D-JEPA · baseline+vit+ema+vicreg_lam001

A ViT-small encoder trained from scratch with a JEPA objective on The Well's active_matter simulation, shipped here as a frozen feature extractor. The encoder produces patch-token embeddings from (11, 16, 256, 256) (C, T, H, W) snippets of the active-matter fields; mean-pooled features are linearly informative about the two underlying physical parameters (active dipole strength α and steric alignment ζ) without any labeled supervision during pretraining.

This is the project champion out of 34 trained-and-evaluated cells across a 4-axis design space (routing × backbone × target × loss). Two checkpoints are included: the online encoder and its BYOL-style EMA twin.

Headline frozen-encoder probe scores

Test MSE on z-scored targets — a constant-mean predictor scores 1.0 by construction; lower is better.

| Metric | Test MSE | Selection on val |
|---|---|---|
| α linear probe (closed-form ridge) | 0.0063 | best ridge α = 1.0 |
| α kNN regression | 0.0147 | k = 3, cosine |
| ζ linear probe (closed-form ridge) | 0.0680 | best ridge α = 1.0 |
| ζ kNN regression | 0.1017 | k = 3, cosine |

Train / val / test split sizes: 700 / 96 / 104 trajectories. Per-target z-score statistics are fit on the train split only and applied unchanged to val/test.
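
For concreteness, the linear-probe protocol (closed-form ridge on mean-pooled features, with the ridge strength swept on val) can be sketched as below. The feats_* / y_* arrays and the alpha grid are illustrative assumptions, not the project's actual eval code.

import numpy as np

# Minimal closed-form ridge probe on frozen, mean-pooled features.
# feats_train/feats_val/feats_test: (N, 384) arrays; y_*: z-scored targets.
def ridge_fit(X, y, alpha):
    d = X.shape[1]
    # w = (X^T X + alpha * I)^{-1} X^T y
    return np.linalg.solve(X.T @ X + alpha * np.eye(d), X.T @ y)

def mse(X, y, w):
    return float(np.mean((X @ w - y) ** 2))

alphas = [1e-2, 1e-1, 1.0, 10.0]  # hypothetical sweep; selection happens on val
best = min(alphas, key=lambda a: mse(feats_val, y_val, ridge_fit(feats_train, y_train, a)))
w = ridge_fit(feats_train, y_train, best)
print("test MSE:", mse(feats_test, y_test, w))  # constant-mean baseline scores ~1.0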

Architecture

Standard ViT-small with V-JEPA-style 3D tubelet patch embedding:

| Component | Value |
|---|---|
| backbone | ViT-small (depth 12, 6 heads, MLP ratio 4.0, no dropout) |
| input shape | (B, 11, 16, 256, 256), i.e. (B, C, T, H, W) |
| patch size (spatial) | 16 × 16 |
| tubelet size (temporal) | 2 |
| token grid | (8 frames after tubelet) × (16 × 16 spatial patches) = 2048 tokens |
| embed dim | 384 |
| encoder params | 23,457,408 (~23.5 M) |
| total params (with predictor) | 34,400,640, well under a 100 M parameter budget |
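
The 2048-token count follows directly from the patching arithmetic:

num_frames, tubelet_size = 16, 2
img_size, patch_size = 256, 16
tokens = (num_frames // tubelet_size) * (img_size // patch_size) ** 2
# 8 temporal slots × 16 × 16 spatial patches = 2048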

Channel layout (fixed, all 11 channels are inputs in baseline routing):

| idx | channel | meaning |
|---|---|---|
| 0 | phi | concentration |
| 1, 2 | u_1, u_2 | velocity |
| 3–6 | D_11, D_12, D_21, D_22 | orientation tensor |
| 7–10 | E_11, E_12, E_21, E_22 | strain-rate tensor |
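
When preparing inputs, a small index map keeps channel slicing explicit. This map is a convenience sketch, not part of the shipped model.py:

# Hypothetical helper reflecting the channel table above.
CHANNELS = {
    "phi": 0,              # concentration
    "u":   slice(1, 3),    # u_1, u_2 (velocity)
    "D":   slice(3, 7),    # D_11, D_12, D_21, D_22 (orientation tensor)
    "E":   slice(7, 11),   # E_11, E_12, E_21, E_22 (strain-rate tensor)
}
# e.g. velocity = x[:, CHANNELS["u"]]  # (B, 2, 16, 256, 256)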

Training recipe

| Setting | Value |
|---|---|
| dataset | The Well — active_matter (45 train / 16 valid / 21 test HDF5 files, ~49 GB) |
| objective | JEPA: ‖predictor(f(ctx)) − sg(f_tgt(tgt))‖² + λ · VICReg(f(ctx)) |
| target encoder | BYOL-style EMA of the online encoder, decay 0.996 |
| regularizer | VICReg (variance hinge weight 25, off-diagonal covariance weight 1.0), outer scale λ = 0.01 |
| optimizer | AdamW (β = (0.9, 0.999)), cosine LR; cosine weight-decay schedule 0.05 → 0.4, bias/norm excluded |
| batch size | 2 |
| learning rate | 3e-4, 2-epoch warmup, cosine decay to 1e-6 over 30 epochs (10,500 steps) |
| precision | AMP fp16 on a single NVIDIA RTX 4070 SUPER |
| seed | 0 (recorded in config.json) |
| wall time | 46.8 min |
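
Schematically, one training step of the objective above looks like the following sketch. The predictor, the ctx/tgt views, and all names are stand-ins for the repo's actual modules; the VICReg weights follow the table.

import torch

def vicreg(z, var_w=25.0, cov_w=1.0, eps=1e-4):
    # Variance hinge + off-diagonal covariance penalty (no invariance term;
    # the JEPA prediction loss plays that role here).
    z = z.reshape(-1, z.shape[-1])
    std = z.var(dim=0).add(eps).sqrt()
    var_loss = torch.relu(1.0 - std).mean()
    zc = z - z.mean(dim=0)
    cov = (zc.T @ zc) / (zc.shape[0] - 1)
    cov_loss = (cov.pow(2).sum() - cov.diagonal().pow(2).sum()) / z.shape[-1]
    return var_w * var_loss + cov_w * cov_loss

def jepa_loss(encoder, target_encoder, predictor, ctx, tgt, lam=0.01):
    z_ctx = encoder(ctx)                  # online features of the context view
    with torch.no_grad():                 # sg(·): stop-gradient on the target branch
        z_tgt = target_encoder(tgt)
    pred = (predictor(z_ctx) - z_tgt).pow(2).mean()
    return pred + lam * vicreg(z_ctx)

@torch.no_grad()
def ema_update(online, target, decay=0.996):
    # BYOL-style EMA step, applied after each optimizer step.
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1.0 - decay)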

The full merged config is shipped in config.json; the headline eval numbers are in eval_results.json.

Files in this repo

| File | Contents |
|---|---|
| encoder.pt | online encoder state_dict (149 tensors, ~97 MB fp32) |
| target_encoder_ema.pt | EMA target-encoder state_dict (same architecture, EMA-averaged weights; usable as an alternate frozen feature extractor) |
| model.py | self-contained, MIT-licensed encoder definition: standard 3D-tubelet ViT with sincos position embedding; matches the state_dict layout of the shipped weights bit-for-bit; no dependencies beyond torch and numpy |
| config.json | full merged training config |
| eval_results.json | linear-probe and kNN test MSE on α and ζ, plus the train-fit z-score statistics needed to invert predictions |
| final.json | wall time, parameter counts, final losses |
| PROVENANCE.txt | which checkpoint was packaged, plus its step/epoch |

How to load

import torch, json
from model import build_encoder

cfg = json.load(open("config.json"))
m = cfg["model"]["encoder"]

encoder = build_encoder(
    in_chans=11,
    size=cfg["model"]["encoder_size"],   # "small"
    img_size=m["img_size"],              # 256
    patch_size=m["patch_size"],          # 16
    num_frames=m["num_frames"],          # 16
    tubelet_size=m["tubelet_size"],      # 2
    mlp_ratio=m["mlp_ratio"],            # 4.0
)

encoder.load_state_dict(torch.load("encoder.pt", map_location="cpu"))
encoder.eval()

# x: (B, 11, 16, 256, 256) — 16-frame windows of the 11-channel active_matter fields
with torch.no_grad():
    z = encoder(x)            # (B, 2048, 384) — per-token features
    pooled = z.mean(dim=1)    # (B, 384)       — mean-pooled, ready for a linear probe

To swap in the EMA target encoder, replace the load line with torch.load("target_encoder_ema.pt", ...) — same architecture, EMA-averaged weights.
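
Concretely:

encoder.load_state_dict(torch.load("target_encoder_ema.pt", map_location="cpu"))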

To recover real-units predictions from the linear probe, undo the z-score using the linear_probe.stats block in eval_results.json:

import json, numpy as np
stats = json.load(open("eval_results.json"))["linear_probe"]["stats"]
means = np.array(stats["means"])  # [alpha_mean, zeta_mean]
stds  = np.array(stats["stds"])   # [alpha_std,  zeta_std]
# y_real = y_zscored * stds + means
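
Similarly, the kNN rows can be reproduced in spirit with a standard regressor; using scikit-learn here is an assumption (the repo's own implementation may differ), and feats_* / y_* are the same assumed arrays as in the ridge sketch above.

from sklearn.neighbors import KNeighborsRegressor

# k=3 with cosine distance — the settings selected on val.
knn = KNeighborsRegressor(n_neighbors=3, metric="cosine")
knn.fit(feats_train, y_train)
test_mse = ((knn.predict(feats_test) - y_test) ** 2).mean()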

Intended use

  • Frozen-encoder feature extraction on active_matter-distributed inputs (i.e. the same channel layout, frame count, and 256×256 spatial size). The encoder is shape-strict — the patch embed is 3D (C=11, T=16, H=256, W=256).
  • Linear / kNN probing for downstream regression on physical parameters of the same simulation family.
  • Probe-development baseline for SSL methods on physics simulations.

Out of scope

  • Finetuning: not the project setup; the results above were obtained with a frozen encoder. The optimizer / scheduler state is intentionally not shipped.
  • Other channel counts or input sizes: the patch embed is hardcoded to 11 input channels and a 256×256 / 16-frame window. Different inputs need a different stem.
  • Other physics datasets in The Well: trained only on active_matter; transfer is untested.
  • Image-classification-style use: this is a 3D spatio-temporal encoder for scalar/tensor field stacks, not RGB images.

Limitations and reproducibility caveats

  • Non-bit-deterministic: the source repo seeds Python/numpy/torch CPU and CUDA RNG and records them per-run, but does not set torch.backends.cudnn.deterministic = True or torch.use_deterministic_algorithms(True). Combined with AMP fp16, two re-runs land within a small numerical window but are not bit-identical across machines.
  • Single-seed: project champion is one seed (seed: 0). No multi-seed variance band.
  • Probe-only evaluation: α and ζ are recovered via a single Linear layer (closed-form ridge with regularization swept on val) and a non-parametric kNN regressor (k and metric swept on val). No MLP head, no finetuning.
  • 104-sample test set: probe MSEs sit on a small evaluation split. Treat sub-percent differences between top cells as within-noise.

Source code, training pipeline, and full project context

This checkpoint comes from a wider ablation study: 34 cells across routing × backbone × target × loss. The source repo contains the trainer, encoder/predictor/loss code, eval pipeline, all configs, the per-run inventory, and the physics-validation scripts for the derived-field kernels.

See the project repo for the full REFACTORED_CODEBASE/, including README.md (with a reproducibility audit section), results/RUN_INVENTORY.md (every cell's metrics and hyperparameters), and results/PARETO.md (per-metric top-5s and the Pareto frontier).

Acknowledgments

The shipped model.py is a clean-room minimal implementation of standard architecture components — it is licensed under MIT alongside the weights and has no upstream code dependencies. The training pipeline that produced these weights (in the source repo) does vendor several upstream codebases:

  • V-JEPA (facebookresearch/jepa) — ViT encoder, predictor, 3D tubelet patch embed, optimizer/schedule pattern. (Subject to its own upstream license; see the source repo's ENV.md.)
  • physical-representation-learning (helenqu/physical-representation-learning) — WellDatasetForJEPA HDF5 reader, run/config layout.
  • VICReg (Bardes, Ponce, LeCun. ICLR 2022. arXiv:2105.04906) — variance-invariance-covariance loss, re-implemented in this project.
  • BYOL (Grill et al. NeurIPS 2020. arXiv:2006.07733) — EMA target-encoder pattern.
  • The Well — the active_matter simulation dataset.

Architecture references for the model.py reimplementation:

  • ViT trunk: Dosovitskiy et al. (2020). An Image is Worth 16x16 Words. arXiv:2010.11929
  • 3D tubelet patch embedding: Arnab et al. (2021). ViViT: A Video Vision Transformer. arXiv:2103.15691
  • Sincos position embedding: Vaswani et al. (2017). Attention is All You Need. arXiv:1706.03762