Hanan-Alnakhal committed on
Commit
cad89d4
·
verified ·
1 Parent(s): 7890291

Update rag_engine.py

Browse files
Files changed (1) hide show
  1. rag_engine.py +207 -166
rag_engine.py CHANGED
@@ -1,53 +1,41 @@
1
  """
2
  RAG Query Engine for Lab Report Decoder
3
- Uses Hugging Face models for embeddings and generation
4
  """
5
 
6
  from sentence_transformers import SentenceTransformer
7
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
8
  import chromadb
9
- from chromadb.config import Settings
10
  from typing import List, Dict
11
  from pdf_extractor import LabResult
12
- import torch
13
 
14
  class LabReportRAG:
15
- """RAG system for explaining lab results using Hugging Face models"""
16
 
17
  def __init__(self, db_path: str = "./chroma_db"):
18
- """Initialize the RAG system with Hugging Face models"""
19
 
20
- print("πŸ”„ Loading Hugging Face models...")
21
 
22
- # Use smaller, faster models for embeddings
23
  self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
24
 
25
- # Use a medical-focused or general LLM
26
- # Options:
27
- # - "microsoft/Phi-3-mini-4k-instruct" (good balance)
28
- # - "google/flan-t5-base" (lighter)
29
- # - "meta-llama/Llama-2-7b-chat-hf" (requires auth)
30
-
31
- model_name = "microsoft/Phi-3-mini-4k-instruct"
32
-
33
  try:
34
- self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
35
- self.llm = AutoModelForCausalLM.from_pretrained(
36
- model_name,
37
- trust_remote_code=True,
38
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
39
- device_map="auto" if torch.cuda.is_available() else None
40
- )
41
- print(f"βœ… Loaded model: {model_name}")
42
- except Exception as e:
43
- print(f"⚠️ Could not load {model_name}, falling back to simpler model")
44
- # Fallback to lighter model
45
  self.text_generator = pipeline(
46
- "text-generation",
47
- model="google/flan-t5-base",
48
- max_length=512
 
49
  )
50
- self.llm = None
 
 
 
51
 
52
  # Load vector store
53
  try:
@@ -55,49 +43,13 @@ class LabReportRAG:
55
  self.collection = self.client.get_collection("lab_reports")
56
  print("βœ… Vector database loaded")
57
  except Exception as e:
58
- print(f"⚠️ No vector database found. Please run build_vector_db.py first.")
59
  self.collection = None
60
 
61
- def _generate_with_phi(self, prompt: str, max_tokens: int = 512) -> str:
62
- """Generate text using Phi-3 model"""
63
- inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
64
-
65
- if torch.cuda.is_available():
66
- inputs = {k: v.to('cuda') for k, v in inputs.items()}
67
-
68
- outputs = self.llm.generate(
69
- **inputs,
70
- max_new_tokens=max_tokens,
71
- temperature=0.7,
72
- do_sample=True,
73
- top_p=0.9
74
- )
75
-
76
- response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
77
- # Remove the prompt from response
78
- response = response.replace(prompt, "").strip()
79
- return response
80
-
81
- def _generate_with_fallback(self, prompt: str) -> str:
82
- """Generate text using fallback pipeline"""
83
- result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
84
- return result[0]['generated_text']
85
-
86
- def _generate_text(self, prompt: str) -> str:
87
- """Generate text using available model"""
88
- try:
89
- if self.llm is not None:
90
- return self._generate_with_phi(prompt)
91
- else:
92
- return self._generate_with_fallback(prompt)
93
- except Exception as e:
94
- print(f"Generation error: {e}")
95
- return "Sorry, I encountered an error generating the explanation."
96
-
97
- def _retrieve_context(self, query: str, k: int = 3) -> str:
98
  """Retrieve relevant context from vector database"""
99
  if self.collection is None:
100
- return "No medical reference data available."
101
 
102
  try:
103
  # Create query embedding
@@ -110,144 +62,233 @@ class LabReportRAG:
110
  )
111
 
112
  # Combine documents
