"""Batched GPU inference script (inference_gpu.py).

Based on inference.py, with the following additions:
- DataLoader-based batched inference (much better GPU utilisation)
- Multi-GPU support (torchrun / LOCAL_RANK automatic sharding)
- Writes the reconstructed audio (.wav) and keeps the input sub-directory layout
- Trims the padding so each output has the same length as its source audio

Single-GPU example:
    python inference_gpu.py \
        --input-dir test_audio/input_test \
        --output-dir test_audio/output_test \
        --ckpt ckpt/epoch=4-step=1400000.ckpt \
        --batch-size 8 --num-workers 4

Multi-GPU example (4 GPUs):
    torchrun --nproc_per_node=4 inference_gpu.py \
        --input-dir test_audio/input_test \
        --output-dir test_audio/output_test \
        --ckpt ckpt/epoch=4-step=1400000.ckpt \
        --batch-size 8 --num-workers 4
"""

import os
from argparse import ArgumentParser
from collections import OrderedDict
from glob import glob
from os.path import join
from time import time

import numpy as np
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
from tqdm import tqdm
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel

from vq.codec_encoder import CodecEncoder
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder


# ─────────────────────────────────────────────
# Dataset
# ─────────────────────────────────────────────
class AudioDataset(Dataset):
    """Load audio files and pre-compute the feature-extractor inputs.

    Each item is the tuple ``(audio, feat, path, orig_len)``:

    * ``audio``    -- mono waveform, right-padded with zeros to a multiple of
      ``HOP`` samples, shape ``(1, T')``.
    * ``feat``     -- the ``input_features`` produced by the feature extractor
      for the (unpadded) waveform with 160 zero samples added on both sides.
    * ``path``     -- the source file path, passed through unchanged.
    * ``orig_len`` -- number of samples after resampling and before padding;
      used later to trim the reconstruction.
    """

    # Samples per frame that the waveform is padded to a multiple of.
    # The original author describes it as the encoder's downsampling hop.
    HOP = 320

    def __init__(self, file_list, sampling_rate: int, feature_extractor_path: str):
        """Store the file list and load the feature extractor.

        Args:
            file_list: Sequence of audio file paths.
            sampling_rate: Target sample rate; files at another rate are resampled.
            feature_extractor_path: Name or path given to
                ``AutoFeatureExtractor.from_pretrained``.
        """
        self.file_list = file_list
        self.sr = sampling_rate
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
        # One Resample transform per source rate, so the resampling kernel is
        # built once per rate instead of once per file.
        self._resamplers = {}

    def __len__(self):
        return len(self.file_list)

    def _resample(self, audio, src_sr):
        """Resample *audio* from *src_sr* to ``self.sr`` with a cached transform."""
        resampler = self._resamplers.get(src_sr)
        if resampler is None:
            resampler = Resample(src_sr, self.sr)
            self._resamplers[src_sr] = resampler
        return resampler(audio)

    def __getitem__(self, idx):
        """Load, down-mix, resample and pad the file at position *idx*."""
        path = self.file_list[idx]
        audio, sr = torchaudio.load(path)

        # Down-mix multi-channel audio to mono by averaging the channels.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)

        if sr != self.sr:
            audio = self._resample(audio, sr)  # (1, T)

        # Length before any padding; the reconstruction is cut back to this.
        orig_len = audio.shape[1]

        # Right-pad so the encoder input length is a multiple of HOP.
        pad_enc = (self.HOP - audio.shape[1] % self.HOP) % self.HOP
        audio_padded = F.pad(audio, (0, pad_enc))  # (1, T')

        # The feature extractor is fed the unpadded waveform with 160 zero
        # samples added at both ends.
        feat = self.feature_extractor(
            F.pad(audio[0], (160, 160)),
            sampling_rate=self.sr,
            return_tensors="pt",
        ).data['input_features']  # (1, T_feat, 160) per the original author's note

        return audio_padded, feat, path, orig_len


def collate_fn(batch):
    """Pad the variable-length samples of one batch to a common length.

    The common feature length is the largest feature frame count in the batch,
    and the common audio length is that frame count times ``AudioDataset.HOP``.
    Shorter tensors are zero-padded on the right; longer ones are truncated.

    Args:
        batch: Sequence of ``(audio, feat, path, orig_len)`` items as returned
            by ``AudioDataset.__getitem__``.

    Returns:
        ``(audios, feats, paths, orig_lens)`` where ``audios`` and ``feats``
        are stacked tensors with a new leading batch dimension, ``paths`` is a
        tuple of file paths and ``orig_lens`` is a ``torch.long`` tensor.
    """
    audios, feats, paths, orig_lens = zip(*batch)

    # The audio length is derived from the feature frame count so both
    # branches of the model see the same number of frames.
    max_feat_len = max(f.shape[1] for f in feats)
    max_audio_len = max_feat_len * AudioDataset.HOP

    def fit_audio(a):
        missing = max_audio_len - a.shape[1]
        return F.pad(a, (0, missing)) if missing > 0 else a[:, :max_audio_len]

    def fit_feat(f):
        missing = max_feat_len - f.shape[1]
        # Pad spec (0, 0, 0, missing): last dim untouched, frame dim extended.
        return F.pad(f, (0, 0, 0, missing)) if missing > 0 else f[:, :max_feat_len, :]

    padded_audios = torch.stack([fit_audio(a) for a in audios])  # (B, 1, T)
    padded_feats = torch.stack([fit_feat(f) for f in feats])     # (B, 1, T_feat, 160)

    return padded_audios, padded_feats, paths, torch.tensor(orig_lens, dtype=torch.long)


