File size: 5,677 Bytes
fc4c601
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
"""
Modified from https://github.com/karpathy/nanoGPT/blob/master/sample.py
"""
from __future__ import annotations

import argparse
from pathlib import Path

import torch
import torch.nn.functional as F
import os
# Default this script to GPU 3, but respect a device selection the caller
# already made (e.g. `CUDA_VISIBLE_DEVICES=0 python sample.py ...`);
# assigning unconditionally would silently override it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

from train import MusicLLM, get_dataset
from utils import parse_yaml
from piano_transcription.utils import write_midi
import random
import numpy as np

@torch.no_grad()
def generate_codes(model, codes, max_new_codes, temperature=1.0, top_k=None):
    """Autoregressively extend ``codes`` by ``max_new_codes`` sampled codes.

    Shared sampling loop for unconditional and continuation generation.

    Args:
        model: callable mapping a (B, T) LongTensor of codes to logits; the
            last position (``logits[:, -1, :]``) scores the next code.
        codes: (B, T) LongTensor prompt. It is not modified in place.
        max_new_codes: number of codes to append.
        temperature: logits are divided by this before sampling. ``0`` selects
            greedy (argmax) decoding instead of dividing by zero.
        top_k: if not None, only the ``top_k`` highest-scoring codes may be
            sampled.

    Returns:
        LongTensor of shape (B, T + max_new_codes): the prompt followed by the
        generated codes.
    """
    for _ in range(max_new_codes):
        # The full sequence is re-fed at every step (no incremental cache).
        logits = model(codes)[:, -1, :]
        if temperature == 0:
            next_code = torch.argmax(logits, dim=-1, keepdim=True)
        else:
            # Division yields a new tensor, so the in-place masking below
            # does not touch the model's output.
            logits = logits / temperature
            if top_k is not None:
                kth_best, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < kth_best[:, [-1]]] = -float('Inf')
            probs = F.softmax(logits, dim=-1)
            next_code = torch.multinomial(probs, num_samples=1)
        codes = torch.cat([codes, next_code], dim=1)
    return codes


def sample_unconditional(model, config_path, max_new_codes=1000, num_samples=5, temperature=1.0, top_k=200, device="cuda"):
    """Generate MIDI files unconditionally, starting from a single code 0.

    Each sample starts from a (1, 1) tensor of zeros and is extended by
    ``max_new_codes`` sampled codes. It is decoded with ``model.codec`` and
    written to ``results/<config stem>/unconditional/sample_<n>.midi``.
    """
    results_dir = Path("results", Path(config_path).stem, "unconditional")
    results_dir.mkdir(parents=True, exist_ok=True)

    start_codes = torch.zeros(size=(1, 1), dtype=torch.long, device=device)
    for n in range(num_samples):
        with torch.no_grad():
            codes = generate_codes(model, start_codes, max_new_codes, temperature=temperature, top_k=top_k)
            # The codec's decode takes an extra trailing axis here
            # (presumably a single codebook; TODO confirm against codec).
            events = model.codec.decode(codes.unsqueeze(dim=2))

        midi_path = Path(results_dir, f"sample_{n}.midi")
        write_midi(events, midi_path)
        print(f"Unconditional sample {n} saved to {midi_path}")


