gemeo-twin-stack / src /gemeo /event_stream.py
timmers's picture
GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
089d665 verified
"""TwinWeaver-style event-stream serialization.
Following TwinWeaver (arXiv 2601.20906, Genie Digital Twin, 93k cancer
patients, MASE 0.87 vs 0.97 baseline), we serialize the patient's
longitudinal history as a chronological event tape that LLMs can
extend (next-event prediction).
Format:
[t=0d] case_opened — Menino 5a M SP
[t=2y] onset hpo:0001251 ataxia
[t=4y] sign hpo:0001009 telangiectasia
[t=4y] lab AFP=280 ng/mL [abnormal]
[t=4y] lab IgA=18 mg/dL [abnormal]
[t=4y] imaging RM cerebellar atrophy
[t=4y] dx_suspected ORPHA:100 p=0.85
[t=?] ?
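The trailing `?` is the prediction prompt — the LLM completes it with
the most likely next event, age at occurrence, and confidence.
`serialize_twin_as_event_stream` emits one such prompt per requested
horizon, written as `[t=+<months>m] ? <event_type> <details> conf=<0..1>`.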
"""
from __future__ import annotations
from datetime import datetime
from typing import Optional
def _t(months_ago: float) -> str:
if months_ago is None:
return "[t=?]"
if abs(months_ago) >= 12:
return f"[t={months_ago/12:.0f}y]" if months_ago != int(months_ago) else f"[t={int(months_ago/12)}y]"
if abs(months_ago) >= 1:
return f"[t={int(months_ago)}m]"
return f"[t={int(months_ago*30)}d]"
def _current_snapshot_for(twin):
    """Best-effort lookup of the current patient-space snapshot for *twin*.

    Returns ``None`` whenever the lookup cannot be completed: no case id, the
    ``core`` / ``patient_space`` modules cannot be imported, the space has no
    ``get_current_snapshot`` method, or any exception is raised on the way.
    """
    try:
        from . import core as gcore

        case_id = twin.case_id
        cached = gcore.get_gemeo(case_id) if case_id else None

        from patient_space import get_space

        if cached is not None:
            # Prefer the id carried by the cached twin.
            space = get_space(cached.case_id)
        else:
            space = get_space(case_id) if case_id else None
        if space and hasattr(space, "get_current_snapshot"):
            return space.get_current_snapshot()
        return None
    except Exception:
        # Deliberate best effort: without a snapshot the tape still carries
        # the demographics anchor and the hypothesised diagnoses.
        return None


def _snapshot_events(snap) -> list[tuple[float, str]]:
    """Turn a snapshot's phenotypes, genes, labs and imaging into tape events.

    The snapshot carries no real timestamps here, so phenotypes get synthetic
    onset offsets and every other finding is placed at t=0. Collections that
    are ``None`` are treated as empty.
    """
    events: list[tuple[float, str]] = []

    # Phenotypes: heuristic onset starting 3 years before t=0, one sign every
    # 2 months. Clamp at 0 so that an observed sign is never dated *after*
    # presentation (the unclamped offset turns positive from the 20th sign on).
    onset_offset = -3 * 12
    for i, p in enumerate((snap.phenotypes or [])[:30]):
        t = min(0, onset_offset + i * 2)
        name = str(p.get("name") or "")[:60]
        events.append((t, f"sign {p.get('hpo_id', '?')} {name}"))

    for g in (snap.genes or [])[:10]:
        sym = g.get("symbol", "?")
        var = g.get("variant", "")
        variant_part = f" {var}" if var else ""
        events.append(
            (0, f"genetic_test {sym}{variant_part} ({g.get('pathogenicity', '?')})")
        )

    for lab in (snap.labs or [])[:15]:
        test = lab.get("test", "?")
        val = lab.get("value", "?")
        unit = lab.get("unit", "")
        flag = " [abnormal]" if lab.get("abnormal") else ""
        events.append((0, f"lab {test}={val} {unit}{flag}"))

    for img in (snap.imaging or [])[:5]:
        mod = img.get("modality", "?")
        find = str(img.get("finding") or "")[:80]
        events.append((0, f"imaging {mod}: {find}"))

    return events


def serialize_twin_as_event_stream(twin, *, presentation_year: int = 2020,
                                   horizons_months: Optional[list[int]] = None) -> str:
    """Render the twin as a chronological event tape ending in `?` prompts.

    The tape contains a demographics anchor at t=0, events synthesised from
    the current snapshot (when one can be fetched), and up to five suspected
    diagnoses. It is followed by one prediction prompt per horizon.

    Args:
        twin: Twin object exposing ``extra``, ``case_id`` and ``diagnoses``;
            ``None`` yields an empty string.
        presentation_year: Currently unused; kept for interface compatibility.
        horizons_months: Prediction horizons in months. ``None`` or an empty
            list falls back to ``[12, 36, 72]``.

    Returns:
        The formatted tape and prompt section as a single string.
    """
    if twin is None:
        return ""
    horizons = horizons_months or [12, 36, 72]

    # Demographics anchor. Age 0 (a newborn) is a valid age, so only a missing
    # or empty value is rendered as "?".
    extra = twin.extra or {}
    age = extra.get("age")
    age_txt = "?" if age is None or age == "" else age
    sex = extra.get("sex", "?")
    uf = extra.get("sus_region", "?")
    events: list[tuple[float, str]] = [
        (0, f"case_opened — paciente {sex} idade {age_txt}a UF {uf}")
    ]

    snap = _current_snapshot_for(twin)
    if snap is not None:
        events.extend(_snapshot_events(snap))

    # Hypothesised diagnoses
    for d in (twin.diagnoses or [])[:5]:
        name = d.get("name", d.get("disease", "?"))
        orpha = d.get("orpha", "?")
        try:
            prob = float(d.get("probability", 0))
        except (TypeError, ValueError):
            # A missing / None / non-numeric probability must not abort the tape.
            prob = 0.0
        events.append((0, f"dx_suspected ORPHA:{orpha} {name} p={prob:.2f}"))

    # Stable sort: events sharing a timestamp keep their insertion order.
    events.sort(key=lambda ev: ev[0])
    lines = [f" {_t(t):<8} {line}" for t, line in events]

    # Prediction prompt — for each horizon, ask the model to complete
    horizon_prompts = "\n".join(
        f" [t=+{h}m] ? <event_type> <details> conf=<0..1>"
        for h in horizons
    )
    return (
        "## EVENT TAPE (TwinWeaver-style serialization)\n"
        "Each line: [time_from_now] event_type details. Predict next events at the prompts below.\n\n"
        + "\n".join(lines) +
        "\n\n## PREDICT (most likely events at each horizon):\n" + horizon_prompts
    )
def parse_predicted_events(llm_text: str) -> list[dict]:
    """Parse an LLM completion of the event prompt format into structured events.

    Each line is examined on its own for the shape
    ``[t=+<months>m] [?] <event_type> [<details>] [conf=<number>]``.
    Lines that do not match are ignored.

    Args:
        llm_text: Raw completion text; ``None`` or empty yields ``[]``.

    Returns:
        One dict per parsed line with keys ``horizon_months`` (int),
        ``event_type`` (lower-cased str), ``details`` (str, possibly empty)
        and ``confidence`` (float, 0.5 when no ``conf=`` value is present).
    """
    if not llm_text:
        return []
    import re

    pattern = re.compile(
        # Horizon tag, then an optional echoed "?" prompt marker.
        r"\[t=\+(\d+)m\]\s*(?:\?\s*)?([a-z_]+)"
        # Details are optional and tried *last* (lazy "??") so that a line
        # such as "death conf=0.1" yields a confidence, not details.
        r"(?:\s+(.+?))??"
        # A well-formed number only; one trailing punctuation mark tolerated
        # so that "conf=0.85." no longer discards the whole event.
        r"(?:\s+conf=(\d+(?:\.\d+)?|\.\d+)[.,;]?)?\s*$",
        re.IGNORECASE,
    )

    events = []
    # Line-by-line matching keeps one event from swallowing the next line and
    # makes CRLF input behave like LF input.
    for line in llm_text.splitlines():
        m = pattern.search(line)
        if m is None:
            continue
        horizon, event_type, details, conf = m.groups()
        try:
            events.append({
                "horizon_months": int(horizon),
                "event_type": event_type.lower(),
                "details": (details or "").strip(),
                "confidence": float(conf) if conf else 0.5,
            })
        except ValueError:
            continue
    return events