113
- if results and results['documents']:
114
- context = "\n\n".join(results['documents'][0])
115
- return context
 
116
  else:
117
- return "No relevant information found."
118
  except Exception as e:
119
  print(f"Retrieval error: {e}")
120
- return "Error retrieving medical information."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  def explain_result(self, result: LabResult) -> str:
123
  """Generate explanation for a single lab result"""
124
 
125
- # Retrieve relevant context
126
- query = f"{result.test_name} {result.status} meaning causes treatment"
127
- context = self._retrieve_context(query, k=3)
128
 
129
- # Create prompt
130
- prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.
131
-
132
- Medical Information:
133
- {context}
134
-
135
- Lab Test: {result.test_name}
136
- Value: {result.value} {result.unit}
137
- Reference Range: {result.reference_range}
138
- Status: {result.status}
 
 
 
 
 
139
 
140
- Please explain:
141
- 1. What this test measures
142
- 2. What this result means
143
- 3. Possible causes if abnormal
144
- 4. Dietary recommendations if applicable
 
 
 
 
 
 
 
 
145
 
146
- Keep it simple and clear. Answer:"""
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- # Generate explanation
149
- explanation = self._generate_text(prompt)
 
 
 
 
150
 
151
  return explanation
152
 
 
 
 
 
 
 
 
 
153
  def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
154
- """Generate explanations for all lab results"""
155
  explanations = {}
156
 
157
- for result in results:
158
- print(f"Explaining {result.test_name}...")
159
- explanation = self.explain_result(result)
160
- explanations[result.test_name] = explanation
 
 
 
 
 
 
161
 
 
162
  return explanations
163
 
164
  def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
165
- """Answer follow-up questions about lab results"""
 
 
166
 
167
  # Create context from lab results
168
- results_context = "\n".join([
169
- f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
170
- for r in lab_results
171
- ])
 
 
172
 
173
- # Retrieve relevant medical information
174
- medical_context = self._retrieve_context(question, k=3)
175
 
176
- # Create prompt
177
- prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.
178
-
179
- Patient's Lab Results:
180
- {results_context}
181
-
182
- Medical Information:
183
- {medical_context}
184
-
185
- Question: {question}
186
-
187
- Provide a clear, helpful answer. Answer:"""
188
 
189
- # Generate answer
190
- answer = self._generate_text(prompt)
 
 
 
 
 
 
 
 
 
 
 
191
 
 
192
  return answer
193
 
194
  def generate_summary(self, results: List[LabResult]) -> str:
195
- """Generate overall summary of lab results"""
 
 
196
 
197
  abnormal = [r for r in results if r.status in ['high', 'low']]
198
  normal = [r for r in results if r.status == 'normal']
199
 
200
  if not abnormal:
201
- return "βœ… Great news! All your lab results are within normal ranges. Keep up the good work with your health!"
202
-
203
- # Get context about abnormal results
204
- queries = [f"{r.test_name} {r.status}" for r in abnormal]
205
- combined_query = " ".join(queries)
206
- context = self._retrieve_context(combined_query, k=4)
207
-
208
- # Create summary prompt
209
- abnormal_list = "\n".join([
210
- f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
211
- for r in abnormal
212
- ])
213
-
214
- prompt = f"""Provide a brief summary of these lab results.
215
 
216
- Normal Results: {len(normal)} tests
217
- Abnormal Results: {len(abnormal)} tests
218
-
219
- Abnormal Tests:
220
- {abnormal_list}
221
 
222
- Medical Context:
223
- {context}
 
224
 
225
- Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- # Generate summary
228
- summary = self._generate_text(prompt)
 
 
 
 
 
 
 
 
 
 
 
 
229
 
 
230
  return summary
231
 
232
 
233
- # Example usage
234
  if __name__ == "__main__":
235
- from pdf_extractor import LabResult
236
-
237
- # Initialize RAG system
238
- print("Initializing RAG system...")
239
- rag = LabReportRAG()
240
-
241
- # Example result
242
- test_result = LabResult(
243
- test_name="Hemoglobin",
244
- value="10.5",
245
- unit="g/dL",
246
- reference_range="12.0-15.5",
247
- status="low"
248
- )
249
 
250
- # Generate explanation
251
- print("\nGenerating explanation...")
252
- explanation = rag.explain_result(test_result)
253
- print(f"\n{explanation}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
  RAG Query Engine for Lab Report Decoder
3
+ Uses Hugging Face models - OPTIMIZED for speed
4
  """
