# artigen / train.py
# krystv's picture
# Upload train.py
# eb8440b verified
"""
ArtiGen Training Script — Flow Matching + Modular Curriculum + Spectral Smoothness.
Optimized for Colab Free Tier / small GPU.
"""
import os
import math
import random
from pathlib import Path
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from PIL import Image
import numpy as np
try:
from artigen.model import ArtiGen
except ImportError:
from model import ArtiGen
def sample_timesteps(batch_size, device, min_t=0.0, max_t=1.0):
    """Draw one random timestep per batch element.

    Args:
        batch_size: Number of timesteps to draw.
        device: Device on which the result tensor is allocated.
        min_t: Value that a ``torch.rand`` draw of 0 maps to.
        max_t: Upper end of the interval the unit draws are scaled towards.

    Returns:
        A 1-D tensor of length ``batch_size`` obtained by affinely mapping
        uniform ``torch.rand`` samples onto the ``min_t``..``max_t`` range.
    """
    span = max_t - min_t
    unit_draws = torch.rand(batch_size, device=device)
    return unit_draws * span + min_t
def rectified_flow_step(z0, z1, t):
    """Build the straight-line interpolant between ``z0`` and ``z1``.

    Computes ``z_t = (1 - t) * z0 + t * z1`` together with the constant
    velocity ``z1 - z0`` of that path.

    Args:
        z0: Tensor whose first axis is the batch axis; any number of
            trailing axes is supported.
        z1: Tensor broadcast-compatible with ``z0`` (same shape in the
            training loop of this file).
        t: Tensor holding exactly one interpolation weight per batch element.

    Returns:
        Tuple ``(z_t, v_target)``.
    """
    B = z0.shape[0]
    # One singleton axis per non-batch axis of z0, so that t scales each
    # sample as a whole regardless of the latent rank (the previous
    # hard-coded (B, 1, 1, 1) silently mis-broadcast for non-4-D inputs).
    t_broadcast = t.reshape(B, *([1] * (z0.dim() - 1)))
    z_t = (1.0 - t_broadcast) * z0 + t_broadcast * z1
    v_target = z1 - z0
    return z_t, v_target
def spectral_smoothness_loss(v_pred, z_t):
    """Penalise second differences of ``v_pred`` along its last two axes.

    For each of axes 2 and 3 the discrete second difference
    ``x[i + 2] - 2 * x[i + 1] + x[i]`` is taken, reflect-padded by one
    element on both sides of that axis, and reduced to its mean absolute
    value. The two means are summed and scaled by the fixed factor 0.01.

    Axes too short for the original computation are handled explicitly
    instead of raising inside ``F.pad``:

    * length >= 4: unchanged behaviour (pad, then mean);
    * length == 3: a single second difference exists, which cannot be
      reflect-padded, so its mean absolute value is used directly;
    * length < 3: no second difference exists; the axis contributes 0.

    Args:
        v_pred: 4-D tensor (indexed with four subscripts here).
        z_t: Unused; kept so existing call sites remain valid.

    Returns:
        Scalar tensor holding the scaled penalty.
    """

    def _axis_penalty(dim, pad):
        size = v_pred.shape[dim]
        if size < 3:
            return v_pred.new_zeros(())
        inner = size - 2
        second_diff = (
            v_pred.narrow(dim, 2, inner)
            - 2 * v_pred.narrow(dim, 1, inner)
            + v_pred.narrow(dim, 0, inner)
        )
        # Reflect padding by 1 requires at least 2 entries along the axis.
        if inner >= 2:
            second_diff = F.pad(second_diff, pad, mode='reflect')
        return second_diff.abs().mean()

    penalty_h = _axis_penalty(2, (0, 0, 1, 1))
    penalty_w = _axis_penalty(3, (1, 1, 0, 0))
    return (penalty_h + penalty_w) * 0.01
