Spaces:
Build error
Build error
| import gradio as gr | |
| import tensorflow as tf | |
| from tensorflow.keras.models import Model, load_model | |
| from tensorflow.keras.preprocessing.image import img_to_array | |
| from tensorflow.keras.applications.densenet import preprocess_input | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| import time | |
| from collections import defaultdict | |
# Name of the layer whose output is handed to get_gradcam by classify_image.
layer_name = "conv5_block16_concat"

# Load model: restore the network from its HDF5 file, then load the
# separately stored weight file into it (order matters).
model = load_model("Densenet.h5")
model.load_weights("pretrained_model.h5")
# Define classes: label for each index of the model's prediction vector.
# classify_image indexes this list with the positions of the top scores,
# so the order here must match the order of the model outputs.
class_names = [
    "Cardiomegaly",
    "Emphysema",
    "Effusion",
    "Hernia",
    "Infiltration",
    "Mass",
    "Nodule",
    "Atelectasis",
    "Pneumothorax",
    "Pleural_Thickening",
    "Pneumonia",
    "Fibrosis",
    "Edema",
    "Consolidation",
    "No Finding",
]
# Symptom mapping used by the chatbot.
#
# Each key is a primary symptom the user may report. Its entry holds:
#   questions   - follow-up questions, asked in order
#   diseases    - candidate conditions for this symptom
#   weights_yes - score added to each disease (same position) on a "yes" answer
#   weights_no  - score added to each disease (same position) on any other answer
symptom_data = {
    "Shortness of breath": {
        "questions": [
            "Do you also have chest pain?",
            "Do you feel fatigued often?",
            "Have you noticed swelling in your legs?",
        ],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30],
    },
    "Persistent cough": {
        "questions": [
            "Is your cough dry or with mucus?",
            "Do you experience fever?",
            "Do you have difficulty breathing?",
        ],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20],
    },
    "Sharp chest pain": {
        "questions": [
            "Does it worsen with deep breaths?",
            "Do you feel lightheaded?",
            "Have you had recent trauma or surgery?",
        ],
        "diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
        "weights_yes": [40, 30, 30],
        "weights_no": [15, 20, 25],
    },
}
# User state tracking for the symptom chatbot.
# NOTE: this dict lives at module level, so every caller of chatbot() in this
# process shares one conversation. Keys used: "state", "symptom",
# "current_question", "disease_scores".
user_state = {}


def _find_symptom(text):
    """Return the symptom_data key equal to text, ignoring case and outer spaces, else None."""
    wanted = (text or "").strip().lower()
    for name in symptom_data:
        if name.lower() == wanted:
            return name
    return None


# Chatbot function
def chatbot(user_input, history=None):
    """Advance the symptom conversation by one turn.

    The conversation is a small state machine stored in ``user_state``:
    greeting -> "ask_symptom" -> "ask_duration" -> "follow_up" -> reset.

    Args:
        user_input: Text the user just sent.
        history: Previous chat turns as ``(user_message, bot_message)`` pairs,
            or None/empty for a new chat. It is never modified.

    Returns:
        A new list: ``history`` plus one ``(user_input, reply)`` pair. Every
        path returns this same shape, so it can feed a single chat output.
    """
    # Copy so neither the caller's list nor a shared default is ever mutated.
    history = list(history) if history else []
    answer = (user_input or "").strip().lower()

    def reply(message):
        return history + [(user_input, message)]

    state = user_state.get("state")

    # New conversation (or unrecognised state): greet and wait for a symptom.
    # The greeting already asks for the symptom, so go straight to
    # "ask_symptom" instead of asking a second time and discarding the answer.
    if state not in ("ask_symptom", "ask_duration", "follow_up"):
        user_state.clear()
        user_state["state"] = "ask_symptom"
        return reply("Hello! I'm a medical AI assistant. Please describe your primary symptom.")

    if state == "ask_symptom":
        symptom = _find_symptom(user_input)
        if symptom is None:
            return reply("Please enter a valid symptom: " + ", ".join(symptom_data.keys()))
        user_state["symptom"] = symptom
        user_state["state"] = "ask_duration"
        return reply("How long have you had this symptom? (Less than a week / More than a week)")

    if state == "ask_duration":
        if answer not in ("less than a week", "more than a week"):
            return reply("Please respond with 'Less than a week' or 'More than a week'.")
        if answer == "less than a week":
            # Short-lived symptom: end the conversation without follow-ups.
            user_state.clear()
            return reply("It might be temporary. Monitor symptoms and see a doctor if needed.")
        user_state["state"] = "follow_up"
        user_state["current_question"] = 0
        user_state["disease_scores"] = defaultdict(int)
        return reply(symptom_data[user_state["symptom"]]["questions"][0])

    # state == "follow_up": score the answer, then ask the next question or conclude.
    entry = symptom_data[user_state["symptom"]]
    # Any answer other than "yes" is scored with the "no" weights.
    weights = entry["weights_yes"] if answer == "yes" else entry["weights_no"]
    scores = user_state["disease_scores"]
    for disease, weight in zip(entry["diseases"], weights):
        scores[disease] += weight

    user_state["current_question"] += 1
    if user_state["current_question"] < len(entry["questions"]):
        return reply(entry["questions"][user_state["current_question"]])

    probable_disease = max(scores, key=scores.get)
    user_state.clear()
    return reply(
        f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor."
    )
