# TrialPath — trialpath/services/medgemma_extractor.py
# Author: yakilee
# feat(medgemma): switch to text_generation with Gemma chat template
# commit cd8cf7c
"""MedGemma HF endpoint integration for patient profile extraction."""
import asyncio
import base64
import json
import re
import time
from pathlib import Path
import structlog
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError
from trialpath.config import (
HF_TOKEN,
MEDGEMMA_COLD_START_TIMEOUT,
MEDGEMMA_ENDPOINT_URL,
MEDGEMMA_MAX_RETRIES,
MEDGEMMA_MAX_WAIT,
MEDGEMMA_RETRY_BACKOFF,
)
# Module-level logger namespaced under the TrialPath logging hierarchy.
logger = structlog.get_logger("trialpath.medgemma")
# Retry/cold-start tuning, aliased from trialpath.config so deployments can
# override behavior without code changes.
_MAX_RETRIES = MEDGEMMA_MAX_RETRIES  # max attempts per endpoint call
_RETRY_BACKOFF_BASE = MEDGEMMA_RETRY_BACKOFF  # exponential backoff base (seconds)
_RETRY_MAX_WAIT = MEDGEMMA_MAX_WAIT  # cap on a single backoff sleep (seconds)
_COLD_START_TIMEOUT = MEDGEMMA_COLD_START_TIMEOUT  # total retry budget (seconds)
class MedGemmaExtractor:
"""Extract patient profiles from medical documents using MedGemma.
Uses a HuggingFace Inference Endpoint running medgemma-1-5-4b-it-hae
for multimodal extraction (image-text-to-text).
"""
def __init__(
self,
endpoint_url: str | None = None,
hf_token: str | None = None,
):
self.endpoint_url = endpoint_url or MEDGEMMA_ENDPOINT_URL
self.hf_token = hf_token or HF_TOKEN
self._client = InferenceClient(
model=self.endpoint_url,
token=self.hf_token,
)
def _system_prompt(self) -> str:
return (
"You are an expert medical data extractor specializing in oncology. "
"Extract structured patient information from medical documents. "
"Always cite the source document and location for each extracted fact. "
"If information is unclear or missing, explicitly note it as unknown."
)
def _build_extraction_prompt(self, metadata: dict) -> str:
return f"""
Extract a structured patient profile from the following medical documents.
Known metadata: age={metadata.get("age", "unknown")}, sex={metadata.get("sex", "unknown")}
Extract the following fields in JSON format:
- diagnosis (primary_condition, histology, stage, diagnosis_date)
- performance_status (scale, value, evidence)
- biomarkers (name, result, date, evidence for each)
- key_labs (name, value, unit, date, evidence for each)
- treatments (drug_name, start_date, end_date, line, evidence)
- comorbidities (name, grade, evidence)
- imaging_summary (modality, date, finding, interpretation, certainty, evidence)
- unknowns (field, reason, importance for each missing critical field)
For each evidence reference, include: doc_id (filename), page number, span_id.
Return ONLY valid JSON matching the PatientProfile schema.
"""
def _parse_profile(self, raw_text: str, metadata: dict) -> dict:
"""Parse MedGemma output into PatientProfile structure."""
try:
profile = json.loads(raw_text)
except json.JSONDecodeError:
json_match = re.search(r"```(?:json)?\s*(.*?)\s*```", raw_text, re.DOTALL)
if json_match:
profile = json.loads(json_match.group(1))
else:
raise ValueError(f"Could not parse MedGemma output as JSON: {raw_text[:200]}")
if "demographics" not in profile:
profile["demographics"] = {}
profile["demographics"].update(metadata)
return profile
@staticmethod
def _load_documents(document_paths: list[str]) -> list[dict]:
"""Load documents and encode images as base64 for multimodal input."""
content_parts: list[dict] = []
image_extensions = {".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp"}
for path_str in document_paths:
path = Path(path_str)
if not path.exists():
continue
if path.suffix.lower() in image_extensions:
with open(path, "rb") as f:
b64 = base64.b64encode(f.read()).decode("utf-8")
mime = "image/png" if path.suffix.lower() == ".png" else "image/jpeg"
content_parts.append(
{
"type": "image_url",
"image_url": {"url": f"data:{mime};base64,{b64}"},
}
)
else:
content_parts.append(
{
"type": "text",
"text": f"[Document: {path.name}]",
}
)
return content_parts
@staticmethod
def _format_gemma_prompt(messages: list[dict]) -> str:
"""Convert chat messages to Gemma chat template format.
The HF endpoint uses the default inference image (not TGI), so
/v1/chat/completions is unavailable. We format manually and call
text_generation() on the root endpoint instead.
Template: <start_of_turn>role\ncontent<end_of_turn>
"""
parts: list[str] = []
for msg in messages:
role = msg["role"]
content = msg["content"]
if isinstance(content, list):
# Multimodal content — extract text parts only for now
text_parts = [p["text"] for p in content if p.get("type") == "text"]
content = "\n".join(text_parts)
if role == "system":
# Gemma folds system into the first user turn
parts.append(f"<start_of_turn>user\n{content}")
continue
if role == "user":
if parts and parts[-1].startswith("<start_of_turn>user"):
# Append to existing user turn (system was prepended)
parts[-1] += f"\n\n{content}<end_of_turn>"
else:
parts.append(f"<start_of_turn>user\n{content}<end_of_turn>")
elif role == "model" or role == "assistant":
parts.append(f"<start_of_turn>model\n{content}<end_of_turn>")
# Open the model turn for generation
parts.append("<start_of_turn>model\n")
return "\n".join(parts)
async def _call_with_retry(self, messages: list[dict], max_tokens: int) -> str:
"""Call HF endpoint with retry logic for cold-start 503 errors.
Uses text_generation() with Gemma chat template (the endpoint
runs the default HF inference image, not TGI, so /v1/chat/completions
is unavailable — returns 404).
- 4XX errors fail immediately (client error, no retry).
- 503 errors retry with exponential backoff (cold-start).
- Total retry budget capped at _COLD_START_TIMEOUT seconds.
- Individual wait capped at _RETRY_MAX_WAIT seconds.
"""
prompt = self._format_gemma_prompt(messages)
start = time.monotonic()
last_error = None
for attempt in range(_MAX_RETRIES):
elapsed = time.monotonic() - start
if elapsed > _COLD_START_TIMEOUT:
break
attempt_start = time.monotonic()
try:
content = await asyncio.to_thread(
self._client.text_generation,
prompt=prompt,
max_new_tokens=max_tokens,
)
if not content:
raise ValueError("MedGemma returned empty content")
attempt_elapsed = time.monotonic() - attempt_start
logger.info(
"medgemma_call",
attempt=attempt + 1,
max_tokens=max_tokens,
duration_s=round(attempt_elapsed, 2),
)
return content
except HfHubHTTPError as e:
last_error = e
status = getattr(getattr(e, "response", None), "status_code", None)
if status and 400 <= status < 500:
logger.error(
"medgemma_client_error",
status_code=status,
error=str(e)[:200],
)
raise
if status == 503 and attempt < _MAX_RETRIES - 1:
wait = min(_RETRY_BACKOFF_BASE**attempt, _RETRY_MAX_WAIT)
logger.warning(
"medgemma_cold_start",
status_code=503,
attempt=attempt + 1,
max_retries=_MAX_RETRIES,
wait_s=round(wait, 0),
)
await asyncio.sleep(wait)
continue
raise
raise last_error or RuntimeError("MedGemma retry budget exhausted")
async def extract(self, document_urls: list[str], metadata: dict) -> dict:
"""Extract PatientProfile from documents via MedGemma HF Endpoint."""
prompt_text = self._build_extraction_prompt(metadata)
doc_parts = self._load_documents(document_urls)
has_images = any(p.get("type") == "image_url" for p in doc_parts)
if has_images:
content: str | list[dict] = [{"type": "text", "text": prompt_text}] + doc_parts
else:
doc_texts = [p["text"] for p in doc_parts if p.get("type") == "text"]
content = prompt_text + "\n\n" + "\n".join(doc_texts) if doc_texts else prompt_text
messages = [
{"role": "system", "content": self._system_prompt()},
{"role": "user", "content": content},
]
raw_text = await self._call_with_retry(messages, max_tokens=2048)
return self._parse_profile(raw_text, metadata)
async def evaluate_medical_criterion(
self,
criterion_text: str,
patient_profile: object,
evidence_docs: list,
) -> dict:
"""Evaluate a single medical criterion against patient evidence."""
prompt = f"""
Evaluate whether the patient meets this clinical trial criterion
based on the patient profile and any supporting evidence.
CRITERION: {criterion_text}
Patient Profile:
{json.dumps(patient_profile, indent=2, default=str)}
Respond with JSON:
{{"decision": "met|not_met|unknown", "reasoning": "...", "confidence": 0.0-1.0}}
Rules:
- "met" only if the profile provides clear evidence the criterion is satisfied
- "not_met" if the profile contradicts the criterion
- "unknown" if the profile lacks sufficient data to determine
- Cite specific profile fields in your reasoning
"""
doc_parts = self._load_documents(evidence_docs) if evidence_docs else []
has_images = any(p.get("type") == "image_url" for p in doc_parts)
if has_images:
content: str | list[dict] = [{"type": "text", "text": prompt}] + doc_parts
else:
doc_texts = [p["text"] for p in doc_parts if p.get("type") == "text"]
content = prompt + "\n\n" + "\n".join(doc_texts) if doc_texts else prompt
messages = [
{"role": "system", "content": self._system_prompt()},
{"role": "user", "content": content},
]
raw_text = await self._call_with_retry(messages, max_tokens=1024)
return json.loads(raw_text)