MusicTokenizer / sample.py
ZheqiDAI's picture
Initial commit with cleaned files
fc4c601
Raw
History Blame Contribute Delete
5.68 kB
"""
Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from train import MusicLLM, get_dataset
from utils import parse_yaml
from piano_transcription.utils import write_midi
import random
import numpy as np
def _sample_next_code(logits, temperature=1.0, top_k=None):
    """Sample one code per batch row from next-step logits.

    Args:
        logits: Tensor of shape (B, V) holding unnormalized scores for the
            next code. It is not modified in place.
        temperature: Positive divisor applied to the logits before softmax.
        top_k: If not None, only codes whose score is at least the k-th
            largest score of their row can be drawn.

    Returns:
        LongTensor of shape (B, 1) with the sampled code indices.

    Raises:
        ValueError: If ``temperature`` is not positive; dividing by it would
            otherwise give non-finite or inverted logits.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # Division allocates a new tensor, so the in-place masking below never
    # touches the caller's logits.
    logits = logits / temperature
    if top_k is not None:
        kth_score = torch.topk(logits, min(top_k, logits.size(-1))).values[:, [-1]]
        logits[logits < kth_score] = -float("Inf")
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def _generate_codes(model, codes, max_new_codes, temperature=1.0, top_k=None):
    """Autoregressively append ``max_new_codes`` sampled codes to ``codes``.

    Args:
        model: Callable taking a (B, T) LongTensor of codes and returning
            logits indexable as (B, T, V).
        codes: (B, T) LongTensor used as the prompt; it is not modified.
        max_new_codes: Number of codes to append.
        temperature: Sampling temperature, see ``_sample_next_code``.
        top_k: Top-k truncation, see ``_sample_next_code``.

    Returns:
        (B, T + max_new_codes) LongTensor: the prompt followed by the new codes.
    """
    for _ in range(max_new_codes):
        # The whole sequence is re-fed every step and never cropped.
        # NOTE(review): confirm the model accepts len(prompt) + max_new_codes.
        logits = model(codes)[:, -1, :]
        next_code = _sample_next_code(logits, temperature, top_k)
        codes = torch.cat([codes, next_code], dim=1)
    return codes


def _audio_to_batch(audio):
    """Return a dataset audio clip as a (1, 1, T) tensor of its first channel.

    Accepts a NumPy array or tensor, intended to be 1-D (T,) or 2-D (C, T).
    For (C, T) input with C > 1 only channel 0 is kept.
    """
    if isinstance(audio, np.ndarray):
        audio = torch.from_numpy(audio)
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)  # (T,) -> (1, T)
    elif audio.shape[0] > 1:
        audio = audio[0:1]  # (C, T) -> (1, T): keep the first channel only
    return audio.unsqueeze(1)  # (1, T) -> (1, 1, T)


def sample_unconditional(model, config_path, max_new_codes=1000, num_samples=5, temperature=1.0, top_k=200, device="cuda"):
    """Generate MIDI files unconditionally, starting from a single code 0.

    For each sample, ``max_new_codes`` codes are sampled after a start code
    of 0. The resulting sequence, including the start code, is decoded with
    ``model.codec.decode`` and written to
    ``results/<config stem>/unconditional/sample_<n>.midi``.

    Args:
        model: Model callable on (B, T) codes and exposing ``codec.decode``.
        config_path: Path of the config YAML. Only its stem is used, to name
            the output directory.
        max_new_codes: Number of codes to generate per sample.
        num_samples: Number of independent samples to write.
        temperature: Sampling temperature; must be positive.
        top_k: Top-k truncation for sampling, or None for the full distribution.
        device: Device on which the start code is created.
    """
    results_dir = Path("results", Path(config_path).stem, "unconditional")
    results_dir.mkdir(parents=True, exist_ok=True)
    start_codes = torch.zeros(size=(1, 1), dtype=torch.long, device=device)
    for n in range(num_samples):
        with torch.no_grad():
            codes = _generate_codes(model, start_codes, max_new_codes, temperature, top_k)
            # The codec works on 3-D code tensors: (B, T) -> (B, T, 1).
            # Presumably the last axis is a codebook axis -- TODO confirm.
            events = model.codec.decode(codes.unsqueeze(dim=2))
        midi_path = Path(results_dir, f"sample_{n}.midi")
        write_midi(events, midi_path)
        print(f"Unconditional sample {n} saved to {midi_path}")