class DummyLatentDataset(Dataset):
    """Synthetic dataset that fabricates a random sample on every access.

    Each item is a tuple ``(z0, text_emb, style_label, content_label,
    mood_label)``: a Gaussian latent, a Gaussian text embedding and three
    integer class labels. The requested index is ignored, so repeated reads
    of the same index return different samples.
    """

    def __init__(
        self,
        num_samples=1024,
        latent_ch=4,
        latent_h=32,
        latent_w=32,
        text_dim=768,
        num_style_classes=128,
        num_content_classes=512,
        num_mood_classes=64,
    ):
        """Store the reported length, tensor shapes and label ranges."""
        self.num_samples = num_samples
        self.text_dim = text_dim
        self.latent_shape = (latent_ch, latent_h, latent_w)
        self.num_style_classes = num_style_classes
        self.num_content_classes = num_content_classes
        self.num_mood_classes = num_mood_classes

    def __len__(self):
        """Return the configured number of samples."""
        return self.num_samples

    def __getitem__(self, idx):
        """Return a freshly drawn random sample; ``idx`` is not used."""
        z0 = torch.randn(self.latent_shape)
        text_emb = torch.randn(self.text_dim)
        # Labels come from Python's ``random`` module, drawn in the order
        # style, content, mood, each uniform over [0, class_count - 1].
        class_counts = (
            self.num_style_classes,
            self.num_content_classes,
            self.num_mood_classes,
        )
        style_label, content_label, mood_label = (
            torch.tensor(random.randint(0, count - 1), dtype=torch.long)
            for count in class_counts
        )
        return z0, text_emb, style_label, content_label, mood_label
def train_one_epoch(
    model,
    dataloader,
    optimizer,
    device,
    stage: int = 1,
    lambda_flow: float = 1.0,
    lambda_smooth: float = 0.05,
    lambda_style: float = 0.1,
    lambda_content: float = 0.1,
    lambda_mood: float = 0.1,
    p_uncond: float = 0.1,
    grad_clip: float = 1.0,
    ema_model=None,
    ema_decay: float = 0.9999,
):
    """Run one optimisation pass over ``dataloader`` and return mean losses.

    For every batch the function zeroes the text embedding of a random
    subset of samples, draws a second endpoint and timesteps, builds the
    interpolant and target velocity with ``rectified_flow_step``, and
    minimises a weighted sum of: MSE between predicted and target velocity,
    ``spectral_smoothness_loss`` on the prediction, and stage-gated
    cross-entropy terms on the auxiliary logits returned by the model.

    Args:
        model: Module called as ``model(z_t, t, text_emb, return_asdl=True)``
            and expected to return ``(v_pred, asdl)``. ``asdl`` may be
            ``None``; otherwise it is indexed with ``'style_logits'``,
            ``'content_logits'`` and ``'mood_logits'`` as needed.
        dataloader: Iterable yielding 5-tuples
            ``(z0, text_emb, style_label, content_label, mood_label)``.
        optimizer: Optimizer stepped once per batch.
        device: Device every batch tensor is moved to.
        stage: Curriculum stage. Style loss is used for ``stage >= 1``,
            content loss for ``stage >= 2``, mood loss for ``stage >= 4``.
        lambda_flow: Weight of the velocity MSE term.
        lambda_smooth: Weight of the smoothness term.
        lambda_style: Weight of the style cross-entropy (skipped if <= 0).
        lambda_content: Weight of the content cross-entropy (skipped if <= 0).
        lambda_mood: Weight of the mood cross-entropy (skipped if <= 0).
        p_uncond: Per-sample probability of zeroing the text embedding.
        grad_clip: Max gradient norm passed to ``clip_grad_norm_``.
        ema_model: Optional model whose parameters track an exponential
            moving average of ``model``'s parameters after every step.
        ema_decay: Decay factor of that moving average.

    Returns:
        Dict with keys ``'loss'``, ``'flow'`` and ``'smooth'`` holding the
        per-batch means (0.0 for each when the dataloader yields nothing).
    """
    model.train()
    total_loss = 0.0
    total_flow = 0.0
    total_smooth = 0.0
    num_batches = 0

    for z0, text_emb, style_label, content_label, mood_label in dataloader:
        z0 = z0.to(device)
        text_emb = text_emb.to(device)
        style_label = style_label.to(device)
        content_label = content_label.to(device)
        mood_label = mood_label.to(device)
        B = z0.shape[0]

        # Conditioning dropout. Done out of place: when the batch already
        # lives on ``device``, ``.to`` returns the very tensor the loader
        # produced, and an in-place write would alter the loader's data.
        mask_uncond = torch.rand(B, device=device) < p_uncond
        drop = mask_uncond.view(B, *([1] * (text_emb.dim() - 1)))
        text_emb = text_emb.masked_fill(drop, 0.0)

        z1 = torch.randn_like(z0)
        t = sample_timesteps(B, device)
        z_t, v_target = rectified_flow_step(z0, z1, t)

        v_pred, asdl = model(z_t, t, text_emb, return_asdl=True)

        loss_flow = F.mse_loss(v_pred, v_target)
        loss_smooth = spectral_smoothness_loss(v_pred, z_t)
        loss = lambda_flow * loss_flow + lambda_smooth * loss_smooth

        # Auxiliary classification terms, enabled progressively by stage.
        if stage >= 1 and asdl is not None:
            if lambda_style > 0:
                loss_style = F.cross_entropy(asdl['style_logits'], style_label)
                loss = loss + lambda_style * loss_style
            if stage >= 2 and lambda_content > 0:
                # Content logits carry an extra axis 1 that is averaged away
                # before the cross-entropy.
                c_logits_avg = asdl['content_logits'].mean(dim=1)
                loss_content = F.cross_entropy(c_logits_avg, content_label)
                loss = loss + lambda_content * loss_content
            if stage >= 4 and lambda_mood > 0:
                loss_mood = F.cross_entropy(asdl['mood_logits'], mood_label)
                loss = loss + lambda_mood * loss_mood

        # Reset gradients to None (not zero) so that parameters which
        # receive no gradient in this step are skipped by the optimizer
        # rather than updated with a stale or all-zero gradient.
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        if ema_model is not None:
            with torch.no_grad():
                for p_ema, p in zip(ema_model.parameters(), model.parameters()):
                    p_ema.mul_(ema_decay).add_(p.detach(), alpha=1 - ema_decay)

        total_loss += loss.item()
        total_flow += loss_flow.item()
        total_smooth += loss_smooth.item()
        num_batches += 1

    denom = max(num_batches, 1)
    return {
        'loss': total_loss / denom,
        'flow': total_flow / denom,
        'smooth': total_smooth / denom,
    }
