""" Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py """ from __future__ import annotations import argparse from pathlib import Path import torch import torch.nn.functional as F import os os.environ["CUDA_VISIBLE_DEVICES"] = "3" from train import MusicLLM, get_dataset from utils import parse_yaml from piano_transcription.utils import write_midi import random import numpy as np def sample_unconditional(model, config_path, max_new_codes=1000, num_samples=5, temperature=1.0, top_k=200, device="cuda"): """无条件生成,从零开始生成音频 codes""" B = 1 start_codes = torch.zeros(size=(B, 1), dtype=torch.long, device=device) for n in range(num_samples): with torch.no_grad(): current_codes = start_codes.clone() for _ in range(max_new_codes): logits = model(current_codes) logits = logits[:, -1, :] logits = logits / temperature if top_k is not None: v, idx = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) next_code = torch.multinomial(probs, num_samples=1) current_codes = torch.cat([current_codes, next_code], dim=1) audio_codes = current_codes.unsqueeze(dim=2) events = model.codec.decode(audio_codes) results_dir = Path("results", Path(config_path).stem, "unconditional") results_dir.mkdir(parents=True, exist_ok=True) midi_path = Path(results_dir, f"sample_{n}.midi") write_midi(events, midi_path) print(f"Unconditional sample {n} saved to {midi_path}") def sample_continuation(model, config_path, dataset, max_new_codes=1000, num_samples=5, prompt_duration=10, temperature=1.0, top_k=200, device="cuda", sr=16000): """条件续写:从数据集随机加载音频,用前 prompt_duration 秒作为 prompt,续写后保存原音频和续写结果""" indices = random.sample(range(len(dataset)), num_samples) for n, idx in enumerate(indices): with torch.no_grad(): data = dataset[idx] # 将 NumPy 数组转换为 PyTorch 张量 audio = torch.from_numpy(data["audio"]) if isinstance(data["audio"], np.ndarray) else data["audio"] if audio.dim() == 1: audio = audio.unsqueeze(0) # 添加通道维度 elif audio.shape[0] > 1: audio = audio[0:1] # 多声道取第一个 audio = audio.unsqueeze(1) prompt_samples = int(prompt_duration * sr) if audio.shape[-1] < prompt_samples: print(f"Audio {idx} too short ({audio.shape[-1]/sr:.2f}s), skipping") continue prompt_audio = audio[:,:, :prompt_samples] prompt_codes = model.codec.encode(prompt_audio.to(device)).squeeze(dim=2) current_codes = prompt_codes.clone() for _ in range(max_new_codes): logits = model(current_codes) logits = logits[:, -1, :] logits = logits / temperature if top_k is not None: v, idx = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) next_code = torch.multinomial(probs, num_samples=1) current_codes = torch.cat([current_codes, next_code], dim=1) full_codes = current_codes.unsqueeze(dim=2) full_events = model.codec.decode(full_codes) orig_audio = audio[:,:, :int(40 * sr)] orig_codes = model.codec.encode(orig_audio.to(device)) orig_events = model.codec.decode(orig_codes) results_dir = Path("results", Path(config_path).stem, "continuation") results_dir.mkdir(parents=True, exist_ok=True) full_midi_path = Path(results_dir, f"continuation_{n}_full.midi") write_midi(full_events, full_midi_path) print(f"Continuation sample {n} (full) saved to {full_midi_path}") orig_midi_path = Path(results_dir, f"continuation_{n}_original.midi") write_midi(orig_events, orig_midi_path) print(f"Original sample {n} saved to {orig_midi_path}") def sample(args): # Arguments config_path = args.config ckpt_path = args.ckpt_path # Configs configs = parse_yaml(config_path) device = "cuda" sr = configs["sample_rate"] # 加载 MusicLLM 模型 model = MusicLLM(configs) checkpoint = torch.load(ckpt_path, map_location=device) state_dict = checkpoint["state_dict"] llm_state_dict = {k.replace("llm.", ""): v for k, v in state_dict.items() if k.startswith("llm.")} model.llm.load_state_dict(llm_state_dict) model.to(device) model.eval() # 无条件生成 print("Starting unconditional sampling...") sample_unconditional(model, config_path, max_new_codes=1000, num_samples=1, device=device) # 条件续写 print("Starting continuation sampling...") dataset = get_dataset(configs, "test") sample_continuation(model, config_path, dataset, max_new_codes=600, num_samples=1, prompt_duration=15, device=device, sr=sr) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True, help="Path of config yaml.") parser.add_argument('--ckpt_path', type=str, required=True) args = parser.parse_args() sample(args)