def sample_continuation(model, config_path, dataset, max_new_codes=1000, num_samples=5, prompt_duration=10, temperature=1.0, top_k=200, device="cuda", sr=16000, original_duration=40):
    """Continue audio prompts drawn at random from ``dataset`` and save MIDI.

    For each randomly chosen clip:
      * the first ``prompt_duration`` seconds are encoded as the prompt;
      * ``max_new_codes`` codes are generated after it;
      * prompt + continuation is decoded and written as
        ``continuation_<n>_full.midi``.

    As a reference, the first ``original_duration`` seconds of the same clip
    are round-tripped through the codec (encode then decode) and written as
    ``continuation_<n>_original.midi``. Both files go to
    ``results/<config stem>/continuation/``.

    Clips shorter than the prompt are skipped, with a message. At most
    ``len(dataset)`` clips are used, even if ``num_samples`` is larger.

    Args:
        model: Model callable on (B, T) codes, exposing ``codec.encode`` and
            ``codec.decode``.
        config_path: Path of the config YAML. Only its stem is used, to name
            the output directory.
        dataset: Indexable dataset whose items contain an ``"audio"`` entry.
        max_new_codes: Number of codes to generate after each prompt.
        num_samples: Number of clips to draw, without replacement.
        prompt_duration: Prompt length in seconds.
        temperature: Sampling temperature; must be positive.
        top_k: Top-k truncation for sampling, or None for the full distribution.
        device: Device the audio is moved to before encoding.
        sr: Sample rate of the dataset audio, used to convert seconds to samples.
        original_duration: Length in seconds of the reference clip.
    """
    results_dir = Path("results", Path(config_path).stem, "continuation")
    results_dir.mkdir(parents=True, exist_ok=True)
    prompt_samples = int(prompt_duration * sr)
    original_samples = int(original_duration * sr)
    # Clamp so a small test split does not make random.sample raise.
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))
    # Named `data_idx` so it cannot be shadowed by anything inside the loop.
    for n, data_idx in enumerate(indices):
        with torch.no_grad():
            audio = _audio_to_batch(dataset[data_idx]["audio"])
            if audio.shape[-1] < prompt_samples:
                print(f"Audio {data_idx} too short ({audio.shape[-1]/sr:.2f}s), skipping")
                continue
            # encode() yields a 3-D code tensor; drop axis 2 to get (B, T) codes.
            prompt_codes = model.codec.encode(audio[:, :, :prompt_samples].to(device)).squeeze(dim=2)
            full_codes = _generate_codes(model, prompt_codes, max_new_codes, temperature, top_k)
            full_events = model.codec.decode(full_codes.unsqueeze(dim=2))
            orig_codes = model.codec.encode(audio[:, :, :original_samples].to(device))
            orig_events = model.codec.decode(orig_codes)
        full_midi_path = Path(results_dir, f"continuation_{n}_full.midi")
        write_midi(full_events, full_midi_path)
        print(f"Continuation sample {n} (full) saved to {full_midi_path}")
        orig_midi_path = Path(results_dir, f"continuation_{n}_original.midi")
        write_midi(orig_events, orig_midi_path)
        print(f"Original sample {n} saved to {orig_midi_path}")
def _strip_prefix(state_dict, prefix):
    """Return the entries of ``state_dict`` whose keys start with ``prefix``, with it removed.

    Only the leading prefix is stripped. Later occurrences of ``prefix``
    inside a key are kept, so "llm.mllm.x" maps to "mllm.x". The earlier
    ``str.replace`` version removed every occurrence and turned it into "mx".
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def sample(args):
    """Load a MusicLLM checkpoint, then run unconditional and continuation sampling.

    Args:
        args: Namespace with ``config`` (path of the config YAML) and
            ``ckpt_path`` (path of a checkpoint containing a ``"state_dict"``
            entry whose ``llm.*`` keys hold the language-model weights).
    """
    config_path = args.config
    ckpt_path = args.ckpt_path
    configs = parse_yaml(config_path)
    # Fall back to CPU so the script still runs when no GPU is visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    sr = configs["sample_rate"]

    model = MusicLLM(configs)
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location=device)
    # Only the "llm." weights are restored here. Other submodules, such as
    # model.codec, are not touched by this checkpoint load.
    model.llm.load_state_dict(_strip_prefix(checkpoint["state_dict"], "llm."))
    model.to(device)
    model.eval()

    print("Starting unconditional sampling...")
    sample_unconditional(model, config_path, max_new_codes=1000, num_samples=1, device=device)

    print("Starting continuation sampling...")
    dataset = get_dataset(configs, "test")
    sample_continuation(model, config_path, dataset, max_new_codes=600, num_samples=1, prompt_duration=15, device=device, sr=sr)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, required=True, help="Path of config yaml.")
parser.add_argument('--ckpt_path', type=str, required=True)
args = parser.parse_args()
sample(args)