| """ |
| ArtFlow v2 Training Utilities |
| ============================== |
| Real Mamba SSM training with: |
| - Real dataset support (WikiArt, Teyvat, Pokemon, Danbooru tags) |
| - Pseudo-Huber + Min-SNR-γ + Art-Aware Frequency loss |
| - Stable training with spike detection and EMA |
| - Multi-stage freeze/unfreeze pipeline |
| - Push-to-Hub support for HF Jobs |
| |
| Uses only modern, non-deprecated PyTorch APIs. |
| """ |
|
|
| import os |
| import math |
| import json |
| import time |
| from dataclasses import dataclass, asdict |
| from typing import Tuple, Optional, List |
| from collections import deque |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
|
|
| from artflow_model import ( |
| ArtFlow, ArtFlowConfig, HaarWavelet2D, logit_normal_timestep |
| ) |
|
|
|
|
class ArtFlowLoss(nn.Module):
    """Velocity-prediction loss with robust, SNR and wavelet-band weighting.

    The element-wise error ``v_pred - v_target`` goes through a Pseudo-Huber
    penalty (or a squared error), is optionally scaled per sample by a Min-SNR
    weight, and is then reduced to a scalar.  When both spatial sizes of the
    loss map are even, the map is passed through ``HaarWavelet2D`` and the
    means of the four returned bands are combined with the ``w_*`` weights;
    otherwise the plain mean is used.

    The module also keeps an exponential moving average of the scalar loss in
    ``loss_ema`` and reports whether the current value looks like a spike.

    Args:
        huber_c: Pseudo-Huber constant ``c`` in ``sqrt(x^2 + c^2) - c``.
        min_snr_gamma: Upper clamp applied to the SNR in the Min-SNR weight.
        use_pseudo_huber: Use the Pseudo-Huber penalty instead of ``x^2``.
        use_min_snr: Multiply the element-wise loss by the Min-SNR weight.
        w_LL, w_LH, w_HL, w_HH: Weights of the four wavelet band means.
        max_nonfinite_streak: Number of consecutive NaN/inf losses tolerated
            before ``forward`` raises ``FloatingPointError``.
    """

    def __init__(self, huber_c=0.00054, min_snr_gamma=5.0,
                 use_pseudo_huber=True, use_min_snr=True,
                 w_LL=1.0, w_LH=2.0, w_HL=2.0, w_HH=1.5,
                 max_nonfinite_streak=100):
        super().__init__()
        self.huber_c = huber_c
        self.min_snr_gamma = min_snr_gamma
        self.use_pseudo_huber = use_pseudo_huber
        self.use_min_snr = use_min_snr
        self.wavelet = HaarWavelet2D()
        self.freq_weights = {'LL': w_LL, 'LH': w_LH, 'HL': w_HL, 'HH': w_HH}
        self.loss_ema = None
        self.max_nonfinite_streak = max_nonfinite_streak
        self._nonfinite_streak = 0

    def pseudo_huber(self, x):
        """Return the element-wise Pseudo-Huber penalty ``sqrt(x^2 + c^2) - c``."""
        return (x.pow(2) + self.huber_c ** 2).sqrt() - self.huber_c

    def snr_weight(self, t):
        """Return the per-sample weight ``min(snr, gamma) / snr``.

        ``snr`` is computed as ``((1 - t) / t)^2``.  ``t`` is indexed as a 1-D
        tensor and the result gets three trailing singleton axes so that it
        broadcasts against a 4-D loss map.
        """
        snr = ((1 - t) / t.clamp(min=1e-6)).pow(2)
        # The second clamp keeps the division finite when snr is 0 (t == 1).
        w = torch.clamp(snr, max=self.min_snr_gamma) / snr.clamp(min=1e-6)
        return w[:, None, None, None]

    def forward(self, v_pred, v_target, t):
        """Compute the loss and a spike flag.

        Returns:
            ``(loss, spike)`` where ``loss`` is a scalar tensor and ``spike``
            is a ``bool``.  ``spike`` is True when the loss is non-finite or
            exceeds ten times ``max(loss_ema, 0.01)``.

        Raises:
            FloatingPointError: after more than ``max_nonfinite_streak``
                consecutive non-finite losses.
        """
        error = v_pred - v_target
        elem = self.pseudo_huber(error) if self.use_pseudo_huber else error.pow(2)
        if self.use_min_snr:
            elem = elem * self.snr_weight(t)
        if elem.shape[2] % 2 == 0 and elem.shape[3] % 2 == 0:
            LL, LH, HL, HH = self.wavelet(elem)
            loss = (self.freq_weights['LL'] * LL.mean() +
                    self.freq_weights['LH'] * LH.mean() +
                    self.freq_weights['HL'] * HL.mean() +
                    self.freq_weights['HH'] * HH.mean())
        else:
            loss = elem.mean()
        lv = loss.item()

        # A NaN/inf value must never reach the EMA: once the EMA is NaN every
        # later comparison against it is False and spike detection is
        # silently disabled for the rest of the run.
        if not math.isfinite(lv):
            self._nonfinite_streak += 1
            if self._nonfinite_streak > self.max_nonfinite_streak:
                # Fail loudly instead of letting a caller that skips spikes
                # loop forever on a model that has diverged.
                raise FloatingPointError(
                    f"loss was non-finite for {self._nonfinite_streak} consecutive steps")
            return loss, True
        self._nonfinite_streak = 0

        if self.loss_ema is None:
            self.loss_ema = lv
        else:
            self.loss_ema = 0.99 * self.loss_ema + 0.01 * lv
        return loss, lv > 10.0 * max(self.loss_ema, 0.01)
