---
license: apache-2.0
tags:
- physics
- diffusion
- jepa
- pde
- simulation
- the-well
- ddpm
- ddim
- spatiotemporal
datasets:
- polymathic-ai/turbulent_radiative_layer_2D
language:
- en
pipeline_tag: image-to-image
---
# The Well: Diffusion & JEPA for PDE Dynamics
A conditional diffusion model (DDPM/DDIM) and a Spatial JEPA trained to predict the evolution of 2D physics simulations from [The Well](https://polymathic-ai.org/the_well/) dataset collection by Polymathic AI.
Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. It can be run autoregressively to generate multi-step rollout trajectories.
## Architecture
### Conditional DDPM (62M parameters)
| Component | Details |
|---|---|
| **Backbone** | U-Net with 4 resolution levels (64→128→256→512 channels) |
| **Conditioning** | Previous frame concatenated to noisy target along channel dim |
| **Time encoding** | Sinusoidal positional embedding → MLP (256-d) |
| **Residual blocks** | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| **Attention** | Multi-head self-attention at bottleneck (16x48 spatial, 768 tokens) |
| **Noise schedule** | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| **Parameterization** | Epsilon-prediction (predict noise) |
| **Sampling** | DDPM (1000 steps) or DDIM (50 steps, deterministic) |
```
Input: [B, 8, 128, 384] ← 4ch noisy target + 4ch condition
↓ Conv3x3 → 64ch
↓ Level 0: 2×ResBlock(64), Downsample → 64×64×192
↓ Level 1: 2×ResBlock(128), Downsample → 128×32×96
↓ Level 2: 2×ResBlock(256), Downsample → 256×16×48
↓ Level 3: 2×ResBlock(512) + SelfAttention
↓ Middle: ResBlock + Attention + ResBlock (512ch)
↑ Level 3: 3×ResBlock(512) + Attention, Upsample
↑ Level 2: 3×ResBlock(256), Upsample
↑ Level 1: 3×ResBlock(128), Upsample
↑ Level 0: 3×ResBlock(64)
↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384] ← predicted noise
```
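The linear schedule and epsilon-prediction objective from the table can be sketched in a few lines. This is a generic DDPM forward process; the function and variable names are illustrative, not the actual API of `diffusion.py`:

```python
import torch

# Linear beta schedule (1e-4 -> 0.02 over 1000 steps), as in the table above.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas = 1.0 - betas
alpha_bar = torch.cumprod(alphas, dim=0)  # cumulative product \bar{alpha}_t

def q_sample(x0, t, noise):
    """Diffuse a clean target x0 to timestep t:
    x_t = sqrt(abar_t) * x0 + sqrt(1 - abar_t) * eps."""
    ab = alpha_bar[t].view(-1, 1, 1, 1)
    return ab.sqrt() * x0 + (1.0 - ab).sqrt() * noise

# One epsilon-prediction training step (conditioning by channel concat):
x0 = torch.randn(2, 4, 128, 384)    # target frame
cond = torch.randn(2, 4, 128, 384)  # previous frame (condition)
t = torch.randint(0, T, (2,))
noise = torch.randn_like(x0)
x_t = q_sample(x0, t, noise)
model_in = torch.cat([x_t, cond], dim=1)  # [B, 8, 128, 384], matching the U-Net input
# loss = F.mse_loss(unet(model_in, t), noise)  # the epsilon-prediction objective
```

The concatenated 8-channel input matches the `[B, 8, 128, 384]` shape shown in the diagram above.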
### Spatial JEPA (1.8M trainable parameters)
| Component | Details |
|---|---|
| **Online encoder** | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| **Target encoder** | EMA copy of online encoder (decay 0.996 → 1.0 cosine schedule) |
| **Predictor** | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| **Loss** | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |
The JEPA learns compressed dynamics representations without generating pixels, useful for downstream tasks and transfer learning.
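The target-encoder EMA update from the table can be sketched as follows. This is a generic implementation with a toy stand-in module; the real encoders live in `jepa.py`:

```python
import copy
import torch

@torch.no_grad()
def ema_update(online, target, decay):
    """Per parameter: target <- decay * target + (1 - decay) * online."""
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1.0 - decay)

# Toy stand-in for the CNN encoder (the real modules are in jepa.py).
online = torch.nn.Conv2d(4, 128, 3, padding=1)
target = copy.deepcopy(online)
for p in target.parameters():
    p.requires_grad_(False)  # target is updated only via EMA, never by the optimizer

ema_update(online, target, decay=0.996)
```

Because gradients never flow into the target encoder, the online encoder cannot trivially collapse by chasing its own moving output.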
## Training
### Dataset
Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-ai.org/the_well/) (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):
- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories × 101 timesteps; 7,200 training samples
- 6.9 GB total (HDF5 format)
### Diffusion Training Config
| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |
### Diffusion Training Results
| Metric | Value |
|---|---|
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |
Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).
### JEPA Training Config
| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=3e-4, wd=0.05) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 16 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| EMA schedule | Cosine 0.996 → 1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~1.5 hours |
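The cosine EMA ramp from 0.996 to 1.0 listed above can be sketched as follows; this is a common formulation, and the exact schedule in `train_jepa.py` may differ:

```python
import math

def ema_decay(step, total_steps, base=0.996, final=1.0):
    """Cosine ramp of the EMA decay from `base` at step 0 to `final` at the end."""
    cos = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))  # goes 1 -> 0
    return final - (final - base) * cos

# decay starts at 0.996 and approaches 1.0 as training progresses
print(ema_decay(0, 1000), ema_decay(500, 1000), ema_decay(1000, 1000))
```

A slowly tightening decay lets the target encoder track the online encoder early in training and stabilize toward the end.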
### JEPA Training Results
| Metric | Value |
|---|---|
| Final train loss | 4.07 |
| Similarity (sim) | 0.079 |
| Variance (VICReg) | 1.476 |
| Covariance (VICReg) | 0.578 |
Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99), plateauing around epoch 50. The VICReg regularization keeps the representations from collapsing while the similarity loss learns dynamics prediction.
Full JEPA training metrics available on the [WandB run](https://wandb.ai/alexwortega/the-well-jepa/runs/obwyebcv).
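The VICReg variance and covariance terms on channel-averaged features can be sketched as below. This is the standard VICReg formulation; `vicreg_reg` and its constants are illustrative, not the exact loss in `jepa.py`:

```python
import torch

def vicreg_reg(z, gamma=1.0, eps=1e-4):
    """Variance + covariance regularizers on pooled features z: [B, C].

    The variance term pushes each channel's std above gamma (anti-collapse);
    the covariance term penalizes off-diagonal correlations between channels.
    """
    z = z - z.mean(dim=0)
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = torch.relu(gamma - std).mean()
    B, C = z.shape
    cov = (z.T @ z) / (B - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = (off_diag ** 2).sum() / C
    return var_loss, cov_loss

# Spatial latents are average-pooled to per-channel vectors before regularizing:
z_map = torch.randn(16, 128, 16, 48)  # [B, C, H/8, W/8] latent map
v, c = vicreg_reg(z_map.mean(dim=(2, 3)))
```

If the encoder collapsed to a constant output, the variance term would spike toward `gamma`, which is exactly what the regularizer penalizes.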
## Usage
### Installation
```bash
pip install the_well torch einops wandb tqdm h5py matplotlib "wandb[media]"
```
### Inference (generate next frame)
```python
import torch
from unet import UNet
from diffusion import GaussianDiffusion
# Load model
device = "cuda"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)
ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()
# Given a condition frame [1, 4, 128, 384]:
x_cond = torch.randn(1, 4, 128, 384, device=device)  # stand-in; use a real frame here
x_pred = model.sample_ddim(x_cond, steps=50)  # fast deterministic DDIM sampling
```
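For reference, one deterministic DDIM update (eta = 0) under epsilon-prediction looks like this; it is the textbook step, not necessarily the exact code inside `sample_ddim`:

```python
import torch

def ddim_step(x_t, eps, ab_t, ab_prev):
    """One deterministic DDIM update (eta = 0) for an epsilon-prediction model.

    x0_hat = (x_t - sqrt(1 - abar_t) * eps) / sqrt(abar_t)
    x_prev = sqrt(abar_prev) * x0_hat + sqrt(1 - abar_prev) * eps
    """
    x0_hat = (x_t - (1.0 - ab_t) ** 0.5 * eps) / ab_t ** 0.5
    return ab_prev ** 0.5 * x0_hat + (1.0 - ab_prev) ** 0.5 * eps
```

Because each step re-estimates `x0_hat` and re-noises it deterministically, 50 DDIM steps can approximate the full 1000-step DDPM chain.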
### JEPA inference (extract dynamics embeddings)
```python
import torch
from jepa import JEPA
device = "cuda"
model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)
ckpt = torch.load("jepa_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()
# Given a frame [1, 4, 128, 384]:
x = torch.randn(1, 4, 128, 384, device=device)  # stand-in; use a real frame here
z = model.online_encoder(x) # [1, 128, 16, 48] spatial latent map
```
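The spatial latent map can be pooled into a single vector for downstream probes; the pooling choice and linear probe below are our illustration, not part of the repo:

```python
import torch

# Hypothetical downstream probe on pooled JEPA features.
z = torch.randn(1, 128, 16, 48)  # stand-in for model.online_encoder(x)
z_vec = z.mean(dim=(2, 3))       # [1, 128] global-average-pooled descriptor
probe = torch.nn.Linear(128, 1)  # e.g. regress some scalar property of the flow
y = probe(z_vec)
```

Keeping the latent spatial (rather than pooling inside the encoder) preserves the option of dense downstream tasks as well.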
### Autoregressive rollout
```python
# Generate 20-step trajectory
trajectory = [x_cond]
cond = x_cond
for step in range(20):
    pred = model.sample_ddim(cond, steps=50, eta=0.0)
    trajectory.append(pred)
    cond = pred  # feed prediction back as next condition
```
### Training from scratch
```bash
# Download data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D
# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 8 --epochs 100 --wandb
# Train JEPA
python train_jepa.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 16 --epochs 100 --wandb
```
### Streaming from HuggingFace (no download needed)
```bash
python train_diffusion.py --streaming --batch_size 4
```
## Project Structure
| File | Description |
|---|---|
| `unet.py` | U-Net with time conditioning, skip connections, self-attention |
| `diffusion.py` | DDPM/DDIM framework: noise schedule, training loss, sampling |
| `jepa.py` | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| `data_pipeline.py` | Data loading from The Well (streaming HF or local HDF5) |
| `train_diffusion.py` | Diffusion training with eval, video logging, checkpointing |
| `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
| `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
| `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748MB) |
| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23MB) |
## Evaluation Details
Every 5 epochs, the training script runs:
1. **Single-step evaluation**: DDIM-50 sampling on 4 validation batches, MSE against ground truth
2. **Multi-step rollout**: 10-step autoregressive prediction from a validation sample
3. **Video logging**: Side-by-side GT vs Prediction video logged to WandB as mp4
4. **Comparison images**: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
5. **Rollout MSE curve**: Per-step MSE showing prediction degradation over horizon
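Step 5's per-step error curve amounts to the following sketch; `rollout_mse` is illustrative, not a function from `eval_utils.py`:

```python
import torch

def rollout_mse(pred_traj, gt_traj):
    """Per-step MSE between predicted and ground-truth rollouts.

    pred_traj, gt_traj: equal-length lists of [B, C, H, W] tensors.
    """
    return [torch.mean((p - g) ** 2).item() for p, g in zip(pred_traj, gt_traj)]

# Toy example: errors typically grow with horizon as small mistakes compound.
gt = [torch.zeros(1, 4, 8, 8) for _ in range(3)]
pred = [torch.full((1, 4, 8, 8), 0.1 * (i + 1)) for i in range(3)]
curve = rollout_mse(pred, gt)  # increasing per-step MSE
```

Plotting this curve over the 10-step horizon shows how quickly autoregressive predictions drift from the ground truth.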
## The Well Dataset
[The Well](https://polymathic-ai.org/the_well/) is a 15TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well — just change `--dataset`:
```bash
python train_diffusion.py --dataset active_matter # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion # 154 GB
```
## Citation
```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```
## License
Apache 2.0