# CHATBOT2 / app.py
# Author: Santhosh1705kumar
# Last commit: "Update app.py" (b83ec23, verified)
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
import gradio as gr
import numpy as np
import cv2
import time
from collections import defaultdict
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import matplotlib.pyplot as plt
from PIL import Image
# Load Model
_MODEL_PATH = 'Densenet.h5'
_WEIGHTS_PATH = "pretrained_model.h5"

# Restore the saved network, then overwrite its parameters with the separate
# weights checkpoint.
# NOTE(review): presumably pretrained_model.h5 holds the intended chest X-ray
# weights and is layer-compatible with Densenet.h5 — confirm.
model = load_model(_MODEL_PATH)
model.load_weights(_WEIGHTS_PATH)

# Feature layer whose activations/gradients are used to build the Grad-CAM map.
layer_name = 'conv5_block16_concat'
# Labels indexed by the position of the model's highest-scoring output unit.
# NOTE(review): this order must match the label order the weights were trained
# with, and the list has 15 entries ("No Finding" last) — confirm the model
# really has 15 outputs.
class_names = [
    'Cardiomegaly',
    'Emphysema',
    'Effusion',
    'Hernia',
    'Infiltration',
    'Mass',
    'Nodule',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Thickening',
    'Pneumonia',
    'Fibrosis',
    'Edema',
    'Consolidation',
    'No Finding',
]
# Symptom-to-Disease Mapping
# For each recognised symptom: the follow-up questions (asked in order), the
# candidate diseases, and the score each disease gains when the answer is
# "yes" (weights_yes) or anything else (weights_no). Weights are indexed by
# disease, not by question.
# NOTE(review): because every question adds the same per-disease weights, the
# result barely depends on which question was answered "yes"; with the current
# numbers "Shortness of breath" always ends in Edema. Also "Is your cough dry
# or with mucus?" is not a yes/no question. Confirm the intended scoring.
symptom_data = {
    "Shortness of breath": {
        "questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30],
    },
    "Persistent cough": {
        "questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20],
    },
}

# User State
# Conversation state for the chatbot: "state", plus "symptom",
# "current_question" and "disease_scores" once they are known.
# NOTE(review): this single module-level dict is shared by every Gradio
# session, so concurrent users will interfere with each other's dialogue.
user_state = {}


def _match_symptom(text):
    """Return the symptom_data key equal to *text* ignoring case, or None."""
    wanted = text.casefold()
    for symptom in symptom_data:
        if symptom.casefold() == wanted:
            return symptom
    return None


def _reply(history, user_input, message):
    """Return a new history list with the (user_input, message) turn appended."""
    return history + [(user_input, message)]


