"""Sample MIDI/WAV from a CompoundGPT checkpoint only (no CLAP / prefix weights)."""

from __future__ import annotations

import argparse
from pathlib import Path
from typing import List, Sequence

import numpy as np
import pretty_midi
import scipy.io.wavfile
import torch
import torch.nn.functional as F

from compound import (
    SENTINELS,
    STEP_BAR_END,
    STEP_BOS,
    STEP_CHORD_END,
    STEP_EOS,
    STEP_PB,
    decode_compound,
)
from compound_model import CompoundGPT, CompoundGPTConfig, default_compound_config
from inference_pipeline import _pick_device


def _load_compound_gpt(ckpt_path: Path, device: torch.device) -> CompoundGPT:
    """Build a CompoundGPT from a checkpoint file and switch it to eval mode.

    The checkpoint may be a dict holding ``"model_state_dict"`` and an
    optional ``"config"`` mapping, or a bare state dict. Keys of ``"config"``
    that name ``CompoundGPTConfig`` dataclass fields override the values from
    ``default_compound_config()``; other keys are ignored.

    Weights are loaded with ``strict=False``. Missing or unexpected keys are
    reported on stdout so that a mismatched checkpoint is visible instead of
    silently producing a partially-initialized model.
    """
    ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
    is_mapping = isinstance(ckpt, dict)

    cfg = default_compound_config()
    raw_cfg = ckpt.get("config") if is_mapping else None
    if isinstance(raw_cfg, dict):
        for key in CompoundGPTConfig.__dataclass_fields__:
            if key in raw_cfg:
                setattr(cfg, key, raw_cfg[key])

    model = CompoundGPT(cfg).to(device)
    state = ckpt.get("model_state_dict", ckpt) if is_mapping else ckpt
    incompatible = model.load_state_dict(state, strict=False)
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(
            "[gen_compound_uncond] warning: checkpoint key mismatch "
            f"(missing={len(incompatible.missing_keys)}, "
            f"unexpected={len(incompatible.unexpected_keys)})"
        )
    model.eval()
    return model


