# Source: Lab-test-decoder / rag_engine.py
# (Hugging Face Space upload by Hanan-Alnakhal, commit 8a693e2 — "Upload 12 files", 8.99 kB)
"""
RAG Query Engine for Lab Report Decoder
Uses Hugging Face models for embeddings and generation
"""
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import chromadb
from chromadb.config import Settings
from typing import List, Dict
from pdf_extractor import LabResult
import torch
class LabReportRAG:
    """RAG system for explaining lab results using Hugging Face models.

    Retrieval: a SentenceTransformer embedding model queries a persistent
    ChromaDB collection (built beforehand by ``build_vector_db.py``).
    Generation: Phi-3-mini-4k-instruct when it loads, otherwise a lighter
    flan-t5-base seq2seq pipeline as fallback.
    """

    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize embedding model, generator (with fallback), and vector store.

        Args:
            db_path: Filesystem path of the persistent ChromaDB database.
        """
        print("🔄 Loading Hugging Face models...")
        # Small, fast model for query embeddings.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Primary generator. Alternatives:
        # - "microsoft/Phi-3-mini-4k-instruct" (good balance)
        # - "google/flan-t5-base" (lighter)
        # - "meta-llama/Llama-2-7b-chat-hf" (requires auth)
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                # fp16 only when a GPU is present; CPU runs fp32.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            print(f"✅ Loaded model: {model_name}")
        except Exception as e:
            # BUG FIX: include the actual failure reason instead of discarding it.
            print(f"⚠️ Could not load {model_name}, falling back to simpler model ({e})")
            # BUG FIX: flan-t5 is an encoder-decoder (T5) model, so the correct
            # pipeline task is "text2text-generation", not "text-generation".
            self.text_generator = pipeline(
                "text2text-generation",
                model="google/flan-t5-base",
                max_length=512
            )
            self.llm = None

        # Load the persistent vector store produced by build_vector_db.py.
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("✅ Vector database loaded")
        except Exception as e:
            # Missing DB is non-fatal: retrieval degrades gracefully (see
            # _retrieve_context), so the app can still run without references.
            print(f"⚠️ No vector database found ({e}). Please run build_vector_db.py first.")
            self.collection = None

    def _generate_with_phi(self, prompt: str, max_tokens: int = 512) -> str:
        """Generate a completion for ``prompt`` with the Phi-3 model.

        Args:
            prompt: Full prompt text (truncated to 2048 tokens on encode).
            max_tokens: Maximum number of new tokens to generate.

        Returns:
            The decoded completion with the echoed prompt stripped.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        # BUG FIX: with device_map="auto" the model may not live on 'cuda:0';
        # move inputs to the model's actual device instead of assuming CUDA.
        # (On CPU this is a no-op, preserving the original CPU behavior.)
        device = next(self.llm.parameters()).device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        outputs = self.llm.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9,
            # Avoid the "pad_token_id not set" warning during sampling.
            pad_token_id=self.tokenizer.eos_token_id
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # decode() re-emits the prompt; keep only the completion.
        response = response.replace(prompt, "").strip()
        return response

    def _generate_with_fallback(self, prompt: str) -> str:
        """Generate text using the lighter flan-t5 pipeline."""
        result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
        return result[0]['generated_text']

    def _generate_text(self, prompt: str) -> str:
        """Generate text with whichever model loaded; never raises.

        Returns an apology string on any generation error so callers can
        display it directly to the user.
        """
        try:
            if self.llm is not None:
                return self._generate_with_phi(prompt)
            return self._generate_with_fallback(prompt)
        except Exception as e:
            print(f"Generation error: {e}")
            return "Sorry, I encountered an error generating the explanation."

    def _retrieve_context(self, query: str, k: int = 3) -> str:
        """Retrieve the top-``k`` reference passages for ``query``.

        Returns a newline-joined context string, or a human-readable
        placeholder when the DB is missing, empty, or errors out.
        """
        if self.collection is None:
            return "No medical reference data available."
        try:
            # Embed the query with the same model used to build the index.
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k
            )
            # Chroma returns documents as a list-of-lists (one per query).
            if results and results['documents']:
                return "\n\n".join(results['documents'][0])
            return "No relevant information found."
        except Exception as e:
            print(f"Retrieval error: {e}")
            return "Error retrieving medical information."

    def explain_result(self, result: "LabResult") -> str:
        """Generate a plain-language explanation for a single lab result."""
        # Retrieve reference passages keyed on the test name and status.
        query = f"{result.test_name} {result.status} meaning causes treatment"
        context = self._retrieve_context(query, k=3)
        prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.

Medical Information:
{context}

Lab Test: {result.test_name}
Value: {result.value} {result.unit}
Reference Range: {result.reference_range}
Status: {result.status}

Please explain:
1. What this test measures
2. What this result means
3. Possible causes if abnormal
4. Dietary recommendations if applicable

Keep it simple and clear. Answer:"""
        return self._generate_text(prompt)

    def explain_all_results(self, results: List["LabResult"]) -> Dict[str, str]:
        """Generate explanations for every lab result.

        Returns:
            Mapping of test name -> explanation text ({} for empty input).
        """
        explanations = {}
        for result in results:
            print(f"Explaining {result.test_name}...")
            explanations[result.test_name] = self.explain_result(result)
        return explanations

    def answer_followup_question(self, question: str, lab_results: List["LabResult"]) -> str:
        """Answer a follow-up question grounded in the patient's results."""
        # Compact one-line-per-test summary of the patient's panel.
        results_context = "\n".join(
            f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
            for r in lab_results
        )
        medical_context = self._retrieve_context(question, k=3)
        prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.

Patient's Lab Results:
{results_context}

Medical Information:
{medical_context}

Question: {question}

Provide a clear, helpful answer. Answer:"""
        return self._generate_text(prompt)

    def generate_summary(self, results: List["LabResult"]) -> str:
        """Generate an overall summary of the panel.

        Short-circuits with a canned all-clear message when nothing is
        flagged 'high' or 'low'; otherwise prompts the LLM with context
        retrieved for the abnormal tests.
        """
        abnormal = [r for r in results if r.status in ['high', 'low']]
        normal = [r for r in results if r.status == 'normal']
        if not abnormal:
            return "✅ Great news! All your lab results are within normal ranges. Keep up the good work with your health!"

        # One combined retrieval query covering all abnormal tests.
        combined_query = " ".join(f"{r.test_name} {r.status}" for r in abnormal)
        context = self._retrieve_context(combined_query, k=4)

        abnormal_list = "\n".join(
            f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
            for r in abnormal
        )
        prompt = f"""Provide a brief summary of these lab results.

Normal Results: {len(normal)} tests
Abnormal Results: {len(abnormal)} tests

Abnormal Tests:
{abnormal_list}

Medical Context:
{context}

Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""
        return self._generate_text(prompt)
# Example usage
if __name__ == "__main__":
    from pdf_extractor import LabResult

    # Spin up the full RAG pipeline (loads models and the vector DB).
    print("Initializing RAG system...")
    engine = LabReportRAG()

    # A sample low-hemoglobin result to demonstrate a single explanation.
    sample = LabResult(
        test_name="Hemoglobin",
        value="10.5",
        unit="g/dL",
        reference_range="12.0-15.5",
        status="low",
    )

    print("\nGenerating explanation...")
    print(f"\n{engine.explain_result(sample)}")