# YingMusic-Singer-Plus / inference_mp.py (upstream revision ffbb4ab)
"""
YingMusicSinger 批量推理脚本
支持多卡多进程、进度条显示
输入支持 JSONL 文件 或 LyricEditBench 数据集
用法:
# JSONL 输入,4卡
python batch_infer.py \
--input_type jsonl \
--input_path /path/to/input.jsonl \
--output_dir /path/to/output \
--ckpt_path /path/to/ckpts \
--num_gpus 4
# LyricEditBench 输入
python batch_infer.py \
--input_type lyric_edit_bench \
--output_dir /path/to/output \
--ckpt_path /path/to/ckpts \
--num_gpus 4
"""
import argparse
import json
import os
import sys
import traceback
from pathlib import Path
import torch
import torch.multiprocessing as mp
import torchaudio
from datasets import Audio, Dataset
from huggingface_hub import hf_hub_download
from tqdm import tqdm
def load_jsonl(path: str) -> list[dict]:
    """Read a JSONL file and return one parsed dict per non-empty line."""
    with open(path, "r", encoding="utf-8") as f:
        return [json.loads(text) for raw in f if (text := raw.strip())]
def build_dataset_from_local(gtsinger_root: str):
    """
    Build the LyricEditBench dataset backed by a local GTSinger directory.

    Args:
        gtsinger_root: Root directory of the local GTSinger dataset.

    Returns:
        A HuggingFace ``Dataset`` whose audio columns point at local files.
    """
    # Fetch the inherited metadata JSON from the HuggingFace Hub.
    meta_path = hf_hub_download(
        repo_id="ASLP-lab/LyricEditBench",
        filename="GTSinger_Inherited.json",
        repo_type="dataset",
    )
    with open(meta_path, "r") as fh:
        records = json.load(fh)

    root = str(Path(gtsinger_root).resolve())
    # Resolve the relative reference paths against the local root, then
    # mirror them into the audio columns that will be decoded below.
    for rec in records:
        for key in ("melody_ref_path", "timbre_ref_path"):
            rec[key] = os.path.join(root, rec[key])
        rec["melody_ref_audio"] = rec["melody_ref_path"]
        rec["timbre_ref_audio"] = rec["timbre_ref_path"]

    # Materialize as a HuggingFace Dataset with decoded Audio features.
    dataset = Dataset.from_list(records)
    for col in ("melody_ref_audio", "timbre_ref_audio"):
        dataset = dataset.cast_column(col, Audio())
    return dataset
def load_subset(data: list, subset_id: str) -> list:
    """Keep only the items whose ``id`` appears in the named subset ID list."""
    id_list_path = hf_hub_download(
        repo_id="ASLP-lab/LyricEditBench",
        filename=f"id_lists/{subset_id}.txt",
        repo_type="dataset",
    )
    with open(id_list_path, "r") as fh:
        wanted = {text for raw in fh if (text := raw.strip())}
    return [entry for entry in data if entry["id"] in wanted]
# Default local GTSinger root; override via the ``gtsinger_root`` argument.
_DEFAULT_GTSINGER_ROOT = (
    "/user-fs/chenzihao/zhengjunjie/datas/Music/openvocaldata/GTSinger"
)