def build_optimizer(model_or_params, lr=2e-4, weight_decay=0.01):
    """Create an AdamW optimizer with betas ``(0.9, 0.999)``.

    Args:
        model_or_params: Either an object exposing ``parameters()`` (its
            result is optimised) or something AdamW accepts directly, such
            as an iterable of parameters.
        lr: Learning rate.
        weight_decay: Weight-decay coefficient.

    Returns:
        The constructed ``torch.optim.AdamW`` instance. No filtering on
        ``requires_grad`` is performed here.
    """
    if hasattr(model_or_params, 'parameters'):
        targets = model_or_params.parameters()
    else:
        targets = model_or_params
    return torch.optim.AdamW(
        targets,
        lr=lr,
        weight_decay=weight_decay,
        betas=(0.9, 0.999),
    )
def apply_curriculum_freeze(model, stage: int):
    """Set ``requires_grad`` on ``model``'s parameters for a curriculum stage.

    Everything is frozen first. The sub-modules ``patch_embed``, ``t_embed``,
    ``cond_proj``, ``cond_transform``, ``blocks``, ``adalns``,
    ``skip_connect`` and ``final_proj`` are then always unfrozen, followed by
    the stage-specific heads:

    * stage 1: ``style_head``
    * stage 2: ``content_head``
    * stage 3: ``concept_head``
    * stage 4: ``mood_head`` and ``comp_head``
    * stage >= 5: every parameter of the model

    Any other stage value leaves all heads frozen. A one-line summary with
    the number of frozen and trainable parameter tensors is printed.
    """

    def _set_trainable(module, flag):
        for param in module.parameters():
            param.requires_grad = flag

    _set_trainable(model, False)

    always_on = (
        'patch_embed',
        't_embed',
        'cond_proj',
        'cond_transform',
        'blocks',
        'adalns',
        'skip_connect',
        'final_proj',
    )
    for attr in always_on:
        _set_trainable(getattr(model, attr), True)

    heads_by_stage = {
        1: ('style_head',),
        2: ('content_head',),
        3: ('concept_head',),
        4: ('mood_head', 'comp_head'),
    }
    if stage in heads_by_stage:
        for attr in heads_by_stage[stage]:
            _set_trainable(getattr(model, attr), True)
    elif stage >= 5:
        _set_trainable(model, True)

    # Counts are per parameter tensor, not per scalar element.
    flags = [param.requires_grad for param in model.parameters()]
    trainable = sum(flags)
    frozen = len(flags) - trainable
    print(f"[Curriculum] Stage {stage}: frozen {frozen} params, trainable {trainable} params")