# ─────────────────────────────────────────────
# Model loading helpers
# ─────────────────────────────────────────────
def _load_weights(module, state_dict, device):
    """Load *state_dict* into *module*, switch to eval mode and move to *device*."""
    module.load_state_dict(state_dict)
    module.eval().to(device)
    return module


def load_models(ckpt_path: str, w2v_path: str, device: torch.device):
    """Build every sub-model and load its weights from one checkpoint.

    The checkpoint's ``"state_dict"`` entry is split by key prefix
    (``CodecEnc.``, ``generator.``, ``fc_post_a.``, ``SemanticEncoder_module.``,
    ``fc_prior.``); keys with any other prefix are ignored. The prefix is
    stripped before the weights are loaded into the matching module.

    Args:
        ckpt_path: Path to the checkpoint file. It is read with ``torch.load``,
            which unpickles the file, so only load checkpoints you trust.
        w2v_path: Name or path given to ``Wav2Vec2BertModel.from_pretrained``.
        device: Device every module is moved to.

    Returns:
        ``(semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior)``,
        all in eval mode on *device*.
    """
    print(f"[rank {device}] Loading checkpoint: {ckpt_path}")
    ckpt = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    fd_codec = OrderedDict()
    fd_sem_enc = OrderedDict()
    fd_gen = OrderedDict()
    fd_fc_post = OrderedDict()
    fd_fc_pri = OrderedDict()

    prefix_map = {
        "CodecEnc.": fd_codec,
        "generator.": fd_gen,
        "fc_post_a.": fd_fc_post,
        "SemanticEncoder_module.": fd_sem_enc,
        "fc_prior.": fd_fc_pri,
    }
    # Route each parameter to the first prefix it matches.
    for key, val in ckpt.items():
        for prefix, target in prefix_map.items():
            if key.startswith(prefix):
                target[key[len(prefix):]] = val
                break

    # The semantic model comes from its own pretrained weights, not the checkpoint.
    semantic_model = Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
    semantic_model.eval().to(device)

    sem_enc = _load_weights(SemanticEncoder(1024, 1024, 1024), fd_sem_enc, device)
    encoder = _load_weights(CodecEncoder(), fd_codec, device)
    decoder = _load_weights(CodecDecoderVocos(), fd_gen, device)
    fc_post_a = _load_weights(nn.Linear(1024, 1024), fd_fc_post, device)
    fc_prior = _load_weights(nn.Linear(2048, 1024), fd_fc_pri, device)

    return semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior


# ─────────────────────────────────────────────
# Single-batch inference
# ─────────────────────────────────────────────
@torch.no_grad()
def infer_batch(wavs, feats, orig_lens, models, device, sr):
    """Reconstruct one padded batch and trim every item to its original length.

    Args:
        wavs: Padded waveforms, ``(B, 1, T)``.
        feats: Padded feature-extractor outputs, ``(B, 1, Tf, 160)``.
        orig_lens: ``(B,)`` integer tensor of unpadded sample counts.
        models: Tuple returned by ``load_models``.
        device: Device the inputs are moved to.
        sr: Unused; kept so existing callers keep working.

    Returns:
        A list of ``B`` one-dimensional ``np.ndarray`` waveforms, each cut to
        at most the matching ``orig_lens`` entry.
    """
    semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior = models

    wavs = wavs.to(device)                # (B, 1, T)
    feats = feats[:, 0, :, :].to(device)  # (B, Tf, 160)

    # 1) Codec encoder. The result is concatenated with the semantic features
    #    along dim 1 and must add up to the 2048 inputs of fc_prior, so after
    #    the transpose it is presumably channel-first (B, 1024, Tf).
    #    TODO confirm against CodecEncoder.
    vq_emb = encoder(wavs)
    vq_emb = vq_emb.transpose(1, 2)

    # 2) Semantic branch: hidden state 16 of the semantic model, moved to
    #    channel-first for the semantic encoder.
    sem_out = semantic_model(feats)
    sem_feat = sem_out.hidden_states[16]   # presumably (B, Tf, 1024)
    sem_feat = sem_feat.transpose(1, 2)    # presumably (B, 1024, Tf)
    sem_feat = sem_enc(sem_feat)

    # 3) Concatenate both branches on the channel axis and project with
    #    fc_prior = Linear(2048, 1024), which acts on the last dimension.
    vq_emb = torch.cat([sem_feat, vq_emb], dim=1)               # channels: 2048
    vq_emb = fc_prior(vq_emb.transpose(1, 2)).transpose(1, 2)   # channels: 1024

    # NOTE: quantisation is bypassed in this script. The original pipeline ran
    # decoder(vq_emb, vq=True) and decoder.quantizer.get_output_from_indices(...)
    # here; instead the continuous embedding goes straight to fc_post_a.
    vq_post_emb = fc_post_a(vq_emb.transpose(1, 2)).transpose(1, 2)

    # 4) Decode. Element 0 of the decoder output is the waveform; its singleton
    #    dim 1 is dropped to get (B, T).
    recon_batch = decoder(vq_post_emb.transpose(1, 2), vq=False)[0].squeeze(1)

    # 5) Move the batch to the host once, then cut every item to its original
    #    length (avoids one device-to-host copy per sample).
    recon_cpu = recon_batch.detach().cpu()
    results = [
        recon_cpu[i, :wav_len].numpy()
        for i, wav_len in enumerate(orig_lens.tolist())
    ]
    return results
