gemeo-twin-stack / src /gemeo /multimodal_extract.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""Multimodal clinical extraction — single image → structured entities.
Lives in the Python backend (rather than the Next.js serverless function)
because:
1. We can re-use `gemeo.extractor._kg_normalize_phenotypes` to map
free-text labels to HP:xxxxxxx IDs via the raras-app KG
phenotype_search fulltext (PT-BR coverage). The Vercel function
can't reach Neo4j directly without extra latency.
2. AI logic stays centralized — same auth/audit/redaction pipeline.
3. Render gives fixed-cost compute; Vercel function execution costs
scale per-call with vision payloads (~600KB each).
Provider chain — same model as the Next.js path (Groq Llama 4 Scout)
plus an optional Gemini fallback if the GEMINI_API_KEY is healthy.
Both share the same response shape so the front-end is agnostic to
which provider answered.
Public API:
await extract_image(image_bytes, mime, source_url=None) -> dict
"""
from __future__ import annotations
import json
import logging
import os
from base64 import b64encode
from typing import Any
import httpx
# Module-level logger; provider failures in extract_image are reported
# through it as warnings.
logger = logging.getLogger("gemeo.multimodal_extract")
# System prompt (written in Portuguese) sent unchanged to every provider.
# It tells the model to act as a clinical extractor for a rare-disease
# system and to answer with a single JSON object holding the keys
# hpo, medications, diagnoses, labs, patient, free_text and language.
# This text is runtime behavior: editing it changes what the models return.
SYSTEM_PROMPT = """Você é um extrator clínico para um sistema de doenças raras.
Recebe uma imagem (screenshot de prontuário eletrônico, PDF, laudo, planilha)
e extrai entidades estruturadas em JSON.
Regras:
- Idioma fonte é provavelmente PT-BR — preserve termos clínicos no
original em "label" e adicione tradução EN só se óbvia.
- Mapeie fenótipos a HPO IDs (HP:xxxxxxx) quando confiante. Se não
conseguir mapear com >70% de confiança, omita o "id".
- Mapeie diagnósticos a ICD-10-BR e/ou ORPHA quando confiante.
- "confidence" 0..1 conforme certeza da extração (visibilidade do
texto, ambiguidade clínica, qualidade da imagem).
- "evidence" copia a frase fonte que sustenta o achado — máximo 80 chars.
- Se a imagem não contém dado clínico, retorne arrays vazios e
free_text descrevendo o que viu.
- NÃO INVENTE dados. Se um campo não está visível, omita-o.
Retorne APENAS um objeto JSON com este schema (campos vazios viram
array/objeto vazio, jamais null):
{
"hpo": [{"id": "HP:xxxxxxx", "label": "...", "confidence": 0..1, "evidence": "..."}],
"medications": [{"name": "...", "dose": "...", "route": "...", "confidence": 0..1}],
"diagnoses": [{"name": "...", "icd10": "...", "orpha": "...", "confidence": 0..1}],
"labs": [{"name": "...", "value": "...", "unit": "...", "date": "...", "confidence": 0..1}],
"patient": {"age": "...", "sex": "...", "weight": "..."},
"free_text": "...",
"language": "pt-BR"
}"""
# Groq chat-completions endpoint used by the primary provider call.
GROQ_URL = "https://api.groq.com/openai/v1/chat/completions"
# Model used when the GROQ_EXTRACT_MODEL environment variable is unset.
GROQ_MODEL_DEFAULT = "meta-llama/llama-4-scout-17b-16e-instruct"
# Gemini generateContent URL template with `{model}` and `{key}`
# placeholders.
GEMINI_URL_TPL = (
"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}"
)
# Model used when the GEMINI_EXTRACT_MODEL environment variable is unset.
GEMINI_MODEL_DEFAULT = "gemini-2.5-flash"
def _normalize_entities(raw: dict[str, Any]) -> dict[str, Any]:
    """Map a model's JSON object onto the canonical entity shape.

    Models sometimes answer with Portuguese (or otherwise alternate) key
    names. Each canonical field is looked up under its known aliases, and
    every field is type-checked so the result always contains four lists,
    a ``patient`` dict and two strings.

    Args:
        raw: Decoded JSON object returned by the model.

    Returns:
        A dict with exactly the keys ``hpo``, ``medications``,
        ``diagnoses``, ``labs`` (lists), ``patient`` (dict), ``free_text``
        (str, ``""`` when absent) and ``language`` (str, ``"pt-BR"`` when
        absent).
    """

    def pick(kind: type, *keys: str) -> Any:
        # First alias holding a non-empty value of the expected type wins.
        # Skipping empty values keeps an empty canonical key (e.g.
        # "hpo": []) from hiding a populated alias ("fenotipos": [...]).
        for key in keys:
            value = raw.get(key)
            if isinstance(value, kind) and value:
                return value
        # list() -> [], dict() -> {}, str() -> "": the empty canonical value.
        return kind()

    return {
        "hpo": pick(list, "hpo", "achados", "phenotypes", "fenotipos"),
        "medications": pick(list, "medications", "medicamentos", "meds", "drugs"),
        "diagnoses": pick(list, "diagnoses", "diagnosticos", "hipoteses", "differentials"),
        "labs": pick(list, "labs", "exames", "laboratory", "tests"),
        # Type-checked like every other field: a string or list under
        # "patient" must not leak into the canonical shape.
        "patient": pick(dict, "patient", "paciente"),
        "free_text": pick(str, "free_text", "observacao"),
        "language": pick(str, "language") or "pt-BR",
    }
def _parse_json_safely(text: str) -> dict[str, Any]:
    """Decode a model reply into a JSON object, tolerating common wrappers.

    Candidates are tried in order: the stripped text as-is, the contents
    of the first markdown code fence (with an optional ``json`` language
    tag in any letter case), and finally the span from the first ``{`` to
    the last ``}``. The first candidate that decodes to a dict is
    returned.

    Args:
        text: Raw text content returned by the model. ``None`` is treated
            as an empty string.

    Returns:
        The decoded JSON object.

    Raises:
        json.JSONDecodeError: If no candidate is valid JSON.
        ValueError: If the reply is valid JSON but not an object (for
            example a top-level list).
    """
    cleaned = (text or "").strip()

    # Try the raw text first so valid JSON that merely contains backticks
    # or braces inside a string value is never mangled by the fallbacks.
    candidates = [cleaned]

    fence_at = cleaned.find("```")
    if fence_at != -1:
        fenced = cleaned[fence_at + 3:]
        closing = fenced.find("```")
        if closing != -1:
            fenced = fenced[:closing]
        fenced = fenced.strip()
        if fenced[:4].lower() == "json":
            fenced = fenced[4:]
        candidates.append(fenced.strip().rstrip("`").strip())

    first_brace, last_brace = cleaned.find("{"), cleaned.rfind("}")
    if 0 <= first_brace < last_brace:
        candidates.append(cleaned[first_brace:last_brace + 1])

    decode_error: json.JSONDecodeError | None = None
    wrong_type: str | None = None
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError as exc:
            # Keep the error for the least-modified candidate; it is the
            # most informative one to surface.
            if decode_error is None:
                decode_error = exc
            continue
        if isinstance(parsed, dict):
            return parsed
        if wrong_type is None:
            wrong_type = type(parsed).__name__

    if wrong_type is not None:
        raise ValueError(f"expected a JSON object, got {wrong_type}")
    # No candidate decoded and none had the wrong type, so at least one
    # decode error was recorded above.
    raise decode_error
async def _call_groq(image_b64: str, mime: str, user_prompt: str) -> dict[str, Any]:
    """Call the Groq vision model (primary provider) and normalize its reply.

    The image is sent inline as a ``data:`` URL next to the text prompt,
    with JSON-object response mode enabled.

    Args:
        image_b64: Base64-encoded image bytes (ASCII string).
        mime: MIME type of the image, used in the data URL.
        user_prompt: Text instruction placed next to the image.

    Returns:
        ``{"entities": ..., "model": "groq:<model>", "tokens": {"input":
        ..., "output": ...}}`` where token counts come from the response's
        ``usage`` object and may be ``None``.

    Raises:
        RuntimeError: If ``GROQ_API_KEY`` is unset or the response holds
            no message content.
        httpx.HTTPStatusError: On a non-success HTTP status.
        ValueError: If the content cannot be decoded as a JSON object
            (propagated from ``_parse_json_safely``).
    """
    api_key = os.getenv("GROQ_API_KEY")
    if not api_key:
        raise RuntimeError("GROQ_API_KEY not set")
    model = os.getenv("GROQ_EXTRACT_MODEL", GROQ_MODEL_DEFAULT)
    data_url = f"data:{mime};base64,{image_b64}"
    payload = {
        "model": model,
        "messages": [
            {"role": "system", "content": SYSTEM_PROMPT},
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": user_prompt},
                    {"type": "image_url", "image_url": {"url": data_url}},
                ],
            },
        ],
        "response_format": {"type": "json_object"},
        "temperature": 0.1,
        "max_tokens": 4096,
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        r = await client.post(
            GROQ_URL,
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json=payload,
        )
        r.raise_for_status()
        j = r.json()

    # An empty `choices` list or `content: null` used to surface as a bare
    # IndexError / AttributeError; report it explicitly instead so the
    # fallback log in the caller says what actually happened.
    choices = j.get("choices") or [{}]
    first_choice = choices[0] if isinstance(choices[0], dict) else {}
    text = (first_choice.get("message") or {}).get("content") or ""
    if not text:
        raise RuntimeError(
            f"groq returned no content (finish_reason={first_choice.get('finish_reason')})"
        )

    entities = _normalize_entities(_parse_json_safely(text))
    usage = j.get("usage") or {}
    return {
        "entities": entities,
        "model": f"groq:{model}",
        "tokens": {
            "input": usage.get("prompt_tokens"),
            "output": usage.get("completion_tokens"),
        },
    }
async def _call_gemini(image_b64: str, mime: str, user_prompt: str) -> dict[str, Any]:
    """Call the Gemini vision model (fallback provider) and normalize its reply.

    JSON output is requested through ``responseMimeType``; no
    ``responseSchema`` is sent, so the shape is enforced only by the
    system prompt and by ``_normalize_entities``.

    Args:
        image_b64: Base64-encoded image bytes (ASCII string).
        mime: MIME type of the image, sent as ``inlineData.mimeType``.
        user_prompt: Text instruction placed next to the image.

    Returns:
        ``{"entities": ..., "model": "gemini:<model>", "tokens": {"input":
        ..., "output": ...}}`` where token counts come from the response's
        ``usageMetadata`` object and may be ``None``.

    Raises:
        RuntimeError: If ``GEMINI_API_KEY`` is unset or the response holds
            no text part.
        httpx.HTTPStatusError: On a non-success HTTP status.
        ValueError: If the text cannot be decoded as a JSON object
            (propagated from ``_parse_json_safely``).
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise RuntimeError("GEMINI_API_KEY not set")
    model = os.getenv("GEMINI_EXTRACT_MODEL", GEMINI_MODEL_DEFAULT)
    # Security: keep the key out of the URL. httpx puts the request URL in
    # HTTPStatusError messages, which the caller logs and re-raises, so a
    # `?key=` query parameter would leak the secret. Only the path part of
    # the template is used; the key travels in a header instead.
    url = GEMINI_URL_TPL.partition("?")[0].format(model=model)
    payload = {
        "systemInstruction": {"parts": [{"text": SYSTEM_PROMPT}]},
        "contents": [
            {
                "role": "user",
                "parts": [
                    {"text": user_prompt},
                    {"inlineData": {"mimeType": mime, "data": image_b64}},
                ],
            }
        ],
        "generationConfig": {
            "responseMimeType": "application/json",
            "temperature": 0.1,
            "maxOutputTokens": 4096,
        },
    }
    async with httpx.AsyncClient(timeout=60.0) as client:
        r = await client.post(url, headers={"x-goog-api-key": api_key}, json=payload)
        r.raise_for_status()
        j = r.json()

    # Empty `candidates` / `parts` used to surface as a bare IndexError;
    # report the provider's own reason instead.
    candidates = j.get("candidates") or [{}]
    first_candidate = candidates[0] if isinstance(candidates[0], dict) else {}
    parts = (first_candidate.get("content") or {}).get("parts") or [{}]
    first_part = parts[0] if isinstance(parts[0], dict) else {}
    text = first_part.get("text") or ""
    if not text:
        raise RuntimeError(
            "gemini returned no text "
            f"(finishReason={first_candidate.get('finishReason')}, "
            f"promptFeedback={j.get('promptFeedback')})"
        )

    entities = _normalize_entities(_parse_json_safely(text))
    usage = j.get("usageMetadata") or {}
    return {
        "entities": entities,
        "model": f"gemini:{model}",
        "tokens": {
            "input": usage.get("promptTokenCount"),
            "output": usage.get("candidatesTokenCount"),
        },
    }
