# CHATBOT3 / app.py
# Last commit by Santhosh1705kumar: "Update app.py" (d192207, verified)
import gradio as gr
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import numpy as np
import cv2
from PIL import Image
import time
from collections import defaultdict
# Load the classifier for inference only. compile=False skips restoring the
# saved training configuration (loss/optimizer): predict() and the Grad-CAM
# gradient tape do not need it, and restoring it can fail when the model was
# saved with a custom loss or metric.
model = load_model('Densenet.h5', compile=False)
# Overwrite the loaded weights with the pretrained checkpoint.
model.load_weights("pretrained_model.h5")
# Layer whose activations get_gradcam explains (presumably the final dense
# block output of a DenseNet121-style backbone — TODO confirm).
layer_name = 'conv5_block16_concat'
# Output labels, in model-output order: class_names[i] names score i of the
# prediction vector (classify_image indexes it that way).
# NOTE(review): 15 names are listed here; if the network actually has 14
# outputs (the 14 ChestX-ray14 findings), 'No Finding' can never be reported.
# Confirm against the model's output size.
class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural_Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
    'No Finding',
]
# Follow-up questionnaire for each supported primary symptom:
#   questions   - asked one at a time, in order, by the chatbot
#   diseases    - candidate conditions for this symptom
#   weights_yes - points added to diseases[i] after a "yes" answer
#   weights_no  - points added to diseases[i] after any other answer
# NOTE(review): "Is your cough dry or with mucus?" is not a yes/no question,
# yet answers are only scored as "yes" versus anything else.
symptom_data = {
    symptom: {
        "questions": list(questions),
        "diseases": list(diseases),
        "weights_yes": list(yes_points),
        "weights_no": list(no_points),
    }
    for symptom, questions, diseases, yes_points, no_points in (
        (
            "Shortness of breath",
            ("Do you also have chest pain?",
             "Do you feel fatigued often?",
             "Have you noticed swelling in your legs?"),
            ("Atelectasis", "Emphysema", "Edema"),
            (30, 30, 40),
            (10, 20, 30),
        ),
        (
            "Persistent cough",
            ("Is your cough dry or with mucus?",
             "Do you experience fever?",
             "Do you have difficulty breathing?"),
            ("Pneumonia", "Fibrosis", "Infiltration"),
            (35, 30, 35),
            (10, 15, 20),
        ),
        (
            "Sharp chest pain",
            ("Does it worsen with deep breaths?",
             "Do you feel lightheaded?",
             "Have you had recent trauma or surgery?"),
            ("Pneumothorax", "Effusion", "Cardiomegaly"),
            (40, 30, 30),
            (15, 20, 25),
        ),
    )
}
# Conversation progress for chatbot(). It holds "state" and, while a
# conversation is running, "symptom", "current_question" and "disease_scores".
# NOTE(review): this is one module-level dict, so every browser session served
# by this process shares a single conversation. Isolating concurrent users
# would need per-session state (e.g. gr.State).
user_state = dict()
def _match_symptom(text):
    """Return the ``symptom_data`` key equal to ``text`` ignoring case and
    surrounding whitespace, or None when no supported symptom matches."""
    wanted = text.strip().casefold()
    for name in symptom_data:
        if name.casefold() == wanted:
            return name
    return None


def _next_reply(text):
    """Advance the conversation stored in ``user_state`` and return the AI reply.

    ``text`` is the user's message with surrounding whitespace removed.

    States:
        (none)         -> greet and ask for the primary symptom
        "ask_symptom"  -> validate the symptom against ``symptom_data``
        "ask_duration" -> stop early for symptoms lasting less than a week
        "follow_up"    -> ask the symptom's questions, accumulate disease
                          scores, then report the highest-scoring disease
    """
    state = user_state.get("state")

    if state == "ask_symptom":
        symptom = _match_symptom(text)
        if symptom is None:
            return "Please enter a valid symptom: " + ", ".join(symptom_data.keys())
        user_state["symptom"] = symptom
        user_state["state"] = "ask_duration"
        return "How long have you had this symptom? (Less than a week / More than a week)"

    if state == "ask_duration":
        answer = text.lower()
        if answer not in ("less than a week", "more than a week"):
            return "Please respond with 'Less than a week' or 'More than a week'."
        if answer == "less than a week":
            user_state.clear()
            return "It might be temporary. Monitor symptoms and see a doctor if needed."
        user_state["state"] = "follow_up"
        user_state["current_question"] = 0
        user_state["disease_scores"] = defaultdict(int)
        return symptom_data[user_state["symptom"]]["questions"][0]

    if state == "follow_up":
        entry = symptom_data[user_state["symptom"]]
        # Any answer other than "yes" (case-insensitive) is scored as "no".
        # NOTE(review): the same per-disease weights are added whichever
        # question is being answered, so the final ranking depends only on how
        # many answers were "yes". Confirm whether weights were meant per question.
        points = entry["weights_yes"] if text.lower() == "yes" else entry["weights_no"]
        for disease, weight in zip(entry["diseases"], points):
            user_state["disease_scores"][disease] += weight
        user_state["current_question"] += 1
        if user_state["current_question"] < len(entry["questions"]):
            return entry["questions"][user_state["current_question"]]
        scores = user_state["disease_scores"]
        # Ties go to the disease listed first in symptom_data.
        probable_disease = max(scores, key=scores.get)
        user_state.clear()
        return f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor."

    # No conversation yet (or an unrecognised state): greet and go straight to
    # "ask_symptom", so the user's next message is taken as the symptom
    # instead of being discarded by an extra greeting step.
    user_state.clear()
    user_state["state"] = "ask_symptom"
    return "Hello! I'm a medical AI assistant. Please describe your primary symptom."


