| """ | |
| RAG Query Engine for Lab Report Decoder | |
| Uses Hugging Face models for embeddings and generation | |
| """ | |
| from sentence_transformers import SentenceTransformer | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
| import chromadb | |
| from chromadb.config import Settings | |
| from typing import List, Dict | |
| from pdf_extractor import LabResult | |
| import torch | |


class LabReportRAG:
    """RAG system for explaining lab results using Hugging Face models"""

    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize the RAG system with Hugging Face models"""
        print("🔄 Loading Hugging Face models...")

        # Small, fast model for embeddings
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Use a medical-focused or general LLM. Options:
        # - "microsoft/Phi-3-mini-4k-instruct" (good balance)
        # - "google/flan-t5-base" (lighter)
        # - "meta-llama/Llama-2-7b-chat-hf" (requires auth)
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            print(f"✅ Loaded model: {model_name}")
        except Exception as e:
            print(f"⚠️ Could not load {model_name} ({e}), falling back to a lighter model")
            # Fallback: Flan-T5 is a seq2seq model, so it needs the
            # "text2text-generation" task, not "text-generation"
            self.text_generator = pipeline(
                "text2text-generation",
                model="google/flan-t5-base",
                max_length=512
            )
            self.llm = None

        # Load vector store
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("✅ Vector database loaded")
        except Exception as e:
            print(f"⚠️ No vector database found ({e}). Please run build_vector_db.py first.")
            self.collection = None
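
    # For reference, the "lab_reports" collection queried in __init__ is
    # assumed to be created by build_vector_db.py roughly along these lines
    # (hypothetical sketch; the real chunking and sources live in that script):
    #
    #   client = chromadb.PersistentClient(path="./chroma_db")
    #   collection = client.get_or_create_collection("lab_reports")
    #   embeddings = SentenceTransformer('all-MiniLM-L6-v2').encode(chunks)
    #   collection.add(ids=ids, documents=chunks, embeddings=embeddings.tolist())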

    def _generate_with_phi(self, prompt: str, max_tokens: int = 512) -> str:
        """Generate text using the Phi-3 model"""
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}
        outputs = self.llm.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )
        # Decode only the newly generated tokens; string-replacing the prompt
        # out of the decoded text is fragile because decoding may not
        # reproduce the prompt exactly
        new_tokens = outputs[0][inputs['input_ids'].shape[1]:]
        return self.tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
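
    # Phi-3-mini is instruction-tuned, so formatting the prompt with the
    # tokenizer's chat template usually improves output quality over raw text.
    # A minimal alternative sketch (assumes transformers >= 4.34, which
    # provides apply_chat_template); not wired into _generate_text below:
    def _generate_with_chat_template(self, prompt: str, max_tokens: int = 512) -> str:
        """Alternative generation path using the model's chat template."""
        messages = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        )
        if torch.cuda.is_available():
            input_ids = input_ids.to('cuda')
        outputs = self.llm.generate(
            input_ids,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )
        # Decode only the newly generated tokens
        return self.tokenizer.decode(outputs[0][input_ids.shape[1]:], skip_special_tokens=True).strip()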

    def _generate_with_fallback(self, prompt: str) -> str:
        """Generate text using the fallback pipeline"""
        result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
        return result[0]['generated_text']

    def _generate_text(self, prompt: str) -> str:
        """Generate text using whichever model is available"""
        try:
            if self.llm is not None:
                return self._generate_with_phi(prompt)
            else:
                return self._generate_with_fallback(prompt)
        except Exception as e:
            print(f"Generation error: {e}")
            return "Sorry, I encountered an error generating the explanation."

    def _retrieve_context(self, query: str, k: int = 3) -> str:
        """Retrieve relevant context from the vector database"""
        if self.collection is None:
            return "No medical reference data available."
        try:
            # Embed the query
            query_embedding = self.embedding_model.encode(query).tolist()
            # Query the collection
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k
            )
            # Combine the retrieved documents; note the inner list must be
            # checked too, since results['documents'] is a list per query
            if results and results['documents'] and results['documents'][0]:
                return "\n\n".join(results['documents'][0])
            return "No relevant information found."
        except Exception as e:
            print(f"Retrieval error: {e}")
            return "Error retrieving medical information."

    def explain_result(self, result: LabResult) -> str:
        """Generate an explanation for a single lab result"""
        # Retrieve relevant context
        query = f"{result.test_name} {result.status} meaning causes treatment"
        context = self._retrieve_context(query, k=3)

        # Build the prompt
        prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.

Medical Information:
{context}

Lab Test: {result.test_name}
Value: {result.value} {result.unit}
Reference Range: {result.reference_range}
Status: {result.status}

Please explain:
1. What this test measures
2. What this result means
3. Possible causes if abnormal
4. Dietary recommendations if applicable

Keep it simple and clear. Answer:"""

        # Generate the explanation
        return self._generate_text(prompt)

    def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
        """Generate explanations for all lab results"""
        explanations = {}
        for result in results:
            print(f"Explaining {result.test_name}...")
            explanations[result.test_name] = self.explain_result(result)
        return explanations

    def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
        """Answer a follow-up question about the lab results"""
        # Summarize the lab results for the prompt
        results_context = "\n".join([
            f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
            for r in lab_results
        ])
        # Retrieve relevant medical information
        medical_context = self._retrieve_context(question, k=3)

        prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.

Patient's Lab Results:
{results_context}

Medical Information:
{medical_context}

Question: {question}

Provide a clear, helpful answer. Answer:"""

        return self._generate_text(prompt)

    def generate_summary(self, results: List[LabResult]) -> str:
        """Generate an overall summary of the lab results"""
        abnormal = [r for r in results if r.status in ['high', 'low']]
        normal = [r for r in results if r.status == 'normal']

        if not abnormal:
            return "✅ Great news! All your lab results are within normal ranges. Keep up the good work with your health!"

        # Retrieve context about the abnormal results
        combined_query = " ".join(f"{r.test_name} {r.status}" for r in abnormal)
        context = self._retrieve_context(combined_query, k=4)

        abnormal_list = "\n".join([
            f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
            for r in abnormal
        ])

        prompt = f"""Provide a brief summary of these lab results.

Normal Results: {len(normal)} tests
Abnormal Results: {len(abnormal)} tests

Abnormal Tests:
{abnormal_list}

Medical Context:
{context}

Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""

        return self._generate_text(prompt)


# Example usage
if __name__ == "__main__":
    # Initialize RAG system
    print("Initializing RAG system...")
    rag = LabReportRAG()

    # Example result
    test_result = LabResult(
        test_name="Hemoglobin",
        value="10.5",
        unit="g/dL",
        reference_range="12.0-15.5",
        status="low"
    )

    # Generate explanation
    print("\nGenerating explanation...")
    explanation = rag.explain_result(test_result)
    print(f"\n{explanation}")