# Health / app.py
# Author: Chinez-dev
# Last commit: 1faa320 "Update app.py" (verified)
from duckduckgo_search import DDGS
import datetime
import requests
import pytz
import yaml
import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
class MedicalAgent:
    """Nurse/doctor helper that stores symptoms and looks them up on the web.

    Symptoms are kept in memory per patient ID. The "doctor" step runs three
    DuckDuckGo text searches (diagnosis, medication, treatment advice) for the
    stored symptoms and returns the result snippets as-is; the text is not
    checked or filtered.
    """

    # Follow-up questions returned by the nurse step after recording symptoms.
    FOLLOW_UP_QUESTIONS = (
        "How long have you had these symptoms?",
        "Do you have any allergies?",
        "Are you taking any medications?",
        "Have you experienced these symptoms before?",
        "Have you had any recent illnesses?",
        "Have you noticed any other unusual changes?",
        "What is your medical history related to these symptoms?",
    )

    def __init__(self):
        # patient_id -> {"symptoms": str, "additional_info": dict}
        self.patient_data = {}
        self.ddgs = DDGS()

    def search_web(self, query: str, max_results: int = 2) -> str:
        """Return the body snippets of the top DuckDuckGo text results.

        Args:
            query: Free-text search query.
            max_results: Maximum number of results to request (default 2).

        Returns:
            The non-empty ``body`` snippets joined by newlines,
            ``"No search results found."`` when nothing usable comes back, or
            ``"Error searching: ..."`` if the search call itself fails.
        """
        try:
            # Treat a None result from text() the same as "no results".
            results = self.ddgs.text(query, max_results=max_results) or []
            # Results may be lazy, so iterate them inside the try as well.
            bodies = [result.get("body") for result in results]
        except Exception as e:
            return f"Error searching: {str(e)}"
        snippets = [body for body in bodies if body]
        if not snippets:
            return "No search results found."
        return "\n".join(snippets)

    def collect_symptoms(self, patient_id: str, symptoms: str) -> str:
        """Nurse tool: record a patient's symptoms and return follow-up questions.

        Any record previously stored for ``patient_id`` is replaced.

        Args:
            patient_id: Unique identifier for the patient.
            symptoms: Initial symptoms provided by the patient.

        Returns:
            A "Nurse: ..." message listing the follow-up questions, or an
            "Error: ..." message if either argument is missing or blank.
        """
        try:
            if not patient_id or not symptoms or not symptoms.strip():
                raise ValueError("Patient ID and symptoms must be provided.")
            self.patient_data[patient_id] = {"symptoms": symptoms, "additional_info": {}}
            return (
                f"Nurse: I have noted the symptoms ({symptoms}). Here are follow-up questions:\n"
                + "\n".join(self.FOLLOW_UP_QUESTIONS)
            )
        except ValueError as e:
            return f"Error: {str(e)}"
        except Exception as e:
            return f"Unexpected Error: {str(e)}"

    def diagnose_patient(self, patient_id: str) -> str:
        """Doctor tool: look up diagnosis, medication and advice for stored symptoms.

        Args:
            patient_id: Unique identifier for the patient.

        Returns:
            A "Doctor: ..." report containing the three search results, or an
            "Error: ..." message if no symptoms were collected for the patient.
        """
        try:
            if patient_id not in self.patient_data:
                raise ValueError("No symptoms found. Nurse must collect symptoms first.")
            symptoms = self.patient_data[patient_id]["symptoms"]
            diagnosis = self.fetch_diagnosis(symptoms)
            medication = self.fetch_medication(symptoms)
            advice = self.fetch_treatment_advice(symptoms)
            return (f"Doctor: Based on the symptoms: {symptoms},\n"
                    f"Diagnosis: {diagnosis}\n"
                    f"Medication: {medication}\n"
                    f"Advice: {advice}")
        except ValueError as e:
            return f"Error: {str(e)}"
        except Exception as e:
            return f"Unexpected Error: {str(e)}"

    def _lookup(self, symptoms: str, query_prefix: str, label: str) -> str:
        """Search for "<query_prefix>: <symptoms>", reporting failures as "Error fetching <label>: ..."."""
        try:
            if not symptoms or not symptoms.strip():
                raise ValueError("Symptoms must be provided.")
            return self.search_web(f"{query_prefix}: {symptoms}")
        except Exception as e:
            return f"Error fetching {label}: {str(e)}"

    def fetch_diagnosis(self, symptoms: str) -> str:
        """Web-search tool: retrieve diagnosis information for the symptoms.

        Args:
            symptoms: The symptoms provided by the patient.

        Returns:
            Search snippets, or an "Error fetching diagnosis: ..." message.
        """
        return self._lookup(
            symptoms, "A detailed medical diagnosis for these symptoms", "diagnosis"
        )

    def fetch_medication(self, symptoms: str) -> str:
        """Web-search tool: retrieve medication suggestions for the symptoms.

        Args:
            symptoms: The symptoms provided by the patient.

        Returns:
            Search snippets, or an "Error fetching medication: ..." message.
        """
        return self._lookup(
            symptoms, "Recommended medications for these symptoms", "medication"
        )

    def fetch_treatment_advice(self, symptoms: str) -> str:
        """Web-search tool: retrieve treatment advice for the symptoms.

        Args:
            symptoms: The symptoms provided by the patient.

        Returns:
            Search snippets, or an "Error fetching treatment advice: ..." message.
        """
        return self._lookup(
            symptoms,
            "Treatment advice and recommendations for these symptoms",
            "treatment advice",
        )
