| |
| """ |
| LRF Extended Training — more epochs with cached latents (runs on CUDA when available, otherwise CPU). |
| Uses the same proven architecture from v3, just trains longer. |
| Pushes results to HF Hub. |
| """ |
| import math, os, sys, time, json |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, TensorDataset |
| from einops import rearrange |
| import numpy as np |
|
|
# Device string used by main() for model/tensor placement: 'cuda' when
# torch reports an available CUDA device, 'cpu' otherwise.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'
|
|
| |
| sys.path.insert(0, '/app') |
| from lrf_v3 import LRF, FlowScheduler, get_taesd, get_cifar, precompute, save_grid, gen |
|
|
def _load_params(model, source):
    """Copy tensors from ``source`` (name -> tensor) into ``model``'s parameters in place."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(source[name])


def main():
    """Run the extended LRF training pipeline end to end.

    Steps, in order:
      1. Load the TAESD VAE and CIFAR-10, then obtain training latents via
         ``precompute`` (pointing at the v3 cache file when it exists).
      2. Build an ``LRF`` model from its default config and, if a v3
         checkpoint exists, warm-start from it.
      3. Train for ``EPOCHS`` epochs with AdamW, cosine LR annealing,
         gradient clipping and an EMA copy of the parameters; write a
         sample grid every 25 epochs using the EMA weights.
      4. Load the EMA weights into the model, generate a per-class grid,
         save the checkpoint and a loss plot.
      5. Upload the resulting ``.pt`` / ``.png`` files and this script to
         the Hugging Face Hub.

    Raises:
        ValueError: if there are fewer cached latents than one batch, so
            no training step could run.
    """
    OUT = '/app/lrf_extended'
    REPO = 'krystv/LatentRecurrentFlow'
    os.makedirs(OUT, exist_ok=True)

    EPOCHS = 100
    BS = 128
    LR = 5e-4

    print("=" * 60, flush=True)
    print(f"LRF Extended Training — {EPOCHS} epochs, bs={BS}", flush=True)
    print(f"Device: {DEVICE}", flush=True)
    print("=" * 60, flush=True)

    # ---- [1] data -------------------------------------------------------
    print("\n[1] Loading TAESD + CIFAR-10...", flush=True)
    vae = get_taesd(DEVICE)
    # Only the first dataset returned by get_cifar() is used here.
    tr, _ = get_cifar()

    # Point precompute() at the v3 cache file when it exists; otherwise use
    # a cache path inside OUT.
    # NOTE(review): presumably precompute() loads from / writes to the given
    # path — confirm against lrf_v3.precompute.
    cache_dir = '/app/lrf_out'
    if os.path.exists(f'{cache_dir}/cache_train.pt'):
        print(" Using cached latents from v3 run!", flush=True)
        cache_path = f'{cache_dir}/cache_train.pt'
    else:
        cache_path = f'{OUT}/cache_train.pt'
    tr_lat, tr_lab = precompute(vae, tr, 256, DEVICE, cache_path)

    # ---- [2] model ------------------------------------------------------
    print("\n[2] Creating model...", flush=True)
    cfg = LRF.default()
    model = LRF(cfg).to(DEVICE)
    print(f" {model.count():,} params", flush=True)

    # Warm-start from the v3 checkpoint when present.
    v3_ckpt = '/app/lrf_out/model.pt'
    prev_losses = []
    if os.path.exists(v3_ckpt):
        print(f" Warm-starting from {v3_ckpt}", flush=True)
        # SECURITY NOTE: weights_only=False unpickles arbitrary objects. The
        # checkpoint written at the end of this function stores the config
        # object under 'cfg', so full unpickling is kept — only load
        # checkpoints from a trusted source.
        ckpt = torch.load(v3_ckpt, map_location=DEVICE, weights_only=False)
        model.load_state_dict(ckpt['state'])
        prev_losses = list(ckpt.get('losses') or [])
        # min() on an empty history would raise ValueError, so only report
        # a previous best when the checkpoint actually carries one.
        if prev_losses:
            print(f" Previous best loss: {min(prev_losses):.4f}", flush=True)

    # ---- [3] training ---------------------------------------------------
    print(f"\n[3] Training {EPOCHS} epochs...", flush=True)
    steps_per_epoch = len(tr_lat) // BS
    if steps_per_epoch == 0:
        # With drop_last=True the loader would yield nothing and the epoch
        # average below would divide by zero.
        raise ValueError(
            f"Need at least {BS} cached latents to form one batch, got {len(tr_lat)}"
        )

    sched = FlowScheduler()
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=0.01, betas=(0.9, 0.95))
    total_steps = EPOCHS * steps_per_epoch
    lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=total_steps, eta_min=LR * 0.01)
    # EMA starts from the (possibly warm-started) parameters.
    ema = {n: p.clone().detach() for n, p in model.named_parameters()}
    losses = list(prev_losses)

    dl = DataLoader(TensorDataset(tr_lat, tr_lab), BS, shuffle=True, drop_last=True, num_workers=0)
    best_loss = min(losses) if losses else float('inf')
    t0 = time.time()

    for ep in range(EPOCHS):
        model.train()
        el, nb = 0, 0
        for lat, lab in dl:
            lat, lab = lat.to(DEVICE), lab.to(DEVICE)
            B = lat.shape[0]
            t = sched.sample_t(B, DEVICE)
            eps = torch.randn_like(lat)
            zt = sched.add_noise(lat, eps, t)
            vp = model.predict_v(zt, t, lab, cfg_drop=0.1)
            vt = sched.velocity(lat, eps)
            # Per-sample squared error between predicted and target velocity.
            lps = (vp - vt).pow(2).mean([1, 2, 3])
            # Weight grows as t approaches 0 or 1; normalised to mean 1 so
            # the overall loss scale is unchanged.
            # assumes t has shape (B,) so it lines up with lps — TODO confirm
            w = 1.0 / (t * (1 - t) + 0.01)
            w = w / w.mean()
            loss = (lps * w).mean()

            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            lr_sched.step()

            with torch.no_grad():
                for n, p in model.named_parameters():
                    ema[n].mul_(0.9995).add_(p, alpha=0.0005)

            el += loss.item()
            nb += 1

        al = el / nb
        losses.append(al)
        best_loss = min(best_loss, al)
        elapsed = time.time() - t0

        if (ep + 1) % 10 == 0 or ep == 0 or ep == EPOCHS - 1:
            print(f" Ep {ep+1:3d}/{EPOCHS}: loss={al:.4f} best={best_loss:.4f} "
                  f"lr={opt.param_groups[0]['lr']:.1e} {elapsed:.0f}s", flush=True)

        # Periodic sample grid: temporarily swap in the EMA weights, sample,
        # then restore the live training weights.
        if (ep + 1) % 25 == 0 or ep == EPOCHS - 1:
            bak = {n: p.detach().clone() for n, p in model.named_parameters()}
            _load_params(model, ema)
            model.eval()
            # NOTE(review): trailing positional args are presumably
            # (num samples, sampling steps, guidance scale) — confirm
            # against lrf_v3.gen.
            samps = gen(model, vae, sched, DEVICE, 16, 20, 2.5)
            save_grid(samps, f'{OUT}/ep{ep+1:03d}.png', 4)
            _load_params(model, bak)

    # From here on the model carries the EMA weights; these are what get
    # sampled from and saved below.
    _load_params(model, ema)
    model.eval()

    # ---- [4] final samples, checkpoint, loss plot -------------------------
    print("\n[4] Final generation...", flush=True)
    classes = ['airplane', 'auto', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
    all_s = []
    for ci in range(10):
        s = gen(model, vae, sched, DEVICE, 8, 50, 3.0, ci)
        all_s.append(s)
        print(f" {classes[ci]:10s}: std={s.std():.3f}", flush=True)
    save_grid(torch.cat(all_s), f'{OUT}/final.png', 8)

    torch.save({'state': model.state_dict(), 'cfg': cfg, 'losses': losses}, f'{OUT}/model.pt')

    # The loss plot is best-effort: a missing or broken matplotlib must not
    # abort the run, but the reason is reported instead of being swallowed.
    try:
        import matplotlib
        matplotlib.use('Agg')
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 4))
        plt.plot(losses, 'b-', alpha=0.7)
        if prev_losses:
            plt.axvline(x=len(prev_losses), color='r', linestyle='--', alpha=0.5,
                        label='Extended training start')
            plt.legend()
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title(f'LRF Training (best={best_loss:.4f})')
        plt.grid(True, alpha=0.3)
        plt.savefig(f'{OUT}/loss.png', dpi=150, bbox_inches='tight')
        plt.close()
    except Exception as exc:
        print(f" Skipping loss plot: {exc}", flush=True)

    # ---- [5] upload -------------------------------------------------------
    print("\n[5] Pushing to Hub...", flush=True)
    from huggingface_hub import HfApi
    api = HfApi()
    for fname in sorted(os.listdir(OUT)):
        fp = os.path.join(OUT, fname)
        # Upload checkpoints and images only; skip latent caches and any
        # file of 100 MB or more.
        if (fname.endswith(('.pt', '.png'))
                and os.path.getsize(fp) < 100_000_000
                and 'cache' not in fname):
            api.upload_file(path_or_fileobj=fp, path_in_repo=f'gpu_trained/{fname}',
                            repo_id=REPO, repo_type='model')
            print(f" Uploaded gpu_trained/{fname}", flush=True)

    api.upload_file(path_or_fileobj='/app/train_extended.py', path_in_repo='train_gpu.py',
                    repo_id=REPO, repo_type='model')
    print(" Uploaded train_gpu.py", flush=True)

    print(f"\n{'='*60}", flush=True)
    print(f"DONE! Best loss: {best_loss:.4f}", flush=True)
    print(f"See: https://huggingface.co/{REPO}", flush=True)
    print(f"{'='*60}", flush=True)
|
# Script entry point: run the training pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|