# PhysioJEPA — Minimal Experiment Matrix
*Oz Labs — April 2026*
*Revision 2: post-reviewer critique. All "CausalCardio-JEPA" references replaced.*
---
## The single question this matrix answers
> Does predicting PPG at Δt from ECG produce better cardiovascular representations
> than aligning ECG and PPG at t=0?
Every experiment below either answers this question or gates the next one.
Nothing else runs until K2 is resolved.
---
## Experiment map overview
```
Day 1–2     E0: Data audit            → Go/No-go on dataset
               │
               ▼
Day 3       E1: Morphology vs raw     → Choose PPG encoding, once, forever
               │
               ▼
Day 4–5     E2: Baselines A+B+C       → Establish floor and ceiling
               │
               ▼
Day 6–8     E3: Δt-JEPA v1            → Core claim test (K1, K2, K3)
               │
               ├── FAIL → exit
               │
               ▼
Day 9–10    E4: Rollout coherence     → World model validation
               │
               ▼
Day 11–12   E5: PTT probe             → Downstream validation
               │
               ▼
Day 13–14   E6: Ablation Δt=0 vs Δt>0 → Isolate the single variable
               │
               ▼
Day 15      Decision: paper or pivot
```
---
## E0 — Data audit
**Days 1–2 | Prerequisite for everything**
### What to run
```python
import datasets

# Load the aligned ECG-PPG dataset (the "train" split is an assumption)
ds = datasets.load_dataset("lucky9-cyou/mimic-iv-aligned-ppg-ecg", split="train")

# For each record, compute:
# 1. ECG-PPG alignment tolerance
#    (detect_r_peaks, detect_ppg_peaks, align_peaks, ptt_variability are project helpers)
alignment_error_ms = []
for record in ds:
    r_peak_ts = detect_r_peaks(record['ecg'])
    ppg_peak_ts = detect_ppg_peaks(record['ppg'])
    ptt = align_peaks(r_peak_ts, ppg_peak_ts)
    alignment_error_ms.append(ptt_variability(ptt))

# 2. Coverage
n_patients = len(set(record['subject_id'] for record in ds))
total_hours = sum(record['duration'] for record in ds) / 3600
missing_pct = mean_missing_rate(ds)
```
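The helper functions above (`detect_r_peaks`, `detect_ppg_peaks`, `align_peaks`, `ptt_variability`) are project code not shown here. A minimal sketch of one way to implement them with `scipy.signal.find_peaks`; the 125 Hz sampling rate, the prominence thresholds, and the first-following-peak pairing rule are all illustrative assumptions:

```python
import numpy as np
from scipy.signal import find_peaks

FS = 125  # Hz, assumed sampling rate for both ECG and PPG

def detect_r_peaks(ecg: np.ndarray) -> np.ndarray:
    """Rough R-peak detector: prominence-based peak picking on a single lead."""
    peaks, _ = find_peaks(ecg, distance=int(0.3 * FS), prominence=np.std(ecg))
    return peaks / FS  # timestamps in seconds

def detect_ppg_peaks(ppg: np.ndarray) -> np.ndarray:
    """Systolic-peak detector for PPG with the same prominence heuristic."""
    peaks, _ = find_peaks(ppg, distance=int(0.3 * FS), prominence=0.5 * np.std(ppg))
    return peaks / FS

def align_peaks(r_peak_ts: np.ndarray, ppg_peak_ts: np.ndarray) -> np.ndarray:
    """Pair each R peak with the first PPG peak that follows it; the gap is a PTT estimate."""
    ptt = []
    for r in r_peak_ts:
        later = ppg_peak_ts[ppg_peak_ts > r]
        if len(later):
            ptt.append(later[0] - r)
    return np.asarray(ptt) * 1000  # ms

def ptt_variability(ptt_ms: np.ndarray) -> float:
    """Within-record PTT spread, used as the alignment-error proxy."""
    return float(np.std(ptt_ms))
```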
### Pass criteria — ALL must be true
| Criterion | Pass action | Fail action |
|-----------|-------------|-------------|
| Median alignment ≤ 50ms | ✓ proceed | Pivot to PhysioNet BIDMC |
| PTT within-patient std ≤ 80ms | ✓ proceed | Same pivot |
| Patients ≥ 500 | ✓ proceed | Supplement with PhysioNet MIMIC-III waveforms |
| Missing rate ≤ 20% after windowing | ✓ proceed | Tighten quality filter |
| PTT range [50ms, 500ms] physiologically plausible | ✓ proceed | Check synchronisation method |
### Output
- `data_card.md`: patients, hours, alignment stats, missing rates
- `ptt_histogram.png`: histogram of measured PTT per patient
- Go/no-go decision logged in `experiments/e0_decision.md`
**If E0 fails**: PhysioNet BIDMC (ECG + PPG, documented 0.1ms alignment, 53 subjects — smaller but clean). All downstream experiments are identical; only scale changes.
---
## E1 — Morphology vs raw PPG patches
**Day 3 | One-time architectural decision**
### What to run
Two target encoders, same ViT-S backbone, 10% of data, 20 epochs each:
**E1a — Raw patch encoder**
- PPG windowed into 200ms patches (25 samples at 125Hz)
- Linear projection → d=256 tokens
- Standard I-JEPA spatial masking within window
**E1b — Morphological encoder**
- Per-beat features: systolic peak height, diastolic notch depth, pulse width, upstroke slope, augmentation index
- Extracted via Bishop & Ercole peak detection + `scipy.signal` (see the sketch after this list)
- Linear projection → d=256 tokens per beat
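Below is a minimal sketch of per-beat morphology extraction on a foot-to-foot PPG segment. The feature definitions (especially the augmentation-index proxy) are simplified assumptions for illustration; the Bishop & Ercole detector itself is not reproduced:

```python
import numpy as np
from scipy.signal import find_peaks

FS = 125  # Hz, assumed PPG sampling rate

def beat_morphology(beat: np.ndarray):
    """Extract simple morphological features from one PPG beat (foot-to-foot segment)."""
    if len(beat) < int(0.3 * FS):
        return None  # too short to be a plausible beat
    sys_idx = int(np.argmax(beat))
    sys_height = float(beat[sys_idx] - beat[0])
    # Dicrotic/diastolic notch: first local minimum after the systolic peak
    troughs, _ = find_peaks(-beat[sys_idx:])
    notch_depth = float(beat[sys_idx] - beat[sys_idx + troughs[0]]) if len(troughs) else np.nan
    return {
        "sys_height": sys_height,
        "notch_depth": notch_depth,
        "pulse_width": len(beat) / FS,                          # seconds, foot to foot
        "upstroke_slope": sys_height / max(sys_idx / FS, 1e-3), # amplitude per second
        "aug_index": notch_depth / sys_height if sys_height > 0 else np.nan,
    }
```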
### Metrics to compare
| Metric | What it tests |
|--------|--------------|
| % beats with valid morphology extraction | Is E1b viable on this dataset? |
| Target encoder latent variance | Stability (collapse check) |
| Linear probe AUROC on AF (frozen, 100 AF / 100 normal) | Representation quality |
| MAE of PTT regression from frozen encoder | Vascular information content |
### Decision rule (made once, frozen)
```
if morphology_extraction_rate < 0.70:
    USE raw patches (E1a)
elif E1b_linear_probe_AUROC > E1a_linear_probe_AUROC + 0.02:
    USE morphological (E1b)
else:
    USE raw patches (E1a)   # simpler, fewer failure modes
```
### Output
- `e1_decision.md`: which encoder, exact threshold used, quality stats
- `ppg_encoder.py`: the chosen implementation, committed to repo
---
## E2 — Baseline suite
**Days 4–5 | Floor and ceiling**
Run all three in parallel. Same data split, same 20 epochs, same evaluation harness.
These are reference points for E3, not ablations.
### AF label source — decide before running E2
**Decision required by**: Day 3 (before baselines start training)
**Owner**: Zack
**Option 1 — MIMIC-IV ECG module (preferred)**
Join `mimic-iv-ecg` rhythm annotations to the aligned waveform dataset by `subject_id` + `hadm_id`.
- Pros: in-distribution, same patient population as training data
- Cons: requires verifying the join yields enough AF-positive patients (need ≥100 AF, ≥100 normal for the linear probe to be meaningful)
- Check: `SELECT count(*) FROM mimic-iv-ecg WHERE rhythm = 'atrial fibrillation'` on the HF mirror (a pandas version of this join and count is sketched after Option 3)
**Option 2 — PTB-XL (fallback)**
Use PTB-XL rhythm labels as the AF evaluation benchmark.
- Pros: clean, well-labelled, already used by Weimann & Conrad (enables direct comparison)
- Cons: different population (German outpatient vs MIMIC ICU) — becomes a generalisation test, not in-distribution
- Note: framing in paper changes slightly to "transfer to PTB-XL" rather than "in-distribution evaluation"
**Option 3 — PhysioNet AFDB**
MIT-BIH AF Database: 25 long-term ECG recordings with AF annotations.
- Only if Options 1 and 2 both fail
- Very small; only useful for AUROC, not for sample efficiency curves
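A minimal sketch of the Option 1 join and AF count with pandas, assuming both tables can be loaded as DataFrames with `subject_id`, `hadm_id`, and a free-text `rhythm` column (file names below are hypothetical; the actual column names on the HF mirror may differ):

```python
import pandas as pd

# labels: one row per ECG with a rhythm annotation (from mimic-iv-ecg)
# waves:  one row per aligned ECG-PPG record (from the aligned waveform dataset)
labels = pd.read_parquet("mimic_iv_ecg_rhythm_labels.parquet")   # hypothetical file name
waves = pd.read_parquet("aligned_ppg_ecg_index.parquet")         # hypothetical file name

labels["is_af"] = labels["rhythm"].str.contains("atrial fibrillation", case=False, na=False)

joined = waves.merge(labels[["subject_id", "hadm_id", "is_af"]],
                     on=["subject_id", "hadm_id"], how="inner")

# Count at the patient level so the >=100 AF / >=100 normal check is meaningful
per_patient = joined.groupby("subject_id")["is_af"].max()
print("AF-positive patients:", int(per_patient.sum()))
print("AF-negative patients:", int((~per_patient).sum()))
```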
**Decision log**:
```
AF_LABEL_SOURCE = "" # fill in before Day 4
DECISION_DATE = ""
DECISION_BY = ""
N_AF_POSITIVE = 0 # verify after join/filter
N_AF_NEGATIVE = 0
```
### Baseline A — ECG-JEPA (Weimann & Conrad exact replication)
```python
# Fork: github.com/kweimann/ECG-JEPA
# Config: ViT-S/8, multi-block masking, EMA τ=0.996
# Input: ECG only (no PPG at all)
# Loss: standard I-JEPA L1 latent prediction (within ECG)
```
This is the unimodal ceiling. If our model can't match this on ECG-only tasks, something is wrong with the cross-modal architecture.
### Baseline B — Symmetric cross-modal JEPA (Δt = 0)
```python
# Architecture: identical to E3 in every detail
# EXCEPT: Δt is hardcoded to 0
# - context: ECG window at time t
# - target: PPG window at the SAME time t (no lag)
# - predictor: cross-attention ECG → PPG
# Loss: L1 latent prediction
```
This isolates the Δt variable. If E3 beats B on the same tasks, Δt matters. If not, the core claim fails.
### Baseline C — InfoNCE contrastive (AnyPPG-style)
```python
# Architecture: same dual encoder
# Loss: symmetric InfoNCE
# z_ecg = ecg_encoder(ECG_t)
# z_ppg = ppg_encoder(PPG_t)
# L = InfoNCE(z_ecg, z_ppg, temperature=0.07)
# No Δt, no prediction — pure alignment
```
This is the comparison against the dominant paradigm in the field.
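A minimal PyTorch sketch of the symmetric InfoNCE loss for Baseline C, assuming both encoders emit one pooled d=256 embedding per window (variable and function names are illustrative):

```python
import torch
import torch.nn.functional as F

def symmetric_infonce(z_ecg: torch.Tensor, z_ppg: torch.Tensor, temperature: float = 0.07) -> torch.Tensor:
    """z_ecg, z_ppg: (B, d) pooled embeddings of time-aligned ECG/PPG windows."""
    z_ecg = F.normalize(z_ecg, dim=-1)
    z_ppg = F.normalize(z_ppg, dim=-1)
    logits = z_ecg @ z_ppg.t() / temperature              # (B, B) cosine-similarity matrix
    targets = torch.arange(z_ecg.size(0), device=z_ecg.device)
    loss_e2p = F.cross_entropy(logits, targets)           # ECG -> matching PPG
    loss_p2e = F.cross_entropy(logits.t(), targets)       # PPG -> matching ECG
    return 0.5 * (loss_e2p + loss_p2e)
```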
### Metrics for all three
```
After 20 epochs on 10% data, for each model:
1. Pretraining loss convergence curve
2. Linear probe AUROC — AF detection (frozen encoder; see the probe sketch below)
3. Linear probe R² — HR estimation (frozen encoder)
4. Latent variance + eigenspectrum rank (collapse check)
5. UMAP: coloured by patient ID, AF status, HR decile
```
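The same frozen-encoder linear probe recurs in E2, E3 (K2/K3), and E5b. A minimal scikit-learn sketch, assuming pooled embeddings and labels have already been extracted to arrays (the patient-level split via `GroupShuffleSplit` is an assumption about the intended protocol):

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import GroupShuffleSplit

def linear_probe_auroc(X: np.ndarray, y: np.ndarray, groups: np.ndarray, seed: int = 42) -> float:
    """X: (N, d) frozen pooled embeddings; y: AF labels; groups: subject_id per window.
    Patient-level split so no subject appears in both train and test."""
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    train_idx, test_idx = next(splitter.split(X, y, groups))
    clf = LogisticRegression(max_iter=1000).fit(X[train_idx], y[train_idx])
    scores = clf.predict_proba(X[test_idx])[:, 1]
    return roc_auc_score(y[test_idx], scores)
```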
### What to learn from E2 before running E3
| Observation | Implication |
|-------------|-------------|
| Baseline A AUROC > 0.80 | ECG alone is strong; cross-modal has a high bar |
| Baseline B collapses | Symmetric cross-modal JEPA is unstable; add SIGReg to E3 from the start |
| Baseline C > Baseline A | Cross-modal information helps; our model has something to beat |
| All three collapse | Data quality problem — revisit E0 |
---
## E3 — Δt-JEPA v1
**Days 6–8 | The paper test**
Minimal version of the actual contribution.
PPG encoding from E1 decision. No SIGReg. No cardiac phase encoding.
Just: ECG context predicts PPG target at t+Δt.
### Architecture
```python
# ECG encoder: ViT-S/8, 2D patches (leads × time), EMA target
# PPG encoder: ViT-S/8, encoding chosen in E1, EMA target
# Predictor: 4-layer cross-attention transformer
#   query   = positional tokens for target PPG beats
#   key/val = ECG context latents + Δt embedding
# Δt embed: sinusoidal over [50ms, 500ms] → R^256
# Loss:
#   L_cross = L1(predicted_ppg_latent, ema_ppg_encoder_output)
#   L_self  = L1(masked_ecg_pred, ema_ecg_target)   [auxiliary, α=0.3]
#   L_total = L_cross + α * L_self
# Δt sampling per batch:
#   60% log-uniform in [50ms, 500ms]
#   40% ground-truth PTT from dataset
```
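A minimal sketch of the Δt embedding and the per-batch Δt sampling described above. The geometric frequency ladder and the exact sampling mechanics are assumptions; only the [50 ms, 500 ms] range and the 60/40 split come from the spec:

```python
import math
import torch

def delta_t_embedding(delta_t_ms: torch.Tensor, dim: int = 256,
                      t_min: float = 50.0, t_max: float = 500.0) -> torch.Tensor:
    """(B,) lags in ms -> (B, dim) sinusoidal embedding."""
    pos = (delta_t_ms.clamp(t_min, t_max) - t_min) / (t_max - t_min)   # normalise to [0, 1]
    freqs = torch.exp(torch.linspace(0.0, 8.0, dim // 2, device=delta_t_ms.device))
    angles = pos[:, None] * freqs[None, :]                              # (B, dim/2)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)    # (B, dim)

def sample_delta_t(gt_ptt_ms: torch.Tensor, p_gt: float = 0.4,
                   t_min: float = 50.0, t_max: float = 500.0) -> torch.Tensor:
    """Per-sample Δt: 40% ground-truth PTT, otherwise log-uniform in [t_min, t_max]."""
    b = gt_ptt_ms.size(0)
    log_u = torch.empty(b, device=gt_ptt_ms.device).uniform_(math.log(t_min), math.log(t_max)).exp()
    use_gt = torch.rand(b, device=gt_ptt_ms.device) < p_gt
    return torch.where(use_gt, gt_ptt_ms.clamp(t_min, t_max), log_u)
```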
### Training config
```yaml
epochs: 100
batch_size: 64
optimizer: AdamW, lr=1e-4, weight_decay=0.04
scheduler: cosine with 10-epoch warmup
ema_tau: 0.996 → 0.9999 over first 30% of training
window: 10s ECG + matched PPG
stride: 5s
data: 100% of passing-E0 records
```
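A minimal sketch of the EMA τ schedule and target-encoder update implied by the config. A linear ramp from 0.996 to 0.9999 over the first 30% of steps is one reading of the spec; a cosine ramp would also be consistent with it:

```python
import torch

def ema_tau(step: int, total_steps: int, tau_start: float = 0.996,
            tau_end: float = 0.9999, warmup_frac: float = 0.30) -> float:
    """Linearly ramp the EMA momentum over the first warmup_frac of training, then hold."""
    ramp_steps = max(int(warmup_frac * total_steps), 1)
    frac = min(step / ramp_steps, 1.0)
    return tau_start + frac * (tau_end - tau_start)

@torch.no_grad()
def ema_update(target_params, online_params, tau: float) -> None:
    """In-place EMA update of the target encoder from the online encoder."""
    for p_t, p_o in zip(target_params, online_params):
        p_t.mul_(tau).add_(p_o, alpha=1.0 - tau)
```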
### Collapse monitoring (every 100 steps)
```python
# Log these — stop if cross_modal_cosim > 0.99 for 500 consecutive steps
metrics = {
'ecg_latent_variance': var(z_ecg).mean(),
'ppg_latent_variance': var(z_ppg).mean(),
'cross_modal_cosim': cosine_sim(z_ecg_pooled, z_ppg_pred).mean(),
'ecg_eigenspectrum_rank': effective_rank(cov(z_ecg)),
}
```
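`effective_rank` above is a project helper. A minimal sketch of one common definition (exponential of the entropy of the normalised eigenvalue spectrum); this is an assumption about what the monitoring code computes:

```python
import torch

def effective_rank(cov: torch.Tensor, eps: float = 1e-8) -> float:
    """Effective rank = exp(entropy of normalised eigenvalues) of a covariance matrix."""
    eigvals = torch.linalg.eigvalsh(cov).clamp_min(0)
    p = eigvals / (eigvals.sum() + eps)
    entropy = -(p * torch.log(p + eps)).sum()
    return float(torch.exp(entropy))
```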
### Kill criteria — evaluated at epoch 25
**K1 — Is the model learning anything?**
```python
mean_baseline_loss = L1(z_ppg_target, z_ppg_mean_over_dataset)
# PASS: model_loss < 0.85 * mean_baseline_loss
```
**K2 — Does Δt matter? (the core claim)**
```python
# Run identical linear probe on frozen E3 and Baseline B encoders
# PASS: E3_AUROC > Baseline_B_AUROC + 0.02 (AF detection)
# OR    E3_R² > Baseline_B_R² + 0.05 (HR estimation)
# At least one metric must pass
```
**K3 — Is cross-modal at least as good as unimodal?**
```python
# PASS: E3_AUROC >= Baseline_A_AUROC - 0.01 (i.e. no worse than unimodal within a 0.01 tolerance)
```
### Decision tree at epoch 25
```
K1 FAIL → Stop entirely.
          Data is unusable or encoder collapsed.
          Check alignment, quality filtering, EMA schedule.
          If clean: the architecture is wrong. Move to Architecture A (temporal ECG-JEPA only).

K2 FAIL → Stop. The paper does not exist.
          Δt-aware prediction ≈ t-aligned prediction.
          Pivot options:
            (a) Architecture A — temporal unimodal ECG-JEPA
            (b) Study 4 — anomaly detection reusing this codebase
            (c) Rerun with cleaner BIDMC data before final decision.

K2 PASS + K3 FAIL → Cross-modal hurts.
          Run 10 more epochs. If still failing:
          Reduce PPG encoder capacity, check EMA instability.
          If persistent: use lighter PPG encoder (ViT-T instead of ViT-S).

K1 ✓, K2 ✓, K3 ✓ → Continue to epoch 100. Proceed to E4.
```
---
## E4 — Rollout coherence test
**Days 9–10 | World model validation**
This is the experiment that separates "JEPA with a lag" from "a cardiovascular world model." Without it, the paper cannot make the world model claim.
### Protocol
```python
# Frozen encoder + trained predictor. N=200 held-out patients.
delta_t_grid = [50, 100, 150, 200, 250, 300, 350, 400, 450, 500]  # ms
optimal_delta_t = {}
for patient in held_out_patients:
    z_ecg = ecg_encoder(ecg_window_t)
    # Predict at a grid of Δt values and score against the true PPG latent
    errors = []
    for dt in delta_t_grid:
        z_ppg_pred = predictor(z_ecg, delta_t=dt)
        z_ppg_true = ppg_encoder(ppg_window_at_t_plus_dt)
        errors.append(L1(z_ppg_pred, z_ppg_true))
    # Find optimal Δt (prediction-error minimum)
    optimal_delta_t[patient] = delta_t_grid[argmin(errors)]
```
### Physiological consistency checks
```python
# Check 1: Does optimal_Δt correlate with measured PTT?
correlation = spearman(optimal_delta_t, measured_ptt_per_patient)
# PASS: correlation > 0.30

# Check 2: HR-PTT inverse relationship
# High HR → shorter PTT → shorter optimal Δt
high_hr = windows_where(hr > 90)   # bpm
low_hr = windows_where(hr < 60)    # bpm
# PASS: mean(optimal_Δt[high_hr]) < mean(optimal_Δt[low_hr]), p < 0.05

# Check 3: U-shaped error curve (predictor has a real minimum, not flat)
n_clear = sum(has_clear_minimum(errors_per_patient[p]) for p in sample_50_patients)
# PASS: ≥ 60% of the sampled patients have a clear minimum (not monotone, not flat)
```
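A minimal sketch of Checks 1 and 2 with `scipy.stats`; the one-sided Mann-Whitney U test for Check 2 is an assumption, and any one-sided two-sample test would serve:

```python
import numpy as np
from scipy.stats import spearmanr, mannwhitneyu

def check_ptt_correlation(optimal_dt_ms: np.ndarray, measured_ptt_ms: np.ndarray) -> bool:
    """Check 1: Spearman correlation between per-patient optimal Δt and measured PTT."""
    rho, _ = spearmanr(optimal_dt_ms, measured_ptt_ms)
    return rho > 0.30

def check_hr_ordering(opt_dt_high_hr: np.ndarray, opt_dt_low_hr: np.ndarray) -> bool:
    """Check 2: optimal Δt should be shorter in high-HR windows than in low-HR windows."""
    _, p = mannwhitneyu(opt_dt_high_hr, opt_dt_low_hr, alternative="less")
    return bool(opt_dt_high_hr.mean() < opt_dt_low_hr.mean() and p < 0.05)
```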
### Pass criteria
| Check | What a pass shows | Paper implication |
|-------|-------------------|-------------------|
| Spearman > 0.30 | Model learned PTT implicitly | Core world-model claim supported |
| HR-PTT ordering | Physiologically consistent | Not a lookup table |
| U-curve ≥ 60% | Predictor has a real minimum | Latent space is smooth |
### If E4 passes but E5 PTT probe fails
The representation has the information but a linear probe can't extract it. Try a 3-layer MLP probe. If that also fails, the PTT information is encoded nonlinearly — mention this as a limitation but don't remove the E4 claim from the paper.
---
## E5 — Downstream probes
**Days 11–12 | Validation signals**
These run on frozen encoders from E3 best checkpoint. They are probes, not contributions.
### E5a — PTT regression probe
```python
mlp_ptt = MLP(in_dim=256, hidden=128, out_dim=1)   # `in` is a reserved word; renamed for valid Python
train(mlp_ptt,
      X=pool(ecg_latent),
      y=measured_ptt_per_beat,
      split=patient_level_80_20)
# Report:
#   MAE (ms) vs naive mean-PTT baseline
#   Pearson(predicted_ptt, measured_ptt)
#   Within-patient: does the probe track PTT changes over time?
```
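A minimal sketch of E5a with scikit-learn instead of a hand-rolled MLP (a substitution made purely for illustration; the hidden width mirrors the pseudocode above, and the patient-level split uses `GroupShuffleSplit`):

```python
import numpy as np
from scipy.stats import pearsonr
from sklearn.metrics import mean_absolute_error
from sklearn.model_selection import GroupShuffleSplit
from sklearn.neural_network import MLPRegressor

def ptt_probe(X: np.ndarray, y_ptt_ms: np.ndarray, subject_id: np.ndarray, seed: int = 42) -> dict:
    """X: (N, 256) pooled frozen ECG latents; y_ptt_ms: measured per-beat PTT in ms."""
    splitter = GroupShuffleSplit(n_splits=1, test_size=0.2, random_state=seed)
    tr, te = next(splitter.split(X, y_ptt_ms, subject_id))
    probe = MLPRegressor(hidden_layer_sizes=(128,), max_iter=500, random_state=seed)
    probe.fit(X[tr], y_ptt_ms[tr])
    pred = probe.predict(X[te])
    naive = np.full_like(y_ptt_ms[te], y_ptt_ms[tr].mean())   # predict the training-mean PTT
    r, _ = pearsonr(pred, y_ptt_ms[te])
    return {"mae_ms": mean_absolute_error(y_ptt_ms[te], pred),
            "naive_mae_ms": mean_absolute_error(y_ptt_ms[te], naive),
            "pearson_r": float(r)}
```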
### E5b — AF detection sample efficiency
```python
# Same linear probe as used in E2/E3 — enables direct comparison
# Label fractions: 1%, 5%, 10%, 50%, 100%
# Models: E3 vs Baseline_A vs Baseline_C
# Goal: sample efficiency curve (not just full-data comparison)
```
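A minimal sketch of the label-fraction sweep, reusing the `linear_probe_auroc` helper sketched under E2. Subsampling whole patients (rather than windows) at each fraction is an assumption about the intended protocol:

```python
import numpy as np

def sample_efficiency_curve(X, y, groups, fractions=(0.01, 0.05, 0.10, 0.50, 1.00), seed=42):
    """AUROC of the frozen-encoder linear probe at decreasing label budgets."""
    rng = np.random.default_rng(seed)
    patients = np.unique(groups)
    curve = {}
    for frac in fractions:
        keep = rng.choice(patients, size=max(int(frac * len(patients)), 2), replace=False)
        mask = np.isin(groups, keep)
        curve[frac] = linear_probe_auroc(X[mask], y[mask], groups[mask], seed=seed)
    return curve
```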
### E5c — HR estimation
```python
# Linear regression on frozen latent → HR
# Baseline: RR-interval to HR (trivial — sets floor)
```
### What must be true for the paper
| Result | Why it matters |
|--------|----------------|
| E5a MAE < naive by ≥ 20% | PTT is in the latent — confirms E4 |
| E5b: E3 ≥ Baseline_A at all label fractions | Cross-modal doesn't hurt |
| E5b: E3 > Baseline_C at 1% labels | JEPA more sample-efficient than InfoNCE |
---
## E6 — The decisive ablation
**Days 13–14 | The main result**
One variable changed. Everything else identical.
| Model | Δt | Architecture |
|-------|-----|-------------|
| E3 (PhysioJEPA) | log-uniform [50, 500ms] | Identical |
| Baseline B (t-aligned) | Fixed 0ms | Identical |
Both trained to 100 epochs, full data. Evaluated identically.
### The comparison table (this becomes Table 1 of the paper)
```
Model                | AF AUROC | HR R² | PTT R² | ECG-PPG R@1
─────────────────────┼──────────┼───────┼────────┼────────────
Baseline A (ECG)     |          |       |  N/A   |    N/A
Baseline B (Δt=0)    |          |       |        |
Baseline C (InfoNCE) |          |       |        |
E3 (Δt>0, ours)      |          |       |        |
```
### Paper-level claim, if E6 supports it
> Predicting PPG at variable time offset Δt from ECG produces latent representations
> that implicitly encode vascular timing structure (PTT).
> Contrastive alignment at t=0 and predictive alignment at t=0 both destroy this structure.
> This is demonstrated by improved PTT regression, superior sample efficiency on AF detection,
> and physiologically consistent rollout behaviour under varying heart rate.
One paragraph. Defensible. Not overclaiming causality or blood pressure.
---
## Day 15 — Decision
```
GREEN  — all of K1, K2, K3, E4 coherence, E6 Δt > Δt=0
         → Write the paper.
         → Weeks 3–4: run ablations A1–A5 (morphology, phase encoding,
           SIGReg, PTT head, curriculum Δt).
         → Target venues (with actual 2026 deadlines):
             NeurIPS 2026 workshops (TS4H, BrainBodyFM): ~August 2026
             ML4H 2026 symposium (archival proceedings track): ~September 2026
             ICLR 2027: ~October 2026 (needs strong E4 + clean ablations)

YELLOW — K2 passes weakly, E4 marginal
         → Extend E3 to 200 epochs before deciding.
         → If still weak: reframe as temporal ECG-JEPA (Architecture A).
           Smaller claim but still publishable as an extension of Weimann & Conrad.
           Target: NeurIPS 2026 workshop TS4H.

RED    — K2 fails
         → The core idea does not work on this dataset at this scale.
         → Immediate pivot options:
             (a) Architecture A (temporal ECG-JEPA, unimodal) — reuses everything
             (b) Study 4 (anomaly detection via prediction error) — same codebase
             (c) Re-run E0 on PhysioNet BIDMC before final call.
           Note: CHIL 2026 deadline (Apr 17) has passed. MLHC 2026 (Apr 17) has passed.
           Next realistic archival venue: ML4H 2026 (~Sep 2026 estimated).
```
---
## Post-hoc (2026-04-15): K2 failed, K3 passed, τ mechanism falsified
Actual results from the E2/E3 run (subset_frac=0.10, 25 epochs, seed=42):
| Model | Config | ep5 AUROC | ep10 AUROC | ep25 AUROC |
|-------|--------|-----------|------------|------------|
| F (Δt>0) | PhysioJEPA v1 | 0.652 | 0.859 | 0.835 |
| B (Δt=0) | symmetric cross-modal | 0.660 | 0.844 | **0.847** |
| A (unimodal) | ECG-JEPA | 0.783 | 0.736 | 0.703 |
| C (InfoNCE) | symmetric | — | — | under-tuned; not usable |
**K2: FAIL.** F−B at ep25 = −0.012 (target was +0.02). Δt doesn't matter.
**K3: PASS BIG.** F−A at ep25 = +0.133. Cross-modal beats unimodal by ~0.13 AUROC.
**τ-saturation mechanism (slow-τ A ablation): FALSIFIED.**
Slow-τ A (ema_end=0.999, warmup_frac=0.60) had L_self rising *more* than
original A through steps 2000–5000, not less. τ is not the lever.
Working hypothesis for A's degradation: predictor+query-embedding overfits
to a narrow target distribution in unimodal training. Cross-modal training
provides target diversity the predictor can't overfit to, which is why
F/B stay stable. Needs a different ablation (e.g. shrink predictor, shrink
query embedding, vary masking ratio) to confirm.
---
## Summary
| Day | Experiment | Key output | Decision gated |
|-----|-----------|-----------|----------------|
| 1–2 | E0: data audit | data_card.md, PTT histogram | Dataset go/no-go |
| 3 | E1: PPG encoding | e1_decision.md, ppg_encoder.py | Architecture lock |
| 4–5 | E2: baselines | Floor + ceiling numbers | Calibrates E3 expectations |
| 6–8 | E3: Δt-JEPA v1 | K1/K2/K3 at epoch 25 | Paper exists or doesn't |
| 9–10 | E4: rollout coherence | World model evidence | World model claim |
| 11–12 | E5: probes | PTT, AF, HR numbers | Downstream story |
| 13–14 | E6: decisive ablation | Table 1 | Paper's main result |
| 15 | Decision | Green / yellow / red | What gets written |
**Compute to day 15 decision point: ~50–70 GPU-hours. Cost: ~$125–175.**
K2 is answered by day 8. Everything after that is filling in the paper.
---
## Division of work
| Task | Owner |
|------|-------|
| E0: data pipeline, quality metrics, PTT computation | Zack |
| E1: morphology extractor, two-encoder comparison | Zack |
| E2: ECG-JEPA fork (Baseline A), training | Guy |
| E2: InfoNCE baseline (Baseline C) | Zack |
| E2: Symmetric JEPA (Baseline B) | Guy |
| E3: Δt-JEPA architecture + training loop | Guy |
| E3: collapse monitoring, checkpoint saving | Both |
| E4: rollout coherence test, physiological checks | Guy |
| E5: probe training harness, sample efficiency curves | Zack |
| E6: final comparison, Table 1 | Both |
| Day 15 decision | Both |
---
*Designed so the most important question — does Δt matter? — is answered by day 8, not day 28.*
*Total time to go/no-go: 8 days. Total compute: ~50–70 GPU-hours.*