Sat-JEPA-Diff

Bridging Self-Supervised Learning and Generative Diffusion for Satellite Image Forecasting

Accepted at ICLR 2026 — Machine Learning for Remote Sensing (ML4RS) Workshop


Model Description

Sat-JEPA-Diff predicts future satellite images (t → t+1) by combining self-supervised semantic prediction with generative diffusion:

  1. IJEPA module encodes the current image and predicts future semantic embeddings (what will be there)
  2. Conditioning adapter transforms these embeddings into cross-attention signals
  3. Frozen Stable Diffusion 3.5 Medium + LoRA generates high-fidelity RGB imagery (how it looks)

This approach produces sharp, realistic predictions that preserve roads, buildings, and vegetation boundaries — details that traditional deterministic methods blur away.
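The three stages above can be followed at the shape level. The numbers below come from the architecture table (ViT-Base encoder, patch size 8, 128×128 input; VAE with 8× compression and 16 latent channels); the modules themselves live in the GitHub repository, so this is only bookkeeping, not a runnable model:

```python
# Shape bookkeeping for the three-stage pipeline (values from the
# architecture table; actual modules are in the GitHub repo).
image_size, patch_size = 128, 8
embed_dim = 768                        # ViT-Base embedding width

# 1) IJEPA encoder: 128x128 image -> 16x16 grid of patch tokens
num_patches = (image_size // patch_size) ** 2
enc_out = (num_patches, embed_dim)     # (256, 768)

# 2) Predictor keeps the token grid; the adapter maps these predicted
#    t+1 embeddings into cross-attention conditioning for the denoiser
pred_out = enc_out

# 3) VAE latent the diffusion backbone denoises: 8x compression, 16 ch
latent = (16, image_size // 8, image_size // 8)   # (16, 16, 16)

print(num_patches, enc_out, latent)
```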

Key Results

| Model | L1 ↓ | MSE ↓ | PSNR ↑ | SSIM ↑ | GSSIM ↑ | LPIPS ↓ | FID ↓ |
|---|---|---|---|---|---|---|---|
| *Deterministic baselines* | | | | | | | |
| Default | 0.0131 | 0.0008 | 37.52 | 0.9361 | 0.7858 | 0.0708 | 0.6959 |
| PredRNN | 0.0117 | 0.0005 | 38.38 | 0.9476 | 0.7836 | 0.0726 | 9.9720 |
| SimVP v2 | 0.0131 | 0.0006 | 37.63 | 0.9391 | 0.7719 | 0.0928 | 18.7208 |
| *Generative models* | | | | | | | |
| Stable Diff. 3.5 | 0.0175 | 0.0005 | 32.98 | 0.8398 | 0.8711 | 0.4528 | 0.1533 |
| MCVD | 0.0314 | 0.0031 | 31.28 | 0.8637 | 0.7665 | 0.1890 | 0.1956 |
| **Ours** | 0.0158 | 0.0004 | 33.81 | 0.8672 | **0.8984** | 0.4449 | **0.1475** |

Our model improves GSSIM by roughly 11 points over the best deterministic baseline (0.8984 vs 0.7858), confirming superior preservation of geospatial boundaries and structural gradients.

Architecture Details

| Component | Specification |
|---|---|
| IJEPA Encoder | ViT-Base, patch size 8, input 128×128 |
| IJEPA Predictor | 6-layer transformer, embed dim 384 |
| Conditioning Adapter | ~25M params, multi-stream fusion |
| Diffusion Backbone | SD 3.5 Medium (frozen) + LoRA (rank 8, alpha 16) |
| VAE | 8× spatial compression, 16 latent channels |
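The LoRA configuration (rank 8, alpha 16) implies a scaling factor of alpha/rank = 2 on the low-rank update added to each frozen weight. A minimal numpy sketch of how such an update merges into a base weight, with an illustrative layer size rather than the actual SD 3.5 dimensions:

```python
import numpy as np

rank, alpha = 8, 16
d_out, d_in = 64, 64                   # illustrative size, not SD 3.5's

W = np.random.randn(d_out, d_in)       # frozen base weight
A = np.random.randn(rank, d_in) * 0.01 # trainable down-projection
B = np.zeros((d_out, rank))            # B starts at zero: no-op at init

scale = alpha / rank                   # = 2.0
W_merged = W + scale * (B @ A)         # LoRA update merged into W

print(scale)                           # 2.0
```

Because `B` is zero-initialized, the adapted model is exactly the frozen backbone at the start of training, which is the standard LoRA design choice.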

How to Use

Note: This is a custom PyTorch checkpoint, not a standard transformers model. You need the source code from the GitHub repository to load and run inference.

1. Setup

```bash
git clone https://github.com/VU-AIML/SAT-JEPA-DIFF.git
cd SAT-JEPA-DIFF

conda create -n satjepa python=3.12
conda activate satjepa

pip install torch torchvision --index-url https://download.pytorch.org/whl/cu121
pip install diffusers transformers peft accelerate
pip install rasterio matplotlib pyyaml lpips
```

You need a Hugging Face token with access to Stable Diffusion 3.5 Medium:

```bash
export HF_TOKEN=hf_your_token_here
```

2. Download the Checkpoint

```python
from huggingface_hub import hf_hub_download

checkpoint_path = hf_hub_download(
    repo_id="kursatkomurcu/SAT-JEPA-DIFF",
    filename="s2_future_jepa-best.pth.tar",
)
```

Or via CLI:

```bash
huggingface-cli download kursatkomurcu/SAT-JEPA-DIFF s2_future_jepa-best.pth.tar --local-dir ./checkpoints
```

3. Run Inference

```bash
cd SAT-JEPA-DIFF/src

python inference.py \
    --checkpoint /path/to/s2_future_jepa-best.pth.tar \
    --output_dir ./results \
    --diffusion_steps 20 \
    --noise_strength 0.35
```

4. Programmatic Usage

```python
import torch
from helper import init_model
from sd_models import load_sd_model
from sd_joint_loss import diffusion_sample
from inference import load_model, predict_next_frame, load_and_resize_tif

device = torch.device("cuda")

# Load all components from checkpoint
encoder, predictor, sd_state, embed_dim = load_model(
    checkpoint_path="path/to/s2_future_jepa-best.pth.tar",
    device=device,
)

# Load a Sentinel-2 GeoTIFF (128x128, [0,1] range)
rgb_t = load_and_resize_tif("path/to/sentinel2_image.tif", target_size=128)

# Predict next time step
rgb_t1_pred = predict_next_frame(
    rgb_t=rgb_t,
    encoder=encoder,
    predictor=predictor,
    sd_state=sd_state,
    device=device,
    num_diffusion_steps=20,
    noise_strength=0.35,
)
# rgb_t1_pred: (3, 128, 128) tensor in [0, 1]
```

Checkpoint Contents

The `.pth.tar` file contains the following keys:

| Key | Description |
|---|---|
| `encoder` | IJEPA ViT encoder state dict |
| `predictor` | IJEPA predictor state dict |
| `target_encoder` | EMA target encoder state dict |
| `proj_head` | Projection head (768 → 64) state dict |
| `cond_adapter` | Multi-caption conditioning adapter state dict |
| `lora_state_dict` | LoRA weights for the SD 3.5 backbone |
| `prompt_embeds` | Pre-encoded text prompt embeddings |
| `pooled_prompt_embeds` | Pooled prompt embeddings |
| `config` | Training configuration dict |
| `epoch` | Training epoch |
| `optimizer` | Optimizer state |
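In practice the archive is opened with `torch.load(path, map_location="cpu")`. The stand-in dict below mirrors the documented keys so the example runs without the checkpoint file; which subset suffices for inference (versus resuming training) is an assumption here, inferred from the key descriptions:

```python
# Stand-in for: ckpt = torch.load("s2_future_jepa-best.pth.tar",
#                                 map_location="cpu")
# The keys mirror the table above; values are placeholders.
ckpt = {
    "encoder": {}, "predictor": {}, "target_encoder": {},
    "proj_head": {}, "cond_adapter": {}, "lora_state_dict": {},
    "prompt_embeds": None, "pooled_prompt_embeds": None,
    "config": {}, "epoch": 0, "optimizer": {},
}

# Assumed split: these entries drive generation; the rest (EMA target
# encoder, optimizer state, epoch) are only needed to resume training.
INFERENCE_KEYS = {
    "encoder", "predictor", "proj_head", "cond_adapter",
    "lora_state_dict", "prompt_embeds", "pooled_prompt_embeds",
}
inference_state = {k: v for k, v in ckpt.items() if k in INFERENCE_KEYS}

print(sorted(inference_state))
```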

Training

See the GitHub repository for full training instructions. Training takes approximately 5 days on a single NVIDIA RTX 5090 (32 GB).

Dataset

Sentinel-2 RGB imagery (10m GSD) paired with Alpha Earth Foundation Model embeddings across 100 global Regions of Interest (2017–2024). Available on Zenodo.
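At 10 m ground sampling distance, each 128×128 model input covers about 1.28 km per side. A quick sanity check of that footprint arithmetic:

```python
gsd_m = 10           # Sentinel-2 RGB ground sampling distance (m/pixel)
chip_px = 128        # model input size in pixels

side_km = chip_px * gsd_m / 1000   # 1.28 km per side
area_km2 = side_km ** 2            # ~1.64 km^2 per chip

print(side_km, area_km2)
```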

Citation

```bibtex
@inproceedings{komurcu2026satjepadiff,
  title={Sat-{JEPA}-Diff: Bridging Self-Supervised Learning and Generative Diffusion for Remote Sensing},
  author={Kursat Komurcu and Linas Petkevicius},
  booktitle={4th ICLR Workshop on Machine Learning for Remote Sensing (Main Track)},
  year={2026},
  url={https://openreview.net/forum?id=WBHfQLbgZR}
}
```

Acknowledgments

This project was funded by the European Union (project No S-MIP-23-45) under the agreement with the Research Council of Lithuania (LMTLT).
