Spaces:
Sleeping
Sleeping
| """Autoregressive MIDI token generation from a trained checkpoint.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| import sys | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Sequence, Tuple | |
| import pretty_midi | |
| import torch | |
| import torch.nn.functional as F | |
| _SCRIPT_DIR = Path(__file__).resolve().parent | |
| _ROOT = _SCRIPT_DIR.parent | |
| if str(_SCRIPT_DIR) not in sys.path: | |
| sys.path.insert(0, str(_SCRIPT_DIR)) | |
| from bpe import ( # noqa: E402 | |
| Merge, | |
| apply_bpe, | |
| load as load_bpe_merges, | |
| unapply_bpe, | |
| ) | |
| from model import GPT, GPTConfig, default_gpt_config # noqa: E402 | |
| from tokenizer import ID2TOKEN, decode, encode # noqa: E402 | |
| DEFAULT_BPE_MERGES_PATH = _ROOT / "data" / "bpe" / "merges.json" | |
def _pick_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then MPS, then CPU."""
    # Older torch builds have no ``torch.backends.mps`` attribute at all.
    mps_backend = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        name = "cuda"
    elif mps_backend is not None and mps_backend.is_available():
        name = "mps"
    else:
        name = "cpu"
    return torch.device(name)
def top_k_filter(logits: torch.Tensor, k: int) -> torch.Tensor:
    """Keep only the top-k logits along the last dimension and mask the rest.

    Args:
        logits: Scores whose last dimension is the one being filtered. Any
            number of leading dimensions is accepted (a plain 1-D vector,
            a 2-D batch, or higher rank).
        k: Number of entries to keep in each slice of the last dimension.

    Returns:
        ``logits`` itself when ``k <= 0`` or ``k`` is at least the size of
        the last dimension. Otherwise a new tensor in which every entry
        strictly below the k-th largest value of its slice is ``-inf``;
        entries tied with the k-th largest value are kept.
    """
    if k <= 0 or k >= logits.size(-1):
        return logits
    top_values, _ = torch.topk(logits, k, dim=-1)
    # ``-1:`` keeps a trailing axis of size one, so the threshold broadcasts
    # against logits of any rank instead of assuming a 2-D batch.
    threshold = top_values[..., -1:]
    return logits.masked_fill(logits < threshold, float("-inf"))
def _extract_gpt_config_dict(raw: Dict[str, Any]) -> Dict[str, Any]:
    """Return the entries of ``raw`` whose keys are ``GPTConfig`` fields.

    Keys that do not name a ``GPTConfig`` dataclass field are dropped, and
    fields missing from ``raw`` are simply not included in the result.
    """
    field_names = set(GPTConfig.__dataclass_fields__)
    selected: Dict[str, Any] = {}
    for name in field_names:
        if name in raw:
            selected[name] = raw[name]
    return selected
def _load_config_from_sources(
    checkpoint_data: Dict[str, Any], config_path: str
) -> GPTConfig:
    """Build the model config from defaults, the checkpoint, and a JSON file.

    Sources are layered in increasing priority: ``default_gpt_config()``,
    then the checkpoint's ``"config"`` entry (when it is a dict), then the
    JSON object stored at ``config_path`` (when the path is non-empty).
    Only keys that name ``GPTConfig`` dataclass fields are applied.

    Args:
        checkpoint_data: The loaded checkpoint. A value that is not a dict,
            or that has no dict under ``"config"``, contributes nothing.
        config_path: Path to a JSON config file, or ``""`` to skip it.

    Returns:
        The resulting ``GPTConfig`` instance.

    Raises:
        ValueError: If the JSON file parses to something other than an
            object (malformed JSON also surfaces as a ``ValueError``
            subclass from ``json.loads``).
    """
    cfg = default_gpt_config()
    layers: List[Dict[str, Any]] = []

    # The caller also accepts checkpoints that are not wrapper dicts, so do
    # not assume ``.get`` is available here.
    if isinstance(checkpoint_data, dict):
        ckpt_cfg = checkpoint_data.get("config")
        if isinstance(ckpt_cfg, dict):
            layers.append(ckpt_cfg)

    if config_path:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        loaded = json.loads(Path(config_path).read_text(encoding="utf-8"))
        if not isinstance(loaded, dict):
            raise ValueError("--config must point to a JSON object.")
        layers.append(loaded)

    # Later layers win because they are applied last.
    for layer in layers:
        for field_name, value in _extract_gpt_config_dict(layer).items():
            setattr(cfg, field_name, value)
    return cfg
