---
license: apache-2.0
tags:
  - physics
  - diffusion
  - jepa
  - pde
  - simulation
  - the-well
  - ddpm
  - ddim
  - spatiotemporal
datasets:
  - polymathic-ai/turbulent_radiative_layer_2D
language:
  - en
pipeline_tag: image-to-image
---

# The Well: Diffusion & JEPA for PDE Dynamics

A conditional diffusion model (DDPM/DDIM) and a Spatial JEPA trained to predict the evolution of 2D physics simulations from [The Well](https://polymathic-ai.org/the_well/) dataset collection by Polymathic AI.

Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. It can be run autoregressively to generate multi-step rollout trajectories.

## Architecture

### Conditional DDPM (62M parameters)

| Component | Details |
|---|---|
| **Backbone** | U-Net with 4 resolution levels (64→128→256→512 channels) |
| **Conditioning** | Previous frame concatenated to noisy target along channel dim |
| **Time encoding** | Sinusoidal positional embedding → MLP (256-d) |
| **Residual blocks** | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| **Attention** | Multi-head self-attention at bottleneck (16×48 spatial, 768 tokens) |
| **Noise schedule** | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| **Parameterization** | Epsilon-prediction (predict noise) |
| **Sampling** | DDPM (1000 steps) or DDIM (50 steps, deterministic) |
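
The sinusoidal time encoding in the table can be sketched as follows (a minimal Transformer-style formulation; the 256-d size matches the table, everything else is illustrative). Per the table, the embedding is then passed through an MLP before being added inside each residual block.

```python
import math
import torch

def sinusoidal_embedding(t: torch.Tensor, dim: int = 256) -> torch.Tensor:
    """Map integer timesteps [B] to sinusoidal embeddings [B, dim]."""
    half = dim // 2
    # Frequencies spaced geometrically from 1 down to 1/10000.
    freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
    args = t.float()[:, None] * freqs[None, :]                     # [B, half]
    return torch.cat([torch.sin(args), torch.cos(args)], dim=-1)   # [B, dim]

emb = sinusoidal_embedding(torch.tensor([0, 500, 999]), dim=256)
```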

```
Input: [B, 8, 128, 384]  ← 4ch noisy target + 4ch condition
  ↓ Conv3x3 → 64ch
  ↓ Level 0: 2×ResBlock(64),   Downsample → 64×64×192
  ↓ Level 1: 2×ResBlock(128),  Downsample → 128×32×96
  ↓ Level 2: 2×ResBlock(256),  Downsample → 256×16×48
  ↓ Level 3: 2×ResBlock(512) + SelfAttention
  ↓ Middle:  ResBlock + Attention + ResBlock (512ch)
  ↑ Level 3: 3×ResBlock(512) + Attention, Upsample
  ↑ Level 2: 3×ResBlock(256), Upsample
  ↑ Level 1: 3×ResBlock(128), Upsample
  ↑ Level 0: 3×ResBlock(64)
  ↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384]  ← predicted noise
```
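
The linear noise schedule and epsilon-prediction objective can be sketched as follows (a simplified forward-diffusion step; during training, the U-Net receives `x_t` concatenated with the condition frame and regresses `noise` with an MSE loss):

```python
import torch

T = 1000
betas = torch.linspace(1e-4, 0.02, T)                 # linear beta schedule
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)    # \bar{alpha}_t

def q_sample(x0: torch.Tensor, t: torch.Tensor, noise: torch.Tensor) -> torch.Tensor:
    """Diffuse clean frames x0 [B, 4, H, W] forward to timestep t [B]."""
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1 - a).sqrt() * noise

x0 = torch.randn(2, 4, 128, 384)
t = torch.randint(0, T, (2,))
noise = torch.randn_like(x0)
x_t = q_sample(x0, t, noise)
# Training target: the model, conditioned on the previous frame,
# predicts `noise` from x_t; loss = MSE(prediction, noise).
```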

### Spatial JEPA (1.8M trainable parameters)

| Component | Details |
|---|---|
| **Online encoder** | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| **Target encoder** | EMA copy of online encoder (decay 0.996 → 1.0 cosine schedule) |
| **Predictor** | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| **Loss** | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |

The JEPA learns compressed dynamics representations without generating pixels, which makes it useful for downstream tasks and transfer learning.
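
A minimal sketch of the two JEPA ingredients above, the EMA target-encoder update and the VICReg variance/covariance penalties (illustrative; the repo's exact feature pooling and loss coefficients may differ):

```python
import torch

@torch.no_grad()
def ema_update(online: torch.nn.Module, target: torch.nn.Module, decay: float):
    """target <- decay * target + (1 - decay) * online, parameter-wise."""
    for p_t, p_o in zip(target.parameters(), online.parameters()):
        p_t.lerp_(p_o, 1.0 - decay)

def vicreg_terms(z: torch.Tensor, eps: float = 1e-4):
    """Variance and covariance penalties on features z [N, D]."""
    z = z - z.mean(dim=0)
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = torch.relu(1.0 - std).mean()        # keep each dim's std >= 1
    cov = (z.T @ z) / (z.shape[0] - 1)             # [D, D] covariance
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = (off_diag ** 2).sum() / z.shape[1]  # decorrelate dimensions
    return var_loss, cov_loss
```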

## Training

### Dataset

Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-ai.org/the_well/) (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):
- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories × 101 timesteps, yielding 7,200 (condition, target) training pairs after the train/val split
- 6.9 GB total (HDF5 format)
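
Conceptually, each training sample is a (condition, target) pair of consecutive frames within one trajectory. A sketch with an in-memory stand-in array (the real pipeline reads HDF5 through `data_pipeline.py`; only the shapes here mirror the dataset spec):

```python
import numpy as np

# One trajectory: [T, C, H, W] = [101, 4, 128, 384] (random stand-in data)
traj = np.random.randn(101, 4, 128, 384).astype(np.float32)

# Consecutive (condition, target) pairs: 100 pairs per 101-step trajectory.
conds = traj[:-1]    # frames 0..99 condition the model
targets = traj[1:]   # frames 1..100 are the prediction targets
pairs = list(zip(conds, targets))
```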

### Diffusion Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |
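
The cosine-with-warmup schedule can be sketched with a standard `LambdaLR` (linear warmup over 500 steps, then cosine decay to zero; the total step count here is illustrative):

```python
import math
import torch

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

warmup, total = 500, 90_000  # illustrative total optimizer-step count

def lr_lambda(step: int) -> float:
    if step < warmup:
        return step / warmup                               # linear warmup 0 -> 1
    progress = (step - warmup) / max(1, total - warmup)
    return 0.5 * (1.0 + math.cos(math.pi * progress))      # cosine decay 1 -> 0

sched = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
```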

### Diffusion Training Results

| Metric | Value |
|---|---|
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |

Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).

### JEPA Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=3e-4, wd=0.05) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 16 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| EMA schedule | Cosine 0.996 → 1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~1.5 hours |

### JEPA Training Results

| Metric | Value |
|---|---|
| Final train loss | 4.07 |
| Similarity (sim) | 0.079 |
| Variance (VICReg) | 1.476 |
| Covariance (VICReg) | 0.578 |

Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99, converged ~epoch 50). The VICReg regularization keeps representations from collapsing while the similarity loss learns dynamics prediction.

Full JEPA training metrics available on the [WandB run](https://wandb.ai/alexwortega/the-well-jepa/runs/obwyebcv).

## Usage

### Installation

```bash
pip install the_well torch einops wandb tqdm h5py matplotlib "wandb[media]"
```

### Inference (generate next frame)

```python
import torch
from unet import UNet
from diffusion import GaussianDiffusion

# Load model
device = "cuda"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)

ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a condition frame [1, 4, 128, 384]:
x_cond = ...  # your input frame
x_pred = model.sample_ddim(x_cond, steps=50)  # fast DDIM sampling
```

### JEPA inference (extract dynamics embeddings)

```python
import torch
from jepa import JEPA

device = "cuda"
model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)

ckpt = torch.load("jepa_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a frame [1, 4, 128, 384]:
x = ...  # your input frame
z = model.online_encoder(x)  # [1, 128, 16, 48] spatial latent map
```

### Autoregressive rollout

```python
import torch

# Generate a 20-step trajectory autoregressively
trajectory = [x_cond]
cond = x_cond
with torch.no_grad():
    for step in range(20):
        pred = model.sample_ddim(cond, steps=50, eta=0.0)  # deterministic DDIM
        trajectory.append(pred)
        cond = pred  # feed prediction back as next condition
rollout = torch.cat(trajectory, dim=0)  # [21, 4, 128, 384]
```

### Training from scratch

```bash
# Download data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D

# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
  --no-streaming --local_path ./data/datasets \
  --batch_size 8 --epochs 100 --wandb

# Train JEPA
python train_jepa.py \
  --no-streaming --local_path ./data/datasets \
  --batch_size 16 --epochs 100 --wandb
```

### Streaming from HuggingFace (no download needed)

```bash
python train_diffusion.py --streaming --batch_size 4
```

## Project Structure

| File | Description |
|---|---|
| `unet.py` | U-Net with time conditioning, skip connections, self-attention |
| `diffusion.py` | DDPM/DDIM framework: noise schedule, training loss, sampling |
| `jepa.py` | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| `data_pipeline.py` | Data loading from The Well (streaming HF or local HDF5) |
| `train_diffusion.py` | Diffusion training with eval, video logging, checkpointing |
| `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
| `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
| `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748MB) |
| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23MB) |

## Evaluation Details

Every 5 epochs, the training script runs:

1. **Single-step evaluation**: DDIM-50 sampling on 4 validation batches, MSE against ground truth
2. **Multi-step rollout**: 10-step autoregressive prediction from a validation sample
3. **Video logging**: Side-by-side GT vs Prediction video logged to WandB as mp4
4. **Comparison images**: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
5. **Rollout MSE curve**: Per-step MSE showing prediction degradation over horizon
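
The per-step rollout MSE (item 5) reduces to comparing each predicted frame with its ground-truth counterpart at the same horizon; a sketch, assuming predictions and ground truth stacked as [T, C, H, W]:

```python
import torch

def rollout_mse_curve(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
    """Per-step MSE over a rollout; pred and gt are [T, C, H, W]."""
    return ((pred - gt) ** 2).mean(dim=(1, 2, 3))  # [T]

pred = torch.zeros(10, 4, 16, 16)
gt = torch.ones(10, 4, 16, 16)
curve = rollout_mse_curve(pred, gt)  # each step's MSE is 1.0 here
```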

## The Well Dataset

[The Well](https://polymathic-ai.org/the_well/) is a 15TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well — just change `--dataset`:

```bash
python train_diffusion.py --dataset active_matter          # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow             # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion  # 154 GB
```

## Citation

```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```

## License

Apache 2.0