# agent.py
"""Patient-intake agent: asks a few short follow-up questions, then summarizes for a doctor."""
import json
import os
from typing import Any, Dict, List, Optional

import faiss
import numpy as np
import torch
from langchain.schema import AIMessage, BaseMessage, HumanMessage
from langchain_community.llms import HuggingFacePipeline
from langgraph.graph import Graph
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# NOTE(review): confirm this Hub repo id; published HuatuoGPT checkpoints appear
# to live under the "FreedomIntelligence" organisation and some of them need
# trust_remote_code=True. Kept unchanged here for backward compatibility.
DEFAULT_MODEL_NAME = "HuatuoGPT/HuatuoGPT-7B"

# Question asked when the LLM offers no usable follow-up candidate.
FALLBACK_QUESTION = "Any other symptoms?"

# Prefix used for the agent's own turns in conversation_history.
_AGENT_PREFIX = "Agent: "


class MedicalAgent:
    """Interview a patient with short follow-up questions, then summarize for the doctor.

    Each patient message is analyzed by the LLM, grounded with MedicalRAG search
    results. The agent asks at most ``max_questions`` follow-up questions of at
    most ``max_words_per_question`` words each, then returns an LLM summary of
    the whole conversation.
    """

    def __init__(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        rag: Optional["MedicalRAG"] = None,
    ):
        """Load the LLM and set up an empty interview state.

        Args:
            model_name: HuggingFace Hub id of the causal LM to load.
            rag: Optional pre-populated knowledge base. A fresh, empty
                MedicalRAG is created when omitted; searches on it return
                no results until documents are added.
        """
        self.llm = self._load_huatuogpt(model_name)
        self.rag = rag if rag is not None else MedicalRAG()
        # Ordered transcript lines, prefixed "Patient: " or "Agent: ".
        self.conversation_history: List[str] = []
        self.question_count = 0
        self.max_questions = 3
        self.max_words_per_question = 5

    def _load_huatuogpt(
        self,
        model_name: str = DEFAULT_MODEL_NAME,
        max_new_tokens: int = 256,
    ) -> HuggingFacePipeline:
        """Load a causal LM from the HuggingFace Hub and wrap it as a LangChain LLM.

        Args:
            model_name: Hub repository id of the model and tokenizer.
            max_new_tokens: Generation budget per call. The summary prompt asks
                for up to ~150 words, which the previous 100-token budget cut
                off mid-sentence.
        """
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
        )
        return HuggingFacePipeline(pipeline=pipe)

    def _ask_llm(self, prompt: str) -> str:
        """Run *prompt* through the LLM and return the generated text, whitespace-stripped."""
        # `invoke` is the Runnable entry point; calling the LLM object directly
        # is deprecated in LangChain.
        return str(self.llm.invoke(prompt)).strip()

    def process_patient_input(self, patient_text: str) -> Dict[str, Any]:
        """Record a patient message and return the agent's next step.

        Returns:
            ``{"type": "question", "content": <question>, "question_count": n}``
            while more information is needed and the question budget allows,
            otherwise the dict produced by ``_generate_summary``.
        """
        self.conversation_history.append(f"Patient: {patient_text}")

        # Question budget exhausted: go straight to the doctor's summary.
        if self.question_count >= self.max_questions:
            return self._generate_summary()

        analysis = self._analyze_symptoms(patient_text)
        if not analysis["needs_follow_up"]:
            return self._generate_summary()

        follow_up_question = self._generate_follow_up_question(analysis)
        self.question_count += 1
        # Record the question so the summary shows what each patient answer
        # was responding to (a bare "3 days" is meaningless otherwise).
        self.conversation_history.append(f"{_AGENT_PREFIX}{follow_up_question}")
        return {
            "type": "question",
            "content": follow_up_question,
            "question_count": self.question_count,
        }

    def _analyze_symptoms(self, patient_text: str) -> Dict[str, Any]:
        """Ask the LLM whether a follow-up is needed, grounded by RAG search results.

        Always returns a dict with keys ``needs_follow_up`` (bool),
        ``missing_info`` (list of str) and ``possible_questions`` (list of str).
        A conservative default is used when the LLM output holds no JSON object.
        """
        relevant_info = self.rag.search(patient_text, k=3)
        # One bullet per retrieved document instead of a Python list repr.
        context = "\n".join(f"- {doc}" for doc in relevant_info) or "None available."
        # The full transcript is included so that answers to earlier questions
        # are interpreted in context rather than in isolation.
        conversation_text = "\n".join(self.conversation_history)

        prompt = f"""
Conversation so far:
{conversation_text}

Patient complaint: {patient_text}

Relevant medical information:
{context}

Analyze the symptoms and determine:
1. If we need follow-up questions (True/False)
2. What key information is missing
3. Suggested follow-up questions (max {self.max_words_per_question} words each)

Respond in JSON format:
{{
    "needs_follow_up": boolean,
    "missing_info": list,
    "possible_questions": list
}}
"""
        response = self._ask_llm(prompt)
        parsed = self._extract_json_object(response)
        if parsed is None:
            return self._default_analysis()
        return self._normalize_analysis(parsed)

    @staticmethod
    def _extract_json_object(text: str) -> Optional[Dict[str, Any]]:
        """Return the first JSON object embedded in *text*, or None if there is none.

        LLMs often wrap requested JSON in prose or code fences, so ``json.loads``
        on the whole response rarely succeeds. Each "{" is tried in turn as the
        start of an object.
        """
        decoder = json.JSONDecoder()
        start = text.find("{")
        while start != -1:
            try:
                obj, _ = decoder.raw_decode(text, start)
            except json.JSONDecodeError:
                pass
            else:
                if isinstance(obj, dict):
                    return obj
            start = text.find("{", start + 1)
        return None

    @staticmethod
    def _default_analysis() -> Dict[str, Any]:
        """Fallback analysis used when the LLM response cannot be parsed."""
        return {
            "needs_follow_up": True,
            "missing_info": ["symptom details"],
            "possible_questions": ["How long have symptoms lasted?", "Any other symptoms?"],
        }

    @staticmethod
    def _as_str_list(value: Any) -> List[str]:
        """Coerce an LLM-provided field into a list of non-empty strings."""
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return []
        cleaned = (str(item).strip() for item in value)
        return [item for item in cleaned if item]

    @classmethod
    def _normalize_analysis(cls, raw: Dict[str, Any]) -> Dict[str, Any]:
        """Coerce a parsed LLM analysis into the shape the rest of the agent expects."""
        needs = raw.get("needs_follow_up", True)
        if isinstance(needs, str):
            # The string "false" would otherwise be truthy.
            needs = needs.strip().lower() in {"true", "yes", "1"}
        return {
            "needs_follow_up": bool(needs),
            "missing_info": cls._as_str_list(raw.get("missing_info")),
            "possible_questions": cls._as_str_list(raw.get("possible_questions")),
        }

    def _generate_follow_up_question(self, analysis: Dict) -> str:
        """Pick the first suggested question not yet asked, clipped to the word limit.

        Falls back to FALLBACK_QUESTION when no candidate remains.
        """
        already_asked = {
            entry[len(_AGENT_PREFIX):]
            for entry in self.conversation_history
            if entry.startswith(_AGENT_PREFIX)
        }
        for candidate in analysis.get("possible_questions", []):
            words = str(candidate).split()[: self.max_words_per_question]
            question = " ".join(words)
            if question and question not in already_asked:
                return question
        return FALLBACK_QUESTION

    def _generate_summary(self) -> Dict[str, Any]:
        """Summarize the whole conversation for the doctor.

        Returns:
            ``{"type": "summary", "content": <LLM text>, "question_count": n}``.
        """
        conversation_text = "\n".join(self.conversation_history)
        prompt = f"""
Patient conversation:
{conversation_text}

Generate a concise medical summary for the doctor including:
- Main symptoms
- Key findings
- Suggested preliminary diagnosis
- Recommended tests if any

Keep it under 150 words.
"""
        summary = self._ask_llm(prompt)
        return {
            "type": "summary",
            "content": summary,
            "question_count": self.question_count,
        }

    def process_doctor_question(self, doctor_text: str) -> str:
        """Ask the LLM to rephrase a doctor's question in simple, short patient-facing terms."""
        prompt = f"""
Doctor's question: {doctor_text}

Rephrase this question to be clear and simple for the patient.
Keep it under 5 words and make it easy to understand.
"""
        return self._ask_llm(prompt)