5
 
6
  from sentence_transformers import SentenceTransformer
7
+ from transformers import pipeline
8
  import chromadb
 
9
  from typing import List, Dict
10
  from pdf_extractor import LabResult
11
+ import os
12
 
13
  class LabReportRAG:
14
+ """RAG system for explaining lab results - Fast and efficient"""
15
 
16
  def __init__(self, db_path: str = "./chroma_db"):
17
+ """Initialize the RAG system with fast models"""
18
 
19
+ print("πŸ”„ Loading models (optimized for speed)...")
20
 
21
+ # Fast embedding model
22
  self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
23
+ print("βœ… Embeddings loaded")
24
 
25
+ # Use FAST text generation model
26
+ print("πŸ”„ Loading text generation model...")
 
 
 
 
 
 
27
  try:
28
+ # Use Flan-T5 - much faster than Phi-3
 
 
 
 
 
 
 
 
 
 
29
  self.text_generator = pipeline(
30
+ "text2text-generation",
31
+ model="google/flan-t5-small", # Even smaller/faster
32
+ max_length=256,
33
+ device=-1 # Force CPU (HF Spaces default)
34
  )
35
+ print("βœ… Text generation model loaded (Flan-T5-small)")
36
+ except Exception as e:
37
+ print(f"⚠️ Model loading error: {e}")
38
+ self.text_generator = None
39
 
40
  # Load vector store
41
  try:
 
43
  self.collection = self.client.get_collection("lab_reports")
44
  print("βœ… Vector database loaded")
45
  except Exception as e:
46
+ print(f"⚠️ Vector database not found: {e}")
47
  self.collection = None
48
 
49
+ def _retrieve_context(self, query: str, k: int = 2) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  """Retrieve relevant context from vector database"""
51
  if self.collection is None:
52
+ return "Limited medical information available."
53
 
54
  try:
55
  # Create query embedding
 
62
  )
63
 
64
  # Combine documents
65
+ if results and results['documents'] and len(results['documents'][0]) > 0:
66
+ context = "\n".join(results['documents'][0])
67
+ # Limit context length for speed
68
+ return context[:1000]
69
  else:
70
+ return "No specific information found."
71
  except Exception as e:
72
  print(f"Retrieval error: {e}")
73
+ return "Error retrieving information."
74
+
75
+ def _generate_text(self, prompt: str) -> str:
76
+ """Generate text - with fallback to template-based"""
77
+ if self.text_generator is None:
78
+ return "AI model not available. Using basic explanation."
79
+
80
+ try:
81
+ # Generate with timeout protection
82
+ result = self.text_generator(
83
+ prompt,
84
+ max_length=256,
85
+ do_sample=True,
86
+ temperature=0.7,
87
+ num_return_sequences=1
88
+ )
89
+ return result[0]['generated_text'].strip()
90
+ except Exception as e:
91
+ print(f"Generation error: {e}")
92
+ return "Unable to generate detailed explanation."
93
 
94
def explain_result(self, result: LabResult) -> str:
    """Generate explanation for a single lab result.

    Dispatches to a status-specific template builder; unrecognized
    statuses fall through to the "unknown" template.
    """

    print(f" Explaining: {result.test_name} ({result.status})...")

    # Status -> template-builder dispatch table (replaces the if/elif chain).
    builders = {
        'normal': self._explain_normal,
        'high': self._explain_high,
        'low': self._explain_low,
    }
    builder = builders.get(result.status, self._explain_unknown)
    return builder(result)
