""" Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py """ from __future__ import annotations import argparse from pathlib import Path import torch import torch.nn.functional as F import os os.environ["CUDA_VISIBLE_DEVICES"] = "3" from train import MusicLLM # 直接导入 MusicLLM 类 from utils import parse_yaml from piano_transcription.utils import write_midi def sample(args): # Arguments config_path = args.config ckpt_path = args.ckpt_path # Configs configs = parse_yaml(config_path) num_samples = 5 # Number of samples to draw max_new_codes = 1000 # Number of codes generated in each sample temperature = 1.0 # 控制随机性 top_k = 200 # 保留 top-k 个概率最高的 tokens device = "cuda" sr = configs["sample_rate"] # 加载 MusicLLM 模型 model = MusicLLM(configs) checkpoint = torch.load(ckpt_path, map_location=device) # 提取 state_dict 并加载 llm 参数 state_dict = checkpoint["state_dict"] llm_state_dict = {k.replace("llm.", ""): v for k, v in state_dict.items() if k.startswith("llm.")} model.llm.load_state_dict(llm_state_dict) model.to(device) model.eval() # 初始音频 codes B = 1 start_codes = torch.zeros(size=(B, 1), dtype=torch.long, device=device) # Sample for n in range(num_samples): with torch.no_grad(): current_codes = start_codes.clone() for _ in range(max_new_codes): logits = model(current_codes) logits = logits[:, -1, :] logits = logits / temperature if top_k is not None: v, idx = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < v[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) next_code = torch.multinomial(probs, num_samples=1) current_codes = torch.cat([current_codes, next_code], dim=1) audio_codes = current_codes.unsqueeze(dim=2) events = model.codec.decode(audio_codes) results_dir = Path("results", Path(config_path).stem) results_dir.mkdir(parents=True, exist_ok=True) midi_path = Path(results_dir, f"sample_{n}.midi") Path(midi_path).parent.mkdir(parents=True, exist_ok=True) write_midi(events, midi_path) print(f"Generated sample {n} saved to {midi_path}") if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--config", type=str, required=True, help="Path of config yaml.") parser.add_argument('--ckpt_path', type=str, required=True) args = parser.parse_args() sample(args)