Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,44 +9,32 @@ from PIL import Image
|
|
| 9 |
import time
|
| 10 |
from collections import defaultdict
|
| 11 |
|
| 12 |
-
# Load
|
| 13 |
model = load_model('Densenet.h5')
|
| 14 |
model.load_weights("pretrained_model.h5")
|
| 15 |
layer_name = 'conv5_block16_concat'
|
| 16 |
|
| 17 |
-
# Define
|
| 18 |
class_names = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
|
| 19 |
'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening', 'Pneumonia',
|
| 20 |
'Fibrosis', 'Edema', 'Consolidation', 'No Finding']
|
| 21 |
|
| 22 |
-
# Symptom
|
| 23 |
symptom_data = {
|
| 24 |
"Shortness of breath": {
|
| 25 |
-
"questions": [
|
| 26 |
-
"Do you also have chest pain?",
|
| 27 |
-
"Do you feel fatigued often?",
|
| 28 |
-
"Have you noticed swelling in your legs?"
|
| 29 |
-
],
|
| 30 |
"diseases": ["Atelectasis", "Emphysema", "Edema"],
|
| 31 |
"weights_yes": [30, 30, 40],
|
| 32 |
"weights_no": [10, 20, 30]
|
| 33 |
},
|
| 34 |
"Persistent cough": {
|
| 35 |
-
"questions": [
|
| 36 |
-
"Is your cough dry or with mucus?",
|
| 37 |
-
"Do you experience fever?",
|
| 38 |
-
"Do you have difficulty breathing?"
|
| 39 |
-
],
|
| 40 |
"diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
|
| 41 |
"weights_yes": [35, 30, 35],
|
| 42 |
"weights_no": [10, 15, 20]
|
| 43 |
},
|
| 44 |
"Sharp chest pain": {
|
| 45 |
-
"questions": [
|
| 46 |
-
"Does it worsen with deep breaths?",
|
| 47 |
-
"Do you feel lightheaded?",
|
| 48 |
-
"Have you had recent trauma or surgery?"
|
| 49 |
-
],
|
| 50 |
"diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
|
| 51 |
"weights_yes": [40, 30, 30],
|
| 52 |
"weights_no": [15, 20, 25]
|
|
@@ -57,34 +45,38 @@ symptom_data = {
|
|
| 57 |
user_state = {}
|
| 58 |
|
| 59 |
# Chatbot function
|
| 60 |
-
def chatbot(user_input, history):
|
| 61 |
if "state" not in user_state:
|
| 62 |
user_state["state"] = "greet"
|
| 63 |
-
|
|
|
|
|
|
|
| 64 |
if user_state["state"] == "greet":
|
| 65 |
user_state["state"] = "ask_symptom"
|
| 66 |
-
return history + [
|
| 67 |
-
|
| 68 |
-
|
| 69 |
if user_input not in symptom_data:
|
| 70 |
-
return history + [
|
|
|
|
| 71 |
user_state["symptom"] = user_input
|
| 72 |
user_state["state"] = "ask_duration"
|
| 73 |
-
return history + [
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
| 76 |
if user_input.lower() == "less than a week":
|
| 77 |
user_state.clear()
|
| 78 |
-
return history + [
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
elif user_state["state"] == "follow_up":
|
| 88 |
symptom = user_state["symptom"]
|
| 89 |
question_index = user_state["current_question"]
|
| 90 |
|
|
@@ -97,18 +89,15 @@ def chatbot(user_input, history):
|
|
| 97 |
|
| 98 |
user_state["current_question"] += 1
|
| 99 |
if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
|
| 100 |
-
return history + [
|
| 101 |
|
| 102 |
probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
|
| 103 |
user_state.clear()
|
| 104 |
-
return history + [
|
| 105 |
|
| 106 |
# Grad-CAM function
|
| 107 |
def get_gradcam(model, img, layer_name):
|
| 108 |
-
img_array = img_to_array(img)
|
| 109 |
-
img_array = np.expand_dims(img_array, axis=0)
|
| 110 |
-
img_array = preprocess_input(img_array)
|
| 111 |
-
|
| 112 |
grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
|
| 113 |
|
| 114 |
with tf.GradientTape() as tape:
|
|
@@ -118,50 +107,37 @@ def get_gradcam(model, img, layer_name):
|
|
| 118 |
output = conv_outputs[0]
|
| 119 |
grads = tape.gradient(predictions, conv_outputs)[0]
|
| 120 |
guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
|
| 121 |
-
|
| 122 |
weights = tf.reduce_mean(guided_grads, axis=(0, 1))
|
| 123 |
-
cam = tf.reduce_sum(tf.multiply(weights, output), axis=-1)
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
heatmap /= tf.reduce_max(heatmap)
|
| 127 |
-
heatmap = np.uint8(255 * heatmap)
|
| 128 |
-
heatmap = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
|
| 129 |
-
|
| 130 |
-
return Image.fromarray(heatmap)
|
| 131 |
|
| 132 |
# X-ray classification function
|
| 133 |
def classify_image(img):
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
img_array = img_to_array(img)
|
| 137 |
-
img_array = np.expand_dims(img_array, axis=0)
|
| 138 |
-
img_array = preprocess_input(img_array)
|
| 139 |
-
|
| 140 |
-
predictions = model.predict(img_array)
|
| 141 |
top_indices = predictions[0].argsort()[-4:][::-1]
|
| 142 |
decoded_predictions = [(class_names[i], float(predictions[0][i])) for i in top_indices]
|
| 143 |
-
|
| 144 |
-
heatmap = get_gradcam(model, img, layer_name)
|
| 145 |
-
return decoded_predictions, heatmap
|
| 146 |
|
| 147 |
# Gradio UI
|
| 148 |
with gr.Blocks() as demo:
|
| 149 |
gr.Markdown("# Medical AI Assistant")
|
| 150 |
|
| 151 |
with gr.Tab("Symptom Chatbot"):
|
| 152 |
-
chatbot_ui = gr.Chatbot()
|
| 153 |
-
user_input = gr.Textbox(
|
| 154 |
submit = gr.Button("Send")
|
| 155 |
clear_chat = gr.Button("Clear Chat")
|
| 156 |
|
| 157 |
-
submit.click(chatbot, [user_input, chatbot_ui], [chatbot_ui
|
| 158 |
-
clear_chat.click(lambda:
|
| 159 |
|
| 160 |
with gr.Tab("X-ray Classification"):
|
| 161 |
-
image_input = gr.Image(type="numpy")
|
| 162 |
-
classify_button = gr.Button("Classify")
|
| 163 |
-
output_text = gr.Text()
|
| 164 |
-
output_image = gr.Image()
|
| 165 |
|
| 166 |
classify_button.click(classify_image, [image_input], [output_text, output_image])
|
| 167 |
|
|
|
|
| 9 |
import time
|
| 10 |
from collections import defaultdict
|
| 11 |
|
| 12 |
+
# Load model
|
| 13 |
model = load_model('Densenet.h5')
|
| 14 |
model.load_weights("pretrained_model.h5")
|
| 15 |
layer_name = 'conv5_block16_concat'
|
| 16 |
|
| 17 |
+
# Define classes
|
| 18 |
class_names = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
|
| 19 |
'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening', 'Pneumonia',
|
| 20 |
'Fibrosis', 'Edema', 'Consolidation', 'No Finding']
|
| 21 |
|
| 22 |
+
# Symptom mapping
|
| 23 |
symptom_data = {
|
| 24 |
"Shortness of breath": {
|
| 25 |
+
"questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
"diseases": ["Atelectasis", "Emphysema", "Edema"],
|
| 27 |
"weights_yes": [30, 30, 40],
|
| 28 |
"weights_no": [10, 20, 30]
|
| 29 |
},
|
| 30 |
"Persistent cough": {
|
| 31 |
+
"questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
|
| 33 |
"weights_yes": [35, 30, 35],
|
| 34 |
"weights_no": [10, 15, 20]
|
| 35 |
},
|
| 36 |
"Sharp chest pain": {
|
| 37 |
+
"questions": ["Does it worsen with deep breaths?", "Do you feel lightheaded?", "Have you had recent trauma or surgery?"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
"diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
|
| 39 |
"weights_yes": [40, 30, 30],
|
| 40 |
"weights_no": [15, 20, 25]
|
|
|
|
| 45 |
user_state = {}
|
| 46 |
|
| 47 |
# Chatbot function
|
| 48 |
+
def chatbot(user_input, history=[]):
|
| 49 |
if "state" not in user_state:
|
| 50 |
user_state["state"] = "greet"
|
| 51 |
+
history.append(("User", user_input))
|
| 52 |
+
return history, "Hello! I'm a medical AI assistant. Please describe your primary symptom."
|
| 53 |
+
|
| 54 |
if user_state["state"] == "greet":
|
| 55 |
user_state["state"] = "ask_symptom"
|
| 56 |
+
return history + [("User", user_input), ("AI", "Please describe your primary symptom.")]
|
| 57 |
+
|
| 58 |
+
if user_state["state"] == "ask_symptom":
|
| 59 |
if user_input not in symptom_data:
|
| 60 |
+
return history + [("User", user_input), ("AI", "Please enter a valid symptom: " + ", ".join(symptom_data.keys()))]
|
| 61 |
+
|
| 62 |
user_state["symptom"] = user_input
|
| 63 |
user_state["state"] = "ask_duration"
|
| 64 |
+
return history + [("User", user_input), ("AI", "How long have you had this symptom? (Less than a week / More than a week)")]
|
| 65 |
+
|
| 66 |
+
if user_state["state"] == "ask_duration":
|
| 67 |
+
if user_input.lower() not in ["less than a week", "more than a week"]:
|
| 68 |
+
return history + [("User", user_input), ("AI", "Please respond with 'Less than a week' or 'More than a week'.")]
|
| 69 |
+
|
| 70 |
if user_input.lower() == "less than a week":
|
| 71 |
user_state.clear()
|
| 72 |
+
return history + [("User", user_input), ("AI", "It might be temporary. Monitor symptoms and see a doctor if needed.")]
|
| 73 |
+
|
| 74 |
+
user_state["state"] = "follow_up"
|
| 75 |
+
user_state["current_question"] = 0
|
| 76 |
+
user_state["disease_scores"] = defaultdict(int)
|
| 77 |
+
return history + [("User", user_input), ("AI", symptom_data[user_state['symptom']]['questions'][0])]
|
| 78 |
+
|
| 79 |
+
if user_state["state"] == "follow_up":
|
|
|
|
|
|
|
| 80 |
symptom = user_state["symptom"]
|
| 81 |
question_index = user_state["current_question"]
|
| 82 |
|
|
|
|
| 89 |
|
| 90 |
user_state["current_question"] += 1
|
| 91 |
if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
|
| 92 |
+
return history + [("User", user_input), ("AI", symptom_data[symptom]["questions"][user_state["current_question"]])]
|
| 93 |
|
| 94 |
probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
|
| 95 |
user_state.clear()
|
| 96 |
+
return history + [("User", user_input), (f"AI", f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")]
|
| 97 |
|
| 98 |
# Grad-CAM function
|
| 99 |
def get_gradcam(model, img, layer_name):
|
| 100 |
+
img_array = preprocess_input(np.expand_dims(img_to_array(img), axis=0))
|
|
|
|
|
|
|
|
|
|
| 101 |
grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
|
| 102 |
|
| 103 |
with tf.GradientTape() as tape:
|
|
|
|
| 107 |
output = conv_outputs[0]
|
| 108 |
grads = tape.gradient(predictions, conv_outputs)[0]
|
| 109 |
guided_grads = tf.cast(output > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
|
|
|
|
| 110 |
weights = tf.reduce_mean(guided_grads, axis=(0, 1))
|
| 111 |
+
cam = np.maximum(tf.reduce_sum(tf.multiply(weights, output), axis=-1), 0)
|
| 112 |
+
heatmap = np.uint8(255 * cam / tf.reduce_max(cam))
|
| 113 |
+
return Image.fromarray(cv2.applyColorMap(heatmap, cv2.COLORMAP_JET))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 114 |
|
| 115 |
# X-ray classification function
|
| 116 |
def classify_image(img):
|
| 117 |
+
img = cv2.resize(img, (540, 540))
|
| 118 |
+
predictions = model.predict(np.expand_dims(preprocess_input(img_to_array(img)), axis=0))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 119 |
top_indices = predictions[0].argsort()[-4:][::-1]
|
| 120 |
decoded_predictions = [(class_names[i], float(predictions[0][i])) for i in top_indices]
|
| 121 |
+
return decoded_predictions, get_gradcam(model, img, layer_name)
|
|
|
|
|
|
|
| 122 |
|
| 123 |
# Gradio UI
|
| 124 |
with gr.Blocks() as demo:
|
| 125 |
gr.Markdown("# Medical AI Assistant")
|
| 126 |
|
| 127 |
with gr.Tab("Symptom Chatbot"):
|
| 128 |
+
chatbot_ui = gr.Chatbot(label="Chatbot")
|
| 129 |
+
user_input = gr.Textbox(label="Your Message", interactive=True)
|
| 130 |
submit = gr.Button("Send")
|
| 131 |
clear_chat = gr.Button("Clear Chat")
|
| 132 |
|
| 133 |
+
submit.click(chatbot, [user_input, chatbot_ui], [chatbot_ui])
|
| 134 |
+
clear_chat.click(lambda: [], [], chatbot_ui)
|
| 135 |
|
| 136 |
with gr.Tab("X-ray Classification"):
|
| 137 |
+
image_input = gr.Image(type="numpy", label="Upload Chest X-ray", height=250)
|
| 138 |
+
classify_button = gr.Button("Classify X-ray")
|
| 139 |
+
output_text = gr.Text(label="Prediction Results")
|
| 140 |
+
output_image = gr.Image(label="Grad-CAM Heatmap", height=250)
|
| 141 |
|
| 142 |
classify_button.click(classify_image, [image_input], [output_text, output_image])
|
| 143 |
|