OnSensorWorldModel
JEPA-based latent world models designed for on-sensor deployment on the Sony IMX500 vision sensor. The models learn stable latent representations from raw pixels with a two-term loss (prediction + SIGReg) and support goal-conditioned planning with the cross-entropy method (CEM) in latent space.
Model Variants
This repo contains 5 encoder variants across 4 robot control tasks (19 checkpoints total).
Encoders
| Encoder | Predictor | Params | Quantization | Target Hardware |
|---|---|---|---|---|
| MobileNetV3-Large | GRU (2L, 256h) | 7.6M | Fake Quantization (INT8-ready) | IMX500 sensor |
| EfficientNet-B0 | GRU (2L, 256h) | 9.2M | Fake Quantization (INT8-ready) | IMX500 sensor |
| MobileNetV2 | GRU (2L, 256h) | 7.5M | Fake Quantization (INT8-ready) | IMX500 sensor |
| MobileViT-XS | GRU (2L, 256h) | 9.4M | Fake Quantization (INT8-ready) | IMX500 sensor |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | GPU baseline |
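For readers unfamiliar with the term in the Quantization column: fake quantization rounds values onto an INT8 grid and immediately maps them back to floating point, so the model sees quantization error while still computing in floats. The sketch below assumes symmetric per-tensor scaling; the scheme, the function name `fake_quantize_int8`, and the sample values are illustrative assumptions, not this repo's implementation.

```python
def fake_quantize_int8(values):
    """Quantize-dequantize a list of floats on a symmetric INT8 grid.

    Assumption: symmetric per-tensor scaling onto the levels [-127, 127].
    The scheme used for the checkpoints in this repo may differ.
    """
    max_abs = max((abs(v) for v in values), default=0.0)
    if max_abs == 0.0:
        return list(values), 1.0
    scale = max_abs / 127.0
    out = []
    for v in values:
        q = round(v / scale)            # nearest integer level
        q = max(-127, min(127, q))      # clamp to the range assumed here
        out.append(q * scale)           # map back to float
    return out, scale


weights = [0.5, -1.27, 0.003, 1.0]      # made-up example values
dequantized, scale = fake_quantize_int8(weights)
max_error = max(abs(a - b) for a, b in zip(weights, dequantized))
```

With this scheme the rounding error per value is at most half a quantization step (`scale / 2`), which is the error a fake-quantized model is exposed to before export to INT8.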
Results
Planning success rates (%, mean +/- std over 10 seeds, 50 episodes each):
| Encoder | Predictor | Params | Quantization | Task | Success Rate | Checkpoint |
|---|---|---|---|---|---|---|
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | pusht | 78.0 +/- 6.8 | checkpoints/mobilenetv3/pusht.ckpt |
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | tworoom | 94.2 +/- 3.8 | checkpoints/mobilenetv3/tworoom.ckpt |
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | ogb | 70.2 +/- 7.9 | checkpoints/mobilenetv3/ogb.ckpt |
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | dmc | 82.6 +/- 5.5 | checkpoints/mobilenetv3/dmc.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | pusht | 78.6 +/- 6.8 | checkpoints/efficientnet-b0/pusht.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | tworoom | 93.8 +/- 3.7 | checkpoints/efficientnet-b0/tworoom.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | ogb | 71.8 +/- 6.3 | checkpoints/efficientnet-b0/ogb.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | dmc | 80.2 +/- 6.4 | checkpoints/efficientnet-b0/dmc.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | pusht | 43.4 +/- 4.3 | checkpoints/mobilenetv2/pusht.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | tworoom | 93.8 +/- 4.0 | checkpoints/mobilenetv2/tworoom.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | ogb | 69.2 +/- 7.3 | checkpoints/mobilenetv2/ogb.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | dmc | 86.2 +/- 4.3 | checkpoints/mobilenetv2/dmc.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | pusht | 82.6 +/- 4.2 | checkpoints/mobilevit-xs/pusht.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | tworoom | 94.2 +/- 4.2 | checkpoints/mobilevit-xs/tworoom.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | ogb | 68.8 +/- 5.2 | checkpoints/mobilevit-xs/ogb.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | dmc | 76.2 +/- 6.6 | checkpoints/mobilevit-xs/dmc.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | pusht | 87.6 +/- 4.3 | checkpoints/resnet18/pusht.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | tworoom | 88.8 +/- 3.6 | checkpoints/resnet18/tworoom.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | ogb | 66.0 +/- 6.5 | checkpoints/resnet18/ogb.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | dmc | 84.8 +/- 5.8 | checkpoints/resnet18/dmc.ckpt |
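As a rough illustration of how a "mean +/- std over 10 seeds, 50 episodes each" entry can be computed, the sketch below aggregates per-seed success counts. The helper name `success_rate_summary`, the use of the sample standard deviation, and the example counts are all assumptions; the source does not state which estimator was used, and the counts are not results from this repo.

```python
import statistics


def success_rate_summary(successes_per_seed, episodes_per_seed=50):
    """Return (mean, std) of per-seed success rates, in percent.

    Assumption: std is the sample standard deviation across seeds.
    """
    rates = [100.0 * s / episodes_per_seed for s in successes_per_seed]
    return statistics.mean(rates), statistics.stdev(rates)


# Made-up success counts for 10 seeds (NOT results from this repo).
example_counts = [40, 42, 38, 41, 39, 43, 37, 40, 41, 39]
mean_rate, std_rate = success_rate_summary(example_counts)
print(f"{mean_rate:.1f} +/- {std_rate:.1f}")
```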
Architecture
Encoder (CNN/ViT) -> Projector (MLP+BN, hidden=2048) -> Predictor (GRU/Transformer) -> PredProj (MLP+BN)
- Input: 224x224 RGB images
- Embedding dimension: 192
- History size: 3 frames
- Projector: MLP with BatchNorm, hidden_dim=2048
- Training: 10 epochs per task, SIGReg regularization (lambda=0.09)
- Planning: CEM (300 samples, 30 steps, topk=30, horizon=5)
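The planning settings above (300 samples, 30 steps, topk=30, horizon=5) map onto a standard cross-entropy method loop. The following is a minimal, dependency-free sketch of that loop, not the repo's planner: `cem_plan`, `init_std`, and the toy cost function are hypothetical. In the actual models the cost would presumably compare predicted latents against a goal latent.

```python
import random


def cem_plan(cost_fn, action_dim, horizon=5, num_samples=300,
             num_steps=30, topk=30, init_std=1.0, seed=0):
    """Cross-entropy method over flat action sequences.

    Each candidate is a list of length horizon * action_dim.
    Returns the mean of the final sampling distribution.
    """
    rng = random.Random(seed)
    n = horizon * action_dim
    mean = [0.0] * n
    std = [init_std] * n
    for _ in range(num_steps):
        # Sample candidate action sequences from the current Gaussian.
        samples = [[rng.gauss(m, s) for m, s in zip(mean, std)]
                   for _ in range(num_samples)]
        # Keep the topk lowest-cost candidates (the elites).
        elites = sorted(samples, key=cost_fn)[:topk]
        # Refit the Gaussian to the elites, one dimension at a time.
        for i in range(n):
            column = [e[i] for e in elites]
            mean[i] = sum(column) / topk
            var = sum((x - mean[i]) ** 2 for x in column) / topk
            std[i] = max(var ** 0.5, 1e-6)
    return mean


# Toy stand-in for a latent-space cost: squared distance to a fixed target.
target = [0.5, -0.25] * 5


def toy_cost(seq):
    return sum((a - t) ** 2 for a, t in zip(seq, target))


plan = cem_plan(toy_cost, action_dim=2)
```

Here `plan` is a flat sequence of 5 two-dimensional actions. How many of those actions are executed before replanning is a design choice that this sketch does not cover.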
Tasks
| Task | Description | Action Dim |
|---|---|---|
| PushT | 2D push manipulation | 2 |
| TwoRoom | Grid navigation | 2 |
| OGB (Cube) | 3D robotic arm manipulation | 5 |
| DMC (Reacher) | DeepMind Control reacher | 2 |
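The action dimension matters when building inputs, since the Usage section gives the action tensor shape as `[B, T, frameskip * action_dim]`. A small sketch of that arithmetic follows; the helper `action_feature_size` is hypothetical, and `frameskip=4` is an arbitrary illustrative value because the source does not state the one used in training.

```python
# Action dimensions copied from the Tasks table, keyed by checkpoint name.
ACTION_DIM = {"pusht": 2, "tworoom": 2, "ogb": 5, "dmc": 2}


def action_feature_size(task, frameskip):
    """Size of the last axis of the action tensor: frameskip * action_dim."""
    if frameskip < 1:
        raise ValueError("frameskip must be a positive integer")
    return frameskip * ACTION_DIM[task]


# frameskip=4 is a hypothetical value chosen only for illustration.
size = action_feature_size("ogb", frameskip=4)
```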
Repo Structure
checkpoints/
  mobilenetv3/        # MobileNetV3-Large + GRU + FakeQuant
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  efficientnet-b0/    # EfficientNet-B0 + GRU + FakeQuant
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  mobilenetv2/        # MobileNetV2 + GRU + FakeQuant
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  mobilevit-xs/       # MobileViT-XS + GRU + FakeQuant
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  resnet18/           # ResNet-18 + Transformer (FP32 baseline)
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
config.json           # Model configurations for all variants
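Checkpoint paths follow the pattern `checkpoints/<encoder>/<task>.ckpt`, so they can be built programmatically. The helpers below are a hypothetical convenience sketch (`checkpoint_path` and `available_checkpoints` are not part of the repo). Checking for existence before loading is sensible because the tree above lists no `pusht.ckpt` under `mobilevit-xs`, although the Results table references one.

```python
from pathlib import Path

# Directory and task names as they appear in the repo tree.
ENCODER_DIRS = ("mobilenetv3", "efficientnet-b0", "mobilenetv2",
                "mobilevit-xs", "resnet18")
TASKS = ("pusht", "tworoom", "ogb", "dmc")


def checkpoint_path(encoder, task, root="checkpoints"):
    """Build <root>/<encoder>/<task>.ckpt, validating both names."""
    if encoder not in ENCODER_DIRS:
        raise ValueError(f"unknown encoder directory: {encoder!r}")
    if task not in TASKS:
        raise ValueError(f"unknown task: {task!r}")
    return Path(root) / encoder / f"{task}.ckpt"


def available_checkpoints(root="checkpoints"):
    """List the checkpoint files that actually exist under root."""
    paths = (checkpoint_path(e, t, root) for e in ENCODER_DIRS for t in TASKS)
    return [p for p in paths if p.is_file()]


path = checkpoint_path("mobilenetv3", "pusht")
```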
Usage
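import torch

# torch.load returns the full model object here, which needs weights_only=False on PyTorch >= 2.6.
# Only load checkpoints you trust: unpickling can execute arbitrary code.
model = torch.load("checkpoints/mobilenetv3/pusht.ckpt", map_location="cpu", weights_only=False)
model.eval()

with torch.no_grad():
    # Encode an observation (img_tensor is supplied by the caller)
    info = {"pixels": img_tensor}  # [B, T, 3, 224, 224]
    info = model.encode(info)
    # Predict future states (action_tensor is supplied by the caller)
    info["action"] = action_tensor  # [B, T, frameskip * action_dim]
    preds = model.predict(info)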
Citation
Based on LeWorldModel (Maes et al., 2026). On-sensor deployment work for CoRL 2026.
License
Apache 2.0