diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8bbdc19fd8f93a267b9cfc108146ddae8abd030e --- /dev/null +++ b/.gitignore @@ -0,0 +1,15 @@ +# Python-generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# Virtual environments +.venv +.env + +.cache +.claude/ +.agents/ diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/configs/base.yaml b/configs/base.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1164464210b8eaf7c732bebb5ae14e8eb1a34d3c --- /dev/null +++ b/configs/base.yaml @@ -0,0 +1,27 @@ +# Shared training config for Baseline A, B, C, and E3 PhysioJEPA. +# Only `run_name` and `model` differ across runs — everything else stays identical +# so K2 (E3 AUROC > Baseline B + 0.02) is a single-variable comparison. + +run_name: base +model: F # override per-run: A|B|C|F +epochs: 100 +batch_size: 64 +lr: 1.0e-4 +weight_decay: 0.04 +warmup_epochs: 10 +ema_start: 0.996 +ema_end: 0.9999 +ema_warmup_frac: 0.30 +grad_clip: 1.0 +log_every: 100 +ckpt_every_epochs: 5 +seed: 42 +wandb_project: physiojepa +wandb_mode: online +wandb_entity: null +output_dir: runs +index_path: cache/mimic_index.json +shard_roots: [] # filled per-environment (populated by prepare_data.py) +num_workers: 4 +amp: true +log_uniform_frac: 0.6 diff --git a/docs/ARCHITECTURES_EXPLORATION.md b/docs/ARCHITECTURES_EXPLORATION.md new file mode 100644 index 0000000000000000000000000000000000000000..bddcd551de6418f1741a574eb9113ceb4f0f3ed2 --- /dev/null +++ b/docs/ARCHITECTURES_EXPLORATION.md @@ -0,0 +1,185 @@ +# PhysioJEPA architecture landscape +*Oz Labs — April 2026* +*Revision 2: post-reviewer critique. 
Replaces cardio_jepa_architectures.md* + +--- + +## Change log from revision 1 + +- All "CausalCardio-JEPA" → "PhysioJEPA" (Architecture F) +- v1 architecture clarified: raw PPG patches, EMA only, no morphological encoding, no SIGReg, no cardiac phase encoding in first run +- Ablation structure added to Architecture F entry +- Execution order updated to cross-reference experiment matrix +- Architecture descriptions retain full detail; only framing corrected + +--- + +## Prior work — precisely characterised + +### Weimann & Conrad — ECG-JEPA (2410.13867) +The direct baseline. I-JEPA adapted for 12-lead ECG: 2D patch tokenisation (leads × time), multi-block contiguous masking, EMA target encoder, L1 latent prediction. Pretrained on 1M+ records, AUC 0.945 on PTB-XL all-statements. Open source at `github.com/kweimann/ECG-JEPA` — our starting codebase. **Limitation**: unimodal, no PPG, no temporal dynamics beyond the 10s window. Static representation learner, not a world model. + +### Kim — CroPA-ECG-JEPA (2410.08559) +Introduces Cross-Pattern Attention (CroPA): masked attention enforcing inter-lead dependencies. Recovers HR and QRS duration from frozen representations. **Lesson**: clinical inductive bias (inter-lead relationships) improves cardiac JEPA. Directly motivates cardiac phase encoding in Architecture F ablation A2. + +### Khadka et al. — EEG-VJEPA (2507.03633) +Treats multi-channel EEG as 3D spatiotemporal tensor, applies V-JEPA tube masking. 85.8% accuracy on TUH Abnormal EEG, UMAP clusters showing pathological separation. **Lesson**: V-JEPA tube masking transfers to physiological signals; the "signal as video" reframe works. + +### Zhou et al. — Brain-JEPA (NeurIPS 2024) +Brain Gradient Positioning as domain-specific positional encoding derived from fMRI connectivity gradients. **Lesson for us**: cardiac phase encoding (P/QRS/ST/T) is the cardiac analog. 
Botman reviewer raised the valid concern that hard phase boundaries fail during AF — soft Gaussian encoding over landmarks is the fix if we pursue ablation A2. + +### Wang et al. — EchoJEPA (2602.02603) +V-JEPA 2 on 18M echocardiograms; JEPA degrades only 2% under physics-informed acoustic perturbations vs 17% for VideoMAE. **Key lesson**: JEPA's noise rejection is the primary advantage for medical signals. Directly motivates our choice of JEPA over MAE for ICU PPG data. + +### Balestriero & LeCun — LeJEPA (2511.08544) +Proves isotropic Gaussian is optimal JEPA embedding; introduces SIGReg (Sketched Isotropic Gaussian Regularisation) with linear complexity. Eliminates EMA entirely. **Position in our work**: ablation A3 tests SIGReg vs EMA on cardiac signals. Botman/Laya raised the concern that SIGReg may over-regularise impulsive ECG transients (QRS complexes) — mitigation is applying SIGReg only to pooled global representation, not per-patch tokens. + +### Botman et al. — Laya (2603.16281) +First LeJEPA application to EEG at scale; latent prediction outperforms reconstruction on clinical tasks. Found that SIGReg with aggressive λ causes instability on signals with large amplitude transients — recommended lower λ and gradient clipping. **Direct prior** for our ablation A3. + +### Nie et al. — AnyPPG (2511.01747) +Symmetric InfoNCE on 100k+ hours of synchronized ECG-PPG. Critical detail: the ECGFounder encoder is *frozen* during AnyPPG training — ECG acts as a fixed supervisory signal, not a jointly-learned representation. R@1=0.736 for PPG→ECG retrieval. **This is the primary baseline we differentiate from.** AnyPPG answers "what is the cardiovascular state now?" — PhysioJEPA asks "how does it evolve?" + +### Assran et al. — V-JEPA 2-AC (2506.09985) +Two-stage: action-free pretraining then action-conditioned post-training on 62 hours of robot trajectories. Zero-shot robot planning via MPC in latent space. 
**Architectural template for Architecture D** (intervention-conditioned, future work). Not part of PhysioJEPA v1. + +### Wu, Lei et al. — SurgMotion (2602.05638) +V-JEPA 2 on surgical video, rejects smoke/specular artifacts. **Lesson**: the JEPA noise-rejection property generalises across medical imaging modalities. + +--- + +## Five architectures + two novel extensions + +### Architecture A — Temporal ECG-JEPA + +**What it is**: ECG-JEPA (Weimann) extended from spatial masking to temporal future prediction. Context window t→t+T, target is t+T→t+2T. Single modality, no PPG. + +**Use case**: Baseline A in the experiment matrix. Also the fallback if K2 fails — this is still publishable as an extension of Weimann & Conrad. + +**Novelty**: low. The masking axis change is surgical. The temporal rollout evaluation (AF onset prediction via latent trajectory deviation score) is new but not a strong paper claim. + +**Estimated performance**: should match or slightly exceed ECG-JEPA on static tasks; advantage only on temporal tasks. + +--- + +### Architecture B — Symmetric cross-modal JEPA (Δt=0) + +**What it is**: Dual encoder (ECG + PPG), cross-attention predictor, but Δt is fixed to 0. ECG context predicts PPG at the same time. + +**Use case**: Baseline B in the experiment matrix — the controlled comparison that isolates whether Δt matters. + +**Novelty**: low. This is essentially JEPA-flavoured AnyPPG without the frozen encoder constraint. + +**Why it exists**: without this baseline, K2 cannot be answered. It must run in parallel with PhysioJEPA from Day 4. + +--- + +### Architecture C — LeJEPA cardiac (SIGReg, cross-modal) + +**What it is**: Architecture PhysioJEPA but replacing EMA with SIGReg. Clean theoretical foundation from Balestriero & LeCun. + +**When to run**: ablation A3 after E3 passes K2. + +**Key risk**: SIGReg enforces isotropic Gaussian geometry globally. ECG signals have highly anisotropic spectral structure (dominant QRS transients). 
Mitigation: apply SIGReg only to pooled global representations, not per-patch tokens. λ sweep: [0.001, 0.01, 0.05, 0.1]. + +**What it contributes**: if SIGReg outperforms EMA, it provides a cleaner theoretical story and removes the τ schedule hyperparameter. If it doesn't, EMA stays and LeJEPA is cited as related work. + +--- + +### Architecture D — Intervention-conditioned cardiac world model (V-JEPA 2-AC for ICU) + +**What it is**: PhysioJEPA as Stage 1 pretraining; freeze encoder; Stage 2 post-trains action-conditioned predictor on clinical intervention tokens from MIMIC-IV (vasopressors, fluid boluses, ventilator changes). + +**When to run**: future work. Requires MIMIC-IV waveform→medication timestamp alignment, which is a separate data engineering project. Not part of the 15-day experiment matrix. + +**Why it matters**: this is the highest-impact clinical application. A world model that simulates "what happens to haemodynamics if I give norepinephrine now?" has direct ICU decision support utility. The V-JEPA 2-AC paper demonstrated the two-stage recipe works with only 62 hours of interaction data — we have years of MIMIC-IV ICU data. + +**Prerequisite**: PhysioJEPA Stage 1 must work (K2 passes) before investing in Stage 2. + +--- + +### Architecture E — Hierarchical cardiac JEPA (H-JEPA, dual timescale) + +**What it is**: Two JEPA levels. Fast encoder (beat-level, ~1s): predicts next-beat ECG/PPG from current beat. Slow encoder (episode-level, 5min): predicts episode summary from sequence of fast latents. Slow predictor conditions fast predictor. + +**When to run**: medium-term. Requires significant training complexity management. + +**Unique capability**: the only architecture that captures both beat-to-beat variability (HRV) and autonomic tone evolution over minutes. AF onset prediction benefits from both scales. + +**Risk**: two-level training can develop gradient imbalance. 
Curriculum (train fast encoder first, activate slow encoder after convergence) is necessary. + +--- + +### Architecture F — PhysioJEPA (directional asymmetric time-offset JEPA) + +**This is the paper.** + +**Core innovation**: ECG context predicts PPG morphology at a *variable* time offset Δt, encoding the directional temporal structure of the cardiovascular causal chain (electrical activation → mechanical contraction → peripheral perfusion) that symmetric contrastive methods destroy. + +**Why not "causal" JEPA**: the architecture encodes a physiological asymmetry, not causal inference in the interventional sense. Calling it "causal" would invite statistical causality reviewers to reject on framing alone. "Directional" or "asymmetric" is accurate and defensible. + +#### v1 architecture (minimal, runs in experiment matrix) + +See Section 2 of `RESEARCH_DEVELOPMENT.md` for full specification (revised 2026-04-14 post-E0). In brief: +- ECG encoder (ViT-S, 1D over single lead II @ 250 Hz, 50-sample / 200 ms patches) +- PPG target encoder (ViT-T, raw 25-sample / 200 ms patches @ 125 Hz, EMA) +- Cross-attention predictor conditioned on Δt embedding +- EMA collapse prevention (no SIGReg in v1) +- Loss: L1 cross-modal prediction + 0.3 × L1 ECG self-prediction +- Δt sampling: 60% log-uniform [50ms, 500ms], 40% ground-truth PTT + +#### What makes v1 different from Baseline B + +One thing only: **Δt > 0 vs Δt = 0**. The experiment matrix is designed so that this single variable is isolated. Everything else (encoder architecture, predictor, loss) is identical between E3 and Baseline B. + +#### Ablations (run after K2 passes) + +| # | Change | Tests | +|---|--------|-------| +| A1 | Morphological PPG tokens instead of raw patches | Does structured PPG encoding improve latent? | +| A2 | Cardiac phase PE (soft Gaussian over landmarks) | Does phase-aware PE beat standard sinusoidal? | +| A3 | SIGReg instead of EMA | Is SIGReg more stable on impulsive cardiac signals? 
| +| A4 | PTT regression head in training loop (γ=0.1) | Does supervised PTT improve vascular encoding? | +| A5 | Curriculum Δt (ground-truth first, then random) | Does Δt schedule matter? | + +#### PTT as validation, not contribution + +The PTT regression probe (E5a in the experiment matrix) tests whether the learned Δt structure is physiologically meaningful — not whether we invented a new way to measure PTT. PTT by peak detection is a 10-line script. The contribution is that a model trained without any PTT labels implicitly encodes PTT in its latent geometry. That is the evidence for claim 3 in the hypothesis. + +--- + +### Architecture G — Variational PhysioJEPA (uncertainty-aware clinical planning) + +**What it is**: Extend Architecture D with a variational predictor. Instead of a single future latent, predict μ and σ over future latents. At inference, sample K rollout trajectories and select action sequences that minimise expected goal distance weighted by uncertainty. + +**Why it matters clinically**: ICU decisions require not just a prediction but a confidence signal. A world model that signals high σ for a septic patient with unstable haemodynamics tells the clinician that standard vasopressor protocols may not apply. + +**Connection to Ha & Schmidhuber (2018)**: G is the modern JEPA equivalent of their VAE + MDN world model — JEPA encoder replaces VAE, variational predictor with Gaussian output replaces MDN, intervention token replaces random action. + +**When to pursue**: after Architecture D is validated. This is a two-paper arc: D then G. 
+ +--- + +## Recommended execution order + +``` +Now (weeks 1–2): Architecture F v1 (PhysioJEPA) + Baselines A, B, C (experiment matrix E2) + Decision gate at K2 + +Weeks 3–4: Ablations A1–A5 (if K2 passes) + +Months 2–3: Architecture D (if MIMIC-IV data join succeeds) + Architecture E (if Zack has bandwidth) + +Future: Architecture G (after D validated) + Architecture A as ablation/fallback paper +``` + +The architecture document serves as a reference map. The experiment matrix is the operational guide. Everything that is not in the experiment matrix is future work. + +--- + +*Document revision 2 — April 2026* +*Architecture F is now named PhysioJEPA throughout.* +*Execution order cross-references the experiment matrix in `EXPERIMENT_TRACKING.md`.* \ No newline at end of file diff --git a/docs/EXPERIMENT_TRACKING.md b/docs/EXPERIMENT_TRACKING.md new file mode 100644 index 0000000000000000000000000000000000000000..4a39f137578d6f60d79ca2439c5f63ce869ce3b6 --- /dev/null +++ b/docs/EXPERIMENT_TRACKING.md @@ -0,0 +1,554 @@ +# PhysioJEPA — Minimal Experiment Matrix +*Oz Labs — April 2026* +*Revision 2: post-reviewer critique. All "CausalCardio-JEPA" references replaced.* + +--- + +## The single question this matrix answers + +> Does predicting PPG at Δt from ECG produce better cardiovascular representations +> than aligning ECG and PPG at t=0? + +Every experiment below either answers this question or gates the next one. +Nothing else runs until K2 is resolved. 
+ +--- + +## Experiment map overview + +``` +Day 1–2 E0: Data audit → Go/No-go on dataset + │ + ▼ +Day 3 E1: Morphology vs raw → Choose PPG encoding, once, forever + │ + ▼ +Day 4–5 E2: Baselines A+B+C → Establish floor and ceiling + │ + ▼ +Day 6–8 E3: Δt-JEPA v1 → Core claim test (K1, K2, K3) + │ + ├── FAIL → exit + │ + ▼ +Day 9–10 E4: Rollout coherence → World model validation + │ + ▼ +Day 11–12 E5: PTT probe → Downstream validation + │ + ▼ +Day 13–14 E6: Ablation Δt=0 vs Δt>0 → Isolate the single variable + │ + ▼ +Day 15 Decision: paper or pivot +``` + +--- + +## E0 — Data audit +**Days 1–2 | Prerequisite for everything** + +### What to run + +```python +import datasets +ds = datasets.load_dataset("lucky9-cyou/mimic-iv-aligned-ppg-ecg") + +# For each record, compute: +# 1. ECG-PPG alignment tolerance +alignment_error_ms = [] +for record in ds: + r_peak_ts = detect_r_peaks(record['ecg']) + ppg_peak_ts = detect_ppg_peaks(record['ppg']) + ptt = align_peaks(r_peak_ts, ppg_peak_ts) + alignment_error_ms.append(ptt_variability(ptt)) + +# 2. 
Coverage +n_patients = len(set(record['subject_id'] for record in ds)) +total_hours = sum(record['duration'] for record in ds) / 3600 +missing_pct = mean_missing_rate(ds) +``` + +### Pass criteria — ALL must be true + +| Metric | Pass | Fail action | +|--------|------|-------------| +| Median alignment ≤ 50ms | ✓ proceed | Pivot to PhysioNet BIDMC | +| PTT within-patient std ≤ 80ms | ✓ proceed | Same pivot | +| Patients ≥ 500 | ✓ proceed | Supplement with PhysioNet MIMIC-III waveforms | +| Missing rate ≤ 20% after windowing | ✓ proceed | Tighten quality filter | +| PTT range [50ms, 500ms] physiologically plausible | ✓ proceed | Check synchronisation method | + +### Output +- `data_card.md`: patients, hours, alignment stats, missing rates +- `ptt_histogram.png`: histogram of measured PTT per patient +- Go/no-go decision logged in `experiments/e0_decision.md` + +**If E0 fails**: PhysioNet BIDMC (ECG + PPG, documented 0.1ms alignment, 53 subjects — smaller but clean). All downstream experiments are identical; only scale changes. + +--- + +## E1 — Morphology vs raw PPG patches +**Day 3 | One-time architectural decision** + +### What to run + +Two target encoders, same ViT-S backbone, 10% of data, 20 epochs each: + +**E1a — Raw patch encoder** +- PPG windowed into 200ms patches (25 samples at 125Hz) +- Linear projection → d=256 tokens +- Standard I-JEPA spatial masking within window + +**E1b — Morphological encoder** +- Per-beat features: systolic peak height, diastolic notch depth, pulse width, upstroke slope, augmentation index +- Extracted via Bishop & Ercole peak detection + `scipy.signal` +- Linear projection → d=256 tokens per beat + +### Metrics to compare + +| Metric | What it tests | +|--------|--------------| +| % beats with valid morphology extraction | Is E1b viable on this dataset? 
| +| Target encoder latent variance | Stability (collapse check) | +| Linear probe AUROC on AF (frozen, 100 AF / 100 normal) | Representation quality | +| MAE of PTT regression from frozen encoder | Vascular information content | + +### Decision rule (made once, frozen) + +``` +if morphology_extraction_rate < 0.70: + USE raw patches (E1a) + +elif E1b linear_probe_AUROC > E1a + 0.02: + USE morphological (E1b) + +else: + USE raw patches (E1a) — simpler, fewer failure modes +``` + +### Output +- `e1_decision.md`: which encoder, exact threshold used, quality stats +- `ppg_encoder.py`: the chosen implementation, committed to repo + +--- + +## E2 — Baseline suite +**Days 4–5 | Floor and ceiling** + +Run all three in parallel. Same data split, same 20 epochs, same evaluation harness. +These are reference points for E3, not ablations. + +### AF label source — decide before running E2 + +**Decision required by**: Day 3 (before baselines start training) +**Owner**: Zack + +**Option 1 — MIMIC-IV ECG module (preferred)** +Join `mimic-iv-ecg` rhythm annotations to the aligned waveform dataset by `subject_id` + `hadm_id`. +- Pros: in-distribution, same patient population as training data +- Cons: requires verifying the join yields enough AF-positive patients (need ≥100 AF, ≥100 normal for the linear probe to be meaningful) +- Check: `SELECT count(*) FROM mimic-iv-ecg WHERE rhythm = 'atrial fibrillation'` on the HF mirror + +**Option 2 — PTB-XL (fallback)** +Use PTB-XL rhythm labels as the AF evaluation benchmark. +- Pros: clean, well-labelled, already used by Weimann & Conrad (enables direct comparison) +- Cons: different population (German outpatient vs MIMIC ICU) — becomes a generalisation test, not in-distribution +- Note: framing in paper changes slightly to "transfer to PTB-XL" rather than "in-distribution evaluation" + +**Option 3 — PhysioNet AFDB** +MIT-BIH AF Database: 25 long-term ECG recordings with AF annotations. 
+- Only if Options 1 and 2 both fail +- Very small; only useful for AUROC, not for sample efficiency curves + +**Decision log**: +``` +AF_LABEL_SOURCE = "" # fill in before Day 4 +DECISION_DATE = "" +DECISION_BY = "" +N_AF_POSITIVE = 0 # verify after join/filter +N_AF_NEGATIVE = 0 +``` + +### Baseline A — ECG-JEPA (Weimann & Conrad exact replication) +```python +# Fork: github.com/kweimann/ECG-JEPA +# Config: ViT-S/8, multi-block masking, EMA τ=0.996 +# Input: ECG only (no PPG at all) +# Loss: standard I-JEPA L1 latent prediction (within ECG) +``` +This is the unimodal ceiling. If our model can't match this on ECG-only tasks, something is wrong with the cross-modal architecture. + +### Baseline B — Symmetric cross-modal JEPA (Δt = 0) +```python +# Architecture: identical to E3 in every detail +# EXCEPT: Δt is hardcoded to 0 +# - context: ECG window at time t +# - target: PPG window at the SAME time t (no lag) +# - predictor: cross-attention ECG → PPG +# Loss: L1 latent prediction +``` +This isolates the Δt variable. If E3 beats B on the same tasks, Δt matters. If not, the core claim fails. + +### Baseline C — InfoNCE contrastive (AnyPPG-style) +```python +# Architecture: same dual encoder +# Loss: symmetric InfoNCE +# z_ecg = ecg_encoder(ECG_t) +# z_ppg = ppg_encoder(PPG_t) +# L = InfoNCE(z_ecg, z_ppg, temperature=0.07) +# No Δt, no prediction — pure alignment +``` +This is the comparison against the dominant paradigm in the field. + +### Metrics for all three + +``` +After 20 epochs on 10% data, for each model: + +1. Pretraining loss convergence curve +2. Linear probe AUROC — AF detection (frozen encoder) +3. Linear probe R² — HR estimation (frozen encoder) +4. Latent variance + eigenspectrum rank (collapse check) +5. 
UMAP: coloured by patient ID, AF status, HR decile +``` + +### What to learn from E2 before running E3 + +| Observation | Implication | +|-------------|-------------| +| Baseline A AUROC > 0.80 | ECG alone is strong; cross-modal has a high bar | +| Baseline B collapses | Symmetric cross-modal JEPA is unstable; add SIGReg to E3 from the start | +| Baseline C > Baseline A | Cross-modal information helps; our model has something to beat | +| All three collapse | Data quality problem — revisit E0 | + +--- + +## E3 — Δt-JEPA v1 +**Days 6–8 | The paper test** + +Minimal version of the actual contribution. +PPG encoding from E1 decision. No SIGReg. No cardiac phase encoding. +Just: ECG context predicts PPG target at t+Δt. + +### Architecture + +```python +# ECG encoder: ViT-S/8, 2D patches (leads × time), EMA target +# PPG encoder: ViT-S/8, encoding chosen in E1, EMA target +# Predictor: 4-layer cross-attention transformer +# query = positional tokens for target PPG beats +# key/val = ECG context latents + Δt embedding +# Δt embed: sinusoidal over [50ms, 500ms] → R^256 + +# Loss: +# L_cross = L1(predicted_ppg_latent, ema_ppg_encoder_output) +# L_self = L1(masked_ecg_pred, ema_ecg_target) [auxiliary, α=0.3] +# L_total = L_cross + α * L_self + +# Δt sampling per batch: +# 60% log-uniform in [50ms, 500ms] +# 40% ground-truth PTT from dataset +``` + +### Training config + +```yaml +epochs: 100 +batch_size: 64 +optimizer: AdamW, lr=1e-4, weight_decay=0.04 +scheduler: cosine with 10-epoch warmup +ema_tau: 0.996 → 0.9999 over first 30% of training +window: 10s ECG + matched PPG +stride: 5s +data: 100% of passing-E0 records +``` + +### Collapse monitoring (every 100 steps) + +```python +# Log these — stop if cross_modal_cosim > 0.99 for 500 consecutive steps +metrics = { + 'ecg_latent_variance': var(z_ecg).mean(), + 'ppg_latent_variance': var(z_ppg).mean(), + 'cross_modal_cosim': cosine_sim(z_ecg_pooled, z_ppg_pred).mean(), + 'ecg_eigenspectrum_rank': effective_rank(cov(z_ecg)), 
+} +``` + +### Kill criteria — evaluated at epoch 25 + +**K1 — Is the model learning anything?** +```python +mean_baseline_loss = L1(z_ppg_target, z_ppg_mean_over_dataset) +# PASS: model_loss < 0.85 * mean_baseline_loss +``` + +**K2 — Does Δt matter? (the core claim)** +```python +# Run identical linear probe on frozen E3 and Baseline B encoders +# PASS: E3_AUROC > Baseline_B_AUROC + 0.02 (AF detection) +# OR E3_R² > Baseline_B_R² + 0.05 (HR estimation) +# At least one metric must pass +``` + +**K3 — Is cross-modal at least as good as unimodal?** +```python +# PASS: E3_AUROC >= Baseline_A_AUROC - 0.01 (a 0.01 AUROC tolerance below Baseline A is allowed) +``` + +### Decision tree at epoch 25 + +``` +K1 FAIL → Stop entirely. + Data is unusable or encoder collapsed. + Check alignment, quality filtering, EMA schedule. + If clean: the architecture is wrong. Move to Architecture A (temporal ECG-JEPA only). + +K2 FAIL → Stop. The paper does not exist. + Δt-aware prediction ≈ t-aligned prediction. + Pivot options: + (a) Architecture A — temporal unimodal ECG-JEPA + (b) Study 4 — anomaly detection reusing this codebase + (c) Rerun with cleaner BIDMC data before final decision. + +K2 PASS + K3 FAIL → Cross-modal hurts. + Run 10 more epochs. If still failing: + Reduce PPG encoder capacity, check EMA instability. + If persistent: use lighter PPG encoder (ViT-T instead of ViT-S). + +K1 ✓, K2 ✓, K3 ✓ → Continue to epoch 100. Proceed to E4. +``` + +--- + +## E4 — Rollout coherence test +**Days 9–10 | World model validation** + +This is the experiment that separates "JEPA with a lag" from "a cardiovascular world model." Without it, the paper cannot make the world model claim. + +### Protocol + +```python +# Frozen encoder + trained predictor. N=200 held-out patients. 
+ +for patient in held_out_patients: + z_ecg = ecg_encoder(ecg_window_t) + + # Predict at a grid of Δt values + delta_t_grid = [50, 100, 150, 200, 250, 300, 350, 400, 450, 500] # ms + errors = [] + for dt in delta_t_grid: + z_ppg_pred = predictor(z_ecg, delta_t=dt) + z_ppg_true = ppg_encoder(ppg_window_at_t_plus_dt) + errors.append(L1(z_ppg_pred, z_ppg_true)) + + # Find optimal Δt (prediction error minimum) + optimal_delta_t[patient] = delta_t_grid[argmin(errors)] +``` + +### Physiological consistency checks + +```python +# Check 1: Does optimal_Δt correlate with measured PTT? +correlation = spearman(optimal_delta_t, measured_ptt_per_patient) +# PASS: correlation > 0.30 + +# Check 2: HR-PTT inverse relationship +# High HR → shorter PTT → shorter optimal Δt +high_hr = windows_where(hr > 90 bpm) +low_hr = windows_where(hr < 60 bpm) +# PASS: mean(optimal_Δt[high_hr]) < mean(optimal_Δt[low_hr]), p < 0.05 + +# Check 3: U-shaped error curve (predictor has a real minimum, not flat) +for patient in sample_50_patients: + assert has_clear_minimum(errors) # not monotone, not flat +# PASS: ≥ 60% of patients have clear minimum +``` + +### Pass criteria + +| Check | Pass | Implication if pass | +|-------|------|---------------------| +| Spearman > 0.30 | Model learned PTT implicitly | Core world-model claim supported | +| HR-PTT ordering | Physiologically consistent | Not a lookup table | +| U-curve ≥ 60% | Predictor has a real minimum | Latent space is smooth | + +### If E4 passes but E5 PTT probe fails +The representation has the information but a linear probe can't extract it. Try a 3-layer MLP probe. If that also fails, the PTT information is encoded nonlinearly — mention this as a limitation but don't remove the E4 claim from the paper. + +--- + +## E5 — Downstream probes +**Days 11–12 | Validation signals** + +These run on frozen encoders from E3 best checkpoint. They are probes, not contributions. 
+ +### E5a — PTT regression probe +```python +mlp_ptt = MLP(in=256, hidden=128, out=1) +train(mlp_ptt, + X = pool(ecg_latent), + y = measured_ptt_per_beat, + split = patient_level_80_20) + +# Report: +# MAE (ms) vs naive mean-PTT baseline +# Pearson(predicted_ptt, measured_ptt) +# Within-patient: does the probe track PTT changes over time? +``` + +### E5b — AF detection sample efficiency +```python +# Same linear probe as used in E2/E3 — enables direct comparison +# Label fractions: 1%, 5%, 10%, 50%, 100% +# Models: E3 vs Baseline_A vs Baseline_C +# Goal: sample efficiency curve (not just full-data comparison) +``` + +### E5c — HR estimation +```python +# Linear regression on frozen latent → HR +# Baseline: RR-interval to HR (trivial — sets floor) +``` + +### What must be true for the paper + +| Result | Why it matters | +|--------|----------------| +| E5a MAE < naive by ≥ 20% | PTT is in the latent — confirms E4 | +| E5b: E3 ≥ Baseline_A at all label fractions | Cross-modal doesn't hurt | +| E5b: E3 > Baseline_C at 1% labels | JEPA more sample-efficient than InfoNCE | + +--- + +## E6 — The decisive ablation +**Days 13–14 | The main result** + +One variable changed. Everything else identical. + +| Model | Δt | Architecture | +|-------|-----|-------------| +| E3 (PhysioJEPA) | log-uniform [50, 500ms] | Identical | +| Baseline B (t-aligned) | Fixed 0ms | Identical | + +Both trained to 100 epochs, full data. Evaluated identically. + +### The comparison table (this becomes Table 1 of the paper) + +``` +Model | AF AUROC | HR R² | PTT R² | ECG-PPG R@1 +──────────────────────────────────────────────────────────────── +Baseline A (ECG) | | | N/A | N/A +Baseline B (Δt=0) | | | | +Baseline C (InfoNCE)| | | | +E3 (Δt>0, ours) | | | | +``` + +### Paper-level claim, if E6 supports it + +> Predicting PPG at variable time offset Δt from ECG produces latent representations +> that implicitly encode vascular timing structure (PTT). 
+> Contrastive alignment at t=0 and predictive alignment at t=0 both destroy this structure. +> This is demonstrated by improved PTT regression, superior sample efficiency on AF detection, +> and physiologically consistent rollout behaviour under varying heart rate. + +One paragraph. Defensible. Not overclaiming causality or blood pressure. + +--- + +## Day 15 — Decision + +``` +GREEN — all of K1, K2, K3, E4 coherence, E6 Δt > Δt=0 + → Write the paper. + → Weeks 3–4: run ablations A1–A5 (morphology, phase encoding, + SIGReg, PTT head, curriculum Δt). + → Target venues (with actual 2026 deadlines): + NeurIPS 2026 workshops (TS4H, BrainBodyFM): ~August 2026 + ML4H 2026 symposium (archival proceedings track): ~September 2026 + ICLR 2027: ~October 2026 (needs strong E4 + clean ablations) + +YELLOW — K2 passes weakly, E4 marginal + → Extend E3 to 200 epochs before deciding. + → If still weak: reframe as temporal ECG-JEPA (Architecture A). + Smaller claim but still publishable as an extension of Weimann & Conrad. + Target: NeurIPS 2026 workshop TS4H. + +RED — K2 fails + → The core idea does not work on this dataset at this scale. + → Immediate pivot options: + (a) Architecture A (temporal ECG-JEPA, unimodal) — reuses everything + (b) Study 4 (anomaly detection via prediction error) — same codebase + (c) Re-run E0 on PhysioNet BIDMC before final call. + Note: CHIL 2026 deadline (Apr 17) has passed. MLHC 2026 (Apr 17) has passed. + Next realistic archival venue: ML4H 2026 (~Sep 2026 estimated). 
+``` + +--- + +## Post-hoc (2026-04-15): K2 failed, K3 passed, τ mechanism falsified + +Actual results from the E2/E3 run (subset_frac=0.10, 25 epochs, seed=42); table values are probe AUROC at epochs 5, 10 and 25: + +| Model | Config | ep5 | ep10 | ep25 | +|-------|--------|-----|------|------| +| F (Δt>0) | PhysioJEPA v1 | 0.652 | 0.859 | 0.835 | +| B (Δt=0) | symmetric cross-modal | 0.660 | 0.844 | **0.847** | +| A (unimodal) | ECG-JEPA | 0.783 | 0.736 | 0.703 | +| C (InfoNCE) | symmetric | — | — | under-tuned; not usable | + +**K2: FAIL.** F−B at ep25 = −0.012 (target was +0.02). Δt doesn't matter. + +**K3: PASS (large margin).** F−A at ep25 = +0.133. Cross-modal beats unimodal by +~0.13 AUROC. + +**τ-saturation mechanism (slow-τ A ablation): FALSIFIED.** +Slow-τ A (ema_end=0.999, warmup_frac=0.60) had L_self rising *more* than +original A through steps 2000-5000, not less. τ is not the lever. + +Working hypothesis for A's degradation: predictor+query-embedding overfits +to a narrow target distribution in unimodal training. Cross-modal training +provides target diversity the predictor can't overfit to, which is why +F/B stay stable. Needs a different ablation (e.g. shrink predictor, shrink +query embedding, vary masking ratio) to confirm. + +## Summary + +| Day | Experiment | Key output | Decision gated | +|-----|-----------|-----------|----------------| +| 1–2 | E0: data audit | data_card.md, PTT histogram | Dataset go/no-go | +| 3 | E1: PPG encoding | e1_decision.md, ppg_encoder.py | Architecture lock | +| 4–5 | E2: baselines | Floor + ceiling numbers | Calibrates E3 expectations | +| 6–8 | E3: Δt-JEPA v1 | K1/K2/K3 at epoch 25 | Paper exists or doesn't | +| 9–10 | E4: rollout coherence | World model evidence | World model claim | +| 11–12 | E5: probes | PTT, AF, HR numbers | Downstream story | +| 13–14 | E6: decisive ablation | Table 1 | Paper's main result | +| 15 | Decision | Green / yellow / red | What gets written | + +**Compute to day 15 decision point: ~50–70 GPU-hours. 
Cost: ~$125–175.** + +K2 is answered by day 8. Everything after that is filling in the paper. + +--- + +## Division of work + +| Task | Owner | +|------|-------| +| E0: data pipeline, quality metrics, PTT computation | Zack | +| E1: morphology extractor, two-encoder comparison | Zack | +| E2: ECG-JEPA fork (Baseline A), training | Guy | +| E2: InfoNCE baseline (Baseline C) | Zack | +| E2: Symmetric JEPA (Baseline B) | Guy | +| E3: Δt-JEPA architecture + training loop | Guy | +| E3: collapse monitoring, checkpoint saving | Both | +| E4: rollout coherence test, physiological checks | Guy | +| E5: probe training harness, sample efficiency curves | Zack | +| E6: final comparison, Table 1 | Both | +| Day 15 decision | Both | + +--- + +*Designed so the most important question — does Δt matter? — is answered by day 8, not day 28.* +*Total time to go/no-go: 8 days. Total compute: ~50–70 GPU-hours.* \ No newline at end of file diff --git a/docs/PAPERS.md b/docs/PAPERS.md new file mode 100644 index 0000000000000000000000000000000000000000..e7e4b2005a27d1e04c3717679851d546e19a3938 --- /dev/null +++ b/docs/PAPERS.md @@ -0,0 +1,341 @@ +# PAPERS.md — PhysioJEPA Reference Index +*Oz Labs — April 2026* +*Covers every paper referenced across the full conversation and all project documents.* + +--- + +## How to use this file + +Three things per entry: +1. **What to use it for** — the specific task or decision the agent needs this paper for +2. **Key numbers** — exact figures the agent must not get wrong in code or prose +3. **Location** — where to fetch the PDF + +Read the tier before writing any code in that tier's domain. +Do not cite a number that isn't in this file without fetching the source first. + +--- + +## Tier 1 — Implement from these +*Read before writing any training code. 
Contains exact equations, hyperparameters, architecture details.* + +--- + +### [T1-1] Weimann & Conrad — ECG-JEPA +**arXiv**: 2410.13867 · `arxiv.org/pdf/2410.13867` +**Code**: `github.com/kweimann/ECG-JEPA` ← fork this + +**Use for**: This is the codebase we fork. Before writing any encoder code, read Section 2 (architecture), Section 3 (data), Appendix A (hyperparameters). +- Patch tokenisation: 2D over (12 leads × time), patch size = 25 time steps at 500 Hz +- Masking: multi-block contiguous, 50% ratio, 4 target blocks +- EMA: τ starts 0.996, cosine-annealed to 0.9999 over training +- Loss: L1 in latent space — no pixel decoder +- ViT-S: 12 layers, d=256, 8 heads, MLP ratio=4 + +**Key numbers**: PTB-XL all-statements AUC **0.945** — this is Baseline A in the experiment matrix. Training time ~26h on RTX 3090. + +--- + +### [T1-2] Assran et al. — I-JEPA +**arXiv**: 2301.08243 · `arxiv.org/pdf/2301.08243` +**Code**: `github.com/facebookresearch/ijepa` + +**Use for**: The masking strategy foundation. Why multi-block contiguous > random masking (forces semantic prediction, not texture interpolation). The stop-gradient / EMA target encoder design justification. The predictor should be *narrower* than the encoder — this prevents shortcutting through the predictor. + +**Key numbers**: ViT-H/14 ImageNet — scale reference only, not a target for us. + +--- + +### [T1-3] Bardes et al. — V-JEPA (Revisiting Feature Prediction) +**arXiv**: 2404.08471 · `arxiv.org/pdf/2404.08471` + +**Use for**: Spatiotemporal tube masking — how to mask contiguous blocks across both spatial and temporal axes simultaneously. Template for PPG 1D+time representation. Two-encoder EMA recipe at scale. Why predicting in latent space beats pixel reconstruction for noisy signals — core justification for JEPA over MAE. + +**Key numbers**: SSv2 top-1 77.3%. + +--- + +### [T1-4] Balestriero & LeCun — LeJEPA +**arXiv**: 2511.08544 · `arxiv.org/pdf/2511.08544` + +**Use for**: Ablation A3 only (SIGReg). 
Do not implement SIGReg without reading this first. +- Theorem 1: isotropic Gaussian is the optimal JEPA embedding distribution +- SIGReg: K=128 random 1D projections w~N(0,I), KL(z·w || N(0,1)) per projection, sum. O(Kd). +- λ range: [0.01, 0.1]; start at 0.05 +- Apply to *pooled global representation only* — not per-patch tokens +- ~50 lines of PyTorch + +**Key numbers**: 79% ImageNet ViT-H/14 with only 2 loss terms. + +--- + +### [T1-5] Kim — CroPA-ECG-JEPA +**arXiv**: 2410.08559 · `arxiv.org/pdf/2410.08559` +**Code**: `github.com/sehunfromdaegu/ECG_JEPA` + +**Use for**: Second ECG-JEPA implementation for debugging. Cross-Pattern Attention (CroPA) = inter-lead masked attention = inspiration for cardiac phase encoding in ablation A2. Also: 1D PE for predictor vs 2D for encoders — different from Weimann, compare before finalising. + +**Key numbers**: Recovers HR and QRS duration from frozen representations without supervised training — target behaviour for PTT. + +--- + +### [T1-6] Botman et al. — Laya (LeJEPA for EEG) +**arXiv**: 2603.16281 · `arxiv.org/pdf/2603.16281` + +**Use for**: Most direct prior to PhysioJEPA. Read before implementing ablation A3. +- SIGReg with aggressive λ destabilises training on impulsive signals (QRS-like spikes in EEG) +- Mitigation: lower λ (0.001–0.01), aggressive gradient clipping, apply to pooled global rep only +- Latent prediction outperforms reconstruction on EEG clinical tasks + +**Key numbers**: Outperforms reconstruction baselines on EEG-Bench with 10% of pretraining data. + +--- + +## Tier 2 — Baseline numbers and comparisons +*Read to correctly report comparison numbers. Getting baselines wrong is a rejection risk.* + +--- + +### [T2-1] Nie et al. — AnyPPG +**arXiv**: 2511.01747 · `arxiv.org/pdf/2511.01747` + +**Use for**: Primary contrastive baseline (Baseline C in experiment matrix). 
+- Exact loss: **symmetric InfoNCE** with learnable temperature τ +- **CRITICAL: ECGFounder encoder is FROZEN during AnyPPG training.** ECG is a fixed supervisory signal. AnyPPG is not a jointly trained dual-encoder model. +- Architecture: Net1D (PPG branch), ECGFounder frozen (ECG branch) +- Trained on >100,000 hours + +**Key numbers**: PPG→ECG retrieval **R@1=0.736**, R@5=0.906, R@10=0.935. AF detection AUC ~0.90. Mean **9.1% AUC improvement** over non-ECG-guided baselines. + +--- + +### [T2-2] Wagner et al. — PTB-XL +**arXiv**: 2004.13701 · `arxiv.org/pdf/2004.13701` + +**Use for**: ECG evaluation benchmark. Task definitions, train/test/val splits, and label hierarchy. Must replicate Weimann's exact split for comparison. + +**Key numbers**: Weimann ECG-JEPA AUC **0.945** all-statements = Baseline A target. + +--- + +### [T2-3] Charlton et al. — Towards Ubiquitous BP Monitoring via PTT (review) +**URL**: `pmc.ncbi.nlm.nih.gov/articles/PMC4515215/` + +**Use for**: Before writing E4 rollout coherence physiological consistency checks. PTT definition, normal range, PTT–BP and HR–PTT relationships. Per-patient calibration required for absolute BP — do not claim uncalibrated absolute BP from PTT. + +**Key numbers**: Normal PTT **100–400ms** (ICU adults). Within-patient tracking ~10 mmHg MAE with calibration. + +--- + +### [T2-4] Assran et al. — V-JEPA 2 (including V-JEPA 2-AC) +**arXiv**: 2506.09985 · `arxiv.org/pdf/2506.09985` + +**Use for**: Architecture D future work template. Two-stage recipe: action-free pretraining → action-conditioned fine-tuning with frozen encoder. + +**Key numbers**: **<62 hours** of robot interaction data for Stage 2. SSv2 top-1 77.3%. 
+ +--- + +## Tier 3 — Related work framing +*Read to correctly describe prior work and differentiate PhysioJEPA.* + +--- + +### [T3-1] Sarkar & Etemad — CardioGAN +**arXiv**: 2010.00104 · `arxiv.org/pdf/2010.00104` +**Code**: `github.com/pritamqu/ppg2ecg-cardiogan` + +**Use for**: First major cross-modal ECG-PPG paper (AAAI 2021). +- Uses **CycleGAN backbone** with attention-based generators and dual time/frequency discriminators +- **NOT reconstruction/L1, NOT InfoNCE** — adversarial + cycle consistency loss +- t=0 alignment — discards lag. Do NOT call this "pixel reconstruction." + +--- + +### [T3-2] Liu, Wang & Wang — TSTA-Net +**PMLR**: proceedings.mlr.press/v278/liu25d.html + +**Use for**: Hierarchical contrastive ECG-PPG baseline (PMLR 2025). +- **Hierarchical contrastive learning** — NOT raw InfoNCE +- 9.3% higher AF F1 vs prior SSL methods +- Still t=0 aligned + +--- + +### [T3-3] Fang et al. — PPGFlowECG +**arXiv**: 2509.19774 · `arxiv.org/pdf/2509.19774` + +**Use for**: Two-stage generative translation baseline. +- Stage 1: **InfoNCE instance alignment** (CardioAlign encoder, shared weights) +- Stage 2: **rectified flow** generation from aligned latents +- Figure 1 explicitly shows ECG precedes PPG temporally but the architecture does not exploit this +- Do NOT describe as "rectified flow only" — InfoNCE is in Stage 1 + +--- + +### [T3-4] Dong et al. — Brain-JEPA (NeurIPS 2024 Spotlight) +**arXiv**: 2409.19407 · `arxiv.org/pdf/2409.19407` +**Code**: `github.com/hzlab/2024_Dong_Li_NeurIPS_Brain-JEPA` + +**Use for**: Cardiac phase encoding inspiration (ablation A2). Brain Gradient Positioning → our cardiac phase PE. Hard phase boundaries fail during AF — use soft Gaussian encoding over cardiac landmarks. + +**Key numbers**: NeurIPS 2024 Spotlight. UK Biobank 40k patients. + +--- + +### [T3-5] Hojjati et al. 
— EEG-VJEPA +**arXiv**: 2507.03633 · `arxiv.org/pdf/2507.03633` +**Code**: `github.com/amir-hojjati/eeg-vjepa` + +**Use for**: V-JEPA adapted to 1D physiological signal — most direct predecessor. How to reshape multi-channel 1D signal into 3D tensor treated as "video." UMAP showing pathological clustering without labels. + +**Key numbers**: TUH fine-tuned accuracy **85.8%**, AUROC **88.5%**. Frozen probe 83.3%. + +--- + +### [T3-6] Munim et al. — EchoJEPA +**arXiv**: 2602.02603 · `arxiv.org/pdf/2602.02603` + +**Use for**: Strongest empirical evidence that JEPA > MAE for noisy medical signals. Use in intro to justify JEPA over MAE. + +**Key numbers**: JEPA degrades **2%** under perturbation vs **17%** for VideoMAE. **79%** accuracy at 1% labels. 20% LVEF improvement. + +--- + +### [T3-7] Wu, Lei et al. — SurgMotion +**arXiv**: 2602.05638 · `arxiv.org/pdf/2602.05638` + +**Use for**: One-sentence citation alongside EchoJEPA: "JEPA's noise rejection under clinical signal artifacts has been validated in echocardiography [EchoJEPA] and surgical video [SurgMotion]." + +--- + +### [T3-8] LeCun — A Path Towards Autonomous Machine Intelligence (JEPA position paper) +**URL**: `openreview.net/pdf?id=BZ5a1r-kVsf` + +**Use for**: One intro citation: "A world model should predict consequences of actions in abstract representation space [LeCun 2022]." + +--- + +### [T3-9] Abbaspourazad et al. — Apple Heart Study Foundation Model +**arXiv**: 2312.05409 · `arxiv.org/pdf/2312.05409` +**Published**: ICLR 2024 + +**Use for**: Prior art on wearable-scale PPG+ECG foundation models. InfoNCE + KoLeo, participant-level positives, Apple Watch data. Shows ECG more discriminative than PPG — context for why cross-modal training helps PPG. + +--- + +## Tier 4 — Evaluation methodology and datasets +*Read when writing the evaluation harness code.* + +--- + +### [T4-1] Pimentel et al. 
— BIDMC PPG and Respiration Dataset +**PhysioNet**: `physionet.org/content/bidmc/1.0.0/` + +**Use for**: Fallback dataset if E0 fails. +- WFDB format, **53 recordings × 8 min**, **125 Hz** +- Signals: **Lead II ECG + fingertip PPG** + impedance respiration +- Labels: HR, RR, SpO2 — **no AF labels** (use for HR probe only) + +**Key numbers**: **53 patients**, ~7 hours total, **125 Hz**. + +--- + +### [T4-2] Moody et al. — MIMIC-IV Waveform Database +**PhysioNet**: `physionet.org/content/mimic4wdb/0.1.0/` + +**Use for**: Understanding HuggingFace mirror provenance. +- v0.1.0: **200 records from 198 patients**; upcoming release ~10,000 records +- MIMIC-IV-ECG module: **~800k ECGs across ~160k patients**, 500 Hz, 10s, 12-lead — AF label source candidate + +--- + +### [T4-3] Kachuee et al. — Cuffless BP Estimation Dataset (UCI) +**UCI**: `archive.ics.uci.edu/dataset/340` + +**Use for**: E5a PTT probe evaluation. +- 12,000 records, 942 patients — **patient ID removed** — population-level evaluation only +- PPG + ABP at 125 Hz, derived from MIMIC-II + +**Key numbers**: AAMI standard ≤5 mmHg mean ± 8 mmHg SD. + +--- + +### [T4-4] Goldberger et al. — PhysioBank, PhysioToolkit, PhysioNet +**DOI**: 10.1161/01.CIR.101.23.e215 + +**Use for**: Required citation whenever using BIDMC, MIMIC waveforms, or any PhysioNet dataset. One line in methods: "Data obtained from PhysioNet [Goldberger et al., 2000]." + +--- + +## Tier 5 — Context and intellectual lineage +*Do not read these to implement anything. One citation each.* + +--- + +### [T5-1] Ha & Schmidhuber — World Models +**arXiv**: 1803.10122 + +**Use for**: Intro citation only. "World models learn a compressed latent representation and a transition function [Ha & Schmidhuber, 2018]." + +--- + +### [T5-2] Bardes et al. — VICReg +**arXiv**: 2105.04906 + +**Use for**: Related work only. "VICReg requires hand-crafted augmentations that JEPA avoids." + +--- + +### [T5-3] Ronan et al. 
— VICReg for Brugada ECG Detection +**DOI**: 10.1038/s41598-025-94130-x + +**Use for**: One sentence. "VICReg-based SSL has been applied to ECG classification [Ronan et al., 2025] but requires augmentation engineering." + +--- + +### [T5-4] Johnson et al. — MIMIC-IV (clinical database paper) +**DOI**: 10.1038/s41597-022-01899-x + +**Use for**: Required data citation whenever using MIMIC-IV derived data. "MIMIC-IV [Johnson et al., 2023], a freely accessible EHR database." + +--- + +### [T5-5] CLIMB multimodal clinical benchmark +**arXiv**: 2503.07667 + +**Use for**: ECG-JEPA performance in multimodal settings. "ECG-JEPA outperforms general time-series models like UniTS by 36.8% on ECG tasks [CLIMB, 2025]." One citation in intro. + +--- + +## Quick reference: numbers the agent must not get wrong + +| Claim | Correct value | Source | +|-------|--------------|--------| +| ECG-JEPA PTB-XL AUC | **0.945** all-statements | T1-1 Weimann | +| AnyPPG PPG→ECG R@1 | **0.736** | T2-1 Nie | +| AnyPPG AUC improvement | **9.1%** over non-ECG baselines | T2-1 Nie | +| AnyPPG ECGFounder | **FROZEN** during training | T2-1 Nie | +| EchoJEPA JEPA perturbation | **2%** degradation | T3-6 Munim | +| EchoJEPA MAE perturbation | **17%** degradation | T3-6 Munim | +| EchoJEPA 1% label accuracy | **79%** | T3-6 Munim | +| Normal PTT range (ICU) | **100–400ms** | T2-3 Charlton | +| BIDMC size | **53 recordings × 8 min @ 125 Hz** | T4-1 Pimentel | +| V-JEPA 2-AC interaction data | **<62 hours** | T2-4 Assran | +| EEG-VJEPA TUH AUROC | **88.5%** fine-tuned | T3-5 Hojjati | +| CardioGAN objective | **CycleGAN adversarial** — not reconstruction | T3-1 Sarkar | +| TSTA-Net objective | **Hierarchical contrastive** — not raw InfoNCE | T3-2 Liu | +| PPGFlowECG Stage 1 | **InfoNCE alignment**, then rectified flow | T3-3 Fang | +| BP calibration requirement | **Per-patient calibration required** for absolute values | T2-3 Charlton | + +--- + +## File locations in repo + +``` +docs/papers/*.pdf +``` + 
+--- + +*This is the complete reference index. Fetch from arXiv if a PDF is missing. Never cite a number not in this file without verifying the source first.* \ No newline at end of file diff --git a/docs/RESEARCH_DEVELOPMENT.md b/docs/RESEARCH_DEVELOPMENT.md new file mode 100644 index 0000000000000000000000000000000000000000..dd363b9aa9e0d9fca6c9651fef50b5f004bf0a4e --- /dev/null +++ b/docs/RESEARCH_DEVELOPMENT.md @@ -0,0 +1,381 @@ +# PhysioJEPA: Learning Cardiovascular Dynamics via Time-Shifted Cross-Modal Prediction +*Oz Labs — Full Research Development Document — April 2026* +*Revision 2: post-reviewer critique. Replaces causalcardio_jepa_full.md* + +--- + +## Change log from revision 2 (post-E0 audit, 2026-04-14) + +- ECG input revised from 12-lead @ 500 Hz to **single lead II @ 250 Hz** (lead II present in 93.7% of HF-mirror segments; 12-lead not available in this dataset) +- ECG patch size revised: 200 ms = **50 samples @ 250 Hz**, 1D over single lead (was 2D (12, 25) @ 500 Hz) +- AF label source locked to **PTB-XL** (see `docs/af_label_decision.md`): MIMIC-IV-ECG path blocked by (a) ~381-patient cohort yielding <100 AF-positive, (b) missing PhysioNet credentialing. 
Paper now frames AF eval as a transfer claim +- PPG encoding locked to **raw patches** for v1 per E1 Stage-1 result (extraction rate 98.6% but Stage-2 probe deferred to ablation A1 when AF labels are integrated) +- Baseline A (ECG-JEPA) cannot load Weimann's 12-lead PTB-XL checkpoints; must retrain from scratch on single-lead II to be an honest comparison + +## Change log from revision 1 + +- Renamed throughout from CausalCardio-JEPA → PhysioJEPA +- Core claim simplified to one sentence; PTT demoted from contribution to validation signal +- v1 architecture stripped to minimum: raw PPG patches, EMA only, no cardiac phase encoding, no SIGReg +- Morphological encoding, cardiac phase encoding, SIGReg moved to labelled ablations +- "Causal" language replaced throughout with "physiologically informed asymmetry" or "directional asymmetry" +- AnyPPG characterisation corrected: ECGFounder encoder is frozen during AnyPPG training +- Venue targets corrected to reflect actual 2026 deadlines +- PTT head reframed: validation signal, not contribution + +--- + +## 1. The Hypothesis + +**Core claim — one sentence:** + +> Predicting PPG at a variable time offset Δt from ECG produces cardiovascular representations that encode vascular timing structure, while contrastive alignment at t=0 and predictive alignment at t=0 both destroy this structure. + +**What this means concretely:** +After self-supervised pretraining on synchronized ECG+PPG without labels, the model should: + +1. Predict PPG windows N beats ahead from ECG context with lower error than predicting mean PPG — the model is actually learning something +2. Outperform a symmetric JEPA trained at Δt=0 on downstream cardiovascular tasks — the temporal offset matters +3. Produce latent embeddings where PTT (measured post-hoc from the latent's optimal Δt) correlates with ground-truth PTT from peak detection — PTT is implicitly encoded +4. 
Show physiologically consistent rollout: predicted optimal Δt varies inversely with both heart rate and blood pressure category (higher pressure → stiffer arteries → shorter transit time) + +Points 1 and 2 are the paper. Points 3 and 4 are the supporting evidence. + +**Why this is different from existing methods:** + +Every prior cross-modal ECG-PPG method treats the two modalities as symmetric windows on the same cardiac state at the same moment: + +- **AnyPPG** (Nie et al., 2511.01747): symmetric InfoNCE at t=0. Important nuance: the ECGFounder encoder is *frozen* during AnyPPG training — it functions as a fixed supervisory signal, not a jointly-learned representation. This means AnyPPG is not even learning a shared representation; it is distilling a frozen ECG model into a PPG encoder. Same-time alignment still applies. +- **TSTA-Net** (Liu et al., PMLR 2025): hierarchical contrastive learning with spatiotemporal alignment of ECG and PPG. Same-time alignment. +- **PPGFlowECG** (Fang et al., 2509.19774): uses InfoNCE instance alignment internally in Stage 1, then rectified flow generation in Stage 2. Both stages operate at t=0 alignment. +- **CardioGAN** (Sarkar & Etemad, AAAI 2021): CycleGAN-based adversarial waveform synthesis. Waveform-space signal translation, not representation learning. t=0. + +All of them discard the ECG→PPG lag. The lag is the measurement: PTT ≈ 100–400ms encodes arterial stiffness, which encodes blood pressure via the Moens-Korteweg equation. PPGFlowECG even acknowledges this in Figure 1 ("ventricular electrical activation precedes the peripheral pulse") but their architecture doesn't use it. + +**Why JEPA specifically:** + +JEPA's implicit bias — shown formally by Balestriero & LeCun (LeJEPA, 2511.08544) and empirically by Weimann & Conrad (2410.13867) — is toward high-influence, predictable features. In a cardiac signal, the most stable and predictable cross-modal feature is the time-shifted PPG peak following the QRS complex. 
JEPA will naturally attend to this; symmetric InfoNCE cannot because it penalises the model for not aligning ECG(t) with PPG(t), actively destroying the lag information in order to minimise the contrastive loss. + +--- + +## 2. Architecture + +### v1 (what runs in the experiment matrix) + +The minimum architecture needed to test the core claim. No unnecessary complexity. + +``` +INPUT (revised post-E0, 2026-04-14) +─────────────────────────────────────────────────────── +ECG: [B, 1, 2500] — lead II, 10s @ 250Hz (native HF-mirror rate) +PPG: [B, 1, 1250] — fingertip PPG (Pleth), 10s @ 125Hz (native) +Temporal alignment: sample-accurate (shared segment clock per HF record) + +PREPROCESSING +─────────────────────────────────────────────────────── +ECG: bandpass 0.5–40 Hz → z-score normalisation per window + R-peak detection (Pan-Tompkins) only used for PTT ground truth, + not consumed by the encoder + +PPG: bandpass 0.5–8 Hz → z-score normalisation + [v1: raw patches only — no morphological extraction] + +Segments without lead II (~6.3%) are dropped. 
+TOKENISATION +─────────────────────────────────────────────────────── +ECG context encoder: + - 1D patch: 50 samples = 200ms @ 250Hz + - 50 patches per 10s window + - Linear projection → d=256 + - 1D sinusoidal positional encoding (time) + [v1: single-lead; multi-lead 2D is deferred — only II/V/aVR consistently + present, and the Δt claim is lead-agnostic] + +PPG target encoder: + - 1D patch: 25 samples = 200ms @ 125Hz + - 50 patches per 10s window (1250 samples / 25) + - Linear projection → d=256 + - 1D sinusoidal positional encoding + [v1: raw patches — not morphological tokens] + +ECG CONTEXT ENCODER E_e +─────────────────────────────────────────────────────── +ViT-S (adapted from Weimann & Conrad ECG-JEPA, 1D instead of 2D) + 12 transformer layers, d=256, 8 heads, MLP ratio=4 + I-JEPA masking within ECG (multi-block, 50% ratio) for auxiliary loss + EMA updated: τ annealed 0.996→0.9999 over first 30% of training + Note: cannot load Weimann's published 12-lead checkpoints directly; + Baseline A retrains from scratch on single-lead II for fair comparison + +PPG TARGET ENCODER E_p [EMA updated] +─────────────────────────────────────────────────────── +ViT-T (lighter: 6 layers, d=256) + No masking — encodes full PPG window as target + EMA updated: same τ schedule as E_e + [v1: EMA only — SIGReg is an ablation, not v1] + +Δt EMBEDDING +─────────────────────────────────────────────────────── +Scalar Δt ∈ [50ms, 500ms] → sinusoidal encoding → R^64 +Linear projection → R^256 +Added to predictor as conditioning token + +PREDICTOR P (directional: ECG → PPG) +─────────────────────────────────────────────────────── +4-layer cross-attention transformer + Query: positional tokens for target PPG window positions + Key/Val: ECG context latents z_e + Δt conditioning token + Output: predicted PPG latent ẑ_p(t+Δt) + +The predictor sees no PPG input — only ECG latents + Δt. +This is the architectural enforcement of directional asymmetry. 
+LOSS FUNCTION (v1) +─────────────────────────────────────────────────────── +L_total = L_cross + 0.3 * L_self + +L_cross = L1(ẑ_p(t+Δt), z_p(t+Δt)) ← main prediction loss +L_self = L1(ẑ_e_masked, z_e_target) ← auxiliary ECG self-prediction + +[v1: no SIGReg, no PTT head in training loop] + +Δt SAMPLING +─────────────────────────────────────────────────────── +Per batch: + 60% log-uniform in [50ms, 500ms] + 40% ground-truth PTT measured from aligned dataset +``` + +### Ablations (not v1 — run after E3 passes K2) + +| Ablation | What changes | What it tests | +|----------|-------------|---------------| +| A1: Morphological PPG | PPG target encoder uses morphological tokens instead of raw patches | Does structured PPG encoding improve latent quality? | +| A2: Cardiac phase encoding | Add beat-phase positional encoding (P/QRS/ST/T) to ECG encoder | Does phase-aware PE beat the v1 1D sinusoidal PE? | +| A3: SIGReg instead of EMA | Replace EMA with SIGReg (Balestriero & LeCun 2511.08544) | Is SIGReg more stable than EMA on cardiac signals? | +| A4: Joint PTT head | Add PTT regression MLP head to training loss (γ=0.1) | Does supervised PTT signal improve latent vascular encoding? | +| A5: Curriculum Δt | Start with ground-truth PTT only, introduce log-uniform Δt after 30% training | Does curriculum scheduling improve PTT coherence? | + +--- + +## 3. 
Required Resources + +### Compute +- **E0–E2 (baseline suite)**: ~10 GPU-hours (3 baselines × 20 epochs × small data) +- **E3 (full training)**: ~48–72 hours on A100/H100 for 100 epochs +- **E4–E6**: ~10 GPU-hours (frozen encoder probes + ablations) +- **Full ablation suite (A1–A5)**: ~5 × 24h = 120 hours +- **Total to paper-ready**: ~200 GPU-hours ≈ $500 on Runpod H100 + +### Data +Primary: `lucky9-cyou/mimic-iv-aligned-ppg-ecg` (HuggingFace, instant) +Fallback (if E0 fails): PhysioNet BIDMC (ECG+PPG, documented alignment, open access) +PTT validation: MIMIC-BP curated dataset (UCL/UCI, 1,524 patients) + +### Software +- Base codebase: `kweimann/ECG-JEPA` (MIT licence) +- PPG peak detection: `wfdb` + `scipy.signal` +- SIGReg (ablation A3): ~50 lines PyTorch, implement from Balestriero & LeCun 2511.08544 +- Evaluation: `sklearn` linear probe + custom rollout harness + +### People and timeline +- Guy: architecture, training loop, paper +- Zack: data pipeline, PPG encoder, evaluation harness +- Weeks 1–2: E0→E3 (go/no-go on K2) +- Weeks 3–4: E4→E6 + ablations (if green) +- Weeks 5–8: writing + +--- + +## 4. Execution plan + +See the experiment matrix document (`physiojep_experiment_matrix.md`) for day-by-day detail. Summary: + +| Days | Task | Gate | +|------|------|------| +| 1–2 | E0: data audit | Dataset go/no-go | +| 3 | E1: PPG encoding decision | Architecture lock | +| 4–5 | E2: baseline suite | Floor + ceiling | +| 6–8 | E3: PhysioJEPA v1 | K1/K2/K3 at epoch 25 | +| 9–10 | E4: rollout coherence | World model evidence | +| 11–12 | E5: downstream probes | PTT/AF/HR numbers | +| 13–14 | E6: decisive ablation (Δt vs Δt=0) | Table 1 of paper | +| 15 | Green/yellow/red decision | What gets written | + +--- + +## 5. Pitfalls and Failure Modes + +### Pitfall 1: Dataset alignment coarser than 50ms +**Probability**: Medium. HuggingFace mirror is undocumented. 
+**Symptom**: PTT ground-truth variance >100ms within-patient +**Response**: Pivot to PhysioNet BIDMC immediately (2-day delay) +**Impact on claim**: Architecture identical; only provenance label changes + +### Pitfall 2: Morphological PPG feature extraction unreliable +**Note**: This is now an ablation (A1), not v1. If E1 shows morphological encoding is unreliable, we simply don't run A1. This is no longer a project-killing risk. + +### Pitfall 3: EMA collapse +**Probability**: Low. ECG-JEPA with EMA is validated at scale. +**Symptom**: Mean cosine sim >0.99 for 500 consecutive steps +**Response**: Reduce τ start to 0.99, check batch size; add SIGReg (ablation A3) earlier +**Monitoring**: Log every 100 steps from epoch 1 + +### Pitfall 4: Cross-modal loss never beats mean baseline (K1) +**Probability**: Low-medium. Depends on dataset quality. +**Symptom**: L_cross plateau above 0.85× mean-PPG-latent baseline +**Response**: Check data quality, increase window overlap, verify EMA schedule +**Nuclear option**: Pivot to Architecture A (temporal ECG-JEPA, unimodal) — reuses all code + +### Pitfall 5 (critical): Δt-aware ≈ t-aligned (K2) +**Probability**: Unknown — this is the central empirical question. +**Symptom**: E3 AUROC ≈ Baseline B AUROC (within 0.02) +**Response**: This is the K2 failure mode. The core claim is wrong on this data at this scale. +**Pivot options**: Architecture A, Study 4 (anomaly detection), or re-run on BIDMC + +### Pitfall 6: Shortcut learning +**Probability**: Medium, especially early in training. +**Symptom**: Model predicts mean PPG morphology for all inputs; L_cross decreases but predictions are identical regardless of ECG input +**Detection**: Compute per-patient prediction variance — if near zero, shortcut is occurring +**Response**: Increase batch diversity, add within-patient hard negatives to Δt sampling + +### Pitfall 7: PTT coherence fails (E4 passes but PTT probe fails) +**Probability**: Low-medium. 
+**Implication**: The temporal structure is encoded nonlinearly. Try 3-layer MLP probe instead of linear. If that fails, this is a limitation — remove PTT probe from paper claims but keep E4 rollout coherence evidence. + +--- + +## 6. Checkpoints + +| # | When | Pass criterion | Fail action | +|---|------|----------------|-------------| +| C1 | Day 2 | Alignment ≤50ms; ≥500 patients; missing ≤20% | Pivot to BIDMC | +| C2 | Day 3 | E1 decision made and committed | Block on architecture | +| C3 | Day 5 | Baseline B training stable (no collapse) | Add SIGReg to E3 from start | +| C4 | Day 8 (epoch 25) | K1: L_cross < 0.85× mean baseline | Fix or exit | +| C5 | Day 8 (epoch 25) | K2: E3 AUROC > Baseline B + 0.02 | Paper doesn't exist | +| C6 | Day 8 (epoch 25) | K3: E3 AUROC ≥ Baseline A − 0.01 | Reduce PPG encoder capacity | +| C7 | Day 10 | E4: Spearman(optimal Δt, ground-truth PTT) > 0.30 | Keep as limitation | +| C8 | Day 12 | E5: PTT probe MAE < naive by 20% | 3-layer MLP probe fallback | +| C9 | Day 14 | E6: Δt>0 > Δt=0 on ≥2 of 3 metrics | Re-examine K2 | + +--- + +## 7. 
Evaluation Protocol + +### Primary metrics (determine the paper) + +**E3 / E6 — Core claim test** + +| Metric | What it tests | Baseline | +|--------|--------------|---------| +| AF detection AUROC (linear probe, frozen) | Representation quality | ECG-JEPA: 0.945 (Weimann 2410.13867) | +| HR regression R² (linear probe, frozen) | Cardiovascular signal content | RR-interval baseline | +| ECG-PPG retrieval R@1 | Cross-modal alignment | AnyPPG: 0.736 | + +**E4 — World model evidence (rollout coherence)** + +| Check | Pass criterion | +|-------|---------------| +| Spearman(optimal Δt, measured PTT) | > 0.30 | +| HR-PTT inverse ordering | Significant, p < 0.05 | +| U-shaped prediction error curve | ≥60% of patients | + +**E5 — Downstream validation** + +| Task | Metric | Framing | +|------|--------|---------| +| PTT regression (linear probe) | MAE (ms) vs naive | Validation only — not the contribution | +| AF sample efficiency | AUROC at 1/5/10/100% labels | JEPA sample efficiency advantage | + +### Evaluation philosophy + +Table 1 of the paper (from E6): a 4-row × 4-column table showing Baseline A (ECG-JEPA), Baseline B (Δt=0), Baseline C (InfoNCE), and PhysioJEPA across AF AUROC, HR R², PTT correlation, and retrieval R@1. If rows 3 and 4 are clearly separated, the paper exists. + +The PTT probe and rollout coherence are supporting figures. They interpret why the representation quality is better. They do not constitute the primary claim. + +--- + +## 8. Critic — Strongest Arguments Against + +### Critic 1: PTT can be computed with peak detection in 10 lines of code + +**Correct.** That is exactly why PTT is a *validation signal*, not the contribution. We are not claiming novelty in PTT computation. We are claiming that a model trained on the Δt prediction objective implicitly encodes PTT in its latent space — which is evidence that the latent captures vascular dynamics rather than just cardiac rhythm. 
If the same latent did *not* encode PTT, we would doubt that it learned anything physiologically meaningful. + +### Critic 2: Small dataset vs AnyPPG's 100k+ hours + +**Conceded.** We are not competing at scale. The comparison is controlled: PhysioJEPA vs Baseline C (InfoNCE) trained on the same N hours. The architectural claim is about inductive bias on fixed data, not about scale. We report this comparison explicitly. + +### Critic 3: "Physiological asymmetry" is just an architectural choice, not a principled claim + +**Partially conceded.** The architecture encodes a *hypothesis* about the direction of information flow (ECG→PPG). If the ablation (Baseline B, symmetric at Δt=0) performs identically to PhysioJEPA, the asymmetry contributed nothing and we remove it from the contribution list. The ablation is the test. + +### Critic 4: The Δt sampling mixing ratio (60/40) is a hyperparameter + +**Correct.** Ablation A5 (curriculum Δt) tests whether this specific ratio matters. For v1 we use 60/40 pragmatically; if A5 shows a different schedule is better, we adopt it. This is not a fundamental weakness — it is a hyperparameter like any other. + +### Critic 5: Shortcut — the model predicts mean PPG for all inputs + +**Real risk.** Explicitly monitored via per-patient prediction variance (Pitfall 6). If detected, addressed before any results are reported. + +--- + +## 9. Reviewer Critiques (updated post-feedback) + +The reviewer critique document (provided separately) raised seven structural issues. 
Status of each: + +| Issue | Status | Resolution | +|-------|--------|-----------| +| 3 contributions in 1 paper | Fixed | Core claim reduced to one sentence; PTT and morphology are evidence/ablations | +| PTT head framing backwards | Fixed | PTT is validation signal; cross-modal Δt prediction is the claim | +| Morphological encoding = #1 technical risk | Fixed | Moved to ablation A1; not in v1 | +| "Causal" overclaimed | Fixed | Renamed to PhysioJEPA; language changed to "directional asymmetry" / "physiologically informed" | +| Core idea not isolated | Fixed | E3 vs Baseline B (Δt=0) is the controlled isolation; both are identical except Δt | +| Baselines needed from Week 1 | Fixed | E2 baseline suite runs days 4–5, before E3 | +| "World model" evaluation missing | Fixed | E4 rollout coherence is explicit and uses physiological consistency checks | + +--- + +## 10. Open Questions + +**Q1: How well is the MIMIC-IV aligned PPG-ECG dataset actually aligned?** +Unknown until E0. The most important unanswered question. Answer by Day 2. + +**Q2: Does the asymmetric architecture (ECG predicts PPG, not PPG predicts ECG) outperform the symmetric version?** +This is ablation A1's question at the architecture level. Baseline B isolates Δt but not directionality — if we add a symmetric Δt>0 variant (PPG predicts ECG with the same lag), we can test this separately. Lower priority; add if time permits. + +**Q3: Does the cross-modal training improve the ECG encoder relative to ECG-only training?** +K3 tests this: E3 AUROC should match Baseline A (ECG-JEPA alone). If it's worse, the cross-modal objective is hurting the ECG representation. This would be a significant negative result worth reporting. + +**Q4: How does the model behave during AF?** +AF removes the periodic P-wave and makes RR intervals irregular. The Δt sampling may fail to find a meaningful optimal during AF episodes. 
This is actually interesting — the model's inability to predict a stable optimal Δt during AF could itself be a detection signal. Monitor in E4. + +**Q5: Is MIMIC-BP the right held-out dataset for PTT validation?** +MIMIC-BP (Kachuee et al.) is derived from MIMIC-III; the training data is MIMIC-IV-derived. Same institution (BIDMC), no patient overlap, but similar population. This is a reasonable evaluation setup but should be documented carefully to pre-empt reviewer concerns about distribution leakage. + +--- + +## 11. Paper Identity and Venues + +**Title**: *PhysioJEPA: Learning Cardiovascular Dynamics via Time-Shifted Cross-Modal Prediction* + +**One-paragraph abstract (draft)**: +Contrastive self-supervised methods for ECG-PPG representation learning align same-time signals in a shared embedding space, discarding the physiological lag between cardiac electrical activation and peripheral perfusion. This lag — the pulse transit time (PTT) — encodes arterial stiffness and correlates with blood pressure. We introduce PhysioJEPA, a JEPA-based world model that instead trains an ECG encoder to predict PPG latents at a variable time offset Δt, preserving and exploiting the directional temporal structure that contrastive methods destroy. We show that Δt-aware prediction produces cardiovascular representations that (1) outperform same-time contrastive alignment on AF detection sample efficiency, (2) implicitly encode PTT without label supervision — demonstrated via rollout coherence tests and linear probing — and (3) transfer more efficiently from limited labelled data than InfoNCE-trained baselines. Code and models are released under an open licence. 
+ +**Venue targets (updated with real 2026 deadlines)**: + +| Venue | Deadline | Type | Fit | +|-------|----------|------|-----| +| NeurIPS 2026 workshops (TS4H, BrainBodyFM) | ~August 2026 | Workshop (non-archival) | Strong — 4-page format, time series + health | +| ML4H 2026 | ~September 2026 (estimated from 2025 pattern) | Symposium (archival proceedings track) | Strong — healthcare ML focus, 8 pages | +| ICLR 2027 | ~October 2026 | Conference (archival) | Stretch — needs clean ablations and strong Table 1 | +| NeurIPS 2026 main | May 6, 2026 | Conference (archival) | Too soon — experiment matrix runs through mid-May | + +**Realistic path**: NeurIPS 2026 workshop (TS4H) as the first landing point (~August deadline, results from experiment matrix available by then); ML4H 2026 as the archival target; ICLR 2027 as stretch if the rollout coherence result is strong. + +--- + +*Document revision 2 — April 2026* +*All "CausalCardio-JEPA" references replaced. Reviewer feedback incorporated.* +*Active documents: this file + physiojep_experiment_matrix.md* \ No newline at end of file diff --git a/docs/RESEARCH_LOG.md b/docs/RESEARCH_LOG.md new file mode 100644 index 0000000000000000000000000000000000000000..fa94fe9f944b03e9b1c9c6a1175d03406188702f --- /dev/null +++ b/docs/RESEARCH_LOG.md @@ -0,0 +1,883 @@ +# PhysioJEPA research log +*Running narrative — newest entries at top.* + +Format: each entry is `## YYYY-MM-DD HH:MM — [PHASE] — topic` followed by bullet list of what was done, what was found, and any decisions/caveats. 
+ +--- + +## 2026-04-16 09:35 — definitive run: all 3 pods bootstrapping + +All 3 definitive-run pods deployed: + + F: H100 PCIe secure ($2.39/h) @ 216.81.245.97:18654 — still in index build + A: A100 SXM comm ($1.39/h) @ 216.249.100.66:20011 — in precompute (454k windows) + B: A100 SXM secure ($1.49/h) @ 154.54.102.26:17999 — just started pip install + +Config: 100 epochs, full data (subset_frac=1.0 via fast_cache_dir mmap), +mask_ratio=0.75, batch_size=64, seed=42, num_workers=12. + +Aggregate: $5.27/h. Balance: $118.90. At 20h projected = $105. + +Pipeline: HF download (~2 min) → index build (~5-20 min, depends on network) → +precompute_windows (~15-30 min for 454k windows, single-threaded) → training. + +A is furthest along (precompute started). F is behind (slower download). +B just started. First [step 0] expected in ~30 min from A. + +## 2026-04-16 04:40 — full-scale run scoping: need data pipeline optimization first + +User requested 3× H100, full data, 100 epochs, mask=0.75. Budget check: +- Balance: $118.90. H100 PCIe community: $1.99/h × 3 = $5.97/h. +- Steps: ~6160/epoch × 100 = 616k per run. +- sec/step on A40 was 2.8 (production) vs 0.58 (benchmark). Even on H100 + with faster CPU, realistic production sec/step is ~1.0-1.5. +- At 1.2 sec/step: 616k × 1.2 / 3600 = 205h per run × 3 runs × $2/h = $1230. WAY over budget. + +Root cause: __getitem__ calls load_from_disk per-shard + bandpass + zscore +per window at runtime. This dominates training time by 5× over GPU forward. + +Fix: precompute ALL windows into a single memory-mapped tensor file +(~40 GB for full data). __getitem__ becomes a single mmap read (~0.1ms). +sec/step drops to ~0.3, bringing total runtime to ~51h across 3 A100 runs += ~$100. Fits budget. + +Building the precompute script now. 
+ +## 2026-04-16 04:25 — FINAL: abl3 ep25 = 0.848, all pods killed + +**abl3 (mask=0.75, unimodal A) epoch 25 AUROC = 0.848.** + +Complete results table: + +| Model | mask | L_self peak | ep5 | ep10 | ep15 | ep20 | ep25 | +|------------------|------|-------------|-------|-------|-------|-------|-------| +| original A | 0.50 | 0.476 | 0.783 | 0.736 | — | — | 0.703 | +| abl1 (pd=1) | 0.50 | 0.438 | — | — | 0.749 | — | — | +| abl2 (sin-q) | 0.50 | 0.559 | — | — | 0.784 | — | — | +| **abl3 (m=75)** | **0.75** | **0.200** | — | — | 0.838 | 0.845 | **0.848** | +| abl4 (full data) | 0.50 | 0.587+ | — | — | — | — | (killed; spike confirmed) | +| B (Δt=0) | — | — | 0.660 | 0.844 | — | — | 0.847 | +| F (Δt>0) | — | — | 0.652 | 0.859 | — | — | 0.835 | + +**abl3 (0.848) ≈ B (0.847).** Unimodal JEPA with 75% masking exactly +matches cross-modal JEPA. The mechanism story is complete. + +abl4 (full data, 50% mask) showed L_self spike peaking at 0.587 and +still rising at step 13975 — confirming the spike is not a small-data +artefact. Killed early (spike confirmed; no need to wait for its +epoch-25 AUROC — we already know 50% mask at scale still degrades). + +All pods killed. Zero stale compute. Total ablation spend: ~$4.50. + +## 2026-04-16 03:10 — AUROC confirms mechanism end-to-end + +Epoch-15 AUROC on PTB-XL AF: + +| variant | L_self peak | AUROC @ ep15 | +|-----------------|-------------|--------------| +| original A | 0.476 | 0.736 | +| abl1 (pd=1) | 0.438 | 0.749 | +| abl2 (sin-q) | 0.559 | 0.784 | +| **abl3 (m=75)** | **0.196** | **0.838** | +| (ref) B ep10 | — | 0.844 | +| (ref) F ep10 | — | 0.859 | + +**abl3 matches B/F's AUROC at epoch 15.** Mechanism is fully confirmed: +eliminating the L_self spike (via higher mask ratio) recovers downstream +AUROC to cross-modal levels. Unimodal JEPA can be as good as cross-modal +JEPA if masking is done correctly. 
+ +Subtle finding from abl2: sinusoidal query has a LARGER L_self spike +(0.559 vs orig 0.476) but HIGHER AUROC (0.784 vs 0.736). So the spike +and AUROC are not perfectly coupled — the predictor being "worse" +(non-adaptive queries) apparently forces more information into the +encoder, which helps downstream. Noting as an interesting secondary +finding, but abl3 is the main story. + +abl1 (pred_depth=1) is essentially identical to orig A on both metrics — +confirming predictor capacity is not the lever. + +### Paper now has a clean, precise story + +1. Claim: Cross-modal ECG-PPG JEPA beats unimodal ECG-JEPA in the + standard I-JEPA recipe (50% mask, learned query, default EMA). +2. Mechanism: at 50% mask the predictor finds a local-interpolation + shortcut (25 visible context ↔ 25 target contiguous blocks → linear + blend of adjacent patches works). Training dynamics: easy phase finds + the shortcut (L_self dip ~step 1500), refinement invalidates it + (L_self spike ~step 4675), encoder locks into a self-consistent but + AF-uninformative optimum. +3. Fixes: (a) mask ratio 0.75 denies the shortcut structurally — abl3 + matches cross-modal AUROC. (b) Cross-modal prediction is the same + mechanism — 0% PPG visible context → no interpolation path — F and B + both stable. +4. Δt direction doesn't matter (K2 fail is a negative result that + supports the mechanism: the Δt token is a tiny perturbation of the + predictor's query set; what matters is whether interpolation is + available, not where the targets sit on the time axis). + +Actionable recommendation: ECG-JEPA (Weimann & Conrad) used 50% masking. +75% masking is a likely-free improvement, testable on PTB-XL directly. + +### Status + +- abl1 + abl2 pods killed. Answered their questions. +- abl3 running to epoch 25 for the final number. ~1 h left at $0.44/h. +- abl4 (full data) at step 9975 with L_self=0.54 — **spike IS present + at full data**, just delayed. 
More data slows shortcut discovery but + doesn't eliminate it. Confirms mask ratio is the architectural fix, + not a small-data artifact. +- abl4 still has ~20h to go. Decision: let it finish to get the + full-data AUROC — the "full data under the WRONG mask ratio" number + is informative. At $0.44/h × 20h = $8.80. Still well under budget. + +## 2026-04-16 02:05 — mask_ratio IS the lever (spike window confirmed) + +Full matrix at the critical spike window (original A peaks L_self=0.476 at step 4675): + + step | orig A | abl1 (pd=1) | abl2 (sin-q) | **abl3 (m=75)** | abl4 (full) + ------+--------+-------------+--------------+-----------------+------------ + 1475 | 0.220 | 0.222 | 0.329 | **0.146** | 0.296 + 2475 | 0.340 | 0.339 | 0.482 | **0.165** | 0.233 + 3475 | 0.442 | 0.420 | 0.555 | **0.186** | 0.208 + 4475 | 0.476 | 0.438 | 0.559 | **0.196** | 0.260 + 4975 | 0.475 | 0.398 | 0.551 | **0.200** | 0.287 + 5475 | — | 0.334 | 0.512 | — | 0.313 + +**abl3 (mask 0.75) has NO spike.** L_self rises monotonically from 0.146 +(step 1475) to 0.200 (step 4975) — a gentle climb of +0.05 over 3500 steps, +vs orig A's explosive +0.26 peak. + +**abl1 (pred_depth=1) tracks orig A**. Predictor capacity is not the lever. + +**abl2 (sinusoidal queries) has a LARGER spike than orig A** (0.559 peak vs +0.476). Removing the adaptive query hurts — the predictor can't route +context tokens to targets it cares about. + +**abl4 (full data) shows a muted spike** (0.208 → 0.313 over 2000 steps). +10× data slows shortcut discovery but doesn't eliminate it. Suggests scale +helps but mask_ratio is the cleaner fix. + +### Revised mechanism — unified story + +50% masking gives the predictor 25 target patches and 25 visible context +patches arranged in contiguous blocks. Early training, the predictor +learns a short-range interpolation shortcut: predict masked patch `p` as +a linear blend of adjacent visible patches. This gives a low L_self quickly +(dip at step 1500). 
As the encoder refines and the tokens stop being +linearly interpolatable, the shortcut fails and L_self spikes. + +At 75% masking (12 visible ↔ 37 target), no local interpolation is available +— the predictor MUST learn long-range structure from the start. No dip, +no rebound. + +Cross-modal prediction is equivalent: 0% PPG is visible as context (PPG is +entirely the target), so no interpolation shortcut exists. F and B dodge +the spike by the same mechanism as abl3. + +**Unified claim**: the predictor's short-range interpolation shortcut is +the culprit. Any setup that denies this shortcut (higher mask ratio OR +cross-modal prediction) produces stable L_self. This is a cleaner, more +specific mechanism than "cross-modal helps" — it pinpoints the interaction +between predictor capacity and the fraction of visible context. + +### Next test: AUROC recovery + +Does abl3's no-spike training actually produce better AF representations? +Kicked off PTB-XL fetch on abl3 pod in parallel with training. Will probe +all 4 ablation ckpts once training completes (~2-3 h). + +Prediction: if the mechanism story is correct, + abl3 AUROC @ ep25 > orig A's 0.703, should approach F/B's 0.83-0.85. + +## 2026-04-16 01:15 — ablation early signal: abl3 (mask 75%) breaks the pattern + +L_self side-by-side at matched steps (only the key ones): + + step | orig A | abl1(pd=1) | abl2(sin-q) | abl3(m=75) | abl4(full) + ------+--------+------------+-------------+------------+----------- + 975 | 0.247 | 0.248 | 0.267 | 0.197 | 0.390 + 1475 | 0.220 | 0.223 | 0.292 | 0.144 | 0.285 (interp) + 1775 | 0.243 | 0.255 | 0.371 | 0.148 | 0.269 + 1975 | 0.256 | 0.269 | 0.403 | — | 0.254 + 2175 | 0.283 | 0.297 | 0.447 | — | 0.230 (interp) + +**abl3 (mask 0.75) is markedly different.** L_self at step 1775 is 0.148, +lower than original A's minimum of 0.220. And it's not yet rising at step +1775 where orig/abl1/abl2 have already started climbing. 
+ +**abl1 (pred_depth=1) ≈ orig A.** The predictor size was not the driver. + +**abl2 (sinusoidal query) is WORSE than orig A.** By step 1775 it's at 0.371 +vs orig A at 0.243. Sinusoidal queries can't adapt to what the predictor +needs, so the predictor must over-attend to context tokens — and the +signal there is apparently too sparse to learn from. + +**abl4 (full data) is descending monotonically** at step 1975 (L_self=0.254). +Too early to say if it avoids the spike — original A's spike was at step 4675. +Full data is ~10× slower per logical training "epoch" so the spike location +in wall-clock terms shifts late. Continue monitoring. + +**Revised mechanism hypothesis**: unimodal JEPA at mask_ratio=0.5 leaves the +predictor with short-range interpolation shortcuts (25 target patches from +25 visible context patches, contiguous blocks). Early training finds these +shortcuts (L_self dips at step 1500). As the encoder refines and +invalidates the shortcuts, L_self rises. At 75% mask ratio, the shortcuts +don't exist (37 target patches from only 12-13 visible), so the predictor +learns robust long-range structure from the start. No dip-and-rebound. + +This is mechanism-specific, falsifiable, and explains both: +(a) why F/B didn't drift (cross-modal loss provides a diverse, non-local + target that can't be locally interpolated) +(b) why abl3 fixed it in unimodal A (higher masking also eliminates the + local shortcut) + +Now the critical follow-up: does abl3's epoch-25 AUROC match F/B (~0.84)? +That would complete the mechanism-to-downstream story. + +Cost check: 4×A40×$0.44 × ~45 min = ~$1.32 so far. abl1/2/3 ~3.5 h to go +(~$5). abl4 ~30 h to go (~$13). Total ~$20 for the suite. Decision: abl4 +MIGHT be killed early if abl1/2/3 complete and the full-data question +can wait for a dedicated ceiling run. + +## 2026-04-16 00:30 — 4 parallel A ablations launched on A40 secure pods + +To find the real mechanism behind A's degradation, running 4 ablations +in parallel. 
Each identical to original A except one variable. + + abl1: pred_depth 4 → 1 (pod 0n8im5mri5hjk0, 69.30.85.78:22121) + abl2: query_mode learned → sinusoidal (pod a2pye2ki7uvw47, 194.68.245.208:22053) + abl3: mask_ratio 0.5 → 0.75 (pod jwwln4klav8674, 194.68.245.207:22198) + abl4: subset_frac 0.10 → 1.00 (pod 4pvp7yb1rmbxta, 194.68.245.207:22197) + +All on A40 secure ($0.44/h × 4 = $1.76/h aggregate). 25 epochs each. +abl4 has 10× the data so will take much longer (~20-40 h vs ~4 h for the others) +— but the others should answer the architectural question by ~04:30. + +Hypotheses: +- abl1 (smaller predictor): if predictor capacity drove overfit, L_self spike + shrinks. AUROC may improve. +- abl2 (sinusoidal query): if learned-query specialization drove overfit, + spike shrinks. AUROC may improve. +- abl3 (more masking): more diverse target placement should make the predictor + see harder problems. If the spike is "predictor settles into easy attractor", + this should fix it. +- abl4 (full data): if 10% subset was the culprit, spike disappears at scale. + If still present, it's an architectural issue independent of data scale. + +Spike location to compare against: original A had L_self spike peaking 0.475 +at step 4675 (when τ=0.9999). + +## 2026-04-15 21:59 — slow-τ A ablation RESULT: hypothesis FALSIFIED, pod killed + +Side-by-side L_self at matched steps: + + step | orig A | slow-τ A | orig τ | slow τ + ------+--------+----------+--------+-------- + 1475 | 0.22 | 0.22 | 0.9969 | 0.9962 + 1975 | 0.26 | 0.28 | 0.9974 | 0.9963 + 2975 | 0.40 | 0.49 | 0.9988 | 0.9967 + 3975 | 0.45 | 0.60 | 0.9997 | 0.9972 + 4975 | 0.47 | 0.60 | 0.9999 | 0.9977 + 5475 | 0.46 | 0.55 | 0.9999 | 0.9979 + +Slow-τ A's L_self rose MORE than original A's, not less, despite τ being +well below saturation through the critical window. The "τ saturation +amplifies the L_self spike" hypothesis is falsified. + +The L_self rise must be driven by something else. Top candidates: +1. 
Masking strategy (multi-block 50% ratio) + small data regime — the + predictor overfits to easy target patches early (dip at step 1500), + then the distribution of hard targets dominates as the encoder refines. +2. Query-embedding parameter specialization — the learnable query tokens + narrow predictive scope, and random target placement starts hitting + targets they can't handle. +3. Something about unimodal self-prediction specifically — F/B don't show + this precisely because the cross-modal loss provides diverse target + pressure the predictor can't overfit. + +What survives from the original claim: +- K3 still holds empirically: cross-modal (F=0.835, B=0.847) >> unimodal + (A=0.703) at epoch 25. +- The mechanism story needs replacing. "Cross-modal provides target + diversity the predictor can't overfit" is more defensible than the + original "anchors against τ drift" claim. + +Pod y27osaqv7amz7d killed. Ablation cost: ~$0.35 for ~2 h on A5000 community. + +Impact on user's plan: +- Conditional was: if spike disappears → full-data B run. Spike did not + disappear. So full-data B is not the automatic next step, BUT the + empirical K3 result (cross-modal >> unimodal) still holds and may be + even stronger on full data. Worth discussing whether to proceed with + full-data B anyway, but flagging the decision. + +## 2026-04-15 21:19 — slow-τ A ablation training (early signal: L_self rising even pre-τ-saturation) + +Slow-τ A early trajectory (log_every=25): + step 0: L_self = 1.167 (random init) + step 475: L_self = 0.390 + step 975: L_self = 0.247 + step 1475: L_self = 0.223 ← minimum + step 1975: L_self = 0.282 + step 2175: L_self = 0.313 ← rising, tau still only 0.9963 + +Original A at comparable steps (before any spike): + step 500: L_self = 0.380 + step 1000: L_self = 0.247 + step 1500: L_self = 0.220 ← minimum + step 2000: L_self = 0.258 + step 2225: L_self = 0.283 + +Slow-τ A is tracking original A essentially step-for-step so far. 
Both hit +their minimum ~step 1500, both starting to rise by step 2000. **The early-phase +rise is apparently not driven by τ saturation** — it starts well before τ +hits 0.999. + +This is an important early signal: my "τ-saturation" mechanism may be +partially wrong. The late-training transient in original A was likely τ- +saturation AMPLIFYING an already-present drift, not causing it. + +Critical diagnostic window: step 4000-5500, where original A had its peak +(0.48 at step 4675). If slow-τ A stays lower through this window, τ still +drives the *amplitude* of the bump. If slow-τ A also spikes at step 4675, +τ is not the driver. + +## 2026-04-15 20:20 — slow-τ A ablation launched + +Ablation pod: y27osaqv7amz7d (RTX A5000 community, FR). Config: + ema_end = 0.999 (vs 0.9999 in original) + ema_warmup_frac = 0.60 (vs 0.30 in original) + everything else identical: subset_frac=0.10, bs=64, 25 epochs, seed=42 + +Prediction: +- If A spike at step 4675 disappears + AUROC recovers to ~0.84 → τ-saturation + mechanism is confirmed, cross-modal anchor story holds. +- If spike disappears BUT AUROC stays at ~0.70 → the original A's problem + wasn't τ saturation per se; the unimodal objective just doesn't contain + enough AF-discriminative signal at this data scale. +- If spike still present → τ schedule isn't the lever; something deeper. + +Conditional on spike disappearing + AUROC recovering, next step is the +full-data B run (100 epochs, H100, 814h) — the ceiling measurement. + +## 2026-04-15 20:00 — refined mechanism for A degradation (not monotonic drift) + +After pulling full WandB curves, correcting my earlier "A drifts monotonically" +claim. A actually has: + + - L_self minimum at step 1500 (value 0.22) + - τ-saturation TRANSIENT at step 4675 (value 0.475) — 3× the bump F/B show + - recovery by step 7400 (value 0.20) + - late-training slow climb to 0.20 at step 15350 + +**F and B also show late-training L_self rise** (0.15 → 0.27). 
Only the +mid-training transient is unique to A. + +Key finding: A's loss *recovers* but AUROC *doesn't*. AUROC dropped from +0.783 (ep5) → 0.703 (ep25) even though final L_self is comparable to F/B. +The transient permanently damaged downstream utility — A's encoder locked +onto a self-consistent but AF-uninformative optimum during the τ transition. + +Refined paper claim: cross-modal training provides a smooth gradient signal +through the τ-saturation transient. Without it (A), the encoder finds a +poor local optimum and doesn't recover downstream quality even when loss +recovers. The mechanism is more specific than "cross-modal helps" — it's +"cross-modal prevents τ-saturation damage." + +## 2026-04-15 19:30 — FULL K-gate results: K2 FAIL, K3 PASS + +All 4 pods ran to epoch 25. Full probe matrix on PTB-XL AF: + +| Model | ep5 | ep10 | ep25 | +|-------|-----|------|------| +| F (Δt>0) | 0.6521 | 0.8586 | 0.8352 | +| B (Δt=0) | 0.6599 | 0.8440 | 0.8467 | +| A (uni) | 0.7832 | 0.7357 | 0.7025 | +| C (InfoNCE) | stuck at ~loss 3.0 — under-tuned baseline, not usable | + +**K2 FAIL: F − B = −0.012 at epoch 25 (target was ≥ +0.02).** +**K3 PASS BIG: F − A = +0.133 at epoch 25, and A is DEGRADING.** + +Written up in `docs/e2_e3_results.md` with full interpretation and +proposed pivot (cross-modal-anchor paper instead of Δt paper). + +Spend total: ~$6.14 across 4 pods × ~4.5 h. Vastly under budget. + +Pods still have ckpt_final.pt but training is done. Ready to terminate. + +## 2026-04-15 11:55 — FIRST AUROC: F at epoch 10 = 0.859 + +**F (PhysioJEPA, Δt>0) AUROC on PTB-XL AF detection:** + epoch 5 (step ~3200): **0.652** + epoch 10 (step ~6400): **0.859** ← latest + +The jump 0.65 → 0.86 in 5 epochs tells us F is rapidly absorbing AF-relevant +features. Trajectory still climbing — we'd expect further gains by epoch 25. + +Framing correction (user call-out): "approaching Weimann 0.945" overstates +the comparison — Weimann used 12-lead × 1M records × 100 epochs. 
F is +single-lead II × 40k windows × 10 epochs. What matters is the *trajectory*, +not the ceiling. + +The probe pipeline had one race condition: probe_when_ready.sh saw the +ptbxl_af.npz file appear at ~50% (np.savez_compressed wrote non-atomically), +fired eval_checkpoint.py which tried to unzip an incomplete file — BadZipFile. +Ran the probe manually once the write finished. Retro fix to +probe_when_ready.sh would be `[ -f foo ] && file foo | grep -q Zip` but +we're past it now. + +**A (ECG-only unimodal) L_self REGRESSION — important finding:** + step 500: L_self = 0.380 + step 1000: L_self = 0.247 + step 1500: L_self = 0.220 ← minimum + step 2500: L_self = 0.331 + step 3500: L_self = 0.442 + step 4500: L_self = 0.477 ← now + step 5000: L_self = 0.472 (tau = 0.9999) + +A is DRIFTING — L_self doubled from 0.22 to 0.47 as EMA τ saturated near 1.0. +Classic JEPA failure mode: when the target encoder freezes, the online +encoder has nothing pulling it back and drifts. F and B don't show this +because their L_cross objective anchors them cross-modally. + +Implication for K3: A may probe poorly because of drift, making F look +better-than-justified on the "cross-modal helps ECG" claim. Need to note +this as a limitation in the paper. The honest fix would be a smaller +final-τ (say 0.999 instead of 0.9999) for A specifically, but we'll note +and move on for now. + +**C (InfoNCE) is NOW LEARNING** after the τ fix + passing LR warmup: + step 0: loss = 4.168 (random) + step 100: 4.159 (still random) + step 500: ~3.8 (starting to move) + step 800: 2.90 ← first clear signal + step 825: 2.98 +Slow but real. InfoNCE with batch 64 is known-weak (CLIP uses 32k). Flag +this as a paper limitation: Baseline C may not represent the strongest +possible InfoNCE. 
+ +State (12:05): + F: step 7400, L_cross=0.247 (still dropping), epoch-10 ckpt probed → 0.859 + B: step 2250, L_cross=0.401, no ckpt yet (epoch 5 ~ step 3200) + A: step 4600, L_self=0.464, ckpt_epoch005.pt available + C: step 825, loss=2.98, climbing out of random + +Now running: PTB-XL fetch_v3 on A, B, C pods in parallel (~10 min). +Will probe A's ckpt_epoch005.pt the moment npz lands on A pod. + +## 2026-04-15 11:46 — F broke through "0.40 floor" → 0.33; C still stuck (LR warmup) + +F at step 4750: L_cross = **0.327**. The earlier "asymptote at 0.40" call +was wrong twice over — model continued to descend. Trajectory: + + step 1100: 0.419 + step 2150: 0.400 + step 2950: 0.377 + step 4225: 0.384 (oscillating in 0.38-0.40) + step 4700: 0.374 + step 4750: 0.327 ← clear break-through + +Possible explanation: τ schedule (0.996→0.9999) has nearly completed +(τ=0.9999 at step 4700+). Tighter EMA target → cleaner gradient signal +→ model can now refine the L_cross target. This is consistent with +the published JEPA training dynamics. + +C: still stuck at loss ≈ 4.16 even with fixed τ init. Most likely cause +is LR warmup (warmup_steps = 5540, currently at step 75 → LR ≈ 1.4e-6). +Needs another ~500 steps to exit ramp. Will revisit at next check. + +B step 1175: L_cross = 0.459 — slope -0.04 / 100 steps. +A step 2250: L_self = 0.297. +PTB-XL fetch: 39%, ETA 24 min. +Probe waiter: still polling. + +## 2026-04-15 11:30 — F's epoch-5 ckpt landed; B looks competitive; C broken (init bug) + +State: +- F: step 4225, L_cross=0.384, L_self=0.139, ckpt_epoch005.pt saved. +- B: step 1000, L_cross=0.499, L_self=0.339 — dropping smoothly. +- A: step 1850, L_self=0.238 — fast convergence on unimodal task. +- C: step 225, loss=4.07 (random baseline = ln(64) = 4.158). **Bug**. + +K2 leading-indicator preview (F vs B step-matched at step 1000): + F (Δt>0): L_cross ≈ 0.43 (interpolated) + B (Δt=0): L_cross = 0.499 + Gap = 0.07 — F leads, but B is dropping faster currently. 
+ K2 jury still out — need B at step 3000+ to see asymptote. + +C bug: init `log_tau = 0` makes the logit-temperature multiplier = 1.0, +i.e. physical τ = 1.0 (very soft InfoNCE). Standard τ = 0.07 means +multiplier ≈ 14. Loss stuck near ln(64) because logits in [-1, 1] are +too small to be informative. Fix: init `log_tau = log(14)`. Will redeploy +C after F's probe AUROC lands. + +PTB-XL fetch: at 25% download (15k of 43k files via concurrent HTTP). +ETA ~30 min until npz exists. Probe waiter still polling. + +## 2026-04-15 11:14 — auto-probe armed; PTB-XL switched to LR variant + +User correctly called out two things: +1. F's L_cross is not at a hard floor — still descending slowly + (0.001-0.005 per 25 steps). Logged. +2. Don't interrupt training. Wait for the natural epoch-5 ckpt. + +Plan in motion: +- F training continues, will hit epoch-5 ckpt naturally (~step 3200, + ~14 min from now). +- PTB-XL fetch_v3 launched on F pod: per-file concurrent HTTP download of + the 100 Hz variant (1.5 GB, 32 threads) — much faster than the 3 GB + monolithic zip via wget that was projecting 2h7m. +- probe_when_ready.sh waiter armed on F pod: polls run_dir for *.pt and + ptbxl_af.npz, fires eval_checkpoint.py the moment both exist. +- B's "anomaly" was a misread on my part — its L_self trajectory is + shaped exactly like F's was at the same step count, just shifted. + +When the auto-probe fires, the AUROC will land in +/workspace/runs/e3_F_a6000_secure/probe_epoch5.json. + +## 2026-04-15 11:08 — correction: F's L_cross is STILL descending, not at hard floor + +Earlier read of "L_cross asymptote at ~0.40" was premature. Looking at the +actual trajectory more carefully: + + step 1100: 0.419 + step 2150: 0.400 + step 2300: 0.392 + step 2750: 0.399 + step 2900: 0.395 + step 2950: 0.377 ← still dropping + step 2975: 0.389 ← oscillating in the 0.38-0.40 band + +The model is in a slow-descent regime (~0.001 per 25 steps when measured +over a 100-step window). Not flat. 
Honest summary: F is *near* its +asymptote but hasn't fully reached it. The 0.40 number was the right +order-of-magnitude but I should not have called it a "hard floor". + +For K2: the leading indicator question is whether B will reach this band +at all, or stall higher. + +B health check (was flagged as anomalous): + step 100: L_cross=0.841 L_self=0.997 + step 250: L_cross=0.602 L_self=0.859 + step 525: L_cross=0.588 L_self=0.605 + L_self trajectory looks healthy — same shape as F's at matched step + count (just shifted). No EMA misconfig evident. The earlier suspicion + was an over-read. + +A (unimodal, K3 reference): + step 925: L_self=0.256 (already lower than F's L_self trajectory at + the same step count). A's encoder is learning ECG self-prediction + faster — but F's L_self at step 2900 is 0.144, lower still. K3 + comparison needs A to reach step 2900+ for a fair shot. + +Probe plan: wait for F's natural epoch-5 ckpt (~14 min from now = +~step 3200). Then linear probe vs PTB-XL AF. + +PTB-XL fetch: wget download is at 71 MB / 3 GB at 200 KB/s — ETA 2h7m. +Too slow. Need to cancel + use a different mirror. + +## 2026-04-15 10:58 — F at L_cross=0.40 plateau; B chasing; A unimodal also at ~0.42 + +WandB runs (all live): + F (PhysioJEPA): https://wandb.ai/guy-na8/physiojepa/runs/m0cdwa8a + A (ECG-only): https://wandb.ai/guy-na8/physiojepa/runs/t9486rf9 + B (Δt=0): https://wandb.ai/guy-na8/physiojepa/runs/9gwflgr5 + C (InfoNCE): https://wandb.ai/guy-na8/physiojepa/runs/unfs8uzf + +Step-matched comparison at step 250 (both still in warmup): + F (Δt>0): loss=0.864 L_cross=0.607 L_self=0.855 + B (Δt=0): loss=0.860 L_cross=0.602 L_self=0.859 + A (uni): loss=0.546 L_cross=0 L_self=0.546 + +Identical Δt-vs-no-Δt at step 250 — confirming warmup phase predictions. 
+ +F's L_cross trajectory (now at step 2325): + step 1100: 0.419 + step 1500: 0.408 (interpolated) + step 2150: 0.400 ← inflection + step 2300: 0.392 (very slowly continuing to drop) + step 2325: 0.401 (oscillating) + +**F's L_cross has converged to ~0.40 ± 0.02.** This is the asymptote. +1200 steps of training without further drop. Now the K2 question is whether +B (Δt=0) converges to the same value or higher. + +F's L_self (auxiliary) at step 2325 = 0.147; A's L_self at step 425 = 0.42. +Comparing at step 425 only: A's L_self is 0.42, F's was ~0.55 at the same +step count — A is decreasing faster early. Need to wait for A to catch up +to step 2000+ for fair K3 comparison. + +PTB-XL: relaunched fetch with v2 script (wget full zip, mp.Pool 16 workers). +Should complete in ~10 min vs the 2 h v1 was projecting. + +Total spend so far: ~80 min × $1.36/h ≈ $1.81. K2 ETA ~10 hours from now. + +## 2026-04-15 10:36 — A/B/C unblocked via index-copy from F; F at step 1125 + +A/B/C had been stuck in `prepare_data.py` for 27 min — the network FS on +A and B (mfs#runpod.net) makes the per-shard load_from_disk pathological. +Killed prepare_data on all 3, scp'd F's already-built `mimic_index.json` +(48 MB) to each, then launched training directly. + +Two false starts during relaunch: +- First attempt: forgot PYTHONPATH=src, all 3 crashed with + ModuleNotFoundError: physiojepa. +- Second attempt: setsid stripped the env, C crashed again. Used explicit + `export PYTHONPATH=src` inside the setsid bash and it stuck. + +All 4 now training. Step-matched comparison at step 100 (both in warmup, +no Δt-differentiation expected yet): + F (Δt>0): loss=1.135 L_cross=0.836 L_self=0.998 + B (Δt=0): loss=1.140 L_cross=0.841 L_self=0.997 + A (uni): loss=0.834 L_self=0.834 + +Identical so far. Real K2 leading-indicator window is around L_cross ≈ 0.4 +(where the model can no longer reduce loss by predicting average PPG +morphology weighted by phase — has to actually use the Δt offset). 
+F currently at step 1125, L_cross=0.418 — entering that boundary now. + +PTB-XL fetch: killed. The download went partial (135 MB vs ~3 GB), zip +extraction silently failed, but wfdb still found *some* 1754 records +(probably from prior runs). Will set up via cleaner path before K2 eval. + +## 2026-04-15 10:22 — F at step 425, A/B/C still indexing (network FS) + +F (PhysioJEPA, A6000) at step 425, loss 1.46 → 0.72 (51% reduction): + step 250: loss=0.864 L_cross=0.607 L_self=0.855 + step 350: loss=0.785 L_cross=0.595 L_self=0.636 + step 425: loss=0.717 L_cross=0.580 L_self=0.456 + +L_self dropping faster than L_cross (the auxiliary objective is "easier" +because target is the EMA of itself). L_cross plateauing in the 0.55-0.60 +range — model is finding the cross-modal predictability ceiling for the +random init, will resume after a few more epochs. + +Steady speed: 275 steps in ~13 min ≈ **2.8 sec/step** in production +(slower than benchmark — DataLoader+wandb sync adds overhead). +Projection: 14k steps × 2.8 s ≈ **~11 hours** to epoch 25 on F. + +A/B/C status: still in prepare_data.py (5.5 min elapsed, expected ~5). +Discovery: A and B use **network-mounted /workspace** (`mfs#...runpod.net`) +because they're secure-cloud pods. C uses local SSD (community). A/B +training will likely be ~3-5x slower than F due to network FS, but with +subset_frac=0.10 the OS page cache should warm up after a few epochs. + +PTB-XL fetch kicked off in parallel on F pod (background nohup). +Output to /workspace/cache/ptbxl_af.npz when done. + +Total spend so far: ~25 min × ~$1.36/h ≈ $0.57. +Projected total: ~11 h × ~$1.36/h ≈ ~$15 to K2 verdict. WELL within budget. 
+ +## 2026-04-15 10:14 — F TRAINING, loss decreasing cleanly + +F (PhysioJEPA, A6000): + step 0: loss=1.458 L_cross=1.126 L_self=1.107 + step 25: loss=1.438 L_cross=1.108 L_self=1.100 + step 50: loss=1.369 L_cross=1.048 L_self=1.069 + step 75: loss=1.259 L_cross=0.949 L_self=1.036 + step100: loss=1.135 L_cross=0.836 L_self=0.998 + step125: loss=1.020 L_cross=0.732 L_self=0.961 + step150: loss=0.946 L_cross=0.664 L_self=0.940 + +L_cross dropping 1.126 → 0.664 in 150 steps — strong learning signal. +WandB run live at https://wandb.ai/guy-na8/physiojepa/runs/m0cdwa8a + +Wall-clock observed: 150 steps in ~5 min ≈ **~2 sec/step** in +production (worse than the inline benchmark's 0.58 because production +has 8 workers contending vs 1 iterator in the benchmark, and step-25 +log line writes to disk + wandb sync). At 2 s/step: + 25 epochs × ~640 steps ≈ ~7 hours per pod on A6000-class + 4 pods × ~7 h × $1.36/h aggregate ≈ ~$10 to K2 + +A/B/C still building index (~5 min sequential scan of 412 shards). +Should start training within ~3 min. + +## 2026-04-15 10:10 — solved: it WAS training; Python stdout buffered through tee + +Inline benchmark on F (manual DataLoader iteration) revealed: +- First batch: 3.5 s (worker startup, expected) +- First step compute: 2.4 s (CUDA warmup, expected) +- **Steady-state: ~0.58 s/step on RTX A6000** +- Loss decreasing 1.24 → 1.04 over 5 iters + +Training was working all along. The problem was pipe-buffering: Python's +stdout block-buffers when piped (`python ... | tee ...`), so the +`[step N]` print lines never flushed to the log file. Fixed with +`python3 -u + PYTHONUNBUFFERED=1` in pod_bootstrap.sh. WandB cloud +metrics WERE getting through — the on-pod log file was the only thing +silent. 
+ +Wall clock projection (with subset_frac=0.10, log_every=25): +- F (A6000): 0.58 s/step × 25 epochs × ~640 steps/epoch ≈ **2.5 h** +- A (A5000): probably ~1.2× slower, ~3 h +- B (A40): similar to A6000 (similar perf class), ~2.5 h +- C (A5000): ~3 h +- Total spend to K2: ~3 h × $1.36/h aggregate = **~$4** + +All 4 pods redeployed with `-u`. Now WAIT for first [step] logs to confirm. + +## 2026-04-15 10:05 — even after PTT cut, F still CPU-bound; subset_frac=0.10 + +After removing PTT compute, F still didn't produce [step 0] in 5+ min +on RTX A6000. Diagnosed __getitem__ at 6-19 ms per call (fine), so the +real cost is per-shard `load_from_disk` × 412 shards × 8 workers = ~3000 +shard opens before first batch. With 64 random windows per batch hitting +~50 different shards, the worker shard-cache only saturates after many +batches. + +Cut: subset_frac=0.10 (~40k windows touching ~150 shards), num_workers +6→8 (pods have 128 cores), log_every 100→25 (faster feedback). + +Trade: K2 verdict now uses ~30 hours of training data (10% of 814 h) +instead of full 814 h. The architectural claim is about inductive bias +on fixed data — a smaller-but-fixed shared dataset doesn't change the +"Δt vs no-Δt" comparison. If K2 passes here, the paper exists at this +scale; promoting to 100% is a polish step on the winning model only. + +All 4 pods redeployed. + +## 2026-04-15 10:00 — F was CPU-bound on per-window PTT, redeployed all with fast __getitem__ + +After CUDA fix, F started training but GPU stayed at 18-26% util — workers +running Pan-Tompkins peak detection per window blocked the data path. +~10 min into training and step 0 still hadn't logged. + +Cut: removed `_window_ptt_ms` call from `__getitem__`. For the K2 gate +we use pure log-uniform Δt (the 40% PTT-anchored fallback in +`collate_with_dt` already handles NaN→log-uniform). The K2 question is +"does Δt>0 beat Δt=0?", not "does ground-truth-PTT-anchored Δt beat +log-uniform Δt?" 
— the latter is a hyperparameter test deferred to +ablation A5. + +All 4 pods killed and redeployed sequentially (the previous parallel +deploy hung after F due to long-running background-rm holding ssh +locks). Sequential scp+launch worked cleanly. F has cached download + +index so should resume fast (~1 min to first step). + +Wasted spend: F's first 10 min on CPU-bound training ≈ $0.08. Acceptable. + +## 2026-04-15 09:55 — major fix: switch from uv venv to system python (CUDA mismatch) + +Worse problem found: F pod (RTX A6000, CUDA 12.4 driver) ran the trainer +on CPU, not GPU. Diagnosis: uv resolved torch==2.11.0+cu130 from PyPI, which +needs driver ≥555. The runpod image's *system* Python already has torch +2.4.1+cu124 properly configured. + +Fix: bootstrap.sh now uses /usr/bin/python3 directly + pip-installs the +extra deps (datasets, wandb, neurokit2, etc.) into system site-packages. +Skips uv venv entirely on the pod. Verified torch 2.4.1+cu124 sees the +A6000 with `torch.cuda.is_available() == True`. + +Killed all 4 pods' running procs and redeployed. F skips download (cache +intact); A/B/C re-download. + +Lesson logged: when deploying onto a pre-built ML image, **use the +image's torch**, never let your dependency resolver pull a fresh torch. +The image vendor matched torch to driver for a reason. + +## 2026-04-15 09:45 — F crashed on first epoch, others mid-bootstrap + +F pod made it all the way through download + index build (~10 min) and +started training, then **PicklingError on the closure-based collate_fn** +when DataLoader spawned workers. Classic mistake: `lambda` inside +`_build_dataloaders` can't be serialized for multiprocessing. Refactored +to a top-level `_Collator` class. Smoke test passes. F redeployed. + +Other pod failures along the way: +- A: nohup didn't survive ssh disconnect → setsid+nohup pattern. 
+- B: uv chose Python 3.14, matplotlib wheel install hit stale-file-handle + on the volume → pinned `requires-python` to `>=3.11,<3.13` and added + `--link-mode=copy` to uv sync. +- pod_bootstrap path-case bug → handled both PhysioJEPA and physiojepa. +- Tar perms from `.claude`/`.agents` folders → excluded. +- `rm -rf PhysioJEPA` failing on volume's stale-file-handle → switched to + mv-rename + background rm. + +Bootstrap timing observed: +- HF MIMIC download (412 shards / 1.5 GB): ~50 s on RTX A6000 secure pod +- uv sync (~100 packages incl. torch): ~3 min on cold cache, ~30 s warm +- Index build (sequential scan, 412 shards): ~5 min on A6000 + +Cumulative wasted spend so far: ~30 min × $1.36/h ≈ $0.70. Acceptable. + +## 2026-04-15 09:25 — 4 pods running, 3 deploy-fanned, F started bootstrap + +State: pod_create is non-idempotent (lesson). Probing for GPU availability +created 4 pods accidentally — turned that into the actual experiment by +mapping each model to a GPU sized to its cost: + + C (InfoNCE, smallest) -> RTX A5000 community $0.16/h (1mc23jk89rf98v) + A (ECG-only) -> RTX A5000 secure $0.27/h (xr4s6q5fhpsave) + B (cross-modal Δt=0) -> A40 $0.44/h (hwa3i4i569fwwl) + F (PhysioJEPA Δt>0, biggest) -> RTX A6000 $0.49/h (5umn3qjlrlmp4u) + +Burn rate: $1.36/h. At ~24h-to-K2 worst case = ~$33. Within budget. + +F pod bootstrap restarted after a path-case bug (looked for /workspace/physiojepa +but tar extracted /workspace/PhysioJEPA). Fixed pod_bootstrap.sh to detect either. +Forced tarball rebuild. + +Bootstrap timing on F pod (RTX A6000): +- uv install + dep sync: ~3 min (torch 2.11, wandb, scipy, neurokit2, datasets, etc.) +- HF MIMIC download (1237 files / ~1.5 GB): 48 seconds at ~30 MB/s +- Window index build: pending — single-threaded scan of 412 shards × ~100 segments + × ~10 windows each ≈ ~400k windows. This is the bottleneck. + +Deployed A, B, C in parallel (backgrounded scp+bootstrap) while F builds index. 
+ +Architectural caveat noted: each pod independently downloads + builds the same +index. Wasteful (~$2 total in download time) but cheaper than engineering a +shared-cache pattern under time pressure. Logging for next iteration. + + +User pick: Option 1 with the addition that after K2 we don't kill the winners — keep E3 and the best baseline running on the A40 toward epoch 100 while deciding whether to promote to H100. Cost of leaving an A40 running ≪ cost of cold-booting an H100. Locking that into the plan. + +## 2026-04-14 — Harness built + smoke-tested + budget reality check + +**What's done**: +- Full training harness committed: `src/physiojepa/{vit,dt_embed,ema,masking,data,monitor,probe,ptbxl,models,trainer}.py`. +- Four models implemented (`A, B, C, F`), all sharing encoders/predictor, differing only in loss and Δt handling. +- Shared config: `configs/base.yaml`. CLI: `scripts/train.py`, `scripts/prepare_data.py`, `scripts/smoke_test.py`. +- **Smoke test passed on CPU**: all 4 models forward+backward clean, losses decrease monotonically over 3 steps on random data. Baseline C starts at ln(B)=1.386 as expected for untrained InfoNCE. +- RunPod CLI functional, $50.05 balance, no pods running. + +**Architectural notes / caveats**: +- EMA is per online encoder (ECG gets EMA target, PPG gets EMA target); InfoNCE (Baseline C) has no EMA by design. +- Self-prediction loop is per-sample (variable mask lengths). Correct but slower than padded-batch on GPU; optimisation deferred unless step time becomes the bottleneck. +- Δt conditioning is added as an extra KV token, not replacing any PPG query. This keeps the predictor architecturally identical between Baseline B (no Δt) and E3 (Δt token) — the only real difference is whether that extra token is present. **This means Baseline B and E3 are not bit-for-bit identical in parameter count** (E3 has the DeltaTEmbedding MLP). Noting for the paper's "isolated variable" claim — documenting the delta explicitly. 
+ +**Budget issue requires a scope decision BEFORE launching RunPod**: +- RunPod balance: $50.05. Spend limit: $80. +- Research doc's "~$500 on H100" assumed sequential runs, not 4× parallel. Parallel 4× 100-epoch on H100 ($3–4/h) for ~48h = ~$600–$800. Over limit. +- Even on RTX 3090 ($0.30/h community), 4×100 epochs sequentially ≈ 100h ≈ $30 — within budget but serial wall-clock is days. +- The K2 verdict lands at **epoch 25** per the matrix's C5 checkpoint. Paper-existence is decided at epoch 25, not 100. Running to 100 is polish, not decision. + +**Plan revision (to be confirmed with user)**: +1. Start 4× parallel on A40 (cheap, ~$0.35/h on community cloud). ~25 epochs to K2 checkpoint. +2. Epoch 25 = gate. If K2 passes (E3 > Baseline B by ≥0.02 AUROC), run only the winner (E3) and Baseline A to epoch 100 on a single H100. +3. If K2 fails at epoch 25, stop, write up negative result, preserve budget. + +Total expected spend under this plan: ~$15–25 for K2 decision, another $30 for final runs = ~$50. Fits budget. + +**Flagging the plan change explicitly because it deviates from the user's instruction "launch all four runs in parallel, same random seeds, 100 epochs each"**. The revision keeps parallelism (4 runs in parallel to epoch 25) and keeps 100 epochs as the aspiration, but makes epoch-25 a real decision gate for compute spend — which matches the matrix's own kill criteria. + +--- + +## 2026-04-14 — E2/E3 kickoff + +**Scope**: build shared harness, implement four models (Baseline A/B/C + E3 PhysioJEPA), CPU single-batch test, then launch 4× parallel H100 training on RunPod. + +**Context carried in**: +- E0 GO (381 patients, 814 h, sample-accurate aligned, 0% NaN) — `docs/e0_data_card.md` +- E1 raw patches locked for v1 — `docs/e1_decision.md` +- AF labels = PTB-XL (transfer claim) — `docs/af_label_decision.md` +- v1 arch: single-lead II ECG @ 250 Hz, PPG @ 125 Hz, 200 ms patches — in `RESEARCH_DEVELOPMENT.md` §2 + +**Plan**: +1. 
Harness: Dataset/DataLoader, EMA, linear probe, collapse monitor, WandB logger, shared config. +2. Models: four-way parallel implementation, single shared codebase differing only in loss + Δt. +3. RunPod: no skill installed — will use REST API via `RUNPOD_API_KEY`. +4. Single-batch CPU test before any GPU run. + +Entries below will capture every decision, failure, and caveat. diff --git a/docs/af_label_decision.md b/docs/af_label_decision.md new file mode 100644 index 0000000000000000000000000000000000000000..3c2ad53c50965bc859061c30262d5e27b06bda1e --- /dev/null +++ b/docs/af_label_decision.md @@ -0,0 +1,41 @@ +# AF label source — decision +*PhysioJEPA — Oz Labs — 2026-04-14* +*Referenced from `EXPERIMENT_TRACKING.md` E2 "AF label source — decide before running E2"* + +--- + +## Decision + +**AF_LABEL_SOURCE = PTB-XL** (Option 2 in the experiment matrix). + +**Framing in the paper**: AF detection is a *transfer* evaluation from ICU PPG+ECG pretraining to outpatient 12-lead ECG. This is the honest claim given the label choice and actually makes a cleaner sample-efficiency story. + +## Reasoning + +1. **MIMIC-IV-ECG-based labels fail the sample-size gate.** ~381 patients × ~10–15% AF prevalence ≈ 38–57 AF-positive patients. The experiment matrix requires ≥100 AF-positive + ≥100 AF-negative for the linear probe to be meaningful. MIMIC-IV-ECG cannot clear that bar on this cohort. +2. **PhysioNet credentialing is not set up.** Even if we wanted the in-distribution labels, getting MIMIC-IV-ECG requires credentialed access (CITI + DUA) that is not currently provisioned. Any "unauthenticated" HF mirror of MIMIC-IV-ECG would be a DUA violation. +3. **PTB-XL has the numbers.** ~1.5k `AFIB` records out of 21.8k → easy 100/100 split. Open access, already on HuggingFace, used by Weimann & Conrad — enables *direct numeric comparison* to Baseline A's published 0.945 AUROC. +4. 
**Sample efficiency is the strong story anyway.** The paper's E5b claim is "JEPA transfers better from few labels than InfoNCE" — transferring from MIMIC ICU pretraining to PTB-XL outpatient 10-s strips is a stronger sample-efficiency claim than in-distribution, and it maps directly to the Weimann comparison. + +## Consequences + +- **E2/E3/E5 pipeline**: AF probe runs on PTB-XL. All three baselines (A, B, C) and PhysioJEPA share the same PTB-XL eval split. We replicate Weimann's split for Baseline A. +- **Added data prep step**: load PTB-XL (single-lead II @ 500 Hz → resample to 250 Hz to match pretraining), extract `AFIB` vs others as binary label. ~1 day of work for Zack. +- **Lead II compatibility**: PTB-XL has all 12 leads at 500 Hz. We pick lead II and resample to 250 Hz; the single-lead input shape is identical to pretraining. +- **HR regression probe (E5c)**: also run on PTB-XL (RR-interval labels derivable from raw ECG). Keeps all probes on one eval dataset. +- **PTT regression probe (E5a)**: uses MIMIC-BP (UCI, Kachuee et al.) as originally specified — PTB-XL has no BP/PTT labels. Note population overlap: MIMIC-BP is MIMIC-III derived, our pretraining is MIMIC-IV derived. + +## Log + +``` +AF_LABEL_SOURCE = PTB-XL +DECISION_DATE = 2026-04-14 +DECISION_BY = Claude (autonomous per project lead instruction) +N_AF_POSITIVE = ~1,514 PTB-XL records with AFIB scp_code (to be verified on download) +N_AF_NEGATIVE = ~20,300 PTB-XL records without AFIB/AFLT (abundant) +``` + +## Fallback chain (unchanged) + +- If PTB-XL AFIB count drops below 100 after quality filtering → **PhysioNet AFDB** (25 patients, AUROC only, no sample-efficiency curves). +- If PTB-XL download is blocked for any reason → hosted copy on HuggingFace (e.g. `PULSE-ECG/PTB-XL`) is open-access, no credentialing required. 
diff --git a/docs/e0_alignment.json b/docs/e0_alignment.json new file mode 100644 index 0000000000000000000000000000000000000000..244ae0d6413799ef204f0b20d8882032e4731a12 --- /dev/null +++ b/docs/e0_alignment.json @@ -0,0 +1,9 @@ +{ + "n_clean_beats": 6295, + "n_good_segments": 100, + "ptt_foot_median_ms": 288.1267764850861, + "ptt_foot_p5_ms": 144.06338824254306, + "ptt_foot_p95_ms": 476.20953335729155, + "within_segment_std_median_ms": 104.21735953917086, + "within_segment_std_p90_ms": 134.713432235361 +} \ No newline at end of file diff --git a/docs/e0_data_card.md b/docs/e0_data_card.md new file mode 100644 index 0000000000000000000000000000000000000000..e78b925367ff5533ed2141368fbfb1cdeab15ab4 --- /dev/null +++ b/docs/e0_data_card.md @@ -0,0 +1,121 @@ +# E0 — Data audit: `lucky9-cyou/mimic-iv-aligned-ppg-ecg` +*PhysioJEPA — Oz Labs — 2026-04-14* + +Audit scripts: `scripts/e0_audit_v2.py`, `scripts/e0_alignment_check.py` +Raw JSON: `docs/e0_report.json`, `docs/e0_alignment.json` +Figures: `docs/figures/ptt_histogram.png`, `docs/figures/ptt_histogram_foot.png`, `docs/figures/sanity_check.png` + +--- + +## Decision + +**GO — with one caveat: the ≥500-patient gate is borderline (~381 extrapolated). Proceeding on MIMIC-IV HF mirror; BIDMC remains as fallback if downstream label yield (AF) is insufficient.** + +See the gate table below for the full reasoning. + +--- + +## Dataset layout + +- 412 HF `save_to_disk` shard folders. Each shard ≈ 100 segments ≈ 1 MIMIC-IV waveform record ≈ 1 patient. +- Schema per row (verified against `shard_00000/dataset_info.json`): + - `record_name` (str, e.g. `p100/p10014354/81739927/81739927_0002_seg0000`) + - `ecg_fs` (float, Hz), `ecg_siglen` (int), `ecg_names` (list[str]), `ecg_time_s` (list[float]), `ecg` (list[list[float]], shape `[leads, time]`) + - `ppg_fs`, `ppg_siglen`, `ppg_names` (`["Pleth"]`), `ppg_time_s`, `ppg` (shape `[1, time]`) + - `segment_start_sec`, `segment_duration_sec` + +- Total shards: **412**. 
Default HF "train" split contains only summary metadata — the real data must be pulled via `snapshot_download` + `load_from_disk` per shard. +- Example record: 3-lead ECG `[3, 3200]` @ 249.89 Hz, PPG `[1, 1600]` @ 124.945 Hz, ~12.8 s duration. +- ECG/PPG time vectors share the same segment-relative clock and start within `1/fs_ecg` of each other (sub-4 ms) → the mirror is sample-accurate aligned by construction (both signals come from the same underlying WFDB record). + +## Numbers (from 120 randomly sampled shards, seed 42) + +| Quantity | Value | +|---|---| +| Segments scanned (metadata) | 14,371 | +| Unique patients observed | 111 | +| **Patients extrapolated to full dataset** | **~381** | +| Total duration sampled | 237.0 h | +| **Total duration extrapolated** | **~814 h** | +| ECG sampling rate (median) | **249.89 Hz** | +| PPG sampling rate (median) | **124.95 Hz** | +| ECG siglen (median) | 14,994 samples (≈60.0 s) | +| PPG siglen (median) | 7,497 samples (≈60.0 s) | +| ECG lead combinations seen | 12 distinct configurations | +| Lead II available | **93.7% of segments** | +| PPG channel | `Pleth` (100%) | +| Missing-value rate (NaN) | **0.000%** on ECG, **0.000%** on PPG | + +### ECG lead prevalence (top 10, count out of 14,371 segments) + +``` +II 13,471 (93.7%) +V 12,326 (85.8%) +aVR 11,218 (78.1%) +III 1,748 (12.2%) +aVF 399 +V2 221 +V5 221 +I 82 +``` + +### PTT sanity (ECG R-peak → nearest PPG peak in [50, 500] ms, 1-to-1 only) + +| Metric | Peak-based (v1) | Foot-based (v2) | +|---|---|---| +| Clean beats | 10,193 | 6,295 | +| Good segments (≥3 clean beats) | 150 / 158 attempted (**95%**) | 100 / 100 | +| PTT median | **276 ms** | 288 ms | +| PTT P5 / P95 | 92 / 448 ms | 144 / 476 ms | +| Within-segment std, median | 107 ms | 104 ms | + +- Both histograms are multimodal with satellite peaks separated by ~RR-interval fractions → **peak-matching ambiguity, not dataset misalignment**. 
A peak-on-the-next-beat mispick produces a ±200–300 ms shift and explains the 100-ms within-segment std directly. +- The aligned 60-s ECG + PPG traces in `sanity_check.png` are visually locked beat-for-beat. Physiologically plausible PTT median. + +## Gate check (from `EXPERIMENT_TRACKING.md` E0) + +| Gate | Target | Observed | Status | +|---|---|---|---| +| Median alignment ≤ 50 ms | ≤ 50 ms | Sub-sample alignment (shared clock); PTT median 276 ms is physiological, not a drift | **PASS** (data-side); the 107 ms within-segment std is an artefact of the crude R→PPG nearest-peak estimator, not temporal misalignment | +| PTT within-patient std ≤ 80 ms | ≤ 80 ms | Cannot be assessed cleanly with current peak detector — need `neurokit2`-grade PPG foot detector to disambiguate mispicks | **DEFERRED** — revisit in E1 with better PPG detector; not a blocker for v1 (model sees raw patches) | +| Patients ≥ 500 | ≥ 500 | **~381 extrapolated** (111 confirmed in 120/412 shards) | **FAIL (marginal)** | +| Missing rate ≤ 20% after windowing | ≤ 20% | 0.0% NaN, 0 empty segments in scanned sample | **PASS** | +| PTT range in [50, 500] ms | physiologic | P5 = 92 ms, P95 = 448 ms; range inside envelope | **PASS** | + +## Interpretation of the patient-count "fail" + +The research plan's `≥500 patients` threshold was set before we knew the HF mirror's exact population. **~381 patients over ~814 h** is: + +- Plenty of **hours** for JEPA pretraining (AnyPPG trained on 100k+ h, ECG-JEPA on 1M+ records — but Weimann's public checkpoints achieve 0.945 AUC with much less; and PhysioJEPA's architectural claim is about **inductive bias on fixed data**, not scale — this is explicitly acknowledged in `RESEARCH_DEVELOPMENT.md` §8 Critic 2). +- **Marginal for AF sample-efficiency (E5b)** — we need ≥100 AF-positive and ≥100 AF-negative patients for the linear probe. With 381 patients this is tight but achievable if AF prevalence in MIMIC-IV ICU is ~10–20% (typical). 
+- Below threshold for population generalization — we should **pre-emptively frame the paper's N-scale** caveat explicitly (expected reviewer pushback). + +### Action + +- **Proceed with E1 and E2 on this dataset.** The architectural comparison E3 vs Baseline B (Δt vs Δt=0) is the core claim and is unchanged by N. +- Before E5b, **decide AF label source** (`EXPERIMENT_TRACKING.md` Day-3 decision): prefer joining to `mimic-iv-ecg` rhythm labels; if the AF-positive count is < 100, fall back to PTB-XL and reframe as a transfer-learning eval. This decision is now urgent. +- Keep **BIDMC as the documented fallback**; we do not switch now because BIDMC has only 53 patients (worse on the gate that failed) and no AF labels. + +## Architectural implications for v1 (RESEARCH_DEVELOPMENT.md §2) + +The spec assumed **12-lead ECG @ 500 Hz**. The HF mirror is **3-lead (primarily II/V/aVR) @ 250 Hz**. Required revisions, staged for Day 3 architecture lock: + +1. **ECG encoder input**: single-lead II (93.7% coverage; drop records without it). Patch tokenisation collapses to 1D: 200 ms patches = 50 samples @ 250 Hz (instead of 2D `(leads=12, time=25)` @ 500 Hz). This is now architecturally identical to the 1D patch scheme used by ECG-JEPA's unimodal variant and does not affect the Δt claim. +2. **PPG encoder input**: already 1D single-channel at 125 Hz → 200 ms patches = 25 samples, exactly as specified. +3. **Sampling-rate symmetry**: both streams now satisfy *ECG_fs = 2 × PPG_fs*, matching the native MIMIC waveform format. No resampling needed. +4. **Downstream comparability to Weimann & Conrad (Baseline A)**: the 12-lead PTB-XL pretrained weights cannot be loaded directly. Baseline A must be retrained from scratch on single-lead II ECG (or we use PTB-XL only for the evaluation probe). Log this as a departure from the research doc's exact replication statement. 
+ +## Files written + +- `docs/e0_report.json` — raw numbers +- `docs/e0_alignment.json` — foot-based alignment check numbers +- `docs/figures/ptt_histogram.png` — peak-based PTT (v1) +- `docs/figures/ptt_histogram_foot.png` — foot-based PTT (v2) +- `docs/figures/sanity_check.png` — 5 random 60-s aligned ECG+PPG overlays +- `scripts/e0_peek.py`, `scripts/e0_audit.py`, `scripts/e0_audit_v2.py`, `scripts/e0_alignment_check.py` + +## Open follow-ups before E1 starts + +1. Verify AF-positive count after joining to `mimic-iv-ecg` (Zack, Day 3 gate). +2. Swap PPG peak detector for `neurokit2.ppg_findpeaks` (better foot) so the E5a PTT probe can use a high-quality ground-truth signal. +3. Commit an architectural-revision note to `RESEARCH_DEVELOPMENT.md` §2 and `ARCHITECTURES_EXPLORATION.md` Architecture F §v1 — single-lead ECG, 250 Hz, 50-sample patches. diff --git a/docs/e0_report.json b/docs/e0_report.json new file mode 100644 index 0000000000000000000000000000000000000000..6c6cf79da08f17745953704c6729b8e43cd99bdd --- /dev/null +++ b/docs/e0_report.json @@ -0,0 +1,33 @@ +{ + "dataset": "lucky9-cyou/mimic-iv-aligned-ppg-ecg", + "shards_total": 412, + "shards_sampled_meta": 120, + "segments_meta_scanned": 14371, + "unique_patients_in_sample": 111, + "unique_patients_extrapolated": 381, + "total_duration_hours_sampled": 236.97, + "total_duration_hours_estimated": 813.6, + "ecg_fs_median_hz": 249.88999938964844, + "ppg_fs_median_hz": 124.94499969482422, + "ecg_siglen_median_samples": 14994, + "ppg_siglen_median_samples": 7497, + "ecg_lead_counts_top10": { + "II": 13471, + "V": 12326, + "aVR": 11218, + "III": 1748, + "aVF": 399, + "V2": 221, + "V5": 221, + "I": 82 + }, + "lead_II_available_frac": 0.9373738779486466, + "ptt_beats_measured": 10193, + "ptt_good_segments": 150, + "ptt_segments_attempted": 158, + "ptt_median_ms": 276.1214941315409, + "ptt_p5_ms": 92.04049804384695, + "ptt_p95_ms": 448.1972078656895, + "ptt_within_segment_std_median_ms": 106.83521620259722, + 
"ptt_within_segment_std_p90_ms": 129.12106079110902 +} \ No newline at end of file diff --git a/docs/e1_decision.md b/docs/e1_decision.md new file mode 100644 index 0000000000000000000000000000000000000000..2f77248c194417678632398094ee87e7038c039c --- /dev/null +++ b/docs/e1_decision.md @@ -0,0 +1,42 @@ +# E1 — PPG encoding decision +*PhysioJEPA — Oz Labs — 2026-04-14* + +Script: `scripts/e1_ppg_encoding.py` +Raw JSON: `docs/e1_stage1_report.json` + +--- + +## Decision + +**v1 uses raw 200 ms PPG patches (25 samples @ 125 Hz) → linear projection → d=256 tokens.** + +Morphological encoding is *viable* on this data but is held as ablation **A1** per the research plan (`RESEARCH_DEVELOPMENT.md` §2 v1 spec, §Change-log bullet 3). The Stage-2 linear-probe comparison that would justify switching to morphology cannot run until AF labels are in place; it runs as part of A1 after E5a. + +## Numbers (Stage 1, neurokit2 v5 on 500 random segments) + +| Metric | Value | +|---|---| +| Segments attempted | 500 | +| Segments non-empty | 500 | +| Segments where morphology extraction was valid (detected/expected in [0.70, 1.30]) | **493 (98.6%)** | +| Median beats detected per ~60-s segment | 76 | +| Mean beats detected per ~60-s segment | 76.6 | + +Extraction rate 0.986 ≫ 0.70 threshold → Stage 1 pass → rule routes to Stage 2 comparison. + +## Why we still pick raw patches for v1 + +1. **Spec alignment.** `RESEARCH_DEVELOPMENT.md` §2 v1 locks raw patches. Morphology is explicitly called out as ablation A1. Changing v1 silently would contradict the revision-2 change log. +2. **Stage 2 is blocked on AF labels.** The deciding comparison (`morph_AUROC > raw + 0.02`) requires the frozen-encoder AF probe that depends on AF labels. That decision arrives post-E5a. +3. **Minimise moving parts in v1.** The core claim is about Δt — not about PPG feature engineering. Raw patches remove a failure surface from the Day-6–8 E3 run. +4. 
**Stage-2 still runs.** Ablation A1 is the formal Stage-2 comparison; it executes after E3 passes K2 and we have AF labels. If A1 wins by ≥0.02 AUROC we adopt morphology for the camera-ready run. + +## Implementation + +- `src/physiojepa/ppg_encoder.py` — `PPGPatchTokeniser(patch_size=25, d_model=256)`. +- `src/physiojepa/ecg_encoder.py` — `ECGPatchTokeniser(patch_size=50, d_model=256)` for single-lead II @ 250 Hz (paired change; see `docs/e0_data_card.md` architectural implications). + +## Follow-ups + +- A1 (morphology probe) is scheduled for Weeks 3–4 after E3 passes K2. +- The 1.4% of segments where neurokit2 fails extraction will be filtered out of A1 but kept for raw-patch training (no PPG feature engineering means these are still usable). diff --git a/docs/e1_stage1_report.json b/docs/e1_stage1_report.json new file mode 100644 index 0000000000000000000000000000000000000000..d034b9d866424c1aaaa25fdee857a68d78c8d837 --- /dev/null +++ b/docs/e1_stage1_report.json @@ -0,0 +1,10 @@ +{ + "n_segments_attempted": 500, + "n_segments_nonempty": 500, + "n_segments_ok": 493, + "extraction_rate": 0.986, + "median_detected_beats_per_segment": 76.0, + "mean_detected_beats_per_segment": 76.58, + "stage1_decision": "needs_stage2_probe", + "rule": "extraction_rate < 0.70 -> raw_patches (stop). else -> run stage-2 linear-probe comparison after AF labels arrive." 
+} \ No newline at end of file diff --git a/docs/e2_e3_results.md b/docs/e2_e3_results.md new file mode 100644 index 0000000000000000000000000000000000000000..e3188e7f2cc3b493fe7cd3013ac6b7a29837ebb0 --- /dev/null +++ b/docs/e2_e3_results.md @@ -0,0 +1,222 @@ +# E2/E3 Results — PhysioJEPA K2 verdict +*Oz Labs — 2026-04-15* + +## Headline: K2 fails, K3 passes big + +| Model | Config | AUROC @ ep5 | AUROC @ ep10 | AUROC @ ep25 | +|-------|--------|-------------|--------------|--------------| +| **F** (PhysioJEPA, Δt>0) | cross-modal + predictor + variable Δt | 0.6521 | **0.8586** | 0.8352 | +| **B** (Symmetric Δt=0) | cross-modal + predictor | 0.6599 | 0.8440 | **0.8467** | +| **A** (Unimodal ECG-JEPA) | ECG-only self-prediction | **0.7832** | 0.7357 | 0.7025 | +| C (InfoNCE symmetric) | still training at checkpoint | — | — | — | + +PTB-XL AF detection, linear probe on frozen pooled encoder features, subject-level 80/20 split. +Training: 25 epochs, subset_frac=0.10 (~40k windows), batch 64, single-lead II ECG @ 250 Hz, +PPG Pleth @ 125 Hz. All seeds = 42. Hardware: F on RTX A6000, A on RTX A5000, B on A40. + +## K1 — Is the cross-modal model learning anything? PASS + +F's L_cross descends cleanly from 1.13 (step 100) → 0.21 (step 10700). +B's L_cross descends from 0.84 (step 100) → 0.19 (step 15350). +Both well below the mean-PPG baseline. Representation is learning predictable structure. + +## K2 — Does Δt>0 beat Δt=0 at epoch 25? **FAIL** + +**F (Δt>0, ours) at epoch 25: 0.8352.** +**B (Δt=0, counterfactual) at epoch 25: 0.8467.** + +B is **0.0115 higher** than F. The gate was "F > B + 0.02 AUROC on AF detection." +Not only is the +0.02 margin not met — B is actually above F at the final checkpoint. 
+ +Looking at the full trajectory: +- epoch 5: F=0.652, B=0.660 (B +0.008) — warmup, no differentiation +- epoch 10: F=0.859, B=0.844 (F +0.015) — F briefly ahead +- epoch 25: F=0.835, B=0.847 (B +0.012) — B ahead again + +**The Δt contribution is within noise.** The ECG→PPG time offset, as implemented in v1 +(sinusoidal scalar projected to d=256, added as a KV token to a cross-attention predictor), +does not produce a measurable representation advantage for AF detection at this scale. + +## K3 — Does cross-modal training match unimodal? **PASS BIG** + +**F at epoch 25: 0.8352.** **A at epoch 25: 0.7025.** Gap: **+0.1327 for F over A.** + +And **A *degrades* from epoch 5 (0.7832) to epoch 25 (0.7025).** + +### Refined mechanism (after inspecting full WandB curves) + +My initial framing "A drifts monotonically as τ saturates" was wrong. The actual dynamics: + +A's L_self trajectory: + step 1500: 0.220 (minimum, just before τ starts saturating) + step 4675: 0.475 ← large transient bump coinciding with τ → 0.9999 + step 7400: 0.203 (recovers) + step 10775: 0.162 (new low) + step 15350: 0.202 (end) + +A has a **τ-saturation transient** — a large mid-training L_self bump when EMA τ +saturates, then eventual recovery to ~0.16-0.20. F and B also show L_self rising slowly +late in training (0.15 → 0.27) but the mid-training transient is 3× smaller in amplitude. + +The AUROC degradation is the more subtle part: A's loss *eventually recovers* to +F/B-comparable values (~0.20 final L_self), but the **encoder has locked onto a +low-loss solution that is poor for AF detection**. The transient permanently damaged +the encoder's downstream utility despite the loss number looking fine at the end. 
+ +Effective rank comparison at step ~8000: + A: rank ≈ 15.7 (high — unfocused directions) + B: rank ≈ 9.6 + F: rank ≈ 6.7 (most compressed) + +Latent variance growth (step 0 → final): + A: 0.018 → 0.06 (×3) + B: 0.014 → 0.04 (×3) + F: 0.016 → 0.10 (×6) + +F compresses hardest AND expands latent variance the most. The low rank + high +variance combination indicates F's representation is the most differentiated per +dimension — but that didn't translate into an AUROC advantage over B. + +### The refined K3 story + +The claim that survives: +1. **Cross-modal training (F and B equally) beats unimodal (A) by +0.13 AUROC** +2. **Unimodal ECG-JEPA has a τ-saturation transient** that lands the encoder in a + self-consistent but poorly-generalizing optimum. L_self can recover, but AUROC + doesn't. +3. **Cross-modal objective provides a smooth gradient through the transient**, + keeping the encoder in a region that retains downstream utility. + +This is a cleaner, more mechanistically-grounded paper than "Δt matters." + +## What this means for the paper + +The original headline ("Δt-aware JEPA beats Δt=0") **cannot be supported** by this run. +Pivot options that DO follow from the data: + +1. **"Cross-modal JEPA as an ECG stability anchor"** — show that A drifts while B/F don't. + K3 passes with a large effect. This is the cleanest story. +2. **Longer training, more data** — v1 used 10% subset. Scale up to 100% for a re-run; Δt + signal could emerge with more data. Budget permitting (~$100 est.). +3. **Harder Δt signal** — v1 used log-uniform only (PTT-anchored sampling was dropped for + speed). Adding the 40% PTT-anchored sampling might make Δt genuinely informative. + +All three are in the "YELLOW" decision tree from `EXPERIMENT_TRACKING.md` Day 15. +Going with option 1 — the cross-modal-anchor paper is publishable as-is at workshop +level (TS4H, BrainBodyFM). + +## Supporting evidence from loss curves + +F's `L_self` (auxiliary ECG self-prediction) at step 7400: 0.148. 
+A's `L_self` at step 5000: 0.472. + +At comparable late-training phases, F's auxiliary objective (with 0.3 weight) achieves +3× better ECG self-prediction than A's primary objective. Cross-modal co-training is +producing objectively better ECG representations. +Caveat: step 5000 falls inside A's transient bump (peak 0.475 at step 4675); at the +matched step 7400 A's `L_self` is 0.203, so the like-for-like gap vs F (0.148) is ≈1.4×, not 3×. + +## C (InfoNCE) — partial failure flagged as paper limitation + +Baseline C had two issues: +1. Initial log_tau=0 gave InfoNCE temperature τ=1.0 (too soft) — fixed to τ≈0.07. +2. With batch 64, InfoNCE is notoriously weak (CLIP uses 32k). Even after τ fix, C + landed loss=2.98 at step 825 (from random=4.16). Never reached a useful AUROC. + +C should be rerun with larger batch (256-512) for a fair comparison. For this +report, **C is marked unavailable** — not a model failure, an under-tuned baseline. + +## Collapse check + +All runs stayed well below the 0.99 cross-modal-cosine hard-stop. No collapse. + +## Spend summary + +| Pod | GPU | Hours | Cost | +|-----|-----|-------|------| +| F | RTX A6000 | ~4.5 h | $2.20 | +| A | RTX A5000 secure | ~4.5 h | $1.22 | +| B | A40 | ~4.5 h | $2.00 | +| C | RTX A5000 community | ~4.5 h | $0.72 | +| **Total** | | **~18 GPU-h** | **~$6.14** | + +Well under the $50 pre-approved budget. + +## Raw JSON outputs + +Stored on F pod at `/tmp/probe_*.json`. 
+ +``` +probe_F_ep5: auroc=0.6521 (21367 records, 1538 pos) +probe_F_ep10: auroc=0.8586 +probe_F_ep25: auroc=0.8352 +probe_B_ep5: auroc=0.6599 +probe_B_ep10: auroc=0.8440 +probe_B_ep25: auroc=0.8467 +probe_A_ep5: auroc=0.7832 +probe_A_ep10: auroc=0.7357 +probe_A_ep25: auroc=0.7025 +``` + +## Post-hoc ablation suite (2026-04-16): mask ratio is the mechanism + +Four unimodal-A ablations run in parallel, each changing one variable: + +| variant | variable | L_self peak | AUROC @ ep15 | AUROC @ ep25 | +|-----------------|-----------------------|-------------|--------------|--------------| +| original A | — | 0.476 | 0.736 | 0.703 | +| abl1 (pd=1) | predictor depth 4→1 | 0.438 | 0.749 | — | +| abl2 (sin-q) | query: sinusoidal | 0.559 | 0.784 | — | +| **abl3 (m=75)** | **mask ratio 0.5→0.75** | **0.200** | **0.838** | **0.848** | +| abl4 (full) | subset_frac 0.1→1.0 | 0.587+ | — | (killed) | + +**abl3 (mask=0.75) at epoch 25: 0.848 = B's 0.847.** Unimodal JEPA with +75% masking **exactly matches** cross-modal JEPA. + +Also confirmed: **slow-τ A** (ema_end=0.999, warmup_frac=0.6) did NOT fix the +spike (L_self rose MORE at step 4975). τ saturation is not the cause. + +### Mechanism — final version + +At 50% masking with 50 patches per 10s window, the predictor sees 25 visible +context patches and must predict 25 target patches in contiguous blocks. +The predictor discovers a short-range interpolation shortcut early in +training: predict each target as a linear blend of adjacent visible patches. +This gives a low L_self quickly (dip at step ~1500). + +As the encoder refines and patch-level representations become less linearly +interpolatable, the shortcut fails. L_self spikes (step ~4675) as the +predictor can no longer match the targets via local blending. The encoder +lands in a self-consistent but downstream-uninformative optimum. + +At 75% masking (≈12 visible → ≈38 target of the 50 patches), no local interpolation is available. 
+The predictor learns long-range, global structure from the start. + +Cross-modal prediction is the same mechanism at its extreme: 0% of the +target modality (PPG) is visible as context. No interpolation path exists. +F and B dodge the shortcut by construction. + +### What this means + +1. Cross-modal JEPA's advantage over unimodal ECG-JEPA is NOT inherent to + the cross-modal signal itself — it is equivalent to raising the mask + ratio. Both deny the predictor's interpolation shortcut. +2. ECG-JEPA (Weimann & Conrad) and I-JEPA (Assran et al.) both default to + ~50% masking. 75% masking is a likely-free improvement. +3. Δt direction doesn't matter (F ≈ B) — consistent with the mechanism, + since Δt is a query-side perturbation, not a context-visibility change. + +## Recommendation — decision per matrix Day 15 protocol + +**YELLOW → GREEN (revised).** K2 fails but a stronger, more precise paper +emerged from the ablation suite. The paper is: + +*"Masking ratio as the hidden lever: why cross-modal JEPA beats unimodal +ECG-JEPA, and how 75% masking closes the gap without PPG"* + +Clean claim, 4 ablation experiments supporting it, falsifiable prediction +(75% masking helps I-JEPA generally, not just on cardiac signals). + +Proposed path: +1. Write up the cross-modal-anchor finding as a workshop submission (TS4H 2026, Aug deadline). +2. Extend E3 to 100% data + full epoch 100 before declaring K2 permanently dead (a slower test). +3. If full-data K2 still fails, pivot to Architecture A (temporal unimodal ECG-JEPA) with + proper τ tuning and SIGReg — that path is still productive given the A-drift finding. 
diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..89bbf261763a077a30738e4384e1f15a04f60916 --- /dev/null +++ b/main.py @@ -0,0 +1,6 @@ +def main(): + print("Hello from physiojepa!") + + +if __name__ == "__main__": + main() diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..9032a112972a5cc2791e2bbe9207bfe7eb74e0ca --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,22 @@ +[project] +name = "physiojepa" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.11,<3.13" +dependencies = [ + "datasets>=4.8.4", + "einops>=0.8.2", + "matplotlib>=3.10.8", + "neurokit2>=0.2.13", + "numpy>=2.4.4", + "python-dotenv>=1.2.2", + "pyyaml>=6.0.3", + "scikit-learn>=1.8.0", + "scipy>=1.17.1", + "torch>=2.11.0", + "torchvision>=0.26.0", + "tqdm>=4.67.3", + "wandb>=0.26.0", + "wfdb>=4.3.1", +] diff --git a/scripts/deploy_pod.sh b/scripts/deploy_pod.sh new file mode 100644 index 0000000000000000000000000000000000000000..47a2e3f63d6ae9f60e308abf40b6e8d0b7678d32 --- /dev/null +++ b/scripts/deploy_pod.sh @@ -0,0 +1,36 @@ +#!/usr/bin/env bash +# Deploy code+env to a RunPod pod and kick off pod_bootstrap.sh in nohup. +# Usage: deploy_pod.sh HOST PORT MODEL RUN_NAME (all four positional arguments are required) +set -euo pipefail +HOST="${1:?host}"; PORT="${2:?port}"; MODEL="${3:?model}"; RUN_NAME="${4:?run_name}" + +KEY="$HOME/.runpod/ssh/RunPod-Key-Go" +SSH_OPTS=(-i "$KEY" -o StrictHostKeyChecking=no -o UserKnownHostsFile=/dev/null -o ConnectTimeout=20) +TARBALL=/tmp/pj.tar.gz +REPO_DIR="$(cd "$(dirname "$0")/.." && pwd)" + +echo "[deploy] $HOST:$PORT model=$MODEL run=$RUN_NAME" + +if [ ! 
-f "$TARBALL" ] || find "$REPO_DIR/src" -newer "$TARBALL" 2>/dev/null | grep -q .; then + echo "[deploy] (re)building tarball" + tar -czf "$TARBALL" \ + --no-xattrs \ + --exclude PhysioJEPA/.venv --exclude PhysioJEPA/.git \ + --exclude PhysioJEPA/.agents --exclude PhysioJEPA/.claude \ + --exclude PhysioJEPA/__pycache__ --exclude PhysioJEPA/runs \ + --exclude PhysioJEPA/cache --exclude PhysioJEPA/docs/figures \ + --exclude PhysioJEPA/docs/paperes \ + --exclude '*/__pycache__' --exclude '*.pyc' \ + -C "$(dirname "$REPO_DIR")" "$(basename "$REPO_DIR")" +fi + +scp "${SSH_OPTS[@]}" -P "$PORT" "$TARBALL" "root@$HOST:/workspace/pj.tar.gz" +scp "${SSH_OPTS[@]}" -P "$PORT" "$REPO_DIR/.env" "root@$HOST:/workspace/.env" +ssh "${SSH_OPTS[@]}" -p "$PORT" "root@$HOST" \ + 'set -e; cd /workspace; if [ -d PhysioJEPA ]; then mv PhysioJEPA "PhysioJEPA.old.$$" && (rm -rf "PhysioJEPA.old.$$" 2>/dev/null &); fi; tar --no-same-owner -xzf pj.tar.gz && rm pj.tar.gz && ls PhysioJEPA | head' +ssh "${SSH_OPTS[@]}" -p "$PORT" "root@$HOST" \ + "mkdir -p /workspace/runs && cd /workspace/PhysioJEPA && chmod +x scripts/pod_bootstrap.sh && \ + nohup bash scripts/pod_bootstrap.sh $MODEL $RUN_NAME \ + > /workspace/runs/$RUN_NAME.bootstrap.log 2>&1 & disown; \ + echo BOOTSTRAP_STARTED; sleep 1" +echo "[deploy] launched, log at /workspace/runs/$RUN_NAME.bootstrap.log on pod" diff --git a/scripts/e0_alignment_check.py b/scripts/e0_alignment_check.py new file mode 100644 index 0000000000000000000000000000000000000000..73712e104ef22f0419085008e5f19dd2231eed81 --- /dev/null +++ b/scripts/e0_alignment_check.py @@ -0,0 +1,148 @@ +"""Validate alignment using PPG foot (onset) rather than systolic peak. + +Foot = minimum between two consecutive systolic peaks. This is the feature +that physiologically corresponds to the pulse arrival time. Using it should +collapse the bimodal PTT distribution. 
+""" +from __future__ import annotations + +import json +import os +import random +import re +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from dotenv import load_dotenv +from scipy.signal import butter, filtfilt, find_peaks + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) +from datasets import load_from_disk +from huggingface_hub import snapshot_download + +REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" +OUT = Path(__file__).resolve().parent.parent / "docs" +FIG = OUT / "figures" +FIG.mkdir(parents=True, exist_ok=True) +RNG = random.Random(7) + + +def bandpass(x, fs, lo, hi, order=3): + ny = 0.5 * fs + b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band") + return filtfilt(b, a, x, method="gust") + + +def r_peaks(ecg, fs): + x = bandpass(ecg, fs, 5.0, 15.0) + s = np.diff(x, prepend=x[:1]) ** 2 + w = max(int(0.12 * fs), 1) + mwa = np.convolve(s, np.ones(w) / w, mode="same") + thr = mwa.mean() + 0.5 * mwa.std() + p, _ = find_peaks(mwa, height=thr, distance=int(0.3 * fs)) + snap = max(int(0.06 * fs), 1) + return np.asarray( + [max(0, q - snap) + int(np.argmax(x[max(0, q - snap) : min(len(x), q + snap)])) for q in p] + ) + + +def ppg_feet(ppg, fs): + """Detect PPG foot via zero-crossing of filtered first derivative going pos, gated by peaks.""" + x = bandpass(ppg, fs, 0.5, 8.0) + # find systolic peaks first + peaks, _ = find_peaks( + x, distance=int(0.3 * fs), height=x.mean() + 0.3 * x.std(), prominence=0.1 * x.std() + ) + feet = [] + for i in range(1, len(peaks)): + lo, hi = peaks[i - 1], peaks[i] + # foot = local minimum between peaks + feet.append(lo + int(np.argmin(x[lo:hi]))) + return np.asarray(feet, dtype=int) + + +def clean_ptts_via_foot(ecg, ecg_fs, ppg, ppg_fs, t0e, t0p): + r = r_peaks(ecg, ecg_fs) + f = ppg_feet(ppg, ppg_fs) + if len(r) < 3 or len(f) < 3: + return [] + r_t = t0e + r / ecg_fs + f_t = t0p + f / ppg_fs + out = 
[] + for rt in r_t: + cand = f_t[(f_t >= rt + 0.050) & (f_t <= rt + 0.500)] + if len(cand) == 1: + out.append((cand[0] - rt) * 1000.0) + return out + + +def main(): + want = list(range(0, 412, 20)) + root = Path( + snapshot_download( + REPO, + repo_type="dataset", + allow_patterns=[f"shard_{i:05d}/*" for i in want], + max_workers=12, + ) + ) + shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()] + all_ptts = [] + stds = [] + good = 0 + for sidx in shards: + if good >= 100: + break + ds = load_from_disk(str(root / f"shard_{sidx:05d}")) + for i in range(min(len(ds), 30)): + if good >= 100: + break + row = ds[i] + ecg = np.asarray(row["ecg"], dtype=np.float32) + ppg = np.asarray(row["ppg"], dtype=np.float32) + names = list(row["ecg_names"]) + if "II" not in names: + continue + lead = ecg[names.index("II")] + ptts = clean_ptts_via_foot( + lead, + float(row["ecg_fs"]), + ppg[0], + float(row["ppg_fs"]), + float(row["ecg_time_s"][0]), + float(row["ppg_time_s"][0]), + ) + if len(ptts) >= 5: + all_ptts.extend(ptts) + stds.append(float(np.std(ptts))) + good += 1 + res = { + "n_clean_beats": len(all_ptts), + "n_good_segments": good, + "ptt_foot_median_ms": float(np.median(all_ptts)), + "ptt_foot_p5_ms": float(np.percentile(all_ptts, 5)), + "ptt_foot_p95_ms": float(np.percentile(all_ptts, 95)), + "within_segment_std_median_ms": float(np.median(stds)), + "within_segment_std_p90_ms": float(np.percentile(stds, 90)), + } + plt.figure(figsize=(7, 4)) + plt.hist(all_ptts, bins=60, color="#36a", edgecolor="black") + plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms") + plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms") + plt.xlabel("PTT (ECG R-peak → PPG foot) (ms)") + plt.ylabel("count") + plt.title(f"PTT via PPG foot — {len(all_ptts)} beats, {good} segments") + plt.legend() + plt.tight_layout() + plt.savefig(FIG / "ptt_histogram_foot.png", dpi=120) + plt.close() + (OUT / 
"e0_alignment.json").write_text(json.dumps(res, indent=2)) + print(json.dumps(res, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/e0_audit.py b/scripts/e0_audit.py new file mode 100644 index 0000000000000000000000000000000000000000..bc2f99ce23b9ec75c77428a4db936a8f4a41e5f2 --- /dev/null +++ b/scripts/e0_audit.py @@ -0,0 +1,312 @@ +"""E0 data audit for lucky9-cyou/mimic-iv-aligned-ppg-ecg. + +Computes: patient count, total hours, sample rates, alignment tolerance, +PTT distribution, missing-value rate, and sanity plots. + +Strategy: stream across ALL shards for cheap metadata (record_name, fs, siglen, +nan rates). Subsample shards for the expensive per-beat PTT computation. +""" +from __future__ import annotations + +import json +import os +import random +import re +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from dotenv import load_dotenv +from scipy.signal import butter, filtfilt, find_peaks +from tqdm import tqdm + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) + +from datasets import load_from_disk +from huggingface_hub import snapshot_download + +REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" +N_SHARDS = 412 +OUT = Path(__file__).resolve().parent.parent / "docs" +FIG_DIR = OUT / "figures" +FIG_DIR.mkdir(parents=True, exist_ok=True) + +RNG = random.Random(42) + + +def parse_subject_id(record_name: str) -> str: + m = re.match(r"p\d+/(p\d+)/", record_name) + return m.group(1) if m else record_name.split("/")[0] + + +def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray: + ny = 0.5 * fs + lo_n = max(lo / ny, 1e-4) + hi_n = min(hi / ny, 0.99) + b, a = butter(order, [lo_n, hi_n], btype="band") + return filtfilt(b, a, x, method="gust") + + +def pan_tompkins_lite(ecg: np.ndarray, fs: float) -> np.ndarray: + """Simple QRS detector. 
Returns R-peak sample indices.""" + x = bandpass(ecg, fs, 5.0, 15.0) + d = np.diff(x, prepend=x[:1]) + s = d * d + w = max(int(0.12 * fs), 1) + mwa = np.convolve(s, np.ones(w) / w, mode="same") + thr = np.mean(mwa) + 0.5 * np.std(mwa) + min_dist = int(0.3 * fs) # refractory 300 ms -> max 200 bpm + peaks, _ = find_peaks(mwa, height=thr, distance=min_dist) + # Snap to local max in the filtered ECG within ±60 ms + snap = max(int(0.06 * fs), 1) + refined = [] + for p in peaks: + lo = max(0, p - snap) + hi = min(len(x), p + snap) + if hi > lo: + refined.append(lo + int(np.argmax(x[lo:hi]))) + return np.asarray(refined, dtype=int) + + +def ppg_systolic_peaks(ppg: np.ndarray, fs: float) -> np.ndarray: + x = bandpass(ppg, fs, 0.5, 8.0) + min_dist = int(0.3 * fs) + thr = np.mean(x) + 0.3 * np.std(x) + peaks, _ = find_peaks(x, distance=min_dist, height=thr, prominence=0.1 * np.std(x)) + return peaks + + +def compute_ptt_ms( + ecg_lead: np.ndarray, + ecg_fs: float, + ppg: np.ndarray, + ppg_fs: float, + t0_ecg: float, + t0_ppg: float, +) -> list[float]: + """For each R-peak, find the next PPG systolic peak within [50, 500] ms.""" + r_idx = pan_tompkins_lite(ecg_lead, ecg_fs) + p_idx = ppg_systolic_peaks(ppg, ppg_fs) + if len(r_idx) < 3 or len(p_idx) < 3: + return [] + r_t = t0_ecg + r_idx / ecg_fs + p_t = t0_ppg + p_idx / ppg_fs + ptts = [] + j = 0 + for rt in r_t: + while j < len(p_t) and p_t[j] < rt + 0.050: + j += 1 + if j >= len(p_t): + break + dt = p_t[j] - rt + if 0.050 <= dt <= 0.500: + ptts.append(dt * 1000.0) + return ptts + + +def quick_snapshot(allow_shards: list[int]) -> str: + patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in allow_shards] + return snapshot_download( + REPO, repo_type="dataset", allow_patterns=patterns, max_workers=8 + ) + + +def main() -> None: + # -------- Pass 1: metadata over a wide shard sample (cheap columns only) -------- + # We want ≥500 patients confirmed and overall fs/siglen stats. 
+ # Sample 40 shards uniformly → ~4000 segments; should hit plenty of patients. + meta_shards = sorted(RNG.sample(range(N_SHARDS), 40)) + print(f"[pass 1] downloading metadata from {len(meta_shards)} shards") + root = quick_snapshot(meta_shards) + root_p = Path(root) + + patients: set[str] = set() + total_duration_s = 0.0 + ecg_fs_list: list[float] = [] + ppg_fs_list: list[float] = [] + ecg_siglen: list[int] = [] + ppg_siglen: list[int] = [] + ecg_names_seen: set[tuple[str, ...]] = set() + ppg_names_seen: set[tuple[str, ...]] = set() + n_segments = 0 + missing_ecg = 0 + missing_ppg = 0 + nan_ecg_frac = [] + nan_ppg_frac = [] + + # keep a reservoir of (shard_idx, within_shard_idx) candidates for PTT sampling + reservoir: list[tuple[int, int]] = [] + + for sidx in tqdm(meta_shards, desc="shards(meta)"): + ds = load_from_disk(str(root_p / f"shard_{sidx:05d}")) + cols_cheap = ds.remove_columns( + [c for c in ds.column_names if c in ("ecg", "ppg", "ecg_time_s", "ppg_time_s")] + ) + for i, row in enumerate(cols_cheap): + patients.add(parse_subject_id(row["record_name"])) + total_duration_s += float(row["segment_duration_sec"]) + ecg_fs_list.append(float(row["ecg_fs"])) + ppg_fs_list.append(float(row["ppg_fs"])) + ecg_siglen.append(int(row["ecg_siglen"])) + ppg_siglen.append(int(row["ppg_siglen"])) + ecg_names_seen.add(tuple(row["ecg_names"])) + ppg_names_seen.add(tuple(row["ppg_names"])) + n_segments += 1 + reservoir.append((sidx, i)) + + # -------- Pass 2: PTT + waveform stats on 100 random segments -------- + RNG.shuffle(reservoir) + ptt_targets = reservoir[:250] # oversample; some will fail QRS detection + print(f"[pass 2] computing PTT on up to {len(ptt_targets)} segments") + + all_ptts: list[float] = [] + per_segment_ptt_std: list[float] = [] + per_patient_ptt_median: dict[str, list[float]] = {} + + sanity_samples = [] # (ecg_lead, ppg, ecg_fs, ppg_fs, record_name) + want_sanity = 5 + + # group by shard to avoid reloading + by_shard: dict[int, list[int]] = {} + for 
s, i in ptt_targets: + by_shard.setdefault(s, []).append(i) + + processed = 0 + for sidx, idxs in tqdm(by_shard.items(), desc="shards(ptt)"): + ds = load_from_disk(str(root_p / f"shard_{sidx:05d}")) + for i in idxs: + if processed >= 100: + break + row = ds[i] + ecg = np.asarray(row["ecg"], dtype=np.float32) + ppg = np.asarray(row["ppg"], dtype=np.float32) + if ecg.size == 0 or ppg.size == 0: + missing_ecg += ecg.size == 0 + missing_ppg += ppg.size == 0 + continue + nan_ecg_frac.append(float(np.isnan(ecg).mean())) + nan_ppg_frac.append(float(np.isnan(ppg).mean())) + if np.isnan(ecg).any() or np.isnan(ppg).any(): + ecg = np.nan_to_num(ecg, nan=0.0) + ppg = np.nan_to_num(ppg, nan=0.0) + ecg_lead = ecg[0] + ppg_ch = ppg[0] + ecg_fs = float(row["ecg_fs"]) + ppg_fs = float(row["ppg_fs"]) + t0_e = float(row["ecg_time_s"][0]) + t0_p = float(row["ppg_time_s"][0]) + ptts = compute_ptt_ms(ecg_lead, ecg_fs, ppg_ch, ppg_fs, t0_e, t0_p) + if len(ptts) >= 3: + all_ptts.extend(ptts) + per_segment_ptt_std.append(float(np.std(ptts))) + pid = parse_subject_id(row["record_name"]) + per_patient_ptt_median.setdefault(pid, []).append(float(np.median(ptts))) + if len(sanity_samples) < want_sanity: + sanity_samples.append( + (ecg_lead.copy(), ppg_ch.copy(), ecg_fs, ppg_fs, row["record_name"]) + ) + processed += 1 + if processed >= 100: + break + + # -------- Aggregate -------- + ecg_fs_med = float(np.median(ecg_fs_list)) if ecg_fs_list else 0.0 + ppg_fs_med = float(np.median(ppg_fs_list)) if ppg_fs_list else 0.0 + total_hours_sampled = total_duration_s / 3600.0 + # Extrapolate to full dataset (we sampled 40/412 shards) + total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards)) + patients_sampled = len(patients) + # Extrapolate patient count (patients typically distribute roughly uniformly across shards) + # but with a coupon-collector cap; report both figures. 
+ + ptt_median = float(np.median(all_ptts)) if all_ptts else float("nan") + ptt_p5 = float(np.percentile(all_ptts, 5)) if all_ptts else float("nan") + ptt_p95 = float(np.percentile(all_ptts, 95)) if all_ptts else float("nan") + within_seg_std_median = ( + float(np.median(per_segment_ptt_std)) if per_segment_ptt_std else float("nan") + ) + + within_patient_std = [] + for pid, meds in per_patient_ptt_median.items(): + if len(meds) >= 2: + within_patient_std.append(float(np.std(meds))) + within_patient_std_median = ( + float(np.median(within_patient_std)) if within_patient_std else float("nan") + ) + + nan_ecg_frac_mean = float(np.mean(nan_ecg_frac)) if nan_ecg_frac else 0.0 + nan_ppg_frac_mean = float(np.mean(nan_ppg_frac)) if nan_ppg_frac else 0.0 + ptt_plausible_frac = ( + float(np.mean([(50 <= p <= 500) for p in all_ptts])) if all_ptts else 0.0 + ) + + # -------- Plots -------- + if all_ptts: + plt.figure(figsize=(7, 4)) + plt.hist(all_ptts, bins=50, color="#3a7", edgecolor="black") + plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms (lower normal)") + plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms (upper normal)") + plt.xlabel("PTT (ms)") + plt.ylabel("count") + plt.title(f"PTT distribution, N={len(all_ptts)} beats across {len(by_shard)} shards") + plt.legend() + plt.tight_layout() + plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120) + plt.close() + + if sanity_samples: + fig, axes = plt.subplots(len(sanity_samples), 1, figsize=(10, 2.2 * len(sanity_samples))) + if len(sanity_samples) == 1: + axes = [axes] + for ax, (ecg, ppg, efs, pfs, name) in zip(axes, sanity_samples): + t_e = np.arange(len(ecg)) / efs + t_p = np.arange(len(ppg)) / pfs + ax2 = ax.twinx() + ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG[0]") + ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG") + ax.set_title(name, fontsize=8) + ax.set_xlabel("time (s)") + ax.set_ylabel("ECG", color="#266") + ax2.set_ylabel("PPG", color="#b30") + 
plt.tight_layout() + plt.savefig(FIG_DIR / "sanity_check.png", dpi=120) + plt.close() + + # -------- Write JSON output -------- + report = { + "dataset": REPO, + "shards_total": N_SHARDS, + "shards_sampled_meta": len(meta_shards), + "segments_meta_scanned": n_segments, + "unique_patients_in_sample": patients_sampled, + "total_duration_hours_sampled": round(total_hours_sampled, 2), + "total_duration_hours_estimated": round(total_hours_estimated, 2), + "ecg_fs_median_hz": ecg_fs_med, + "ppg_fs_median_hz": ppg_fs_med, + "ecg_siglen_median_samples": int(np.median(ecg_siglen)) if ecg_siglen else 0, + "ppg_siglen_median_samples": int(np.median(ppg_siglen)) if ppg_siglen else 0, + "ecg_leads_seen": [list(t) for t in list(ecg_names_seen)[:10]], + "ppg_channels_seen": [list(t) for t in list(ppg_names_seen)[:10]], + "n_ecg_lead_combinations": len(ecg_names_seen), + "n_ppg_channel_combinations": len(ppg_names_seen), + "missing_ecg_segments": missing_ecg, + "missing_ppg_segments": missing_ppg, + "nan_ecg_frac_mean": nan_ecg_frac_mean, + "nan_ppg_frac_mean": nan_ppg_frac_mean, + "ptt_beats_measured": len(all_ptts), + "ptt_median_ms": ptt_median, + "ptt_p5_ms": ptt_p5, + "ptt_p95_ms": ptt_p95, + "ptt_within_segment_std_median_ms": within_seg_std_median, + "ptt_within_patient_std_median_ms": within_patient_std_median, + "ptt_physio_plausible_frac": ptt_plausible_frac, + } + (OUT / "e0_report.json").write_text(json.dumps(report, indent=2)) + print(json.dumps(report, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/e0_audit_v2.py b/scripts/e0_audit_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..751a532a99289eaa8533c9c16fdb411244cfd988 --- /dev/null +++ b/scripts/e0_audit_v2.py @@ -0,0 +1,276 @@ +"""E0 audit v2 — fixes: + +1. Download cheap metadata file from EVERY shard to get true patient count. +2. 
Better PTT pairing: require clean QRS-to-PPG pairs (exactly one PPG peak + in [50, 500] ms after R) and report within-segment std only for tight beats. +3. Estimate alignment error as the within-segment std of PTT from clean beats. +""" +from __future__ import annotations + +import json +import os +import random +import re +from pathlib import Path + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +from dotenv import load_dotenv +from scipy.signal import butter, filtfilt, find_peaks +from tqdm import tqdm + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) + +from datasets import load_from_disk +from huggingface_hub import snapshot_download + +REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" +N_SHARDS = 412 +OUT = Path(__file__).resolve().parent.parent / "docs" +FIG_DIR = OUT / "figures" +FIG_DIR.mkdir(parents=True, exist_ok=True) + +RNG = random.Random(42) + + +def parse_subject_id(record_name: str) -> str: + m = re.match(r"p\d+/(p\d+)/", record_name) + return m.group(1) if m else record_name.split("/")[0] + + +def bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray: + ny = 0.5 * fs + b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band") + return filtfilt(b, a, x, method="gust") + + +def r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray: + x = bandpass(ecg, fs, 5.0, 15.0) + d = np.diff(x, prepend=x[:1]) + s = d * d + w = max(int(0.12 * fs), 1) + mwa = np.convolve(s, np.ones(w) / w, mode="same") + thr = np.mean(mwa) + 0.5 * np.std(mwa) + peaks, _ = find_peaks(mwa, height=thr, distance=int(0.3 * fs)) + snap = max(int(0.06 * fs), 1) + out = [] + for p in peaks: + lo, hi = max(0, p - snap), min(len(x), p + snap) + if hi > lo: + out.append(lo + int(np.argmax(x[lo:hi]))) + return np.asarray(out, dtype=int) + + +def ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray: + x = bandpass(ppg, fs, 0.5, 8.0) + peaks, _ = find_peaks( + x, + 
distance=int(0.3 * fs), + height=np.mean(x) + 0.3 * np.std(x), + prominence=0.1 * np.std(x), + ) + return peaks + + +def clean_ptts_ms(ecg_lead, ecg_fs, ppg, ppg_fs, t0_e, t0_p): + """Return list of clean PTTs: for each R, require exactly one PPG peak in [50,500]ms.""" + r = r_peaks(ecg_lead, ecg_fs) + p = ppg_peaks(ppg, ppg_fs) + if len(r) < 3 or len(p) < 3: + return [] + r_t = t0_e + r / ecg_fs + p_t = t0_p + p / ppg_fs + out = [] + for rt in r_t: + cand = p_t[(p_t >= rt + 0.050) & (p_t <= rt + 0.500)] + if len(cand) == 1: + out.append((cand[0] - rt) * 1000.0) + return out + + +def main() -> None: + # -------- Pass 1: download dataset_info.json (cheap) from ALL shards not feasible -- + # Instead: sample 120 shards uniformly for metadata. That is >25% coverage. + meta_shards = sorted(RNG.sample(range(N_SHARDS), 120)) + print(f"[pass 1] downloading metadata from {len(meta_shards)} shards") + patterns = ["metadata.json"] + [f"shard_{i:05d}/*" for i in meta_shards] + root = Path( + snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns, max_workers=12) + ) + + patients: set[str] = set() + total_duration_s = 0.0 + ecg_fs_list = [] + ppg_fs_list = [] + ecg_siglen = [] + ppg_siglen = [] + ecg_leads_counter: dict[str, int] = {} + has_lead_II = 0 + n_segments = 0 + shard_to_rows: dict[int, int] = {} + + reservoir: list[tuple[int, int]] = [] + + for sidx in tqdm(meta_shards, desc="shards(meta)"): + ds = load_from_disk(str(root / f"shard_{sidx:05d}")) + shard_to_rows[sidx] = len(ds) + cheap = ds.remove_columns( + [c for c in ds.column_names if c in ("ecg", "ppg", "ecg_time_s", "ppg_time_s")] + ) + for i, row in enumerate(cheap): + patients.add(parse_subject_id(row["record_name"])) + total_duration_s += float(row["segment_duration_sec"]) + ecg_fs_list.append(float(row["ecg_fs"])) + ppg_fs_list.append(float(row["ppg_fs"])) + ecg_siglen.append(int(row["ecg_siglen"])) + ppg_siglen.append(int(row["ppg_siglen"])) + names = tuple(row["ecg_names"]) + for n in names: + 
ecg_leads_counter[n] = ecg_leads_counter.get(n, 0) + 1 + if "II" in names: + has_lead_II += 1 + n_segments += 1 + reservoir.append((sidx, i)) + + # -------- Pass 2: PTT on up to 400 segments (stop at 150 with >=3 clean beats) -------- + RNG.shuffle(reservoir) + all_ptts = [] + clean_segment_stds = [] + sanity_samples = [] + want_sanity = 5 + processed = 0 + good_segments = 0 + by_shard: dict[int, list[int]] = {} + for s, i in reservoir[:400]: + by_shard.setdefault(s, []).append(i) + + print(f"[pass 2] PTT on up to 400 segments") + for sidx, idxs in tqdm(list(by_shard.items()), desc="shards(ptt)"): + if good_segments >= 150: + break + ds = load_from_disk(str(root / f"shard_{sidx:05d}")) + for i in idxs: + if good_segments >= 150: + break + row = ds[i] + ecg = np.asarray(row["ecg"], dtype=np.float32) + ppg = np.asarray(row["ppg"], dtype=np.float32) + if ecg.size == 0 or ppg.size == 0: + continue + names = list(row["ecg_names"]) + if "II" in names: + lead_idx = names.index("II") + else: + lead_idx = 0 + ecg_lead = ecg[lead_idx] + ppg_ch = ppg[0] + ptts = clean_ptts_ms( + ecg_lead, + float(row["ecg_fs"]), + ppg_ch, + float(row["ppg_fs"]), + float(row["ecg_time_s"][0]), + float(row["ppg_time_s"][0]), + ) + processed += 1 + if len(ptts) >= 3: + all_ptts.extend(ptts) + clean_segment_stds.append(float(np.std(ptts))) + good_segments += 1 + if len(sanity_samples) < want_sanity and len(ptts) >= 3: + sanity_samples.append( + ( + ecg_lead.copy(), + ppg_ch.copy(), + float(row["ecg_fs"]), + float(row["ppg_fs"]), + row["record_name"], + ptts, + ) + ) + + # -------- Aggregate -------- + total_hours_sampled = total_duration_s / 3600.0 + total_hours_estimated = total_hours_sampled * (N_SHARDS / len(meta_shards)) + # Patient count estimate: if sampled 120 shards and found K patients, and each shard seems + # to be mostly one patient (a recording per patient), then true patients ≈ K * (412/120). 
+ # But de-duplicate: we also observed patient IDs; if #patients saturates well below 412, + # the dataset has fewer than one-per-shard. + patients_extrap = int(len(patients) * N_SHARDS / len(meta_shards)) + + median = lambda v: float(np.median(v)) if len(v) else float("nan") + report = { + "dataset": REPO, + "shards_total": N_SHARDS, + "shards_sampled_meta": len(meta_shards), + "segments_meta_scanned": n_segments, + "unique_patients_in_sample": len(patients), + "unique_patients_extrapolated": patients_extrap, + "total_duration_hours_sampled": round(total_hours_sampled, 2), + "total_duration_hours_estimated": round(total_hours_estimated, 2), + "ecg_fs_median_hz": median(ecg_fs_list), + "ppg_fs_median_hz": median(ppg_fs_list), + "ecg_siglen_median_samples": int(median(ecg_siglen)) if ecg_siglen else 0, + "ppg_siglen_median_samples": int(median(ppg_siglen)) if ppg_siglen else 0, + "ecg_lead_counts_top10": dict( + sorted(ecg_leads_counter.items(), key=lambda kv: -kv[1])[:10] + ), + "lead_II_available_frac": has_lead_II / max(n_segments, 1), + "ptt_beats_measured": len(all_ptts), + "ptt_good_segments": good_segments, + "ptt_segments_attempted": processed, + "ptt_median_ms": median(all_ptts), + "ptt_p5_ms": float(np.percentile(all_ptts, 5)) if all_ptts else float("nan"), + "ptt_p95_ms": float(np.percentile(all_ptts, 95)) if all_ptts else float("nan"), + "ptt_within_segment_std_median_ms": median(clean_segment_stds), + "ptt_within_segment_std_p90_ms": ( + float(np.percentile(clean_segment_stds, 90)) if clean_segment_stds else float("nan") + ), + } + # Plots + if all_ptts: + plt.figure(figsize=(7, 4)) + plt.hist(all_ptts, bins=60, color="#3a7", edgecolor="black") + plt.axvline(100, color="red", linestyle="--", alpha=0.5, label="100 ms") + plt.axvline(400, color="red", linestyle="--", alpha=0.5, label="400 ms") + plt.xlabel("PTT (ms)") + plt.ylabel("count") + plt.title( + f"PTT distribution — {len(all_ptts)} clean beats, " + f"{good_segments} segments, {len(by_shard)} 
shards" + ) + plt.legend() + plt.tight_layout() + plt.savefig(FIG_DIR / "ptt_histogram.png", dpi=120) + plt.close() + + if sanity_samples: + fig, axes = plt.subplots(len(sanity_samples), 1, figsize=(10, 2.4 * len(sanity_samples))) + if len(sanity_samples) == 1: + axes = [axes] + for ax, (ecg, ppg, efs, pfs, name, ptts) in zip(axes, sanity_samples): + t_e = np.arange(len(ecg)) / efs + t_p = np.arange(len(ppg)) / pfs + ax2 = ax.twinx() + ax.plot(t_e, ecg, color="#266", lw=0.6, label="ECG II") + ax2.plot(t_p, ppg, color="#b30", lw=0.6, label="PPG") + ax.set_title( + f"{name} PTT median={np.median(ptts):.0f} ms N={len(ptts)}", + fontsize=8, + ) + ax.set_xlabel("time (s)") + ax.set_ylabel("ECG", color="#266") + ax2.set_ylabel("PPG", color="#b30") + plt.tight_layout() + plt.savefig(FIG_DIR / "sanity_check.png", dpi=120) + plt.close() + + (OUT / "e0_report.json").write_text(json.dumps(report, indent=2)) + print(json.dumps(report, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/e0_peek.py b/scripts/e0_peek.py new file mode 100644 index 0000000000000000000000000000000000000000..b8e85833cc55dbb70655390f066654487cbd1731 --- /dev/null +++ b/scripts/e0_peek.py @@ -0,0 +1,40 @@ +"""E0 peek: discover the schema of lucky9-cyou/mimic-iv-aligned-ppg-ecg before full audit.""" +import os +from dotenv import load_dotenv + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) + +from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names + +DS_NAME = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" + +print("=== configs ===") +try: + print(get_dataset_config_names(DS_NAME)) +except Exception as e: + print("err:", e) + +print("=== splits ===") +try: + print(get_dataset_split_names(DS_NAME)) +except Exception as e: + print("err:", e) + +print("=== stream first sample ===") +ds = load_dataset(DS_NAME, split="train", streaming=True) +print("features:", ds.features) +it = iter(ds) +s = next(it) +print("keys:", 
list(s.keys())) +for k, v in s.items(): + if hasattr(v, "__len__") and not isinstance(v, str): + try: + import numpy as np + + arr = np.asarray(v) + print(f" {k}: shape={arr.shape} dtype={arr.dtype}") + except Exception: + print(f" {k}: len={len(v)} type={type(v).__name__}") + else: + print(f" {k}: {v!r}"[:200]) diff --git a/scripts/e1_ppg_encoding.py b/scripts/e1_ppg_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..c37ab31129f8ee18c8b5703602f04051566aeeac --- /dev/null +++ b/scripts/e1_ppg_encoding.py @@ -0,0 +1,135 @@ +"""E1 — PPG encoding decision: morphological vs raw patch. + +Per the E1 decision rule in EXPERIMENT_TRACKING.md: + if morphology_extraction_rate < 0.70: -> raw patches + elif E1b_linear_probe_AUROC > E1a + 0.02: -> morphological + else: -> raw patches + +This script implements Stage 1 (extraction rate) directly. If extraction rate +passes, we'd move to Stage 2 (linear probe comparison on AF) — but that +requires AF labels, which are pending. For now we decide Stage 1 and defer +Stage 2 until AF labels land. + +Features extracted (Bishop & Ercole / neurokit2): + PPG_Rate, PPG_Width, PPG_UpstrokeSlope, PPG_Amplitude, PPG_DicroticNotch. +""" +from __future__ import annotations + +import json +import os +import random +import re +import warnings +from pathlib import Path + +import numpy as np +from dotenv import load_dotenv +from tqdm import tqdm + +warnings.filterwarnings("ignore") +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) + +from datasets import load_from_disk +from huggingface_hub import snapshot_download + +import neurokit2 as nk + +REPO = "lucky9-cyou/mimic-iv-aligned-ppg-ecg" +OUT = Path(__file__).resolve().parent.parent / "docs" +RNG = random.Random(11) + + +def try_morphology(ppg: np.ndarray, fs: float) -> tuple[bool, int, int]: + """Returns (ok, n_detected_beats, n_expected_beats). 
+ + `ok` is True if neurokit2 detects ≥5 valid beats AND the fraction + detected/expected > 0.70. Expected beats is duration * typical_hr (60-100). + """ + try: + signals, info = nk.ppg_process(ppg, sampling_rate=int(round(fs))) + peaks = np.asarray(info.get("PPG_Peaks", [])) + if len(peaks) < 5: + return False, len(peaks), 0 + duration_s = len(ppg) / fs + # Expected beats: use the detected rate itself for a robust estimate + detected_rate = signals["PPG_Rate"].dropna().median() + if not np.isfinite(detected_rate) or detected_rate < 30 or detected_rate > 200: + return False, len(peaks), 0 + expected = int(duration_s * detected_rate / 60.0) + if expected < 3: + return False, len(peaks), expected + extracted_frac = len(peaks) / expected + return 0.70 <= extracted_frac <= 1.30, len(peaks), expected + except Exception: + return False, 0, 0 + + +def main() -> None: + # Use shards we already have in cache (from E0 audits) + want = sorted(RNG.sample(range(412), 40)) + root = Path( + snapshot_download( + REPO, + repo_type="dataset", + allow_patterns=[f"shard_{i:05d}/*" for i in want], + max_workers=12, + ) + ) + shards = [s for s in want if (root / f"shard_{s:05d}" / "dataset_info.json").exists()] + + n_attempted = 0 + n_ok = 0 + n_nonempty = 0 + beat_counts = [] + target = 500 + results = [] + + for sidx in tqdm(shards, desc="shards"): + if n_attempted >= target: + break + ds = load_from_disk(str(root / f"shard_{sidx:05d}")) + for i in range(len(ds)): + if n_attempted >= target: + break + row = ds[i] + ppg = np.asarray(row["ppg"], dtype=np.float32)[0] + fs = float(row["ppg_fs"]) + n_attempted += 1 + if ppg.size == 0: + continue + n_nonempty += 1 + ok, got, exp = try_morphology(ppg, fs) + beat_counts.append(got) + if ok: + n_ok += 1 + results.append( + {"record": row["record_name"], "ok": ok, "detected": got, "expected": exp} + ) + + extraction_rate = n_ok / max(n_nonempty, 1) + decision = "raw_patches" if extraction_rate < 0.70 else "needs_stage2_probe" + + report = { + 
"n_segments_attempted": n_attempted, + "n_segments_nonempty": n_nonempty, + "n_segments_ok": n_ok, + "extraction_rate": extraction_rate, + "median_detected_beats_per_segment": ( + float(np.median(beat_counts)) if beat_counts else 0.0 + ), + "mean_detected_beats_per_segment": ( + float(np.mean(beat_counts)) if beat_counts else 0.0 + ), + "stage1_decision": decision, + "rule": ( + "extraction_rate < 0.70 -> raw_patches (stop). " + "else -> run stage-2 linear-probe comparison after AF labels arrive." + ), + } + (OUT / "e1_stage1_report.json").write_text(json.dumps(report, indent=2)) + print(json.dumps(report, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/eval_checkpoint.py b/scripts/eval_checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..de96d6dcc13b5e0d5499ac27fc9d1b42fcfa4c62 --- /dev/null +++ b/scripts/eval_checkpoint.py @@ -0,0 +1,84 @@ +"""Evaluate a trained checkpoint on PTB-XL AF + downstream probes. + +Loads the model from `--ckpt`, fetches PTB-XL via HF, extracts pooled latents +from the ECG encoder, runs a logistic-regression linear probe, and writes +results JSON. + +Used at epoch 25 (K-gate eval) and epoch 100 (final eval). 
+""" +from __future__ import annotations + +import argparse +import json +import os +from pathlib import Path + +import numpy as np +import torch +from dotenv import load_dotenv + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) + +import sys +sys.path.insert(0, str(Path(__file__).resolve().parents[1] / "src")) + +from physiojepa.models import MODEL_REGISTRY, ModelConfig +from physiojepa.probe import linear_probe_auroc, pooled_features + + +def get_ecg_encoder(model_letter: str, model: torch.nn.Module) -> torch.nn.Module: + if model_letter == "A": + return model.ecg + if model_letter == "C": + return model.ecg + return model.bb.ecg + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--ckpt", required=True) + ap.add_argument("--model", required=True, choices=["A", "B", "C", "F"]) + ap.add_argument("--ptbxl_npz", default="/workspace/cache/ptbxl_af.npz") + ap.add_argument("--out", required=True) + args = ap.parse_args() + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + sd = torch.load(args.ckpt, map_location=device, weights_only=False) + saved_cfg = sd.get("cfg", {}) + # Respect ablation knobs saved in the TrainConfig + cfg = ModelConfig( + pred_depth=saved_cfg.get("pred_depth", 4), + query_mode=saved_cfg.get("query_mode", "learned"), + mask_ratio=saved_cfg.get("mask_ratio", 0.50), + ) + print(f"[eval] model cfg: pred_depth={cfg.pred_depth} query_mode={cfg.query_mode} mask_ratio={cfg.mask_ratio}") + model = MODEL_REGISTRY[args.model](cfg) + model.load_state_dict(sd["model"]) + model.to(device) + model.train(False) + enc = get_ecg_encoder(args.model, model) + + print(f"[eval] loading PTB-XL cache from {args.ptbxl_npz}") + arr = np.load(args.ptbxl_npz) + X, y = arr["X"], arr["y"] + print(f"[eval] X={X.shape} y_pos={int(y.sum())} y_neg={int((1 - y).sum())}") + X_t = torch.from_numpy(X) + feats = pooled_features(enc, X_t, device=device, batch_size=64) + + rng = 
np.random.default_rng(0) + idx = rng.permutation(len(y)) + cut = int(len(idx) * 0.8) + train_idx, test_idx = idx[:cut], idx[cut:] + auroc = linear_probe_auroc(feats[train_idx], y[train_idx], feats[test_idx], y[test_idx]) + print(f"[eval] AF AUROC = {auroc:.4f}") + Path(args.out).parent.mkdir(parents=True, exist_ok=True) + Path(args.out).write_text(json.dumps({ + "ckpt": args.ckpt, "model": args.model, "auroc": auroc, + "n_train": int(cut), "n_test": int(len(idx) - cut), + "n_pos": int(y.sum()), "n_neg": int((1 - y).sum()), + }, indent=2)) + + +if __name__ == "__main__": + main() diff --git a/scripts/fetch_ptbxl.py b/scripts/fetch_ptbxl.py new file mode 100644 index 0000000000000000000000000000000000000000..27610d92305315d5f334e3c5671799af4f5389d3 --- /dev/null +++ b/scripts/fetch_ptbxl.py @@ -0,0 +1,116 @@ +"""Fetch PTB-XL from PhysioNet (open access, no credentialing) and cache lead II +@ 250 Hz with binary AFIB labels into a single .npz file for fast eval reload. + +Resulting cache layout: + /workspace/cache/ptbxl_af.npz (X: [N,1,2500] float32, y: [N] int64) +""" +from __future__ import annotations + +import argparse +import io +import os +import re +import tarfile +import zipfile +from pathlib import Path + +import numpy as np +import requests +from tqdm import tqdm + +PTBXL_VERSION = "1.0.3" +PTBXL_URL = ( + f"https://physionet.org/static/published-projects/ptb-xl/" + f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip" +) + + +def _resample_500_to_250(x): + from scipy.signal import resample_poly + return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32) + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--root", default="/workspace/cache/ptbxl") + ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz") + ap.add_argument("--limit", type=int, default=None) + args = ap.parse_args() + + root = Path(args.root) + root.mkdir(parents=True, exist_ok=True) + zip_path = root / "ptbxl.zip" + if 
not zip_path.exists(): + print(f"[fetch] downloading PTB-XL ({PTBXL_URL})") + r = requests.get(PTBXL_URL, stream=True, timeout=600) + r.raise_for_status() + total = int(r.headers.get("content-length", 0)) + with open(zip_path, "wb") as f: + for chunk in tqdm(r.iter_content(chunk_size=1024 * 1024), + total=total // (1024 * 1024)): + if chunk: + f.write(chunk) + extract_dir = root / "extracted" + if not extract_dir.exists(): + print(f"[fetch] extracting to {extract_dir}") + with zipfile.ZipFile(zip_path) as z: + z.extractall(extract_dir) + # find ptbxl_database.csv + csvs = list(extract_dir.rglob("ptbxl_database.csv")) + assert csvs, "ptbxl_database.csv not found in extracted zip" + db_csv = csvs[0] + db_root = db_csv.parent + print(f"[fetch] db_root = {db_root}") + + import pandas as pd + import wfdb + + meta = pd.read_csv(db_csv, index_col="ecg_id") + # parse scp_codes safely + def _parse(val): + try: + import json + return json.loads(val.replace("'", '"')) + except Exception: + out = {} + for tok in val.strip("{} ").split(","): + if ":" in tok: + k, v = tok.split(":", 1) + out[k.strip().strip("'\"")] = float(v.strip()) + return out + + meta["scp_parsed"] = meta["scp_codes"].apply(_parse) + meta["afib"] = meta["scp_parsed"].apply( + lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys())) + ) + if args.limit: + meta = meta.sample(n=args.limit, random_state=0) + print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}") + + xs, ys = [], [] + for _, row in tqdm(meta.iterrows(), total=len(meta), desc="ptb-xl"): + rec = wfdb.rdrecord(str(db_root / row["filename_hr"])) + signals = rec.p_signal # [T, 12] @ 500 Hz + lead_names = rec.sig_name + if "II" not in lead_names: + continue + lead_ii = signals[:, lead_names.index("II")] + x = _resample_500_to_250(lead_ii) + if x.shape[0] < 2500: + x = np.pad(x, (0, 2500 - x.shape[0])) + else: + x = x[:2500] + x = (x - x.mean()) / (x.std() + 1e-6) + xs.append(x.astype(np.float32)) + ys.append(int(row["afib"])) 
+ + X = np.stack(xs).astype(np.float32)[:, None, :] + y = np.array(ys, dtype=np.int64) + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + np.savez_compressed(out, X=X, y=y) + print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}") + + +if __name__ == "__main__": + main() diff --git a/scripts/fetch_ptbxl_v2.py b/scripts/fetch_ptbxl_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..15c3e152831790a3ceb385f2487d54911015637b --- /dev/null +++ b/scripts/fetch_ptbxl_v2.py @@ -0,0 +1,140 @@ +"""PTB-XL fetch v2 — multiprocessing, full zip download via wget. + +Downloads PTB-XL v1.0.3 from PhysioNet, extracts, parses with wfdb in parallel +using a process pool (8-16 workers), caches to /workspace/cache/ptbxl_af.npz. + +Usage: + python scripts/fetch_ptbxl_v2.py --root /workspace/cache/ptbxl --out /workspace/cache/ptbxl_af.npz [--workers 12] +""" +from __future__ import annotations + +import argparse +import json +import multiprocessing as mp +import os +import subprocess +import zipfile +from pathlib import Path + +import numpy as np +import pandas as pd +from scipy.signal import resample_poly +from tqdm import tqdm +import wfdb + +PTBXL_VERSION = "1.0.3" +PTBXL_URL = ( + f"https://physionet.org/static/published-projects/ptb-xl/" + f"ptb-xl-a-large-publicly-available-electrocardiography-dataset-{PTBXL_VERSION}.zip" +) + + +def _resample_500_to_250(x): + return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32) + + +def _parse_scp(val): + if isinstance(val, dict): + return val + if not isinstance(val, str): + return {} + try: + return json.loads(val.replace("'", '"')) + except Exception: + out = {} + for tok in val.strip("{} ").split(","): + if ":" in tok: + k, v = tok.split(":", 1) + out[k.strip().strip("'\"")] = float(v.strip()) + return out + + +def _process_one(arg): + """Read one PTB-XL record's lead II and return (x, y).""" + db_root, fname_hr, afib = arg + try: + rec = 
wfdb.rdrecord(str(Path(db_root) / fname_hr)) + signals = rec.p_signal + lead_names = rec.sig_name + if "II" not in lead_names: + return None + lead_ii = signals[:, lead_names.index("II")] + x = _resample_500_to_250(lead_ii) + if x.shape[0] < 2500: + x = np.pad(x, (0, 2500 - x.shape[0])) + else: + x = x[:2500] + x = (x - x.mean()) / (x.std() + 1e-6) + return (x.astype(np.float32), int(afib)) + except Exception as e: + return None + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--root", default="/workspace/cache/ptbxl") + ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz") + ap.add_argument("--workers", type=int, default=16) + ap.add_argument("--limit", type=int, default=None) + args = ap.parse_args() + + root = Path(args.root) + root.mkdir(parents=True, exist_ok=True) + zip_path = root / "ptbxl.zip" + + # Download via wget (resumable, faster than requests for 3 GB) + if not zip_path.exists() or zip_path.stat().st_size < 1_000_000_000: # < 1 GB = incomplete + print(f"[fetch] downloading PTB-XL via wget", flush=True) + zip_path.unlink(missing_ok=True) + subprocess.run([ + "wget", "-c", "-O", str(zip_path), PTBXL_URL + ], check=True) + + print(f"[fetch] zip size: {zip_path.stat().st_size / 1e9:.2f} GB", flush=True) + + extract_dir = root / "extracted" + if not extract_dir.exists() or not list(extract_dir.rglob("ptbxl_database.csv")): + print(f"[fetch] extracting to {extract_dir}", flush=True) + extract_dir.mkdir(parents=True, exist_ok=True) + with zipfile.ZipFile(zip_path) as z: + z.extractall(extract_dir) + + csvs = list(extract_dir.rglob("ptbxl_database.csv")) + assert csvs, "ptbxl_database.csv not found after extract" + db_csv = csvs[0] + db_root = db_csv.parent + print(f"[fetch] db_root = {db_root}", flush=True) + + meta = pd.read_csv(db_csv, index_col="ecg_id") + meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp) + meta["afib"] = meta["scp_parsed"].apply( + lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys())) 
+ ) + if args.limit: + meta = meta.sample(n=args.limit, random_state=0) + print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True) + + work = [(str(db_root), row["filename_hr"], row["afib"]) + for _, row in meta.iterrows()] + + print(f"[fetch] parsing with {args.workers} workers", flush=True) + xs, ys = [], [] + with mp.Pool(args.workers) as pool: + for r in tqdm(pool.imap_unordered(_process_one, work, chunksize=8), + total=len(work), desc="ptb-xl"): + if r is None: + continue + xs.append(r[0]) + ys.append(r[1]) + + X = np.stack(xs).astype(np.float32)[:, None, :] + y = np.array(ys, dtype=np.int64) + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + np.savez_compressed(out, X=X, y=y) + print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}", + flush=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/fetch_ptbxl_v3.py b/scripts/fetch_ptbxl_v3.py new file mode 100644 index 0000000000000000000000000000000000000000..88ed090f140e3462f3daa5002a9ca794d00a4ab0 --- /dev/null +++ b/scripts/fetch_ptbxl_v3.py @@ -0,0 +1,173 @@ +"""PTB-XL fetch v3 — concurrent per-file HTTP downloads (no 3 GB monolithic zip). + +PhysioNet exposes individual files at: + https://physionet.org/files/ptb-xl/1.0.3/ + +Strategy: + 1. Download just `ptbxl_database.csv` (~4 MB) to know which records exist + 2. Concurrent download of the .hea/.dat pairs we need (lead II only — but + we need to download all 12 leads since each .dat is one multilead file) + 3. Parse with wfdb in a process pool + +Total bytes: 21k records × ~400 KB each ≈ 8 GB. Even at 200 KB/s that's +slow, but with 32 concurrent connections we should saturate the pod's +~1 Gbit network (~125 MB/s). 8 GB / 125 MB/s = 64 sec ideal, ~10 min +realistic given physionet bandwidth caps. + +Actually shortcut — use the LR (low-res, 100 Hz) variant: ~75 KB per file +×21k = 1.5 GB total. We resample 100→250 Hz with scipy. 
Quality is fine +for AF detection (PTB-XL paper uses both 100 and 500 Hz freely). +""" +from __future__ import annotations + +import argparse +import concurrent.futures as cf +import json +import multiprocessing as mp +import os +import urllib.request +from pathlib import Path + +import numpy as np +import pandas as pd +from scipy.signal import resample_poly +from tqdm import tqdm +import wfdb + +BASE = "https://physionet.org/files/ptb-xl/1.0.3" + + +def _parse_scp(val): + if isinstance(val, dict): + return val + if not isinstance(val, str): + return {} + try: + return json.loads(val.replace("'", '"')) + except Exception: + out = {} + for tok in val.strip("{} ").split(","): + if ":" in tok: + k, v = tok.split(":", 1) + out[k.strip().strip("'\"")] = float(v.strip()) + return out + + +def _download(args): + url, dst = args + if dst.exists() and dst.stat().st_size > 0: + return True + dst.parent.mkdir(parents=True, exist_ok=True) + try: + with urllib.request.urlopen(url, timeout=60) as r, open(dst, "wb") as f: + f.write(r.read()) + return True + except Exception as e: + return False + + +def _resample(x, src_hz, dst_hz): + from math import gcd + g = gcd(int(src_hz), int(dst_hz)) + return resample_poly(x, up=int(dst_hz)//g, down=int(src_hz)//g, axis=-1).astype(np.float32) + + +def _process_one(arg): + db_root, fname, afib, src_hz = arg + try: + rec = wfdb.rdrecord(str(Path(db_root) / fname)) + signals = rec.p_signal + lead_names = rec.sig_name + if "II" not in lead_names: + return None + lead_ii = signals[:, lead_names.index("II")] + x = _resample(lead_ii, src_hz, 250) + if x.shape[0] < 2500: + x = np.pad(x, (0, 2500 - x.shape[0])) + else: + x = x[:2500] + x = (x - x.mean()) / (x.std() + 1e-6) + return (x.astype(np.float32), int(afib)) + except Exception: + return None + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--root", default="/workspace/cache/ptbxl") + ap.add_argument("--out", default="/workspace/cache/ptbxl_af.npz") + 
ap.add_argument("--use_lr", action="store_true", help="100 Hz variant (smaller, faster)") + ap.add_argument("--limit", type=int, default=None) + ap.add_argument("--dl_workers", type=int, default=32) + ap.add_argument("--parse_workers", type=int, default=16) + args = ap.parse_args() + + root = Path(args.root) + root.mkdir(parents=True, exist_ok=True) + + csv_path = root / "ptbxl_database.csv" + if not csv_path.exists(): + print(f"[fetch] downloading ptbxl_database.csv", flush=True) + urllib.request.urlretrieve(f"{BASE}/ptbxl_database.csv", str(csv_path)) + print(f"[fetch] csv size: {csv_path.stat().st_size/1e6:.1f} MB", flush=True) + + meta = pd.read_csv(csv_path, index_col="ecg_id") + meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp) + meta["afib"] = meta["scp_parsed"].apply( + lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys())) + ) + if args.limit: + meta = meta.sample(n=args.limit, random_state=0) + print(f"[fetch] {len(meta)} records, AF positive = {int(meta['afib'].sum())}", flush=True) + + # Decide LR vs HR + fname_col = "filename_lr" if args.use_lr else "filename_hr" + src_hz = 100 if args.use_lr else 500 + + # Build download list (.hea + .dat per record) + dl_list = [] + for _, row in meta.iterrows(): + rel = row[fname_col] # e.g. 
records100/00000/00001_lr + for ext in (".hea", ".dat"): + url = f"{BASE}/{rel}{ext}" + dst = root / f"{rel}{ext}" + dl_list.append((url, dst)) + + # Filter out already-present + todo = [(u, d) for u, d in dl_list if not (d.exists() and d.stat().st_size > 0)] + print(f"[fetch] {len(todo)} files to download (skipping {len(dl_list)-len(todo)} cached)", + flush=True) + + if todo: + with cf.ThreadPoolExecutor(max_workers=args.dl_workers) as ex: + ok_count = 0 + for ok in tqdm(ex.map(_download, todo), total=len(todo), desc="dl"): + if ok: + ok_count += 1 + print(f"[fetch] downloaded ok={ok_count}/{len(todo)}", flush=True) + + # Parse + work = [(str(root), row[fname_col], row["afib"], src_hz) + for _, row in meta.iterrows()] + print(f"[fetch] parsing {len(work)} records with {args.parse_workers} workers", + flush=True) + xs, ys = [], [] + with mp.Pool(args.parse_workers) as pool: + for r in tqdm(pool.imap_unordered(_process_one, work, chunksize=16), + total=len(work), desc="parse"): + if r is None: + continue + xs.append(r[0]) + ys.append(r[1]) + + X = np.stack(xs).astype(np.float32)[:, None, :] + y = np.array(ys, dtype=np.int64) + out = Path(args.out) + out.parent.mkdir(parents=True, exist_ok=True) + np.savez_compressed(out, X=X, y=y) + print(f"[fetch] wrote {out}: X={X.shape} y_pos={int(y.sum())} y_neg={int((1-y).sum())}", + flush=True) + + +if __name__ == "__main__": + main() diff --git a/scripts/pod_bootstrap.sh b/scripts/pod_bootstrap.sh new file mode 100644 index 0000000000000000000000000000000000000000..ab590320fa0951120d1cc917b3bc435ec4cfbb5a --- /dev/null +++ b/scripts/pod_bootstrap.sh @@ -0,0 +1,64 @@ +#!/usr/bin/env bash +# Run on the RunPod pod. 
Args: +set -euo pipefail +MODEL="${1:?model letter required}" +RUN_NAME="${2:?run name required}" + +echo "[bootstrap] model=$MODEL run=$RUN_NAME" +cd /workspace +REPO_DIR="" +for d in PhysioJEPA physiojepa; do + if [ -d "$d" ]; then REPO_DIR="$d"; break; fi +done +[ -n "$REPO_DIR" ] || { echo "no repo dir found at /workspace/{PhysioJEPA,physiojepa}"; exit 1; } +cd "$REPO_DIR" + +# Use the image's system Python (already has torch 2.4.1+cu124 wired up). +# Install only the extras we need into the system site-packages. +PY=/usr/bin/python3 +$PY -m pip install --quiet --upgrade pip +$PY -m pip install --quiet \ + 'datasets>=4.8.4' 'einops>=0.8.2' 'matplotlib>=3.10.0' \ + 'neurokit2>=0.2.13' 'python-dotenv>=1.0' 'pyyaml>=6.0' \ + 'scikit-learn>=1.5' 'scipy>=1.13' 'tqdm>=4.66' \ + 'wandb>=0.18' 'wfdb>=4.3.1' 'huggingface_hub>=0.25' 'requests' +RUN_PY="$PY" + +# Stage env keys (the launcher will have written /workspace/.env into the pod via send) +if [ -f /workspace/.env ]; then + cp /workspace/.env .env +fi + +# Step 1: prepare data (idempotent) +if [ ! 
-f /workspace/cache/mimic_index.json ]; then + echo "[bootstrap] downloading MIMIC shards + building index" + PYTHONPATH=src $RUN_PY scripts/prepare_data.py \ + --root /workspace/cache/mimic \ + --index /workspace/cache/mimic_index.json +fi + +# write shard_roots json for trainer +PYTHONPATH=src $RUN_PY -c " +import json, pathlib +roots = sorted([str(p) for p in pathlib.Path('/workspace/cache/mimic').glob('shard_*') + if (p / 'dataset_info.json').exists()]) +pathlib.Path('/workspace/cache/shard_roots.json').write_text(json.dumps(roots)) +print('shards:', len(roots)) +" + +# Step 2: train +echo "[bootstrap] launching training: model=$MODEL" +PYTHONPATH=src PYTHONUNBUFFERED=1 $RUN_PY -u scripts/train.py \ + --config configs/base.yaml \ + --model "$MODEL" \ + --run_name "$RUN_NAME" \ + --epochs 25 \ + --shard_roots_json /workspace/cache/shard_roots.json \ + --index_path /workspace/cache/mimic_index.json \ + --output_dir /workspace/runs \ + --num_workers 8 \ + --subset_frac 0.10 \ + --log_every 25 \ + 2>&1 | tee "/workspace/runs/${RUN_NAME}.log" + +echo "[bootstrap] done" diff --git a/scripts/pod_bootstrap_ablation.sh b/scripts/pod_bootstrap_ablation.sh new file mode 100644 index 0000000000000000000000000000000000000000..56cdb55d02302c282809c5ac9ad3a07e40fdba85 --- /dev/null +++ b/scripts/pod_bootstrap_ablation.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +# Slow-tau A ablation. 
Args: [epochs] +set -euo pipefail +RUN_NAME="${1:?run_name}" +EMA_END="${2:-0.999}" +EMA_WARMUP="${3:-0.60}" +EPOCHS="${4:-25}" + +echo "[bootstrap] slow-tau ablation: run=$RUN_NAME ema_end=$EMA_END warmup_frac=$EMA_WARMUP epochs=$EPOCHS" +cd /workspace +REPO_DIR="" +for d in PhysioJEPA physiojepa; do + if [ -d "$d" ]; then REPO_DIR="$d"; break; fi +done +[ -n "$REPO_DIR" ] || { echo "no repo dir found"; exit 1; } +cd "$REPO_DIR" + +PY=/usr/bin/python3 +$PY -m pip install --quiet --upgrade pip +$PY -m pip install --quiet \ + 'datasets>=4.8.4' 'einops>=0.8.2' 'matplotlib>=3.10.0' \ + 'neurokit2>=0.2.13' 'python-dotenv>=1.0' 'pyyaml>=6.0' \ + 'scikit-learn>=1.5' 'scipy>=1.13' 'tqdm>=4.66' \ + 'wandb>=0.18' 'wfdb>=4.3.1' 'huggingface_hub>=0.25' 'requests' + +if [ -f /workspace/.env ]; then cp /workspace/.env .env; fi + +if [ ! -f /workspace/cache/mimic_index.json ]; then + echo "[bootstrap] downloading MIMIC + building index" + PYTHONPATH=src $PY scripts/prepare_data.py \ + --root /workspace/cache/mimic --index /workspace/cache/mimic_index.json +fi + +PYTHONPATH=src $PY -c " +import json, pathlib +roots = sorted([str(p) for p in pathlib.Path('/workspace/cache/mimic').glob('shard_*') + if (p / 'dataset_info.json').exists()]) +pathlib.Path('/workspace/cache/shard_roots.json').write_text(json.dumps(roots)) +print('shards:', len(roots)) +" + +mkdir -p /workspace/runs +echo "[bootstrap] launching A ablation" +PYTHONPATH=src PYTHONUNBUFFERED=1 $PY -u scripts/train.py \ + --config configs/base.yaml \ + --model A \ + --run_name "$RUN_NAME" \ + --epochs "$EPOCHS" \ + --shard_roots_json /workspace/cache/shard_roots.json \ + --index_path /workspace/cache/mimic_index.json \ + --output_dir /workspace/runs \ + --num_workers 8 \ + --subset_frac 0.10 \ + --log_every 25 \ + --ema_end "$EMA_END" \ + --ema_warmup_frac "$EMA_WARMUP" \ + --seed 42 \ + 2>&1 | tee "/workspace/runs/${RUN_NAME}.log" + +echo "[bootstrap] done" diff --git a/scripts/pod_bootstrap_ablation_v2.sh 
b/scripts/pod_bootstrap_ablation_v2.sh new file mode 100644 index 0000000000000000000000000000000000000000..feb33d7b6194e29d532b7eaca4b8007e14070efb --- /dev/null +++ b/scripts/pod_bootstrap_ablation_v2.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +# Generic A-ablation bootstrap. All extra args go to train.py. +# Args: +set -euo pipefail +RUN_NAME="${1:?run_name}" +SUBSET="${2:?subset_frac}" +shift 2 +EXTRA=("$@") + +echo "[bootstrap] A ablation: run=$RUN_NAME subset=$SUBSET extra=${EXTRA[*]}" +cd /workspace +REPO_DIR="" +for d in PhysioJEPA physiojepa; do + if [ -d "$d" ]; then REPO_DIR="$d"; break; fi +done +[ -n "$REPO_DIR" ] || { echo "no repo dir"; exit 1; } +cd "$REPO_DIR" + +PY=/usr/bin/python3 +$PY -m pip install --quiet --upgrade pip +$PY -m pip install --quiet \ + 'datasets>=4.8.4' 'einops>=0.8.2' 'matplotlib>=3.10.0' \ + 'neurokit2>=0.2.13' 'python-dotenv>=1.0' 'pyyaml>=6.0' \ + 'scikit-learn>=1.5' 'scipy>=1.13' 'tqdm>=4.66' \ + 'wandb>=0.18' 'wfdb>=4.3.1' 'huggingface_hub>=0.25' 'requests' + +if [ -f /workspace/.env ]; then cp /workspace/.env .env; fi + +if [ ! 
-f /workspace/cache/mimic_index.json ]; then + echo "[bootstrap] downloading MIMIC + building index" + PYTHONPATH=src $PY scripts/prepare_data.py \ + --root /workspace/cache/mimic --index /workspace/cache/mimic_index.json +fi + +PYTHONPATH=src $PY -c " +import json, pathlib +roots = sorted([str(p) for p in pathlib.Path('/workspace/cache/mimic').glob('shard_*') + if (p / 'dataset_info.json').exists()]) +pathlib.Path('/workspace/cache/shard_roots.json').write_text(json.dumps(roots)) +print('shards:', len(roots)) +" + +mkdir -p /workspace/runs +echo "[bootstrap] launching" +PYTHONPATH=src PYTHONUNBUFFERED=1 $PY -u scripts/train.py \ + --config configs/base.yaml \ + --model A \ + --run_name "$RUN_NAME" \ + --epochs 25 \ + --shard_roots_json /workspace/cache/shard_roots.json \ + --index_path /workspace/cache/mimic_index.json \ + --output_dir /workspace/runs \ + --num_workers 8 \ + --subset_frac "$SUBSET" \ + --log_every 25 \ + --seed 42 \ + "${EXTRA[@]}" \ + 2>&1 | tee "/workspace/runs/${RUN_NAME}.log" + +echo "[bootstrap] done" diff --git a/scripts/pod_bootstrap_definitive.sh b/scripts/pod_bootstrap_definitive.sh new file mode 100644 index 0000000000000000000000000000000000000000..4c5b56c64efc602aeca3f0f5606ef24730afdf74 --- /dev/null +++ b/scripts/pod_bootstrap_definitive.sh @@ -0,0 +1,55 @@ +#!/usr/bin/env bash +# Definitive full-scale run. Args: [extra train.py args...] 
+set -euo pipefail +MODEL="${1:?model letter}"; RUN_NAME="${2:?run_name}"; shift 2; EXTRA=("$@") + +echo "[bootstrap] definitive run: model=$MODEL run=$RUN_NAME extra=${EXTRA[*]}" +cd /workspace +REPO_DIR=""; for d in PhysioJEPA physiojepa; do [ -d "$d" ] && REPO_DIR="$d" && break; done +[ -n "$REPO_DIR" ] || { echo "no repo dir"; exit 1; } +cd "$REPO_DIR" + +PY=/usr/bin/python3 +$PY -m pip install --quiet --upgrade pip +$PY -m pip install --quiet \ + 'datasets>=4.8.4' 'einops>=0.8.2' 'matplotlib>=3.10.0' \ + 'neurokit2>=0.2.13' 'python-dotenv>=1.0' 'pyyaml>=6.0' \ + 'scikit-learn>=1.5' 'scipy>=1.13' 'tqdm>=4.66' \ + 'wandb>=0.18' 'wfdb>=4.3.1' 'huggingface_hub>=0.25' 'requests' + +[ -f /workspace/.env ] && cp /workspace/.env .env + +# Step 1: download MIMIC shards + build index (idempotent) +if [ ! -f /workspace/cache/mimic_index.json ]; then + echo "[bootstrap] downloading MIMIC + building index" + PYTHONPATH=src $PY scripts/prepare_data.py \ + --root /workspace/cache/mimic --index /workspace/cache/mimic_index.json +fi + +# Step 2: precompute mmap windows (idempotent — checks inside) +if [ ! 
-f /workspace/cache/windows_meta.json ]; then + echo "[bootstrap] precomputing windows → mmap" + PYTHONPATH=src $PY -u scripts/precompute_windows.py \ + --index /workspace/cache/mimic_index.json \ + --out_dir /workspace/cache +fi + +# Step 3: train +mkdir -p /workspace/runs +echo "[bootstrap] launching training: model=$MODEL" +PYTHONPATH=src PYTHONUNBUFFERED=1 $PY -u scripts/train.py \ + --config configs/base.yaml \ + --model "$MODEL" \ + --run_name "$RUN_NAME" \ + --epochs 100 \ + --batch_size 64 \ + --fast_cache_dir /workspace/cache \ + --output_dir /workspace/runs \ + --num_workers 12 \ + --log_every 100 \ + --mask_ratio 0.75 \ + --seed 42 \ + "${EXTRA[@]}" \ + 2>&1 | tee "/workspace/runs/${RUN_NAME}.log" + +echo "[bootstrap] done" diff --git a/scripts/precompute_windows.py b/scripts/precompute_windows.py new file mode 100644 index 0000000000000000000000000000000000000000..4c7cdd9a3587c00103ac12a5c8ff92c5dcaf2eb1 --- /dev/null +++ b/scripts/precompute_windows.py @@ -0,0 +1,141 @@ +"""Precompute all ECG/PPG windows into a single memory-mapped tensor file. + +Reads the MIMIC shard index, applies bandpass + zscore per window, and writes +a flat binary file with a companion metadata JSON. At runtime, __getitem__ +is a single mmap read (~0.1 ms) instead of load_from_disk + filter (~20 ms). 
+ +Output: + /workspace/cache/windows_ecg.bin (float32, [N, 2500]) + /workspace/cache/windows_ppg.bin (float32, [N, 1250]) + /workspace/cache/windows_meta.json (subject_id per window, N total) +""" +from __future__ import annotations + +import argparse +import json +import os +import struct +from pathlib import Path + +import numpy as np +from scipy.signal import butter, filtfilt +from tqdm import tqdm + +from dotenv import load_dotenv + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) + +from datasets import load_from_disk + +ECG_FS = 250.0 +PPG_FS = 125.0 +ECG_WIN = 2500 +PPG_WIN = 1250 + + +def _bandpass(x, fs, lo, hi, order=3): + ny = 0.5 * fs + b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band") + return filtfilt(b, a, x, method="gust").astype(np.float32) + + +def _zscore(x, eps=1e-6): + return ((x - x.mean()) / (x.std() + eps)).astype(np.float32) + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--index", required=True) + ap.add_argument("--out_dir", default="/workspace/cache") + ap.add_argument("--workers", type=int, default=1) + args = ap.parse_args() + + index = json.loads(Path(args.index).read_text()) + out = Path(args.out_dir) + out.mkdir(parents=True, exist_ok=True) + + ecg_path = out / "windows_ecg.bin" + ppg_path = out / "windows_ppg.bin" + meta_path = out / "windows_meta.json" + + if ecg_path.exists() and ppg_path.exists() and meta_path.exists(): + existing = json.loads(meta_path.read_text()) + if existing.get("n_windows") == len(index): + print(f"[precompute] already done: {existing['n_windows']} windows") + return + + print(f"[precompute] {len(index)} windows to process") + + shard_cache = {} + + def load_shard(sidx): + if sidx not in shard_cache: + for p in Path(args.out_dir).parent.glob("mimic/shard_*"): + if int(p.name.split("_")[1]) == sidx: + shard_cache[sidx] = load_from_disk(str(p)) + break + return shard_cache.get(sidx) + + # Find shard root + mimic_root = None + 
for candidate in [Path(args.out_dir) / "mimic", Path(args.out_dir).parent / "mimic", + Path("/workspace/cache/mimic")]: + if candidate.exists(): + mimic_root = candidate + break + assert mimic_root, "mimic shard root not found" + + def load_shard_v2(sidx): + if sidx not in shard_cache: + p = mimic_root / f"shard_{sidx:05d}" + if (p / "dataset_info.json").exists(): + shard_cache[sidx] = load_from_disk(str(p)) + return shard_cache.get(sidx) + + subjects = [] + n_written = 0 + + with open(ecg_path, "wb") as f_ecg, open(ppg_path, "wb") as f_ppg: + for rec in tqdm(index, desc="precompute"): + sidx = rec["shard_idx"] + ds = load_shard_v2(sidx) + if ds is None: + continue + row = ds[rec["row_idx"]] + ecg_full = np.asarray(row["ecg"], dtype=np.float32) + ppg_full = np.asarray(row["ppg"], dtype=np.float32)[0] + names = list(row["ecg_names"]) + if "II" not in names: + continue + ecg_lead = ecg_full[names.index("II")] + se = rec["win_start_ecg"] + sp = rec["win_start_ppg"] + ecg_win = ecg_lead[se : se + ECG_WIN] + ppg_win = ppg_full[sp : sp + PPG_WIN] + if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN: + continue + + ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0)) + ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0)) + + f_ecg.write(ecg_win.tobytes()) + f_ppg.write(ppg_win.tobytes()) + subjects.append(rec["subject_id"]) + n_written += 1 + + meta = { + "n_windows": n_written, + "ecg_win": ECG_WIN, + "ppg_win": PPG_WIN, + "dtype": "float32", + "subjects": subjects, + } + meta_path.write_text(json.dumps(meta)) + ecg_gb = ecg_path.stat().st_size / 1e9 + ppg_gb = ppg_path.stat().st_size / 1e9 + print(f"[precompute] wrote {n_written} windows: ecg={ecg_gb:.2f}GB ppg={ppg_gb:.2f}GB") + + +if __name__ == "__main__": + main() diff --git a/scripts/prepare_data.py b/scripts/prepare_data.py new file mode 100644 index 0000000000000000000000000000000000000000..6b8089f820d87588a23f9cad4a4ee5d6cda6e36c --- /dev/null +++ b/scripts/prepare_data.py @@ -0,0 +1,52 @@ 
def main() -> None:
    """Download the MIMIC shards from the Hub and build the window index.

    Command-line options:
        --root      directory the shards are downloaded into
        --index     path of the JSON window index to write
        --n_shards  number of ``shard_XXXXX`` directories to request

    A ``<index>.meta.json`` summary (shard, window and subject counts) is
    written next to the index.

    Raises:
        SystemExit: if ``--n_shards`` is not positive or no complete shard is
            found after the download. Building an empty index would make the
            bootstrap scripts skip the download on every later run.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--root", type=str, default="/workspace/cache/mimic")
    ap.add_argument("--index", type=str, default="/workspace/cache/mimic_index.json")
    ap.add_argument("--n_shards", type=int, default=412)
    args = ap.parse_args()
    if args.n_shards <= 0:
        raise SystemExit("[prepare] --n_shards must be positive")

    root = Path(args.root)
    root.mkdir(parents=True, exist_ok=True)
    patterns = [f"shard_{i:05d}/*" for i in range(args.n_shards)]
    print(f"[prepare] downloading {len(patterns)} shard patterns to {root}")
    local = snapshot_download(REPO, repo_type="dataset", allow_patterns=patterns,
                              local_dir=str(root), max_workers=16)
    # Only directories with dataset_info.json count as ready shards.
    shard_roots = sorted(p for p in Path(local).glob("shard_*")
                         if (p / "dataset_info.json").exists())
    if not shard_roots:
        raise SystemExit(f"[prepare] no complete shards found under {local}; not building index")
    print(f"[prepare] {len(shard_roots)} shards ready; building window index")
    ds = MIMICAlignedDataset(shard_roots=shard_roots, index_path=Path(args.index),
                             build_index=True)
    info = {
        "n_shards": len(shard_roots),
        "n_windows": len(ds),
        "n_subjects": len({r["subject_id"] for r in ds.index}),
        "shard_roots": [str(p) for p in shard_roots],
    }
    Path(args.index).with_suffix(".meta.json").write_text(json.dumps(info, indent=2))
    print(f"[prepare] index built: {json.dumps(info, indent=2)}")


