Hanan-Alnakhal committed on
Commit
10eddac
·
verified ·
1 Parent(s): 98542aa

edited model

Browse files
Files changed (1) hide show
  1. rag_engine.py +250 -252
rag_engine.py CHANGED
@@ -1,253 +1,251 @@
1
- """
2
- RAG Query Engine for Lab Report Decoder
3
- Uses Hugging Face models for embeddings and generation
4
- """
5
-
6
- from sentence_transformers import SentenceTransformer
7
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
8
- import chromadb
9
- from chromadb.config import Settings
10
- from typing import List, Dict
11
- from pdf_extractor import LabResult
12
- import torch
13
-
14
- class LabReportRAG:
15
- """RAG system for explaining lab results using Hugging Face models"""
16
-
17
- def __init__(self, db_path: str = "./chroma_db"):
18
- """Initialize the RAG system with Hugging Face models"""
19
-
20
- print("🔄 Loading Hugging Face models...")
21
-
22
- # Use smaller, faster models for embeddings
23
- self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
24
-
25
- # Use a medical-focused or general LLM
26
- # Options:
27
- # - "microsoft/Phi-3-mini-4k-instruct" (good balance)
28
- # - "google/flan-t5-base" (lighter)
29
- # - "meta-llama/Llama-2-7b-chat-hf" (requires auth)
30
-
31
- model_name = "microsoft/Phi-3-mini-4k-instruct"
32
-
33
- try:
34
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
35
- self.llm = AutoModelForCausalLM.from_pretrained(
36
- model_name,
37
- trust_remote_code=True,
38
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
39
- device_map="auto" if torch.cuda.is_available() else None
40
- )
41
- print(f"✅ Loaded model: {model_name}")
42
- except Exception as e:
43
- print(f"⚠️ Could not load {model_name}, falling back to simpler model")
44
- # Fallback to lighter model
45
- self.text_generator = pipeline(
46
- "text-generation",
47
- model="google/flan-t5-base",
48
- max_length=512
49
- )
50
- self.llm = None
51
-
52
- # Load vector store
53
- try:
54
- self.client = chromadb.PersistentClient(path=db_path)
55
- self.collection = self.client.get_collection("lab_reports")
56
- print("✅ Vector database loaded")
57
- except Exception as e:
58
- print(f"⚠️ No vector database found. Please run build_vector_db.py first.")
59
- self.collection = None
60
-
61
- def _generate_with_phi(self, prompt: str, max_tokens: int = 512) -> str:
62
- """Generate text using Phi-3 model"""
63
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
64
-
65
- if torch.cuda.is_available():
66
- inputs = {k: v.to('cuda') for k, v in inputs.items()}
67
-
68
- outputs = self.llm.generate(
69
- **inputs,
70
- max_new_tokens=max_tokens,
71
- temperature=0.7,
72
- do_sample=True,
73
- top_p=0.9
74
- )
75
-
76
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
77
- # Remove the prompt from response
78
- response = response.replace(prompt, "").strip()
79
- return response
80
-
81
- def _generate_with_fallback(self, prompt: str) -> str:
82
- """Generate text using fallback pipeline"""
83
- result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
84
- return result[0]['generated_text']
85
-
86
- def _generate_text(self, prompt: str) -> str:
87
- """Generate text using available model"""
88
- try:
89
- if self.llm is not None:
90
- return self._generate_with_phi(prompt)
91
- else:
92
- return self._generate_with_fallback(prompt)
93
- except Exception as e:
94
- print(f"Generation error: {e}")
95
- return "Sorry, I encountered an error generating the explanation."
96
-
97
- def _retrieve_context(self, query: str, k: int = 3) -> str:
98
- """Retrieve relevant context from vector database"""
99
- if self.collection is None:
100
- return "No medical reference data available."
101
-
102
- try:
103
- # Create query embedding
104
- query_embedding = self.embedding_model.encode(query).tolist()
105
-
106
- # Query the collection
107
- results = self.collection.query(
108
- query_embeddings=[query_embedding],
109
- n_results=k
110
- )
111
-
112
- # Combine documents
113
- if results and results['documents']:
114
- context = "\n\n".join(results['documents'][0])
115
- return context
116
- else:
117
- return "No relevant information found."
118
- except Exception as e:
119
- print(f"Retrieval error: {e}")
120
- return "Error retrieving medical information."
121
-
122
- def explain_result(self, result: LabResult) -> str:
123
- """Generate explanation for a single lab result"""
124
-
125
- # Retrieve relevant context
126
- query = f"{result.test_name} {result.status} meaning causes treatment"
127
- context = self._retrieve_context(query, k=3)
128
-
129
- # Create prompt
130
- prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.
131
-
132
- Medical Information:
133
- {context}
134
-
135
- Lab Test: {result.test_name}
136
- Value: {result.value} {result.unit}
137
- Reference Range: {result.reference_range}
138
- Status: {result.status}
139
-
140
- Please explain:
141
- 1. What this test measures
142
- 2. What this result means
143
- 3. Possible causes if abnormal
144
- 4. Dietary recommendations if applicable
145
-
146
- Keep it simple and clear. Answer:"""
147
-
148
- # Generate explanation
149
- explanation = self._generate_text(prompt)
150
-
151
- return explanation
152
-
153
- def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
154
- """Generate explanations for all lab results"""
155
- explanations = {}
156
-
157
- for result in results:
158
- print(f"Explaining {result.test_name}...")
159
- explanation = self.explain_result(result)
160
- explanations[result.test_name] = explanation
161
-
162
- return explanations
163
-
164
- def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
165
- """Answer follow-up questions about lab results"""
166
-
167
- # Create context from lab results
168
- results_context = "\n".join([
169
- f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
170
- for r in lab_results
171
- ])
172
-
173
- # Retrieve relevant medical information
174
- medical_context = self._retrieve_context(question, k=3)
175
-
176
- # Create prompt
177
- prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.
178
-
179
- Patient's Lab Results:
180
- {results_context}
181
-
182
- Medical Information:
183
- {medical_context}
184
-
185
- Question: {question}
186
-
187
- Provide a clear, helpful answer. Answer:"""
188
-
189
- # Generate answer
190
- answer = self._generate_text(prompt)
191
-
192
- return answer
193
-
194
- def generate_summary(self, results: List[LabResult]) -> str:
195
- """Generate overall summary of lab results"""
196
-
197
- abnormal = [r for r in results if r.status in ['high', 'low']]
198
- normal = [r for r in results if r.status == 'normal']
199
-
200
- if not abnormal:
201
- return "✅ Great news! All your lab results are within normal ranges. Keep up the good work with your health!"
202
-
203
- # Get context about abnormal results
204
- queries = [f"{r.test_name} {r.status}" for r in abnormal]
205
- combined_query = " ".join(queries)
206
- context = self._retrieve_context(combined_query, k=4)
207
-
208
- # Create summary prompt
209
- abnormal_list = "\n".join([
210
- f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
211
- for r in abnormal
212
- ])
213
-
214
- prompt = f"""Provide a brief summary of these lab results.
215
-
216
- Normal Results: {len(normal)} tests
217
- Abnormal Results: {len(abnormal)} tests
218
-
219
- Abnormal Tests:
220
- {abnormal_list}
221
-
222
- Medical Context:
223
- {context}
224
-
225
- Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""
226
-
227
- # Generate summary
228
- summary = self._generate_text(prompt)
229
-
230
- return summary
231
-
232
-
233
- # Example usage
234
- if __name__ == "__main__":
235
- from pdf_extractor import LabResult
236
-
237
- # Initialize RAG system
238
- print("Initializing RAG system...")
239
- rag = LabReportRAG()
240
-
241
- # Example result
242
- test_result = LabResult(
243
- test_name="Hemoglobin",
244
- value="10.5",
245
- unit="g/dL",
246
- reference_range="12.0-15.5",
247
- status="low"
248
- )
249
-
250
- # Generate explanation
251
- print("\nGenerating explanation...")
252
- explanation = rag.explain_result(test_result)
253
  print(f"\n{explanation}")
 
1
+ """
2
+ RAG Query Engine for Lab Report Decoder
3
+ Uses Hugging Face models for embeddings and generation
4
+ """
5
+
6
+ from sentence_transformers import SentenceTransformer
7
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
8
+ import chromadb
9
+ from chromadb.config import Settings
10
+ from typing import List, Dict
11
+ from pdf_extractor import LabResult
12
+ import torch
13
+
14
class LabReportRAG:
    """RAG system for explaining lab results using Hugging Face models.

    Retrieval: queries are embedded with a SentenceTransformer and matched
    against a persistent Chroma collection of medical reference text.
    Generation: a Hugging Face model turns the retrieved context plus the
    patient's values into a plain-language explanation.
    """

    def __init__(self, db_path: str = "./chroma_db"):
        """Load the embedding model, the generator LLM and the vector store.

        Args:
            db_path: Directory of the persisted Chroma database
                (created beforehand by build_vector_db.py).
        """
        print("🔄 Loading Hugging Face models...")

        # Small, fast model for query embeddings.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Generator options:
        # - "microsoft/Phi-3-mini-4k-instruct" (causal LM, good balance)
        # - "google/flan-t5-base" (seq2seq, lighter)
        # - "meta-llama/Llama-2-7b-chat-hf" (causal LM, requires auth)
        model_name = "google/flan-t5-base"

        # Exactly one generation backend is populated; _generate_text
        # dispatches on whichever attribute is not None.
        self.llm = None
        self.tokenizer = None
        self.text_generator = None

        try:
            # BUG FIX: flan-t5 is an encoder-decoder checkpoint, so it cannot
            # be loaded with AutoModelForCausalLM, and its pipeline task is
            # "text2text-generation" (not the causal-LM "text-generation").
            self.text_generator = pipeline(
                "text2text-generation",
                model=model_name,
                max_length=512
            )
            print(f"✅ Loaded model: {model_name}")
        except Exception as e:
            # Degrade gracefully: generation methods return an apology string.
            print(f"⚠️ Could not load {model_name}: {e}")

        # Vector store: must have been built by build_vector_db.py.
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("✅ Vector database loaded")
        except Exception:
            print("⚠️ No vector database found. Please run build_vector_db.py first.")
            self.collection = None

    def _generate_with_phi(self, prompt: str, max_tokens: int = 512) -> str:
        """Generate with a causal LM held in self.llm / self.tokenizer.

        Only used when a causal model (e.g. Phi-3) has been loaded; the
        default flan-t5 configuration goes through the pipeline backend.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        outputs = self.llm.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )

        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Causal LMs echo the prompt; keep only the newly generated text.
        return response.replace(prompt, "").strip()

    def _generate_with_fallback(self, prompt: str) -> str:
        """Generate with the Hugging Face pipeline backend."""
        result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
        return result[0]['generated_text']

    def _generate_text(self, prompt: str) -> str:
        """Generate text with whichever backend loaded successfully.

        BUG FIX: the previous revision accessed self.text_generator
        unconditionally, raising AttributeError whenever the causal-LM
        path had been taken in __init__. Dispatch explicitly and return
        a friendly message on any failure instead of crashing.
        """
        try:
            if self.llm is not None:
                return self._generate_with_phi(prompt)
            if self.text_generator is not None:
                return self._generate_with_fallback(prompt)
            # Neither backend loaded (model download failed in __init__).
            return "Sorry, I encountered an error generating the explanation."
        except Exception as e:
            print(f"Generation error: {e}")
            return "Sorry, I encountered an error generating the explanation."

    def _retrieve_context(self, query: str, k: int = 3) -> str:
        """Return up to *k* reference passages relevant to *query*.

        Never raises: falls back to a human-readable message when the
        vector store is missing or the lookup fails.
        """
        if self.collection is None:
            return "No medical reference data available."

        try:
            # Embed the query with the same model used to build the DB.
            query_embedding = self.embedding_model.encode(query).tolist()

            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k
            )

            # Chroma returns one document list per query embedding.
            if results and results['documents']:
                return "\n\n".join(results['documents'][0])
            return "No relevant information found."
        except Exception as e:
            print(f"Retrieval error: {e}")
            return "Error retrieving medical information."

    def explain_result(self, result: LabResult) -> str:
        """Generate a plain-language explanation for one lab result."""

        # Retrieve reference text about this test and its status.
        query = f"{result.test_name} {result.status} meaning causes treatment"
        context = self._retrieve_context(query, k=3)

        prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.