# ─────────────────────────────────────────────
# Entry point
# ─────────────────────────────────────────────
def _build_arg_parser():
    """Create the command-line parser for the inference script."""
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=str, default="test_audio/input_test")
    parser.add_argument("--output-dir", type=str, default="test_audio/output_test")
    parser.add_argument("--ckpt", type=str,
                        default="/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/log/shengpeng_debug/last.ckpt")
    parser.add_argument("--w2v-path", type=str,
                        default="/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/ckpt/w2v-bert-2.0")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--num-gpus", type=int, default=None,
                        help="手动指定 GPU 总数;默认自动检测(或读取 torchrun 环境变量)")
    return parser


def _resolve_ranks(num_gpus):
    """Return ``(rank, local_rank, world_size)`` from the launcher environment.

    ``LOCAL_RANK`` picks the CUDA device on this node, while the global
    ``RANK`` picks the data shard. ``RANK`` falls back to ``LOCAL_RANK`` so a
    manual single-node launch that only exports ``LOCAL_RANK`` keeps working.
    A non-``None`` *num_gpus* overrides ``WORLD_SIZE``.
    """
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    rank = int(os.getenv("RANK", local_rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    if num_gpus is not None:
        world_size = num_gpus
    return rank, local_rank, world_size


def _collect_audio_paths(input_dir):
    """Return the sorted, de-duplicated wav/flac/mp3 paths found under *input_dir*."""
    found = set()
    for ext in ("wav", "flac", "mp3"):
        found.update(glob(join(input_dir, "**", f"*.{ext}"), recursive=True))
    return sorted(found)


def main():
    """Run batched reconstruction over every audio file under ``--input-dir``.

    The sorted file list is split into ``world_size`` contiguous shards and
    this process handles the shard selected by its global rank. Each
    reconstruction is written as a ``.wav`` file under ``--output-dir`` at the
    same relative path as its source.

    Raises:
        ValueError: If the rank is outside ``[0, world_size)`` while sharding.
    """
    args = _build_arg_parser().parse_args()

    sr = 16000  # sample rate used for loading/resampling and for writing output

    # ── Multi-GPU sharding (torchrun or a manual single-node launch) ──
    rank, local_rank, world_size = _resolve_ranks(args.num_gpus)

    if torch.cuda.is_available():
        # The device index is per node, hence LOCAL_RANK rather than RANK.
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        print("[警告] 未检测到 CUDA,使用 CPU 推理,速度较慢。")

    # ── Collect the audio files ──
    all_paths = _collect_audio_paths(args.input_dir)

    if world_size > 1:
        if not 0 <= rank < world_size:
            raise ValueError(
                f"rank {rank} is outside the valid range for world size {world_size}"
            )
        # Shard by the GLOBAL rank: sharding by LOCAL_RANK would make every
        # node of a multi-node job process the same first shards and leave
        # the remaining shards unprocessed.
        all_paths = np.array_split(all_paths, world_size)[rank].tolist()

    print(f"[rank {rank}] {len(all_paths)} files to process on {device}")

    if not all_paths:
        print(f"[rank {rank}] No files found, exiting.")
        return

    # ── Build the DataLoader ──
    dataset = AudioDataset(all_paths, sr, args.w2v_path)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        drop_last=False,
    )

    # ── Load the models ──
    models = load_models(args.ckpt, args.w2v_path, device)

    os.makedirs(args.output_dir, exist_ok=True)

    # ── Inference loop ──
    st = time()
    for wavs, feats, paths, orig_lens in tqdm(dataloader, desc=f"rank {rank}", dynamic_ncols=True):
        recon_list = infer_batch(wavs, feats, orig_lens, models, device, sr)

        for recon, src_path in zip(recon_list, paths):
            # Keep the sub-directory layout relative to the input directory
            # and always write a .wav file.
            # NOTE(review): inputs that differ only by extension (a.wav and
            # a.flac) map to the same output path; the later one overwrites.
            rel = os.path.relpath(src_path, args.input_dir)
            rel = os.path.splitext(rel)[0] + ".wav"
            dst = join(args.output_dir, rel)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            sf.write(dst, recon, sr)

    et = time()
    print(f"[rank {rank}] Done. Total time: {(et - st) / 60:.2f} min")


if __name__ == "__main__":
    main()