Santhosh1705kumar commited on
Commit
d192207
·
verified ·
1 Parent(s): 1db26d9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -70
app.py CHANGED
@@ -9,44 +9,32 @@ from PIL import Image
9
  import time
10
  from collections import defaultdict
11
 
12
- # Load the X-ray classification model
13
  model = load_model('Densenet.h5')
14
  model.load_weights("pretrained_model.h5")
15
  layer_name = 'conv5_block16_concat'
16
 
17
- # Define class names
18
  class_names = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
19
  'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening', 'Pneumonia',
20
  'Fibrosis', 'Edema', 'Consolidation', 'No Finding']
21
 
22
- # Symptom-to-disease mapping
23
  symptom_data = {
24
  "Shortness of breath": {
25
- "questions": [
26
- "Do you also have chest pain?",
27
- "Do you feel fatigued often?",
28
- "Have you noticed swelling in your legs?"
29
- ],
30
  "diseases": ["Atelectasis", "Emphysema", "Edema"],
31
  "weights_yes": [30, 30, 40],
32
  "weights_no": [10, 20, 30]
33
  },
34
  "Persistent cough": {
35
- "questions": [
36
- "Is your cough dry or with mucus?",
37
- "Do you experience fever?",
38
- "Do you have difficulty breathing?"
39
- ],
40
  "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
41
  "weights_yes": [35, 30, 35],
42
  "weights_no": [10, 15, 20]
43
  },
44
  "Sharp chest pain": {
45
- "questions": [
46
- "Does it worsen with deep breaths?",
47
- "Do you feel lightheaded?",
48
- "Have you had recent trauma or surgery?"
49
- ],
50
  "diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
51
  "weights_yes": [40, 30, 30],
52
  "weights_no": [15, 20, 25]
@@ -57,34 +45,38 @@ symptom_data = {
57
  user_state = {}
58
 
59
  # Chatbot function
60
- def chatbot(user_input, history):
61
  if "state" not in user_state:
62
  user_state["state"] = "greet"
63
-
 
 
64
  if user_state["state"] == "greet":
65
  user_state["state"] = "ask_symptom"
66
- return history + [{"role": "assistant", "content": "Hello! I'm a medical AI assistant. Please describe your primary symptom."}]
67
-
68
- elif user_state["state"] == "ask_symptom":
69
  if user_input not in symptom_data:
70
- return history + [{"role": "assistant", "content": "I don't recognize that symptom. Please enter one of these: " + ", ".join(symptom_data.keys())}]
 
71
  user_state["symptom"] = user_input
72
  user_state["state"] = "ask_duration"
73
- return history + [{"role": "assistant", "content": "How long have you been experiencing this symptom? (Less than a week / More than a week)"}]
74
-
75
- elif user_state["state"] == "ask_duration":
 
 
 
76
  if user_input.lower() == "less than a week":
77
  user_state.clear()
78
- return history + [{"role": "assistant", "content": "It might be temporary. Monitor symptoms and see a doctor if needed."}]
79
- elif user_input.lower() == "more than a week":
80
- user_state["state"] = "follow_up"
81
- user_state["current_question"] = 0
82
- user_state["disease_scores"] = defaultdict(int)
83
- return history + [{"role": "assistant", "content": symptom_data[user_state['symptom']]['questions'][0]}]
84
- else:
85
- return history + [{"role": "assistant", "content": "Please respond with 'Less than a week' or 'More than a week'."}]
86
-
87
- elif user_state["state"] == "follow_up":
88
  symptom = user_state["symptom"]
89
  question_index = user_state["current_question"]
90
 
@@ -97,18 +89,15 @@ def chatbot(user_input, history):
97
 
98
  user_state["current_question"] += 1
99
  if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
100
- return history + [{"role": "assistant", "content": symptom_data[symptom]["questions"][user_state["current_question"]]}]
101
 
102
  probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
103
  user_state.clear()
104
- return history + [{"role": "assistant", "content": f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor."}]
105
 
106
  # Grad-CAM function
107
  def get_gradcam(model, img, layer_name):
108
- img_array = img_to_array(img)
109
- img_array = np.expand_dims(img_array, axis=0)
110
- img_array = preprocess_input(img_array)
111
-
112
  grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
113
 
114
  with tf.GradientTape() as tape:
@@ -118,50 +107,37 @@ def get_gradcam(model, img, layer_name):
118
  output = conv_outputs[0]
119
  grads = tape.gradient(predictions, conv_outputs)[0]
120
  guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
121
-
122
  weights = tf.reduce_mean(guided_grads, axis=(0, 1))
123
- cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)
124
-
125
- heatmap = np.maximum(cam, 0)
126
- heatmap /= tf.reduce_max(heatmap)
127
- heatmap = np.uint8(255 * heatmap)
128
- heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
129
-
130
- return Image.fromarray(heatmap)
131
 
132
  # X-ray classification function
133
  def classify_image(img):
134
- target_size = (540, 540)
135
- img = cv2.resize(img, target_size, interpolation=cv2.INTER_AREA)
136
- img_array = img_to_array(img)
137
- img_array = np.expand_dims(img_array, axis=0)
138
- img_array = preprocess_input(img_array)
139
-
140
- predictions = model.predict(img_array)
141
  top_indices = predictions[0].argsort()[-4:][::-1]
142
  decoded_predictions = [(class_names[i], float(predictions[0][i])) for i in top_indices]
143
-
144
- heatmap = get_gradcam(model, img, layer_name)
145
- return decoded_predictions, heatmap
146
 
147
  # Gradio UI
148
  with gr.Blocks() as demo:
149
  gr.Markdown("# Medical AI Assistant")
150
 
151
  with gr.Tab("Symptom Chatbot"):
152
- chatbot_ui = gr.Chatbot()
153
- user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
154
  submit = gr.Button("Send")
155
  clear_chat = gr.Button("Clear Chat")
156
 
157
- submit.click(chatbot, [user_input, chatbot_ui], [chatbot_ui, user_input])
158
- clear_chat.click(lambda: ([], ""), outputs=[chatbot_ui, user_input])
159
 
160
  with gr.Tab("X-ray Classification"):
161
- image_input = gr.Image(type="numpy")
162
- classify_button = gr.Button("Classify")
163
- output_text = gr.Text()
164
- output_image = gr.Image()
165
 
166
  classify_button.click(classify_image, [image_input], [output_text, output_image])
167
 
 
9
  import time
10
  from collections import defaultdict
11
 
12
+ # Load model
13
  model = load_model('Densenet.h5')
14
  model.load_weights("pretrained_model.h5")
15
  layer_name = 'conv5_block16_concat'
16
 
17
+ # Define classes
18
  class_names = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
19
  'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening', 'Pneumonia',
20
  'Fibrosis', 'Edema', 'Consolidation', 'No Finding']
21
 
22
+ # Symptom mapping
23
  symptom_data = {
24
  "Shortness of breath": {
25
+ "questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
 
 
 
 
26
  "diseases": ["Atelectasis", "Emphysema", "Edema"],
27
  "weights_yes": [30, 30, 40],
28
  "weights_no": [10, 20, 30]
29
  },
30
  "Persistent cough": {
31
+ "questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
 
 
 
 
32
  "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
33
  "weights_yes": [35, 30, 35],
34
  "weights_no": [10, 15, 20]
35
  },
36
  "Sharp chest pain": {
37
+ "questions": ["Does it worsen with deep breaths?", "Do you feel lightheaded?", "Have you had recent trauma or surgery?"],
 
 
 
 
38
  "diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
39
  "weights_yes": [40, 30, 30],
40
  "weights_no": [15, 20, 25]
 
45
  user_state = {}
46
 
47
  # Chatbot function
48
+ def chatbot(user_input, history=[]):
49
  if "state" not in user_state:
50
  user_state["state"] = "greet"
51
+ history.append(("User", user_input))
52
+ return history, "Hello! I'm a medical AI assistant. Please describe your primary symptom."
53
+
54
  if user_state["state"] == "greet":
55
  user_state["state"] = "ask_symptom"
56
+ return history + [("User", user_input), ("AI", "Please describe your primary symptom.")]
57
+
58
+ if user_state["state"] == "ask_symptom":
59
  if user_input not in symptom_data:
60
+ return history + [("User", user_input), ("AI", "Please enter a valid symptom: " + ", ".join(symptom_data.keys()))]
61
+
62
  user_state["symptom"] = user_input
63
  user_state["state"] = "ask_duration"
64
+ return history + [("User", user_input), ("AI", "How long have you had this symptom? (Less than a week / More than a week)")]
65
+
66
+ if user_state["state"] == "ask_duration":
67
+ if user_input.lower() not in ["less than a week", "more than a week"]:
68
+ return history + [("User", user_input), ("AI", "Please respond with 'Less than a week' or 'More than a week'.")]
69
+
70
  if user_input.lower() == "less than a week":
71
  user_state.clear()
72
+ return history + [("User", user_input), ("AI", "It might be temporary. Monitor symptoms and see a doctor if needed.")]
73
+
74
+ user_state["state"] = "follow_up"
75
+ user_state["current_question"] = 0
76
+ user_state["disease_scores"] = defaultdict(int)
77
+ return history + [("User", user_input), ("AI", symptom_data[user_state['symptom']]['questions'][0])]
78
+
79
+ if user_state["state"] == "follow_up":
 
 
80
  symptom = user_state["symptom"]
81
  question_index = user_state["current_question"]
82
 
 
89
 
90
  user_state["current_question"] += 1
91
  if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
92
+ return history + [("User", user_input), ("AI", symptom_data[symptom]["questions"][user_state["current_question"]])]
93
 
94
  probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
95
  user_state.clear()
96
+ return history + [("User", user_input), (f"AI", f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")]
97
 
98
  # Grad-CAM function
99
  def get_gradcam(model, img, layer_name):
100
+ img_array = preprocess_input(np.expand_dims(img_to_array(img), axis=0))
 
 
 
101
  grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
102
 
103
  with tf.GradientTape() as tape:
 
107
  output = conv_outputs[0]
108
  grads = tape.gradient(predictions, conv_outputs)[0]
109
  guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
 
110
  weights = tf.reduce_mean(guided_grads, axis=(0, 1))
111
+ cam = np.maximum(tf.reduce_sum(tf.multiply(weights, output), axis=-1), 0)
112
+ heatmap = np.uint8(255 * cam / tf.reduce_max(cam))
113
+ return Image.fromarray(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET))
 
 
 
 
 
114
 
115
  # X-ray classification function
116
  def classify_image(img):
117
+ img = cv2.resize(img, (540, 540))
118
+ predictions = model.predict(np.expand_dims(preprocess_input(img_to_array(img)), axis=0))
 
 
 
 
 
119
  top_indices = predictions[0].argsort()[-4:][::-1]
120
  decoded_predictions = [(class_names[i], float(predictions[0][i])) for i in top_indices]
121
+ return decoded_predictions, get_gradcam(model, img, layer_name)
 
 
122
 
123
  # Gradio UI
124
  with gr.Blocks() as demo:
125
  gr.Markdown("# Medical AI Assistant")
126
 
127
  with gr.Tab("Symptom Chatbot"):
128
+ chatbot_ui = gr.Chatbot(label="Chatbot")
129
+ user_input = gr.Textbox(label="Your Message", interactive=True)
130
  submit = gr.Button("Send")
131
  clear_chat = gr.Button("Clear Chat")
132
 
133
+ submit.click(chatbot, [user_input, chatbot_ui], [chatbot_ui])
134
+ clear_chat.click(lambda: [], [], chatbot_ui)
135
 
136
  with gr.Tab("X-ray Classification"):
137
+ image_input = gr.Image(type="numpy", label="Upload Chest X-ray", height=250)
138
+ classify_button = gr.Button("Classify X-ray")
139
+ output_text = gr.Text(label="Prediction Results")
140
+ output_image = gr.Image(label="Grad-CAM Heatmap", height=250)
141
 
142
  classify_button.click(classify_image, [image_input], [output_text, output_image])
143