| """ |
| ArtiGen Training Script — Flow Matching + Modular Curriculum + Spectral Smoothness. |
| Optimized for Colab Free Tier / small GPU. |
| """ |
| import os |
| import math |
| import random |
| from pathlib import Path |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
| from PIL import Image |
| import numpy as np |
|
|
| try: |
| from artigen.model import ArtiGen |
| except ImportError: |
| from model import ArtiGen |
|
|
|
|
def sample_timesteps(batch_size, device, min_t=0.0, max_t=1.0):
    """Draw `batch_size` flow times uniformly from [min_t, max_t)."""
    uniform = torch.rand(batch_size, device=device)
    return min_t + uniform * (max_t - min_t)
|
|
|
|
def rectified_flow_step(z0, z1, t):
    """Build the rectified-flow interpolant and its velocity target.

    For per-sample time t (shape (B,)), returns:
      z_t = (1 - t) * z0 + t * z1   (t broadcast over the C/H/W dims)
      v   = z1 - z0                 (constant velocity along the straight path)
    """
    tt = t.view(-1, 1, 1, 1)
    interpolant = (1.0 - tt) * z0 + tt * z1
    velocity = z1 - z0
    return interpolant, velocity
|
|
|
|
def spectral_smoothness_loss(v_pred, z_t, weight=0.01):
    """Penalize high-frequency content in the predicted velocity field.

    Computes discrete second differences (1-D Laplacians) along the height
    and width of `v_pred`, reflect-pads them back to the input spatial size,
    and returns the mean absolute curvature scaled by `weight`.

    Args:
        v_pred: predicted velocity, shape (B, C, H, W). Reflect padding of
            the (H-2)/(W-2)-sized Laplacians requires H, W >= 4.
        z_t: interpolated latent at time t; accepted for API symmetry with
            the flow loss but currently unused.
        weight: scale applied to the raw smoothness term. Defaults to 0.01,
            matching the previously hard-coded constant, so existing callers
            are unaffected. NOTE: callers typically apply a lambda_smooth
            factor on top of this as well.

    Returns:
        Scalar tensor: weight * (mean |d²v/dh²| + mean |d²v/dw²|).
    """
    # Second finite difference along H: f[i+1] - 2 f[i] + f[i-1].
    laplacian_h = v_pred[:, :, 2:, :] - 2 * v_pred[:, :, 1:-1, :] + v_pred[:, :, :-2, :]
    # Same along W.
    laplacian_w = v_pred[:, :, :, 2:] - 2 * v_pred[:, :, :, 1:-1] + v_pred[:, :, :, :-2]
    # Reflect-pad so the mean is taken over the full spatial extent.
    lap_h = F.pad(laplacian_h, (0, 0, 1, 1), mode='reflect')
    lap_w = F.pad(laplacian_w, (1, 1, 0, 0), mode='reflect')
    return (lap_h.abs().mean() + lap_w.abs().mean()) * weight
|
|
|
|
class DummyLatentDataset(Dataset):
    """Synthetic stand-in dataset of random latents, text embeddings and labels.

    Each item is a 5-tuple ``(z0, text_emb, style_label, content_label,
    mood_label)`` with freshly sampled random contents — useful for
    smoke-testing the training loop without a real data pipeline.
    """

    def __init__(
        self,
        num_samples=1024,
        latent_ch=4,
        latent_h=32,
        latent_w=32,
        text_dim=768,
        num_style_classes=128,
        num_content_classes=512,
        num_mood_classes=64,
    ):
        self.num_samples = num_samples
        # Stored as (C, H, W) — one sample's latent shape.
        self.latent_shape = (latent_ch, latent_h, latent_w)
        self.text_dim = text_dim
        self.num_style_classes = num_style_classes
        self.num_content_classes = num_content_classes
        self.num_mood_classes = num_mood_classes

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        # Data is resampled on every access; `idx` only bounds the epoch length.
        latent = torch.randn(self.latent_shape)
        text = torch.randn(self.text_dim)
        labels = tuple(
            torch.tensor(random.randint(0, count - 1), dtype=torch.long)
            for count in (
                self.num_style_classes,
                self.num_content_classes,
                self.num_mood_classes,
            )
        )
        return (latent, text) + labels