|
|
|
|
@dataclass
class TrainConfig:
    """Hyper-parameters and bookkeeping options for one training stage."""

    # --- Optimizer -------------------------------------------------------
    lr: float = 1e-4                          # peak learning rate, scaled by the schedule
    weight_decay: float = 0.01                # applied to the weight-decay parameter group
    betas: Tuple[float, float] = (0.9, 0.99)  # AdamW betas
    max_grad_norm: float = 1.0                # gradient-norm clipping threshold

    # --- Schedule and batching -------------------------------------------
    warmup_steps: int = 500                   # linear warm-up length, in optimizer steps
    batch_size: int = 2                       # micro-batch size given to the DataLoader
    grad_accum: int = 32                      # micro-batches accumulated per optimizer step
    num_steps: int = 50000                    # total optimizer steps for the stage
    min_lr_ratio: float = 0.05                # cosine-schedule floor as a fraction of lr

    # --- EMA -------------------------------------------------------------
    ema_decay: float = 0.9999
    ema_start_step: int = 1000                # optimizer step at which averaging begins

    # --- Logging and checkpoints -----------------------------------------
    log_every: int = 50                       # in optimizer steps
    save_every: int = 2500                    # in optimizer steps
    output_dir: str = './artflow_ckpts'

    # --- Stage and Hub ---------------------------------------------------
    stage: int = 1                            # used in log lines and the final checkpoint name
    push_to_hub: bool = False
    hub_model_id: str = ''
|
|
|
|
class SyntheticDataset(Dataset):
    """Deterministic Gaussian (latent, text-embedding) pairs.

    Item ``idx`` is drawn from a generator seeded with ``idx``, so the same
    index always yields the same pair.  Tensor sizes come from ``config``
    (a default ``ArtFlowConfig`` when none is given).
    """

    def __init__(self, n=10000, config=None):
        self.n = n
        self.cfg = config if config else ArtFlowConfig()

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        cfg = self.cfg
        rng = torch.Generator().manual_seed(idx)
        # Draw order matters: latent first, then text, from the same generator.
        latent = torch.randn(cfg.latent_channels, cfg.latent_size, cfg.latent_size,
                             generator=rng)
        text_emb = torch.randn(cfg.text_length, cfg.text_dim, generator=rng)
        return (latent, text_emb)
