File size: 1,540 Bytes
41ef0df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
"""
Minimal CPU training demo to verify ArtiGen trains end-to-end.
Run: python demo_train_cpu.py
"""
import sys, os
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import torch
from model import ArtiGen
from train import train_one_epoch, build_optimizer, apply_curriculum_freeze
from train import DummyLatentDataset
from torch.utils.data import DataLoader

def demo():
    """Run a tiny two-stage curriculum training loop for ArtiGen on the CPU.

    Builds a small ArtiGen model and a second, architecturally identical
    instance used as the EMA copy. The EMA copy is loaded with the model's
    initial weights, has gradients disabled and is put in eval mode.
    Training uses ``DummyLatentDataset`` samples. For each of two stages the
    curriculum freeze is applied and a fresh optimizer is built, then two
    epochs are run. The metrics dict returned by ``train_one_epoch`` is
    printed after every epoch.
    """
    device = 'cpu'

    # One set of constructor kwargs shared by the model and its EMA twin so
    # their state dicts line up for load_state_dict below.
    arch = dict(
        embed_dim=64, num_layers=4,
        latent_h=8, latent_w=8,
        style_classes=8, content_objects=8, mood_classes=4,
    )
    model = ArtiGen(**arch).to(device)

    # Constructed as a separate instance and then overwritten with the
    # model's weights; frozen and kept in eval mode.
    ema = ArtiGen(**arch).to(device)
    ema.load_state_dict(model.state_dict())
    ema.requires_grad_(False)
    ema.eval()

    dataset = DummyLatentDataset(
        num_samples=64, latent_h=8, latent_w=8,
        num_style_classes=8, num_content_classes=8, num_mood_classes=4,
    )
    loader = DataLoader(dataset, batch_size=2, shuffle=True)

    num_stages, epochs_per_stage = 2, 2
    for stage in range(1, num_stages + 1):
        apply_curriculum_freeze(model, stage)
        # Optimizer is rebuilt after each freeze step.
        optimizer = build_optimizer(model, lr=1e-3)
        print(f"\n=== Stage {stage} ===")
        for epoch in range(1, epochs_per_stage + 1):
            stats = train_one_epoch(
                model, loader, optimizer, device,
                stage=stage, ema_model=ema, ema_decay=0.995,
            )
            print(f"  Epoch {epoch} | loss={stats['loss']:.4f} flow={stats['flow']:.4f} smooth={stats['smooth']:.4f}")
    print("\nDemo training complete — ArtiGen works!")

# Script entry point: run the demo only when executed directly, not on import.
if __name__ == "__main__":
    demo()