108
+
109
def _explain_normal(self, result: LabResult) -> str:
    """Fast template for normal results."""
    # A single snippet of reference text is enough for a normal value.
    reference = self._retrieve_context(f"{result.test_name} normal meaning", k=1)

    pieces = [
        f"""βœ… Your {result.test_name} level of {result.value} {result.unit} is within the normal range ({result.reference_range}).

This indicates healthy levels. """
    ]

    # Append retrieved context only when it looks substantive
    # (very short strings are error/placeholder messages).
    if reference and len(reference) > 20:
        pieces.append(f"\n\n{reference[:300]}")

    return "".join(pieces)
122
+
123
def _explain_high(self, result: LabResult) -> str:
    """Fast template for high results."""
    # Reference text about causes/treatment for an elevated value.
    reference = self._retrieve_context(f"{result.test_name} high causes treatment", k=2)

    pieces = [
        f"""⚠️ Your {result.test_name} level of {result.value} {result.unit} is ABOVE the normal range ({result.reference_range}).

"""
    ]

    # Skip placeholder/error strings (they are all short).
    if reference and len(reference) > 20:
        pieces.append(f"{reference[:400]}\n\n")

    pieces.append("πŸ’‘ Recommendation: Discuss these results with your healthcare provider for personalized advice.")

    return "".join(pieces)
137
+
138
def _explain_low(self, result: LabResult) -> str:
    """Fast template for low results."""
    # Reference text about causes/treatment for a depressed value.
    reference = self._retrieve_context(f"{result.test_name} low causes treatment", k=2)

    pieces = [
        f"""⚠️ Your {result.test_name} level of {result.value} {result.unit} is BELOW the normal range ({result.reference_range}).

"""
    ]

    # Skip placeholder/error strings (they are all short).
    if reference and len(reference) > 20:
        pieces.append(f"{reference[:400]}\n\n")

    pieces.append("πŸ’‘ Recommendation: Consult with your healthcare provider about these results.")

    return "".join(pieces)
152
 
153
def _explain_unknown(self, result: LabResult) -> str:
    """Template for unknown status."""
    # Assembled from paragraphs instead of one triple-quoted literal;
    # the joined output is identical to the original template.
    paragraphs = [
        f"Your {result.test_name} result is {result.value} {result.unit}.",
        f"Reference range: {result.reference_range}",
        "We couldn't automatically determine if this is within normal range. Please consult your healthcare provider to interpret this result.",
    ]
    return "\n\n".join(paragraphs)
160
+
161
def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
    """Generate explanations for all lab results - FAST.

    Returns a mapping of test name -> explanation text; a per-result
    failure is caught and replaced with a fallback message so one bad
    result cannot abort the whole batch.
    """
    explanations: Dict[str, str] = {}
    total = len(results)

    print(f"🧠 Generating explanations for {total} results...")

    for index, lab_result in enumerate(results, 1):
        print(f" [{index}/{total}] {lab_result.test_name}...")
        try:
            explanations[lab_result.test_name] = self.explain_result(lab_result)
        except Exception as e:
            print(f" Error: {e}")
            explanations[lab_result.test_name] = f"Unable to generate explanation for {lab_result.test_name}."

    print("βœ… All explanations generated")
    return explanations
178
 
