""" YingMusicSinger 批量推理脚本 支持多卡多进程、进度条显示 输入支持 JSONL 文件 或 LyricEditBench 数据集 用法: # JSONL 输入,4卡 python batch_infer.py \ --input_type jsonl \ --input_path /path/to/input.jsonl \ --output_dir /path/to/output \ --ckpt_path /path/to/ckpts \ --num_gpus 4 # LyricEditBench 输入 python batch_infer.py \ --input_type lyric_edit_bench \ --output_dir /path/to/output \ --ckpt_path /path/to/ckpts \ --num_gpus 4 """ import argparse import json import os import sys import traceback from pathlib import Path import torch import torch.multiprocessing as mp import torchaudio from datasets import Audio, Dataset from huggingface_hub import hf_hub_download from tqdm import tqdm def load_jsonl(path: str) -> list[dict]: items = [] with open(path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if line: items.append(json.loads(line)) return items def build_dataset_from_local(gtsinger_root: str): """ Build LyricEditBench dataset using your local GTSinger directory. Args: gtsinger_root: Root directory of your local GTSinger dataset. """ # Download the inherited metadata from HuggingFace json_path = hf_hub_download( repo_id="ASLP-lab/LyricEditBench", filename="GTSinger_Inherited.json", repo_type="dataset", ) with open(json_path, "r") as f: data = json.load(f) gtsinger_root = str(Path(gtsinger_root).resolve()) # Prepend local root to relative paths for item in data: item["melody_ref_path"] = os.path.join(gtsinger_root, item["melody_ref_path"]) item["timbre_ref_path"] = os.path.join(gtsinger_root, item["timbre_ref_path"]) # Set audio fields to the resolved file paths item["melody_ref_audio"] = item["melody_ref_path"] item["timbre_ref_audio"] = item["timbre_ref_path"] # Build HuggingFace Dataset with Audio features ds = Dataset.from_list(data) ds = ds.cast_column("melody_ref_audio", Audio()) ds = ds.cast_column("timbre_ref_audio", Audio()) return ds def load_subset(data: list, subset_id: str) -> list: """Filter dataset by a subset ID list.""" subset_path = hf_hub_download( repo_id="ASLP-lab/LyricEditBench", filename=f"id_lists/{subset_id}.txt", repo_type="dataset", ) with open(subset_path, "r") as f: id_set = set(line.strip() for line in f if line.strip()) return [item for item in data if item["id"] in id_set] def load_lyric_edit_bench(input_type) -> list[dict]: # If you have GTsinger downloaded, use this: ds_full = build_dataset_from_local( "/user-fs/chenzihao/zhengjunjie/datas/Music/openvocaldata/GTSinger" ) # else, you kan use this: # from datasets import load_dataset # ds_full = load_dataset("ASLP-lab/LyricEditBench", split="test") # ds_full loaded subset_1k = load_subset(ds_full, "1K") print(f"Loaded {len(subset_1k)} items") items = [] for row in subset_1k: if input_type == "lyric_edit_bench_melody_control": items.append( { "id": row.get("id", ""), "melody_ref_path": row.get("melody_ref_path", ""), "gen_text": row.get("gen_text", ""), "timbre_ref_path": row.get("timbre_ref_path", ""), "timbre_ref_text": row.get("timbre_ref_text", ""), } ) elif input_type == "lyric_edit_bench_sing_edit": items.append( { "id": row.get("id", ""), "melody_ref_path": row.get("melody_ref_path", ""), "gen_text": row.get("gen_text", ""), "timbre_ref_path": row.get("melody_ref_path", ""), "timbre_ref_text": row.get("melody_ref_text", ""), } ) else: assert 0 return items def worker( rank: int, world_size: int, items: list[dict], output_dir: str, ckpt_path: str, args: argparse.Namespace, ): """每个 GPU 上运行的 worker 进程""" device = f"cuda:{rank}" torch.cuda.set_device(rank) # ---- 加载模型 ---- from src.YingMusicSinger.infer.YingMusicSinger import YingMusicSinger model = YingMusicSinger.from_pretrained(ckpt_path) model.to(device) model.eval() # ---- 分片: 每个 worker 处理自己那份 ---- shard = items[rank::world_size] # ---- 只在 rank 0 显示进度条 ---- pbar = tqdm( shard, desc=f"[GPU {rank}]", position=rank, leave=True, disable=(rank != 0 and not args.show_all_progress), ) success, fail = 0, 0 for item in pbar: item_id = item.get("id", f"unknown_{success + fail}") out_path = os.path.join(output_dir, f"{item_id}.wav") # 跳过已存在的文件 if os.path.exists(out_path) and not args.overwrite: success += 1 pbar.set_postfix(ok=success, err=fail) continue try: with torch.no_grad(): audio, sr = model( ref_audio_path=item["timbre_ref_path"], melody_audio_path=item["melody_ref_path"], ref_text=item.get("timbre_ref_text", ""), target_text=item.get("gen_text", ""), lrc_align_mode=args.lrc_align_mode, sil_len_to_end=args.sil_len_to_end, t_shift=args.t_shift, nfe_step=args.nfe_step, cfg_strength=args.cfg_strength, seed=args.seed if args.seed != -1 else torch.randint(0, 2**32, (1,)).item(), ) torchaudio.save(out_path, audio, sample_rate=sr) success += 1 except Exception as e: fail += 1 print(f"\n[GPU {rank}] ERROR on {item_id}: {e}", file=sys.stderr) if args.verbose: traceback.print_exc() pbar.set_postfix(ok=success, err=fail) pbar.close() print(f"[GPU {rank}] Done. success={success}, fail={fail}") def main(): parser = argparse.ArgumentParser(description="YingMusicSinger 批量推理") # ---- 输入 ---- parser.add_argument( "--input_type", type=str, required=True, choices=[ "jsonl", "lyric_edit_bench_melody_control", "lyric_edit_bench_sing_edit", ], help="输入类型: jsonl / lyric_edit_bench_melody_control 或 lyric_edit_bench_sing_edit", ) parser.add_argument( "--input_path", type=str, default=None, help="JSONL 文件路径 (input_type=jsonl 时必填)", ) # ---- 输出 ---- parser.add_argument( "--output_dir", type=str, required=True, help="输出目录", ) # ---- 模型 ---- parser.add_argument( "--ckpt_path", type=str, required=False, help="模型 checkpoint 路径 (save_pretrained 保存的目录)", default=None, ) # ---- 推理参数 ---- parser.add_argument( "--num_gpus", type=int, default=None, help="使用 GPU 数量,默认全部" ) parser.add_argument( "--lrc_align_mode", type=str, default="sentence_level", choices=["sentence_level"], ) parser.add_argument("--sil_len_to_end", type=float, default=0.5) parser.add_argument("--t_shift", type=float, default=0.5) parser.add_argument("--nfe_step", type=int, default=32) parser.add_argument("--cfg_strength", type=float, default=3.0) parser.add_argument("--seed", type=int, default=-1) # ---- 其它 ---- parser.add_argument("--overwrite", action="store_true", help="覆盖已有输出文件") parser.add_argument( "--show_all_progress", action="store_true", help="所有 GPU 都显示进度条" ) parser.add_argument("--verbose", action="store_true", help="打印详细错误信息") args = parser.parse_args() # ---- 校验 ---- if args.input_type == "jsonl": assert args.input_path is not None, "--input_path 是 jsonl 模式下必填的" assert os.path.isfile(args.input_path), f"文件不存在: {args.input_path}" # ---- 加载数据 ---- print("加载数据...") if args.input_type == "jsonl": items = load_jsonl(args.input_path) else: items = load_lyric_edit_bench(args.input_type) print(f"共 {len(items)} 条数据") # ---- 确定 GPU 数量 ---- available_gpus = torch.cuda.device_count() num_gpus = args.num_gpus or available_gpus num_gpus = min(num_gpus, available_gpus, len(items)) assert num_gpus > 0, "没有可用的 GPU" print(f"使用 {num_gpus} 张 GPU") # ---- 创建输出目录 ---- os.makedirs(args.output_dir, exist_ok=True) # ---- 启动多进程 ---- if num_gpus == 1: # 单卡直接跑,不需要 spawn worker(0, 1, items, args.output_dir, args.ckpt_path, args) else: mp.set_start_method("spawn", force=True) processes = [] for rank in range(num_gpus): p = mp.Process( target=worker, args=(rank, num_gpus, items, args.output_dir, args.ckpt_path, args), ) p.start() processes.append(p) for p in processes: p.join() print(f"\n推理完成! 输出目录: {args.output_dir}") if __name__ == "__main__": main()