Update README.md
Browse files
README.md
CHANGED
|
@@ -28,6 +28,7 @@ By processing views sequentially with action conditioning, seq-JEPA naturally se
|
|
| 28 |
| Checkpoint | Dataset | Training | Download |
|
| 29 |
|------------|---------|----------|----------|
|
| 30 |
| `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
|
|
|
|
| 31 |
| `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |
|
| 32 |
|
| 33 |
## Usage
|
|
@@ -40,15 +41,21 @@ cd seq-jepa — then load the checkpoints:
|
|
| 40 |
import torch
|
| 41 |
from models import SeqJEPA_Transforms, SeqJEPA_PLS
|
| 42 |
|
| 43 |
-
# 3DIEBench
|
| 44 |
kwargs = {
|
| 45 |
"num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
|
| 46 |
"num_classes": 55, "act_cond": True, "pred_hidden": 1024,
|
| 47 |
-
"act_projdim": 128,
|
| 48 |
"learn_act_emb": True
|
| 49 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
|
| 51 |
-
ckpt = torch.load('3diebench_rot_seqlen3.pth')
|
| 52 |
model.load_state_dict(ckpt['model_state_dict'])
|
| 53 |
|
| 54 |
# STL10 PLS checkpoint
|
|
|
|
| 28 |
| Checkpoint | Dataset | Training | Download |
|
| 29 |
|------------|---------|----------|----------|
|
| 30 |
| `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
|
| 31 |
+
| `3diebench_rotcol_seqlen4.pth` | 3DIEBench | seq-len=4, rotation and color conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rotcol_seqlen4.pth) |
|
| 32 |
| `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |
|
| 33 |
|
| 34 |
## Usage
|
|
|
|
| 41 |
import torch
|
| 42 |
from models import SeqJEPA_Transforms, SeqJEPA_PLS
|
| 43 |
|
| 44 |
+
# 3DIEBench checkpoints
|
| 45 |
kwargs = {
|
| 46 |
"num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
|
| 47 |
"num_classes": 55, "act_cond": True, "pred_hidden": 1024,
|
| 48 |
+
"act_projdim": 128, "cifar_resnet": False,
|
| 49 |
"learn_act_emb": True
|
| 50 |
}
|
| 51 |
+
|
| 52 |
+
### for ckpt with rotation and color conditioning
|
| 53 |
+
kwargs["act_latentdim"]=6
|
| 54 |
+
### for ckpt with rotation conditioning
|
| 55 |
+
kwargs["act_latentdim"]=4
|
| 56 |
+
|
| 57 |
model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
|
| 58 |
+
ckpt = torch.load('3diebench_rot_seqlen3.pth')  ## or ckpt = torch.load('3diebench_rotcol_seqlen4.pth') for the checkpoint with rotation and color conditioning
|
| 59 |
model.load_state_dict(ckpt['model_state_dict'])
|
| 60 |
|
| 61 |
# STL10 PLS checkpoint
|