Spaces:
Sleeping
Sleeping
| """Phase 3 training loop (prefix projector only).""" | |
| from __future__ import annotations | |
| import argparse | |
| import math | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Dict, List | |
| import torch | |
| import torch.nn.functional as F | |
| from torch.optim import AdamW | |
| _SCRIPT_DIR = Path(__file__).resolve().parent | |
| _ROOT = _SCRIPT_DIR.parent | |
| if str(_SCRIPT_DIR) not in sys.path: | |
| sys.path.insert(0, str(_SCRIPT_DIR)) | |
| from caption_dataloader import build_caption_dataloaders # noqa: E402 | |
| from prefix_projector import ( # noqa: E402 | |
| clap_text_for_prefix_projector, | |
| load_phase3_components, | |
| phase3_prefix_lm_loss, | |
| ) | |
| from tokenizer import ID2TOKEN, PHRASE_START # noqa: E402 | |
def _pick_device() -> torch.device:
    """Choose the torch device to run on.

    Preference order: CUDA, then Apple MPS (only when this torch build exposes
    ``torch.backends.mps`` and reports it available), then CPU.
    """
    backend_name = "cpu"
    if torch.cuda.is_available():
        backend_name = "cuda"
    else:
        # Older torch builds have no ``mps`` attribute at all, hence getattr.
        mps_backend = getattr(torch.backends, "mps", None)
        if mps_backend is not None and mps_backend.is_available():
            backend_name = "mps"
    return torch.device(backend_name)
def parse_args() -> argparse.Namespace:
    """Parse the Phase 3 command-line options from ``sys.argv``.

    Options are declared in a single ordered table and registered in that
    order. Path defaults are resolved relative to the module-level ``_ROOT``.
    """
    results_root = _ROOT / "results"
    option_table = [
        (
            "--midi-checkpoint",
            {
                "type": str,
                "default": str(results_root / "checkpoints" / "best_model.pt"),
            },
        ),
        (
            "--clap-checkpoint",
            {
                "type": str,
                "default": str(
                    results_root / "checkpoints_contrastive" / "clap_best.pt"
                ),
            },
        ),
        ("--n-prefix-tokens", {"type": int, "default": 8}),
        (
            "--captions-jsonl",
            {
                "type": str,
                "default": str(_ROOT / "data" / "captions_llm.jsonl"),
            },
        ),
        ("--batch-size", {"type": int, "default": 64}),
        ("--max-seq-len", {"type": int, "default": 512}),
        ("--split-ratio", {"type": float, "default": 0.95}),
        ("--num-workers", {"type": int, "default": 4}),
        ("--seed", {"type": int, "default": 17}),
        ("--results-dir", {"type": str, "default": str(results_root)}),
        ("--epochs", {"type": int, "default": 20}),
        ("--lr", {"type": float, "default": 1e-4}),
        ("--weight-decay", {"type": float, "default": 0.01}),
        ("--grad-clip-norm", {"type": float, "default": 1.0}),
        ("--warmup-steps", {"type": int, "default": 100}),
        ("--min-lr-scale", {"type": float, "default": 0.01}),
        ("--prefix-attn-reg-weight", {"type": float, "default": 0.0}),
        ("--prefix-attn-min-mean", {"type": float, "default": 0.05}),
        ("--qualitative-every", {"type": int, "default": 5}),
        ("--qual-gen-tokens", {"type": int, "default": 40}),
        (
            "--qual-prompts",
            {
                "nargs": "+",
                "default": [
                    "A fast bright piano étude with rising melodic contour.",
                    "A syncopated jazz combo with saxophone and walking bass.",
                    "An ambient electronic piece with sustained synth pads.",
                ],
            },
        ),
    ]

    parser = argparse.ArgumentParser(
        description="Phase 3 training: prefix projector only"
    )
    for flag, options in option_table:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def _set_warmup_cosine_lr(
    optimizer: AdamW,
    step: int,
    total_steps: int,
    warmup_steps: int,
    base_lr: float,
    min_lr_scale: float,
) -> None:
    """Write the scheduled learning rate for ``step`` into the optimizer.

    Schedule, as a multiplier on ``base_lr``:
      * ``step < warmup_steps``: linear ramp ``(step + 1) / warmup_steps``.
      * warm-up covers the whole run (``total_steps <= warmup_steps``): 1.0.
      * otherwise: cosine decay from 1.0 down to ``min_lr_scale``, with the
        progress fraction clamped to [0, 1].

    Only ``optimizer.param_groups[0]`` is updated.
    """
    if step < warmup_steps:
        scale = float(step + 1) / float(max(1, warmup_steps))
    elif total_steps <= warmup_steps:
        scale = 1.0
    else:
        decay_span = float(total_steps - warmup_steps)
        frac = (step - warmup_steps) / decay_span
        frac = min(1.0, max(0.0, frac))
        headroom = 1.0 - min_lr_scale
        wave = 1.0 + math.cos(math.pi * frac)
        scale = min_lr_scale + headroom * 0.5 * wave
    first_group = optimizer.param_groups[0]
    first_group["lr"] = base_lr * scale
