Coda / src /generate.py
Prajanya Gupta
initial deploy
6b7b403
"""Autoregressive MIDI token generation from a trained checkpoint."""
from __future__ import annotations
import argparse
import json
import random
import sys
import tempfile
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
import pretty_midi
import torch
import torch.nn.functional as F
# Anchor all default paths to this file's location so the script behaves the
# same regardless of the current working directory.
_SCRIPT_DIR = Path(__file__).resolve().parent
_ROOT = _SCRIPT_DIR.parent
# Put this file's directory first on the import path so the sibling modules
# imported below (bpe, model, tokenizer) resolve when run as a script.
if all(entry != str(_SCRIPT_DIR) for entry in sys.path):
    sys.path.insert(0, str(_SCRIPT_DIR))
from bpe import ( # noqa: E402
Merge,
apply_bpe,
load as load_bpe_merges,
unapply_bpe,
)
from model import GPT, GPTConfig, default_gpt_config # noqa: E402
from tokenizer import ID2TOKEN, decode, encode # noqa: E402
# Default location of the BPE merge table; main() treats a missing file as
# "no BPE" rather than an error.
DEFAULT_BPE_MERGES_PATH = _ROOT.joinpath("data", "bpe", "merges.json")
def _pick_device() -> torch.device:
    """Choose the preferred torch device: CUDA, then Apple MPS, then CPU.

    ``torch.backends.mps`` is looked up with ``getattr`` so this still works
    on torch builds that presumably lack the MPS backend module.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    mps_backend = getattr(torch.backends, "mps", None)
    use_mps = mps_backend is not None and mps_backend.is_available()
    return torch.device("mps" if use_mps else "cpu")
def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Mask all but the ``k`` largest logits along the last dimension.

    Entries strictly below the k-th largest value of their slice are replaced
    with ``-inf``. Ties with the k-th value are kept, so a slice may retain
    more than ``k`` finite entries.

    Args:
        logits: Tensor of rank >= 1; filtering is applied independently to
            every slice along the last dimension (e.g. ``(vocab,)``,
            ``(batch, vocab)`` or ``(batch, seq, vocab)``).
        k: Number of candidates to keep. ``k <= 0`` or
            ``k >= logits.size(-1)`` disables filtering.

    Returns:
        ``logits`` itself when filtering is disabled, otherwise a new tensor
        of the same shape with the filtered-out positions set to ``-inf``.
    """
    if k <= 0 or k >= logits.size(-1):
        return logits
    # topk returns values sorted in descending order, so the last entry of each
    # slice is the k-th largest. Slicing with ``-1:`` keeps the trailing dim so
    # the threshold broadcasts against logits of any rank (``values[:, -1]``
    # only worked for 2-D input).
    values, _ = torch.topk(logits, k, dim=-1)
    threshold = values[..., -1:]
    return logits.masked_fill(logits < threshold, float("-inf"))
