| |
| |
|
|
| import os |
| import sys |
|
|
| |
| parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) |
| sys.path.append(parent_dir) |
|
|
| import torch |
| import json |
| from env import AttrDict |
| from bigvgan import BigVGAN |
| from time import time |
| from tqdm import tqdm |
| from meldataset import mel_spectrogram, MAX_WAV_VALUE |
| from scipy.io.wavfile import write |
| import numpy as np |
|
|
| import argparse |
|
|
# Turn on cuDNN benchmark mode in PyTorch for the whole process.
| torch.backends.cudnn.benchmark = True |
|
|
| |
# When tensors are printed, use a 200-character line width and only summarize
# tensors that have more than 10,000 elements.
| torch.set_printoptions(linewidth=200, threshold=10_000) |
|
|
|
|
def generate_soundwave(duration=5.0, sr=24000):
    """Synthesize a peak-normalized sine test tone with a swept frequency term.

    The frequency term follows one full sine cycle over the clip, moving
    between 220 and 1760, and the finished tone is scaled so that its largest
    absolute sample equals 0.95.

    Args:
        duration: Clip length; the time axis spans ``[0, duration)``.
        sr: Sampling rate; the clip has ``int(sr * duration)`` samples.

    Returns:
        A ``(waveform, sr)`` tuple. ``waveform`` is a 1-D float32 NumPy array
        and ``sr`` is handed back unchanged.
    """
    freq_lo, freq_hi = 220, 1760
    two_pi = 2 * np.pi
    num_samples = int(sr * duration)

    # Sample instants; the end point ``duration`` itself is excluded.
    t = np.linspace(0, duration, num_samples, endpoint=False, dtype=np.float32)

    # One sine cycle across the clip (values in [-1, 1]) steers the frequency
    # term between freq_lo and freq_hi.
    sweep = np.sin(two_pi * t / duration)
    frequencies = freq_lo + (freq_hi - freq_lo) * (sweep + 1) / 2

    tone = np.sin(two_pi * frequencies * t)

    # Normalize by the peak first, then apply the 0.95 headroom factor.
    peak = np.max(np.abs(tone))
    return tone / peak * 0.95, sr
|
|
|
|
def get_mel(x, h):
    """Compute the mel spectrogram of ``x`` using the settings stored on ``h``.

    Args:
        x: Input forwarded unchanged as the first argument of
            ``mel_spectrogram``.
        h: Config object exposing ``n_fft``, ``num_mels``, ``sampling_rate``,
            ``hop_size``, ``win_size``, ``fmin`` and ``fmax``.

    Returns:
        Whatever ``mel_spectrogram`` returns for those arguments.
    """
    # Settings in the positional order mel_spectrogram is called with.
    spectrogram_settings = (
        h.n_fft,
        h.num_mels,
        h.sampling_rate,
        h.hop_size,
        h.win_size,
        h.fmin,
        h.fmax,
    )
    return mel_spectrogram(x, *spectrogram_settings)
