| """Event-stream loader for GEMEO-CWM. |
| |
| Reads the same DATASUS SIH/APAC/SIM JSONs that DT-FM-Joint trains on, but |
| emits cohort sequences with explicit (cohort_key, condition_id, time_zero) |
| metadata so the diffusion model can be CFG-conditioned on treatment status. |
| |
| The condition_id is the treatment-assignment code used for CFG: |
| 0 = null (CFG dropout / unconditional) |
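    1 = no orphan-drug observed in trajectory (also used when the first drug
        is too rare to get its own id)
    2-N = specific drug, keyed by the first 7 characters of its procedure
          code (only the most frequent first drugs get their own id)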
| |
| A trajectory's condition_id is set to the FIRST orphan drug observed, |
| mimicking real-world treatment assignment at time of first dispensation. |
| This is the Target-Trial-Emulation "treatment intention" arm. |
| """ |
| from __future__ import annotations |
| import json |
| import logging |
| import os |
| from collections import Counter, defaultdict |
| from dataclasses import dataclass |
|
|
| import torch |
|
|
log = logging.getLogger("gemeo.cwm.data")


# Reserved tokens that build_cwm_dataset always puts in the vocabulary,
# whether or not they occur in any sequence.
SPECIAL = [
    "<PAD>",         # right-padding filler used by CWMDataset.to
    "<BOS>",         # first token of every cohort sequence
    "<EOS>",         # last token of every cohort sequence (lost on truncation)
    "<SEP>",         # terminates each event's token run
    "<UNK>",         # fallback id for tokens missing from the vocabulary
    "<YEAR_BREAK>",  # inserted between consecutive events of different years
]
|
|
|
|
def age_bucket(a):
    """Map an age in years to a coarse age-band token.

    Bands are half-open ``[lower, upper)``: the first band whose upper bound
    ``a`` is strictly below wins. ``None`` maps to ``"age_unk"``; any value
    not below 70 maps to ``"age_70plus"``.
    """
    if a is None:
        return "age_unk"
    bands = (
        (1, "age_0_1"),
        (2, "age_1_2"),
        (5, "age_2_5"),
        (12, "age_5_12"),
        (18, "age_12_18"),
        (30, "age_18_30"),
        (50, "age_30_50"),
        (70, "age_50_70"),
    )
    for upper, label in bands:
        if a < upper:
            return label
    return "age_70plus"
|
|
|
|
def los_bucket(l):
    """Map a length of stay in days to a coarse LOS token.

    Bounds are inclusive (``<= 1`` short, ``<= 7`` week, ``<= 30`` month,
    otherwise long). Returns ``None`` when the length of stay is unknown.
    """
    if l is None:
        return None
    for upper, label in ((1, "los_short"), (7, "los_week"), (30, "los_month")):
        if l <= upper:
            return label
    return "los_long"
|
|
|
|
def event_ym(r):
    """Return the ``(year, month)`` pair used to order an event in time.

    Death events take year and month from ``date_of_death`` when it is a
    ``"YYYY-MM[-DD]"`` string. Otherwise, including when that date cannot be
    parsed, and for every other event type, the ``year``/``month`` fields are
    used. Missing or null fields become 0, so the key never contains ``None``
    and stays comparable when used with ``sort``.
    """
    if r.get("type") == "death":
        d = r.get("date_of_death")
        if d and "-" in str(d):
            # Containing "-" guarantees split() yields at least two parts.
            parts = str(d).split("-")
            try:
                return (int(parts[0]), int(parts[1]))
            except ValueError:
                pass  # malformed date: fall back to the year/month fields
    # `or 0` (not a .get default) so explicit nulls in the JSON also become 0.
    return (r.get("year") or 0, r.get("month") or 0)
|
|
|
|
def cohort_key(r):
    """Return the ``(orpha, uf_code, birth_cohort, sex)`` key of an event.

    The age is the first of ``age_at_admission_years``,
    ``age_at_authorization_years`` and ``age_at_death_years`` that is not
    ``None``. An age of 0 (infants) is a valid value and is not skipped. The
    birth cohort is ``year - int(age)`` floored to a multiple of 5. Events
    with no (or a null) ``year`` are treated as 2020. A missing ``uf_code``
    becomes ``"??"`` and a missing ``sex`` becomes ``"?"``.

    Returns:
        The key tuple, or ``None`` when no age is available or the event has
        no ``orpha`` code.
    """
    age = None
    for field in ("age_at_admission_years",
                  "age_at_authorization_years",
                  "age_at_death_years"):
        value = r.get(field)
        if value is not None:
            age = value
            break
    if age is None or r.get("orpha") is None:
        return None
    yr = r.get("year") or 2020
    birth = ((yr - int(age)) // 5) * 5
    return (r["orpha"], r.get("uf_code", "??"), birth, r.get("sex", "?"))
|
|
|
|
def event_to_tokens(r):
    """Serialize one event record into tokens, always ending with ``<SEP>``.

    Token layout by ``r["type"]``. Bracketed tokens are emitted only when the
    source field is truthy, and procedure codes keep their first 7 characters:

    * ``"admission"``: age band, ``EV_ADM``, [``cid_*``], [``los_*``],
      [``proc_*``], then ``outcome_death`` or ``outcome_discharge``
    * ``"treatment"``: age band, ``EV_TX``, [``cid_*``], [``drug_*``],
      [``ORPHAN_DRUG``]
    * ``"death"``: age band, ``EV_DEATH``, [``cid_*``] taken from
      ``cause_cid``, else ``cid_princ``, else ``cid``
    * any other type: no event tokens, only ``<SEP>``

    Raises:
        KeyError: if ``r`` has no ``"type"`` key.
    """
    kind = r["type"]
    tokens = []
    if kind == "admission":
        tokens.extend((age_bucket(r.get("age_at_admission_years")), "EV_ADM"))
        diagnosis = r.get("cid_princ", "")
        if diagnosis:
            tokens.append(f"cid_{diagnosis}")
        stay = los_bucket(r.get("los_days"))
        if stay:
            tokens.append(stay)
        procedure = r.get("primary_procedure")
        if procedure:
            tokens.append(f"proc_{procedure[:7]}")
        if r.get("death_during_stay"):
            tokens.append("outcome_death")
        else:
            tokens.append("outcome_discharge")
    elif kind == "treatment":
        tokens.extend((age_bucket(r.get("age_at_authorization_years")), "EV_TX"))
        diagnosis = r.get("cid", "")
        if diagnosis:
            tokens.append(f"cid_{diagnosis}")
        procedure = r.get("procedure_code")
        if procedure:
            tokens.append(f"drug_{procedure[:7]}")
        if r.get("is_orphan_drug"):
            tokens.append("ORPHAN_DRUG")
    elif kind == "death":
        tokens.extend((age_bucket(r.get("age_at_death_years")), "EV_DEATH"))
        diagnosis = r.get("cause_cid") or r.get("cid_princ") or r.get("cid", "")
        if diagnosis:
            tokens.append(f"cid_{diagnosis}")
    tokens.append("<SEP>")
    return tokens
|
|
|
|
def first_drug_token(events_for_cohort, *, orphan_only=True):
    """Return the drug token of the FIRST orphan-drug dispensation of a cohort.

    Events are scanned in the given order. The caller is expected to have
    sorted them chronologically, as ``build_cwm_dataset`` does. The result is
    ``"drug_" + procedure_code[:7]``, the same token ``event_to_tokens`` emits
    for that event.

    Args:
        events_for_cohort: iterable of event dicts.
        orphan_only: when True (default), only treatment events flagged
            ``is_orphan_drug`` qualify. This is the documented
            treatment-assignment rule. When False, any treatment event with
            a procedure code qualifies.

    Returns:
        The drug token, or ``None`` if no qualifying treatment event exists.
    """
    for r in events_for_cohort:
        if r.get("type") != "treatment" or not r.get("procedure_code"):
            continue
        if orphan_only and not r.get("is_orphan_drug"):
            continue  # non-orphan APAC procedure: not a treatment assignment
        return f"drug_{r['procedure_code'][:7]}"
    return None
|
|
|
|
@dataclass
class CWMDataset:
    """Integer-encoded cohort sequences plus their CFG treatment-condition ids.

    Sequences are stored unpadded as lists of token ids. ``to`` right-pads
    them with the ``<PAD>`` id to ``max_seq_len`` and stacks them into one
    tensor.
    """

    sequences: list           # token-id lists, expected to be <= max_seq_len long
    conditions: torch.Tensor  # one condition id (index into cond_vocab) per sequence
    cohort_keys: list         # cohort key per sequence, e.g. (orpha, uf, birth, sex)
    tok2id: dict              # token string -> id; must contain "<PAD>"
    vocab: list               # id -> token string
    cond2id: dict             # condition string -> id
    cond_vocab: list          # id -> condition string
    max_seq_len: int          # padded width produced by ``to``

    def __len__(self):
        return len(self.sequences)

    def to(self, device):
        """Pad, stack and move the dataset to ``device``.

        Returns:
            ``(seqs, conditions)``. ``seqs`` is a ``torch.long`` tensor of
            shape ``(len(self), max_seq_len)``, right-padded with the
            ``<PAD>`` id; it keeps that 2-D shape even when the dataset is
            empty. ``conditions`` is ``self.conditions`` moved to ``device``.

        Raises:
            ValueError: if any sequence is longer than ``max_seq_len``.
        """
        pad_id = self.tok2id["<PAD>"]
        width = self.max_seq_len
        too_long = [i for i, s in enumerate(self.sequences) if len(s) > width]
        if too_long:
            raise ValueError(
                f"{len(too_long)} sequence(s) exceed max_seq_len={width} "
                f"(first at index {too_long[0]})"
            )
        rows = [s + [pad_id] * (width - len(s)) for s in self.sequences]
        # reshape keeps an empty dataset at (0, width) instead of torch's (0,).
        seqs = torch.tensor(rows, dtype=torch.long, device=device).reshape(len(rows), width)
        return seqs, self.conditions.to(device)
|
|
|
|
def load_events(sih_path=None, apac_path=None, sim_path=None):
    """Load SIH admissions, APAC treatments and SIM deaths into one event list.

    Each path is optional. Paths that are ``None`` or do not exist are
    skipped silently. Each file must hold a JSON array of objects. Every
    record is tagged in place with ``type`` (``"admission"``, ``"treatment"``
    or ``"death"``). Death records whose ``age_at_death_years`` is missing or
    null get it copied from their ``age`` field.

    Returns:
        One list with all SIH records, then APAC, then SIM.
    """
    events = []
    sources = ((sih_path, "admission"), (apac_path, "treatment"), (sim_path, "death"))
    for path, t in sources:
        if not path or not os.path.exists(path):
            continue
        # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
        # The context manager closes the handle (the old bare open() leaked it).
        with open(path, encoding="utf-8") as fh:
            recs = json.load(fh)
        for r in recs:
            r["type"] = t
            if t == "death" and r.get("age_at_death_years") is None:
                r["age_at_death_years"] = r.get("age")
        events.extend(recs)
        log.info("loaded %d %s events from %s", len(recs), t, path)
    return events
|
|
|
|
def _serialize_cohort(ck, recs, max_seq_len):
    """Turn one cohort's chronologically sorted events into a token list.

    Layout: ``<BOS>``, the static cohort tokens (orpha/uf/sex/birth), each
    event's tokens with ``<YEAR_BREAK>`` between events from different years,
    then ``<EOS>``. The list is truncated to ``max_seq_len``, which drops the
    trailing ``<EOS>`` for over-long cohorts.
    """
    orpha, uf, birth, sex = ck
    seq = ["<BOS>", f"orpha_{orpha}", f"uf_{uf}", f"sex_{sex}", f"birth_{birth}"]
    last_y = None
    for r in recs:
        y = event_ym(r)[0]
        if last_y is not None and y != last_y:
            seq.append("<YEAR_BREAK>")
        last_y = y
        seq.extend(event_to_tokens(r))
    seq.append("<EOS>")
    return seq[:max_seq_len]


def build_cwm_dataset(events, max_seq_len=384, min_events=3,
                      year_filter=None, *, max_drug_conditions=30) -> CWMDataset:
    """Build the cohort-sequence dataset with treatment-condition labels for CFG.

    Args:
        events: event dicts as returned by ``load_events``.
        max_seq_len: cohort sequences are truncated to this many tokens.
        min_events: cohorts with fewer events are skipped.
        year_filter: optional collection of years. When given, only events
            whose ``event_ym`` year is in it are used.
        max_drug_conditions: how many of the most frequent first-drug tokens
            get their own condition id (default 30). ``None`` keeps all of
            them. Cohorts whose first drug falls outside this set are
            labelled ``<NO_TX>``.

    Returns:
        A ``CWMDataset`` with unpadded encoded sequences. Condition ids are
        0 = ``<NULL>`` (CFG dropout), 1 = ``<NO_TX>``, and 2.. = the kept
        drug tokens in descending frequency.
    """
    if year_filter is not None:
        years = set(year_filter)  # build once instead of scanning per event
        events = [e for e in events if event_ym(e)[0] in years]

    # Group events by cohort; events without a usable cohort key are dropped.
    by_cohort = defaultdict(list)
    for r in events:
        ck = cohort_key(r)
        if ck is not None:
            by_cohort[ck].append(r)

    seqs, conds, keys = [], [], []
    drug_counter = Counter()
    for ck, recs in by_cohort.items():
        if len(recs) < min_events:
            continue
        # Chronological order, so first_drug_token sees the first dispensation.
        recs.sort(key=event_ym)
        seqs.append(_serialize_cohort(ck, recs, max_seq_len))
        # Condition is derived from the full trajectory, even if truncation
        # removed the drug token from the sequence itself.
        first_drug = first_drug_token(recs)
        conds.append(first_drug)
        keys.append(ck)
        if first_drug:
            drug_counter[first_drug] += 1

    # Token vocabulary: special tokens plus every token seen, sorted for determinism.
    vocab = sorted(set(SPECIAL).union(*seqs))
    tok2id = {t: i for i, t in enumerate(vocab)}

    cond_vocab = ["<NULL>", "<NO_TX>"] + [
        d for d, _ in drug_counter.most_common(max_drug_conditions)
    ]
    cond2id = {c: i for i, c in enumerate(cond_vocab)}
    no_tx = cond2id["<NO_TX>"]
    # NOTE(review): treated cohorts whose first drug is not kept are merged
    # into the untreated <NO_TX> arm; surface how many so it is not silent.
    n_collapsed = sum(1 for c in conds if c and c not in cond2id)
    if n_collapsed:
        log.warning("%d treated cohorts have a first drug outside the %d kept "
                    "drug conditions and are labelled <NO_TX>",
                    n_collapsed, len(cond_vocab) - 2)
    # None (no treatment) is never a key in cond2id, so it also maps to <NO_TX>.
    cond_ids = torch.tensor([cond2id.get(c, no_tx) for c in conds], dtype=torch.long)

    unk = tok2id["<UNK>"]
    encoded = [[tok2id.get(t, unk) for t in s] for s in seqs]

    log.info("built %d cohort sequences over %d tokens, %d conditions (%s)",
             len(encoded), len(vocab), len(cond_vocab), drug_counter.most_common(5))
    return CWMDataset(
        sequences=encoded, conditions=cond_ids, cohort_keys=keys,
        tok2id=tok2id, vocab=vocab, cond2id=cond2id, cond_vocab=cond_vocab,
        max_seq_len=max_seq_len,
    )
|
|
|
|
def temporal_split(events, train_years, test_years):
    """Split events into ``(train, test)`` lists by their ``event_ym`` year.

    An event goes to ``train`` if its year is in ``train_years`` and to
    ``test`` if it is in ``test_years``. Overlapping year sets put it in
    both, and events from other years are dropped. Input order is preserved.
    ``events`` and both year collections may be any iterables, including
    one-shot generators.
    """
    # Build the year sets once. Rebuilding them per event was O(n*m) and,
    # for generator arguments, left empty sets after the first event.
    train_set, test_set = set(train_years), set(test_years)
    train, test = [], []
    # One pass, so a generator of events feeds both outputs.
    for e in events:
        y = event_ym(e)[0]
        if y in train_set:
            train.append(e)
        if y in test_set:
            test.append(e)
    return train, test
|
|