|
|
|
|
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device,
    stage: int = 1,
    lambda_flow: float = 1.0,
    lambda_smooth: float = 0.05,
    lambda_style: float = 0.1,
    lambda_content: float = 0.1,
    lambda_mood: float = 0.1,
    p_uncond: float = 0.1,
    grad_clip: float = 1.0,
    ema_model=None,
    ema_decay: float = 0.9999,
) -> dict:
    """Run one training epoch of rectified-flow matching with stage-gated auxiliary losses.

    Each batch: interpolates data latents toward Gaussian noise at a random
    time t, trains the model to predict the straight-line velocity (MSE),
    adds a spatial-smoothness penalty on the prediction, and — depending on
    `stage` — adds style/content/mood classification losses computed from the
    model's auxiliary outputs.

    Args:
        model: the velocity model. Assumes calling it as
            ``model(z_t, t, text_emb, return_asdl=True)`` yields
            ``(v_pred, asdl)`` where `asdl` is a dict with keys
            'style_logits', 'content_logits', 'mood_logits' — TODO confirm
            against ArtiGen.forward, which is not visible here.
        dataloader: yields (z0, text_emb, style_label, content_label, mood_label).
        optimizer: optimizer over `model`'s (trainable) parameters.
        device: target device for all batch tensors.
        stage: curriculum stage; style loss applies at >=1, content at >=2,
            mood at >=4 (losses are cumulative across stages).
        lambda_*: per-loss weights; a lambda of 0 skips that auxiliary loss.
        p_uncond: probability of zeroing a sample's text embedding
            (classifier-free-guidance style conditioning dropout).
        grad_clip: max grad norm passed to clip_grad_norm_.
        ema_model: optional shadow model updated in-place each step.
        ema_decay: EMA decay factor (parameters only; buffers are not synced).

    Returns:
        Dict of epoch-mean 'loss', 'flow' and 'smooth' values.
    """
    model.train()
    total_loss = 0.0
    total_flow = 0.0
    total_smooth = 0.0
    num_batches = 0


    for z0, text_emb, style_label, content_label, mood_label in dataloader:
        z0 = z0.to(device)
        text_emb = text_emb.to(device)
        style_label = style_label.to(device)
        content_label = content_label.to(device)
        mood_label = mood_label.to(device)
        B = z0.shape[0]


        # Conditioning dropout: zero out text embeddings for a random subset
        # of the batch (in-place on the device copy) so the model also learns
        # an unconditional velocity field.
        mask_uncond = torch.rand(B, device=device) < p_uncond
        text_emb[mask_uncond] = 0.0


        # Rectified flow: pair data z0 with fresh noise z1, interpolate at
        # random time t; the regression target is the constant velocity z1 - z0.
        z1 = torch.randn_like(z0)
        t = sample_timesteps(B, device)
        z_t, v_target = rectified_flow_step(z0, z1, t)


        v_pred, asdl = model(z_t, t, text_emb, return_asdl=True)


        loss_flow = F.mse_loss(v_pred, v_target)
        loss = lambda_flow * loss_flow


        # Smoothness penalty is always on (z_t is passed but unused by the
        # current implementation of spectral_smoothness_loss).
        loss_smooth = spectral_smoothness_loss(v_pred, z_t)
        loss = loss + lambda_smooth * loss_smooth


        # Auxiliary classification losses, gated cumulatively by stage.
        if stage >= 1 and asdl is not None:
            if lambda_style > 0:
                s_logits = asdl['style_logits']
                loss_style = F.cross_entropy(s_logits, style_label)
                loss = loss + lambda_style * loss_style
            if stage >= 2 and lambda_content > 0:
                c_logits = asdl['content_logits']
                # content_logits appears to carry an extra (e.g. spatial/token)
                # dim that is averaged out before CE — confirm its layout.
                c_logits_avg = c_logits.mean(dim=1)
                loss_content = F.cross_entropy(c_logits_avg, content_label)
                loss = loss + lambda_content * loss_content
            if stage >= 4 and lambda_mood > 0:
                m_logits = asdl['mood_logits']
                loss_mood = F.cross_entropy(m_logits, mood_label)
                loss = loss + lambda_mood * loss_mood


        optimizer.zero_grad()
        loss.backward()
        # Clip over all model parameters (frozen ones have no grads and are
        # ignored by the norm computation).
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()


        # EMA update after the optimizer step: shadow <- decay*shadow + (1-decay)*live.
        # Parameters only; any model buffers are not tracked here.
        if ema_model is not None:
            with torch.no_grad():
                for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                    p_ema.data.mul_(ema_decay).add_(p.data, alpha=1 - ema_decay)


        total_loss += loss.item()
        total_flow += loss_flow.item()
        total_smooth += loss_smooth.item()
        num_batches += 1


    # max(..., 1) guards against an empty dataloader.
    return {
        'loss': total_loss / max(num_batches, 1),
        'flow': total_flow / max(num_batches, 1),
        'smooth': total_smooth / max(num_batches, 1),
    }
