| | """ |
| | RAG Query Engine for Lab Report Decoder |
| | Uses Hugging Face models for embeddings and generation |
| | """ |
| |
|
| | from sentence_transformers import SentenceTransformer |
| | from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM |
| | import chromadb |
| | from chromadb.config import Settings |
| | from typing import List, Dict |
| | from pdf_extractor import LabResult |
| | import torch |
| |
|
class LabReportRAG:
    """RAG system for explaining lab results using Hugging Face models"""

    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize the RAG system with Hugging Face models"""

        print("🔄 Loading Hugging Face models...")

        # Embedding model for retrieval (must match the one used to build the index)
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # flan-t5 is an encoder-decoder (seq2seq) model, so it must be loaded
        # with AutoModelForSeq2SeqLM, not AutoModelForCausalLM
        model_name = "google/flan-t5-base"

        self.text_generator = None
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.llm = AutoModelForSeq2SeqLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            print(f"✅ Loaded model: {model_name}")
        except Exception as e:
            print(f"⚠️ Could not load {model_name} ({e}), falling back to a pipeline")
            # text2text-generation is the correct pipeline task for T5-style models
            self.text_generator = pipeline(
                "text2text-generation",
                model="google/flan-t5-base",
                max_length=512
            )
            self.llm = None

        # Load the persisted vector database built by build_vector_db.py
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("✅ Vector database loaded")
        except Exception:
            print("⚠️ No vector database found. Please run build_vector_db.py first.")
            self.collection = None
    def _generate_with_model(self, prompt: str, max_tokens: int = 512) -> str:
        """Generate text with the locally loaded seq2seq model"""
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)

        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}

        outputs = self.llm.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )

        # Seq2seq outputs contain only the generated answer, not the prompt,
        # so no prompt-stripping is needed
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    def _generate_with_fallback(self, prompt: str) -> str:
        """Generate text using the fallback pipeline"""
        result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
        return result[0]['generated_text']

    def _generate_text(self, prompt: str) -> str:
        """Dispatch to the loaded model if available, else the fallback pipeline"""
        try:
            if self.llm is not None:
                return self._generate_with_model(prompt)
            return self._generate_with_fallback(prompt)
        except Exception as e:
            print(f"Generation error: {e}")
            return "Sorry, I encountered an error generating the explanation."

    def _retrieve_context(self, query: str, k: int = 3) -> str:
        """Retrieve relevant context from the vector database"""
        if self.collection is None:
            return "No medical reference data available."

        try:
            # Embed the query with the same model used to build the index
            query_embedding = self.embedding_model.encode(query).tolist()

            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k
            )

            # Chroma returns one list of documents per query embedding
            if results and results['documents'] and results['documents'][0]:
                return "\n\n".join(results['documents'][0])
            return "No relevant information found."
        except Exception as e:
            print(f"Retrieval error: {e}")
            return "Error retrieving medical information."

    def explain_result(self, result: LabResult) -> str:
        """Generate an explanation for a single lab result"""

        # Retrieve reference material relevant to this test and its status
        query = f"{result.test_name} {result.status} meaning causes treatment"
        context = self._retrieve_context(query, k=3)

        prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.

Medical Information:
{context}

Lab Test: {result.test_name}
Value: {result.value} {result.unit}
Reference Range: {result.reference_range}
Status: {result.status}

Please explain:
1. What this test measures
2. What this result means
3. Possible causes if abnormal
4. Dietary recommendations if applicable

Keep it simple and clear. Answer:"""

        return self._generate_text(prompt)

    def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
        """Generate explanations for all lab results"""
        explanations = {}

        for result in results:
            print(f"Explaining {result.test_name}...")
            explanation = self.explain_result(result)
            explanations[result.test_name] = explanation

        return explanations

    def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
        """Answer follow-up questions about lab results"""

        # Summarize the patient's results for the prompt
        results_context = "\n".join([
            f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
            for r in lab_results
        ])

        # Retrieve reference material relevant to the question itself
        medical_context = self._retrieve_context(question, k=3)

        prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.

Patient's Lab Results:
{results_context}

Medical Information:
{medical_context}

Question: {question}

Provide a clear, helpful answer. Answer:"""

        return self._generate_text(prompt)

    def generate_summary(self, results: List[LabResult]) -> str:
        """Generate an overall summary of lab results"""

        abnormal = [r for r in results if r.status in ['high', 'low']]
        normal = [r for r in results if r.status == 'normal']

        if not abnormal:
            return "✅ Great news! All your lab results are within normal ranges. Keep up the good work with your health!"

        # Retrieve reference material covering all abnormal tests at once
        queries = [f"{r.test_name} {r.status}" for r in abnormal]
        combined_query = " ".join(queries)
        context = self._retrieve_context(combined_query, k=4)

        abnormal_list = "\n".join([
            f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
            for r in abnormal
        ])

        prompt = f"""Provide a brief summary of these lab results.

Normal Results: {len(normal)} tests
Abnormal Results: {len(abnormal)} tests

Abnormal Tests:
{abnormal_list}

Medical Context:
{context}

Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""

        return self._generate_text(prompt)

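# For reference, a minimal sketch of how build_vector_db.py might populate the
# "lab_reports" collection. The document contents and ID scheme here are
# assumptions for illustration, not the actual script:
#
#   model = SentenceTransformer('all-MiniLM-L6-v2')
#   client = chromadb.PersistentClient(path="./chroma_db")
#   collection = client.get_or_create_collection("lab_reports")
#   docs = ["Hemoglobin carries oxygen in red blood cells. Low levels...", ...]
#   collection.add(
#       ids=[str(i) for i in range(len(docs))],
#       documents=docs,
#       embeddings=model.encode(docs).tolist(),
#   )
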
if __name__ == "__main__":
    # Quick smoke test with a sample lab result
    # (LabResult is already imported at module level)
    print("Initializing RAG system...")
    rag = LabReportRAG()

    test_result = LabResult(
        test_name="Hemoglobin",
        value="10.5",
        unit="g/dL",
        reference_range="12.0-15.5",
        status="low"
    )

    print("\nGenerating explanation...")
    explanation = rag.explain_result(test_result)
    print(f"\n{explanation}")