# SIAD World Model (Medium)

Satellite Imagery Anticipatory Dynamics - a transformer-based world model for predicting future satellite observations.
## Model Description
This model predicts future satellite imagery based on:
- Current satellite observations (Sentinel-2, Sentinel-1, VIIRS nightlights)
- Climate action variables (rainfall and temperature anomalies)
The model uses JEPA (Joint Embedding Predictive Architecture) with token-based spatial representations.
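In a JEPA-style setup the prediction target is a latent representation rather than pixels. A minimal sketch of that objective (illustrative shapes only, not the model's actual training code):

```python
import torch
import torch.nn.functional as F

# JEPA idea: encode context and target frames, predict the target *latent*
# from the context latent plus the action, and take the loss in latent space.
B, N, D = 1, 256, 1024  # batch, spatial tokens, latent dim (from the config below)

z_context = torch.randn(B, N, D)  # stands in for encoder(obs_context)
z_target = torch.randn(B, N, D)   # stands in for encoder(obs_target); typically no grad
z_pred = torch.randn(B, N, D)     # stands in for transition(z_context, action)

latent_loss = F.mse_loss(z_pred, z_target)  # latent-space loss, not pixel-space
print(latent_loss.item())
```

Predicting in latent space lets the model ignore pixel-level noise it cannot forecast; the separate decoder only maps latents back to pixels for visualization.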
## Architecture
- Size: 378,914,568 parameters (378.9M)
- Type: medium variant
- Latent Dimension: 1024
- Encoder: 8 transformer blocks, 16 heads
- Transition Model: 12 transformer blocks, 16 heads
- Decoder: ConvTranspose (16×16 → 256×256)
- Spatial Tokens: 256 tokens (16×16 grid)
- Input Channels: 8 (Sentinel-2: B2,B3,B4,B8 | Sentinel-1: VV,VH | VIIRS | mask)
- Rollout Horizon: 6 months
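The spatial-token geometry above can be sanity-checked with a little arithmetic (the 16-pixel patch size is inferred from the 256×256 input and the 16×16 token grid; it is an assumption, not read from the model code):

```python
# Back-of-envelope check of the token geometry described above
image_size = 256
grid_size = 16                          # 16x16 spatial token grid
patch_size = image_size // grid_size    # each token covers a 16x16-pixel patch
num_tokens = grid_size * grid_size      # 256 tokens per frame
latent_dim = 1024

# Latent sequence fed to the transition model: [batch, num_tokens, latent_dim]
print(patch_size, num_tokens, latent_dim)
```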
## Training
- Best Val Loss: 0.0131
- Epochs: 85
- Decoder Val Loss: 0.0711 (MSE in pixel space)
## Quick Start

```python
import torch
from transformers import AutoModel

# Load the model from the Hugging Face Hub
model = AutoModel.from_pretrained("ozlabs/siad-wm-medium", trust_remote_code=True)
model.inference_mode()

# Prepare inputs
obs_context = torch.randn(1, 8, 256, 256)  # current observation [B, C, H, W]
actions = torch.randn(1, 6, 2)             # 6 months of climate actions [B, H, 2]

# Run prediction (with the decoder for pixel-space output)
with torch.no_grad():
    z0 = model.encode(obs_context)
    z_pred = model.rollout(z0, actions, H=6)
    x_pred = model.decode(z_pred)  # [1, 6, 8, 256, 256] - decode to pixels

print(f"Predicted 6 months (pixels): {x_pred.shape}")
```
## Visualization (RGB Composite)

```python
import numpy as np
import matplotlib.pyplot as plt

def create_rgb(bands: np.ndarray) -> np.ndarray:
    """Create an RGB composite from an 8-band satellite image."""
    # Sentinel-2 true color: B4 (red), B3 (green), B2 (blue) -> channels 2, 1, 0
    rgb = bands[[2, 1, 0]].transpose(1, 2, 0)  # [H, W, 3]
    for i in range(3):
        channel = rgb[:, :, i]
        # 2-98 percentile contrast stretch per channel
        vmin, vmax = np.percentile(channel, [2, 98])
        rgb[:, :, i] = np.clip((channel - vmin) / (vmax - vmin + 1e-8), 0, 1)
    return rgb

# Visualize the first predicted month
x_first = x_pred[0, 0].cpu().numpy()  # [8, 256, 256]
rgb = create_rgb(x_first)
plt.imshow(rgb)
plt.axis("off")
plt.show()
```
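The same percentile stretch works for other band combinations. A hypothetical `create_false_color` helper (assuming the channel order listed above: B2, B3, B4, B8, VV, VH, VIIRS, mask) builds a NIR/red/green false-color composite, in which healthy vegetation appears bright red:

```python
import numpy as np

def create_false_color(bands: np.ndarray) -> np.ndarray:
    """NIR/red/green false-color composite (B8, B4, B3) from an 8-band image."""
    # Channel order assumed: [B2, B3, B4, B8, VV, VH, VIIRS, mask]
    fc = bands[[3, 2, 1]].transpose(1, 2, 0).astype(np.float64)  # [H, W, 3]
    for i in range(3):
        channel = fc[:, :, i]
        # Same 2-98 percentile contrast stretch as the RGB composite
        vmin, vmax = np.percentile(channel, [2, 98])
        fc[:, :, i] = np.clip((channel - vmin) / (vmax - vmin + 1e-8), 0, 1)
    return fc
```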
## Advanced Usage

```python
# Full forward pass with loss computation
targets = torch.randn(1, 6, 8, 256, 256)  # ground-truth future observations [B, H, C, H_px, W_px]

outputs = model(
    obs_context=obs_context,
    actions_rollout=actions,
    obs_targets=targets,
    return_dict=True,
)

print(f"Loss: {outputs.loss}")
print(f"Predictions: {outputs.predictions.shape}")
print(f"Metrics: {outputs.metrics}")
```
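For a quick per-horizon diagnostic, decoded predictions can be compared against ground truth one month at a time. This sketch uses random tensors with the decoder's output shape; the breakdown is an illustration, not part of the model API:

```python
import torch

# Hypothetical per-horizon error breakdown: given decoded predictions and
# ground-truth frames shaped [B, H, C, H_px, W_px], average the squared
# error over everything except the horizon axis.
x_pred = torch.randn(1, 6, 8, 256, 256)  # stand-in for model.decode(...)
x_true = torch.randn(1, 6, 8, 256, 256)  # stand-in for observed future frames

per_month_mse = ((x_pred - x_true) ** 2).mean(dim=(0, 2, 3, 4))  # shape [6]
for month, mse in enumerate(per_month_mse, start=1):
    print(f"month {month}: MSE = {mse.item():.4f}")
```

Errors typically grow with the horizon, so a per-month curve is more informative than the single aggregate MSE reported above.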
## Model Configuration

This is the medium configuration:

```yaml
latent_dim: 1024
encoder_blocks: 8
encoder_heads: 16
encoder_mlp_dim: 4096
transition_blocks: 12
transition_heads: 16
transition_mlp_dim: 4096
dropout: 0.1
```
## Citation

```bibtex
@misc{siad_world_model,
  title={SIAD: Satellite Imagery Anticipatory Dynamics},
  author={OzLabs.ai},
  year={2025},
  howpublished={\url{https://huggingface.co/ozlabs/siad-wm-medium}},
}
```