Upload train_jetformer_sogol.py with huggingface_hub
Browse files- train_jetformer_sogol.py +898 -0
train_jetformer_sogol.py
ADDED
|
@@ -0,0 +1,898 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
train_jetformer_ddp.py
|
| 3 |
+
|
| 4 |
+
JetFormer (Flow + AR Transformer + GMM) training script with DDP + HF streaming dataset.
|
| 5 |
+
|
| 6 |
+
Key edits in THIS version (requested):
|
| 7 |
+
1) Noise curriculum is tied DIRECTLY to (step, max_iters) and goes to ~0 at max_iters.
|
| 8 |
+
- Image/RGB noise uses paper-style sigma in [0,255] (default 64 -> 0).
|
| 9 |
+
- Latent z noise uses paper-style std (default 0.3 -> 0).
|
| 10 |
+
2) Removed the constant latent noise (z += N(0, 0.3)) and replaced it with a decaying schedule.
|
| 11 |
+
3) Forward signature changed to: forward(x, step, max_iters)
|
| 12 |
+
4) Fixed a few structural/indent issues in the original paste (HF shard indent, ViTFlow.forward indent, etc.)
|
| 13 |
+
|
| 14 |
+
Notes:
|
| 15 |
+
- This keeps your architecture and training logic intact, only changing noise scheduling + small code fixes.
|
| 16 |
+
- If you want "almost zero but not exact" at the end, set CFG.noise_floor = 1e-6.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
import os
|
| 21 |
+
import csv
|
| 22 |
+
import time
|
| 23 |
+
import pandas as pd
|
| 24 |
+
from dataclasses import dataclass
|
| 25 |
+
from typing import Tuple
|
| 26 |
+
|
| 27 |
+
# --- PIL Fix for Truncated Images ---
|
| 28 |
+
from PIL import ImageFile
|
| 29 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
| 30 |
+
|
| 31 |
+
import torch
|
| 32 |
+
import torch.nn as nn
|
| 33 |
+
import torch.nn.functional as F
|
| 34 |
+
import torch.distributed as dist
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 36 |
+
from torch.utils.data import DataLoader
|
| 37 |
+
from torchvision.utils import save_image
|
| 38 |
+
from torchvision import transforms
|
| 39 |
+
from tqdm import tqdm
|
| 40 |
+
|
| 41 |
+
# Hugging Face Datasets
|
| 42 |
+
from datasets import load_dataset
|
| 43 |
+
|
| 44 |
+
# This must be done BEFORE importing pyplot
|
| 45 |
+
import matplotlib
|
| 46 |
+
matplotlib.use('Agg')
|
| 47 |
+
import matplotlib.pyplot as plt
|
| 48 |
+
|
| 49 |
+
import numpy as np
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
# ======================================================================================
|
| 53 |
+
# Block 1: DDP Setup & Configuration
|
| 54 |
+
# ======================================================================================
|
| 55 |
+
def setup_ddp():
    """Start the default NCCL process group and bind this process to its GPU.

    When ``RANK`` is absent from the environment, single-process defaults are
    written into ``os.environ`` first so the script can run without a
    launcher.

    Returns:
        A ``(rank, local_rank, world_size)`` tuple of ints read from the
        environment.
    """
    if "RANK" not in os.environ:
        # Stand-alone run: pretend to be a world of size one.
        os.environ.update({
            "RANK": "0",
            "WORLD_SIZE": "1",
            "LOCAL_RANK": "0",
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": "12355",
        })

    dist.init_process_group(backend="nccl")

    rank, local_rank, world_size = (
        int(os.environ[key]) for key in ("RANK", "LOCAL_RANK", "WORLD_SIZE")
    )
    torch.cuda.set_device(local_rank)
    return rank, local_rank, world_size
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def cleanup_ddp():
    """Tear down the default process group if one is active.

    Safe to call when the group was never initialised (for example when
    start-up failed early), in which case it does nothing.
    """
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@dataclass
class CFG:
    """Hyper-parameters and run-time settings for the training script.

    ``n_tokens`` and ``d_token`` are derived from ``img_size``, ``patch`` and
    ``in_ch``. If the caller leaves them at their class defaults they are
    re-derived in ``__post_init__`` so that overriding ``img_size``, ``patch``
    or ``in_ch`` at construction keeps them consistent. An explicitly passed
    non-default value is kept as given.
    """

    # --- Model Config (Scaled for RTX 4090 24GB) ---
    d_model: int = 768
    n_heads: int = 12
    n_layers: int = 12

    # --- AstroPT Specific Configs ---
    block_size: int = 1024
    dropout: float = 0.0
    bias: bool = False
    is_causal: bool = True

    # --- Flow Specification ---
    flow_steps: int = 16

    # --- Training Config ---
    max_iters: int = 80_000
    save_interval: int = 5000
    batch_size: int = 8
    val_check_interval: int = 5000

    # --- Optimizer Config ---
    lr: float = 1e-4
    wd: float = 1e-4
    beta2: float = 0.95
    warmup_steps: int = 10000

    # --- Data Params ---
    img_size: int = 256
    patch: int = 8
    in_ch: int = 3

    # Derived Dimensions (re-derived in __post_init__ when left at default)
    n_tokens: int = (img_size // patch) ** 2
    d_token: int = in_ch * patch * patch

    # --- GMM Head ---
    gmm_K: int = 256

    # --- Noise curriculum (paper-style, tied to max_iters) ---
    # JetFormer paper uses σ0 = 64 in pixel space [0,255] (≈ 0.251 in [0,1]).
    rgb_sigma0_255: float = 64.0  # start noise in [0,255]
    rgb_sigmaT_255: float = 0.0   # final noise at max_iters (0 => sharpest end)

    # Latent noise in flow token space (paper mentions std=0.3)
    z_sigma0: float = 0.3  # start latent noise
    z_sigmaT: float = 0.0  # final latent noise at max_iters

    # If 1.0: reaches final exactly at max_iters.
    # If <1.0: reaches final earlier and stays there.
    noise_decay_frac: float = 1.0

    # Optional: set to 1e-6 to avoid EXACT zero (sometimes smoother numerically)
    noise_floor: float = 0.0

    # --- System ---
    grad_clip_val: float = 0.5

    # Paths
    dataset_name: str = "final_sogol_image_patch_8"
    checkpoint_path: str = ""
    samples_dir: str = ""
    loss_csv_path: str = ""
    loss_plot_path: str = ""

    # --- Data Sources (Hugging Face) ---
    hf_repo: str = "Smith42/galaxies"
    val_steps: int = 100

    # DDP Placeholders
    rank: int = 0
    world_size: int = 1
    device: str = "cuda"

    def __post_init__(self):
        # The class attributes still hold the defaults computed at class
        # definition time; a field equal to its default is treated as
        # "not overridden" and recomputed from the instance's own values.
        defaults = type(self)
        if self.n_tokens == defaults.n_tokens:
            self.n_tokens = (self.img_size // self.patch) ** 2
        if self.d_token == defaults.d_token:
            self.d_token = self.in_ch * self.patch * self.patch
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
# ======================================================================================
|
| 154 |
+
# Block 2: Logging Utilities (Rank 0 Only)
|
| 155 |
+
# ======================================================================================
|
| 156 |
+
def append_losses_to_csv(step, train_loss, val_loss, filename):
    """Append one ``(step, train_loss, val_loss)`` row to a CSV log.

    A header row is written first when ``filename`` does not yet exist as a
    regular file.
    """
    pending_rows = []
    if not os.path.isfile(filename):
        pending_rows.append(['step', 'train_loss', 'val_loss'])
    pending_rows.append([step, train_loss, val_loss])

    with open(filename, 'a', newline='') as handle:
        csv.writer(handle).writerows(pending_rows)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def plot_loss_from_csv(csv_path, output_path):
    """Render the loss curves stored in ``csv_path`` to an image file.

    Does nothing when ``csv_path`` is not an existing file. The training
    loss is plotted for every row; the validation loss is plotted only for
    rows where ``val_loss`` is present.
    """
    if not os.path.isfile(csv_path):
        return

    history = pd.read_csv(csv_path)
    figure, axes = plt.subplots(figsize=(10, 6))
    axes.plot(history['step'], history['train_loss'], label='Train Loss', color='blue')

    # Validation is logged sparsely, so drop the rows without a value.
    with_val = history.dropna(subset=['val_loss'])
    if not with_val.empty:
        val_style = dict(label='Validation Loss', color='orange', linestyle='--', marker='o')
        axes.plot(with_val['step'], with_val['val_loss'], **val_style)

    axes.set_title('Training and Validation Loss per Step')
    axes.set_xlabel('Step')
    axes.set_ylabel('Average Loss')
    axes.legend()
    axes.grid(True)

    figure.savefig(output_path)
    plt.close(figure)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
# ======================================================================================
|
| 190 |
+
# Block 3: Data Loading
|
| 191 |
+
# ======================================================================================
|
| 192 |
+
def process_hf_item(item):
    """Convert one dataset record into ``{"img": tensor}``.

    The ``image_crop`` entry is passed through ``transforms.ToTensor``; a
    result with a single leading channel is repeated three times along that
    dimension.
    """
    tensor = transforms.ToTensor()(item['image_crop'])
    single_channel = tensor.shape[0] == 1
    if single_channel:
        tensor = tensor.repeat(3, 1, 1)
    return {"img": tensor}
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def get_train_dataloader(cfg: CFG):
    """Build a DataLoader over the streaming ``train`` split of ``cfg.hf_repo``.

    The stream is sharded by ``cfg.rank`` when ``cfg.world_size > 1``, each
    record is mapped through ``process_hf_item``, and the original columns
    are dropped.

    Returns:
        A ``DataLoader`` with ``cfg.batch_size``, pinned memory and between
        2 and 6 worker processes.
    """
    if cfg.rank == 0:
        print(f"Loading streaming dataset: {cfg.hf_repo} (Split: train)")
    ds = load_dataset(cfg.hf_repo, split="train", streaming=True)

    # shard across ranks so each GPU sees a different stream
    if cfg.world_size > 1:
        ds = ds.shard(num_shards=cfg.world_size, index=cfg.rank)

    ds = ds.map(process_hf_item, remove_columns=["image", "image_crop", "survey", "ra", "dec"])

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division below cannot raise TypeError.
    cpu_total = os.cpu_count() or 1
    nw = min(6, max(2, (cpu_total // max(cfg.world_size, 1)) - 1))
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        num_workers=nw,
        pin_memory=True,
    )
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def get_val_dataloader(cfg: CFG):
    """Build a DataLoader over the streaming ``test`` split of ``cfg.hf_repo``.

    Unlike the training loader the stream is not sharded, so every rank
    iterates the same validation data.

    Returns:
        A ``DataLoader`` with ``cfg.batch_size``, pinned memory and between
        2 and 4 worker processes.
    """
    if cfg.rank == 0:
        print(f"Loading streaming dataset: {cfg.hf_repo} (Split: test)")
    ds = load_dataset(cfg.hf_repo, split="test", streaming=True)

    # No validation sharding to prevent empty shards on small val sets
    ds = ds.map(process_hf_item, remove_columns=["image", "image_crop", "survey", "ra", "dec"])

    # os.cpu_count() may return None when the count cannot be determined;
    # fall back to 1 so the integer division below cannot raise TypeError.
    cpu_total = os.cpu_count() or 1
    nw = min(4, max(2, (cpu_total // max(cfg.world_size, 1)) - 1))
    return DataLoader(
        ds,
        batch_size=cfg.batch_size,
        num_workers=nw,
        pin_memory=True,
    )
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
# ======================================================================================
|
| 239 |
+
# Block 4: Checkpointing (Rank 0 Only)
|
| 240 |
+
# ======================================================================================
|
| 241 |
+
def save_checkpoint(step, model, optimizer, cfg, is_latest=True):
    """Write a training checkpoint to ``cfg.samples_dir`` (rank 0 only).

    1) Always overwrites 'checkpoint_latest.pt' for easy resuming.
    2) If is_latest=False, also saves a numbered file like
       'checkpoint_step_0005000.pt'.

    Args:
        step: Training step stored under the ``'step'`` key.
        model: Module to save; a DDP wrapper is unwrapped first.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        cfg: Object providing ``rank`` and ``samples_dir``.
        is_latest: When False, a numbered historical copy is written too.
    """
    if cfg.rank != 0:
        return

    model_state = model.module.state_dict() if isinstance(model, DDP) else model.state_dict()

    checkpoint = {
        'step': step,
        'model_state_dict': model_state,
        'optimizer_state_dict': optimizer.state_dict()
    }

    # An empty samples_dir means "current directory"; otherwise make sure the
    # target exists so torch.save does not fail on a fresh run.
    if cfg.samples_dir:
        os.makedirs(cfg.samples_dir, exist_ok=True)

    latest_path = os.path.join(cfg.samples_dir, "checkpoint_latest.pt")
    # Write to a sibling temp file and rename, so an interrupted save never
    # leaves a truncated 'checkpoint_latest.pt' behind.
    tmp_path = latest_path + ".tmp"
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, latest_path)

    if not is_latest:
        history_path = os.path.join(cfg.samples_dir, f"checkpoint_step_{step:07d}.pt")
        torch.save(checkpoint, history_path)
        print(f"Saved historical checkpoint: {history_path}")
    else:
        print(f"Updated latest checkpoint: {latest_path}")
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def load_checkpoint(model, optimizer, cfg):
    """Restore model and optimizer state from 'checkpoint_latest.pt'.

    Args:
        model: Module to load into; a DDP wrapper is unwrapped first.
        optimizer: Optimizer whose state is restored.
        cfg: Object providing ``rank`` and ``samples_dir``.

    Returns:
        The step stored in the checkpoint, or 0 when no checkpoint file
        exists in ``cfg.samples_dir``.
    """
    latest_path = os.path.join(cfg.samples_dir, "checkpoint_latest.pt")

    if not os.path.exists(latest_path):
        if cfg.rank == 0:
            print(f"No checkpoint found at {latest_path}. Starting from scratch.")
        return 0

    model_unwrap = model.module if isinstance(model, DDP) else model

    # Load tensors straight onto the device the model already lives on.
    # Keying the mapping on the global rank (as before) picks a non-existent
    # device index on multi-node runs, where rank != local device index.
    first_param = next(model_unwrap.parameters(), None)
    map_location = first_param.device if first_param is not None else "cpu"
    checkpoint = torch.load(latest_path, map_location=map_location)

    model_unwrap.load_state_dict(checkpoint['model_state_dict'])

    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    step = checkpoint['step']

    if cfg.rank == 0:
        print(f"Checkpoint loaded from {latest_path}. Resuming from step {step}")
    return step
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
# ======================================================================================
|
| 292 |
+
# Block 5: Model Definitions
|
| 293 |
+
# ======================================================================================
|
| 294 |
+
def uniform_dequantize(x: torch.Tensor) -> torch.Tensor:
    """Add uniform noise drawn from [0, 1/256) to ``x`` and clamp to [0, 1]."""
    # Standard dequantization for 8-bit images
    jitter = torch.rand_like(x) / 256.0
    return torch.clamp(x + jitter, min=0.0, max=1.0)
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
def patchify(x: torch.Tensor, patch_size: int = 16) -> torch.Tensor:
    """Cut a ``(B, C, H, W)`` tensor into flattened non-overlapping patches.

    Returns a ``(B, num_patches, C * patch_size * patch_size)`` tensor.
    Patches are ordered row-major over the patch grid, and each token is
    laid out channel first, then patch row, then patch column.
    """
    B, C, H, W = x.shape
    token_dim = C * patch_size * patch_size

    # After both unfolds the layout is (B, C, grid_h, grid_w, p, p).
    windows = x.unfold(2, patch_size, patch_size)
    windows = windows.unfold(3, patch_size, patch_size)

    # Move the channel axis next to the in-patch axes before flattening.
    windows = windows.contiguous().permute(0, 2, 3, 1, 4, 5)
    return windows.reshape(B, -1, token_dim)
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def depatchify(tokens: torch.Tensor, C: int = 3, H: int = 256, W: int = 256, patch_size: int = 16) -> torch.Tensor:
    """Reassemble ``(B, N, D)`` patch tokens into a ``(B, C, H, W)`` tensor.

    Each token is read as ``(C, patch_size, patch_size)`` and the tokens are
    placed row-major on an ``(H // patch_size, W // patch_size)`` grid.
    """
    B, _, _ = tokens.shape
    grid_rows = H // patch_size
    grid_cols = W // patch_size

    grid = tokens.reshape(B, grid_rows, grid_cols, C, patch_size, patch_size)
    # (B, rows, cols, C, p, p) -> (B, C, rows, p, cols, p), then merge the
    # grid and in-patch axes into full height and width.
    return grid.permute(0, 3, 1, 4, 2, 5).reshape(B, C, H, W)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def cosine_decay(step: int, T: int, start: float, end: float) -> float:
    """Half-cosine interpolation from ``start`` (step 0) to ``end`` (step T).

    Returns ``end`` unchanged when ``T <= 0`` or ``step >= T``.
    """
    schedule_finished = T <= 0 or step >= T
    if schedule_finished:
        return end

    progress = step / T  # in [0,1)
    cosine_term = 1.0 + math.cos(math.pi * progress)
    return end + 0.5 * (start - end) * cosine_term
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
class ViTCouplingBlock(nn.Module):
    """Transformer encoder that predicts the scale and shift of a coupling layer.

    Input tokens of width ``in_channels`` are projected to ``width``, given a
    learned positional embedding, passed through ``depth`` pre-norm encoder
    layers, and projected to ``2 * in_channels`` outputs. The output
    projection is zero-initialised, so a freshly built block returns zeros
    for both scale and shift.
    """

    def __init__(self, in_channels: int, n_tokens: int, width: int = 512, depth: int = 4, heads: int = 8):
        super().__init__()
        self.in_proj = nn.Linear(in_channels, width)
        self.pos_emb = nn.Parameter(torch.randn(1, n_tokens, width) * 0.02)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=width,
            nhead=heads,
            dim_feedforward=2048,
            dropout=0.0,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=depth)
        self.out_proj = nn.Linear(width, in_channels * 2)

        # Zero the head so the predicted scale and shift start at zero.
        for tensor in (self.out_proj.weight, self.out_proj.bias):
            nn.init.zeros_(tensor)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(s, t)``: tanh-bounded scale and unbounded shift, each shaped like ``x``."""
        hidden = self.transformer(self.in_proj(x) + self.pos_emb)
        raw_scale, shift = self.out_proj(hidden).chunk(2, dim=-1)
        return torch.tanh(raw_scale), shift
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
class ViTAffineCoupling(nn.Module):
    """Affine coupling layer over the token feature dimension.

    A fixed random permutation of the ``d_token`` features is applied, the
    first ``d_token // 2`` features pass through unchanged and condition a
    ``ViTCouplingBlock`` that scales and shifts the remaining features.
    """

    def __init__(self, d_token: int, n_tokens: int):
        super().__init__()
        self.half_d = d_token // 2
        # The permutation is stored as a buffer; its inverse is used to
        # restore the feature order in the reverse pass.
        self.register_buffer('perm', torch.randperm(d_token))
        self.register_buffer('inv_perm', torch.argsort(self.perm))
        self.net = ViTCouplingBlock(self.half_d, n_tokens)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Apply the coupling (or its inverse) and return ``(output, logdet)``.

        ``logdet`` is the sum of the predicted log-scales over dims 1 and 2,
        negated when ``reverse`` is True.
        """
        split = self.half_d

        if reverse:
            passthrough, transformed = x[..., :split], x[..., split:]
            s, t = self.net(passthrough)
            recovered = (transformed - t) * torch.exp(-s)
            merged = torch.cat([passthrough, recovered], dim=-1)
            # Undo the permutation last, mirroring the forward order.
            return merged[..., self.inv_perm], -s.sum(dim=(1, 2))

        shuffled = x[..., self.perm]
        passthrough, to_transform = shuffled[..., :split], shuffled[..., split:]
        s, t = self.net(passthrough)
        transformed = to_transform * torch.exp(s) + t
        merged = torch.cat([passthrough, transformed], dim=-1)
        return merged, s.sum(dim=(1, 2))
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
class ViTFlow(nn.Module):
    """Stack of ``steps`` ``ViTAffineCoupling`` layers applied in sequence."""

    def __init__(self, d_token: int, n_tokens: int, steps: int = 32):
        super().__init__()
        couplings = [ViTAffineCoupling(d_token, n_tokens) for _ in range(steps)]
        self.blocks = nn.ModuleList(couplings)

    def forward(self, x: torch.Tensor, reverse: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Run every block (in reversed order when ``reverse``) and sum the log-dets.

        Returns ``(z, logdet)`` where ``logdet`` has one entry per batch item.
        """
        total_logdet = x.new_zeros(x.size(0))
        ordered_blocks = reversed(self.blocks) if reverse else self.blocks

        z = x
        for coupling in ordered_blocks:
            z, step_logdet = coupling(z, reverse=reverse)
            total_logdet += step_logdet
        return z, total_logdet
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
def compute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute unit-magnitude complex rotary factors.

    Returns a complex tensor of shape ``(end, dim // 2)`` whose entry
    ``[p, i]`` has angle ``p / theta ** (2 * i / dim)``.
    """
    exponents = torch.arange(0, dim, 2)[: (dim // 2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end, device=inv_freq.device, dtype=torch.float32)
    angles = torch.outer(positions, inv_freq)
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor):
    """Rotate consecutive feature pairs of ``x`` by the given complex factors.

    ``x`` is indexed as ``(batch, seq, heads, head_dim)``: the factors are
    broadcast over dims 0 and 2. ``freqs_cis`` may be complex already, or a
    real tensor whose trailing values are (real, imag) pairs. The result has
    the dtype of ``x``.
    """
    # View the last dimension as complex numbers built from adjacent pairs.
    paired = x.float().reshape(*x.shape[:-1], -1, 2)
    x_complex = torch.view_as_complex(paired)

    already_complex = freqs_cis.dtype in (torch.complex64, torch.complex128)
    if not already_complex:
        if freqs_cis.dim() == 2:
            freqs_cis = freqs_cis.view(*freqs_cis.shape[:-1], -1, 2)
        freqs_cis = torch.view_as_complex(freqs_cis)

    rotation = freqs_cis.view(1, x.size(1), 1, x_complex.size(-1))
    rotated = torch.view_as_real(x_complex * rotation)
    return rotated.flatten(3).type_as(x)
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain."""

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Divide ``x`` by the RMS of its last dimension (``eps`` added under the root)."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        # Normalise in float32, cast back to the input dtype, then apply the gain.
        normalised = self._norm(x.float())
        return normalised.type_as(x) * self.weight
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
class GemmaMLP(nn.Module):
    """Gated feed-forward block: ``down(gelu_tanh(gate(x)) * up(x))`` followed by dropout.

    The hidden width is ``4 * cfg.d_model``; bias usage and dropout rate
    come from ``cfg.bias`` and ``cfg.dropout``.
    """

    def __init__(self, cfg: CFG):
        super().__init__()
        self.hidden_dim = 4 * cfg.d_model
        model_dim, hidden_dim, use_bias = cfg.d_model, self.hidden_dim, cfg.bias
        self.gate_proj = nn.Linear(model_dim, hidden_dim, bias=use_bias)
        self.up_proj = nn.Linear(model_dim, hidden_dim, bias=use_bias)
        self.down_proj = nn.Linear(hidden_dim, model_dim, bias=use_bias)
        self.dropout = nn.Dropout(cfg.dropout)

    def forward(self, x):
        activated_gate = F.gelu(self.gate_proj(x), approximate="tanh")
        gated = activated_gate * self.up_proj(x)
        return self.dropout(self.down_proj(gated))
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
class GemmaAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Queries and keys are rotated by factors precomputed for up to
    ``cfg.block_size`` positions. Attention uses
    ``scaled_dot_product_attention`` when PyTorch provides it and an
    explicit masked-softmax fallback otherwise; both paths apply the same
    attention dropout during training.

    NOTE(review): ``cfg.is_causal`` is not consulted here; the mask is
    always causal.
    """

    def __init__(self, cfg: CFG):
        super().__init__()
        self.head_dim = cfg.d_model // cfg.n_heads

        self.q_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
        self.k_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
        self.v_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)
        self.o_proj = nn.Linear(cfg.d_model, cfg.d_model, bias=cfg.bias)

        self.resid_dropout = nn.Dropout(cfg.dropout)
        self.n_head = cfg.n_heads
        self.dropout = cfg.dropout
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')

        # Non-persistent: recomputed at construction, never stored in checkpoints.
        self.register_buffer("freqs_cis", compute_freqs_cis(self.head_dim, cfg.block_size), persistent=False)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, C)`` and return the same shape."""
        B, T, C = x.size()
        q = self.q_proj(x).view(B, T, self.n_head, self.head_dim)
        k = self.k_proj(x).view(B, T, self.n_head, self.head_dim)
        v = self.v_proj(x).view(B, T, self.n_head, self.head_dim)

        # Rotary embedding is applied while the layout is (B, T, heads, head_dim).
        freqs_cis = self.freqs_cis[:T]
        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        # Attention dropout is active only in training mode.
        attn_drop_p = self.dropout if self.training else 0.0

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=attn_drop_p,
                is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            mask = torch.tril(torch.ones(T, T, device=x.device)).view(1, 1, T, T)
            att = att.masked_fill(mask == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            # Keep the fallback consistent with the fused path, which applies
            # dropout to the attention weights.
            att = F.dropout(att, p=attn_drop_p, training=self.training)
            y = att @ v

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.o_proj(y))
        return y
|
| 498 |
+
|
| 499 |
+
|
| 500 |
+
class GemmaBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a residual connection."""

    def __init__(self, cfg: CFG):
        super().__init__()
        width = cfg.d_model
        self.ln_1 = RMSNorm(width)
        self.attn = GemmaAttention(cfg)
        self.ln_2 = RMSNorm(width)
        self.mlp = GemmaMLP(cfg)

    def forward(self, x):
        after_attention = x + self.attn(self.ln_1(x))
        return after_attention + self.mlp(self.ln_2(after_attention))
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
class AstroPTBackbone(nn.Module):
    """Stack of ``cfg.n_layers`` ``GemmaBlock`` layers with input dropout and a final RMSNorm."""

    def __init__(self, cfg: CFG):
        super().__init__()
        self.drop = nn.Dropout(cfg.dropout)
        layers = [GemmaBlock(cfg) for _ in range(cfg.n_layers)]
        self.h = nn.ModuleList(layers)
        self.ln_f = RMSNorm(cfg.d_model)
        # Re-initialise every Linear / Embedding submodule.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Draw weights from N(0, 0.02) and zero the bias of Linear layers that have one."""
        if not isinstance(module, (nn.Linear, nn.Embedding)):
            return
        torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            torch.nn.init.zeros_(module.bias)

    def forward(self, x):
        hidden = self.drop(x)
        for layer in self.h:
            hidden = layer(hidden)
        return self.ln_f(hidden)
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
class GMMHead(nn.Module):
    """Linear head producing the parameters of a ``K``-component diagonal Gaussian mixture.

    For every input position it emits ``K`` mixture logits, ``K`` mean
    vectors and ``K`` log-standard-deviation vectors of width ``d_token``.
    """

    def __init__(self, d_model: int, d_token: int, K: int):
        super().__init__()
        self.K, self.D = K, d_token
        self.proj = nn.Linear(d_model, K * (1 + 2 * d_token))

    def forward(self, h: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Return ``(logits_pi, mu, log_sigma)`` shaped ``(B, N, K)``, ``(B, N, K, D)``, ``(B, N, K, D)``."""
        B, N, _ = h.shape
        per_component = self.proj(h).view(B, N, self.K, 1 + 2 * self.D)

        # Each component's slot is laid out as [logit | mean | log_sigma].
        logit_slot, mu, raw_log_sigma = per_component.split([1, self.D, self.D], dim=-1)
        logits_pi = logit_slot.squeeze(-1)

        # Bound the log-std so exp() downstream stays in a sane range.
        log_sigma = raw_log_sigma.clamp(min=-7, max=2)
        return logits_pi, mu, log_sigma
|
| 554 |
+
|
| 555 |
+
|
| 556 |
+
def gmm_nll(y: torch.Tensor, logits_pi: torch.Tensor, mu: torch.Tensor, log_sigma: torch.Tensor) -> torch.Tensor:
    """Negative log-likelihood of ``y`` under a diagonal Gaussian mixture.

    Args:
        y: Targets of shape ``(B, N, D)``.
        logits_pi: Unnormalised mixture weights, last dimension is the
            component axis.
        mu: Component means, broadcast against ``y.unsqueeze(2)``.
        log_sigma: Component log standard deviations, same shape as ``mu``.

    Returns:
        The NLL summed over dimension 1, i.e. one value per batch element.
    """
    _, _, D = y.shape

    residual = y.unsqueeze(2) - mu
    inv_var = torch.exp(-2 * log_sigma)
    mahalanobis = (residual ** 2 * inv_var).sum(-1)

    # Log-density of every component at y (diagonal covariance).
    log_component = -0.5 * mahalanobis - log_sigma.sum(-1) - 0.5 * D * math.log(2 * math.pi)
    log_weights = F.log_softmax(logits_pi, dim=-1)

    log_likelihood = torch.logsumexp(log_weights + log_component, dim=-1)
    return -log_likelihood.sum(dim=1)
|
| 565 |
+
|
| 566 |
+
|
| 567 |
+
class JetFormer(nn.Module):
|
| 568 |
+
def __init__(self, cfg: CFG):
|
| 569 |
+
super().__init__()
|
| 570 |
+
self.cfg = cfg
|
| 571 |
+
self.flow = ViTFlow(cfg.d_token, cfg.n_tokens, cfg.flow_steps)
|
| 572 |
+
self.in_proj = nn.Linear(cfg.d_token, cfg.d_model)
|
| 573 |
+
self.pos = nn.Parameter(torch.randn(1, cfg.n_tokens, cfg.d_model) * 0.02)
|
| 574 |
+
self.gpt = AstroPTBackbone(cfg)
|
| 575 |
+
self.head = GMMHead(cfg.d_model, cfg.d_token, cfg.gmm_K)
|
| 576 |
+
|
| 577 |
+
def forward(self, x: torch.Tensor, step: int, max_iters: int) -> torch.Tensor:
|
| 578 |
+
"""
|
| 579 |
+
Noise curriculum is tied to (step, max_iters) and decays to final values at max_iters.
|
| 580 |
+
- RGB noise: sigma in [0,255] (paper-style)
|
| 581 |
+
- z-noise: token-space Gaussian std (paper-style)
|
| 582 |
+
"""
|
| 583 |
+
x = uniform_dequantize(x)
|
| 584 |
+
|
| 585 |
+
# Curriculum length T (end exactly at max_iters if noise_decay_frac=1.0)
|
| 586 |
+
T = int(max_iters * self.cfg.noise_decay_frac)
|
| 587 |
+
T = max(T, 1)
|
| 588 |
+
|
| 589 |
+
# ---- RGB noise schedule ----
|
| 590 |
+
rgb_sigma_255 = cosine_decay(step, T, self.cfg.rgb_sigma0_255, self.cfg.rgb_sigmaT_255)
|
| 591 |
+
rgb_sigma = rgb_sigma_255 / 255.0
|
| 592 |
+
if self.cfg.noise_floor > 0:
|
| 593 |
+
rgb_sigma = max(rgb_sigma, self.cfg.noise_floor)
|
| 594 |
+
|
| 595 |
+
if self.training and rgb_sigma > 0.0:
|
| 596 |
+
x = (x + torch.randn_like(x) * rgb_sigma).clamp(0.0, 1.0)
|
| 597 |
+
|
| 598 |
+
# ---- Flow encode ----
|
| 599 |
+
tokens_in = patchify(x, self.cfg.patch)
|
| 600 |
+
z, logdet = self.flow(tokens_in, reverse=False)
|
| 601 |
+
|
| 602 |
+
# ---- Latent z noise schedule (decays to ~0 by max_iters) ----
|
| 603 |
+
z_sigma = cosine_decay(step, T, self.cfg.z_sigma0, self.cfg.z_sigmaT)
|
| 604 |
+
if self.cfg.noise_floor > 0:
|
| 605 |
+
z_sigma = max(z_sigma, self.cfg.noise_floor)
|
| 606 |
+
|
| 607 |
+
if self.training and z_sigma > 0.0:
|
| 608 |
+
z = z + torch.randn_like(z) * z_sigma
|
| 609 |
+
|
| 610 |
+
# ---- AR transformer + GMM ----
|
| 611 |
+
h = self.in_proj(z) + self.pos
|
| 612 |
+
h = self.gpt(h)
|
| 613 |
+
|
| 614 |
+
logits_pi, mu, log_sigma = self.head(h[:, :-1])
|
| 615 |
+
target = z[:, 1:]
|
| 616 |
+
|
| 617 |
+
nll_gmm = gmm_nll(target, logits_pi, mu, log_sigma)
|
| 618 |
+
loss = (nll_gmm - logdet).mean()
|
| 619 |
+
return loss
|
| 620 |
+
|
| 621 |
+
@torch.no_grad()
def sample(self, n: int = 16, x_real_batch: torch.Tensor = None):
    """
    Generate images, or teacher-forced reconstructions of real images.

    The module is switched to eval mode for the duration of the call and its
    previous train/eval mode is restored afterwards, so calling this from a
    training loop does not silently leave the model in eval mode.

    Args:
        n: Number of images to draw in unconditional mode. Ignored when
           ``x_real_batch`` is given (the batch size of that tensor is used).
        x_real_batch: Optional batch of real images. When provided, each
           token ``t+1`` is replaced by the mean of the most likely GMM
           component predicted from the real tokens up to ``t`` (token 0 is
           copied from the real encoding).

    Returns:
        Unconditional mode: decoded images clamped to [0, 1].
        Reconstruction mode: real and reconstructed images interleaved
        (real_0, rec_0, real_1, rec_1, ...), viewed as
        ``(-1, cfg.in_ch, cfg.img_size, cfg.img_size)``.
    """
    cfg = self.cfg
    device = next(self.parameters()).device

    was_training = self.training
    self.eval()
    try:
        if x_real_batch is None:
            # Token 0 stays all-zero; tokens 1..N-1 are sampled one at a time.
            z_seq = torch.zeros(n, cfg.n_tokens, cfg.d_token, device=device)
            for t in range(cfg.n_tokens - 1):
                # The whole sequence is re-encoded each step; only position t
                # is read to predict token t + 1.
                h_out = self.gpt(self.in_proj(z_seq) + self.pos)
                logits_pi, mu, log_sigma = self.head(h_out[:, t:t + 1])

                # Pick one mixture component per sample ...
                pi = F.softmax(logits_pi.squeeze(1), dim=-1)
                comp_idx = torch.multinomial(pi, 1)
                gather_idx = comp_idx[..., None].expand(-1, -1, cfg.d_token)

                # ... then draw from that component's diagonal Gaussian.
                sel_mu = mu.squeeze(1).gather(1, gather_idx).squeeze(1)
                sel_sigma = log_sigma.squeeze(1).gather(1, gather_idx).squeeze(1).exp()
                z_seq[:, t + 1] = sel_mu + torch.randn_like(sel_mu) * sel_sigma

            x_rec_tokens, _ = self.flow(z_seq, reverse=True)
            x_rec = depatchify(x_rec_tokens, cfg.in_ch, cfg.img_size, cfg.img_size, cfg.patch)
            return x_rec.clamp(0, 1)

        # ---- Reconstruction of a real batch ----
        x_real = x_real_batch.to(device)
        x_real_proc = uniform_dequantize(x_real)

        z_real, _ = self.flow(patchify(x_real_proc, cfg.patch), reverse=False)
        h_out = self.gpt(self.in_proj(z_real) + self.pos)

        logits_pi, mu, log_sigma = self.head(h_out)

        # Mean of the highest-weight component at every position.
        best_comp_idx = torch.argmax(logits_pi, dim=-1, keepdim=True)
        gather_idx = best_comp_idx.unsqueeze(-1).expand(-1, -1, -1, cfg.d_token)
        z_pred_next = torch.gather(mu, 2, gather_idx).squeeze(2)

        # The prediction made at position t fills slot t + 1.
        z_rec = torch.zeros_like(z_real)
        z_rec[:, 0] = z_real[:, 0]
        z_rec[:, 1:] = z_pred_next[:, :-1]

        x_rec_tokens, _ = self.flow(z_rec, reverse=True)
        x_rec = depatchify(x_rec_tokens, cfg.in_ch, cfg.img_size, cfg.img_size, cfg.patch)

        combined = torch.stack([x_real, x_rec.clamp(0, 1)], dim=1).view(
            -1, cfg.in_ch, cfg.img_size, cfg.img_size
        )
        return combined
    finally:
        # Undo the eval() above without forcing train mode on callers that
        # were already evaluating.
        self.train(was_training)
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
# ======================================================================================
|
| 677 |
+
# Block 6: Main Training Loop (DDP Aware)
|
| 678 |
+
# ======================================================================================
|
| 679 |
+
def train():
    """
    Run (or resume) distributed JetFormer training; one process per GPU.

    Steps:
      1. Initialise DDP via ``setup_ddp`` and build a ``CFG`` with output paths.
      2. Wrap ``JetFormer`` in DDP, create an AdamW optimizer and load any
         checkpoint through ``load_checkpoint`` (which yields the start step).
      3. Train until ``cfg.max_iters`` with a linear-warmup + cosine LR
         schedule, bfloat16 autocast and gradient-norm clipping.
      4. Every ``cfg.val_check_interval`` steps: rank 0 writes a
         real/reconstruction image grid, all ranks run validation, and rank 0
         saves the "latest" checkpoint, appends to the loss CSV and re-plots.
         Every ``cfg.save_interval`` steps rank 0 saves a historical checkpoint.
      5. After the loop: one final validation, sample grid, checkpoints and
         log row, then ``cleanup_ddp``.

    Validation loss is the mean over all batches actually evaluated on all
    ranks; NaN is reported if no batch was available.
    """
    # --- 1. DDP Setup ---
    rank, local_rank, world_size = setup_ddp()
    is_main = rank == 0

    cfg = CFG()
    cfg.rank = rank
    cfg.world_size = world_size
    cfg.device = f"cuda:{local_rank}"

    # ### PATH SETUP ###
    cfg.samples_dir = f"samples_{cfg.dataset_name}_256"
    cfg.loss_csv_path = f"loss_log_{cfg.dataset_name}_256.csv"
    cfg.loss_plot_path = f"loss_plot_{cfg.dataset_name}_256.png"

    if is_main:
        os.makedirs(cfg.samples_dir, exist_ok=True)
        print("--- DDP CONFIGURATION ---")
        print(f" World Size: {world_size}")
        print(f" Per-GPU Batch: {cfg.batch_size}")
        print(f" Global Batch: {cfg.batch_size * world_size}")
        print(f" Dataset: {cfg.hf_repo} (Streaming + Sharded)")
        print(f" Saving Checkpoints to: {cfg.samples_dir}")
        print("-------------------------")

        # NOTE(review): assumed to be rank-0 only, like the banner above.
        print("Noise curriculum:")
        print(f" RGB sigma: {cfg.rgb_sigma0_255} -> {cfg.rgb_sigmaT_255} (in [0,255])")
        print(f" z sigma: {cfg.z_sigma0} -> {cfg.z_sigmaT} (token space)")
        print(f" decay_frac: {cfg.noise_decay_frac} (T = {int(cfg.max_iters*cfg.noise_decay_frac)})")
        print(f" floor: {cfg.noise_floor}")

    # --- 2. Model Setup ---
    model = JetFormer(cfg).to(cfg.device)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # --- 3. Optimizer Setup ---
    opt = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.lr,
        weight_decay=cfg.wd,
        betas=(0.9, cfg.beta2),
    )

    # --- 4. Checkpoint Loading ---
    start_step = load_checkpoint(model, opt, cfg)

    # --- 5. Data Loading ---
    train_loader = get_train_dataloader(cfg)
    val_loader = get_val_dataloader(cfg)

    # Pre-load fixed batch for viz (rank 0 only; failure is non-fatal).
    viz_batch = None
    if is_main:
        print("Fetching visualization batch...")
        try:
            viz_batch = next(iter(val_loader))['img'][:16].to(cfg.device)
        except Exception as e:
            print(f"Warning: Could not load viz batch: {e}")

    # --- 6. Scheduler ---
    def get_lr_schedule(step):
        """Return the LR multiplier: linear warmup, then cosine decay to 0."""
        if step < cfg.warmup_steps:
            return step / cfg.warmup_steps
        progress = (step - cfg.warmup_steps) / (cfg.max_iters - cfg.warmup_steps)
        progress = max(0.0, min(1.0, progress))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    def run_validation(at_step):
        """Evaluate up to cfg.val_steps batches; return the global mean loss.

        Must be called on every rank (it performs an all_reduce). Leaves the
        model in eval mode; the caller switches back to train mode.
        """
        model.eval()
        loss_sum = 0.0
        n_batches = 0
        val_iter = iter(val_loader)
        with torch.no_grad():
            for _ in range(cfg.val_steps):
                try:
                    vbatch = next(val_iter)
                except StopIteration:
                    break
                vimg = vbatch["img"].to(cfg.device)
                vloss = model(vimg, step=at_step, max_iters=cfg.max_iters)
                loss_sum += float(vloss.item())
                n_batches += 1

        # Reduce sum and count together so the mean is taken over batches that
        # were really evaluated, not over cfg.val_steps.
        stats = torch.tensor(
            [loss_sum, float(n_batches)], device=cfg.device, dtype=torch.float64
        )
        dist.all_reduce(stats, op=dist.ReduceOp.SUM)
        total_loss, total_batches = stats.tolist()
        if total_batches <= 0:
            if is_main:
                print("Warning: validation produced no batches; reporting NaN.")
            return float("nan")
        return total_loss / total_batches

    def save_viz_samples(step_label, error_prefix):
        """Write a real/reconstruction grid for viz_batch; errors are logged only."""
        try:
            with torch.no_grad():
                fake_images = model.module.sample(n=16, x_real_batch=viz_batch)
            sample_path = os.path.join(cfg.samples_dir, f"step_{step_label:07d}.png")
            save_image(fake_images, sample_path, nrow=2)
        except Exception as e:
            # Best effort: a failed visualisation must not abort training.
            print(f"{error_prefix} Sampling Error: {e}")

    if is_main:
        print(f"Starting training loop from step {start_step}...")

    # --- 7. Main Loop ---
    model.train()
    train_iter = iter(train_loader)

    if is_main:
        pbar = tqdm(range(start_step, cfg.max_iters), initial=start_step, total=cfg.max_iters)
    else:
        pbar = range(start_step, cfg.max_iters)

    # Running train-loss average between two validation points (rank 0 only).
    train_loss_accum = 0.0
    accum_steps = 0

    for step in pbar:
        try:
            batch = next(train_iter)
        except StopIteration:
            # Loader exhausted: start a new pass.
            train_iter = iter(train_loader)
            batch = next(train_iter)

        img = batch["img"].to(cfg.device)

        # LR Update
        lr_scale = get_lr_schedule(step)
        for param_group in opt.param_groups:
            param_group['lr'] = cfg.lr * lr_scale

        # Forward Pass (BFloat16 for H100 / Ampere+)
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            loss = model(img, step=step, max_iters=cfg.max_iters)

        # Backward
        opt.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.grad_clip_val)
        opt.step()

        if is_main:
            # .item() forces a host sync, so only the rank that logs pays it.
            current_loss = float(loss.item())
            train_loss_accum += current_loss
            accum_steps += 1
            if isinstance(pbar, tqdm):
                pbar.set_postfix(loss=f"{current_loss:.3f}", lr=f"{opt.param_groups[0]['lr']:.2e}")

        if step > 0:
            # 1. Validation and Image Sampling
            if step % cfg.val_check_interval == 0:
                if is_main:
                    avg_train_loss = train_loss_accum / max(accum_steps, 1)
                    train_loss_accum = 0.0
                    accum_steps = 0

                    # Generate Samples
                    if viz_batch is not None:
                        model.eval()
                        save_viz_samples(step, "Interval")

                # Run Validation (all ranks)
                avg_val_loss = run_validation(step)

                if is_main:
                    save_checkpoint(step, model, opt, cfg, is_latest=True)
                    append_losses_to_csv(step, avg_train_loss, avg_val_loss, cfg.loss_csv_path)
                    plot_loss_from_csv(cfg.loss_csv_path, cfg.loss_plot_path)
                    print(f"\nStep {step}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

                model.train()

            # 2. Historical checkpoint saving
            if step % cfg.save_interval == 0:
                if is_main:
                    save_checkpoint(step, model, opt, cfg, is_latest=False)

    # Final checkpoint and logging at max_iters (after training loop completes)
    final_step = cfg.max_iters - 1

    # Calculate final average train loss (rank 0 only)
    if is_main:
        avg_train_loss = train_loss_accum / accum_steps if accum_steps > 0 else 0.0

    # Final validation (all ranks)
    avg_val_loss = run_validation(final_step)

    # Final sampling, checkpoint and logging (rank 0 only)
    if is_main:
        if viz_batch is not None:
            save_viz_samples(cfg.max_iters, "Final")

        save_checkpoint(final_step, model, opt, cfg, is_latest=True)
        save_checkpoint(final_step, model, opt, cfg, is_latest=False)  # Also save as historical
        append_losses_to_csv(final_step, avg_train_loss, avg_val_loss, cfg.loss_csv_path)
        plot_loss_from_csv(cfg.loss_csv_path, cfg.loss_plot_path)
        print(f"\nFinal Step {final_step}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}")

    cleanup_ddp()
    if is_main:
        print("Training finished.")
|
| 895 |
+
|
| 896 |
+
|
| 897 |
+
# Script entry point: start training only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    train()
|