File size: 9,847 Bytes
ffbb4ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
"""
YingMusicSinger 批量推理脚本
支持多卡多进程、进度条显示
输入支持 JSONL 文件 或 LyricEditBench 数据集

用法:
    # JSONL 输入,4卡
    python batch_infer.py \
        --input_type jsonl \
        --input_path /path/to/input.jsonl \
        --output_dir /path/to/output \
        --ckpt_path /path/to/ckpts \
        --num_gpus 4

    # LyricEditBench 输入
    python batch_infer.py \
        --input_type lyric_edit_bench \
        --output_dir /path/to/output \
        --ckpt_path /path/to/ckpts \
        --num_gpus 4
"""

import argparse
import json
import os
import sys
import traceback
from pathlib import Path

import torch
import torch.multiprocessing as mp
import torchaudio
from datasets import Audio, Dataset
from huggingface_hub import hf_hub_download
from tqdm import tqdm


def load_jsonl(path: str) -> list[dict]:
    """Read a JSONL file and return each non-empty line parsed as a dict."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(stripped) for raw in f if (stripped := raw.strip())]


def build_dataset_from_local(gtsinger_root: str):
    """
    Build the LyricEditBench dataset backed by a local GTSinger directory.

    Args:
        gtsinger_root: Root directory of the local GTSinger dataset; the
            relative reference paths in the metadata are resolved against it.

    Returns:
        A HuggingFace ``Dataset`` whose ``melody_ref_audio`` and
        ``timbre_ref_audio`` columns are cast to ``Audio`` features.
    """
    # Fetch the inherited metadata JSON from the HuggingFace Hub.
    json_path = hf_hub_download(
        repo_id="ASLP-lab/LyricEditBench",
        filename="GTSinger_Inherited.json",
        repo_type="dataset",
    )

    with open(json_path, "r") as f:
        records = json.load(f)

    root = str(Path(gtsinger_root).resolve())

    # Resolve the relative reference paths against the local root and mirror
    # them into the *_audio fields so they can be cast to Audio columns.
    for record in records:
        record["melody_ref_path"] = os.path.join(root, record["melody_ref_path"])
        record["timbre_ref_path"] = os.path.join(root, record["timbre_ref_path"])
        record["melody_ref_audio"] = record["melody_ref_path"]
        record["timbre_ref_audio"] = record["timbre_ref_path"]

    dataset = Dataset.from_list(records)
    for column in ("melody_ref_audio", "timbre_ref_audio"):
        dataset = dataset.cast_column(column, Audio())

    return dataset


def load_subset(data: list, subset_id: str) -> list:
    """Keep only the items whose ``id`` appears in the named subset ID list."""
    id_list_path = hf_hub_download(
        repo_id="ASLP-lab/LyricEditBench",
        filename=f"id_lists/{subset_id}.txt",
        repo_type="dataset",
    )

    # One ID per line; ignore blank lines.
    with open(id_list_path, "r") as f:
        wanted = {stripped for line in f if (stripped := line.strip())}

    return [item for item in data if item["id"] in wanted]


def load_lyric_edit_bench(
    input_type: str,
    gtsinger_root: str = "/user-fs/chenzihao/zhengjunjie/datas/Music/openvocaldata/GTSinger",
) -> list[dict]:
    """
    Load the LyricEditBench 1K subset and adapt it to the inference item schema.

    Args:
        input_type: Either ``"lyric_edit_bench_melody_control"`` (timbre
            reference taken from the timbre_ref_* fields) or
            ``"lyric_edit_bench_sing_edit"`` (the melody reference doubles as
            the timbre reference).
        gtsinger_root: Local GTSinger root used to resolve the audio paths.
            Defaults to the path previously hard-coded in this function.

    Returns:
        A list of dicts with keys: ``id``, ``melody_ref_path``, ``gen_text``,
        ``timbre_ref_path``, ``timbre_ref_text``.

    Raises:
        ValueError: If ``input_type`` is not one of the supported modes.
    """
    # If you have GTSinger downloaded locally, build the dataset from it:
    ds_full = build_dataset_from_local(gtsinger_root)

    # Otherwise, you can load it directly from the HuggingFace Hub:
    # from datasets import load_dataset
    # ds_full = load_dataset("ASLP-lab/LyricEditBench", split="test")

    subset_1k = load_subset(ds_full, "1K")
    print(f"Loaded {len(subset_1k)} items")

    items = []
    for row in subset_1k:
        if input_type == "lyric_edit_bench_melody_control":
            items.append(
                {
                    "id": row.get("id", ""),
                    "melody_ref_path": row.get("melody_ref_path", ""),
                    "gen_text": row.get("gen_text", ""),
                    "timbre_ref_path": row.get("timbre_ref_path", ""),
                    "timbre_ref_text": row.get("timbre_ref_text", ""),
                }
            )
        elif input_type == "lyric_edit_bench_sing_edit":
            # Singing-edit mode: reuse the melody reference as the timbre
            # reference so the output keeps the original singer's voice.
            items.append(
                {
                    "id": row.get("id", ""),
                    "melody_ref_path": row.get("melody_ref_path", ""),
                    "gen_text": row.get("gen_text", ""),
                    "timbre_ref_path": row.get("melody_ref_path", ""),
                    "timbre_ref_text": row.get("melody_ref_text", ""),
                }
            )
        else:
            # Was `assert 0`: an explicit error survives `python -O` and
            # tells the caller what went wrong.
            raise ValueError(f"Unsupported input_type: {input_type!r}")
    return items


def worker(
    rank: int,
    world_size: int,
    items: list[dict],
    output_dir: str,
    ckpt_path: str,
    args: argparse.Namespace,
):
    """Per-GPU worker process.

    Loads the model onto ``cuda:{rank}``, takes the strided shard
    ``items[rank::world_size]``, and writes one ``{id}.wav`` per item into
    ``output_dir``. Existing outputs are skipped unless ``args.overwrite``
    is set; per-item failures are logged and counted, not fatal.
    """
    device = f"cuda:{rank}"
    torch.cuda.set_device(rank)

    # ---- Load the model (project import kept local to the worker process) ----
    from src.YingMusicSinger.infer.YingMusicSinger import YingMusicSinger

    model = YingMusicSinger.from_pretrained(ckpt_path)
    model.to(device)
    model.eval()

    # ---- Shard: each worker handles its own strided slice of the items ----
    shard = items[rank::world_size]

    # ---- Progress bar shown only on rank 0 unless --show_all_progress ----
    pbar = tqdm(
        shard,
        desc=f"[GPU {rank}]",
        position=rank,
        leave=True,
        disable=(rank != 0 and not args.show_all_progress),
    )

    success, fail = 0, 0
    for item in pbar:
        item_id = item.get("id", f"unknown_{success + fail}")
        out_path = os.path.join(output_dir, f"{item_id}.wav")

        # Skip outputs that already exist (cheap resume of interrupted runs).
        if os.path.exists(out_path) and not args.overwrite:
            success += 1
            pbar.set_postfix(ok=success, err=fail)
            continue

        try:
            with torch.no_grad():
                audio, sr = model(
                    ref_audio_path=item["timbre_ref_path"],
                    melody_audio_path=item["melody_ref_path"],
                    ref_text=item.get("timbre_ref_text", ""),
                    target_text=item.get("gen_text", ""),
                    lrc_align_mode=args.lrc_align_mode,
                    sil_len_to_end=args.sil_len_to_end,
                    t_shift=args.t_shift,
                    nfe_step=args.nfe_step,
                    cfg_strength=args.cfg_strength,
                    # Draw a fresh random seed per item unless one was pinned
                    # via --seed (sentinel -1 means "not pinned").
                    seed=args.seed
                    if args.seed != -1
                    else torch.randint(0, 2**32, (1,)).item(),
                )

            torchaudio.save(out_path, audio, sample_rate=sr)
            success += 1

        except Exception as e:
            # Best-effort batch run: record the failure and move on.
            fail += 1
            print(f"\n[GPU {rank}] ERROR on {item_id}: {e}", file=sys.stderr)
            if args.verbose:
                traceback.print_exc()

        pbar.set_postfix(ok=success, err=fail)

    pbar.close()
    print(f"[GPU {rank}] Done. success={success}, fail={fail}")


def main():
    """CLI entry point: parse arguments, load the work items, then run
    inference inline (single GPU) or fan out one spawned process per GPU."""
    parser = argparse.ArgumentParser(description="YingMusicSinger 批量推理")

    # ---- Input ----
    parser.add_argument(
        "--input_type",
        type=str,
        required=True,
        choices=[
            "jsonl",
            "lyric_edit_bench_melody_control",
            "lyric_edit_bench_sing_edit",
        ],
        help="输入类型: jsonl / lyric_edit_bench_melody_control 或 lyric_edit_bench_sing_edit",
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=None,
        help="JSONL 文件路径 (input_type=jsonl 时必填)",
    )

    # ---- Output ----
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="输出目录",
    )

    # ---- Model ----
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=False,
        help="模型 checkpoint 路径 (save_pretrained 保存的目录)",
        default=None,
    )

    # ---- Inference parameters ----
    parser.add_argument(
        "--num_gpus", type=int, default=None, help="使用 GPU 数量,默认全部"
    )
    parser.add_argument(
        "--lrc_align_mode",
        type=str,
        default="sentence_level",
        choices=["sentence_level"],
    )
    parser.add_argument("--sil_len_to_end", type=float, default=0.5)
    parser.add_argument("--t_shift", type=float, default=0.5)
    parser.add_argument("--nfe_step", type=int, default=32)
    parser.add_argument("--cfg_strength", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)

    # ---- Misc ----
    parser.add_argument("--overwrite", action="store_true", help="覆盖已有输出文件")
    parser.add_argument(
        "--show_all_progress", action="store_true", help="所有 GPU 都显示进度条"
    )
    parser.add_argument("--verbose", action="store_true", help="打印详细错误信息")

    args = parser.parse_args()

    # ---- Validate (parser.error prints a usage message and exits; unlike
    # `assert`, it is not stripped when running under `python -O`) ----
    if args.input_type == "jsonl":
        if args.input_path is None:
            parser.error("--input_path 是 jsonl 模式下必填的")
        if not os.path.isfile(args.input_path):
            parser.error(f"文件不存在: {args.input_path}")

    # ---- Load data ----
    print("加载数据...")
    if args.input_type == "jsonl":
        items = load_jsonl(args.input_path)
    else:
        items = load_lyric_edit_bench(args.input_type)
    print(f"共 {len(items)} 条数据")

    # ---- Decide how many GPUs to use (never more than items to process) ----
    available_gpus = torch.cuda.device_count()
    num_gpus = args.num_gpus or available_gpus
    num_gpus = min(num_gpus, available_gpus, len(items))
    if num_gpus <= 0:
        raise RuntimeError("没有可用的 GPU")
    print(f"使用 {num_gpus} 张 GPU")

    # ---- Create the output directory ----
    os.makedirs(args.output_dir, exist_ok=True)

    # ---- Launch workers ----
    if num_gpus == 1:
        # Single GPU: run inline, no need to spawn subprocesses.
        worker(0, 1, items, args.output_dir, args.ckpt_path, args)
    else:
        # CUDA requires the "spawn" start method for child processes.
        mp.set_start_method("spawn", force=True)
        processes = []
        for rank in range(num_gpus):
            p = mp.Process(
                target=worker,
                args=(rank, num_gpus, items, args.output_dir, args.ckpt_path, args),
            )
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

    print(f"\n推理完成! 输出目录: {args.output_dir}")


if __name__ == "__main__":
    main()