"""Phase 4b text-conditioned MIDI generation with soft prefixes.""" from __future__ import annotations import argparse from pathlib import Path from typing import List, Tuple import pretty_midi import torch import torch.nn.functional as F from inference_pipeline import ( # noqa: E402 _pick_device, load_clap, load_midi_gpt, load_prefix_projector, ) from prefix_projector import clap_text_for_prefix_projector # noqa: E402 from tokenizer import BAR_END, EOS, ID2TOKEN, PHRASE_END, PHRASE_START, decode def encode_text_prompt( clap, prompt: str, device: torch.device ) -> torch.Tensor: """Step 1: CLAP 256-d projected text emb (Phase 3 / contrastive space).""" return clap_text_for_prefix_projector(clap, [prompt], device) def project_prefix(projector, text_emb: torch.Tensor) -> torch.Tensor: """Step 2: deterministic prefix projection to (1, K, d_model).""" return projector(text_emb) def make_initial_context( midi_gpt, prefix_embeds: torch.Tensor ) -> torch.Tensor: """Step 3: BOS token embedding + prefix concat -> (1, K+1, d_model).""" bos = torch.tensor( [[PHRASE_START]], dtype=torch.long, device=prefix_embeds.device ) token_embeds = midi_gpt.wte(bos) return torch.cat([prefix_embeds, token_embeds], dim=1) def _sample_token( logits: torch.Tensor, temperature: float, top_k: int, top_p: float, repetition_penalty: float, recent_tokens: List[int], ) -> torch.Tensor: if temperature <= 0: raise ValueError("temperature must be > 0") if not 0.0 < top_p <= 1.0: raise ValueError("top_p must be in (0, 1].") logits = logits.clone() / temperature if repetition_penalty > 1.0 and recent_tokens: unique_recent = set(recent_tokens) idx = torch.tensor( list(unique_recent), dtype=torch.long, device=logits.device ) logits[:, idx] = logits[:, idx] / repetition_penalty if top_k > 0 and top_k < logits.size(-1): values, _ = torch.topk(logits, top_k) cutoff = values[:, -1].unsqueeze(-1) logits = logits.masked_fill(logits < cutoff, float("-inf")) if top_p < 1.0: sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1) sorted_probs = F.softmax(sorted_logits, dim=-1) cumprobs = torch.cumsum(sorted_probs, dim=-1) # Keep the smallest prefix with cumulative prob >= top_p. 
        remove = cumprobs > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        logits_filtered = torch.full_like(logits, float("-inf"))
        logits_filtered.scatter_(1, sorted_idx, sorted_logits)
        logits = logits_filtered
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def autoregressive_decode(
    midi_gpt,
    inputs_embeds: torch.Tensor,
    max_new_tokens: int,
    temperature: float,
    top_k: int,
    top_p: float,
    repetition_penalty: float,
    repetition_window: int,
    eos_token_id: int,
) -> List[int]:
    """Step 4: cached autoregressive decoding."""
    generated = [PHRASE_START]
    seq_len = inputs_embeds.size(1)
    position_ids = torch.arange(
        seq_len, device=inputs_embeds.device, dtype=torch.long
    ).unsqueeze(0)
    out = midi_gpt(
        inputs_embeds=inputs_embeds,
        position_ids=position_ids,
        use_cache=True,
    )
    logits, past_key_values = out
    next_token = _sample_token(
        logits=logits[:, -1, :],
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
        recent_tokens=generated[-repetition_window:],
    )
    token_id = int(next_token.item())
    generated.append(token_id)
    if token_id == eos_token_id:
        return generated
    cur_pos = seq_len
    for _ in range(max_new_tokens - 1):
        step_pos = torch.tensor(
            [[cur_pos]], device=inputs_embeds.device, dtype=torch.long
        )
        out = midi_gpt(
            idx=next_token,
            position_ids=step_pos,
            use_cache=True,
            past_key_values=past_key_values,
        )
        logits, past_key_values = out
        next_token = _sample_token(
            logits=logits[:, -1, :],
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            recent_tokens=generated[-repetition_window:],
        )
        token_id = int(next_token.item())
        generated.append(token_id)
        cur_pos += 1
        if token_id == eos_token_id:
            break
    return generated


def truncate_to_last_boundary(ids: List[int]) -> List[int]:
    """Step 5 helper: truncate malformed tails at safe structural boundary."""
    boundaries = {EOS, PHRASE_END, BAR_END}
    last = -1
    for i, tid in enumerate(ids):
        if tid in boundaries:
            last = i
    if last == -1:
        return ids
    return ids[: last + 1]


def save_and_verify_midi(ids: List[int], out_path: Path) -> Tuple[int, float]:
    """Step 6: write MIDI and verify it's readable/non-empty."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    decode(ids).write(str(out_path))
    pm = pretty_midi.PrettyMIDI(str(out_path))
    n_notes = sum(len(inst.notes) for inst in pm.instruments)
    duration = pm.get_end_time()
    if n_notes == 0 or duration < 1.0:
        raise RuntimeError(
            "Generated MIDI is empty or too short; "
            "check EOS behavior / decode logic."
        )
    return n_notes, duration


def _token_preview(ids: List[int], max_len: int = 60) -> str:
    toks = [ID2TOKEN.get(i, f"UNK({i})") for i in ids[:max_len]]
    suffix = " ..." if len(ids) > max_len else ""
    return " ".join(toks) + suffix


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="Generate MIDI from text prompt.")
    p.add_argument(
        "--midi-checkpoint",
        type=str,
        default="results/checkpoints/best_model.pt",
    )
    p.add_argument(
        "--clap-checkpoint",
        type=str,
        default="results/checkpoints_contrastive/clap_best.pt",
    )
    p.add_argument(
        "--prefix-checkpoint",
        type=str,
        default="results/checkpoints_prefix/prefix_projector_best.pt",
    )
    p.add_argument("--prompt", type=str, required=True)
    p.add_argument(
        "--out", type=str, default="results/conditional_generated.mid"
    )
    p.add_argument("--max-new-tokens", type=int, default=512)
    p.add_argument("--temperature", type=float, default=0.9)
    p.add_argument("--top-k", type=int, default=50)
    p.add_argument("--top-p", type=float, default=0.92)
    p.add_argument("--repetition-penalty", type=float, default=1.0)
    p.add_argument("--repetition-window", type=int, default=64)
    p.add_argument("--n-prefix-tokens", type=int, default=0)
    return p.parse_args()


def main() -> None:
    args = parse_args()
    device = _pick_device()
    print(f"[gen_cond] device={device}")
    midi_gpt, _ = load_midi_gpt(Path(args.midi_checkpoint), device=device)
    clap, _ = load_clap(
        Path(args.clap_checkpoint), midi_gpt=midi_gpt, device=device
    )
    override = None if args.n_prefix_tokens <= 0 else args.n_prefix_tokens
    projector, _ = load_prefix_projector(
        Path(args.prefix_checkpoint),
        gpt_d_model=midi_gpt.config.d_model,
        device=device,
        n_prefix_tokens_override=override,
    )
    with torch.no_grad():
        # Step 1
        text_emb = encode_text_prompt(clap, args.prompt, device=device)
        # Step 2
        prefix_embeds = project_prefix(projector, text_emb)
        # Step 3
        inputs_embeds = make_initial_context(midi_gpt, prefix_embeds)
        max_required = inputs_embeds.size(1) + args.max_new_tokens
        if max_required > midi_gpt.config.block_size:
            raise ValueError(
                "Requested generation exceeds GPT block size: "
                f"{max_required} > {midi_gpt.config.block_size}"
            )
        # Step 4
        generated_ids = autoregressive_decode(
            midi_gpt=midi_gpt,
            inputs_embeds=inputs_embeds,
            max_new_tokens=args.max_new_tokens,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
            repetition_penalty=args.repetition_penalty,
            repetition_window=args.repetition_window,
            eos_token_id=EOS,
        )
        # Step 5
        generated_ids = truncate_to_last_boundary(generated_ids)
        # Step 6
        out_path = Path(args.out)
        n_notes, duration = save_and_verify_midi(generated_ids, out_path)
        print(f"[gen_cond] output -> {out_path}")
        print(f"[gen_cond] notes={n_notes} duration={duration:.2f}s")
        print(f"[gen_cond] token preview: {_token_preview(generated_ids)}")


if __name__ == "__main__":
    main()
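
# Example invocation (a sketch; the module filename and prompt are
# illustrative, and the checkpoint paths are just the argparse defaults
# above):
#
#   python generate_conditional.py \
#       --prompt "calm solo piano, slow tempo" \
#       --max-new-tokens 512 \
#       --temperature 0.9 \
#       --top-p 0.92 \
#       --out results/conditional_generated.mid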