# Coda / src /generate_conditional_compound.py
# Prajanya Gupta
# initial deploy
# 6b7b403
"""Text-conditioned compound generation (Phase 4, compound path)."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import List, Sequence, Tuple
import pretty_midi
import torch
import torch.nn.functional as F
from compound import (
AXIS_NAMES,
SENTINELS,
STEP_BAR_END,
STEP_BOS,
STEP_CHORD_END,
STEP_EOS,
STEP_PB,
decode_compound,
)
from inference_pipeline import _pick_device
from prefix_projector import (
PrefixProjector,
clap_text_for_prefix_projector,
load_phase3_compound_components,
)
def _compound_step_embeds(midi_compound_gpt, step_ids: torch.Tensor) -> torch.Tensor:
    """Embed compound steps by summing one embedding lookup per axis.

    Args:
        midi_compound_gpt: Object exposing ``input_embeds`` (indexable by
            axis, each entry callable on an id tensor) and ``n_axes``.
        step_ids: Integer ids whose last dimension selects the axis;
            expected layout is ``(B, T, N_AXES)``.

    Returns:
        The element-wise sum of the per-axis embeddings, ``(B, T, d_model)``.
    """
    tables = midi_compound_gpt.input_embeds
    # Axis 0 seeds the running total; remaining axes are added in order.
    total = tables[0](step_ids[..., 0])
    for axis in range(1, midi_compound_gpt.n_axes):
        axis_ids = step_ids[..., axis]
        total = total + tables[axis](axis_ids)
    return total
def _sample_axis(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
) -> int:
    """Sample one index from ``logits`` after temperature, top-k and top-p filtering.

    Args:
        logits: Unnormalised scores for a single axis; treated as a 1-D
            tensor (the filters index and scatter along dimension 0).
        temperature: Divisor applied to the logits before filtering.
        top_k: Keep only entries at least as large as the k-th largest
            logit. Ignored when ``top_k <= 0`` or ``top_k >= logits.numel()``.
        top_p: Nucleus threshold in ``(0, 1]``; ``1.0`` disables the filter.

    Returns:
        The sampled index as a Python ``int``.

    Raises:
        ValueError: If ``temperature <= 0`` or ``top_p`` is outside ``(0, 1]``.
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in (0, 1].")

    neg_inf = float("-inf")
    # Division allocates a new tensor, so the caller's logits stay untouched.
    scaled = logits / temperature

    if 0 < top_k < scaled.numel():
        kth_best = torch.topk(scaled, top_k).values[-1]
        # Strict comparison: entries tied with the k-th value survive.
        scaled = scaled.masked_fill(scaled < kth_best, neg_inf)

    if top_p < 1.0:
        ranked_logits, ranked_idx = torch.sort(scaled, descending=True)
        cumulative = torch.cumsum(F.softmax(ranked_logits, dim=-1), dim=-1)
        over_budget = cumulative > top_p
        # Shift the mask right by one so the entry that crosses the
        # threshold is kept, and never drop the single best candidate.
        drop = torch.roll(over_budget, shifts=1, dims=0)
        drop[0] = False
        ranked_logits = ranked_logits.masked_fill(drop, neg_inf)
        # Undo the sort: write the filtered values back to their original slots.
        scaled = torch.full_like(scaled, neg_inf).scatter_(
            0, ranked_idx, ranked_logits
        )

    distribution = F.softmax(scaled, dim=-1)
    choice = torch.multinomial(distribution, num_samples=1)
    return int(choice.item())