def _eval_loss(
    clap_model,
    midi_gpt,
    projector,
    loader,
    device: torch.device,
    prefix_attn_reg_weight: float,
    prefix_attn_min_mean: float,
) -> float:
    """Return the mean per-batch prefix LM loss over ``loader``.

    The projector is switched to eval mode for the pass and back to train
    mode afterwards, including when a batch raises. The pass runs under
    ``torch.no_grad()`` because nothing here calls ``backward()``; without it
    every batch would build an autograd graph through the trainable projector.

    The average is over batches, not examples, so a short final batch weighs
    as much as a full one. Returns 0.0 when ``loader`` yields no batches.
    """
    projector.eval()
    running_loss = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for batch in loader:
                loss, _ = phase3_prefix_lm_loss(
                    clap_model=clap_model,
                    midi_gpt=midi_gpt,
                    prefix_projector=projector,
                    input_ids=batch["input_ids"].to(device),
                    captions=batch["captions"],
                    prefix_attn_reg_weight=prefix_attn_reg_weight,
                    prefix_attn_min_mean=prefix_attn_min_mean,
                )
                running_loss += float(loss.item())
                n_batches += 1
    finally:
        # Never leave the projector stuck in eval mode for the training loop.
        projector.train()
    return running_loss / max(1, n_batches)
