Chinez-dev commited on
Commit
cefd566
·
verified ·
1 Parent(s): 138ebb9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +193 -178
app.py CHANGED
@@ -1,179 +1,194 @@
1
- from duckduckgo_search import DDGS
2
- import datetime
3
- import requests
4
- import pytz
5
- import yaml
6
- import gradio as gr
7
- from transformers import AutoModelForCausalLM, AutoTokenizer
8
- import torch
9
-
10
- class MedicalAgent:
11
- def __init__(self):
12
- self.patient_data = {}
13
- self.ddgs = DDGS()
14
-
15
- def search_web(self, query: str) -> str:
16
- try:
17
- results = list(self.ddgs.text(query, max_results=2))
18
- return "\n".join([result['body'] for result in results])
19
- except Exception as e:
20
- return f"Error searching: {str(e)}"
21
-
22
- def collect_symptoms(self, patient_id: str, symptoms: str) -> str:
23
- """
24
- Nurse tool to collect symptoms and ask follow-up questions.
25
- Args:
26
- patient_id: Unique identifier for the patient.
27
- symptoms: Initial symptoms provided by the patient.
28
- Returns:
29
- Follow-up questions based on the symptoms.
30
- """
31
- try:
32
- if not patient_id or not symptoms:
33
- raise ValueError("Patient ID and symptoms must be provided.")
34
-
35
- self.patient_data[patient_id] = {"symptoms": symptoms, "additional_info": {}}
36
-
37
- questions = [
38
- "How long have you had these symptoms?",
39
- "Do you have any allergies?",
40
- "Are you taking any medications?",
41
- "Have you experienced these symptoms before?",
42
- "Have you had any recent illnesses?",
43
- "Have you noticed any other unusual changes?",
44
- "What is your medical history related to these symptoms?"
45
- ]
46
-
47
- return f"Nurse: I have noted the symptoms ({symptoms}). Here are follow-up questions:\n" + "\n".join(questions)
48
-
49
- except ValueError as e:
50
- return f"Error: {str(e)}"
51
- except Exception as e:
52
- return f"Unexpected Error: {str(e)}"
53
-
54
- def diagnose_patient(self, patient_id: str) -> str:
55
- """
56
- Doctor tool to diagnose the patient based on symptoms.
57
- Args:
58
- patient_id: Unique identifier for the patient.
59
- Returns:
60
- Diagnosis and recommendations.
61
- """
62
- try:
63
- if patient_id not in self.patient_data:
64
- raise ValueError("No symptoms found. Nurse must collect symptoms first.")
65
-
66
- symptoms = self.patient_data[patient_id]["symptoms"]
67
- diagnosis = self.fetch_diagnosis(symptoms)
68
- medication = self.fetch_medication(symptoms)
69
- advice = self.fetch_treatment_advice(symptoms)
70
-
71
- return (f"Doctor: Based on the symptoms: {symptoms},\n"
72
- f"Diagnosis: {diagnosis}\n"
73
- f"Medication: {medication}\n"
74
- f"Advice: {advice}")
75
-
76
- except ValueError as e:
77
- return f"Error: {str(e)}"
78
- except Exception as e:
79
- return f"Unexpected Error: {str(e)}"
80
-
81
- def fetch_diagnosis(self, symptoms: str) -> str:
82
- """
83
- AI tool to retrieve a diagnosis based on symptoms.
84
- Args:
85
- symptoms: The symptoms provided by the patient.
86
- Returns:
87
- Detailed Diagnosis information.
88
- """
89
- try:
90
- if not symptoms:
91
- raise ValueError("Symptoms must be provided.")
92
- search_query = f"A detailed medical diagnosis for these symptoms: {symptoms}"
93
- return self.search_web(search_query)
94
- except Exception as e:
95
- return f"Error fetching diagnosis: {str(e)}"
96
-
97
- def fetch_medication(self, symptoms: str) -> str:
98
- """
99
- AI tool to suggest medications based on symptoms.
100
- Args:
101
- symptoms: The symptoms provided by the patient.
102
- Returns:
103
- Suggested medication.
104
- """
105
- try:
106
- if not symptoms:
107
- raise ValueError("Symptoms must be provided.")
108
- search_query = f"Recommended medications for these symptoms: {symptoms}"
109
- return self.search_web(search_query)
110
- except Exception as e:
111
- return f"Error fetching medication: {str(e)}"
112
-
113
- def fetch_treatment_advice(self, symptoms: str) -> str:
114
- """
115
- AI tool to provide treatment recommendations.
116
- Args:
117
- symptoms: The symptoms provided by the patient.
118
- Returns:
119
- Recommended treatment and advice.
120
- """
121
- try:
122
- if not symptoms:
123
- raise ValueError("Symptoms must be provided.")
124
- search_query = f"Treatment advice and recommendations for these symptoms: {symptoms}"
125
- return self.search_web(search_query)
126
- except Exception as e:
127
- return f"Error fetching treatment advice: {str(e)}"
128
-
129
- def get_current_time_in_timezone(timezone: str) -> str:
130
- """A tool that fetches the current local time in a specified timezone.
131
- Args:
132
- timezone: A string representing a valid timezone (e.g., 'America/New_York').
133
- """
134
- try:
135
- tz = pytz.timezone(timezone)
136
- local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
137
- return f"The current local time in {timezone} is: {local_time}"
138
- except Exception as e:
139
- return f"Error fetching time for timezone '{timezone}': {str(e)}"
140
-
141
- class HealthAssistant:
142
- def __init__(self):
143
- self.medical_agent = MedicalAgent()
144
- self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
145
- self.model = AutoModelForCausalLM.from_pretrained(
146
- "microsoft/phi-2",
147
- torch_dtype=torch.float16,
148
- device_map="auto",
149
- trust_remote_code=True
150
- )
151
-
152
- def process_message(self, message: str) -> str:
153
- inputs = self.tokenizer(f"User: {message}\nAssistant:", return_tensors="pt", return_attention_mask=False)
154
- outputs = self.model.generate(**inputs, max_length=512, pad_token_id=self.tokenizer.eos_token_id)
155
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
156
-
157
- # Extract only the assistant's response
158
- response = response.split("Assistant:")[-1].strip()
159
-
160
- # Check if we need to collect symptoms or diagnose
161
- if "symptoms" in message.lower():
162
- response += "\n\n" + self.medical_agent.collect_symptoms("user_1", message)
163
- elif "diagnose" in message.lower():
164
- response += "\n\n" + self.medical_agent.diagnose_patient("user_1")
165
-
166
- return response
167
-
168
- def chat_interface(message, history):
169
- assistant = HealthAssistant()
170
- response = assistant.process_message(message)
171
- return response
172
-
173
- if __name__ == "__main__":
174
- iface = gr.ChatInterface(
175
- fn=chat_interface,
176
- title="Health Assistant",
177
- description="Chat with our AI health assistant to discuss your symptoms and get medical advice. Note: This is for educational purposes only and should not replace professional medical advice."
178
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  iface.launch()
 
1
+ from duckduckgo_search import DDGS
2
+ import datetime
3
+ import requests
4
+ import pytz
5
+ import yaml
6
+ import gradio as gr
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer
8
+ import torch
9
+
10
+ class MedicalAgent:
11
+ def __init__(self):
12
+ self.patient_data = {}
13
+ self.ddgs = DDGS()
14
+
15
+ def search_web(self, query: str) -> str:
16
+ try:
17
+ results = list(self.ddgs.text(query, max_results=2))
18
+ return "\n".join([result['body'] for result in results])
19
+ except Exception as e:
20
+ return f"Error searching: {str(e)}"
21
+
22
+ def collect_symptoms(self, patient_id: str, symptoms: str) -> str:
23
+ """
24
+ Nurse tool to collect symptoms and ask follow-up questions.
25
+ Args:
26
+ patient_id: Unique identifier for the patient.
27
+ symptoms: Initial symptoms provided by the patient.
28
+ Returns:
29
+ Follow-up questions based on the symptoms.
30
+ """
31
+ try:
32
+ if not patient_id or not symptoms:
33
+ raise ValueError("Patient ID and symptoms must be provided.")
34
+
35
+ self.patient_data[patient_id] = {"symptoms": symptoms, "additional_info": {}}
36
+
37
+ questions = [
38
+ "How long have you had these symptoms?",
39
+ "Do you have any allergies?",
40
+ "Are you taking any medications?",
41
+ "Have you experienced these symptoms before?",
42
+ "Have you had any recent illnesses?",
43
+ "Have you noticed any other unusual changes?",
44
+ "What is your medical history related to these symptoms?"
45
+ ]
46
+
47
+ return f"Nurse: I have noted the symptoms ({symptoms}). Here are follow-up questions:\n" + "\n".join(questions)
48
+
49
+ except ValueError as e:
50
+ return f"Error: {str(e)}"
51
+ except Exception as e:
52
+ return f"Unexpected Error: {str(e)}"
53
+
54
+ def diagnose_patient(self, patient_id: str) -> str:
55
+ """
56
+ Doctor tool to diagnose the patient based on symptoms.
57
+ Args:
58
+ patient_id: Unique identifier for the patient.
59
+ Returns:
60
+ Diagnosis and recommendations.
61
+ """
62
+ try:
63
+ if patient_id not in self.patient_data:
64
+ raise ValueError("No symptoms found. Nurse must collect symptoms first.")
65
+
66
+ symptoms = self.patient_data[patient_id]["symptoms"]
67
+ diagnosis = self.fetch_diagnosis(symptoms)
68
+ medication = self.fetch_medication(symptoms)
69
+ advice = self.fetch_treatment_advice(symptoms)
70
+
71
+ return (f"Doctor: Based on the symptoms: {symptoms},\n"
72
+ f"Diagnosis: {diagnosis}\n"
73
+ f"Medication: {medication}\n"
74
+ f"Advice: {advice}")
75
+
76
+ except ValueError as e:
77
+ return f"Error: {str(e)}"
78
+ except Exception as e:
79
+ return f"Unexpected Error: {str(e)}"
80
+
81
+ def fetch_diagnosis(self, symptoms: str) -> str:
82
+ """
83
+ AI tool to retrieve a diagnosis based on symptoms.
84
+ Args:
85
+ symptoms: The symptoms provided by the patient.
86
+ Returns:
87
+ Detailed Diagnosis information.
88
+ """
89
+ try:
90
+ if not symptoms:
91
+ raise ValueError("Symptoms must be provided.")
92
+ search_query = f"A detailed medical diagnosis for these symptoms: {symptoms}"
93
+ return self.search_web(search_query)
94
+ except Exception as e:
95
+ return f"Error fetching diagnosis: {str(e)}"
96
+
97
+ def fetch_medication(self, symptoms: str) -> str:
98
+ """
99
+ AI tool to suggest medications based on symptoms.
100
+ Args:
101
+ symptoms: The symptoms provided by the patient.
102
+ Returns:
103
+ Suggested medication.
104
+ """
105
+ try:
106
+ if not symptoms:
107
+ raise ValueError("Symptoms must be provided.")
108
+ search_query = f"Recommended medications for these symptoms: {symptoms}"
109
+ return self.search_web(search_query)
110
+ except Exception as e:
111
+ return f"Error fetching medication: {str(e)}"
112
+
113
+ def fetch_treatment_advice(self, symptoms: str) -> str:
114
+ """
115
+ AI tool to provide treatment recommendations.
116
+ Args:
117
+ symptoms: The symptoms provided by the patient.
118
+ Returns:
119
+ Recommended treatment and advice.
120
+ """
121
+ try:
122
+ if not symptoms:
123
+ raise ValueError("Symptoms must be provided.")
124
+ search_query = f"Treatment advice and recommendations for these symptoms: {symptoms}"
125
+ return self.search_web(search_query)
126
+ except Exception as e:
127
+ return f"Error fetching treatment advice: {str(e)}"
128
+
129
+ def get_current_time_in_timezone(timezone: str) -> str:
130
+ """A tool that fetches the current local time in a specified timezone.
131
+ Args:
132
+ timezone: A string representing a valid timezone (e.g., 'America/New_York').
133
+ """
134
+ try:
135
+ tz = pytz.timezone(timezone)
136
+ local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
137
+ return f"The current local time in {timezone} is: {local_time}"
138
+ except Exception as e:
139
+ return f"Error fetching time for timezone '{timezone}': {str(e)}"
140
+
141
+ class HealthAssistant:
142
+ def __init__(self):
143
+ self.medical_agent = MedicalAgent()
144
+ self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")
145
+ self.model = AutoModelForCausalLM.from_pretrained(
146
+ "microsoft/phi-2",
147
+ torch_dtype=torch.float16,
148
+ device_map="auto",
149
+ trust_remote_code=True
150
+ )
151
+
152
+ def process_message(self, message: str) -> str:
153
+ # Tokenize with attention mask
154
+ inputs = self.tokenizer(
155
+ f"User: {message}\nAssistant:",
156
+ return_tensors="pt",
157
+ padding=True,
158
+ truncation=True,
159
+ return_attention_mask=True
160
+ )
161
+
162
+ # Generate response
163
+ outputs = self.model.generate(
164
+ input_ids=inputs["input_ids"],
165
+ attention_mask=inputs["attention_mask"],
166
+ max_length=512,
167
+ pad_token_id=self.tokenizer.eos_token_id
168
+ )
169
+
170
+ response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
171
+
172
+ # Extract only the assistant's response
173
+ response = response.split("Assistant:")[-1].strip()
174
+
175
+ # Check if we need to collect symptoms or diagnose
176
+ if "symptoms" in message.lower():
177
+ response += "\n\n" + self.medical_agent.collect_symptoms("user_1", message)
178
+ elif "diagnose" in message.lower():
179
+ response += "\n\n" + self.medical_agent.diagnose_patient("user_1")
180
+
181
+ return response
182
+
183
+ def chat_interface(message, history):
184
+ assistant = HealthAssistant()
185
+ response = assistant.process_message(message)
186
+ return response
187
+
188
+ if __name__ == "__main__":
189
+ iface = gr.ChatInterface(
190
+ fn=chat_interface,
191
+ title="Health Assistant",
192
+ description="Chat with our AI health assistant to discuss your symptoms and get medical advice. Note: This is for educational purposes only and should not replace professional medical advice."
193
+ )
194
  iface.launch()