---
license: apache-2.0
tags:
- physics
- diffusion
- jepa
- pde
- simulation
- the-well
- ddpm
- ddim
- spatiotemporal
datasets:
- polymathic-ai/turbulent_radiative_layer_2D
language:
- en
pipeline_tag: image-to-image
---

# The Well: Diffusion & JEPA for PDE Dynamics

A conditional diffusion model (DDPM/DDIM) and a Spatial JEPA trained to predict the evolution of 2D physics simulations from [The Well](https://polymathic-ai.org/the_well/), the dataset collection by Polymathic AI.

Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. It can be run autoregressively to generate multi-step rollout trajectories.

## Architecture

### Conditional DDPM (62M parameters)

| Component | Details |
|---|---|
| **Backbone** | U-Net with 4 resolution levels (64 → 128 → 256 → 512 channels) |
| **Conditioning** | Previous frame concatenated to the noisy target along the channel dim |
| **Time encoding** | Sinusoidal positional embedding → MLP (256-d) |
| **Residual blocks** | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| **Attention** | Multi-head self-attention at the bottleneck (16×48 spatial, 768 tokens) |
| **Noise schedule** | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| **Parameterization** | Epsilon-prediction (predict the noise) |
| **Sampling** | DDPM (1000 steps) or DDIM (50 steps, deterministic) |

```
Input: [B, 8, 128, 384]  ← 4ch noisy target + 4ch condition
  ↓ Conv3x3 → 64ch
  ↓ Level 0: 2×ResBlock(64),  Downsample → 64×64×192
  ↓ Level 1: 2×ResBlock(128), Downsample → 128×32×96
  ↓ Level 2: 2×ResBlock(256), Downsample → 256×16×48
  ↓ Level 3: 2×ResBlock(512) + SelfAttention
  ↓ Middle:  ResBlock + Attention + ResBlock (512ch)
  ↑ Level 3: 3×ResBlock(512) + Attention, Upsample
  ↑ Level 2: 3×ResBlock(256), Upsample
  ↑ Level 1: 3×ResBlock(128), Upsample
  ↑ Level 0: 3×ResBlock(64)
  ↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384]  ← predicted noise
```

### Spatial JEPA (1.8M trainable parameters)

| Component | Details |
|---|---|
| **Online encoder** | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| **Target encoder** | EMA copy of the online encoder (decay 0.996 → 1.0, cosine schedule) |
| **Predictor** | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| **Loss** | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |

The JEPA learns compressed dynamics representations without generating pixels, which is useful for downstream tasks and transfer learning.

## Training

### Dataset

Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-ai.org/the_well/) (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):

- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories × 101 timesteps; 7,200 training samples
- 6.9 GB total (HDF5 format)

### Diffusion Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |

### Diffusion Training Results

| Metric | Value |
|---|---|
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |

The training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).
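The noise schedule in the table above (linear beta, 1e-4 → 0.02, 1000 timesteps) controls how much signal survives at each diffusion step under epsilon-prediction. A minimal pure-Python sketch of that arithmetic, independent of the repo's `diffusion.py`:

```python
# Illustrative sketch of the linear DDPM noise schedule described above:
# beta ramps linearly from 1e-4 to 0.02 over T = 1000 steps, and
# alpha_bar_t (cumulative product of 1 - beta) is the fraction of signal
# variance remaining at step t. The U-Net is trained to predict the noise
# eps mixed in as sqrt(alpha_bar_t) * x0 + sqrt(1 - alpha_bar_t) * eps.

T = 1000
beta_start, beta_end = 1e-4, 0.02
betas = [beta_start + (beta_end - beta_start) * t / (T - 1) for t in range(T)]

alpha_bar = []
prod = 1.0
for b in betas:
    prod *= 1.0 - b          # alpha_t = 1 - beta_t
    alpha_bar.append(prod)   # alpha_bar_t = product of alpha_s for s <= t

# Early steps keep almost all of the signal; the last step is nearly pure noise.
print(round(alpha_bar[0], 4))   # 0.9999
print(round(alpha_bar[-1], 4))  # 0.0
```

DDIM sampling traverses a 50-step subsequence of these same `alpha_bar` values deterministically, which is why it is drop-in compatible with a model trained on the full 1000-step schedule.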
### JEPA Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=3e-4, wd=0.05) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 16 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| EMA schedule | Cosine 0.996 → 1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~1.5 hours |

### JEPA Training Results

| Metric | Value |
|---|---|
| Final train loss | 4.07 |
| Similarity (sim) | 0.079 |
| Variance (VICReg) | 1.476 |
| Covariance (VICReg) | 0.578 |

Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99; converged around epoch 50). The VICReg regularization keeps the representations from collapsing while the similarity loss learns dynamics prediction. Full JEPA training metrics are available on the [WandB run](https://wandb.ai/alexwortega/the-well-jepa/runs/obwyebcv).

## Usage

### Installation

```bash
pip install the_well torch einops wandb tqdm h5py matplotlib "wandb[media]"
```

### Inference (generate next frame)

```python
import torch

from unet import UNet
from diffusion import GaussianDiffusion

# Load model
device = "cuda"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)
ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a condition frame [1, 4, 128, 384]:
x_cond = ...  # your input frame
x_pred = model.sample_ddim(x_cond, steps=50)  # fast DDIM sampling
```

### JEPA inference (extract dynamics embeddings)

```python
import torch

from jepa import JEPA

device = "cuda"
model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)
ckpt = torch.load("jepa_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a frame [1, 4, 128, 384]:
x = ...
# your input frame
z = model.online_encoder(x)  # [1, 128, 16, 48] spatial latent map
```

### Autoregressive rollout

```python
# Generate a 20-step trajectory
trajectory = [x_cond]
cond = x_cond
for step in range(20):
    pred = model.sample_ddim(cond, steps=50, eta=0.0)
    trajectory.append(pred)
    cond = pred  # feed the prediction back as the next condition
```

### Training from scratch

```bash
# Download the data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D

# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 8 --epochs 100 --wandb

# Train JEPA
python train_jepa.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 16 --epochs 100 --wandb
```

### Streaming from HuggingFace (no download needed)

```bash
python train_diffusion.py --streaming --batch_size 4
```

## Project Structure

| File | Description |
|---|---|
| `unet.py` | U-Net with time conditioning, skip connections, self-attention |
| `diffusion.py` | DDPM/DDIM framework: noise schedule, training loss, sampling |
| `jepa.py` | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| `data_pipeline.py` | Data loading from The Well (streaming HF or local HDF5) |
| `train_diffusion.py` | Diffusion training with eval, video logging, checkpointing |
| `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
| `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
| `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748MB) |
| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23MB) |

## Evaluation Details

Every 5 epochs, the training script runs:

1. **Single-step evaluation**: DDIM-50 sampling on 4 validation batches, MSE against ground truth
2. **Multi-step rollout**: 10-step autoregressive prediction from a validation sample
3.
   **Video logging**: Side-by-side GT vs Prediction video logged to WandB as mp4
4. **Comparison images**: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
5. **Rollout MSE curve**: Per-step MSE showing prediction degradation over the horizon

## The Well Dataset

[The Well](https://polymathic-ai.org/the_well/) is a 15 TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well; just change `--dataset`:

```bash
python train_diffusion.py --dataset active_matter                 # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow                    # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion # 154 GB
```

## Citation

```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```

## License

Apache 2.0
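As a footnote on item 5 of the evaluation list: the rollout MSE curve is simply the mean squared error between predicted and ground-truth frames at each rollout step. A hypothetical, dependency-free sketch (plain nested lists stand in for [C, H, W] tensors; `frame_mse` and `rollout_mse_curve` are illustrative names, not functions from `eval_utils.py`):

```python
# Illustrative per-step rollout MSE, as in item 5 of Evaluation Details.
# Frames are nested lists standing in for [C, H, W] tensors; in the real
# pipeline these would be torch tensors from the autoregressive rollout.

def frame_mse(pred, gt):
    """Mean squared error over all elements of two equally-shaped frames."""
    flat_p = [v for ch in pred for row in ch for v in row]
    flat_g = [v for ch in gt for row in ch for v in row]
    return sum((p - g) ** 2 for p, g in zip(flat_p, flat_g)) / len(flat_p)

def rollout_mse_curve(preds, gts):
    """Per-step MSE; degradation over the horizon shows up as a rising curve."""
    return [frame_mse(p, g) for p, g in zip(preds, gts)]

# Tiny 1-channel, 1x2 example: error grows with the rollout step.
gts   = [[[[0.0, 0.0]]], [[[0.0, 0.0]]]]
preds = [[[[1.0, 1.0]]], [[[2.0, 2.0]]]]
print(rollout_mse_curve(preds, gts))  # [1.0, 4.0]
```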