# Coda / src / demo_phase4_batch.py
# Prajanya Gupta
# initial deploy
# 6b7b403
"""Phase 4f demo artifact builder: 8 prompts x 2 seeds + retrieval report."""
from __future__ import annotations

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence

import torch
import torch.nn.functional as F

from generate_conditional import (
    EOS,
    autoregressive_decode,
    encode_text_prompt,
    make_initial_context,
    project_prefix,
    save_and_verify_midi,
    truncate_to_last_boundary,
)
from inference_pipeline import (
    _pick_device,
    load_clap,
    load_midi_gpt,
    load_prefix_projector,
)
# Text prompts used for the demo batch. The list order defines the prompt
# index that generated takes are ranked against during retrieval scoring.
DEMO_PROMPTS: List[str] = [
    "a slow, melancholic piano piece in a minor key with sparse, flowing notes",
    "an upbeat rock band with electric guitar, bass, and drums in a major key",
    "a fast jazz trio with saxophone, piano, and bass - syncopated and energetic",
    "a soft, ambient electronic piece with synthesizer pads, slow and atmospheric",
    "a loud, driving hard rock piece with heavy guitar and a dense, busy texture",
    "a gentle classical piece for piano and strings, moderate tempo, expressive",
    "a funky horn-driven arrangement with brass, guitar, and bass with strong rhythm",
    "a sparse solo acoustic guitar piece, fingerpicked, quiet and introspective",
]
def _fixed_window(
    ids: List[int], max_seq_len: int, *, pad_id: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Clip or right-pad ``ids`` to exactly ``max_seq_len`` tokens.

    Args:
        ids: Token ids of a single sequence.
        max_seq_len: Width of the returned window; must not be negative.
        pad_id: Id written into the unused trailing slots. Defaults to 0,
            the value this helper previously hard-coded.

    Returns:
        ``(input_ids, attention_mask)``: two ``torch.long`` tensors of shape
        ``(1, max_seq_len)``. The mask holds 1 for real tokens and 0 for
        padding.

    Raises:
        ValueError: If ``max_seq_len`` is negative. A negative bound would
            otherwise slice from the tail of ``ids`` and return a window
            whose width is not ``max_seq_len``.
    """
    if max_seq_len < 0:
        raise ValueError(f"max_seq_len must be >= 0, got {max_seq_len}")

    # Copy into a list so tuples and other sequences can be padded too.
    window = list(ids[:max_seq_len])
    valid_len = len(window)
    n_pad = max_seq_len - valid_len

    input_ids = torch.tensor([window + [pad_id] * n_pad], dtype=torch.long)
    mask = torch.tensor([[1] * valid_len + [0] * n_pad], dtype=torch.long)
    return input_ids, mask
def _write_markdown(
    path: Path,
    rows: List[Dict[str, Any]],
    n_prompts: Optional[int] = None,
) -> None:
    """Write the per-prompt demo report as Markdown.

    Rows are grouped by ``prompt_idx``. Within each group the takes are
    listed in ascending ``take`` order, followed by the take selected for
    the demo: a top-1 take is preferred, then the highest
    ``correct_similarity``; on a full tie the lower take number wins.

    Args:
        path: Destination file; overwritten if it already exists.
        rows: One dict per generated take. Keys read here: ``prompt_idx``,
            ``prompt``, ``take``, ``seed``, ``midi_path``, ``correct_rank``,
            ``correct_similarity`` and ``is_top1``.
        n_prompts: Denominator shown next to each rank (``rank=R/N``).
            Defaults to ``len(DEMO_PROMPTS)`` so the report stays in step
            with the prompt list instead of a hard-coded 8.
    """
    if n_prompts is None:
        n_prompts = len(DEMO_PROMPTS)

    by_prompt: Dict[int, List[Dict[str, Any]]] = {}
    for row in rows:
        by_prompt.setdefault(int(row["prompt_idx"]), []).append(row)

    lines: List[str] = ["# Phase 4 Demo Artifact", ""]
    for prompt_idx in sorted(by_prompt):
        entries = sorted(by_prompt[prompt_idx], key=lambda r: int(r["take"]))
        # max() returns the first maximal element, so ties resolve to the
        # lowest take number because entries are already sorted by take.
        best = max(
            entries,
            key=lambda r: (int(r["is_top1"]), float(r["correct_similarity"])),
        )
        lines.append(f"## Prompt {prompt_idx + 1}")
        lines.append(f"- Text: {entries[0]['prompt']}")
        for row in entries:
            similarity = float(row["correct_similarity"])
            is_top1 = bool(row["is_top1"])
            lines.append(
                f"- Take {row['take']} (seed={row['seed']}): "
                f"file=`{row['midi_path']}` "
                f"rank={row['correct_rank']}/{n_prompts} "
                f"sim={similarity:.4f} top1={is_top1}"
            )
        lines.append(
            f"- Selected for demo: take {best['take']} (`{best['midi_path']}`)"
        )
        lines.append("")

    # Explicit encoding keeps the report identical across platforms.
    path.write_text("\n".join(lines), encoding="utf-8")
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of the demo builder.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is what this function
            always did; passing a list allows programmatic use.

    Returns:
        The populated ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(
        description="Build Phase 4 batch demo artifact."
    )

    # Paths. An empty string means "derive the location from --results-dir".
    p.add_argument(
        "--results-dir", type=str, default="results",
        help="Root directory for default checkpoint and output locations.",
    )
    p.add_argument(
        "--midi-checkpoint", type=str, default="",
        help="MIDI GPT checkpoint "
             "(default: <results-dir>/checkpoints/best_model.pt).",
    )
    p.add_argument(
        "--clap-checkpoint", type=str, default="",
        help="CLAP checkpoint "
             "(default: <results-dir>/checkpoints_contrastive/clap_best.pt).",
    )
    p.add_argument(
        "--prefix-checkpoint", type=str, default="",
        help="Prefix projector checkpoint (default: "
             "<results-dir>/checkpoints_prefix/prefix_projector_best.pt).",
    )
    p.add_argument(
        "--out-dir", type=str, default="",
        help="Output directory (default: <results-dir>/demo_phase4).",
    )

    # Sequence lengths.
    p.add_argument(
        "--max-seq-len", type=int, default=512,
        help="Token window length used when scoring generated MIDI.",
    )
    p.add_argument(
        "--max-new-tokens", type=int, default=512,
        help="Maximum number of tokens to generate per take.",
    )

    # Sampling options forwarded to the decoder.
    p.add_argument("--temperature", type=float, default=0.9,
                   help="Sampling temperature.")
    p.add_argument("--top-k", type=int, default=50,
                   help="Top-k sampling cutoff.")
    p.add_argument("--top-p", type=float, default=0.92,
                   help="Top-p sampling cutoff.")
    p.add_argument("--repetition-penalty", type=float, default=1.0,
                   help="Repetition penalty passed to the decoder.")
    p.add_argument("--repetition-window", type=int, default=64,
                   help="Repetition window passed to the decoder.")

    p.add_argument(
        "--n-prefix-tokens", type=int, default=0,
        help="Prefix token count override; values <= 0 disable the override.",
    )

    # One seed per take.
    p.add_argument("--seed-a", type=int, default=17,
                   help="RNG seed for take 1.")
    p.add_argument("--seed-b", type=int, default=29,
                   help="RNG seed for take 2.")
    return p.parse_args(argv)
def _path_or_default(cli_value: str, default: Path) -> Path:
    """Return ``Path(cli_value)`` when it is non-empty, otherwise ``default``."""
    return Path(cli_value) if cli_value else default


def main() -> None:
    """Generate two seeded takes per demo prompt and write the reports.

    For every prompt in ``DEMO_PROMPTS`` and each of the two seeds, a token
    sequence is decoded from the prompt-conditioned prefix and saved as a
    MIDI file. The tokens are then embedded with ``clap.encode_midi`` and
    ``clap.midi_projection`` and compared, by dot product of L2-normalised
    embeddings, against the text embeddings of all demo prompts. The rank of
    the originating prompt is recorded per take.

    Outputs, all under the output directory:
        ``midi/prompt_XX_take_N.mid``, ``demo_phase4_results.json`` and
        ``demo_phase4_report.md``.

    Raises:
        ValueError: If the initial context length plus ``--max-new-tokens``
            exceeds ``midi_gpt.config.block_size``.
    """
    args = parse_args()
    results_dir = Path(args.results_dir)

    # An empty CLI value selects the conventional location under results_dir.
    midi_ckpt = _path_or_default(
        args.midi_checkpoint,
        results_dir / "checkpoints" / "best_model.pt",
    )
    clap_ckpt = _path_or_default(
        args.clap_checkpoint,
        results_dir / "checkpoints_contrastive" / "clap_best.pt",
    )
    prefix_ckpt = _path_or_default(
        args.prefix_checkpoint,
        results_dir / "checkpoints_prefix" / "prefix_projector_best.pt",
    )
    out_dir = _path_or_default(args.out_dir, results_dir / "demo_phase4")
    out_dir.mkdir(parents=True, exist_ok=True)
    midi_out_dir = out_dir / "midi"
    midi_out_dir.mkdir(parents=True, exist_ok=True)

    device = _pick_device()
    print(f"[demo4f] device={device}")

    midi_gpt, _ = load_midi_gpt(midi_ckpt, device=device)
    clap, _ = load_clap(clap_ckpt, midi_gpt=midi_gpt, device=device)
    # A non-positive --n-prefix-tokens means "do not override".
    override = None if args.n_prefix_tokens <= 0 else args.n_prefix_tokens
    projector, _ = load_prefix_projector(
        prefix_ckpt,
        gpt_d_model=midi_gpt.config.d_model,
        device=device,
        n_prefix_tokens_override=override,
    )
    midi_gpt.eval()
    clap.eval()
    clap.text_encoder.eval()
    projector.eval()

    # Text embeddings of every demo prompt: the retrieval candidates.
    with torch.no_grad():
        prompt_text = clap.encode_text(DEMO_PROMPTS, device=device)
        prompt_embs = F.normalize(
            clap.text_projection(prompt_text), p=2, dim=-1
        )

    n_prompts = len(DEMO_PROMPTS)
    seeds = [args.seed_a, args.seed_b]
    rows: List[Dict[str, Any]] = []

    for prompt_idx, prompt in enumerate(DEMO_PROMPTS):
        for take_idx, seed in enumerate(seeds, start=1):
            # Re-seed before every take so each one is reproducible alone.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)

            with torch.no_grad():
                text_emb = encode_text_prompt(clap, prompt, device=device)
                prefix_embeds = project_prefix(projector, text_emb)
                inputs_embeds = make_initial_context(midi_gpt, prefix_embeds)

                max_required = inputs_embeds.size(1) + args.max_new_tokens
                if max_required > midi_gpt.config.block_size:
                    raise ValueError(
                        "Requested generation exceeds GPT block size: "
                        f"{max_required} > {midi_gpt.config.block_size}"
                    )

                generated_ids = autoregressive_decode(
                    midi_gpt=midi_gpt,
                    inputs_embeds=inputs_embeds,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    repetition_penalty=args.repetition_penalty,
                    repetition_window=args.repetition_window,
                    eos_token_id=EOS,
                )

            generated_ids = truncate_to_last_boundary(generated_ids)
            midi_path = (
                midi_out_dir
                / f"prompt_{prompt_idx + 1:02d}_take_{take_idx}.mid"
            )
            n_notes, duration = save_and_verify_midi(generated_ids, midi_path)

            # Embed the generated tokens and score them against all prompts.
            input_ids, attn_mask = _fixed_window(
                generated_ids, args.max_seq_len
            )
            input_ids = input_ids.to(device)
            attn_mask = attn_mask.to(device)
            with torch.no_grad():
                midi_feat = clap.encode_midi(input_ids, attn_mask)
                midi_emb = F.normalize(
                    clap.midi_projection(midi_feat), p=2, dim=-1
                )
                sims = (midi_emb @ prompt_embs.t()).squeeze(0)

            # 1-based position of the originating prompt in the ranking.
            sorted_idx = torch.argsort(sims, descending=True)
            rank = (
                int(
                    (sorted_idx == prompt_idx)
                    .nonzero(as_tuple=False)[0]
                    .item()
                )
                + 1
            )
            best_idx = int(sorted_idx[0].item())

            row: Dict[str, Any] = {
                "prompt_idx": prompt_idx,
                "prompt": prompt,
                "take": take_idx,
                "seed": seed,
                "midi_path": str(midi_path),
                "n_notes": n_notes,
                "duration_sec": float(duration),
                "correct_rank": rank,
                "is_top1": rank == 1,
                "correct_similarity": float(sims[prompt_idx].item()),
                "top1_prompt_idx": best_idx,
                "top1_similarity": float(sims[best_idx].item()),
            }
            rows.append(row)
            # Denominator follows the prompt list instead of a literal 8.
            print(
                f"[demo4f] prompt={prompt_idx + 1} take={take_idx} "
                f"seed={seed} rank={rank}/{n_prompts} file={midi_path.name}"
            )

    json_path = out_dir / "demo_phase4_results.json"
    md_path = out_dir / "demo_phase4_report.md"
    # Explicit encoding keeps the artifact identical across platforms.
    json_path.write_text(
        json.dumps({"prompts": DEMO_PROMPTS, "rows": rows}, indent=2),
        encoding="utf-8",
    )
    _write_markdown(md_path, rows)
    print(f"[demo4f] wrote {json_path}")
    print(f"[demo4f] wrote {md_path}")
# Run the demo builder only when executed as a script, not on import.
if __name__ == "__main__":
    main()