def _lm_loss_without_prefix(midi_gpt, input_ids: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy of ``midi_gpt`` on ``input_ids``, no prefix.

    Logits at position ``t`` are scored against the token at ``t + 1``; the
    last position has no target and is dropped. Every remaining position is
    counted (no ignore index is applied).
    """
    all_logits = midi_gpt(input_ids)
    vocab_size = all_logits.size(-1)
    predictions = all_logits[:, :-1, :].reshape(-1, vocab_size)
    targets = input_ids[:, 1:].reshape(-1)
    return F.cross_entropy(predictions, targets)
def _infer_genre_label(caption: str) -> str:
    """Map a caption to a coarse genre bucket by keyword lookup.

    Matching is a case-insensitive substring test. Buckets are tried in the
    order listed below and the first one with any hit wins; captions with no
    hit are labelled ``"other"``.
    """
    keyword_table = (
        ("jazz", ("jazz", "swing", "bebop")),
        ("electronic", ("electronic", "synth", "edm")),
        ("classical", ("classical", "orchestral", "baroque")),
        ("rock", ("rock", "guitar", "band")),
    )
    lowered = caption.lower()
    for label, keywords in keyword_table:
        if any(word in lowered for word in keywords):
            return label
    return "other"
def _conditional_perplexity_gap_by_genre(
    clap_model,
    midi_gpt,
    projector,
    loader,
    device: torch.device,
    max_examples: int = 200,
) -> Dict[str, float]:
    """Per-genre perplexity gap between prefixed and unprefixed LM loss.

    Up to ``max_examples`` examples are scored one at a time, each with its
    own caption as prefix and with no prefix at all. Examples are bucketed by
    ``_infer_genre_label(caption)``. For each bucket the result is
    ``exp(mean loss with prefix) - exp(mean loss without prefix)``, so a
    negative value means the prefix lowered perplexity for that genre.

    Runs under ``torch.no_grad()`` (nothing is backpropagated here) with the
    projector in eval mode; train mode is restored afterwards even on error.
    Returns an empty dict when no example was scored.
    """
    sums_with: Dict[str, float] = {}
    sums_without: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    seen = 0

    projector.eval()
    try:
        with torch.no_grad():
            for batch in loader:
                if seen >= max_examples:
                    break
                input_ids = batch["input_ids"].to(device)
                captions = batch["captions"]
                for i in range(input_ids.size(0)):
                    if seen >= max_examples:
                        break
                    # Keep the batch dimension: slice, don't index.
                    x = input_ids[i : i + 1]
                    cap = [captions[i]]
                    loss_with, _ = phase3_prefix_lm_loss(
                        clap_model=clap_model,
                        midi_gpt=midi_gpt,
                        prefix_projector=projector,
                        input_ids=x,
                        captions=cap,
                    )
                    loss_without = _lm_loss_without_prefix(
                        midi_gpt=midi_gpt, input_ids=x
                    )
                    genre = _infer_genre_label(cap[0])
                    sums_with[genre] = sums_with.get(genre, 0.0) + float(
                        loss_with.item()
                    )
                    sums_without[genre] = sums_without.get(genre, 0.0) + float(
                        loss_without.item()
                    )
                    counts[genre] = counts.get(genre, 0) + 1
                    seen += 1
    finally:
        # Never leave the projector stuck in eval mode for the training loop.
        projector.train()

    gaps: Dict[str, float] = {}
    for genre, n in counts.items():
        mean_with = sums_with[genre] / n
        mean_without = sums_without[genre] / n
        gaps[genre] = math.exp(mean_with) - math.exp(mean_without)
    return gaps
def _generate_unconditional(midi_gpt, gen_tokens: int, device: torch.device) -> List[int]:
    """Sample ``gen_tokens`` tokens from ``midi_gpt`` with no text prefix.

    Generation starts from a single ``PHRASE_START`` token. Each step re-runs
    the model on the whole sequence so far and draws the next token from the
    softmax over the last position's logits. The returned list includes the
    starting token.
    """
    tokens = torch.tensor([[PHRASE_START]], dtype=torch.long, device=device)
    for _step in range(gen_tokens):
        last_logits = midi_gpt(tokens)[:, -1, :]
        next_token = torch.multinomial(
            F.softmax(last_logits, dim=-1), num_samples=1
        )
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens[0].tolist()
def _generate_with_text_prefix(
    clap_model,
    midi_gpt,
    projector,
    text_prompt: str,
    gen_tokens: int,
    device: torch.device,
) -> List[int]:
    """Sample ``gen_tokens`` tokens conditioned on ``text_prompt``.

    Diagnostic-only helper for qualitative checks during training. It re-runs
    the full prefix+GPT forward for every token (O(n^2)); production
    inference should use cached decoding in generate_conditional.py.

    Generation starts from ``PHRASE_START``. The second return value of
    ``phase3_prefix_lm_loss`` is used as the logits and its last position is
    sampled; the loss it returns is ignored. The returned list includes the
    starting token.

    The loop runs under ``torch.no_grad()``: nothing is backpropagated, and
    the projector is trainable, so otherwise every step would build an
    autograd graph.
    """
    ids: List[int] = [PHRASE_START]
    with torch.no_grad():
        for _ in range(gen_tokens):
            x = torch.tensor([ids], dtype=torch.long, device=device)
            _loss, logits_full = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=x,
                captions=[text_prompt],
            )
            probs = F.softmax(logits_full[:, -1, :], dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            ids.append(int(nxt.item()))
    return ids
def _token_preview(ids: List[int], max_len: int = 40) -> str:
    """Render the first ``max_len`` token ids as space-separated token names.

    Ids missing from ``ID2TOKEN`` are shown as ``UNK(<id>)``. When ``ids`` is
    longer than ``max_len`` the text ends with `` ...`` to mark truncation.
    """
    shown = ids[:max_len]
    rendered = " ".join(
        ID2TOKEN.get(token_id, f"UNK({token_id})") for token_id in shown
    )
    if len(ids) > max_len:
        return rendered + " ..."
    return rendered
def _prefix_token_scale_diagnostics(
    clap_model,
    midi_gpt,
    projector,
    batch,
    device: torch.device,
) -> None:
    """Print how the projector's output norms compare with ``midi_gpt.wte``'s.

    For one batch, the mean last-dimension L2 norm of ``projector(text_emb)``
    is compared with that of ``midi_gpt.wte(input_ids)`` (presumably the
    GPT's token embeddings). A ratio above 10 or below 0.1 prints a warning.

    Only scalars are read out, so the forward passes run under
    ``torch.no_grad()`` to avoid building an autograd graph through the
    trainable projector.
    """
    x = batch["input_ids"].to(device)
    caps = batch["captions"]
    with torch.no_grad():
        text_emb = clap_text_for_prefix_projector(clap_model, caps, device)
        prefix = projector(text_emb)
        token = midi_gpt.wte(x)
        pnorm = float(prefix.norm(dim=-1).mean().item())
        tnorm = float(token.norm(dim=-1).mean().item())
    # Guard the division in case the token norms are all (near) zero.
    ratio = pnorm / max(1e-8, tnorm)
    print(
        "[phase3][scale] prefix_norm="
        f"{pnorm:.4f} token_norm={tnorm:.4f} ratio={ratio:.3f}"
    )
    if ratio > 10.0 or ratio < 0.1:
        print(
            "[phase3][scale][warn] prefix/token norm mismatch is large."
        )
def _verify_prefix_usage(
    clap_model,
    midi_gpt,
    projector,
    batch,
    device: torch.device,
) -> None:
    """Check that the loss depends on which caption is used as prefix.

    The first example of ``batch`` is scored twice: once with its own caption
    and once with the first other caption in the batch whose text differs.
    Both losses and their difference are printed; a difference below 1e-4 in
    magnitude prints a warning that the prefix may be ignored.

    The check is skipped (with a message) when the batch has fewer than two
    examples, or when every other caption equals the first one. In the latter
    case both runs would be identical and the warning would be a false alarm.

    The comparison runs under ``torch.no_grad()`` with the projector in eval
    mode, so stochastic layers (if the projector has any) cannot produce a
    loss difference on their own. The projector's previous mode is restored
    afterwards.
    """
    input_ids = batch["input_ids"].to(device)
    captions = batch["captions"]
    if input_ids.size(0) < 2:
        print("[phase3][verify1] skipped: need batch size >= 2.")
        return

    reference = captions[0]
    mismatched = next((c for c in captions[1:] if c != reference), None)
    if mismatched is None:
        print("[phase3][verify1] skipped: no differing caption in batch.")
        return

    x = input_ids[0:1]
    was_training = projector.training
    projector.eval()
    try:
        with torch.no_grad():
            loss_correct, _ = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=x,
                captions=[reference],
            )
            loss_wrong, _ = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=x,
                captions=[mismatched],
            )
    finally:
        projector.train(was_training)

    delta = float(loss_wrong.item() - loss_correct.item())
    print(
        "[phase3][verify1] loss(correct_prefix)="
        f"{loss_correct.item():.4f} loss(wrong_prefix)={loss_wrong.item():.4f} "
        f"delta(wrong-correct)={delta:+.4f}"
    )
    if abs(delta) < 1e-4:
        print(
            "[phase3][verify1][warn] losses are almost identical; "
            "prefix may be ignored."
        )
def _baseline_val_loss(midi_gpt, val_loader, device: torch.device) -> float:
    """Return the mean per-batch no-prefix LM loss on ``val_loader``."""
    total = 0.0
    n_batches = 0
    with torch.no_grad():
        for vbatch in val_loader:
            total += float(
                _lm_loss_without_prefix(
                    midi_gpt=midi_gpt,
                    input_ids=vbatch["input_ids"].to(device),
                ).item()
            )
            n_batches += 1
    return total / max(1, n_batches)


def _report_projector_grad(projector) -> None:
    """Print the gradient magnitude of ``projector.fc2.weight`` (verify2)."""
    grad = projector.fc2.weight.grad
    if grad is None:
        print(
            "[phase3][verify2][warn] projector.fc2.weight.grad is None."
        )
        return
    grad_norm = float(grad.norm().item())
    grad_abs = float(grad.abs().sum().item())
    print(
        "[phase3][verify2] projector.fc2.weight grad_norm="
        f"{grad_norm:.6f} grad_abs_sum={grad_abs:.6f}"
    )
    if grad_abs == 0.0:
        print(
            "[phase3][verify2][warn] projector gradient is all zeros."
        )


def _qualitative_generation_check(
    clap_model,
    midi_gpt,
    projector,
    args,
    device: torch.device,
) -> None:
    """Print one unconditional sample and one sample per ``args.qual_prompts``."""
    print("\n[phase3] qualitative generation check")
    uncond = _generate_unconditional(
        midi_gpt=midi_gpt,
        gen_tokens=args.qual_gen_tokens,
        device=device,
    )
    print(f" [unconditional] {_token_preview(uncond)}")
    for prompt in args.qual_prompts:
        cond = _generate_with_text_prefix(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            text_prompt=prompt,
            gen_tokens=args.qual_gen_tokens,
            device=device,
        )
        print(f" [prompt] {prompt}")
        print(f" {_token_preview(cond)}")


def main() -> None:
    """Run Phase 3 training: optimise only the prefix projector.

    Loads the components and dataloaders, then for each epoch:
      1. trains the projector with AdamW under a warm-up + cosine LR schedule
         and gradient-norm clipping;
      2. on the first batch of the epoch, prints the caption-swap check
         (verify1) and the ``fc2`` gradient check (verify2);
      3. evaluates the validation loss, the no-prefix baseline and the
         per-genre perplexity gaps;
      4. writes ``prefix_projector_latest.pt`` and, on a new best validation
         loss, ``prefix_projector_best.pt`` under
         ``<results-dir>/checkpoints_prefix``;
      5. prints prefix/token norm diagnostics and, every
         ``--qualitative-every`` epochs, sample generations.

    Raises:
        RuntimeError: if the training loader yields no batch, or its batches
            lack ``input_ids``.
    """
    args = parse_args()
    device = _pick_device()
    print(f"[phase3] device={device}")
    clap_model, midi_gpt, projector, counts = load_phase3_components(
        midi_checkpoint=args.midi_checkpoint,
        clap_checkpoint=args.clap_checkpoint,
        n_prefix_tokens=args.n_prefix_tokens,
        device=device,
    )
    # Phase 3 uses the exact same dataset setup as Phase 2.
    train_loader, val_loader, stats = build_caption_dataloaders(
        jsonl_path=args.captions_jsonl,
        max_seq_len=args.max_seq_len,
        batch_size=args.batch_size,
        split_ratio=args.split_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
    )
    # Peek at one batch up front so a broken dataloader fails before training.
    batch = next(iter(train_loader), None)
    if batch is None:
        raise RuntimeError("Phase 3 training dataloader produced no batches.")
    if "input_ids" not in batch:
        raise RuntimeError(
            "Phase 3 requires input_ids from the dataloader to build LM labels."
        )
    print("[phase3] freeze policy check passed.")
    print(
        "[phase3] dataset total/train/val="
        f"{stats.n_total_records}/{stats.n_train_records}/{stats.n_val_records}"
    )
    print(
        "[phase3] dataloader check passed: input_ids shape="
        f"{tuple(batch['input_ids'].shape)}"
    )
    print(f"[phase3] CLAP params (frozen): {counts.n_clap_params:,}")
    print(f"[phase3] GPT params (frozen): {counts.n_gpt_params:,}")
    print(f"[phase3] projector params: {counts.n_projector_params:,}")
    print(f"[phase3] total trainable: {counts.n_total_trainable:,}")

    # Only the projector's parameters are handed to the optimizer.
    optimizer = AdamW(
        projector.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    steps_per_epoch = len(train_loader)
    total_steps = max(1, args.epochs * steps_per_epoch)
    print(
        f"[phase3] epochs={args.epochs} steps_per_epoch={steps_per_epoch} "
        f"total_steps={total_steps} warmup_steps={args.warmup_steps}"
    )

    # parse_args() always defines --results-dir, so no fallback is needed.
    ckpt_dir = Path(args.results_dir) / "checkpoints_prefix"

    # The no-prefix baseline depends only on midi_gpt and the validation
    # data, and midi_gpt is never optimised here, so it is computed once (on
    # the first epoch) instead of on every epoch.
    # NOTE(review): assumes midi_gpt's forward is deterministic (eval mode) —
    # confirm in load_phase3_components.
    baseline_val = None

    global_step = 0
    best_val = float("inf")
    t0 = time.perf_counter()
    for epoch in range(1, args.epochs + 1):
        train_loss_sum = 0.0
        n_train = 0
        verify1_done = False
        verify2_done = False
        for batch in train_loader:
            if not verify1_done:
                _verify_prefix_usage(
                    clap_model=clap_model,
                    midi_gpt=midi_gpt,
                    projector=projector,
                    batch=batch,
                    device=device,
                )
                verify1_done = True
            _set_warmup_cosine_lr(
                optimizer=optimizer,
                step=global_step,
                total_steps=total_steps,
                warmup_steps=args.warmup_steps,
                base_lr=args.lr,
                min_lr_scale=args.min_lr_scale,
            )
            optimizer.zero_grad(set_to_none=True)
            loss, _ = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=batch["input_ids"].to(device),
                captions=batch["captions"],
                prefix_attn_reg_weight=args.prefix_attn_reg_weight,
                prefix_attn_min_mean=args.prefix_attn_min_mean,
            )
            loss.backward()
            if not verify2_done:
                # Inspect the raw gradient before clipping rescales it.
                _report_projector_grad(projector)
                verify2_done = True
            torch.nn.utils.clip_grad_norm_(
                projector.parameters(), args.grad_clip_norm
            )
            optimizer.step()
            train_loss_sum += float(loss.item())
            n_train += 1
            global_step += 1

        train_loss = train_loss_sum / max(1, n_train)
        val_loss = _eval_loss(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            loader=val_loader,
            device=device,
            prefix_attn_reg_weight=args.prefix_attn_reg_weight,
            prefix_attn_min_mean=args.prefix_attn_min_mean,
        )
        if baseline_val is None:
            baseline_val = _baseline_val_loss(
                midi_gpt=midi_gpt, val_loader=val_loader, device=device
            )
        ppl_gap = math.exp(val_loss) - math.exp(baseline_val)
        genre_gap = _conditional_perplexity_gap_by_genre(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            loader=val_loader,
            device=device,
            max_examples=200,
        )
        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"[phase3] epoch={epoch}/{args.epochs} "
            f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
            f"baseline_val={baseline_val:.4f} ppl_gap={ppl_gap:+.3f} "
            f"lr={current_lr:.2e}"
        )
        if genre_gap:
            parts = " ".join(f"{k}:{v:+.3f}" for k, v in sorted(genre_gap.items()))
            print(f"[phase3] genre ppl_gap(with-without): {parts}")
        if epoch >= 10 and val_loss >= baseline_val:
            print("[phase3][warn] prefix loss not below no-prefix baseline.")

        ckpt_dir.mkdir(parents=True, exist_ok=True)
        ckpt = {
            "projector_state_dict": projector.state_dict(),
            "epoch": epoch,
            "val_loss": val_loss,
            "args": vars(args),
        }
        torch.save(ckpt, ckpt_dir / "prefix_projector_latest.pt")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(ckpt, ckpt_dir / "prefix_projector_best.pt")

        # ``batch`` is the last training batch of this epoch (or the peeked
        # batch if the loader yielded nothing).
        _prefix_token_scale_diagnostics(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            batch=batch,
            device=device,
        )
        # A value of 0 disables the qualitative check instead of crashing
        # on ``epoch % 0``.
        if args.qualitative_every and epoch % args.qualitative_every == 0:
            _qualitative_generation_check(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                projector=projector,
                args=args,
                device=device,
            )

    elapsed = time.perf_counter() - t0
    print(
        f"[phase3] finished in {elapsed/60:.1f} min, best_val={best_val:.4f}"
    )
| if __name__ == "__main__": | |
| main() | |