def _extract_gpt_config_dict(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of ``raw`` whose keys are ``GPTConfig`` fields.

    Keys that are not ``GPTConfig`` fields are dropped. The result follows
    ``GPTConfig`` field-declaration order, so callers that apply it with
    ``setattr`` do so in a deterministic order across runs.
    """
    # __dataclass_fields__ is an insertion-ordered dict keyed by field name;
    # iterating it directly avoids the hash-randomized ordering of a set.
    return {
        name: raw[name]
        for name in GPTConfig.__dataclass_fields__
        if name in raw
    }
def _load_config_from_sources(
    checkpoint_data: Dict[str, Any], config_path: str
) -> GPTConfig:
    """Build the model config from defaults, the checkpoint, and optional JSON.

    Precedence, lowest to highest: ``default_gpt_config()``, then the
    ``"config"`` dict stored in the checkpoint (if present), then the JSON
    object at ``config_path`` (if non-empty). Only keys that are ``GPTConfig``
    fields are applied.

    Args:
        checkpoint_data: Object returned by ``torch.load``. Anything that is
            not a dict contributes no overrides instead of crashing on
            ``.get``.
        config_path: Path to a JSON file, or ``""`` to skip it.

    Returns:
        The populated ``GPTConfig``.

    Raises:
        ValueError: If the JSON file does not contain a top-level object.
    """
    cfg = default_gpt_config()
    override_sources: List[Dict[str, Any]] = []

    if isinstance(checkpoint_data, dict):
        ckpt_cfg = checkpoint_data.get("config")
        if isinstance(ckpt_cfg, dict):
            override_sources.append(ckpt_cfg)

    if config_path:
        loaded = json.loads(Path(config_path).read_text(encoding="utf-8"))
        if not isinstance(loaded, dict):
            raise ValueError("--config must point to a JSON object.")
        override_sources.append(loaded)

    # Sources are applied in order, so the --config file wins over the
    # checkpoint's stored config.
    for source in override_sources:
        for key, value in _extract_gpt_config_dict(source).items():
            setattr(cfg, key, value)
    return cfg
def _load_jsb_prompt(seed: int) -> Tuple[List[int], str]:
    """Encode a randomly chosen JSB chorale from the music21 corpus.

    The chorale is chosen with ``random.Random(seed)``, so a given seed picks
    the same index as long as the corpus iterator yields the same sequence.
    The score is round-tripped through a temporary MIDI file so it can be read
    by ``pretty_midi`` and passed to ``encode``.

    Args:
        seed: Seed for the chorale choice.

    Returns:
        ``(token_ids, label)``. ``token_ids`` come straight from ``encode``
        (no BPE applied); ``label`` names the chorale by its iterator index.

    Raises:
        RuntimeError: If music21 is not installed or yields no chorales.
    """
    try:
        from music21 import corpus
    except ImportError as e:
        raise RuntimeError(
            "JSB prompt mode requires music21 to be installed."
        ) from e
    rng = random.Random(seed)
    # NOTE(review): this parses every chorale into a stream just to pick one,
    # which is slow; kept so the seed -> chorale mapping stays unchanged.
    all_scores = list(
        corpus.chorales.Iterator(
            numberingSystem="bwv",
            returnType="stream",
        )
    )
    if not all_scores:
        raise RuntimeError("No JSB chorales found via music21 corpus.")
    idx = rng.randrange(len(all_scores))
    score = all_scores[idx]
    # Use a private temp directory instead of re-opening a
    # NamedTemporaryFile(delete=True) by name: on Windows such a file cannot
    # be opened a second time while it is still open.
    with tempfile.TemporaryDirectory() as tmp_dir:
        midi_path = Path(tmp_dir) / "prompt.mid"
        score.write("midi", fp=str(midi_path))
        pm = pretty_midi.PrettyMIDI(str(midi_path))
    return encode(pm), f"jsb chorale #{idx}"
def _load_prompt_tokens(
    prompt: str,
    prompt_tokens: int,
    seed: int,
    merges: Sequence[Merge],
    vocab_size: int,
) -> Tuple[List[int], str]:
    """Resolve the ``--prompt`` option into a list of model token ids.

    Args:
        prompt: ``"random"`` for uniformly random ids, ``"jsb"`` for a music21
            chorale, or a path to a MIDI file.
        prompt_tokens: Maximum number of ids to return; must be > 0.
        seed: Seed for the random ids or the chorale choice.
        merges: BPE merges applied to encoded MIDI prompts. They are not
            applied to the random prompt, whose ids are drawn directly from
            ``range(vocab_size)``.
        vocab_size: Exclusive upper bound for random prompt ids.

    Returns:
        ``(ids, label)``: at most ``prompt_tokens`` ids (counted after BPE)
        and a short description of the prompt source.

    Raises:
        ValueError: If ``prompt_tokens`` is not positive.
        FileNotFoundError: If ``prompt`` is a path that is not an existing file.
    """
    if prompt_tokens <= 0:
        # A negative value would otherwise silently drop tokens from the *end*
        # of an encoded prompt (ids[:-n]) instead of limiting its length.
        raise ValueError(f"prompt_tokens must be > 0, got {prompt_tokens}")

    if prompt == "random":
        rng = random.Random(seed)
        ids = [rng.randrange(vocab_size) for _ in range(prompt_tokens)]
        return ids, "random"

    if prompt == "jsb":
        ids, label = _load_jsb_prompt(seed=seed)
    else:
        midi_path = Path(prompt)
        if not midi_path.is_file():
            raise FileNotFoundError(f"Prompt MIDI not found: {midi_path}")
        ids = encode(pretty_midi.PrettyMIDI(str(midi_path)))
        label = str(midi_path)

    # Encoded MIDI is in base-tokenizer ids. Compress it with the loaded
    # merges so it matches the id space main() later expands with unapply_bpe.
    if merges:
        ids = apply_bpe(ids, merges)
    return ids[:prompt_tokens], label
@torch.no_grad()
def generate_tokens(
    model: GPT,
    prompt_ids: torch.Tensor,
    gen_tokens: int,
    temperature: float,
    top_k: int,
) -> torch.Tensor:
    """Extend ``prompt_ids`` by sampling ``gen_tokens`` ids from ``model``.

    At each step the model sees at most ``model.config.block_size`` trailing
    ids. Its final-position logits are divided by ``temperature``, optionally
    restricted with ``top_k_filter``, and one id per row is drawn with
    ``torch.multinomial``. Gradient tracking is disabled for the whole call.

    Args:
        model: Callable whose output is indexed as ``[batch, position, vocab]``.
        prompt_ids: 2-D id tensor ``(batch, prompt_len)``; it is kept as the
            prefix of the result.
        gen_tokens: Number of sampling steps (new ids appended per row).
        temperature: Softmax temperature; must be strictly positive.
        top_k: Candidate count for top-k filtering; ``<= 0`` disables it.

    Returns:
        ``prompt_ids`` with the sampled ids appended along dim 1.

    Raises:
        ValueError: If ``temperature <= 0``.
    """
    if temperature <= 0.0:
        raise ValueError("temperature must be > 0")
    sequence = prompt_ids
    for _step in range(gen_tokens):
        window = model.config.block_size
        step_logits = model(sequence[:, -window:])[:, -1, :]
        step_logits = step_logits / temperature
        if top_k > 0:
            step_logits = top_k_filter(step_logits, top_k)
        step_probs = F.softmax(step_logits, dim=-1)
        sampled = torch.multinomial(step_probs, num_samples=1)
        sequence = torch.cat((sequence, sampled), dim=1)
    return sequence
def _token_text(ids: List[int], max_len: int = 120) -> str:
    """Render up to ``max_len`` ids as space-separated ``ID2TOKEN`` names.

    Ids missing from ``ID2TOKEN`` are shown as ``UNK(<id>)``. When ``ids`` is
    longer than ``max_len``, a trailing ``" ..."`` marks the truncation.
    """
    shown = " ".join(ID2TOKEN.get(tok, f"UNK({tok})") for tok in ids[:max_len])
    return f"{shown} ..." if len(ids) > max_len else shown
def parse_args(argv: Sequence[str] | None = None) -> argparse.Namespace:
    """Parse command-line options for MIDI generation.

    Args:
        argv: Arguments to parse. ``None`` (the default) makes argparse read
            ``sys.argv[1:]``, exactly as before; pass a list to drive the CLI
            programmatically (e.g. from tests).

    Returns:
        The parsed ``argparse.Namespace``.
    """
    p = argparse.ArgumentParser(
        description="Generate MIDI from a trained GPT checkpoint"
    )
    p.add_argument(
        "--checkpoint",
        type=str,
        default=str(_ROOT / "results" / "checkpoints" / "best_model.pt"),
        help="Path to the trained model checkpoint.",
    )
    p.add_argument(
        "--config",
        type=str,
        default="",
        help="Optional JSON config path; overrides checkpoint config.",
    )
    p.add_argument(
        "--prompt",
        type=str,
        default="random",
        help='Prompt source: "jsb", "random", or path to .mid/.midi file.',
    )
    p.add_argument(
        "--prompt-tokens",
        type=int,
        default=64,
        help="Maximum number of prompt tokens (counted after BPE).",
    )
    p.add_argument(
        "--gen-tokens",
        type=int,
        default=128,
        help="Number of new tokens to sample.",
    )
    p.add_argument(
        "--temperature",
        type=float,
        default=0.8,
        help="Sampling temperature; must be > 0.",
    )
    p.add_argument(
        "--top-k",
        type=int,
        default=40,
        help="Sample only from the k most likely tokens; <= 0 disables.",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Seed for torch sampling and prompt selection.",
    )
    p.add_argument(
        "--out",
        type=str,
        default=str(_ROOT / "results" / "generated.mid"),
        help="Output MIDI path; a <stem>_continuation file is written next to it.",
    )
    p.add_argument(
        "--bpe-merges",
        type=str,
        default=str(DEFAULT_BPE_MERGES_PATH),
        help="BPE merges JSON path. Skipped silently if file missing.",
    )
    return p.parse_args(argv)
def _load_model(
    checkpoint_path: str, config_path: str, device: torch.device
) -> Tuple[GPT, GPTConfig]:
    """Build ``GPT`` from a checkpoint, move it to ``device``, set eval mode.

    The checkpoint is loaded with ``weights_only=True``. Weights come from its
    ``"model"`` entry when present; otherwise the whole loaded object is used
    as the state dict.

    Returns:
        ``(model, cfg)``, where ``cfg`` is the config the model was built with.
    """
    ckpt = torch.load(
        checkpoint_path,
        map_location=device,
        weights_only=True,
    )
    # Only a dict checkpoint can carry a "config" entry. Pass an empty mapping
    # otherwise so config resolution falls back to defaults / --config.
    cfg_source = ckpt if isinstance(ckpt, dict) else {}
    cfg = _load_config_from_sources(cfg_source, config_path)
    model = GPT(cfg).to(device)
    state = (
        ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    )
    model.load_state_dict(state)
    model.eval()
    return model, cfg


def main() -> None:
    """CLI entry point: load the model, build a prompt, sample, write MIDI.

    Writes the full sequence (prompt + continuation) to ``--out`` and, when
    anything was generated, the continuation alone to
    ``<stem>_continuation<suffix>`` in the same directory.
    """
    args = parse_args()
    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)
    device = _pick_device()
    print(f"[generate] device={device}")

    model, cfg = _load_model(args.checkpoint, args.config, device)

    # A missing merges file means "no BPE": ids stay in the base vocabulary.
    merges_path = Path(args.bpe_merges)
    merges: List[Merge] = (
        load_bpe_merges(merges_path) if merges_path.exists() else []
    )
    if merges:
        print(f"[generate] BPE merges loaded: {len(merges)} from {merges_path}")

    prompt_ids_list, prompt_label = _load_prompt_tokens(
        prompt=args.prompt,
        prompt_tokens=args.prompt_tokens,
        seed=args.seed,
        merges=merges,
        vocab_size=cfg.vocab_size,
    )
    if not prompt_ids_list:
        raise ValueError("Prompt produced zero tokens.")
    x = torch.tensor([prompt_ids_list], dtype=torch.long, device=device)
    out_ids = generate_tokens(
        model=model,
        prompt_ids=x,
        gen_tokens=args.gen_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )[0].tolist()

    # Split before expanding BPE so the prompt/continuation boundary is exact.
    prompt_len = len(prompt_ids_list)
    cont_ids = out_ids[prompt_len:]
    base_out_ids = unapply_bpe(out_ids, merges) if merges else out_ids
    base_cont_ids = unapply_bpe(cont_ids, merges) if merges else cont_ids

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    decode(base_out_ids).write(str(out_path))
    cont_path = out_path.with_name(
        f"{out_path.stem}_continuation{out_path.suffix}"
    )
    # One flag drives both the write and the report, so they cannot disagree.
    wrote_continuation = bool(base_cont_ids)
    if wrote_continuation:
        decode(base_cont_ids).write(str(cont_path))

    print(f"[generate] prompt: {prompt_len} tokens ({prompt_label})")
    print(f"[generate] generated: {len(cont_ids)} tokens")
    print(f"[generate] temperature={args.temperature}, top_k={args.top_k}")
    print(f"[generate] output -> {out_path}")
    if wrote_continuation:
        print(f"[generate] continuation_only -> {cont_path}")
    print("[generate] token preview:")
    # ID2TOKEN belongs to the base tokenizer, the same id space decode() is
    # fed above. Preview the expanded ids; merged BPE ids would presumably
    # render as UNK(...) or be mislabeled.
    print(_token_text(base_out_ids))
if __name__ == "__main__":
    # Run generation only when executed as a script, not on import.
    main()