def load_lyric_edit_bench(
    input_type: str,
    gtsinger_root: str = _DEFAULT_GTSINGER_ROOT,
) -> list[dict]:
    """
    Load the LyricEditBench 1K subset and shape each row for inference.

    Args:
        input_type: "lyric_edit_bench_melody_control" (timbre reference comes
            from the ``timbre_ref_*`` fields) or "lyric_edit_bench_sing_edit"
            (the melody reference doubles as the timbre reference).
        gtsinger_root: Root directory of the local GTSinger dataset.

    Returns:
        A list of dicts with keys: ``id``, ``melody_ref_path``, ``gen_text``,
        ``timbre_ref_path``, ``timbre_ref_text``.

    Raises:
        ValueError: If ``input_type`` is not one of the supported modes.
    """
    # If you have GTSinger downloaded locally, build the dataset from disk.
    ds_full = build_dataset_from_local(gtsinger_root)
    # Otherwise, you can use the hosted dataset instead:
    # from datasets import load_dataset
    # ds_full = load_dataset("ASLP-lab/LyricEditBench", split="test")
    subset_1k = load_subset(ds_full, "1K")
    print(f"Loaded {len(subset_1k)} items")

    # The only difference between the two modes is which source fields feed
    # the timbre reference, so pick the keys once instead of duplicating the
    # row-building logic per branch.
    if input_type == "lyric_edit_bench_melody_control":
        timbre_path_key, timbre_text_key = "timbre_ref_path", "timbre_ref_text"
    elif input_type == "lyric_edit_bench_sing_edit":
        # Singing-edit mode reuses the melody reference as the timbre reference.
        timbre_path_key, timbre_text_key = "melody_ref_path", "melody_ref_text"
    else:
        # Explicit error instead of `assert 0`, which is stripped under -O.
        raise ValueError(f"Unsupported input_type: {input_type!r}")

    return [
        {
            "id": row.get("id", ""),
            "melody_ref_path": row.get("melody_ref_path", ""),
            "gen_text": row.get("gen_text", ""),
            "timbre_ref_path": row.get(timbre_path_key, ""),
            "timbre_ref_text": row.get(timbre_text_key, ""),
        }
        for row in subset_1k
    ]
def worker(
    rank: int,
    world_size: int,
    items: list[dict],
    output_dir: str,
    ckpt_path: str,
    args: argparse.Namespace,
):
    """Per-GPU worker: run inference over this rank's shard of ``items``."""
    device = f"cuda:{rank}"
    torch.cuda.set_device(rank)

    # ---- Load the model on this worker's device. Import inside the worker
    # so each spawned process initializes its own copy. ----
    from src.YingMusicSinger.infer.YingMusicSinger import YingMusicSinger

    model = YingMusicSinger.from_pretrained(ckpt_path)
    model.to(device)
    model.eval()

    # ---- Round-robin sharding: rank r handles items r, r+world_size, ... ----
    shard = items[rank::world_size]

    # ---- Progress bar shown on rank 0 only, unless --show_all_progress. ----
    pbar = tqdm(
        shard,
        desc=f"[GPU {rank}]",
        position=rank,
        leave=True,
        disable=(rank != 0 and not args.show_all_progress),
    )

    ok_count = 0
    err_count = 0
    for entry in pbar:
        entry_id = entry.get("id", f"unknown_{ok_count + err_count}")
        target_path = os.path.join(output_dir, f"{entry_id}.wav")

        # Resume support: skip outputs that already exist.
        if not args.overwrite and os.path.exists(target_path):
            ok_count += 1
            pbar.set_postfix(ok=ok_count, err=err_count)
            continue

        try:
            seed = (
                args.seed
                if args.seed != -1
                else torch.randint(0, 2**32, (1,)).item()
            )
            with torch.no_grad():
                audio, sr = model(
                    ref_audio_path=entry["timbre_ref_path"],
                    melody_audio_path=entry["melody_ref_path"],
                    ref_text=entry.get("timbre_ref_text", ""),
                    target_text=entry.get("gen_text", ""),
                    lrc_align_mode=args.lrc_align_mode,
                    sil_len_to_end=args.sil_len_to_end,
                    t_shift=args.t_shift,
                    nfe_step=args.nfe_step,
                    cfg_strength=args.cfg_strength,
                    seed=seed,
                )
            torchaudio.save(target_path, audio, sample_rate=sr)
            ok_count += 1
        except Exception as e:
            err_count += 1
            print(f"\n[GPU {rank}] ERROR on {entry_id}: {e}", file=sys.stderr)
            if args.verbose:
                traceback.print_exc()
        pbar.set_postfix(ok=ok_count, err=err_count)

    pbar.close()
    print(f"[GPU {rank}] Done. success={ok_count}, fail={err_count}")
