Spaces:
Sleeping
Sleeping
| """Phase 3 prefix projector training (compound path).""" | |
| from __future__ import annotations | |
| import argparse | |
| import math | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.optim import AdamW | |
| _SCRIPT_DIR = Path(__file__).resolve().parent | |
| _ROOT = _SCRIPT_DIR.parent | |
| if str(_SCRIPT_DIR) not in sys.path: | |
| sys.path.insert(0, str(_SCRIPT_DIR)) | |
| from caption_dataloader import build_compound_caption_dataloaders # noqa: E402 | |
| from prefix_projector import ( # noqa: E402 | |
| load_phase3_compound_components, | |
| phase3_compound_prefix_lm_loss, | |
| ) | |
| from compound import SENTINELS, STEP_PAD # noqa: E402 | |
| from compound_model import compound_loss # noqa: E402 | |
| from tokenizer import PHRASE_START # noqa: E402 | |
| def _pick_device() -> torch.device: | |
| if torch.cuda.is_available(): | |
| return torch.device("cuda") | |
| mps = getattr(torch.backends, "mps", None) | |
| if mps is not None and mps.is_available(): | |
| return torch.device("mps") | |
| return torch.device("cpu") | |
| def parse_args() -> argparse.Namespace: | |
| p = argparse.ArgumentParser(description="Phase 3 prefix projector (compound path)") | |
| p.add_argument("--compound-checkpoint", type=str, required=True) | |
| p.add_argument("--clap-checkpoint", type=str, required=True) | |
| p.add_argument("--captions-jsonl", type=str, required=True) | |
| p.add_argument("--n-prefix-tokens", type=int, default=8) | |
| p.add_argument("--results-dir", type=str, default=str(_ROOT / "results")) | |
| p.add_argument("--epochs", type=int, default=15) | |
| p.add_argument("--batch-size", type=int, default=32) | |
| p.add_argument("--max-seq-len", type=int, default=504) | |
| p.add_argument("--split-ratio", type=float, default=0.95) | |
| p.add_argument("--num-workers", type=int, default=4) | |
| p.add_argument("--seed", type=int, default=17) | |
| p.add_argument("--lr", type=float, default=1e-4) | |
| p.add_argument("--weight-decay", type=float, default=0.01) | |
| p.add_argument("--grad-clip-norm", type=float, default=1.0) | |
| p.add_argument("--warmup-steps", type=int, default=100) | |
| p.add_argument("--min-lr-scale", type=float, default=0.01) | |
| return p.parse_args() | |
| def _set_lr(optimizer, step, total_steps, warmup_steps, base_lr, min_lr_scale): | |
| if step < warmup_steps: | |
| mult = float(step + 1) / float(max(1, warmup_steps)) | |
| else: | |
| progress = (step - warmup_steps) / float(max(1, total_steps - warmup_steps)) | |
| progress = min(1.0, max(0.0, progress)) | |
| mult = min_lr_scale + (1.0 - min_lr_scale) * 0.5 * ( | |
| 1.0 + math.cos(math.pi * progress) | |
| ) | |
| optimizer.param_groups[0]["lr"] = base_lr * mult | |
| def _eval_loss(clap_model, compound_gpt, projector, loader, device): | |
| projector.eval() | |
| total = 0.0 | |
| n = 0 | |
| for batch in loader: | |
| loss, _ = phase3_compound_prefix_lm_loss( | |
| clap_model=clap_model, | |
| midi_compound_gpt=compound_gpt, | |
| prefix_projector=projector, | |
| compound_input=batch["compound_input"].to(device), | |
| captions=batch["captions"], | |
| ) | |
| total += float(loss.item()) | |
| n += 1 | |
| projector.train() | |
| return total / max(1, n) | |
| def _baseline_loss(compound_gpt, loader, device): | |
| """Loss without prefix — for comparison.""" | |
| compound_gpt.eval() | |
| total = 0.0 | |
| n = 0 | |
| for batch in loader: | |
| x = batch["compound_input"].to(device) | |
| logits = compound_gpt(idx=x) | |
| # shift: predict next step | |
| loss = compound_loss( | |
| logits_per_axis=[l[:, :-1, :] for l in logits], | |
| targets=x[:, 1:, :], | |
| pad_step_value=STEP_PAD, | |
| ignore_pad_steps=True, | |
| ) | |
| total += float(loss.item()) | |
| n += 1 | |
| compound_gpt.eval() | |
| return total / max(1, n) | |
| def main() -> None: | |
| args = parse_args() | |
| device = _pick_device() | |
| print(f"[phase3-compound] device={device}") | |
| clap_model, compound_gpt, projector, counts = load_phase3_compound_components( | |
| compound_midi_checkpoint=args.compound_checkpoint, | |
| compound_clap_checkpoint=args.clap_checkpoint, | |
| n_prefix_tokens=args.n_prefix_tokens, | |
| device=device, | |
| ) | |
| train_loader, val_loader, stats = build_compound_caption_dataloaders( | |
| jsonl_path=args.captions_jsonl, | |
| max_seq_len=args.max_seq_len, | |
| batch_size=args.batch_size, | |
| split_ratio=args.split_ratio, | |
| seed=args.seed, | |
| num_workers=args.num_workers, | |
| ) | |
| print(f"[phase3-compound] dataset total/train/val=" | |
| f"{stats.n_total_records}/{stats.n_train_records}/{stats.n_val_records}") | |
| print(f"[phase3-compound] CLAP params (frozen): {counts.n_clap_params:,}") | |
| print(f"[phase3-compound] GPT params (frozen): {counts.n_gpt_params:,}") | |
| print(f"[phase3-compound] projector params: {counts.n_projector_params:,}") | |
| optimizer = AdamW( | |
| projector.parameters(), | |
| lr=args.lr, | |
| weight_decay=args.weight_decay, | |
| ) | |
| steps_per_epoch = len(train_loader) | |
| total_steps = max(1, args.epochs * steps_per_epoch) | |
| print(f"[phase3-compound] epochs={args.epochs} " | |
| f"steps_per_epoch={steps_per_epoch} total_steps={total_steps}") | |
| ckpt_dir = Path(args.results_dir) / "checkpoints_prefix" | |
| ckpt_dir.mkdir(parents=True, exist_ok=True) | |
| global_step = 0 | |
| best_val = float("inf") | |
| t0 = time.perf_counter() | |
| for epoch in range(1, args.epochs + 1): | |
| projector.train() | |
| train_loss_sum = 0.0 | |
| n_train = 0 | |
| for batch in train_loader: | |
| _set_lr(optimizer, global_step, total_steps, | |
| args.warmup_steps, args.lr, args.min_lr_scale) | |
| optimizer.zero_grad(set_to_none=True) | |
| loss, _ = phase3_compound_prefix_lm_loss( | |
| clap_model=clap_model, | |
| midi_compound_gpt=compound_gpt, | |
| prefix_projector=projector, | |
| compound_input=batch["compound_input"].to(device), | |
| captions=batch["captions"], | |
| ) | |
| loss.backward() | |
| torch.nn.utils.clip_grad_norm_(projector.parameters(), args.grad_clip_norm) | |
| optimizer.step() | |
| train_loss_sum += float(loss.item()) | |
| n_train += 1 | |
| global_step += 1 | |
| train_loss = train_loss_sum / max(1, n_train) | |
| val_loss = _eval_loss(clap_model, compound_gpt, projector, val_loader, device) | |
| baseline = _baseline_loss(compound_gpt, val_loader, device) | |
| lr = optimizer.param_groups[0]["lr"] | |
| print( | |
| f"[phase3-compound] epoch={epoch}/{args.epochs} " | |
| f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} " | |
| f"baseline={baseline:.4f} gap={val_loss - baseline:+.4f} " | |
| f"lr={lr:.2e}" | |
| ) | |
| ckpt = { | |
| "projector_state_dict": projector.state_dict(), | |
| "epoch": epoch, | |
| "val_loss": val_loss, | |
| "args": vars(args), | |
| } | |
| torch.save(ckpt, ckpt_dir / "prefix_projector_latest.pt") | |
| if val_loss < best_val: | |
| best_val = val_loss | |
| torch.save(ckpt, ckpt_dir / "prefix_projector_best.pt") | |
| elapsed = time.perf_counter() - t0 | |
| print(f"[phase3-compound] done in {elapsed/60:.1f} min, best_val={best_val:.4f}") | |
| if __name__ == "__main__": | |
| main() | |