# Source page: fff / X-Codec-2.0 / inference_gpu.py
# Uploader avatar caption: novateur's picture
# Commit message: Add files using upload-large-folder tool
# Commit: 24b4b54 (verified)
"""
GPU batch inference script (inference_gpu.py).

Based on inference.py, with the following additions:
- DataLoader-based batch inference (greatly improves GPU utilization)
- Multi-GPU support (torchrun / LOCAL_RANK automatic sharding)
- Writes the reconstructed audio (.wav), preserving the original subdirectory structure
- Automatically trims padding so the output length matches the original audio

Single-GPU example:
    python inference_gpu.py \
        --input-dir test_audio/input_test \
        --output-dir test_audio/output_test \
        --ckpt ckpt/epoch=4-step=1400000.ckpt \
        --batch-size 8 --num-workers 4

Multi-GPU example (4 GPUs):
    torchrun --nproc_per_node=4 inference_gpu.py \
        --input-dir test_audio/input_test \
        --output-dir test_audio/output_test \
        --ckpt ckpt/epoch=4-step=1400000.ckpt \
        --batch-size 8 --num-workers 4
"""
import os
import torch
import torch.nn.functional as F
import numpy as np
import soundfile as sf
import torchaudio
from torchaudio.transforms import Resample
from glob import glob
from tqdm import tqdm
from os.path import join
from collections import OrderedDict
from argparse import ArgumentParser
from time import time
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import AutoFeatureExtractor, Wav2Vec2BertModel
from vq.codec_encoder import CodecEncoder
from vq.codec_decoder_vocos import CodecDecoderVocos
from vq.module import SemanticEncoder
# ---------------------------------------------
# Dataset
# ---------------------------------------------
class AudioDataset(Dataset):
    """Load audio files and pre-compute the semantic-model input features.

    Each item is a tuple ``(audio_padded, feat, path, orig_len)``:

    * ``audio_padded`` -- mono waveform at ``sampling_rate``, shape ``(1, T')``,
      zero-padded on the right so that ``T'`` is a multiple of ``HOP``.
    * ``feat`` -- the ``input_features`` produced by the Hugging Face feature
      extractor for the padded waveform (expected ``(1, T_feat, 160)`` per the
      original author's note -- TODO confirm against the extractor in use).
    * ``path`` -- the source file path, unchanged.
    * ``orig_len`` -- number of samples after resampling and before padding;
      used downstream to trim the reconstruction.
    """

    HOP = 320  # encoder downsampling stride, in samples per frame
    FEAT_PAD = 160  # zeros added on each side of the waveform before feature extraction

    def __init__(self, file_list, sampling_rate: int, feature_extractor_path: str):
        """Store the file list and load the feature extractor.

        Args:
            file_list: Sequence of audio file paths.
            sampling_rate: Target sampling rate every file is resampled to.
            feature_extractor_path: Name or path passed to
                ``AutoFeatureExtractor.from_pretrained``.
        """
        self.file_list = file_list
        self.sr = sampling_rate
        self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
        # One Resample transform per source rate, built lazily, so the
        # resampling kernel is not recomputed for every file.
        self._resamplers = {}

    def __len__(self):
        """Return the number of audio files in the dataset."""
        return len(self.file_list)

    def _resample(self, audio, src_sr):
        """Resample ``audio`` from ``src_sr`` to ``self.sr`` with a cached transform."""
        resampler = self._resamplers.get(src_sr)
        if resampler is None:
            resampler = Resample(src_sr, self.sr)
            self._resamplers[src_sr] = resampler
        return resampler(audio)

    def __getitem__(self, idx):
        """Load, down-mix, resample and pad one file, then extract its features."""
        path = self.file_list[idx]
        audio, sr = torchaudio.load(path)

        # Down-mix multi-channel audio to mono.
        if audio.shape[0] > 1:
            audio = audio.mean(dim=0, keepdim=True)

        # Resample to the target rate when needed.
        if sr != self.sr:
            audio = self._resample(audio, sr)  # (1, T)

        # Length before any padding; the caller trims the output to this.
        orig_len = audio.shape[1]

        # Right-pad so the encoder input length is a multiple of HOP.
        pad_enc = -orig_len % self.HOP
        audio_padded = F.pad(audio, (0, pad_enc))  # (1, T')

        # Extract features from the HOP-aligned waveform (not the raw one) so
        # the feature frame count follows the same length the codec encoder
        # sees. collate_fn derives the batch audio length from the feature
        # frame count; with features taken from the unpadded waveform that
        # length can come out shorter than the clip and crop its tail.
        feat = self.feature_extractor(
            F.pad(audio_padded[0], (self.FEAT_PAD, self.FEAT_PAD)),
            sampling_rate=self.sr,
            return_tensors="pt",
        ).data['input_features']

        return audio_padded, feat, path, orig_len