def main():
    """CLI entry point: parse arguments, load items, fan out across GPUs."""
    parser = argparse.ArgumentParser(description="YingMusicSinger 批量推理")
    # ---- Input ----
    parser.add_argument(
        "--input_type",
        type=str,
        required=True,
        choices=[
            "jsonl",
            "lyric_edit_bench_melody_control",
            "lyric_edit_bench_sing_edit",
        ],
        help="输入类型: jsonl / lyric_edit_bench_melody_control 或 lyric_edit_bench_sing_edit",
    )
    parser.add_argument(
        "--input_path",
        type=str,
        default=None,
        help="JSONL 文件路径 (input_type=jsonl 时必填)",
    )
    # ---- Output ----
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="输出目录",
    )
    # ---- Model ----
    parser.add_argument(
        "--ckpt_path",
        type=str,
        required=False,
        help="模型 checkpoint 路径 (save_pretrained 保存的目录)",
        default=None,
    )
    # ---- Inference parameters ----
    parser.add_argument(
        "--num_gpus", type=int, default=None, help="使用 GPU 数量,默认全部"
    )
    parser.add_argument(
        "--lrc_align_mode",
        type=str,
        default="sentence_level",
        choices=["sentence_level"],
    )
    parser.add_argument("--sil_len_to_end", type=float, default=0.5)
    parser.add_argument("--t_shift", type=float, default=0.5)
    parser.add_argument("--nfe_step", type=int, default=32)
    parser.add_argument("--cfg_strength", type=float, default=3.0)
    parser.add_argument("--seed", type=int, default=-1)
    # ---- Misc ----
    parser.add_argument("--overwrite", action="store_true", help="覆盖已有输出文件")
    parser.add_argument(
        "--show_all_progress", action="store_true", help="所有 GPU 都显示进度条"
    )
    parser.add_argument("--verbose", action="store_true", help="打印详细错误信息")
    args = parser.parse_args()

    # ---- Validate (explicit errors instead of `assert`, which -O strips) ----
    if args.input_type == "jsonl":
        if args.input_path is None:
            parser.error("--input_path 是 jsonl 模式下必填的")
        if not os.path.isfile(args.input_path):
            parser.error(f"文件不存在: {args.input_path}")

    # ---- Load data ----
    print("加载数据...")
    if args.input_type == "jsonl":
        items = load_jsonl(args.input_path)
    else:
        items = load_lyric_edit_bench(args.input_type)
    print(f"共 {len(items)} 条数据")

    # ---- Decide GPU count: never launch more workers than GPUs or items ----
    available_gpus = torch.cuda.device_count()
    num_gpus = args.num_gpus or available_gpus
    num_gpus = min(num_gpus, available_gpus, len(items))
    if num_gpus <= 0:
        # Explicit exit instead of `assert`, which is stripped under -O.
        sys.exit("没有可用的 GPU")
    print(f"使用 {num_gpus} 张 GPU")

    # ---- Create output directory ----
    os.makedirs(args.output_dir, exist_ok=True)

    # ---- Launch workers ----
    if num_gpus == 1:
        # Single GPU: run inline, no spawn needed.
        worker(0, 1, items, args.output_dir, args.ckpt_path, args)
    else:
        mp.set_start_method("spawn", force=True)
        processes = []
        for rank in range(num_gpus):
            p = mp.Process(
                target=worker,
                args=(rank, num_gpus, items, args.output_dir, args.ckpt_path, args),
            )
            p.start()
            processes.append(p)
        for p in processes:
            p.join()
    print(f"\n推理完成! 输出目录: {args.output_dir}")
# Standard script entry guard: run `main()` only when executed directly,
# not when imported (e.g. by spawned multiprocessing workers).
if __name__ == "__main__":
    main()