# Lab-test-decoder / rag_engine.py
# Hanan-Alnakhal's picture
# Update rag_engine.py
# 59e581a verified
import os
from typing import Dict, List, Optional

import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from pdf_extractor import LabResult
class LabReportRAG:
    """RAG system for explaining lab results - fast and efficient.

    Combines three components:

    * a sentence-transformer embedding model (``embedding_model``),
    * a persistent Chroma collection named ``lab_reports`` (``collection``),
    * a Flan-T5 ``text2text-generation`` pipeline (``text_generator``).

    The generator and the collection are optional at runtime: if either one
    fails to load, the attribute is set to ``None`` and the public methods
    fall back to fixed messages instead of raising.

    ``LabResult`` annotations are written as strings so they are not
    evaluated when the class body is executed.
    """

    # Model identifiers and limits, hoisted so they can be overridden in a
    # subclass without touching method bodies.
    EMBEDDING_MODEL = "all-MiniLM-L6-v2"
    GENERATION_MODEL = "google/flan-t5-base"
    COLLECTION_NAME = "lab_reports"
    MAX_CONTEXT_CHARS = 1500  # retrieved context is truncated to this length
    MAX_PATIENT_RESULTS = 5  # results included in a follow-up prompt
    ABNORMAL_STATUSES = ("high", "low")
    # Exact phrase used both as the early-exit answer and inside the prompt.
    FALLBACK_ANSWER = "Sorry I can't answer this Question"

    def __init__(self, db_path: str = "./chroma_db"):
        """Load the embedding model, the text generator and the vector store.

        Args:
            db_path: Directory of the persistent Chroma database.

        A failure to load the text generator or the vector store is reported
        on stdout and leaves ``text_generator`` / ``collection`` as ``None``.
        A failure to load the embedding model is not caught.
        """
        print("🔄 Loading models (optimized for speed)...")
        # Fast embedding model
        self.embedding_model = SentenceTransformer(self.EMBEDDING_MODEL)
        print("✅ Embeddings loaded")

        # Use FAST text generation model
        print("🔄 Loading text generation model...")
        try:
            # Use Flan-T5 - efficient instruction tuned model
            self.text_generator = pipeline(
                "text2text-generation",
                model=self.GENERATION_MODEL,
                max_length=512,  # Increased slightly for better answers
                device=-1,  # Force CPU
            )
            print("✅ Text generation model loaded (Flan-T5-base)")
        except Exception as e:
            print(f"⚠️ Model loading error: {e}")
            self.text_generator = None

        # Load vector store. ``client`` is pre-set so the attribute always
        # exists, even when opening the database fails.
        self.client = None
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection(self.COLLECTION_NAME)
            print("✅ Vector database loaded")
        except Exception as e:
            print(f"⚠️ Vector database not found: {e}")
            self.collection = None

    def _retrieve_context(self, query: str, k: int = 2) -> Optional[str]:
        """Retrieve relevant context from the vector database.

        Args:
            query: Free-text query to embed and search for.
            k: Number of documents requested from the collection.

        Returns:
            The retrieved documents joined by newlines and truncated to
            ``MAX_CONTEXT_CHARS`` characters, or ``None`` when there is no
            collection, nothing was found, the joined text is empty, or the
            lookup raised.
        """
        if self.collection is None:
            return None
        try:
            # Create query embedding
            query_embedding = self.embedding_model.encode(query).tolist()
            # Query the collection
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k,
            )
            if not results or not results['documents']:
                return None
            documents = results['documents'][0]
            if not documents:
                return None
            # No relevance threshold is applied: whatever the collection
            # returns is treated as the best available context.
            context = "\n".join(documents)[:self.MAX_CONTEXT_CHARS]
            # An empty string carries no information; report it as "nothing".
            return context or None
        except Exception as e:
            # Best effort: a retrieval failure degrades to "no context".
            print(f"Retrieval error: {e}")
            return None

    def _generate_text(self, prompt: str) -> str:
        """Generate text for ``prompt`` using the Flan-T5 pipeline.

        Returns:
            The stripped generated text, ``"AI model not available."`` when
            no generator was loaded, or ``"Unable to generate explanation."``
            when generation raised.
        """
        if self.text_generator is None:
            return "AI model not available."
        try:
            result = self.text_generator(
                prompt,
                max_length=256,
                do_sample=True,
                temperature=0.5,  # Lower temperature for more factual answers
                num_return_sequences=1,
            )
            return result[0]['generated_text'].strip()
        except Exception as e:
            print(f"Generation error: {e}")
            return "Unable to generate explanation."

    def explain_result(self, result: "LabResult") -> str:
        """Generate an explanation for a single lab result.

        Args:
            result: Object exposing ``test_name``, ``value``, ``unit``,
                ``status`` and ``reference_range``.

        Returns:
            A deterministic sentence stating the value, followed either by a
            "consult your doctor" note (no context retrieved) or by the
            generated analysis and the reference range.
        """
        print(f" Explaining: {result.test_name} ({result.status})...")

        # 1. Base message (deterministic)
        base_msg = (
            f"Your {result.test_name} is {result.value} {result.unit} "
            f"({result.status}). "
        )

        # 2. Retrieve context
        query = (
            f"What does {result.status} {result.test_name} mean? "
            "causes and simple explanation"
        )
        context = self._retrieve_context(query, k=1)
        if not context:
            return base_msg + "Please consult your doctor for interpretation."

        # 3. Ask the model to summarise the retrieved context simply
        prompt = (
            "\n"
            f"Explain simply in 2 sentences what {result.status} "
            f"{result.test_name} means based on this medical info.\n"
            f"Medical Info: {context}\n"
            "Explanation:"
        )
        ai_explanation = self._generate_text(prompt)

        # 4. Construct final output
        return (
            f"{base_msg}\n"
            f"💡 Analysis: {ai_explanation}\n"
            f"(Reference Range: {result.reference_range})"
        )

    def answer_followup_question(
        self, question: str, lab_results: "List[LabResult]"
    ) -> str:
        """Answer a follow-up question using retrieval plus generation.

        Args:
            question: The user's question.
            lab_results: The patient's results; the first
                ``MAX_PATIENT_RESULTS`` are included in the prompt.

        Returns:
            The generated answer, or ``FALLBACK_ANSWER`` when no context
            could be retrieved or the model produced an empty answer.
        """
        print(f"💬 Processing question: {question[:50]}...")

        # 1. Retrieve medical context
        medical_context = self._retrieve_context(question, k=2)

        # 2. Out-of-context check: nothing retrieved means nothing to ground
        #    an answer on.
        if not medical_context:
            return self.FALLBACK_ANSWER

        # 3. Give the model a snapshot of the patient's actual data
        relevant_results = [
            f"{r.test_name}: {r.value} ({r.status})" for r in lab_results
        ]
        patient_data = ", ".join(relevant_results[:self.MAX_PATIENT_RESULTS])

        # 4. Strict prompt instructing the model to use the fallback phrase
        prompt = (
            "\n"
            "Answer the question based strictly on the Medical Context and "
            "Patient Results provided below.\n"
            "If the answer cannot be found in the context, or if the question "
            "is not about health/labs, reply exactly "
            f"\"{self.FALLBACK_ANSWER}\".\n"
            f"Medical Context: {medical_context}\n"
            f"Patient Results: {patient_data}\n"
            f"Question: {question}\n"
            "Answer:"
        )

        # 5. Generate. The prompt is the only guard against off-topic
        #    answers; an empty generation is replaced by the fallback phrase.
        answer = self._generate_text(prompt)
        return answer or self.FALLBACK_ANSWER

    def generate_summary(self, results: "List[LabResult]") -> str:
        """Generate a summary of the abnormal results.

        Args:
            results: All lab results; those whose ``status`` is in
                ``ABNORMAL_STATUSES`` are summarised.

        Returns:
            A fixed "all normal" message when nothing is abnormal, otherwise
            a header, the generated summary and a bullet list of the abnormal
            values.
        """
        print("📊 Generating summary...")
        abnormal = [r for r in results if r.status in self.ABNORMAL_STATUSES]
        if not abnormal:
            return "✅ All results are normal. Great job maintaining your health!"

        abnormal_text = ", ".join(
            f"{r.test_name} is {r.status}" for r in abnormal
        )

        # Get general context about these specific abnormal tests
        context = self._retrieve_context(
            f"health implications of {abnormal_text}", k=1
        )

        prompt_lines = [
            "",
            f"The patient has these abnormal lab results: {abnormal_text}.",
        ]
        # Only mention medical info when some was retrieved; otherwise the
        # text "None" would be fed to the model as if it were context.
        if context:
            prompt_lines.append(f"Based on this medical info: {context}")
        prompt_lines.append(
            "Write a short, encouraging 2-sentence summary advising them "
            "to see a doctor."
        )
        prompt_lines.append("Summary:")
        ai_summary = self._generate_text("\n".join(prompt_lines))

        details = "\n".join(
            f"- {r.test_name}: {r.value} {r.unit}" for r in abnormal
        )
        return (
            f"⚠️ **Abnormal Results Detected**\n\n{ai_summary}\n\n"
            f"Detailed changes:\n{details}"
        )

    def explain_all_results(self, results: "List[LabResult]") -> Dict[str, str]:
        """Explain every result, keyed by test name.

        Results sharing a ``test_name`` overwrite each other; the last one
        wins.
        """
        return {
            result.test_name: self.explain_result(result) for result in results
        }
