Mr-HASSAN committed on
Commit
c7d192d
·
verified ·
1 Parent(s): fad6fa5

Upload agent.py

Browse files
Files changed (1) hide show
  1. agent.py +171 -0
agent.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # agent.py
2
+ import os
3
+ from typing import Dict, List, Any, Optional
4
+ from langgraph.graph import Graph
5
+ from langchain.schema import BaseMessage, HumanMessage, AIMessage
6
+ from langchain_community.llms import HuggingFacePipeline
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
+ import torch
9
+ from sentence_transformers import SentenceTransformer
10
+ import faiss
11
+ import numpy as np
12
+ import json
13
+
14
+
15
class MedicalAgent:
    """Conversational medical-triage agent.

    Interviews a patient with short follow-up questions (hard-capped at
    ``max_questions``), then produces a concise summary for the doctor.
    Uses a HuatuoGPT pipeline as the LLM and a :class:`MedicalRAG` store
    for retrieval context.
    """

    def __init__(self):
        self.llm = self._load_huatuogpt()
        self.rag = MedicalRAG()
        # Alternating "Patient: ..." / "Agent: ..." lines, in order.
        self.conversation_history = []
        self.question_count = 0
        self.max_questions = 3           # budget of follow-up questions
        self.max_words_per_question = 5  # keep questions easy for patients

    def _load_huatuogpt(self):
        """Load the HuatuoGPT model from HuggingFace and wrap it for LangChain.

        NOTE(review): this downloads a 7B model (fp16, device_map="auto") —
        needs a GPU or very large RAM; confirm the model id exists on the Hub.
        """
        model_name = "HuatuoGPT/HuatuoGPT-7B"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16,
            device_map="auto",
        )

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=100,
            temperature=0.7,
            do_sample=True,
        )

        return HuggingFacePipeline(pipeline=pipe)

    def process_patient_input(self, patient_text: str) -> Dict[str, Any]:
        """Process one patient message and decide the next step.

        Returns a dict with ``type`` ("question" or "summary"), ``content``,
        and ``question_count``.
        """
        self.conversation_history.append(f"Patient: {patient_text}")

        # Stop interviewing once the question budget is spent.
        if self.question_count >= self.max_questions:
            return self._generate_summary()

        analysis = self._analyze_symptoms(patient_text)

        # .get(): parsed LLM JSON is not guaranteed to contain the key.
        if analysis.get("needs_follow_up"):
            follow_up_question = self._generate_follow_up_question(analysis)
            self.question_count += 1
            # Record our own question too, so the final doctor summary
            # reflects the whole exchange, not just the patient's side.
            self.conversation_history.append(f"Agent: {follow_up_question}")
            return {
                "type": "question",
                "content": follow_up_question,
                "question_count": self.question_count,
            }
        return self._generate_summary()

    def _analyze_symptoms(self, patient_text: str) -> Dict[str, Any]:
        """Analyze symptoms via RAG + LLM; return the parsed analysis dict.

        Falls back to a generic "ask more" analysis whenever the LLM does
        not produce valid JSON.
        """
        # Retrieve supporting context from the medical knowledge base.
        relevant_info = self.rag.search(patient_text, k=3)

        prompt = f"""
Patient complaint: {patient_text}
Relevant medical information: {relevant_info}

Analyze the symptoms and determine:
1. If we need follow-up questions (True/False)
2. What key information is missing
3. Suggested follow-up questions (max 5 words each)

Respond in JSON format:
{{
    "needs_follow_up": boolean,
    "missing_info": list,
    "possible_questions": list
}}
"""

        response = self.llm(prompt)

        # LLMs routinely echo the prompt or wrap the JSON in prose; try the
        # widest {...} span before giving up.
        start = response.find("{")
        end = response.rfind("}")
        candidate = response[start:end + 1] if start != -1 and end > start else response
        try:
            analysis = json.loads(candidate)
        except (json.JSONDecodeError, TypeError):
            # Narrow except (was a bare `except:`): degrade gracefully to a
            # generic follow-up instead of hiding unrelated errors.
            analysis = {
                "needs_follow_up": True,
                "missing_info": ["symptom details"],
                "possible_questions": ["How long have headache?", "Any other symptoms?"],
            }

        return analysis

    def _generate_follow_up_question(self, analysis: Dict) -> str:
        """Return the first usable suggested question, trimmed to the word limit."""
        for suggestion in analysis.get("possible_questions") or []:
            words = str(suggestion).split()
            if words:  # skip empty / whitespace-only suggestions
                return " ".join(words[:self.max_words_per_question])
        return "Any other symptoms?"

    def _generate_summary(self) -> Dict[str, Any]:
        """Generate the final summary for the doctor from the full conversation."""
        conversation_text = "\n".join(self.conversation_history)

        prompt = f"""
Patient conversation:
{conversation_text}

Generate a concise medical summary for the doctor including:
- Main symptoms
- Key findings
- Suggested preliminary diagnosis
- Recommended tests if any

Keep it under 150 words.
"""

        summary = self.llm(prompt)
        return {
            "type": "summary",
            "content": summary,
            "question_count": self.question_count,
        }

    def process_doctor_question(self, doctor_text: str) -> str:
        """Rephrase a doctor's question into patient-friendly wording."""
        prompt = f"""
Doctor's question: {doctor_text}

Rephrase this question to be clear and simple for the patient.
Keep it under 5 words and make it easy to understand.
"""

        simplified_question = self.llm(prompt)
        return simplified_question
147
+
148
+
149
class MedicalRAG:
    """Minimal FAISS-backed retriever over a list of medical text snippets."""

    def __init__(self):
        # all-MiniLM-L6-v2 emits 384-dim embeddings; the index dim must match.
        self.encoder = SentenceTransformer('all-MiniLM-L6-v2')
        self.index = faiss.IndexFlatL2(384)
        self.knowledge_base = []  # documents, aligned with FAISS row ids

    def add_medical_knowledge(self, documents: List[str]):
        """Embed *documents* and add them to both the index and the store."""
        if not documents:
            # Encoding an empty batch is pointless and can error downstream.
            return
        self.knowledge_base.extend(documents)
        embeddings = self.encoder.encode(documents)
        # FAISS requires contiguous float32 input.
        self.index.add(np.asarray(embeddings, dtype=np.float32))

    def search(self, query: str, k: int = 3) -> List[str]:
        """Return up to *k* documents most similar to *query* (L2 distance)."""
        if not self.knowledge_base:
            return []

        query_embedding = np.asarray(self.encoder.encode([query]), dtype=np.float32)
        _distances, indices = self.index.search(query_embedding, k)

        # BUG FIX: FAISS pads missing neighbours with -1. The old check
        # `idx < len(...)` let -1 through, which Python negative-indexed to
        # the *last* document — a silently wrong result. Require 0 <= idx.
        return [
            self.knowledge_base[idx]
            for idx in indices[0]
            if 0 <= idx < len(self.knowledge_base)
        ]