MusicTokenizer / unconditional_sample.py
ZheqiDAI's picture
Initial commit with cleaned files
fc4c601
Raw
History Blame Contribute Delete
2.75 kB
"""
Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py
"""
from __future__ import annotations
import argparse
from pathlib import Path
import torch
import torch.nn.functional as F
import os
# Expose only GPU index 3 to CUDA for this process. This is set before
# `train` is imported, presumably so it takes effect before anything
# initializes CUDA -- TODO confirm.
# NOTE(review): the device index is hard-coded; on a machine with fewer GPUs
# no CUDA device will be visible at all.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
from train import MusicLLM # Import the MusicLLM class directly
from utils import parse_yaml
from piano_transcription.utils import write_midi
def _extract_llm_state_dict(state_dict: dict) -> dict:
    """Return the ``llm.``-prefixed entries of *state_dict* with the prefix removed.

    Only the leading ``llm.`` is stripped, so a key that contains ``llm.``
    again further along (e.g. ``llm.blocks.0.llm.weight``) keeps the rest of
    its name intact. Entries without the prefix are dropped.
    """
    prefix = "llm."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def _sample_next_code(logits: torch.Tensor, temperature: float, top_k) -> torch.Tensor:
    """Draw one code per row from *logits* using temperature and top-k sampling.

    Args:
        logits: Scores for the next position, indexed as (row, candidate).
        temperature: Divisor applied to the logits before the softmax.
        top_k: Number of highest-scoring candidates to keep per row, or
            ``None`` to sample from all of them.

    Returns:
        The tensor produced by ``torch.multinomial(..., num_samples=1)``,
        i.e. one sampled index per row.
    """
    # Division allocates a new tensor, so the masking below does not modify
    # the caller's logits.
    logits = logits / temperature
    if top_k is not None:
        top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        # Everything scoring below the k-th largest value gets zero probability.
        logits[logits < top_values[:, [-1]]] = -float("Inf")
    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


def sample(args: argparse.Namespace) -> None:
    """Generate unconditional samples from a trained MusicLLM and save them as MIDI.

    Builds the model from the config YAML, restores the ``llm`` weights from
    the checkpoint, autoregressively samples several code sequences starting
    from a single zero code, decodes each one with ``model.codec`` and writes
    it to ``results/<config file stem>/sample_<n>.midi``.

    Args:
        args: Parsed command-line arguments providing ``config`` (path of the
            config YAML) and ``ckpt_path`` (path of the checkpoint file).

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    # Arguments
    config_path = args.config
    ckpt_path = args.ckpt_path

    # Configs
    configs = parse_yaml(config_path)

    num_samples = 5  # Number of samples to draw
    max_new_codes = 1000  # Number of codes generated in each sample
    temperature = 1.0  # Controls randomness: logits are divided by this value
    top_k = 200  # Keep only the top-k highest-scoring codes (None disables it)
    # Fall back to CPU so the script still runs when no GPU is visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the MusicLLM model.
    model = MusicLLM(configs)
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)

    # Only the `llm` sub-module is restored from the checkpoint; every other
    # part of the model keeps whatever MusicLLM(configs) set up.
    model.llm.load_state_dict(_extract_llm_state_dict(checkpoint["state_dict"]))
    model.to(device)
    model.eval()

    # Initial audio codes: one zero code per batch row.
    batch_size = 1
    start_codes = torch.zeros(size=(batch_size, 1), dtype=torch.long, device=device)

    # The output directory does not depend on the sample index, so create it once.
    results_dir = Path("results", Path(config_path).stem)
    results_dir.mkdir(parents=True, exist_ok=True)

    # Sample
    for n in range(num_samples):
        with torch.no_grad():
            current_codes = start_codes.clone()
            for _ in range(max_new_codes):
                # NOTE(review): the whole growing sequence is fed back in each
                # step; this assumes the model accepts max_new_codes + 1
                # positions -- confirm against the configured context length.
                logits = model(current_codes)
                # Only the scores at the last position predict the next code.
                next_code = _sample_next_code(logits[:, -1, :], temperature, top_k)
                current_codes = torch.cat([current_codes, next_code], dim=1)

            # Append a trailing axis before handing the codes to the codec.
            audio_codes = current_codes.unsqueeze(dim=2)
            events = model.codec.decode(audio_codes)

        midi_path = Path(results_dir, f"sample_{n}.midi")
        write_midi(events, midi_path)
        print(f"Generated sample {n} saved to {midi_path}")
if __name__ == "__main__":
    # Script entry point: both the config YAML and the checkpoint path are
    # mandatory command-line options.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--config",
        type=str,
        required=True,
        help="Path of config yaml.",
    )
    cli.add_argument(
        "--ckpt_path",
        type=str,
        required=True,
    )
    sample(cli.parse_args())