# Lab Report RAG — retrieval-augmented, patient-friendly explanations of lab results.
import os
from typing import Dict, List, Optional

import chromadb
from sentence_transformers import SentenceTransformer
from transformers import pipeline

from pdf_extractor import LabResult
class LabReportRAG:
    """RAG system for explaining lab results - fast and efficient.

    Combines a MiniLM sentence-embedding retriever backed by a ChromaDB
    vector store with a small instruction-tuned LLM (Flan-T5-base) to
    produce patient-friendly explanations of lab values.
    """

    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize the RAG system with fast models.

        Components that fail to load are set to None so the public
        methods can degrade gracefully instead of raising.

        Args:
            db_path: Filesystem path of the persisted ChromaDB store.
        """
        print("π Loading models (optimized for speed)...")
        # Fast, small embedding model.
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        print("β Embeddings loaded")
        print("π Loading text generation model...")
        try:
            # Flan-T5-base: efficient instruction-tuned seq2seq model.
            self.text_generator = pipeline(
                "text2text-generation",
                model="google/flan-t5-base",
                max_length=512,  # headroom for better answers
                device=-1  # force CPU
            )
            print("β Text generation model loaded (Flan-T5-base)")
        except Exception as e:
            print(f"β οΈ Model loading error: {e}")
            self.text_generator = None
        # Load the persisted vector store; None signals "unavailable".
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("β Vector database loaded")
        except Exception as e:
            print(f"β οΈ Vector database not found: {e}")
            self.collection = None

    def _retrieve_context(self, query: str, k: int = 2) -> Optional[str]:
        """Retrieve relevant context from the vector database.

        Args:
            query: Natural-language search query.
            k: Maximum number of documents to retrieve.

        Returns:
            Up to ``k`` matching documents joined with newlines (capped at
            1500 characters), or None when the store is unavailable, the
            query matches nothing, or retrieval fails.
        """
        if self.collection is None:
            return None
        try:
            # Embed the query with the same model used at indexing time.
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k
            )
            # Chroma nests documents per query; we issued a single query,
            # so the hits live in results['documents'][0].
            if results and results['documents'] and results['documents'][0]:
                context = "\n".join(results['documents'][0])
                return context[:1500]  # cap the context window fed to the LLM
            return None
        except Exception as e:
            print(f"Retrieval error: {e}")
            return None

    def _generate_text(self, prompt: str) -> str:
        """Generate text using Flan-T5; returns a fallback string on failure."""
        if self.text_generator is None:
            return "AI model not available."
        try:
            result = self.text_generator(
                prompt,
                max_length=256,
                do_sample=True,
                temperature=0.5,  # lower temperature for more factual answers
                num_return_sequences=1
            )
            # Guard against an empty pipeline result before indexing.
            if not result:
                return "Unable to generate explanation."
            return result[0]['generated_text'].strip()
        except Exception as e:
            print(f"Generation error: {e}")
            return "Unable to generate explanation."

    def explain_result(self, result: "LabResult") -> str:
        """Generate an explanation for a single lab result using the LLM.

        Falls back to a "consult your doctor" message when no supporting
        context can be retrieved.
        """
        print(f" Explaining: {result.test_name} ({result.status})...")
        # 1. Deterministic base message with the raw numbers.
        base_msg = f"Your {result.test_name} is {result.value} {result.unit} ({result.status}). "
        # 2. Retrieve supporting medical context.
        query = f"What does {result.status} {result.test_name} mean? causes and simple explanation"
        context = self._retrieve_context(query, k=1)
        if not context:
            return base_msg + "Please consult your doctor for interpretation."
        # 3. Instruct Flan-T5 to summarize the retrieved context simply.
        prompt = f"""
Explain simply in 2 sentences what {result.status} {result.test_name} means based on this medical info.
Medical Info: {context}
Explanation:"""
        ai_explanation = self._generate_text(prompt)
        # 4. Construct the final output shown to the user.
        final_output = f"""{base_msg}
