"""Gradio chat front-end for a small demo "health assistant".

The module wires together three pieces:

* ``MedicalAgent`` - keeps per-patient symptoms in memory and looks up
  diagnosis / medication / treatment text through a DuckDuckGo text search.
* ``HealthAssistant`` - wraps the ``microsoft/phi-2`` causal language model
  and appends ``MedicalAgent`` output to the model's reply when the user's
  message mentions "symptoms" or "diagnose".
* ``chat_interface`` - the callback handed to ``gr.ChatInterface``.
"""

import datetime

import gradio as gr
import pytz
import requests
import torch
import yaml
from duckduckgo_search import DDGS
from transformers import AutoModelForCausalLM, AutoTokenizer

# Hugging Face model id used for both the tokenizer and the model.
MODEL_NAME = "microsoft/phi-2"

# Upper bound on the number of tokens generated per reply.
MAX_NEW_TOKENS = 512

# Patient id used for every chat message; the UI has no notion of accounts.
DEFAULT_PATIENT_ID = "user_1"


class MedicalAgent:
    """In-memory symptom store plus web-search-backed medical lookups."""

    # Follow-up questions returned by ``collect_symptoms``.
    FOLLOW_UP_QUESTIONS = (
        "How long have you had these symptoms?",
        "Do you have any allergies?",
        "Are you taking any medications?",
        "Have you experienced these symptoms before?",
        "Have you had any recent illnesses?",
        "Have you noticed any other unusual changes?",
        "What is your medical history related to these symptoms?",
    )

    def __init__(self):
        # Maps patient_id -> {"symptoms": str, "additional_info": dict}.
        self.patient_data = {}
        self.ddgs = DDGS()

    def search_web(self, query: str) -> str:
        """Run a text search and return the result bodies joined by newlines.

        Args:
            query: Search string passed to ``DDGS.text``.

        Returns:
            The ``body`` field of up to two results, one per line, or an
            ``"Error searching: ..."`` message if the search raised.
        """
        try:
            results = self.ddgs.text(query, max_results=2)
            return "\n".join(result["body"] for result in results)
        except Exception as e:
            # Best-effort tool: report the failure as text instead of raising.
            return f"Error searching: {str(e)}"

    def collect_symptoms(self, patient_id: str, symptoms: str) -> str:
        """Nurse tool to collect symptoms and ask follow-up questions.

        Stores the symptoms under ``patient_id``, replacing any earlier
        record for that patient.

        Args:
            patient_id: Unique identifier for the patient.
            symptoms: Initial symptoms provided by the patient.

        Returns:
            Follow-up questions based on the symptoms, or an error message
            when either argument is empty.
        """
        try:
            if not patient_id or not symptoms:
                raise ValueError("Patient ID and symptoms must be provided.")

            self.patient_data[patient_id] = {
                "symptoms": symptoms,
                "additional_info": {},
            }
            return (
                f"Nurse: I have noted the symptoms ({symptoms}). "
                "Here are follow-up questions:\n"
                + "\n".join(self.FOLLOW_UP_QUESTIONS)
            )
        except ValueError as e:
            return f"Error: {str(e)}"
        except Exception as e:
            return f"Unexpected Error: {str(e)}"

    def diagnose_patient(self, patient_id: str) -> str:
        """Doctor tool to diagnose the patient based on stored symptoms.

        Args:
            patient_id: Unique identifier for the patient.

        Returns:
            Diagnosis, medication and advice text, or an error message when
            no symptoms were collected for ``patient_id``.
        """
        try:
            if patient_id not in self.patient_data:
                raise ValueError(
                    "No symptoms found. Nurse must collect symptoms first."
                )

            symptoms = self.patient_data[patient_id]["symptoms"]
            diagnosis = self.fetch_diagnosis(symptoms)
            medication = self.fetch_medication(symptoms)
            advice = self.fetch_treatment_advice(symptoms)

            return (
                f"Doctor: Based on the symptoms: {symptoms},\n"
                f"Diagnosis: {diagnosis}\n"
                f"Medication: {medication}\n"
                f"Advice: {advice}"
            )
        except ValueError as e:
            return f"Error: {str(e)}"
        except Exception as e:
            return f"Unexpected Error: {str(e)}"

    def _fetch(self, symptoms: str, query_prefix: str, topic: str) -> str:
        """Shared implementation of the ``fetch_*`` tools.

        Args:
            symptoms: The symptoms provided by the patient.
            query_prefix: Text placed before the symptoms in the search query.
            topic: Short label used in the error message.

        Returns:
            The search result text, or ``"Error fetching <topic>: ..."``.
        """
        try:
            if not symptoms:
                raise ValueError("Symptoms must be provided.")
            return self.search_web(f"{query_prefix}: {symptoms}")
        except Exception as e:
            return f"Error fetching {topic}: {str(e)}"

    def fetch_diagnosis(self, symptoms: str) -> str:
        """AI tool to retrieve a diagnosis based on symptoms.

        Args:
            symptoms: The symptoms provided by the patient.

        Returns:
            Detailed diagnosis information.
        """
        return self._fetch(
            symptoms,
            "A detailed medical diagnosis for these symptoms",
            "diagnosis",
        )

    def fetch_medication(self, symptoms: str) -> str:
        """AI tool to suggest medications based on symptoms.

        Args:
            symptoms: The symptoms provided by the patient.

        Returns:
            Suggested medication.
        """
        return self._fetch(
            symptoms,
            "Recommended medications for these symptoms",
            "medication",
        )

    def fetch_treatment_advice(self, symptoms: str) -> str:
        """AI tool to provide treatment recommendations.

        Args:
            symptoms: The symptoms provided by the patient.

        Returns:
            Recommended treatment and advice.
        """
        return self._fetch(
            symptoms,
            "Treatment advice and recommendations for these symptoms",
            "treatment advice",
        )


