File size: 5,647 Bytes
089d665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
"""TwinWeaver-style event-stream serialization.

Following TwinWeaver (arXiv 2601.20906, Genie Digital Twin, 93k cancer
patients, MASE 0.87 vs 0.97 baseline), we serialize the patient's
longitudinal history as a chronological event tape that LLMs can
extend (next-event prediction).

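Format (illustrative; times are offsets from presentation at t=0, so past
events carry negative offsets):
    [t=-3y]  sign hpo:0001251 ataxia
    [t=-2y]  sign hpo:0001009 telangiectasia
    [t=0d]   case_opened — paciente M idade 5a UF SP
    [t=0d]   lab AFP=280 ng/mL [abnormal]
    [t=0d]   lab IgA=18 mg/dL [abnormal]
    [t=0d]   imaging RM: cerebellar atrophy
    [t=0d]   dx_suspected ORPHA:100 Ataxia-telangiectasia p=0.85

The tape is followed by one prediction prompt per forecast horizon, e.g.
    [t=+12m] ? <event_type> <details> conf=<0..1>
The LLM completes each prompt with the most likely event at that horizon
and a confidence in [0, 1].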
"""
from __future__ import annotations
from datetime import datetime
from typing import Optional


def _t(months_ago: float) -> str:
    if months_ago is None:
        return "[t=?]"
    if abs(months_ago) >= 12:
        return f"[t={months_ago/12:.0f}y]" if months_ago != int(months_ago) else f"[t={int(months_ago/12)}y]"
    if abs(months_ago) >= 1:
        return f"[t={int(months_ago)}m]"
    return f"[t={int(months_ago*30)}d]"


def _fetch_current_snapshot(twin):
    """Best-effort fetch of the twin's current PatientSpace snapshot (None on failure)."""
    import logging

    try:
        from . import core as gcore
        cached = gcore.get_gemeo(twin.case_id) if twin.case_id else None
        from patient_space import get_space
        # Prefer the case id of the cached GemeoTwin when one exists.
        case_id = cached.case_id if cached is not None else twin.case_id
        space = get_space(case_id) if case_id else None
        if not space or not hasattr(space, "get_current_snapshot"):
            return None
        return space.get_current_snapshot()
    except Exception:
        # Deliberately best-effort: a missing or failing PatientSpace backend
        # must not block serialization of the rest of the twin.
        logging.getLogger(__name__).debug(
            "snapshot lookup failed for case %r",
            getattr(twin, "case_id", None),
            exc_info=True,
        )
        return None


def _snapshot_events(snap) -> list[tuple[float, str]]:
    """Turn a snapshot's phenotypes, genes, labs and imaging into (month, line) events."""
    events: list[tuple[float, str]] = []

    phenotypes = (getattr(snap, "phenotypes", None) or [])[:30]
    # No real onset dates: assume phenotypes emerged during the 3 years before
    # presentation, ~2 months apart. With many phenotypes the (integer, >= 1)
    # spacing shrinks so that no past sign is placed at or after t=0.
    window = 36
    step = max(1, min(2, window // max(len(phenotypes), 1)))
    for i, p in enumerate(phenotypes):
        name = (p.get("name", "") or "")[:60]
        events.append((-window + i * step, f"sign {p.get('hpo_id', '?')} {name}"))

    for g in (getattr(snap, "genes", None) or [])[:10]:
        sym = g.get("symbol", "?")
        var = g.get("variant", "")
        suffix = f" {var}" if var else ""
        events.append((0, f"genetic_test {sym}{suffix} ({g.get('pathogenicity', '?')})"))

    for lab in (getattr(snap, "labs", None) or [])[:15]:
        test = lab.get("test", "?")
        val = lab.get("value", "?")
        unit = lab.get("unit", "")
        flag = " [abnormal]" if lab.get("abnormal") else ""
        events.append((0, f"lab {test}={val} {unit}{flag}"))

    for img in (getattr(snap, "imaging", None) or [])[:5]:
        mod = img.get("modality", "?")
        find = (img.get("finding", "") or "")[:80]
        events.append((0, f"imaging {mod}: {find}"))

    return events


def _diagnosis_events(twin) -> list[tuple[float, str]]:
    """Turn up to five of the twin's hypothesised diagnoses into t=0 events."""
    events: list[tuple[float, str]] = []
    for d in (twin.diagnoses or [])[:5]:
        name = d.get("name", d.get("disease", "?"))
        orpha = d.get("orpha", "?")
        # Missing/None/non-numeric probabilities render as 0.00 instead of
        # crashing the ``:.2f`` format below.
        try:
            prob = float(d.get("probability") or 0)
        except (TypeError, ValueError):
            prob = 0.0
        events.append((0, f"dx_suspected ORPHA:{orpha} {name} p={prob:.2f}"))
    return events


def serialize_twin_as_event_stream(twin, *, presentation_year: int = 2020,
                                    horizons_months: Optional[list[int]] = None) -> str:
    """Render the twin as a chronological event tape ending in prediction prompts.

    Events sit on a coarse timeline measured in months relative to
    presentation (t=0); negative offsets lie in the past. Real timestamps are
    unavailable, so phenotype onsets are synthesised within the 3 years
    before presentation and every other finding is pinned to t=0.

    Args:
        twin: The patient twin. This function reads ``extra`` (optional keys
            ``age``, ``sex``, ``sus_region``), ``case_id`` (to look up the
            PatientSpace snapshot) and ``diagnoses`` (list of dicts).
        presentation_year: Currently unused; kept for API compatibility.
        horizons_months: Forecast horizons in months; one prediction prompt
            line is emitted per horizon. Falsy values fall back to
            ``[12, 36, 72]``.

    Returns:
        The event-tape text followed by the per-horizon prompts, or ``""``
        when ``twin`` is None.
    """
    if twin is None:
        return ""
    horizons_months = horizons_months or [12, 36, 72]

    extra = twin.extra or {}
    age = extra.get("age")
    sex = extra.get("sex", "?")
    uf = extra.get("sus_region", "?")
    # Only a missing age is unknown; 0 is a valid age (e.g. a neonate).
    age_txt = "?" if age is None or age == "" else age

    events: list[tuple[float, str]] = [
        (0, f"case_opened — paciente {sex} idade {age_txt}a UF {uf}"),
    ]
    snap = _fetch_current_snapshot(twin)
    if snap is not None:
        events.extend(_snapshot_events(snap))
    events.extend(_diagnosis_events(twin))

    # list.sort is stable: events sharing a time keep their insertion order.
    events.sort(key=lambda ev: ev[0])

    lines = [f"  {_t(t):<8} {line}" for t, line in events]

    # Prediction prompt — one line per horizon for the model to complete.
    horizon_prompts = "\n".join(
        f"  [t=+{h}m] ? <event_type> <details> conf=<0..1>"
        for h in horizons_months
    )

    return (
        "## EVENT TAPE (TwinWeaver-style serialization)\n"
        "Each line: [time_from_now] event_type details. Predict next events at the prompts below.\n\n"
        + "\n".join(lines) +
        "\n\n## PREDICT (most likely events at each horizon):\n" + horizon_prompts
    )


def parse_predicted_events(llm_text: str) -> list[dict]:
    """Parse an LLM completion of the prediction prompts into structured events.

    A recognised line has the shape produced by the prompt template,
    ``[t=+<H>m] [?] <event_type> <details> [conf=<x>]``. The leading ``?``
    copied from the template is optional. Lines that do not match are
    ignored, and each match is confined to a single line.

    Args:
        llm_text: Raw completion text (empty or None yields ``[]``).

    Returns:
        A list of dicts, in order of appearance, with these keys:
        - ``horizon_months`` (int);
        - ``event_type`` (lower-cased str);
        - ``details`` (stripped str);
        - ``confidence`` (float; 0.5 when the value is absent or unparseable).
    """
    import re

    out: list[dict] = []
    if not llm_text:
        return out
    # `[ \t]` instead of `\s` keeps each match on one line; with `\s` a short
    # line could absorb the following line as its details.
    pattern = re.compile(
        r"\[t=\+(\d+)m\][ \t]*(?:\?[ \t]*)?"   # horizon tag, optional echoed '?'
        r"([a-z_]+)[ \t]+(.+?)"                # event type, details
        r"(?:[ \t]+conf=([\d.]+))?[ \t\r]*$",  # optional confidence, trailing ws/CR
        re.IGNORECASE | re.MULTILINE,
    )
    for m in pattern.finditer(llm_text):
        conf_txt = m.group(4)
        try:
            confidence = float(conf_txt) if conf_txt else 0.5
        except ValueError:  # e.g. "conf=0.7." or "conf=1.2.3": keep the event
            confidence = 0.5
        out.append({
            "horizon_months": int(m.group(1)),
            "event_type": m.group(2).lower(),
            "details": m.group(3).strip(),
            "confidence": confidence,
        })
    return out