"""Event-stream loader for GEMEO-CWM. Reads the same DATASUS SIH/APAC/SIM JSONs that DT-FM-Joint trains on, but emits cohort sequences with explicit (cohort_key, condition_id, time_zero) metadata so the diffusion model can be CFG-conditioned on treatment status. The condition_id is the treatment-assignment code used for CFG: 0 = null (CFG dropout / unconditional) 1 = no orphan-drug observed in trajectory 2-N = specific drug procedure code (one id per APAC drug subgroup) A trajectory's condition_id is set to the FIRST orphan drug observed, mimicking real-world treatment assignment at time of first dispensation. This is the Target-Trial-Emulation "treatment intention" arm. """ from __future__ import annotations import json import logging import os from collections import Counter, defaultdict from dataclasses import dataclass import torch log = logging.getLogger("gemeo.cwm.data") SPECIAL = ["", "", "", "", "", ""] def age_bucket(a): if a is None: return "age_unk" if a < 1: return "age_0_1" if a < 2: return "age_1_2" if a < 5: return "age_2_5" if a < 12: return "age_5_12" if a < 18: return "age_12_18" if a < 30: return "age_18_30" if a < 50: return "age_30_50" if a < 70: return "age_50_70" return "age_70plus" def los_bucket(l): if l is None: return None if l <= 1: return "los_short" if l <= 7: return "los_week" if l <= 30: return "los_month" return "los_long" def event_ym(r): if r.get("type") == "death": d = r.get("date_of_death") if d and "-" in str(d): p = str(d).split("-") return (int(p[0]), int(p[1]) if len(p) > 1 else 0) return (r.get("year", 0), r.get("month", 0)) def cohort_key(r): age = (r.get("age_at_admission_years") or r.get("age_at_authorization_years") or r.get("age_at_death_years")) if age is None or r.get("orpha") is None: return None yr = r.get("year") or 2020 birth = ((yr - int(age)) // 5) * 5 return (r["orpha"], r.get("uf_code", "??"), birth, r.get("sex", "?")) def event_to_tokens(r): out = [] if r["type"] == "admission": out.append(age_bucket(r.get("age_at_admission_years"))) out.append("EV_ADM") cid = r.get("cid_princ", "") if cid: out.append(f"cid_{cid}") lb = los_bucket(r.get("los_days")) if lb: out.append(lb) proc = r.get("primary_procedure") if proc: out.append(f"proc_{proc[:7]}") out.append("outcome_death" if r.get("death_during_stay") else "outcome_discharge") elif r["type"] == "treatment": out.append(age_bucket(r.get("age_at_authorization_years"))) out.append("EV_TX") cid = r.get("cid", "") if cid: out.append(f"cid_{cid}") proc = r.get("procedure_code") if proc: out.append(f"drug_{proc[:7]}") if r.get("is_orphan_drug"): out.append("ORPHAN_DRUG") elif r["type"] == "death": out.append(age_bucket(r.get("age_at_death_years"))) out.append("EV_DEATH") cid = (r.get("cause_cid") or r.get("cid_princ") or r.get("cid", "")) if cid: out.append(f"cid_{cid}") out.append("") return out def first_drug_token(events_for_cohort): """Find the FIRST orphan-drug token in the cohort's event stream.""" for r in events_for_cohort: if r.get("type") == "treatment" and r.get("procedure_code"): return f"drug_{r['procedure_code'][:7]}" return None @dataclass class CWMDataset: sequences: list # list of token-id sequences (each len <= max_seq_len) conditions: torch.Tensor # (N,) condition id per sequence cohort_keys: list # list of (orpha, uf, birth, sex) tok2id: dict vocab: list cond2id: dict cond_vocab: list max_seq_len: int def __len__(self): return len(self.sequences) def to(self, device): seqs = torch.tensor( [s + [self.tok2id[""]] * (self.max_seq_len - len(s)) for s in self.sequences], dtype=torch.long, device=device, ) return seqs, self.conditions.to(device) def load_events(sih_path=None, apac_path=None, sim_path=None): events = [] for path, t in [(sih_path, "admission"), (apac_path, "treatment"), (sim_path, "death")]: if path and os.path.exists(path): recs = json.load(open(path)) for r in recs: r["type"] = t if t == "death" and r.get("age_at_death_years") is None: r["age_at_death_years"] = r.get("age") events.extend(recs) log.info(f"loaded {len(recs)} {t} events from {path}") return events def build_cwm_dataset(events, max_seq_len=384, min_events=3, year_filter=None) -> CWMDataset: """Build cohort-sequence dataset with treatment-condition labels for CFG.""" if year_filter is not None: events = [e for e in events if event_ym(e)[0] in year_filter] by_cohort = defaultdict(list) for r in events: ck = cohort_key(r) if ck is not None: by_cohort[ck].append(r) # Build sequences seqs, conds, keys = [], [], [] drug_counter = Counter() for ck, recs in by_cohort.items(): if len(recs) < min_events: continue recs.sort(key=event_ym) orpha, uf, birth, sex = ck seq = ["", f"orpha_{orpha}", f"uf_{uf}", f"sex_{sex}", f"birth_{birth}"] last_y = None for r in recs: y = event_ym(r)[0] if last_y is not None and y != last_y: seq.append("") last_y = y seq.extend(event_to_tokens(r)) seq.append("") seq = seq[:max_seq_len] first_drug = first_drug_token(recs) seqs.append(seq) conds.append(first_drug) keys.append(ck) if first_drug: drug_counter[first_drug] += 1 # Build token vocab vocab_set = set(SPECIAL) for s in seqs: vocab_set.update(s) vocab = sorted(vocab_set) tok2id = {t: i for i, t in enumerate(vocab)} # Build condition vocab: "" + "" + top-30 drugs (others -> ) cond_vocab = ["", ""] + [d for d, _ in drug_counter.most_common(30)] cond2id = {c: i for i, c in enumerate(cond_vocab)} cond_ids = torch.tensor( [cond2id.get(c, cond2id[""]) if c else cond2id[""] for c in conds], dtype=torch.long, ) # Encode sequences encoded = [] for s in seqs: ids = [tok2id.get(t, tok2id[""]) for t in s] encoded.append(ids) log.info(f"built {len(encoded)} cohort sequences over {len(vocab)} tokens, " f"{len(cond_vocab)} conditions ({drug_counter.most_common(5)})") return CWMDataset( sequences=encoded, conditions=cond_ids, cohort_keys=keys, tok2id=tok2id, vocab=vocab, cond2id=cond2id, cond_vocab=cond_vocab, max_seq_len=max_seq_len, ) def temporal_split(events, train_years, test_years): train = [e for e in events if event_ym(e)[0] in set(train_years)] test = [e for e in events if event_ym(e)[0] in set(test_years)] return train, test