def _sample_axis(
    logits: torch.Tensor,
    temperature: float,
    top_k: int,
    top_p: float,
) -> int:
    """Draw one index from a 1-D logit vector with temperature/top-k/top-p.

    Args:
        logits: Unnormalized scores for one output axis (1-D: filtered and
            scattered along dim 0).
        temperature: Divisor applied to the logits; must be > 0.
        top_k: Keep only logits >= the k-th largest. ``<= 0`` or
            ``>= logits.numel()`` disables the filter.
        top_p: Nucleus threshold in (0, 1]; ``1.0`` disables the filter.

    Returns:
        The sampled index as a Python ``int``.

    Raises:
        ValueError: If ``temperature <= 0`` or ``top_p`` is outside (0, 1].
    """
    if temperature <= 0:
        raise ValueError("temperature must be > 0")
    if not 0.0 < top_p <= 1.0:
        raise ValueError("top_p must be in (0, 1].")

    # Division allocates a new tensor, so the caller's logits are untouched.
    scaled = logits / temperature

    if 0 < top_k < scaled.numel():
        kth_best = torch.topk(scaled, top_k).values[-1]
        scaled = scaled.masked_fill(scaled < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(scaled, descending=True)
        cumprobs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Shift the mask right by one so the token that first crosses top_p
        # is kept; the single most likely token is never removed.
        remove = cumprobs > top_p
        remove[1:] = remove[:-1].clone()
        remove[0] = False
        sorted_logits = sorted_logits.masked_fill(remove, float("-inf"))
        scaled = torch.full_like(scaled, float("-inf")).scatter_(
            0, sorted_idx, sorted_logits
        )

    probs = F.softmax(scaled, dim=-1)
    return int(torch.multinomial(probs, num_samples=1).item())


def _truncate_to_last_boundary(steps: Sequence[Sequence[int]]) -> List[List[int]]:
    """Cut ``steps`` after the last EOS / bar-end / chord-end step.

    Each step's type is read from element 0. If no boundary step is present,
    every step is kept. Steps are returned as fresh lists.
    """
    boundaries = {STEP_EOS, STEP_BAR_END, STEP_CHORD_END}
    last = -1
    for i, step in enumerate(steps):
        if int(step[0]) in boundaries:
            last = i
    kept = steps if last == -1 else steps[: last + 1]
    return [list(step) for step in kept]


@torch.no_grad()
def _generate_one_sequence(
    model: CompoundGPT,
    device: torch.device,
    max_new_steps: int,
    temperature: float,
    top_k: int,
    top_p: float,
) -> List[List[int]]:
    """Autoregressively sample one compound-step sequence starting from BOS.

    Each step is a list with one id per model output axis. Element 0 is
    compared against ``STEP_EOS``/``STEP_BOS``, so it carries the step type.
    The full prefix is re-fed to the model on every iteration.

    Generation stops early when axis 0 samples ``STEP_EOS``. That step's other
    axes are overwritten with ``SENTINELS[1:]``. The result is trimmed by
    ``_truncate_to_last_boundary``.

    Raises:
        ValueError: If the sequence would exceed ``model.config.block_size``.
    """
    bos = list(SENTINELS)
    bos[0] = STEP_BOS
    generated_steps: List[List[int]] = [bos]

    for _ in range(max_new_steps):
        step_ids = torch.tensor([generated_steps], dtype=torch.long, device=device)
        if step_ids.size(1) > model.config.block_size:
            raise ValueError(
                f"sequence length {step_ids.size(1)} > block_size {model.config.block_size}"
            )
        position_ids = torch.arange(
            step_ids.size(1), device=device, dtype=torch.long
        ).unsqueeze(0)
        logits_per_axis = model(idx=step_ids, position_ids=position_ids)

        next_step: List[int] = [
            _sample_axis(
                logits=axis_logits[0, -1, :],
                temperature=temperature,
                top_k=top_k,
                top_p=top_p,
            )
            for axis_logits in logits_per_axis
        ]
        if next_step[0] == STEP_EOS:
            # Unpacking works whether SENTINELS is a list or a tuple;
            # `[x] + SENTINELS[1:]` would raise TypeError for a tuple.
            generated_steps.append([STEP_EOS, *SENTINELS[1:]])
            break
        generated_steps.append(next_step)

    return _truncate_to_last_boundary(generated_steps)


def _steps_to_safe_midi_steps(steps: Sequence[Sequence[int]]) -> List[List[int]]:
    """Return ``steps`` (as lists) without ``STEP_PB`` steps.

    NOTE(review): presumably pitch-bend steps are dropped because
    ``decode_compound`` does not handle them reliably; confirm.
    """
    return [list(s) for s in steps if int(s[0]) != STEP_PB]


def _append_pm(dst: pretty_midi.PrettyMIDI, src: pretty_midi.PrettyMIDI, t0: float) -> None:
    """Copy every instrument of ``src`` into ``dst``, shifted by ``t0`` seconds.

    Each source instrument becomes a new instrument in ``dst`` (same
    program/drum flag/name) holding time-shifted notes, control changes and
    pitch bends. Instruments with no events are skipped. Mutates ``dst``.
    """
    for inst in src.instruments:
        new_inst = pretty_midi.Instrument(
            program=inst.program, is_drum=inst.is_drum, name=inst.name
        )
        new_inst.notes.extend(
            pretty_midi.Note(
                velocity=n.velocity,
                pitch=n.pitch,
                start=n.start + t0,
                end=n.end + t0,
            )
            for n in inst.notes
        )
        new_inst.control_changes.extend(
            pretty_midi.ControlChange(
                number=cc.number,
                value=cc.value,
                time=float(cc.time) + t0,
            )
            for cc in inst.control_changes
        )
        new_inst.pitch_bends.extend(
            pretty_midi.PitchBend(pitch=pb.pitch, time=pb.time + t0)
            for pb in inst.pitch_bends
        )
        if new_inst.notes or new_inst.control_changes or new_inst.pitch_bends:
            dst.instruments.append(new_inst)


def _synthesize_wav_numpy(pm: pretty_midi.PrettyMIDI, sample_rate: int) -> np.ndarray:
    """Fallback PCM when FluidSynth/pyfluidsynth is unavailable (simple additive tones).

    Each note is rendered as four harmonics (weights 0.55/0.28/0.12/0.05),
    scaled by velocity/127. The tone is shaped by a linear envelope with an
    8 ms attack and a 40 ms release. Both ramps are clamped to the note
    length. The mix is peak-normalized to 0.85 unless it is silent.

    Returns:
        Mono float32 samples covering ``pm.get_end_time()`` seconds.
    """
    duration = float(pm.get_end_time())
    n_samples = int(np.ceil(duration * sample_rate)) + 1
    y = np.zeros(n_samples, dtype=np.float64)
    twopi = 2.0 * np.pi
    # Envelope ramp lengths in samples (loop-invariant).
    atk = max(1, int(0.008 * sample_rate))
    rel = max(1, int(0.04 * sample_rate))
    attack_ramp = np.linspace(0.0, 1.0, atk, endpoint=False)

    for inst in pm.instruments:
        for note in inst.notes:
            f = 440.0 * (2.0 ** ((float(note.pitch) - 69.0) / 12.0))
            i0 = max(0, int(note.start * sample_rate))
            i1 = min(n_samples, int(np.ceil(note.end * sample_rate)))
            if i1 <= i0:
                continue
            seg_len = i1 - i0
            t = (np.arange(seg_len, dtype=np.float64) + i0) / sample_rate
            ph = twopi * f * t
            vel = float(note.velocity) / 127.0
            sig = vel * (
                0.55 * np.sin(ph)
                + 0.28 * np.sin(2.0 * ph)
                + 0.12 * np.sin(3.0 * ph)
                + 0.05 * np.sin(4.0 * ph)
            )
            env = np.ones(seg_len, dtype=np.float64)
            # Clamp the attack to the note length: notes shorter than the
            # attack used to raise a broadcasting ValueError here.
            head = min(atk, seg_len)
            env[:head] *= attack_ramp[:head]
            tail = min(rel, seg_len)
            env[-tail:] *= np.linspace(1.0, 0.0, tail, endpoint=False)
            y[i0:i1] += sig * env

    peak = float(np.max(np.abs(y))) if y.size else 0.0
    if peak > 1e-8:
        y = y / peak * 0.85
    return y.astype(np.float32)


def parse_args() -> argparse.Namespace:
    """Parse command-line options for unconditional compound sampling."""
    p = argparse.ArgumentParser(
        description="Unconditional compound GPT sampling (checkpoint weights only)."
    )
    p.add_argument(
        "--checkpoint",
        type=str,
        default="compound_best.pt",
        help="CompoundGPT checkpoint (e.g. compound_best.pt).",
    )
    p.add_argument("--out-midi", type=str, default="results/compound_unconditional.mid")
    p.add_argument("--out-wav", type=str, default="results/compound_unconditional.wav")
    p.add_argument(
        "--target-seconds",
        type=float,
        default=60.0,
        help="Accumulate decoded MIDI segments until at least this duration.",
    )
    p.add_argument(
        "--max-segments",
        type=int,
        default=64,
        help="Safety cap on number of BOS..EOS sequences to stitch.",
    )
    p.add_argument("--temperature", type=float, default=0.9)
    p.add_argument("--top-k", type=int, default=30)
    p.add_argument("--top-p", type=float, default=0.95)
    p.add_argument("--seed", type=int, default=0)
    p.add_argument("--sample-rate", type=int, default=44100)
    return p.parse_args()


def _seed_torch(seed: int) -> None:
    """Seed the torch CPU RNG and, when CUDA is present, all CUDA RNGs."""
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)