|
|
|
|
class RealArtDataset(Dataset):
    """Real illustration dataset from HF Hub (WikiArt, Teyvat, Pokemon, etc.)

    Each item is a ``(latent, text_emb)`` pair:

    * ``latent`` - the image resized to ``latent_size * 8`` pixels per side,
      normalised, and passed through a small frozen convolutional
      "pseudo encoder".  Rows without an image get Gaussian noise instead.
    * ``text_emb`` - Gaussian noise of shape ``(text_length, text_dim)``
      seeded from the caption text; the first ``min(word_count, text_length)``
      rows are scaled by 2.  Rows without text get unseeded noise.

    Args:
        dataset_name: Hub dataset id passed to ``datasets.load_dataset``.
        config: Object providing ``latent_size`` and ``latent_channels``
            (a default ``ArtFlowConfig`` when omitted).
        max_samples: Optional cap on the number of rows kept.
        split: Dataset split to load.
        text_dim, text_length: Shape of the generated text embedding.
    """

    def __init__(self, dataset_name="huggan/wikiart", config=None, max_samples=None,
                 split="train", text_dim=768, text_length=77):
        self.cfg = config or ArtFlowConfig()
        self.text_dim, self.text_length = text_dim, text_length
        self.latent_size = self.cfg.latent_size
        self.latent_channels = self.cfg.latent_channels

        print(f"Loading dataset: {dataset_name} ...")
        from datasets import load_dataset
        import torchvision.transforms as T

        # SECURITY NOTE(review): trust_remote_code=True lets the dataset repo
        # run arbitrary Python on this machine.  Only use dataset ids you trust.
        try:
            ds = load_dataset(dataset_name, split=split, trust_remote_code=True)
        except Exception as e:
            # Best-effort fallback: stream the rows and materialise them locally.
            print(f" Streaming: {e}")
            ds = load_dataset(dataset_name, split=split, streaming=True, trust_remote_code=True)
            items = []
            for i, item in enumerate(ds):
                if max_samples and i >= max_samples:
                    break
                items.append(item)
            from datasets import Dataset as HFD
            ds = HFD.from_list(items)

        if max_samples and len(ds) > max_samples:
            ds = ds.select(range(max_samples))
        self.ds = ds
        self.columns = ds.column_names
        # Use the first recognised column name of each kind, or None.
        self.image_col = next((c for c in ['image', 'img', 'pixel_values'] if c in self.columns), None)
        self.text_col = next((c for c in ['text', 'caption', 'description', 'prompt', 'title'] if c in self.columns), None)
        self.style_col = next((c for c in ['style', 'genre', 'artist'] if c in self.columns), None)

        target_px = self.latent_size * 8
        self.transform = T.Compose([T.Resize((target_px, target_px)), T.ToTensor(), T.Normalize([0.5], [0.5])])
        # NOTE(review): these conv weights are freshly initialised on every
        # construction and are not seeded here, so latents are presumably not
        # comparable between separate runs - confirm whether that matters.
        self.pseudo_encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=4), nn.SiLU(),
            nn.Conv2d(32, self.latent_channels, 4, stride=2, padding=1))
        for p in self.pseudo_encoder.parameters():
            p.requires_grad_(False)
        print(f" Loaded {len(self.ds)} samples | img={self.image_col} txt={self.text_col} style={self.style_col}")

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        """Return the ``(latent, text_emb)`` pair for row ``idx``."""
        item = self.ds[idx]
        if self.image_col and item.get(self.image_col) is not None:
            img = item[self.image_col]
            if hasattr(img, 'convert'):
                img = img.convert('RGB')
            with torch.no_grad():
                latent = self.pseudo_encoder(self.transform(img).unsqueeze(0)).squeeze(0)
            # Resample only if the encoder output is not already latent-sized.
            if latent.shape[1] != self.latent_size or latent.shape[2] != self.latent_size:
                latent = F.interpolate(latent.unsqueeze(0), size=(self.latent_size, self.latent_size),
                                       mode='bilinear', align_corners=False).squeeze(0)
        else:
            latent = torch.randn(self.latent_channels, self.latent_size, self.latent_size)

        if self.text_col and item.get(self.text_col):
            import zlib

            text = str(item[self.text_col])
            # Built-in hash() of a str is salted per process, so it would give
            # a different embedding for the same caption on every run.  CRC32
            # of the UTF-8 bytes is stable across processes and machines.
            seed = zlib.crc32(text.encode('utf-8', errors='replace')) % (2**31)
            g = torch.Generator().manual_seed(seed)
            text_emb = torch.randn(self.text_length, self.text_dim, generator=g) * 0.1
            text_emb[:min(len(text.split()), self.text_length)] *= 2.0
        else:
            text_emb = torch.randn(self.text_length, self.text_dim) * 0.1
        return latent, text_emb