|
|
|
|
def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filepath: Path to the checkpoint file on disk.
        device: Forwarded to ``torch.load`` as ``map_location``.

    Returns:
        The object deserialized by ``torch.load``.

    Raises:
        FileNotFoundError: If ``filepath`` is not an existing regular file.
    """
    # Raise explicitly instead of using ``assert`` so the check is still
    # performed when Python runs with ``-O`` (which strips assertions).
    if not os.path.isfile(filepath):
        raise FileNotFoundError(f"Checkpoint file not found: '{filepath}'")

    print(f"Loading '{filepath}'")
    # NOTE(security): torch.load unpickles the file's contents; only load
    # checkpoints that come from a trusted source.
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
|
|
|
|
# Largest mean absolute difference between the two generators' outputs that is
# still reported as a success (value carried over unchanged from the script).
_MAX_MEAN_ABS_DIFF = 2e-3


def _check_outputs_match(
    generator_original, generator_cuda_kernel, num_mels, num_sample, num_mel_frame
):
    """Compare both generators on random inputs and print a Success/Fail report.

    Each of the ``num_sample`` rounds feeds the same random
    ``(1, num_mels, num_mel_frame)`` tensor to both generators and accumulates
    the mean absolute difference of their outputs.

    Returns:
        The mean absolute difference averaged over all rounds.
    """
    diff = 0.0
    for _ in tqdm(range(num_sample)):
        data = torch.rand((1, num_mels, num_mel_frame), device="cuda")
        with torch.inference_mode():
            audio_original = generator_original(data)
            audio_cuda_kernel = generator_cuda_kernel(data)
        diff += (audio_original - audio_cuda_kernel).abs().mean(dim=-1).item()
    diff /= num_sample

    # One report for both outcomes; only the status label differs. The last 30
    # output values of the final round are shown for eyeballing.
    status = "Success" if diff <= _MAX_MEAN_ABS_DIFF else "Fail"
    print(
        f"\n[{status}] test CUDA fused vs. plain torch BigVGAN inference"
        f"\n > mean_difference={diff}"
        f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}"
        f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
    )
    return diff


def _benchmark_generator(generator, num_mels, num_sample, num_mel_frame):
    """Time ``generator`` on random inputs and record peak CUDA memory.

    Only the forward pass is timed, bracketed by ``torch.cuda.synchronize()``
    calls; creating the random input happens before the clock starts.

    Returns:
        A ``(total_seconds, total_peak_bytes, total_output_samples)`` tuple,
        each summed over the ``num_sample`` rounds.
    """
    total_seconds = 0
    total_peak_bytes = 0
    total_output_samples = 0
    for _ in tqdm(range(num_sample)):
        torch.cuda.reset_peak_memory_stats(device="cuda")
        data = torch.rand((1, num_mels, num_mel_frame), device="cuda")
        torch.cuda.synchronize()
        tic = time()
        with torch.inference_mode():
            audio = generator(data)
        torch.cuda.synchronize()
        total_seconds += time() - tic

        total_output_samples += audio.shape[-1]
        total_peak_bytes += torch.cuda.max_memory_allocated(device="cuda")

    # Drop the last tensors and release cached blocks before the next benchmark.
    del data, audio
    torch.cuda.empty_cache()
    return total_seconds, total_peak_bytes, total_output_samples


def _report_benchmark(label, seconds, audio_second, khz, vram_gb):
    """Print one benchmark summary line for the model named ``label``."""
    print(
        f"{label}: took {seconds:.2f} seconds to generate {audio_second:.2f} seconds of audio, "
        f"{khz:.1f}kHz, {audio_second / seconds:.1f} faster than realtime, VRAM used {vram_gb:.1f} GB"
    )


def _to_int16_pcm(audio):
    """Scale a float tensor by ``MAX_WAV_VALUE`` and return an int16 NumPy array.

    The scaled values are clamped to the int16 range before the cast so an
    out-of-range sample saturates instead of wrapping around; in-range values
    are converted exactly as a plain cast would.
    """
    scaled = audio.squeeze() * MAX_WAV_VALUE
    return scaled.clamp(-32768, 32767).cpu().numpy().astype("int16")


def main():
    """Check the fused CUDA-kernel BigVGAN against the plain PyTorch model.

    Steps: load ``config.json`` from the checkpoint's directory, build both
    generators from the same checkpoint, compare their outputs on random
    inputs, benchmark speed and peak VRAM for each, and finally write example
    WAV files for a synthetic tone to ``tmp/``.
    """
    parser = argparse.ArgumentParser(
        description="Test script to check CUDA kernel correctness."
    )
    parser.add_argument(
        "--checkpoint_file",
        type=str,
        required=True,
        help="Path to the checkpoint file. Assumes config.json exists in the directory.",
    )
    args = parser.parse_args()

    config_file = os.path.join(os.path.dirname(args.checkpoint_file), "config.json")
    with open(config_file, encoding="utf-8") as f:
        h = AttrDict(json.load(f))

    print("loading plain Pytorch BigVGAN")
    generator_original = BigVGAN(h).to("cuda")
    print("loading CUDA kernel BigVGAN with auto-build")
    generator_cuda_kernel = BigVGAN(h, use_cuda_kernel=True).to("cuda")

    # Both models receive the same generator weights.
    state_dict_g = load_checkpoint(args.checkpoint_file, "cuda")
    for generator in (generator_original, generator_cuda_kernel):
        generator.load_state_dict(state_dict_g["generator"])
    for generator in (generator_original, generator_cuda_kernel):
        generator.remove_weight_norm()
        generator.eval()

    # Number of rounds and input length (in mel frames) used by every stage.
    num_sample = 10
    num_mel_frame = 16384

    # Stage 1: numerical agreement between the two implementations.
    _check_outputs_match(
        generator_original, generator_cuda_kernel, h.num_mels, num_sample, num_mel_frame
    )

    # Stage 2: speed / peak-memory benchmark, plain PyTorch first.
    toc_total_original, vram_used_original_total, _ = _benchmark_generator(
        generator_original, h.num_mels, num_sample, num_mel_frame
    )
    # The reported audio length is taken from the CUDA-kernel run, as before.
    toc_total_cuda_kernel, vram_used_cuda_kernel_total, audio_length_total = (
        _benchmark_generator(
            generator_cuda_kernel, h.num_mels, num_sample, num_mel_frame
        )
    )

    audio_second = audio_length_total / h.sampling_rate
    khz_original = audio_length_total / toc_total_original / 1000
    khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
    vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
    vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)

    _report_benchmark(
        "Original BigVGAN",
        toc_total_original,
        audio_second,
        khz_original,
        vram_used_original_gb,
    )
    _report_benchmark(
        "CUDA kernel BigVGAN",
        toc_total_cuda_kernel,
        audio_second,
        khz_cuda_kernel,
        vram_used_cuda_kernel_gb,
    )
    print(f"speedup of CUDA kernel: {khz_cuda_kernel / khz_original}")
    print(f"VRAM saving of CUDA kernel: {vram_used_original_gb / vram_used_cuda_kernel_gb}")

    # Stage 3: run both models on a synthetic tone and save audible examples.
    audio_real, _ = generate_soundwave(duration=5.0, sr=h.sampling_rate)
    audio_real = torch.tensor(audio_real).to("cuda")
    x = get_mel(audio_real.unsqueeze(0), h)

    with torch.inference_mode():
        y_g_hat_original = generator_original(x)
        y_g_hat_cuda_kernel = generator_cuda_kernel(x)

    os.makedirs("tmp", exist_ok=True)
    outputs = (
        ("audio_real.wav", audio_real),
        ("audio_generated_original.wav", y_g_hat_original),
        ("audio_generated_cuda_kernel.wav", y_g_hat_cuda_kernel),
    )
    for filename, audio in outputs:
        write(os.path.join("tmp", filename), h.sampling_rate, _to_int16_pcm(audio))
    print("Example generated audios of original vs. fused CUDA kernel written to tmp!")
    print("Done")


if __name__ == "__main__":
    main()
|
|