| """ |
| Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import torch |
| import torch.nn.functional as F |
import os

# Default to GPU 3, but respect a CUDA_VISIBLE_DEVICES value that the caller
# already exported instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")
|
|
| from train import MusicLLM, get_dataset |
| from utils import parse_yaml |
| from piano_transcription.utils import write_midi |
| import random |
| import numpy as np |
|
|
def sample_unconditional(model, config_path, max_new_codes=1000, num_samples=5, temperature=1.0, top_k=200, device="cuda"):
    """Generate audio codes from scratch and save each result as a MIDI file.

    Every sample starts from a single zero code and is extended one code at a
    time by sampling from the model's temperature-scaled (and optionally
    top-k filtered) distribution for the next position. The finished
    sequence is decoded with ``model.codec.decode`` and written to
    ``results/<config stem>/unconditional/sample_<n>.midi``.

    Args:
        model: Callable that maps a ``(1, T)`` long tensor of codes to logits
            indexed as ``[batch, position, vocabulary]``. It must also expose
            ``codec.decode``.
        config_path: Path of the config file. Only its stem is used, to name
            the output directory.
        max_new_codes: Number of codes sampled per sequence.
        num_samples: Number of sequences to generate.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If not None, only the ``top_k`` highest logits stay eligible
            at each step.
        device: Device on which the start codes are created.

    Raises:
        ValueError: If ``temperature`` is not positive, or ``top_k`` is
            neither None nor at least 1.
    """
    # Fail early with a clear message: a zero temperature would divide by
    # zero and a negative one would silently invert the distribution.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be None or a positive integer, got {top_k}")

    # Loop-invariant values.
    results_dir = Path("results", Path(config_path).stem, "unconditional")
    start_codes = torch.zeros(size=(1, 1), dtype=torch.long, device=device)

    for n in range(num_samples):
        with torch.no_grad():
            current_codes = start_codes.clone()
            for _ in range(max_new_codes):
                # Only the logits of the last position predict the next code.
                logits = model(current_codes)[:, -1, :] / temperature
                if top_k is not None:
                    top_values = torch.topk(logits, min(top_k, logits.size(-1))).values
                    # Mask everything below the k-th largest logit.
                    logits[logits < top_values[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_code = torch.multinomial(probs, num_samples=1)
                current_codes = torch.cat([current_codes, next_code], dim=1)

            # The codec is fed codes with an extra trailing dimension.
            audio_codes = current_codes.unsqueeze(dim=2)
            events = model.codec.decode(audio_codes)

        # Create the directory lazily so nothing is written for zero samples.
        results_dir.mkdir(parents=True, exist_ok=True)
        midi_path = Path(results_dir, f"sample_{n}.midi")
        write_midi(events, midi_path)
        print(f"Unconditional sample {n} saved to {midi_path}")
|
|
def sample_continuation(model, config_path, dataset, max_new_codes=1000, num_samples=5, prompt_duration=10, temperature=1.0, top_k=200, device="cuda", sr=16000, original_duration=40):
    """Continue randomly chosen dataset clips and save the results as MIDI.

    For each selected item, the first ``prompt_duration`` seconds of audio are
    encoded with ``model.codec.encode`` and used as the prompt. The prompt is
    extended by ``max_new_codes`` sampled codes and then decoded and saved as
    ``continuation_<n>_full.midi``. For comparison, the first
    ``original_duration`` seconds of the source audio are passed through the
    codec (encode then decode) and saved as ``continuation_<n>_original.midi``.
    Both files go to ``results/<config stem>/continuation/``.

    Args:
        model: Callable that maps a ``(1, T)`` tensor of codes to logits
            indexed as ``[batch, position, vocabulary]``. It must also expose
            ``codec.encode`` and ``codec.decode``.
        config_path: Path of the config file. Only its stem is used, to name
            the output directory.
        dataset: Indexable collection with ``len()``. Each item is a mapping
            whose ``"audio"`` entry is a NumPy array or a torch tensor.
        max_new_codes: Number of codes sampled after the prompt.
        num_samples: Number of dataset items to draw without replacement.
            Clamped to ``len(dataset)`` when larger.
        prompt_duration: Prompt length in seconds.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If not None, only the ``top_k`` highest logits stay eligible
            at each step.
        device: Device the audio is moved to before encoding.
        sr: Audio sample rate in Hz, used to convert seconds to samples.
        original_duration: Seconds of source audio kept for the reference
            MIDI file.

    Raises:
        ValueError: If ``temperature`` is not positive, or ``top_k`` is
            neither None nor at least 1.
    """
    # Fail early with a clear message: a zero temperature would divide by
    # zero and a negative one would silently invert the distribution.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    if top_k is not None and top_k < 1:
        raise ValueError(f"top_k must be None or a positive integer, got {top_k}")

    # random.sample raises when asked for more items than exist, so clamp.
    available = len(dataset)
    if num_samples > available:
        print(f"Requested {num_samples} samples but the dataset only has {available}; using {available}")
        num_samples = available
    indices = random.sample(range(available), num_samples)

    # Loop-invariant values.
    results_dir = Path("results", Path(config_path).stem, "continuation")
    prompt_samples = int(prompt_duration * sr)
    original_samples = int(original_duration * sr)

    for n, data_idx in enumerate(indices):
        with torch.no_grad():
            data = dataset[data_idx]

            audio = data["audio"]
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio)
            # Reduce to a single leading row: 1-D audio gains a leading
            # dimension, otherwise only the first entry along dim 0 is kept
            # (presumably the first channel -- TODO confirm with the dataset).
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)
            elif audio.shape[0] > 1:
                audio = audio[0:1]

            # Insert the extra dimension the codec is fed with.
            audio = audio.unsqueeze(1)

            if audio.shape[-1] < prompt_samples:
                print(f"Audio {data_idx} too short ({audio.shape[-1]/sr:.2f}s), skipping")
                continue

            prompt_audio = audio[:, :, :prompt_samples]
            prompt_codes = model.codec.encode(prompt_audio.to(device)).squeeze(dim=2)

            current_codes = prompt_codes.clone()
            for _ in range(max_new_codes):
                # Only the logits of the last position predict the next code.
                logits = model(current_codes)[:, -1, :] / temperature
                if top_k is not None:
                    top_values = torch.topk(logits, min(top_k, logits.size(-1))).values
                    # Mask everything below the k-th largest logit.
                    logits[logits < top_values[:, [-1]]] = -float("Inf")
                probs = F.softmax(logits, dim=-1)
                next_code = torch.multinomial(probs, num_samples=1)
                current_codes = torch.cat([current_codes, next_code], dim=1)

            full_codes = current_codes.unsqueeze(dim=2)
            full_events = model.codec.decode(full_codes)

            # Reference: the source audio passed through the same codec.
            orig_audio = audio[:, :, :original_samples]
            orig_codes = model.codec.encode(orig_audio.to(device))
            orig_events = model.codec.decode(orig_codes)

        # Create the directory lazily so skipped clips leave nothing behind.
        results_dir.mkdir(parents=True, exist_ok=True)

        full_midi_path = Path(results_dir, f"continuation_{n}_full.midi")
        write_midi(full_events, full_midi_path)
        print(f"Continuation sample {n} (full) saved to {full_midi_path}")

        orig_midi_path = Path(results_dir, f"continuation_{n}_original.midi")
        write_midi(orig_events, orig_midi_path)
        print(f"Original sample {n} saved to {orig_midi_path}")
|
|
def sample(args):
    """Load a trained MusicLLM checkpoint and run both sampling modes.

    Reads the YAML config, restores the ``llm`` sub-module weights from the
    checkpoint, then generates one unconditional sample and one continuation
    of a clip from the ``"test"`` split.

    Args:
        args: Namespace with ``config`` (path of the config YAML) and
            ``ckpt_path`` (path of the checkpoint file).
    """
    config_path = args.config
    ckpt_path = args.ckpt_path

    configs = parse_yaml(config_path)
    sr = configs["sample_rate"]

    # Prefer the GPU but keep the script usable on CPU-only machines.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        print("CUDA is not available, falling back to CPU")
        device = "cpu"

    model = MusicLLM(configs)
    # NOTE(security): torch.load may unpickle arbitrary objects; only load
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint["state_dict"]

    # Keep only the weights of the ``llm`` sub-module and strip the leading
    # prefix. Slicing removes just the prefix, whereas str.replace would
    # also delete any later "llm." inside a key.
    prefix = "llm."
    llm_state_dict = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    model.llm.load_state_dict(llm_state_dict)
    model.to(device)
    model.eval()

    print("Starting unconditional sampling...")
    sample_unconditional(model, config_path, max_new_codes=1000, num_samples=1, device=device)

    print("Starting continuation sampling...")
    dataset = get_dataset(configs, "test")
    sample_continuation(model, config_path, dataset, max_new_codes=600, num_samples=1, prompt_duration=15, device=device, sr=sr)
|
|
if __name__ == "__main__":
    # Script entry point: the config YAML and the checkpoint path are both
    # mandatory command-line options.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        required=True,
        type=str,
        help="Path of config yaml.",
    )
    cli.add_argument(
        "--ckpt_path",
        required=True,
        type=str,
    )
    args = cli.parse_args()

    sample(args)