"""
ArtiGen Training Script — Flow Matching + Modular Curriculum + Spectral Smoothness.
Optimized for Colab Free Tier / small GPU.
"""
import os
import math
import random
from pathlib import Path
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np

try:
    from artigen.model import ArtiGen
except ImportError:
    from model import ArtiGen


def sample_timesteps(batch_size, device, min_t=0.0, max_t=1.0):
    """Sample ``batch_size`` timesteps uniformly from ``[min_t, max_t)`` on ``device``."""
    return torch.rand(batch_size, device=device) * (max_t - min_t) + min_t


def rectified_flow_step(z0, z1, t):
    """Build the straight-line (rectified-flow) interpolant and its velocity target.

    Args:
        z0: clean latents, shape ``(B, ...)``.
        z1: noise tensor with the same shape as ``z0``.
        t: per-sample timesteps, shape ``(B,)``.

    Returns:
        ``(z_t, v_target)`` with ``z_t = (1 - t) * z0 + t * z1`` and
        ``v_target = z1 - z0``.
    """
    B = z0.shape[0]
    # Broadcast t over every non-batch dimension; works for any latent rank,
    # not only 4-D (B, C, H, W) tensors.
    t_broadcast = t.view(B, *([1] * (z0.dim() - 1)))
    z_t = (1.0 - t_broadcast) * z0 + t_broadcast * z1
    v_target = z1 - z0
    return z_t, v_target


def spectral_smoothness_loss(v_pred, z_t=None, scale=0.01):
    """Penalize the mean absolute second difference of ``v_pred`` along H and W.

    Args:
        v_pred: predicted velocity, indexed as ``(B, C, H, W)``.
        z_t: unused; accepted so existing call sites keep working.
        scale: multiplier applied to the summed penalty (default 0.01).

    Returns:
        Scalar tensor. A spatial axis shorter than 3 has no second difference
        and contributes 0 instead of raising.
    """
    height, width = v_pred.shape[-2], v_pred.shape[-1]
    smooth = v_pred.new_zeros(())
    if height >= 3:
        lap_h = v_pred[:, :, 2:, :] - 2 * v_pred[:, :, 1:-1, :] + v_pred[:, :, :-2, :]
        # Pad back to H rows. 'reflect' needs at least 2 rows to reflect from,
        # so fall back to 'replicate' when only one Laplacian row exists.
        lap_h = F.pad(lap_h, (0, 0, 1, 1), mode='reflect' if height >= 4 else 'replicate')
        smooth = smooth + lap_h.abs().mean()
    if width >= 3:
        lap_w = v_pred[:, :, :, 2:] - 2 * v_pred[:, :, :, 1:-1] + v_pred[:, :, :, :-2]
        lap_w = F.pad(lap_w, (1, 1, 0, 0), mode='reflect' if width >= 4 else 'replicate')
        smooth = smooth + lap_w.abs().mean()
    return smooth * scale


class DummyLatentDataset(Dataset):
    """Synthetic dataset of random latents, text embeddings and class labels.

    Each item is ``(z0, text_emb, style_label, content_label, mood_label)``.
    ``z0`` is a standard-normal tensor of shape ``(latent_ch, latent_h, latent_w)``
    and ``text_emb`` is a standard-normal vector of length ``text_dim``. The three
    labels are scalar ``torch.long`` tensors drawn uniformly from their class ranges.
    """

    def __init__(
        self,
        num_samples=1024,
        latent_ch=4,
        latent_h=32,
        latent_w=32,
        text_dim=768,
        num_style_classes=128,
        num_content_classes=512,
        num_mood_classes=64,
    ):
        self.num_samples = num_samples
        self.latent_shape = (latent_ch, latent_h, latent_w)
        self.text_dim = text_dim
        self.num_style_classes = num_style_classes
        self.num_content_classes = num_content_classes
        self.num_mood_classes = num_mood_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # idx is ignored: every access draws fresh random data.
        z0 = torch.randn(self.latent_shape)
        text_emb = torch.randn(self.text_dim)
        style_label = torch.tensor(random.randint(0, self.num_style_classes - 1), dtype=torch.long)
        content_label = torch.tensor(random.randint(0, self.num_content_classes - 1), dtype=torch.long)
        mood_label = torch.tensor(random.randint(0, self.num_mood_classes - 1), dtype=torch.long)
        return z0, text_emb, style_label, content_label, mood_label


@torch.no_grad()
def _update_ema(ema_model, model, decay):
    """Blend ``model``'s parameters into ``ema_model`` and copy its buffers."""
    for p_ema, p in zip(ema_model.parameters(), model.parameters()):
        p_ema.mul_(decay).add_(p, alpha=1 - decay)
    # Buffers are not parameters, so the blend above never touches them.
    # Copy them so the EMA model does not keep stale initial values.
    for b_ema, b in zip(ema_model.buffers(), model.buffers()):
        b_ema.copy_(b)


