| """MedGemma HF endpoint integration for patient profile extraction.""" |
|
|
| import asyncio |
| import base64 |
| import json |
| import re |
| import time |
| from pathlib import Path |
|
|
| import structlog |
| from huggingface_hub import InferenceClient |
| from huggingface_hub.utils import HfHubHTTPError |
|
|
| from trialpath.config import ( |
| HF_TOKEN, |
| MEDGEMMA_COLD_START_TIMEOUT, |
| MEDGEMMA_ENDPOINT_URL, |
| MEDGEMMA_MAX_RETRIES, |
| MEDGEMMA_MAX_WAIT, |
| MEDGEMMA_RETRY_BACKOFF, |
| ) |
|
|
logger = structlog.get_logger("trialpath.medgemma")


# Module-local aliases for the retry/backoff tuning knobs from config,
# kept short because they appear throughout _call_with_retry.
_MAX_RETRIES = MEDGEMMA_MAX_RETRIES  # max attempts per endpoint call
_RETRY_BACKOFF_BASE = MEDGEMMA_RETRY_BACKOFF  # base of the exponential backoff (base**attempt)
_RETRY_MAX_WAIT = MEDGEMMA_MAX_WAIT  # cap on a single backoff sleep, seconds
_COLD_START_TIMEOUT = MEDGEMMA_COLD_START_TIMEOUT  # total retry budget, seconds
|
|
|
|
| class MedGemmaExtractor: |
| """Extract patient profiles from medical documents using MedGemma. |
| |
| Uses a HuggingFace Inference Endpoint running medgemma-1-5-4b-it-hae |
| for multimodal extraction (image-text-to-text). |
| """ |
|
|
| def __init__( |
| self, |
| endpoint_url: str | None = None, |
| hf_token: str | None = None, |
| ): |
| self.endpoint_url = endpoint_url or MEDGEMMA_ENDPOINT_URL |
| self.hf_token = hf_token or HF_TOKEN |
| self._client = InferenceClient( |
| model=self.endpoint_url, |
| token=self.hf_token, |
| ) |
|
|
| def _system_prompt(self) -> str: |
| return ( |
| "You are an expert medical data extractor specializing in oncology. " |
| "Extract structured patient information from medical documents. " |
| "Always cite the source document and location for each extracted fact. " |
| "If information is unclear or missing, explicitly note it as unknown." |
| ) |
|
|
| def _build_extraction_prompt(self, metadata: dict) -> str: |
| return f""" |
| Extract a structured patient profile from the following medical documents. |
| |
| Known metadata: age={metadata.get("age", "unknown")}, sex={metadata.get("sex", "unknown")} |
| |
| Extract the following fields in JSON format: |
| - diagnosis (primary_condition, histology, stage, diagnosis_date) |
| - performance_status (scale, value, evidence) |
| - biomarkers (name, result, date, evidence for each) |
| - key_labs (name, value, unit, date, evidence for each) |
| - treatments (drug_name, start_date, end_date, line, evidence) |
| - comorbidities (name, grade, evidence) |
| - imaging_summary (modality, date, finding, interpretation, certainty, evidence) |
| - unknowns (field, reason, importance for each missing critical field) |
| |
| For each evidence reference, include: doc_id (filename), page number, span_id. |
| |
| Return ONLY valid JSON matching the PatientProfile schema. |
| """ |
|
|
| def _parse_profile(self, raw_text: str, metadata: dict) -> dict: |
| """Parse MedGemma output into PatientProfile structure.""" |
| try: |
| profile = json.loads(raw_text) |
| except json.JSONDecodeError: |
| json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", raw_text, re.DOTALL) |
| if json_match: |
| profile = json.loads(json_match.group(1)) |
| else: |
| raise ValueError(f"Could not parse MedGemma output as JSON: {raw_text[:200]}") |
|
|
| if "demographics" not in profile: |
| profile["demographics"] = {} |
| profile["demographics"].update(metadata) |
|
|
| return profile |
|
|
| @staticmethod |
| def _load_documents(document_paths: list[str]) -> list[dict]: |
| """Load documents and encode images as base64 for multimodal input.""" |
| content_parts: list[dict] = [] |
| image_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"} |
|
|
| for path_str in document_paths: |
| path = Path(path_str) |
| if not path.exists(): |
| continue |
|
|
| if path.suffix.lower() in image_extensions: |
| with open(path, "rb") as f: |
| b64 = base64.b64encode(f.read()).decode("utf-8") |
| mime = "image/png" if path.suffix.lower() == ".png" else "image/jpeg" |
| content_parts.append( |
| { |
| "type": "image_url", |
| "image_url": {"url": f"data:{mime};base64,{b64}"}, |
| } |
| ) |
| else: |
| content_parts.append( |
| { |
| "type": "text", |
| "text": f"[Document: {path.name}]", |
| } |
| ) |
|
|
| return content_parts |
|
|
| @staticmethod |
| def _format_gemma_prompt(messages: list[dict]) -> str: |
| """Convert chat messages to Gemma chat template format. |
| |
| The HF endpoint uses the default inference image (not TGI), so |
| /v1/chat/completions is unavailable. We format manually and call |
| text_generation() on the root endpoint instead. |
| |
| Template: <start_of_turn>role\ncontent<end_of_turn> |
| """ |
| parts: list[str] = [] |
| for msg in messages: |
| role = msg["role"] |
| content = msg["content"] |
| if isinstance(content, list): |
| |
| text_parts = [p["text"] for p in content if p.get("type") == "text"] |
| content = "\n".join(text_parts) |
| if role == "system": |
| |
| parts.append(f"<start_of_turn>user\n{content}") |
| continue |
| if role == "user": |
| if parts and parts[-1].startswith("<start_of_turn>user"): |
| |
| parts[-1] += f"\n\n{content}<end_of_turn>" |
| else: |
| parts.append(f"<start_of_turn>user\n{content}<end_of_turn>") |
| elif role == "model" or role == "assistant": |
| parts.append(f"<start_of_turn>model\n{content}<end_of_turn>") |
| |
| parts.append("<start_of_turn>model\n") |
| return "\n".join(parts) |
|
|
| async def _call_with_retry(self, messages: list[dict], max_tokens: int) -> str: |
| """Call HF endpoint with retry logic for cold-start 503 errors. |
| |
| Uses text_generation() with Gemma chat template (the endpoint |
| runs the default HF inference image, not TGI, so /v1/chat/completions |
| is unavailable — returns 404). |
| |
| - 4XX errors fail immediately (client error, no retry). |
| - 503 errors retry with exponential backoff (cold-start). |
| - Total retry budget capped at _COLD_START_TIMEOUT seconds. |
| - Individual wait capped at _RETRY_MAX_WAIT seconds. |
| """ |
| prompt = self._format_gemma_prompt(messages) |
| start = time.monotonic() |
| last_error = None |
| for attempt in range(_MAX_RETRIES): |
| elapsed = time.monotonic() - start |
| if elapsed > _COLD_START_TIMEOUT: |
| break |
| attempt_start = time.monotonic() |
| try: |
| content = await asyncio.to_thread( |
| self._client.text_generation, |
| prompt=prompt, |
| max_new_tokens=max_tokens, |
| ) |
| if not content: |
| raise ValueError("MedGemma returned empty content") |
| attempt_elapsed = time.monotonic() - attempt_start |
| logger.info( |
| "medgemma_call", |
| attempt=attempt + 1, |
| max_tokens=max_tokens, |
| duration_s=round(attempt_elapsed, 2), |
| ) |
| return content |
| except HfHubHTTPError as e: |
| last_error = e |
| status = getattr(getattr(e, "response", None), "status_code", None) |
| if status and 400 <= status < 500: |
| logger.error( |
| "medgemma_client_error", |
| status_code=status, |
| error=str(e)[:200], |
| ) |
| raise |
| if status == 503 and attempt < _MAX_RETRIES - 1: |
| wait = min(_RETRY_BACKOFF_BASE**attempt, _RETRY_MAX_WAIT) |
| logger.warning( |
| "medgemma_cold_start", |
| status_code=503, |
| attempt=attempt + 1, |
| max_retries=_MAX_RETRIES, |
| wait_s=round(wait, 0), |
| ) |
| await asyncio.sleep(wait) |
| continue |
| raise |
| raise last_error or RuntimeError("MedGemma retry budget exhausted") |
|
|
| async def extract(self, document_urls: list[str], metadata: dict) -> dict: |
| """Extract PatientProfile from documents via MedGemma HF Endpoint.""" |
| prompt_text = self._build_extraction_prompt(metadata) |
| doc_parts = self._load_documents(document_urls) |
|
|
| has_images = any(p.get("type") == "image_url" for p in doc_parts) |
|
|
| if has_images: |
| content: str | list[dict] = [{"type": "text", "text": prompt_text}] + doc_parts |
| else: |
| doc_texts = [p["text"] for p in doc_parts if p.get("type") == "text"] |
| content = prompt_text + "\n\n" + "\n".join(doc_texts) if doc_texts else prompt_text |
|
|
| messages = [ |
| {"role": "system", "content": self._system_prompt()}, |
| {"role": "user", "content": content}, |
| ] |
|
|
| raw_text = await self._call_with_retry(messages, max_tokens=2048) |
| return self._parse_profile(raw_text, metadata) |
|
|
| async def evaluate_medical_criterion( |
| self, |
| criterion_text: str, |
| patient_profile: object, |
| evidence_docs: list, |
| ) -> dict: |
| """Evaluate a single medical criterion against patient evidence.""" |
| prompt = f""" |
| Evaluate whether the patient meets this clinical trial criterion |
| based on the patient profile and any supporting evidence. |
| |
| CRITERION: {criterion_text} |
| |
| Patient Profile: |
| {json.dumps(patient_profile, indent=2, default=str)} |
| |
| Respond with JSON: |
| {{"decision": "met|not_met|unknown", "reasoning": "...", "confidence": 0.0-1.0}} |
| |
| Rules: |
| - "met" only if the profile provides clear evidence the criterion is satisfied |
| - "not_met" if the profile contradicts the criterion |
| - "unknown" if the profile lacks sufficient data to determine |
| - Cite specific profile fields in your reasoning |
| """ |
|
|
| doc_parts = self._load_documents(evidence_docs) if evidence_docs else [] |
| has_images = any(p.get("type") == "image_url" for p in doc_parts) |
|
|
| if has_images: |
| content: str | list[dict] = [{"type": "text", "text": prompt}] + doc_parts |
| else: |
| doc_texts = [p["text"] for p in doc_parts if p.get("type") == "text"] |
| content = prompt + "\n\n" + "\n".join(doc_texts) if doc_texts else prompt |
|
|
| messages = [ |
| {"role": "system", "content": self._system_prompt()}, |
| {"role": "user", "content": content}, |
| ] |
|
|
| raw_text = await self._call_with_retry(messages, max_tokens=1024) |
| return json.loads(raw_text) |
|
|