Coda / src /generate_compound_unconditional.py
Prajanya Gupta
initial deploy
6b7b403
"""Sample MIDI/WAV from a CompoundGPT checkpoint only (no CLAP / prefix weights)."""
from __future__ import annotations
import argparse
from pathlib import Path
from typing import List, Sequence
import numpy as np
import pretty_midi
import scipy.io.wavfile
import torch
import torch.nn.functional as F
from compound import (
SENTINELS,
STEP_BAR_END,
STEP_BOS,
STEP_CHORD_END,
STEP_EOS,
STEP_PB,
decode_compound,
)
from compound_model import CompoundGPT, CompoundGPTConfig, default_compound_config
from inference_pipeline import _pick_device
def _load_compound_gpt(ckpt_path: Path, device: torch.device) -> CompoundGPT:
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
cfg = default_compound_config()
raw_cfg = ckpt.get("config") if isinstance(ckpt, dict) else None
if isinstance(raw_cfg, dict):
for k in CompoundGPTConfig.__dataclass_fields__.keys():
if k in raw_cfg:
setattr(cfg, k, raw_cfg[k])
model = CompoundGPT(cfg).to(device)
state = ckpt.get("model_state_dict", ckpt)
model.load_state_dict(state, strict=False)
model.eval()
return model
def _sample_axis(
logits: torch.Tensor,
temperature: float,
top_k: int,
top_p: float,
) -> int:
if temperature <= 0:
raise ValueError("temperature must be > 0")
if not 0.0 < top_p <= 1.0:
raise ValueError("top_p must be in (0, 1].")
l = logits.clone() / temperature
if top_k > 0 and top_k < l.numel():
values, _ = torch.topk(l, top_k)
cutoff = values[-1]
l = torch.where(l < cutoff, torch.tensor(float("-inf"), device=l.device), l)
if top_p < 1.0:
sorted_logits, sorted_idx = torch.sort(l, descending=True)
sorted_probs = F.softmax(sorted_logits, dim=-1)
cumprobs = torch.cumsum(sorted_probs, dim=-1)
remove = cumprobs > top_p
remove[1:] = remove[:-1].clone()
remove[0] = False
sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
l_filtered = torch.full_like(l, float("-inf"))
l_filtered.scatter_(0, sorted_idx, sorted_logits)
l = l_filtered
probs = F.softmax(l, dim=-1)
return int(torch.multinomial(probs, num_samples=1).item())
def _truncate_to_last_boundary(steps: Sequence[Sequence[int]]) -> List[List[int]]:
boundaries = {STEP_EOS, STEP_BAR_END, STEP_CHORD_END}
last = -1
for i, s in enumerate(steps):
if int(s[0]) in boundaries:
last = i
if last == -1:
return [list(s) for s in steps]
return [list(s) for s in steps[: last + 1]]
@torch.no_grad()
def _generate_one_sequence(
model: CompoundGPT,
device: torch.device,
max_new_steps: int,
temperature: float,
top_k: int,
top_p: float,
) -> List[List[int]]:
generated_steps: List[List[int]] = []
bos = list(SENTINELS)
bos[0] = STEP_BOS
generated_steps.append(bos)
for _ in range(max_new_steps):
step_ids = torch.tensor([generated_steps], dtype=torch.long, device=device)
if step_ids.size(1) > model.config.block_size:
raise ValueError(
f"sequence length {step_ids.size(1)} > block_size {model.config.block_size}"
)
position_ids = torch.arange(
step_ids.size(1), device=device, dtype=torch.long
).unsqueeze(0)
logits_per_axis = model(idx=step_ids, position_ids=position_ids)
next_step: List[int] = []
for axis_logits in logits_per_axis:
axis_next = _sample_axis(
logits=axis_logits[0, -1, :],
temperature=temperature,
top_k=top_k,
top_p=top_p,
)
next_step.append(axis_next)
if next_step[0] == STEP_EOS:
next_step = [STEP_EOS] + SENTINELS[1:]
generated_steps.append(next_step)
break
generated_steps.append(next_step)
return _truncate_to_last_boundary(generated_steps)
def _steps_to_safe_midi_steps(steps: Sequence[Sequence[int]]) -> List[List[int]]:
return [list(s) for s in steps if int(s[0]) != STEP_PB]
def _append_pm(dst: pretty_midi.PrettyMIDI, src: pretty_midi.PrettyMIDI, t0: float) -> None:
for inst in src.instruments:
new_inst = pretty_midi.Instrument(
program=inst.program, is_drum=inst.is_drum, name=inst.name
)
for n in inst.notes:
new_inst.notes.append(
pretty_midi.Note(
velocity=n.velocity,
pitch=n.pitch,
start=n.start + t0,
end=n.end + t0,
)
)
for cc in inst.control_changes:
new_inst.control_changes.append(
pretty_midi.ControlChange(
number=cc.number,
value=cc.value,
time=float(cc.time) + t0,
)
)
for pb in inst.pitch_bends:
new_inst.pitch_bends.append(
pretty_midi.PitchBend(pitch=pb.pitch, time=pb.time + t0)
)
if (
new_inst.notes
or new_inst.control_changes
or new_inst.pitch_bends
):
dst.instruments.append(new_inst)
def _synthesize_wav_numpy(pm: pretty_midi.PrettyMIDI, sample_rate: int) -> np.ndarray:
"""Fallback PCM when FluidSynth/pyfluidsynth is unavailable (simple additive tones)."""
duration = float(pm.get_end_time())
n_samples = int(np.ceil(duration * sample_rate)) + 1
y = np.zeros(n_samples, dtype=np.float64)
twopi = 2.0 * np.pi
for inst in pm.instruments:
for note in inst.notes:
f = 440.0 * (2.0 ** ((float(note.pitch) - 69.0) / 12.0))
i0 = max(0, int(note.start * sample_rate))
i1 = min(n_samples, int(np.ceil(note.end * sample_rate)))
if i1 <= i0:
continue
seg_len = i1 - i0
t = (np.arange(seg_len, dtype=np.float64) + i0) / sample_rate
ph = twopi * f * t
vel = float(note.velocity) / 127.0
sig = vel * (
0.55 * np.sin(ph)
+ 0.28 * np.sin(2.0 * ph)
+ 0.12 * np.sin(3.0 * ph)
+ 0.05 * np.sin(4.0 * ph)
)
atk = max(1, int(0.008 * sample_rate))
rel = max(1, int(0.04 * sample_rate))
env = np.ones(seg_len, dtype=np.float64)
env[:atk] *= np.linspace(0.0, 1.0, atk, endpoint=False)
tail = min(rel, seg_len)
env[-tail:] *= np.linspace(1.0, 0.0, tail, endpoint=False)
y[i0:i1] += sig * env
peak = float(np.max(np.abs(y))) if y.size else 0.0
if peak > 1e-8:
y = y / peak * 0.85
return y.astype(np.float32)
def parse_args() -> argparse.Namespace:
p = argparse.ArgumentParser(
description="Unconditional compound GPT sampling (checkpoint weights only)."
)
p.add_argument(
"--checkpoint",
type=str,
default="compound_best.pt",
help="CompoundGPT checkpoint (e.g. compound_best.pt).",
)
p.add_argument("--out-midi", type=str, default="results/compound_unconditional.mid")
p.add_argument("--out-wav", type=str, default="results/compound_unconditional.wav")
p.add_argument(
"--target-seconds",
type=float,
default=60.0,
help="Accumulate decoded MIDI segments until at least this duration.",
)
p.add_argument(
"--max-segments",
type=int,
default=64,
help="Safety cap on number of BOS..EOS sequences to stitch.",
)
p.add_argument("--temperature", type=float, default=0.9)
p.add_argument("--top-k", type=int, default=30)
p.add_argument("--top-p", type=float, default=0.95)
p.add_argument("--seed", type=int, default=0)
p.add_argument("--sample-rate", type=int, default=44100)
return p.parse_args()
def main() -> None:
args = parse_args()
device = _pick_device()
torch.manual_seed(args.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed)
ckpt_path = Path(args.checkpoint)
print(f"[gen_compound_uncond] device={device} ckpt={ckpt_path}")
model = _load_compound_gpt(ckpt_path, device=device)
bs = model.config.block_size
max_new = bs - 1
pm_out = pretty_midi.PrettyMIDI(initial_tempo=120.0)
t_off = 0.0
n_segments = 0
while t_off < args.target_seconds and n_segments < args.max_segments:
torch.manual_seed(args.seed + n_segments)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(args.seed + n_segments)
steps = _generate_one_sequence(
model=model,
device=device,
max_new_steps=max_new,
temperature=args.temperature,
top_k=args.top_k,
top_p=args.top_p,
)
safe = _steps_to_safe_midi_steps(steps)
if len(safe) < 2:
print("[gen_compound_uncond] warning: empty segment, retrying sweep")
break
seg = decode_compound(safe)
dur = seg.get_end_time()
if dur < 0.05:
print("[gen_compound_uncond] warning: near-empty decode, stopping")
break
_append_pm(pm_out, seg, t_off)
t_off = pm_out.get_end_time()
n_segments += 1
print(
f"[gen_compound_uncond] segment={n_segments} steps={len(steps)} "
f"seg_dur={dur:.2f}s total={t_off:.2f}s"
)
if t_off < 1.0:
raise RuntimeError(
"Generated MIDI is too short; try different --seed or sampling params."
)
midi_path = Path(args.out_midi)
wav_path = Path(args.out_wav)
midi_path.parent.mkdir(parents=True, exist_ok=True)
wav_path.parent.mkdir(parents=True, exist_ok=True)
pm_out.write(str(midi_path))
print(f"[gen_compound_uncond] midi -> {midi_path} duration={t_off:.2f}s")
try:
audio = pm_out.fluidsynth(fs=args.sample_rate)
audio = np.asarray(audio, dtype=np.float32).reshape(-1)
except (ImportError, OSError, ValueError) as e:
print(
f"[gen_compound_uncond] fluidsynth unavailable ({e!s}); "
"using numpy additive synthesizer fallback."
)
audio = _synthesize_wav_numpy(pm_out, args.sample_rate)
max_samples = int(args.target_seconds * args.sample_rate)
if audio.size > max_samples:
audio = audio[:max_samples]
audio = np.clip(audio, -1.0, 1.0)
scipy.io.wavfile.write(
str(wav_path), args.sample_rate, (audio * 32767.0).astype(np.int16)
)
print(f"[gen_compound_uncond] wav -> {wav_path} samples={audio.size}")
if __name__ == "__main__":
main()