File size: 6,323 Bytes
d33786f
48aaa68
4db7613
d33786f
 
b83ec23
decc13a
b83ec23
 
d33786f
b83ec23
75166d9
b83ec23
d33786f
 
 
b83ec23
 
 
 
 
75166d9
b83ec23
e95e76d
 
b83ec23
536d560
 
 
 
 
b83ec23
536d560
 
 
b83ec23
e95e76d
b83ec23
 
decc13a
fe48a9c
b83ec23
 
e95e76d
 
b83ec23
e95e76d
 
b83ec23
 
48aaa68
 
b83ec23
 
48aaa68
e95e76d
b83ec23
 
48aaa68
 
 
b83ec23
48aaa68
 
 
 
b83ec23
48aaa68
b83ec23
 
48aaa68
e95e76d
 
b83ec23
 
e95e76d
 
 
 
 
 
b83ec23
 
e95e76d
 
b83ec23
 
e95e76d
 
b83ec23
48aaa68
 
b83ec23
48aaa68
 
 
b83ec23
48aaa68
b83ec23
48aaa68
 
 
b83ec23
48aaa68
b83ec23
 
 
 
 
 
 
 
 
 
 
 
48aaa68
d33786f
 
b83ec23
 
 
d33786f
b83ec23
48aaa68
 
b83ec23
 
48aaa68
d33786f
b83ec23
d33786f
 
b83ec23
d33786f
 
b83ec23
 
d33786f
b83ec23
48aaa68
b83ec23
d33786f
b83ec23
d33786f
 
 
 
48aaa68
b83ec23
48aaa68
4db7613
48aaa68
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
import gradio as gr
import numpy as np
import cv2
import time
from collections import defaultdict
from tensorflow.keras.preprocessing.image import img_to_array
from tensorflow.keras.applications.densenet import preprocess_input
import matplotlib.pyplot as plt
from PIL import Image

# Load Model
# The architecture comes from 'Densenet.h5'; fine-tuned weights are applied
# separately from 'pretrained_model.h5' on top of it.
model = load_model('Densenet.h5')
model.load_weights("pretrained_model.h5")
# Target layer for Grad-CAM: the final concatenation of DenseNet's block 5.
layer_name = 'conv5_block16_concat'
# Output labels in the model's prediction-index order.
# NOTE(review): looks like the ChestX-ray14 label set plus "No Finding" —
# confirm the ordering matches the training label encoding.
class_names = [
    'Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
    'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural Thickening', 'Pneumonia',
    'Fibrosis', 'Edema', 'Consolidation', 'No Finding'
]

# Symptom-to-Disease Mapping
# Each primary symptom maps to:
#   - "questions":   follow-up yes/no questions asked in order
#   - "diseases":    candidate conditions to score
#   - "weights_yes": score added to diseases[i] when the user answers "yes"
#   - "weights_no":  score added to diseases[i] for any other answer
# Note: weights are indexed per *disease*, and the same weight vector is
# applied on every follow-up question (see chatbot's "follow_up" state).
symptom_data = {
    "Shortness of breath": {
        "questions": ["Do you also have chest pain?", "Do you feel fatigued often?", "Have you noticed swelling in your legs?"],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30]
    },
    "Persistent cough": {
        "questions": ["Is your cough dry or with mucus?", "Do you experience fever?", "Do you have difficulty breathing?"],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20]
    },
}

# User State
# Module-level conversation state — supports a single user session at a time.
user_state = {}

