"""
ArtiGen Training Script — Flow Matching + Modular Curriculum + Spectral Smoothness.
Optimized for Colab Free Tier / small GPU.
"""
import os
import math
import random
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
try:
from artigen.model import ArtiGen
except ImportError:
from model import ArtiGen
def sample_timesteps(batch_size, device, min_t=0.0, max_t=1.0):
    """Draw one random timestep per batch element.

    Samples come from ``torch.rand`` and are then scaled by
    ``max_t - min_t`` and shifted by ``min_t``.

    Args:
        batch_size: Number of timesteps to draw.
        device: Device on which the tensor is created.
        min_t: Value a raw sample of 0 maps to.
        max_t: Value a raw sample of 1 would map to.

    Returns:
        A 1-D tensor with ``batch_size`` entries.
    """
    span = max_t - min_t
    unit_samples = torch.rand(batch_size, device=device)
    return unit_samples * span + min_t
def rectified_flow_step(z0, z1, t):
    """Build the interpolated sample and velocity target for a flow step.

    The interpolant is the straight line ``z_t = (1 - t) * z0 + t * z1`` and
    the target is its constant slope ``z1 - z0``.

    Args:
        z0: Tensor with the batch on dim 0; any number of trailing dims.
        z1: Tensor broadcast-compatible with ``z0``.
        t: Tensor holding one timestep per batch element (``z0.shape[0]``
            values in total).

    Returns:
        Tuple ``(z_t, v_target)``.
    """
    batch = z0.shape[0]
    # One singleton axis per non-batch dim so that t broadcasts against
    # latents of any rank (not only 4-D); reshape also accepts a
    # non-contiguous t, which view would reject.
    t_broadcast = t.reshape(batch, *([1] * (z0.dim() - 1)))
    z_t = (1.0 - t_broadcast) * z0 + t_broadcast * z1
    v_target = z1 - z0
    return z_t, v_target
def spectral_smoothness_loss(v_pred, z_t):
    """Penalize spatial roughness of a predicted velocity field.

    Despite the name, no Fourier transform is involved: the penalty is the
    mean absolute second-order finite difference of ``v_pred`` along dims 2
    and 3, each reflect-padded by one element on both sides of the
    differenced axis, summed and scaled by 0.01.

    Args:
        v_pred: 4-D tensor (it is indexed with four subscripts).
        z_t: Not used by the computation.

    Returns:
        A scalar tensor.
    """
    # Second difference along dim 2: x[i + 2] - 2 * x[i + 1] + x[i].
    rows_next = v_pred[:, :, 2:, :]
    rows_mid = v_pred[:, :, 1:-1, :]
    rows_prev = v_pred[:, :, :-2, :]
    d2_height = rows_next - 2 * rows_mid + rows_prev

    # Same stencil along dim 3.
    cols_next = v_pred[:, :, :, 2:]
    cols_mid = v_pred[:, :, :, 1:-1]
    cols_prev = v_pred[:, :, :, :-2]
    d2_width = cols_next - 2 * cols_mid + cols_prev

    # F.pad orders its tuple from the last dim backwards:
    # (left, right, top, bottom).
    d2_height = F.pad(d2_height, (0, 0, 1, 1), mode='reflect')
    d2_width = F.pad(d2_width, (1, 1, 0, 0), mode='reflect')

    roughness = d2_height.abs().mean() + d2_width.abs().mean()
    return roughness * 0.01