def run_training(
    num_epochs_per_stage=5,
    batch_size=4,
    lr=2e-4,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='./checkpoints',
    embed_dim=256,
    num_layers=16,
    latent_h=32,
    latent_w=32,
):
    """Train an ArtiGen model through curriculum stages 1 to 5.

    Builds the model and an identically configured EMA copy, creates a
    ``DummyLatentDataset`` loader, and for each stage: applies the stage's
    freeze pattern, creates a fresh optimizer and cosine LR schedule, trains
    for ``num_epochs_per_stage`` epochs, and writes a checkpoint named
    ``artigen_stage{stage}.pt`` into ``save_dir``.

    Args:
        num_epochs_per_stage: Epochs run in every stage; also the cosine
            schedule's ``T_max``.
        batch_size: Loader batch size (incomplete final batches are dropped).
        lr: Initial learning rate; the schedule anneals towards ``0.1 * lr``.
        device: Device for both models and all training tensors.
        save_dir: Checkpoint directory, created if missing.
        embed_dim: Passed to ``ArtiGen`` as ``embed_dim``.
        num_layers: Passed to ``ArtiGen`` as ``num_layers``.
        latent_h: Latent height used by the model and the dataset.
        latent_w: Latent width used by the model and the dataset.
    """
    os.makedirs(save_dir, exist_ok=True)
    print(f"Device: {device}")

    # A single kwargs dict keeps the live model and its EMA twin in sync.
    arch = dict(
        latent_ch=4, latent_h=latent_h, latent_w=latent_w,
        embed_dim=embed_dim, num_layers=num_layers,
        d_state=16, expand=2, text_dim=768,
        style_classes=128, content_objects=1024, mood_classes=64,
    )
    model = ArtiGen(**arch).to(device)

    total = sum(p.numel() for p in model.parameters()) / 1e6
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6
    print(f"Model total params: {total:.2f}M, trainable: {trainable:.2f}M")

    # The EMA model starts as an exact copy and is never trained directly.
    ema_model = ArtiGen(**arch).to(device)
    ema_model.load_state_dict(model.state_dict())
    ema_model.requires_grad_(False)
    ema_model.eval()

    dataset = DummyLatentDataset(
        num_samples=2048,
        latent_h=latent_h,
        latent_w=latent_w,
        num_style_classes=128,
        num_content_classes=1024,
        num_mood_classes=64,
    )
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True,
    )

    for stage in (1, 2, 3, 4, 5):
        print(f"\n{'='*40}\n STAGE {stage}\n{'='*40}")
        apply_curriculum_freeze(model, stage)

        # Optimizer and schedule are rebuilt from scratch for every stage.
        optimizer = build_optimizer(model, lr=lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=num_epochs_per_stage,
            eta_min=lr * 0.1,
        )

        for epoch_index in range(num_epochs_per_stage):
            epoch = epoch_index + 1
            metrics = train_one_epoch(
                model, dataloader, optimizer, device,
                stage=stage, ema_model=ema_model,
            )
            scheduler.step()
            print(f" Stage {stage} Epoch {epoch}/{num_epochs_per_stage} | loss={metrics['loss']:.4f} flow={metrics['flow']:.4f} smooth={metrics['smooth']:.4f}")

        ckpt_path = os.path.join(save_dir, f"artigen_stage{stage}.pt")
        snapshot = {
            'stage': stage,
            'model': model.state_dict(),
            'ema': ema_model.state_dict(),
            'optimizer': optimizer.state_dict(),
        }
        torch.save(snapshot, ckpt_path)
        print(f" Saved checkpoint to {ckpt_path}")

    print("\nTraining complete!")
if __name__ == '__main__':
    import argparse

    # Use the same auto-detection as run_training's own default. The former
    # hard-coded 'cpu' default always overrode it, so a plain invocation
    # never used an available GPU.
    default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    parser = argparse.ArgumentParser()
    parser.add_argument('--epochs', type=int, default=3,
                        help='epochs per curriculum stage')
    parser.add_argument('--bs', type=int, default=4, help='batch size')
    parser.add_argument('--lr', type=float, default=2e-4,
                        help='initial learning rate')
    parser.add_argument('--dim', type=int, default=256,
                        help='model embedding dimension')
    parser.add_argument('--layers', type=int, default=16,
                        help='number of model layers')
    parser.add_argument('--latent_h', type=int, default=32,
                        help='latent height')
    parser.add_argument('--latent_w', type=int, default=32,
                        help='latent width')
    parser.add_argument('--device', type=str, default=default_device,
                        help='training device (default: cuda if available, else cpu)')
    parser.add_argument('--save_dir', type=str, default='./checkpoints',
                        help='directory for per-stage checkpoints')
    args = parser.parse_args()

    run_training(
        num_epochs_per_stage=args.epochs,
        batch_size=args.bs,
        lr=args.lr,
        device=args.device,
        save_dir=args.save_dir,
        embed_dim=args.dim,
        num_layers=args.layers,
        latent_h=args.latent_h,
        latent_w=args.latent_w,
    )