def _load_jsb_prompt(seed: int) -> Tuple[List[int], str]:
    """Encode one pseudo-randomly chosen JSB chorale from the music21 corpus.

    Args:
        seed: Seed for the private ``random.Random`` that picks the chorale,
            so the same seed selects the same index.

    Returns:
        ``(token_ids, label)``: ``token_ids`` is ``encode`` applied to the
        chorale's MIDI rendering, and ``label`` names the chosen index.

    Raises:
        RuntimeError: If music21 cannot be imported, or if the corpus
            iterator yields no chorales.
    """
    try:
        from music21 import corpus
    except ImportError as e:
        # Only a failed import means "not installed"; other errors raised
        # while importing music21 propagate unchanged.
        raise RuntimeError(
            "JSB prompt mode requires music21 to be installed."
        ) from e

    rng = random.Random(seed)
    # The iterator is materialised in full so that an index can be drawn.
    all_scores = list(
        corpus.chorales.Iterator(
            numberingSystem="bwv",
            returnType="stream",
        )
    )
    if not all_scores:
        raise RuntimeError("No JSB chorales found via music21 corpus.")

    idx = rng.randrange(len(all_scores))
    score = all_scores[idx]

    # Write into a private temporary directory rather than an open
    # NamedTemporaryFile: reopening such a file by name while it is still
    # open is not possible on every platform (notably Windows).
    with tempfile.TemporaryDirectory() as tmp_dir:
        midi_path = str(Path(tmp_dir) / "prompt.mid")
        score.write("midi", fp=midi_path)
        pm = pretty_midi.PrettyMIDI(midi_path)
        return encode(pm), f"jsb chorale #{idx}"
def _load_prompt_tokens(
    prompt: str,
    prompt_tokens: int,
    seed: int,
    merges: Sequence[Merge],
    vocab_size: int,
) -> Tuple[List[int], str]:
    """Produce the prompt token ids and a human-readable label for them.

    ``prompt`` selects the source:

    * ``"random"``: ``prompt_tokens`` ids drawn from ``range(vocab_size)``
      with a ``random.Random`` seeded by ``seed``; BPE is not applied.
    * ``"jsb"``: a chorale obtained from ``_load_jsb_prompt``.
    * anything else: treated as a path to a MIDI file.

    For the two MIDI-backed sources the ids are BPE-merged when ``merges``
    is non-empty and then cut to the first ``prompt_tokens`` entries.

    Raises:
        FileNotFoundError: If ``prompt`` is a path that does not exist.
    """
    if prompt == "random":
        sampler = random.Random(seed)
        drawn: List[int] = []
        for _ in range(prompt_tokens):
            drawn.append(sampler.randrange(vocab_size))
        return drawn, "random"

    if prompt == "jsb":
        token_ids, source = _load_jsb_prompt(seed=seed)
    else:
        midi_file = Path(prompt)
        if not midi_file.exists():
            raise FileNotFoundError(f"Prompt MIDI not found: {midi_file}")
        token_ids = encode(pretty_midi.PrettyMIDI(str(midi_file)))
        source = str(midi_file)

    # Merging happens before truncation, so the limit counts merged tokens.
    if merges:
        token_ids = apply_bpe(token_ids, merges)
    return token_ids[:prompt_tokens], source
