---
license: mit
tags:
- self-supervised-learning
- world-models
- equivariance
- vision
- pytorch
datasets:
- 3DIEBench
- STL10
---

# seq-JEPA: Autoregressive Predictive Learning of Invariant-Equivariant World Models
## Model Description

By processing views sequentially with action conditioning, seq-JEPA naturally segregates representations for equivariance- and invariance-demanding tasks.

## Available Checkpoints

| Checkpoint | Dataset | Training | Download |
|------------|---------|----------|----------|
| `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
| `3diebench_rotcol_seqlen4.pth` | 3DIEBench | seq-len=4, rotation and color conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rotcol_seqlen4.pth) |
| `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |

## Usage

First, clone the repository to access the model definitions:

```bash
git clone https://github.com/hafezgh/seq-jepa.git
cd seq-jepa
```

Then load the checkpoints. The snippets below assume the `.pth` files from the table above have been downloaded into the current working directory.

```python
import torch
from models import SeqJEPA_Transforms, SeqJEPA_PLS

# 3DIEBench checkpoints
kwargs = {
    "num_heads": 4,
    "n_channels": 3,
    "num_enc_layers": 3,
    "num_classes": 55,
    "act_cond": True,
    "pred_hidden": 1024,
    "act_projdim": 128,
    "cifar_resnet": False,
    "learn_act_emb": True,
}
# The action latent dimension must match the conditioning used in training:
kwargs["act_latentdim"] = 4  # rotation conditioning (3diebench_rot_seqlen3.pth)
# kwargs["act_latentdim"] = 6  # rotation and color conditioning (3diebench_rotcol_seqlen4.pth)

model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
# map_location="cpu" lets the checkpoint load on machines without a GPU
ckpt = torch.load("3diebench_rot_seqlen3.pth", map_location="cpu")
# ckpt = torch.load("3diebench_rotcol_seqlen4.pth", map_location="cpu")  # rotation and color conditioning
model.load_state_dict(ckpt["model_state_dict"])

# STL10 PLS checkpoint
kwargs = {
    "num_heads": 4,
    "n_channels": 3,
    "num_enc_layers": 3,
    "num_classes": 10,
    "act_cond": True,
    "pred_hidden": 1024,
    "act_projdim": 128,
    "act_latentdim": 2,
    "cifar_resnet": True,
    "learn_act_emb": True,
    "pos_dim": 2,
}
model = SeqJEPA_PLS(fovea_size=32, img_size=96, ema=True, ema_decay=0.996, **kwargs)
ckpt = torch.load("stl10_pls.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
```

## Citation

```bibtex
@inproceedings{ghaemi2025seqjepa,
  title={seq-{JEPA}: Autoregressive Predictive Learning of Invariant-Equivariant World Models},
  author={Hafez Ghaemi and Eilif Benjamin Muller and Shahab Bakhtiari},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=GKt3VRaCU1}
}
```
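
## Selecting the Configuration by Checkpoint Name

The three checkpoints share most hyperparameters and differ only in a few constructor arguments (`num_classes`, `act_latentdim`, `cifar_resnet`, `pos_dim`, and the image/fovea sizes). As a minimal sketch, those values can be collected into a lookup table keyed by checkpoint file name so the right configuration is chosen automatically. `seqjepa_config`, `_BASE_KWARGS`, and `_CHECKPOINTS` are hypothetical names introduced here for illustration and are not part of the seq-jepa repository; the values are copied from the usage example.

```python
from typing import Any, Dict, Tuple

# Hyperparameters shared by all three checkpoints (copied from the usage example).
_BASE_KWARGS: Dict[str, Any] = {
    "num_heads": 4,
    "n_channels": 3,
    "num_enc_layers": 3,
    "act_cond": True,
    "pred_hidden": 1024,
    "act_projdim": 128,
    "learn_act_emb": True,
}

# Hypothetical table: checkpoint file -> (model class name, constructor args, extra kwargs).
_CHECKPOINTS: Dict[str, Tuple[str, Dict[str, Any], Dict[str, Any]]] = {
    "3diebench_rot_seqlen3.pth": (
        "SeqJEPA_Transforms",
        {"img_size": 128, "ema": True, "ema_decay": 0.996},
        {"num_classes": 55, "act_latentdim": 4, "cifar_resnet": False},
    ),
    "3diebench_rotcol_seqlen4.pth": (
        "SeqJEPA_Transforms",
        {"img_size": 128, "ema": True, "ema_decay": 0.996},
        {"num_classes": 55, "act_latentdim": 6, "cifar_resnet": False},
    ),
    "stl10_pls.pth": (
        "SeqJEPA_PLS",
        {"fovea_size": 32, "img_size": 96, "ema": True, "ema_decay": 0.996},
        {"num_classes": 10, "act_latentdim": 2, "cifar_resnet": True, "pos_dim": 2},
    ),
}


def seqjepa_config(ckpt_name: str) -> Tuple[str, Dict[str, Any]]:
    """Return (model class name, all keyword arguments) for a checkpoint file."""
    if ckpt_name not in _CHECKPOINTS:
        raise KeyError(f"unknown checkpoint: {ckpt_name!r}")
    cls_name, ctor_args, extra = _CHECKPOINTS[ckpt_name]
    # Build a fresh dict so callers cannot mutate the module-level tables.
    return cls_name, {**ctor_args, **_BASE_KWARGS, **extra}
```

Inside the cloned repository, `cls_name, kwargs = seqjepa_config("stl10_pls.pth")` followed by `import models` and `model = getattr(models, cls_name)(**kwargs)` should build the same model as the explicit STL10 example; the weights are then loaded with `torch.load` and `load_state_dict` as before. Keeping the table in one place makes it harder to pair a checkpoint with the wrong `act_latentdim`.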