"""Multimodal clinical extraction — single image → structured entities. Lives in the Python backend (rather than the Next.js serverless function) because: 1. We can re-use `gemeo.extractor._kg_normalize_phenotypes` to map free-text labels to HP:xxxxxxx IDs via the raras-app KG phenotype_search fulltext (PT-BR coverage). The Vercel function can't reach Neo4j directly without extra latency. 2. AI logic stays centralized — same auth/audit/redaction pipeline. 3. Render gives fixed-cost compute; Vercel function execution costs scale per-call with vision payloads (~600KB each). Provider chain — same model as the Next.js path (Groq Llama 4 Scout) plus an optional Gemini fallback if the GEMINI_API_KEY is healthy. Both share the same response shape so the front-end is agnostic to which provider answered. Public API: await extract_image(image_bytes, mime, source_url=None) -> dict """ from __future__ import annotations import json import logging import os from base64 import b64encode from typing import Any import httpx logger = logging.getLogger("gemeo.multimodal_extract") SYSTEM_PROMPT = """Você é um extrator clínico para um sistema de doenças raras. Recebe uma imagem (screenshot de prontuário eletrônico, PDF, laudo, planilha) e extrai entidades estruturadas em JSON. Regras: - Idioma fonte é provavelmente PT-BR — preserve termos clínicos no original em "label" e adicione tradução EN só se óbvia. - Mapeie fenótipos a HPO IDs (HP:xxxxxxx) quando confiante. Se não conseguir mapear com >70% de confiança, omita o "id". - Mapeie diagnósticos a ICD-10-BR e/ou ORPHA quando confiante. - "confidence" 0..1 conforme certeza da extração (visibilidade do texto, ambiguidade clínica, qualidade da imagem). - "evidence" copia a frase fonte que sustenta o achado — máximo 80 chars. - Se a imagem não contém dado clínico, retorne arrays vazios e free_text descrevendo o que viu. - NÃO INVENTE dados. Se um campo não está visível, omita-o. Retorne APENAS um objeto JSON com este schema (campos vazios viram array/objeto vazio, jamais null): { "hpo": [{"id": "HP:xxxxxxx", "label": "...", "confidence": 0..1, "evidence": "..."}], "medications": [{"name": "...", "dose": "...", "route": "...", "confidence": 0..1}], "diagnoses": [{"name": "...", "icd10": "...", "orpha": "...", "confidence": 0..1}], "labs": [{"name": "...", "value": "...", "unit": "...", "date": "...", "confidence": 0..1}], "patient": {"age": "...", "sex": "...", "weight": "..."}, "free_text": "...", "language": "pt-BR" }""" GROQ_URL = "https://api.groq.com/openai/v1/chat/completions" GROQ_MODEL_DEFAULT = "meta-llama/llama-4-scout-17b-16e-instruct" GEMINI_URL_TPL = ( "https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}" ) GEMINI_MODEL_DEFAULT = "gemini-2.5-flash" def _normalize_entities(raw: dict[str, Any]) -> dict[str, Any]: """Cope with the model emitting Portuguese keys when the schema hint loses fidelity. Map alternates back to canonical keys + ensure arrays exist so the front-end never crashes on a missing key.""" def pick_list(*keys: str) -> list: for k in keys: v = raw.get(k) if isinstance(v, list): return v return [] return { "hpo": pick_list("hpo", "achados", "phenotypes", "fenotipos"), "medications": pick_list("medications", "medicamentos", "meds", "drugs"), "diagnoses": pick_list("diagnoses", "diagnosticos", "hipoteses", "differentials"), "labs": pick_list("labs", "exames", "laboratory", "tests"), "patient": raw.get("patient") or raw.get("paciente") or {}, "free_text": ( raw.get("free_text") if isinstance(raw.get("free_text"), str) else (raw.get("observacao") if isinstance(raw.get("observacao"), str) else "") ), "language": raw.get("language") if isinstance(raw.get("language"), str) else "pt-BR", } def _parse_json_safely(text: str) -> dict[str, Any]: cleaned = text.strip() # Strip optional markdown code fences. if cleaned.startswith("```"): cleaned = cleaned.split("```", 2)[1] if "```" in cleaned[3:] else cleaned[3:] if cleaned.startswith("json"): cleaned = cleaned[4:] cleaned = cleaned.strip().rstrip("`").strip() return json.loads(cleaned) async def _call_groq(image_b64: str, mime: str, user_prompt: str) -> dict[str, Any]: """Groq Llama 4 Scout vision — primary provider. ~600ms p50.""" api_key = os.getenv("GROQ_API_KEY") if not api_key: raise RuntimeError("GROQ_API_KEY not set") model = os.getenv("GROQ_EXTRACT_MODEL", GROQ_MODEL_DEFAULT) data_url = f"data:{mime};base64,{image_b64}" payload = { "model": model, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, { "role": "user", "content": [ {"type": "text", "text": user_prompt}, {"type": "image_url", "image_url": {"url": data_url}}, ], }, ], "response_format": {"type": "json_object"}, "temperature": 0.1, "max_tokens": 4096, } async with httpx.AsyncClient(timeout=60.0) as client: r = await client.post( GROQ_URL, headers={ "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", }, json=payload, ) r.raise_for_status() j = r.json() text = j.get("choices", [{}])[0].get("message", {}).get("content", "") entities = _normalize_entities(_parse_json_safely(text)) usage = j.get("usage") or {} return { "entities": entities, "model": f"groq:{model}", "tokens": { "input": usage.get("prompt_tokens"), "output": usage.get("completion_tokens"), }, } async def _call_gemini(image_b64: str, mime: str, user_prompt: str) -> dict[str, Any]: """Gemini 2.5 Flash vision — fallback. Uses responseSchema for strict JSON. Native PT-BR; handles tabular reports a touch better than Llama 4 Scout but ~3× slower.""" api_key = os.getenv("GEMINI_API_KEY") if not api_key: raise RuntimeError("GEMINI_API_KEY not set") model = os.getenv("GEMINI_EXTRACT_MODEL", GEMINI_MODEL_DEFAULT) url = GEMINI_URL_TPL.format(model=model, key=api_key) payload = { "systemInstruction": {"parts": [{"text": SYSTEM_PROMPT}]}, "contents": [ { "role": "user", "parts": [ {"text": user_prompt}, {"inlineData": {"mimeType": mime, "data": image_b64}}, ], } ], "generationConfig": { "responseMimeType": "application/json", "temperature": 0.1, "maxOutputTokens": 4096, }, } async with httpx.AsyncClient(timeout=60.0) as client: r = await client.post(url, json=payload) r.raise_for_status() j = r.json() text = j.get("candidates", [{}])[0].get("content", {}).get("parts", [{}])[0].get("text", "") entities = _normalize_entities(_parse_json_safely(text)) usage = j.get("usageMetadata") or {} return { "entities": entities, "model": f"gemini:{model}", "tokens": { "input": usage.get("promptTokenCount"), "output": usage.get("candidatesTokenCount"), }, } async def extract_image( image_bytes: bytes, mime: str = "image/png", source_url: str | None = None, ) -> dict[str, Any]: """Run a screenshot / clinical image through a multimodal model and return canonical structured entities. Provider order: Groq Llama 4 Scout → Gemini Flash. Returns `{entities, model, tokens, fallback?, elapsed_ms}` so the caller can log which provider answered. """ import time started = time.time() image_b64 = b64encode(image_bytes).decode("ascii") user_prompt = ( f"Captura de tela vinda de: {source_url}. Extraia as entidades clínicas." if source_url else "Extraia as entidades clínicas desta imagem." ) errors: list[str] = [] for fn, label in [(_call_groq, "groq"), (_call_gemini, "gemini")]: try: result = await fn(image_b64, mime, user_prompt) result["elapsed_ms"] = int((time.time() - started) * 1000) if label != "groq": result["fallback"] = True return result except Exception as e: msg = str(e)[:240] logger.warning("[multimodal_extract] %s failed: %s", label, msg) errors.append(f"{label}: {msg}") raise RuntimeError(f"all providers failed — {' | '.join(errors)}")