# Sat-JEPA-Diff

**Bridging Self-Supervised Learning and Generative Diffusion for Satellite Image Forecasting**

Accepted at ICLR 2026 — Machine Learning for Remote Sensing (ML4RS) Workshop

## Model Description
Sat-JEPA-Diff predicts future satellite images (t → t+1) by combining self-supervised semantic prediction with generative diffusion:
- IJEPA module encodes the current image and predicts future semantic embeddings (what will be there)
- Conditioning adapter transforms these embeddings into cross-attention signals
- Frozen Stable Diffusion 3.5 Medium + LoRA generates high-fidelity RGB imagery (how it looks)
This approach produces sharp, realistic predictions that preserve roads, buildings, and vegetation boundaries — details that traditional deterministic methods blur away.
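To make the three-stage flow concrete, here is a shape-level sketch of the pipeline. The function names and the adapter's output shape are illustrative, not the repository's API; the dimensions come from the architecture table below (ViT-Base with patch size 8 on a 128×128 input gives (128/8)² = 256 tokens of dim 768, the predictor uses embed dim 384).

```python
import numpy as np

# Shape-level sketch of Sat-JEPA-Diff (function names are illustrative).
def encode(img):                 # img: (3, 128, 128) RGB frame at time t
    return np.zeros((256, 768))  # ViT-B tokens: (128/8)^2 patches, dim 768

def predict_future(tokens):      # IJEPA predictor, embed dim 384
    return np.zeros((256, 384))  # predicted t+1 semantic embeddings

def adapt(embeddings):           # conditioning adapter -> cross-attention signals
    return np.zeros((256, 384))  # output shape here is an assumption

def diffuse(cond):               # frozen SD 3.5 Medium + LoRA, decodes to RGB
    return np.zeros((3, 128, 128))

rgb_t = np.zeros((3, 128, 128))
rgb_t1 = diffuse(adapt(predict_future(encode(rgb_t))))
print(rgb_t1.shape)  # (3, 128, 128)
```

The key design point is the division of labor: the IJEPA branch commits to *what* the scene will contain in embedding space, and the diffusion branch only has to render *how* it looks.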
## Key Results
| Model | L1 ↓ | MSE ↓ | PSNR ↑ | SSIM ↑ | GSSIM ↑ | LPIPS ↓ | FID ↓ |
|---|---|---|---|---|---|---|---|
| **Deterministic Baselines** | | | | | | | |
| Default | 0.0131 | 0.0008 | 37.52 | 0.9361 | 0.7858 | 0.0708 | 0.6959 |
| PredRNN | 0.0117 | 0.0005 | 38.38 | 0.9476 | 0.7836 | 0.0726 | 9.9720 |
| SimVP v2 | 0.0131 | 0.0006 | 37.63 | 0.9391 | 0.7719 | 0.0928 | 18.7208 |
| **Generative Models** | | | | | | | |
| Stable Diff. 3.5 | 0.0175 | 0.0005 | 32.98 | 0.8398 | 0.8711 | 0.4528 | 0.1533 |
| MCVD | 0.0314 | 0.0031 | 31.28 | 0.8637 | 0.7665 | 0.1890 | 0.1956 |
| Ours | 0.0158 | 0.0004 | 33.81 | 0.8672 | 0.8984 | 0.4449 | 0.1475 |
Our model achieves +11% GSSIM over the best baseline, confirming superior preservation of geospatial boundaries and structural gradients.
## Architecture Details
| Component | Specification |
|---|---|
| IJEPA Encoder | ViT-Base, patch size 8, input 128×128 |
| IJEPA Predictor | 6-layer transformer, embed dim 384 |
| Conditioning Adapter | ~25M params, multi-stream fusion |
| Diffusion Backbone | SD 3.5 Medium (frozen) + LoRA (rank 8, alpha 16) |
| VAE | 8× spatial compression, 16 latent channels |
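The table's dimensions can be sanity-checked with a little arithmetic: the VAE maps a 128×128 input to a 16×16 latent with 16 channels, and the patch-8 ViT produces 256 tokens.

```python
# Latent and token dimensions implied by the architecture table.
def latent_shape(height, width, compression=8, channels=16):
    """VAE latent shape: 8x spatial compression, 16 latent channels."""
    return (channels, height // compression, width // compression)

def num_patch_tokens(size, patch=8):
    """ViT token count for a square input with the given patch size."""
    return (size // patch) ** 2

print(latent_shape(128, 128))  # (16, 16, 16)
print(num_patch_tokens(128))   # 256
```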
## How to Use

> **Note:** This is a custom PyTorch checkpoint, not a standard `transformers` model. You need the source code from the GitHub repository to load and run inference.
### 1. Setup

```bash
git clone https://github.com/VU-AIML/SAT-JEPA-DIFF.git
cd SAT-JEPA-DIFF

conda create -n satjepa python=3.12
conda activate satjepa

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install diffusers transformers peft accelerate
pip install rasterio matplotlib pyyaml lpips
```

You need a Hugging Face token with access to Stable Diffusion 3.5 Medium:

```bash
export HF_TOKEN=hf_your_token_here
```
### 2. Download the Checkpoint

```python
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="kursatkomurcu/SAT-JEPA-DIFF",
    filename="s2_future_jepa-best.pth.tar",
)
```

Or via CLI:

```bash
huggingface-cli download kursatkomurcu/SAT-JEPA-DIFF s2_future_jepa-best.pth.tar --local-dir ./checkpoints
```
### 3. Run Inference

```bash
cd SAT-JEPA-DIFF/src
python inference.py \
    --checkpoint /path/to/s2_future_jepa-best.pth.tar \
    --output_dir ./results \
    --diffusion_steps 20 \
    --noise_strength 0.35
```
### 4. Programmatic Usage

```python
import torch

from helper import init_model
from sd_models import load_sd_model
from sd_joint_loss import diffusion_sample
from inference import load_model, predict_next_frame, load_and_resize_tif

device = torch.device("cuda")

# Load all components from the checkpoint
encoder, predictor, sd_state, embed_dim = load_model(
    checkpoint_path="path/to/s2_future_jepa-best.pth.tar",
    device=device,
)

# Load a Sentinel-2 GeoTIFF (128x128, [0, 1] range)
rgb_t = load_and_resize_tif("path/to/sentinel2_image.tif", target_size=128)

# Predict the next time step
rgb_t1_pred = predict_next_frame(
    rgb_t=rgb_t,
    encoder=encoder,
    predictor=predictor,
    sd_state=sd_state,
    device=device,
    num_diffusion_steps=20,
    noise_strength=0.35,
)
# rgb_t1_pred: (3, 128, 128) tensor in [0, 1]
```
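Since the prediction comes back as a (3, 128, 128) float tensor in [0, 1], a small channels-last uint8 conversion is handy before writing the result with PIL or matplotlib. This helper is not part of the repository; for a GPU torch tensor, call `.detach().cpu().numpy()` first.

```python
import numpy as np

def to_uint8_image(pred):
    """Convert a (3, H, W) float array in [0, 1] to an (H, W, 3) uint8 image."""
    arr = np.asarray(pred, dtype=np.float32)
    arr = np.clip(arr, 0.0, 1.0)                # guard against small overshoot
    arr = (arr * 255.0 + 0.5).astype(np.uint8)  # round instead of truncating
    return np.transpose(arr, (1, 2, 0))         # channels-last for image writers

img = to_uint8_image(np.random.rand(3, 128, 128))
print(img.shape, img.dtype)  # (128, 128, 3) uint8
```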
## Checkpoint Contents

The `.pth.tar` file contains the following keys:

| Key | Description |
|---|---|
| `encoder` | IJEPA ViT encoder state dict |
| `predictor` | IJEPA predictor state dict |
| `target_encoder` | EMA target encoder state dict |
| `proj_head` | Projection head (768 → 64) state dict |
| `cond_adapter` | Multi-caption conditioning adapter state dict |
| `lora_state_dict` | LoRA weights for SD 3.5 UNet |
| `prompt_embeds` | Pre-encoded text prompt embeddings |
| `pooled_prompt_embeds` | Pooled prompt embeddings |
| `config` | Training configuration dict |
| `epoch` | Training epoch |
| `optimizer` | Optimizer state |
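A quick way to sanity-check a downloaded checkpoint is to load it on CPU and list its top-level keys against the table above. The sketch below demonstrates the pattern on a tiny synthetic checkpoint; point `path` at the real `s2_future_jepa-best.pth.tar` instead.

```python
import os
import tempfile

import torch

def inspect_checkpoint(path):
    """Load a .pth.tar checkpoint on CPU and summarize its top-level keys."""
    ckpt = torch.load(path, map_location="cpu")
    return {k: type(v).__name__ for k, v in ckpt.items()}

# Synthetic stand-in mirroring a few of the keys listed above.
path = os.path.join(tempfile.mkdtemp(), "demo.pth.tar")
torch.save({"encoder": {}, "epoch": 12, "config": {"lr": 1e-4}}, path)
print(inspect_checkpoint(path))  # {'encoder': 'dict', 'epoch': 'int', 'config': 'dict'}
```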
## Training
See the GitHub repository for full training instructions. Training takes approximately 5 days on a single NVIDIA RTX 5090 (24GB).
## Dataset
Sentinel-2 RGB imagery (10m GSD) paired with Alpha Earth Foundation Model embeddings across 100 global Regions of Interest (2017–2024). Available on Zenodo.
## Requirements
- Python 3.12+
- PyTorch 2.0+ with CUDA
- ~24GB GPU VRAM (RTX 3090/4090/5090 or A100)
- Stable Diffusion 3.5 Medium access via HF token
## Citation

```bibtex
@inproceedings{komurcu2026satjepadiff,
  title={Sat-{JEPA}-Diff: Bridging Self-Supervised Learning and Generative Diffusion for Remote Sensing},
  author={Kursat Komurcu and Linas Petkevicius},
  booktitle={4th ICLR Workshop on Machine Learning for Remote Sensing (Main Track)},
  year={2026},
  url={https://openreview.net/forum?id=WBHfQLbgZR}
}
```
## Acknowledgments
This project was funded by the European Union (project No S-MIP-23-45) under the agreement with the Research Council of Lithuania (LMTLT).