# Testing block
def _run_smoke_test() -> None:
    """Run a manual end-to-end check of ``LabReportRAG`` and print the output.

    Builds the RAG system with its default database path, then exercises one
    explanation, one on-topic follow-up question and one off-topic follow-up
    question against two mock results. Any exception is reported on stdout
    rather than propagated, so the script always exits normally.
    """
    print("Testing RAG system...")
    try:
        rag = LabReportRAG()

        # Mock data. Imported here so an import failure is reported by the
        # handler below like any other error.
        from pdf_extractor import LabResult
        results = [
            LabResult("Hemoglobin", "10.5", "g/dL", "12.0-15.5", "low"),
            LabResult("Glucose", "95", "mg/dL", "70-100", "normal"),
        ]

        # Test 1: Explanation using LLM
        print("\n--- Test Explanation ---")
        print(rag.explain_result(results[0]))

        # Test 2: Follow up (Valid)
        print("\n--- Test Valid Question ---")
        q1 = "What foods should I eat for low hemoglobin?"
        print(f"Q: {q1}")
        print(f"A: {rag.answer_followup_question(q1, results)}")

        # Test 3: Follow up (Out of Context)
        print("\n--- Test Invalid Question ---")
        q2 = "Who is the president of the USA?"
        print(f"Q: {q2}")
        print(f"A: {rag.answer_followup_question(q2, results)}")
    except Exception as e:
        print(f"\n❌ Error: {e}")


if __name__ == "__main__":
    _run_smoke_test()