---
license: mit
tags:
  - ecg
  - ppg
  - jepa
  - self-supervised
  - cardiac
  - representation-learning
datasets:
  - lucky9-cyou/mimic-iv-aligned-ppg-ecg
  - physionet/ptb-xl
---

# PhysioJEPA

**Self-supervised ECG-PPG representation learning via Joint Embedding Predictive Architecture.**

## Key finding: mask ratio is the hidden lever

We discovered that unimodal ECG-JEPA (Weimann & Conrad, 2024) has a
**predictor shortcut vulnerability** at the standard 50% mask ratio: the
predictor learns local-interpolation shortcuts that degrade downstream
performance as training progresses.

**Raising mask ratio from 50% to 75% eliminates the shortcut** and recovers
full downstream performance, matching cross-modal JEPA:

| Model | Mask ratio | AF AUROC (epoch 25) |
|-------|------|-----------------|
| Unimodal ECG-JEPA | 0.50 | 0.703 |
| **Unimodal ECG-JEPA** | **0.75** | **0.848** |
| Cross-modal ECG-PPG JEPA | -- | 0.847 |
| PhysioJEPA (cross-modal + Δt) | -- | 0.835 |
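For readers less familiar with the metric in the table, AUROC is the probability that a randomly chosen positive (AF) example receives a higher score than a randomly chosen negative one. The sketch below computes it by direct pairwise comparison. The function name `auroc` and the toy labels and scores are illustrative assumptions; this is not the evaluation code in `probe.py`.

```python
def auroc(labels, scores):
    """Pairwise (Mann-Whitney) AUROC for binary labels in {0, 1}.

    Illustrative helper, not the repository's probe evaluator.
    """
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative example")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0      # positive ranked above negative
            elif p == n:
                wins += 0.5      # ties count as half a win
    return wins / (len(pos) * len(neg))


# Toy example: 3 of the 4 positive/negative pairs are ordered correctly.
print(auroc([0, 0, 1, 1], [0.1, 0.4, 0.35, 0.8]))  # 0.75
```

The pairwise form is quadratic in the number of examples; it is meant to make the definition concrete, not to replace a rank-based implementation on a dataset the size of PTB-XL.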

The mechanism: at 50% masking with contiguous blocks, the predictor has 25
visible context patches and 25 target patches. It discovers a short-range
interpolation shortcut early in training (L_self dips at step ~1500). As
the encoder refines and patches become less linearly interpolatable, the
shortcut fails (L_self spikes at step ~4675). The encoder locks into a
self-consistent but downstream-uninformative optimum.

At 75% masking (12 visible, 37 target), no interpolation path exists. The
predictor learns long-range structure from the start.

Cross-modal prediction avoids the shortcut for the same reason: 0% of PPG is
visible as context, so there is nothing nearby to interpolate from.
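The effect of mask ratio on how far a target patch sits from its nearest visible neighbour can be sketched as follows. This is a simplified stand-in for multi-block masking, not the code in `masking.py`: the function names, the 50-patch sequence length, and the choice of 4 target blocks are all assumptions made for illustration, and exact visible/target counts depend on how the ratio is rounded.

```python
import random


def sample_block_mask(num_patches, mask_ratio, num_blocks=4, seed=0):
    """Return a list of booleans, True where a patch is a masked target.

    Hypothetical simplification: the target budget is split into
    `num_blocks` contiguous blocks, and the visible patches are scattered
    at random into the gaps before, between, and after those blocks.
    """
    rng = random.Random(seed)
    num_target = int(num_patches * mask_ratio)
    num_visible = num_patches - num_target

    # Near-equal contiguous target blocks.
    base, extra = divmod(num_target, num_blocks)
    block_sizes = [base + (1 if i < extra else 0) for i in range(num_blocks)]

    # Distribute visible patches over the num_blocks + 1 gaps.
    gaps = [0] * (num_blocks + 1)
    for _ in range(num_visible):
        gaps[rng.randrange(num_blocks + 1)] += 1

    mask = []
    for gap, size in zip(gaps, block_sizes):
        mask += [False] * gap + [True] * size
    mask += [False] * gaps[-1]
    return mask


def mean_distance_to_context(mask):
    """Average distance (in patches) from each target to its nearest visible patch."""
    visible = [i for i, m in enumerate(mask) if not m]
    targets = [i for i, m in enumerate(mask) if m]
    return sum(min(abs(t - v) for v in visible) for t in targets) / len(targets)


for ratio in (0.50, 0.75):
    dists = [
        mean_distance_to_context(sample_block_mask(50, ratio, seed=s))
        for s in range(100)
    ]
    print(f"mask_ratio={ratio:.2f}  mean distance to context: {sum(dists) / len(dists):.2f}")
```

A larger mean distance means fewer targets can be filled in from an adjacent visible patch, which is the intuition behind the interpolation shortcut described above. The sketch only illustrates the geometry; it says nothing about the training dynamics.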

## Confirmed by 5 ablation arms

1. **Slow tau** (ema_end=0.999, warmup=60%): spike persists -> tau is NOT the cause
2. **Smaller predictor** (depth 4->1): spike persists -> capacity is NOT the cause
3. **Sinusoidal queries** (no learned embeddings): spike WORSENS
4. **Mask ratio 0.75**: spike ELIMINATED, AUROC recovers to 0.848
5. **Full data** (10x): spike delayed but present -> architectural, not data-scale

## Architecture

- ECG encoder: ViT-S (12 layers, d=256, 8 heads) on single-lead II @ 250 Hz
- PPG encoder: ViT-T (6 layers, d=256) on Pleth @ 125 Hz
- Predictor: 4-layer cross-attention transformer
- EMA target encoder (tau 0.996 -> 0.9999 cosine over 30% of training)
- Loss: L1 latent prediction (cross-modal) + 0.3 * L1 ECG self-prediction
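
The EMA schedule and loss weighting listed above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the code in `ema.py` or `trainer.py`: it assumes the cosine ramp covers the first 30% of steps and tau is then held at its final value, and the names `ema_tau`, `ema_update`, and `total_loss` are hypothetical.

```python
import math


def ema_tau(step, total_steps, tau_start=0.996, tau_end=0.9999, ramp_frac=0.30):
    """Cosine ramp from tau_start to tau_end over the first ramp_frac of training.

    Assumption: after the ramp, tau stays at tau_end.
    """
    ramp_steps = max(1, int(total_steps * ramp_frac))
    if step >= ramp_steps:
        return tau_end
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / ramp_steps))  # 1 -> 0
    return tau_end - (tau_end - tau_start) * cosine


def ema_update(target_params, online_params, tau):
    """Standard EMA update: target <- tau * target + (1 - tau) * online."""
    return [tau * t + (1.0 - tau) * o for t, o in zip(target_params, online_params)]


def total_loss(l1_cross_modal, l1_ecg_self, self_weight=0.3):
    """Cross-modal L1 latent loss plus a down-weighted ECG self-prediction term."""
    return l1_cross_modal + self_weight * l1_ecg_self


# Toy usage with plain floats standing in for parameter tensors.
tau = ema_tau(step=150, total_steps=1000)
target = ema_update([1.0, 2.0], [0.0, 0.0], tau)
print(tau, target, total_loss(0.40, 0.20))
```

In a real training loop the parameter lists would be tensors and the update would run without gradient tracking; plain floats are used here only to keep the sketch dependency-free.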

## Dataset

Training: [lucky9-cyou/mimic-iv-aligned-ppg-ecg](https://huggingface.co/datasets/lucky9-cyou/mimic-iv-aligned-ppg-ecg)
(MIMIC-IV ICU waveforms, ~814 hours, ~381 patients, sample-accurate ECG-PPG alignment)

Evaluation: PTB-XL (PhysioNet, 21.8k 12-lead ECGs, lead II resampled to 250 Hz)

## Usage

```bash
# Install
git clone https://huggingface.co/guychuk/PhysioJEPA
cd PhysioJEPA
uv sync

# Smoke test (CPU, random data)
PYTHONPATH=src uv run python scripts/smoke_test.py

# Train (requires GPU + MIMIC data)
PYTHONPATH=src uv run python scripts/train.py --config configs/base.yaml --model A --mask_ratio 0.75
```

## Repository structure

```
src/physiojepa/
  models.py      # 4 model variants (A=unimodal, B=cross-modal, C=InfoNCE, F=PhysioJEPA)
  vit.py         # ViT-1D encoder + cross-attention predictor
  data.py        # MIMIC dataset with sliding windows
  data_fast.py   # mmap-backed fast dataset for full-scale runs
  trainer.py     # shared training loop with WandB + collapse monitoring
  ema.py         # EMA with cosine tau schedule
  masking.py     # I-JEPA multi-block 1D masking
  probe.py       # linear probe evaluators
configs/
  base.yaml      # shared hyperparameters
docs/
  RESEARCH_LOG.md          # complete research narrative
  e2_e3_results.md         # K-gate results + ablation findings
  EXPERIMENT_TRACKING.md   # experiment matrix + post-hoc results
  RESEARCH_DEVELOPMENT.md  # full research development document
```

## Citation

```bibtex
@misc{physiojepa2026,
  title={PhysioJEPA: Mask Ratio as the Hidden Lever in Cardiac JEPA},
  author={Oz Labs},
  year={2026},
  url={https://huggingface.co/guychuk/PhysioJEPA}
}
```

## License

MIT