SogolS commited on
Commit
ae8a025
·
verified ·
1 Parent(s): f49889d

Upload train_jetformer_sogol.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_jetformer_sogol.py +898 -0
train_jetformer_sogol.py ADDED
@@ -0,0 +1,898 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ train_jetformer_ddp.py
3
+
4
+ JetFormer (Flow + AR Transformer + GMM) training script with DDP + HF streaming dataset.
5
+
6
+ Key edits in THIS version (requested):
7
+ 1) Noise curriculum is tied DIRECTLY to (step, max_iters) and goes to ~0 at max_iters.
8
+ - Image/RGB noise uses paper-style sigma in [0,255] (default 64 -> 0).
9
+ - Latent z noise uses paper-style std (default 0.3 -> 0).
10
+ 2) Removed the constant latent noise (z += N(0, 0.3)) and replaced it with a decaying schedule.
11
+ 3) Forward signature changed to: forward(x, step, max_iters)
12
+ 4) Fixed a few structural/indent issues in the original paste (HF shard indent, ViTFlow.forward indent, etc.)
13
+
14
+ Notes:
15
+ - This keeps your architecture and training logic intact, only changing noise scheduling + small code fixes.
16
+ - If you want "almost zero but not exact" at the end, set CFG.noise_floor = 1e-6.
17
+ """
18
+
19
+ import math
20
+ import os
21
+ import csv
22
+ import time
23
+ import pandas as pd
24
+ from dataclasses import dataclass
25
+ from typing import Tuple
26
+
27
+ # --- PIL Fix for Truncated Images ---
28
+ from PIL import ImageFile
29
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
30
+
31
+ import torch
32
+ import torch.nn as nn
33
+ import torch.nn.functional as F
34
+ import torch.distributed as dist
35
+ from torch.nn.parallel import DistributedDataParallel as DDP
36
+ from torch.utils.data import DataLoader
37
+ from torchvision.utils import save_image
38
+ from torchvision import transforms
39
+ from tqdm import tqdm
40
+
41
+ # Hugging Face Datasets
42
+ from datasets import load_dataset
43
+
44
+ # This must be done BEFORE importing pyplot
45
+ import matplotlib
46
+ matplotlib.use('Agg')
47
+ import matplotlib.pyplot as plt
48
+
49
+ import numpy as np
50
+
51
+
52
+ # ======================================================================================
53
+ # Block 1: DDP Setup & Configuration
54
+ # ======================================================================================
55
+ def setup_ddp():
56
+ """Initializes the distributed process group."""
57
+ if "RANK" not in os.environ:
58
+ os.environ["RANK"] = "0"
59
+ os.environ["WORLD_SIZE"] = "1"
60
+ os.environ["LOCAL_RANK"] = "0"
61
+ os.environ["MASTER_ADDR"] = "localhost"
62
+ os.environ["MASTER_PORT"] = "12355"
63
+
64
+ dist.init_process_group(backend="nccl")
65
+
66
+ rank = int(os.environ["RANK"])
67
+ local_rank = int(os.environ["LOCAL_RANK"])
68
+ world_size = int(os.environ["WORLD_SIZE"])
69
+ torch.cuda.set_device(local_rank)
70
+ return rank, local_rank, world_size
71
+
72
+
73
+ def cleanup_ddp():
74
+ dist.destroy_process_group()
75
+
76
+
77
+ @dataclass
78
+ class CFG:
79
+ # --- Model Config (Scaled for RTX 4090 24GB) ---
80
+ d_model: int = 768
81
+ n_heads: int = 12
82
+ n_layers: int = 12
83
+
84
+ # --- AstroPT Specific Configs ---
85
+ block_size: int = 1024
86
+ dropout: float = 0.0
87
+ bias: bool = False
88
+ is_causal: bool = True
89
+
90
+ # --- Flow Specification ---
91
+ flow_steps: int = 16
92
+
93
+ # --- Training Config ---
94
+ max_iters: int = 80_000
95
+ save_interval: int = 5000
96
+ batch_size: int = 8
97
+ val_check_interval: int = 5000
98
+
99
+ # --- Optimizer Config ---
100
+ lr: float = 1e-4
101
+ wd: float = 1e-4
102
+ beta2: float = 0.95
103
+ warmup_steps: int = 10000
104
+
105
+ # --- Data Params ---
106
+ img_size: int = 256
107
+ patch: int = 8
108
+ in_ch: int = 3
109
+
110
+ # Derived Dimensions
111
+ n_tokens: int = (img_size // patch) ** 2
112
+ d_token: int = in_ch * patch * patch
113
+
114
+ # --- GMM Head ---
115
+ gmm_K: int = 256
116
+
117
+ # --- Noise curriculum (paper-style, tied to max_iters) ---
118
+ # JetFormer paper uses σ0 = 64 in pixel space [0,255] (≈ 0.251 in [0,1]).
119
+ rgb_sigma0_255: float = 64.0 # start noise in [0,255]
120
+ rgb_sigmaT_255: float = 0.0 # final noise at max_iters (0 => sharpest end)
121
+
122
+ # Latent noise in flow token space (paper mentions std=0.3)
123
+ z_sigma0: float = 0.3 # start latent noise
124
+ z_sigmaT: float = 0.0 # final latent noise at max_iters
125
+
126
+ # If 1.0: reaches final exactly at max_iters.
127
+ # If <1.0: reaches final earlier and stays there.
128
+ noise_decay_frac: float = 1.0
129
+
130
+ # Optional: set to 1e-6 to avoid EXACT zero (sometimes smoother numerically)
131
+ noise_floor: float = 0.0
132
+
133
+ # --- System ---
134
+ grad_clip_val: float = 0.5
135
+
136
+ # Paths
137
+ dataset_name: str = "final_sogol_image_patch_8"
138
+ checkpoint_path: str = ""
139
+ samples_dir: str = ""
140
+ loss_csv_path: str = ""
141
+ loss_plot_path: str = ""
142
+
143
+ # --- Data Sources (Hugging Face) ---
144
+ hf_repo: str = "Smith42/galaxies"
145
+ val_steps: int = 100
146
+
147
+ # DDP Placeholders
148
+ rank: int = 0
149
+ world_size: int = 1
150
+ device: str = "cuda"
151
+
152
+
153
+ # ======================================================================================
154
+ # Block 2: Logging Utilities (Rank 0 Only)
155
+ # ======================================================================================
156
+ def append_losses_to_csv(step, train_loss, val_loss, filename):
157
+ file_exists = os.path.isfile(filename)
158
+ with open(filename, 'a', newline='') as csvfile:
159
+ writer = csv.writer(csvfile)
160
+ if not file_exists:
161
+ writer.writerow(['step', 'train_loss', 'val_loss'])
162
+ writer.writerow([step, train_loss, val_loss])
163
+
164
+
165
+ def plot_loss_from_csv(csv_path, output_path):
166
+ if not os.path.isfile(csv_path):
167
+ return
168
+ df = pd.read_csv(csv_path)
169
+ fig, ax = plt.subplots(figsize=(10, 6))
170
+ ax.plot(df['step'], df['train_loss'], label='Train Loss', color='blue')
171
+
172
+ df_val = df.dropna(subset=['val_loss'])
173
+ if not df_val.empty:
174
+ ax.plot(
175
+ df_val['step'], df_val['val_loss'],
176
+ label='Validation Loss', color='orange',
177
+ linestyle='--', marker='o'
178
+ )
179
+
180
+ ax.set_title('Training and Validation Loss per Step')
181
+ ax.set_xlabel('Step')
182
+ ax.set_ylabel('Average Loss')
183
+ ax.legend()
184
+ ax.grid(True)
185
+ fig.savefig(output_path)
186
+ plt.close(fig)
187
+
188
+
189
+ # ======================================================================================
190
+ # Block 3: Data Loading
191
+ # ======================================================================================
192
+ def process_hf_item(item):
193
+ img = item['image_crop']
194
+ to_tensor = transforms.ToTensor()
195
+ img_t = to_tensor(img)
196
+ if img_t.shape[0] == 1:
197
+ img_t = img_t.repeat(3, 1, 1)
198
+ return {"img": img_t}
199
+
200
+
201
+ def get_train_dataloader(cfg: CFG):
202
+ if cfg.rank == 0:
203
+ print(f"Loading streaming dataset: {cfg.hf_repo} (Split: train)")
204
+ ds = load_dataset(cfg.hf_repo, split="train", streaming=True)
205
+
206
+ # shard across ranks so each GPU sees a different stream
207
+ if cfg.world_size > 1:
208
+ ds = ds.shard(num_shards=cfg.world_size, index=cfg.rank)
209
+
210
+ ds = ds.map(process_hf_item, remove_columns=["image", "image_crop", "survey", "ra", "dec"])
211
+
212
+ nw = min(6, max(2, (os.cpu_count() // max(cfg.world_size, 1)) - 1))
213
+ return DataLoader(
214
+ ds,
215
+ batch_size=cfg.batch_size,
216
+ num_workers=nw,
217
+ pin_memory=True,
218
+ )
219
+
220
+
221
+ def get_val_dataloader(cfg: CFG):
222
+ if cfg.rank == 0:
223
+ print(f"Loading streaming dataset: {cfg.hf_repo} (Split: test)")
224
+ ds = load_dataset(cfg.hf_repo, split="test", streaming=True)
225
+
226
+ # No validation sharding to prevent empty shards on small val sets
227
+ ds = ds.map(process_hf_item, remove_columns=["image", "image_crop", "survey", "ra", "dec"])
228
+
229
+ nw = min(4, max(2, (os.cpu_count() // max(cfg.world_size, 1)) - 1))
230
+ return DataLoader(
231
+ ds,
232
+ batch_size=cfg.batch_size,
233
+ num_workers=nw,
234
+ pin_memory=True,
235
+ )
236
+
237
+
238
+ # ======================================================================================
239
+ # Block 4: Checkpointing (Rank 0 Only)
240
+ # ======================================================================================
241
+ def save_checkpoint(step, model, optimizer, cfg, is_latest=True):
242
+ """
243
+ Saves the checkpoint.
244
+ 1) Always overwrites 'checkpoint_latest.pt' for easy resuming.
245
+ 2) If is_latest=False, saves a numbered file like 'checkpoint_step_005000.pt'.
246
+ """
247
+ if cfg.rank != 0:
248
+ return
249
+
250
+ model_state = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()
251
+
252
+ checkpoint = {
253
+ 'step': step,
254
+ 'model_state_dict': model_state,
255
+ 'optimizer_state_dict': optimizer.state_dict()
256
+ }
257
+
258
+ latest_path = os.path.join(cfg.samples_dir, "checkpoint_latest.pt")
259
+ torch.save(checkpoint, latest_path)
260
+
261
+ if not is_latest:
262
+ history_path = os.path.join(cfg.samples_dir, f"checkpoint_step_{step:07d}.pt")
263
+ torch.save(checkpoint, history_path)
264
+ print(f"Saved historical checkpoint: {history_path}")
265
+ else:
266
+ print(f"Updated latest checkpoint: {latest_path}")
267
+
268
+
269
+ def load_checkpoint(model, optimizer, cfg):
270
+ latest_path = os.path.join(cfg.samples_dir, "checkpoint_latest.pt")
271
+
272
+ if not os.path.exists(latest_path):
273
+ if cfg.rank == 0:
274
+ print(f"No checkpoint found at {latest_path}. Starting from scratch.")
275
+ return 0
276
+
277
+ map_location = {'cuda:%d' % 0: 'cuda:%d' % cfg.rank}
278
+ checkpoint = torch.load(latest_path, map_location=map_location)
279
+
280
+ model_unwrap = model.module if isinstance(model, DDP) else model
281
+ model_unwrap.load_state_dict(checkpoint['model_state_dict'])
282
+
283
+ optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
284
+ step = checkpoint['step']
285
+
286
+ if cfg.rank == 0:
287
+ print(f"Checkpoint loaded from {latest_path}. Resuming from step {step}")
288
+ return step
289
+
290
+
291
+ # ======================================================================================
292
+ # Block 5: Model Definitions
293
+ # ======================================================================================
294
+ def uniform_dequantize(x: torch.Tensor) -> torch.Tensor:
295
+ # Standard dequantization for 8-bit images
296
+ return (x + torch.rand_like(x) / 256.0).clamp(0.0, 1.0)
297
+
298
+
299
+ def patchify(x: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
300
+ B, C, H, W = x.shape
301
+ x = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
302
+ x = x.contiguous().permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * patch_size * patch_size)
303
+ return x
304
+
305
+
306
+ def depatchify(tokens: torch.Tensor, C: int = 3, H: int = 256, W: int = 256, patch_size: int = 16) -> torch.Tensor:
307
+ B, N, D = tokens.shape
308
+ hp, wp = H // patch_size, W // patch_size
309
+ x = tokens.reshape(B, hp, wp, C, patch_size, patch_size)
310
+ x = x.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
311
+ return x
312
+
313
+
314
+ def cosine_decay(step: int, T: int, start: float, end: float) -> float:
315
+ """
316
+ Cosine decay from start -> end over steps [0, T].
317
+ Returns exactly end for step >= T.
318
+ """
319
+ if T <= 0:
320
+ return end
321
+ if step >= T:
322
+ return end
323
+ x = step / T # in [0,1)
324
+ return end + 0.5 * (start - end) * (1.0 + math.cos(math.pi * x))
325
+
326
+
327
+ class ViTCouplingBlock(nn.Module):
328
+ def __init__(self, in_channels: int, n_tokens: int, width: int = 512, depth: int = 4, heads: int = 8):
329
+ super().__init__()
330
+ self.in_proj = nn.Linear(in_channels, width)
331
+ self.pos_emb = nn.Parameter(torch.randn(1, n_tokens, width) * 0.02)
332
+
333
+ layer = nn.TransformerEncoderLayer(
334
+ d_model=width, nhead=heads, dim_feedforward=2048, dropout=0.0,
335
+ activation="gelu", batch_first=True, norm_first=True
336
+ )
337
+ self.transformer = nn.TransformerEncoder(layer, num_layers=depth)
338
+ self.out_proj = nn.Linear(width, in_channels * 2)
339
+
340
+ nn.init.zeros_(self.out_proj.weight)
341
+ nn.init.zeros_(self.out_proj.bias)
342
+
343
+ def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
344
+ h = self.in_proj(x) + self.pos_emb
345
+ h = self.transformer(h)
346
+ st = self.out_proj(h)
347
+ s, t = st.chunk(2, dim=-1)
348
+ s = torch.tanh(s)
349
+ return s, t
350
+
351
+
352
+ class ViTAffineCoupling(nn.Module):
353
+ def __init__(self, d_token: int, n_tokens: int):
354
+ super().__init__()
355
+ self.half_d = d_token // 2
356
+ self.register_buffer('perm', torch.randperm(d_token))
357
+ self.register_buffer('inv_perm', torch.argsort(self.perm))
358
+ self.net = ViTCouplingBlock(self.half_d, n_tokens)
359
+
360
+ def forward(self, x: torch.Tensor, reverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
361
+ if not reverse:
362
+ x = x[..., self.perm]
363
+ x_a, x_b = x[..., :self.half_d], x[..., self.half_d:]
364
+ s, t = self.net(x_a)
365
+ y_b = x_b * torch.exp(s) + t
366
+ y = torch.cat([x_a, y_b], dim=-1)
367
+ logdet = s.sum(dim=(1, 2))
368
+ return y, logdet
369
+ else:
370
+ x_a, x_b = x[..., :self.half_d], x[..., self.half_d:]
371
+ s, t = self.net(x_a)
372
+ y_b = (x_b - t) * torch.exp(-s)
373
+ y = torch.cat([x_a, y_b], dim=-1)
374
+ y = y[..., self.inv_perm]
375
+ logdet = -s.sum(dim=(1, 2))
376
+ return y, logdet
377
+
378
+
379
+ class ViTFlow(nn.Module):
380
+ def __init__(self, d_token: int, n_tokens: int, steps: int = 32):
381
+ super().__init__()
382
+ self.blocks = nn.ModuleList([ViTAffineCoupling(d_token, n_tokens) for _ in range(steps)])
383
+
384
+ def forward(self, x: torch.Tensor, reverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
385
+ logdet = x.new_zeros(x.size(0))
386
+ z = x
387
+ if not reverse:
388
+ for b in self.blocks:
389
+ z, ld = b(z, reverse=False)
390
+ logdet += ld
391
+ else:
392
+ for b in reversed(self.blocks):
393
+ z, ld = b(z, reverse=True)
394
+ logdet += ld
395
+ return z, logdet
396
+
397
+
398
+ def compute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
399
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
400
+ t = torch.arange(end, device=freqs.device, dtype=torch.float32)
401
+ freqs = torch.outer(t, freqs)
402
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
403
+ return freqs_cis
404
+
405
+
406
+ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
407
+ x_c = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
408
+ if freqs_cis.dtype not in (torch.complex64, torch.complex128):
409
+ if freqs_cis.dim() == 2:
410
+ freqs_cis = freqs_cis.view(*freqs_cis.shape[:-1], -1, 2)
411
+ freqs_cis = torch.view_as_complex(freqs_cis)
412
+ freqs_cis = freqs_cis.view(1, x.size(1), 1, x_c.size(-1))
413
+ x_out = torch.view_as_real(x_c * freqs_cis).flatten(3)
414
+ return x_out.type_as(x)
415
+
416
+
417
+ class RMSNorm(nn.Module):
418
+ def __init__(self, dim: int, eps: float = 1e-6):
419
+ super().__init__()
420
+ self.eps = eps
421
+ self.weight = nn.Parameter(torch.ones(dim))
422
+
423
+ def _norm(self, x):
424
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
425
+
426
+ def forward(self, x):
427
+ output = self._norm(x.float()).type_as(x)
428
+ return output * self.weight
429
+
430
+
431
+ class GemmaMLP(nn.Module):
432
+ def __init__(self, cfg: CFG):
433
+ super().__init__()
434
+ self.hidden_dim = 4 * cfg.d_model
435
+ self.gate_proj = nn.Linear(cfg.d_model, self.hidden_dim, bias=cfg.bias)
436
+ self.up_proj = nn.Linear(cfg.d_model, self.hidden_dim, bias=cfg.bias)
437
+ self.down_proj = nn.Linear(self.hidden_dim, cfg.d_model, bias=cfg.bias)
438
+ self.dropout = nn.Dropout(cfg.dropout)
439
+
440
+ def forward(self, x):
441
+ gate = self.gate_proj(x)
442
+ gate = F.gelu(gate, approximate="tanh")
443
+ up = self.up_proj(x)
444
+ x = gate * up
445
+ x = self.down_proj(x)
446
+ x = self.dropout(x)
447
+ return x
448
+
449
+
450
+ class GemmaAttention(nn.Module):
451
+ def __init__(self, cfg: CFG):
452
+ super().__init__()
453
+ self.head_dim = cfg.d_model // cfg.n_heads
454
+
455
+ self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
456
+ self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
457
+ self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
458
+ self.o_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
459
+
460
+ self.resid_dropout = nn.Dropout(cfg.dropout)
461
+ self.n_head = cfg.n_heads
462
+ self.dropout = cfg.dropout
463
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
464
+
465
+ self.register_buffer("freqs_cis", compute_freqs_cis(self.head_dim, cfg.block_size), persistent=False)
466
+
467
+ def forward(self, x):
468
+ B, T, C = x.size()
469
+ q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
470
+ k = self.k_proj(x).view(B, T, self.n_head, self.head_dim)
471
+ v = self.v_proj(x).view(B, T, self.n_head, self.head_dim)
472
+
473
+ freqs_cis = self.freqs_cis[:T]
474
+ q = apply_rotary_emb(q, freqs_cis)
475
+ k = apply_rotary_emb(k, freqs_cis)
476
+
477
+ q = q.transpose(1, 2)
478
+ k = k.transpose(1, 2)
479
+ v = v.transpose(1, 2)
480
+
481
+ if self.flash:
482
+ y = torch.nn.functional.scaled_dot_product_attention(
483
+ q, k, v,
484
+ attn_mask=None,
485
+ dropout_p=self.dropout if self.training else 0.0,
486
+ is_causal=True
487
+ )
488
+ else:
489
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
490
+ mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
491
+ att = att.masked_fill(mask == 0, float('-inf'))
492
+ att = F.softmax(att, dim=-1)
493
+ y = att @ v
494
+
495
+ y = y.transpose(1, 2).contiguous().view(B, T, C)
496
+ y = self.resid_dropout(self.o_proj(y))
497
+ return y
498
+
499
+
500
+ class GemmaBlock(nn.Module):
501
+ def __init__(self, cfg: CFG):
502
+ super().__init__()
503
+ self.ln_1 = RMSNorm(cfg.d_model)
504
+ self.attn = GemmaAttention(cfg)
505
+ self.ln_2 = RMSNorm(cfg.d_model)
506
+ self.mlp = GemmaMLP(cfg)
507
+
508
+ def forward(self, x):
509
+ x = x + self.attn(self.ln_1(x))
510
+ x = x + self.mlp(self.ln_2(x))
511
+ return x
512
+
513
+
514
+ class AstroPTBackbone(nn.Module):
515
+ def __init__(self, cfg: CFG):
516
+ super().__init__()
517
+ self.drop = nn.Dropout(cfg.dropout)
518
+ self.h = nn.ModuleList([GemmaBlock(cfg) for _ in range(cfg.n_layers)])
519
+ self.ln_f = RMSNorm(cfg.d_model)
520
+ self.apply(self._init_weights)
521
+
522
+ def _init_weights(self, module):
523
+ if isinstance(module, nn.Linear):
524
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
525
+ if module.bias is not None:
526
+ torch.nn.init.zeros_(module.bias)
527
+ elif isinstance(module, nn.Embedding):
528
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
529
+
530
+ def forward(self, x):
531
+ x = self.drop(x)
532
+ for block in self.h:
533
+ x = block(x)
534
+ x = self.ln_f(x)
535
+ return x
536
+
537
+
538
+ class GMMHead(nn.Module):
539
+ def __init__(self, d_model: int, d_token: int, K: int):
540
+ super().__init__()
541
+ self.K, self.D = K, d_token
542
+ self.proj = nn.Linear(d_model, K * (1 + 2 * d_token))
543
+
544
+ def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
545
+ B, N, _ = h.shape
546
+ out = self.proj(h).view(B, N, self.K, 1 + 2 * self.D)
547
+
548
+ logits_pi = out[..., 0]
549
+ mu = out[..., 1:1 + self.D]
550
+ log_sigma = out[..., 1 + self.D:]
551
+
552
+ log_sigma = torch.clamp(log_sigma, -7, 2)
553
+ return logits_pi, mu, log_sigma
554
+
555
+
556
+ def gmm_nll(y: torch.Tensor, logits_pi: torch.Tensor, mu: torch.Tensor, log_sigma: torch.Tensor) -> torch.Tensor:
557
+ B, N, D = y.shape
558
+ K = logits_pi.size(-1)
559
+
560
+ y = y.unsqueeze(2)
561
+ inv_var = torch.exp(-2 * log_sigma)
562
+ logp = -0.5 * ((y - mu) ** 2 * inv_var).sum(-1) - log_sigma.sum(-1) - 0.5 * D * math.log(2 * math.pi)
563
+ logmix = F.log_softmax(logits_pi, dim=-1) + logp
564
+ return -torch.logsumexp(logmix, dim=-1).sum(dim=1)
565
+
566
+
567
+ class JetFormer(nn.Module):
568
+ def __init__(self, cfg: CFG):
569
+ super().__init__()
570
+ self.cfg = cfg
571
+ self.flow = ViTFlow(cfg.d_token, cfg.n_tokens, cfg.flow_steps)
572
+ self.in_proj = nn.Linear(cfg.d_token, cfg.d_model)
573
+ self.pos = nn.Parameter(torch.randn(1, cfg.n_tokens, cfg.d_model) * 0.02)
574
+ self.gpt = AstroPTBackbone(cfg)
575
+ self.head = GMMHead(cfg.d_model, cfg.d_token, cfg.gmm_K)
576
+
577
+ def forward(self, x: torch.Tensor, step: int, max_iters: int) -> torch.Tensor:
578
+ """
579
+ Noise curriculum is tied to (step, max_iters) and decays to final values at max_iters.
580
+ - RGB noise: sigma in [0,255] (paper-style)
581
+ - z-noise: token-space Gaussian std (paper-style)
582
+ """
583
+ x = uniform_dequantize(x)
584
+
585
+ # Curriculum length T (end exactly at max_iters if noise_decay_frac=1.0)
586
+ T = int(max_iters * self.cfg.noise_decay_frac)
587
+ T = max(T, 1)
588
+
589
+ # ---- RGB noise schedule ----
590
+ rgb_sigma_255 = cosine_decay(step, T, self.cfg.rgb_sigma0_255, self.cfg.rgb_sigmaT_255)
591
+ rgb_sigma = rgb_sigma_255 / 255.0
592
+ if self.cfg.noise_floor > 0:
593
+ rgb_sigma = max(rgb_sigma, self.cfg.noise_floor)
594
+
595
+ if self.training and rgb_sigma > 0.0:
596
+ x = (x + torch.randn_like(x) * rgb_sigma).clamp(0.0, 1.0)
597
+
598
+ # ---- Flow encode ----
599
+ tokens_in = patchify(x, self.cfg.patch)
600
+ z, logdet = self.flow(tokens_in, reverse=False)
601
+
602
+ # ---- Latent z noise schedule (decays to ~0 by max_iters) ----
603
+ z_sigma = cosine_decay(step, T, self.cfg.z_sigma0, self.cfg.z_sigmaT)
604
+ if self.cfg.noise_floor > 0:
605
+ z_sigma = max(z_sigma, self.cfg.noise_floor)
606
+
607
+ if self.training and z_sigma > 0.0:
608
+ z = z + torch.randn_like(z) * z_sigma
609
+
610
+ # ---- AR transformer + GMM ----
611
+ h = self.in_proj(z) + self.pos
612
+ h = self.gpt(h)
613
+
614
+ logits_pi, mu, log_sigma = self.head(h[:, :-1])
615
+ target = z[:, 1:]
616
+
617
+ nll_gmm = gmm_nll(target, logits_pi, mu, log_sigma)
618
+ loss = (nll_gmm - logdet).mean()
619
+ return loss
620
+
621
+ @torch.no_grad()
622
+ def sample(self, n: int = 16, x_real_batch: torch.Tensor = None):
623
+ self.eval()
624
+ B = n
625
+ N = self.cfg.n_tokens
626
+ device = next(self.parameters()).device
627
+
628
+ if x_real_batch is None:
629
+ z_seq = torch.zeros(B, N, self.cfg.d_token, device=device)
630
+ for t in range(N - 1):
631
+ h_in = self.in_proj(z_seq) + self.pos
632
+ h_out = self.gpt(h_in)
633
+ logits_pi, mu, log_sigma = self.head(h_out[:, t:t + 1])
634
+
635
+ pi = F.softmax(logits_pi.squeeze(1), dim=-1)
636
+ comp_idx = torch.multinomial(pi, 1)
637
+ gather_idx = comp_idx[..., None].expand(-1, -1, self.cfg.d_token)
638
+
639
+ sel_mu = mu.squeeze(1).gather(1, gather_idx).squeeze(1)
640
+ sel_sigma = log_sigma.squeeze(1).gather(1, gather_idx).squeeze(1).exp()
641
+
642
+ z_next = sel_mu + torch.randn_like(sel_mu) * sel_sigma
643
+ z_seq[:, t + 1] = z_next
644
+
645
+ x_rec_tokens, _ = self.flow(z_seq, reverse=True)
646
+ x_rec = depatchify(x_rec_tokens, self.cfg.in_ch, self.cfg.img_size, self.cfg.img_size, self.cfg.patch)
647
+ return x_rec.clamp(0, 1)
648
+
649
+ else:
650
+ x_real = x_real_batch.to(device)
651
+ x_real_proc = uniform_dequantize(x_real)
652
+
653
+ z_real, _ = self.flow(patchify(x_real_proc, self.cfg.patch), reverse=False)
654
+ h_in = self.in_proj(z_real) + self.pos
655
+ h_out = self.gpt(h_in)
656
+
657
+ logits_pi, mu, log_sigma = self.head(h_out)
658
+ best_comp_idx = torch.argmax(logits_pi, dim=-1, keepdim=True)
659
+ gather_idx = best_comp_idx.unsqueeze(-1).expand(-1, -1, -1, self.cfg.d_token)
660
+
661
+ z_pred_next = torch.gather(mu, 2, gather_idx).squeeze(2)
662
+
663
+ z_rec = torch.zeros_like(z_real)
664
+ z_rec[:, 0] = z_real[:, 0]
665
+ z_rec[:, 1:] = z_pred_next[:, :-1]
666
+
667
+ x_rec_tokens, _ = self.flow(z_rec, reverse=True)
668
+ x_rec = depatchify(x_rec_tokens, self.cfg.in_ch, self.cfg.img_size, self.cfg.img_size, self.cfg.patch)
669
+
670
+ combined = torch.stack([x_real, x_rec.clamp(0, 1)], dim=1).view(
671
+ -1, self.cfg.in_ch, self.cfg.img_size, self.cfg.img_size
672
+ )
673
+ return combined
674
+
675
+
676
+ # ======================================================================================
677
+ # Block 6: Main Training Loop (DDP Aware)
678
+ # ======================================================================================
679
+ def train():
680
+ # --- 1. DDP Setup ---
681
+ rank, local_rank, world_size = setup_ddp()
682
+
683
+ cfg = CFG()
684
+ cfg.rank = rank
685
+ cfg.world_size = world_size
686
+ cfg.device = f"cuda:{local_rank}"
687
+
688
+ # ### PATH SETUP ###
689
+ cfg.samples_dir = f"samples_{cfg.dataset_name}_256"
690
+ cfg.loss_csv_path = f"loss_log_{cfg.dataset_name}_256.csv"
691
+ cfg.loss_plot_path = f"loss_plot_{cfg.dataset_name}_256.png"
692
+
693
+ if rank == 0:
694
+ os.makedirs(cfg.samples_dir, exist_ok=True)
695
+ print(f"--- DDP CONFIGURATION ---")
696
+ print(f" World Size: {world_size}")
697
+ print(f" Per-GPU Batch: {cfg.batch_size}")
698
+ print(f" Global Batch: {cfg.batch_size * world_size}")
699
+ print(f" Dataset: {cfg.hf_repo} (Streaming + Sharded)")
700
+ print(f" Saving Checkpoints to: {cfg.samples_dir}")
701
+ print(f"-------------------------")
702
+
703
+ print("Noise curriculum:")
704
+ print(f" RGB sigma: {cfg.rgb_sigma0_255} -> {cfg.rgb_sigmaT_255} (in [0,255])")
705
+ print(f" z sigma: {cfg.z_sigma0} -> {cfg.z_sigmaT} (token space)")
706
+ print(f" decay_frac: {cfg.noise_decay_frac} (T = {int(cfg.max_iters*cfg.noise_decay_frac)})")
707
+ print(f" floor: {cfg.noise_floor}")
708
+
709
+ # --- 2. Model Setup ---
710
+ model = JetFormer(cfg).to(cfg.device)
711
+ model = DDP(model, device_ids=[local_rank], output_device=local_rank)
712
+
713
+ # --- 3. Optimizer Setup ---
714
+ opt = torch.optim.AdamW(
715
+ model.parameters(),
716
+ lr=cfg.lr,
717
+ weight_decay=cfg.wd,
718
+ betas=(0.9, cfg.beta2)
719
+ )
720
+
721
+ # --- 4. Checkpoint Loading ---
722
+ start_step = load_checkpoint(model, opt, cfg)
723
+
724
+ # --- 5. Data Loading ---
725
+ train_loader = get_train_dataloader(cfg)
726
+ val_loader = get_val_dataloader(cfg)
727
+
728
+ # Pre-load fixed batch for viz
729
+ viz_batch = None
730
+ if rank == 0:
731
+ print("Fetching visualization batch...")
732
+ try:
733
+ viz_batch = next(iter(val_loader))['img'][:16].to(cfg.device)
734
+ except Exception as e:
735
+ print(f"Warning: Could not load viz batch: {e}")
736
+
737
+ # --- 6. Scheduler ---
738
+ def get_lr_schedule(step):
739
+ if step < cfg.warmup_steps:
740
+ return step / cfg.warmup_steps
741
+ else:
742
+ progress = (step - cfg.warmup_steps) / (cfg.max_iters - cfg.warmup_steps)
743
+ progress = max(0.0, min(1.0, progress))
744
+ return 0.5 * (1.0 + math.cos(math.pi * progress))
745
+
746
+ if rank == 0:
747
+ print(f"Starting training loop from step {start_step}...")
748
+
749
+ # --- 7. Main Loop ---
750
+ model.train()
751
+ train_iter = iter(train_loader)
752
+
753
+ if rank == 0:
754
+ pbar = tqdm(range(start_step, cfg.max_iters), initial=start_step, total=cfg.max_iters)
755
+ else:
756
+ pbar = range(start_step, cfg.max_iters)
757
+
758
+ train_loss_accum = 0.0
759
+ accum_steps = 0
760
+
761
+ for step in pbar:
762
+ try:
763
+ batch = next(train_iter)
764
+ except StopIteration:
765
+ train_iter = iter(train_loader)
766
+ batch = next(train_iter)
767
+
768
+ img = batch["img"].to(cfg.device)
769
+
770
+ # LR Update
771
+ lr_scale = get_lr_schedule(step)
772
+ for param_group in opt.param_groups:
773
+ param_group['lr'] = cfg.lr * lr_scale
774
+
775
+ # Forward Pass (BFloat16 for H100 / Ampere+)
776
+ with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
777
+ loss = model(img, step=step, max_iters=cfg.max_iters)
778
+
779
+ # Backward
780
+ opt.zero_grad(set_to_none=True)
781
+ loss.backward()
782
+ torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_val)
783
+ opt.step()
784
+
785
+ current_loss = float(loss.item())
786
+
787
+ if rank == 0:
788
+ train_loss_accum += current_loss
789
+ accum_steps += 1
790
+ if isinstance(pbar, tqdm):
791
+ pbar.set_postfix(loss=f"{current_loss:.3f}", lr=f"{opt.param_groups[0]['lr']:.2e}")
792
+
793
+ if step > 0:
794
+ # 1. Validation and Image Sampling
795
+ if step % cfg.val_check_interval == 0:
796
+ if rank == 0:
797
+ avg_train_loss = train_loss_accum / max(accum_steps, 1)
798
+ train_loss_accum = 0.0
799
+ accum_steps = 0
800
+
801
+ # Generate Samples
802
+ if viz_batch is not None:
803
+ model.eval()
804
+ try:
805
+ with torch.no_grad():
806
+ fake_images = model.module.sample(n=16, x_real_batch=viz_batch)
807
+ sample_path = os.path.join(cfg.samples_dir, f"step_{step:07d}.png")
808
+ save_image(fake_images, sample_path, nrow=2)
809
+ except Exception as e:
810
+ print(f"Interval Sampling Error: {e}")
811
+ model.train()
812
+
813
+ # Run Validation
814
+ model.eval()
815
+ val_iter = iter(val_loader)
816
+ local_val_loss = 0.0
817
+
818
+ with torch.no_grad():
819
+ for _ in range(cfg.val_steps):
820
+ try:
821
+ vbatch = next(val_iter)
822
+ vimg = vbatch["img"].to(cfg.device)
823
+ vloss = model(vimg, step=step, max_iters=cfg.max_iters)
824
+ local_val_loss += float(vloss.item())
825
+ except StopIteration:
826
+ break
827
+
828
+ avg_local_val = local_val_loss / max(cfg.val_steps, 1)
829
+ val_tensor = torch.tensor([avg_local_val], device=cfg.device)
830
+ dist.all_reduce(val_tensor, op=dist.ReduceOp.SUM)
831
+ avg_val_loss = val_tensor.item() / world_size
832
+
833
+ if rank == 0:
834
+ save_checkpoint(step, model, opt, cfg, is_latest=True)
835
+ append_losses_to_csv(step, avg_train_loss, avg_val_loss, cfg.loss_csv_path)
836
+ plot_loss_from_csv(cfg.loss_csv_path, cfg.loss_plot_path)
837
+ print(f"\nStep {step}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")
838
+
839
+ model.train()
840
+
841
+ # 2. Historical checkpoint saving
842
+ if step % cfg.save_interval == 0:
843
+ if rank == 0:
844
+ save_checkpoint(step, model, opt, cfg, is_latest=False)
845
+
846
+ # Final checkpoint and logging at max_iters (after training loop completes)
847
+ final_step = cfg.max_iters - 1
848
+
849
+ # Calculate final average train loss (rank 0 only)
850
+ if rank == 0:
851
+ avg_train_loss = train_loss_accum / max(accum_steps, 1) if accum_steps > 0 else 0.0
852
+
853
+ # Final validation (all ranks)
854
+ model.eval()
855
+ val_iter = iter(val_loader)
856
+ local_val_loss = 0.0
857
+
858
+ with torch.no_grad():
859
+ for _ in range(cfg.val_steps):
860
+ try:
861
+ vbatch = next(val_iter)
862
+ vimg = vbatch["img"].to(cfg.device)
863
+ vloss = model(vimg, step=final_step, max_iters=cfg.max_iters)
864
+ local_val_loss += float(vloss.item())
865
+ except StopIteration:
866
+ break
867
+
868
+ avg_local_val = local_val_loss / max(cfg.val_steps, 1)
869
+ val_tensor = torch.tensor([avg_local_val], device=cfg.device)
870
+ dist.all_reduce(val_tensor, op=dist.ReduceOp.SUM)
871
+ avg_val_loss = val_tensor.item() / world_size
872
+
873
+ # Final sampling, checkpoint and logging (rank 0 only)
874
+ if rank == 0:
875
+ # Final sampling
876
+ if viz_batch is not None:
877
+ try:
878
+ with torch.no_grad():
879
+ fake_images = model.module.sample(n=16, x_real_batch=viz_batch)
880
+ sample_path = os.path.join(cfg.samples_dir, f"step_{cfg.max_iters:07d}.png")
881
+ save_image(fake_images, sample_path, nrow=2)
882
+ except Exception as e:
883
+ print(f"Final Sampling Error: {e}")
884
+
885
+ # Final checkpoint and logging
886
+ save_checkpoint(final_step, model, opt, cfg, is_latest=True)
887
+ save_checkpoint(final_step, model, opt, cfg, is_latest=False) # Also save as historical
888
+ append_losses_to_csv(final_step, avg_train_loss, avg_val_loss, cfg.loss_csv_path)
889
+ plot_loss_from_csv(cfg.loss_csv_path, cfg.loss_plot_path)
890
+ print(f"\nFinal Step {final_step}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")
891
+
892
+ cleanup_ddp()
893
+ if rank == 0:
894
+ print("Training finished.")
895
+
896
+
897
+ if __name__ == "__main__":
898
+ train()