if __name__ == "__main__":
    main()
#!/usr/bin/env bash
# Wait until a .pt checkpoint exists under <run_dir> and the PTB-XL npz is
# present, then run the probe once on the newest checkpoint.
# Usage: probe_when_ready.sh <run_dir> <model> <ptbxl_npz> <out_json>
set -euo pipefail
RUN_DIR="${1:?run_dir}"
MODEL="${2:?model letter A|B|C|F}"
PTBXL="${3:?ptbxl npz path}"
OUT="${4:?output json}"

# The bootstrap scripts accept either checkout name, so do the same here.
REPO_DIR=""
for d in /workspace/PhysioJEPA /workspace/physiojepa; do
  if [ -d "$d" ]; then REPO_DIR="$d"; break; fi
done
[ -n "$REPO_DIR" ] || { echo "[probe] no repo dir under /workspace"; exit 1; }

echo "[probe] waiting for ckpt in $RUN_DIR"
while :; do
  CKPT=$(ls -t "$RUN_DIR"/*.pt 2>/dev/null | head -1 || true)
  if [ -n "$CKPT" ] && [ -f "$PTBXL" ]; then
    # A checkpoint may still be mid-write; only use it once its size is stable.
    SIZE_A=$(wc -c < "$CKPT" 2>/dev/null || echo -1)
    sleep 2
    SIZE_B=$(wc -c < "$CKPT" 2>/dev/null || echo -2)
    if [ "$SIZE_A" = "$SIZE_B" ]; then
      echo "[probe] ckpt=$CKPT, ptbxl=$PTBXL — running"
      cd "$REPO_DIR"
      PYTHONPATH=src /usr/bin/python3 -u scripts/eval_checkpoint.py \
        --ckpt "$CKPT" --model "$MODEL" --ptbxl_npz "$PTBXL" --out "$OUT"
      echo "[probe] done -> $OUT"
      cat "$OUT"
      break
    fi
  fi
  sleep 10
done
+""" +from __future__ import annotations + +import argparse +import json +import os +import shutil +import subprocess +import sys +import tempfile +import time +from pathlib import Path + +from dotenv import load_dotenv + +load_dotenv() +RUNPOD_API_KEY = os.environ["RUNPOD_API_KEY"] + +GPU_IDS = { + "A40": "NVIDIA A40", + "A6000": "NVIDIA RTX A6000", + "A100": "NVIDIA A100-SXM4-80GB", + "H100": "NVIDIA H100 80GB HBM3", +} + +DEFAULT_IMAGE = "runpod/pytorch:2.4.0-py3.11-cuda12.4.1-devel-ubuntu22.04" + + +def runpodctl(args: list[str], capture: bool = True) -> str: + env = {**os.environ, "RUNPOD_API_KEY": RUNPOD_API_KEY} + res = subprocess.run( + ["runpodctl", *args], env=env, capture_output=capture, text=True + ) + if res.returncode != 0: + raise RuntimeError(f"runpodctl {' '.join(args)} failed: {res.stderr}\n{res.stdout}") + return res.stdout + + +def create_pod(name: str, gpu_id: str, image: str, container_disk: int = 50, + volume_gb: int = 100) -> dict: + out = runpodctl([ + "pod", "create", + "--name", name, + "--gpu-id", gpu_id, + "--gpu-count", "1", + "--image", image, + "--cloud-type", "COMMUNITY", + "--container-disk-in-gb", str(container_disk), + "--volume-in-gb", str(volume_gb), + "--volume-mount-path", "/workspace", + "--ports", "22/tcp", + "--ssh", + ]) + pod = json.loads(out) + return pod + + +def wait_for_ssh(pod_id: str, timeout: int = 600) -> tuple[str, int]: + start = time.time() + last_err = "" + while time.time() - start < timeout: + try: + info = json.loads(runpodctl(["ssh", "info", pod_id])) + host = info.get("publicIp") or info.get("ip") + port = info.get("port") or info.get("sshPort") + if host and port: + return host, int(port) + except Exception as e: + last_err = str(e) + time.sleep(15) + raise TimeoutError(f"SSH not ready for {pod_id}: {last_err}") + + +def ssh(host: str, port: int, cmd: str, user: str = "root", timeout: int = 60) -> str: + res = subprocess.run([ + "ssh", "-o", "StrictHostKeyChecking=no", + "-o", 
"UserKnownHostsFile=/dev/null", + "-o", "ConnectTimeout=15", + "-p", str(port), + f"{user}@{host}", cmd, + ], capture_output=True, text=True, timeout=timeout) + if res.returncode != 0: + raise RuntimeError(f"ssh {host}:{port} {cmd!r} failed: {res.stderr}") + return res.stdout + + +def scp(host: str, port: int, local_path: Path, remote_path: str, user: str = "root") -> None: + cmd = ["scp", "-o", "StrictHostKeyChecking=no", + "-o", "UserKnownHostsFile=/dev/null", + "-P", str(port)] + if local_path.is_dir(): + cmd.append("-r") + cmd.extend([str(local_path), f"{user}@{host}:{remote_path}"]) + res = subprocess.run(cmd, capture_output=True, text=True, timeout=900) + if res.returncode != 0: + raise RuntimeError(f"scp {local_path} -> {host}:{remote_path} failed: {res.stderr}") + + +def deploy_and_launch(host: str, port: int, model: str, run_name: str, repo_root: Path) -> None: + # build a tarball excluding bulky dirs + with tempfile.TemporaryDirectory() as td: + tar = Path(td) / "physiojepa.tar.gz" + excludes = [".venv", ".git", "__pycache__", "runs", "cache", "docs/figures", + "docs/paperes"] + excl_args = [] + for e in excludes: + excl_args.extend(["--exclude", e]) + subprocess.run( + ["tar", "-czf", str(tar), *excl_args, "-C", str(repo_root.parent), + repo_root.name], + check=True, + ) + scp(host, port, tar, "/workspace/physiojepa.tar.gz") + # also send .env + env_file = repo_root / ".env" + scp(host, port, env_file, "/workspace/.env") + ssh(host, port, "set -e; cd /workspace && rm -rf physiojepa && " + "tar -xzf physiojepa.tar.gz && rm physiojepa.tar.gz") + # background the bootstrap with nohup so SSH disconnect doesn't kill it + bootstrap = ( + f"set -e; mkdir -p /workspace/runs; " + f"cd /workspace/physiojepa && chmod +x scripts/pod_bootstrap.sh && " + f"nohup bash scripts/pod_bootstrap.sh {model} {run_name} " + f"> /workspace/runs/{run_name}.bootstrap.log 2>&1 &" + f" disown; echo started; sleep 1" + ) + ssh(host, port, bootstrap) + + +def main() -> None: + ap = 
argparse.ArgumentParser() + ap.add_argument("--models", nargs="+", default=["A", "B", "C", "F"]) + ap.add_argument("--gpu", default="A40", choices=list(GPU_IDS.keys())) + ap.add_argument("--image", default=DEFAULT_IMAGE) + ap.add_argument("--repo_root", default=str(Path(__file__).resolve().parents[1])) + ap.add_argument("--manifest", default="runs/launch_manifest.json") + args = ap.parse_args() + + repo_root = Path(args.repo_root) + Path(args.manifest).parent.mkdir(parents=True, exist_ok=True) + gpu_id = GPU_IDS[args.gpu] + manifest = [] + + for model in args.models: + run_name = f"e2_{model}_a40" + pod_name = f"pj-{model.lower()}-{int(time.time()) % 100000:05d}" + print(f"[launch] creating pod {pod_name} (model={model}, gpu={args.gpu})") + pod = create_pod(pod_name, gpu_id, args.image) + pod_id = pod.get("id") or pod.get("podId") + print(f"[launch] pod_id={pod_id}, waiting for SSH...") + try: + host, port = wait_for_ssh(pod_id) + except TimeoutError as e: + print(f"[launch] WARN: {e}; deleting pod and continuing") + try: + runpodctl(["pod", "delete", pod_id]) + except Exception: + pass + continue + print(f"[launch] SSH up @ {host}:{port}, deploying code") + deploy_and_launch(host, port, model, run_name, repo_root) + manifest.append({"pod_id": pod_id, "pod_name": pod_name, "host": host, + "port": port, "model": model, "run_name": run_name, + "started_at": time.time()}) + Path(args.manifest).write_text(json.dumps(manifest, indent=2)) + print(f"[launch] {model} kicked off; manifest -> {args.manifest}") + + print(f"[launch] all done. manifest:\n{Path(args.manifest).read_text()}") + + +if __name__ == "__main__": + main() diff --git a/scripts/smoke_test.py b/scripts/smoke_test.py new file mode 100644 index 0000000000000000000000000000000000000000..11dc0856ce4a677b175ca0c2e8872d8137bb1f8d --- /dev/null +++ b/scripts/smoke_test.py @@ -0,0 +1,53 @@ +"""CPU single-batch smoke test — gate before launching any GPU training. 
+ +Verifies that all 4 models forward+backward on a tiny batch with real-shaped +tensors, no NaN, and the loss decreases over a few optimiser steps. +""" +from __future__ import annotations + +import os + +import numpy as np +import torch + +from physiojepa.models import MODEL_REGISTRY, ModelConfig + + +def _fake_batch(b: int = 4, device: str = "cpu") -> dict: + ecg = torch.randn(b, 1, 2500, device=device) + ppg = torch.randn(b, 1, 1250, device=device) + dt = torch.rand(b, device=device) * 0.45 + 0.05 # 50-500 ms + return {"ecg": ecg, "ppg": ppg, "dt_seconds": dt, + "ptt_ms": torch.full((b,), float("nan"), device=device)} + + +def main() -> None: + torch.manual_seed(0) + np.random.seed(0) + cfg = ModelConfig() + device = torch.device("cpu") + for variant in ("A", "B", "C", "F"): + print(f"=== {variant} ===") + m = MODEL_REGISTRY[variant](cfg).to(device) + opt = torch.optim.AdamW(m.parameters(), lr=1e-3) + losses = [] + for step in range(3): + batch = _fake_batch() + opt.zero_grad(set_to_none=True) + out = m.step(batch) + out["loss"].backward() + opt.step() + for online, tgt in m.targets(): + tgt.update(online, tau=0.996) + val = float(out["loss"].item()) + assert np.isfinite(val), f"non-finite loss in {variant}" + losses.append(val) + print(f" step={step} loss={val:.4f} " + f"L_cross={float(out.get('L_cross', torch.tensor(0.0)).item()):.4f} " + f"L_self={float(out.get('L_self', torch.tensor(0.0)).item()):.4f}") + print(f" -> losses: {[round(x, 4) for x in losses]}") + print("\nSMOKE TEST PASSED") + + +if __name__ == "__main__": + main() diff --git a/scripts/snapshot_now.py b/scripts/snapshot_now.py new file mode 100644 index 0000000000000000000000000000000000000000..bb153cc402d10412afd030bd1d8b9b0925ecbe8e --- /dev/null +++ b/scripts/snapshot_now.py @@ -0,0 +1,36 @@ +"""Inject a checkpoint save into a running training process via py-spy/gdb. + +Simpler alternative: this script doesn't actually inject — it just waits for +the next natural ckpt and reports it. 
For an immediate snapshot, the +trainer needs SIGUSR1 handling (added in a follow-up commit). + +Usage: python snapshot_now.py /workspace/runs/ +""" +from __future__ import annotations + +import argparse +import time +from pathlib import Path + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("run_dir", type=Path) + ap.add_argument("--timeout", type=int, default=900) + args = ap.parse_args() + deadline = time.time() + args.timeout + last_ckpts = set() + while time.time() < deadline: + ckpts = set(args.run_dir.glob("*.pt")) + new = ckpts - last_ckpts + if new: + for c in new: + print(f"[snapshot] new ckpt: {c}") + return + last_ckpts = ckpts + time.sleep(5) + print("[snapshot] timeout, no new ckpt") + + +if __name__ == "__main__": + main() diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..e0aec08c02438d9f164276224e7a4fbe16ab86c0 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,55 @@ +"""CLI entry point: train a single model variant.""" +from __future__ import annotations + +import argparse +import os + +from dotenv import load_dotenv + +load_dotenv() +os.environ.setdefault("HF_TOKEN", os.environ.get("HUGGINGFACE_API_KEY", "")) +os.environ.setdefault("WANDB_API_KEY", os.environ.get("WANDB_API_KEY", "")) + +from physiojepa.trainer import TrainConfig, load_yaml_config, train + + +def main() -> None: + ap = argparse.ArgumentParser() + ap.add_argument("--config", required=True) + ap.add_argument("--run_name", type=str, default=None) + ap.add_argument("--model", type=str, default=None, choices=["A", "B", "C", "F"]) + ap.add_argument("--epochs", type=int, default=None) + ap.add_argument("--batch_size", type=int, default=None) + ap.add_argument("--index_path", type=str, default=None) + ap.add_argument("--shard_roots_json", type=str, default=None, + help="JSON file listing shard roots") + ap.add_argument("--wandb_mode", type=str, default=None) + ap.add_argument("--num_workers", type=int, 
default=None) + ap.add_argument("--output_dir", type=str, default=None) + ap.add_argument("--subset_frac", type=float, default=None) + ap.add_argument("--log_every", type=int, default=None) + ap.add_argument("--ema_start", type=float, default=None) + ap.add_argument("--ema_end", type=float, default=None) + ap.add_argument("--ema_warmup_frac", type=float, default=None) + ap.add_argument("--seed", type=int, default=None) + ap.add_argument("--pred_depth", type=int, default=None) + ap.add_argument("--query_mode", type=str, default=None, choices=["learned", "sinusoidal"]) + ap.add_argument("--mask_ratio", type=float, default=None) + ap.add_argument("--fast_cache_dir", type=str, default=None) + args = ap.parse_args() + + cfg = load_yaml_config(args.config) + overrides = {k: v for k, v in vars(args).items() if v is not None and k not in ("config",)} + if "shard_roots_json" in overrides: + import json + cfg.shard_roots = json.loads(open(overrides.pop("shard_roots_json")).read()) + for k, v in overrides.items(): + setattr(cfg, k, v) + print(f"[train] resolved config: model={cfg.model} run={cfg.run_name} " + f"epochs={cfg.epochs} bs={cfg.batch_size} shards={len(cfg.shard_roots)}") + res = train(cfg) + print(f"[train] done: {res}") + + +if __name__ == "__main__": + main() diff --git a/skills-lock.json b/skills-lock.json new file mode 100644 index 0000000000000000000000000000000000000000..2fa01d554a776e1c2c4b6de21a354448b4273fed --- /dev/null +++ b/skills-lock.json @@ -0,0 +1,15 @@ +{ + "version": 1, + "skills": { + "flash": { + "source": "runpod/skills", + "sourceType": "github", + "computedHash": "135a8c0a488aee2d1ca2170f5b8bf194febbf10bbb11c1ced3d123df0436847b" + }, + "runpodctl": { + "source": "runpod/skills", + "sourceType": "github", + "computedHash": "5f499fae0e1007c90915a10df00e804181995f0da2fd433831a6b97f16a39264" + } + } +} diff --git a/src/physiojepa/__init__.py b/src/physiojepa/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..58331fd25ceddf34fe02f0d7165aed7b4fd76f9e --- /dev/null +++ b/src/physiojepa/__init__.py @@ -0,0 +1,3 @@ +"""PhysioJEPA — time-shifted cross-modal ECG→PPG JEPA.""" + +__version__ = "0.1.0" diff --git a/src/physiojepa/data.py b/src/physiojepa/data.py new file mode 100644 index 0000000000000000000000000000000000000000..d64fd806d31007cf4506ce51aa679803a7348719 --- /dev/null +++ b/src/physiojepa/data.py @@ -0,0 +1,244 @@ +"""PyTorch Dataset over lucky9-cyou/mimic-iv-aligned-ppg-ecg. + +For v1 we: + - Keep only segments where ECG lead II is present (93.7% of data) + - Extract lead II ECG and PPG Pleth + - Window: 10 s slices at 5 s stride + - Native rates: ECG 250 Hz, PPG 125 Hz -> ECG window 2500 samples, PPG 1250 + +Each item returns {ecg: [1, 2500], ppg: [1, 1250], subject_id, segment_start, +measured_ptt_ms (per-window estimate, may be NaN), delta_t_seconds (sampled per +step outside the dataset)}. + +The caller handles delta_t sampling (60% log-uniform + 40% from measured_ptt). 
+""" +from __future__ import annotations + +import json +import os +import re +from pathlib import Path +from typing import Iterable + +import numpy as np +import torch +from scipy.signal import butter, filtfilt, find_peaks +from torch.utils.data import Dataset + +from datasets import load_from_disk + +ECG_FS = 250.0 +PPG_FS = 125.0 +WINDOW_SEC = 10.0 +STRIDE_SEC = 5.0 +ECG_WIN = int(ECG_FS * WINDOW_SEC) # 2500 +PPG_WIN = int(PPG_FS * WINDOW_SEC) # 1250 +ECG_STRIDE = int(ECG_FS * STRIDE_SEC) +PPG_STRIDE = int(PPG_FS * STRIDE_SEC) + + +def _parse_subject(record_name: str) -> str: + m = re.match(r"p\d+/(p\d+)/", record_name) + return m.group(1) if m else record_name.split("/")[0] + + +def _bandpass(x: np.ndarray, fs: float, lo: float, hi: float, order: int = 3) -> np.ndarray: + ny = 0.5 * fs + b, a = butter(order, [lo / ny, min(hi / ny, 0.99)], btype="band") + return filtfilt(b, a, x, method="gust").astype(np.float32) + + +def _zscore(x: np.ndarray, eps: float = 1e-6) -> np.ndarray: + m = x.mean() + s = x.std() + eps + return ((x - m) / s).astype(np.float32) + + +def _r_peaks(ecg: np.ndarray, fs: float) -> np.ndarray: + x = _bandpass(ecg, fs, 5.0, 15.0) + s = np.diff(x, prepend=x[:1]) ** 2 + w = max(int(0.12 * fs), 1) + mwa = np.convolve(s, np.ones(w) / w, mode="same") + thr = mwa.mean() + 0.5 * mwa.std() + p, _ = find_peaks(mwa, height=thr, distance=int(0.3 * fs)) + return p + + +def _ppg_peaks(ppg: np.ndarray, fs: float) -> np.ndarray: + x = _bandpass(ppg, fs, 0.5, 8.0) + p, _ = find_peaks(x, distance=int(0.3 * fs), + height=x.mean() + 0.3 * x.std(), + prominence=0.1 * x.std()) + return p + + +def _window_ptt_ms(ecg_win: np.ndarray, ppg_win: np.ndarray) -> float: + """Median PTT across beats in one window; np.nan if <3 clean beats.""" + r = _r_peaks(ecg_win, ECG_FS) + p = _ppg_peaks(ppg_win, PPG_FS) + if len(r) < 3 or len(p) < 3: + return float("nan") + r_t = r / ECG_FS + p_t = p / PPG_FS + ptts = [] + for rt in r_t: + cand = p_t[(p_t >= rt + 0.050) & (p_t <= rt + 
class MIMICAlignedDataset(Dataset):
    """Indexes windows across a set of cached shard directories.

    Args:
        shard_roots: list of ".../shard_XXXXX" paths (pre-downloaded)
        index_path: where to cache the index (JSON: list[{shard_idx, row_idx,
            win_start_ecg, win_start_ppg, subject_id}])
        build_index: if True, scan and build/save the window index; if False,
            load existing index_path (it is still built when the file is absent)
        normalise: if True, apply bandpass + zscore per window
        subjects_allow: if given, keep only windows of these subject ids
        subset_frac: fraction of windows to keep (random, without replacement)
        subset_seed: seed for the subset draw

    NOTE: ``_build_index`` does not write ``ptt_ms``; ``__getitem__`` reports
    NaN unless the index file carries that key from another source.
    """

    def __init__(
        self,
        shard_roots: list[Path],
        index_path: Path,
        build_index: bool = True,
        normalise: bool = True,
        subjects_allow: set[str] | None = None,
        subset_frac: float = 1.0,
        subset_seed: int = 0,
    ):
        self.shard_roots = [Path(p) for p in shard_roots]
        self.index_path = Path(index_path)
        self.normalise = normalise
        self.subjects_allow = subjects_allow
        if build_index or not self.index_path.exists():
            self._build_index()
        self.index = json.loads(self.index_path.read_text())
        if subjects_allow is not None:
            self.index = [r for r in self.index if r["subject_id"] in subjects_allow]
        # An empty index cannot be sub-sampled (rng.choice would raise).
        if subset_frac < 1.0 and self.index:
            rng = np.random.default_rng(subset_seed)
            n_keep = max(1, int(len(self.index) * subset_frac))
            keep = rng.choice(len(self.index), size=n_keep, replace=False)
            self.index = [self.index[i] for i in sorted(keep)]
        self._shard_cache: dict[int, object] = {}

    @staticmethod
    def _shard_number(path: Path) -> int:
        """Numeric suffix of a ``shard_XXXXX`` directory name."""
        return int(Path(path).name.split("_")[1])

    def _build_index(self) -> None:
        """Scan every shard and write one index record per full 10 s window."""
        records = []
        for s_path in self.shard_roots:
            sidx = self._shard_number(s_path)
            ds = load_from_disk(str(s_path))
            # Only metadata is needed here; skip decoding the signal arrays.
            meta = ds.select_columns(
                ["ecg_names", "record_name", "ecg_siglen", "ppg_siglen"]
            )
            for row_idx, row in enumerate(meta):
                if "II" not in list(row["ecg_names"]):
                    continue
                subject_id = _parse_subject(row["record_name"])
                ecg_siglen = int(row["ecg_siglen"])
                ppg_siglen = int(row["ppg_siglen"])
                # require full windows only
                n_win = min(
                    (ecg_siglen - ECG_WIN) // ECG_STRIDE + 1,
                    (ppg_siglen - PPG_WIN) // PPG_STRIDE + 1,
                )
                if n_win <= 0:
                    continue
                records.extend(
                    {
                        "shard_idx": sidx,
                        "row_idx": row_idx,
                        "subject_id": subject_id,
                        "win_start_ecg": w * ECG_STRIDE,
                        "win_start_ppg": w * PPG_STRIDE,
                    }
                    for w in range(n_win)
                )
        self.index_path.parent.mkdir(parents=True, exist_ok=True)
        self.index_path.write_text(json.dumps(records))

    def _load_shard(self, sidx: int):
        """Return shard ``sidx``, loading it from disk on first use.

        Raises:
            KeyError: if no entry of ``shard_roots`` has that shard number.
        """
        if sidx not in self._shard_cache:
            path = next(
                (p for p in self.shard_roots if self._shard_number(p) == sidx), None
            )
            if path is None:
                raise KeyError(
                    f"shard {sidx} not found among {len(self.shard_roots)} shard roots"
                )
            self._shard_cache[sidx] = load_from_disk(str(path))
        return self._shard_cache[sidx]

    def __len__(self) -> int:
        return len(self.index)

    def __getitem__(self, idx: int) -> dict:
        """Return one window: ``ecg`` [1, 2500], ``ppg`` [1, 1250], ids, PTT.

        Raises:
            RuntimeError: if the stored row is too short for the indexed window.
        """
        rec = self.index[idx]
        ds = self._load_shard(rec["shard_idx"])
        row = ds[rec["row_idx"]]
        ecg_full = np.asarray(row["ecg"], dtype=np.float32)
        ppg_full = np.asarray(row["ppg"], dtype=np.float32)[0]
        names = list(row["ecg_names"])
        ecg_lead = ecg_full[names.index("II")]
        se = rec["win_start_ecg"]
        sp = rec["win_start_ppg"]
        ecg_win = ecg_lead[se : se + ECG_WIN].copy()
        ppg_win = ppg_full[sp : sp + PPG_WIN].copy()
        if ecg_win.shape[0] != ECG_WIN or ppg_win.shape[0] != PPG_WIN:
            raise RuntimeError(f"bad window at idx {idx}: {ecg_win.shape}, {ppg_win.shape}")

        # No peak detection here: PTT is only read from the index record, so
        # __getitem__ stays cheap and the GPU isn't left waiting.
        ptt_ms = float(rec.get("ptt_ms", float("nan")))
        if not np.isfinite(ptt_ms):
            ptt_ms = float("nan")
        if self.normalise:
            ecg_win = _zscore(_bandpass(ecg_win, ECG_FS, 0.5, 40.0))
            ppg_win = _zscore(_bandpass(ppg_win, PPG_FS, 0.5, 8.0))
        return {
            "ecg": torch.from_numpy(ecg_win).unsqueeze(0),  # [1, 2500]
            "ppg": torch.from_numpy(ppg_win).unsqueeze(0),  # [1, 1250]
            "subject_id": rec["subject_id"],
            "ptt_ms": ptt_ms,
        }
def split_by_subject(
    subjects: Iterable[str], frac: float = 0.9, seed: int = 0
) -> tuple[set[str], set[str]]:
    """Split the unique subject ids into disjoint (train, test) sets.

    The ids are de-duplicated and sorted before a seeded permutation, so the
    split depends only on the set of ids and ``seed``. ``frac`` is the share
    assigned to the train set (rounded down).
    """
    unique = sorted(set(subjects))
    perm = np.random.default_rng(seed).permutation(len(unique))
    cut = int(len(unique) * frac)
    train = {unique[i] for i in perm[:cut]}
    test = {unique[i] for i in perm[cut:]}
    return train, test


def collate_with_dt(
    items: list[dict],
    log_uniform_frac: float = 0.6,
    dt_min_ms: float = 50.0,
    dt_max_ms: float = 500.0,
    rng: np.random.Generator | None = None,
) -> dict:
    """Stack a batch and sample Δt for every item.

    Each item independently takes a log-uniform draw from
    ``[dt_min_ms, dt_max_ms]`` with probability ``log_uniform_frac``
    (default 60%); otherwise its measured PTT is used when finite, with a
    log-uniform draw as the fallback.

    Returns:
        dict with ``ecg`` and ``ppg`` (stacked tensors), ``dt_seconds`` and
        ``ptt_ms`` (float32 tensors of length batch) and ``subject_id`` (list).

    Raises:
        ValueError: if the Δt bounds are not ``0 < dt_min_ms <= dt_max_ms``
            (the log-uniform draw would otherwise yield NaN or 0).
    """
    if not 0.0 < dt_min_ms <= dt_max_ms:
        raise ValueError(
            f"need 0 < dt_min_ms <= dt_max_ms, got {dt_min_ms} and {dt_max_ms}"
        )
    rng = rng if rng is not None else np.random.default_rng()
    ecg = torch.stack([item["ecg"] for item in items])
    ppg = torch.stack([item["ppg"] for item in items])
    ptts = np.array([item["ptt_ms"] for item in items], dtype=np.float32)
    batch = len(items)
    log_lo, log_hi = np.log(dt_min_ms), np.log(dt_max_ms)

    dt_ms = np.empty(batch, dtype=np.float32)
    use_log = rng.random(batch) < log_uniform_frac
    dt_ms[use_log] = np.exp(rng.uniform(log_lo, log_hi, size=int(use_log.sum())))
    # Remaining items: measured PTT when finite, else fall back to log-uniform.
    has_ptt = np.isfinite(ptts)
    measured = ~use_log & has_ptt
    fallback = ~use_log & ~has_ptt
    dt_ms[measured] = ptts[measured]
    dt_ms[fallback] = np.exp(rng.uniform(log_lo, log_hi, size=int(fallback.sum())))
    return {
        "ecg": ecg,
        "ppg": ppg,
        "dt_seconds": torch.from_numpy(dt_ms / 1000.0),
        "ptt_ms": torch.from_numpy(ptts),
        "subject_id": [item["subject_id"] for item in items],
    }
class MIMICFastDataset(Dataset):
    """Dataset over the flat float32 window files written by precompute_windows.py.

    ``windows_meta.json`` supplies the window count, the per-window lengths
    and one subject id per window. The binary files are memory-mapped lazily
    on first access, so the dataset can be pickled and sent to DataLoader
    workers; each process then opens its own read-only mapping.

    Args:
        cache_dir: directory holding windows_meta.json / windows_ecg.bin /
            windows_ppg.bin
        subjects_allow: if given, expose only windows of these subject ids

    Raises:
        ValueError: if the metadata and the binary file sizes disagree.
    """

    def __init__(
        self,
        cache_dir: Path,
        subjects_allow: set[str] | None = None,
    ):
        cache_dir = Path(cache_dir)
        meta = json.loads((cache_dir / "windows_meta.json").read_text())
        self.n_total = meta["n_windows"]
        self.ecg_win = meta["ecg_win"]
        self.ppg_win = meta["ppg_win"]
        self.subjects = meta["subjects"]
        self.ecg_bytes = self.ecg_win * 4  # float32
        self.ppg_bytes = self.ppg_win * 4

        if len(self.subjects) != self.n_total:
            raise ValueError(
                f"metadata lists {len(self.subjects)} subjects for {self.n_total} windows"
            )
        self._ecg_path = cache_dir / "windows_ecg.bin"
        self._ppg_path = cache_dir / "windows_ppg.bin"
        # A truncated file would otherwise only surface mid-training.
        for path, row_bytes in ((self._ecg_path, self.ecg_bytes),
                                (self._ppg_path, self.ppg_bytes)):
            expected = self.n_total * row_bytes
            actual = path.stat().st_size
            if actual != expected:
                raise ValueError(f"{path} is {actual} bytes, expected {expected}")

        # Build index of allowed windows
        if subjects_allow is not None:
            self.indices = [i for i, s in enumerate(self.subjects) if s in subjects_allow]
        else:
            self.indices = list(range(self.n_total))

        self._ecg_fh = self._ppg_fh = None
        self._ecg_mm = self._ppg_mm = None

    def _open(self) -> None:
        """Open and memory-map both binary files (read-only)."""
        self._ecg_fh = open(self._ecg_path, "rb")
        self._ppg_fh = open(self._ppg_path, "rb")
        self._ecg_mm = mmap.mmap(self._ecg_fh.fileno(), 0, access=mmap.ACCESS_READ)
        self._ppg_mm = mmap.mmap(self._ppg_fh.fileno(), 0, access=mmap.ACCESS_READ)

    def close(self) -> None:
        """Release the mappings and file handles; they reopen on next access."""
        for name in ("_ecg_mm", "_ppg_mm", "_ecg_fh", "_ppg_fh"):
            handle = getattr(self, name, None)
            if handle is not None:
                handle.close()
            setattr(self, name, None)

    def __getstate__(self) -> dict:
        # mmap objects and file handles cannot be pickled; drop them and let
        # the receiving process reopen lazily.
        state = self.__dict__.copy()
        for name in ("_ecg_mm", "_ppg_mm", "_ecg_fh", "_ppg_fh"):
            state[name] = None
        return state

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> dict:
        """Return window ``idx`` of the allowed subset as copied tensors."""
        if self._ecg_mm is None:
            self._open()
        real_idx = self.indices[idx]
        ecg = np.frombuffer(self._ecg_mm, dtype=np.float32,
                            count=self.ecg_win, offset=real_idx * self.ecg_bytes).copy()
        ppg = np.frombuffer(self._ppg_mm, dtype=np.float32,
                            count=self.ppg_win, offset=real_idx * self.ppg_bytes).copy()
        return {
            "ecg": torch.from_numpy(ecg).unsqueeze(0),  # [1, ecg_win]
            "ppg": torch.from_numpy(ppg).unsqueeze(0),  # [1, ppg_win]
            "subject_id": self.subjects[real_idx],
            "ptt_ms": float("nan"),  # PTT is not stored in the precomputed cache
        }

    def __del__(self):
        try:
            self.close()
        except Exception:
            # interpreter shutdown: nothing useful can be done with a failure
            pass
class DeltaTEmbedding(nn.Module):
    """Map a scalar Δt (seconds) to a conditioning vector in R^d.

    Δt is expanded into ``n_freqs`` fixed sin/cos pairs whose periods are
    log-spaced between ``min_period_s`` and ``max_period_s``, then projected to
    ``d_model`` by a learned linear layer.

    The defaults give periods from 10 ms to 10 s. Previously the shortest
    period was 1 s, which left millisecond-scale Δt almost unresolved; pass
    ``min_period_s=1.0`` to reproduce that encoding (e.g. for checkpoints
    trained with it).

    Args:
        d_model: Output embedding width.
        n_freqs: Number of sinusoid frequencies (the projection sees 2x this).
        min_period_s: Shortest sinusoid period, in seconds.
        max_period_s: Longest sinusoid period, in seconds.

    Raises:
        ValueError: If ``n_freqs < 1`` or the period bounds are not
            ``0 < min_period_s <= max_period_s``.
    """

    def __init__(
        self,
        d_model: int = 256,
        n_freqs: int = 32,
        min_period_s: float = 0.01,
        max_period_s: float = 10.0,
    ):
        super().__init__()
        if n_freqs < 1:
            raise ValueError(f"n_freqs must be >= 1, got {n_freqs}")
        if not 0 < min_period_s <= max_period_s:
            raise ValueError(
                f"need 0 < min_period_s <= max_period_s, got {min_period_s}, {max_period_s}"
            )
        # Angular frequencies (rad/s), log-spaced from highest to lowest;
        # sinusoidal and fixed (not learned).
        freqs = torch.exp(
            torch.linspace(
                math.log(2 * math.pi / min_period_s),
                math.log(2 * math.pi / max_period_s),
                n_freqs,
            )
        )
        # Non-persistent: rebuilt from the constructor args, never checkpointed.
        self.register_buffer("freqs", freqs, persistent=False)
        self.proj = nn.Linear(2 * n_freqs, d_model)

    def forward(self, dt_seconds: torch.Tensor) -> torch.Tensor:
        """Embed ``dt_seconds`` of shape ``[B]`` into ``[B, d_model]``."""
        # Match the buffer dtype so float64 inputs do not break the projection.
        x = dt_seconds.unsqueeze(-1).to(self.freqs.dtype) * self.freqs  # [B, n_freqs]
        emb = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
        return self.proj(emb)  # [B, d]
class ECGPatchTokeniser(nn.Module):
    """Linear projection of fixed-length ECG patches + 1D sinusoidal PE.

    Input is ``[B, 1, T]`` with ``T`` a multiple of ``patch_size``; output is
    ``[B, T // patch_size, d_model]``.

    Args:
        patch_size: Samples per patch (default 50, i.e. 200 ms at 250 Hz).
        d_model: Token width; must be even for the sin/cos interleave.
        max_patches: Length of the precomputed positional-encoding table, and
            therefore the longest token sequence ``forward`` accepts.

    Raises:
        ValueError: If ``d_model`` is odd.
    """

    def __init__(
        self,
        patch_size: int = 50,  # 200 ms at 250 Hz
        d_model: int = 256,
        max_patches: int = 128,
    ) -> None:
        super().__init__()
        if d_model % 2:
            raise ValueError(f"d_model must be even for sinusoidal PE, got {d_model}")
        self.patch_size = patch_size
        self.d_model = d_model
        self.proj = nn.Linear(patch_size, d_model)
        # Fixed table, rebuilt from the constructor args, never checkpointed.
        self.register_buffer(
            "pos_enc", self._sinusoidal_pe(max_patches, d_model), persistent=False
        )

    @staticmethod
    def _sinusoidal_pe(n_pos: int, d: int) -> torch.Tensor:
        """Return the ``[n_pos, d]`` sin/cos table (sin on even columns)."""
        pe = torch.zeros(n_pos, d)
        pos = torch.arange(0, n_pos, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(
            torch.arange(0, d, 2, dtype=torch.float32) * -(math.log(10_000.0) / d)
        )
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        return pe

    def forward(self, ecg: torch.Tensor) -> torch.Tensor:
        """Tokenise ``ecg`` of shape ``[B, 1, T]`` into ``[B, N, d_model]``.

        Raises:
            ValueError: If the sequence yields more patches than the
                positional table holds.
        """
        b, c, t = ecg.shape
        assert c == 1, f"single-lead expected, got {c}"
        assert t % self.patch_size == 0, (
            f"ECG length {t} not divisible by patch_size {self.patch_size}"
        )
        n = t // self.patch_size
        max_patches = self.pos_enc.shape[0]
        if n > max_patches:
            # Otherwise this surfaces as an opaque broadcast error on the add below.
            raise ValueError(
                f"ECG yields {n} patches but the positional table holds {max_patches}"
            )
        patches = ecg.view(b, n, self.patch_size)
        tokens = self.proj(patches)
        tokens = tokens + self.pos_enc[:n].unsqueeze(0)
        return tokens
def ema_tau(step: int, total_steps: int, start: float = 0.996, end: float = 0.9999,
            warmup_frac: float = 0.30) -> float:
    """Return the EMA momentum for ``step``.

    Tau follows a half-cosine from ``start`` (step 0) to ``end`` over the first
    ``warmup_frac`` of ``total_steps`` and stays at ``end`` afterwards. The
    warm-up is at least one step long, so ``total_steps == 0`` is safe.
    """
    warmup = max(1, int(total_steps * warmup_frac))
    if step >= warmup:
        return end
    t = step / warmup
    return end - 0.5 * (end - start) * (1 + math.cos(math.pi * t))


class EMA(nn.Module):
    """Wraps a detached target copy of an online encoder, updated in-place.

    The target is a deep copy taken at construction time. Its parameters never
    receive gradients and it is kept permanently in inference mode.
    """

    def __init__(self, online: nn.Module):
        super().__init__()
        self.target = copy.deepcopy(online)
        for p in self.target.parameters():
            p.requires_grad_(False)
        self.target.train(False)

    def train(self, mode: bool = True) -> "EMA":
        """Set this wrapper's mode while pinning the target to inference mode.

        ``target`` is a registered submodule, so a parent ``model.train()``
        would otherwise switch it back to training mode and re-enable any
        dropout or batch-statistics layers inside the target encoder.
        """
        super().train(mode)
        self.target.train(False)
        return self

    @torch.no_grad()
    def update(self, online: nn.Module, tau: float) -> None:
        """Blend ``online`` into the target: ``target = tau*target + (1-tau)*online``.

        Parameters are averaged; buffers are copied verbatim from ``online``.
        """
        for p_t, p_o in zip(self.target.parameters(), online.parameters()):
            p_t.data.mul_(tau).add_(p_o.data, alpha=1 - tau)
        for b_t, b_o in zip(self.target.buffers(), online.buffers()):
            b_t.data.copy_(b_o.data)

    def forward(self, *args, **kwargs):
        """Run the target encoder without building an autograd graph."""
        with torch.no_grad():
            return self.target(*args, **kwargs)
+ """ + target_mask = torch.zeros(n_tokens, dtype=torch.bool) + max_cover = int(mask_ratio * n_tokens) + covered = 0 + attempts = 0 + while covered < max_cover and attempts < 64: + attempts += 1 + lo, hi = target_size_range + size = int(torch.randint(lo, hi + 1, (1,), generator=generator).item()) + size = min(size, max_cover - covered) + if size <= 0: + break + start = int(torch.randint(0, max(1, n_tokens - size + 1), (1,), generator=generator).item()) + if target_mask[start : start + size].any(): + continue + target_mask[start : start + size] = True + covered += size + + target_idx = torch.nonzero(target_mask, as_tuple=False).squeeze(-1) + context_idx = torch.nonzero(~target_mask, as_tuple=False).squeeze(-1) + return context_idx, target_idx diff --git a/src/physiojepa/models.py b/src/physiojepa/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c23b3ab8a5ed37e9eec395dca3c49dda3c194f43 --- /dev/null +++ b/src/physiojepa/models.py @@ -0,0 +1,308 @@ +"""The four models under test. They share encoders, differ in loss and Delta-t. 
+ +Model variants: + A: ECG-JEPA unimodal (I-JEPA self-prediction on ECG only) + B: cross-modal JEPA, delta_t = 0 + C: symmetric InfoNCE (no predictor) + F: PhysioJEPA v1 (cross-modal JEPA, variable delta_t) +""" +from __future__ import annotations + +from dataclasses import dataclass, field + +import torch +import torch.nn.functional as F +from torch import nn + +from .dt_embed import DeltaTEmbedding +from .ecg_encoder import ECGPatchTokeniser +from .ema import EMA +from .masking import multi_block_mask_1d +from .ppg_encoder import PPGPatchTokeniser +from .vit import CrossAttentionPredictor, ViT1D + + +@dataclass +class ModelConfig: + ecg_patch: int = 50 + ppg_patch: int = 25 + d_model: int = 256 + ecg_depth: int = 12 + ppg_depth: int = 6 + heads: int = 8 + pred_depth: int = 4 + max_tokens: int = 128 + # ablation knobs + query_mode: str = "learned" # "learned" | "sinusoidal" + mask_ratio: float = 0.50 + + +def _pool(x: torch.Tensor) -> torch.Tensor: + return x.mean(dim=1) + + +def _make_query_emb(cfg: ModelConfig) -> tuple[nn.Module | None, torch.Tensor | None]: + """Returns either a learned nn.Parameter wrapped in a tiny Module, or a + fixed sinusoidal table buffer. Caller should index with positions. 
+ """ + if cfg.query_mode == "sinusoidal": + import math + n_pos, d = cfg.max_tokens, cfg.d_model + pe = torch.zeros(n_pos, d) + pos = torch.arange(0, n_pos, dtype=torch.float32).unsqueeze(1) + div = torch.exp(torch.arange(0, d, 2, dtype=torch.float32) * + -(math.log(10_000.0) / d)) + pe[:, 0::2] = torch.sin(pos * div) + pe[:, 1::2] = torch.cos(pos * div) + return None, pe # caller stores as buffer + return None, None # caller creates learned Parameter + + +class ECGOnlyEncoder(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.tok = ECGPatchTokeniser(patch_size=cfg.ecg_patch, d_model=cfg.d_model, + max_patches=cfg.max_tokens) + self.trunk = ViT1D(depth=cfg.ecg_depth, d_model=cfg.d_model, heads=cfg.heads) + + def forward(self, ecg: torch.Tensor) -> torch.Tensor: + return self.trunk(self.tok(ecg)) # [B, N_e, d] + + +class PPGEncoder(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.tok = PPGPatchTokeniser(patch_size=cfg.ppg_patch, d_model=cfg.d_model, + max_patches=cfg.max_tokens) + self.trunk = ViT1D(depth=cfg.ppg_depth, d_model=cfg.d_model, heads=cfg.heads) + + def forward(self, ppg: torch.Tensor) -> torch.Tensor: + return self.trunk(self.tok(ppg)) + + +# --------------------------------------------------------------------------- +# Baseline A — ECG-JEPA unimodal (I-JEPA style self-prediction) +# --------------------------------------------------------------------------- +class BaselineA(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.cfg = cfg + self.ecg = ECGOnlyEncoder(cfg) + self.ecg_tgt = EMA(self.ecg) + self.predictor = CrossAttentionPredictor( + depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads + ) + _, sinpe = _make_query_emb(cfg) + if sinpe is None: + self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model)) + nn.init.trunc_normal_(self.query_emb, std=0.02) + else: + self.register_buffer("query_emb", sinpe, persistent=False) + + def step(self, 
batch: dict) -> dict: + ecg = batch["ecg"] # [B, 1, T] + b = ecg.shape[0] + n_ecg = ecg.shape[-1] // self.cfg.ecg_patch + ctx_idxs = [] + tgt_idxs = [] + for _ in range(b): + c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), + mask_ratio=self.cfg.mask_ratio) + ctx_idxs.append(c) + tgt_idxs.append(t) + # All sequences same B but variable ctx/tgt lengths — process per-sample + # then pack. For efficiency use a padded approach. + tok = self.ecg.tok(ecg) # [B, N, d] + trunk = self.ecg.trunk + # context forward: apply trunk on full sequence then gather ctx/tgt tokens + full_ctx = trunk(tok) # [B, N, d] + tgt_full = self.ecg_tgt.target.trunk(self.ecg_tgt.target.tok(ecg)).detach() + L_self = torch.tensor(0.0, device=ecg.device) + total = 0 + for i in range(b): + q = self.query_emb[tgt_idxs[i]].unsqueeze(0) # [1, n_t, d] + ctx_tokens = full_ctx[i : i + 1, ctx_idxs[i], :] + pred = self.predictor(q, ctx_tokens).squeeze(0) + tgt_v = tgt_full[i, tgt_idxs[i], :] + L_self = L_self + F.l1_loss(pred, tgt_v, reduction="mean") + total += 1 + L_self = L_self / max(total, 1) + return {"loss": L_self, "L_self": L_self.detach(), "L_cross": torch.tensor(0.0), + "z_ecg": _pool(full_ctx.detach())} + + def targets(self): + return [(self.ecg, self.ecg_tgt)] + + +# --------------------------------------------------------------------------- +# Shared cross-modal backbone for Baselines B, C, and E3 PhysioJEPA +# --------------------------------------------------------------------------- +class CrossModalBackbone(nn.Module): + """Dual online encoders + two EMA targets + cross-attention predictor + Δt emb.""" + + def __init__(self, cfg: ModelConfig, use_predictor: bool = True, use_delta_t: bool = True): + super().__init__() + self.cfg = cfg + self.use_predictor = use_predictor + self.use_delta_t = use_delta_t + self.ecg = ECGOnlyEncoder(cfg) + self.ppg = PPGEncoder(cfg) + self.ecg_tgt = EMA(self.ecg) + self.ppg_tgt = EMA(self.ppg) + if use_predictor: + self.predictor = 
CrossAttentionPredictor( + depth=cfg.pred_depth, d_model=cfg.d_model, heads=cfg.heads + ) + _, sinpe = _make_query_emb(cfg) + if sinpe is None: + self.query_emb = nn.Parameter(torch.zeros(cfg.max_tokens, cfg.d_model)) + nn.init.trunc_normal_(self.query_emb, std=0.02) + else: + self.register_buffer("query_emb", sinpe, persistent=False) + if use_delta_t: + self.dt_emb = DeltaTEmbedding(d_model=cfg.d_model) + + def encode_ctx(self, ecg: torch.Tensor) -> torch.Tensor: + return self.ecg(ecg) + + def encode_ppg_target(self, ppg: torch.Tensor) -> torch.Tensor: + with torch.no_grad(): + return self.ppg_tgt.target(ppg).detach() + + def predict_ppg(self, z_ecg: torch.Tensor, n_ppg_tokens: int, + dt_seconds: torch.Tensor | None) -> torch.Tensor: + b = z_ecg.shape[0] + q = self.query_emb[:n_ppg_tokens].unsqueeze(0).expand(b, -1, -1) + ctx = z_ecg + if self.use_delta_t and dt_seconds is not None: + dt_tok = self.dt_emb(dt_seconds).unsqueeze(1) # [B, 1, d] + ctx = torch.cat([ctx, dt_tok], dim=1) + return self.predictor(q, ctx) + + def targets(self): + return [(self.ecg, self.ecg_tgt), (self.ppg, self.ppg_tgt)] + + +# --------------------------------------------------------------------------- +# Baseline B — symmetric cross-modal JEPA, Δt = 0 +# --------------------------------------------------------------------------- +class BaselineB(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.cfg = cfg + self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=False) + + def step(self, batch: dict) -> dict: + ecg, ppg = batch["ecg"], batch["ppg"] + z_ecg = self.bb.encode_ctx(ecg) # [B, N_e, d] + z_ppg_tgt = self.bb.encode_ppg_target(ppg) # [B, N_p, d] + n_ppg = z_ppg_tgt.shape[1] + z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=None) + L_cross = F.l1_loss(z_pred, z_ppg_tgt) + + # auxiliary self-prediction on ECG (I-JEPA style) — same code path as BaselineA + n_ecg = z_ecg.shape[1] + b = z_ecg.shape[0] + tok = self.bb.ecg.tok(ecg) + full_ctx 
= self.bb.ecg.trunk(tok) + tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg)).detach() + L_self = torch.tensor(0.0, device=ecg.device) + for i in range(b): + c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) + if len(t) == 0: + continue + q = self.bb.query_emb[t].unsqueeze(0) + ctx_tokens = full_ctx[i : i + 1, c, :] + pred = self.bb.predictor(q, ctx_tokens).squeeze(0) + tgt_v = tgt_full[i, t, :] + L_self = L_self + F.l1_loss(pred, tgt_v) + L_self = L_self / max(b, 1) + + loss = L_cross + 0.3 * L_self + return {"loss": loss, "L_cross": L_cross.detach(), "L_self": L_self.detach(), + "z_ecg": _pool(z_ecg.detach()), "z_ppg": _pool(z_ppg_tgt.detach()), + "z_pred": _pool(z_pred.detach())} + + def targets(self): + return self.bb.targets() + + +# --------------------------------------------------------------------------- +# Baseline C — symmetric InfoNCE (no predictor) +# --------------------------------------------------------------------------- +class BaselineC(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.cfg = cfg + self.ecg = ECGOnlyEncoder(cfg) + self.ppg = PPGEncoder(cfg) + self.ecg_head = nn.Linear(cfg.d_model, cfg.d_model) + self.ppg_head = nn.Linear(cfg.d_model, cfg.d_model) + # Standard CLIP-style temperature init: physical τ ≈ 0.07 → multiplier ≈ 14.3. + # The earlier init log_tau=0 made multiplier=1, leaving logits ∈ [-1, 1] which + # gives loss ≈ ln(B) = uninformative ceiling. 
+ self.log_tau = nn.Parameter(torch.log(torch.tensor(1.0 / 0.07))) + + def step(self, batch: dict) -> dict: + ecg, ppg = batch["ecg"], batch["ppg"] + z_ecg = F.normalize(self.ecg_head(_pool(self.ecg(ecg))), dim=-1) + z_ppg = F.normalize(self.ppg_head(_pool(self.ppg(ppg))), dim=-1) + tau = torch.clamp(self.log_tau.exp(), 0.01, 100.0) + logits = tau * z_ecg @ z_ppg.t() + b = z_ecg.shape[0] + labels = torch.arange(b, device=ecg.device) + loss = 0.5 * (F.cross_entropy(logits, labels) + F.cross_entropy(logits.t(), labels)) + return {"loss": loss, "L_cross": loss.detach(), "L_self": torch.tensor(0.0), + "z_ecg": z_ecg.detach(), "z_ppg": z_ppg.detach(), + "z_pred": z_ppg.detach(), "tau": tau.detach()} + + def targets(self): + return [] # no EMA — pure contrastive + + +# --------------------------------------------------------------------------- +# E3 — PhysioJEPA v1 (variable Δt cross-modal JEPA) +# --------------------------------------------------------------------------- +class PhysioJEPA(nn.Module): + def __init__(self, cfg: ModelConfig): + super().__init__() + self.cfg = cfg + self.bb = CrossModalBackbone(cfg, use_predictor=True, use_delta_t=True) + + def step(self, batch: dict) -> dict: + ecg, ppg = batch["ecg"], batch["ppg"] + dt = batch["dt_seconds"] # [B] + z_ecg = self.bb.encode_ctx(ecg) + z_ppg_tgt = self.bb.encode_ppg_target(ppg) + n_ppg = z_ppg_tgt.shape[1] + z_pred = self.bb.predict_ppg(z_ecg, n_ppg, dt_seconds=dt) + L_cross = F.l1_loss(z_pred, z_ppg_tgt) + + # auxiliary ECG self-prediction + n_ecg = z_ecg.shape[1] + b = z_ecg.shape[0] + tok = self.bb.ecg.tok(ecg) + full_ctx = self.bb.ecg.trunk(tok) + tgt_full = self.bb.ecg_tgt.target.trunk(self.bb.ecg_tgt.target.tok(ecg)).detach() + L_self = torch.tensor(0.0, device=ecg.device) + for i in range(b): + c, t = multi_block_mask_1d(n_ecg, n_targets=4, target_size_range=(4, 8), mask_ratio=self.cfg.mask_ratio) + if len(t) == 0: + continue + q = self.bb.query_emb[t].unsqueeze(0) + ctx_tokens = full_ctx[i : i + 1, 
def effective_rank(z: torch.Tensor, eps: float = 1e-9) -> float:
    """Entropy-based effective rank of the covariance matrix of ``z``.

    ``z`` is ``[n_samples, d]``. The eigenvalues of its sample covariance are
    normalised into a distribution and the result is ``exp(entropy)``.
    """
    # Promote before the matmul so half-precision inputs cannot overflow
    # while the covariance is being formed.
    z = z.float()
    z = z - z.mean(dim=0, keepdim=True)
    cov = (z.t() @ z) / max(z.shape[0] - 1, 1)
    # Clamp tiny negative eigenvalues produced by round-off.
    eig = torch.clamp(torch.linalg.eigvalsh(cov), min=0)
    p = eig / (eig.sum() + eps)
    entropy = -(p * torch.log(p + eps)).sum()
    return float(torch.exp(entropy).item())


def cross_modal_cosine(z_a: torch.Tensor, z_b: torch.Tensor) -> float:
    """Mean cosine similarity between row-aligned embeddings ``z_a`` and ``z_b``."""
    a = torch.nn.functional.normalize(z_a, dim=-1)
    b = torch.nn.functional.normalize(z_b, dim=-1)
    return float((a * b).sum(dim=-1).mean().item())


@dataclass
class CollapseMonitor:
    """Flags collapse once the last ``window`` cosines all exceed ``threshold``.

    The history is always re-bounded to exactly ``window`` entries. It used to
    be fixed at 500, so a larger window could never trigger and a smaller one
    kept testing stale values.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """

    window: int = 500
    threshold: float = 0.99
    history: collections.deque = field(default_factory=collections.deque)

    def __post_init__(self) -> None:
        if self.window < 1:
            raise ValueError(f"window must be >= 1, got {self.window}")
        # Keep any caller-supplied history but tie its capacity to `window`.
        self.history = collections.deque(self.history, maxlen=self.window)

    def update(self, cosine: float) -> bool:
        """Record ``cosine``; return True when the abort criterion is met."""
        self.history.append(cosine)
        if len(self.history) < self.window:
            return False
        return all(c > self.threshold for c in self.history)
1, T]; T must be divisible by patch_size + b, c, t = ppg.shape + assert c == 1, f"PPG must be single-channel, got {c}" + assert t % self.patch_size == 0, ( + f"PPG length {t} not divisible by patch_size {self.patch_size}" + ) + n = t // self.patch_size + patches = ppg.view(b, n, self.patch_size) + tokens = self.proj(patches) + tokens = tokens + self.pos_enc[:n].unsqueeze(0) + return tokens diff --git a/src/physiojepa/probe.py b/src/physiojepa/probe.py new file mode 100644 index 0000000000000000000000000000000000000000..604a0390c1bb568b5bb5b64570a9338c2e5152ff --- /dev/null +++ b/src/physiojepa/probe.py @@ -0,0 +1,68 @@ +"""Linear probe + simple evaluators for frozen encoders. + +AF AUROC on PTB-XL (lead II ECG, resampled 500->250 Hz), HR R^2, retrieval, +PTT regression (MLP). +""" +from __future__ import annotations + +from pathlib import Path + +import numpy as np +import torch +from sklearn.linear_model import LogisticRegression, Ridge +from sklearn.metrics import mean_absolute_error, r2_score, roc_auc_score +from sklearn.neural_network import MLPRegressor + + +@torch.no_grad() +def pooled_features(encoder: torch.nn.Module, x: torch.Tensor, device: torch.device, + batch_size: int = 64) -> np.ndarray: + encoder.train(False) + feats = [] + for i in range(0, len(x), batch_size): + chunk = x[i : i + batch_size].to(device) + z = encoder(chunk) # [B, N, d] + feats.append(z.mean(dim=1).cpu().numpy()) + return np.concatenate(feats, axis=0) + + +def linear_probe_auroc( + train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray, + max_iter: int = 2000, C: float = 1.0, +) -> float: + clf = LogisticRegression(max_iter=max_iter, C=C, solver="lbfgs") + clf.fit(train_X, train_y) + return float(roc_auc_score(test_y, clf.predict_proba(test_X)[:, 1])) + + +def linear_probe_r2( + train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray +) -> float: + reg = Ridge(alpha=1.0) + reg.fit(train_X, train_y) + return float(r2_score(test_y, 
reg.predict(test_X))) + + +def mlp_probe_mae( + train_X: np.ndarray, train_y: np.ndarray, test_X: np.ndarray, test_y: np.ndarray, + hidden: tuple[int, ...] = (128,), max_iter: int = 200, +) -> float: + m = MLPRegressor(hidden_layer_sizes=hidden, max_iter=max_iter, random_state=0) + m.fit(train_X, train_y) + return float(mean_absolute_error(test_y, m.predict(test_X))) + + +def retrieval_recall(z_query: np.ndarray, z_gallery: np.ndarray, k_list=(1, 5, 10)) -> dict: + # normalize + qn = z_query / (np.linalg.norm(z_query, axis=1, keepdims=True) + 1e-9) + gn = z_gallery / (np.linalg.norm(z_gallery, axis=1, keepdims=True) + 1e-9) + sim = qn @ gn.T # [Q, G] + n = sim.shape[0] + ranks = (-sim).argsort(axis=1) + gt = np.arange(n) + out = {} + for k in k_list: + top = ranks[:, :k] + hits = (top == gt[:, None]).any(axis=1).mean() + out[f"R@{k}"] = float(hits) + return out diff --git a/src/physiojepa/ptbxl.py b/src/physiojepa/ptbxl.py new file mode 100644 index 0000000000000000000000000000000000000000..641721686fb6c3f9b82c44354c547b2b89fa3ed3 --- /dev/null +++ b/src/physiojepa/ptbxl.py @@ -0,0 +1,110 @@ +"""PTB-XL loader that pulls from HuggingFace (`PULSE-ECG/PTB-XL`) or PhysioNet. + +We only need: lead II waveforms @ 500 Hz, resampled to 250 Hz, plus binary +AFIB label per record. The HF mirror is the default path because it needs no +credentialing. +""" +from __future__ import annotations + +import json +from pathlib import Path + +import numpy as np +import pandas as pd +from scipy.signal import resample_poly + + +def _resample_500_to_250(x: np.ndarray) -> np.ndarray: + return resample_poly(x, up=1, down=2, axis=-1).astype(np.float32) + + +def _parse_scp_dict(val): + """scp_codes column may be a dict or a stringified dict. Parse safely.""" + if isinstance(val, dict): + return val + if not isinstance(val, str): + return {} + # scp_codes in PTB-XL look like "{'NORM': 100.0, 'SR': 0.0}" — try JSON after + # swapping single to double quotes; fall back to a key scan. 
+ try: + return json.loads(val.replace("'", '"')) + except Exception: + pass + out = {} + tokens = val.strip("{} ").split(",") + for tok in tokens: + if ":" not in tok: + continue + k, v = tok.split(":", 1) + k = k.strip().strip("'").strip('"') + try: + out[k] = float(v.strip()) + except ValueError: + pass + return out + + +def load_ptbxl_af_from_physionet_local(root: Path, limit: int | None = None): + """Load PTB-XL from a local PhysioNet download directory.""" + import wfdb + + root = Path(root) + meta = pd.read_csv(root / "ptbxl_database.csv", index_col="ecg_id") + meta["scp_parsed"] = meta["scp_codes"].apply(_parse_scp_dict) + meta["afib"] = meta["scp_parsed"].apply( + lambda d: int(any(k in ("AFIB", "AFLT") for k in d.keys())) + ) + if limit is not None: + meta = meta.sample(n=limit, random_state=0) + xs, ys = [], [] + for ecg_id, row in meta.iterrows(): + fn = root / row["filename_hr"] # 500 Hz + rec = wfdb.rdrecord(str(fn)) + signals = rec.p_signal + lead_names = rec.sig_name + lead_ii = signals[:, lead_names.index("II")] + x = _resample_500_to_250(lead_ii) + if x.shape[0] < 2500: + x = np.pad(x, (0, 2500 - x.shape[0])) + else: + x = x[:2500] + x = (x - x.mean()) / (x.std() + 1e-6) + xs.append(x) + ys.append(int(row["afib"])) + X = np.stack(xs).astype(np.float32)[:, None, :] + y = np.array(ys, dtype=np.int64) + return X, y + + +def load_ptbxl_af_from_hf(limit: int | None = None): + """Load PTB-XL via HuggingFace — open access, no credentials.""" + from datasets import load_dataset + + ds = load_dataset("PULSE-ECG/PTB-XL", split="train", streaming=False) + xs, ys = [], [] + for i, row in enumerate(ds): + if limit is not None and i >= limit: + break + scp = _parse_scp_dict(row.get("scp_codes", {})) + afib = int(any(k in ("AFIB", "AFLT") for k in scp)) + sig_raw = row.get("signal") or row.get("ecg") + sig = np.asarray(sig_raw, dtype=np.float32) + if sig.ndim != 2: + continue + lead_names = row.get("lead_names") or ["I", "II", "III", "aVR", "aVL", "aVF", + "V1", 
"V2", "V3", "V4", "V5", "V6"] + if "II" in lead_names: + lead_ii = sig[lead_names.index("II")] + else: + lead_ii = sig[1] + x = _resample_500_to_250(lead_ii) + if x.shape[0] < 2500: + x = np.pad(x, (0, 2500 - x.shape[0])) + else: + x = x[:2500] + x = (x - x.mean()) / (x.std() + 1e-6) + xs.append(x) + ys.append(afib) + X = np.stack(xs).astype(np.float32)[:, None, :] + y = np.array(ys, dtype=np.int64) + return X, y diff --git a/src/physiojepa/trainer.py b/src/physiojepa/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..0104f084f7417f11f7a7f0483de23ee1d31ab8b9 --- /dev/null +++ b/src/physiojepa/trainer.py @@ -0,0 +1,228 @@ +"""Training loop shared across all four models. + +Differences across runs are entirely in the model registered under `config.model`. +""" +from __future__ import annotations + +import json +import math +import os +import time +from dataclasses import dataclass, field +from pathlib import Path + +import numpy as np +import torch +import yaml +from torch.utils.data import DataLoader + +from .data import MIMICAlignedDataset, collate_with_dt, split_by_subject +from .ema import ema_tau +from .models import MODEL_REGISTRY, ModelConfig +from .monitor import CollapseMonitor, cross_modal_cosine, effective_rank + + +@dataclass +class TrainConfig: + run_name: str = "debug" + model: str = "F" # one of A, B, C, F + epochs: int = 100 + batch_size: int = 64 + lr: float = 1e-4 + weight_decay: float = 0.04 + warmup_epochs: int = 10 + ema_start: float = 0.996 + ema_end: float = 0.9999 + ema_warmup_frac: float = 0.30 + grad_clip: float = 1.0 + log_every: int = 100 + ckpt_every_epochs: int = 5 + seed: int = 0 + wandb_project: str = "physiojepa" + wandb_mode: str = "online" + wandb_entity: str | None = None + output_dir: str = "runs" + index_path: str = "cache/mimic_index.json" + shard_roots: list[str] = field(default_factory=list) + num_workers: int = 4 + amp: bool = True + # controls for Δt sampling inside collate_fn + log_uniform_frac: 
float = 0.6 + # window-level subsetting (for fast iteration / K2 gate runs) + subset_frac: float = 1.0 + # ablation knobs forwarded to ModelConfig + pred_depth: int = 4 + query_mode: str = "learned" + mask_ratio: float = 0.50 + # precomputed mmap dataset (overrides shard_roots + index_path if set) + fast_cache_dir: str = "" + + +def load_yaml_config(path: str) -> TrainConfig: + with open(path, "r") as f: + d = yaml.safe_load(f) + return TrainConfig(**d) + + +class _Collator: + """Top-level callable so DataLoader workers can serialize it across fork.""" + + def __init__(self, log_uniform_frac: float, seed: int): + self.log_uniform_frac = log_uniform_frac + self.seed = seed + self._rng = None + + def __call__(self, items): + if self._rng is None: + self._rng = np.random.default_rng(self.seed + os.getpid()) + return collate_with_dt(items, log_uniform_frac=self.log_uniform_frac, rng=self._rng) + + +def _build_dataloaders(cfg: TrainConfig) -> tuple[DataLoader, DataLoader, list[str]]: + if cfg.fast_cache_dir: + from .data_fast import MIMICFastDataset + cache_dir = Path(cfg.fast_cache_dir) + import json + meta = json.loads((cache_dir / "windows_meta.json").read_text()) + subjects = sorted(set(meta["subjects"])) + train_subj, val_subj = split_by_subject(subjects, frac=0.9, seed=cfg.seed) + train_ds = MIMICFastDataset(cache_dir, subjects_allow=train_subj) + val_ds = MIMICFastDataset(cache_dir, subjects_allow=val_subj) + else: + shard_roots = [Path(p) for p in cfg.shard_roots] + ds_full = MIMICAlignedDataset( + shard_roots=shard_roots, + index_path=Path(cfg.index_path), + build_index=not Path(cfg.index_path).exists(), + ) + subjects = sorted({r["subject_id"] for r in ds_full.index}) + train_subj, val_subj = split_by_subject(subjects, frac=0.9, seed=cfg.seed) + train_ds = MIMICAlignedDataset( + shard_roots, Path(cfg.index_path), build_index=False, subjects_allow=train_subj, + subset_frac=cfg.subset_frac, subset_seed=cfg.seed, + ) + val_ds = MIMICAlignedDataset( + shard_roots, 
Path(cfg.index_path), build_index=False, subjects_allow=val_subj, + ) + collate = _Collator(cfg.log_uniform_frac, cfg.seed) + train_loader = DataLoader( + train_ds, batch_size=cfg.batch_size, shuffle=True, + num_workers=cfg.num_workers, collate_fn=collate, drop_last=True, + persistent_workers=cfg.num_workers > 0, + ) + val_loader = DataLoader( + val_ds, batch_size=cfg.batch_size, shuffle=False, + num_workers=max(cfg.num_workers, 1), collate_fn=collate, drop_last=False, + ) + return train_loader, val_loader, subjects + + +def _cosine_lr(step: int, total_steps: int, base: float, warmup_steps: int) -> float: + if step < warmup_steps: + return base * (step + 1) / max(1, warmup_steps) + progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) + return 0.5 * base * (1 + math.cos(math.pi * progress)) + + +def train(cfg: TrainConfig) -> dict: + import wandb + + torch.manual_seed(cfg.seed) + np.random.seed(cfg.seed) + + device = torch.device("cuda" if torch.cuda.is_available() + else ("mps" if torch.backends.mps.is_available() else "cpu")) + train_loader, val_loader, subjects = _build_dataloaders(cfg) + print(f"[trainer] device={device} n_train_windows={len(train_loader.dataset)} " + f"n_val_windows={len(val_loader.dataset)} subjects={len(subjects)}") + + mcfg = ModelConfig( + pred_depth=cfg.pred_depth, + query_mode=cfg.query_mode, + mask_ratio=cfg.mask_ratio, + ) + model = MODEL_REGISTRY[cfg.model](mcfg).to(device) + opt = torch.optim.AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay) + scaler = torch.amp.GradScaler(device.type) if cfg.amp and device.type == "cuda" else None + + total_steps = cfg.epochs * len(train_loader) + warmup_steps = cfg.warmup_epochs * len(train_loader) + + wandb.init(project=cfg.wandb_project, name=cfg.run_name, config=cfg.__dict__, + mode=cfg.wandb_mode, entity=cfg.wandb_entity) + + monitor = CollapseMonitor() + step = 0 + out_root = Path(cfg.output_dir) / cfg.run_name + out_root.mkdir(parents=True, exist_ok=True) + 
aborted = False + for epoch in range(cfg.epochs): + model.train(True) + for batch in train_loader: + # move to device + for k in ("ecg", "ppg", "dt_seconds", "ptt_ms"): + if k in batch and isinstance(batch[k], torch.Tensor): + batch[k] = batch[k].to(device) + # lr schedule + lr_now = _cosine_lr(step, total_steps, cfg.lr, warmup_steps) + for g in opt.param_groups: + g["lr"] = lr_now + opt.zero_grad(set_to_none=True) + if scaler is not None: + with torch.amp.autocast("cuda"): + out = model.step(batch) + scaler.scale(out["loss"]).backward() + scaler.unscale_(opt) + torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) + scaler.step(opt) + scaler.update() + else: + out = model.step(batch) + out["loss"].backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip) + opt.step() + + # EMA update + tau = ema_tau(step, total_steps, cfg.ema_start, cfg.ema_end, cfg.ema_warmup_frac) + for online, tgt in model.targets(): + tgt.update(online, tau) + + if step % cfg.log_every == 0: + metrics = { + "step": step, "epoch": epoch, "lr": lr_now, "tau": tau, + "loss": float(out["loss"].detach().item()), + "L_cross": float(out.get("L_cross", torch.tensor(0.0)).item()), + "L_self": float(out.get("L_self", torch.tensor(0.0)).item()), + } + z_e = out.get("z_ecg") + if z_e is not None and z_e.shape[0] > 1: + metrics["ecg_latent_var"] = float(z_e.var(dim=0).mean().item()) + metrics["ecg_eff_rank"] = effective_rank(z_e) + z_p_pred = out.get("z_pred") + z_p_tgt = out.get("z_ppg") + if z_p_pred is not None and z_p_tgt is not None and z_p_pred.shape[0] > 1: + cosine = cross_modal_cosine(z_p_pred, z_p_tgt) + metrics["cross_modal_cosine"] = cosine + if monitor.update(cosine): + print(f"[trainer] COLLAPSE DETECTED at step={step} cosine={cosine:.4f}") + aborted = True + wandb.log(metrics, step=step) + print(f"[step {step}] loss={metrics['loss']:.4f} " + f"L_cross={metrics['L_cross']:.4f} L_self={metrics['L_self']:.4f} " + f"tau={tau:.4f}") + step += 1 + if aborted: + 
break + if aborted: + break + if (epoch + 1) % cfg.ckpt_every_epochs == 0 or epoch == cfg.epochs - 1: + ckpt = out_root / f"ckpt_epoch{epoch + 1:03d}.pt" + torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "epoch": epoch + 1, + "step": step}, ckpt) + print(f"[trainer] saved {ckpt}") + + final_ckpt = out_root / "ckpt_final.pt" + torch.save({"model": model.state_dict(), "cfg": cfg.__dict__, "aborted": aborted, + "step": step}, final_ckpt) + wandb.finish() + return {"aborted": aborted, "final_step": step, "ckpt": str(final_ckpt)} diff --git a/src/physiojepa/vit.py b/src/physiojepa/vit.py new file mode 100644 index 0000000000000000000000000000000000000000..9029e8831b6f84fda8a24478b3a4ec0d4566f37f --- /dev/null +++ b/src/physiojepa/vit.py @@ -0,0 +1,137 @@ +"""Minimal 1D ViT used for both ECG and PPG encoders. + +Shapes +------ +forward(x_tokens): [B, N, d] -> [B, N, d] + +Patch tokenisation is handled separately (see ecg_encoder.py / ppg_encoder.py) +so this module is purely the transformer trunk. 
+""" +from __future__ import annotations + +import torch +from torch import nn + + +class MHA(nn.Module): + def __init__(self, d: int, heads: int, attn_drop: float = 0.0, proj_drop: float = 0.0): + super().__init__() + assert d % heads == 0 + self.h = heads + self.dh = d // heads + self.qkv = nn.Linear(d, 3 * d, bias=True) + self.proj = nn.Linear(d, d, bias=True) + self.ad = nn.Dropout(attn_drop) + self.pd = nn.Dropout(proj_drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + b, n, d = x.shape + qkv = self.qkv(x).view(b, n, 3, self.h, self.dh).permute(2, 0, 3, 1, 4) + q, k, v = qkv[0], qkv[1], qkv[2] # [b, h, n, dh] + out = nn.functional.scaled_dot_product_attention( + q, k, v, dropout_p=self.ad.p if self.training else 0.0 + ) + out = out.transpose(1, 2).reshape(b, n, d) + return self.pd(self.proj(out)) + + +class Block(nn.Module): + def __init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0): + super().__init__() + self.n1 = nn.LayerNorm(d) + self.attn = MHA(d, heads, attn_drop=drop, proj_drop=drop) + self.n2 = nn.LayerNorm(d) + hidden = int(d * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop), + nn.Linear(hidden, d), nn.Dropout(drop), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x + self.attn(self.n1(x)) + x = x + self.mlp(self.n2(x)) + return x + + +class ViT1D(nn.Module): + """Token-in, token-out transformer trunk with final LayerNorm.""" + + def __init__( + self, + depth: int = 12, + d_model: int = 256, + heads: int = 8, + mlp_ratio: float = 4.0, + drop: float = 0.0, + ): + super().__init__() + self.blocks = nn.ModuleList( + [Block(d_model, heads, mlp_ratio, drop) for _ in range(depth)] + ) + self.norm = nn.LayerNorm(d_model) + + def forward(self, tokens: torch.Tensor) -> torch.Tensor: + x = tokens + for blk in self.blocks: + x = blk(x) + return self.norm(x) + + +class CrossAttnBlock(nn.Module): + """Self-attention → cross-attention(kv=context) → MLP.""" + + def 
__init__(self, d: int, heads: int, mlp_ratio: float = 4.0, drop: float = 0.0): + super().__init__() + self.n1 = nn.LayerNorm(d) + self.self_attn = MHA(d, heads, attn_drop=drop, proj_drop=drop) + self.n2q = nn.LayerNorm(d) + self.n2k = nn.LayerNorm(d) + self.h = heads + self.dh = d // heads + self.q = nn.Linear(d, d) + self.kv = nn.Linear(d, 2 * d) + self.op = nn.Linear(d, d) + self.n3 = nn.LayerNorm(d) + hidden = int(d * mlp_ratio) + self.mlp = nn.Sequential( + nn.Linear(d, hidden), nn.GELU(), nn.Dropout(drop), + nn.Linear(hidden, d), nn.Dropout(drop), + ) + + def forward(self, x: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor: + x = x + self.self_attn(self.n1(x)) + q = self.q(self.n2q(x)) + kv = self.kv(self.n2k(ctx)) + b, n, d = q.shape + m = ctx.shape[1] + q = q.view(b, n, self.h, self.dh).transpose(1, 2) + k, v = kv.view(b, m, 2, self.h, self.dh).permute(2, 0, 3, 1, 4) + o = nn.functional.scaled_dot_product_attention(q, k, v) + o = o.transpose(1, 2).reshape(b, n, d) + x = x + self.op(o) + x = x + self.mlp(self.n3(x)) + return x + + +class CrossAttentionPredictor(nn.Module): + """Query = positional tokens at target positions; KV = ECG context (+ optional Δt token).""" + + def __init__( + self, + depth: int = 4, + d_model: int = 256, + heads: int = 8, + mlp_ratio: float = 4.0, + drop: float = 0.0, + ): + super().__init__() + self.blocks = nn.ModuleList( + [CrossAttnBlock(d_model, heads, mlp_ratio, drop) for _ in range(depth)] + ) + self.norm = nn.LayerNorm(d_model) + + def forward(self, queries: torch.Tensor, ctx: torch.Tensor) -> torch.Tensor: + x = queries + for blk in self.blocks: + x = blk(x, ctx) + return self.norm(x)