timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""FastAPI router — /api/gemeo/*

Mount in main.py:
    from gemeo.api import router as gemeo_router
    app.include_router(gemeo_router)

Endpoints:
    POST /api/gemeo/build — create twin from case text
    GET /api/gemeo/{id} — full twin
    POST /api/gemeo/{id}/evolve — evolve with new data
    POST /api/gemeo/{id}/whatif — counterfactual
    POST /api/gemeo/{id}/feedback — record user correction
    GET /api/gemeo/{id}/cohort
    GET /api/gemeo/{id}/subgraph
    GET /api/gemeo/{id}/trajectory
    GET /api/gemeo/{id}/risk
    GET /api/gemeo/{id}/drugs
    GET /api/gemeo/{id}/trials
    GET /api/gemeo/{id}/next-questions
    GET /api/gemeo/{id}/sus
    GET /api/gemeo/{id}/viz
    GET /api/gemeo/health
"""
from __future__ import annotations
import logging
from dataclasses import asdict
from typing import Optional
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from . import core as gcore
from . import bridge, feedback
# Logger used by this module.
logger = logging.getLogger("gemeo.api")

# Router that owns every endpoint declared below; all paths are relative to
# the /api/gemeo prefix.
router = APIRouter(
    prefix="/api/gemeo",
    tags=["gemeo"],
)
# ─── request models ────────────────────────────────────────────────────────
class BuildRequest(BaseModel):
    """Request body for ``POST /build``.

    Every field is forwarded to ``gcore.build_gemeo`` under the same
    keyword name.
    """

    # Required case text; values shorter than 10 characters fail validation.
    case_text: str = Field(..., min_length=10)
    # Optional mappings; each request gets its own empty dict by default.
    patient_info: dict = Field(default_factory=dict)
    context: dict = Field(default_factory=dict)
    run_diagnosis: bool = True
    cohort_k: int = 10
    # Defaults to a fresh [6, 12, 24] list per request.
    horizons_months: list[int] = Field(default_factory=lambda: [6, 12, 24])
    fast: bool = False
class EvolveRequest(BaseModel):
    """Request body for ``POST /{case_id}/evolve``.

    Each list defaults to empty and is forwarded to ``gcore.evolve_gemeo``
    under the same keyword name. Element types are not constrained here.
    """

    new_phenotypes: list = Field(default_factory=list)
    new_genes: list = Field(default_factory=list)
    new_labs: list = Field(default_factory=list)
    new_treatments: list = Field(default_factory=list)
class WhatIfRequest(BaseModel):
    """Request body for ``POST /{case_id}/whatif``."""

    # Required; passed unchanged as the second argument of gcore.what_if.
    intervention: dict = Field(...)
class FeedbackRequest(BaseModel):
    """Request body for ``POST /{case_id}/feedback``.

    All fields are handed to ``feedback.record`` under the same keyword
    name.
    """

    # Required fields.
    kind: str
    target: dict
    user_correction: dict
    # Optional fields; omitted values are recorded as None.
    user_id: Optional[str] = None
    comment: Optional[str] = None
# ─── helpers ───────────────────────────────────────────────────────────────
def _twin_to_dict(twin) -> Optional[dict]:
    """Convert a twin dataclass instance into a plain ``dict``.

    Args:
        twin: A dataclass instance, or ``None``.

    Returns:
        ``dataclasses.asdict(twin)``, or ``None`` when *twin* is ``None``
        (the annotation reflects that ``None`` pass-through).
    """
    return None if twin is None else asdict(twin)
# ─── endpoints ─────────────────────────────────────────────────────────────
@router.get("/health")
async def health():
    """Return module identity plus the bridge and feedback statistics."""
    payload = {"ok": True, "module": "gemeo", "version": "0.1.0"}
    payload["bridge"] = bridge.stats()
    payload["feedback"] = feedback.stats()
    return payload
@router.post("/build")
async def build(req: BuildRequest):
    """Create a twin from the submitted case and return it as a dict."""
    # Every request field maps 1:1 onto a build_gemeo keyword argument.
    build_kwargs = {
        "case_text": req.case_text,
        "patient_info": req.patient_info,
        "context": req.context,
        "run_diagnosis": req.run_diagnosis,
        "cohort_k": req.cohort_k,
        "horizons_months": req.horizons_months,
        "fast": req.fast,
    }
    twin = await gcore.build_gemeo(**build_kwargs)
    return _twin_to_dict(twin)
@router.get("/{case_id}")
async def get_twin(case_id: str):
    """Return the full twin for *case_id* as a dict.

    Raises:
        HTTPException: 404 when ``gcore.query_gemeo`` yields no twin.
    """
    twin = await gcore.query_gemeo(case_id)
    if twin is not None:
        return _twin_to_dict(twin)
    raise HTTPException(404, f"twin not found for case {case_id}")
@router.post("/{case_id}/evolve")
async def evolve(case_id: str, req: EvolveRequest):
    """Evolve the twin for *case_id* with new data and return it as a dict.

    Raises:
        HTTPException: 404 when ``gcore.evolve_gemeo`` yields no twin.
    """
    additions = {
        "new_phenotypes": req.new_phenotypes,
        "new_genes": req.new_genes,
        "new_labs": req.new_labs,
        "new_treatments": req.new_treatments,
    }
    twin = await gcore.evolve_gemeo(case_id, **additions)
    if twin is not None:
        return _twin_to_dict(twin)
    raise HTTPException(404, "twin not found")
@router.post("/{case_id}/whatif")
async def whatif(case_id: str, req: WhatIfRequest):
    """Run a counterfactual for *case_id* and return ``gcore.what_if``'s result.

    Raises:
        HTTPException: 404 when ``gcore.what_if`` returns ``None``.
    """
    result = await gcore.what_if(case_id, req.intervention)
    if result is not None:
        return result
    raise HTTPException(404, "twin not found")
@router.post("/{case_id}/feedback")
async def record_feedback(case_id: str, req: FeedbackRequest):
    """Record a user correction for *case_id*.

    The feedback is stored even when no twin is held for the case: the
    case id is then used in place of the twin id.
    """
    twin = gcore.get_gemeo(case_id)
    if twin:
        twin_id = twin.id
    else:
        twin_id = case_id
    recorded = feedback.record(
        twin_id=twin_id,
        case_id=case_id,
        kind=req.kind,
        target=req.target,
        user_correction=req.user_correction,
        user_id=req.user_id,
        comment=req.comment,
    )
    return {"ok": True, "recorded": recorded}
@router.get("/{case_id}/cohort")
async def get_cohort(case_id: str):
    """Return the twin's ``cohort`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no cohort.
    """
    twin = await gcore.query_gemeo(case_id)
    cohort = None if twin is None else twin.cohort
    if cohort is None:
        raise HTTPException(404, "cohort not available")
    return asdict(cohort)
@router.get("/{case_id}/subgraph")
async def get_subgraph(case_id: str):
    """Return the twin's ``subgraph`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no subgraph.
    """
    twin = await gcore.query_gemeo(case_id)
    subgraph = None if twin is None else twin.subgraph
    if subgraph is None:
        raise HTTPException(404, "subgraph not available")
    return asdict(subgraph)
@router.get("/{case_id}/trajectory")
async def get_trajectory(case_id: str):
    """Return the twin's ``trajectory`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no trajectory.
    """
    twin = await gcore.query_gemeo(case_id)
    trajectory = None if twin is None else twin.trajectory
    if trajectory is None:
        raise HTTPException(404, "trajectory not available")
    return asdict(trajectory)
@router.get("/{case_id}/risk")
async def get_risk(case_id: str):
    """Return the twin's ``risk`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no risk data.
    """
    twin = await gcore.query_gemeo(case_id)
    risk = None if twin is None else twin.risk
    if risk is None:
        raise HTTPException(404, "risk not available")
    return asdict(risk)
@router.get("/{case_id}/drugs")
async def get_drugs(case_id: str):
    """Return the twin's ``drugs`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no drugs data.
    """
    twin = await gcore.query_gemeo(case_id)
    drugs = None if twin is None else twin.drugs
    if drugs is None:
        raise HTTPException(404, "drugs not available")
    return asdict(drugs)
@router.get("/{case_id}/trials")
async def get_trials(case_id: str):
    """Return the twin's ``trials`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no trials data.
    """
    twin = await gcore.query_gemeo(case_id)
    trials = None if twin is None else twin.trials
    if trials is None:
        raise HTTPException(404, "trials not available")
    return asdict(trials)
@router.get("/{case_id}/next-questions")
async def get_next_questions(case_id: str):
    """Return the twin's next questions as a list of dicts.

    A twin whose ``next_questions`` is empty or ``None`` yields ``[]``.

    Raises:
        HTTPException: 404 when no twin is found.
    """
    twin = await gcore.query_gemeo(case_id)
    if twin is None:
        raise HTTPException(404, "twin not found")
    questions = twin.next_questions or []
    return list(map(asdict, questions))
@router.get("/{case_id}/sus")
async def get_sus(case_id: str):
    """Return the twin's ``sus_check`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no sus check.
    """
    twin = await gcore.query_gemeo(case_id)
    sus_check = None if twin is None else twin.sus_check
    if sus_check is None:
        raise HTTPException(404, "sus check not available")
    return asdict(sus_check)
@router.get("/{case_id}/viz")
async def get_viz(case_id: str):
    """Return the twin's ``viz_data`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no viz data.
    """
    twin = await gcore.query_gemeo(case_id)
    viz_data = None if twin is None else twin.viz_data
    if viz_data is None:
        raise HTTPException(404, "viz not available")
    return asdict(viz_data)
# ─── Phase-2 case-driven endpoints ────────────────────────────────────────
@router.get("/{case_id}/ddi")
async def get_ddi(case_id: str):
    """Return the twin's ``ddi`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no DDI analysis.
    """
    twin = await gcore.query_gemeo(case_id)
    ddi = None if twin is None else twin.ddi
    if ddi is None:
        raise HTTPException(404, "DDI analysis not available")
    return asdict(ddi)
@router.get("/{case_id}/pharmacogen")
async def get_pharmacogen(case_id: str):
    """Return the twin's ``pharmacogen`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no pharmacogen
            analysis.
    """
    twin = await gcore.query_gemeo(case_id)
    pharmacogen = None if twin is None else twin.pharmacogen
    if pharmacogen is None:
        raise HTTPException(404, "pharmacogen analysis not available")
    return asdict(pharmacogen)
@router.get("/{case_id}/family")
async def get_family(case_id: str):
    """Return the twin's ``family`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no family
            analysis.
    """
    twin = await gcore.query_gemeo(case_id)
    family = None if twin is None else twin.family
    if family is None:
        raise HTTPException(404, "family analysis not available")
    return asdict(family)
@router.get("/{case_id}/reverse-pheno")
async def get_reverse_pheno(case_id: str):
    """Return the twin's ``reverse_pheno`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no reverse
            phenotyping result.
    """
    twin = await gcore.query_gemeo(case_id)
    reverse_pheno = None if twin is None else twin.reverse_pheno
    if reverse_pheno is None:
        raise HTTPException(404, "reverse phenotyping not available")
    return asdict(reverse_pheno)
@router.get("/{case_id}/protocol-compliance")
async def get_protocol_compliance(case_id: str):
    """Return the twin's ``protocol_compliance`` section as a dict.

    Raises:
        HTTPException: 404 when the twin is missing or has no protocol
            compliance result.
    """
    twin = await gcore.query_gemeo(case_id)
    compliance = None if twin is None else twin.protocol_compliance
    if compliance is None:
        raise HTTPException(404, "protocol compliance not available")
    return asdict(compliance)
class ConsultRequest(BaseModel):
    """Request body for ``POST /{case_id}/consult``."""

    # Defaults to an empty list; the endpoint forwards an empty panel as None.
    panel: list[str] = Field(default_factory=list)
    # Optional question text, forwarded to gcore.consult unchanged.
    question: Optional[str] = None
@router.post("/{case_id}/consult")
async def post_consult(case_id: str, req: ConsultRequest):
    """Call ``gcore.consult`` for *case_id* and return its result.

    Raises:
        HTTPException: 404 when ``gcore.consult`` returns ``None``.
    """
    # An empty panel is normalised to None before it reaches gcore.consult.
    panel = req.panel or None
    result = await gcore.consult(case_id, panel=panel, question=req.question)
    if result is not None:
        return result
    raise HTTPException(404, "twin not found")
class SimulateRequest(BaseModel):
    """Request body for ``POST /{case_id}/simulate``.

    Every field is forwarded to ``gcore.simulate`` under the same keyword
    name.
    """

    n_runs: int = 30
    # Optional; None is passed through when the caller omits it.
    intervention: Optional[dict] = None
    # Defaults to a fresh [6, 12, 24] list per request.
    horizons_months: list[int] = Field(default_factory=lambda: [6, 12, 24])
@router.post("/{case_id}/simulate")
async def post_simulate(case_id: str, req: SimulateRequest):
    """Call ``gcore.simulate`` for *case_id* and return its result.

    Raises:
        HTTPException: 404 when ``gcore.simulate`` returns ``None``.
    """
    sim_kwargs = {
        "n_runs": req.n_runs,
        "intervention": req.intervention,
        "horizons_months": req.horizons_months,
    }
    result = await gcore.simulate(case_id, **sim_kwargs)
    if result is not None:
        return result
    raise HTTPException(404, "twin not found")
class DDIRequest(BaseModel):
    """Request body for ``POST /ddi/check``.

    Both fields are forwarded to the ``ddi`` module's ``predict`` under the
    same keyword name.
    """

    # Defaults to an empty list; element type is not constrained here.
    medications: list = Field(default_factory=list)
    # Optional; None is passed through when the caller omits it.
    add_drug: Optional[dict] = None
@router.post("/ddi/check")
async def check_ddi_standalone(req: DDIRequest):
    """Standalone DDI check — does not require a twin."""
    # Imported on demand, as in the original implementation.
    from . import ddi as gddi_mod

    prediction = await gddi_mod.predict(
        medications=req.medications,
        add_drug=req.add_drug,
    )
    return asdict(prediction)
# ─── LLM context / GraphRAG / absorb ──────────────────────────────────────
class LookupRequest(BaseModel):
    """Request body for ``POST /{case_id}/lookup``."""

    # Required; values shorter than 2 characters fail validation.
    query: str = Field(..., min_length=2)
    # Passed to graphrag.retrieve as ``mode``; defaults to "local".
    mode: str = Field(default="local")
class AbsorbRequest(BaseModel):
    """Request body for ``POST /{case_id}/absorb``."""

    # Required, non-empty free-text message.
    message: str = Field(..., min_length=1)
    # Passed to extractor.absorb as ``source``; defaults to "user".
    source: str = Field(default="user")
@router.get("/{case_id}/context")
async def get_llm_context(case_id: str, max_chars: int = 4000):
    """Returns the Markdown context block injected into LLM system prompts.

    Args:
        case_id: Case whose twin should be serialized.
        max_chars: Forwarded to ``llm_context.serialize_twin_for_llm`` and
            echoed back in the response.

    Returns:
        A dict with the case and twin ids, the ``GEMEO_HEADER`` and
        ``GEMEO_INSTRUCTIONS`` constants, the serialized context block and
        its length.

    Raises:
        HTTPException: 404 when no twin can be found for *case_id*.
    """
    # ``gcore`` is already imported at module level; only llm_context needs
    # the on-demand import here.
    from . import llm_context

    # Try the synchronous lookup first, then fall back to query_gemeo.
    twin = gcore.get_gemeo(case_id) or await gcore.query_gemeo(case_id)
    if twin is None:
        raise HTTPException(404, "twin not found")
    block = llm_context.serialize_twin_for_llm(twin, max_chars=max_chars)
    return {
        "case_id": case_id,
        "twin_id": twin.id,
        "header": llm_context.GEMEO_HEADER,
        "instructions": llm_context.GEMEO_INSTRUCTIONS,
        "context": block,
        "context_chars": len(block),
        "max_chars": max_chars,
    }
@router.post("/{case_id}/lookup")
async def llm_lookup(case_id: str, req: LookupRequest):
    """GraphRAG retrieval — used by the LLM via the gemeo_lookup tool, or by
    the front-end to preview what evidence the model would receive.

    The response is the retrieval result plus a ``rendered`` key holding
    ``graphrag.format_for_llm``'s output for that result.
    """
    from . import graphrag

    result = await graphrag.retrieve(case_id, req.query, mode=req.mode)
    response = {**result}
    response["rendered"] = graphrag.format_for_llm(result)
    return response
@router.post("/{case_id}/absorb")
async def absorb_msg(case_id: str, req: AbsorbRequest):
    """Extract HPO/ORPHA/gene/lab/treatment mentions from a free-text
    message via the LLM-based extractor (negation/family-history aware)
    and feed back into the twin via evolve_gemeo.

    Falls back to regex when the LLM router is unavailable.
    """
    from . import extractor

    outcome = await extractor.absorb(case_id, req.message, source=req.source)
    return outcome
# ─── SOTA additions: event-stream, verifier, suggested-skills, cache ─────
@router.get("/{case_id}/event-stream")
async def get_event_stream(case_id: str):
    """TwinWeaver-style chronological event tape (LLM-friendly).

    Raises:
        HTTPException: 404 when no twin is found.
    """
    twin = await gcore.query_gemeo(case_id)
    if twin is None:
        raise HTTPException(404, "twin not found")

    # Imported only once a twin has been found, as in the original.
    from . import event_stream

    tape = event_stream.serialize_twin_as_event_stream(twin)
    response = {"case_id": case_id, "event_stream": tape}
    response["chars"] = len(tape)
    return response
class VerifyRequest(BaseModel):
    """Request body for ``POST /{case_id}/verify``."""

    # Required text handed to verifier.verify.
    text: str
    # Passed to verifier.verify as ``mode``; defaults to "light".
    mode: str = "light"
@router.post("/{case_id}/verify")
async def verify_recommendation(case_id: str, req: VerifyRequest):
    """Med-TIV-style verifier — extract claims, ground via KG, flag unverified."""
    from . import verifier

    report = await verifier.verify(case_id, req.text, mode=req.mode)
    return asdict(report)
@router.get("/{case_id}/suggested-skills")
async def get_suggested_skills(case_id: str, top_n: int = 8):
    """Skill router — return top-N relevant skills for this twin.

    Raises:
        HTTPException: 404 when no twin is found.
    """
    twin = await gcore.query_gemeo(case_id)
    if twin is None:
        raise HTTPException(404, "twin not found")

    # Imported only once a twin has been found, as in the original.
    from . import skill_router

    suggestions = skill_router.suggest(twin, top_n=top_n)
    mcp_servers = skill_router.suggest_mcp_servers(twin)
    skills = list(map(asdict, suggestions))
    return {"case_id": case_id, "skills": skills, "mcp_servers": mcp_servers}
@router.get("/cache/stats")
async def get_cache_stats():
    """Aura LookupCache stats (AMG-RAG style)."""
    from . import cache

    stats = await cache.stats()
    return stats
# ─── Multimodal extraction ─────────────────────────────────────────────────
# Receives a clinical image (screenshot / PDF page / lab printout) and
# returns structured entities {hpo, medications, diagnoses, labs, patient}.
# Lives in the backend (not the Vercel function) so we can reuse the
# raras-app KG to normalize HPO IDs and so the Groq/Gemini billing is
# centralized.
# NOTE: don't import UploadFile/File/Form at module top-level — that
# pulls python-multipart into FastAPI's dependency analyzer at import
# time, which crashes the orch boot if the package isn't installed.
# Instead, do the import + the actual route registration lazily below
# guarded by a try/except so the orch still serves the rest of /api/gemeo
# even if multipart isn't installed.
class _ExtractImageJSON(BaseModel):
    """JSON-only request shape, so the primary path does not need
    python-multipart. The frontend can POST application/json with a base64
    payload, or multipart/form-data when that package is available."""

    # Required base64 payload; values shorter than 16 characters fail
    # validation. The endpoint also accepts a leading "data:" URL prefix.
    image_base64: str = Field(..., min_length=16)
    mime: str = "image/png"
    source_url: Optional[str] = None
@router.post("/extract-image-json")
async def extract_image_json_endpoint(req: _ExtractImageJSON):
    """JSON intake — same response shape as /extract-image. Used by
    the Next.js proxy which always converts inbound multipart to base64
    before forwarding here. Avoids the python-multipart dep entirely.

    Raises:
        HTTPException: 400 when the payload is not decodable base64, 413
            when the decoded image exceeds 15MB, 502 when extraction fails
            with a non-HTTP error. An ``HTTPException`` raised by the
            extractor itself is passed through unchanged.
    """
    from base64 import b64decode

    from . import multimodal_extract

    max_bytes = 15 * 1024 * 1024

    # Strip optional data-URL prefix in case the caller sent one.
    payload = req.image_base64
    if payload.startswith("data:"):
        payload = payload.split(",", 1)[-1]

    try:
        raw = b64decode(payload)
    except ValueError as e:
        # binascii.Error (bad padding) subclasses ValueError, and non-ASCII
        # input raises ValueError directly, so this covers every decode
        # failure for a str payload.
        raise HTTPException(400, f"bad base64: {e!s}") from e

    if len(raw) > max_bytes:
        raise HTTPException(413, f"image too large ({len(raw)/1024/1024:.1f}MB, max 15MB)")

    try:
        return await multimodal_extract.extract_image(
            image_bytes=raw, mime=req.mime, source_url=req.source_url,
        )
    except HTTPException:
        # Keep a deliberate status code from downstream instead of
        # rewrapping it as a 502.
        raise
    except Exception as e:
        # Keep the traceback server-side; the client only gets the summary.
        logger.exception("[gemeo.api] image extraction failed")
        raise HTTPException(502, f"extraction failed: {e!s}") from e
# Multipart route is registered lazily so a missing python-multipart
# install doesn't take the whole router down at import time. When the
# package IS available, the route adds parity with the Vercel intake.
try:
    from fastapi import UploadFile, File, Form

    @router.post("/extract-image")
    async def extract_image_endpoint(
        file: UploadFile = File(...),
        source_url: Optional[str] = Form(None),
    ):
        """Multipart intake — same response shape as /extract-image-json.

        Raises:
            HTTPException: 413 when the upload exceeds 15MB, 502 when
                extraction fails with a non-HTTP error. An
                ``HTTPException`` raised by the extractor itself is passed
                through unchanged.
        """
        from . import multimodal_extract

        raw = await file.read()
        if len(raw) > 15 * 1024 * 1024:
            raise HTTPException(413, f"image too large ({len(raw)/1024/1024:.1f}MB, max 15MB)")
        # Fall back to PNG when the client sent no content type.
        mime = file.content_type or "image/png"
        try:
            return await multimodal_extract.extract_image(
                image_bytes=raw, mime=mime, source_url=source_url,
            )
        except HTTPException:
            # Keep a deliberate status code from downstream instead of
            # rewrapping it as a 502.
            raise
        except Exception as e:
            # Keep the traceback server-side; the client only gets the summary.
            logger.exception("[gemeo.api] image extraction failed")
            raise HTTPException(502, f"extraction failed: {e!s}") from e
except Exception as _multipart_err:
    # Most likely python-multipart is not installed — the JSON route still
    # works. The actual error is logged so that an unrelated registration
    # failure is not silently attributed to the missing package.
    logger.warning(
        "[gemeo.api] multipart extract route disabled (%s: %s); "
        "python-multipart is probably missing. "
        "JSON route /extract-image-json remains available.",
        type(_multipart_err).__name__,
        _multipart_err,
    )