|
|
|
|
def build_optimizer(model_or_params, lr=2e-4, weight_decay=0.01):
    """Create an AdamW optimizer from either a module or an explicit parameter iterable."""
    if hasattr(model_or_params, 'parameters'):
        params = model_or_params.parameters()
    else:
        params = model_or_params
    return torch.optim.AdamW(params, lr=lr, betas=(0.9, 0.999), weight_decay=weight_decay)
|
|
|
|
def apply_curriculum_freeze(model, stage: int):
    """Freeze everything, then re-enable the backbone plus the head(s) for `stage`.

    The backbone (embedding, conditioning, transformer blocks, output
    projection) is always trainable. Stages 1-3 each enable one
    stage-specific head, stage 4 enables two, and stage >= 5 unfreezes the
    whole model. NOTE(review): stage matching is exclusive (stage 2 re-freezes
    the style head even though the style loss still applies at stage 2) —
    presumably intentional for the modular curriculum, but worth confirming.
    """
    for param in model.parameters():
        param.requires_grad = False

    def unfreeze(module):
        for param in module.parameters():
            param.requires_grad = True

    # Always-trainable backbone components.
    backbone = (
        model.patch_embed,
        model.t_embed,
        model.cond_proj,
        model.cond_transform,
        model.blocks,
        model.adalns,
        model.skip_connect,
        model.final_proj,
    )
    for module in backbone:
        unfreeze(module)

    if stage >= 5:
        # Full fine-tuning: everything trains.
        for param in model.parameters():
            param.requires_grad = True
    elif stage == 4:
        unfreeze(model.mood_head)
        unfreeze(model.comp_head)
    elif stage == 3:
        unfreeze(model.concept_head)
    elif stage == 2:
        unfreeze(model.content_head)
    elif stage == 1:
        unfreeze(model.style_head)

    frozen = sum(1 for p in model.parameters() if not p.requires_grad)
    trainable = sum(1 for p in model.parameters() if p.requires_grad)
    print(f"[Curriculum] Stage {stage}: frozen {frozen} params, trainable {trainable} params")