179
def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
    """Answer follow-up questions - FAST.

    Routes the question through simple keyword templates (diet, causes,
    general) and enriches the reply with retrieved reference text.
    Fix: the original recomputed ``question.lower()`` up to four times and
    repeated the context-length check in three branches; both are hoisted.
    Output is byte-identical to the original for every input.
    """

    print(f"πŸ’¬ Processing question: {question[:50]}...")

    # Create context from lab results (first 10 only, for speed).
    results_context = "\n".join(
        f"{r.test_name}: {r.value} {r.unit} ({r.status})"
        for r in lab_results[:10]
    )

    # Get relevant medical info from the vector store.
    medical_context = self._retrieve_context(question, k=2)
    # Very short strings are placeholder/error messages, not real context.
    has_context = bool(medical_context) and len(medical_context) > 20

    # Hoisted: lowercase once instead of per keyword check.
    q = question.lower()

    # Simple template-based response for speed.
    if any(word in q for word in ("food", "eat", "diet")):
        answer = f"""Based on your lab results:\n\n{results_context}\n\n"""
        if has_context:
            answer += f"{medical_context[:500]}"
        else:
            answer += "For dietary recommendations specific to your results, please consult with a healthcare provider or nutritionist."

    elif any(word in q for word in ("why", "cause")):
        answer = f"""Regarding your question about your results:\n\n"""
        if has_context:
            answer += f"{medical_context[:500]}"
        else:
            answer += "There can be various causes for abnormal lab results. Your healthcare provider can help identify the specific cause in your case."

    else:
        # General question: prefer retrieved context, else recap the results.
        if has_context:
            answer = medical_context[:500]
        else:
            answer = f"""Based on your results:\n{results_context}\n\nFor specific medical advice about your results, please consult with your healthcare provider."""

    print("βœ… Answer generated")
    return answer
219
 
220
def generate_summary(self, results: List[LabResult]) -> str:
    """Generate overall summary - FAST.

    Returns a reassuring all-clear message when every result is normal;
    otherwise a structured summary of abnormal tests, retrieved context,
    and next-step guidance.
    Fix: after the ``if not abnormal: return`` early exit, ``abnormal`` is
    guaranteed non-empty, so the two later ``if abnormal:`` guards were
    always true; they are removed. Output is unchanged for every input.
    """

    print("πŸ“Š Generating summary...")

    abnormal = [r for r in results if r.status in ['high', 'low']]
    normal = [r for r in results if r.status == 'normal']

    if not abnormal:
        return """βœ… Excellent news! All your lab results are within normal ranges.

This suggests that the tested parameters are functioning well. Continue maintaining your current health habits, and follow your healthcare provider's recommendations for routine monitoring."""

    # Build summary header; `abnormal` is non-empty past this point.
    summary = f"""πŸ“Š Lab Results Summary

Total Tests: {len(results)}
βœ… Normal: {len(normal)}
⚠️ Abnormal: {len(abnormal)}

"""

    summary += "**Tests Outside Normal Range:**\n"
    for r in abnormal[:5]:  # Limit to first 5
        status_emoji = "↑" if r.status == "high" else "↓"
        summary += f"{status_emoji} {r.test_name}: {r.value} {r.unit} ({r.status})\n"

    if len(abnormal) > 5:
        summary += f"... and {len(abnormal) - 5} more\n"

    summary += "\n"

    # Get context for the first few abnormal results (3 max, for speed).
    abnormal_names = ", ".join([r.test_name for r in abnormal[:3]])
    context = self._retrieve_context(f"{abnormal_names} interpretation", k=2)

    # Very short strings are placeholder/error messages, not real context.
    if context and len(context) > 20:
        summary += f"**Key Information:**\n{context[:400]}\n\n"

    summary += """**Next Steps:**
1. Review these results with your healthcare provider
2. Discuss any concerns or symptoms you're experiencing
3. Follow recommended treatment or monitoring plans

Remember: These results are for educational purposes. Always consult your doctor for medical advice."""

    print("βœ… Summary generated")
    return summary
270
 
271
 
272
# Smoke test when run directly: initialize the RAG system and explain
# one hand-built low-hemoglobin result.
if __name__ == "__main__":
    print("Testing RAG system...")

    try:
        rag = LabReportRAG()
        print("\nβœ… RAG system initialized successfully!")

        # Test with example
        from pdf_extractor import LabResult
        test_result = LabResult(
            test_name="Hemoglobin",
            value="10.5",
            unit="g/dL",
            reference_range="12.0-15.5",
            status="low"
        )

        explanation = rag.explain_result(test_result)
        print(f"\nTest Explanation:\n{explanation}")

    except Exception as e:
        # Fix: the original printed only the message, discarding the stack
        # trace and making smoke-test failures hard to diagnose.
        import traceback
        traceback.print_exc()
        print(f"\n❌ Error: {e}")