The Well: Diffusion & JEPA for PDE Dynamics

A conditional diffusion model (DDPM/DDIM) and a Spatial JEPA trained to predict the evolution of 2D physics simulations from The Well dataset collection by Polymathic AI.

Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. It can be run autoregressively to generate multi-step rollout trajectories.

Architecture

Conditional DDPM (62M parameters)

| Component | Details |
| --- | --- |
| Backbone | U-Net with 4 resolution levels (64→128→256→512 channels) |
| Conditioning | Previous frame concatenated to the noisy target along the channel dim |
| Time encoding | Sinusoidal positional embedding → MLP (256-d) |
| Residual blocks | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| Attention | Multi-head self-attention at the bottleneck (16×48 spatial, 768 tokens) |
| Noise schedule | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| Parameterization | Epsilon-prediction (predict the noise) |
| Sampling | DDPM (1000 steps) or DDIM (50 steps, deterministic) |
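The linear beta schedule and the forward (noising) process it defines can be sketched in NumPy — a minimal illustration of the standard DDPM construction, not the project's `diffusion.py`:

```python
import numpy as np

# Linear beta schedule: 1e-4 -> 0.02 over 1000 timesteps.
T = 1000
betas = np.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = np.cumprod(alphas)  # \bar{alpha}_t, strictly decreasing in t

def q_sample(x0, t, eps):
    """Forward process: draw x_t given x_0 and noise eps ~ N(0, I)."""
    return np.sqrt(alpha_bar[t]) * x0 + np.sqrt(1.0 - alpha_bar[t]) * eps

# With epsilon-prediction, the network sees (x_t, condition, t) and
# regresses eps under an MSE loss.
x0 = np.random.randn(4, 8, 8)
eps = np.random.randn(*x0.shape)
x_t = q_sample(x0, 500, eps)
```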
```
Input: [B, 8, 128, 384]  ← 4ch noisy target + 4ch condition
  ↓ Conv3x3 → 64ch
  ↓ Level 0: 2×ResBlock(64),   Downsample → 64×64×192
  ↓ Level 1: 2×ResBlock(128),  Downsample → 128×32×96
  ↓ Level 2: 2×ResBlock(256),  Downsample → 256×16×48
  ↓ Level 3: 2×ResBlock(512) + SelfAttention
  ↓ Middle:  ResBlock + Attention + ResBlock (512ch)
  ↑ Level 3: 3×ResBlock(512) + Attention, Upsample
  ↑ Level 2: 3×ResBlock(256), Upsample
  ↑ Level 1: 3×ResBlock(128), Upsample
  ↑ Level 0: 3×ResBlock(64)
  ↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384]  ← predicted noise
```
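The "sinusoidal positional embedding" used for time encoding can be sketched as follows — the classic transformer-style construction, with the 256-d width from the table assumed (the MLP that follows it is omitted):

```python
import numpy as np

def sinusoidal_embedding(t, dim=256):
    """Timestep embedding: first half sin, second half cos, geometric frequencies."""
    half = dim // 2
    freqs = np.exp(-np.log(10000.0) * np.arange(half) / half)  # [half]
    angles = t * freqs
    return np.concatenate([np.sin(angles), np.cos(angles)])    # [dim]

emb = sinusoidal_embedding(t=500, dim=256)
```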

Spatial JEPA (1.8M trainable parameters)

| Component | Details |
| --- | --- |
| Online encoder | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| Target encoder | EMA copy of the online encoder (decay 0.996 → 1.0, cosine schedule) |
| Predictor | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| Loss | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |

The JEPA learns compressed dynamics representations without generating pixels, useful for downstream tasks and transfer learning.
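The EMA target-encoder update with its cosine decay schedule can be sketched as below — a simplified version operating on plain parameter dicts; the real `jepa.py` updates module parameters in place:

```python
import math

def ema_decay(step, total_steps, base=0.996, final=1.0):
    """Cosine ramp of the EMA decay from `base` at step 0 to `final` at the end."""
    cos = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))  # 1 -> 0
    return final - (final - base) * cos

def ema_update(target, online, decay):
    """target <- decay * target + (1 - decay) * online, per parameter."""
    return {k: decay * target[k] + (1.0 - decay) * online[k] for k in target}

target = {"w": 0.0}
online = {"w": 1.0}
target = ema_update(target, online, ema_decay(step=0, total_steps=1000))
```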

Training

Dataset

Trained on turbulent_radiative_layer_2D from The Well (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):

- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories × 101 timesteps; 7,200 consecutive-frame training pairs in the training split
- 6.9 GB total (HDF5 format)
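Training samples are (current frame → next frame) pairs; slicing them out of a trajectory array can be sketched as follows (dummy shapes, not the project's `data_pipeline.py`):

```python
import numpy as np

def make_pairs(traj):
    """traj: [T, C, H, W] -> (conditions [T-1, ...], targets [T-1, ...])."""
    return traj[:-1], traj[1:]

# Dummy trajectory with the dataset's layout: 101 timesteps, 4 field channels
# (spatial dims downscaled here to keep the example small).
traj = np.zeros((101, 4, 16, 48), dtype=np.float32)
cond, target = make_pairs(traj)  # 100 consecutive-frame pairs per trajectory
```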

Diffusion Training Config

| Parameter | Value |
| --- | --- |
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |
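The cosine schedule with a 500-step warmup (used by both training runs) can be sketched in pure Python — function and argument names are assumptions for illustration:

```python
import math

def lr_at(step, total_steps, base_lr=1e-4, warmup=500):
    """Linear warmup to base_lr over `warmup` steps, then cosine decay to 0."""
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```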

Diffusion Training Results

| Metric | Value |
| --- | --- |
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |

Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the WandB run.

JEPA Training Config

| Parameter | Value |
| --- | --- |
| Optimizer | AdamW (lr=3e-4, wd=0.05) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 16 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| EMA schedule | Cosine 0.996 → 1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~1.5 hours |

JEPA Training Results

| Metric | Value |
| --- | --- |
| Final train loss | 4.07 |
| Similarity (sim) | 0.079 |
| Variance (VICReg) | 1.476 |
| Covariance (VICReg) | 0.578 |

Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99; roughly converged by epoch 50). The VICReg regularization keeps the representations from collapsing while the similarity loss learns dynamics prediction.
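The VICReg variance and covariance terms on channel-averaged features can be sketched in NumPy — a simplified version of the regularizer described above; the hinge threshold and normalization follow the standard VICReg formulation, and any weighting is an assumption:

```python
import numpy as np

def vicreg_terms(z, eps=1e-4):
    """z: [N, D] batch of channel-averaged features.
    Variance term pushes each dimension's std above 1; covariance term
    penalizes off-diagonal entries of the feature covariance matrix."""
    z = z - z.mean(axis=0)
    std = np.sqrt(z.var(axis=0) + eps)
    var_term = np.mean(np.maximum(0.0, 1.0 - std))  # hinge at std = 1
    cov = (z.T @ z) / (z.shape[0] - 1)              # [D, D]
    off_diag = cov - np.diag(np.diag(cov))
    cov_term = np.sum(off_diag ** 2) / z.shape[1]
    return var_term, cov_term

# Collapsed (all-identical) features are maximally penalized by the variance term.
v, c = vicreg_terms(np.ones((8, 16)))
```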

Full JEPA training metrics available on the WandB run.

Usage

Installation

```bash
pip install the_well torch einops wandb tqdm h5py matplotlib "wandb[media]"
```

Inference (generate next frame)

```python
import torch
from unet import UNet
from diffusion import GaussianDiffusion

# Load model
device = "cuda"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)

ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a condition frame [1, 4, 128, 384]:
x_cond = ...  # your input frame
x_pred = model.sample_ddim(x_cond, steps=50)  # fast DDIM sampling
```
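The deterministic (eta=0) DDIM update behind `sample_ddim` can be sketched in NumPy, assuming the linear beta schedule from the architecture table — an illustration of one reverse step, not the project's `diffusion.py`:

```python
import numpy as np

betas = np.linspace(1e-4, 0.02, 1000)
alpha_bar = np.cumprod(1.0 - betas)

def ddim_step(x_t, eps_pred, t, t_prev):
    """One deterministic DDIM step: estimate x0, then re-noise to level t_prev."""
    ab_t = alpha_bar[t]
    ab_prev = alpha_bar[t_prev] if t_prev >= 0 else 1.0
    x0_pred = (x_t - np.sqrt(1.0 - ab_t) * eps_pred) / np.sqrt(ab_t)
    return np.sqrt(ab_prev) * x0_pred + np.sqrt(1.0 - ab_prev) * eps_pred

# If the network predicted the true noise exactly, the step recovers x0.
x0 = np.random.randn(4, 8, 8)
eps = np.random.randn(*x0.shape)
x_t = np.sqrt(alpha_bar[999]) * x0 + np.sqrt(1.0 - alpha_bar[999]) * eps
x_prev = ddim_step(x_t, eps, 999, -1)  # t_prev = -1 jumps straight to x0
```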

JEPA inference (extract dynamics embeddings)

```python
import torch
from jepa import JEPA

device = "cuda"
model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)

ckpt = torch.load("jepa_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a frame [1, 4, 128, 384]:
x = ...  # your input frame
z = model.online_encoder(x)  # [1, 128, 16, 48] spatial latent map
```

Autoregressive rollout

```python
# Generate a 20-step trajectory
trajectory = [x_cond]
cond = x_cond
for step in range(20):
    pred = model.sample_ddim(cond, steps=50, eta=0.0)
    trajectory.append(pred)
    cond = pred  # feed prediction back as next condition
```

Training from scratch

```bash
# Download data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D

# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
  --no-streaming --local_path ./data/datasets \
  --batch_size 8 --epochs 100 --wandb

# Train JEPA
python train_jepa.py \
  --no-streaming --local_path ./data/datasets \
  --batch_size 16 --epochs 100 --wandb
```

Streaming from HuggingFace (no download needed)

```bash
python train_diffusion.py --streaming --batch_size 4
```

Project Structure

| File | Description |
| --- | --- |
| unet.py | U-Net with time conditioning, skip connections, self-attention |
| diffusion.py | DDPM/DDIM framework: noise schedule, training loss, sampling |
| jepa.py | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| data_pipeline.py | Data loading from The Well (streaming HF or local HDF5) |
| train_diffusion.py | Diffusion training with eval, video logging, checkpointing |
| train_jepa.py | JEPA training with EMA schedule, VICReg metrics |
| eval_utils.py | Evaluation: single-step MSE, rollout videos, WandB media logging |
| test_pipeline.py | End-to-end verification script (data → forward → backward) |
| diffusion_ep0099.pt | Diffusion final checkpoint (epoch 99, 748 MB) |
| jepa_ep0099.pt | JEPA final checkpoint (epoch 99, 23 MB) |

Evaluation Details

Every 5 epochs, the training script runs:

1. Single-step evaluation: DDIM-50 sampling on 4 validation batches, MSE against ground truth
2. Multi-step rollout: 10-step autoregressive prediction from a validation sample
3. Video logging: side-by-side GT vs Prediction video logged to WandB as mp4
4. Comparison images: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
5. Rollout MSE curve: per-step MSE showing prediction degradation over the horizon
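Computing the per-step rollout MSE curve can be sketched as follows (NumPy; function name and shapes are assumptions, not the project's `eval_utils.py`):

```python
import numpy as np

def rollout_mse_curve(preds, gts):
    """preds, gts: [S, C, H, W] rollout vs ground truth; returns MSE per step."""
    return ((preds - gts) ** 2).reshape(preds.shape[0], -1).mean(axis=1)

# Dummy 10-step rollout with a constant error of 1 at every pixel.
preds = np.zeros((10, 4, 16, 48))
gts = np.ones((10, 4, 16, 48))
curve = rollout_mse_curve(preds, gts)
```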

The Well Dataset

The Well is a 15TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well — just change --dataset:

```bash
python train_diffusion.py --dataset active_matter          # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow             # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion  # 154 GB
```

Citation

```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```

License

Apache 2.0