class DummyLatentDataset(Dataset):
    """Synthetic dataset yielding random latents, text embeddings and labels.

    Every item is freshly sampled on access; the index passed to
    ``__getitem__`` is ignored, so repeated reads of the same index return
    different values unless the ``torch`` and ``random`` generators are
    re-seeded.
    """

    def __init__(
        self,
        num_samples=1024,
        latent_ch=4,
        latent_h=32,
        latent_w=32,
        text_dim=768,
        num_style_classes=128,
        num_content_classes=512,
        num_mood_classes=64,
    ):
        """Store the dataset length, tensor sizes and label ranges."""
        self.num_samples = num_samples
        self.latent_shape = (latent_ch, latent_h, latent_w)
        self.text_dim = text_dim
        self.num_style_classes = num_style_classes
        self.num_content_classes = num_content_classes
        self.num_mood_classes = num_mood_classes

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return ``(z0, text_emb, style_label, content_label, mood_label)``.

        ``z0`` has shape ``latent_shape`` and ``text_emb`` has shape
        ``(text_dim,)``, both standard normal. Each label is a 0-dim
        ``torch.long`` tensor drawn uniformly from ``[0, num_classes - 1]``
        using Python's ``random`` module.
        """
        latent = torch.randn(self.latent_shape)
        text_emb = torch.randn(self.text_dim)

        # Draw order (style, content, mood) matters for reproducibility
        # under a fixed random seed.
        class_counts = (
            self.num_style_classes,
            self.num_content_classes,
            self.num_mood_classes,
        )
        labels = []
        for count in class_counts:
            drawn = random.randint(0, count - 1)
            labels.append(torch.tensor(drawn, dtype=torch.long))

        return (latent, text_emb, *labels)
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device,
    stage: int = 1,
    lambda_flow: float = 1.0,
    lambda_smooth: float = 0.05,
    lambda_style: float = 0.1,
    lambda_content: float = 0.1,
    lambda_mood: float = 0.1,
    p_uncond: float = 0.1,
    grad_clip: float = 1.0,
    ema_model=None,
    ema_decay: float = 0.9999,
):
    """Run one optimization pass over ``dataloader``.

    For every batch the function draws noise and timesteps, builds the
    rectified-flow interpolant, and minimizes a weighted sum of:

    * the MSE between the predicted and target velocity (``lambda_flow``),
    * the smoothness penalty on the predicted velocity (``lambda_smooth``),
    * stage-gated cross-entropy terms on the auxiliary logits returned by
      the model: style (stage >= 1), content (stage >= 2), mood (stage >= 4).

    Args:
        model: Module called as ``model(z_t, t, text_emb, return_asdl=True)``
            and expected to return ``(v_pred, asdl)``, where ``asdl`` is
            either ``None`` or a mapping with the keys ``'style_logits'``,
            ``'content_logits'`` and ``'mood_logits'``.
        dataloader: Iterable of ``(z0, text_emb, style_label, content_label,
            mood_label)`` batches.
        optimizer: Optimizer stepping ``model``'s parameters.
        device: Device the batch tensors are moved to.
        stage: Curriculum stage; gates the auxiliary losses.
        lambda_flow: Weight of the velocity MSE.
        lambda_smooth: Weight of the smoothness penalty.
        lambda_style: Weight of the style cross-entropy (skipped if <= 0).
        lambda_content: Weight of the content cross-entropy (skipped if <= 0).
        lambda_mood: Weight of the mood cross-entropy (skipped if <= 0).
        p_uncond: Per-sample probability of zeroing the text embedding.
        grad_clip: Max gradient norm passed to ``clip_grad_norm_``.
        ema_model: Optional module whose parameters track an exponential
            moving average of ``model``'s parameters (buffers are not
            touched).
        ema_decay: EMA decay factor.

    Returns:
        Dict with the per-batch means of ``'loss'``, ``'flow'`` and
        ``'smooth'``; all zeros when ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    total_flow = 0.0
    total_smooth = 0.0
    num_batches = 0

    for z0, text_emb, style_label, content_label, mood_label in dataloader:
        z0 = z0.to(device)
        text_emb = text_emb.to(device)
        style_label = style_label.to(device)
        content_label = content_label.to(device)
        mood_label = mood_label.to(device)
        B = z0.shape[0]

        # Randomly drop the conditioning for a fraction of the samples.
        # Done out of place: `.to(device)` returns the very same tensor when
        # it already lives on `device`, so an in-place write would corrupt
        # the caller's batch (and fails outright if it requires grad).
        mask_uncond = torch.rand(B, device=device) < p_uncond
        drop_mask = mask_uncond.view(B, *([1] * (text_emb.dim() - 1)))
        text_emb = text_emb.masked_fill(drop_mask, 0.0)

        z1 = torch.randn_like(z0)
        t = sample_timesteps(B, device)
        z_t, v_target = rectified_flow_step(z0, z1, t)

        v_pred, asdl = model(z_t, t, text_emb, return_asdl=True)

        loss_flow = F.mse_loss(v_pred, v_target)
        loss = lambda_flow * loss_flow

        loss_smooth = spectral_smoothness_loss(v_pred, z_t)
        loss = loss + lambda_smooth * loss_smooth

        # NOTE(review): the content and mood terms are nested under the
        # `asdl is not None` guard because they index `asdl` -- confirm this
        # matches the intended structure.
        if stage >= 1 and asdl is not None:
            if lambda_style > 0:
                s_logits = asdl['style_logits']
                loss_style = F.cross_entropy(s_logits, style_label)
                loss = loss + lambda_style * loss_style
            if stage >= 2 and lambda_content > 0:
                c_logits = asdl['content_logits']
                # Averaged over dim 1 before classification -- presumably a
                # per-object/slot axis; verify against the model.
                c_logits_avg = c_logits.mean(dim=1)
                loss_content = F.cross_entropy(c_logits_avg, content_label)
                loss = loss + lambda_content * loss_content
            if stage >= 4 and lambda_mood > 0:
                m_logits = asdl['mood_logits']
                loss_mood = F.cross_entropy(m_logits, mood_label)
                loss = loss + lambda_mood * loss_mood

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        if ema_model is not None:
            with torch.no_grad():
                # ema <- decay * ema + (1 - decay) * current
                for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                    p_ema.mul_(ema_decay).add_(p.detach(), alpha=1 - ema_decay)

        total_loss += loss.item()
        total_flow += loss_flow.item()
        total_smooth += loss_smooth.item()
        num_batches += 1

    # max(..., 1) keeps the averages defined for an empty dataloader.
    denom = max(num_batches, 1)
    return {
        'loss': total_loss / denom,
        'flow': total_flow / denom,
        'smooth': total_smooth / denom,
    }
