Spaces:
Runtime error
Runtime error
| # agent.py | |
| import os | |
| from typing import Dict, List, Any, Optional | |
| from langgraph.graph import Graph | |
| from langchain.schema import BaseMessage, HumanMessage, AIMessage | |
| from langchain_community.llms import HuggingFacePipeline | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| import numpy as np | |
| import json | |
class MedicalAgent:
    """Patient-intake agent that asks short follow-up questions, then summarizes.

    Each patient message is appended to ``conversation_history``. While fewer
    than ``max_questions`` follow-ups have been asked, the LLM (grounded with
    snippets from a ``MedicalRAG`` index) decides whether more information is
    needed. If so, the agent returns one question of at most
    ``max_words_per_question`` words; otherwise it returns a summary intended
    for the doctor.
    """

    def __init__(self):
        self.llm = self._load_huatuogpt()
        self.rag = MedicalRAG()
        # Transcript lines such as "Patient: ..." and "Agent: ...".
        self.conversation_history: List[str] = []
        # Number of follow-up questions asked so far.
        self.question_count = 0
        self.max_questions = 3
        self.max_words_per_question = 5

    def _load_huatuogpt(self, max_new_tokens: int = 256):
        """Load HuatuoGPT from the Hugging Face Hub, wrapped as a LangChain LLM.

        The repository id defaults to ``"HuatuoGPT/HuatuoGPT-7B"`` and can be
        overridden with the ``HUATUOGPT_MODEL`` environment variable.

        Args:
            max_new_tokens: Generation budget per call. It is large enough that
                the ~150-word summary and the JSON analysis are not cut off
                mid-output.

        Returns:
            A ``HuggingFacePipeline`` around a sampling text-generation pipeline.
        """
        # NOTE(review): confirm this repo id exists on the Hub (published
        # HuatuoGPT checkpoints appear under "FreedomIntelligence/...") and
        # whether the checkpoint needs trust_remote_code=True.
        model_name = os.environ.get("HUATUOGPT_MODEL", "HuatuoGPT/HuatuoGPT-7B")
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # NOTE(review): float16 weights with device_map="auto" presumably
        # target a GPU (and require `accelerate`); confirm this works on the
        # deployment hardware.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=max_new_tokens,
            temperature=0.7,
            do_sample=True,
        )
        return HuggingFacePipeline(pipeline=pipe)

    def _ask_llm(self, prompt: str) -> str:
        """Send *prompt* to the LLM and return its reply with outer whitespace removed."""
        # ``invoke`` is the Runnable API; calling the LLM object directly is
        # deprecated in LangChain.
        return str(self.llm.invoke(prompt)).strip()

    def process_patient_input(self, patient_text: str) -> Dict[str, Any]:
        """Record a patient message and return the agent's next turn.

        Returns:
            ``{"type": "question", "content": str, "question_count": int}`` when
            a follow-up is needed and the question budget is not exhausted,
            otherwise the summary dict from ``_generate_summary``.
        """
        self.conversation_history.append(f"Patient: {patient_text}")

        # Once the question budget is spent, always hand over to the doctor.
        if self.question_count >= self.max_questions:
            return self._generate_summary()

        analysis = self._analyze_symptoms(patient_text)
        if not analysis["needs_follow_up"]:
            return self._generate_summary()

        follow_up_question = self._generate_follow_up_question(analysis)
        # Record the question so the summary and later analyses can pair each
        # patient answer with the question it responds to.
        self.conversation_history.append(f"Agent: {follow_up_question}")
        self.question_count += 1
        return {
            "type": "question",
            "content": follow_up_question,
            "question_count": self.question_count,
        }

    def _analyze_symptoms(self, patient_text: str) -> Dict[str, Any]:
        """Ask the LLM, grounded with RAG snippets, whether a follow-up is needed.

        Returns:
            A dict with ``needs_follow_up`` (bool), ``missing_info`` (list) and
            ``possible_questions`` (list of str). A generic fallback analysis is
            returned when the reply cannot be parsed.
        """
        relevant_info = self.rag.search(patient_text, k=3)
        knowledge = "\n".join(f"- {snippet}" for snippet in relevant_info) or "- (none found)"
        # Include the whole transcript so short answers ("3 days") keep their
        # context and already-asked questions are not repeated.
        conversation_text = "\n".join(self.conversation_history)
        prompt = f"""
Conversation so far:
{conversation_text}
Latest patient message: {patient_text}
Relevant medical information:
{knowledge}
Analyze the symptoms and determine:
1. If we need follow-up questions (True/False)
2. What key information is missing
3. Suggested follow-up questions not already asked (max {self.max_words_per_question} words each)
Respond in JSON format:
{{
"needs_follow_up": boolean,
"missing_info": list,
"possible_questions": list
}}
"""
        return self._parse_analysis(self._ask_llm(prompt))

    def _parse_analysis(self, response: str) -> Dict[str, Any]:
        """Extract and normalize the JSON analysis object from a free-form LLM reply."""
        # Models often wrap JSON in prose or code fences, so parse only the
        # outermost {...} span rather than the whole reply.
        start, end = response.find("{"), response.rfind("}")
        if start == -1 or end < start:
            return self._fallback_analysis()
        try:
            parsed = json.loads(response[start:end + 1])
        except json.JSONDecodeError:
            return self._fallback_analysis()
        if not isinstance(parsed, dict):
            return self._fallback_analysis()

        needs_follow_up = parsed.get("needs_follow_up", True)
        if isinstance(needs_follow_up, str):
            # bool("false") is True, so interpret common string spellings.
            needs_follow_up = needs_follow_up.strip().lower() not in ("false", "no", "0", "")
        return {
            "needs_follow_up": bool(needs_follow_up),
            "missing_info": self._as_list(parsed.get("missing_info")),
            "possible_questions": [str(q) for q in self._as_list(parsed.get("possible_questions"))],
        }

    @staticmethod
    def _as_list(value: Any) -> List[Any]:
        """Coerce a JSON field to a list (``None`` -> ``[]``, scalar -> ``[scalar]``)."""
        if value is None:
            return []
        return value if isinstance(value, list) else [value]

    @staticmethod
    def _fallback_analysis() -> Dict[str, Any]:
        """Return a fresh default analysis used when the LLM reply is unusable."""
        return {
            "needs_follow_up": True,
            "missing_info": ["symptom details"],
            # Complaint-agnostic, because this is used for any patient message.
            "possible_questions": ["How long have symptoms lasted?", "Any other symptoms?"],
        }

    def _generate_follow_up_question(self, analysis: Dict) -> str:
        """Return the first non-empty suggested question, truncated to the word limit."""
        for candidate in self._as_list(analysis.get("possible_questions")):
            words = str(candidate).split()[: self.max_words_per_question]
            if words:
                return " ".join(words)
        return "Any other symptoms?"

    def _generate_summary(self) -> Dict[str, Any]:
        """Summarize the transcript for the doctor.

        Returns:
            ``{"type": "summary", "content": str, "question_count": int}``.
        """
        conversation_text = "\n".join(self.conversation_history)
        prompt = f"""
Patient conversation:
{conversation_text}
Generate a concise medical summary for the doctor including:
- Main symptoms
- Key findings
- Suggested preliminary diagnosis
- Recommended tests if any
Keep it under 150 words.
"""
        summary = self._ask_llm(prompt)
        return {
            "type": "summary",
            "content": summary,
            "question_count": self.question_count,
        }

    def process_doctor_question(self, doctor_text: str) -> str:
        """Rephrase a doctor's question into a short, patient-friendly one.

        The reply is reduced to its first non-empty line and truncated to
        ``max_words_per_question`` words, the same limit applied to the agent's
        own follow-up questions. If the model returns nothing, the doctor's
        text is returned unchanged (whitespace-stripped).
        """
        prompt = f"""
Doctor's question: {doctor_text}
Rephrase this question to be clear and simple for the patient.
Keep it to at most {self.max_words_per_question} words and make it easy to understand.
Reply with the question only.
"""
        reply = self._ask_llm(prompt)
        first_line = next((line for line in reply.splitlines() if line.strip()), "")
        words = first_line.split()[: self.max_words_per_question]
        return " ".join(words) if words else doctor_text.strip()
