Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,141 +1,153 @@
|
|
| 1 |
import tensorflow as tf
|
| 2 |
from tensorflow.keras.models import Model, load_model
|
| 3 |
import gradio as gr
|
| 4 |
-
from tensorflow.keras.preprocessing.image import img_to_array
|
| 5 |
-
from tensorflow.keras.applications.densenet import preprocess_input
|
| 6 |
import numpy as np
|
| 7 |
import cv2
|
|
|
|
| 8 |
from collections import defaultdict
|
| 9 |
-
from
|
|
|
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
-
import
|
| 12 |
|
| 13 |
-
# Load
|
| 14 |
model = load_model('Densenet.h5')
|
| 15 |
model.load_weights("pretrained_model.h5")
|
| 16 |
layer_name = 'conv5_block16_concat'
|
| 17 |
-
class_names = [
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
# Symptom-to-
|
| 20 |
symptom_data = {
|
| 21 |
"Shortness of breath": {
|
| 22 |
-
"questions": [
|
| 23 |
-
"Do you also have chest pain?",
|
| 24 |
-
"Do you feel fatigued often?",
|
| 25 |
-
"Have you noticed swelling in your legs?"
|
| 26 |
-
],
|
| 27 |
"diseases": ["Atelectasis", "Emphysema", "Edema"],
|
| 28 |
"weights_yes": [30, 30, 40],
|
| 29 |
"weights_no": [10, 20, 30]
|
| 30 |
},
|
| 31 |
"Persistent cough": {
|
| 32 |
-
"questions": [
|
| 33 |
-
"Is your cough dry or with mucus?",
|
| 34 |
-
"Do you experience fever?",
|
| 35 |
-
"Do you have difficulty breathing?"
|
| 36 |
-
],
|
| 37 |
"diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
|
| 38 |
"weights_yes": [35, 30, 35],
|
| 39 |
"weights_no": [10, 15, 20]
|
| 40 |
-
}
|
| 41 |
}
|
|
|
|
|
|
|
| 42 |
user_state = {}
|
| 43 |
|
| 44 |
-
def chatbot(user_input, history):
|
|
|
|
| 45 |
if "state" not in user_state:
|
| 46 |
user_state["state"] = "greet"
|
| 47 |
-
|
| 48 |
if user_state["state"] == "greet":
|
| 49 |
user_state["state"] = "ask_symptom"
|
| 50 |
-
return history + [
|
| 51 |
-
|
| 52 |
elif user_state["state"] == "ask_symptom":
|
| 53 |
if user_input not in symptom_data:
|
| 54 |
-
return history + [
|
|
|
|
| 55 |
user_state["symptom"] = user_input
|
| 56 |
user_state["state"] = "ask_duration"
|
| 57 |
-
return history + [
|
| 58 |
-
|
| 59 |
elif user_state["state"] == "ask_duration":
|
| 60 |
if user_input.lower() == "less than a week":
|
| 61 |
user_state.clear()
|
| 62 |
-
return history + [
|
| 63 |
elif user_input.lower() == "more than a week":
|
| 64 |
user_state["state"] = "follow_up"
|
| 65 |
user_state["current_question"] = 0
|
| 66 |
user_state["disease_scores"] = defaultdict(int)
|
| 67 |
-
return history + [
|
| 68 |
else:
|
| 69 |
-
return history + [
|
| 70 |
-
|
| 71 |
elif user_state["state"] == "follow_up":
|
| 72 |
symptom = user_state["symptom"]
|
| 73 |
question_index = user_state["current_question"]
|
| 74 |
-
|
|
|
|
| 75 |
if user_input.lower() == "yes":
|
| 76 |
for i, disease in enumerate(symptom_data[symptom]["diseases"]):
|
| 77 |
user_state["disease_scores"][disease] += symptom_data[symptom]["weights_yes"][i]
|
| 78 |
else:
|
| 79 |
for i, disease in enumerate(symptom_data[symptom]["diseases"]):
|
| 80 |
user_state["disease_scores"][disease] += symptom_data[symptom]["weights_no"][i]
|
| 81 |
-
|
|
|
|
| 82 |
user_state["current_question"] += 1
|
| 83 |
if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
|
| 84 |
-
return history + [
|
| 85 |
-
|
| 86 |
probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
|
| 87 |
user_state.clear()
|
| 88 |
-
return history + [
|
| 89 |
|
| 90 |
def get_gradcam(img, model, layer_name):
|
|
|
|
| 91 |
img_array = img_to_array(img)
|
| 92 |
img_array = np.expand_dims(img_array, axis=0)
|
| 93 |
img_array = preprocess_input(img_array)
|
| 94 |
-
|
| 95 |
grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
|
|
|
|
| 96 |
with tf.GradientTape() as tape:
|
| 97 |
conv_outputs, predictions = grad_model(img_array)
|
| 98 |
class_idx = tf.argmax(predictions[0])
|
| 99 |
-
|
| 100 |
-
output = conv_outputs[0]
|
| 101 |
grads = tape.gradient(predictions, conv_outputs)[0]
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
cam =
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
heatmap =
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
return overlay
|
| 111 |
|
| 112 |
def classify_image(img):
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
img_array = np.expand_dims(
|
| 116 |
img_array = preprocess_input(img_array)
|
| 117 |
-
|
| 118 |
predictions = model.predict(img_array)
|
| 119 |
-
top_pred = class_names[np.argmax(predictions)]
|
| 120 |
overlay_img = get_gradcam(img, model, layer_name)
|
| 121 |
-
|
|
|
|
| 122 |
return top_pred, overlay_img
|
| 123 |
|
|
|
|
| 124 |
with gr.Blocks() as demo:
|
| 125 |
gr.Markdown("# Medical AI Assistant")
|
|
|
|
| 126 |
with gr.Tab("Chatbot"):
|
| 127 |
chatbot_ui = gr.Chatbot()
|
| 128 |
-
user_input = gr.Textbox(label="Your Message")
|
| 129 |
-
submit = gr.Button("Send")
|
| 130 |
clear_chat = gr.Button("Clear Chat")
|
|
|
|
| 131 |
submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
|
|
|
|
| 132 |
clear_chat.click(lambda: ([], ""), outputs=[chatbot_ui, user_input])
|
| 133 |
-
|
| 134 |
with gr.Tab("X-ray Classification"):
|
| 135 |
image_input = gr.Image()
|
| 136 |
classify_button = gr.Button("Classify")
|
| 137 |
output_text = gr.Text()
|
| 138 |
output_image = gr.Image()
|
|
|
|
| 139 |
classify_button.click(classify_image, [image_input], [output_text, output_image])
|
| 140 |
|
| 141 |
demo.launch()
|
|
|
|
| 1 |
import tensorflow as tf
|
| 2 |
from tensorflow.keras.models import Model, load_model
|
| 3 |
import gradio as gr
|
|
|
|
|
|
|
| 4 |
import numpy as np
|
| 5 |
import cv2
|
| 6 |
+
import time
|
| 7 |
from collections import defaultdict
|
| 8 |
+
from tensorflow.keras.preprocessing.image import img_to_array
|
| 9 |
+
from tensorflow.keras.applications.densenet import preprocess_input
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
+
from PIL import Image
|
| 12 |
|
| 13 |
+
# Load Model
|
| 14 |
model = load_model('Densenet.h5')
|
| 15 |
model.load_weights("pretrained_model.h5")
|
| 16 |
layer_name = 'conv5_block16_concat'
|
| 17 |
+
class_names = [
|
| 18 |
+
'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
|
| 19 |
+
'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural Thickening', 'Pneumonia',
|
| 20 |
+
'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
|
| 21 |
+
]
|
| 22 |
|
| 23 |
+
# Symptom-to-Disease Mapping
|
| 24 |
symptom_data = {
|
| 25 |
"Shortness of breath": {
|
| 26 |
+
"questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
"diseases": ["Atelectasis", "Emphysema", "Edema"],
|
| 28 |
"weights_yes": [30, 30, 40],
|
| 29 |
"weights_no": [10, 20, 30]
|
| 30 |
},
|
| 31 |
"Persistent cough": {
|
| 32 |
+
"questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
"diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
|
| 34 |
"weights_yes": [35, 30, 35],
|
| 35 |
"weights_no": [10, 15, 20]
|
| 36 |
+
},
|
| 37 |
}
|
| 38 |
+
|
| 39 |
+
# User State
|
| 40 |
user_state = {}
|
| 41 |
|
| 42 |
+
def chatbot(user_input, history=[]):
|
| 43 |
+
"""Chatbot for symptom-based diagnosis."""
|
| 44 |
if "state" not in user_state:
|
| 45 |
user_state["state"] = "greet"
|
| 46 |
+
|
| 47 |
if user_state["state"] == "greet":
|
| 48 |
user_state["state"] = "ask_symptom"
|
| 49 |
+
return history + [(user_input, "Hello! Please describe your primary symptom.")]
|
| 50 |
+
|
| 51 |
elif user_state["state"] == "ask_symptom":
|
| 52 |
if user_input not in symptom_data:
|
| 53 |
+
return history + [(user_input, f"I don't recognize that symptom. Please enter one of these: {', '.join(symptom_data.keys())}")]
|
| 54 |
+
|
| 55 |
user_state["symptom"] = user_input
|
| 56 |
user_state["state"] = "ask_duration"
|
| 57 |
+
return history + [(user_input, "How long have you had this symptom? (Less than a week / More than a week)")]
|
| 58 |
+
|
| 59 |
elif user_state["state"] == "ask_duration":
|
| 60 |
if user_input.lower() == "less than a week":
|
| 61 |
user_state.clear()
|
| 62 |
+
return history + [(user_input, "It might be a temporary issue. Monitor symptoms and consult a doctor if they persist.")]
|
| 63 |
elif user_input.lower() == "more than a week":
|
| 64 |
user_state["state"] = "follow_up"
|
| 65 |
user_state["current_question"] = 0
|
| 66 |
user_state["disease_scores"] = defaultdict(int)
|
| 67 |
+
return history + [(user_input, symptom_data[user_state["symptom"]]["questions"][0])]
|
| 68 |
else:
|
| 69 |
+
return history + [(user_input, "Please respond with 'Less than a week' or 'More than a week'.")]
|
| 70 |
+
|
| 71 |
elif user_state["state"] == "follow_up":
|
| 72 |
symptom = user_state["symptom"]
|
| 73 |
question_index = user_state["current_question"]
|
| 74 |
+
|
| 75 |
+
# Update disease probability scores
|
| 76 |
if user_input.lower() == "yes":
|
| 77 |
for i, disease in enumerate(symptom_data[symptom]["diseases"]):
|
| 78 |
user_state["disease_scores"][disease] += symptom_data[symptom]["weights_yes"][i]
|
| 79 |
else:
|
| 80 |
for i, disease in enumerate(symptom_data[symptom]["diseases"]):
|
| 81 |
user_state["disease_scores"][disease] += symptom_data[symptom]["weights_no"][i]
|
| 82 |
+
|
| 83 |
+
# Next question or final diagnosis
|
| 84 |
user_state["current_question"] += 1
|
| 85 |
if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
|
| 86 |
+
return history + [(user_input, symptom_data[symptom]["questions"][user_state["current_question"]])]
|
| 87 |
+
|
| 88 |
probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
|
| 89 |
user_state.clear()
|
| 90 |
+
return history + [(user_input, f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")]
|
| 91 |
|
| 92 |
def get_gradcam(img, model, layer_name):
|
| 93 |
+
"""Generate Grad-CAM heatmap for X-ray image."""
|
| 94 |
img_array = img_to_array(img)
|
| 95 |
img_array = np.expand_dims(img_array, axis=0)
|
| 96 |
img_array = preprocess_input(img_array)
|
| 97 |
+
|
| 98 |
grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])
|
| 99 |
+
|
| 100 |
with tf.GradientTape() as tape:
|
| 101 |
conv_outputs, predictions = grad_model(img_array)
|
| 102 |
class_idx = tf.argmax(predictions[0])
|
| 103 |
+
|
|
|
|
| 104 |
grads = tape.gradient(predictions, conv_outputs)[0]
|
| 105 |
+
guided_grads = tf.cast(conv_outputs > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
|
| 106 |
+
|
| 107 |
+
weights = tf.reduce_mean(guided_grads, axis=(0, 1))
|
| 108 |
+
cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)
|
| 109 |
+
|
| 110 |
+
heatmap = np.maximum(cam, 0)
|
| 111 |
+
heatmap /= tf.reduce_max(heatmap)
|
| 112 |
+
heatmap = cv2.resize(heatmap.numpy(), (img.shape[1], img.shape[0]))
|
| 113 |
+
|
| 114 |
+
colormap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
|
| 115 |
+
overlay = cv2.addWeighted(img, 0.5, colormap, 0.5, 0)
|
| 116 |
+
|
| 117 |
return overlay
|
| 118 |
|
| 119 |
def classify_image(img):
|
| 120 |
+
"""Classify X-ray image and return Grad-CAM visualization."""
|
| 121 |
+
img = cv2.resize(np.array(img), (540, 540))
|
| 122 |
+
img_array = np.expand_dims(img, axis=0)
|
| 123 |
img_array = preprocess_input(img_array)
|
| 124 |
+
|
| 125 |
predictions = model.predict(img_array)
|
|
|
|
| 126 |
overlay_img = get_gradcam(img, model, layer_name)
|
| 127 |
+
|
| 128 |
+
top_pred = class_names[np.argmax(predictions)]
|
| 129 |
return top_pred, overlay_img
|
| 130 |
|
| 131 |
+
# Gradio UI
|
| 132 |
with gr.Blocks() as demo:
|
| 133 |
gr.Markdown("# Medical AI Assistant")
|
| 134 |
+
|
| 135 |
with gr.Tab("Chatbot"):
|
| 136 |
chatbot_ui = gr.Chatbot()
|
| 137 |
+
user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
|
| 138 |
+
submit = gr.Button("Send", variant="primary", interactive=True)
|
| 139 |
clear_chat = gr.Button("Clear Chat")
|
| 140 |
+
|
| 141 |
submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
|
| 142 |
+
user_input.submit(chatbot, [user_input, chatbot_ui], chatbot_ui)
|
| 143 |
clear_chat.click(lambda: ([], ""), outputs=[chatbot_ui, user_input])
|
| 144 |
+
|
| 145 |
with gr.Tab("X-ray Classification"):
|
| 146 |
image_input = gr.Image()
|
| 147 |
classify_button = gr.Button("Classify")
|
| 148 |
output_text = gr.Text()
|
| 149 |
output_image = gr.Image()
|
| 150 |
+
|
| 151 |
classify_button.click(classify_image, [image_input], [output_text, output_image])
|
| 152 |
|
| 153 |
demo.launch()
|