def build_optimizer(model_or_params, lr=2e-4, weight_decay=0.01):
    """Create an AdamW optimizer with betas ``(0.9, 0.999)``.

    Args:
        model_or_params: Either an object exposing ``.parameters()`` (for
            example an ``nn.Module``) or anything ``torch.optim.AdamW``
            accepts directly as its parameter argument.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.

    Returns:
        The constructed ``torch.optim.AdamW`` instance.
    """
    if hasattr(model_or_params, 'parameters'):
        params = model_or_params.parameters()
    else:
        params = model_or_params
    return torch.optim.AdamW(
        params,
        lr=lr,
        betas=(0.9, 0.999),
        weight_decay=weight_decay,
    )
def apply_curriculum_freeze(model, stage: int):
    """Set ``requires_grad`` on ``model``'s parameters for a curriculum stage.

    Everything is frozen first. The shared submodules (``patch_embed``,
    ``t_embed``, ``cond_proj``, ``cond_transform``, ``blocks``, ``adalns``,
    ``skip_connect``, ``final_proj``) are then always re-enabled, followed by
    the head(s) belonging to the stage:

    * 1 -> ``style_head``
    * 2 -> ``content_head``
    * 3 -> ``concept_head``
    * 4 -> ``mood_head`` and ``comp_head``
    * 5 or higher -> every parameter of the model

    Any other stage value leaves only the shared submodules trainable. A
    one-line summary counting frozen and trainable parameter tensors (not
    elements) is printed.

    Args:
        model: Module exposing the attributes listed above.
        stage: Curriculum stage number.
    """
    def set_trainable(module, flag):
        for param in module.parameters():
            param.requires_grad = flag

    set_trainable(model, False)

    always_on = (
        'patch_embed',
        't_embed',
        'cond_proj',
        'cond_transform',
        'blocks',
        'adalns',
        'skip_connect',
        'final_proj',
    )
    for attr in always_on:
        set_trainable(getattr(model, attr), True)

    heads_by_stage = {
        1: ('style_head',),
        2: ('content_head',),
        3: ('concept_head',),
        4: ('mood_head', 'comp_head'),
    }
    if stage >= 5:
        set_trainable(model, True)
    else:
        # Only the heads of the selected stage are looked up on the model.
        for attr in heads_by_stage.get(stage, ()):
            set_trainable(getattr(model, attr), True)

    grad_flags = [param.requires_grad for param in model.parameters()]
    trainable = sum(grad_flags)
    frozen = len(grad_flags) - trainable
    print(f"[Curriculum] Stage {stage}: frozen {frozen} params, trainable {trainable} params")