def sample_continuation(model, config_path, dataset, max_new_codes=1000, num_samples=5, prompt_duration=10, temperature=1.0, top_k=200, device="cuda", sr=16000, orig_duration=40):
    """Continue randomly chosen dataset clips from their first seconds of audio.

    For up to ``num_samples`` random items of ``dataset``, the item's
    ``"audio"`` (NumPy array or tensor) is reduced to one channel. Only the
    first row is kept when there are several; channel-first layout assumed.
    Its first ``prompt_duration`` seconds are encoded with ``model.codec`` and
    extended by ``max_new_codes`` generated codes. Two MIDI files are written
    to ``results/<config stem>/continuation/``:
    ``continuation_<n>_full.midi`` holds the prompt plus continuation, and
    ``continuation_<n>_original.midi`` holds a codec round-trip of the first
    ``orig_duration`` seconds of the source audio. Clips shorter than the
    prompt are skipped, so fewer than ``num_samples`` pairs may be written.
    """
    results_dir = Path("results", Path(config_path).stem, "continuation")
    results_dir.mkdir(parents=True, exist_ok=True)

    prompt_samples = int(prompt_duration * sr)
    # random.sample raises ValueError when asked for more items than exist.
    indices = random.sample(range(len(dataset)), min(num_samples, len(dataset)))

    for n, data_idx in enumerate(indices):
        data = dataset[data_idx]
        audio = data["audio"]
        if isinstance(audio, np.ndarray):
            audio = torch.from_numpy(audio)
        if audio.dim() == 1:
            audio = audio.unsqueeze(0)  # add a channel axis
        elif audio.shape[0] > 1:
            audio = audio[0:1]  # keep only the first channel
        audio = audio.unsqueeze(1)  # -> (1, 1, samples) for the codec

        if audio.shape[-1] < prompt_samples:
            print(f"Audio {data_idx} too short ({audio.shape[-1]/sr:.2f}s), skipping")
            continue

        with torch.no_grad():
            prompt_audio = audio[:, :, :prompt_samples]
            prompt_codes = model.codec.encode(prompt_audio.to(device)).squeeze(dim=2)
            full_codes = generate_codes(model, prompt_codes, max_new_codes, temperature=temperature, top_k=top_k)
            full_events = model.codec.decode(full_codes.unsqueeze(dim=2))

            # Reference: codec round-trip of the untouched source audio.
            orig_audio = audio[:, :, :int(orig_duration * sr)]
            orig_codes = model.codec.encode(orig_audio.to(device))
            orig_events = model.codec.decode(orig_codes)

        full_midi_path = Path(results_dir, f"continuation_{n}_full.midi")
        write_midi(full_events, full_midi_path)
        print(f"Continuation sample {n} (full) saved to {full_midi_path}")

        orig_midi_path = Path(results_dir, f"continuation_{n}_original.midi")
        write_midi(orig_events, orig_midi_path)
        print(f"Original sample {n} saved to {orig_midi_path}")

def extract_prefixed_state_dict(state_dict, prefix):
    """Return the entries of ``state_dict`` whose keys start with ``prefix``, prefix removed.

    Only the leading prefix is stripped. ``str.replace`` would also rewrite
    any later occurrence of the prefix inside a key
    (e.g. ``"llm.a.llm.b"`` -> ``"a.b"``).
    """
    return {key[len(prefix):]: value for key, value in state_dict.items() if key.startswith(prefix)}


def sample(args):
    """Load a MusicLLM checkpoint and write unconditional and continuation samples.

    Args:
        args: parsed CLI namespace with ``config`` (YAML config path) and
            ``ckpt_path``. The checkpoint's ``state_dict`` must hold the LLM
            weights under the ``llm.`` key prefix.
    """
    config_path = args.config
    ckpt_path = args.ckpt_path

    configs = parse_yaml(config_path)
    # Fall back to CPU instead of crashing when no CUDA device is visible.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")
    sr = configs["sample_rate"]

    # Load the MusicLLM model; only the `llm.` sub-module weights are restored.
    model = MusicLLM(configs)
    checkpoint = torch.load(ckpt_path, map_location=device)
    llm_state_dict = extract_prefixed_state_dict(checkpoint["state_dict"], "llm.")
    model.llm.load_state_dict(llm_state_dict)
    model.to(device)
    model.eval()

    # Unconditional generation
    print("Starting unconditional sampling...")
    sample_unconditional(model, config_path, max_new_codes=1000, num_samples=1, device=device)

    # Conditional continuation of test-set clips
    print("Starting continuation sampling...")
    dataset = get_dataset(configs, "test")
    sample_continuation(model, config_path, dataset, max_new_codes=600, num_samples=1, prompt_duration=15, device=device, sr=sr)

if __name__ == "__main__":
    # CLI entry point: parse the config/checkpoint paths and run sampling.
    cli = argparse.ArgumentParser()
    cli.add_argument("--config", type=str, required=True, help="Path of config yaml.")
    cli.add_argument("--ckpt_path", type=str, required=True)
    sample(cli.parse_args())