Spaces:
Sleeping
Sleeping
| from duckduckgo_search import DDGS | |
| import datetime | |
| import requests | |
| import pytz | |
| import yaml | |
| import gradio as gr | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
class MedicalAgent:
    """Nurse/doctor style helper backed by DuckDuckGo text search.

    Symptoms are kept per patient in ``patient_data`` (in memory only).
    "Diagnosis", "medication" and "advice" texts are simply the joined
    snippet bodies of web search results; they are not clinically validated.
    """

    # Questions the nurse step always asks after recording the symptoms.
    FOLLOW_UP_QUESTIONS = (
        "How long have you had these symptoms?",
        "Do you have any allergies?",
        "Are you taking any medications?",
        "Have you experienced these symptoms before?",
        "Have you had any recent illnesses?",
        "Have you noticed any other unusual changes?",
        "What is your medical history related to these symptoms?",
    )

    def __init__(self):
        # Maps patient_id -> {"symptoms": str, "additional_info": dict}.
        self.patient_data = {}
        self.ddgs = DDGS()

    def search_web(self, query: str, max_results: int = 2) -> str:
        """
        Run a text search and join the result snippets.
        Args:
            query: Free-text search query.
            max_results: Maximum number of search hits to request.
        Returns:
            The snippet bodies joined by newlines, "No results found." when
            the search yields no usable snippet, or an "Error searching: ..."
            message if the search itself failed.
        """
        try:
            results = self.ddgs.text(query, max_results=max_results)
            # Use .get() so that one hit without a "body" field does not
            # discard every other snippet via a KeyError.
            snippets = [result.get("body", "") for result in results]
        except Exception as e:
            # Deliberate best-effort boundary: network and library failures
            # are reported as text so the chat reply can still be produced.
            return f"Error searching: {str(e)}"
        snippets = [snippet for snippet in snippets if snippet]
        if not snippets:
            return "No results found."
        return "\n".join(snippets)

    def collect_symptoms(self, patient_id: str, symptoms: str) -> str:
        """
        Nurse tool to collect symptoms and ask follow-up questions.
        Args:
            patient_id: Unique identifier for the patient.
            symptoms: Initial symptoms provided by the patient.
        Returns:
            Follow-up questions based on the symptoms, or an error message
            when either argument is empty.
        """
        if not patient_id or not symptoms:
            return "Error: Patient ID and symptoms must be provided."
        try:
            # Replaces any record previously stored for this patient.
            self.patient_data[patient_id] = {"symptoms": symptoms, "additional_info": {}}
        except TypeError as e:
            # An unhashable patient_id cannot be used as a dictionary key.
            return f"Unexpected Error: {str(e)}"
        return (
            f"Nurse: I have noted the symptoms ({symptoms}). Here are follow-up questions:\n"
            + "\n".join(self.FOLLOW_UP_QUESTIONS)
        )

    def diagnose_patient(self, patient_id: str) -> str:
        """
        Doctor tool to diagnose the patient based on symptoms.
        Args:
            patient_id: Unique identifier for the patient.
        Returns:
            Diagnosis and recommendations, or an error message when no
            symptoms were collected for this patient.
        """
        try:
            symptoms = self.patient_data[patient_id]["symptoms"]
        except KeyError:
            return "Error: No symptoms found. Nurse must collect symptoms first."
        except TypeError as e:
            return f"Unexpected Error: {str(e)}"
        diagnosis = self.fetch_diagnosis(symptoms)
        medication = self.fetch_medication(symptoms)
        advice = self.fetch_treatment_advice(symptoms)
        return (f"Doctor: Based on the symptoms: {symptoms},\n"
                f"Diagnosis: {diagnosis}\n"
                f"Medication: {medication}\n"
                f"Advice: {advice}")

    def _search_for_symptoms(self, symptoms: str, query_prefix: str, topic: str) -> str:
        """Validate the symptoms and search for ``"<query_prefix>: <symptoms>"``."""
        if not symptoms:
            return f"Error fetching {topic}: Symptoms must be provided."
        return self.search_web(f"{query_prefix}: {symptoms}")

    def fetch_diagnosis(self, symptoms: str) -> str:
        """
        AI tool to retrieve a diagnosis based on symptoms.
        Args:
            symptoms: The symptoms provided by the patient.
        Returns:
            Search snippets describing possible diagnoses, or an error message.
        """
        return self._search_for_symptoms(
            symptoms,
            "A detailed medical diagnosis for these symptoms",
            "diagnosis",
        )

    def fetch_medication(self, symptoms: str) -> str:
        """
        AI tool to suggest medications based on symptoms.
        Args:
            symptoms: The symptoms provided by the patient.
        Returns:
            Search snippets about medications, or an error message.
        """
        return self._search_for_symptoms(
            symptoms,
            "Recommended medications for these symptoms",
            "medication",
        )

    def fetch_treatment_advice(self, symptoms: str) -> str:
        """
        AI tool to provide treatment recommendations.
        Args:
            symptoms: The symptoms provided by the patient.
        Returns:
            Search snippets with treatment advice, or an error message.
        """
        return self._search_for_symptoms(
            symptoms,
            "Treatment advice and recommendations for these symptoms",
            "treatment advice",
        )
