Hafez commited on
Commit
a9e535d
·
verified ·
1 Parent(s): eb9b83b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +10 -3
README.md CHANGED
@@ -28,6 +28,7 @@ By processing views sequentially with action conditioning, seq-JEPA naturally se
28
  | Checkpoint | Dataset | Training | Download |
29
  |------------|---------|----------|----------|
30
  | `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
 
31
  | `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |
32
 
33
  ## Usage
@@ -40,15 +41,21 @@ cd seq-jepaThen load the checkpoints:
40
  import torch
41
  from models import SeqJEPA_Transforms, SeqJEPA_PLS
42
 
43
- # 3DIEBench checkpoint
44
  kwargs = {
45
  "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
46
  "num_classes": 55, "act_cond": True, "pred_hidden": 1024,
47
- "act_projdim": 128, "act_latentdim": 4, "cifar_resnet": False,
48
  "learn_act_emb": True
49
  }
 
 
 
 
 
 
50
  model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
51
- ckpt = torch.load('3diebench_rot_seqlen3.pth')
52
  model.load_state_dict(ckpt['model_state_dict'])
53
 
54
  # STL10 PLS checkpoint
 
28
  | Checkpoint | Dataset | Training | Download |
29
  |------------|---------|----------|----------|
30
  | `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
31
+ | `3diebench_rotcol_seqlen4.pth` | 3DIEBench | seq-len=4, rotation and color conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rotcol_seqlen4.pth) |
32
  | `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |
33
 
34
  ## Usage
 
41
  import torch
42
  from models import SeqJEPA_Transforms, SeqJEPA_PLS
43
 
44
+ # 3DIEBench checkpoints
45
  kwargs = {
46
  "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
47
  "num_classes": 55, "act_cond": True, "pred_hidden": 1024,
48
+ "act_projdim": 128, "cifar_resnet": False,
49
  "learn_act_emb": True
50
  }
51
+
52
+ ### for ckpt with rotation and color conditioning
53
+ kwargs["act_latentdim"]=6
54
+ ### for ckpt with rotation conditioning
55
+ kwargs["act_latentdim"]=4
56
+
57
  model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
58
+ ckpt = torch.load('3diebench_rot_seqlen3.pth') ## or ckpt = torch.load('3diebench_rotcol_seqlen4.pth') for ckpt w/ rotcolor conditioning
59
  model.load_state_dict(ckpt['model_state_dict'])
60
 
61
  # STL10 PLS checkpoint