class MedicalRAG:
    """Minimal in-memory retrieval index over medical text snippets.

    Documents are embedded with a SentenceTransformer model and stored in an
    exact (flat) L2 FAISS index. ``search`` returns the stored texts of the
    nearest neighbours.
    """

    def __init__(self, model_name: str = "all-MiniLM-L6-v2"):
        self.encoder = SentenceTransformer(model_name)
        # Size the index from the encoder instead of hard-coding 384, so a
        # different embedding model cannot cause a dimension mismatch.
        self.index = faiss.IndexFlatL2(self.encoder.get_sentence_embedding_dimension())
        # Row i of the FAISS index corresponds to knowledge_base[i].
        self.knowledge_base: List[str] = []

    def add_medical_knowledge(self, documents: List[str]):
        """Embed *documents* and add them to the index; an empty input is a no-op."""
        docs = list(documents)
        if not docs:
            # encode([]) does not yield a 2-D (0, dim) array that FAISS accepts.
            return
        # FAISS requires a C-contiguous float32 matrix.
        embeddings = np.ascontiguousarray(self.encoder.encode(docs), dtype=np.float32)
        self.index.add(embeddings)
        # Extend only after the index accepted the vectors, so the two stay aligned.
        self.knowledge_base.extend(docs)

    def search(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* stored documents nearest to *query*, closest first."""
        if k <= 0 or not self.knowledge_base:
            return []
        query_embedding = np.ascontiguousarray(self.encoder.encode([query]), dtype=np.float32)
        _distances, indices = self.index.search(query_embedding, k)
        # FAISS pads missing neighbours with -1. Without the lower bound, a -1
        # would silently select the last document.
        return [
            self.knowledge_base[idx]
            for idx in indices[0]
            if 0 <= idx < len(self.knowledge_base)
        ]