OnSensorWorldModel

Latent world models based on JEPA (Joint-Embedding Predictive Architecture), designed for on-sensor deployment on the Sony IMX500 vision sensor. The models learn stable latent representations from raw pixels using a two-term loss (prediction + SIGReg) and support goal-conditioned planning in latent space via the cross-entropy method (CEM).
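Schematically, the two-term objective can be written as below. The weight is the lambda value listed under Architecture. The exact form of each term is not spelled out in this card, so the two symbols on the right are placeholders rather than definitions.

```latex
\mathcal{L} \;=\; \mathcal{L}_{\text{pred}} \;+\; \lambda \, \mathcal{L}_{\text{SIGReg}},
\qquad \lambda = 0.09
```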

Model Variants

This repo contains 5 encoder variants across 4 robot control tasks (19 checkpoints total).

Encoders

| Encoder | Predictor | Params | Quantization | Target Hardware |
|---|---|---|---|---|
| MobileNetV3-Large | GRU (2L, 256h) | 7.6M | Fake Quantization (INT8-ready) | IMX500 sensor |
| EfficientNet-B0 | GRU (2L, 256h) | 9.2M | Fake Quantization (INT8-ready) | IMX500 sensor |
| MobileNetV2 | GRU (2L, 256h) | 7.5M | Fake Quantization (INT8-ready) | IMX500 sensor |
| MobileViT-XS | GRU (2L, 256h) | 9.4M | Fake Quantization (INT8-ready) | IMX500 sensor |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | GPU baseline |
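For readers unfamiliar with the term, fake quantization rounds values to an integer grid and immediately maps them back to floating point, so the network sees INT8 rounding error while still running in float. The sketch below shows symmetric per-tensor INT8 fake quantization with NumPy. It is a generic illustration, not necessarily the scheme used to train these checkpoints, and `fake_quantize_int8` is a hypothetical helper name.

```python
import numpy as np

def fake_quantize_int8(x, scale=None):
    """Quantize to a symmetric INT8 grid, then dequantize back to float32.

    Hypothetical helper for illustration only. Returns (dequantized, scale).
    """
    x = np.asarray(x, dtype=np.float32)
    if scale is None:
        # Per-tensor scale chosen so the largest magnitude maps to 127.
        max_abs = float(np.max(np.abs(x)))
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # Round to the nearest integer level and clamp to the symmetric INT8 range.
    q = np.clip(np.round(x / scale), -127, 127)
    # Map the integer levels back to float so downstream ops stay in float32.
    return (q * scale).astype(np.float32), scale

weights = np.array([-1.0, -0.26, 0.0, 0.4, 1.0], dtype=np.float32)
dequantized, scale = fake_quantize_int8(weights)
max_error = float(np.max(np.abs(dequantized - weights)))
print(dequantized, scale, max_error)
```

For values inside the clamping range, the rounding error of each element is at most half a quantization step (`scale / 2`).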

Results

Planning success rates (%, mean +/- std over 10 seeds, 50 episodes each):

| Encoder | Predictor | Params | Quantization | Task | Success Rate | Checkpoint |
|---|---|---|---|---|---|---|
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | pusht | 78.0 +/- 6.8 | checkpoints/mobilenetv3/pusht.ckpt |
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | tworoom | 94.2 +/- 3.8 | checkpoints/mobilenetv3/tworoom.ckpt |
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | ogb | 70.2 +/- 7.9 | checkpoints/mobilenetv3/ogb.ckpt |
| MobileNetV3-Large | GRU (2-layer, 256 hidden) | 7.6M | Fake Quantization (INT8-ready) | dmc | 82.6 +/- 5.5 | checkpoints/mobilenetv3/dmc.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | pusht | 78.6 +/- 6.8 | checkpoints/efficientnet-b0/pusht.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | tworoom | 93.8 +/- 3.7 | checkpoints/efficientnet-b0/tworoom.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | ogb | 71.8 +/- 6.3 | checkpoints/efficientnet-b0/ogb.ckpt |
| EfficientNet-B0 | GRU (2-layer, 256 hidden) | 9.2M | Fake Quantization (INT8-ready) | dmc | 80.2 +/- 6.4 | checkpoints/efficientnet-b0/dmc.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | pusht | 43.4 +/- 4.3 | checkpoints/mobilenetv2/pusht.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | tworoom | 93.8 +/- 4.0 | checkpoints/mobilenetv2/tworoom.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | ogb | 69.2 +/- 7.3 | checkpoints/mobilenetv2/ogb.ckpt |
| MobileNetV2 | GRU (2-layer, 256 hidden) | 7.5M | Fake Quantization (INT8-ready) | dmc | 86.2 +/- 4.3 | checkpoints/mobilenetv2/dmc.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | pusht | 82.6 +/- 4.2 | checkpoints/mobilevit-xs/pusht.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | tworoom | 94.2 +/- 4.2 | checkpoints/mobilevit-xs/tworoom.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | ogb | 68.8 +/- 5.2 | checkpoints/mobilevit-xs/ogb.ckpt |
| MobileViT-XS | GRU (2-layer, 256 hidden) | 9.4M | Fake Quantization (INT8-ready) | dmc | 76.2 +/- 6.6 | checkpoints/mobilevit-xs/dmc.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | pusht | 87.6 +/- 4.3 | checkpoints/resnet18/pusht.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | tworoom | 88.8 +/- 3.6 | checkpoints/resnet18/tworoom.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | ogb | 66.0 +/- 6.5 | checkpoints/resnet18/ogb.ckpt |
| ResNet-18 | Transformer (6L/16H, AR) | 24.4M | None (FP32 baseline) | dmc | 84.8 +/- 5.8 | checkpoints/resnet18/dmc.ckpt |
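To compare the on-sensor variants against the FP32 baseline, the mean success rates above can be loaded into a small dictionary. The sketch below computes an unweighted per-encoder average over the four tasks and the per-task gap to ResNet-18. Neither derived number is reported in this card; both are simple arithmetic on the table, and `task_average` and `gap_to_baseline` are hypothetical helper names.

```python
from statistics import mean

# Mean success rates (%) copied from the results table.
SUCCESS = {
    "MobileNetV3-Large": {"pusht": 78.0, "tworoom": 94.2, "ogb": 70.2, "dmc": 82.6},
    "EfficientNet-B0": {"pusht": 78.6, "tworoom": 93.8, "ogb": 71.8, "dmc": 80.2},
    "MobileNetV2": {"pusht": 43.4, "tworoom": 93.8, "ogb": 69.2, "dmc": 86.2},
    "MobileViT-XS": {"pusht": 82.6, "tworoom": 94.2, "ogb": 68.8, "dmc": 76.2},
    "ResNet-18": {"pusht": 87.6, "tworoom": 88.8, "ogb": 66.0, "dmc": 84.8},
}
BASELINE = "ResNet-18"

def task_average(encoder):
    """Unweighted mean success rate over the four tasks."""
    return mean(SUCCESS[encoder].values())

def gap_to_baseline(encoder, task):
    """Success-rate difference to the FP32 baseline (negative means worse)."""
    return SUCCESS[encoder][task] - SUCCESS[BASELINE][task]

for encoder in SUCCESS:
    gaps = ", ".join(
        f"{task} {gap_to_baseline(encoder, task):+.1f}" for task in SUCCESS[encoder]
    )
    print(f"{encoder:18s} mean={task_average(encoder):.2f}  gaps: {gaps}")
```

The reported standard deviations are several points wide, so small gaps in either direction should be read with care.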

Architecture

Encoder (CNN/ViT) -> Projector (MLP+BN, hidden=2048) -> Predictor (GRU/Transformer) -> PredProj (MLP+BN)
  • Input: 224x224 RGB images
  • Embedding dimension: 192
  • History size: 3 frames
  • Projector: MLP with BatchNorm, hidden_dim=2048
  • Training: 10 epochs per task, SIGReg regularization (lambda=0.09)
  • Planning: CEM (300 samples, 30 steps, topk=30, horizon=5)
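The planning settings above can be read as a standard cross-entropy method loop. The following is a minimal NumPy sketch, assuming that "30 steps" means 30 CEM refinement iterations and that the sampling distribution is a diagonal Gaussian. The cost function is a toy stand-in: in the real pipeline the cost would presumably compare predicted and goal embeddings, which this card does not specify. `cem_plan` and `toy_cost` are hypothetical names.

```python
import numpy as np

def cem_plan(cost_fn, action_dim, horizon=5, num_samples=300,
             num_steps=30, topk=30, seed=0):
    """Return a low-cost action sequence of shape (horizon, action_dim).

    cost_fn maps candidates of shape (num_samples, horizon, action_dim)
    to a cost vector of shape (num_samples,).
    """
    rng = np.random.default_rng(seed)
    mean = np.zeros((horizon, action_dim))
    std = np.ones((horizon, action_dim))
    for _ in range(num_steps):
        # Sample candidate action sequences from the current Gaussian.
        noise = rng.standard_normal((num_samples, horizon, action_dim))
        candidates = mean + std * noise
        costs = cost_fn(candidates)
        # Keep the topk lowest-cost candidates (the elites).
        elites = candidates[np.argsort(costs)[:topk]]
        # Refit the sampling distribution to the elites.
        mean = elites.mean(axis=0)
        std = elites.std(axis=0) + 1e-6
    return mean

GOAL = np.array([1.0, -0.5])

def toy_cost(candidates):
    """Toy dynamics (assumption): the state is the running sum of actions."""
    final_state = candidates.sum(axis=1)
    return np.linalg.norm(final_state - GOAL, axis=1)

plan = cem_plan(toy_cost, action_dim=2)
final_error = float(np.linalg.norm(plan.sum(axis=0) - GOAL))
print(plan.shape, final_error)
```

To plan with one of the checkpoints, `toy_cost` would be replaced by a function that rolls candidate actions through the predictor and scores the resulting latent states against the goal.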

Tasks

| Task | Description | Action Dim |
|---|---|---|
| PushT | 2D push manipulation | 2 |
| TwoRoom | Grid navigation | 2 |
| OGB (Cube) | 3D robotic arm manipulation | 5 |
| DMC (Reacher) | DeepMind Control reacher | 2 |

Repo Structure

checkpoints/
  mobilenetv3/       # MobileNetV3-Large + GRU + FakeQuant
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  efficientnet-b0/   # EfficientNet-B0 + GRU + FakeQuant
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  mobilenetv2/       # MobileNetV2 + GRU + FakeQuant
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  mobilevit-xs/      # MobileViT-XS + GRU + FakeQuant
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
  resnet18/          # ResNet-18 + Transformer (FP32 baseline)
    pusht.ckpt
    tworoom.ckpt
    ogb.ckpt
    dmc.ckpt
config.json          # Model configurations for all variants
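The layout above follows a regular `checkpoints/<encoder>/<task>.ckpt` pattern, so paths can be built programmatically. The sketch below does that and reports which files are absent from a local copy. This is useful because the listing shows no `pusht.ckpt` under `mobilevit-xs/`, while the results table references one. `checkpoint_path` and `missing_checkpoints` are hypothetical helpers, not part of the repo.

```python
from pathlib import Path

# Directory and file names taken from the repo listing.
ENCODERS = ["mobilenetv3", "efficientnet-b0", "mobilenetv2", "mobilevit-xs", "resnet18"]
TASKS = ["pusht", "tworoom", "ogb", "dmc"]

def checkpoint_path(encoder, task, root="checkpoints"):
    """Build the path for one encoder/task pair, validating both names."""
    if encoder not in ENCODERS:
        raise ValueError(f"unknown encoder: {encoder!r}")
    if task not in TASKS:
        raise ValueError(f"unknown task: {task!r}")
    return Path(root) / encoder / f"{task}.ckpt"

def missing_checkpoints(root="checkpoints"):
    """List every encoder/task checkpoint that is not present under root."""
    missing = []
    for encoder in ENCODERS:
        for task in TASKS:
            path = checkpoint_path(encoder, task, root)
            if not path.is_file():
                missing.append(path)
    return missing

if __name__ == "__main__":
    for path in missing_checkpoints():
        print("missing:", path.as_posix())
```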

Usage

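import torch

# This call assumes the checkpoint is a pickled model object, so PyTorch >= 2.6
# needs weights_only=False and the model class must be importable.
# Only load checkpoints from a source you trust.
model = torch.load("checkpoints/mobilenetv3/pusht.ckpt", map_location="cpu", weights_only=False)
model.eval()

# Placeholder inputs: batch of 1, history of 3 frames, PushT action_dim=2.
B, T, action_dim = 1, 3, 2
frameskip = 1  # assumed placeholder; use the value the checkpoint was trained with
img_tensor = torch.rand(B, T, 3, 224, 224)
action_tensor = torch.zeros(B, T, frameskip * action_dim)

with torch.no_grad():
    # Encode an observation
    info = {"pixels": img_tensor}  # [B, T, 3, 224, 224]
    info = model.encode(info)

    # Predict future states
    info["action"] = action_tensor  # [B, T, frameskip * action_dim]
    preds = model.predict(info)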

Citation

Based on LeWorldModel (Maes et al., 2026). On-sensor deployment work for CoRL 2026.

License

Apache 2.0
