import tensorflow as tf
from tensorflow.keras.models import Model, load_model
import gradio as gr
import numpy as np
import cv2
import time
from collections import defaultdict
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import matplotlib.pyplot as plt
from PIL import Image
# Load Model
# The model is read from 'Densenet.h5' and then weights from
# "pretrained_model.h5" are loaded into it, in that order.
model = load_model('Densenet.h5')
model.load_weights("pretrained_model.h5")
# Name of the layer whose output get_gradcam uses to build the heatmap.
layer_name = 'conv5_block16_concat'
# Labels looked up with the argmax of the model's prediction in classify_image.
# NOTE(review): presumably the order (and count) matches the model's output
# units -- confirm against the training setup.
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
    'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural Thickening', 'Pneumonia',
    'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
]
# Symptom-to-Disease Mapping
# Each key is a symptom the chatbot accepts. Per symptom:
#   "questions"   - follow-up questions, asked in list order
#   "diseases"    - candidate conditions that accumulate a score
#   "weights_yes" - added to each disease (same index) when the answer is "yes"
#   "weights_no"  - added to each disease (same index) for any other answer
# NOTE(review): the weights are indexed per disease, not per question, so every
# answer adjusts all candidate diseases at once -- confirm this is intended.
symptom_data = {
    "Shortness of breath": {
        "questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30]
    },
    "Persistent cough": {
        "questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20]
    },
}
# User State
# Conversation state read and written by chatbot(). Keys used there:
# "state", "symptom", "current_question" and "disease_scores".
# NOTE(review): this is a single module-level dict, so it looks like it is
# shared by every browser session served by this process -- confirm whether
# per-session state is needed.
user_state = {}
def chatbot(user_input, history=None):
    """Advance the symptom-based diagnosis conversation by one turn.

    The conversation is a small state machine stored in the module-level
    ``user_state`` dict: greet -> ask_symptom -> ask_duration -> follow_up.
    Free-text answers are compared after stripping surrounding whitespace
    and lower-casing, so "  persistent COUGH " selects "Persistent cough".

    Args:
        user_input: The message the user just typed.
        history: Previous ``(user_message, bot_reply)`` pairs; ``None`` is
            treated as an empty conversation.

    Returns:
        A new list made of ``history`` plus the ``(user_input, reply)`` pair
        for this turn. ``history`` itself is not modified.
    """
    history = [] if history is None else history
    answer = (user_input or "").strip().lower()

    def reply(text):
        # Always build a fresh list so the caller's history is left untouched.
        return history + [(user_input, text)]

    state = user_state.setdefault("state", "greet")

    if state == "ask_symptom":
        # Resolve to the canonical key so later lookups in symptom_data work
        # regardless of how the user capitalised the symptom.
        symptom = next((name for name in symptom_data if name.lower() == answer), None)
        if symptom is None:
            options = ", ".join(symptom_data)
            return reply(f"I don't recognize that symptom. Please enter one of these: {options}")
        user_state["symptom"] = symptom
        user_state["state"] = "ask_duration"
        return reply("How long have you had this symptom? (Less than a week / More than a week)")

    if state == "ask_duration":
        if answer == "less than a week":
            user_state.clear()
            return reply("It might be a temporary issue. Monitor symptoms and consult a doctor if they persist.")
        if answer == "more than a week":
            user_state["state"] = "follow_up"
            user_state["current_question"] = 0
            user_state["disease_scores"] = defaultdict(int)
            return reply(symptom_data[user_state["symptom"]]["questions"][0])
        return reply("Please respond with 'Less than a week' or 'More than a week'.")

    if state == "follow_up":
        entry = symptom_data[user_state["symptom"]]
        scores = user_state["disease_scores"]

        # Update disease scores: "yes" uses weights_yes, anything else is
        # counted as "no" (some questions are not strictly yes/no).
        weights = entry["weights_yes"] if answer == "yes" else entry["weights_no"]
        for disease, weight in zip(entry["diseases"], weights):
            scores[disease] += weight

        # Next question or final diagnosis
        user_state["current_question"] += 1
        next_index = user_state["current_question"]
        if next_index < len(entry["questions"]):
            return reply(entry["questions"][next_index])

        probable_disease = max(scores, key=scores.get)
        user_state.clear()
        return reply(f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")

    # "greet", or an unexpected/corrupted state: start over instead of
    # falling through and returning None to the UI.
    user_state.clear()
    user_state["state"] = "ask_symptom"
    return reply("Hello! Please describe your primary symptom.")
def get_gradcam(img, model, layer_name):
    """Generate a Grad-CAM overlay for an X-ray image.

    The heatmap highlights the regions of ``layer_name``'s output that most
    increase the score of the class the model predicts for ``img``.

    Args:
        img: Image array of shape (H, W, 3). It is blended with the colour
            map via ``cv2.addWeighted``, so it must have the same dtype and
            channel count as the colour map (uint8, 3 channels).
            NOTE(review): assumed to be RGB, as delivered by ``gr.Image``.
        model: Keras model used for the prediction.
        layer_name: Name of the convolutional layer to visualise.

    Returns:
        The input image blended 50/50 with the colour-mapped heatmap.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate only the predicted class score; using the whole
        # prediction vector would mix the evidence of every class together.
        class_score = tf.gather(predictions[0], class_idx)

    grads = tape.gradient(class_score, conv_outputs)[0]
    # Drop the batch axis so the spatial axes are (0, 1) for the mean below.
    conv_outputs = conv_outputs[0]

    guided_grads = tf.cast(conv_outputs > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)

    heatmap = np.maximum(cam.numpy(), 0)
    peak = heatmap.max()
    if peak > 0:
        # Skip normalisation for an all-zero map to avoid dividing by zero.
        heatmap = heatmap / peak
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

    colormap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert so it matches the RGB input image.
    colormap = cv2.cvtColor(colormap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(img, 0.5, colormap, 0.5, 0)
    return overlay
def classify_image(img):
    """Classify an X-ray image and return a Grad-CAM visualization.

    Args:
        img: Image from the UI (array-like), or ``None`` when nothing has
            been uploaded.

    Returns:
        A ``(label, overlay)`` tuple: the entry of ``class_names`` with the
        highest predicted score and the Grad-CAM overlay image. When ``img``
        is ``None`` the label is a prompt to upload an image and the overlay
        is ``None``.
    """
    if img is None:
        # "Classify" was pressed without an image; report it instead of
        # crashing inside cv2.resize.
        return "Please upload an X-ray image first.", None

    img = np.asarray(img)
    # Normalise to three channels: preprocessing and the Grad-CAM overlay
    # both work on 3-channel images.
    if img.ndim == 2 or img.shape[-1] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)

    # NOTE(review): 540x540 is presumably the input size the model expects --
    # confirm against the saved model.
    img = cv2.resize(img, (540, 540))
    img_array = np.expand_dims(img, axis=0)
    img_array = preprocess_input(img_array)

    predictions = model.predict(img_array)
    overlay_img = get_gradcam(img, model, layer_name)
    top_pred = class_names[np.argmax(predictions)]
    return top_pred, overlay_img
# Gradio UI
def _reset_chat():
    """Forget the current conversation and blank the chat widgets.

    Clearing ``user_state`` as well as the widgets makes the next message
    start a new conversation instead of resuming a half-finished one.

    Returns:
        ``([], "")`` -- the new values for the chat history and the textbox.
    """
    user_state.clear()
    return [], ""


with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")

    # Tab 1: symptom chatbot. The button and the Enter key both run the same
    # handler on the current message and chat history.
    with gr.Tab("Chatbot"):
        chatbot_ui = gr.Chatbot()
        user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
        submit = gr.Button("Send", variant="primary", interactive=True)
        clear_chat = gr.Button("Clear Chat")

        submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
        user_input.submit(chatbot, [user_input, chatbot_ui], chatbot_ui)
        clear_chat.click(_reset_chat, outputs=[chatbot_ui, user_input])

    # Tab 2: X-ray classification with the predicted label and Grad-CAM overlay.
    with gr.Tab("X-ray Classification"):
        image_input = gr.Image()
        classify_button = gr.Button("Classify")
        output_text = gr.Text()
        output_image = gr.Image()

        classify_button.click(classify_image, [image_input], [output_text, output_image])

demo.launch()
|