def main() -> None:
    """Sample and stitch segments until ``--target-seconds``, then write MIDI and WAV.

    Each segment is sampled with seed ``--seed + segment index``, decoded, and
    appended at the current end time of the accumulated MIDI. The loop stops
    on the first empty or near-empty segment, or after ``--max-segments``.

    The WAV is rendered with FluidSynth when available and otherwise with the
    NumPy fallback. It is truncated to ``--target-seconds``, clipped to
    [-1, 1], and written as 16-bit PCM.

    Raises:
        RuntimeError: If the stitched MIDI is shorter than 1 second.
    """
    args = parse_args()
    device = _pick_device()
    _seed_torch(args.seed)

    ckpt_path = Path(args.checkpoint)
    print(f"[gen_compound_uncond] device={device} ckpt={ckpt_path}")
    model = _load_compound_gpt(ckpt_path, device=device)
    # BOS occupies one position, leaving block_size - 1 steps to sample.
    max_new = model.config.block_size - 1

    pm_out = pretty_midi.PrettyMIDI(initial_tempo=120.0)
    t_off = 0.0
    n_segments = 0
    while t_off < args.target_seconds and n_segments < args.max_segments:
        _seed_torch(args.seed + n_segments)
        steps = _generate_one_sequence(
            model=model,
            device=device,
            max_new_steps=max_new,
            temperature=args.temperature,
            top_k=args.top_k,
            top_p=args.top_p,
        )
        safe = _steps_to_safe_midi_steps(steps)
        if len(safe) < 2:
            # The loop ends here (no retry), so say so in the log.
            print("[gen_compound_uncond] warning: empty segment, stopping")
            break
        seg = decode_compound(safe)
        dur = seg.get_end_time()
        if dur < 0.05:
            print("[gen_compound_uncond] warning: near-empty decode, stopping")
            break
        _append_pm(pm_out, seg, t_off)
        t_off = pm_out.get_end_time()
        n_segments += 1
        print(
            f"[gen_compound_uncond] segment={n_segments} steps={len(steps)} "
            f"seg_dur={dur:.2f}s total={t_off:.2f}s"
        )

    if t_off < 1.0:
        raise RuntimeError(
            "Generated MIDI is too short; try different --seed or sampling params."
        )

    midi_path = Path(args.out_midi)
    wav_path = Path(args.out_wav)
    midi_path.parent.mkdir(parents=True, exist_ok=True)
    wav_path.parent.mkdir(parents=True, exist_ok=True)
    pm_out.write(str(midi_path))
    print(f"[gen_compound_uncond] midi -> {midi_path} duration={t_off:.2f}s")

    try:
        audio = pm_out.fluidsynth(fs=args.sample_rate)
        audio = np.asarray(audio, dtype=np.float32).reshape(-1)
    except (ImportError, OSError, ValueError) as e:
        print(
            f"[gen_compound_uncond] fluidsynth unavailable ({e!s}); "
            "using numpy additive synthesizer fallback."
        )
        audio = _synthesize_wav_numpy(pm_out, args.sample_rate)

    max_samples = int(args.target_seconds * args.sample_rate)
    if audio.size > max_samples:
        audio = audio[:max_samples]
    audio = np.clip(audio, -1.0, 1.0)
    scipy.io.wavfile.write(
        str(wav_path), args.sample_rate, (audio * 32767.0).astype(np.int16)
    )
    print(f"[gen_compound_uncond] wav -> {wav_path} samples={audio.size}")


if __name__ == "__main__":
    main()