Santhosh1705kumar committed on
Commit
b83ec23
·
verified ·
1 Parent(s): 48aaa68

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -54
app.py CHANGED
@@ -1,141 +1,153 @@
1
  import tensorflow as tf
2
  from tensorflow.keras.models import Model, load_model
3
  import gradio as gr
4
- from tensorflow.keras.preprocessing.image import img_to_array
5
- from tensorflow.keras.applications.densenet import preprocess_input
6
  import numpy as np
7
  import cv2
 
8
  from collections import defaultdict
9
- from PIL import Image
 
10
  import matplotlib.pyplot as plt
11
- import time
12
 
13
# Load the pretrained DenseNet chest X-ray classifier and its weights.
model = load_model('Densenet.h5')
model.load_weights("pretrained_model.h5")

# Convolutional layer tapped for Grad-CAM visualizations.
layer_name = 'conv5_block16_concat'

# Class labels, ordered to match the model's output vector.
class_names = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening', 'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding']
 
 
 
 
18
 
19
# Symptom-to-disease mapping: each symptom carries follow-up questions and
# per-disease score weights (index-aligned with "diseases") added for
# yes/no answers.
symptom_data = {
    "Shortness of breath": {
        "questions": [
            "Do you also have chest pain?",
            "Do you feel fatigued often?",
            "Have you noticed swelling in your legs?"
        ],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30]
    },
    "Persistent cough": {
        "questions": [
            "Is your cough dry or with mucus?",
            "Do you experience fever?",
            "Do you have difficulty breathing?"
        ],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20]
    }
}

# Conversation state shared across calls (single user at a time).
user_state = {}