# Grad-CAM function
def get_gradcam(model, img, layer_name):
    """Build a colour heatmap of what drove the model's top prediction.

    Args:
        model: Keras model to explain.
        img: Image accepted by ``img_to_array``; it is preprocessed here with
            ``preprocess_input`` and given a batch axis.
        layer_name: Name of the layer in ``model`` whose output is visualised.

    Returns:
        A PIL RGB image of the JET-coloured heatmap. It has the spatial size of
        the chosen layer's output (assumed (height, width, channels)), not of
        ``img``.
    """
    img_array = preprocess_input(np.expand_dims(img_to_array(img), axis=0))
    grad_model = Model(
        inputs=model.inputs,
        outputs=[model.get_layer(layer_name).output, model.output],
    )

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Differentiate only the top class score; differentiating the whole
        # prediction vector would mix the gradients of every class together.
        class_score = predictions[:, class_idx]
    grads = tape.gradient(class_score, conv_outputs)[0]

    output = conv_outputs[0]
    # Keep gradients only where both the activation and the gradient are positive.
    guided_grads = tf.cast(output > 0, "float32") * tf.cast(grads > 0, "float32") * grads
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = np.maximum(tf.reduce_sum(tf.multiply(weights, output), axis=-1).numpy(), 0)

    # Scale to [0, 1]; an all-zero map stays zero instead of dividing by zero.
    peak = cam.max()
    if peak > 0:
        cam = cam / peak
    heatmap = np.uint8(255 * cam)

    # OpenCV colour maps are BGR; PIL expects RGB.
    colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
    return Image.fromarray(cv2.cvtColor(colored, cv2.COLOR_BGR2RGB))
# X-ray classification function
def classify_image(img):
    """Classify an X-ray image and produce its Grad-CAM heatmap.

    Args:
        img: Image as a NumPy array, or None when nothing was uploaded.

    Returns:
        A 2-tuple ``(results, heatmap)``. ``results`` is a list of
        ``(class name, score)`` pairs for the four highest scores, best first,
        and ``heatmap`` is the image returned by ``get_gradcam``. When ``img``
        is None, ``results`` is a message asking for an upload and ``heatmap``
        is None.
    """
    if img is None:
        # Nothing uploaded: answer with a hint instead of crashing in cv2.resize.
        return "Please upload a chest X-ray image first.", None

    img = cv2.resize(img, (540, 540))
    batch = np.expand_dims(preprocess_input(img_to_array(img)), axis=0)
    scores = model.predict(batch)[0]

    # argsort is ascending: take the last four and reverse for best-first order.
    top_indices = scores.argsort()[-4:][::-1]
    decoded_predictions = [(class_names[i], float(scores[i])) for i in top_indices]
    return decoded_predictions, get_gradcam(model, img, layer_name)
# Gradio UI
def _reset_chat():
    """Clear the chatbot's conversation state and return an empty history."""
    # Without this the next message would resume the old conversation mid-flow.
    user_state.clear()
    return []


with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")

    # Tab 1: rule-based symptom conversation.
    with gr.Tab("Symptom Chatbot"):
        chatbot_ui = gr.Chatbot(label="Chatbot")
        user_input = gr.Textbox(label="Your Message", interactive=True)
        submit = gr.Button("Send")
        clear_chat = gr.Button("Clear Chat")
        # chatbot receives the message and current history, returns the new history.
        submit.click(chatbot, [user_input, chatbot_ui], [chatbot_ui])
        clear_chat.click(_reset_chat, [], chatbot_ui)

    # Tab 2: X-ray classification with a Grad-CAM heatmap.
    with gr.Tab("X-ray Classification"):
        image_input = gr.Image(type="numpy", label="Upload Chest X-ray", height=250)
        classify_button = gr.Button("Classify X-ray")
        output_text = gr.Text(label="Prediction Results")
        output_image = gr.Image(label="Grad-CAM Heatmap", height=250)
        classify_button.click(classify_image, [image_input], [output_text, output_image])

demo.launch()