# artigen / demo_train_cpu.py
# krystv's picture
# Upload demo_train_cpu.py
# 41ef0df verified
"""
Minimal CPU training demo to verify ArtiGen trains end-to-end.
Run: python demo_train_cpu.py
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from model import ArtiGen
from train import train_one_epoch, build_optimizer, apply_curriculum_freeze
from train import DummyLatentDataset
from torch.utils.data import DataLoader
def demo():
    """Run a tiny ArtiGen training loop on CPU and print per-epoch metrics.

    Builds a small ArtiGen model plus a second instance initialised from
    the same weights with gradients disabled and set to eval mode, which is
    handed to ``train_one_epoch`` as ``ema_model``. Training runs over a
    ``DummyLatentDataset`` for curriculum stages 1 and 2, two epochs each.
    For every stage the curriculum freeze is applied first and a fresh
    optimizer is built afterwards.
    """
    device = 'cpu'
    latent_side = 8

    # Architecture settings shared by the trainable network and its EMA twin.
    arch = dict(
        embed_dim=64, num_layers=4,
        latent_h=latent_side, latent_w=latent_side,
        style_classes=8, content_objects=8, mood_classes=4,
    )

    net = ArtiGen(**arch).to(device)

    # The EMA twin is constructed independently, then synced to the
    # trainable weights; it never receives gradients.
    # NOTE(review): presumably train_one_epoch updates it using ema_decay -
    # confirm in train.py.
    ema_net = ArtiGen(**arch).to(device)
    ema_net.load_state_dict(net.state_dict())
    ema_net.requires_grad_(False)
    ema_net.eval()

    dataset = DummyLatentDataset(
        num_samples=64,
        latent_h=latent_side,
        latent_w=latent_side,
        num_style_classes=8,
        num_content_classes=8,
        num_mood_classes=4,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    for stage in (1, 2):
        # Freeze first, then build the optimizer for this stage.
        apply_curriculum_freeze(net, stage)
        optimizer = build_optimizer(net, lr=1e-3)
        print(f"\n=== Stage {stage} ===")

        for epoch in (1, 2):
            m = train_one_epoch(
                net,
                loader,
                optimizer,
                device,
                stage=stage,
                ema_model=ema_net,
                ema_decay=0.995,
            )
            print(f" Epoch {epoch} | loss={m['loss']:.4f} flow={m['flow']:.4f} smooth={m['smooth']:.4f}")

    print("\nDemo training complete — ArtiGen works!")
# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    demo()