---
license: apache-2.0
tags:
- physics
- diffusion
- jepa
- pde
- simulation
- the-well
- ddpm
- ddim
- spatiotemporal
datasets:
- polymathic-ai/turbulent_radiative_layer_2D
language:
- en
pipeline_tag: image-to-image
---

# The Well: Diffusion & JEPA for PDE Dynamics

A conditional diffusion model (DDPM/DDIM) and a Spatial JEPA trained to predict the evolution of 2D physics simulations from [The Well](https://polymathic-ai.org/the_well/), a dataset collection by Polymathic AI.

Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. It can be run autoregressively to generate multi-step rollout trajectories.

## Architecture

### Conditional DDPM (62M parameters)

| Component | Details |
|---|---|
| **Backbone** | U-Net with 4 resolution levels (64→128→256→512 channels) |
| **Conditioning** | Previous frame concatenated to the noisy target along the channel dim |
| **Time encoding** | Sinusoidal positional embedding → MLP (256-d) |
| **Residual blocks** | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| **Attention** | Multi-head self-attention at the bottleneck (16×48 spatial, 768 tokens) |
| **Noise schedule** | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| **Parameterization** | Epsilon-prediction (predict the noise) |
| **Sampling** | DDPM (1000 steps) or DDIM (50 steps, deterministic) |

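The linear noise schedule and the ε-prediction training setup in the table can be sketched as follows. This is a standalone illustration of the standard DDPM forward process, not the repo's `diffusion.py`; `q_sample` is a hypothetical name.

```python
import torch

# Linear beta schedule from the table: 1e-4 -> 0.02 over 1000 timesteps
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0, t, noise):
    """Forward process: noise the clean target x0 to timestep t."""
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1.0 - a).sqrt() * noise

# Training step sketch: the U-Net sees the noisy target stacked with the condition
x0 = torch.randn(2, 4, 128, 384)        # clean next frame (target)
cond = torch.randn(2, 4, 128, 384)      # previous frame (condition)
t = torch.randint(0, T, (2,))
noise = torch.randn_like(x0)
x_t = q_sample(x0, t, noise)
unet_input = torch.cat([x_t, cond], dim=1)  # [B, 8, 128, 384], matching the diagram below
# loss = F.mse_loss(unet(unet_input, t), noise)   # epsilon-prediction objective
```
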
```
Input: [B, 8, 128, 384] ← 4ch noisy target + 4ch condition
↓ Conv3x3 → 64ch
↓ Level 0: 2×ResBlock(64), Downsample → 64×64×192
↓ Level 1: 2×ResBlock(128), Downsample → 128×32×96
↓ Level 2: 2×ResBlock(256), Downsample → 256×16×48
↓ Level 3: 2×ResBlock(512) + SelfAttention
↓ Middle: ResBlock + Attention + ResBlock (512ch)
↑ Level 3: 3×ResBlock(512) + Attention, Upsample
↑ Level 2: 3×ResBlock(256), Upsample
↑ Level 1: 3×ResBlock(128), Upsample
↑ Level 0: 3×ResBlock(64)
↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384] ← predicted noise
```

### Spatial JEPA (1.8M trainable parameters)

| Component | Details |
|---|---|
| **Online encoder** | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| **Target encoder** | EMA copy of the online encoder (decay 0.996 → 1.0, cosine schedule) |
| **Predictor** | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| **Loss** | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |

The JEPA learns compressed dynamics representations without generating pixels, which makes it useful for downstream tasks and transfer learning.

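The EMA target-encoder update described in the table can be sketched as below. The exact cosine ramp is an assumption (only the 0.996 → 1.0 endpoints are stated), and `ema_decay`/`ema_update` are hypothetical names, not functions from `jepa.py`.

```python
import copy
import math

import torch
import torch.nn as nn

def ema_decay(step, total_steps, base=0.996):
    """Cosine ramp of the EMA decay from `base` at step 0 to 1.0 at the end (assumed form)."""
    return 1.0 - (1.0 - base) * (math.cos(math.pi * step / total_steps) + 1.0) / 2.0

@torch.no_grad()
def ema_update(target, online, decay):
    """target <- decay * target + (1 - decay) * online, parameter-wise."""
    for pt, po in zip(target.parameters(), online.parameters()):
        pt.mul_(decay).add_(po, alpha=1.0 - decay)

# Toy modules standing in for the encoders
online = nn.Linear(4, 4)
target = copy.deepcopy(online)   # the target encoder starts as a copy of the online one
ema_update(target, online, ema_decay(step=0, total_steps=1000))
```
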
## Training

### Dataset

Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-ai.org/the_well/) (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):

- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories × 101 timesteps (7,200 training samples)
- 6.9 GB total (HDF5 format)

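Training samples are consecutive (condition, target) frame pairs, so a trajectory of 101 frames yields 100 pairs. A minimal sketch, assuming a trajectory tensor of shape [T, C, H, W] (the HDF5 field layout is not shown here):

```python
import torch

# One trajectory as a tensor: [T, C, H, W] = [101, 4, 128, 384]
traj = torch.randn(101, 4, 128, 384)

# Each consecutive pair (frame t -> frame t+1) is one training sample
pairs = [(traj[t], traj[t + 1]) for t in range(traj.shape[0] - 1)]
```
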
### Diffusion Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |

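The "cosine with 500-step warmup" schedule likely has the following shape; this is an assumed sketch (`lr_at` is a hypothetical helper, not a function in the repo):

```python
import math

def lr_at(step, total_steps, base_lr=1e-4, warmup=500):
    """Linear warmup for `warmup` steps, then cosine decay toward 0 (assumed shape)."""
    if step < warmup:
        return base_lr * step / warmup
    progress = (step - warmup) / max(1, total_steps - warmup)
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
```
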
### Diffusion Training Results

| Metric | Value |
|---|---|
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |

Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).

### JEPA Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=3e-4, wd=0.05) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 16 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| EMA schedule | Cosine 0.996 → 1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~1.5 hours |

### JEPA Training Results

| Metric | Value |
|---|---|
| Final train loss | 4.07 |
| Similarity (sim) | 0.079 |
| Variance (VICReg) | 1.476 |
| Covariance (VICReg) | 0.578 |

Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99; converged around epoch 50). The VICReg regularization keeps representations from collapsing while the similarity loss learns dynamics prediction.

Full JEPA training metrics are available on the [WandB run](https://wandb.ai/alexwortega/the-well-jepa/runs/obwyebcv).

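The variance and covariance terms reported above follow the standard VICReg form, which can be sketched as below (an illustration applied to [N, D] channel-averaged features, not the repo's `jepa.py`):

```python
import torch

def vicreg_terms(z, eps=1e-4):
    """VICReg-style variance hinge and covariance penalty on [N, D] features."""
    z = z - z.mean(dim=0)
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = torch.relu(1.0 - std).mean()            # hinge: keep per-dim std above 1
    cov = (z.T @ z) / (z.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = (off_diag ** 2).sum() / z.shape[1]      # penalize cross-dim correlation
    return var_loss, cov_loss

v, c = vicreg_terms(torch.randn(256, 128))
```

The variance term prevents dimensional collapse (all features constant) and the covariance term decorrelates feature dimensions, so the similarity loss cannot be minimized by a trivial encoder.
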
## Usage

### Installation

```bash
pip install the_well torch einops wandb tqdm h5py matplotlib "wandb[media]"
```

### Inference (generate the next frame)

```python
import torch

from unet import UNet
from diffusion import GaussianDiffusion

# Load model
device = "cuda" if torch.cuda.is_available() else "cpu"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)

ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a condition frame [1, 4, 128, 384]:
x_cond = ...  # your input frame, on `device`
with torch.no_grad():
    x_pred = model.sample_ddim(x_cond, steps=50)  # fast deterministic DDIM sampling
```

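Internally, deterministic DDIM sampling applies updates of the following form at each of the 50 steps. This is the standard eta=0 DDIM step under ε-parameterization, shown for reference; it is not the repo's `diffusion.py` code, and `ddim_step` is a hypothetical name.

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)
abar = torch.cumprod(1.0 - betas, dim=0)  # cumulative alpha products

def ddim_step(x_t, eps, t, t_prev):
    """One deterministic DDIM update (eta=0): jump from timestep t to t_prev."""
    a_t = abar[t]
    a_prev = abar[t_prev] if t_prev >= 0 else torch.tensor(1.0)
    x0 = (x_t - (1.0 - a_t).sqrt() * eps) / a_t.sqrt()   # clean frame implied by eps
    return a_prev.sqrt() * x0 + (1.0 - a_prev).sqrt() * eps
```

Because the update is deterministic, the same condition frame always yields the same prediction, which is what makes DDIM-50 usable for reproducible evaluation.
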
### JEPA inference (extract dynamics embeddings)

```python
import torch

from jepa import JEPA

device = "cuda" if torch.cuda.is_available() else "cpu"
model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)

ckpt = torch.load("jepa_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a frame [1, 4, 128, 384]:
x = ...  # your input frame, on `device`
with torch.no_grad():
    z = model.online_encoder(x)  # [1, 128, 16, 48] spatial latent map
```

### Autoregressive rollout

```python
# Generate a 20-step trajectory by feeding each prediction back in as the next condition
trajectory = [x_cond]
cond = x_cond
with torch.no_grad():
    for step in range(20):
        pred = model.sample_ddim(cond, steps=50, eta=0.0)  # deterministic DDIM
        trajectory.append(pred)
        cond = pred
```

### Training from scratch

```bash
# Download the data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D

# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 8 --epochs 100 --wandb

# Train JEPA
python train_jepa.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 16 --epochs 100 --wandb
```

### Streaming from HuggingFace (no download needed)

```bash
python train_diffusion.py --streaming --batch_size 4
```

## Project Structure

| File | Description |
|---|---|
| `unet.py` | U-Net with time conditioning, skip connections, self-attention |
| `diffusion.py` | DDPM/DDIM framework: noise schedule, training loss, sampling |
| `jepa.py` | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| `data_pipeline.py` | Data loading from The Well (streaming HF or local HDF5) |
| `train_diffusion.py` | Diffusion training with eval, video logging, checkpointing |
| `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
| `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
| `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748MB) |
| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23MB) |

## Evaluation Details

Every 5 epochs, the training script runs:

1. **Single-step evaluation**: DDIM-50 sampling on 4 validation batches, MSE against ground truth
2. **Multi-step rollout**: 10-step autoregressive prediction from a validation sample
3. **Video logging**: side-by-side GT vs Prediction video logged to WandB as mp4
4. **Comparison images**: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
5. **Rollout MSE curve**: per-step MSE showing prediction degradation over the horizon

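The per-step MSE curve (item 5) amounts to comparing each rollout frame against its ground-truth counterpart; a minimal sketch (`per_step_mse` is a hypothetical helper, not a function from `eval_utils.py`):

```python
import torch

def per_step_mse(preds, gts):
    """Per-step MSE along a rollout; preds/gts are lists of same-shape tensors."""
    return [torch.mean((p - g) ** 2).item() for p, g in zip(preds, gts)]

curve = per_step_mse([torch.ones(1, 4, 8, 8)], [torch.zeros(1, 4, 8, 8)])
```

A curve that rises with the step index quantifies how quickly autoregressive errors compound over the horizon.
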
## The Well Dataset

[The Well](https://polymathic-ai.org/the_well/) is a 15 TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well; just change `--dataset`:

```bash
python train_diffusion.py --dataset active_matter                 # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow                    # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion # 154 GB
```

## Citation

```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```

## License

Apache 2.0