Upload README.md with huggingface_hub
README.md
CHANGED
@@ -88,7 +88,7 @@ Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-a
 | GPU | NVIDIA RTX A6000 (48GB) |
 | Training time | ~7 hours |
 
-### Training Results
+### Diffusion Training Results
 
 | Metric | Value |
 |---|---|
@@ -98,6 +98,33 @@ Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-a
 
 Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).
 
+### JEPA Training Config
+
+| Parameter | Value |
+|---|---|
+| Optimizer | AdamW (lr=3e-4, wd=0.05) |
+| LR schedule | Cosine with 500-step warmup |
+| Batch size | 16 |
+| Mixed precision | bfloat16 |
+| Gradient clipping | max_norm=1.0 |
+| EMA schedule | Cosine 0.996 → 1.0 |
+| Epochs | 100 |
+| GPU | NVIDIA RTX A6000 (48GB) |
+| Training time | ~1.5 hours |
+
+### JEPA Training Results
+
+| Metric | Value |
+|---|---|
+| Final train loss | 4.07 |
+| Similarity (sim) | 0.079 |
+| Variance (VICReg) | 1.476 |
+| Covariance (VICReg) | 0.578 |
+
+Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99, converged ~epoch 50). The VICReg regularization keeps representations from collapsing while the similarity loss learns dynamics prediction.
+
+Full JEPA training metrics available on the [WandB run](https://wandb.ai/alexwortega/the-well-jepa/runs/obwyebcv).
+
 ## Usage
 
 ### Installation
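The JEPA config above lists `EMA schedule | Cosine 0.996 → 1.0` without showing the curve. A minimal sketch of a BYOL/JEPA-style cosine momentum ramp — an assumption about the schedule's shape, not necessarily the exact code in `train_jepa.py`:

```python
import math

def ema_momentum(step: int, total_steps: int,
                 base: float = 0.996, final: float = 1.0) -> float:
    """Cosine ramp of the target-encoder EMA momentum from `base` to `final`.

    At step 0 the momentum is `base` (0.996); by the last step it reaches
    `final` (1.0), i.e. the target network is effectively frozen.
    """
    progress = step / max(total_steps, 1)
    # cos goes 1 -> -1 over training, so momentum goes base -> final
    return final - (final - base) * 0.5 * (1.0 + math.cos(math.pi * progress))
```

The target encoder would then be updated each step as `p_t = m * p_t + (1 - m) * p_o` with `m = ema_momentum(step, total_steps)`.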
@@ -127,6 +154,24 @@ x_cond = ... # your input frame
 x_pred = model.sample_ddim(x_cond, steps=50) # fast DDIM sampling
 ```
 
+### JEPA inference (extract dynamics embeddings)
+
+```python
+import torch
+from jepa import JEPA
+
+device = "cuda"
+model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)
+
+ckpt = torch.load("jepa_ep0099.pt", map_location=device)
+model.load_state_dict(ckpt["model"])
+model.eval()
+
+# Given a frame [1, 4, 128, 384]:
+x = ... # your input frame
+z = model.online_encoder(x) # [1, 128, 16, 48] spatial latent map
+```
+
 ### Autoregressive rollout
 
 ```python
@@ -174,7 +219,8 @@ python train_diffusion.py --streaming --batch_size 4
 | `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
 | `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
 | `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
-| `diffusion_ep0099.pt` |
+| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748MB) |
+| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23MB) |
 
 ## Evaluation Details
 
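The JEPA inference snippet in the diff stops at the spatial latent map `z` and leaves open how to consume it. One illustrative pattern — not part of this repo's API — is to average-pool the map into a per-frame embedding and compare frames by cosine similarity, e.g. to check how quickly rollout states decorrelate in latent space:

```python
import math

def pool_latent(z):
    """Average-pool a [C][H][W] latent map (nested lists) into a C-dim vector."""
    return [sum(sum(row) for row in ch) / (len(ch) * len(ch[0])) for ch in z]

def cosine_sim(a, b):
    """Cosine similarity between two embedding vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / (norm + 1e-12)

# Toy 2-channel, 2x2 latent maps standing in for encoder outputs
z0 = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]
e0 = pool_latent(z0)       # [2.5, 6.5]
print(cosine_sim(e0, e0))  # ~1.0
```

With the real model, `z.squeeze(0).tolist()` (or a `torch` mean over the spatial dims) would play the role of the toy `z0` here.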