---
license: mit
tags:
- self-supervised-learning
- world-models
- equivariance
- vision
- pytorch
datasets:
- 3DIEBench
- STL10
---

# seq-JEPA: Autoregressive Predictive Learning of Invariant-Equivariant World Models

<p align="center">
<a href="https://openreview.net/forum?id=GKt3VRaCU1"><img src="https://img.shields.io/badge/NeurIPS%202025-Paper-blue" alt="Paper"></a>
<a href="https://hafezgh.github.io/seq-jepa/"><img src="https://img.shields.io/badge/Project-Page-green" alt="Project Page"></a>
<a href="https://github.com/hafezgh/seq-jepa"><img src="https://img.shields.io/badge/GitHub-Code-black" alt="Code"></a>
</p>

## Model Description

By processing views sequentially with action conditioning, seq-JEPA naturally segregates representations for equivariance- and invariance-demanding tasks.

## Available Checkpoints

| Checkpoint | Dataset | Training | Download |
|------------|---------|----------|----------|
| `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
| `3diebench_rotcol_seqlen4.pth` | 3DIEBench | seq-len=4, rotation and color conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rotcol_seqlen4.pth) |
| `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |

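The checkpoints can also be fetched from a script rather than through the browser links. The following is a minimal sketch, not part of the seq-JEPA repository: the repository ID `Hafez/seq-JEPA` is read off the download links in the table, the helper names `checkpoint_url` and `fetch_checkpoint` are hypothetical, and `fetch_checkpoint` assumes the `huggingface_hub` package is installed.

```python
from pathlib import Path

# Repository ID taken from the download links in the checkpoint table.
REPO_ID = "Hafez/seq-JEPA"

# File names exactly as listed in the checkpoint table.
CHECKPOINTS = (
    "3diebench_rot_seqlen3.pth",
    "3diebench_rotcol_seqlen4.pth",
    "stl10_pls.pth",
)


def checkpoint_url(filename: str) -> str:
    """Return the direct download URL for one of the listed checkpoints."""
    if filename not in CHECKPOINTS:
        raise ValueError(f"unknown checkpoint: {filename!r}")
    return f"https://huggingface.co/{REPO_ID}/resolve/main/{filename}"


def fetch_checkpoint(filename: str) -> Path:
    """Download a checkpoint into the local Hugging Face cache and return its path.

    Hypothetical helper: requires the `huggingface_hub` package and network access.
    """
    # Imported lazily so checkpoint_url() works without huggingface_hub installed.
    from huggingface_hub import hf_hub_download

    checkpoint_url(filename)  # raises ValueError for names not in the table
    return Path(hf_hub_download(repo_id=REPO_ID, filename=filename))
```

For example, `fetch_checkpoint("stl10_pls.pth")` returns a local path that can be passed to `torch.load`.
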
## Usage

First, clone the repository to access the model definitions:

```bash
git clone https://github.com/hafezgh/seq-jepa.git
cd seq-jepa
```

Then load the checkpoints:

```python
import torch
from models import SeqJEPA_Transforms, SeqJEPA_PLS

# 3DIEBench checkpoints
kwargs = {
    "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
    "num_classes": 55, "act_cond": True, "pred_hidden": 1024,
    "act_projdim": 128, "cifar_resnet": False,
    "learn_act_emb": True
}

# Set act_latentdim to match the checkpoint being loaded:
#   4 for rotation conditioning (3diebench_rot_seqlen3.pth)
#   6 for rotation and color conditioning (3diebench_rotcol_seqlen4.pth)
kwargs["act_latentdim"] = 4

model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
# For rotation and color conditioning, load '3diebench_rotcol_seqlen4.pth' instead
ckpt = torch.load('3diebench_rot_seqlen3.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])

# STL10 PLS checkpoint
kwargs = {
    "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
    "num_classes": 10, "act_cond": True, "pred_hidden": 1024,
    "act_projdim": 128, "act_latentdim": 2, "cifar_resnet": True,
    "learn_act_emb": True, "pos_dim": 2
}
model = SeqJEPA_PLS(fovea_size=32, img_size=96, ema=True, ema_decay=0.996, **kwargs)
ckpt = torch.load('stl10_pls.pth', map_location='cpu')
model.load_state_dict(ckpt['model_state_dict'])
```

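Because each checkpoint needs its own constructor arguments, it can help to keep the settings in one lookup table keyed by file name. The sketch below is illustrative only: the values are copied from the loading snippet in this section, while `CONFIGS`, `EMA_KWARGS`, and `config_for` are hypothetical names that do not exist in the seq-JEPA repository.

```python
# Arguments shared by both 3DIEBench checkpoints (copied from the Usage snippet).
BASE_KWARGS_3DIEBENCH = {
    "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
    "num_classes": 55, "act_cond": True, "pred_hidden": 1024,
    "act_projdim": 128, "cifar_resnet": False, "learn_act_emb": True,
    "img_size": 128,
}

# Arguments for the STL10 PLS checkpoint (copied from the Usage snippet).
KWARGS_STL10 = {
    "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
    "num_classes": 10, "act_cond": True, "pred_hidden": 1024,
    "act_projdim": 128, "act_latentdim": 2, "cifar_resnet": True,
    "learn_act_emb": True, "pos_dim": 2,
    "fovea_size": 32, "img_size": 96,
}

# EMA settings used for every checkpoint in the Usage snippet.
EMA_KWARGS = {"ema": True, "ema_decay": 0.996}

# Maps checkpoint file name -> (model class name, constructor arguments).
CONFIGS = {
    "3diebench_rot_seqlen3.pth": (
        "SeqJEPA_Transforms", {**BASE_KWARGS_3DIEBENCH, "act_latentdim": 4}),
    "3diebench_rotcol_seqlen4.pth": (
        "SeqJEPA_Transforms", {**BASE_KWARGS_3DIEBENCH, "act_latentdim": 6}),
    "stl10_pls.pth": ("SeqJEPA_PLS", dict(KWARGS_STL10)),
}


def config_for(filename: str) -> tuple[str, dict]:
    """Return (class name, constructor kwargs) for a listed checkpoint."""
    class_name, kwargs = CONFIGS[filename]  # KeyError for unlisted names
    return class_name, {**kwargs, **EMA_KWARGS}
```

With the repository cloned, the model could then be built with `class_name, kwargs = config_for("stl10_pls.pth")` followed by `getattr(models, class_name)(**kwargs)`, assuming `import models` succeeds as in the snippet above.
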
## Citation

```bibtex
@inproceedings{ghaemi2025seqjepa,
  title={seq-{JEPA}: Autoregressive Predictive Learning of Invariant-Equivariant World Models},
  author={Hafez Ghaemi and Eilif Benjamin Muller and Shahab Bakhtiari},
  booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
  year={2025},
  url={https://openreview.net/forum?id=GKt3VRaCU1}
}
```