|
|
|
|
def freeze_for_stage(model, stage):
    """Set ``requires_grad`` on ``model``'s parameters for a training stage.

    All parameters are first made trainable, then:

    * stage 1 freezes parameters whose name contains ``art_style``,
      ``mood_ctrl`` or ``concept_engine``;
    * stages 2 and 3 freeze ``mood_ctrl`` and ``concept_engine``;
    * stage 4 freezes everything except ``mood_ctrl`` and ``concept_engine``;
    * stage 5, and any other value, leaves everything trainable.

    Matching is by substring of the parameter name.  A one-line summary of the
    trainable parameter count is printed.

    Returns:
        The same ``model`` object, modified in place.
    """
    for p in model.parameters():
        p.requires_grad_(True)
    freeze_keys = {1: ['art_style', 'mood_ctrl', 'concept_engine'], 2: ['mood_ctrl', 'concept_engine'],
                   3: ['mood_ctrl', 'concept_engine'], 4: [], 5: []}
    if stage == 4:
        # Inverted rule: only the mood / concept modules stay trainable.
        for n, p in model.named_parameters():
            if not any(k in n for k in ['mood_ctrl', 'concept_engine']):
                p.requires_grad_(False)
    else:
        frozen = freeze_keys.get(stage, [])
        for n, p in model.named_parameters():
            if any(k in n for k in frozen):
                p.requires_grad_(False)
    tr = sum(p.numel() for p in model.parameters() if p.requires_grad)
    tot = sum(p.numel() for p in model.parameters())
    # A module without parameters would otherwise divide by zero here.
    pct = 100 * tr / tot if tot else 0.0
    print(f"Stage {stage}: {tr:,}/{tot:,} trainable ({pct:.1f}%)")
    return model