def generate_tokens(
    model: GPT,
    prompt_ids: torch.Tensor,
    gen_tokens: int,
    temperature: float,
    top_k: int,
) -> torch.Tensor:
    """Sample ``gen_tokens`` new tokens autoregressively after the prompt.

    Args:
        model: Callable model exposing ``config.block_size``. Calling it on
            a ``(batch, time)`` id tensor must return a 3-D tensor indexed
            as ``[batch, time, vocab]`` (only the last time step is used).
        prompt_ids: 2-D tensor of token ids, one row per sequence.
        gen_tokens: Number of tokens to append to every row.
        temperature: Divisor applied to the logits before sampling; must be
            strictly positive.
        top_k: When positive, sampling is restricted via ``top_k_filter``;
            otherwise the full distribution is used.

    Returns:
        The prompt with the sampled tokens concatenated along dimension 1.
        With ``gen_tokens <= 0`` this is ``prompt_ids`` itself.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0.0:
        raise ValueError("temperature must be > 0")

    generated = prompt_ids
    # Sampling never backpropagates; disabling autograd avoids building a
    # graph (and holding activations) on every forward pass.
    with torch.no_grad():
        for _ in range(gen_tokens):
            # Feed at most the last ``block_size`` tokens to the model.
            context = generated[:, -model.config.block_size:]
            logits = model(context)
            logits = logits[:, -1, :] / temperature
            if top_k > 0:
                logits = top_k_filter(logits, top_k)
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)
            generated = torch.cat([generated, next_token], dim=1)
    return generated
def _token_text(ids: List[int], max_len: int = 120) -> str:
    """Render up to ``max_len`` token ids as space-separated token names.

    Ids missing from ``ID2TOKEN`` are shown as ``UNK(<id>)``. When ``ids``
    is longer than ``max_len`` the text ends with `` ...``.
    """
    names = []
    for token_id in ids[:max_len]:
        names.append(ID2TOKEN.get(token_id, f"UNK({token_id})"))
    text = " ".join(names)
    if len(ids) > max_len:
        text += " ..."
    return text
def parse_args() -> argparse.Namespace:
    """Define the command-line options and parse ``sys.argv``.

    Options cover the checkpoint and optional config override, the prompt
    source and length, sampling controls (generated length, temperature,
    top-k, seed), the output MIDI path, and the BPE merges file.
    """
    parser = argparse.ArgumentParser(
        description="Generate MIDI from a trained GPT checkpoint"
    )

    default_checkpoint = _ROOT / "results" / "checkpoints" / "best_model.pt"
    parser.add_argument(
        "--checkpoint", type=str, default=str(default_checkpoint)
    )
    parser.add_argument(
        "--config",
        type=str,
        default="",
        help="Optional JSON config path; overrides checkpoint config.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="random",
        help='Prompt source: "jsb", "random", or path to .mid/.midi file.',
    )

    # Plain numeric options, registered in the order they appear in --help.
    numeric_options = (
        ("--prompt-tokens", int, 64),
        ("--gen-tokens", int, 128),
        ("--temperature", float, 0.8),
        ("--top-k", int, 40),
        ("--seed", int, 42),
    )
    for flag, value_type, fallback in numeric_options:
        parser.add_argument(flag, type=value_type, default=fallback)

    default_output = _ROOT / "results" / "generated.mid"
    parser.add_argument("--out", type=str, default=str(default_output))
    parser.add_argument(
        "--bpe-merges",
        type=str,
        default=str(DEFAULT_BPE_MERGES_PATH),
        help="BPE merges JSON path. Skipped silently if file missing.",
    )
    return parser.parse_args()
def main() -> None:
    """Run the command-line generation pipeline end to end.

    Seeds torch, loads the checkpoint and builds the model, loads the BPE
    merges when the merges file exists, builds the prompt, samples the
    continuation, undoes BPE, and writes two MIDI files: the full sequence
    at ``--out`` and, when there is a continuation, the continuation alone
    at ``<stem>_continuation<suffix>`` next to it. A short summary and a
    token preview are printed.

    Raises:
        ValueError: If the chosen prompt source yields no tokens.
    """
    args = parse_args()

    torch.manual_seed(args.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(args.seed)

    device = _pick_device()
    print(f"[generate] device={device}")

    ckpt = torch.load(
        args.checkpoint,
        map_location=device,
        weights_only=True,
    )
    # A checkpoint is either a wrapper dict holding the weights under
    # "model" or the state dict itself. Only a dict can carry a "config"
    # entry, so anything else contributes no config overrides.
    config_source = ckpt if isinstance(ckpt, dict) else {}
    cfg = _load_config_from_sources(config_source, args.config)
    model = GPT(cfg).to(device)
    state = (
        ckpt["model"] if isinstance(ckpt, dict) and "model" in ckpt else ckpt
    )
    model.load_state_dict(state)
    model.eval()

    # A missing merges file simply disables BPE (see the --bpe-merges help).
    merges_path = Path(args.bpe_merges)
    merges: List[Merge] = (
        load_bpe_merges(merges_path) if merges_path.exists() else []
    )
    if merges:
        print(f"[generate] BPE merges loaded: {len(merges)} from {merges_path}")

    prompt_ids_list, prompt_label = _load_prompt_tokens(
        prompt=args.prompt,
        prompt_tokens=args.prompt_tokens,
        seed=args.seed,
        merges=merges,
        vocab_size=cfg.vocab_size,
    )
    if not prompt_ids_list:
        raise ValueError("Prompt produced zero tokens.")

    x = torch.tensor([prompt_ids_list], dtype=torch.long, device=device)
    out_ids = generate_tokens(
        model=model,
        prompt_ids=x,
        gen_tokens=args.gen_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )[0].tolist()

    prompt_len = len(prompt_ids_list)
    cont_ids = out_ids[prompt_len:]
    # The decoder works on base tokens, so undo the BPE merges first.
    base_out_ids = unapply_bpe(out_ids, merges) if merges else out_ids
    base_cont_ids = unapply_bpe(cont_ids, merges) if merges else cont_ids

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    decode(base_out_ids).write(str(out_path))

    cont_path = out_path.with_name(
        f"{out_path.stem}_continuation{out_path.suffix}"
    )
    # One flag drives both the write and the report, so the summary never
    # names a continuation file that was not written.
    wrote_continuation = bool(base_cont_ids)
    if wrote_continuation:
        decode(base_cont_ids).write(str(cont_path))

    print(f"[generate] prompt: {prompt_len} tokens ({prompt_label})")
    print(f"[generate] generated: {len(cont_ids)} tokens")
    print(f"[generate] temperature={args.temperature}, top_k={args.top_k}")
    print(f"[generate] output -> {out_path}")
    if wrote_continuation:
        print(f"[generate] continuation_only -> {cont_path}")
    print("[generate] token preview:")
    print(_token_text(out_ids))
| if __name__ == "__main__": | |
| main() | |