"""
Minimal CPU training demo that verifies ArtiGen trains end-to-end.

Run: python demo_train_cpu.py
"""
| import sys, os |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| import torch |
| from model import ArtiGen |
| from train import train_one_epoch, build_optimizer, apply_curriculum_freeze |
| from train import DummyLatentDataset |
| from torch.utils.data import DataLoader |
|
|
def demo(device='cpu', seed=0):
    """Run a tiny two-stage curriculum on dummy latents to smoke-test training.

    Builds a small ArtiGen model plus an EMA copy with identical
    architecture, then trains on a ``DummyLatentDataset``. Training runs two
    curriculum stages (1 and 2) of two epochs each and prints the metrics
    returned by ``train_one_epoch`` after every epoch.

    Args:
        device: torch device string to train on. Defaults to ``'cpu'``.
        seed: value passed to ``torch.manual_seed`` before any model is built,
            so repeated runs are comparable. Pass ``None`` to leave torch's
            global RNG unseeded (the previous behaviour).

    Raises:
        RuntimeError: if an epoch reports a non-finite total loss. The demo
            is meant to verify that training works, so it must not print a
            success message for a diverged run.
    """
    import math

    if seed is not None:
        # Seeds weight initialisation. DataLoader shuffling also draws from the
        # global torch RNG when no generator is given.
        torch.manual_seed(seed)

    # Single source of truth for sizes shared by the model, the EMA copy and
    # the dataset, so the three can never disagree.
    latent_h, latent_w = 8, 8
    n_style, n_content, n_mood = 8, 8, 4
    model_cfg = dict(
        embed_dim=64, num_layers=4,
        latent_h=latent_h, latent_w=latent_w,
        style_classes=n_style, content_objects=n_content, mood_classes=n_mood,
    )

    model = ArtiGen(**model_cfg).to(device)

    # EMA shadow model: same architecture, initialised from the live weights,
    # excluded from autograd and kept in eval mode. It is handed to
    # train_one_epoch together with ema_decay, which presumably updates it.
    ema = ArtiGen(**model_cfg).to(device)
    ema.load_state_dict(model.state_dict())
    ema.requires_grad_(False)
    ema.eval()

    ds = DummyLatentDataset(
        num_samples=64, latent_h=latent_h, latent_w=latent_w,
        num_style_classes=n_style, num_content_classes=n_content,
        num_mood_classes=n_mood,
    )
    dl = DataLoader(ds, batch_size=2, shuffle=True)

    for stage in range(1, 3):
        apply_curriculum_freeze(model, stage)
        # The optimizer is rebuilt after each freeze, presumably so it only
        # tracks the parameters that are trainable in this stage.
        opt = build_optimizer(model, lr=1e-3)
        print(f"\n=== Stage {stage} ===")
        for epoch in range(1, 3):
            metrics = train_one_epoch(
                model, dl, opt, device, stage=stage,
                ema_model=ema, ema_decay=0.995,
            )
            print(f" Epoch {epoch} | loss={metrics['loss']:.4f} flow={metrics['flow']:.4f} smooth={metrics['smooth']:.4f}")
            # float() accepts both Python numbers and 0-d tensors.
            if not math.isfinite(float(metrics['loss'])):
                raise RuntimeError(
                    f"Non-finite loss at stage {stage}, epoch {epoch}: "
                    f"{metrics['loss']}"
                )
    print("\nDemo training complete — ArtiGen works!")
|
|
if __name__ == '__main__':
    # Run the demo only when executed as a script, never on import.
    demo()
|
|