from sentence_transformers import SentenceTransformer
from transformers import pipeline
import chromadb
from typing import Dict, List, Optional
from pdf_extractor import LabResult
import os

# Exact reply used when a follow-up question cannot be answered from the
# knowledge base. It is also embedded verbatim in the LLM prompt so the model
# is instructed to emit the same phrase.
OUT_OF_CONTEXT_REPLY = "Sorry I can't answer this Question"


class LabReportRAG:
    """RAG system for explaining lab results - Fast and efficient.

    Combines a SentenceTransformer embedder, a persistent Chroma collection
    named ``"lab_reports"`` and a Flan-T5 ``text2text-generation`` pipeline.
    Loading is best-effort: if the generator or the vector store cannot be
    loaded, ``self.text_generator`` / ``self.collection`` is ``None`` and the
    public methods fall back to canned messages instead of raising.
    """

    def __init__(self, db_path: str = "./chroma_db", *,
                 relevance_threshold: Optional[float] = None):
        """Initialize the RAG system with fast models.

        Args:
            db_path: Directory of the persistent Chroma database.
            relevance_threshold: Optional maximum vector distance for a
                retrieved document to count as relevant when answering
                follow-up questions. ``None`` (default) accepts whatever
                Chroma returns. The scale depends on the collection's
                distance metric, so tune it against your own data.
        """
        print("🔄 Loading models (optimized for speed)...")

        # Small, fast sentence embedder used to encode every retrieval query.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("✅ Embeddings loaded")

        self.relevance_threshold = relevance_threshold

        print("🔄 Loading text generation model...")
        try:
            # Flan-T5: compact instruction-tuned model that runs fine on CPU.
            self.text_generator = pipeline(
                "text2text-generation",
                model="google/flan-t5-base",
                max_length=512,
                device=-1,  # force CPU
            )
            print("✅ Text generation model loaded (Flan-T5-base)")
        except Exception as e:
            # Best effort: without a generator, _generate_text returns a fallback string.
            print(f"⚠️ Model loading error: {e}")
            self.text_generator = None

        # Initialise first so the attribute exists even if the client fails to open.
        self.client = None
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("✅ Vector database loaded")
        except Exception as e:
            print(f"⚠️ Vector database not found: {e}")
            self.collection = None

    def _retrieve_context(self, query: str, k: int = 2,
                          max_distance: Optional[float] = None,
                          max_chars: int = 1500) -> Optional[str]:
        """Retrieve relevant context from the vector database.

        Args:
            query: Free-text query; it is embedded and searched in Chroma.
            k: Number of nearest documents to request.
            max_distance: If given, documents whose returned distance exceeds
                this value are discarded. If Chroma returns no distances, no
                filtering is applied.
            max_chars: The joined context is truncated to this many characters
                so it fits in the LLM prompt.

        Returns:
            The kept documents joined with newlines, or ``None`` when the store
            is unavailable, nothing (relevant) was found, or retrieval failed.
        """
        if self.collection is None:
            return None

        try:
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k,
            )

            documents = results.get("documents") if results else None
            docs = list(documents[0]) if documents else []

            if max_distance is not None:
                distances_all = results.get("distances")
                distances = distances_all[0] if distances_all else None
                if distances:
                    docs = [
                        doc for doc, dist in zip(docs, distances)
                        if dist is not None and dist <= max_distance
                    ]

            if not docs:
                return None
            return "\n".join(docs)[:max_chars]
        except Exception as e:
            # Retrieval is best-effort; callers treat None as "no context".
            print(f"Retrieval error: {e}")
            return None

    def _generate_text(self, prompt: str) -> str:
        """Generate text for *prompt* using Flan-T5.

        Returns the stripped generated text, or a fixed fallback message if the
        model is not loaded or generation raises.
        """
        if self.text_generator is None:
            return "AI model not available."

        try:
            result = self.text_generator(
                prompt,
                max_length=256,
                do_sample=True,
                temperature=0.5,  # lower temperature for more factual answers
                num_return_sequences=1,
            )
            return result[0]['generated_text'].strip()
        except Exception as e:
            print(f"Generation error: {e}")
            return "Unable to generate explanation."

    def explain_result(self, result: LabResult) -> str:
        """Generate a plain-language explanation for a single lab result.

        A deterministic sentence stating the value is always included. If
        retrieval finds context, an LLM summary of that context and the
        reference range are appended. Otherwise the user is advised to
        consult a doctor.
        """
        print(f" Explaining: {result.test_name} ({result.status})...")

        # 1. Deterministic statement of the value.
        base_msg = f"Your {result.test_name} is {result.value} {result.unit} ({result.status}). "

        # 2. Retrieve supporting medical context.
        query = f"What does {result.status} {result.test_name} mean? causes and simple explanation"
        context = self._retrieve_context(query, k=1)
        if not context:
            return base_msg + "Please consult your doctor for interpretation."

        # 3. Ask Flan-T5 to summarise the retrieved context simply.
        prompt = (
            f"Explain simply in 2 sentences what {result.status} {result.test_name} "
            f"means based on this medical info.\n\n"
            f"Medical Info:\n{context}\n\n"
            "Explanation:"
        )
        ai_explanation = self._generate_text(prompt)

        # 4. Assemble the final output.
        return (
            f"{base_msg.rstrip()}\n\n"
            f"💡 Analysis: {ai_explanation}\n\n"
            f"(Reference Range: {result.reference_range})"
        )

    def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
        """Answer a follow-up question using RAG + LLM.

        Returns ``OUT_OF_CONTEXT_REPLY`` when retrieval yields no (sufficiently
        relevant) context. Otherwise the model is prompted to answer strictly
        from the retrieved context and up to five of the patient's results,
        replying with the same fallback phrase when it cannot.
        """
        print(f"💬 Processing question: {question[:50]}...")

        # 1. Retrieve medical context. The optional distance threshold lets
        #    clearly unrelated questions be rejected here.
        threshold = getattr(self, "relevance_threshold", None)
        medical_context = self._retrieve_context(question, k=2, max_distance=threshold)

        # 2. Out-of-context guard.
        if not medical_context:
            return OUT_OF_CONTEXT_REPLY

        # 3. Snapshot of the patient's data (first 5 results, to save prompt space).
        patient_data = ", ".join(
            f"{r.test_name}: {r.value} ({r.status})" for r in lab_results[:5]
        ) or "None provided"

        # 4. Strict prompt instructing Flan-T5 to use the fallback phrase.
        prompt = (
            "Answer the question based strictly on the Medical Context and "
            "Patient Results provided below.\n"
            "If the answer cannot be found in the context, or if the question "
            f'is not about health/labs, reply exactly "{OUT_OF_CONTEXT_REPLY}".\n\n'
            f"Medical Context:\n{medical_context}\n\n"
            f"Patient Results:\n{patient_data}\n\n"
            f"Question: {question}\n\n"
            "Answer:"
        )

        # 5. Generate. Adherence to the fallback instruction relies on the prompt.
        return self._generate_text(prompt)

    def generate_summary(self, results: List[LabResult]) -> str:
        """Generate a summary of abnormal results using the LLM.

        Only results whose status is exactly ``'high'`` or ``'low'`` count as
        abnormal. If there are none, a fixed "all normal" message is returned.
        """
        print("📊 Generating summary...")

        abnormal = [r for r in results if r.status in ('high', 'low')]
        if not abnormal:
            return "✅ All results are normal. Great job maintaining your health!"

        abnormal_text = ", ".join(f"{r.test_name} is {r.status}" for r in abnormal)

        # General context about these abnormal tests. It may be None, in which
        # case the section is omitted rather than rendering the literal "None".
        context = self._retrieve_context(f"health implications of {abnormal_text}", k=1)
        context_section = f"Based on this medical info:\n{context}\n\n" if context else ""

        prompt = (
            f"The patient has these abnormal lab results: {abnormal_text}.\n\n"
            f"{context_section}"
            "Write a short, encouraging 2-sentence summary advising them to see a doctor.\n\n"
            "Summary:"
        )
        ai_summary = self._generate_text(prompt)

        details = "\n".join(f"- {r.test_name}: {r.value} {r.unit}" for r in abnormal)
        return f"⚠️ **Abnormal Results Detected**\n\n{ai_summary}\n\nDetailed changes:\n{details}"

    def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
        """Explain every result, keyed by test name.

        If two results share a test name, the later one's explanation wins.
        """
        return {result.test_name: self.explain_result(result) for result in results}


# Testing block
if __name__ == "__main__":
    print("Testing RAG system...")
    try:
        rag = LabReportRAG()

        # Mock data
        results = [
            LabResult("Hemoglobin", "10.5", "g/dL", "12.0-15.5", "low"),
            LabResult("Glucose", "95", "mg/dL", "70-100", "normal"),
        ]

        # Test 1: Explanation using LLM
        print("\n--- Test Explanation ---")
        print(rag.explain_result(results[0]))

        # Test 2: Follow-up (valid)
        print("\n--- Test Valid Question ---")
        q1 = "What foods should I eat for low hemoglobin?"
        print(f"Q: {q1}")
        print(f"A: {rag.answer_followup_question(q1, results)}")

        # Test 3: Follow-up (out of context)
        print("\n--- Test Invalid Question ---")
        q2 = "Who is the president of the USA?"
        print(f"Q: {q2}")
        print(f"A: {rag.answer_followup_question(q2, results)}")

    except Exception as e:
        print(f"\n❌ Error: {e}")