File size: 7,462 Bytes
cefd566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1faa320
 
cefd566
 
 
 
 
 
1faa320
 
cefd566
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138ebb9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
from duckduckgo_search import DDGS
import datetime
import requests
import pytz
import yaml
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

class MedicalAgent:
    """Toy nurse/doctor workflow backed by DuckDuckGo text search.

    Symptoms are recorded per patient in ``self.patient_data`` by
    :meth:`collect_symptoms`. :meth:`diagnose_patient` later looks them up and
    runs three web searches (diagnosis, medication, treatment advice).
    Every public method reports failures as a returned string instead of
    raising.
    """

    # Fixed follow-up questions returned by collect_symptoms, in display order.
    _FOLLOW_UP_QUESTIONS = (
        "How long have you had these symptoms?",
        "Do you have any allergies?",
        "Are you taking any medications?",
        "Have you experienced these symptoms before?",
        "Have you had any recent illnesses?",
        "Have you noticed any other unusual changes?",
        "What is your medical history related to these symptoms?",
    )

    def __init__(self):
        # patient_id -> {"symptoms": <str>, "additional_info": <dict>}
        self.patient_data = {}
        self.ddgs = DDGS()

    @staticmethod
    def _is_blank(value) -> bool:
        """Return True for falsy values (None, "") and whitespace-only text."""
        return not value or not str(value).strip()

    def search_web(self, query: str, max_results: int = 2) -> str:
        """Search DuckDuckGo and join the snippet bodies of the top hits.

        Args:
            query: Free-text search query.
            max_results: Maximum number of hits to request (default 2).
        Returns:
            The non-empty ``body`` snippets joined by newlines.
            ``"No results found."`` when no hit carries a body.
            ``"Error searching: ..."`` if the search fails.
        """
        try:
            # The result may be produced lazily by some library versions, so it
            # is consumed inside the try; `or []` guards a None return.
            hits = self.ddgs.text(query, max_results=max_results) or []
            # Skip hits without a body instead of failing the whole search.
            bodies = [hit["body"] for hit in hits if hit.get("body")]
            if not bodies:
                return "No results found."
            return "\n".join(bodies)
        except Exception as e:
            return f"Error searching: {str(e)}"

    def collect_symptoms(self, patient_id: str, symptoms: str) -> str:
        """
        Nurse tool to collect symptoms and ask follow-up questions.
        Args:
            patient_id: Unique identifier for the patient.
            symptoms: Initial symptoms provided by the patient.
        Returns:
            Follow-up questions based on the symptoms, or an "Error: ..." message
            when the patient ID or symptoms are missing/blank.
        """
        try:
            if self._is_blank(patient_id) or self._is_blank(symptoms):
                raise ValueError("Patient ID and symptoms must be provided.")

            # Collecting again for an existing patient replaces the earlier record.
            self.patient_data[patient_id] = {"symptoms": symptoms, "additional_info": {}}

            return (
                f"Nurse: I have noted the symptoms ({symptoms}). Here are follow-up questions:\n"
                + "\n".join(self._FOLLOW_UP_QUESTIONS)
            )

        except ValueError as e:
            return f"Error: {str(e)}"
        except Exception as e:
            return f"Unexpected Error: {str(e)}"

    def diagnose_patient(self, patient_id: str) -> str:
        """
        Doctor tool to diagnose the patient based on previously collected symptoms.
        Args:
            patient_id: Unique identifier for the patient.
        Returns:
            Diagnosis, medication and advice text, or an "Error: ..." message when
            collect_symptoms has not been called for this patient.
        """
        try:
            record = self.patient_data.get(patient_id)
            if record is None:
                raise ValueError("No symptoms found. Nurse must collect symptoms first.")

            symptoms = record["symptoms"]
            diagnosis = self.fetch_diagnosis(symptoms)
            medication = self.fetch_medication(symptoms)
            advice = self.fetch_treatment_advice(symptoms)

            return (f"Doctor: Based on the symptoms: {symptoms},\n"
                    f"Diagnosis: {diagnosis}\n"
                    f"Medication: {medication}\n"
                    f"Advice: {advice}")

        except ValueError as e:
            return f"Error: {str(e)}"
        except Exception as e:
            return f"Unexpected Error: {str(e)}"

    def _search_symptoms(self, symptoms: str, query_prefix: str, error_label: str) -> str:
        """Validate *symptoms*, then web-search ``query_prefix + symptoms``.

        Returns the search text, or ``"Error fetching <error_label>: ..."``.
        """
        try:
            if self._is_blank(symptoms):
                raise ValueError("Symptoms must be provided.")
            return self.search_web(f"{query_prefix}{symptoms}")
        except Exception as e:
            return f"Error fetching {error_label}: {str(e)}"

    def fetch_diagnosis(self, symptoms: str) -> str:
        """
        AI tool to retrieve a diagnosis based on symptoms.
        Args:
            symptoms: The symptoms provided by the patient.
        Returns:
            Web-search text about a possible diagnosis, or an error message.
        """
        return self._search_symptoms(
            symptoms, "A detailed medical diagnosis for these symptoms: ", "diagnosis"
        )

    def fetch_medication(self, symptoms: str) -> str:
        """
        AI tool to suggest medications based on symptoms.
        Args:
            symptoms: The symptoms provided by the patient.
        Returns:
            Web-search text about suggested medication, or an error message.
        """
        return self._search_symptoms(
            symptoms, "Recommended medications for these symptoms: ", "medication"
        )

    def fetch_treatment_advice(self, symptoms: str) -> str:
        """
        AI tool to provide treatment recommendations.
        Args:
            symptoms: The symptoms provided by the patient.
        Returns:
            Web-search text with treatment advice, or an error message.
        """
        return self._search_symptoms(
            symptoms,
            "Treatment advice and recommendations for these symptoms: ",
            "treatment advice",
        )

