# Adapting GEMEO to a new database
GEMEO is a **reference architecture**, not a single model. The flagship instance
(`gemeo-sus`) is trained on Brazilian SUS (DATASUS) data, but the whole point is
that any health system can instantiate its own `gemeo-<substrate>` on its own
data — **which never has to leave its environment**. This guide walks through it
end-to-end. Budget: a few hours of data wiring + ≈ 5 minutes of GPU time.
The only hard requirement: your data must be expressible as a stream of
`(subject_id, time, code, value)` events, i.e. the **MEDS v0.4.1** standard
(where `value` is carried as `numeric_value` and/or `text_value`).
If you can produce that, GEMEO trains on it.
---
## 0. Prerequisites
```bash
pip install "torch>=2.5" meds==0.4.1 pyarrow numpy scikit-learn
git clone https://github.com/rarasAI/gemeo && cd gemeo
```
You provide a list of patient records. Each record is a dict with
`patient_id`, optional static fields (sex, birth year, etc.), and an `events`
list. See `reference_impl/meds_export.py` for the exact shape we consume.
---
## 1. Map your data to MEDS v0.4.1
This is the only substrate-specific step. Convert each patient to MEDS rows
`(subject_id, time, code, numeric_value, text_value)`. Use a **namespaced code
vocabulary** so codes from different sources never collide:
| Domain | Convention (example) |
|---|---|
| Static (time = None) | `GENDER//F`, `RACE//…`, your-site region codes |
| Birth / death | `MEDS_BIRTH`, `MEDS_DEATH` (reserved) |
| Diagnoses | `ICD10//E84.0` (interop with OHDSI/Athena) |
| Procedures | `<YOURNS>//<code>` (e.g. `CPT//…`, `SIGTAP//…`) |
| Drugs | `<YOURNS>//<code>` (`numeric_value` = dose/cost if useful) |
| Visits | `Visit//IP`, `Visit//OP`, `Visit//ER` (CLMBR convention) |
| Disease anchor | `ORPHA//…` (rare disease) or your cohort label |
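
A minimal sketch of the mapping table above, assuming a hypothetical raw record layout (`sex`, `birth`, `visits`, `death` keys) and caller-chosen procedure/drug namespaces. It only illustrates the code conventions as plain MEDS-style row dicts; `to_meds_rows` and `meds_row` are illustrative names, not part of the reference code, and the input shape `export_to_meds` actually consumes is defined in `reference_impl/meds_export.py`. Static facts get `time=None` and sort first; all other rows are chronological.

```python
from datetime import datetime
from typing import Any, Optional


def meds_row(subject_id: int, time: Optional[datetime], code: str,
             numeric_value: Optional[float] = None,
             text_value: Optional[str] = None) -> dict:
    # One MEDS-style event row; static facts use time=None.
    return {"subject_id": subject_id, "time": time, "code": code,
            "numeric_value": numeric_value, "text_value": text_value}


def to_meds_rows(rec: dict[str, Any], proc_ns: str, drug_ns: str) -> list[dict]:
    """Map one HYPOTHETICAL raw record to namespaced MEDS-style rows."""
    sid = rec["patient_id"]
    rows = [meds_row(sid, None, f"GENDER//{rec['sex']}")]  # static fact
    if rec.get("birth") is not None:
        rows.append(meds_row(sid, rec["birth"], "MEDS_BIRTH"))  # reserved code
    for visit in rec.get("visits", []):
        t = visit["start"]
        rows.append(meds_row(sid, t, f"Visit//{visit['type']}"))  # IP / OP / ER
        rows += [meds_row(sid, t, f"ICD10//{dx}") for dx in visit.get("diagnoses", [])]
        rows += [meds_row(sid, t, f"{proc_ns}//{p}") for p in visit.get("procedures", [])]
        # Drugs: numeric_value carries the dose (or cost) when useful.
        rows += [meds_row(sid, t, f"{drug_ns}//{d}", numeric_value=dose)
                 for d, dose in visit.get("drugs", [])]
    if rec.get("death") is not None:
        rows.append(meds_row(sid, rec["death"], "MEDS_DEATH"))  # reserved code
    # Static rows (time=None) first, then chronological; sort is stable for ties.
    rows.sort(key=lambda r: (r["time"] is not None, r["time"] or datetime.min))
    return rows
```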
Then export:
```python
from reference_impl.meds_export import export_to_meds
export_to_meds(my_patients, out_dir="/data/meds_mysite", dataset_name="GEMEO-MYSITE")
```
This writes parquet shards + a `metadata/codes.parquet` + `dataset_metadata.json`,
validated against the official `meds.DataSchema`. **`meds_export.py` consumes an
already-built trajectory list — building that list from your raw EHR is your
site's ETL** (we keep ours private; yours stays yours).
> **Tokenization granularity matters.** Decide early how fine-grained your
> procedure/drug codes are (e.g. 7-digit vs 10-digit). The candidate space and
> all metrics depend on it — keep it fixed across train and eval.
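
One way to pin granularity, assuming your procedure/drug codes are hierarchical by prefix, is to truncate every code in a namespace to a fixed length and apply the same function in training and evaluation. `coarsen_code` and the `GRANULARITY` table are hypothetical names, and 7 characters for `SIGTAP` is only an example value; codes without `//` (such as `MEDS_BIRTH`) or in unlisted namespaces pass through unchanged.

```python
# Hypothetical per-namespace prefix lengths; choose once, reuse for train and eval.
GRANULARITY = {"SIGTAP": 7}


def coarsen_code(code: str, granularity: dict[str, int] = GRANULARITY) -> str:
    """Truncate 'NS//CODE' to a fixed prefix length for namespace NS."""
    ns, sep, value = code.partition("//")
    if not sep or ns not in granularity:
        return code  # reserved codes and unlisted namespaces are kept as-is
    return f"{ns}//{value[:granularity[ns]]}"
```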
---
## 2. Define the conditioning vocabulary
GEMEO conditions generation on a per-sequence **condition id** (`cond`) — in the
SUS instance this is derived from the patient's disease/cohort. Build a small
`cond_vocab` mapping your cohort labels (or treatment classes, for counterfactual
rollouts) to integer ids. Reserve **id 0 as the null/`<NULL>` condition** (used
for classifier-free guidance and condition-dropout at train time).
---
## 3. (Optional but recommended) Wire a knowledge graph
The gated cross-attention anchors latent state to a biomedical KG. We use
[PrimeKG](https://huggingface.co/datasets/mims-harvard/PrimeKG). For each cohort
anchor (e.g. each ORPHA/disease), precompute an **ego-subgraph**: the disease +
its top-k phenotypes + top-k genes, each encoded with a sentence embedding.
- For phenotype and Portuguese clinical text, our current recommended encoder is
**`Raras-AI/araras-hpo-brasil`** (a BioLORD-2023 finetune; use `…-int8` for edge deployment).
Swap in any sentence encoder for your language/domain.
- KG conditioning is **optional**: the model is trained with KG-dropout, so
`kg_raw=None` is a valid path. Start without it, add it if it helps your loss.
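
A sketch of the ego-subgraph step, assuming each disease edge comes with a relevance score you supply (PrimeKG edges are not ranked, so "top-k" needs some score of your choosing) and PrimeKG-style node type strings. `encode` is any function mapping a list of strings to embeddings; for instance, if `Raras-AI/araras-hpo-brasil` loads with sentence-transformers, `SentenceTransformer(...).encode` would fit (not verified here). Results are keyed by disease node name; how they are packed into `kg_raw` is defined by the reference trunk, not by this hypothetical helper.

```python
from collections import defaultdict
from typing import Callable, Iterable, Sequence

# PrimeKG-style node type strings (assumed; check against your KG export).
PHENO, GENE = "effect/phenotype", "gene/protein"


def _top_k(scored: list[tuple[float, str]], k: int) -> list[str]:
    # Highest score first; ties broken alphabetically for determinism.
    return [name for _, name in sorted(scored, key=lambda t: (-t[0], t[1]))[:k]]


def build_ego_subgraphs(
    edges: Iterable[tuple[str, str, str, float]],
    encode: Callable[[list[str]], Sequence],
    k_pheno: int = 5,
    k_gene: int = 5,
) -> dict[str, dict]:
    """edges: (disease, neighbor_type, neighbor_name, relevance_score) tuples."""
    groups = defaultdict(lambda: {PHENO: [], GENE: []})
    for disease, ntype, name, score in edges:
        if ntype in (PHENO, GENE):  # other node types are ignored here
            groups[disease][ntype].append((score, name))
    out = {}
    for disease, g in groups.items():
        nodes = [disease] + _top_k(g[PHENO], k_pheno) + _top_k(g[GENE], k_gene)
        out[disease] = {"nodes": nodes, "embeddings": encode(nodes)}
    return out
```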
---
## 4. Train (≈ 5 min on one H100)
The reference trunk is `reference_impl/diffusion_forcing_v13.py`
(`CDFv13Transformer` + `CDFv13Config`). Recipe that produced the flagship:
- Causal Diffusion Forcing: per-token σ ~ U(0,1); masked cross-entropy.
- **Recurrence-aware loss** (the headline objective): weight each token by
`w = max(λ**count, w_min)`, λ=0.25, w_min=0.02, where `count` = prior
occurrences of that token in the patient's history. First occurrences carry
full weight (`count = 0`, so `w = 1`); each repeat is down-weighted by a further
factor of λ until it hits the floor `w_min` (from the fourth occurrence on).
This is what makes the model predict *novel* events instead of echoing repeats.
Reference: `reproducers/modal_raven_v6.py` (the loop you copy into your trainer).
- 8 layers, d_model 384, context length 512, bf16, WSD (warmup-stable-decay) LR
schedule with a 5% / 80% / 15% split, 8 epochs.
- Warm-start from `Raras-AI/gemeo-sus` weights if your vocabulary overlaps, or
train from scratch (still only minutes at this scale).
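
The two fully specified parts of this recipe, the recurrence-aware weight and the 5/80/15 WSD schedule, can be sketched in plain Python. This is not the code in `reproducers/modal_raven_v6.py`: counting occurrences within the given sequence only, normalising the loss by the sum of active weights, and decaying the learning rate linearly to `min_lr` are all assumptions. `mask` marks whichever positions contribute to the loss in your trainer (non-padding and/or noised tokens).

```python
def recurrence_weights(codes: list[str], lam: float = 0.25, w_min: float = 0.02) -> list[float]:
    """w = max(lam**count, w_min), count = prior occurrences of the code in `codes`."""
    seen: dict[str, int] = {}
    weights = []
    for code in codes:
        count = seen.get(code, 0)
        weights.append(max(lam ** count, w_min))
        seen[code] = count + 1
    return weights


def weighted_masked_ce(nll: list[float], mask: list[int], weights: list[float]) -> float:
    """Recurrence-weighted mean of per-token NLL over positions where mask == 1."""
    num = sum(n * m * w for n, m, w in zip(nll, mask, weights))
    den = sum(m * w for m, w in zip(mask, weights))
    return num / den if den > 0 else 0.0


def wsd_lr(step: int, total_steps: int, peak_lr: float,
           warmup: float = 0.05, decay: float = 0.15, min_lr: float = 0.0) -> float:
    """Warmup-stable-decay: linear warmup, flat plateau, linear decay to min_lr."""
    warm_end = round(warmup * total_steps)
    decay_start = total_steps - round(decay * total_steps)
    if step < warm_end:
        return peak_lr * (step + 1) / warm_end
    if step < decay_start:
        return peak_lr
    frac = min(1.0, (step - decay_start) / max(1, total_steps - decay_start))
    return peak_lr + (min_lr - peak_lr) * frac
```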
```bash
# adapt the Modal reproducer to point at /data/meds_mysite, then:
modal run reproducers/modal_raven_v6.py
```
> **Scaling to a larger / multimodal substrate?** Turn on two upgrades that the
> code already ships, behind config flags:
>
> - `use_qk_norm=True`: per-head RMSNorm on Q and K before RoPE (Gemma 3-style),
> for attention stability at depth and scale.
> - `use_adaln=True`: AdaLN-Zero σ/condition conditioning (DiT/SD3-style),
> which gives the model far more conditioning capacity than additive embeddings.
>
> ```python
> from reference_impl.diffusion_forcing_v13 import CDFv13Config
> cfg = CDFv13Config(..., use_qk_norm=True, use_adaln=True)  # ... = your other config fields
> ```
>
> **Use these when your substrate is bigger and richer** (clinical notes,
> whole-exome sequencing (WES), labs, dense timing): that extra signal is what
> the added capacity is built to exploit. The released `gemeo-sus` flagship runs
> the lean ~20M-parameter additive trunk, deliberately right-sized for the
> structured-only SUS substrate. Enabling the upgrades trains a ~27M-parameter
> trunk in the same ≈ 5 min; the path is validated end-to-end (forward/backward
> pass plus a full training run).
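
To make the two flags concrete, here is a NumPy sketch of what they compute. It is not the `CDFv13Transformer` code: tensor layouts, `eps`, and feeding a per-token conditioning vector (natural when σ is per-token) are assumptions. In DiT each block modulates both its attention and MLP sublayers; this sketch wraps a single sublayer. The zero-initialised projection makes shift, scale and gate all zero at the start, so the block begins as the identity.

```python
import numpy as np


def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm over the last axis (here: the per-head dimension)."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps) * weight


def qk_norm(q: np.ndarray, k: np.ndarray, q_weight: np.ndarray, k_weight: np.ndarray):
    """Per-head RMSNorm on Q and K, shapes (batch, heads, seq, head_dim); RoPE comes after."""
    return rms_norm(q, q_weight), rms_norm(k, k_weight)


def layer_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Affine-free LayerNorm; AdaLN supplies scale and shift instead."""
    mu = x.mean(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(x.var(axis=-1, keepdims=True) + eps)


def silu(x: np.ndarray) -> np.ndarray:
    return x / (1.0 + np.exp(-x))


class AdaLNZero:
    """Wrap one residual sublayer: x + gate * f(LN(x) * (1 + scale) + shift)."""

    def __init__(self, d_model: int, d_cond: int):
        # Zero init -> shift = scale = gate = 0 -> identity block at start.
        self.w = np.zeros((d_cond, 3 * d_model))
        self.b = np.zeros(3 * d_model)

    def __call__(self, x: np.ndarray, cond: np.ndarray, sublayer) -> np.ndarray:
        # x: (batch, seq, d_model); cond: (batch, seq, d_cond), e.g. sigma + condition embedding.
        shift, scale, gate = np.split(silu(cond) @ self.w + self.b, 3, axis=-1)
        return x + gate * sublayer(layer_norm(x) * (1.0 + scale) + shift)
```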
---
## 5. Evaluate — and prove it with baselines
Two non-negotiables, both built in:
1. **Conformance:** `python architecture/gemeo_bench.py check <your_checkpoint>`
verifies your instance satisfies the GEMEO principles (per-token σ, gated KG,
MEDS substrate, recurrence-aware objective, …).
2. **Honest metrics:** evaluate on the autocorrelation-immune tasks and **always
report the count-based baselines on the same candidate space** (frequency,
bigram, repeat-last). A model that doesn't beat the bigram on 1-step
transitions isn't SOTA there; the world model's edge is **new-onset**
(first-occurrence) and **long-context** outcomes (discontinuation,
time-to-transition). See `benchmark/` and the RareBench-BR Trajectory tasks
as a template for building your own.
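
The three count-based baselines are cheap to reproduce. The sketch below is one plausible reading, not the `benchmark/` code: `frequency` ranks candidates by global training frequency, `bigram` by how often each code followed the patient's last event (ties fall back to global frequency), and `repeat_last` re-predicts the most recent distinct codes. Restricting the candidate space to codes absent from the history is an assumed way to frame a new-onset task; there, repeat-last cannot score by construction, which is why such tasks are immune to autocorrelation.

```python
from collections import Counter, defaultdict


def fit_count_baselines(train_seqs: list[list[str]]):
    """Global code frequencies and bigram (prev -> next) counts."""
    freq: Counter = Counter()
    bigram: defaultdict = defaultdict(Counter)
    for seq in train_seqs:
        freq.update(seq)
        for prev, nxt in zip(seq, seq[1:]):
            bigram[prev][nxt] += 1
    return freq, bigram


def candidate_space(all_codes: list[str], history: list[str], new_onset: bool = False) -> list[str]:
    """Full candidate set, or (new-onset) only codes never seen in this history."""
    seen = set(history)
    return [c for c in all_codes if not (new_onset and c in seen)]


def predict(baseline: str, history: list[str], candidates: list[str], k: int,
            freq: Counter, bigram: defaultdict) -> list[str]:
    if baseline == "frequency":
        return sorted(candidates, key=lambda c: (-freq[c], c))[:k]
    if baseline == "bigram":
        nxt = bigram.get(history[-1], Counter()) if history else Counter()
        return sorted(candidates, key=lambda c: (-nxt[c], -freq[c], c))[:k]
    if baseline == "repeat_last":
        allowed = set(candidates)
        recent = dict.fromkeys(reversed(history))  # distinct codes, most recent first
        return [c for c in recent if c in allowed][:k]
    raise ValueError(f"unknown baseline: {baseline}")


def hit_at_k(preds: list[str], target: str, k: int) -> bool:
    return target in preds[:k]
```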
EHRSHOT-style linear probes on the frozen representation (`reference_impl/eval_sota.py`)
are the cheapest way to measure long-context value on your data.
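
A minimal probe in that spirit, assuming you have already extracted one frozen embedding per patient (how to pool hidden states is your choice and is not specified here) and a binary label per task. It fits L2-regularised logistic regression, picks `C` on validation AUROC from a small hypothetical grid, and reports test AUROC; `reference_impl/eval_sota.py` remains the authoritative implementation.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score

Split = tuple[np.ndarray, np.ndarray]  # (embeddings, binary labels)


def linear_probe_auroc(train: Split, val: Split, test: Split,
                       cs: tuple[float, ...] = (0.01, 0.1, 1.0, 10.0)) -> tuple[float, float]:
    """Select C on validation AUROC, return (best C, test AUROC) for frozen embeddings."""
    (xtr, ytr), (xva, yva), (xte, yte) = train, val, test
    best = None
    for c in cs:
        clf = LogisticRegression(C=c, max_iter=1000).fit(xtr, ytr)
        auc = roc_auc_score(yva, clf.predict_proba(xva)[:, 1])
        if best is None or auc > best[0]:
            best = (auc, c, clf)
    _, best_c, best_clf = best
    return best_c, roc_auc_score(yte, best_clf.predict_proba(xte)[:, 1])
```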
---
## 6. Name and (optionally) share
Name your instance `gemeo-<substrate>` (e.g. `gemeo-mimic`, `gemeo-mayo`) and,
if you publish, submit a conformance report. Your **data and weights stay
yours**; only the recipe is shared. Open the recipe, keep the spice.
---
## Checklist
- [ ] Raw EHR → MEDS v0.4.1 events with a namespaced code vocabulary
- [ ] `export_to_meds(...)` runs and `gemeo_bench` validates the dataset
- [ ] `cond_vocab` defined, id 0 reserved as `<NULL>`
- [ ] (optional) KG ego-subgraphs precomputed per cohort anchor
- [ ] Trained with the recurrence-aware loss; warm-start or from scratch
- [ ] Evaluated on new-onset + long-context tasks **with count-based baselines**
- [ ] `gemeo_bench check` passes
Questions: dimas@raras.ai