File size: 6,818 Bytes
2c078fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#!/usr/bin/env python3
"""
LRF Extended Training — more epochs on CPU with cached latents.
Uses the same proven architecture from v3, just trains longer.
Pushes results to HF Hub.
"""
import math, os, sys, time, json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from einops import rearrange
import numpy as np

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Reuse the exact architecture from lrf_v3.py
sys.path.insert(0, '/app')
from lrf_v3 import LRF, FlowScheduler, get_taesd, get_cifar, precompute, save_grid, gen

def _snapshot_params(model):
    """Return ``{name: tensor}`` copies of all of ``model``'s parameters.

    The copies are detached before cloning so they carry no autograd history.
    They seed the EMA weights and back up the live training weights while the
    EMA weights are swapped in for sampling.
    """
    return {n: p.detach().clone() for n, p in model.named_parameters()}


def _load_params(model, params):
    """Overwrite ``model``'s parameters in place from a ``{name: tensor}`` dict.

    Runs under ``torch.no_grad()``. Autograd rejects in-place writes to leaf
    tensors that require grad.
    """
    with torch.no_grad():
        for n, p in model.named_parameters():
            p.copy_(params[n])


def _plot_losses(losses, n_prev, best_loss, path):
    """Save the per-epoch loss curve as a PNG at ``path`` (best-effort).

    Args:
        losses: full loss history (warm-start history followed by this run).
        n_prev: number of leading entries that came from the warm-start
            checkpoint. When non-zero, a dashed line marks where this run began.
        best_loss: value shown in the plot title.
        path: output file path.

    Any failure (e.g. matplotlib not installed) is reported and then ignored,
    so plotting can never abort an already-finished training run.
    """
    try:
        import matplotlib
        matplotlib.use('Agg')  # headless backend: no display required
        import matplotlib.pyplot as plt
        plt.figure(figsize=(10, 4))
        plt.plot(losses, 'b-', alpha=0.7)
        if n_prev:
            plt.axvline(x=n_prev, color='r', linestyle='--', alpha=0.5,
                        label='Extended training start')
            plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'LRF Training (best={best_loss:.4f})')
        plt.grid(True, alpha=0.3)
        plt.savefig(path, dpi=150, bbox_inches='tight')
        plt.close()
    except Exception as exc:
        print(f"  Loss plot skipped: {exc!r}", flush=True)


def _push_to_hub(out_dir, repo, script_path):
    """Upload run artifacts and the training script to a Hugging Face model repo.

    Uploads every ``.pt``/``.png`` file directly inside ``out_dir`` to
    ``gpu_trained/<name>``, provided it is smaller than 100 MB and its name
    does not contain ``'cache'`` (the latent caches are skipped). Then uploads
    ``script_path`` as ``train_gpu.py``.
    """
    from huggingface_hub import HfApi
    api = HfApi()
    for f in sorted(os.listdir(out_dir)):
        fp = os.path.join(out_dir, f)
        if (f.endswith(('.pt', '.png')) and 'cache' not in f
                and os.path.getsize(fp) < 100_000_000):
            api.upload_file(path_or_fileobj=fp, path_in_repo=f'gpu_trained/{f}',
                            repo_id=repo, repo_type='model')
            print(f"    Uploaded gpu_trained/{f}", flush=True)

    api.upload_file(path_or_fileobj=script_path, path_in_repo='train_gpu.py',
                    repo_id=repo, repo_type='model')
    print("    Uploaded train_gpu.py", flush=True)