def collate_fn(batch):
    """Merge variable-length samples into fixed-size batch tensors.

    The target feature length is the longest feature sequence in the batch,
    and the target audio length is that frame count times
    ``AudioDataset.HOP``. Shorter tensors are zero-padded at the end; audio
    that is longer than the target is cropped to it.

    Args:
        batch: Iterable of ``(audio, feat, path, orig_len)`` tuples.

    Returns:
        ``(audios, feats, paths, orig_lens)`` where ``audios`` and ``feats``
        are stacked along a new leading batch dimension, ``paths`` is a tuple
        and ``orig_lens`` is a ``torch.long`` tensor.
    """
    wavs, features, paths, lengths = zip(*batch)

    # Frame count of the longest feature sequence drives both targets.
    n_frames = max(feat.shape[1] for feat in features)
    n_samples = n_frames * AudioDataset.HOP

    # Audio: pad on the right when short, otherwise crop to the target.
    wav_batch = torch.stack([
        F.pad(wav, (0, n_samples - wav.shape[1]))
        if wav.shape[1] < n_samples
        else wav[:, :n_samples]
        for wav in wavs
    ])  # (B, 1, T)

    # Features: pad along the frame axis only (second-to-last dimension).
    feat_batch = torch.stack([
        F.pad(feat, (0, 0, 0, n_frames - feat.shape[1]))
        if feat.shape[1] < n_frames
        else feat[:, :n_frames, :]
        for feat in features
    ])  # (B, 1, T_feat, feature_dim)

    return wav_batch, feat_batch, paths, torch.tensor(lengths, dtype=torch.long)
# ---------------------------------------------
# Model loading utilities
# ---------------------------------------------
def load_models(ckpt_path: str, w2v_path: str, device: torch.device):
    """Build every sub-model, load its weights and move it to ``device``.

    The checkpoint's ``"state_dict"`` is split by key prefix into one state
    dict per sub-module (the prefix is stripped from each key). All modules
    are put in eval mode.

    Args:
        ckpt_path: Path to the training checkpoint.
        w2v_path: Name or path of the pretrained ``Wav2Vec2BertModel``.
        device: Device that receives all modules.

    Returns:
        ``(semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior)``.
    """
    print(f"[rank {device}] Loading checkpoint: {ckpt_path}")
    state = torch.load(ckpt_path, map_location="cpu")["state_dict"]

    def take(prefix):
        # Collect the entries under `prefix`, dropping the prefix from keys.
        cut = len(prefix)
        return OrderedDict(
            (name[cut:], tensor)
            for name, tensor in state.items()
            if name.startswith(prefix)
        )

    codec_weights = take("CodecEnc.")
    generator_weights = take("generator.")
    fc_post_weights = take("fc_post_a.")
    sem_enc_weights = take("SemanticEncoder_module.")
    fc_prior_weights = take("fc_prior.")

    def prepare(module, weights=None):
        # Optionally load weights, then switch to eval mode on the target device.
        if weights is not None:
            module.load_state_dict(weights)
        module.eval().to(device)
        return module

    # The pretrained semantic model brings its own weights.
    semantic_model = prepare(
        Wav2Vec2BertModel.from_pretrained(w2v_path, output_hidden_states=True)
    )
    sem_enc = prepare(SemanticEncoder(1024, 1024, 1024), sem_enc_weights)
    encoder = prepare(CodecEncoder(), codec_weights)
    decoder = prepare(CodecDecoderVocos(), generator_weights)
    fc_post_a = prepare(nn.Linear(1024, 1024), fc_post_weights)
    fc_prior = prepare(nn.Linear(2048, 1024), fc_prior_weights)

    return semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior
# ---------------------------------------------
# Single-batch inference
# ---------------------------------------------
@torch.no_grad()
def infer_batch(wavs, feats, orig_lens, models, device, sr):
    """Reconstruct one padded batch and trim each item to its original length.

    Args:
        wavs: Padded waveforms, ``(B, 1, T)``.
        feats: Semantic-model input features, ``(B, 1, Tf, feature_dim)``.
        orig_lens: ``(B,)`` integer tensor of unpadded sample counts.
        models: Tuple ``(semantic_model, sem_enc, encoder, decoder,
            fc_post_a, fc_prior)`` as returned by ``load_models``.
        device: Device the inputs are moved to.
        sr: Sampling rate; not used in this function.

    Returns:
        A list of ``B`` NumPy arrays, one reconstructed waveform per item,
        each sliced to at most its ``orig_lens`` entry.
    """
    semantic_model, sem_enc, encoder, decoder, fc_post_a, fc_prior = models

    wavs = wavs.to(device)                 # (B, 1, T)
    feats = feats[:, 0, :, :].to(device)   # drop the singleton axis -> (B, Tf, feature_dim)

    # 1) Acoustic branch. After the transpose the channel axis must be dim 1
    #    so it can be concatenated with the semantic branch below
    #    (presumably (B, 1024, Tf) -- TODO confirm against CodecEncoder).
    acoustic = encoder(wavs).transpose(1, 2)

    # 2) Semantic branch: hidden state 16 of the semantic model, channels-first.
    hidden = semantic_model(feats).hidden_states[16]
    semantic = sem_enc(hidden.transpose(1, 2))

    # 3) Concatenate along the channel axis and project with fc_prior,
    #    which is applied on the last axis and therefore needs channels-last.
    fused = torch.cat([semantic, acoustic], dim=1)
    fused = fc_prior(fused.transpose(1, 2)).transpose(1, 2)

    # NOTE: the quantizer round trip (decoder(..., vq=True) followed by
    # decoder.quantizer.get_output_from_indices) is not applied here; the
    # continuous embedding is passed straight to fc_post_a and the decoder.
    post = fc_post_a(fused.transpose(1, 2)).transpose(1, 2)

    # 4) Decode to waveforms and drop the singleton channel axis -> (B, T).
    recon_batch = decoder(post.transpose(1, 2), vq=False)[0].squeeze(1)

    # 5) Trim each item to its original (unpadded) length.
    return [
        recon_batch[row, :n_samples].detach().cpu().numpy()
        for row, n_samples in enumerate(orig_lens.tolist())
    ]