def chatbot(user_input, history=None):
    """Handle one chat turn and return the updated Gradio chat history.

    Args:
        user_input: Text the user just submitted.
        history: Existing history as ``(user_message, bot_message)`` pairs
            (Gradio ``Chatbot`` tuple format). None is treated as empty. The
            passed-in list is never mutated.

    Returns:
        list: A new history with ``(user_input, reply)`` appended, or a copy of
        the history unchanged when the input is empty or whitespace.

    NOTE(review): conversation progress lives in the module-level
    ``user_state`` dict, so all sessions served by this process share it.
    """
    history = list(history or [])
    text = (user_input or "").strip()
    if not text:
        # Ignore blank submissions so they don't advance the state machine
        # (a blank follow-up answer would otherwise be scored as "no").
        return history
    return history + [(user_input, _next_reply(text))]
def get_gradcam(model, img, layer_name):
    """Return a guided Grad-CAM heatmap for the model's top-scoring class.

    Args:
        model: Keras model containing a layer named ``layer_name``.
        img: Raw (not yet preprocessed) image array — presumably H x W x 3 RGB,
            as passed by ``classify_image``; confirm for other callers. It is
            converted with ``img_to_array`` and ``preprocess_input`` here.
        layer_name: Name of the convolutional layer whose activations are
            weighted by the class-score gradients.

    Returns:
        PIL.Image.Image: RGB JET-colormap heatmap with the same height and
        width as ``img``. It is a uniform lowest-colour image when no location
        contributes positively.
    """
    img_array = preprocess_input(np.expand_dims(img_to_array(img), axis=0))
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate only the top class's score. Taking the gradient of the
        # whole prediction vector would sum the gradients of every class.
        class_score = predictions[:, class_idx]
    grads = tape.gradient(class_score, conv_outputs)[0]
    output = conv_outputs[0]

    # Guided Grad-CAM: keep positive gradients at positively activated units,
    # then average over axes 0 and 1 (spatial, assuming channels-last feature
    # maps) to get one weight per channel.
    guided_grads = tf.cast(output > 0, "float32") * tf.cast(grads > 0, "float32") * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = np.maximum(tf.reduce_sum(weights * output, axis=-1).numpy(), 0)

    # The CAM has the layer's spatial resolution; upsample it to the input size.
    height, width = img_array.shape[1:3]
    cam = cv2.resize(cam, (width, height))

    # Scale to [0, 1]. Guard against an all-zero map (0 / 0 would give NaN).
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    heatmap = np.uint8(np.clip(255 * cam, 0, 255))

    # applyColorMap returns BGR, while PIL interprets arrays as RGB.
    colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    return Image.fromarray(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB))
def classify_image(img, size=(540, 540), top_k=4):
    """Classify a chest X-ray and return the top predictions plus a Grad-CAM heatmap.

    Args:
        img: Uploaded image as a NumPy array (``gr.Image(type="numpy")``), or
            None when nothing was uploaded. Grayscale and RGBA arrays are
            converted to 3-channel RGB.
        size: ``(width, height)`` the image is resized to before inference
            (``cv2.resize`` convention). Defaults to the original 540 x 540.
        top_k: Number of highest-scoring classes to return.

    Returns:
        tuple: ``(predictions, heatmap)``. ``predictions`` is a list of
        ``(class_name, score)`` pairs sorted by descending score. ``heatmap``
        is the image produced by ``get_gradcam`` for the resized input.

    Raises:
        gr.Error: If no image was provided (shown to the user by Gradio).
    """
    if img is None:
        raise gr.Error("Please upload a chest X-ray image first.")
    img = np.asarray(img)
    # Normalise the channel layout to RGB before resizing and preprocessing.
    if img.ndim == 2 or (img.ndim == 3 and img.shape[-1] == 1):
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.ndim == 3 and img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    img = cv2.resize(img, size)

    batch = np.expand_dims(preprocess_input(img_to_array(img)), axis=0)
    # verbose=0 keeps predict() from logging a progress bar on every request.
    scores = model.predict(batch, verbose=0)[0]
    # Highest scores first. Slicing after reversing keeps top_k=0 empty
    # (argsort()[-0:] would return every index).
    top_indices = scores.argsort()[::-1][:max(int(top_k), 0)]
    decoded_predictions = [(class_names[i], float(scores[i])) for i in top_indices]
    return decoded_predictions, get_gradcam(model, img, layer_name)
def _reset_chat():
    """Clear the shared conversation state and return an empty chat history."""
    user_state.clear()
    return []


# Gradio UI: a symptom-chatbot tab and an X-ray classification tab.
with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")
    with gr.Tab("Symptom Chatbot"):
        chatbot_ui = gr.Chatbot(label="Chatbot")
        user_input = gr.Textbox(label="Your Message", interactive=True)
        submit = gr.Button("Send")
        clear_chat = gr.Button("Clear Chat")
        # Send on button click or Enter, then empty the textbox for the next message.
        submit.click(chatbot, [user_input, chatbot_ui], [chatbot_ui]).then(
            lambda: "", None, user_input
        )
        user_input.submit(chatbot, [user_input, chatbot_ui], [chatbot_ui]).then(
            lambda: "", None, user_input
        )
        # Clearing must also reset user_state. Otherwise the next message would
        # continue the previous conversation mid-flow.
        clear_chat.click(_reset_chat, None, chatbot_ui)
    with gr.Tab("X-ray Classification"):
        image_input = gr.Image(type="numpy", label="Upload Chest X-ray", height=250)
        classify_button = gr.Button("Classify X-ray")
        output_text = gr.Text(label="Prediction Results")
        output_image = gr.Image(label="Grad-CAM Heatmap", height=250)
        classify_button.click(classify_image, [image_input], [output_text, output_image])

if __name__ == "__main__":
    demo.launch()