""" Biomarker Extraction Service Extracts biomarker values from natural language text using LLM """ import json import sys from pathlib import Path from typing import Any # Ensure project root is in path for src imports _project_root = str(Path(__file__).parent.parent.parent.parent) if _project_root not in sys.path: sys.path.insert(0, _project_root) from langchain_core.prompts import ChatPromptTemplate from src.biomarker_normalization import normalize_biomarker_name from src.llm_config import get_chat_model # ============================================================================ # EXTRACTION PROMPT # ============================================================================ BIOMARKER_EXTRACTION_PROMPT = """You are a medical data extraction assistant. Extract biomarker values from the user's message. Known biomarkers (24 total): Glucose, Cholesterol, Triglycerides, HbA1c, LDL, HDL, Insulin, BMI, Hemoglobin, Platelets, WBC (White Blood Cells), RBC (Red Blood Cells), Hematocrit, MCV, MCH, MCHC, Heart Rate, Systolic BP, Diastolic BP, Troponin, C-reactive Protein, ALT, AST, Creatinine User message: {user_message} Extract all biomarker names and their values. Return ONLY valid JSON (no other text): {{ "biomarkers": {{ "Glucose": 140, "HbA1c": 7.5 }}, "patient_context": {{ "age": null, "gender": null, "bmi": null }} }} If you cannot find any biomarkers, return {{"biomarkers": {{}}, "patient_context": {{}}}}. """ # ============================================================================ # EXTRACTION HELPERS # ============================================================================ def _parse_llm_json(content: str) -> dict[str, Any]: """Parse JSON payload from LLM output with fallback recovery.""" text = content.strip() if "```json" in text: text = text.split("```json")[1].split("```")[0].strip() elif "```" in text: text = text.split("```")[1].split("```")[0].strip() try: return json.loads(text) except json.JSONDecodeError: left = text.find("{") right = text.rfind("}") if left != -1 and right != -1 and right > left: return json.loads(text[left : right + 1]) raise # ============================================================================ # EXTRACTION FUNCTION # ============================================================================ def extract_biomarkers( user_message: str, ollama_base_url: str | None = None, # Kept for backward compatibility, ignored ) -> tuple[dict[str, float], dict[str, Any], str]: """ Extract biomarker values from natural language using LLM. Args: user_message: Natural language text containing biomarker information ollama_base_url: DEPRECATED - uses cloud LLM (Groq/Gemini) instead Returns: Tuple of (biomarkers_dict, patient_context_dict, error_message) - biomarkers_dict: Normalized biomarker names -> values - patient_context_dict: Extracted patient context (age, gender, BMI) - error_message: Empty string if successful, error description if failed Example: >>> biomarkers, context, error = extract_biomarkers("My glucose is 185 and HbA1c is 8.2") >>> print(biomarkers) {'Glucose': 185.0, 'HbA1c': 8.2} """ try: # Initialize LLM (uses Groq/Gemini by default - FREE) llm = get_chat_model(temperature=0.0) prompt = ChatPromptTemplate.from_template(BIOMARKER_EXTRACTION_PROMPT) chain = prompt | llm # Invoke LLM response = chain.invoke({"user_message": user_message}) content = response.content.strip() extracted = _parse_llm_json(content) biomarkers = extracted.get("biomarkers", {}) patient_context = extracted.get("patient_context", {}) # Normalize biomarker names and convert to float normalized = {} for key, value in biomarkers.items(): try: standard_name = normalize_biomarker_name(key) normalized[standard_name] = float(value) except (ValueError, TypeError): # Skip invalid values continue # Clean up patient context (remove null values) patient_context = {k: v for k, v in patient_context.items() if v is not None} if not normalized: return {}, patient_context, "No biomarkers found in the input" return normalized, patient_context, "" except json.JSONDecodeError as e: return {}, {}, f"Failed to parse LLM response as JSON: {e!s}" except Exception as e: return {}, {}, f"Extraction failed: {e!s}" # ============================================================================ # SIMPLE DISEASE PREDICTION (Fallback) # ============================================================================ def predict_disease_simple(biomarkers: dict[str, float]) -> dict[str, Any]: """ Simple rule-based disease prediction based on key biomarkers. Used as a fallback when no ML model is available. Args: biomarkers: Dictionary of biomarker names to values Returns: Dictionary with disease, confidence, and probabilities """ scores = {"Diabetes": 0.0, "Anemia": 0.0, "Heart Disease": 0.0, "Thrombocytopenia": 0.0, "Thalassemia": 0.0} # Helper: check both abbreviated and normalized biomarker names # Returns None when biomarker is not present (avoids false triggers) def _get(name, *alt_names): val = biomarkers.get(name) if val is not None: return val for alt in alt_names: val = biomarkers.get(alt) if val is not None: return val return None # Diabetes indicators glucose = _get("Glucose") hba1c = _get("HbA1c") if glucose is not None and glucose > 126: scores["Diabetes"] += 0.4 if glucose is not None and glucose > 180: scores["Diabetes"] += 0.2 if hba1c is not None and hba1c >= 6.5: scores["Diabetes"] += 0.5 # Anemia indicators hemoglobin = _get("Hemoglobin") mcv = _get("Mean Corpuscular Volume", "MCV") if hemoglobin is not None and hemoglobin < 12.0: scores["Anemia"] += 0.6 if hemoglobin is not None and hemoglobin < 10.0: scores["Anemia"] += 0.2 if mcv is not None and mcv < 80: scores["Anemia"] += 0.2 # Heart disease indicators cholesterol = _get("Cholesterol") troponin = _get("Troponin") ldl = _get("LDL Cholesterol", "LDL") if cholesterol is not None and cholesterol > 240: scores["Heart Disease"] += 0.3 if troponin is not None and troponin > 0.04: scores["Heart Disease"] += 0.6 if ldl is not None and ldl > 190: scores["Heart Disease"] += 0.2 # Thrombocytopenia indicators platelets = _get("Platelets") if platelets is not None and platelets < 150000: scores["Thrombocytopenia"] += 0.6 if platelets is not None and platelets < 50000: scores["Thrombocytopenia"] += 0.3 # Thalassemia indicators (simplified) if mcv is not None and hemoglobin is not None and mcv < 80 and hemoglobin < 12.0: scores["Thalassemia"] += 0.4 # Find top prediction top_disease = max(scores, key=scores.get) confidence = min(scores[top_disease], 1.0) # Cap at 1.0 for Pydantic validation if confidence == 0.0: top_disease = "Undetermined" # Normalize probabilities to sum to 1.0 total = sum(scores.values()) if total > 0: probabilities = {k: v / total for k, v in scores.items()} else: probabilities = {k: 1.0 / len(scores) for k in scores} return {"disease": top_disease, "confidence": confidence, "probabilities": probabilities}