async def extract_image(
    image_bytes: bytes,
    mime: str = "image/png",
    source_url: str | None = None,
) -> dict[str, Any]:
    """Run a screenshot / clinical image through a multimodal model and
    return canonical structured entities.

    Providers are tried in order: Groq first, then Gemini. The first one
    that succeeds wins.

    Args:
        image_bytes: Raw bytes of the image to analyse.
        mime: MIME type of the image.
        source_url: Optional origin of the screenshot; when given it is
            mentioned in the prompt sent to the model.

    Returns:
        The provider result ``{entities, model, tokens}`` extended with
        ``elapsed_ms`` (total time including failed attempts) and, when a
        provider other than Groq answered, ``fallback: True``.

    Raises:
        RuntimeError: If every provider fails. The message lists each
            provider's error and the last underlying exception is chained
            as ``__cause__``.
    """
    import re
    import time

    # Monotonic clock: wall-clock adjustments cannot skew the duration.
    started = time.perf_counter()
    image_b64 = b64encode(image_bytes).decode("ascii")
    user_prompt = (
        f"Captura de tela vinda de: {source_url}. Extraia as entidades clínicas."
        if source_url
        else "Extraia as entidades clínicas desta imagem."
    )
    errors: list[str] = []
    last_error: Exception | None = None
    for fn, label in ((_call_groq, "groq"), (_call_gemini, "gemini")):
        try:
            result = await fn(image_b64, mime, user_prompt)
        except Exception as e:
            # Deliberately broad: any provider failure (HTTP, timeout, bad
            # JSON, missing key) must fall through to the next provider.
            # The type name is included because str(e) is empty for some
            # exceptions (e.g. timeouts). Any `key=` query value is masked
            # before truncation so an API key embedded in a request URL
            # never reaches the logs or the raised message.
            detail = re.sub(
                r"([?&]key=)[^&\s'\"]+", r"\1***", f"{type(e).__name__}: {e}"
            )
            msg = detail[:240]
            logger.warning("[multimodal_extract] %s failed: %s", label, msg)
            errors.append(f"{label}: {msg}")
            last_error = e
            continue
        result["elapsed_ms"] = int((time.perf_counter() - started) * 1000)
        if label != "groq":
            result["fallback"] = True
        return result
    raise RuntimeError(f"all providers failed — {' | '.join(errors)}") from last_error