"""
RAG Query Engine for Lab Report Decoder
Uses Hugging Face models for embeddings and generation
"""
from sentence_transformers import SentenceTransformer
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import chromadb
from chromadb.config import Settings
from typing import List, Dict
from pdf_extractor import LabResult
import torch
class LabReportRAG:
    """RAG system for explaining lab results using Hugging Face models.

    Pipeline: embed the query with a SentenceTransformer, retrieve reference
    passages from a persistent Chroma collection, then generate an
    explanation with a local LLM (Phi-3 if it loads, FLAN-T5 as fallback).
    """

    def __init__(self, db_path: str = "./chroma_db"):
        """Initialize embedding model, generator LLM, and vector store.

        Args:
            db_path: Directory of the persistent Chroma database built by
                build_vector_db.py. If missing, retrieval is disabled and
                self.collection stays None.
        """
        print("🔄 Loading Hugging Face models...")
        # Small, fast model for query embeddings (384-dim MiniLM).
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
        # Candidate generator models:
        # - "microsoft/Phi-3-mini-4k-instruct" (good balance)
        # - "google/flan-t5-base" (lighter)
        # - "meta-llama/Llama-2-7b-chat-hf" (requires auth)
        model_name = "microsoft/Phi-3-mini-4k-instruct"
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
            self.llm = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                # fp16 only when a GPU is present; CPU inference stays fp32.
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                device_map="auto" if torch.cuda.is_available() else None
            )
            print(f"✅ Loaded model: {model_name}")
        except Exception as e:
            print(f"⚠️ Could not load {model_name}, falling back to simpler model")
            # Surface the failure reason instead of swallowing it silently.
            print(f"   Reason: {e}")
            # BUGFIX: FLAN-T5 is an encoder-decoder (seq2seq) model, so the
            # correct pipeline task is "text2text-generation", not
            # "text-generation" (which is for causal LMs).
            self.text_generator = pipeline(
                "text2text-generation",
                model="google/flan-t5-base",
                max_length=512
            )
            # None signals _generate_text to use the fallback pipeline.
            self.llm = None
        # Load vector store; retrieval degrades gracefully if absent.
        try:
            self.client = chromadb.PersistentClient(path=db_path)
            self.collection = self.client.get_collection("lab_reports")
            print("✅ Vector database loaded")
        except Exception as e:
            print(f"⚠️ No vector database found. Please run build_vector_db.py first.")
            print(f"   Reason: {e}")
            self.collection = None

    def _generate_with_phi(self, prompt: str, max_tokens: int = 512) -> str:
        """Generate text with the Phi-3 causal LM.

        Args:
            prompt: Full prompt text (truncated to 2048 tokens).
            max_tokens: Cap on newly generated tokens.

        Returns:
            The decoded completion with the prompt text stripped off.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
        if torch.cuda.is_available():
            inputs = {k: v.to('cuda') for k, v in inputs.items()}
        outputs = self.llm.generate(
            **inputs,
            max_new_tokens=max_tokens,
            temperature=0.7,
            do_sample=True,
            top_p=0.9
        )
        response = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
        # Causal LMs echo the prompt; drop it so only the answer remains.
        # NOTE(review): relies on the decoded text containing the prompt
        # verbatim — tokenization round-trips can break this for some inputs.
        response = response.replace(prompt, "").strip()
        return response

    def _generate_with_fallback(self, prompt: str) -> str:
        """Generate text using the FLAN-T5 fallback pipeline."""
        result = self.text_generator(prompt, max_length=512, num_return_sequences=1)
        # Both text-generation and text2text-generation pipelines return
        # a list of dicts keyed by 'generated_text'.
        return result[0]['generated_text']

    def _generate_text(self, prompt: str) -> str:
        """Generate text with whichever model loaded; never raises."""
        try:
            if self.llm is not None:
                return self._generate_with_phi(prompt)
            else:
                return self._generate_with_fallback(prompt)
        except Exception as e:
            print(f"Generation error: {e}")
            return "Sorry, I encountered an error generating the explanation."

    def _retrieve_context(self, query: str, k: int = 3) -> str:
        """Retrieve up to k relevant reference passages for a query.

        Args:
            query: Free-text retrieval query.
            k: Number of nearest documents to fetch.

        Returns:
            Retrieved passages joined by blank lines, or a human-readable
            placeholder string when retrieval is unavailable or empty.
        """
        if self.collection is None:
            return "No medical reference data available."
        try:
            # Embed the query with the same model used to build the index.
            query_embedding = self.embedding_model.encode(query).tolist()
            results = self.collection.query(
                query_embeddings=[query_embedding],
                n_results=k
            )
            # Chroma returns documents as a list-of-lists (one inner list per
            # query). Check the inner list too, so an empty hit list yields
            # the placeholder instead of an empty context string.
            docs = results.get('documents') if results else None
            if docs and docs[0]:
                return "\n\n".join(docs[0])
            return "No relevant information found."
        except Exception as e:
            print(f"Retrieval error: {e}")
            return "Error retrieving medical information."

    def explain_result(self, result: LabResult) -> str:
        """Generate a plain-language explanation for one lab result."""
        # Retrieve reference passages relevant to this test and its status.
        query = f"{result.test_name} {result.status} meaning causes treatment"
        context = self._retrieve_context(query, k=3)
        prompt = f"""You are a helpful medical assistant. Explain this lab result in simple terms.
