# D-JEPA · baseline+vit+ema+vicreg_lam001
A ViT-small encoder trained from scratch with a JEPA objective on The Well's `active_matter` simulation, evaluated as a frozen feature extractor. The encoder produces patch-token embeddings from (11, 16, 256, 256) (C, T, H, W) snippets of the active-matter fields; pooled features are linearly informative about the two underlying physical parameters (active dipole strength α and steric alignment ζ) without any labeled supervision during pretraining.
This is the project champion out of 34 trained-and-evaluated cells across a 4-axis design space (routing × backbone × target × loss). Two checkpoints are included: the online encoder and its BYOL-style EMA twin.
## Headline frozen-encoder probe scores
Test MSE on z-scored targets — a constant-mean predictor scores ≈ 1.0 by construction; lower is better.
| Metric | Test MSE | Selection on val |
|---|---|---|
| α linear probe (closed-form ridge) | 0.0063 | best ridge α = 1.0 |
| α kNN regression | 0.0147 | k=3, cosine |
| ζ linear probe (closed-form ridge) | 0.0680 | best ridge α = 1.0 |
| ζ kNN regression | 0.1017 | k=3, cosine |
Train / val / test split sizes: 700 / 96 / 104 trajectories. Per-target z-score statistics are fit on the train split only and applied unchanged to val/test.
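For concreteness, here is a minimal sketch of that closed-form ridge probe, with placeholder data standing in for the pooled features and z-scored targets; the sweep grid is illustrative and the project's actual eval code lives in the source repo:

```python
import numpy as np

# Illustrative ridge-probe sketch with placeholder data; shapes mirror the
# 700 / 96 / 104 split above. X_* are pooled encoder features, y_* are
# z-scored targets (stats fit on train only).
rng = np.random.default_rng(0)
X_tr, X_va, X_te = (rng.normal(size=(n, 384)) for n in (700, 96, 104))
y_tr, y_va, y_te = (rng.normal(size=n) for n in (700, 96, 104))

def ridge_fit(X, y, alpha):
    """Closed-form ridge: w = (X^T X + alpha * I)^{-1} X^T y."""
    return np.linalg.solve(X.T @ X + alpha * np.eye(X.shape[1]), X.T @ y)

def mse(pred, y):
    return float(np.mean((pred - y) ** 2))

alphas = [1e-2, 1e-1, 1.0, 10.0, 100.0]  # candidate ridge strengths, swept on val
best = min(alphas, key=lambda a: mse(X_va @ ridge_fit(X_tr, y_tr, a), y_va))
print("best ridge alpha:", best, "| test MSE:", mse(X_te @ ridge_fit(X_tr, y_tr, best), y_te))
```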
## Architecture
Standard ViT-small with V-JEPA-style 3D tubelet patch embedding:
| Field | Value |
|---|---|
| backbone | ViT-small (depth 12, 6 heads, MLP ratio 4.0, no dropout) |
| input shape | (B, 11, 16, 256, 256) (B, C, T, H, W) |
| patch size (spatial) | 16 × 16 |
| tubelet size (temporal) | 2 |
| token grid | (8 frames after tubelet) × (16 × 16 spatial patches) = 2048 tokens |
| embed dim | 384 |
| encoder params | 23,457,408 (~23.5 M) |
| total params (with predictor) | 34,400,640 — well under any 100 M cap |
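The 2048-token figure follows directly from the tubelet and patch sizes:

```python
# Token-grid arithmetic for the table above
T, H, W = 16, 256, 256   # input frames and spatial size
tubelet, patch = 2, 16   # temporal tubelet and spatial patch size
tokens = (T // tubelet) * (H // patch) * (W // patch)
print(tokens)            # 8 * 16 * 16 = 2048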
Channel layout (fixed, all 11 channels are inputs in baseline routing):
| idx | channel | meaning |
|---|---|---|
| 0 | phi | concentration |
| 1, 2 | u_1, u_2 | velocity |
| 3–6 | D_11, D_12, D_21, D_22 | orientation tensor |
| 7–10 | E_11, E_12, E_21, E_22 | strain-rate tensor |
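For illustration, one way to assemble an input tensor in this channel order. The array names and shapes below are assumptions, not The Well's documented schema; real runs should use the dataset reader vendored in the source repo:

```python
import numpy as np
import torch

# Hypothetical sketch: build a (1, 11, 16, 256, 256) input in the channel
# order of the table above. phi, u, D, E stand in for arrays read from an
# active_matter HDF5 file; their names and shapes here are assumptions.
T, H, W = 16, 256, 256
phi = np.zeros((T, H, W))        # concentration, channel 0
u   = np.zeros((T, H, W, 2))     # velocity, channels 1-2
D   = np.zeros((T, H, W, 2, 2))  # orientation tensor, channels 3-6
E   = np.zeros((T, H, W, 2, 2))  # strain-rate tensor, channels 7-10

x = np.concatenate([
    phi[..., None],              # (T, H, W, 1)
    u,                           # (T, H, W, 2)
    D.reshape(T, H, W, 4),       # row-major: D_11, D_12, D_21, D_22
    E.reshape(T, H, W, 4),       # row-major: E_11, E_12, E_21, E_22
], axis=-1)                      # (T, H, W, 11)

x = torch.from_numpy(x).permute(3, 0, 1, 2)[None].float()  # (1, 11, 16, 256, 256)
```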
## Training recipe
| Setting | Value |
|---|---|
| dataset | The Well — `active_matter` (45 train / 16 valid / 21 test HDF5 files, ~49 GB) |
| objective | JEPA: ‖predictor(f(ctx)) − sg(f_tgt(tgt))‖² + λ · VICReg(f(ctx)) |
| target encoder | BYOL-style EMA of the online encoder, decay 0.996 |
| regularizer | VICReg (variance hinge weight 25, off-diagonal cov weight 1.0), outer scale λ = 0.01 |
| optimizer | AdamW (β = (0.9, 0.999)), cosine LR, weight-decay schedule cosine 0.05 → 0.4, bias/norm excluded |
| batch size | 2 |
| learning rate | 3e-4, warmup 2 epochs, decay to 1e-6 over 30 epochs (10,500 steps) |
| precision | AMP fp16 on a single NVIDIA RTX 4070 SUPER |
| seed | 0 (recorded in config.json) |
| wall time | 46.8 min |
The full merged config is shipped in `config.json`; the headline eval numbers are in `eval_results.json`.
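To make the objective, EMA, and regularizer rows above concrete, here is a minimal sketch under stated assumptions: `encoder` and `predictor` are modules like those used in training (the predictor is not shipped in this repo, and its real interface likely also takes target positions), and `ctx` / `tgt` are context and target views of a batch. This is a sketch, not the project's trainer:

```python
import copy
import torch
import torch.nn.functional as F

# Frozen EMA twin of the online encoder (BYOL-style target encoder).
target_encoder = copy.deepcopy(encoder)
for p in target_encoder.parameters():
    p.requires_grad_(False)

def vicreg_reg(z, var_w=25.0, cov_w=1.0, eps=1e-4):
    """Variance hinge + off-diagonal covariance penalty on (N, D) features."""
    z = z - z.mean(dim=0)
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = F.relu(1.0 - std).mean()          # hinge: push per-dim std toward >= 1
    cov = (z.T @ z) / (z.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diag(cov)) # zero out the diagonal
    cov_loss = (off_diag ** 2).sum() / z.shape[1]
    return var_w * var_loss + cov_w * cov_loss

def jepa_loss(ctx, tgt, lam=0.01):
    z_ctx = encoder(ctx)                 # (B, N, 384) online features
    with torch.no_grad():                # sg(.): stop-gradient on the target branch
        z_tgt = target_encoder(tgt)
    pred = predictor(z_ctx)              # simplified predictor call
    return F.mse_loss(pred, z_tgt) + lam * vicreg_reg(z_ctx.flatten(0, 1))

@torch.no_grad()
def ema_update(decay=0.996):             # applied once per optimizer step
    for p_t, p_o in zip(target_encoder.parameters(), encoder.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1 - decay)
```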
## Files in this repo
| File | Contents |
|---|---|
| `encoder.pt` | online encoder state_dict (149 tensors, ~97 MB fp32) |
| `target_encoder_ema.pt` | EMA target encoder state_dict (same architecture, EMA-averaged weights — usable as an alternate frozen feature extractor) |
| `model.py` | self-contained MIT-licensed encoder definition. Standard 3D-tubelet ViT + sincos position embedding; matches the state_dict layout of the shipped weights bit-for-bit. No dependencies beyond torch and numpy. |
| `config.json` | full merged training config |
| `eval_results.json` | linear-probe and kNN test MSE on α and ζ + the train-fit z-score statistics needed to invert predictions |
| `final.json` | wall time, parameter counts, final losses |
| `PROVENANCE.txt` | which checkpoint was packaged + step/epoch |
## How to load
```python
import torch, json
from model import build_encoder

cfg = json.load(open("config.json"))
m = cfg["model"]["encoder"]
encoder = build_encoder(
    in_chans=11,
    size=cfg["model"]["encoder_size"],  # "small"
    img_size=m["img_size"],             # 256
    patch_size=m["patch_size"],         # 16
    num_frames=m["num_frames"],         # 16
    tubelet_size=m["tubelet_size"],     # 2
    mlp_ratio=m["mlp_ratio"],           # 4.0
)
encoder.load_state_dict(torch.load("encoder.pt", map_location="cpu"))
encoder.eval()

# x: (B, 11, 16, 256, 256) — 16-frame windows of the 11-channel active_matter fields
with torch.no_grad():
    z = encoder(x)          # (B, 2048, 384) — per-token features
    pooled = z.mean(dim=1)  # (B, 384) — mean-pooled, ready for a linear probe
```
To swap in the EMA target encoder, replace the load line with `torch.load("target_encoder_ema.pt", ...)` — same architecture, EMA-averaged weights.
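If a trajectory is longer than 16 frames, one simple featurization (an assumption about your workflow, not project code) is to slide the 16-frame window along the time axis and mean-pool per window:

```python
import torch

# Hypothetical helper (not project code): featurize one trajectory of shape
# (11, T_total, 256, 256) as a stack of non-overlapping 16-frame windows.
def window_features(encoder, traj, win=16, stride=16):
    feats = []
    with torch.no_grad():
        for t0 in range(0, traj.shape[1] - win + 1, stride):
            x = traj[:, t0:t0 + win][None]        # (1, 11, 16, 256, 256)
            feats.append(encoder(x).mean(dim=1))  # (1, 384) pooled per window
    return torch.cat(feats)                       # (num_windows, 384)
```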
To recover real-units predictions from the linear probe, undo the z-score using the `linear_probe.stats` block in `eval_results.json`:
```python
import json, numpy as np

stats = json.load(open("eval_results.json"))["linear_probe"]["stats"]
means = np.array(stats["means"])  # [alpha_mean, zeta_mean]
stds  = np.array(stats["stds"])   # [alpha_std, zeta_std]
# y_real = y_zscored * stds + means
```
## Intended use
- Frozen-encoder feature extraction on `active_matter`-distributed inputs (i.e. the same channel layout, frame count, and 256×256 spatial size). The encoder is shape-strict — the patch embed is 3D with (C=11, T=16, H=256, W=256).
- Linear / kNN probing for downstream regression on physical parameters of the same simulation family (see the kNN sketch below).
- Probe-development baseline for SSL methods on physics simulations.
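The kNN probe row from the headline table can be reproduced along these lines with scikit-learn (an assumed implementation; the project's own kNN code may differ), reusing the placeholder names from the ridge sketch above:

```python
from sklearn.neighbors import KNeighborsRegressor

# Illustrative kNN regression probe on pooled features; X_* / y_* are the
# placeholder arrays from the ridge sketch. k and the metric were selected
# on the val split (k=3, cosine, per the headline table).
knn = KNeighborsRegressor(n_neighbors=3, metric="cosine", algorithm="brute")
knn.fit(X_tr, y_tr)
print("test MSE:", ((knn.predict(X_te) - y_te) ** 2).mean())
```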
## Out of scope
- Finetuning: not the project setup; the results above were obtained with a frozen encoder. The optimizer / scheduler state is intentionally not shipped.
- Other channel counts or input sizes: the patch embed is hardcoded to 11 input channels and a 256×256 / 16-frame window. Different inputs need a different stem.
- Other physics datasets in The Well: trained only on `active_matter`; transfer is untested.
- Image-classification-style use: this is a 3D spatio-temporal encoder for scalar/tensor field stacks, not RGB images.
## Limitations and reproducibility caveats
- Non-bit-deterministic: the source repo seeds Python/numpy/torch CPU and CUDA RNG and records them per-run, but does not set `torch.backends.cudnn.deterministic = True` or `torch.use_deterministic_algorithms(True)`. Combined with AMP fp16, two re-runs land within a small numerical window but are not bit-identical across machines. (A general determinism recipe is sketched after this list.)
- Single-seed: the project champion is one seed (`seed: 0`). No multi-seed variance band.
- Probe-only evaluation: α and ζ are recovered via a single `Linear` layer (closed-form ridge with regularization swept on val) and a non-parametric kNN regressor (k and metric swept on val). No MLP head, no finetuning.
- 104-sample test set: probe MSEs sit on a small evaluation split. Treat sub-percent differences between top cells as within-noise.
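For re-runs that need tighter determinism, the standard PyTorch switches the repo leaves unset look like this (a general PyTorch recipe, not project code; expect slower kernels and hard errors on ops without deterministic implementations):

```python
import os
import torch

# General PyTorch determinism recipe (not part of this repo's training code).
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"  # needed for deterministic cuBLAS
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
torch.use_deterministic_algorithms(True)           # raises on nondeterministic ops
```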
## Source code, training pipeline, and full project context
This checkpoint comes from a wider ablation study: 34 cells across routing × backbone × target × loss. The source repo contains the trainer, encoder/predictor/loss code, eval pipeline, all configs, the per-run inventory, and the physics-validation scripts for the derived-field kernels.
See the project repo for the full REFACTORED_CODEBASE/, including README.md (with a reproducibility audit section), results/RUN_INVENTORY.md (every cell's metrics and hyperparameters), and results/PARETO.md (per-metric top-5s and the Pareto frontier).
## Acknowledgments
The shipped `model.py` is a clean-room minimal implementation of standard architecture components — it is licensed under MIT alongside the weights and has no upstream code dependencies. The training pipeline that produced these weights (in the source repo) does vendor several upstream codebases:
- V-JEPA (facebookresearch/jepa) — ViT encoder, predictor, 3D tubelet patch embed, optimizer/schedule pattern. (Subject to its own upstream license; see the source repo's `ENV.md`.)
- physical-representation-learning (helenqu/physical-representation-learning) — `WellDatasetForJEPA` HDF5 reader, run/config layout.
- VICReg (Bardes, Ponce, LeCun. ICLR 2022. arXiv:2105.04906) — variance-invariance-covariance loss, re-implemented in this project.
- BYOL (Grill et al. NeurIPS 2020. arXiv:2006.07733) — EMA target-encoder pattern.
- The Well — the `active_matter` simulation dataset.
Architecture references for the `model.py` reimplementation:
- ViT trunk: Dosovitskiy et al. (2020). An Image is Worth 16x16 Words. arXiv:2010.11929
- 3D tubelet patch embedding: Arnab et al. (2021). ViViT: A Video Vision Transformer. arXiv:2103.15691
- Sincos position embedding: Vaswani et al. (2017). Attention is All You Need. arXiv:1706.03762