def main():
    """Continue training the v3 LRF model on cached TAESD latents and publish results.

    Runs on ``DEVICE``, which is CUDA when available and CPU otherwise.

    Steps:
      1. Load TAESD and CIFAR-10, and encode the training set to latents. The
         latent cache from the v3 run (``/app/lrf_out``) is reused if present.
      2. Build ``LRF(LRF.default())``. Warm-start it from
         ``/app/lrf_out/model.pt`` when that checkpoint exists, continuing its
         loss history.
      3. Train for ``EPOCHS`` epochs with AdamW and per-step cosine LR decay,
         keeping an EMA of the weights. Every 25 epochs and after the last
         epoch, save a sample grid generated with the EMA weights.
      4. Load the EMA weights and generate 8 samples per class. Save the grid,
         the checkpoint (``model.pt``) and a loss plot under ``OUT``.
      5. Upload ``.pt``/``.png`` outputs and this script to the Hub repo ``REPO``.

    Raises:
        ValueError: if there are fewer than ``BS`` training latents. With
            ``drop_last=True`` no batch could be formed in that case.
    """
    OUT = '/app/lrf_extended'
    REPO = 'krystv/LatentRecurrentFlow'
    V3_DIR = '/app/lrf_out'  # outputs of the earlier v3 run (latent cache + checkpoint)
    os.makedirs(OUT, exist_ok=True)

    EPOCHS = 100  # 3x more than v3
    BS = 128      # Bigger batch
    LR = 5e-4     # Slightly higher LR for faster convergence

    print("=" * 60, flush=True)
    print(f"LRF Extended Training — {EPOCHS} epochs, bs={BS}", flush=True)
    print(f"Device: {DEVICE}", flush=True)
    print("=" * 60, flush=True)

    # --- [1] VAE + data; reuse the v3 latent cache when it exists ---
    print("\n[1] Loading TAESD + CIFAR-10...", flush=True)
    vae = get_taesd(DEVICE)
    tr, te = get_cifar()

    if os.path.exists(f'{V3_DIR}/cache_train.pt'):
        print("  Using cached latents from v3 run!", flush=True)
        tr_lat, tr_lab = precompute(vae, tr, 256, DEVICE, f'{V3_DIR}/cache_train.pt')
    else:
        tr_lat, tr_lab = precompute(vae, tr, 256, DEVICE, f'{OUT}/cache_train.pt')

    # --- [2] Model (proven default config), optionally warm-started ---
    print("\n[2] Creating model...", flush=True)
    cfg = LRF.default()
    model = LRF(cfg).to(DEVICE)
    print(f"  {model.count():,} params", flush=True)

    v3_ckpt = os.path.join(V3_DIR, 'model.pt')
    prev_losses = []
    if os.path.exists(v3_ckpt):
        print(f"  Warm-starting from {v3_ckpt}", flush=True)
        # weights_only=False unpickles arbitrary objects (checkpoints written by
        # this script include a cfg object). Only load checkpoints this pipeline
        # produced itself.
        ckpt = torch.load(v3_ckpt, map_location=DEVICE, weights_only=False)
        model.load_state_dict(ckpt['state'])
        prev_losses = list(ckpt.get('losses', []))
        # A checkpoint without a loss history is valid; min([]) would raise.
        if prev_losses:
            print(f"  Previous best loss: {min(prev_losses):.4f}", flush=True)

    # --- [3] Training ---
    print(f"\n[3] Training {EPOCHS} epochs...", flush=True)
    steps_per_epoch = len(tr_lat) // BS  # matches drop_last=True below
    if steps_per_epoch == 0:
        # No batches would be produced. The epoch-mean loss would divide by
        # zero, and the cosine schedule would get T_max=0.
        raise ValueError(
            f"need at least {BS} training latents for one batch, got {len(tr_lat)}")
    sched = FlowScheduler()
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01, betas=(0.9, 0.95))
    # Stepped once per optimizer step, decaying to 1% of LR over the whole run.
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, EPOCHS * steps_per_epoch, LR * 0.01)
    ema = _snapshot_params(model)
    losses = list(prev_losses)  # continue the warm-start loss history

    dl = DataLoader(TensorDataset(tr_lat, tr_lab), BS, shuffle=True, drop_last=True, num_workers=0)
    best_loss = min(losses, default=float('inf'))
    t0 = time.time()

    for ep in range(EPOCHS):
        model.train()
        epoch_loss, n_batches = 0.0, 0
        for lat, lab in dl:
            lat, lab = lat.to(DEVICE), lab.to(DEVICE)
            t = sched.sample_t(lat.shape[0], DEVICE)
            eps = torch.randn_like(lat)
            zt = sched.add_noise(lat, eps, t)
            vp = model.predict_v(zt, t, lab, cfg_drop=0.1)
            vt = sched.velocity(lat, eps)
            # Mean over dims 1-3: a per-sample MSE for (B, C, H, W) latents
            # (presumed latent layout — confirm against precompute).
            lps = (vp - vt).pow(2).mean([1, 2, 3])
            # Up-weight samples whose t is near 0 or 1, where t*(1-t) is small.
            # The +0.01 bounds the weight (at 100 for t in [0, 1]). Dividing by
            # the mean makes the weights average to 1 within the batch.
            w = 1.0 / (t * (1 - t) + 0.01)
            w = w / w.mean()
            loss = (lps * w).mean()
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            lr_sched.step()
            with torch.no_grad():
                # Exponential moving average of the weights, decay 0.9995.
                for n, p in model.named_parameters():
                    ema[n].mul_(0.9995).add_(p, alpha=0.0005)
            epoch_loss += loss.item()
            n_batches += 1

        al = epoch_loss / n_batches
        losses.append(al)
        best_loss = min(best_loss, al)
        elapsed = time.time() - t0

        if (ep + 1) % 10 == 0 or ep == 0 or ep == EPOCHS - 1:
            print(f"  Ep {ep+1:3d}/{EPOCHS}: loss={al:.4f} best={best_loss:.4f} "
                  f"lr={opt.param_groups[0]['lr']:.1e} {elapsed:.0f}s", flush=True)

        # Sample every 25 epochs (and after the last one) using the EMA weights.
        if (ep + 1) % 25 == 0 or ep == EPOCHS - 1:
            backup = _snapshot_params(model)
            _load_params(model, ema)
            model.eval()
            try:
                # NOTE(review): gen's positional args look like
                # (n_samples, steps, guidance[, class]) — confirm in lrf_v3.
                samps = gen(model, vae, sched, DEVICE, 16, 20, 2.5)
                save_grid(samps, f'{OUT}/ep{ep+1:03d}.png', 4)
            finally:
                # Always restore the live training weights, even if sampling fails.
                _load_params(model, backup)

    # --- [4] Final EMA weights: generation + checkpoint ---
    _load_params(model, ema)
    model.eval()

    print("\n[4] Final generation...", flush=True)
    classes = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    all_s = []
    for ci, cname in enumerate(classes):
        s = gen(model, vae, sched, DEVICE, 8, 50, 3.0, ci)
        all_s.append(s)
        print(f"  {cname:10s}: std={s.std():.3f}", flush=True)
    save_grid(torch.cat(all_s), f'{OUT}/final.png', 8)

    # The saved state is the EMA weights loaded above.
    torch.save({'state': model.state_dict(), 'cfg': cfg, 'losses': losses}, f'{OUT}/model.pt')

    _plot_losses(losses, len(prev_losses), best_loss, f'{OUT}/loss.png')

    # --- [5] Publish ---
    print("\n[5] Pushing to Hub...", flush=True)
    _push_to_hub(OUT, REPO, os.path.abspath(__file__))

    print(f"\n{'='*60}", flush=True)
    print(f"DONE! Best loss: {best_loss:.4f}", flush=True)
    print(f"See: https://huggingface.co/{REPO}", flush=True)
    print(f"{'='*60}", flush=True)

if __name__ == "__main__":
    # Launch the training pipeline only when run as a script, never on import.
    main()