Santhosh1705kumar committed on
Commit
74d281c
·
verified ·
1 Parent(s): fe48a9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -76
app.py CHANGED
@@ -1,88 +1,168 @@
1
  import gradio as gr
2
- import torch
3
- from transformers import AutoModelForCausalLM, AutoTokenizer
4
- import pandas as pd
5
 
6
- # Load the symptom dataset once
7
- df = pd.read_csv("enhanced_symptom_tree_with_measures.csv")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
# One-time model/tokenizer initialisation so every request reuses the same weights.
def load_model():
    """Fetch (or reuse cached) phi-2 weights and return (tokenizer, model).

    The model is loaded in float16 with automatic device placement to keep
    the memory footprint small, and switched to eval mode since it is only
    used for inference.
    """
    name = "microsoft/phi-2"  # Lightweight model
    tok = AutoTokenizer.from_pretrained(name)
    lm = AutoModelForCausalLM.from_pretrained(
        name,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    lm.eval()  # inference only — disables dropout etc.
    return tok, lm


tokenizer, model = load_model()
18
-
19
# Global variable to track last detected symptom across turns
last_detected_symptom = None


# Function to find symptom match
def find_symptom_match(user_input, data=None):
    """Look up *user_input* in the symptom table and build an advice reply.

    Parameters:
        user_input: free-text symptom description typed by the user.
        data: optional pandas DataFrame to search; defaults to the
            module-level ``df``. Must provide the columns
            ``Primary Symptom``, ``Follow-up Question``,
            ``Possible Diseases`` and ``Recommended Measures``.

    Returns:
        (symptom_name, response_text) for the first row whose
        "Primary Symptom" contains the input (case-insensitive substring),
        or (None, None) when nothing matches.
    """
    table = df if data is None else data
    # regex=False: treat the user's text as a literal substring. The previous
    # default (regex=True) raised re.error on inputs containing regex
    # metacharacters such as "(" or "*".
    mask = table["Primary Symptom"].str.lower().str.contains(
        user_input.lower(), na=False, regex=False
    )
    match = table[mask]
    if match.empty:
        return None, None
    symptom = match.iloc[0]
    response = f"It seems like you're experiencing {symptom['Primary Symptom']}. "
    if pd.notna(symptom["Follow-up Question"]):
        response += f"{symptom['Follow-up Question']} "
    response += f"\nPossible conditions: {symptom['Possible Diseases']} \n"
    response += f"Recommended measures: {symptom['Recommended Measures']}"
    return symptom['Primary Symptom'], response  # symptom name and reply text
34
-
35
# Main chatbot response function
def chatbot_response(user_input, history):
    """Handle one chat turn: follow-up answer, new symptom, or LLM fallback.

    Parameters:
        user_input: raw text from the Gradio textbox.
        history: list of (user, bot) tuples shown in the Chatbot widget.

    Returns:
        (updated history, "") — the empty string clears the textbox.
    """
    global last_detected_symptom  # Maintain previous symptom context

    # Ignore empty / whitespace-only submissions.
    if not user_input.strip():
        return history, ""

    # Step 1: If it's a follow-up response, continue from the last known symptom
    if last_detected_symptom:
        match = df[df["Primary Symptom"].str.lower() == last_detected_symptom.lower()]
        if not match.empty:
            follow_up_options = match.iloc[0]["Follow-up Question"]
            # Treat the reply as a follow-up answer only when it appears
            # verbatim (case-insensitive) inside the stored question text.
            if pd.notna(follow_up_options) and user_input.lower() in follow_up_options.lower():
                response = f"Got it. Based on that, possible conditions: {match.iloc[0]['Possible Diseases']} \nRecommended: {match.iloc[0]['Recommended Measures']}"
                last_detected_symptom = None  # Reset symptom tracking
                history.append((user_input, response))
                return history, ""

    # Step 2: Otherwise, check for a new symptom
    detected_symptom, symptom_info = find_symptom_match(user_input)
    if symptom_info:
        last_detected_symptom = detected_symptom  # Store new symptom for next response
        response = symptom_info
    else:
        # Step 3: Use LLM as a fallback if no match is found
        prompt = f"User: {user_input}\nChatbot:"
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            # NOTE(review): max_length counts prompt + generated tokens, so a
            # long input leaves little room for the reply — consider
            # max_new_tokens instead.
            outputs = model.generate(**inputs, max_length=150, pad_token_id=tokenizer.eos_token_id)
        # Keep only the text after the last "Chatbot:" marker.
        response = tokenizer.decode(outputs[0], skip_special_tokens=True).split("Chatbot:")[-1].strip()

    history.append((user_input, response))
    return history, ""
68
 
69
# Reset handler for the "Clear Chat" button.
def clear_chat():
    """Forget the tracked symptom and empty the chat widgets.

    Returns an (empty history, empty textbox) pair for Gradio to apply.
    """
    global last_detected_symptom
    # Drop the follow-up context so the next message starts a fresh triage.
    last_detected_symptom = None
    return ([], "")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
# ---- Gradio front-end -------------------------------------------------------
with gr.Blocks(theme="compact") as demo:
    gr.Markdown("### Symptom Chatbot 🏥")

    # Conversation transcript plus the user's input row.
    chatbot = gr.Chatbot()
    user_input = gr.Textbox(placeholder="Type your symptom and get advice...", interactive=True)
    submit = gr.Button("Send")
    clear = gr.Button("Clear Chat")

    # Wire the buttons: "Send" runs one chatbot turn, "Clear Chat" resets state.
    submit.click(fn=chatbot_response, inputs=[user_input, chatbot], outputs=[chatbot, user_input])
    clear.click(fn=clear_chat, inputs=[], outputs=[chatbot, user_input])

# Launch when executed directly (Hugging Face Spaces runs this module as main).
if __name__ == "__main__":
    demo.launch()
 
1
  import gradio as gr
2
+ import time
3
+ from collections import defaultdict
4
+ import spacy
5
 
6
# Symptom knowledge base used by the triage state machine.
# For each recognised primary symptom:
#   questions   - yes/no follow-ups asked one at a time
#   diseases    - candidate conditions being scored
#   weights_yes - score added to each disease (by position) on a "yes" answer
#   weights_no  - score added on any other answer
# NOTE(review): "Fatigue & swelling" and "Chronic wheezing" list 3 weights but
# only 2 diseases — the scorer indexes weights by disease position, so the
# trailing weight is never used; confirm whether a third disease is missing.
symptom_data = {
    "Shortness of breath": {
        "questions": [
            "Do you also have chest pain?",
            "Do you feel fatigued often?",
            "Have you noticed swelling in your legs?"
        ],
        "diseases": ["Atelectasis", "Emphysema", "Edema"],
        "weights_yes": [30, 30, 40],
        "weights_no": [10, 20, 30]
    },
    "Persistent cough": {
        "questions": [
            "Is your cough dry or with mucus?",
            "Do you experience fever?",
            "Do you have difficulty breathing?"
        ],
        "diseases": ["Pneumonia", "Fibrosis", "Infiltration"],
        "weights_yes": [35, 30, 35],
        "weights_no": [10, 15, 20]
    },
    "Sharp chest pain": {
        "questions": [
            "Does it worsen with deep breaths?",
            "Do you feel lightheaded?",
            "Have you had recent trauma or surgery?"
        ],
        "diseases": ["Pneumothorax", "Effusion", "Cardiomegaly"],
        "weights_yes": [40, 30, 30],
        "weights_no": [15, 20, 25]
    },
    "Fatigue & swelling": {
        "questions": [
            "Do you feel breathless when lying down?",
            "Have you gained weight suddenly?",
            "Do you experience irregular heartbeat?"
        ],
        "diseases": ["Edema", "Cardiomegaly"],
        "weights_yes": [50, 30, 20],
        "weights_no": [20, 15, 15]
    },
    "Chronic wheezing": {
        "questions": [
            "Do you have a history of smoking?",
            "Do you feel tightness in your chest?",
            "Do you have frequent lung infections?"
        ],
        "diseases": ["Emphysema", "Fibrosis"],
        "weights_yes": [40, 30, 30],
        "weights_no": [15, 25, 20]
    }
}

# spaCy pipeline with word vectors, used for similarity-based symptom matching.
# NOTE(review): "en_core_web_lg" is not bundled with spaCy — confirm the Space
# environment downloads/installs it, otherwise this load raises at startup.
nlp = spacy.load("en_core_web_lg")
 
 
 
 
 
62
 
63
# Function to extract key symptom from user input
def extract_symptom(user_input, threshold=0.4):
    """Map free-text user input to one of the known primary symptoms.

    Parameters:
        user_input: raw message typed by the user.
        threshold: minimum spaCy vector similarity required to accept the
            closest symptom when no literal substring match is found.

    Returns:
        A key of ``symptom_data``, or ``None`` when the input is neither a
        substring match nor similar enough to any known symptom.

    BUG FIX: the previous version unconditionally returned the
    highest-similarity symptom, so a key of ``symptom_data`` was always
    returned and the caller's "I don't recognize that symptom" branch was
    unreachable — unrelated input got forced onto an arbitrary symptom.
    """
    known_symptoms = list(symptom_data.keys())
    lowered = user_input.lower()

    # 1) Cheap literal check first: known symptom phrase contained verbatim.
    for symptom in known_symptoms:
        if symptom.lower() in lowered:
            return symptom

    # 2) Fall back to vector similarity against each known symptom.
    user_doc = nlp(lowered)
    best_symptom, best_score = None, 0.0
    for symptom in known_symptoms:
        score = user_doc.similarity(nlp(symptom.lower()))
        if score > best_score:
            best_symptom, best_score = symptom, score

    # Accept the closest symptom only when it is plausibly related.
    return best_symptom if best_score >= threshold else None
 
 
 
 
 
 
 
 
 
85
 
86
# Global variables to track user state.
# NOTE(review): module-level dict — every visitor to the Space shares this one
# conversation; per-session state (e.g. gr.State) would be needed to support
# multiple simultaneous users.
user_state = {}

def chatbot(user_input):
    """Advance the triage state machine by one turn and return the reply text.

    States:
        greet        -> welcome message, then ask_symptom
        ask_symptom  -> match input to a known symptom (else ask_feeling)
        ask_feeling  -> fallback when the symptom was not recognised
        ask_duration -> short-duration exit, or start follow-up questions
        follow_up    -> yes/no questions accumulating per-disease scores

    ``user_state`` is cleared when a conversation completes, so the next
    message starts from "greet" again.
    """
    # First-ever message: initialise the state machine.
    if "state" not in user_state:
        user_state["state"] = "greet"

    if user_state["state"] == "greet":
        user_state["state"] = "ask_symptom"
        return "Hello! I'm a medical AI assistant. Please describe your primary symptom."

    elif user_state["state"] == "ask_symptom":
        # Extract symptom from the user input
        matched_symptom = extract_symptom(user_input)

        if matched_symptom not in symptom_data:
            user_state["state"] = "ask_feeling"
            return "I'm sorry, I don't recognize that symptom. How do you feel?"

        user_state["symptom"] = matched_symptom
        user_state["state"] = "ask_duration"
        return "How long have you been experiencing this symptom? (Less than a week / More than a week)"

    elif user_state["state"] == "ask_feeling":
        # If the symptom is not recognized, ask how they feel.
        # NOTE(review): this state never transitions anywhere — the
        # conversation is stuck here until the state is cleared.
        return "Can you describe your symptoms in more detail?"

    elif user_state["state"] == "ask_duration":
        if user_input.lower() == "less than a week":
            user_state.clear()
            return "It might be a temporary issue. Please monitor your symptoms and consult a doctor if they persist."
        elif user_input.lower() == "more than a week":
            user_state["state"] = "follow_up"
            user_state["current_question"] = 0
            user_state["disease_scores"] = defaultdict(int)
            return symptom_data[user_state["symptom"]]["questions"][0]
        else:
            return "Please respond with 'Less than a week' or 'More than a week'."

    elif user_state["state"] == "follow_up":
        symptom = user_state["symptom"]
        question_index = user_state["current_question"]  # NOTE(review): unused local

        # Update probabilities: any answer other than an exact "yes"
        # (case-insensitive) scores the "no" weights.
        if user_input.lower() == "yes":
            for i, disease in enumerate(symptom_data[symptom]["diseases"]):
                user_state["disease_scores"][disease] += symptom_data[symptom]["weights_yes"][i]
        else:
            for i, disease in enumerate(symptom_data[symptom]["diseases"]):
                user_state["disease_scores"][disease] += symptom_data[symptom]["weights_no"][i]

        # Move to the next question or finish
        user_state["current_question"] += 1
        if user_state["current_question"] < len(symptom_data[symptom]["questions"]):
            return symptom_data[symptom]["questions"][user_state["current_question"]]

        # Final diagnosis: disease with the highest accumulated score.
        probable_disease = max(user_state["disease_scores"], key=user_state["disease_scores"].get)
        user_state.clear()
        return f"Based on your symptoms, the most likely condition is: {probable_disease}. Please consult a doctor for confirmation."
146
 
147
# Gradio Chatbot UI with improved features
with gr.Blocks() as demo:
    gr.Markdown("# Conversational Image Recognition Assistant: AI-Powered X-ray Diagnosis for Healthcare")
    chatbot_ui = gr.Chatbot()
    user_input = gr.Textbox(placeholder="Enter your response...", label="Your Message")
    submit = gr.Button("Send")
    clear_chat = gr.Button("Clear Chat")

    def respond(user_message, history):
        """Generator handler: show a placeholder, then the real reply.

        Yields (history, textbox_value) twice so the UI updates immediately
        with "Thinking..." and again once the state machine answers.
        """
        history.append((user_message, "Thinking..."))  # Show thinking message
        yield history, ""  # Immediate update

        time.sleep(1.5)  # Simulate processing delay
        bot_response = chatbot(user_message)
        history[-1] = (user_message, bot_response)  # Update with real response
        yield history, ""

    def reset_chat():
        """Reset both the transcript and the conversation state machine.

        BUG FIX: the previous lambda only emptied the widgets, leaving
        ``user_state`` mid-conversation so the next message resumed an old
        dialogue instead of starting fresh.
        """
        user_state.clear()
        return [], ""

    submit.click(respond, [user_input, chatbot_ui], [chatbot_ui, user_input])
    user_input.submit(respond, [user_input, chatbot_ui], [chatbot_ui, user_input])
    clear_chat.click(reset_chat, outputs=[chatbot_ui, user_input])

demo.launch()