def get_current_time_in_timezone(timezone: str) -> str:
    """Report the present wall-clock time for a named timezone.

    Args:
        timezone: Timezone identifier accepted by ``pytz.timezone``
            (e.g., 'America/New_York').

    Returns:
        A sentence containing the time as ``YYYY-MM-DD HH:MM:SS``, or an
        error sentence when the lookup or formatting raised.
    """
    time_format = "%Y-%m-%d %H:%M:%S"
    try:
        zone = pytz.timezone(timezone)
        now = datetime.datetime.now(zone)
        message = f"The current local time in {timezone} is: {now.strftime(time_format)}"
    except Exception as error:
        # Any failure (unknown zone, bad argument type, ...) becomes text.
        message = f"Error fetching time for timezone '{timezone}': {error!s}"
    return message
class HealthAssistant:
    """Chat front end pairing a local causal language model with MedicalAgent tools.

    The model produces a free-text reply; messages containing the keywords
    "symptoms" or "diagnose" additionally trigger the matching
    ``MedicalAgent`` tool, whose output is appended to the reply.
    """

    MODEL_NAME = "microsoft/phi-2"
    # Budget for *generated* tokens only, independent of the prompt length.
    MAX_NEW_TOKENS = 512
    # The chat UI has no notion of users, so every message shares one record.
    PATIENT_ID = "user_1"

    def __init__(self):
        self.medical_agent = MedicalAgent()
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        # Reuse the end-of-sequence token for padding.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # float16 is intended for GPUs; on CPU-only hosts half precision is
        # slow or unsupported for some kernels, so fall back to float32.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        # NOTE(review): trust_remote_code=True executes code shipped with the
        # model repository -- keep only if the checkpoint requires it.
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            torch_dtype=dtype,
            device_map="auto",
            trust_remote_code=True,
        )
        # Keep the model's padding token in sync with the tokenizer's.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def process_message(self, message: str) -> str:
        """
        Generate a reply to a user message and run keyword-triggered tools.
        Args:
            message: Raw text typed by the user.
        Returns:
            The model's reply, followed by the nurse output when the message
            mentions "symptoms", or the doctor output when it mentions
            "diagnose".
        """
        inputs = self.tokenizer(
            f"User: {message}\nAssistant:",
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_attention_mask=True,
        )
        # device_map="auto" may place the model on an accelerator; the inputs
        # must live on the same device as the model.
        device = self.model.device
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # max_length would also count the prompt tokens, leaving little or
            # no room for the answer on long prompts.
            max_new_tokens=self.MAX_NEW_TOKENS,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens so the prompt is never echoed.
        generated = outputs[0][input_ids.shape[-1]:]
        response = self.tokenizer.decode(generated, skip_special_tokens=True)
        # The model may keep writing further dialogue turns; keep only the
        # first assistant turn.
        response = response.split("\nUser:")[0].strip()

        lowered = message.lower()
        if "symptoms" in lowered:
            response += "\n\n" + self.medical_agent.collect_symptoms(self.PATIENT_ID, message)
        elif "diagnose" in lowered:
            response += "\n\n" + self.medical_agent.diagnose_patient(self.PATIENT_ID)
        return response
def chat_interface(message, history):
    """Gradio chat callback: answer ``message`` with one shared HealthAssistant.

    Args:
        message: Latest user message from the chat box.
        history: Previous turns supplied by Gradio; not used here.

    Returns:
        The assistant's reply text.
    """
    # Build the assistant once and reuse it. Creating it per message would
    # reload the language model every time and drop the collected symptoms,
    # so a later "diagnose" request could never find them.
    assistant = getattr(chat_interface, "_assistant", None)
    if assistant is None:
        assistant = HealthAssistant()
        chat_interface._assistant = assistant
    return assistant.process_message(message)
if __name__ == "__main__":
    # Script entry point: build the Gradio chat UI around chat_interface and
    # start serving it.
    disclaimer = (
        "Chat with our AI health assistant to discuss your symptoms and get medical advice. "
        "Note: This is for educational purposes only and should not replace professional medical advice."
    )
    gr.ChatInterface(
        fn=chat_interface,
        title="Health Assistant",
        description=disclaimer,
    ).launch()