def chatbot(user_input, history=None):
    """Drive the symptom-triage conversation one turn at a time.

    State machine: greet -> ask_symptom -> ask_duration -> follow_up,
    tracked in the module-level ``user_state`` dict (single-session).

    Args:
        user_input: The user's latest message.
        history: Gradio chat history as a list of (user, bot) tuples.
            Defaults to a fresh list.

    Returns:
        The updated history with the bot's reply appended.
    """
    # Bug fix: the original used `history=[]`, a mutable default shared
    # across calls — messages would leak between conversations whenever
    # the function is invoked without an explicit history.
    if history is None:
        history = []

    if "state" not in user_state:
        user_state["state"] = "greet"

    if user_state["state"] == "greet":
        user_state["state"] = "ask_symptom"
        return history + [(user_input, "Hello! Please describe your primary symptom.")]

    elif user_state["state"] == "ask_symptom":
        if user_input not in symptom_data:
            return history + [(user_input, f"I don't recognize that symptom. Please enter one of these: {', '.join(symptom_data.keys())}")]

        user_state["symptom"] = user_input
        user_state["state"] = "ask_duration"
        return history + [(user_input, "How long have you had this symptom? (Less than a week / More than a week)")]

    elif user_state["state"] == "ask_duration":
        if user_input.lower() == "less than a week":
            user_state.clear()
            return history + [(user_input, "It might be a temporary issue. Monitor symptoms and consult a doctor if they persist.")]
        elif user_input.lower() == "more than a week":
            user_state["state"] = "follow_up"
            user_state["current_question"] = 0
            user_state["disease_scores"] = defaultdict(int)
            return history + [(user_input, symptom_data[user_state["symptom"]]["questions"][0])]
        else:
            return history + [(user_input, "Please respond with 'Less than a week' or 'More than a week'.")]

    elif user_state["state"] == "follow_up":
        symptom = user_state["symptom"]

        # Any reply other than an exact "yes" is scored with the "no" weights.
        weights_key = "weights_yes" if user_input.lower() == "yes" else "weights_no"
        for i, disease in enumerate(symptom_data[symptom]["diseases"]):
            user_state["disease_scores"][disease] += symptom_data[symptom][weights_key][i]

        # Ask the next follow-up question, or finish with a diagnosis.
        user_state["current_question"] += 1
        if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
            return history + [(user_input, symptom_data[symptom]["questions"][user_state["current_question"]])]

        probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
        user_state.clear()
        return history + [(user_input, f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor.")]

    # Defensive fallback: the original fell off the end and returned None
    # for an unrecognized state, which Gradio renders as an empty chat.
    user_state.clear()
    return history + [(user_input, "Let's start over. Please describe your primary symptom.")]

def get_gradcam(img, model, layer_name):
    """Generate a Grad-CAM overlay for an X-ray image.

    Args:
        img: Input image as an HxWx3 array, already resized to the model's
            expected input size (assumed uint8 so cv2.addWeighted can blend
            it with the uint8 colormap — TODO confirm at the caller).
        model: The loaded Keras classification model.
        layer_name: Name of the convolutional layer to visualize.

    Returns:
        The input image blended 50/50 with a JET-colormapped heatmap.
    """
    img_array = img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array = preprocess_input(img_array)

    grad_model = Model(inputs=model.inputs, outputs=[model.get_layer(layer_name).output, model.output])

    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        class_idx = tf.argmax(predictions[0])
        # Bug fix: differentiate the *predicted class's* score. The original
        # computed class_idx but never used it, taking the gradient of the
        # whole prediction vector, which mixes all classes into the map.
        class_score = predictions[:, class_idx]

    grads = tape.gradient(class_score, conv_outputs)
    # Bug fix: drop the batch dimension from BOTH tensors. The original
    # unbatched only grads, so reduce_mean(axis=(0, 1)) averaged over
    # batch+height instead of the two spatial dims, and the resulting cam
    # kept a leading batch axis that cv2.resize cannot handle.
    conv_outputs = conv_outputs[0]
    grads = grads[0]

    # Guided gradients: keep gradient only where both the activation and
    # the gradient are positive.
    guided_grads = tf.cast(conv_outputs > 0, 'float32') * tf.cast(grads > 0, 'float32') * grads

    # Per-channel weights = spatial mean of the guided gradients.
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(tf.multiply(weights, conv_outputs), axis=-1)

    # Bug fix: convert to NumPy once, then stay in NumPy. The original mixed
    # np.maximum with tf.reduce_max and then called .numpy() on the result,
    # and it divided by the max without guarding against an all-zero map.
    heatmap = np.maximum(cam.numpy(), 0)
    max_val = heatmap.max()
    if max_val > 0:
        heatmap /= max_val
    heatmap = cv2.resize(heatmap, (img.shape[1], img.shape[0]))

    colormap = cv2.applyColorMap(np.uint8(255 * heatmap), cv2.COLORMAP_JET)
    overlay = cv2.addWeighted(img, 0.5, colormap, 0.5, 0)

    return overlay

def classify_image(img):
    """Classify an X-ray image and produce its Grad-CAM visualization.

    Args:
        img: Input image (PIL image or array-like accepted by np.array).

    Returns:
        A (predicted label, Grad-CAM overlay image) pair.
    """
    # Resize to the network's fixed input resolution.
    resized = cv2.resize(np.array(img), (540, 540))

    # Batch + preprocess for DenseNet, then score all classes.
    batch = preprocess_input(np.expand_dims(resized, axis=0))
    scores = model.predict(batch)

    # Build the heatmap overlay from the (unpreprocessed) resized image;
    # get_gradcam performs its own preprocessing internally.
    heat_overlay = get_gradcam(resized, model, layer_name)

    best_label = class_names[np.argmax(scores)]
    return best_label, heat_overlay

# Gradio UI
# Two-tab layout: a text chatbot for symptom triage and an image tab for
# X-ray classification with Grad-CAM output.
with gr.Blocks() as demo:
    gr.Markdown("# Medical AI Assistant")

    with gr.Tab("Chatbot"):
        chatbot_ui = gr.Chatbot()
        user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
        submit = gr.Button("Send", variant="primary", interactive=True)
        clear_chat = gr.Button("Clear Chat")

        # Both the Send button and pressing Enter in the textbox route the
        # message through chatbot(); the updated history replaces the chat.
        # NOTE(review): the textbox is not cleared after sending — the user's
        # last message stays in the input; confirm whether that is intended.
        submit.click(chatbot, [user_input, chatbot_ui], chatbot_ui)
        user_input.submit(chatbot, [user_input, chatbot_ui], chatbot_ui)
        # Clear resets both the chat history and the textbox. This clears the
        # UI only; the module-level user_state dict is untouched here.
        clear_chat.click(lambda: ([], ""), outputs=[chatbot_ui, user_input])

    with gr.Tab("X-ray Classification"):
        image_input = gr.Image()
        classify_button = gr.Button("Classify")
        output_text = gr.Text()
        output_image = gr.Image()

        classify_button.click(classify_image, [image_input], [output_text, output_image])

demo.launch()