---
license: mit
tags:
  - self-supervised-learning
  - world-models
  - equivariance
  - vision
  - pytorch
datasets:
  - 3DIEBench
  - STL10
---

# seq-JEPA: Autoregressive Predictive Learning of Invariant-Equivariant World Models

<p align="center">
  <a href="https://openreview.net/forum?id=GKt3VRaCU1"><img src="https://img.shields.io/badge/NeurIPS%202025-Paper-blue" alt="Paper"></a>
  <a href="https://hafezgh.github.io/seq-jepa/"><img src="https://img.shields.io/badge/Project-Page-green" alt="Project Page"></a>
  <a href="https://github.com/hafezgh/seq-jepa"><img src="https://img.shields.io/badge/GitHub-Code-black" alt="Code"></a>
</p>

## Model Description

By processing views sequentially with action conditioning, seq-JEPA naturally segregates representations for equivariance- and invariance-demanding tasks.
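The bookkeeping behind sequential, action-conditioned prediction can be shown with a toy, framework-free sketch of how one training sequence might be laid out. It rests on several assumptions. Each view is described by a single scalar transformation parameter, such as a rotation angle. Each observed view is paired with the relative action that leads to the next view. The last view is held out as the prediction target. `make_training_sequence` is a hypothetical helper and is not part of the seq-JEPA repository.

```python
from typing import List, Tuple


def make_training_sequence(
    view_params: List[float],
) -> Tuple[List[int], List[float], int]:
    """Split a sequence of views into observed views, actions, and a target.

    Hypothetical illustration: view i is described by one scalar parameter
    (e.g. a rotation angle). The last view is held out as the prediction
    target; every earlier view is paired with the relative action that
    maps it to the next view in the sequence.
    """
    if len(view_params) < 2:
        raise ValueError("need at least one observed view and one target view")
    observed = list(range(len(view_params) - 1))  # indices of context views
    # Relative action from view i to view i + 1 (assumed: scalar difference).
    actions = [view_params[i + 1] - view_params[i] for i in observed]
    target = len(view_params) - 1  # index of the view whose embedding is predicted
    return observed, actions, target


# Three observed views plus one held-out target view.
observed, actions, target = make_training_sequence([0.0, 10.0, 30.0, 60.0])
print(observed, actions, target)  # [0, 1, 2] [10.0, 20.0, 30.0] 3
```

In the actual model, the observed views would be encoded and combined with learned action embeddings before a predictor estimates the target view's representation. This sketch covers only the ordering of views and actions.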

## Available Checkpoints

| Checkpoint | Dataset | Training | Download |
|------------|---------|----------|----------|
| `3diebench_rot_seqlen3.pth` | 3DIEBench | seq-len=3, rotation conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rot_seqlen3.pth) |
| `3diebench_rotcol_seqlen4.pth` | 3DIEBench | seq-len=4, rotation and color conditioning | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/3diebench_rotcol_seqlen4.pth) |
| `stl10_pls.pth` | STL10 | PLS (predictive learning across saccades) | [Download](https://huggingface.co/Hafez/seq-JEPA/resolve/main/stl10_pls.pth) |
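The PLS checkpoint is trained on a sequence of foveal glimpses instead of whole-image transformations. The following is a minimal sketch of glimpse extraction under stated assumptions. Glimpses are square crops with side `fovea_size=32` taken from a 96×96 STL10 image, matching the sizes passed to `SeqJEPA_PLS` in the usage example. The 2-D action is assumed to be the saccade displacement between fixation centers, which would fit the `act_latentdim: 2` setting. `glimpse_box` and `saccade_action` are hypothetical helpers, and the repository may sample fixations and normalize actions differently.

```python
from typing import Tuple

IMG_SIZE = 96    # STL10 images are 96x96
FOVEA_SIZE = 32  # side length of each square glimpse (assumed from fovea_size=32)


def glimpse_box(cx: int, cy: int) -> Tuple[int, int, int, int]:
    """Return (left, top, right, bottom) of a glimpse centred near (cx, cy).

    The box is shifted, not shrunk, so it always stays fully inside the image.
    """
    half = FOVEA_SIZE // 2
    left = min(max(cx - half, 0), IMG_SIZE - FOVEA_SIZE)
    top = min(max(cy - half, 0), IMG_SIZE - FOVEA_SIZE)
    return left, top, left + FOVEA_SIZE, top + FOVEA_SIZE


def saccade_action(
    src: Tuple[int, int], dst: Tuple[int, int]
) -> Tuple[float, float]:
    """Hypothetical 2-D action: displacement between fixations, divided by image size."""
    return (dst[0] - src[0]) / IMG_SIZE, (dst[1] - src[1]) / IMG_SIZE


print(glimpse_box(48, 48))  # (32, 32, 64, 64)
print(glimpse_box(5, 90))   # (0, 64, 32, 96)
print(saccade_action((24, 24), (72, 48)))  # (0.5, 0.25)
```

Shifting the box at the image border keeps every glimpse the same size, so a batch of glimpses can be stacked into a single tensor. Padding the image is a common alternative.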

## Usage

First, clone the repository to access the model definitions:

```bash
git clone https://github.com/hafezgh/seq-jepa.git
cd seq-jepa
```

Then load the checkpoints (downloaded from the table above into the working directory):

```python
import torch
from models import SeqJEPA_Transforms, SeqJEPA_PLS

# --- 3DIEBench checkpoints ---
kwargs = {
    "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
    "num_classes": 55, "act_cond": True, "pred_hidden": 1024,
    "act_projdim": 128, "cifar_resnet": False,
    "learn_act_emb": True,
}

# Pick one checkpoint; act_latentdim must match its conditioning:
#   4 for rotation conditioning, 6 for rotation and color conditioning.
ckpt_path = "3diebench_rot_seqlen3.pth"  # or "3diebench_rotcol_seqlen4.pth"
kwargs["act_latentdim"] = 6 if "rotcol" in ckpt_path else 4

model = SeqJEPA_Transforms(img_size=128, ema=True, ema_decay=0.996, **kwargs)
ckpt = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])

# --- STL10 PLS checkpoint ---
kwargs = {
    "num_heads": 4, "n_channels": 3, "num_enc_layers": 3,
    "num_classes": 10, "act_cond": True, "pred_hidden": 1024,
    "act_projdim": 128, "act_latentdim": 2, "cifar_resnet": True,
    "learn_act_emb": True, "pos_dim": 2,
}
model = SeqJEPA_PLS(fovea_size=32, img_size=96, ema=True, ema_decay=0.996, **kwargs)
ckpt = torch.load("stl10_pls.pth", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
```
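Both constructors take `ema=True, ema_decay=0.996`. In JEPA-style methods, this usually means the target encoder is an exponential moving average of the online encoder. The following framework-free sketch shows that update rule. It assumes the standard form θ_target ← τ·θ_target + (1 − τ)·θ_online with τ = `ema_decay`. `ema_update` is a hypothetical helper that operates on plain lists of floats and is not the repository's implementation.

```python
from typing import List


def ema_update(target: List[float], online: List[float], decay: float = 0.996) -> List[float]:
    """Return target parameters moved a small step toward the online parameters."""
    if len(target) != len(online):
        raise ValueError("parameter lists must have the same length")
    return [decay * t + (1.0 - decay) * o for t, o in zip(target, online)]


target = [0.0, 1.0]
online = [1.0, 1.0]
for _ in range(1000):
    target = ema_update(target, online)
# With a fixed online value, the gap shrinks by a factor of decay**k after k steps:
# target[0] is approximately 1 - 0.996**1000 (about 0.98), and target[1] stays at 1.0.
```

With a decay close to 1, the target encoder changes slowly and gives the predictor a stable regression target. This is the usual motivation for EMA targets in JEPA- and BYOL-style training.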

## Citation

```bibtex
@inproceedings{
ghaemi2025seqjepa,
title={seq-{JEPA}: Autoregressive Predictive Learning of Invariant-Equivariant World Models},
author={Hafez Ghaemi and Eilif Benjamin Muller and Shahab Bakhtiari},
booktitle={The Thirty-ninth Annual Conference on Neural Information Processing Systems},
year={2025},
url={https://openreview.net/forum?id=GKt3VRaCU1}
}
```