Medical Information:
{context}

Lab Test: {result.test_name}
Value: {result.value} {result.unit}
Reference Range: {result.reference_range}
Status: {result.status}

Please explain:
1. What this test measures
2. What this result means
3. Possible causes if abnormal
4. Dietary recommendations if applicable

Keep it simple and clear. Answer:"""

        return self._generate_text(prompt)

    def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
        """Generate explanations for all results, keyed by test name."""
        explanations = {}

        for result in results:
            print(f"Explaining {result.test_name}...")
            explanations[result.test_name] = self.explain_result(result)

        return explanations

    def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
        """Answer a free-form question using the patient's results plus RAG context."""

        # Compact one-line-per-test view of the patient's results.
        results_context = "\n".join(
            f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
            for r in lab_results
        )

        medical_context = self._retrieve_context(question, k=3)

        prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.

Patient's Lab Results:
{results_context}

Medical Information:
{medical_context}

Question: {question}

Provide a clear, helpful answer. Answer:"""

        return self._generate_text(prompt)

    def generate_summary(self, results: List[LabResult]) -> str:
        """Generate an overall 2-3 paragraph summary of all lab results."""

        abnormal = [r for r in results if r.status in ['high', 'low']]
        normal = [r for r in results if r.status == 'normal']

        # Short-circuit: nothing abnormal means no LLM call is needed.
        if not abnormal:
            return "✅ Great news! All your lab results are within normal ranges. Keep up the good work with your health!"

        # One combined retrieval covering every abnormal test.
        combined_query = " ".join(f"{r.test_name} {r.status}" for r in abnormal)
        context = self._retrieve_context(combined_query, k=4)

        abnormal_list = "\n".join(
            f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
            for r in abnormal
        )

        prompt = f"""Provide a brief summary of these lab results.

Normal Results: {len(normal)} tests
Abnormal Results: {len(abnormal)} tests

Abnormal Tests:
{abnormal_list}

Medical Context:
{context}

Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""

        return self._generate_text(prompt)
229
+
230
+
231
# Example usage
if __name__ == "__main__":
    from pdf_extractor import LabResult

    # Spin up the full RAG stack (embedding model, LLM, vector store).
    print("Initializing RAG system...")
    engine = LabReportRAG()

    # A low-hemoglobin sample to demonstrate the explanation flow.
    sample = LabResult(
        test_name="Hemoglobin",
        value="10.5",
        unit="g/dL",
        reference_range="12.0-15.5",
        status="low"
    )

    print("\nGenerating explanation...")
    explanation = engine.explain_result(sample)
    print(f"\n{explanation}")