# gemeo-twin-stack/src/gemeo/consult.py
# Commit 089d665 by timmers (verified):
#   GEMEO world-model — initial release
#   (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
"""Multi-specialist consultation — orchestrate the existing swarm-py
agent swarm to produce a multidisciplinary opinion on the digital twin.
For rare diseases, a multidisciplinary team (MDT) is the gold standard.
We simulate one by routing the patient subgraph + summary through the
swarm-py agents tagged with each clinical specialty (cardio, neuro,
genetics, pediatrics, pharma, immuno, hepato, etc.) and synthesizing
their opinions into a structured `ConsultSpec`.
Bootstrap path: reuses `engine_v2`/`agent_coordinator` when available, otherwise falls back to a structured LLM call via `llm_router`.
Phase-2: dedicated specialist agents fine-tuned per discipline.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional
from .types import ConsultSpec, SpecialistOpinion
logger = logging.getLogger("gemeo.consult")

# Specialties consulted for rare-disease cases when the caller does not
# pass an explicit `panel=` to `consult()`.
DEFAULT_PANEL = [
    "geneticist", "neurologist", "pediatrician",
    "immunologist", "cardiologist", "pharmacologist",
]
async def _ask_specialist(specialty: str, prompt: str) -> Optional[dict]:
    """Route *prompt* to the agent for *specialty* and return its opinion.

    First tries the swarm's native specialist routing
    (``agent_coordinator.ask_specialist``). If the swarm is not installed,
    raises, or returns something that is not a JSON object, falls back to a
    structured JSON-mode LLM call via ``llm_router.call_llm``.

    Args:
        specialty: Specialty label; used for swarm routing and interpolated
            into the fallback system prompt.
        prompt: Case summary plus the question to answer.

    Returns:
        The opinion as a dict, or ``None`` when neither backend produced a
        JSON object. Backend failures are logged, never raised.
    """
    import json

    def _as_dict(raw) -> Optional[dict]:
        # Accept a ready dict, or JSON text that decodes to an object.
        # Anything else (None, a JSON list, malformed JSON) means "no opinion";
        # callers index the result with .get(), so it must be a dict.
        if isinstance(raw, dict):
            return raw
        if isinstance(raw, (str, bytes, bytearray)):
            try:
                decoded = json.loads(raw)
            except ValueError:  # includes JSONDecodeError / UnicodeDecodeError
                return None
            return decoded if isinstance(decoded, dict) else None
        return None

    # 1) Native swarm routing (best effort).
    try:
        from agent_coordinator import ask_specialist
        swarm_reply = await ask_specialist(specialty=specialty, prompt=prompt)
    except ImportError as exc:
        # Swarm not installed: expected outside the full stack.
        logger.debug("swarm unavailable for %s: %s", specialty, exc)
        swarm_reply = None
    except Exception as exc:
        # Still fall back, but do not swallow real swarm failures silently.
        logger.warning(
            "swarm specialist %s failed, using LLM fallback: %s", specialty, exc
        )
        swarm_reply = None
    opinion = _as_dict(swarm_reply)
    if opinion is not None:
        return opinion

    # 2) Fallback: structured LLM call.
    system_prompt = (
        f"You are a senior {specialty}. Review the rare-disease case "
        f"and produce a JSON object with keys: opinion (str), "
        f"confidence (float 0..1), key_concerns (list[str]), "
        f"recommended_next_steps (list[str]), "
        f"red_flags_for_my_specialty (list[str])."
    )
    try:
        from llm_router import call_llm
        out = await call_llm(
            system=system_prompt, user=prompt, json_mode=True, max_tokens=600
        )
    except Exception as exc:
        logger.warning("specialist %s failed: %s", specialty, exc)
        return None
    opinion = _as_dict(out)
    if opinion is None:
        logger.warning(
            "specialist %s failed: reply is not a JSON object", specialty
        )
    return opinion
def _serialize_twin_for_prompt(twin) -> str:
"""Compact human-readable summary for specialist consumption."""
if twin is None:
return ""
parts = []
if twin.diagnoses:
top = twin.diagnoses[:3]
parts.append("Top diagnoses: " + "; ".join(
f"{d.get('name','?')} (ORPHA:{d.get('orpha','?')}, p={d.get('probability', 0):.2f})"
for d in top
))
if twin.subgraph and twin.subgraph.paths:
parts.append("Reasoning paths: " + "; ".join(
f"{p['target_name']}: {' → '.join(s.get('label','?') for s in p['steps'])}"
for p in twin.subgraph.paths[:2]
))
if twin.risk:
parts.append(
f"Risk: severity={twin.risk.overall_severity:.2f}, "
f"progression={twin.risk.progression_risk:.2f}, "
f"urgency={twin.risk.treatment_urgency:.2f}"
)
if twin.trajectory:
for h in twin.trajectory.horizons[:2]:
parts.append(f"T+{h.months}m: {h.state[:120]}")
if twin.drugs and twin.drugs.candidates:
parts.append("Drug candidates: " + ", ".join(
d.get("name", "?") for d in twin.drugs.candidates[:5]
))
if twin.next_questions:
parts.append("Next questions: " + "; ".join(
f"{q.name} ({q.hpo_id})" for q in twin.next_questions[:3]
))
return "\n".join(parts)
def _as_str_list(value, limit: int = 5) -> list[str]:
    """Normalise an LLM-supplied list field to at most *limit* strings.

    ``None`` or ``""`` becomes ``[]``. A bare string becomes a one-item list
    (``list("abc")`` would have exploded it into characters). Items of a
    list/tuple/set are stringified, with ``None`` items dropped, so they
    stay hashable for ``Counter``. Any other scalar is wrapped as one string.
    """
    if value is None:
        return []
    if isinstance(value, str):
        items = [value] if value else []
    elif isinstance(value, (list, tuple, set, frozenset)):
        items = [str(v) for v in value if v is not None]
    else:
        items = [str(value)]
    return items[:limit]


def _coerce_confidence(value, default: float = 0.6) -> float:
    """Parse a specialist confidence into the range [0, 1].

    An explicit ``0``/``0.0`` is kept. It is a legitimate answer, whereas the
    old ``float(x or 0.6)`` turned it into 0.6. Missing, non-numeric and
    non-finite values fall back to *default*. Out-of-range numbers are clamped.
    """
    import math

    if value is None:
        return default
    try:
        conf = float(value)
    except (TypeError, ValueError):
        return default
    if not math.isfinite(conf):
        return default
    return min(1.0, max(0.0, conf))


def _parse_opinion(specialty: str, raw) -> SpecialistOpinion:
    """Turn one raw specialist reply (dict, None or exception) into an opinion.

    Anything other than a non-empty dict yields an "(unavailable)" opinion
    with zero confidence. ``BaseException`` is checked rather than
    ``Exception`` because ``asyncio.gather(return_exceptions=True)`` can hand
    back ``CancelledError``.
    """
    if isinstance(raw, BaseException):
        logger.warning("specialist %s raised: %r", specialty, raw)
    if not isinstance(raw, dict) or not raw:
        return SpecialistOpinion(
            specialty=specialty, opinion="(unavailable)", confidence=0.0,
            key_concerns=[], recommended_next_steps=[], red_flags=[],
        )
    text = raw.get("opinion", "")
    return SpecialistOpinion(
        specialty=specialty,
        opinion=("" if text is None else str(text))[:600],
        confidence=_coerce_confidence(raw.get("confidence")),
        key_concerns=_as_str_list(raw.get("key_concerns")),
        recommended_next_steps=_as_str_list(raw.get("recommended_next_steps")),
        # Prompted key first; accept the shorter generic key as a fallback.
        red_flags=_as_str_list(
            raw.get("red_flags_for_my_specialty", raw.get("red_flags"))
        ),
    )


def _synthesise(opinions: list) -> str:
    """Aggregate parsed opinions into a short, rule-based synthesis.

    Concerns and next steps are ranked by exact-string frequency (top 5).
    Red flags are deduplicated in first-seen order (top 5). An LLM may later
    rewrite this text if one is available.
    """
    from collections import Counter

    def _top(values) -> list[str]:
        return [v for v, _ in Counter(values).most_common(5)]

    concerns = _top(c for o in opinions for c in (o.key_concerns or []))
    steps = _top(s for o in opinions for s in (o.recommended_next_steps or []))
    flags = list(dict.fromkeys(f for o in opinions for f in (o.red_flags or [])))

    lines = []
    if concerns:
        lines.append("Shared concerns: " + "; ".join(concerns))
    if steps:
        lines.append("Converged next steps: " + "; ".join(steps))
    if flags:
        lines.append("Red flags raised: " + "; ".join(flags[:5]))
    return "\n".join(lines) or "No consensus signal."


async def consult(
    twin,
    *,
    panel: Optional[list[str]] = None,
    question: str = "Synthesise your opinion on this case.",
) -> ConsultSpec:
    """Run a multi-specialist consult on the twin.

    Every specialty in *panel* is queried concurrently with the same
    serialised twin summary and *question*. The replies are normalised into
    ``SpecialistOpinion`` objects and summarised into a rule-based synthesis.

    Args:
        twin: The digital twin to review. If it is ``None`` or serialises to
            an empty summary, no specialist is queried.
        panel: Specialties to consult. Defaults to ``DEFAULT_PANEL``.
        question: Question appended to the case summary.

    Returns:
        A ``ConsultSpec`` with one opinion per panel member (in panel order),
        the synthesis text, and the panel used.
    """
    # Copy so the returned ConsultSpec never aliases DEFAULT_PANEL or the
    # caller's list; mutating spec.panel must not change the module default.
    panel = list(panel) if panel else list(DEFAULT_PANEL)
    summary = _serialize_twin_for_prompt(twin)
    if not summary:
        return ConsultSpec(
            opinions=[], synthesis="No twin context available.", panel=panel
        )

    prompt = f"{summary}\n\nQuestion: {question}\n\nReturn JSON."
    # return_exceptions=True: one failing specialist must not sink the panel.
    replies = await asyncio.gather(
        *(_ask_specialist(sp, prompt) for sp in panel),
        return_exceptions=True,
    )
    opinions = [_parse_opinion(sp, reply) for sp, reply in zip(panel, replies)]
    return ConsultSpec(
        opinions=opinions,
        synthesis=_synthesise(opinions),
        panel=panel,
    )