"""
Batched GPU inference script (inference_gpu.py).

Based on inference.py, with the following additions:
- Batched inference through a DataLoader, which greatly improves GPU utilization.
- Multi-GPU support (torchrun / LOCAL_RANK automatic sharding of the file list).
- Writes the reconstructed audio (.wav) while preserving the original sub-directory layout.
- Automatically crops the padding so the output length matches the original audio.

Single-GPU example:
    python inference_gpu.py \
        --input-dir test_audio/input_test \
        --output-dir test_audio/output_test \
        --ckpt ckpt/epoch=4-step=1400000.ckpt \
        --batch-size 8 --num-workers 4

Multi-GPU example (4 GPUs):
    torchrun --nproc_per_node=4 inference_gpu.py \
        --input-dir test_audio/input_test \
        --output-dir test_audio/output_test \
        --ckpt ckpt/epoch=4-step=1400000.ckpt \
        --batch-size 8 --num-workers 4
"""
|
|
| import os |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| import soundfile as sf |
| import torchaudio |
| from torchaudio.transforms import Resample |
| from glob import glob |
| from tqdm import tqdm |
| from os.path import join |
| from collections import OrderedDict |
| from argparse import ArgumentParser |
| from time import time |
|
|
| import torch.nn as nn |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import AutoFeatureExtractor, Wav2Vec2BertModel |
|
|
| from vq.codec_encoder import CodecEncoder |
| from vq.codec_decoder_vocos import CodecDecoderVocos |
| from vq.module import SemanticEncoder |
|
|
|
|
| |
| |
| |
|
|
class AudioDataset(Dataset):
    """Load audio files and pre-compute their w2v-bert input features.

    Each item is a tuple ``(audio, feat, path, orig_len)``:

    * ``audio``    -- mono waveform at ``sampling_rate``, shape ``(1, T)``,
      zero-padded on the right so that ``T`` is a multiple of ``HOP``.
    * ``feat``     -- the ``input_features`` tensor returned by the feature
      extractor for that same padded waveform.
    * ``path``     -- the source file path, unchanged.
    * ``orig_len`` -- number of samples after resampling but before padding;
      used downstream to crop the reconstruction.
    """

    # Audio is padded to a whole multiple of this many samples.
    HOP = 320

    def __init__(self, file_list, sampling_rate: int, feature_extractor_path: str):
        """Store the file list and load the feature extractor.

        Args:
            file_list: sequence of audio file paths.
            sampling_rate: target sampling rate; files at another rate are resampled.
            feature_extractor_path: name or path passed to
                ``AutoFeatureExtractor.from_pretrained``.
        """
        self.file_list = file_list
        self.sr = sampling_rate
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
        # One Resample transform per source rate, built lazily and reused so
        # the resampling kernel is not recomputed for every file.
        self._resamplers = {}

    def __len__(self):
        return len(self.file_list)

    def _resampler_for(self, source_sr: int):
        """Return a cached ``Resample`` transform from ``source_sr`` to ``self.sr``."""
        resampler = self._resamplers.get(source_sr)
        if resampler is None:
            resampler = Resample(source_sr, self.sr)
            self._resamplers[source_sr] = resampler
        return resampler

    def __getitem__(self, idx):
        path = self.file_list[idx]
        audio, sr = torchaudio.load(path)

        # Down-mix multi-channel audio to a single channel.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)

        # Bring every file to the common sampling rate.
        if sr != self.sr:
            audio = self._resampler_for(sr)(audio)

        orig_len = audio.shape[1]

        # Right-pad to a multiple of HOP (no padding when already aligned).
        pad_enc = (self.HOP - audio.shape[1] % self.HOP) % self.HOP
        audio_padded = F.pad(audio, (0, pad_enc))

        # Extract features from the HOP-aligned waveform, not the raw one.
        # collate_fn sizes the audio as ``feat_frames * HOP``; features taken
        # from the un-padded signal can come out a frame short, which made
        # collate_fn cut the tail of the clip and the output shorter than the
        # input. NOTE(review): assumes the extractor yields one frame per HOP
        # samples for aligned input -- confirm against the extractor config.
        feat = self.feature_extractor(
            F.pad(audio_padded[0], (160, 160)),
            sampling_rate=self.sr,
            return_tensors="pt"
        ).data['input_features']

        return audio_padded, feat, path, orig_len