|
|
|
|
class TrainingEngine:
    """Owns the optimizer, AMP scaler, EMA copy and loss for one model.

    ``micro_step`` runs forward/backward on one micro-batch; ``optim_step``
    clips gradients, applies the optimizer and updates the EMA weights.
    ``save`` / ``load`` round-trip the model, EMA, optimizer, scaler, step
    counter and recent loss history.

    Args:
        model: The network being trained.
        model_cfg: Config used to build a second ``ArtFlow`` for the EMA copy.
        train_cfg: Training options (see the attributes read below).
        device: ``torch.device`` used for the EMA copy, AMP and timestep sampling.
    """

    def __init__(self, model, model_cfg, train_cfg, device):
        self.model, self.mcfg, self.tcfg, self.device = model, model_cfg, train_cfg, device
        # The EMA copy starts as an exact, frozen clone of the model.
        self.ema = ArtFlow(model_cfg).to(device)
        self.ema.load_state_dict(model.state_dict())
        self.ema.eval()
        for p in self.ema.parameters():
            p.requires_grad_(False)
        # Parameters whose name contains 'norm' or 'bias' get no weight decay.
        decay, no_decay = [], []
        for n, p in model.named_parameters():
            if not p.requires_grad:
                continue
            (no_decay if ('norm' in n or 'bias' in n) else decay).append(p)
        self.optimizer = torch.optim.AdamW([
            {'params': decay, 'weight_decay': train_cfg.weight_decay},
            {'params': no_decay, 'weight_decay': 0.0}
        ], lr=train_cfg.lr, betas=train_cfg.betas)
        self.use_amp = (device.type == 'cuda')
        self.scaler = torch.amp.GradScaler(device.type, enabled=self.use_amp)
        self.loss_fn = ArtFlowLoss()
        self.global_step = 0
        self.losses, self.grad_norms = [], []

    def _lr_scale(self):
        """Return the LR multiplier: linear warm-up, then cosine decay to ``min_lr_ratio``."""
        s, w, total = self.global_step, self.tcfg.warmup_steps, self.tcfg.num_steps
        if s < w:
            return s / max(w, 1)
        progress = (s - w) / max(total - w, 1)
        floor = self.tcfg.min_lr_ratio
        return floor + 0.5 * (1 - floor) * (1 + math.cos(math.pi * progress))

    def _set_lr(self):
        """Write the scheduled learning rate into every param group and return it."""
        lr = self.tcfg.lr * self._lr_scale()
        for pg in self.optimizer.param_groups:
            pg['lr'] = lr
        return lr

    @torch.no_grad()
    def _update_ema(self):
        """Update the EMA parameters from the live model.

        Before ``ema_start_step`` the EMA simply mirrors the live weights, so
        that averaging later starts from the current model rather than from
        the initial weights.  From ``ema_start_step`` on, the usual
        ``ema = d * ema + (1 - d) * param`` update is applied.
        """
        mirroring = self.global_step < self.tcfg.ema_start_step
        d = self.tcfg.ema_decay
        for ep, p in zip(self.ema.parameters(), self.model.parameters()):
            if mirroring:
                ep.data.copy_(p.data)
            else:
                ep.data.mul_(d).add_(p.data, alpha=1 - d)

    def micro_step(self, x_0, text_emb):
        """Run forward and backward on one micro-batch.

        Returns:
            The unscaled loss value, or ``None`` when the loss function flags
            the batch as a spike (no backward pass is run in that case).
        """
        B = x_0.shape[0]
        t = logit_normal_timestep(B, self.device)
        eps = torch.randn_like(x_0)
        te = t[:, None, None, None]
        with torch.amp.autocast(self.device.type, dtype=torch.float16, enabled=self.use_amp):
            # The input is the interpolation (1 - t) * x_0 + t * eps.
            v_pred = self.model((1 - te) * x_0 + te * eps, t, text_emb)
        # The loss is computed in float32 against the target eps - x_0.
        loss, spike = self.loss_fn(v_pred.float(), (eps - x_0).float(), t)
        loss = loss / self.tcfg.grad_accum
        if spike:
            return None
        self.scaler.scale(loss).backward()
        return loss.item() * self.tcfg.grad_accum

    def optim_step(self):
        """Clip gradients, step the optimizer, update the EMA; return the grad norm."""
        self.scaler.unscale_(self.optimizer)
        trainable = [p for p in self.model.parameters() if p.requires_grad]
        gn = torch.nn.utils.clip_grad_norm_(trainable, self.tcfg.max_grad_norm).item()
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self._update_ema()
        self.global_step += 1
        return gn

    def save(self, path=None):
        """Write a checkpoint to ``path`` (default ``<output_dir>/ckpt_<step>.pt``)."""
        path = path or os.path.join(self.tcfg.output_dir, f'ckpt_{self.global_step}.pt')
        parent = os.path.dirname(path)
        # A bare filename has an empty dirname, and os.makedirs('') raises.
        if parent:
            os.makedirs(parent, exist_ok=True)
        torch.save({'model': self.model.state_dict(), 'ema': self.ema.state_dict(),
                    'optimizer': self.optimizer.state_dict(), 'scaler': self.scaler.state_dict(),
                    'step': self.global_step, 'losses': self.losses[-2000:],
                    'model_config': asdict(self.mcfg), 'train_config': asdict(self.tcfg)}, path)
        # NOTE(review): the glyph in this message looks mis-encoded; restore
        # it from the original source.
        print(f" ๐พ Saved: {path}")

    def load(self, path):
        """Restore model, EMA, optimizer, scaler, step and loss history from ``path``."""
        # SECURITY NOTE(review): weights_only=False unpickles arbitrary
        # objects; only load checkpoints from a trusted source.
        ckpt = torch.load(path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(ckpt['model'])
        self.ema.load_state_dict(ckpt['ema'])
        self.optimizer.load_state_dict(ckpt['optimizer'])
        self.scaler.load_state_dict(ckpt['scaler'])
        self.global_step = ckpt['step']
        self.losses = ckpt.get('losses', [])
        # NOTE(review): the glyph in this message looks mis-encoded; restore
        # it from the original source.
        print(f" ๐ Resumed from step {self.global_step}")
|
|
|
|
def train(model, model_cfg, train_cfg, dataset, device, resume_path=None):
    """Train ``model`` on ``dataset`` until ``train_cfg.num_steps`` optimizer steps.

    Builds a ``TrainingEngine``, optionally resumes from ``resume_path`` (when
    the file exists), and loops over a shuffled ``DataLoader``.  Every
    ``grad_accum`` accepted micro-batches trigger one optimizer step; progress
    is logged every ``log_every`` steps and checkpoints are written every
    ``save_every`` steps.  A final checkpoint is always saved and, when
    ``push_to_hub`` and ``hub_model_id`` are set, uploaded on a best-effort basis.

    Returns:
        The ``TrainingEngine`` used for the run.

    Raises:
        ValueError: if steps remain but the loader yields no batches (the
            dataset is smaller than ``batch_size`` while ``drop_last=True``).
    """
    # NOTE(review): several glyphs in the log strings below look mis-encoded
    # (UTF-8 decoded with the wrong codec); restore them from the original source.
    engine = TrainingEngine(model, model_cfg, train_cfg, device)
    if resume_path and os.path.exists(resume_path):
        engine.load(resume_path)
    loader = DataLoader(dataset, batch_size=train_cfg.batch_size, shuffle=True,
                        num_workers=0, drop_last=True, pin_memory=(device.type == 'cuda'))
    # With no batches the outer while-loop below would spin forever.
    if engine.global_step < train_cfg.num_steps and len(loader) == 0:
        raise ValueError(
            f"dataset yields no full batch of size {train_cfg.batch_size}; "
            "nothing to train on")
    print(f"\n{'='*60}\nStage {train_cfg.stage} โ {engine.global_step} โ {train_cfg.num_steps} steps")
    print(f"Effective batch: {train_cfg.batch_size} ร {train_cfg.grad_accum} = {train_cfg.batch_size*train_cfg.grad_accum}\n{'='*60}\n")
    model.train()
    start = time.time()
    # Throughput must count only the steps done in this run; after a resume
    # global_step already includes steps from earlier runs.
    start_step = engine.global_step
    acc_loss, acc_n = 0.0, 0
    while engine.global_step < train_cfg.num_steps:
        for x_0, txt in loader:
            if engine.global_step >= train_cfg.num_steps:
                break
            x_0, txt = x_0.to(device), txt.to(device)
            engine._set_lr()
            lv = engine.micro_step(x_0, txt)
            # None means the micro-batch was skipped as a loss spike.
            if lv is not None:
                acc_loss += lv
                acc_n += 1
            if acc_n >= train_cfg.grad_accum:
                gn = engine.optim_step()
                engine.losses.append(acc_loss / acc_n)
                engine.grad_norms.append(gn)
                acc_loss, acc_n = 0.0, 0
                if engine.global_step % train_cfg.log_every == 0:
                    el = time.time() - start
                    sps = (engine.global_step - start_step) / max(el, 1)
                    rec = engine.losses[-50:]
                    print(f"Step {engine.global_step:>6d}/{train_cfg.num_steps} | Loss: {sum(rec)/len(rec):.4f} | "
                          f"GN: {gn:.3f} | LR: {engine.optimizer.param_groups[0]['lr']:.2e} | "
                          f"ETA: {(train_cfg.num_steps-engine.global_step)/max(sps,1e-6)/60:.0f}m")
                if engine.global_step % train_cfg.save_every == 0:
                    engine.save()
    final_path = os.path.join(train_cfg.output_dir, f'stage{train_cfg.stage}_final.pt')
    engine.save(final_path)
    if train_cfg.push_to_hub and train_cfg.hub_model_id:
        # The upload is best-effort: a failure is reported, not raised.
        try:
            from huggingface_hub import HfApi
            HfApi().upload_file(path_or_fileobj=final_path, path_in_repo=f'stage{train_cfg.stage}_final.pt',
                                repo_id=train_cfg.hub_model_id)
            print(f" ๐ค Pushed to {train_cfg.hub_model_id}")
        except Exception as e:
            print(f" โ ๏ธ Push failed: {e}")
    print(f"\nโ Stage {train_cfg.stage} done โ {(time.time()-start)/3600:.1f}h")
    return engine
|
|
|
|
if __name__ == '__main__':
    # Smoke test: a tiny model trained for 30 optimizer steps on synthetic
    # data, then checked for NaN parameters.
    # NOTE(review): the status glyphs in the last two messages look
    # mis-encoded; restore them from the original source.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Device: {device}")

    mcfg = ArtFlowConfig(
        latent_channels=4, latent_size=16, stage_channels=(64, 128, 192),
        blocks_per_stage=(1, 1, 1), bottleneck_blocks=2, mamba_state_dim=8,
        num_styles=16, style_dim=128, mood_dim=64, num_moods=8,
        text_dim=256, text_length=16, num_heads=4, concept_dim=64,
        kan_grid_size=3,
    )
    model = ArtFlow(mcfg).to(device)
    model = freeze_for_stage(model, 1)
    param_count = sum(p.numel() for p in model.parameters())
    print(f"Model: {param_count:,} params")

    smoke_cfg = TrainConfig(num_steps=30, log_every=10, save_every=100,
                            batch_size=2, grad_accum=2, warmup_steps=5)
    smoke_data = SyntheticDataset(n=200, config=mcfg)
    engine = train(model, mcfg, smoke_cfg, smoke_data, device)

    has_nan = any(torch.isnan(p).any() for p in model.parameters())
    nan_status = 'FAIL' if has_nan else 'OK'
    print(f"Steps: {engine.global_step} | NaN: {nan_status}")
    if not has_nan and engine.global_step >= 30:
        print("โ All good")
    else:
        print("โ Issues")
|
|