---
license: apache-2.0
tags:
- physics
- diffusion
- jepa
- pde
- simulation
- the-well
- ddpm
- ddim
- spatiotemporal
datasets:
- polymathic-ai/turbulent_radiative_layer_2D
language:
- en
pipeline_tag: image-to-image
---

# The Well: Diffusion & JEPA for PDE Dynamics

A conditional diffusion model (DDPM/DDIM) and a Spatial JEPA trained to predict the evolution of 2D physics simulations from [The Well](https://polymathic-ai.org/the_well/), a dataset collection by Polymathic AI.

Given the current state of a physical system (e.g. a turbulent radiative layer), the model predicts the next time step. It can be run autoregressively to generate multi-step rollout trajectories.

## Architecture

### Conditional DDPM (62M parameters)

| Component | Details |
|---|---|
| **Backbone** | U-Net with 4 resolution levels (64→128→256→512 channels) |
| **Conditioning** | Previous frame concatenated to the noisy target along the channel dim |
| **Time encoding** | Sinusoidal positional embedding → MLP (256-d) |
| **Residual blocks** | GroupNorm → SiLU → Conv3x3 → +time_emb → GroupNorm → SiLU → Dropout → Conv3x3 |
| **Attention** | Multi-head self-attention at the bottleneck (16×48 spatial, 768 tokens) |
| **Noise schedule** | Linear beta: 1e-4 → 0.02, 1000 timesteps |
| **Parameterization** | Epsilon-prediction (predicts the noise) |
| **Sampling** | DDPM (1000 steps) or DDIM (50 steps, deterministic) |

```
Input: [B, 8, 128, 384] ← 4ch noisy target + 4ch condition
↓ Conv3x3 → 64ch
↓ Level 0: 2×ResBlock(64), Downsample → 64×64×192
↓ Level 1: 2×ResBlock(128), Downsample → 128×32×96
↓ Level 2: 2×ResBlock(256), Downsample → 256×16×48
↓ Level 3: 2×ResBlock(512) + SelfAttention
↓ Middle: ResBlock + Attention + ResBlock (512ch)
↑ Level 3: 3×ResBlock(512) + Attention, Upsample
↑ Level 2: 3×ResBlock(256), Upsample
↑ Level 1: 3×ResBlock(128), Upsample
↑ Level 0: 3×ResBlock(64)
↓ GroupNorm → SiLU → Conv3x3
Output: [B, 4, 128, 384] ← predicted noise
```
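The noise schedule and training objective from the table above can be sketched as follows. This is a minimal sketch, not the repo's exact code: the `unet(x, t)` callable and the helper names (`q_sample`, `diffusion_loss`) are illustrative assumptions.

```python
import torch

# Linear beta schedule as in the table: 1e-4 -> 0.02 over 1000 timesteps.
T = 1000
betas = torch.linspace(1e-4, 0.02, T)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def q_sample(x0, t, noise):
    """Forward diffusion: noise the clean target x0 to timestep t."""
    a = alphas_cumprod[t].view(-1, 1, 1, 1)
    return a.sqrt() * x0 + (1 - a).sqrt() * noise

def diffusion_loss(unet, x0, x_cond):
    """Epsilon-prediction loss with the condition frame concatenated
    along the channel dim (4ch noisy target + 4ch condition -> 8ch input)."""
    b = x0.shape[0]
    t = torch.randint(0, T, (b,))
    noise = torch.randn_like(x0)
    x_t = q_sample(x0, t, noise)
    eps_pred = unet(torch.cat([x_t, x_cond], dim=1), t)
    return torch.nn.functional.mse_loss(eps_pred, noise)
```

At sampling time the same concatenation is used at every denoising step, which is how the previous frame conditions the prediction.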

### Spatial JEPA (1.8M trainable parameters)

| Component | Details |
|---|---|
| **Online encoder** | ResNet-style CNN (3 stages, stride-2), outputs spatial latent maps [B, 128, H/8, W/8] |
| **Target encoder** | EMA copy of the online encoder (decay 0.996 → 1.0, cosine schedule) |
| **Predictor** | 3-layer CNN on spatial feature maps (128 → 256 → 128 channels) |
| **Loss** | Spatial MSE + VICReg regularization (variance + covariance on channel-averaged features) |
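The VICReg terms in the loss row can be sketched like this. A minimal sketch of the variance and covariance penalties only; `vicreg_terms` and the `gamma`/`eps` values are illustrative assumptions, not the repo's API.

```python
import torch

def vicreg_terms(z, gamma=1.0, eps=1e-4):
    """VICReg-style regularizers on feature vectors z: [B, D]
    (e.g. spatial latents averaged over H and W)."""
    b, d = z.shape
    z = z - z.mean(dim=0)
    # Variance term: hinge pushing each dimension's std above gamma,
    # which prevents the representation from collapsing to a constant.
    std = torch.sqrt(z.var(dim=0) + eps)
    var_loss = torch.relu(gamma - std).mean()
    # Covariance term: off-diagonal covariance pushed toward zero,
    # decorrelating the feature dimensions.
    cov = (z.T @ z) / (b - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_loss = (off_diag ** 2).sum() / d
    return var_loss, cov_loss
```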

The JEPA learns compressed representations of the dynamics without generating pixels, which makes it useful for downstream tasks and transfer learning.
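The EMA target encoder and its cosine decay schedule can be sketched as follows. An illustrative sketch: `make_target_encoder`, `ema_decay`, and `ema_update` are assumed helper names, not the repo's exact API.

```python
import copy
import math
import torch

def make_target_encoder(online_encoder):
    """Target encoder starts as a frozen copy of the online encoder."""
    target = copy.deepcopy(online_encoder)
    for p in target.parameters():
        p.requires_grad_(False)
    return target

def ema_decay(step, total_steps, base=0.996):
    """Cosine schedule from 0.996 toward 1.0, as in the table above."""
    return 1.0 - (1.0 - base) * (math.cos(math.pi * step / total_steps) + 1) / 2

def ema_update(target, online, decay):
    """target <- decay * target + (1 - decay) * online, in place."""
    with torch.no_grad():
        for pt, po in zip(target.parameters(), online.parameters()):
            pt.mul_(decay).add_(po, alpha=1 - decay)
```

Only the online encoder receives gradients; the target is updated once per step with `ema_update(target, online, ema_decay(step, total_steps))`.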

## Training

### Dataset

Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-ai.org/the_well/) (Polymathic AI, NeurIPS 2024 Datasets & Benchmarks):
- 2D turbulent radiative layer simulation
- Resolution: 128 × 384 spatial, 4 physical field channels
- 90 trajectories of 101 timesteps each (7,200 training samples)
- 6.9 GB total (HDF5 format)

### Diffusion Training Config

| Parameter | Value |
|---|---|
| Optimizer | AdamW (lr=1e-4, wd=0.01) |
| LR schedule | Cosine with 500-step warmup |
| Batch size | 8 |
| Mixed precision | bfloat16 |
| Gradient clipping | max_norm=1.0 |
| Epochs | 100 |
| GPU | NVIDIA RTX A6000 (48GB) |
| Training time | ~7 hours |
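The optimizer and LR-schedule rows above correspond roughly to the following setup. A sketch using PyTorch's `LambdaLR`; the `total_steps` value and the `cosine_with_warmup` helper name are illustrative assumptions.

```python
import math
import torch

def cosine_with_warmup(optimizer, warmup_steps, total_steps):
    """Linear warmup for `warmup_steps`, then cosine decay to zero."""
    def lr_lambda(step):
        if step < warmup_steps:
            return step / max(1, warmup_steps)
        progress = min(1.0, (step - warmup_steps) / max(1, total_steps - warmup_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

net = torch.nn.Linear(4, 4)  # stand-in for the U-Net
opt = torch.optim.AdamW(net.parameters(), lr=1e-4, weight_decay=0.01)
sched = cosine_with_warmup(opt, warmup_steps=500, total_steps=90_000)
```

Call `sched.step()` once per optimizer step; the LR ramps from 0 to 1e-4 over the first 500 steps and then decays along a cosine curve.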

### Training Results

| Metric | Value |
|---|---|
| Final train loss | 0.028 |
| Val MSE (single-step) | 743.3 |
| Rollout MSE (10-step mean) | 805.1 |

The training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs prediction side by side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).

## Usage

### Installation

```bash
pip install the_well torch einops wandb tqdm h5py matplotlib "wandb[media]"
```

### Inference (generate the next frame)

```python
import torch
from unet import UNet
from diffusion import GaussianDiffusion

# Load the model
device = "cuda"
unet = UNet(in_channels=8, out_channels=4, base_ch=64, ch_mults=(1, 2, 4, 8))
model = GaussianDiffusion(unet, timesteps=1000).to(device)

ckpt = torch.load("diffusion_ep0099.pt", map_location=device)
model.load_state_dict(ckpt["model"])
model.eval()

# Given a condition frame [1, 4, 128, 384]:
x_cond = ...  # your input frame
x_pred = model.sample_ddim(x_cond, steps=50)  # fast DDIM sampling
```

### Autoregressive rollout

```python
# Generate a 20-step trajectory
trajectory = [x_cond]
cond = x_cond
for _ in range(20):
    pred = model.sample_ddim(cond, steps=50, eta=0.0)  # deterministic DDIM
    trajectory.append(pred)
    cond = pred  # feed the prediction back as the next condition
```

### Training from scratch

```bash
# Download the data locally (6.9 GB)
the-well-download --base-path ./data --dataset turbulent_radiative_layer_2D

# Train diffusion with WandB logging + eval videos
python train_diffusion.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 8 --epochs 100 --wandb

# Train JEPA
python train_jepa.py \
    --no-streaming --local_path ./data/datasets \
    --batch_size 16 --epochs 100 --wandb
```

### Streaming from Hugging Face (no download needed)

```bash
python train_diffusion.py --streaming --batch_size 4
```

## Project Structure

| File | Description |
|---|---|
| `unet.py` | U-Net with time conditioning, skip connections, self-attention |
| `diffusion.py` | DDPM/DDIM framework: noise schedule, training loss, sampling |
| `jepa.py` | Spatial JEPA: CNN encoder, conv predictor, EMA target, VICReg loss |
| `data_pipeline.py` | Data loading from The Well (streaming HF or local HDF5) |
| `train_diffusion.py` | Diffusion training with eval, video logging, checkpointing |
| `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
| `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
| `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
| `diffusion_ep0099.pt` | Final checkpoint (epoch 99, 748MB) |

## Evaluation Details

Every 5 epochs, the training script runs:

1. **Single-step evaluation**: DDIM-50 sampling on 4 validation batches, MSE against ground truth
2. **Multi-step rollout**: 10-step autoregressive prediction from a validation sample
3. **Video logging**: side-by-side GT vs prediction video logged to WandB as mp4
4. **Comparison images**: Condition | Ground Truth | Prediction for each field channel (RdBu_r colormap)
5. **Rollout MSE curve**: per-step MSE showing how prediction quality degrades over the horizon
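The per-step rollout MSE curve (item 5) can be computed as in the sketch below. `rollout_mse` is a hypothetical helper, not a repo function; it assumes the `sample_ddim` API shown in the Usage section.

```python
import torch

def rollout_mse(model, traj, horizon=10, steps=50):
    """Per-step MSE over an autoregressive rollout.
    traj: ground-truth clip, e.g. [T, 1, 4, 128, 384] with T > horizon.
    Returns one MSE per rollout step, so degradation over the horizon
    can be plotted as a curve."""
    cond = traj[0]
    errors = []
    for t in range(1, horizon + 1):
        pred = model.sample_ddim(cond, steps=steps, eta=0.0)
        errors.append(torch.mean((pred - traj[t]) ** 2).item())
        cond = pred  # feed the prediction back in
    return errors
```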

## The Well Dataset

[The Well](https://polymathic-ai.org/the_well/) is a 15 TB collection of 16 physics simulation datasets (NeurIPS 2024). This project works with any 2D dataset from The Well by changing `--dataset`:

```bash
python train_diffusion.py --dataset active_matter                 # 51 GB, 256×256
python train_diffusion.py --dataset shear_flow                    # 115 GB, 128×256
python train_diffusion.py --dataset gray_scott_reaction_diffusion # 154 GB
```

## Citation

```bibtex
@inproceedings{thewell2024,
  title={The Well: a Large-Scale Collection of Diverse Physics Simulations for Machine Learning},
  author={Polymathic AI},
  booktitle={NeurIPS 2024 Datasets and Benchmarks},
  year={2024}
}
```

## License

Apache 2.0