Spaces:
Sleeping
Sleeping
| """Phase 4f demo artifact builder: 8 prompts x 2 seeds + retrieval report.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import torch | |
| import torch.nn.functional as F | |
| from generate_conditional import ( | |
| EOS, | |
| autoregressive_decode, | |
| encode_text_prompt, | |
| make_initial_context, | |
| project_prefix, | |
| save_and_verify_midi, | |
| truncate_to_last_boundary, | |
| ) | |
| from inference_pipeline import ( | |
| _pick_device, | |
| load_clap, | |
| load_midi_gpt, | |
| load_prefix_projector, | |
| ) | |
# The eight demo text prompts. main() generates two seeded takes for each one,
# and the same list is embedded once to form the retrieval candidate set that
# every generated take is ranked against.
DEMO_PROMPTS: List[str] = [
    "a slow, melancholic piano piece in a minor key with sparse, flowing notes",
    "an upbeat rock band with electric guitar, bass, and drums in a major key",
    "a fast jazz trio with saxophone, piano, and bass - syncopated and energetic",
    "a soft, ambient electronic piece with synthesizer pads, slow and atmospheric",
    "a loud, driving hard rock piece with heavy guitar and a dense, busy texture",
    "a gentle classical piece for piano and strings, moderate tempo, expressive",
    "a funky horn-driven arrangement with brass, guitar, and bass with strong rhythm",
    "a sparse solo acoustic guitar piece, fingerpicked, quiet and introspective",
]
def _fixed_window(
    ids: List[int], max_seq_len: int, pad_id: int = 0
) -> tuple[torch.Tensor, torch.Tensor]:
    """Truncate or right-pad ``ids`` to exactly ``max_seq_len`` tokens.

    Args:
        ids: Token ids; any indexable sequence of ints (list, tuple, 1-D
            tensor). Only the first ``max_seq_len`` entries are kept.
        max_seq_len: Target window length; must be positive.
        pad_id: Id written into padding positions (default 0).

    Returns:
        ``(input_ids, attention_mask)``, both ``torch.long`` tensors of shape
        ``(1, max_seq_len)``. The mask is 1 for positions copied from ``ids``
        and 0 for padding.

    Raises:
        ValueError: if ``max_seq_len`` is not positive. A negative value would
            otherwise slice tokens off the *end* of ``ids`` instead of
            producing a fixed-size window.
    """
    if max_seq_len <= 0:
        raise ValueError(f"max_seq_len must be positive, got {max_seq_len}")
    # int() normalises tensor/numpy scalars so torch.tensor gets plain ints.
    window = [int(tok) for tok in ids[:max_seq_len]]
    n_pad = max_seq_len - len(window)
    input_ids = torch.tensor([window + [pad_id] * n_pad], dtype=torch.long)
    mask = torch.tensor([[1] * len(window) + [0] * n_pad], dtype=torch.long)
    return input_ids, mask
def _write_markdown(
    path: Path,
    rows: List[Dict[str, Any]],
    n_candidates: int | None = None,
) -> None:
    """Write a per-prompt Markdown summary of the demo generations.

    Rows are grouped by ``prompt_idx``; prompts are emitted in ascending order
    and takes within a prompt in ascending ``take`` order. Each take line shows
    its MIDI file, retrieval rank, similarity and top-1 flag. A final line
    names the take selected for the demo: a top-1 take if one exists,
    otherwise (and among ties) the take with the highest
    ``correct_similarity``.

    Args:
        path: Destination file. It is written as UTF-8 and overwrites any
            existing content.
        rows: Result dicts with keys ``prompt_idx``, ``prompt``, ``take``,
            ``seed``, ``midi_path``, ``correct_rank``, ``correct_similarity``
            and ``is_top1``.
        n_candidates: Size of the retrieval candidate set, shown as the rank
            denominator. Defaults to the number of distinct prompts in
            ``rows``.
    """
    by_prompt: Dict[int, List[Dict[str, Any]]] = {}
    for row in rows:
        by_prompt.setdefault(int(row["prompt_idx"]), []).append(row)
    if n_candidates is None:
        n_candidates = len(by_prompt)

    lines: List[str] = ["# Phase 4 Demo Artifact", ""]
    for prompt_idx in sorted(by_prompt):
        entries = sorted(by_prompt[prompt_idx], key=lambda x: int(x["take"]))
        # Prefer a take retrieved at rank 1; break ties by higher similarity.
        best = max(
            entries,
            key=lambda x: (int(x["is_top1"]), float(x["correct_similarity"])),
        )
        lines.append(f"## Prompt {prompt_idx + 1}")
        lines.append(f"- Text: {entries[0]['prompt']}")
        for row in entries:
            lines.append(
                f"- Take {row['take']} (seed={row['seed']}): "
                f"file=`{row['midi_path']}` "
                f"rank={row['correct_rank']}/{n_candidates} "
                f"sim={float(row['correct_similarity']):.4f} "
                f"top1={bool(row['is_top1'])}"
            )
        lines.append(
            f"- Selected for demo: take {best['take']} (`{best['midi_path']}`)"
        )
        lines.append("")
    # Prompts are free text; write UTF-8 explicitly rather than relying on
    # the platform's default encoding.
    path.write_text("\n".join(lines), encoding="utf-8")
def parse_args() -> argparse.Namespace:
    """Read the demo builder's command-line options from ``sys.argv``.

    Empty-string checkpoint / output paths are resolved against
    ``--results-dir`` by ``main``. A non-positive ``--n-prefix-tokens`` is
    translated by ``main`` into "no override".
    """
    parser = argparse.ArgumentParser(
        description="Build Phase 4 batch demo artifact."
    )
    # (flag, type, default), registered in this order so --help is unchanged.
    option_table = (
        ("--results-dir", str, "results"),
        ("--midi-checkpoint", str, ""),
        ("--clap-checkpoint", str, ""),
        ("--prefix-checkpoint", str, ""),
        ("--out-dir", str, ""),
        ("--max-seq-len", int, 512),
        ("--max-new-tokens", int, 512),
        ("--temperature", float, 0.9),
        ("--top-k", int, 50),
        ("--top-p", float, 0.92),
        ("--repetition-penalty", float, 1.0),
        ("--repetition-window", int, 64),
        ("--n-prefix-tokens", int, 0),
        ("--seed-a", int, 17),
        ("--seed-b", int, 29),
    )
    for flag, caster, fallback in option_table:
        parser.add_argument(flag, type=caster, default=fallback)
    return parser.parse_args()
def _path_or_default(value: str, default: Path) -> Path:
    """Return ``Path(value)`` when ``value`` is non-empty, else ``default``."""
    return Path(value) if value else default


def main() -> None:
    """Generate two seeded takes per demo prompt and score them by retrieval.

    For each prompt in ``DEMO_PROMPTS`` and each of ``--seed-a`` and
    ``--seed-b``, this function:

    * builds a prefix from the CLAP text embedding via the prefix projector
      and decodes MIDI tokens with the GPT;
    * saves the MIDI file under ``<out_dir>/midi``;
    * embeds the tokens with CLAP's MIDI side and ranks the take against the
      embeddings of all demo prompts.

    Results are written to ``demo_phase4_results.json`` and
    ``demo_phase4_report.md`` in the output directory.

    Raises:
        ValueError: if the initial context length plus ``--max-new-tokens``
            exceeds the GPT block size.
    """
    args = parse_args()
    results_dir = Path(args.results_dir)
    midi_ckpt = _path_or_default(
        args.midi_checkpoint, results_dir / "checkpoints" / "best_model.pt"
    )
    clap_ckpt = _path_or_default(
        args.clap_checkpoint,
        results_dir / "checkpoints_contrastive" / "clap_best.pt",
    )
    prefix_ckpt = _path_or_default(
        args.prefix_checkpoint,
        results_dir / "checkpoints_prefix" / "prefix_projector_best.pt",
    )
    out_dir = _path_or_default(args.out_dir, results_dir / "demo_phase4")
    out_dir.mkdir(parents=True, exist_ok=True)
    midi_out_dir = out_dir / "midi"
    midi_out_dir.mkdir(parents=True, exist_ok=True)

    device = _pick_device()
    print(f"[demo4f] device={device}")

    midi_gpt, _ = load_midi_gpt(midi_ckpt, device=device)
    clap, _ = load_clap(clap_ckpt, midi_gpt=midi_gpt, device=device)
    # --n-prefix-tokens <= 0 means no override (None) is passed to the loader.
    override = args.n_prefix_tokens if args.n_prefix_tokens > 0 else None
    projector, _ = load_prefix_projector(
        prefix_ckpt,
        gpt_d_model=midi_gpt.config.d_model,
        device=device,
        n_prefix_tokens_override=override,
    )
    midi_gpt.eval()
    clap.eval()
    clap.text_encoder.eval()
    projector.eval()

    n_prompts = len(DEMO_PROMPTS)
    # Unit-normalised embeddings of every demo prompt: the candidate set each
    # generated take is ranked against.
    with torch.no_grad():
        prompt_text = clap.encode_text(DEMO_PROMPTS, device=device)
        prompt_embs = F.normalize(clap.text_projection(prompt_text), p=2, dim=-1)

    seeds = [args.seed_a, args.seed_b]
    if args.seed_a == args.seed_b:
        print("[demo4f] warning: --seed-a equals --seed-b; both takes share a seed")

    rows: List[Dict[str, Any]] = []
    for prompt_idx, prompt in enumerate(DEMO_PROMPTS):
        for take_idx, seed in enumerate(seeds, start=1):
            # Re-seed per take so every (prompt, take) pair is reproducible
            # independently of the ones generated before it.
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed_all(seed)
            with torch.no_grad():
                text_emb = encode_text_prompt(clap, prompt, device=device)
                prefix_embeds = project_prefix(projector, text_emb)
                inputs_embeds = make_initial_context(midi_gpt, prefix_embeds)
                max_required = inputs_embeds.size(1) + args.max_new_tokens
                if max_required > midi_gpt.config.block_size:
                    raise ValueError(
                        "Requested generation exceeds GPT block size: "
                        f"{max_required} > {midi_gpt.config.block_size}"
                    )
                generated_ids = autoregressive_decode(
                    midi_gpt=midi_gpt,
                    inputs_embeds=inputs_embeds,
                    max_new_tokens=args.max_new_tokens,
                    temperature=args.temperature,
                    top_k=args.top_k,
                    top_p=args.top_p,
                    repetition_penalty=args.repetition_penalty,
                    repetition_window=args.repetition_window,
                    eos_token_id=EOS,
                )
            generated_ids = truncate_to_last_boundary(generated_ids)
            midi_path = (
                midi_out_dir / f"prompt_{prompt_idx + 1:02d}_take_{take_idx}.mid"
            )
            n_notes, duration = save_and_verify_midi(generated_ids, midi_path)

            # Score the take: cosine similarity against every prompt.
            input_ids, attn_mask = _fixed_window(generated_ids, args.max_seq_len)
            with torch.no_grad():
                midi_feat = clap.encode_midi(
                    input_ids.to(device), attn_mask.to(device)
                )
                midi_emb = F.normalize(
                    clap.midi_projection(midi_feat), p=2, dim=-1
                )
                sims = (midi_emb @ prompt_embs.t()).squeeze(0)
            sorted_idx = torch.argsort(sims, descending=True)
            # 1-based position of the generating prompt in the ranking.
            rank = (
                int((sorted_idx == prompt_idx).nonzero(as_tuple=False)[0].item())
                + 1
            )
            best_idx = int(sorted_idx[0].item())
            rows.append(
                {
                    "prompt_idx": prompt_idx,
                    "prompt": prompt,
                    "take": take_idx,
                    "seed": seed,
                    "midi_path": str(midi_path),
                    "n_notes": n_notes,
                    "duration_sec": float(duration),
                    "correct_rank": rank,
                    "is_top1": rank == 1,
                    "correct_similarity": float(sims[prompt_idx].item()),
                    "top1_prompt_idx": best_idx,
                    "top1_similarity": float(sims[best_idx].item()),
                }
            )
            print(
                f"[demo4f] prompt={prompt_idx + 1} take={take_idx} seed={seed} "
                f"rank={rank}/{n_prompts} file={midi_path.name}"
            )

    json_path = out_dir / "demo_phase4_results.json"
    md_path = out_dir / "demo_phase4_report.md"
    json_path.write_text(
        json.dumps({"prompts": DEMO_PROMPTS, "rows": rows}, indent=2),
        encoding="utf-8",
    )
    _write_markdown(md_path, rows)
    print(f"[demo4f] wrote {json_path}")
    print(f"[demo4f] wrote {md_path}")
if __name__ == "__main__":
    # Run the demo builder only when executed as a script, not on import.
    main()