"""
Sample MIDI files from a trained MusicLLM checkpoint, both unconditionally and
as continuations of test-set audio prompts.

Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
import os
# NOTE(review): this restricts the process to CUDA device 3 (which then shows
# up as cuda:0) and unconditionally overrides any CUDA_VISIBLE_DEVICES value
# already set in the shell. It is presumably placed here, before any CUDA
# call, so that it takes effect when CUDA is initialized.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from train import MusicLLM, get_dataset
from utils import parse_yaml
from piano_transcription.utils import write_midi
import random
import numpy as np
def sample_unconditional(model, config_path, max_new_codes=1000, num_samples=5, temperature=1.0, top_k=200, device="cuda"):
    """Generate audio codes from scratch and write each sample as a MIDI file.

    Every sample starts from a single code ``0`` and is extended
    autoregressively by ``max_new_codes`` codes drawn with temperature and
    (optional) top-k sampling. The resulting codes are decoded with
    ``model.codec.decode`` and written to
    ``results/<config stem>/unconditional/sample_<n>.midi``.

    Args:
        model: Callable mapping a ``(batch, time)`` tensor of codes to logits
            indexed as ``[:, -1, :]``; it must also expose ``model.codec.decode``.
        config_path: Path of the config yaml; only its stem is used, to name
            the results directory.
        max_new_codes: Number of codes to generate per sample.
        num_samples: Number of independent samples to generate.
        temperature: Softmax temperature applied to the logits; must be > 0.
        top_k: If not ``None``, sample only among the ``top_k`` most likely
            codes at each step.
        device: Device on which the start codes are created.

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    # Dividing logits by a non-positive temperature yields inf/nan
    # probabilities, which would make torch.multinomial fail mid-generation.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    batch_size = 1
    start_codes = torch.zeros(size=(batch_size, 1), dtype=torch.long, device=device)

    # The output directory does not depend on the sample index, so create it once.
    results_dir = Path("results", Path(config_path).stem, "unconditional")
    results_dir.mkdir(parents=True, exist_ok=True)

    with torch.no_grad():
        for n in range(num_samples):
            current_codes = start_codes.clone()
            for _ in range(max_new_codes):
                # Only the logits at the last position predict the next code.
                logits = model(current_codes)[:, -1, :] / temperature
                if top_k is not None:
                    top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask every logit below the k-th largest one.
                    logits[logits < top_values[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                next_code = torch.multinomial(probs, num_samples=1)
                current_codes = torch.cat([current_codes, next_code], dim=1)

            # codec.decode is given the codes with an extra trailing singleton axis.
            audio_codes = current_codes.unsqueeze(dim=2)
            events = model.codec.decode(audio_codes)

            midi_path = Path(results_dir, f"sample_{n}.midi")
            write_midi(events, midi_path)
            print(f"Unconditional sample {n} saved to {midi_path}")
def sample_continuation(model, config_path, dataset, max_new_codes=1000, num_samples=5, prompt_duration=10, temperature=1.0, top_k=200, device="cuda", sr=16000, orig_duration=40):
    """Continue prompts taken from random dataset items and write the results as MIDI.

    For each randomly chosen item (drawn without replacement), the first
    ``prompt_duration`` seconds of its audio are encoded with
    ``model.codec.encode`` and extended by ``max_new_codes`` codes drawn with
    temperature and (optional) top-k sampling. Two MIDI files are written to
    ``results/<config stem>/continuation/``:

    * ``continuation_<n>_full.midi``: prompt plus generated continuation.
    * ``continuation_<n>_original.midi``: the first ``orig_duration`` seconds
      of the original audio, passed through the codec's encode/decode.

    Items whose audio is shorter than the prompt are skipped with a message.

    Args:
        model: Callable mapping a ``(batch, time)`` tensor of codes to logits
            indexed as ``[:, -1, :]``; it must also expose ``model.codec``
            with ``encode`` and ``decode``.
        config_path: Path of the config yaml; only its stem is used, to name
            the results directory.
        dataset: Sized, indexable collection whose items are dicts with an
            ``"audio"`` entry (NumPy array or torch tensor).
        max_new_codes: Number of codes to generate after the prompt.
        num_samples: Number of dataset items to draw; capped at ``len(dataset)``.
        prompt_duration: Prompt length in seconds.
        temperature: Softmax temperature applied to the logits; must be > 0.
        top_k: If not ``None``, sample only among the ``top_k`` most likely
            codes at each step.
        device: Device to which audio is moved before encoding.
        sr: Sample rate of the dataset audio, used to convert seconds to samples.
        orig_duration: Length in seconds of the reference clip written next to
            each continuation (default 40, the previously hard-coded value).

    Raises:
        ValueError: If ``temperature`` is not positive.
    """
    # Dividing logits by a non-positive temperature yields inf/nan
    # probabilities, which would make torch.multinomial fail mid-generation.
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    # random.sample raises if asked for more items than the dataset holds.
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
    prompt_samples = int(prompt_duration * sr)
    orig_samples = int(orig_duration * sr)
    results_dir = Path("results", Path(config_path).stem, "continuation")

    with torch.no_grad():
        for n, data_idx in enumerate(indices):
            data = dataset[data_idx]
            # Convert a NumPy array to a PyTorch tensor; tensors pass through.
            # NOTE(review): the dtype is kept as-is; confirm it matches what
            # model.codec.encode expects (e.g. float32).
            audio = data["audio"]
            if isinstance(audio, np.ndarray):
                audio = torch.from_numpy(audio)
            # Assumes audio is (samples,) or (channels, samples) -- TODO confirm.
            # Result: (1, 1, samples), keeping only the first channel.
            if audio.dim() == 1:
                audio = audio.unsqueeze(0)  # add a channel axis
            elif audio.shape[0] > 1:
                audio = audio[0:1]  # multi-channel: keep the first channel
            audio = audio.unsqueeze(1)

            if audio.shape[-1] < prompt_samples:
                print(f"Audio {data_idx} too short ({audio.shape[-1]/sr:.2f}s), skipping")
                continue

            prompt_audio = audio[:, :, :prompt_samples]
            # Drop the trailing singleton axis so codes are (batch, time) for the model.
            prompt_codes = model.codec.encode(prompt_audio.to(device)).squeeze(dim=2)

            current_codes = prompt_codes.clone()
            for _ in range(max_new_codes):
                # Only the logits at the last position predict the next code.
                logits = model(current_codes)[:, -1, :] / temperature
                if top_k is not None:
                    # `_` avoids clobbering the dataset index `data_idx`.
                    top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    # Mask every logit below the k-th largest one.
                    logits[logits < top_values[:, [-1]]] = -float('Inf')
                probs = F.softmax(logits, dim=-1)
                next_code = torch.multinomial(probs, num_samples=1)
                current_codes = torch.cat([current_codes, next_code], dim=1)

            # Restore the trailing singleton axis before decoding.
            full_events = model.codec.decode(current_codes.unsqueeze(dim=2))

            # Reference: the original audio round-tripped through the codec.
            orig_audio = audio[:, :, :orig_samples]
            orig_codes = model.codec.encode(orig_audio.to(device))
            orig_events = model.codec.decode(orig_codes)

            results_dir.mkdir(parents=True, exist_ok=True)
            full_midi_path = Path(results_dir, f"continuation_{n}_full.midi")
            write_midi(full_events, full_midi_path)
            print(f"Continuation sample {n} (full) saved to {full_midi_path}")
            orig_midi_path = Path(results_dir, f"continuation_{n}_original.midi")
            write_midi(orig_events, orig_midi_path)
            print(f"Original sample {n} saved to {orig_midi_path}")
def sample(args):
    """Load a MusicLLM checkpoint and run unconditional and continuation sampling.

    Args:
        args: Parsed command-line namespace with ``config`` (path of the
            config yaml) and ``ckpt_path`` (path of the checkpoint file).
    """
    # Arguments
    config_path = args.config
    ckpt_path = args.ckpt_path

    # Configs
    configs = parse_yaml(config_path)
    device = "cuda"
    sr = configs["sample_rate"]

    # Build MusicLLM and load only the weights of its `llm` submodule.
    model = MusicLLM(configs)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    state_dict = checkpoint["state_dict"]
    # Strip only the leading "llm." prefix. str.replace would also rewrite any
    # later occurrence of "llm." inside a key (e.g. "llm.foo_llm.bar").
    prefix = "llm."
    llm_state_dict = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
    model.llm.load_state_dict(llm_state_dict)
    model.to(device)
    model.eval()

    # Unconditional generation
    print("Starting unconditional sampling...")
    sample_unconditional(model, config_path, max_new_codes=1000, num_samples=1, device=device)

    # Conditional continuation of test-set prompts
    print("Starting continuation sampling...")
    dataset = get_dataset(configs, "test")
    sample_continuation(model, config_path, dataset, max_new_codes=600, num_samples=1, prompt_duration=15, device=device, sr=sr)
if __name__ == "__main__":
    # Command-line entry point: both the config and the checkpoint are required.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path of config yaml.")
    cli.add_argument("--ckpt_path", type=str, required=True)
    sample(cli.parse_args())