def get_current_time_in_timezone(timezone: str) -> str:
    """A tool that fetches the current local time in a specified timezone.

    Args:
        timezone: A string representing a valid timezone
            (e.g., 'America/New_York').

    Returns:
        A sentence containing the local time formatted as
        ``YYYY-MM-DD HH:MM:SS``, or an error message if the lookup failed.
    """
    try:
        tz = pytz.timezone(timezone)
        local_time = datetime.datetime.now(tz).strftime("%Y-%m-%d %H:%M:%S")
        return f"The current local time in {timezone} is: {local_time}"
    except Exception as e:
        return f"Error fetching time for timezone '{timezone}': {str(e)}"


class HealthAssistant:
    """Chat assistant combining the phi-2 model with ``MedicalAgent`` tools."""

    def __init__(self):
        self.medical_agent = MedicalAgent()
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
        # Reuse the EOS token for padding so ``padding=True`` works.
        self.tokenizer.pad_token = self.tokenizer.eos_token

        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            torch_dtype=torch.float16,
            device_map="auto",
            trust_remote_code=True,
        )
        # Keep the model's padding token in sync with the tokenizer's.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    def process_message(self, message: str) -> str:
        """Generate a reply and append medical-tool output when triggered.

        Args:
            message: The user's chat message.

        Returns:
            The model's reply. If ``message`` contains "symptoms" the nurse
            follow-up questions are appended; otherwise, if it contains
            "diagnose", the doctor's diagnosis is appended.
        """
        # Tokenize with an attention mask and move the tensors to the model's
        # device: with device_map="auto" the model may not be on the CPU.
        inputs = self.tokenizer(
            f"User: {message}\nAssistant:",
            return_tensors="pt",
            padding=True,
            truncation=True,
            return_attention_mask=True,
        ).to(self.model.device)

        # max_new_tokens bounds only the reply; max_length would also count
        # the prompt and leave no room to answer a long message.
        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=MAX_NEW_TOKENS,
            pad_token_id=self.tokenizer.eos_token_id,
        )

        decoded = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # The decoded text includes the prompt; keep only what follows the
        # last "Assistant:" marker.
        response = decoded.split("Assistant:")[-1].strip()

        # Check if we need to collect symptoms or diagnose.
        lowered = message.lower()
        if "symptoms" in lowered:
            response += "\n\n" + self.medical_agent.collect_symptoms(
                DEFAULT_PATIENT_ID, message
            )
        elif "diagnose" in lowered:
            response += "\n\n" + self.medical_agent.diagnose_patient(
                DEFAULT_PATIENT_ID
            )

        return response


# Process-wide assistant, created on first use by ``_get_assistant``.
_assistant = None


def _get_assistant() -> HealthAssistant:
    """Return the shared ``HealthAssistant``, creating it on first call."""
    global _assistant
    if _assistant is None:
        _assistant = HealthAssistant()
    return _assistant


def chat_interface(message, history):
    """Gradio chat callback: answer ``message`` with the shared assistant.

    A single assistant is reused across calls so the model is loaded only
    once and symptoms stored by one message are still available when a later
    message asks for a diagnosis.

    Args:
        message: The user's latest chat message.
        history: Prior conversation supplied by Gradio; not used.

    Returns:
        The assistant's reply text.
    """
    return _get_assistant().process_message(message)


if __name__ == "__main__":
    iface = gr.ChatInterface(
        fn=chat_interface,
        title="Health Assistant",
        description="Chat with our AI health assistant to discuss your symptoms and get medical advice. Note: This is for educational purposes only and should not replace professional medical advice.",
    )
    iface.launch()