|
|
|
|
def collate_fn(batch):
    """Pad the variable-length samples of one batch to a common length.

    ``batch`` is a sequence of ``(audio, feat, path, orig_len)`` tuples. The
    longest feature sequence sets the frame count; the audio length is then
    fixed to ``frames * AudioDataset.HOP`` so both stay in step.

    Returns ``(audios, feats, paths, orig_lens)`` where ``audios`` and
    ``feats`` are stacked tensors, ``paths`` is a tuple of the source paths
    and ``orig_lens`` is a ``torch.long`` tensor.
    """
    audios, feats, paths, orig_lens = zip(*batch)

    # The longest feature sequence decides the size of the whole batch.
    target_frames = max(feat.shape[1] for feat in feats)
    target_samples = target_frames * AudioDataset.HOP

    # Waveforms: zero-pad short ones on the right, clip anything longer.
    audio_rows = []
    for wav in audios:
        missing = target_samples - wav.shape[1]
        if missing > 0:
            wav = F.pad(wav, (0, missing))
        else:
            wav = wav[:, :target_samples]
        audio_rows.append(wav)

    # Features: same treatment along the frame axis (dim 1).
    feat_rows = []
    for feat in feats:
        missing = target_frames - feat.shape[1]
        if missing > 0:
            feat = F.pad(feat, (0, 0, 0, missing))
        else:
            feat = feat[:, :target_frames, :]
        feat_rows.append(feat)

    return (
        torch.stack(audio_rows),
        torch.stack(feat_rows),
        paths,
        torch.tensor(orig_lens, dtype=torch.long),
    )
|
|
|
|
| |
| |
| |
|
|
def load_models(ckpt_path: str, w2v_path: str, device: torch.device):
    """Build every sub-model, load its weights and move it to ``device``.

    The checkpoint's ``"state_dict"`` entry is split by key prefix, and each
    prefix-stripped subset is loaded into the matching module. All modules
    are put in eval mode.

    Args:
        ckpt_path: checkpoint file readable by ``torch.load``.
        w2v_path: name or path passed to ``Wav2Vec2BertModel.from_pretrained``.
        device: device the modules are moved to.

    Returns:
        ``(semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior)``.
    """
    print(f"[rank {device}] Loading checkpoint: {ckpt_path}")
    state_dict = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    # One bucket of weights per sub-module, keyed by its checkpoint prefix.
    buckets = {
        "CodecEnc.": OrderedDict(),
        "generator.": OrderedDict(),
        "fc_post_a.": OrderedDict(),
        "SemanticEncoder_module.": OrderedDict(),
        "fc_prior.": OrderedDict(),
    }
    for name, tensor in state_dict.items():
        # First matching prefix wins; keys matching none are ignored.
        owner = next((p for p in buckets if name.startswith(p)), None)
        if owner is not None:
            buckets[owner][name[len(owner):]] = tensor

    def _ready(module, weights=None):
        """Load ``weights`` (if given), switch to eval mode, move to ``device``."""
        if weights is not None:
            module.load_state_dict(weights)
        module.eval().to(device)
        return module

    semantic_model = _ready(
        Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
    )
    sem_enc = _ready(SemanticEncoder(1024, 1024, 1024), buckets["SemanticEncoder_module."])
    encoder = _ready(CodecEncoder(), buckets["CodecEnc."])
    decoder = _ready(CodecDecoderVocos(), buckets["generator."])
    fc_post_a = _ready(nn.Linear(1024, 1024), buckets["fc_post_a."])
    fc_prior = _ready(nn.Linear(2048, 1024), buckets["fc_prior."])

    return semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior
