SIAD World Model (Medium)

Satellite Imagery Anticipatory Dynamics - A transformer-based world model for predicting future satellite observations.

Model Description

This model predicts future satellite imagery based on:

  • Current satellite observations (Sentinel-2, Sentinel-1, VIIRS nightlights)
  • Climate action variables (rainfall and temperature anomalies)
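A minimal sketch of how the climate action input might be assembled. It assumes the last axis is ordered (rainfall anomaly, temperature anomaly) and that an anomaly is the monthly value minus a climatological baseline. The card does not specify the ordering, baseline, or units, so the helper `make_actions` and its arguments are hypothetical, and the numbers are illustrative only.

```python
import numpy as np

def make_actions(rain_mm, temp_c, rain_baseline, temp_baseline):
    """Build a [1, H, 2] action array from monthly values and baselines (hypothetical helper).

    Channel 0 = rainfall anomaly, channel 1 = temperature anomaly (assumed order).
    """
    rain_anom = np.asarray(rain_mm, dtype=np.float32) - np.asarray(rain_baseline, dtype=np.float32)
    temp_anom = np.asarray(temp_c, dtype=np.float32) - np.asarray(temp_baseline, dtype=np.float32)
    actions = np.stack([rain_anom, temp_anom], axis=-1)  # [H, 2]
    return actions[None]  # add a batch dimension -> [1, H, 2]

# Illustrative six-month inputs (made-up values)
actions = make_actions(
    rain_mm=[80, 95, 60, 40, 30, 55],
    temp_c=[24.0, 25.5, 27.0, 28.0, 27.5, 26.0],
    rain_baseline=[70] * 6,
    temp_baseline=[25.0] * 6,
)
print(actions.shape)  # (1, 6, 2)
```

To pass this to the model, convert it with `torch.from_numpy(actions)`; the result has the same `[1, 6, 2]` shape used in the Quick Start.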

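The model uses a JEPA (Joint Embedding Predictive Architecture) design with token-based spatial representations: observations are encoded into a grid of latent spatial tokens, future states are predicted in that latent space, and a separate decoder maps predicted latents back to pixels.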
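In a JEPA-style objective, the loss is computed between embeddings rather than pixels: a predictor maps the context embedding (here conditioned on an action) to a guess of the next observation's embedding, and that guess is compared with the encoder's embedding of the actual next observation, which is treated as a constant target. The numpy sketch below shows only that objective on `[tokens, dim]` arrays. The random linear `encode` and `predict` functions and the toy sizes are placeholders, not SIAD's modules, and the card does not say whether SIAD uses a separate (for example EMA) target encoder.

```python
import numpy as np

rng = np.random.default_rng(0)
NUM_TOKENS, PATCH_DIM, LATENT_DIM, ACTION_DIM = 256, 32, 64, 2  # toy sizes, not SIAD's

# Placeholder linear maps standing in for the encoder and the action-conditioned predictor
W_enc = rng.normal(scale=0.1, size=(PATCH_DIM, LATENT_DIM))
W_pred = rng.normal(scale=0.1, size=(LATENT_DIM + ACTION_DIM, LATENT_DIM))

def encode(patches):
    """Map [tokens, patch_dim] patches to [tokens, latent_dim] embeddings."""
    return patches @ W_enc

def predict(z, action):
    """Predict next-step embeddings from current embeddings plus a per-step action."""
    a = np.broadcast_to(action, (z.shape[0], ACTION_DIM))  # share the action across tokens
    return np.concatenate([z, a], axis=-1) @ W_pred

def jepa_loss(x_t, x_next, action):
    """MSE between predicted and target embeddings (no pixel reconstruction involved)."""
    z_pred = predict(encode(x_t), action)
    z_target = encode(x_next)  # in training this branch would receive no gradient
    return float(np.mean((z_pred - z_target) ** 2))

x_t = rng.normal(size=(NUM_TOKENS, PATCH_DIM))
x_next = rng.normal(size=(NUM_TOKENS, PATCH_DIM))
print(jepa_loss(x_t, x_next, np.array([0.5, -0.2])))
```

Predicting embeddings lets the model ignore unpredictable pixel-level detail; the pixel-space decoder is then trained separately, which matches the card reporting a distinct decoder loss.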

Architecture

  • Size: 378,914,568 parameters (378.9M)
  • Type: medium variant
  • Latent Dimension: 1024
  • Encoder: 8 transformer blocks, 16 heads
  • Transition Model: 12 transformer blocks, 16 heads
  • Decoder: ConvTranspose (16×16 → 256×256)
  • Spatial Tokens: 256 tokens (16×16 grid)
  • Input Channels: 8 (Sentinel-2: B2,B3,B4,B8 | Sentinel-1: VV,VH | VIIRS | mask)
  • Rollout Horizon: 6 months
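The input layout and token grid can be illustrated with a short numpy sketch. It stacks the eight channels in the order listed above into an `[8, 256, 256]` array, then cuts that array into non-overlapping 16×16-pixel patches, which yields the 16×16 grid of 256 tokens. Treating each token as a 16-pixel patch is an assumption inferred from 256 / 16 = 16; the card does not document how the encoder embeds patches, and `stack_channels` and `patchify` are hypothetical helpers.

```python
import numpy as np

CHANNELS = ["B2", "B3", "B4", "B8", "VV", "VH", "VIIRS", "mask"]  # order from the card
IMG, PATCH = 256, 16
GRID = IMG // PATCH  # 16 -> 16 x 16 = 256 tokens

def stack_channels(layers: dict) -> np.ndarray:
    """Stack per-channel [256, 256] arrays into [8, 256, 256] in the documented order."""
    return np.stack([np.asarray(layers[name], dtype=np.float32) for name in CHANNELS])

def patchify(x: np.ndarray) -> np.ndarray:
    """Split [C, H, W] into [num_patches, C * PATCH * PATCH] patch vectors, row-major over the grid."""
    c, h, w = x.shape
    x = x.reshape(c, h // PATCH, PATCH, w // PATCH, PATCH)  # [C, gh, p, gw, p]
    x = x.transpose(1, 3, 0, 2, 4)  # [gh, gw, C, p, p]
    return x.reshape((h // PATCH) * (w // PATCH), c * PATCH * PATCH)

rng = np.random.default_rng(0)
layers = {name: rng.random((IMG, IMG)) for name in CHANNELS}  # random stand-in data
obs = stack_channels(layers)
tokens = patchify(obs)
print(obs.shape, tokens.shape)  # (8, 256, 256) (256, 2048)
```

Each token here is a raw 2048-value patch vector; a real encoder would project it to the 1024-dimensional latent before the transformer blocks.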

Training

  • Best Val Loss: 0.0131
  • Epochs: 85
  • Decoder Val Loss: 0.0711 (MSE in pixel space)
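The decoder loss is reported as pixel-space MSE. A sketch of that metric on arrays shaped like the Quick Start output `[batch, months, channels, height, width]`, with an extra per-month breakdown. Whether SIAD's reported figure excludes the mask channel or weights pixels is not stated, so this plain unweighted mean is an assumption, and `pixel_mse` is a hypothetical helper.

```python
import numpy as np

def pixel_mse(pred: np.ndarray, target: np.ndarray):
    """Return (overall MSE, per-month MSE) for [B, T, C, H, W] arrays."""
    err = (pred - target) ** 2
    per_month = err.mean(axis=(0, 2, 3, 4))  # average over everything except the month axis
    return float(err.mean()), per_month

rng = np.random.default_rng(0)
target = rng.random((1, 6, 8, 256, 256))
pred = target + 0.1  # constant 0.1 offset -> squared error of about 0.01 everywhere
overall, per_month = pixel_mse(pred, target)
print(round(overall, 4), per_month.shape)
```

The per-month breakdown is useful for rollouts, where error typically grows with the forecast horizon.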

Quick Start

from transformers import AutoModel
import torch

# Load model from HuggingFace Hub
model = AutoModel.from_pretrained("ozlabs/siad-wm-medium", trust_remote_code=True)
model.eval()  # Disable dropout and other training-only behavior

# Prepare inputs
obs_context = torch.randn(1, 8, 256, 256)  # Current observation
actions = torch.randn(1, 6, 2)  # 6-month climate actions

# Run prediction (with decoder for pixel-space output)
with torch.no_grad():
    z0 = model.encode(obs_context)
    z_pred = model.rollout(z0, actions, H=6)
    x_pred = model.decode(z_pred)  # [1, 6, 8, 256, 256] - Decode to pixels

print(f"Predicted 6 months (pixels): {x_pred.shape}")
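`model.rollout(z0, actions, H=6)` takes an initial latent, the action sequence, and a horizon. An action-conditioned world model conventionally rolls out autoregressively: each step feeds the previous predicted latent and that month's action into the transition model. The sketch below shows that loop with a toy transition function. It is an assumption about what `rollout` does internally (the card does not document it); the `[batch, tokens, dim]` latent layout and `toy_transition` are hypothetical.

```python
import numpy as np

def rollout(z0, actions, transition, H):
    """Autoregressive rollout: z_{t+1} = transition(z_t, a_t) for t = 0..H-1.

    z0: [B, N, D] initial latent tokens; actions: [B, H, A].
    Returns [B, H, N, D], the predicted latents for months 1..H.
    """
    z, preds = z0, []
    for t in range(H):
        z = transition(z, actions[:, t])  # feed the model's own prediction back in
        preds.append(z)
    return np.stack(preds, axis=1)

def toy_transition(z, a):
    """Hypothetical stand-in: shift every token by the mean of that step's action."""
    return z + a.mean(axis=-1)[:, None, None]

z0 = np.zeros((1, 256, 1024))  # 256 tokens x 1024 latent dims, matching the card's sizes
actions = np.ones((1, 6, 2))
z_pred = rollout(z0, actions, toy_transition, H=6)
print(z_pred.shape)  # (1, 6, 256, 1024)
```

Because each step consumes its own previous output, errors can compound over the horizon; this is one reason to inspect per-month metrics rather than only the average.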

Visualization (RGB Composite)

import numpy as np
import matplotlib.pyplot as plt

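def create_rgb(bands: np.ndarray) -> np.ndarray:
    """Create a percentile-stretched RGB composite from an [8, H, W] satellite image."""
    # Channels 2, 1, 0 are B4 (red), B3 (green), B2 (blue); cast to float so scaling works in place
    rgb = bands[[2, 1, 0]].transpose(1, 2, 0).astype(np.float32)  # [H, W, 3]

    # Stretch each channel independently between its 2nd and 98th percentiles
    for i in range(3):
        channel = rgb[:, :, i]
        vmin, vmax = np.percentile(channel, [2, 98])
        rgb[:, :, i] = np.clip((channel - vmin) / (vmax - vmin + 1e-8), 0, 1)

    return rgb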

# Visualize first prediction
x_first = x_pred[0, 0].cpu().numpy()  # [8, 256, 256]
rgb = create_rgb(x_first)
plt.imshow(rgb)
plt.axis("off")
plt.show()
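Because the predictions keep all eight channels, derived indices can be computed from them as well. A sketch of NDVI, the standard (NIR - Red) / (NIR + Red) vegetation index, using B8 (index 3) as NIR and B4 (index 2) as Red per the channel order in the Architecture section. It assumes the decoded bands are on a non-negative, reflectance-like scale; the card does not state the normalization, and `ndvi` is a hypothetical helper.

```python
import numpy as np

NIR, RED = 3, 2  # B8 and B4 indices in the 8-channel layout

def ndvi(bands: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """NDVI = (NIR - Red) / (NIR + Red) for an [8, H, W] array; returns [H, W]."""
    nir = bands[NIR].astype(np.float32)
    red = bands[RED].astype(np.float32)
    return (nir - red) / (nir + red + eps)  # eps guards against division by zero

# Synthetic check: NIR = 0.5, Red = 0.1 everywhere
bands = np.zeros((8, 2, 2), dtype=np.float32)
bands[NIR] = 0.5
bands[RED] = 0.1
print(ndvi(bands))  # about 0.667 everywhere
```

With the prediction above, `plt.imshow(ndvi(x_first), cmap="RdYlGn", vmin=-1, vmax=1)` shows vegetation change for the first predicted month.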

Advanced Usage

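# Full forward pass with loss computation
# Ground-truth future observations; the shape is assumed to match the decoded
# predictions ([batch, months, channels, height, width])
targets = torch.randn(1, 6, 8, 256, 256)

outputs = model(
    obs_context=obs_context,
    actions_rollout=actions,
    obs_targets=targets,  # Ground truth for loss
    return_dict=True,
)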

print(f"Loss: {outputs.loss}")
print(f"Predictions: {outputs.predictions.shape}")
print(f"Metrics: {outputs.metrics}")

Model Configuration

This is the medium configuration:

latent_dim: 1024
encoder_blocks: 8
encoder_heads: 16
encoder_mlp_dim: 4096
transition_blocks: 12
transition_heads: 16
transition_mlp_dim: 4096
dropout: 0.1
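A back-of-envelope check of how much of the 378.9M parameters sit in the 20 transformer blocks, assuming standard blocks (Q, K, V, and output projections with biases, a two-layer MLP with biases, two LayerNorms). SIAD's actual block layout is not documented, so this is an estimate; the remainder would be in patch and action embeddings, the decoder, and any components the card does not break out.

```python
def block_params(d: int, mlp: int) -> int:
    """Parameters in one standard transformer block (weights + biases)."""
    attn = 4 * d * d + 4 * d            # Q, K, V, and output projections
    ffn = d * mlp + mlp + mlp * d + d   # two linear layers
    norms = 2 * 2 * d                   # two LayerNorms (scale + shift)
    return attn + ffn + norms

d, mlp = 1024, 4096                     # latent_dim and *_mlp_dim from the config
encoder = 8 * block_params(d, mlp)      # encoder_blocks: 8
transition = 12 * block_params(d, mlp)  # transition_blocks: 12
total_blocks = encoder + transition
print(f"per block: {block_params(d, mlp):,}")  # per block: 12,596,224
print(f"all 20 blocks: {total_blocks:,}")      # all 20 blocks: 251,924,480
print(f"share of 378,914,568: {total_blocks / 378_914_568:.1%}")
```

The head count (16) does not change the parameter count, since heads split the same `d × d` projections.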

Citation

@misc{siad_world_model,
    title={SIAD: Satellite Imagery Anticipatory Dynamics},
    author={OzLabs.ai},
    year={2025},
    howpublished={\url{https://huggingface.co/ozlabs/siad-wm-medium}},
}