# ---------------------------------------------
# Main function
# ---------------------------------------------
def main():
    """Command-line entry point: shard the input files, run inference, write WAVs.

    Files under ``--input-dir`` (``wav``/``flac``/``mp3``, searched
    recursively) are sorted, split into ``world_size`` contiguous shards, and
    the shard belonging to this process is reconstructed batch by batch.
    Outputs are written as ``.wav`` under ``--output-dir`` with the same
    relative sub-directory layout.

    Environment variables read (as set by ``torchrun``):
        LOCAL_RANK: Selects the CUDA device on this node (default 0).
        RANK: Global process index used to pick the shard; falls back to
            ``LOCAL_RANK`` when unset.
        WORLD_SIZE: Number of shards (default 1); overridden by ``--num-gpus``.
    """
    parser = ArgumentParser()
    parser.add_argument("--input-dir", type=str,
                        default="test_audio/input_test")
    parser.add_argument("--output-dir", type=str,
                        default="test_audio/output_test")
    parser.add_argument("--ckpt", type=str,
                        default="/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/log/shengpeng_debug/last.ckpt")
    parser.add_argument("--w2v-path", type=str,
                        default="/apdcephfs/private_jishengpeng2/work/shengpeng/research/X-Codec-2.0/ckpt/w2v-bert-2.0")
    parser.add_argument("--batch-size", type=int, default=8)
    parser.add_argument("--num-workers", type=int, default=4)
    parser.add_argument("--num-gpus", type=int, default=None,
                        help="手动指定 GPU 总数；默认自动检测（或读取 torchrun 环境变量）")
    args = parser.parse_args()
    sr = 16000

    # -- Multi-GPU sharding (torchrun or a manually specified shard count) --
    local_rank = int(os.getenv("LOCAL_RANK", 0))
    # LOCAL_RANK restarts at 0 on every node, so the shard must be chosen by
    # the global RANK; on a single node the two are identical.
    rank = int(os.getenv("RANK", local_rank))
    world_size = int(os.getenv("WORLD_SIZE", 1))
    if args.num_gpus is not None:
        world_size = args.num_gpus
    if world_size < 1:
        parser.error(f"world size must be >= 1, got {world_size}")
    if not 0 <= rank < world_size:
        parser.error(f"rank {rank} is out of range for world size {world_size}")

    if torch.cuda.is_available():
        device = torch.device(f"cuda:{local_rank}")
        torch.cuda.set_device(device)
    else:
        device = torch.device("cpu")
        print("[警告] 未检测到 CUDA，使用 CPU 推理，速度较慢。")

    # -- Collect audio files --
    all_paths = []
    for ext in ("wav", "flac", "mp3"):
        all_paths += glob(join(args.input_dir, "**", f"*.{ext}"), recursive=True)
    # De-duplicate and sort so every process sees the same ordering.
    all_paths = sorted(set(all_paths))
    if world_size > 1:
        # Contiguous, near-equal shards; this process keeps its own.
        all_paths = np.array_split(all_paths, world_size)[rank].tolist()
    print(f"[rank {rank}] {len(all_paths)} files to process on {device}")
    if not all_paths:
        print(f"[rank {rank}] No files found, exiting.")
        return

    # -- Build the DataLoader --
    dataset = AudioDataset(all_paths, sr, args.w2v_path)
    dataloader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
        collate_fn=collate_fn,
        drop_last=False,
    )

    # -- Load models --
    models = load_models(args.ckpt, args.w2v_path, device)
    os.makedirs(args.output_dir, exist_ok=True)

    # -- Inference loop --
    st = time()
    for wavs, feats, paths, orig_lens in tqdm(dataloader,
                                              desc=f"rank {rank}",
                                              dynamic_ncols=True):
        recon_list = infer_batch(wavs, feats, orig_lens, models, device, sr)
        for recon, src_path in zip(recon_list, paths):
            # Keep the sub-directory structure relative to input_dir.
            rel = os.path.relpath(src_path, args.input_dir)
            # Every output is written as .wav regardless of the input format.
            rel = os.path.splitext(rel)[0] + ".wav"
            dst = join(args.output_dir, rel)
            os.makedirs(os.path.dirname(dst), exist_ok=True)
            sf.write(dst, recon, sr)
    et = time()
    print(f"[rank {rank}] Done. Total time: {(et - st) / 60:.2f} min")
# Run the command-line entry point only when executed as a script
# (e.g. via `python` or `torchrun`), not when imported as a module.
if __name__ == "__main__":
    main()