def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device,
    stage: int = 1,
    lambda_flow: float = 1.0,
    lambda_smooth: float = 0.05,
    lambda_style: float = 0.1,
    lambda_content: float = 0.1,
    lambda_mood: float = 0.1,
    p_uncond: float = 0.1,
    grad_clip: float = 1.0,
    ema_model=None,
    ema_decay: float = 0.9999,
):
    """Run one epoch of flow-matching training plus curriculum auxiliary losses.

    Steps for each batch:

    1. Zero the text embedding for a random ``p_uncond`` fraction of samples.
    2. Build the rectified-flow interpolant.
    3. Regress the velocity with MSE and add the smoothness penalty.
    4. Add the auxiliary losses: style cross-entropy (any stage), content
       (stage >= 2) and mood (stage >= 4), when ``asdl`` is returned.
    5. Clip gradients and step the optimizer.
    6. Update ``ema_model`` if one is given.

    Returns:
        dict with the batch-averaged ``'loss'``, ``'flow'`` and ``'smooth'``
        values (all 0.0 for an empty dataloader).
    """
    model.train()
    total_loss = 0.0
    total_flow = 0.0
    total_smooth = 0.0
    num_batches = 0

    for z0, text_emb, style_label, content_label, mood_label in dataloader:
        z0 = z0.to(device)
        text_emb = text_emb.to(device)
        style_label = style_label.to(device)
        content_label = content_label.to(device)
        mood_label = mood_label.to(device)
        B = z0.shape[0]

        # Drop the text condition for a random subset, presumably for
        # classifier-free guidance. The fill is out-of-place so the tensor
        # yielded by the DataLoader is not mutated: on CPU, .to(device)
        # returns that very tensor.
        mask_uncond = torch.rand(B, device=device) < p_uncond
        mask_shape = (B,) + (1,) * (text_emb.dim() - 1)
        text_emb = text_emb.masked_fill(mask_uncond.view(mask_shape), 0.0)

        z1 = torch.randn_like(z0)
        t = sample_timesteps(B, device)
        z_t, v_target = rectified_flow_step(z0, z1, t)

        v_pred, asdl = model(z_t, t, text_emb, return_asdl=True)

        loss_flow = F.mse_loss(v_pred, v_target)
        loss_smooth = spectral_smoothness_loss(v_pred, z_t)
        loss = lambda_flow * loss_flow + lambda_smooth * loss_smooth

        if stage >= 1 and asdl is not None:
            if lambda_style > 0:
                loss_style = F.cross_entropy(asdl['style_logits'], style_label)
                loss = loss + lambda_style * loss_style
            if stage >= 2 and lambda_content > 0:
                # Content logits are averaged over dim 1 before the loss.
                # Presumably the shape is (B, N, num_content); TODO confirm
                # against ArtiGen.
                c_logits_avg = asdl['content_logits'].mean(dim=1)
                loss_content = F.cross_entropy(c_logits_avg, content_label)
                loss = loss + lambda_content * loss_content
            if stage >= 4 and lambda_mood > 0:
                loss_mood = F.cross_entropy(asdl['mood_logits'], mood_label)
                loss = loss + lambda_mood * loss_mood

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        if ema_model is not None:
            _update_ema(ema_model, model, ema_decay)

        total_loss += loss.item()
        total_flow += loss_flow.item()
        total_smooth += loss_smooth.item()
        num_batches += 1

    denom = max(num_batches, 1)
    return {
        'loss': total_loss / denom,
        'flow': total_flow / denom,
        'smooth': total_smooth / denom,
    }


def build_optimizer(model_or_params, lr=2e-4, weight_decay=0.01):
    """Create an AdamW optimizer.

    When given an ``nn.Module``-like object (anything with ``.parameters()``),
    only parameters with ``requires_grad=True`` are included. AdamW updates
    (and weight-decays) any parameter whose ``.grad`` is not None, including
    zero-filled grads, so curriculum-frozen weights must be excluded to stay
    frozen. Any other iterable of parameters or param groups is passed
    through unchanged.
    """
    if hasattr(model_or_params, 'parameters'):
        params = [p for p in model_or_params.parameters() if p.requires_grad]
    else:
        params = model_or_params
    return torch.optim.AdamW(params, lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)


# Submodules that are trainable in every curriculum stage.
_SHARED_MODULES = (
    'patch_embed',
    't_embed',
    'cond_proj',
    'cond_transform',
    'blocks',
    'adalns',
    'skip_connect',
    'final_proj',
)

# Extra heads unfrozen in each stage 1-4; stage >= 5 unfreezes everything.
# NOTE(review): concept_head (stage 3) and comp_head (stage 4) have no
# dedicated loss in train_one_epoch. They only learn if ArtiGen feeds them
# into v_pred or the other logits; confirm.
_STAGE_HEADS = {
    1: ('style_head',),
    2: ('content_head',),
    3: ('concept_head',),
    4: ('mood_head', 'comp_head'),
}