class MedicalRAG:
    """In-memory retrieval store: sentence embeddings in an exact-L2 FAISS index.

    FAISS row ``i`` corresponds to ``knowledge_base[i]``.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        """Create an empty store using the given sentence-transformers model."""
        self.encoder = SentenceTransformer(model_name)
        # Size the index from the encoder instead of hard-coding 384, so that
        # swapping the embedding model cannot cause a dimension mismatch.
        dim = self.encoder.get_sentence_embedding_dimension()
        if dim is None:
            # Some models cannot report their dimension up front; probe once.
            dim = int(self._embed(["dimension probe"]).shape[1])
        self.index = faiss.IndexFlatL2(dim)
        self.knowledge_base: List[str] = []

    def _embed(self, texts: List[str]) -> np.ndarray:
        """Encode non-empty *texts* as a contiguous float32 (len(texts), dim) array, as FAISS requires."""
        embeddings = np.asarray(self.encoder.encode(texts), dtype=np.float32)
        return np.ascontiguousarray(embeddings.reshape(len(texts), -1))

    def add_medical_knowledge(self, documents: List[str]):
        """Embed *documents* and add them to the searchable knowledge base.

        An empty input is a no-op.
        """
        documents = list(documents)
        if not documents:
            return
        embeddings = self._embed(documents)
        self.index.add(embeddings)
        # Extend only after the index accepted the vectors, so that FAISS row
        # ids stay aligned with knowledge_base positions if encoding fails.
        self.knowledge_base.extend(documents)

    def search(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* stored documents nearest to *query*, closest first.

        Returns an empty list when the store is empty or ``k <= 0``.
        """
        if k <= 0 or self.index.ntotal == 0:
            return []
        query_embedding = self._embed([query])
        _, indices = self.index.search(query_embedding, min(k, self.index.ntotal))
        # FAISS pads missing hits with -1. The old `idx < len(...)` check let -1
        # through, which silently returned the last document.
        return [
            self.knowledge_base[idx]
            for idx in indices[0]
            if 0 <= idx < len(self.knowledge_base)
        ]