---
license: mit
tags:
- jepa
- vicreg
- vit3d
- physics
- self-supervised
- representation-learning
datasets:
- polymathic-ai/active_matter
library_name: pytorch
---

# ViT3D-d6 / VICReg / FFT — `active_matter` (epoch 29)

A 6-block 3D Vision Transformer pretrained with VICReg in a JEPA-style setup on the [`active_matter`](https://polymathic-ai.org/the_well/datasets/active_matter/) dataset from The Well. This checkpoint contains the encoder weights at pretrain epoch 29, the best-validation epoch in our sweep.

The encoder produces a frozen `(B, 384, 16, 16)` feature map from a `(B, 11, 16, 256, 256)` input. Linear and k-NN probes on top of those frozen features regress the active-matter parameters $\alpha$ (alignment strength) and $\zeta$ (active stress).

## Architecture

| Component | Spec |
|---|---|
| Input | $(B, 11, 16, 256, 256)$ |
| 3D PatchEmbed | `Conv3d(11 → 384, kernel=stride=4×16×16)` |
| Tokens | $T'{\times}H'{\times}W' = 4{\times}16{\times}16 = 1024$ |
| Transformer blocks | 6 × pre-norm, $h{=}6$, MLP ratio 4, QKV bias |
| Pos. embedding | learnable, $1024 \times 384$ |
| Output | $(B, 384, 16, 16)$ (time collapsed by mean over $T'$) |
| **Total params** | **≈ 15.4 M** |

The companion `ConvPredictor` used during JEPA pretraining (≈ 7.1 M) is **not** included; only the encoder is published.

## Training recipe

| Setting | Value |
|---|---|
| Objective | JEPA + VICReg ($\lambda_{\text{sim}}{=}2,\ \lambda_{\text{std}}{=}40,\ \lambda_{\text{cov}}{=}2$) |
| Preprocessing | band-limited FFT resize (preserves periodic boundary conditions) |
| Optimiser | AdamW, lr $5\!\times\!10^{-4}$, wd $0.05$, cosine schedule, 3 warmup epochs |
| Precision | bf16 |
| Batch size | 4 |
| Epochs | 30 (this checkpoint = epoch 29) |
| Hardware | single A100 |

## Evaluation — frozen-encoder probes on the held-out test split

Lower MSE is better.

| Probe | Mean MSE $\downarrow$ | $\alpha$ MSE $\downarrow$ | $\zeta$ MSE $\downarrow$ | $k$ / metric |
|---|---|---|---|---|
| Linear | **0.107** | **0.016** | **0.197** | — |
| k-NN | **0.120** | **0.009** | **0.231** | $k=20$, cosine |

Every CNN variant we tried (VICReg baseline, VICReg+FFT, Conv+Attn, Conv+Attn×6) lands at a linear-probe MSE of 0.22–0.27 on the same data; the 3D-patch tokeniser is the change that unlocks the 2–2.5× improvement.

## Usage

```python
import torch
from huggingface_hub import hf_hub_download

from physics_jepa.utils.model_utils import ViT3DEncoder

ckpt_path = hf_hub_download(
    repo_id="szcharlesji/vit3d-d6-active-matter",
    filename="ViT3DEncoder_29.pth",
)

encoder = ViT3DEncoder(
    in_chans=11,
    num_frames=16,
    img_size=(256, 256),
    patch_size=(4, 16, 16),
    embed_dim=384,
    depth=6,
    num_heads=6,
    mlp_ratio=4.0,
)
encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
encoder.eval()

x = torch.randn(1, 11, 16, 256, 256)  # replace with real data: (B, 11, 16, 256, 256), float
with torch.no_grad():
    feat = encoder(x)  # (B, 384, 16, 16)
```

Illustrative sketches of the tokeniser, the FFT preprocessing, the VICReg objective, and the k-NN probe appear after the citation below.

## Citation

```bibtex
@misc{ji2026jepa-active-matter,
  author = {Charles Cheng Ji and Zhanhe Shi and Richard Wang and Romina Yalovetzky},
  title  = {Physics-Aware Representation Learning for Physical Systems},
  year   = {2026},
}
```
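
## Sketch: 3D patch tokenisation

The shape arithmetic in the Architecture table can be verified directly: a strided `Conv3d` with kernel = stride = (4, 16, 16) turns a $(11, 16, 256, 256)$ input into a $4{\times}16{\times}16$ token grid. This standalone snippet is illustrative (it is not taken from `physics_jepa`) and reproduces just that step:

```python
import torch
import torch.nn as nn

# Non-overlapping 3D patches: kernel == stride, so 16/4 = 4 temporal
# and 256/16 = 16 spatial positions survive.
patch_embed = nn.Conv3d(11, 384, kernel_size=(4, 16, 16), stride=(4, 16, 16))

x = torch.randn(2, 11, 16, 256, 256)  # (B, C, T, H, W)
tokens = patch_embed(x)               # (2, 384, 4, 16, 16)
print(tokens.flatten(2).shape)        # (2, 384, 1024) -> 1024 tokens of dim 384
```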
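
## Sketch: band-limited FFT resize

The preprocessing row in the training recipe refers to resizing fields in Fourier space, which respects the dataset's periodic boundary conditions in a way that spatial interpolation would not. The exact transform lives in the training code; the following is a minimal sketch, assuming plain spectral truncation over the last two dimensions, and the function name `fft_downsample` is ours:

```python
import torch

def fft_downsample(x: torch.Tensor, out_hw: tuple) -> torch.Tensor:
    """Band-limited downsample of the last two (periodic) dims of a real field.

    Sketch only: covers downsampling and ignores Nyquist-bin edge cases.
    """
    H, W = x.shape[-2:]
    h, w = out_hw
    assert h <= H and w <= W, "sketch covers downsampling only"
    # Centre the spectrum, keep only the low-frequency block, then invert.
    spec = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    top, left = (H - h) // 2, (W - w) // 2
    spec = spec[..., top:top + h, left:left + w]
    spec = torch.fft.ifftshift(spec, dim=(-2, -1))
    # fft2/ifft2 are unnormalised, so rescale for the changed grid size.
    return torch.fft.ifft2(spec).real * (h * w) / (H * W)

x = torch.randn(11, 16, 512, 512)           # e.g. a raw (C, T, H, W) snapshot
print(fft_downsample(x, (256, 256)).shape)  # (11, 16, 256, 256)
```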
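
## Sketch: VICReg objective

The loss weights listed in the training recipe ($\lambda_{\text{sim}}{=}2$, $\lambda_{\text{std}}{=}40$, $\lambda_{\text{cov}}{=}2$) plug into the standard VICReg combination of an invariance (MSE) term, a variance hinge, and a covariance penalty. This is a generic VICReg implementation with those weights as defaults, not the exact loss from `physics_jepa`; how the $(B, 384, 16, 16)$ feature maps are flattened into `(N, D)` embeddings is an assumption here:

```python
import torch
import torch.nn.functional as F

def vicreg_loss(z1, z2, lam_sim=2.0, lam_std=40.0, lam_cov=2.0, eps=1e-4):
    # z1, z2: (N, D) paired embeddings, e.g. predictor output vs. target features.
    # Default weights are the values listed on this card.
    n, d = z1.shape
    sim = F.mse_loss(z1, z2)                             # invariance term
    std_term, cov_term = 0.0, 0.0
    for z in (z1, z2):
        z = z - z.mean(dim=0)
        std = torch.sqrt(z.var(dim=0) + eps)
        std_term = std_term + F.relu(1.0 - std).mean()   # hinge: keep per-dim std >= 1
        cov = (z.T @ z) / (n - 1)
        off_diag = cov - torch.diag(torch.diag(cov))
        cov_term = cov_term + off_diag.pow(2).sum() / d  # decorrelate dimensions
    return lam_sim * sim + lam_std * std_term + lam_cov * cov_term

z1, z2 = torch.randn(4, 384), torch.randn(4, 384)
print(vicreg_loss(z1, z2))
```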
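
## Sketch: reproducing the k-NN probe

The probe numbers in the evaluation table come from frozen features. A close-but-not-identical reproduction with scikit-learn is sketched below; $k{=}20$ and the cosine metric come from the table, while the global average pooling of the $(B, 384, 16, 16)$ map to a 384-dim vector and the `train_loader`/`test_loader` objects are assumptions:

```python
import numpy as np
import torch
from sklearn.neighbors import KNeighborsRegressor

@torch.no_grad()
def pooled_features(encoder, loader):
    feats, targets = [], []
    for x, y in loader:  # x: (B, 11, 16, 256, 256); y: (B, 2) = (alpha, zeta)
        f = encoder(x)   # (B, 384, 16, 16)
        # NOTE: pooling choice is an assumption; the original probe head may differ.
        feats.append(f.mean(dim=(-2, -1)).cpu().numpy())  # global average pool -> (B, 384)
        targets.append(y.numpy())
    return np.concatenate(feats), np.concatenate(targets)

f_train, y_train = pooled_features(encoder, train_loader)
f_test, y_test = pooled_features(encoder, test_loader)

knn = KNeighborsRegressor(n_neighbors=20, metric="cosine")
knn.fit(f_train, y_train)
mse = ((knn.predict(f_test) - y_test) ** 2).mean(axis=0)  # per-target MSE
print({"alpha": float(mse[0]), "zeta": float(mse[1]), "mean": float(mse.mean())})
```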