# Coda / src/train_prefix.py
# Prajanya Gupta
# initial deploy
# 6b7b403
"""Phase 3 training loop (prefix projector only)."""
from __future__ import annotations
import argparse
import math
import sys
import time
from pathlib import Path
from typing import Dict, List
import torch
import torch.nn.functional as F
from torch.optim import AdamW
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT = _SCRIPT_DIR.parent

# Put this script's own directory at the front of the import path so the
# sibling modules imported below resolve even when the script is launched
# from a different working directory.
_script_dir_entry = str(_SCRIPT_DIR)
if _script_dir_entry not in sys.path:
    sys.path.insert(0, _script_dir_entry)
del _script_dir_entry
from caption_dataloader import build_caption_dataloaders # noqa: E402
from prefix_projector import ( # noqa: E402
clap_text_for_prefix_projector,
load_phase3_components,
phase3_prefix_lm_loss,
)
from tokenizer import ID2TOKEN, PHRASE_START # noqa: E402
def _pick_device() -> torch.device:
    """Return the preferred torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        kind = "cuda"
    else:
        # Older torch builds have no ``torch.backends.mps`` at all.
        mps_backend = getattr(torch.backends, "mps", None)
        mps_ready = mps_backend is not None and mps_backend.is_available()
        kind = "mps" if mps_ready else "cpu"
    return torch.device(kind)
def parse_args() -> argparse.Namespace:
    """Build the Phase 3 command-line interface and parse ``sys.argv``.

    Path defaults (checkpoints, captions JSONL, results directory) are
    anchored at the repository root ``_ROOT``; every other option is a
    scalar with a fixed default. Options are registered in a fixed order so
    ``--help`` output is stable.
    """
    results_root = _ROOT / "results"
    parser = argparse.ArgumentParser(
        description="Phase 3 training: prefix projector only"
    )

    # --- model checkpoints / prefix size ---------------------------------
    parser.add_argument(
        "--midi-checkpoint",
        type=str,
        default=str(results_root / "checkpoints" / "best_model.pt"),
    )
    parser.add_argument(
        "--clap-checkpoint",
        type=str,
        default=str(results_root / "checkpoints_contrastive" / "clap_best.pt"),
    )
    parser.add_argument("--n-prefix-tokens", type=int, default=8)

    # --- data --------------------------------------------------------------
    parser.add_argument(
        "--captions-jsonl",
        type=str,
        default=str(_ROOT / "data" / "captions_llm.jsonl"),
    )
    data_options = (
        ("--batch-size", int, 64),
        ("--max-seq-len", int, 512),
        ("--split-ratio", float, 0.95),
        ("--num-workers", int, 4),
        ("--seed", int, 17),
    )
    for flag, value_type, default in data_options:
        parser.add_argument(flag, type=value_type, default=default)
    parser.add_argument("--results-dir", type=str, default=str(results_root))

    # --- optimisation, schedule, regularisation, diagnostics -------------
    train_options = (
        ("--epochs", int, 20),
        ("--lr", float, 1e-4),
        ("--weight-decay", float, 0.01),
        ("--grad-clip-norm", float, 1.0),
        ("--warmup-steps", int, 100),
        ("--min-lr-scale", float, 0.01),
        ("--prefix-attn-reg-weight", float, 0.0),
        ("--prefix-attn-min-mean", float, 0.05),
        ("--qualitative-every", int, 5),
        ("--qual-gen-tokens", int, 40),
    )
    for flag, value_type, default in train_options:
        parser.add_argument(flag, type=value_type, default=default)

    # A fresh default list is built on every call, so it is never shared
    # between parses.
    parser.add_argument(
        "--qual-prompts",
        nargs="+",
        default=[
            "A fast bright piano étude with rising melodic contour.",
            "A syncopated jazz combo with saxophone and walking bass.",
            "An ambient electronic piece with sustained synth pads.",
        ],
    )
    return parser.parse_args()
def _set_warmup_cosine_lr(
    optimizer: AdamW,
    step: int,
    total_steps: int,
    warmup_steps: int,
    base_lr: float,
    min_lr_scale: float,
) -> None:
    """Set the optimizer LR for ``step`` of a linear-warmup + cosine schedule.

    For ``step < warmup_steps`` the multiplier ramps linearly as
    ``(step + 1) / warmup_steps``. After warmup it follows a half-cosine from
    1.0 at ``warmup_steps`` down to ``min_lr_scale`` at ``total_steps``,
    clamped beyond that. If ``total_steps <= warmup_steps`` the post-warmup
    multiplier stays at 1.0.

    Args:
        optimizer: Optimizer whose parameter groups are updated in place.
        step: Zero-based global step index.
        total_steps: Step at which the cosine reaches its floor.
        warmup_steps: Number of linear-warmup steps (0 disables warmup).
        base_lr: Peak learning rate.
        min_lr_scale: Final LR as a fraction of ``base_lr``.
    """
    if step < warmup_steps:
        mult = float(step + 1) / float(max(1, warmup_steps))
    elif total_steps <= warmup_steps:
        mult = 1.0
    else:
        progress = (step - warmup_steps) / float(total_steps - warmup_steps)
        progress = min(1.0, max(0.0, progress))
        cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
        mult = min_lr_scale + (1.0 - min_lr_scale) * cosine
    lr = base_lr * mult
    # Update every parameter group; writing only param_groups[0] would leave
    # any additional groups stuck at a stale learning rate.
    for group in optimizer.param_groups:
        group["lr"] = lr
@torch.no_grad()
def _eval_loss(
    clap_model,
    midi_gpt,
    projector,
    loader,
    device: torch.device,
    prefix_attn_reg_weight: float,
    prefix_attn_min_mean: float,
) -> float:
    """Return the mean per-batch Phase 3 loss over ``loader`` (0.0 if empty).

    Each batch's ``input_ids`` and ``captions`` are scored with
    ``phase3_prefix_lm_loss``. The projector runs in eval mode for the pass
    and is restored to whatever mode it was in beforehand, even if a batch
    raises.

    NOTE(review): the regulariser weight is forwarded here. If
    ``phase3_prefix_lm_loss`` folds that term into the returned loss, this is
    not a pure LM loss — confirm before comparing it with a no-prefix
    baseline.
    """
    was_training = projector.training
    projector.eval()
    total = 0.0
    n = 0
    try:
        for batch in loader:
            loss, _ = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=batch["input_ids"].to(device),
                captions=batch["captions"],
                prefix_attn_reg_weight=prefix_attn_reg_weight,
                prefix_attn_min_mean=prefix_attn_min_mean,
            )
            total += float(loss.item())
            n += 1
    finally:
        # Restore the caller's mode instead of forcing train mode.
        projector.train(was_training)
    # Unweighted mean of batch means; a short final batch counts fully.
    return total / max(1, n)
def _lm_loss_without_prefix(midi_gpt, input_ids: torch.Tensor) -> torch.Tensor:
    """Next-token cross-entropy of ``midi_gpt`` on ``input_ids`` with no prefix.

    Position ``t`` of the model output is scored against token ``t + 1``, and
    the result is averaged over every predicted position in the batch.
    """
    logits = midi_gpt(input_ids)
    vocab_size = logits.size(-1)
    # Drop the last position (nothing to predict) and the first target.
    predicted = logits[:, :-1, :].reshape(-1, vocab_size)
    targets = input_ids[:, 1:].reshape(-1)
    return F.cross_entropy(predicted, targets)
def _infer_genre_label(caption: str) -> str:
    """Map a caption to a coarse genre bucket by case-insensitive keyword match.

    Rules are checked in priority order (jazz, electronic, classical, rock),
    so e.g. "big band swing" is jazz even though "band" is a rock keyword.
    Matching is plain substring search. Returns "other" when nothing matches.
    """
    lowered = caption.lower()
    rules = (
        ("jazz", ("jazz", "swing", "bebop")),
        ("electronic", ("electronic", "synth", "edm")),
        ("classical", ("classical", "orchestral", "baroque")),
        ("rock", ("rock", "guitar", "band")),
    )
    for label, keywords in rules:
        if any(word in lowered for word in keywords):
            return label
    return "other"
@torch.no_grad()
def _conditional_perplexity_gap_by_genre(
    clap_model,
    midi_gpt,
    projector,
    loader,
    device: torch.device,
    max_examples: int = 200,
) -> Dict[str, float]:
    """Per-genre perplexity gap (with prefix minus without) on up to ``max_examples``.

    Each example is scored individually with ``phase3_prefix_lm_loss`` (its
    own caption as prefix) and with ``_lm_loss_without_prefix``. It is then
    bucketed by ``_infer_genre_label`` on its caption. For each genre the
    result is ``exp(mean_loss_with) - exp(mean_loss_without)``, so negative
    values mean the prefix helps. Genres with no examples are absent; an
    empty loader or ``max_examples <= 0`` yields ``{}``.

    The projector runs in eval mode and is restored to its previous mode
    afterwards, even if scoring raises.
    """
    was_training = projector.training
    projector.eval()
    sums_with: Dict[str, float] = {}
    sums_without: Dict[str, float] = {}
    counts: Dict[str, int] = {}
    seen = 0
    try:
        for batch in loader:
            if seen >= max_examples:
                break
            input_ids = batch["input_ids"].to(device)
            captions = batch["captions"]
            for i in range(input_ids.size(0)):
                if seen >= max_examples:
                    break
                x = input_ids[i : i + 1]
                cap = [captions[i]]
                loss_with, _ = phase3_prefix_lm_loss(
                    clap_model=clap_model,
                    midi_gpt=midi_gpt,
                    prefix_projector=projector,
                    input_ids=x,
                    captions=cap,
                )
                loss_without = _lm_loss_without_prefix(midi_gpt=midi_gpt, input_ids=x)
                genre = _infer_genre_label(cap[0])
                sums_with[genre] = sums_with.get(genre, 0.0) + float(loss_with.item())
                sums_without[genre] = sums_without.get(genre, 0.0) + float(
                    loss_without.item()
                )
                counts[genre] = counts.get(genre, 0) + 1
                seen += 1
    finally:
        # Restore the caller's mode instead of forcing train mode.
        projector.train(was_training)
    return {
        genre: math.exp(sums_with[genre] / n) - math.exp(sums_without[genre] / n)
        for genre, n in counts.items()
    }
@torch.no_grad()
def _generate_unconditional(midi_gpt, gen_tokens: int, device: torch.device) -> List[int]:
    """Sample ``gen_tokens`` tokens from ``midi_gpt`` after PHRASE_START, no prefix.

    Uses plain multinomial sampling from the softmax of the last position's
    logits and re-feeds the whole sequence each step. The returned id list
    includes the leading PHRASE_START.
    """
    tokens = torch.full((1, 1), PHRASE_START, dtype=torch.long, device=device)
    for _step in range(gen_tokens):
        last_logits = midi_gpt(tokens)[:, -1, :]
        distribution = F.softmax(last_logits, dim=-1)
        sampled = torch.multinomial(distribution, num_samples=1)
        tokens = torch.cat((tokens, sampled), dim=1)
    return tokens[0].tolist()
@torch.no_grad()
def _generate_with_text_prefix(
    clap_model,
    midi_gpt,
    projector,
    text_prompt: str,
    gen_tokens: int,
    device: torch.device,
) -> List[int]:
    """Sample ``gen_tokens`` tokens after PHRASE_START, conditioned on ``text_prompt``.

    Diagnostic-only helper for qualitative checks during training. It re-runs
    the full prefix+GPT forward for every token (O(n^2)); production
    inference should use cached decoding in generate_conditional.py.

    The projector is put in eval mode while sampling, so the prefix is
    deterministic for a given prompt. Its previous mode is restored
    afterwards. The returned id list includes the leading PHRASE_START.
    """
    was_training = projector.training
    projector.eval()
    ids: List[int] = [PHRASE_START]
    try:
        for _ in range(gen_tokens):
            x = torch.tensor([ids], dtype=torch.long, device=device)
            _loss, logits_full = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=x,
                captions=[text_prompt],
            )
            # NOTE(review): assumes the last position of logits_full predicts
            # the token following ``ids`` — confirm against
            # phase3_prefix_lm_loss's output layout.
            probs = F.softmax(logits_full[:, -1, :], dim=-1)
            nxt = torch.multinomial(probs, num_samples=1)
            ids.append(int(nxt.item()))
    finally:
        projector.train(was_training)
    return ids
def _token_preview(ids: List[int], max_len: int = 40) -> str:
    """Render the first ``max_len`` token ids as space-separated token names.

    Ids missing from ID2TOKEN appear as ``UNK(<id>)``; " ..." is appended
    when the sequence was truncated.
    """
    shown = ids[:max_len]
    names = (ID2TOKEN.get(tok_id, f"UNK({tok_id})") for tok_id in shown)
    text = " ".join(names)
    if len(ids) > max_len:
        text += " ..."
    return text
@torch.no_grad()
def _prefix_token_scale_diagnostics(
    clap_model,
    midi_gpt,
    projector,
    batch,
    device: torch.device,
) -> None:
    """Print mean L2 norms of projector prefix vectors vs. GPT token embeddings.

    Captions are embedded with CLAP and projected into prefix vectors.
    ``input_ids`` are embedded with ``midi_gpt.wte``. The function prints both
    mean per-vector norms and their ratio, and warns when the ratio falls
    outside [0.1, 10].

    The projector runs in eval mode so train-time stochastic layers (if any)
    do not distort the measured norm. Its previous mode is restored
    afterwards.

    NOTE(review): the token norm averages over every position of
    ``input_ids``, including any padding — confirm whether batches are padded.
    """
    x = batch["input_ids"].to(device)
    caps = batch["captions"]
    was_training = projector.training
    projector.eval()
    try:
        text_emb = clap_text_for_prefix_projector(clap_model, caps, device)
        prefix = projector(text_emb)
    finally:
        projector.train(was_training)
    token = midi_gpt.wte(x)
    pnorm = float(prefix.norm(dim=-1).mean().item())
    tnorm = float(token.norm(dim=-1).mean().item())
    # Guard against division by an all-zero embedding table.
    ratio = pnorm / max(1e-8, tnorm)
    print(
        "[phase3][scale] prefix_norm="
        f"{pnorm:.4f} token_norm={tnorm:.4f} ratio={ratio:.3f}"
    )
    if ratio > 10.0 or ratio < 0.1:
        print(
            "[phase3][scale][warn] prefix/token norm mismatch is large."
        )
@torch.no_grad()
def _verify_prefix_usage(
    clap_model,
    midi_gpt,
    projector,
    batch,
    device: torch.device,
) -> None:
    """Check loss is lower with correct caption prefix than random wrong one.

    The first sequence of ``batch`` is scored twice: under its own caption,
    and under the first later caption in the batch that differs from it.
    Both losses and ``delta = wrong - correct`` are printed, with a warning
    when they are nearly identical (the prefix may be ignored).

    The check is skipped, with a message, when the batch has fewer than two
    sequences or contains no distinct caption. The projector runs in eval
    mode so train-time stochastic layers (if any) do not add noise to the
    comparison. Its previous mode is restored afterwards.
    """
    input_ids = batch["input_ids"].to(device)
    captions = batch["captions"]
    if input_ids.size(0) < 2:
        print("[phase3][verify1] skipped: need batch size >= 2.")
        return
    x = input_ids[0:1]
    correct_text = captions[0]
    # Blindly using captions[1] gives a spurious "prefix ignored" warning when
    # the batch holds duplicate captions, so pick a genuinely different one.
    wrong_text = next((c for c in captions[1:] if c != correct_text), None)
    if wrong_text is None:
        print("[phase3][verify1] skipped: no distinct caption in batch.")
        return
    was_training = projector.training
    projector.eval()
    try:
        loss_correct, _ = phase3_prefix_lm_loss(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            prefix_projector=projector,
            input_ids=x,
            captions=[correct_text],
        )
        loss_wrong, _ = phase3_prefix_lm_loss(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            prefix_projector=projector,
            input_ids=x,
            captions=[wrong_text],
        )
    finally:
        projector.train(was_training)
    delta = float(loss_wrong.item() - loss_correct.item())
    print(
        "[phase3][verify1] loss(correct_prefix)="
        f"{loss_correct.item():.4f} loss(wrong_prefix)={loss_wrong.item():.4f} "
        f"delta(wrong-correct)={delta:+.4f}"
    )
    if abs(delta) < 1e-4:
        print(
            "[phase3][verify1][warn] losses are almost identical; "
            "prefix may be ignored."
        )
def main() -> None:
    """Entry point for Phase 3: train only the prefix projector.

    Setup:
      * Pick a device and load CLAP, MIDI GPT and the projector via
        ``load_phase3_components``.
      * Build the caption dataloaders.
      * Check that batches carry ``input_ids`` and that no CLAP/GPT
        parameter requires grad.

    Training runs ``args.epochs`` epochs of AdamW on the projector under a
    warmup+cosine LR schedule. After each epoch it:
      * reports val loss, a no-prefix baseline and per-genre perplexity gaps;
      * saves ``prefix_projector_latest.pt`` (and ``prefix_projector_best.pt``
        when val loss improves) under ``<results_dir>/checkpoints_prefix``;
      * prints prefix/token norm diagnostics;
      * every ``--qualitative-every`` epochs (non-positive disables), prints
        sampled generations.
    """
    args = parse_args()
    device = _pick_device()
    print(f"[phase3] device={device}")
    clap_model, midi_gpt, projector, counts = load_phase3_components(
        midi_checkpoint=args.midi_checkpoint,
        clap_checkpoint=args.clap_checkpoint,
        n_prefix_tokens=args.n_prefix_tokens,
        device=device,
    )
    # Phase 3 uses the exact same dataset setup as Phase 2.
    train_loader, val_loader, stats = build_caption_dataloaders(
        jsonl_path=args.captions_jsonl,
        max_seq_len=args.max_seq_len,
        batch_size=args.batch_size,
        split_ratio=args.split_ratio,
        seed=args.seed,
        num_workers=args.num_workers,
    )
    # Pull one batch up front to validate the dataloader contract.
    batch = next(iter(train_loader))
    if "input_ids" not in batch:
        raise RuntimeError(
            "Phase 3 requires input_ids from the dataloader to build LM labels."
        )
    # Actually check the freeze policy rather than just announcing it: only
    # the projector is meant to receive gradients in Phase 3.
    # NOTE(review): assumes clap_model and midi_gpt expose .parameters()
    # (torch.nn.Module) — confirm against load_phase3_components.
    n_unfrozen = sum(
        p.numel()
        for module in (clap_model, midi_gpt)
        for p in module.parameters()
        if p.requires_grad
    )
    if n_unfrozen:
        print(
            "[phase3][warn] freeze policy violated: "
            f"{n_unfrozen:,} CLAP/GPT params require grad."
        )
    else:
        print("[phase3] freeze policy check passed.")
    print(
        "[phase3] dataset total/train/val="
        f"{stats.n_total_records}/{stats.n_train_records}/{stats.n_val_records}"
    )
    print(
        "[phase3] dataloader check passed: input_ids shape="
        f"{tuple(batch['input_ids'].shape)}"
    )
    print(f"[phase3] CLAP params (frozen): {counts.n_clap_params:,}")
    print(f"[phase3] GPT params (frozen): {counts.n_gpt_params:,}")
    print(f"[phase3] projector params: {counts.n_projector_params:,}")
    print(f"[phase3] total trainable: {counts.n_total_trainable:,}")
    optimizer = AdamW(
        projector.parameters(),
        lr=args.lr,
        weight_decay=args.weight_decay,
    )
    steps_per_epoch = len(train_loader)
    total_steps = max(1, args.epochs * steps_per_epoch)
    print(
        f"[phase3] epochs={args.epochs} steps_per_epoch={steps_per_epoch} "
        f"total_steps={total_steps} warmup_steps={args.warmup_steps}"
    )
    # parse_args always defines --results-dir, so no hasattr fallback is
    # needed; the path is loop-invariant, so compute it once.
    ckpt_dir = Path(args.results_dir) / "checkpoints_prefix"
    global_step = 0
    best_val = float("inf")
    t0 = time.perf_counter()
    for epoch in range(1, args.epochs + 1):
        train_loss_sum = 0.0
        n_train = 0
        # Both sanity checks run once, on the first batch of every epoch.
        verify1_done = False
        verify2_done = False
        for batch in train_loader:
            if not verify1_done:
                _verify_prefix_usage(
                    clap_model=clap_model,
                    midi_gpt=midi_gpt,
                    projector=projector,
                    batch=batch,
                    device=device,
                )
                verify1_done = True
            _set_warmup_cosine_lr(
                optimizer=optimizer,
                step=global_step,
                total_steps=total_steps,
                warmup_steps=args.warmup_steps,
                base_lr=args.lr,
                min_lr_scale=args.min_lr_scale,
            )
            optimizer.zero_grad(set_to_none=True)
            loss, _ = phase3_prefix_lm_loss(
                clap_model=clap_model,
                midi_gpt=midi_gpt,
                prefix_projector=projector,
                input_ids=batch["input_ids"].to(device),
                captions=batch["captions"],
                prefix_attn_reg_weight=args.prefix_attn_reg_weight,
                prefix_attn_min_mean=args.prefix_attn_min_mean,
            )
            loss.backward()
            if not verify2_done:
                # Confirm gradients actually reach the projector's fc2 layer.
                grad = projector.fc2.weight.grad
                if grad is None:
                    print(
                        "[phase3][verify2][warn] projector.fc2.weight.grad is None."
                    )
                else:
                    grad_norm = float(grad.norm().item())
                    grad_abs = float(grad.abs().sum().item())
                    print(
                        "[phase3][verify2] projector.fc2.weight grad_norm="
                        f"{grad_norm:.6f} grad_abs_sum={grad_abs:.6f}"
                    )
                    if grad_abs == 0.0:
                        print(
                            "[phase3][verify2][warn] projector gradient is all zeros."
                        )
                verify2_done = True
            torch.nn.utils.clip_grad_norm_(
                projector.parameters(), args.grad_clip_norm
            )
            optimizer.step()
            train_loss_sum += float(loss.item())
            n_train += 1
            global_step += 1
        train_loss = train_loss_sum / max(1, n_train)
        val_loss = _eval_loss(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            loader=val_loader,
            device=device,
            prefix_attn_reg_weight=args.prefix_attn_reg_weight,
            prefix_attn_min_mean=args.prefix_attn_min_mean,
        )
        # Reference: the same validation set scored by the GPT with no prefix.
        baseline_sum = 0.0
        baseline_n = 0
        for vbatch in val_loader:
            baseline_sum += float(
                _lm_loss_without_prefix(
                    midi_gpt=midi_gpt,
                    input_ids=vbatch["input_ids"].to(device),
                ).item()
            )
            baseline_n += 1
        baseline_val = baseline_sum / max(1, baseline_n)
        ppl_gap = math.exp(val_loss) - math.exp(baseline_val)
        genre_gap = _conditional_perplexity_gap_by_genre(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            loader=val_loader,
            device=device,
            max_examples=200,
        )
        current_lr = optimizer.param_groups[0]["lr"]
        print(
            f"[phase3] epoch={epoch}/{args.epochs} "
            f"train_loss={train_loss:.4f} val_loss={val_loss:.4f} "
            f"baseline_val={baseline_val:.4f} ppl_gap={ppl_gap:+.3f} "
            f"lr={current_lr:.2e}"
        )
        if genre_gap:
            parts = " ".join(f"{k}:{v:+.3f}" for k, v in sorted(genre_gap.items()))
            print(f"[phase3] genre ppl_gap(with-without): {parts}")
        if epoch >= 10 and val_loss >= baseline_val:
            print("[phase3][warn] prefix loss not below no-prefix baseline.")
        ckpt_dir.mkdir(parents=True, exist_ok=True)
        ckpt = {
            "projector_state_dict": projector.state_dict(),
            "epoch": epoch,
            "val_loss": val_loss,
            "args": vars(args),
        }
        torch.save(ckpt, ckpt_dir / "prefix_projector_latest.pt")
        if val_loss < best_val:
            best_val = val_loss
            torch.save(ckpt, ckpt_dir / "prefix_projector_best.pt")
        # `batch` is this epoch's last training batch (or the probe batch
        # fetched above if the loader yielded nothing).
        _prefix_token_scale_diagnostics(
            clap_model=clap_model,
            midi_gpt=midi_gpt,
            projector=projector,
            batch=batch,
            device=device,
        )
        # Guard the modulo: --qualitative-every 0 would raise ZeroDivisionError.
        if args.qualitative_every > 0 and epoch % args.qualitative_every == 0:
            print("\n[phase3] qualitative generation check")
            uncond = _generate_unconditional(
                midi_gpt=midi_gpt,
                gen_tokens=args.qual_gen_tokens,
                device=device,
            )
            print(f"  [unconditional] {_token_preview(uncond)}")
            for prompt in args.qual_prompts:
                cond = _generate_with_text_prefix(
                    clap_model=clap_model,
                    midi_gpt=midi_gpt,
                    projector=projector,
                    text_prompt=prompt,
                    gen_tokens=args.qual_gen_tokens,
                    device=device,
                )
                print(f"  [prompt] {prompt}")
                print(f"    {_token_preview(cond)}")
    elapsed = time.perf_counter() - t0
    print(
        f"[phase3] finished in {elapsed/60:.1f} min, best_val={best_val:.4f}"
    )
if __name__ == "__main__":
    # Script entry point: parse CLI arguments and run Phase 3 training.
    main()