def chatbot(user_input, history):
    """Symptom-triage chatbot.

    Walks a small state machine (greet -> ask_symptom -> ask_duration ->
    follow_up) kept in the module-level ``user_state`` dict and returns
    ``history`` extended with one ``[user_input, bot_reply]`` pair.
    """
    state = user_state.setdefault("state", "greet")

    if state == "greet":
        user_state["state"] = "ask_symptom"
        return history + [[user_input, "Hello! Please describe your primary symptom."]]

    if state == "ask_symptom":
        if user_input not in symptom_data:
            return history + [[user_input, "I don't recognize that symptom. Try: " + ", ".join(symptom_data.keys())]]
        user_state["symptom"] = user_input
        user_state["state"] = "ask_duration"
        return history + [[user_input, "How long have you had this symptom? (Less than a week / More than a week)"]]

    if state == "ask_duration":
        answer = user_input.lower()
        if answer == "less than a week":
            user_state.clear()
            return history + [[user_input, "It might be temporary. Monitor your symptoms and consult a doctor if they persist."]]
        if answer == "more than a week":
            user_state["state"] = "follow_up"
            user_state["current_question"] = 0
            user_state["disease_scores"] = defaultdict(int)
            return history + [[user_input, symptom_data[user_state["symptom"]]["questions"][0]]]
        return history + [[user_input, "Please answer 'Less than a week' or 'More than a week'."]]

    if state == "follow_up":
        symptom = user_state["symptom"]

        # "yes" answers add the heavier weights; anything else the lighter ones.
        weights = symptom_data[symptom]["weights_yes" if user_input.lower() == "yes" else "weights_no"]
        for disease, w in zip(symptom_data[symptom]["diseases"], weights):
            user_state["disease_scores"][disease] += w

        # Ask the next question, or emit the final diagnosis.
        user_state["current_question"] += 1
        if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
            return history + [[user_input, symptom_data[symptom]["questions"][user_state["current_question"]]]]

        probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
        user_state.clear()
        return history + [[user_input, f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor."]]
89
 
90
def get_gradcam(img, model, layer_name):
    """Generate a Grad-CAM overlay for a PIL image.

    Args:
        img: PIL image, presumably already at the model's input size
            (TODO confirm against caller).
        model: Loaded DenseNet classifier.
        layer_name: Name of the convolutional layer whose activations are
            visualized.

    Returns:
        PIL RGB image blending the input with a JET-colored heatmap.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate the top class's score only; differentiating the whole
        # prediction vector mixes gradients from every class into the map
        # (the original computed class_idx but never used it).
        class_score = predictions[:, class_idx]

    output = conv_outputs[0]
    grads = tape.gradient(class_score, conv_outputs)[0]

    # Channel weights = spatially averaged gradients (standard Grad-CAM).
    weights = tf.reduce_mean(grads, axis=(0, 1))
    cam = np.dot(output, weights)
    cam = np.maximum(cam, 0)
    peak = cam.max()
    if peak > 0:  # guard against 0/0 when the map is entirely non-positive
        cam = cam / peak
    cam = cv2.resize(cam, (img.size[0], img.size[1]))

    heatmap = plt.cm.jet(cam)[..., :3]
    overlay = Image.blend(img.convert("RGB"), Image.fromarray((heatmap * 255).astype(np.uint8)), alpha=0.5)

    return overlay
111
 
112
def classify_image(img):
    """Classify an X-ray image and return its Grad-CAM visualization.

    Args:
        img: Image from the Gradio widget (numpy array — presumably uint8
            RGB; TODO confirm against the gr.Image component config).

    Returns:
        Tuple of (predicted class name, PIL overlay image).
    """
    img = cv2.resize(img, (540, 540))
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    predictions = model.predict(img_array)
    top_pred = class_names[np.argmax(predictions)]
    # get_gradcam expects a PIL image (it reads .size and calls .convert);
    # the cv2-resized result is a numpy array, so convert it first — passing
    # the raw array crashed with AttributeError.
    overlay_img = get_gradcam(Image.fromarray(img), model, layer_name)

    return top_pred, overlay_img
123
 
 
124
# Gradio front-end: one tab for the triage chatbot, one for X-ray analysis.
with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")

    # Tab 1: symptom-triage chatbot.
    with gr.Tab("Chatbot"):
        chatbot_ui = gr.Chatbot()
        user_input = gr.Textbox(label="Your Message")
        submit = gr.Button("Send")
        clear_chat = gr.Button("Clear Chat")

        submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
        clear_chat.click(lambda: ([], ""), outputs=[chatbot_ui, user_input])

    # Tab 2: X-ray classification with Grad-CAM overlay.
    with gr.Tab("X-ray Classification"):
        image_input = gr.Image()
        classify_button = gr.Button("Classify")
        output_text = gr.Text()
        output_image = gr.Image()

        classify_button.click(classify_image, [image_input], [output_text, output_image])

demo.launch()
 
1
  import tensorflow as tf
2
  from tensorflow.keras.models import Model, load_model
3
  import gradio as gr
 
 
4
  import numpy as np
5
  import cv2
6
+ import time
7
  from collections import defaultdict
8
+ from tensorflow.keras.preprocessing.image import img_to_array
9
+ from tensorflow.keras.applications.densenet import preprocess_input
10
  import matplotlib.pyplot as plt
11
+ from PIL import Image
12
 
13
# ---- Model setup ----
# DenseNet chest X-ray classifier; Grad-CAM taps the final concat layer.
model = load_model('Densenet.h5')
model.load_weights("pretrained_model.h5")
layer_name = 'conv5_block16_concat'

# Class labels, ordered to match the model's output vector.
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration',
    'Mass', 'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural Thickening',
    'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation', 'No Finding',
]
22
 
23
# Symptom-to-Disease Mapping: each symptom carries follow-up questions and
# per-disease score weights (index-aligned with "diseases") added for
# yes/no answers.
symptom_data = {
    "Shortness of breath": {
        "questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30]
    },
    "Persistent cough": {
        "questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20]
    },
}

# Conversation state shared across calls.
# NOTE(review): module-level state means concurrent users clobber each
# other — consider gr.State for per-session state.
user_state = {}

def chatbot(user_input, history=None):
    """Chatbot for symptom-based diagnosis.

    Drives a small state machine (greet -> ask_symptom -> ask_duration ->
    follow_up) kept in the module-level ``user_state`` dict.

    Args:
        user_input: The user's latest message.
        history: Existing chat history as (user, bot) pairs. Defaults to a
            fresh list — the original used a mutable default (``[]``),
            which is shared across calls and silently accumulates.

    Returns:
        ``history`` extended with one ``(user_input, bot_reply)`` pair.
    """
    if history is None:
        history = []

    if "state" not in user_state:
        user_state["state"] = "greet"

    if user_state["state"] == "greet":
        user_state["state"] = "ask_symptom"
        return history + [(user_input, "Hello! Please describe your primary symptom.")]

    elif user_state["state"] == "ask_symptom":
        if user_input not in symptom_data:
            return history + [(user_input, f"I don't recognize that symptom. Please enter one of these: {', '.join(symptom_data.keys())}")]

        user_state["symptom"] = user_input
        user_state["state"] = "ask_duration"
        return history + [(user_input, "How long have you had this symptom? (Less than a week / More than a week)")]

    elif user_state["state"] == "ask_duration":
        if user_input.lower() == "less than a week":
            user_state.clear()
            return history + [(user_input, "It might be a temporary issue. Monitor symptoms and consult a doctor if they persist.")]
        elif user_input.lower() == "more than a week":
            user_state["state"] = "follow_up"
            user_state["current_question"] = 0
            user_state["disease_scores"] = defaultdict(int)
            return history + [(user_input, symptom_data[user_state["symptom"]]["questions"][0])]
        else:
            return history + [(user_input, "Please respond with 'Less than a week' or 'More than a week'.")]

    elif user_state["state"] == "follow_up":
        symptom = user_state["symptom"]

        # Update disease probability scores: "yes" answers add the heavier
        # weights, anything else the lighter ones (single loop replaces the
        # original duplicated yes/no branches).
        key = "weights_yes" if user_input.lower() == "yes" else "weights_no"
        for i, disease in enumerate(symptom_data[symptom]["diseases"]):
            user_state["disease_scores"][disease] += symptom_data[symptom][key][i]

        # Ask the next question, or emit the final diagnosis.
        user_state["current_question"] += 1
        if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
            return history + [(user_input, symptom_data[symptom]["questions"][user_state["current_question"]])]

        probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
        user_state.clear()
        return history + [(user_input, f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")]
91
 
92
  def get_gradcam(img, model, layer_name):
93
+ """Generate Grad-CAM heatmap for X-ray image."""
94
  img_array = img_to_array(img)
95
  img_array = np.expand_dims(img_array, axis=0)
96
  img_array = preprocess_input(img_array)
97
+
98
  grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
99
+
100
  with tf.GradientTape() as tape:
101
  conv_outputs, predictions = grad_model(img_array)
102
  class_idx = tf.argmax(predictions[0])
103
+
 
104
  grads = tape.gradient(predictions, conv_outputs)[0]
105
+ guided_grads = tf.cast(conv_outputs > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
106
+
107
+ weights = tf.reduce_mean(guided_grads, axis=(0, 1))
108
+ cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)
109
+
110
+ heatmap = np.maximum(cam, 0)
111
+ heatmap /= tf.reduce_max(heatmap)
112
+ heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
113
+
114
+ colormap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
115
+ overlay = cv2.addWeighted(img, 0.5, colormap, 0.5, 0)
116
+
117
  return overlay
118
 
119
def classify_image(img):
    """Classify an X-ray image and return its Grad-CAM visualization.

    Args:
        img: Image from the Gradio widget (numpy array, presumably uint8
            RGB; TODO confirm against the gr.Image component config).

    Returns:
        Tuple of (predicted class name, Grad-CAM overlay image).
    """
    img = cv2.resize(np.array(img), (540, 540))

    # Normalize the channel count: DenseNet preprocessing and the
    # cv2.addWeighted blend in get_gradcam both need exactly 3 channels,
    # but uploads may arrive grayscale or RGBA.
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    img_array = np.expand_dims(img, axis=0)
    img_array = preprocess_input(img_array)

    predictions = model.predict(img_array)
    top_pred = class_names[np.argmax(predictions)]
    overlay_img = get_gradcam(img, model, layer_name)

    return top_pred, overlay_img
130
 
131
# ---- Gradio UI ----
# Two tabs: the triage chatbot and the X-ray classifier with Grad-CAM.
with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")

    # Tab 1: symptom-triage chatbot.
    with gr.Tab("Chatbot"):
        chatbot_ui = gr.Chatbot()
        user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
        submit = gr.Button("Send", variant="primary", interactive=True)
        clear_chat = gr.Button("Clear Chat")

        # Both the Send button and pressing Enter submit the message.
        submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
        user_input.submit(chatbot, [user_input, chatbot_ui], chatbot_ui)
        clear_chat.click(lambda: ([], ""), outputs=[chatbot_ui, user_input])

    # Tab 2: X-ray classification with Grad-CAM overlay.
    with gr.Tab("X-ray Classification"):
        image_input = gr.Image()
        classify_button = gr.Button("Classify")
        output_text = gr.Text()
        output_image = gr.Image()

        classify_button.click(classify_image, [image_input], [output_text, output_image])

demo.launch()