def get_current_time_in_timezone(timezone: str) -> str:
    """Describe the current wall-clock time in a named timezone.

    Args:
        timezone: Name passed to ``pytz.timezone`` (e.g. ``'America/New_York'``).

    Returns:
        A sentence with the local time formatted as ``YYYY-MM-DD HH:MM:SS``.
        If resolving the zone or reading the clock fails, an error sentence
        naming the timezone is returned instead; nothing is raised.
    """
    try:
        zone = pytz.timezone(timezone)
        stamp = datetime.datetime.now(zone).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as exc:
        return f"Error fetching time for timezone '{timezone}': {str(exc)}"
    return f"The current local time in {timezone} is: {stamp}"

class HealthAssistant:
    """Chat front end pairing a causal language model with a MedicalAgent.

    The model writes a free-form reply. Messages mentioning "symptoms" or
    "diagnose" additionally trigger the MedicalAgent's nurse or doctor tool.
    """

    # All chats share one patient record. NOTE(review): per-user IDs would be
    # needed for multi-user deployments.
    PATIENT_ID = "user_1"

    def __init__(self, model_name: str = "microsoft/phi-2", max_new_tokens: int = 512):
        """Load the tokenizer and model.

        Args:
            model_name: Hugging Face model identifier (default ``microsoft/phi-2``).
            max_new_tokens: Upper bound on tokens generated per reply.
        """
        self.medical_agent = MedicalAgent()
        self.max_new_tokens = max_new_tokens
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # The tokenizer has no dedicated pad token here; reuse EOS.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # float16 is only reliably supported on GPU; many CPU kernels lack it.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        # Keep the model's padding token consistent with the tokenizer's.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    @staticmethod
    def _extract_reply(generated: str) -> str:
        """Keep only the assistant's first turn from freshly generated text."""
        # The model often continues with an invented "User:" turn; cut there.
        return generated.split("\nUser:", 1)[0].strip()

    def process_message(self, message: str) -> str:
        """Generate a reply to *message* and append any medical-agent output.

        Args:
            message: The user's chat message.
        Returns:
            The model's reply. If the message mentions "symptoms", the nurse's
            follow-up questions are appended; otherwise, if it mentions
            "diagnose", the doctor's report for the shared patient is appended.
        """
        # Inputs must be on the model's device; device_map="auto" may place
        # the model on a GPU while the tokenizer returns CPU tensors.
        inputs = self.tokenizer(
            f"User: {message}\nAssistant:",
            return_tensors="pt",
            truncation=True,
            return_attention_mask=True,
        ).to(self.model.device)

        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            # max_new_tokens bounds only the reply. max_length also counts the
            # prompt, so long messages would leave no room to answer.
            max_new_tokens=self.max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        # Decode only the newly generated tokens, so "Assistant:" inside the
        # user's own message cannot corrupt reply extraction.
        prompt_length = inputs["input_ids"].shape[-1]
        generated = self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        response = self._extract_reply(generated)

        # Route to the medical tools based on simple keyword matching.
        lowered = message.lower()
        if "symptoms" in lowered:
            response += "\n\n" + self.medical_agent.collect_symptoms(self.PATIENT_ID, message)
        elif "diagnose" in lowered:
            response += "\n\n" + self.medical_agent.diagnose_patient(self.PATIENT_ID)

        return response

# One process-wide assistant, created on first use. The model is loaded only
# once, and the MedicalAgent's patient records survive between chat turns
# (so "diagnose" can see symptoms collected earlier).
_assistant = None


def _get_assistant():
    """Return the shared HealthAssistant, constructing it on first call."""
    global _assistant
    if _assistant is None:
        _assistant = HealthAssistant()
    return _assistant


def chat_interface(message, history):
    """Gradio ChatInterface callback that answers the latest message.

    Args:
        message: The user's latest chat message.
        history: Prior conversation passed by Gradio; currently unused.
    Returns:
        The assistant's reply text.
    """
    return _get_assistant().process_message(message)

if __name__ == "__main__":
    # Wrap chat_interface in Gradio's chat UI and start the local web server.
    disclaimer = (
        "Chat with our AI health assistant to discuss your symptoms and get "
        "medical advice. Note: This is for educational purposes only and "
        "should not replace professional medical advice."
    )
    demo = gr.ChatInterface(
        fn=chat_interface,
        title="Health Assistant",
        description=disclaimer,
    )
    demo.launch()