π‘ Analysis: {ai_explanation}
(Reference Range: {result.reference_range})"""
        return final_output

    def answer_followup_question(self, question: str, lab_results: List["LabResult"]) -> str:
        """Answer questions using RAG + LLM.

        Handles out-of-context questions strictly: if nothing relevant is
        retrieved, a fixed refusal string is returned.
        """
        print(f"π¬ Processing question: {question[:50]}...")
        # 1. Retrieve medical context for the question.
        medical_context = self._retrieve_context(question, k=2)
        # 2. Out-of-context guard: no documents found means we refuse.
        if not medical_context:
            return "Sorry I can't answer this Question"
        # 3. Snapshot of the patient's actual data (top 5 to save space).
        relevant_results = [f"{r.test_name}: {r.value} ({r.status})" for r in lab_results]
        patient_data = ", ".join(relevant_results[:5])
        # 4. Strict prompt instructing Flan-T5 to use the fallback phrase.
        prompt = f"""
Answer the question based strictly on the Medical Context and Patient Results provided below.
If the answer cannot be found in the context, or if the question is not about health/labs, reply exactly "Sorry I can't answer this Question".
Medical Context: {medical_context}
Patient Results: {patient_data}
Question: {question}
Answer:"""
        # 5. Generate. Relying on the prompt instructions to keep the model
        # on-topic is standard practice for T5-style models.
        return self._generate_text(prompt)

    def generate_summary(self, results: List["LabResult"]) -> str:
        """Generate a short summary of abnormal results using the LLM."""
        print("π Generating summary...")
        abnormal = [r for r in results if r.status in ('high', 'low')]
        if not abnormal:
            return "β All results are normal. Great job maintaining your health!"
        abnormal_text = ", ".join(f"{r.test_name} is {r.status}" for r in abnormal)
        # Ground the summary in retrieved context about these abnormal tests.
        context = self._retrieve_context(f"health implications of {abnormal_text}", k=1)
        prompt = f"""
The patient has these abnormal lab results: {abnormal_text}.
Based on this medical info: {context}
Write a short, encouraging 2-sentence summary advising them to see a doctor.
Summary:"""
        ai_summary = self._generate_text(prompt)
        details = "\n".join(f"- {r.test_name}: {r.value} {r.unit}" for r in abnormal)
        return f"β οΈ **Abnormal Results Detected**\n\n{ai_summary}\n\nDetailed changes:\n{details}"

    # Convenience wrapper over explain_result for a batch of results.
    def explain_all_results(self, results: List["LabResult"]) -> Dict[str, str]:
        """Return a mapping of test name -> explanation for every result."""
        return {result.test_name: self.explain_result(result) for result in results}
# Testing block: exercised only when the module is run directly.
if __name__ == "__main__":
    print("Testing RAG system...")
    try:
        rag = LabReportRAG()
        # Mock data (LabResult is already imported at module level).
        results = [
            LabResult("Hemoglobin", "10.5", "g/dL", "12.0-15.5", "low"),
            LabResult("Glucose", "95", "mg/dL", "70-100", "normal")
        ]
        # Test 1: explanation of a single abnormal result via the LLM.
        print("\n--- Test Explanation ---")
        print(rag.explain_result(results[0]))
        # Test 2: follow-up question that should have retrievable context.
        print("\n--- Test Valid Question ---")
        q1 = "What foods should I eat for low hemoglobin?"
        print(f"Q: {q1}")
        print(f"A: {rag.answer_followup_question(q1, results)}")
        # Test 3: follow-up question that is out of scope and must be refused.
        print("\n--- Test Invalid Question ---")
        q2 = "Who is the president of the USA?"
        print(f"Q: {q2}")
        print(f"A: {rag.answer_followup_question(q2, results)}")
    except Exception as e:
        print(f"\nβ Error: {e}")