|
|
|
|
| |
| |
| |
|
|
@torch.no_grad()
def infer_batch(wavs, feats, orig_lens, models, device, sr):
    """Reconstruct one padded batch and crop every item to its original length.

    Args:
        wavs: ``(B, 1, T)`` float32 waveforms, already padded.
        feats: ``(B, 1, Tf, 160)`` float32 input features.
        orig_lens: ``(B,)`` integer tensor with the original sample counts.
        models: tuple ``(semantic_model, sem_enc, encoder, decoder,
            fc_post_a, fc_prior)`` as returned by ``load_models``.
        device: device the inputs are moved to before the forward pass.
        sr: sampling rate; accepted for interface compatibility, not used here.

    Returns:
        A list of ``B`` NumPy arrays, each the reconstructed waveform cut to
        the corresponding entry of ``orig_lens``.
    """
    semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior = models

    wavs = wavs.to(device)
    # Drop the singleton axis: (B, 1, Tf, D) -> (B, Tf, D).
    feats = feats[:, 0, :, :].to(device)

    # Acoustic branch, put into channels-first layout.
    acoustic = encoder(wavs).transpose(1, 2)

    # Semantic branch: hidden state 16 of the semantic model, channels-first,
    # then through the semantic encoder.
    hidden = semantic_model(feats).hidden_states[16]
    semantic = sem_enc(hidden.transpose(1, 2))

    # Concatenate both branches along the channel axis. The two linear layers
    # work channels-last, so the tensor is transposed once and both layers are
    # applied back to back (quantization is bypassed: vq=False).
    fused = torch.cat([semantic, acoustic], dim=1)
    projected = fc_post_a(fc_prior(fused.transpose(1, 2)))

    recon_batch = decoder(projected, vq=False)[0].squeeze(1)

    # Strip the padding: keep only the first orig_len samples of each row.
    return [
        recon_batch[row, :n_samples].detach().cpu().numpy()
        for row, n_samples in enumerate(orig_lens.tolist())
    ]
|
|
|
|
| |
| |
| |
|
|
def main():
    """Reconstruct every audio file under ``--input-dir`` and write it as ``.wav``.

    Steps: parse the command line, pick the device from ``LOCAL_RANK``, gather
    ``wav``/``flac``/``mp3`` files recursively, keep this process's shard of
    the sorted list, run batched inference, and write each reconstruction
    under ``--output-dir`` at the same relative path with a ``.wav`` suffix.
    """
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=str,
                        default="test_audio/input_test")
    parser.add_argument("--output-dir", type=str,
                        default="test_audio/output_test")
    parser.add_argument("--ckpt", type=str,
                        default="/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/log/shengpeng_debug/last.ckpt")
    parser.add_argument("--w2v-path", type=str,
                        default="/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/ckpt/w2v-bert-2.0")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--num-gpus", type=int, default=None,
                        help="手动指定 GPU 总数（默认自动检测，或读取 torchrun 环境变量）")
    args = parser.parse_args()

    sr = 16000

    # LOCAL_RANK picks the GPU on this node. RANK is the process index across
    # the whole job and is what must index the WORLD_SIZE shards; indexing by
    # LOCAL_RANK would make each node of a multi-node job repeat the first
    # shards and skip the rest. On a single node the two are equal, and RANK
    # falls back to LOCAL_RANK when it is not set.
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    rank = int(os.getenv("RANK", local_rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    if args.num_gpus is not None:
        world_size = args.num_gpus

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        print("[警告] 未检测到 CUDA，使用 CPU 推理，速度较慢。")

    # Collect input files recursively; sort so every process sees the same order.
    all_paths = []
    for ext in ("wav", "flac", "mp3"):
        all_paths += glob(join(args.input_dir, "**", f"*.{ext}"), recursive=True)
    all_paths = sorted(set(all_paths))

    if world_size > 1:
        # Contiguous, near-equal shards; this process keeps only its own.
        all_paths = np.array_split(all_paths, world_size)[rank].tolist()

    print(f"[rank {rank}] {len(all_paths)} files to process on {device}")

    if not all_paths:
        print(f"[rank {rank}] No files found, exiting.")
        return

    dataset = AudioDataset(all_paths, sr, args.w2v_path)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        drop_last=False,
    )

    models = load_models(args.ckpt, args.w2v_path, device)

    os.makedirs(args.output_dir, exist_ok=True)

    st = time()
    for wavs, feats, paths, orig_lens in tqdm(dataloader,
                                              desc=f"rank {rank}",
                                              dynamic_ncols=True):
        recon_list = infer_batch(wavs, feats, orig_lens, models, device, sr)

        for recon, src_path in zip(recon_list, paths):
            # Mirror the input's relative location under the output directory.
            rel = os.path.relpath(src_path, args.input_dir)
            # Output is always written as .wav, whatever the input format.
            rel = os.path.splitext(rel)[0] + ".wav"
            dst = join(args.output_dir, rel)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            sf.write(dst, recon, sr)

    et = time()
    print(f"[rank {rank}] Done. Total time: {(et - st) / 60:.2f} min")


if __name__ == "__main__":
    main()
|
|