---
license: mit
tags:
  - jepa
  - vicreg
  - vit3d
  - physics
  - self-supervised
  - representation-learning
datasets:
  - polymathic-ai/active_matter
library_name: pytorch
---

ViT3D-d6 / VICReg / FFT — active_matter (epoch 29)

A 6-block 3D Vision Transformer pretrained with VICReg in a JEPA-style setup on the active_matter dataset from The Well. This checkpoint holds the encoder weights from pretraining epoch 29, the epoch with the best validation result in our sweep.

Used as a frozen feature extractor, the encoder maps a (B, 11, 16, 256, 256) input (batch, channels, frames, height, width) to a (B, 384, 16, 16) feature map. Linear and k-NN probes trained on these frozen features regress the active-matter parameters $\alpha$ (alignment strength) and $\zeta$ (active stress).
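The shape bookkeeping above can be traced with a small NumPy sketch. It assumes the encoder tokenises like a standard Conv3d patch embedding, with tokens flattened in (T', H', W') order, and averages over T' after the transformer. The transformer blocks are omitted and the projection weights are random stand-ins, so this illustrates the shapes only. It is not the `ViT3DEncoder` implementation.

```python
import numpy as np

def patchify_3d(x, patch=(4, 16, 16)):
    """Split a (B, C, T, H, W) array into non-overlapping 3D patches.

    Returns (B, N, C*pt*ph*pw) with N = T'*H'*W' tokens ordered (T', H', W'),
    plus the token grid (T', H', W'). A Conv3d with kernel_size == stride == patch
    computes exactly this split followed by one shared linear projection (plus bias).
    """
    B, C, T, H, W = x.shape
    pt, ph, pw = patch
    Tp, Hp, Wp = T // pt, H // ph, W // pw
    x = x.reshape(B, C, Tp, pt, Hp, ph, Wp, pw)
    x = x.transpose(0, 2, 4, 6, 1, 3, 5, 7)          # grid axes first, patch contents last
    return x.reshape(B, Tp * Hp * Wp, C * pt * ph * pw), (Tp, Hp, Wp)

def tokens_to_map(tokens, grid):
    """Collapse (B, N, D) tokens to a (B, D, H', W') map by averaging over T'."""
    B, N, D = tokens.shape
    Tp, Hp, Wp = grid
    return tokens.reshape(B, Tp, Hp, Wp, D).mean(axis=1).transpose(0, 3, 1, 2)

rng = np.random.default_rng(0)
x = rng.standard_normal((1, 11, 16, 256, 256), dtype=np.float32)
patches, grid = patchify_3d(x)                       # (1, 1024, 11264), grid (4, 16, 16)
proj = 0.01 * rng.standard_normal((patches.shape[-1], 384), dtype=np.float32)  # stand-in weights
tokens = patches @ proj                              # (1, 1024, 384); transformer blocks omitted
feat = tokens_to_map(tokens, grid)                   # (1, 384, 16, 16)
print(grid, tokens.shape, feat.shape)
```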

Architecture

| Component | Spec |
|---|---|
| Input | $(B, 11, 16, 256, 256)$ |
| 3D PatchEmbed | Conv3d(11 → 384, kernel = stride = 4×16×16) |
| Tokens | $T'{\times}H'{\times}W' = 4{\times}16{\times}16 = 1024$ |
| Transformer blocks | 6 × pre-norm, $h{=}6$ heads, MLP ratio 4, QKV bias |
| Pos. embedding | learnable, $1024 \times 384$ |
| Output | $(B, 384, 16, 16)$, time collapsed by mean over $T'$ |
| Total params | ≈ 15.4 M |
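The ≈ 15.4 M total can be checked by hand. The sketch below assumes a timm-style pre-norm ViT block, consisting of two LayerNorms, a fused QKV projection with bias, an output projection, and a two-layer MLP. It also assumes there is no class token and that a final LayerNorm follows the last block. These layout details are assumptions and have not been read from the `ViT3DEncoder` source.

```python
def vit3d_param_count(in_chans=11, patch=(4, 16, 16), num_tokens=1024,
                      dim=384, depth=6, mlp_ratio=4, final_norm=True):
    """Count parameters of a plain pre-norm 3D ViT encoder (assumed layout)."""
    pt, ph, pw = patch
    patch_embed = in_chans * pt * ph * pw * dim + dim   # Conv3d weight + bias
    pos_embed = num_tokens * dim                        # learnable position table
    hidden = int(dim * mlp_ratio)
    block = (
        2 * dim                       # LayerNorm before attention (weight + bias)
        + dim * 3 * dim + 3 * dim     # fused QKV projection with bias
        + dim * dim + dim             # attention output projection
        + 2 * dim                     # LayerNorm before the MLP
        + dim * hidden + hidden       # MLP fc1
        + hidden * dim + dim          # MLP fc2
    )
    norm = 2 * dim if final_norm else 0                 # optional final LayerNorm
    return patch_embed + pos_embed + depth * block + norm

total = vit3d_param_count()
print(f"{total:,} parameters (~{total / 1e6:.1f} M)")
```

Under these assumptions the count is 15,366,528. The Conv3d patch embedding alone contributes about 4.3 M of that (roughly 28 %), because each token is projected from an 11 × 4 × 16 × 16 = 11,264-dimensional patch.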

The companion ConvPredictor used during JEPA pretraining (≈ 7.1 M parameters) is not included. Only the encoder is published.

Training recipe

| Setting | Value |
|---|---|
| Objective | JEPA + VICReg ($\lambda_{\text{sim}}{=}2,\ \lambda_{\text{std}}{=}40,\ \lambda_{\text{cov}}{=}2$) |
| Preprocessing | band-limited FFT resize (preserves periodic BCs) |
| Optimiser | AdamW, lr $5\times10^{-4}$, wd $0.05$, cosine schedule, 3 warmup epochs |
| Precision | bf16 |
| Batch size | 4 |
| Epochs | 30 (this checkpoint = epoch 29) |
| Hardware | single A100 |
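The two less common rows of this recipe can be sketched in NumPy.

`vicreg_loss` follows the published VICReg formulation (Bardes et al.), using the normalisation of its reference implementation and the weights from the table. This card does not state which tensors act as the two branches in the JEPA setup (for instance predictor output versus target features). It also does not say how the (B, 384, 16, 16) maps are flattened into (N, D) matrices, so the flattening in the usage lines is an assumption.

`fft_resize` is a generic spectral crop/zero-pad resize that treats the field as periodic. The target resolution and the handling of the Nyquist bin are illustrative choices and may differ from the training code.

```python
import numpy as np

# --- VICReg objective (standard formulation; weights from the table above) ---

def off_diagonal(m):
    """Zero the diagonal of a square matrix."""
    return m - np.diag(np.diag(m))

def vicreg_loss(za, zb, lam_sim=2.0, lam_std=40.0, lam_cov=2.0, eps=1e-4):
    """VICReg loss between two (N, D) embedding matrices, one row per sample."""
    n, d = za.shape
    sim = np.mean((za - zb) ** 2)                       # invariance: MSE between branches
    za = za - za.mean(axis=0)
    zb = zb - zb.mean(axis=0)
    std_a = np.sqrt(za.var(axis=0, ddof=1) + eps)
    std_b = np.sqrt(zb.var(axis=0, ddof=1) + eps)
    # variance: hinge keeps every dimension's std above 1 (prevents collapse)
    std = np.mean(np.maximum(0.0, 1.0 - std_a)) / 2 + np.mean(np.maximum(0.0, 1.0 - std_b)) / 2
    cov_a = za.T @ za / (n - 1)
    cov_b = zb.T @ zb / (n - 1)
    # covariance: penalise off-diagonal entries to decorrelate dimensions
    cov = (off_diagonal(cov_a) ** 2).sum() / d + (off_diagonal(cov_b) ** 2).sum() / d
    return lam_sim * sim + lam_std * std + lam_cov * cov

# --- Band-limited resize of periodic fields via the 2D FFT ---

def fft_resize(x, out_hw):
    """Resize the last two axes by cropping or zero-padding the centred spectrum.

    Treats the field as periodic, so no boundary padding or interpolation
    kernel is involved. The Nyquist bin of even-sized grids is handled naively.
    """
    H, W = x.shape[-2:]
    h, w = out_hw
    X = np.fft.fftshift(np.fft.fft2(x, axes=(-2, -1)), axes=(-2, -1))
    out = np.zeros(x.shape[:-2] + (h, w), dtype=complex)
    mh, mw = min(H, h), min(W, w)                       # shared low-frequency block
    sh, sw = H // 2 - mh // 2, W // 2 - mw // 2         # block origin in the source
    dh, dw = h // 2 - mh // 2, w // 2 - mw // 2         # block origin in the target
    out[..., dh:dh + mh, dw:dw + mw] = X[..., sh:sh + mh, sw:sw + mw]
    y = np.fft.ifft2(np.fft.ifftshift(out, axes=(-2, -1)), axes=(-2, -1))
    return y.real * (h * w) / (H * W)                   # rescale so amplitudes are unchanged

# Usage with random stand-in data.
rng = np.random.default_rng(0)
# Assumption: each of the 16x16 positions in a batch of 4 maps is one VICReg sample.
pred = rng.standard_normal((4 * 16 * 16, 384))
target = pred + 0.1 * rng.standard_normal(pred.shape)
print("VICReg loss:", vicreg_loss(pred, target))

field = rng.standard_normal((11, 256, 256))             # one frame, 11 channels (stand-in)
print(fft_resize(field, (128, 128)).shape)              # (11, 128, 128)
```

Because the spectrum is only cropped or padded, a field that is already band-limited on the target grid is reproduced exactly. Periodicity is preserved by construction.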

Evaluation: frozen-encoder probes on the held-out test split

Lower MSE is better.

| Probe | mean MSE (avg. of $\alpha$, $\zeta$) $\downarrow$ | $\alpha$ MSE $\downarrow$ | $\zeta$ MSE $\downarrow$ | $k$ / metric |
|---|---|---|---|---|
| Linear | 0.107 | 0.016 | 0.197 | n/a |
| k-NN | 0.120 | 0.009 | 0.231 | $k=20$, cosine |

Every CNN variant we tried (VICReg baseline, VICReg+FFT, Conv+Attn, Conv+Attn×6) reaches a linear-probe MSE of 0.22–0.27 on the same data. The ViT3D encoder's 0.107 is roughly 2–2.5× lower, and we attribute this gain to the 3D-patch tokeniser.
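The two probes can be sketched in NumPy as follows. The card leaves three choices unstated: how the (B, 384, 16, 16) map is reduced before probing, whether $\alpha$ and $\zeta$ are standardised, and how the linear probe is fitted. The spatial average (`pool`), the closed-form ridge fit, and the unweighted neighbour average below are therefore assumptions. Only $k=20$ and the cosine metric come from the table. `mse_report` averages the per-parameter MSEs, matching the table, where each mean is the average of the $\alpha$ and $\zeta$ columns (e.g. (0.016 + 0.197) / 2 ≈ 0.107).

```python
import numpy as np

def pool(feat):
    """Hypothetical pooling: average the (B, 384, 16, 16) map over space -> (B, 384)."""
    return feat.mean(axis=(2, 3))

def fit_linear_probe(X, Y, ridge=1e-3):
    """Closed-form ridge regression with a bias column; Y is (n, 2) = [alpha, zeta]."""
    Xb = np.hstack([X, np.ones((len(X), 1))])
    A = Xb.T @ Xb + ridge * np.eye(Xb.shape[1])      # the bias is lightly regularised too
    return np.linalg.solve(A, Xb.T @ Y)

def predict_linear(Wb, X):
    return np.hstack([X, np.ones((len(X), 1))]) @ Wb

def knn_predict(X_train, Y_train, X_test, k=20):
    """k-NN regression: average the targets of the k most cosine-similar training samples."""
    a = X_train / np.linalg.norm(X_train, axis=1, keepdims=True)
    b = X_test / np.linalg.norm(X_test, axis=1, keepdims=True)
    sim = b @ a.T                                    # (n_test, n_train) cosine similarities
    idx = np.argsort(-sim, axis=1)[:, :k]            # indices of the k nearest neighbours
    return Y_train[idx].mean(axis=1)

def mse_report(Y_true, Y_pred):
    """Return (mean MSE, per-parameter MSE) as in the evaluation table."""
    per_param = ((Y_true - Y_pred) ** 2).mean(axis=0)
    return per_param.mean(), per_param

# Usage with synthetic stand-in features and targets.
rng = np.random.default_rng(0)
X_train, X_test = rng.standard_normal((1000, 384)), rng.standard_normal((100, 384))
w_true = rng.standard_normal((384, 2)) / np.sqrt(384)
Y_train, Y_test = X_train @ w_true, X_test @ w_true
lin_mean, lin_per = mse_report(Y_test, predict_linear(fit_linear_probe(X_train, Y_train), X_test))
knn_mean, knn_per = mse_report(Y_test, knn_predict(X_train, Y_train, X_test, k=20))
print(f"linear: {lin_mean:.4g}  k-NN: {knn_mean:.4g}")
```

With real data, `X` would be `pool(feat.numpy())` for encoder outputs obtained as in the Usage section, and `Y` the $(\alpha, \zeta)$ labels of each simulation.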

Usage

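```python
import torch
from huggingface_hub import hf_hub_download
from physics_jepa.utils.model_utils import ViT3DEncoder  # from the physics_jepa project code

# Fetch the epoch-29 encoder weights from the Hub.
ckpt_path = hf_hub_download(
    repo_id="szcharlesji/vit3d-d6-active-matter",
    filename="ViT3DEncoder_29.pth",
)

# Build the encoder with this checkpoint's architecture and load the weights.
encoder = ViT3DEncoder(
    in_chans=11, num_frames=16, img_size=(256, 256),
    patch_size=(4, 16, 16), embed_dim=384, depth=6, num_heads=6,
    mlp_ratio=4.0,
)
encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
encoder.eval()

# x: float tensor of shape (B, 11, 16, 256, 256) = (batch, channels, frames, height, width).
# Random data stands in for a real active_matter clip here.
x = torch.randn(2, 11, 16, 256, 256)
with torch.no_grad():
    feat = encoder(x)  # (B, 384, 16, 16)
```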

Citation

```bibtex
@misc{ji2026jepa-active-matter,
  author = {Charles Cheng Ji and Zhanhe Shi and Richard Wang and Romina Yalovetzky},
  title  = {Physics-Aware Representation Learning for Physical Systems},
  year   = {2026},
}
```