AlexWortega committed (verified)
Commit 591832c · Parent(s): c98627b

Upload README.md with huggingface_hub

Files changed (1): README.md (+48 −2)
README.md CHANGED

@@ -88,7 +88,7 @@ Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-a
 | GPU | NVIDIA RTX A6000 (48GB) |
 | Training time | ~7 hours |
 
-### Training Results
 
 | Metric | Value |
 |---|---|
@@ -98,6 +98,33 @@ Trained on **turbulent_radiative_layer_2D** from [The Well](https://polymathic-a
 
 Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).
 
 ## Usage
 
 ### Installation
@@ -127,6 +154,24 @@ x_cond = ... # your input frame
 x_pred = model.sample_ddim(x_cond, steps=50) # fast DDIM sampling
 ```
 
 ### Autoregressive rollout
 
 ```python
@@ -174,7 +219,8 @@ python train_diffusion.py --streaming --batch_size 4
 | `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
 | `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
 | `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
-| `diffusion_ep0099.pt` | Final checkpoint (epoch 99, 748MB) |
 
 ## Evaluation Details
 
 
 | GPU | NVIDIA RTX A6000 (48GB) |
 | Training time | ~7 hours |
 
+### Diffusion Training Results
 
 | Metric | Value |
 |---|---|
 
 Training loss curve, validation metrics, comparison images (Condition | Ground Truth | Prediction), and rollout videos (GT vs Prediction side-by-side) are all available on the [WandB run](https://wandb.ai/alexwortega/the-well-diffusion/runs/ilnm4eh9).
 
+### JEPA Training Config
+
+| Parameter | Value |
+|---|---|
+| Optimizer | AdamW (lr=3e-4, wd=0.05) |
+| LR schedule | Cosine with 500-step warmup |
+| Batch size | 16 |
+| Mixed precision | bfloat16 |
+| Gradient clipping | max_norm=1.0 |
+| EMA schedule | Cosine 0.996 → 1.0 |
+| Epochs | 100 |
+| GPU | NVIDIA RTX A6000 (48GB) |
+| Training time | ~1.5 hours |
+
+### JEPA Training Results
+
+| Metric | Value |
+|---|---|
+| Final train loss | 4.07 |
+| Similarity (sim) | 0.079 |
+| Variance (VICReg) | 1.476 |
+| Covariance (VICReg) | 0.578 |
+
+Loss progression: 4.55 (epoch 0) → 3.79 (epoch 2) → 4.07 (epoch 99); training converged around epoch 50. The VICReg regularization keeps the representations from collapsing while the similarity loss learns to predict the dynamics.
+
+Full JEPA training metrics are available on the [WandB run](https://wandb.ai/alexwortega/the-well-jepa/runs/obwyebcv).
+
 ## Usage
 
 ### Installation
 
 x_pred = model.sample_ddim(x_cond, steps=50) # fast DDIM sampling
 ```
 
+### JEPA inference (extract dynamics embeddings)
+
+```python
+import torch
+from jepa import JEPA
+
+device = "cuda"
+model = JEPA(in_channels=4, latent_channels=128, base_ch=32, pred_hidden=256).to(device)
+
+ckpt = torch.load("jepa_ep0099.pt", map_location=device)
+model.load_state_dict(ckpt["model"])
+model.eval()
+
+# Given a frame [1, 4, 128, 384]:
+x = ... # your input frame
+with torch.no_grad():
+    z = model.online_encoder(x) # [1, 128, 16, 48] spatial latent map
+```
+
 ### Autoregressive rollout
 
 ```python
 
 | `train_jepa.py` | JEPA training with EMA schedule, VICReg metrics |
 | `eval_utils.py` | Evaluation: single-step MSE, rollout videos, WandB media logging |
 | `test_pipeline.py` | End-to-end verification script (data → forward → backward) |
+| `diffusion_ep0099.pt` | Diffusion final checkpoint (epoch 99, 748MB) |
+| `jepa_ep0099.pt` | JEPA final checkpoint (epoch 99, 23MB) |
 
 ## Evaluation Details
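The JEPA results table added in this commit reports separate similarity, variance, and covariance (VICReg) terms, but the diff does not show the loss itself. As a rough sketch of what such an objective looks like — the function name, term weights, and flattened `[B, D]` latent shape here are illustrative assumptions, not the repo's `train_jepa.py` API — a VICReg-regularized JEPA loss can be written as:

```python
import torch
import torch.nn.functional as F

def vicreg_jepa_loss(z_pred, z_target, var_weight=1.0, cov_weight=1.0):
    """Sketch of a JEPA objective with VICReg-style regularization.

    z_pred:   predictor output for the next frame's latent, shape [B, D]
    z_target: EMA target-encoder latent (treated as a constant), shape [B, D]
    """
    # Similarity term: pull predicted latents toward the target latents.
    sim = F.mse_loss(z_pred, z_target.detach())

    # Variance term: keep each latent dimension's std above 1 so the
    # online encoder cannot collapse to a constant output.
    std = torch.sqrt(z_pred.var(dim=0) + 1e-4)
    var = torch.relu(1.0 - std).mean()

    # Covariance term: decorrelate latent dimensions by penalizing
    # off-diagonal entries of the batch covariance matrix.
    zc = z_pred - z_pred.mean(dim=0)
    cov = (zc.T @ zc) / (zc.shape[0] - 1)
    off_diag = cov - torch.diag(torch.diag(cov))
    cov_pen = (off_diag ** 2).sum() / zc.shape[1]

    return sim + var_weight * var + cov_weight * cov_pen
```

This matches the logged metric names (sim, variance, covariance); the actual weights and latent pooling used for the numbers in the table above may differ.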