szcharlesji committed
Commit 41bdbc2 · verified · 1 Parent(s): 52f1848

Update README.md

Files changed (1)
  1. README.md +100 -3
README.md CHANGED
@@ -1,3 +1,100 @@
- ---
- license: mit
- ---
+ ---
+ license: mit
+ tags:
+ - jepa
+ - vicreg
+ - vit3d
+ - physics
+ - self-supervised
+ - representation-learning
+ datasets:
+ - polymathic-ai/active_matter
+ library_name: pytorch
+ ---
+
+ # ViT3D-d6 / VICReg / FFT on `active_matter` (epoch 29)
+
+ A 6-block 3D Vision Transformer pretrained with VICReg in a JEPA-style setup
+ on the [`active_matter`](https://polymathic-ai.org/the_well/datasets/active_matter/)
+ dataset from The Well. This checkpoint holds the encoder weights at pretraining
+ epoch 29, the best-validation epoch in our sweep.
+
+ The encoder maps a `(B, 11, 16, 256, 256)` input to a `(B, 384, 16, 16)`
+ feature map. Linear and k-NN probes trained on these frozen features regress
+ the active-matter parameters $\alpha$ (alignment strength) and $\zeta$
+ (active stress).
+
+ ## Architecture
+
+ | Component | Spec |
+ |---|---|
+ | Input | $(B, 11, 16, 256, 256)$ |
+ | 3D PatchEmbed | `Conv3d(11 → 384, kernel=stride=4×16×16)` |
+ | Tokens | $T'{\times}H'{\times}W' = 4{\times}16{\times}16 = 1024$ |
+ | Transformer blocks | 6 × pre-norm, $h{=}6$ heads, MLP ratio 4, QKV bias |
+ | Pos. embedding | learnable, $1024 \times 384$ |
+ | Output | $(B, 384, 16, 16)$, time collapsed by mean over $T'$ |
+ | **Total params** | **≈ 15.4 M** |
+
+ The companion `ConvPredictor` used during JEPA pretraining (≈ 7.1 M parameters) is
+ **not** included; only the encoder is published.
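
The token count and the parameter total in the architecture table can be cross-checked with plain arithmetic. The sketch below assumes a standard ViT block layout (biased linear layers, two LayerNorms per block) plus one final LayerNorm; none of these details are stated in the README, so treat the breakdown as an illustration, not a description of `ViT3DEncoder` itself.

```python
# Shape and parameter bookkeeping for the encoder described in the table.
# Assumed (not stated in the README): standard ViT blocks with biased
# linear layers, two LayerNorms per block, and one final LayerNorm.

in_chans, embed_dim, depth, mlp_ratio = 11, 384, 6, 4
frames, height, width = 16, 256, 256
pt, ph, pw = 4, 16, 16  # patch size (time, height, width)

# Non-overlapping patches: one token per (pt, ph, pw) block.
grid = (frames // pt, height // ph, width // pw)
num_tokens = grid[0] * grid[1] * grid[2]


def linear(n_in, n_out):
    """Weights plus bias of a dense layer."""
    return n_in * n_out + n_out


patch_embed = in_chans * embed_dim * pt * ph * pw + embed_dim  # Conv3d weights + bias
pos_embed = num_tokens * embed_dim
hidden = embed_dim * mlp_ratio
block = (
    linear(embed_dim, 3 * embed_dim)  # fused QKV projection with bias
    + linear(embed_dim, embed_dim)    # attention output projection
    + linear(embed_dim, hidden)       # MLP up-projection
    + linear(hidden, embed_dim)       # MLP down-projection
    + 2 * 2 * embed_dim               # two LayerNorms (scale and shift)
)
final_norm = 2 * embed_dim
total = patch_embed + pos_embed + depth * block + final_norm

print(grid, num_tokens)        # (4, 16, 16) 1024
print(f"{total / 1e6:.2f} M")  # 15.37 M
```

Under these assumptions the total is 15,366,528 parameters, which agrees with the ≈ 15.4 M figure in the table; about 4.3 M of that sits in the patch-embedding convolution.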
+
+ ## Training recipe
+
+ | Setting | Value |
+ |---|---|
+ | Objective | JEPA + VICReg ($\lambda_{\text{sim}}{=}2,\ \lambda_{\text{std}}{=}40,\ \lambda_{\text{cov}}{=}2$) |
+ | Preprocessing | band-limited FFT resize (preserves periodic boundary conditions) |
+ | Optimiser | AdamW, lr $5\!\times\!10^{-4}$, wd $0.05$, cosine schedule, 3 warmup epochs |
+ | Precision | bf16 |
+ | Batch size | 4 |
+ | Epochs | 30 (this checkpoint = epoch 29) |
+ | Hardware | single A100 |
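
To make the optimiser row concrete, here is a minimal sketch of a cosine schedule with warmup using the values from the table. The README does not say how the warmup is shaped, whether the rate is updated per step or per epoch, or what the final rate is; the linear per-epoch ramp and decay towards zero below are assumptions, and `lr_at` is a hypothetical helper, not part of the training code.

```python
import math

BASE_LR = 5e-4       # from the table
WARMUP_EPOCHS = 3    # from the table
TOTAL_EPOCHS = 30    # from the table


def lr_at(epoch: int) -> float:
    """Learning rate for a 0-indexed epoch (hypothetical helper)."""
    if epoch < WARMUP_EPOCHS:
        # Assumed linear ramp: 1/3, 2/3, then the full base rate.
        return BASE_LR * (epoch + 1) / WARMUP_EPOCHS
    # Cosine decay over the remaining epochs, starting at the base rate.
    progress = (epoch - WARMUP_EPOCHS) / (TOTAL_EPOCHS - WARMUP_EPOCHS)
    return 0.5 * BASE_LR * (1.0 + math.cos(math.pi * progress))


schedule = [lr_at(e) for e in range(TOTAL_EPOCHS)]
for epoch in (0, 2, 3, 15, 29):
    print(epoch, f"{schedule[epoch]:.2e}")
```

With this shape the rate peaks at the end of warmup and is close to zero by the final epoch.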
+
+ ## Evaluation: frozen-encoder probes on the held-out test split
+
+ Lower MSE is better.
+
+ | Probe | mean MSE $\downarrow$ | $\alpha$ MSE $\downarrow$ | $\zeta$ MSE $\downarrow$ | $k$ / metric |
+ |---|---|---|---|---|
+ | Linear | **0.107** | **0.016** | **0.197** | n/a |
+ | k-NN | **0.120** | **0.009** | **0.231** | $k=20$, cosine |
+
+ Every CNN variant we tried (VICReg baseline, VICReg+FFT, Conv+Attn,
+ Conv+Attn×6) lands at linear MSE 0.22–0.27 on the same data; the 3D-patch
+ tokeniser is the change that brings it down to 0.107, roughly a 2–2.5× reduction.
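
For readers unfamiliar with the k-NN probe, the following is a minimal pure-Python sketch of cosine k-NN regression with a per-target MSE. It assumes each sample has already been reduced to a single feature vector; how the `(B, 384, 16, 16)` map is pooled or flattened for the probes is not described in the README. The helper names and the toy data are made up for illustration.

```python
import math


def cosine(u, v):
    """Cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def knn_predict(query, train_feats, train_targets, k=20):
    """Average the targets of the k most cosine-similar training features."""
    ranked = sorted(
        range(len(train_feats)),
        key=lambda i: cosine(query, train_feats[i]),
        reverse=True,
    )
    top = ranked[:k]
    dims = len(train_targets[0])
    return [sum(train_targets[i][d] for i in top) / len(top) for d in range(dims)]


def mse(preds, targets):
    """Per-dimension mean squared error over a list of samples."""
    dims = len(targets[0])
    n = len(targets)
    return [sum((p[d] - t[d]) ** 2 for p, t in zip(preds, targets)) / n for d in range(dims)]


# Toy data (illustrative only): 2-D features, targets are (alpha, zeta).
train_feats = [[1.0, 0.0], [0.9, 0.1], [0.0, 1.0], [0.1, 0.9]]
train_targets = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0], [0.0, 1.0]]

pred = knn_predict([0.8, 0.2], train_feats, train_targets, k=2)
print(pred)  # [1.0, 0.0]
```

The "mean MSE" column in the table is consistent with averaging the two per-parameter errors, for example (0.009 + 0.231) / 2 = 0.120 for the k-NN row.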
+
+ ## Usage
+
+ ```python
+ import torch
+ from huggingface_hub import hf_hub_download
+ from physics_jepa.utils.model_utils import ViT3DEncoder
+
+ # Download the encoder checkpoint from the Hub.
+ ckpt_path = hf_hub_download(
+     repo_id="szcharlesji/vit3d-d6-active-matter",
+     filename="ViT3DEncoder_29.pth",
+ )
+ # Hyperparameters match the Architecture table.
+ encoder = ViT3DEncoder(
+     in_chans=11, num_frames=16, img_size=(256, 256),
+     patch_size=(4, 16, 16), embed_dim=384, depth=6, num_heads=6,
+     mlp_ratio=4.0,
+ )
+ encoder.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
+ encoder.eval()
+
+ # x: (B, 11, 16, 256, 256), float. Random placeholder; replace with real data.
+ x = torch.randn(1, 11, 16, 256, 256)
+ with torch.no_grad():
+     feat = encoder(x)  # (B, 384, 16, 16)
+ ```
+
+
+ ## Citation
+
+ ```bibtex
+ @misc{ji2026jepa-active-matter,
+   author = {Charles Cheng Ji and Zhanhe Shi and Richard Wang and Romina Yalovetzky},
+   title = {Physics-Aware Representation Learning for Physical Systems},
+   year = {2026},
+ }
+ ```