Medical Information:
{context}
Lab Test: {result.test_name}
Value: {result.value} {result.unit}
Reference Range: {result.reference_range}
Status: {result.status}
Please explain:
1. What this test measures
2. What this result means
3. Possible causes if abnormal
4. Dietary recommendations if applicable
Keep it simple and clear. Answer:"""
        explanation = self._generate_text(prompt)
        return explanation

    def explain_all_results(self, results: List[LabResult]) -> Dict[str, str]:
        """Generate explanations for every result, keyed by test name.

        Note: if two results share a test_name, the later one wins.
        """
        explanations = {}
        for result in results:
            print(f"Explaining {result.test_name}...")
            explanation = self.explain_result(result)
            explanations[result.test_name] = explanation
        return explanations

    def answer_followup_question(self, question: str, lab_results: List[LabResult]) -> str:
        """Answer a follow-up question grounded in the patient's results."""
        # Compact one-line-per-test summary of the patient's panel.
        results_context = "\n".join([
            f"{r.test_name}: {r.value} {r.unit} (Status: {r.status}, Range: {r.reference_range})"
            for r in lab_results
        ])
        # Retrieve reference material relevant to the question itself.
        medical_context = self._retrieve_context(question, k=3)
        prompt = f"""You are a medical assistant. Answer this question based on the patient's lab results and medical information.
Patient's Lab Results:
{results_context}
Medical Information:
{medical_context}
Question: {question}
Provide a clear, helpful answer. Answer:"""
        answer = self._generate_text(prompt)
        return answer

    def generate_summary(self, results: List[LabResult]) -> str:
        """Generate an overall 2-3 paragraph summary of the panel.

        Returns a canned all-normal message without invoking the LLM when
        no result has status 'high' or 'low'.
        """
        abnormal = [r for r in results if r.status in ['high', 'low']]
        normal = [r for r in results if r.status == 'normal']
        if not abnormal:
            return "✅ Great news! All your lab results are within normal ranges. Keep up the good work with your health!"
        # One combined retrieval query covering all abnormal tests.
        queries = [f"{r.test_name} {r.status}" for r in abnormal]
        combined_query = " ".join(queries)
        context = self._retrieve_context(combined_query, k=4)
        abnormal_list = "\n".join([
            f"- {r.test_name}: {r.value} {r.unit} ({r.status})"
            for r in abnormal
        ])
        prompt = f"""Provide a brief summary of these lab results.
Normal Results: {len(normal)} tests
Abnormal Results: {len(abnormal)} tests
Abnormal Tests:
{abnormal_list}
Medical Context:
{context}
Write a 2-3 paragraph summary explaining what these results mean overall and general recommendations. Be reassuring but honest. Summary:"""
        summary = self._generate_text(prompt)
        return summary
# Example usage
def _run_demo() -> None:
    """Demo: explain a single low-hemoglobin result end to end."""
    from pdf_extractor import LabResult

    print("Initializing RAG system...")
    rag = LabReportRAG()

    # A representative abnormal result to explain.
    sample = LabResult(
        test_name="Hemoglobin",
        value="10.5",
        unit="g/dL",
        reference_range="12.0-15.5",
        status="low"
    )

    print("\nGenerating explanation...")
    print(f"\n{rag.explain_result(sample)}")


if __name__ == "__main__":
    _run_demo()