File size: 5,868 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""Multi-specialist consultation — orchestrate the existing swarm-py
agent swarm to produce a multidisciplinary opinion on the digital twin.

For rare diseases, a multidisciplinary team (MDT) is the gold standard.
We simulate one by routing the patient subgraph + summary through the
swarm-py agents tagged with each clinical specialty (cardio, neuro,
genetics, pediatrics, pharma, immuno, hepato, etc.) and synthesizing
their opinions into a structured `ConsultSpec`.

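Bootstrap path: reuses the existing swarm (`engine_v2`/`agent_coordinator`)
when it is importable; otherwise each specialist is answered by a structured
JSON-mode LLM call through `llm_router`.
Phase 2: dedicated specialist agents fine-tuned per discipline.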
"""
from __future__ import annotations
import asyncio
import logging
from typing import Optional

from .types import ConsultSpec, SpecialistOpinion

logger = logging.getLogger("gemeo.consult")


# Specialties consulted for a rare-disease case when the caller does not pass
# its own `panel=` to `consult()`. Opinions are reported in this order.
DEFAULT_PANEL = (
    "geneticist neurologist pediatrician "
    "immunologist cardiologist pharmacologist"
).split()


async def _ask_specialist(specialty: str, prompt: str) -> Optional[dict]:
    """Route *prompt* to the agent for *specialty* and return its opinion.

    First tries the swarm's native routing
    (``agent_coordinator.ask_specialist``). If that module is not importable,
    the call raises, or it returns something other than a dict, falls back to
    a structured JSON-mode LLM call via ``llm_router.call_llm``.

    Args:
        specialty: Specialty label (e.g. ``"geneticist"``); it is also
            interpolated into the fallback system prompt.
        prompt: Case summary plus the question to answer.

    Returns:
        The opinion as a dict. The fallback prompt asks for the keys
        ``opinion``, ``confidence``, ``key_concerns``,
        ``recommended_next_steps`` and ``red_flags_for_my_specialty``, but
        they are not validated here. Returns None when both routes fail or
        the fallback does not produce a JSON object.
    """
    import json

    # 1) Native swarm routing.
    ask_specialist = None
    try:
        from agent_coordinator import ask_specialist
    except ImportError:
        pass  # swarm not installed: go straight to the LLM fallback
    except Exception as e:
        logger.warning("agent_coordinator import failed, using LLM fallback: %s", e)

    if ask_specialist is not None:
        try:
            result = await ask_specialist(specialty=specialty, prompt=prompt)
        except Exception as e:
            # Best-effort: a swarm failure must not sink the consult, but it
            # should not disappear silently either.
            logger.warning(
                "swarm specialist %s failed, using LLM fallback: %s", specialty, e
            )
        else:
            if isinstance(result, dict):
                return result
            logger.warning(
                "swarm specialist %s returned %s, using LLM fallback",
                specialty, type(result).__name__,
            )

    # 2) Fallback: structured JSON-mode LLM call.
    system_prompt = (
        f"You are a senior {specialty}. Review the rare-disease case "
        "and produce a JSON object with keys: opinion (str), "
        "confidence (float 0..1), key_concerns (list[str]), "
        "recommended_next_steps (list[str]), "
        "red_flags_for_my_specialty (list[str])."
    )
    try:
        from llm_router import call_llm
        out = await call_llm(
            system=system_prompt, user=prompt, json_mode=True, max_tokens=600
        )
        data = out if isinstance(out, dict) else json.loads(out)
    except Exception as e:
        logger.warning("specialist %s failed: %s", specialty, e)
        return None

    # Valid JSON is not necessarily an object; callers rely on `.get`.
    if not isinstance(data, dict):
        logger.warning(
            "specialist %s returned non-object JSON (%s)",
            specialty, type(data).__name__,
        )
        return None
    return data


def _serialize_twin_for_prompt(twin) -> str:
    """Render *twin* as a compact, line-per-section summary for specialists.

    Each section is emitted only when the matching twin attribute is populated:
    - the top-3 diagnoses;
    - up to two reasoning paths;
    - the risk scores;
    - up to two trajectory horizons, one line each;
    - up to five drug candidates;
    - up to three next questions.

    Dict-shaped entries (diagnoses, path records, path steps, drug candidates)
    are read with ``.get`` and placeholder defaults. A partially populated
    entry therefore renders as ``?`` or ``p=0.00`` instead of raising
    ``KeyError`` or failing the ``:.2f`` format on ``None``.

    Returns:
        The newline-joined summary, or ``""`` when *twin* is None or no
        section is populated.
    """
    if twin is None:
        return ""

    parts: list[str] = []

    if twin.diagnoses:
        parts.append("Top diagnoses: " + "; ".join(
            f"{d.get('name', '?')} (ORPHA:{d.get('orpha', '?')}, "
            f"p={float(d.get('probability') or 0):.2f})"
            for d in twin.diagnoses[:3]
        ))

    if twin.subgraph and twin.subgraph.paths:
        parts.append("Reasoning paths: " + "; ".join(
            f"{p.get('target_name', '?')}: "
            + " → ".join(s.get('label', '?') for s in (p.get('steps') or []))
            for p in twin.subgraph.paths[:2]
        ))

    if twin.risk:
        parts.append(
            f"Risk: severity={twin.risk.overall_severity:.2f}, "
            f"progression={twin.risk.progression_risk:.2f}, "
            f"urgency={twin.risk.treatment_urgency:.2f}"
        )

    if twin.trajectory:
        parts.extend(
            f"T+{h.months}m: {h.state[:120]}" for h in twin.trajectory.horizons[:2]
        )

    if twin.drugs and twin.drugs.candidates:
        # `or "?"` also covers an explicit None name, which join() would reject.
        parts.append("Drug candidates: " + ", ".join(
            str(d.get("name") or "?") for d in twin.drugs.candidates[:5]
        ))

    if twin.next_questions:
        parts.append("Next questions: " + "; ".join(
            f"{q.name} ({q.hpo_id})" for q in twin.next_questions[:3]
        ))

    return "\n".join(parts)


def _coerce_confidence(value, default: float = 0.6) -> float:
    """Parse a model-reported confidence into the closed range [0, 1].

    Missing, None, NaN or non-numeric values map to *default*. An explicit
    ``0`` is kept; the previous ``value or 0.6`` idiom silently turned it into
    0.6. Out-of-range numbers are clamped.
    """
    import math

    if value is None:
        return default
    try:
        conf = float(value)
    except (TypeError, ValueError):
        return default
    if math.isnan(conf):
        return default
    return min(1.0, max(0.0, conf))


def _as_str_list(value, limit: int = 5) -> list[str]:
    """Coerce a model-supplied list field to at most *limit* non-empty strings.

    A bare string becomes a one-item list; ``list("abc")`` would otherwise
    split it into characters. None or any other non-list value becomes ``[]``.
    Items are stringified so later ``Counter`` and ``join`` calls are safe.
    """
    if isinstance(value, str):
        items = [value]
    elif isinstance(value, (list, tuple)):
        items = value
    else:
        return []
    cleaned = [str(v).strip() for v in items if v is not None]
    return [v for v in cleaned if v][:limit]


def _parse_opinion(specialty: str, raw) -> "SpecialistOpinion":
    """Convert one raw reply from ``asyncio.gather`` into a SpecialistOpinion.

    Exceptions (including ``BaseException`` subclasses such as
    ``CancelledError``), falsy values and non-dict replies become an
    ``"(unavailable)"`` placeholder with confidence 0.0.
    """
    if isinstance(raw, BaseException):
        logger.warning("specialist %s raised: %r", specialty, raw)
    if not isinstance(raw, dict) or not raw:
        return SpecialistOpinion(
            specialty=specialty, opinion="(unavailable)", confidence=0.0,
            key_concerns=[], recommended_next_steps=[], red_flags=[],
        )

    # The prompt asks for the specialty-scoped key; accept a plain
    # "red_flags" too.
    red_flags = raw.get("red_flags_for_my_specialty")
    if red_flags is None:
        red_flags = raw.get("red_flags")

    return SpecialistOpinion(
        specialty=specialty,
        opinion=str(raw.get("opinion") or "")[:600],
        confidence=_coerce_confidence(raw.get("confidence")),
        key_concerns=_as_str_list(raw.get("key_concerns")),
        recommended_next_steps=_as_str_list(raw.get("recommended_next_steps")),
        red_flags=_as_str_list(red_flags),
    )


def _synthesise(opinions: list) -> str:
    """Build a rule-based synthesis from parsed opinions.

    The result has up to three lines:
    - the five most frequently cited concerns;
    - the five most frequently recommended next steps;
    - the first five distinct red flags, in panel order.

    Returns ``"No consensus signal."`` when all three are empty.
    """
    from collections import Counter

    concerns = Counter(c for o in opinions for c in (o.key_concerns or []))
    steps = Counter(s for o in opinions for s in (o.recommended_next_steps or []))
    # dict.fromkeys de-duplicates while keeping first-seen order. A flag raised
    # by several specialists is then listed once and does not crowd out others
    # within the cap of five.
    flags = list(dict.fromkeys(f for o in opinions for f in (o.red_flags or [])))

    lines = []
    if concerns:
        lines.append(
            "Shared concerns: " + "; ".join(c for c, _ in concerns.most_common(5))
        )
    if steps:
        lines.append(
            "Converged next steps: " + "; ".join(s for s, _ in steps.most_common(5))
        )
    if flags:
        lines.append("Red flags raised: " + "; ".join(flags[:5]))
    return "\n".join(lines) or "No consensus signal."


async def consult(
    twin,
    *,
    panel: Optional[list[str]] = None,
    question: str = "Synthesise your opinion on this case.",
) -> ConsultSpec:
    """Run a multi-specialist consult on *twin* and synthesise the replies.

    Every specialty in the panel is queried concurrently with the same prompt
    (the twin summary plus *question*). Failed, empty or non-object replies
    become ``"(unavailable)"`` placeholders with confidence 0.0. ``opinions``
    therefore always has one entry per panel member, in panel order.

    Args:
        twin: The digital twin to review. If its summary is empty (including
            ``twin is None``), returns an empty consult without querying
            anyone.
        panel: Specialties to consult. A falsy or empty panel means
            ``DEFAULT_PANEL``. The panel is copied, so the returned spec
            never aliases the caller's list or the module default.
        question: Question appended to the case summary.

    Returns:
        A ConsultSpec with per-specialist opinions, a rule-based synthesis
        and the panel actually used.
    """
    # Materialise once. This also keeps a one-shot iterable from being
    # exhausted by the fan-out before zip() pairs replies with specialties.
    panel = list(panel or ()) or list(DEFAULT_PANEL)

    summary = _serialize_twin_for_prompt(twin)
    if not summary:
        return ConsultSpec(opinions=[], synthesis="No twin context available.", panel=panel)

    prompt = f"{summary}\n\nQuestion: {question}\n\nReturn JSON."

    replies = await asyncio.gather(
        *(_ask_specialist(sp, prompt) for sp in panel),
        return_exceptions=True,
    )
    parsed = [_parse_opinion(sp, raw) for sp, raw in zip(panel, replies)]

    return ConsultSpec(
        opinions=parsed,
        synthesis=_synthesise(parsed),
        panel=panel,
    )