| """ |
| Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
| import os |
# Default this script to GPU 3, but respect a device selection the caller has
# already made through the environment (e.g. CUDA_VISIBLE_DEVICES=0 python ...).
# The variable only takes effect if it is set before CUDA is first initialised.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
|
|
| from train import MusicLLM |
| from utils import parse_yaml |
| from piano_transcription.utils import write_midi |
|
|
def _load_model(configs: dict, ckpt_path: str, device: str) -> MusicLLM:
    """Build a ``MusicLLM`` from ``configs`` and restore its ``llm`` weights.

    Only entries of ``checkpoint["state_dict"]`` whose keys start with
    ``"llm."`` are loaded (with that leading prefix removed) into
    ``model.llm``. The model is then moved to ``device`` and set to eval mode.
    """
    model = MusicLLM(configs)
    checkpoint = torch.load(ckpt_path, map_location=device)

    prefix = "llm."
    # Slice off only the leading prefix: str.replace("llm.", "") would also
    # strip any later "llm." inside a key and silently mangle the name.
    llm_state_dict = {
        key[len(prefix):]: value
        for key, value in checkpoint["state_dict"].items()
        if key.startswith(prefix)
    }
    model.llm.load_state_dict(llm_state_dict)

    model.to(device)
    model.eval()
    return model


@torch.no_grad()
def _sample_codes(
    model: MusicLLM,
    start_codes: torch.Tensor,
    max_new_codes: int,
    temperature: float,
    top_k: int | None,
) -> torch.Tensor:
    """Append ``max_new_codes`` sampled codes to ``start_codes`` and return them.

    Each step feeds the whole sequence so far to ``model``, takes the logits
    at the last position, divides them by ``temperature``, optionally masks
    everything below the ``top_k``-th largest logit, and draws the next code
    from the resulting softmax distribution.
    """
    codes = start_codes.clone()
    for _ in range(max_new_codes):
        # NOTE(review): the sequence is never cropped to a context window;
        # confirm the model accepts start length + max_new_codes positions.
        logits = model(codes)[:, -1, :] / temperature
        if top_k is not None:
            k = min(top_k, logits.size(-1))
            kth_largest = torch.topk(logits, k).values[:, [-1]]
            logits[logits < kth_largest] = -float("Inf")
        probs = F.softmax(logits, dim=-1)
        next_code = torch.multinomial(probs, num_samples=1)
        codes = torch.cat([codes, next_code], dim=1)
    return codes


def sample(args):
    """Generate MIDI samples from a trained MusicLLM checkpoint.

    Builds the model from the YAML at ``args.config`` and loads its ``llm``
    weights from ``args.ckpt_path``. It then autoregressively samples
    ``num_samples`` code sequences, decodes each with ``model.codec`` and
    writes it to ``results/<config stem>/sample_<n>.midi``.

    Args:
        args: Namespace with required ``config`` and ``ckpt_path`` attributes.
            These optional attributes override the sampling defaults:
            ``num_samples`` (5), ``max_new_codes`` (1000), ``temperature``
            (1.0), ``top_k`` (200; ``None`` disables top-k filtering) and
            ``device`` ("cuda" when available, otherwise "cpu").

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    config_path = args.config
    ckpt_path = args.ckpt_path

    num_samples = getattr(args, "num_samples", 5)
    max_new_codes = getattr(args, "max_new_codes", 1000)
    temperature = getattr(args, "temperature", 1.0)
    top_k = getattr(args, "top_k", 200)
    device = getattr(args, "device", None)
    if device is None:
        # Fall back to CPU instead of crashing when no GPU is visible.
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Logits are divided by temperature; zero or negative values would yield
    # inf/nan probabilities and an obscure failure inside torch.multinomial.
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    configs = parse_yaml(config_path)
    model = _load_model(configs, ckpt_path, device)

    results_dir = Path("results", Path(config_path).stem)
    results_dir.mkdir(parents=True, exist_ok=True)

    # A single sequence (batch size 1) seeded with code 0.
    start_codes = torch.zeros(size=(1, 1), dtype=torch.long, device=device)

    for n in range(num_samples):
        codes = _sample_codes(model, start_codes, max_new_codes, temperature, top_k)

        # Add a trailing axis, (B, T) -> (B, T, 1), before decoding;
        # presumably the codec expects a per-codebook axis — TODO confirm.
        with torch.no_grad():
            events = model.codec.decode(codes.unsqueeze(dim=2))

        midi_path = results_dir / f"sample_{n}.midi"
        write_midi(events, midi_path)
        print(f"Generated sample {n} saved to {midi_path}")
|
|
if __name__ == "__main__":
    # CLI entry point: both paths are mandatory; everything else uses the
    # defaults baked into sample().
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path of config yaml.")
    cli.add_argument("--ckpt_path", type=str, required=True)
    sample(cli.parse_args())