def apply_curriculum_freeze(model, stage: int):
    """Set ``requires_grad`` on ``model``'s parameters for the given curriculum stage.

    Stages 1-4 train the shared backbone modules plus that stage's heads.
    Stage >= 5 trains every parameter. The function prints the frozen and
    trainable parameter-element counts.
    """
    def set_trainable(module, flag):
        for p in module.parameters():
            p.requires_grad = flag

    if stage >= 5:
        set_trainable(model, True)
    else:
        set_trainable(model, False)
        for name in _SHARED_MODULES + _STAGE_HEADS.get(stage, ()):
            set_trainable(getattr(model, name), True)

    # Count elements (numel) so the numbers match the "params" totals
    # printed by run_training.
    frozen = sum(p.numel() for p in model.parameters() if not p.requires_grad)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"[Curriculum] Stage {stage}: frozen {frozen} params, trainable {trainable} params")


def run_training(
    num_epochs_per_stage=5,
    batch_size=4,
    lr=2e-4,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='./checkpoints',
    embed_dim=256,
    num_layers=16,
    latent_h=32,
    latent_w=32,
):
    """Train ArtiGen through curriculum stages 1-5 on the synthetic dataset.

    Each stage:

    - applies the curriculum freeze;
    - builds a fresh optimizer and cosine LR schedule;
    - trains for ``num_epochs_per_stage`` epochs while maintaining an EMA copy;
    - saves ``artigen_stage{stage}.pt`` into ``save_dir``.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Device: {device}")

    # Single source of truth for the architecture, shared by the live model,
    # the EMA copy and the dataset's class counts.
    model_kwargs = dict(
        latent_ch=4,
        latent_h=latent_h,
        latent_w=latent_w,
        embed_dim=embed_dim,
        num_layers=num_layers,
        d_state=16,
        expand=2,
        text_dim=768,
        style_classes=128,
        content_objects=1024,
        mood_classes=64,
    )

    model = ArtiGen(**model_kwargs).to(device)
    total = sum(p.numel() for p in model.parameters()) / 1e6
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    print(f"Model total params: {total:.2f}M, trainable: {trainable:.2f}M")

    ema_model = ArtiGen(**model_kwargs).to(device)
    ema_model.load_state_dict(model.state_dict())
    ema_model.requires_grad_(False)
    ema_model.eval()

    dataset = DummyLatentDataset(
        num_samples=2048,
        latent_ch=model_kwargs['latent_ch'],
        latent_h=latent_h,
        latent_w=latent_w,
        text_dim=model_kwargs['text_dim'],
        num_style_classes=model_kwargs['style_classes'],
        num_content_classes=model_kwargs['content_objects'],
        num_mood_classes=model_kwargs['mood_classes'],
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)

    for stage in range(1, 6):
        print(f"\n{'='*40}\n STAGE {stage}\n{'='*40}")
        apply_curriculum_freeze(model, stage)
        # Rebuilt per stage so it only tracks the parameters trainable now.
        optimizer = build_optimizer(model, lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=num_epochs_per_stage, eta_min=lr * 0.1
        )

        for epoch in range(1, num_epochs_per_stage + 1):
            metrics = train_one_epoch(model, dataloader, optimizer, device, stage=stage, ema_model=ema_model)
            scheduler.step()
            print(f"  Stage {stage} Epoch {epoch}/{num_epochs_per_stage} | loss={metrics['loss']:.4f} flow={metrics['flow']:.4f} smooth={metrics['smooth']:.4f}")

        ckpt_path = os.path.join(save_dir, f"artigen_stage{stage}.pt")
        torch.save(
            {
                'stage': stage,
                'model': model.state_dict(),
                'ema': ema_model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler.state_dict(),
            },
            ckpt_path,
        )
        print(f"  Saved checkpoint to {ckpt_path}")

    print("\nTraining complete!")


if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=3)
    parser.add_argument('--bs', type=int, default=4)
    parser.add_argument('--lr', type=float, default=2e-4)
    parser.add_argument('--dim', type=int, default=256)
    parser.add_argument('--layers', type=int, default=16)
    parser.add_argument('--latent_h', type=int, default=32)
    parser.add_argument('--latent_w', type=int, default=32)
    # Use the GPU when one is present (e.g. Colab) instead of silently
    # defaulting to CPU.
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
    parser.add_argument('--save_dir', type=str, default='./checkpoints')
    args = parser.parse_args()

    run_training(
        num_epochs_per_stage=args.epochs,
        batch_size=args.bs,
        lr=args.lr,
        device=args.device,
        save_dir=args.save_dir,
        embed_dim=args.dim,
        num_layers=args.layers,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
    )