def chatbot(user_input, history=None):
    """Chatbot for symptom-based diagnosis.

    Advances a small state machine kept in the module-level ``user_state``:
    greet -> ask_symptom -> ask_duration -> follow_up, after which the state
    is cleared and the next message starts a new conversation.

    Args:
        user_input: the user's latest message. Surrounding whitespace and
            letter case are ignored when interpreting it.
        history: the chat so far as a list of ``(user, bot)`` pairs, or
            ``None`` (the default) for an empty chat. It is not modified.

    Returns:
        A new list: ``history`` followed by one ``(user_input, reply)`` pair.
    """
    history = [] if history is None else history
    text = (user_input or "").strip()
    answer = text.lower()
    state = user_state.get("state", "greet")

    if state == "ask_symptom":
        symptom = _match_symptom(text)
        if symptom is None:
            return _reply(history, user_input, f"I don't recognize that symptom. Please enter one of these: {', '.join(symptom_data.keys())}")
        user_state["symptom"] = symptom  # store the canonical key
        user_state["state"] = "ask_duration"
        return _reply(history, user_input, "How long have you had this symptom? (Less than a week / More than a week)")

    if state == "ask_duration":
        if answer == "less than a week":
            user_state.clear()
            return _reply(history, user_input, "It might be a temporary issue. Monitor symptoms and consult a doctor if they persist.")
        if answer == "more than a week":
            user_state["state"] = "follow_up"
            user_state["current_question"] = 0
            user_state["disease_scores"] = defaultdict(int)
            return _reply(history, user_input, symptom_data[user_state["symptom"]]["questions"][0])
        return _reply(history, user_input, "Please respond with 'Less than a week' or 'More than a week'.")

    if state == "follow_up":
        info = symptom_data[user_state["symptom"]]
        scores = user_state["disease_scores"]
        # Update disease probability scores; any answer other than "yes"
        # counts as "no".
        weights = info["weights_yes"] if answer == "yes" else info["weights_no"]
        for disease, weight in zip(info["diseases"], weights):
            scores[disease] += weight
        # Next question or final diagnosis
        user_state["current_question"] += 1
        if user_state["current_question"] < len(info["questions"]):
            return _reply(history, user_input, info["questions"][user_state["current_question"]])
        # Ties go to the disease scored first (dict insertion order).
        probable_disease = max(scores, key=scores.get)
        user_state.clear()
        return _reply(history, user_input, f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")

    # "greet" (first message, or after a finished conversation) — and any
    # unexpected state — starts over; the message text itself is not used.
    user_state.clear()
    user_state["state"] = "ask_symptom"
    return _reply(history, user_input, "Hello! Please describe your primary symptom.")
def get_gradcam(img, model, layer_name):
    """Generate Grad-CAM heatmap for X-ray image.

    Computes a guided Grad-CAM map for the model's highest-scoring class and
    blends it 50/50 with ``img``.

    Args:
        img: the image to explain. NOTE(review): assumed to be an H x W x 3
            uint8 RGB array (what classify_image passes) — cv2.addWeighted
            requires it to match the 3-channel uint8 colormap.
        model: Keras model whose single output is the class-score vector.
        layer_name: name of the convolutional layer whose feature maps are
            weighted to form the map.

    Returns:
        uint8 array with the same shape as ``img``: image and heatmap overlay.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)
    grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate only the predicted class's score, not the sum of all
        # class outputs.
        class_score = predictions[:, class_idx]
    grads = tape.gradient(class_score, conv_outputs)[0]   # (h, w, c)
    feature_maps = conv_outputs[0]                        # (h, w, c)

    # Guided gradients: keep only positive gradients at positive activations.
    guided_grads = tf.cast(feature_maps > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))  # one weight per channel
    cam = tf.reduce_sum(feature_maps * weights, axis=-1)  # (h, w)

    heatmap = np.maximum(cam.numpy(), 0.0)
    peak = heatmap.max()
    if peak > 0:  # an all-zero map would otherwise become NaN
        heatmap /= peak

    # cv2.resize takes (width, height).
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))
    colormap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    # applyColorMap produces BGR; convert so it blends correctly with RGB input.
    colormap = cv2.cvtColor(colormap, cv2.COLOR_BGR2RGB)
    overlay = cv2.addWeighted(img, 0.5, colormap, 0.5, 0)
    return overlay
def classify_image(img):
    """Classify X-ray image and return Grad-CAM visualization.

    Args:
        img: the image from the Gradio ``gr.Image`` input, or ``None`` when
            nothing has been uploaded. Grayscale and RGBA inputs are converted
            to 3-channel RGB before use.

    Returns:
        ``(label, overlay)``: ``label`` is the ``class_names`` entry of the
        highest model output, and ``overlay`` is the 540x540 image blended
        with its Grad-CAM heatmap. Without an image, returns a prompt message
        and ``None``.
    """
    if img is None:
        return "Please upload an X-ray image.", None

    img = np.array(img)
    # The model and the overlay blend both need 3 channels.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
    elif img.shape[-1] == 4:
        img = cv2.cvtColor(img, cv2.COLOR_RGBA2RGB)
    img = cv2.resize(img, (540, 540))

    # Work on a float copy so preprocessing never touches `img`, which is
    # reused unchanged for the overlay.
    img_array = np.expand_dims(img, axis=0).astype("float32")
    img_array = preprocess_input(img_array)
    predictions = model.predict(img_array)
    top_pred = class_names[int(np.argmax(predictions))]

    overlay_img = get_gradcam(img, model, layer_name)
    return top_pred, overlay_img
# Gradio UI


def _reset_chat():
    """Clear the chat window and input box and forget the dialogue state."""
    # Without this the bot would resume the abandoned conversation (e.g. keep
    # asking follow-up questions) after the user pressed "Clear Chat".
    user_state.clear()
    return [], ""


with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")

    with gr.Tab("Chatbot"):
        chatbot_ui = gr.Chatbot()
        user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
        submit = gr.Button("Send", variant="primary", interactive=True)
        clear_chat = gr.Button("Clear Chat")
        # The button and pressing Enter in the textbox do the same thing.
        submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
        user_input.submit(chatbot, [user_input, chatbot_ui], chatbot_ui)
        clear_chat.click(_reset_chat, outputs=[chatbot_ui, user_input])

    with gr.Tab("X-ray Classification"):
        image_input = gr.Image()
        classify_button = gr.Button("Classify")
        output_text = gr.Text()
        output_image = gr.Image()
        classify_button.click(classify_image, [image_input], [output_text, output_image])

# Launch only when run as a script, so importing this module (e.g. by Gradio's
# reload mode or tests) does not start a server.
if __name__ == "__main__":
    demo.launch()