def run_training(
    num_epochs_per_stage=5,
    batch_size=4,
    lr=2e-4,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='./checkpoints',
    embed_dim=256,
    num_layers=16,
    latent_h=32,
    latent_w=32,
):
    """Train an ``ArtiGen`` model through curriculum stages 1..5 on dummy data.

    For each stage the model is re-frozen with ``apply_curriculum_freeze``, a
    fresh AdamW optimizer and a cosine learning-rate schedule (annealing from
    ``lr`` to ``0.1 * lr`` over the stage's epochs) are created, and
    ``train_one_epoch`` runs ``num_epochs_per_stage`` times. An EMA copy of
    the model is maintained across all stages, and a checkpoint named
    ``artigen_stage{stage}.pt`` is written to ``save_dir``.

    Args:
        num_epochs_per_stage: Epochs to run in every stage.
        batch_size: DataLoader batch size.
        lr: Initial learning rate of every stage.
        device: Device for the models and batches. The default is evaluated
            once, when the function is defined.
        save_dir: Directory for checkpoints; created if missing.
        embed_dim: Forwarded to ``ArtiGen``.
        num_layers: Forwarded to ``ArtiGen``.
        latent_h: Latent height used by the model and the dummy dataset.
        latent_w: Latent width used by the model and the dummy dataset.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Device: {device}")

    # Shared by the online model and its EMA twin so both are built alike.
    model_kwargs = dict(
        latent_ch=4, latent_h=latent_h, latent_w=latent_w,
        embed_dim=embed_dim, num_layers=num_layers,
        d_state=16, expand=2, text_dim=768,
        style_classes=128, content_objects=1024, mood_classes=64,
    )

    model = ArtiGen(**model_kwargs).to(device)
    total = sum(p.numel() for p in model.parameters()) / 1e6
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    print(f"Model total params: {total:.2f}M, trainable: {trainable:.2f}M")

    # The EMA model starts as an exact copy, never receives gradients and
    # stays in eval mode.
    ema_model = ArtiGen(**model_kwargs).to(device)
    ema_model.load_state_dict(model.state_dict())
    ema_model.requires_grad_(False)
    ema_model.eval()

    dataset = DummyLatentDataset(
        num_samples=2048,
        latent_h=latent_h,
        latent_w=latent_w,
        num_style_classes=128,
        num_content_classes=1024,
        num_mood_classes=64,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )

    for stage in range(1, 6):
        print(f"\n{'='*40}\n STAGE {stage}\n{'='*40}")
        apply_curriculum_freeze(model, stage)

        # Optimizer and schedule are rebuilt per stage, so optimizer state
        # does not carry over between stages.
        optimizer = build_optimizer(model, lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=num_epochs_per_stage,
            eta_min=lr * 0.1,
        )

        for epoch in range(1, num_epochs_per_stage + 1):
            metrics = train_one_epoch(
                model,
                dataloader,
                optimizer,
                device,
                stage=stage,
                ema_model=ema_model,
            )
            scheduler.step()
            print(f" Stage {stage} Epoch {epoch}/{num_epochs_per_stage} | loss={metrics['loss']:.4f} flow={metrics['flow']:.4f} smooth={metrics['smooth']:.4f}")

        # NOTE(review): written once per stage, after the stage's last
        # epoch -- confirm this matches the intended placement.
        ckpt_path = os.path.join(save_dir, f"artigen_stage{stage}.pt")
        checkpoint = {
            'stage': stage,
            'model': model.state_dict(),
            'ema': ema_model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(checkpoint, ckpt_path)
        print(f" Saved checkpoint to {ckpt_path}")

    print("\nTraining complete!")
if __name__ == '__main__':
    import argparse

    # (flag, type, default) for every command-line option.
    # NOTE(review): --device defaults to 'cpu' here, whereas run_training's
    # own default picks CUDA when available -- confirm this is intended.
    cli_options = (
        ('--epochs', int, 3),
        ('--bs', int, 4),
        ('--lr', float, 2e-4),
        ('--dim', int, 256),
        ('--layers', int, 16),
        ('--latent_h', int, 32),
        ('--latent_w', int, 32),
        ('--device', str, 'cpu'),
        ('--save_dir', str, './checkpoints'),
    )
    parser = argparse.ArgumentParser()
    for flag, value_type, default_value in cli_options:
        parser.add_argument(flag, type=value_type, default=default_value)
    args = parser.parse_args()

    # Map the short CLI names onto run_training's keyword arguments.
    run_training(
        num_epochs_per_stage=args.epochs,
        batch_size=args.bs,
        lr=args.lr,
        device=args.device,
        save_dir=args.save_dir,
        embed_dim=args.dim,
        num_layers=args.layers,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
    )
|