---
license: mit
tags:
- jepa
- vicreg
- vit3d
- physics
- self-supervised
- representation-learning
datasets:
- polymathic-ai/active_matter
library_name: pytorch
---

# ViT3D-d6 / VICReg / FFT on `active_matter` (epoch 29)

A 6-block 3D Vision Transformer pretrained with VICReg in a JEPA-style setup
on the [`active_matter`](https://polymathic-ai.org/the_well/datasets/active_matter/)
dataset from The Well. This checkpoint contains the encoder weights at pretrain
epoch 29, the best-validation epoch in our sweep.

The encoder produces a frozen `(B, 384, 16, 16)` feature map from a
`(B, 11, 16, 256, 256)` input. Linear and k-NN probes on top of those frozen
features regress the active-matter parameters $\alpha$ (alignment strength)
and $\zeta$ (active stress).

## Architecture

| Component | Spec |
|---|---|
| Input | $(B, 11, 16, 256, 256)$ |
| 3D PatchEmbed | `Conv3d(11 → 384, kernel=stride=4×16×16)` |
| Tokens | $T'{\times}H'{\times}W' = 4{\times}16{\times}16 = 1024$ |
| Transformer blocks | 6 × pre-norm, $h{=}6$, MLP ratio 4, QKV bias |
| Pos. embedding | learnable, $1024 \times 384$ |
| Output | $(B, 384, 16, 16)$; time collapsed by mean over $T'$ |
| **Total params** | **≈ 15.4 M** |

The companion `ConvPredictor` used during JEPA pretraining (≈ 7.1 M params) is
**not** included; only the encoder is published.

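The shape bookkeeping above can be checked with a minimal sketch of the 3D patch embedding. This is a hypothetical stand-in for illustration only, not the published `ViT3DEncoder` code:

```python
import torch
import torch.nn as nn

# Hypothetical sketch of the tokeniser described in the table:
# non-overlapping 4x16x16 spatio-temporal patches, embedded to dim 384.
patch_embed = nn.Conv3d(11, 384, kernel_size=(4, 16, 16), stride=(4, 16, 16))

x = torch.randn(2, 11, 16, 256, 256)   # (B, C, T, H, W)
tok = patch_embed(x)                   # (B, 384, 4, 16, 16)
B, D, Tp, Hp, Wp = tok.shape
assert (Tp, Hp, Wp) == (4, 16, 16)     # 4 * 16 * 16 = 1024 tokens

# After the transformer blocks, time is collapsed by a mean over T',
# yielding the (B, 384, 16, 16) feature map the card describes.
feat = tok.mean(dim=2)                 # (B, 384, 16, 16)
```
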
## Training recipe

| | |
|---|---|
| Objective | JEPA + VICReg ($\lambda_{\text{sim}}{=}2,\ \lambda_{\text{std}}{=}40,\ \lambda_{\text{cov}}{=}2$) |
| Preprocessing | band-limited FFT resize (preserves periodic BCs) |
| Optimiser | AdamW, lr $5\!\times\!10^{-4}$, wd $0.05$, cosine, 3 warmup epochs |
| Precision | bf16 |
| Batch size | 4 |
| Epochs | 30 (this checkpoint = epoch 29) |
| Hardware | single A100 |
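
The VICReg objective with the weights quoted above can be sketched as follows. This is a generic rendering of the standard VICReg terms (invariance, variance hinge, covariance penalty), not the repo's exact implementation:

```python
import torch
import torch.nn.functional as F

def vicreg_loss(z1, z2, lam_sim=2.0, lam_std=40.0, lam_cov=2.0, eps=1e-4):
    """Generic VICReg sketch with the recipe's weights (sim=2, std=40, cov=2).

    z1, z2: (N, D) embeddings (e.g. predictor output vs. target branch).
    """
    sim = F.mse_loss(z1, z2)                            # invariance term

    std_loss, cov_loss = 0.0, 0.0
    for z in (z1, z2):                                  # summed over branches
        z = z - z.mean(dim=0)
        std = torch.sqrt(z.var(dim=0) + eps)
        std_loss = std_loss + F.relu(1.0 - std).mean()  # variance hinge
        n, d = z.shape
        cov = (z.T @ z) / (n - 1)
        off_diag = cov - torch.diag(torch.diagonal(cov))
        cov_loss = cov_loss + (off_diag ** 2).sum() / d # covariance penalty

    return lam_sim * sim + lam_std * std_loss + lam_cov * cov_loss
```
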
|
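The "band-limited FFT resize" in the recipe refers to downsampling by cropping low frequencies in Fourier space, which keeps the field exactly periodic (unlike bilinear interpolation, which leaks across the boundary). A hypothetical sketch, which may differ from the repo's preprocessing:

```python
import torch

def fft_resize(x, out_hw):
    """Band-limited downsample via Fourier cropping (sketch, assumed shapes).

    x: (..., H, W) real field; out_hw: (H_out, W_out); all sizes even.
    """
    H, W = x.shape[-2:]
    Ho, Wo = out_hw
    X = torch.fft.fftshift(torch.fft.fft2(x), dim=(-2, -1))
    top, left = (H - Ho) // 2, (W - Wo) // 2
    Xc = X[..., top:top + Ho, left:left + Wo]           # keep low frequencies
    Xc = torch.fft.ifftshift(Xc, dim=(-2, -1))
    # Renormalise so the mean (DC component) of the field is preserved.
    return torch.fft.ifft2(Xc).real * (Ho * Wo) / (H * W)
```
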
|
## Evaluation: frozen-encoder probes on a held-out test split

Lower MSE is better.

| Probe | mean MSE $\downarrow$ | $\alpha$ MSE $\downarrow$ | $\zeta$ MSE $\downarrow$ | $k$ / metric |
|---|---|---|---|---|
| Linear | **0.107** | **0.016** | **0.197** | n/a |
| k-NN | **0.120** | **0.009** | **0.231** | $k=20$, cosine |
|
|
Every CNN variant we tried (VICReg baseline, VICReg+FFT, Conv+Attn,
Conv+Attn×6) lands at linear-probe MSE 0.22–0.27 on the same data; the
3D-patch tokeniser is the change that unlocks the ~3× improvement.
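
The k-NN probe in the table can be sketched as a cosine-similarity nearest-neighbour regressor over frozen features. Everything here (pooled feature vectors, mean-of-neighbours prediction) is an assumption for illustration, not the repo's probe code:

```python
import torch

def knn_regress(train_feats, train_y, test_feats, k=20):
    """Cosine k-NN probe (k=20, as in the table) over frozen features.

    train_feats: (N_train, D), train_y: (N_train, 2) targets (alpha, zeta),
    test_feats: (N_test, D). Prediction = mean target of k nearest neighbours.
    """
    a = torch.nn.functional.normalize(train_feats, dim=1)
    b = torch.nn.functional.normalize(test_feats, dim=1)
    sims = b @ a.T                        # (N_test, N_train) cosine similarity
    idx = sims.topk(k, dim=1).indices     # k most similar training samples
    return train_y[idx].mean(dim=1)       # (N_test, 2)
```
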
|
|
## Usage

```python
import torch
from huggingface_hub import hf_hub_download
from physics_jepa.utils.model_utils import ViT3DEncoder

ckpt_path = hf_hub_download(
    repo_id="szcharlesji/vit3d-d6-active-matter",
    filename="ViT3DEncoder_29.pth",
)
encoder = ViT3DEncoder(
    in_chans=11, num_frames=16, img_size=(256, 256),
    patch_size=(4, 16, 16), embed_dim=384, depth=6, num_heads=6,
    mlp_ratio=4.0,
)
encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
encoder.eval()

x = torch.randn(1, 11, 16, 256, 256)  # dummy (B, C, T, H, W) input
with torch.no_grad():
    feat = encoder(x)  # (B, 384, 16, 16)
```
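
A minimal linear-probe readout consistent with the evaluation above might look like the sketch below. The pooling choice and the plain least-squares fit are assumptions; the published probe setup may differ:

```python
import torch

# Stand-ins for encoder outputs and (alpha, zeta) targets on a probe split.
feats = torch.randn(128, 384, 16, 16)
y = torch.randn(128, 2)

X = feats.mean(dim=(2, 3))                          # global average pool -> (B, 384)
X1 = torch.cat([X, torch.ones(len(X), 1)], dim=1)   # append bias column
W = torch.linalg.lstsq(X1, y).solution              # (385, 2) probe weights
mse = ((X1 @ W - y) ** 2).mean()                    # probe MSE on this split
```
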
|
|
|
|
## Citation

```bibtex
@misc{ji2026jepa-active-matter,
  author = {Charles Cheng Ji and Zhanhe Shi and Richard Wang and Romina Yalovetzky},
  title  = {Physics-Aware Representation Learning for Physical Systems},
  year   = {2026},
}
```