# timmers's picture
# GEMEO world-model — initial release (module + NeuralSurv ckpt + RareBench v49 + KG embeddings)
# 089d665 verified
"""Event-stream loader for GEMEO-CWM.
Reads the same DATASUS SIH/APAC/SIM JSONs that DT-FM-Joint trains on, but
emits cohort sequences with explicit (cohort_key, condition_id, time_zero)
metadata so the diffusion model can be CFG-conditioned on treatment status.
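The condition_id is the treatment-assignment code used for CFG:
0 = <NULL> (CFG dropout / unconditional)
1 = <NO_TX>: no drug observed in the trajectory, or a first drug that is
    not among the most frequent first drugs kept in the condition vocab
2-N = one id per frequent first-drug token (first 7 characters of the
    APAC procedure code), ordered by decreasing frequency
A trajectory's condition_id is set from the FIRST treatment record (with a
procedure code) in its chronologically sorted event stream, mimicking
real-world treatment assignment at time of first dispensation. Note that the
is_orphan_drug flag is not consulted when choosing that record.
This corresponds to the "treatment intention" arm of a target-trial emulation.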
"""
from __future__ import annotations
import json
import logging
import os
from collections import Counter, defaultdict
from dataclasses import dataclass
import torch
# Module logger used by load_events and build_cwm_dataset.
log = logging.getLogger("gemeo.cwm.data")
# Special tokens always present in the token vocabulary, even if unseen:
# <BOS>/<EOS> frame a cohort sequence, <SEP> ends each event, <PAD> is used
# by CWMDataset.to for padding, <UNK> is the encoding fallback, and
# <YEAR_BREAK> separates calendar years inside a cohort sequence.
SPECIAL = ["<PAD>", "<BOS>", "<EOS>", "<SEP>", "<UNK>", "<YEAR_BREAK>"]
def age_bucket(a):
    """Map an age in years to a coarse age-bucket token.

    Buckets are closed on the left and open on the right: anything below 1
    (including negative values) -> "age_0_1", [1, 2) -> "age_1_2", ...,
    [50, 70) -> "age_50_70", and everything else -> "age_70plus".
    A missing age (None) maps to "age_unk".
    """
    if a is None:
        return "age_unk"
    thresholds = (
        (1, "age_0_1"),
        (2, "age_1_2"),
        (5, "age_2_5"),
        (12, "age_5_12"),
        (18, "age_12_18"),
        (30, "age_18_30"),
        (50, "age_30_50"),
        (70, "age_50_70"),
    )
    for upper, label in thresholds:
        if a < upper:
            return label
    return "age_70plus"
def los_bucket(l):
    """Map a length of stay in days to a coarse LOS token, or None.

    <= 1 day -> "los_short", <= 7 -> "los_week", <= 30 -> "los_month",
    anything longer -> "los_long". A missing length of stay (None) returns
    None so the caller can skip the token entirely.
    """
    if l is None:
        return None
    for limit, label in ((1, "los_short"), (7, "los_week"), (30, "los_month")):
        if l <= limit:
            return label
    return "los_long"
def event_ym(r):
    """Return the ``(year, month)`` chronological key of an event record.

    Death records carrying a dash-separated ``date_of_death`` ("YYYY-MM[-...]")
    are keyed on that date. Every other record, and death records whose date
    is missing, has no dash, or does not parse as integers, falls back to the
    record's ``year``/``month`` fields.

    Missing *or null* year/month fields map to 0. Without this, a record with
    an explicit ``"year": null`` would yield None, and sorting it next to
    records with integer years would raise TypeError.
    """
    if r.get("type") == "death":
        d = r.get("date_of_death")
        if d and "-" in str(d):
            parts = str(d).split("-")
            try:
                # The "-" check above guarantees at least two parts.
                return (int(parts[0]), int(parts[1]))
            except ValueError:
                pass  # malformed date: fall back to the year/month fields
    return (r.get("year") or 0, r.get("month") or 0)
def cohort_key(r):
    """Return the synthetic-cohort key ``(orpha, uf_code, birth_bin, sex)``.

    The age is the first value that is not None among
    age_at_admission_years, age_at_authorization_years and
    age_at_death_years. An age of 0 is valid and is kept.

    birth_bin is ``year - int(age)``, floored to a multiple of 5. The record
    year defaults to 2020 when it is missing or falsy. A missing
    uf_code / sex falls back to "??" / "?".

    Returns None when the record has no age or no orpha code, so it cannot be
    assigned to a cohort.
    """
    age = None
    for field in ("age_at_admission_years",
                  "age_at_authorization_years",
                  "age_at_death_years"):
        # Compare against None explicitly: chaining with `or` would treat an
        # age of 0 as missing and silently drop every age-0 record.
        value = r.get(field)
        if value is not None:
            age = value
            break
    if age is None or r.get("orpha") is None:
        return None
    yr = r.get("year") or 2020
    birth = ((yr - int(age)) // 5) * 5
    return (r["orpha"], r.get("uf_code", "??"), birth, r.get("sex", "?"))
def event_to_tokens(r):
    """Render one event record as a flat list of vocabulary tokens.

    Layout by event type (bracketed tokens are emitted only when the source
    field is truthy):

    * admission: age bucket, "EV_ADM", [cid_<cid_princ>], [LOS bucket],
      [proc_<first 7 chars of primary_procedure>],
      "outcome_death" or "outcome_discharge"
    * treatment: age bucket, "EV_TX", [cid_<cid>],
      [drug_<first 7 chars of procedure_code>], ["ORPHAN_DRUG"]
    * death: age bucket, "EV_DEATH", [cid_<cause_cid | cid_princ | cid>]

    Every event ends with "<SEP>". A record of any other type yields only
    ["<SEP>"]. Raises KeyError if the record has no "type" key.
    """
    kind = r["type"]
    tokens = []
    if kind == "admission":
        tokens += [age_bucket(r.get("age_at_admission_years")), "EV_ADM"]
        diagnosis = r.get("cid_princ", "")
        if diagnosis:
            tokens.append(f"cid_{diagnosis}")
        stay = los_bucket(r.get("los_days"))
        if stay:
            tokens.append(stay)
        procedure = r.get("primary_procedure")
        if procedure:
            tokens.append(f"proc_{procedure[:7]}")
        died = r.get("death_during_stay")
        tokens.append("outcome_death" if died else "outcome_discharge")
    elif kind == "treatment":
        tokens += [age_bucket(r.get("age_at_authorization_years")), "EV_TX"]
        diagnosis = r.get("cid", "")
        if diagnosis:
            tokens.append(f"cid_{diagnosis}")
        drug = r.get("procedure_code")
        if drug:
            tokens.append(f"drug_{drug[:7]}")
        if r.get("is_orphan_drug"):
            tokens.append("ORPHAN_DRUG")
    elif kind == "death":
        tokens += [age_bucket(r.get("age_at_death_years")), "EV_DEATH"]
        diagnosis = r.get("cause_cid") or r.get("cid_princ") or r.get("cid", "")
        if diagnosis:
            tokens.append(f"cid_{diagnosis}")
    tokens.append("<SEP>")
    return tokens
def first_drug_token(events_for_cohort, *, orphan_only=False):
    """Return the ``drug_<code>`` token of the first treatment in a cohort.

    Scans events_for_cohort in the given order. The order is assumed to be
    chronological; build_cwm_dataset sorts by event_ym before calling.
    Returns ``"drug_" + procedure_code[:7]`` for the first record with
    type "treatment" and a truthy procedure_code. This is the same token
    event_to_tokens emits for that record. Returns None if no such record
    exists.

    By default the is_orphan_drug flag is NOT checked, so the first treatment
    of any kind wins. Pass orphan_only=True to consider only records whose
    is_orphan_drug flag is truthy.
    """
    for r in events_for_cohort:
        if r.get("type") != "treatment" or not r.get("procedure_code"):
            continue
        if orphan_only and not r.get("is_orphan_drug"):
            continue
        return f"drug_{r['procedure_code'][:7]}"
    return None
@dataclass
class CWMDataset:
    """Encoded cohort sequences plus their CFG treatment-condition labels.

    Instances are produced by build_cwm_dataset. Sequences are stored
    unpadded; call ``to(device)`` to get a padded ``(N, max_seq_len)`` batch.
    """

    sequences: list  # token-id lists, one per cohort, each len <= max_seq_len
    conditions: torch.Tensor  # long tensor, one condition id per sequence
    cohort_keys: list  # (orpha, uf, birth, sex) key per sequence
    tok2id: dict  # token string -> token id; must contain "<PAD>"
    vocab: list  # token id -> token string
    cond2id: dict  # condition string -> condition id
    cond_vocab: list  # condition id -> condition string
    max_seq_len: int  # length every sequence is right-padded to by to()

    def __len__(self):
        return len(self.sequences)

    def to(self, device):
        """Return ``(padded_token_ids, conditions)`` on ``device``.

        Each sequence is right-padded with the "<PAD>" id to max_seq_len,
        giving a ``(N, max_seq_len)`` long tensor. This holds even when N == 0:
        the result is reshaped, because ``torch.tensor([])`` alone would be
        1-D. A sequence longer than max_seq_len cannot be padded and makes
        torch.tensor raise.
        """
        pad_id = self.tok2id["<PAD>"]
        rows = [s + [pad_id] * (self.max_seq_len - len(s)) for s in self.sequences]
        seqs = torch.tensor(rows, dtype=torch.long, device=device)
        seqs = seqs.reshape(len(rows), self.max_seq_len)
        return seqs, self.conditions.to(device)
def load_events(sih_path=None, apac_path=None, sim_path=None):
    """Load DATASUS event records and tag each one with its event type.

    Each path is optional; a None or non-existent path is silently skipped.
    Records are tagged by source:
      - sih_path: type="admission"
      - apac_path: type="treatment"
      - sim_path: type="death"

    Death records without age_at_death_years get it copied from their "age"
    field, which is where cohort_key and event_to_tokens look for it.

    Assumes each file holds a JSON array of objects (TODO confirm against the
    exporters). Record dicts are mutated in place. They are returned as one
    flat list in file order (SIH, APAC, SIM).
    """
    events = []
    sources = ((sih_path, "admission"), (apac_path, "treatment"), (sim_path, "death"))
    for path, event_type in sources:
        if not path or not os.path.exists(path):
            continue
        # Context manager so the handle is closed; JSON is UTF-8 by spec, so
        # don't depend on the platform's locale encoding.
        with open(path, encoding="utf-8") as fh:
            recs = json.load(fh)
        for r in recs:
            r["type"] = event_type
            if event_type == "death" and r.get("age_at_death_years") is None:
                r["age_at_death_years"] = r.get("age")
        events.extend(recs)
        log.info("loaded %d %s events from %s", len(recs), event_type, path)
    return events
def build_cwm_dataset(events, max_seq_len=384, min_events=3,
                      year_filter=None, *, n_cond_drugs=30) -> CWMDataset:
    """Build the cohort-sequence dataset with treatment-condition labels for CFG.

    Args:
        events: event records as returned by load_events (each with a "type").
        max_seq_len: maximum token length of a cohort sequence. Longer
            sequences are truncated, which can drop the trailing "<EOS>".
        min_events: cohorts with fewer events than this are skipped.
        year_filter: optional collection of years. When given, only events
            whose event_ym year is in it are used.
        n_cond_drugs: how many of the most common first-drug tokens get their
            own condition id (default 30, as before; None keeps all). Rarer
            drugs fall back to "<NO_TX>".

    Returns:
        A CWMDataset. Its condition vocab is
        ["<NULL>", "<NO_TX>", <first drugs by decreasing frequency>...],
        so id 0 is the CFG null condition and id 1 means "no (kept) drug".
    """
    if year_filter is not None:
        # Materialize once: O(1) membership per event, and safe if the caller
        # passed a one-shot iterable such as a generator.
        years = set(year_filter)
        events = [e for e in events if event_ym(e)[0] in years]

    # Group events into synthetic cohorts; records without a key are dropped.
    by_cohort = defaultdict(list)
    for r in events:
        ck = cohort_key(r)
        if ck is not None:
            by_cohort[ck].append(r)

    # Render one token sequence per sufficiently large cohort.
    seqs, conds, keys = [], [], []
    drug_counter = Counter()
    for ck, recs in by_cohort.items():
        if len(recs) < min_events:
            continue
        recs.sort(key=event_ym)  # chronological; stable for same-month events
        orpha, uf, birth, sex = ck
        seq = ["<BOS>", f"orpha_{orpha}", f"uf_{uf}", f"sex_{sex}", f"birth_{birth}"]
        last_y = None
        for r in recs:
            y = event_ym(r)[0]
            if last_y is not None and y != last_y:
                seq.append("<YEAR_BREAK>")
            last_y = y
            seq.extend(event_to_tokens(r))
        seq.append("<EOS>")
        seqs.append(seq[:max_seq_len])
        # The treatment label is taken from the full (untruncated) history.
        first_drug = first_drug_token(recs)
        conds.append(first_drug)
        keys.append(ck)
        if first_drug:
            drug_counter[first_drug] += 1

    # Token vocab: special tokens plus every token seen, in sorted order.
    vocab = sorted(set(SPECIAL).union(*seqs))
    tok2id = {t: i for i, t in enumerate(vocab)}

    # Condition vocab: <NULL> (CFG dropout), <NO_TX>, then the most common
    # first drugs.
    cond_vocab = ["<NULL>", "<NO_TX>"] + [
        d for d, _ in drug_counter.most_common(n_cond_drugs)
    ]
    cond2id = {c: i for i, c in enumerate(cond_vocab)}
    no_tx = cond2id["<NO_TX>"]
    # None (no drug) and drugs outside the kept set are not keys, so both map
    # to <NO_TX>.
    cond_ids = torch.tensor([cond2id.get(c, no_tx) for c in conds], dtype=torch.long)

    # Encode tokens to ids. Every token is in vocab here; <UNK> is only a
    # safety net.
    unk = tok2id["<UNK>"]
    encoded = [[tok2id.get(t, unk) for t in s] for s in seqs]

    log.info("built %d cohort sequences over %d tokens, %d conditions (%s)",
             len(encoded), len(vocab), len(cond_vocab), drug_counter.most_common(5))
    return CWMDataset(
        sequences=encoded, conditions=cond_ids, cohort_keys=keys,
        tok2id=tok2id, vocab=vocab, cond2id=cond2id, cond_vocab=cond_vocab,
        max_seq_len=max_seq_len,
    )
def temporal_split(events, train_years, test_years):
    """Split events into ``(train, test)`` lists by their event_ym year.

    An event goes to train if its year is in train_years and to test if its
    year is in test_years. It can land in both (overlapping year sets) or in
    neither. Input order is preserved within each output list.
    """
    # Build the year sets once instead of once per event. This is cheaper, and
    # correct when a one-shot iterable (e.g. a generator) is passed in.
    train_set = set(train_years)
    test_set = set(test_years)
    train = [e for e in events if event_ym(e)[0] in train_set]
    test = [e for e in events if event_ym(e)[0] in test_set]
    return train, test