|
|
|
|
def run_training(
    num_epochs_per_stage=5,
    batch_size=4,
    lr=2e-4,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='./checkpoints',
    embed_dim=256,
    num_layers=16,
    latent_h=32,
    latent_w=32,
):
    """Run the full 5-stage curriculum on a synthetic dataset and checkpoint each stage.

    Builds an ArtiGen model plus a frozen EMA copy, trains
    `num_epochs_per_stage` epochs per curriculum stage (re-building the
    optimizer and cosine LR schedule each stage after re-applying the freeze
    mask), and saves model/EMA/optimizer state to `save_dir` after every stage.

    Note: `device` default is evaluated once at import time; pass it
    explicitly to override.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Device: {device}")


    # NOTE(review): class counts here (style 128, content 1024, mood 64) must
    # match the dataset's label ranges below — keep the two in sync.
    model = ArtiGen(
        latent_ch=4, latent_h=latent_h, latent_w=latent_w,
        embed_dim=embed_dim, num_layers=num_layers,
        d_state=16, expand=2, text_dim=768,
        style_classes=128, content_objects=1024, mood_classes=64,
    ).to(device)


    total = sum(p.numel() for p in model.parameters()) / 1e6
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    print(f"Model total params: {total:.2f}M, trainable: {trainable:.2f}M")


    # EMA shadow: identical architecture, initialized from the live weights,
    # never trained directly (updated in train_one_epoch's EMA step).
    ema_model = ArtiGen(
        latent_ch=4, latent_h=latent_h, latent_w=latent_w,
        embed_dim=embed_dim, num_layers=num_layers,
        d_state=16, expand=2, text_dim=768,
        style_classes=128, content_objects=1024, mood_classes=64,
    ).to(device)
    ema_model.load_state_dict(model.state_dict())
    ema_model.requires_grad_(False)
    ema_model.eval()


    # Synthetic data: label ranges mirror the model's class counts above.
    dataset = DummyLatentDataset(
        num_samples=2048, latent_h=latent_h, latent_w=latent_w,
        num_style_classes=128, num_content_classes=1024, num_mood_classes=64,
    )
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)


    for stage in range(1, 6):
        print(f"\n{'='*40}\n STAGE {stage}\n{'='*40}")
        # Re-apply the freeze mask and rebuild optimizer/scheduler so each
        # stage starts with a fresh LR schedule over its own epoch budget.
        apply_curriculum_freeze(model, stage)
        optimizer = build_optimizer(model, lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs_per_stage, eta_min=lr * 0.1)


        for epoch in range(1, num_epochs_per_stage + 1):
            metrics = train_one_epoch(model, dataloader, optimizer, device, stage=stage, ema_model=ema_model)
            scheduler.step()
            print(f" Stage {stage} Epoch {epoch}/{num_epochs_per_stage} | loss={metrics['loss']:.4f} flow={metrics['flow']:.4f} smooth={metrics['smooth']:.4f}")


        # One checkpoint per stage: live weights, EMA weights, optimizer state.
        ckpt_path = os.path.join(save_dir, f"artigen_stage{stage}.pt")
        torch.save({'stage': stage, 'model': model.state_dict(), 'ema': ema_model.state_dict(), 'optimizer': optimizer.state_dict()}, ckpt_path)
        print(f" Saved checkpoint to {ckpt_path}")


    print("\nTraining complete!")
|
|
|
|
if __name__ == '__main__':
    import argparse

    # CLI entry point: every flag maps 1:1 onto a run_training keyword.
    cli = argparse.ArgumentParser()
    cli.add_argument('--epochs', type=int, default=3)
    cli.add_argument('--bs', type=int, default=4)
    cli.add_argument('--lr', type=float, default=2e-4)
    cli.add_argument('--dim', type=int, default=256)
    cli.add_argument('--layers', type=int, default=16)
    cli.add_argument('--latent_h', type=int, default=32)
    cli.add_argument('--latent_w', type=int, default=32)
    cli.add_argument('--device', type=str, default='cpu')
    cli.add_argument('--save_dir', type=str, default='./checkpoints')
    opts = cli.parse_args()

    run_training(
        num_epochs_per_stage=opts.epochs,
        batch_size=opts.bs,
        lr=opts.lr,
        device=opts.device,
        save_dir=opts.save_dir,
        embed_dim=opts.dim,
        num_layers=opts.layers,
        latent_h=opts.latent_h,
        latent_w=opts.latent_w,
    )
|
|