@torch.no_grad()
def autoregressive_decode_compound(
    clap_model,
    midi_compound_gpt,
    prefix_projector: PrefixProjector,
    prompt: str,
    max_new_steps: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> List[List[int]]:
    """Sample a compound step sequence conditioned on a text prompt.

    The prompt is encoded with ``clap_text_for_prefix_projector`` and mapped
    by ``prefix_projector`` to prefix embeddings. Those are concatenated in
    front of the embedded steps generated so far, and the whole sequence is
    re-run through ``midi_compound_gpt`` at every iteration (no cache is
    used). Each axis is sampled independently from the logits at the last
    position.

    Args:
        clap_model: Text encoder handed to ``clap_text_for_prefix_projector``.
        midi_compound_gpt: Model called with ``inputs_embeds`` and
            ``position_ids``; it must yield one logits tensor per axis and
            expose ``config.block_size``.
        prefix_projector: Maps the text embedding to prefix embeddings; its
            parameters decide the device used for all tensors here.
        prompt: Text prompt to condition on.
        max_new_steps: Upper bound on the number of sampled steps.
        temperature: Sampling temperature, forwarded to ``_sample_axis``.
        top_k: Top-k cutoff, forwarded to ``_sample_axis``.
        top_p: Nucleus threshold, forwarded to ``_sample_axis``.

    Returns:
        The generated steps, starting with the BOS step. Generation stops
        after an EOS step, after ``max_new_steps`` steps, or once prefix plus
        steps would fill ``block_size``, whichever comes first.

    Raises:
        ValueError: If at least one step is requested but the prefix plus the
            BOS step already exceeds the model's block size.
    """
    device = next(prefix_projector.parameters()).device
    text_emb = clap_text_for_prefix_projector(clap_model, [prompt], device=device)
    prefix_embeds = prefix_projector(text_emb)  # (1, K, d_model)

    block_size = midi_compound_gpt.config.block_size
    prefix_len = prefix_embeds.size(1)

    # Check the context budget up front. Inside the loop this guard could
    # never fire, because the loop bound already caps the sequence length, so
    # an over-long prefix used to produce a silent BOS-only result.
    if max_new_steps > 0 and prefix_len + 1 > block_size:
        raise ValueError(
            "Requested generation exceeds CompoundGPT block size: "
            f"{prefix_len + 1} > {block_size}"
        )

    bos = list(SENTINELS)
    bos[0] = STEP_BOS
    generated_steps: List[List[int]] = [bos]
    # Materialise the non-step sentinel values as a list so building the EOS
    # step works whether SENTINELS is a list or a tuple.
    eos_tail = list(SENTINELS[1:])

    max_steps = block_size - prefix_len - 1  # -1 for BOS
    for _ in range(min(max_new_steps, max_steps)):
        step_ids = torch.tensor([generated_steps], dtype=torch.long, device=device)
        token_embeds = _compound_step_embeds(midi_compound_gpt, step_ids)
        inputs_embeds = torch.cat([prefix_embeds, token_embeds], dim=1)
        seq_len = inputs_embeds.size(1)
        position_ids = torch.arange(
            seq_len, device=device, dtype=torch.long
        ).unsqueeze(0)
        logits_per_axis = midi_compound_gpt(
            inputs_embeds=inputs_embeds,
            position_ids=position_ids,
        )
        # Only the final position predicts the next step.
        next_step = [
            _sample_axis(
                logits=axis_logits[0, -1, :],
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            for axis_logits in logits_per_axis
        ]
        # Keep EOS structurally valid: only step-axis carries EOS tag.
        if next_step[0] == STEP_EOS:
            generated_steps.append([STEP_EOS, *eos_tail])
            break
        generated_steps.append(next_step)
    return generated_steps
def _truncate_to_last_boundary(steps: Sequence[Sequence[int]]) -> List[List[int]]:
    """Drop everything after the last EOS, bar-end or chord-end step.

    Args:
        steps: Compound steps; element 0 of each step is its step tag.

    Returns:
        A fresh list of list-copies of the steps up to and including the
        last boundary step, or copies of all steps when there is no boundary.
    """
    boundary_tags = frozenset((STEP_EOS, STEP_BAR_END, STEP_CHORD_END))
    boundary_positions = [
        idx for idx, step in enumerate(steps) if int(step[0]) in boundary_tags
    ]
    if boundary_positions:
        kept = steps[: boundary_positions[-1] + 1]
    else:
        kept = steps
    return [list(step) for step in kept]
def save_and_verify_compound_midi(
    steps: Sequence[Sequence[int]], out_path: Path
) -> Tuple[int, float]:
    """Write the steps to a MIDI file and sanity-check the saved result.

    Steps whose tag equals ``STEP_PB`` are removed before decoding. The file
    is then read back with ``pretty_midi`` so the check runs on what was
    actually written to disk.

    Args:
        steps: Compound steps; element 0 of each step is its step tag.
        out_path: Destination file; missing parent directories are created.

    Returns:
        ``(n_notes, duration)``: the total note count over all instruments
        and the end time reported by ``pretty_midi``.

    Raises:
        RuntimeError: If the saved file has no notes or ends before 1.0.
    """
    out_path.parent.mkdir(parents=True, exist_ok=True)
    target = str(out_path)

    playable = [list(step) for step in steps if int(step[0]) != STEP_PB]
    midi = decode_compound(playable)
    midi.write(target)

    reloaded = pretty_midi.PrettyMIDI(target)
    n_notes = 0
    for instrument in reloaded.instruments:
        n_notes += len(instrument.notes)
    duration = reloaded.get_end_time()

    too_short = duration < 1.0
    if n_notes == 0 or too_short:
        raise RuntimeError(
            "Generated compound MIDI is empty or too short; "
            "check sampling/decoding behavior."
        )
    return n_notes, duration
def _step_preview(steps: Sequence[Sequence[int]], max_len: int = 20) -> str:
    """Render the first ``max_len`` steps as one human-readable line.

    Each step becomes ``name=value`` pairs (names taken from ``AXIS_NAMES``)
    joined by ``", "``; steps are joined by ``" | "``. A trailing ``" ..."``
    marks that more steps exist than were shown.
    """

    def _render(step: Sequence[int]) -> str:
        pairs = zip(AXIS_NAMES, step)
        return ", ".join(f"{name}={int(v)}" for name, v in pairs)

    shown = " | ".join([_render(step) for step in steps[:max_len]])
    if len(steps) > max_len:
        return shown + " ..."
    return shown
def parse_args(argv: "Sequence[str] | None" = None) -> argparse.Namespace:
    """Parse the command-line options for compound MIDI generation.

    Args:
        argv: Argument strings to parse. ``None`` (the default) lets argparse
            read ``sys.argv[1:]``, which is the original behaviour; passing
            an explicit list allows programmatic use and testing.

    Returns:
        The populated namespace. ``--prompt`` is required; every other option
        has a default.
    """
    p = argparse.ArgumentParser(
        description="Generate compound MIDI from text prompt."
    )
    p.add_argument(
        "--compound-midi-checkpoint",
        type=str,
        default="results/test_compound/checkpoints_compound/compound_best.pt",
        help="Path to the compound MIDI model checkpoint.",
    )
    p.add_argument(
        "--compound-clap-checkpoint",
        type=str,
        default="results/test_compound/checkpoints_contrastive_compound/clap_compound_best.pt",
        help="Path to the compound CLAP checkpoint.",
    )
    p.add_argument(
        "--prefix-checkpoint",
        type=str,
        default="results/test_compound/checkpoints_prefix/prefix_projector_best.pt",
        help="Path to the prefix projector checkpoint.",
    )
    p.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt to condition generation on.",
    )
    p.add_argument(
        "--out",
        type=str,
        default="results/test_compound/generated_compound_conditional.mid",
        help="Output MIDI file path.",
    )
    p.add_argument(
        "--max-new-steps",
        type=int,
        default=256,
        help="Maximum number of compound steps to sample.",
    )
    p.add_argument(
        "--temperature",
        type=float,
        default=0.9,
        help="Sampling temperature (must be > 0).",
    )
    p.add_argument(
        "--top-k",
        type=int,
        default=30,
        help="Top-k cutoff per axis; values <= 0 disable it.",
    )
    p.add_argument(
        "--top-p",
        type=float,
        default=0.95,
        help="Nucleus threshold in (0, 1]; 1.0 disables it.",
    )
    p.add_argument(
        "--n-prefix-tokens",
        type=int,
        default=8,
        help="Number of prefix tokens used when building the projector.",
    )
    return p.parse_args(argv)
def main() -> None:
    """Run text-conditioned compound generation from the command line.

    Loads the models through ``load_phase3_compound_components``, restores
    the prefix projector weights from ``--prefix-checkpoint``, samples a step
    sequence for ``--prompt``, truncates it to the last boundary step, writes
    the MIDI file and prints a short summary.
    """
    args = parse_args()
    device = _pick_device()
    print(f"[gen_cond_compound] device={device}")

    clap_model, midi_compound_gpt, projector, _ = load_phase3_compound_components(
        compound_midi_checkpoint=args.compound_midi_checkpoint,
        compound_clap_checkpoint=args.compound_clap_checkpoint,
        n_prefix_tokens=args.n_prefix_tokens,
        device=device,
    )

    prefix_ckpt = torch.load(
        Path(args.prefix_checkpoint), map_location=device, weights_only=True
    )
    # The checkpoint may wrap the weights under one of several keys, or be
    # the bare state dict itself.
    proj_state = None
    if isinstance(prefix_ckpt, dict):
        proj_state = (
            prefix_ckpt.get("projector_state_dict")
            or prefix_ckpt.get("model_state_dict")
            or prefix_ckpt.get("model")
        )
    if proj_state is None:
        proj_state = prefix_ckpt

    # strict=False tolerates key mismatches, so report them: otherwise a
    # checkpoint with the wrong layout leaves the projector (partly)
    # unloaded without any signal.
    load_result = projector.load_state_dict(proj_state, strict=False)
    missing = list(getattr(load_result, "missing_keys", ()))
    unexpected = list(getattr(load_result, "unexpected_keys", ()))
    if missing or unexpected:
        print(
            "[gen_cond_compound] warning: prefix checkpoint keys do not fully "
            f"match the projector (missing={missing}, unexpected={unexpected})"
        )

    projector.eval()
    midi_compound_gpt.eval()
    clap_model.eval()

    steps = autoregressive_decode_compound(
        clap_model=clap_model,
        midi_compound_gpt=midi_compound_gpt,
        prefix_projector=projector,
        prompt=args.prompt,
        max_new_steps=args.max_new_steps,
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
    )
    steps = _truncate_to_last_boundary(steps)

    out_path = Path(args.out)
    n_notes, duration = save_and_verify_compound_midi(steps, out_path)
    print(f"[gen_cond_compound] output -> {out_path}")
    print(f"[gen_cond_compound] notes={n_notes} duration={duration:.2f}s")
    print(f"[gen_cond_compound] steps={len(steps)}")
    print(f"[gen_cond_compound] step preview: {_step_preview(steps)}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()