# The Well: Diffusion & JEPA for PDE Dynamics
A conditional diffusion model (DDPM/DDIM) and a spatial JEPA trained to predict the evolution of 2D physics simulations from The Well dataset collection by Polymathic AI.

Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. Run autoregressively, it generates multi-step rollout trajectories.
## Architecture
### Conditional DDPM (62M parameters)
| Component | Details |
|---|---|
| Backbone | U-Net with 4 resolution levels (64→128→256→512 channels) |
| Conditioning | Previous frame concatenated to noisy target along channel dim |
| Time encoding | Sinusoidal positional embedding → MLP (256-d) |
| Residual blocks | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| Attention | Multi-head self-attention at bottleneck (16x48 spatial, 768 tokens) |
| Noise schedule | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| Parameterization | Epsilon-prediction (predict noise) |
| Sampling | DDPM (1000 steps) or DDIM (50 steps, deterministic) |
```text
Input: [B, 8, 128, 384] ← 4ch noisy target + 4ch condition
  ↓ Conv3x3 → 64ch
  ↓ Level 0: 2×ResBlock(64), Downsample → 64×64×192
  ↓ Level 1: 2×ResBlock(128), Downsample → 128×32×96
  ↓ Level 2: 2×ResBlock(256), Downsample → 256×16×48
  ↓ Level 3: 2×ResBlock(512) + SelfAttention
  ↓ Middle: ResBlock + Attention + ResBlock (512ch)
  ↑ Level 3: 3×ResBlock(512) + Attention, Upsample
  ↑ Level 2: 3×ResBlock(256), Upsample
  ↑ Level 1: 3×ResBlock(128), Upsample
  ↑ Level 0: 3×ResBlock(64)
  ↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384] ← predicted noise
```
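As a concrete sketch of the noise schedule and epsilon-parameterization above (the tensor shapes match this model, but the function names are illustrative, not this repo's API):

```python
import torch

# Linear beta schedule from the table above: 1e-4 -> 0.02 over 1000 steps.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)  # cumulative product, shrinks toward 0

def q_sample(x0, t, noise):
    """Forward process: x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * noise."""
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise

# Epsilon-prediction: the U-Net sees [x_t, condition] along the channel dim
# and is trained with an MSE loss to recover `noise`.
x0 = torch.randn(2, 4, 128, 384)       # clean next-frame target
t = torch.randint(0, T, (2,))
noise = torch.randn_like(x0)
x_t = q_sample(x0, t, noise)           # noisy target fed to the U-Net
```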
### Spatial JEPA (1.8M trainable parameters)
| Component | Details |
|---|---|
| Online encoder | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| Target encoder | EMA copy of online encoder (decay 0.996 → 1.0 cosine schedule) |
| Predictor | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| Loss | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |
The JEPA learns compressed dynamics representations without generating pixels, useful for downstream tasks and transfer learning.
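A minimal sketch of the VICReg-style regularizer described in the table, assuming the channel-averaged setup (function name and hyperparameters are illustrative, not this repo's exact code):

```python
import torch

def vicreg_reg(z, gamma=1.0, eps=1e-4):
    """Variance/covariance regularizer on channel-averaged features.

    z: [B, C, H, W] spatial latents from the online encoder. Features are
    spatially averaged to [B, C] before regularizing.
    """
    f = z.mean(dim=(2, 3))                      # [B, C]
    f = f - f.mean(dim=0)                       # center per channel
    # Variance term: push each channel's std above gamma (anti-collapse).
    std = torch.sqrt(f.var(dim=0) + eps)
    var_loss = torch.relu(gamma - std).mean()
    # Covariance term: penalize off-diagonal correlations between channels.
    B, C = f.shape
    cov = (f.T @ f) / (B - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = (off_diag ** 2).sum() / C
    return var_loss, cov_loss
```

A fully collapsed batch (identical features) is heavily penalized by the variance term, which is what keeps the similarity loss from finding a trivial solution.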
## Training

### Dataset
Trained on turbulent_radiative_layer_2D from The Well (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):
- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories × 101 timesteps; 7,200 training samples
- 6.9 GB total (HDF5 format)
### Diffusion Training Config
| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |
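The warmup-plus-cosine schedule in the table can be realized with a standard `LambdaLR` multiplier (a sketch; `total_steps` here is a placeholder, not the run's actual length):

```python
import math
import torch

warmup_steps, total_steps = 500, 90_000  # 500-step warmup as in the table

def lr_lambda(step):
    """Multiplier on the base lr=1e-4: linear warmup, then cosine decay to 0."""
    if step < warmup_steps:
        return step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))

params = [torch.nn.Parameter(torch.zeros(1))]
opt = torch.optim.AdamW(params, lr=1e-4, weight_decay=0.01)
sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)  # call sched.step() each step
```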
### Diffusion Training Results
| Metric | Value |
|---|---|
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |
Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the WandB run.
### JEPA Training Config
| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=3e-4, wd=0.05) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 16 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| EMA schedule | Cosine 0.996 → 1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~1.5 hours |
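The EMA target-encoder update with the cosine decay ramp from the table can be sketched as follows (names are illustrative, not this repo's API):

```python
import copy
import math
import torch

def ema_decay(step, total_steps, tau_base=0.996):
    """Cosine ramp of the EMA decay from tau_base (step 0) to 1.0 (final step)."""
    return 1.0 - (1.0 - tau_base) * 0.5 * (1.0 + math.cos(math.pi * step / total_steps))

@torch.no_grad()
def ema_update(target, online, tau):
    """target <- tau * target + (1 - tau) * online, parameter by parameter."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.mul_(tau).add_(p_o, alpha=1.0 - tau)

online = torch.nn.Linear(4, 4)    # stand-in for the online encoder
target = copy.deepcopy(online)    # EMA copy, never updated by gradients
ema_update(target, online, ema_decay(step=0, total_steps=1000))
```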
### JEPA Training Results
| Metric | Value |
|---|---|
| Final train loss | 4.07 |
| Similarity (sim) | 0.079 |
| Variance (VICReg) | 1.476 |
| Covariance (VICReg) | 0.578 |
Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99), plateauing around epoch 50. The VICReg regularization keeps the representations from collapsing while the similarity loss learns dynamics prediction.
Full JEPA training metrics available on the WandB run.
## Usage

### Installation

```bash
pip install the_well torch einops tqdm h5py matplotlib "wandb[media]"
```
### Inference (generate next frame)

```python
import torch

from unet import UNet
from diffusion import GaussianDiffusion

# Load the trained model
device = "cuda"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)
ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a condition frame [1, 4, 128, 384]:
x_cond = ...  # your input frame
x_pred = model.sample_ddim(x_cond, steps=50)  # fast DDIM sampling
```
### JEPA inference (extract dynamics embeddings)

```python
import torch

from jepa import JEPA

device = "cuda"
model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)
ckpt = torch.load("jepa_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a frame [1, 4, 128, 384]:
x = ...  # your input frame
z = model.online_encoder(x)  # [1, 128, 16, 48] spatial latent map
```
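The spatial latent map can then be average-pooled into a feature vector for downstream probing. A hypothetical setup, not part of this repo:

```python
import torch

z = torch.randn(1, 128, 16, 48)   # stand-in for model.online_encoder(x)
feat = z.mean(dim=(2, 3))         # global average pool -> [1, 128]
probe = torch.nn.Linear(128, 1)   # linear probe trained on frozen JEPA features
y_hat = probe(feat)               # [1, 1] predicted scalar quantity
```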
### Autoregressive rollout

```python
# Generate a 20-step trajectory
trajectory = [x_cond]
cond = x_cond
for step in range(20):
    pred = model.sample_ddim(cond, steps=50, eta=0.0)
    trajectory.append(pred)
    cond = pred  # feed prediction back as next condition
```
### Training from scratch

```bash
# Download the data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D

# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 8 --epochs 100 --wandb

# Train JEPA
python train_jepa.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 16 --epochs 100 --wandb
```
### Streaming from HuggingFace (no download needed)

```bash
python train_diffusion.py --streaming --batch_size 4
```
## Project Structure

| File | Description |
|---|---|
| `unet.py` | U-Net with time conditioning, skip connections, self-attention |
| `diffusion.py` | DDPM/DDIM framework: noise schedule, training loss, sampling |
| `jepa.py` | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| `data_pipeline.py` | Data loading from The Well (streaming HF or local HDF5) |
| `train_diffusion.py` | Diffusion training with eval, video logging, checkpointing |
| `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
| `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
| `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748 MB) |
| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23 MB) |
## Evaluation Details
Every 5 epochs, the training script runs:
- Single-step evaluation: DDIM-50 sampling on 4 validation batches, MSE against ground truth
- Multi-step rollout: 10-step autoregressive prediction from a validation sample
- Video logging: Side-by-side GT vs Prediction video logged to WandB as mp4
- Comparison images: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
- Rollout MSE curve: Per-step MSE showing prediction degradation over horizon
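The per-step rollout MSE curve can be computed along these lines (a sketch; `sample_ddim` matches the inference snippet above, `gt_frames` is a list of ground-truth frames aligned with the horizon):

```python
import torch

def rollout_mse(model, x0, gt_frames, steps=10, ddim_steps=50):
    """Autoregressive rollout with per-step MSE against ground truth."""
    cond, errs = x0, []
    for t in range(steps):
        pred = model.sample_ddim(cond, steps=ddim_steps, eta=0.0)
        errs.append(torch.mean((pred - gt_frames[t]) ** 2).item())
        cond = pred  # feed the prediction back as the next condition
    return errs
```

Plotting `errs` against the step index gives the degradation curve: errors typically grow with horizon because each step conditions on an imperfect prediction.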
## The Well Dataset

The Well is a 15 TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well; just change `--dataset`:

```bash
python train_diffusion.py --dataset active_matter                 # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow                    # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion # 154 GB
```
## Citation

```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```
## License
Apache 2.0