def get_current_time_in_timezone(timezone: str) -> str:
    """Report the current wall-clock time in the given timezone.

    Args:
        timezone: Timezone name accepted by ``pytz.timezone``,
            e.g. 'America/New_York'.

    Returns:
        A sentence with the local time formatted as ``YYYY-MM-DD HH:MM:SS``,
        or an error sentence if the lookup or formatting raises anything.
    """
    try:
        zone = pytz.timezone(timezone)
        stamp = datetime.datetime.now(zone).strftime("%Y-%m-%d %H:%M:%S")
    except Exception as err:
        return f"Error fetching time for timezone '{timezone}': {str(err)}"
    return f"The current local time in {timezone} is: {stamp}"
class HealthAssistant:
    """Chat front end combining phi-2 text generation with a MedicalAgent.

    Every message gets a free-form reply from the language model. Messages that
    mention "symptoms" also run the nurse step. Messages that mention "diagnose"
    run the doctor step instead. Both steps use a single patient ID.
    """

    MODEL_NAME = "microsoft/phi-2"

    def __init__(self):
        self.medical_agent = MedicalAgent()
        self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
        # Reuse EOS as the pad token so padding and generate() have a valid pad id.
        self.tokenizer.pad_token = self.tokenizer.eos_token
        # If a prompt must be truncated, drop its beginning rather than the
        # trailing "Assistant:" cue the model is meant to continue from.
        self.tokenizer.truncation_side = "left"
        # Half precision only on CUDA; on CPU, float16 is slow and some torch
        # builds lack half-precision kernels for core ops.
        dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        self.model = AutoModelForCausalLM.from_pretrained(
            self.MODEL_NAME,
            torch_dtype=dtype,
            device_map="auto",
            # NOTE(review): this executes Python code shipped with the model
            # repo; recent transformers releases may support phi-2 natively,
            # making this flag unnecessary — confirm against the pinned version.
            trust_remote_code=True,
        )
        # Keep the model config's pad id in sync with the tokenizer.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    @staticmethod
    def _trim_reply(text: str) -> str:
        """Keep only the assistant's turn: cut at the first newline followed by "User:"."""
        return text.split("\nUser:", 1)[0].strip()

    def process_message(
        self,
        message: str,
        patient_id: str = "user_1",
        max_new_tokens: int = 512,
    ) -> str:
        """Generate a model reply to ``message`` and run any triggered medical tool.

        Args:
            message: The user's chat message.
            patient_id: Patient record used by the medical tools. Defaults to
                the single shared ID ``"user_1"``, as before.
            max_new_tokens: Upper bound on tokens generated for the reply.

        Returns:
            The model reply. If the message mentions "symptoms" or "diagnose",
            a blank line and the nurse/doctor tool output follow; empty parts
            are omitted.
        """
        prompt = f"User: {message}\nAssistant:"
        # Reserve room in the context window for the generated tokens so long
        # messages still get a reply (max_length used to include the prompt).
        context_size = getattr(self.model.config, "max_position_embeddings", 2048)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max(1, context_size - max_new_tokens),
            return_attention_mask=True,
        ).to(self.model.device)  # model may live on GPU under device_map="auto"
        outputs = self.model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_new_tokens=max_new_tokens,
            pad_token_id=self.tokenizer.eos_token_id,
        )
        # Decode only the newly generated tokens, so the prompt (including any
        # "Assistant:" inside the user's own text) is never echoed back.
        prompt_length = inputs["input_ids"].shape[-1]
        reply = self._trim_reply(
            self.tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
        )

        lowered = message.lower()
        if "symptoms" in lowered:
            tool_output = self.medical_agent.collect_symptoms(patient_id, message)
        elif "diagnose" in lowered:
            tool_output = self.medical_agent.diagnose_patient(patient_id)
        else:
            tool_output = ""
        return "\n\n".join(part for part in (reply, tool_output) if part)
# Lazily created HealthAssistant shared by every chat turn (loading phi-2 is expensive).
_assistant = None


def chat_interface(message, history):
    """Gradio ChatInterface callback: answer ``message`` with a shared HealthAssistant.

    The assistant (and its model) is created on the first call and reused.
    The model therefore loads once, and symptoms recorded in one turn are still
    available when a later turn asks for a diagnosis. Previously a new
    assistant was built per message, which reloaded the model every time and
    threw away the recorded symptoms.

    NOTE(review): HealthAssistant files all patient data under one fixed
    patient ID, so every user served by this process shares the same symptom
    record. Wire in a per-session ID before multi-user deployment.

    Args:
        message: The user's latest message.
        history: Prior chat turns supplied by Gradio (unused).

    Returns:
        The assistant's reply text.
    """
    global _assistant
    if _assistant is None:
        _assistant = HealthAssistant()
    return _assistant.process_message(message)
if __name__ == "__main__":
    # Build the Gradio chat UI around chat_interface and serve it.
    disclaimer = (
        "Chat with our AI health assistant to discuss your symptoms and get "
        "medical advice. Note: This is for educational purposes only and "
        "should not replace professional medical advice."
    )
    iface = gr.ChatInterface(
        fn=chat_interface,
        title="Health Assistant",
        description=disclaimer,
    )
    iface.launch()