ViT3D-d6 / VICReg / FFT: active_matter (epoch 29)
A 6-block 3D Vision Transformer pretrained with VICReg in a JEPA-style setup
on the active_matter
dataset from The Well. This checkpoint is the encoder weights at pretrain
epoch 29, the best-validation epoch in our sweep.
The encoder maps a (B, 11, 16, 256, 256) input (11 channels, 16 frames,
256×256 pixels) to a (B, 384, 16, 16) feature map. Linear and k-NN probes on top of those frozen
features regress the active-matter parameters $\alpha$ (alignment strength)
and $\zeta$ (active stress).
Architecture
| Component | Spec |
|---|---|
| Input | $(B, 11, 16, 256, 256)$ |
| 3D PatchEmbed | Conv3d(11 → 384, kernel = stride = 4×16×16) |
| Tokens | $T'{\times}H'{\times}W' = 4{\times}16{\times}16 = 1024$ |
| Transformer blocks | 6 × pre-norm, $h{=}6$, MLP ratio 4, QKV bias |
| Pos. embedding | learnable, $1024 \times 384$ |
| Output | $(B, 384, 16, 16)$, time collapsed by mean over $T'$ |
| Total params | ≈ 15.4 M |
The companion ConvPredictor used during JEPA pretraining (≈ 7.1 M parameters) is
not included; only the encoder is published.
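
The token count and the total of about 15.4 M parameters in the table can be reproduced by hand. The sketch below assumes a plain ViT layout (fused QKV projection with bias, two LayerNorms per block, one final LayerNorm, no class token). `vit3d_param_count` is a hypothetical helper written for this illustration, not part of `physics_jepa`, and the real encoder may differ in small details.

```python
def vit3d_param_count(in_chans=11, frames=16, img=(256, 256),
                      patch=(4, 16, 16), dim=384, depth=6, mlp_ratio=4):
    """Analytic parameter count for an assumed plain ViT-style 3D encoder."""
    pt, ph, pw = patch
    # Token grid produced by the strided Conv3d patch embedding.
    grid = (frames // pt, img[0] // ph, img[1] // pw)
    n_tokens = grid[0] * grid[1] * grid[2]

    patch_embed = in_chans * pt * ph * pw * dim + dim  # Conv3d weight + bias
    pos_embed = n_tokens * dim                         # learnable position table
    hidden = int(dim * mlp_ratio)
    block = (
        2 * dim                      # LayerNorm before attention
        + dim * 3 * dim + 3 * dim    # fused QKV projection with bias
        + dim * dim + dim            # attention output projection
        + 2 * dim                    # LayerNorm before the MLP
        + dim * hidden + hidden      # MLP up-projection
        + hidden * dim + dim         # MLP down-projection
    )
    final_norm = 2 * dim             # assumed final LayerNorm
    total = patch_embed + pos_embed + depth * block + final_norm
    return grid, n_tokens, total


grid, n_tokens, total = vit3d_param_count()
print(grid, n_tokens, f"{total / 1e6:.2f} M")  # (4, 16, 16) 1024 15.37 M
```

The number of heads does not appear in the count because splitting the 384 channels into 6 heads does not change the size of the projections. About 4.3 M of the total sits in the patch embedding, which reflects how large each 4×16×16 patch over 11 channels is.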
Training recipe
| Setting | Value |
|---|---|
| Objective | JEPA + VICReg ($\lambda_{\text{sim}}{=}2,\ \lambda_{\text{std}}{=}40,\ \lambda_{\text{cov}}{=}2$) |
| Preprocessing | band-limited FFT resize (preserves periodic BCs) |
| Optimiser | AdamW, lr $5\!\times\!10^{-4}$, wd $0.05$, cosine, 3 warmup epochs |
| Precision | bf16 |
| Batch size | 4 |
| Epochs | 30 (this checkpoint = epoch 29) |
| Hardware | single A100 |
Evaluation: frozen-encoder probes on held-out test split
Lower MSE is better.
| Probe | mean MSE $\downarrow$ | $\alpha$ MSE $\downarrow$ | $\zeta$ MSE $\downarrow$ | $k$ / metric |
|---|---|---|---|---|
| Linear | 0.107 | 0.016 | 0.197 | n/a |
| k-NN | 0.120 | 0.009 | 0.231 | $k=20$, cosine |
Every CNN variant we tried (VICReg baseline, VICReg+FFT, Conv+Attn, Conv+Attn×6) lands at linear MSE 0.22–0.27 on the same data; the 3D-patch tokeniser is the change that brings the linear-probe MSE down to 0.107, a roughly 2 to 2.5× improvement.
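
The two probes can be sketched in a few lines of NumPy. This is an illustration under stated assumptions, not the evaluation code used for the table: it assumes the `(N, C, H, W)` feature map is mean-pooled over space before probing, uses closed-form ridge regression as the linear probe, and averages the targets of the `k` most cosine-similar training samples for k-NN. All function names are hypothetical, and the data is synthetic with small shapes to keep it fast.

```python
import numpy as np


def pool(feat):
    """(N, C, H, W) feature map -> (N, C) by spatial mean (an assumption)."""
    return feat.mean(axis=(2, 3))


def fit_linear_probe(x, y, ridge=1e-3):
    """Closed-form ridge regression with a bias column; returns (C + 1, P)."""
    xb = np.hstack([x, np.ones((len(x), 1))])
    gram = xb.T @ xb + ridge * np.eye(xb.shape[1])
    return np.linalg.solve(gram, xb.T @ y)


def predict_linear(w, x):
    return np.hstack([x, np.ones((len(x), 1))]) @ w


def knn_predict(x_train, y_train, x_test, k=20):
    """Mean target of the k nearest training points under cosine similarity."""
    a = x_train / np.linalg.norm(x_train, axis=1, keepdims=True)
    b = x_test / np.linalg.norm(x_test, axis=1, keepdims=True)
    sim = b @ a.T                                       # (n_test, n_train)
    idx = np.argpartition(-sim, k - 1, axis=1)[:, :k]   # k most similar rows
    return y_train[idx].mean(axis=1)


def mse_per_target(pred, y):
    return ((pred - y) ** 2).mean(axis=0)


# Synthetic stand-in for frozen features and the two targets (alpha, zeta).
rng = np.random.default_rng(0)
x_train = pool(rng.normal(size=(256, 32, 4, 4)))
x_test = pool(rng.normal(size=(64, 32, 4, 4)))
w_true = rng.normal(size=(32, 2))
y_train = x_train @ w_true + 0.01 * rng.normal(size=(256, 2))
y_test = x_test @ w_true + 0.01 * rng.normal(size=(64, 2))

w = fit_linear_probe(x_train, y_train)
lin_mse = mse_per_target(predict_linear(w, x_test), y_test)
knn_mse = mse_per_target(knn_predict(x_train, y_train, x_test, k=20), y_test)
print("linear per-target MSE:", lin_mse, "mean:", lin_mse.mean())
print("k-NN   per-target MSE:", knn_mse, "mean:", knn_mse.mean())
```

The mean MSE column in the table is consistent with the average of the two per-target values, for example (0.016 + 0.197) / 2 ≈ 0.107 for the linear probe.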
Usage
```python
import torch
from huggingface_hub import hf_hub_download
from physics_jepa.utils.model_utils import ViT3DEncoder

# Download the epoch-29 encoder weights from the Hub.
ckpt_path = hf_hub_download(
    repo_id="szcharlesji/vit3d-d6-active-matter",
    filename="ViT3DEncoder_29.pth",
)

# Rebuild the encoder with the hyperparameters from the Architecture table.
encoder = ViT3DEncoder(
    in_chans=11, num_frames=16, img_size=(256, 256),
    patch_size=(4, 16, 16), embed_dim=384, depth=6, num_heads=6,
    mlp_ratio=4.0,
)
encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
encoder.eval()

# x: (B, 11, 16, 256, 256), float. Random placeholder; replace with real data.
x = torch.randn(1, 11, 16, 256, 256)
with torch.no_grad():
    feat = encoder(x)  # (B, 384, 16, 16)
```
Citation
```bibtex
@misc{ji2026jepa-active-matter,
  author = {Charles Cheng Ji and Zhanhe Shi and Richard Wang and Romina Yalovetzky},
  title  = {Physics-Aware Representation Learning for Physical Systems},
  year   = {2026},
}
```