ChristophSchuhmann's picture
Add model code, inference script, and examples
dfd1909 verified
from typing import Union, Optional
from numpy import ndarray
from torch import Tensor
import os
import torch
import torch.nn as nn
from TorchJaekwon.Util.UtilData import UtilData
from TorchJaekwon.Util.UtilAudioMelSpec import UtilAudioMelSpec
from FlashSR.AudioSR.autoencoder import AutoencoderKL
from TorchJaekwon.Util.UtilTorch import UtilTorch
class VAEWrapper:
def __init__(self,
autoencoder_ckpt_path:str,
sr:int = '48000',
frame_sec:float = 5.12,
device:torch.device = torch.device('cpu'),
scale_factor_z:float = 0.3342
) -> None:
vocoder_config_dir:str = f'{os.path.dirname(os.path.abspath(__file__))}/AudioSR/args'
self.sr:int = sr
self.frame_sec:float = frame_sec
self.scale_factor_z:float = scale_factor_z
self.device:torch.device = device
autoencoder:nn.Module = AutoencoderKL(**UtilData.yaml_load(f'{vocoder_config_dir}/model_argument.yaml'))
autoencoder_ckpt = torch.load(autoencoder_ckpt_path, map_location='cpu')
autoencoder.load_state_dict(autoencoder_ckpt)
autoencoder = autoencoder.to(device)
self.autoencoder = UtilTorch.freeze_param(autoencoder)
self.mel_config:dict = UtilData.yaml_load(f'{vocoder_config_dir}/mel_argument.yaml')
self.util_mel_spec = UtilAudioMelSpec(**self.mel_config)
def to(self, device):
self.device = device
self.autoencoder = self.autoencoder.to(self.device)
@torch.no_grad()
def encode_to_z(self,
audio:Union[ndarray,Tensor], # [batch, time]
normalize:bool = True,
scale_dict:dict = None,
) -> dict: #[batch, 16, time / (hop * 8), mel_bin / 8] mel_bin: 256
assert len(audio.shape) == 2, f'audio shape must be [batch, time] but got {audio.shape}'
result_dict:dict = {'wav': audio}
if normalize:
audio, scale_dict = self.normalize_wav(audio, scale_dict=scale_dict)
result_dict['norm_wav'] = audio
result_dict.update(scale_dict)
mel_spec:Tensor = self.audio_to_mel(audio)
result_dict['mel_spec'] = mel_spec
encoder_posterior = self.autoencoder.encode(mel_spec)
z = encoder_posterior.sample() * self.scale_factor_z
result_dict['z'] = z
return result_dict
def normalize_wav(self, waveform:Union[Tensor], scale_dict:dict):
mean_scale_factor = torch.mean(waveform, dim=1, keepdim=True) if scale_dict is None else scale_dict['mean_scale_factor']
waveform = waveform - mean_scale_factor
var_scale_factor = torch.max(torch.abs(waveform), dim=1, keepdim=True)[0] if scale_dict is None else scale_dict['var_scale_factor']
waveform = waveform / (var_scale_factor + 1e-8)
return waveform * 0.5, {'mean_scale_factor':mean_scale_factor, 'var_scale_factor':var_scale_factor}
def denormalize_wav(self, waveform:Union[Tensor], scale_dict:dict):
waveform = waveform * 2.0
waveform = waveform * (scale_dict['var_scale_factor'] + 1e-8)
waveform = waveform + scale_dict['mean_scale_factor']
return waveform
def get_mel_spec(self, audio:Union[ndarray,Tensor]):
return self.util_mel_spec.get_hifigan_mel_spec(audio).to(self.device)
@torch.no_grad()
def audio_to_mel(self, audio):
mel_spec:Tensor = self.util_mel_spec.get_hifigan_mel_spec(audio).to(self.device)
if len(mel_spec.shape) == 3: #to make [batch, channel, freq, time]
mel_spec = mel_spec.unsqueeze(1)
return mel_spec.permute(0, 1, 3, 2)
def z_to_audio(self,z:Tensor, scale_dict:dict = None, with_no_grad:bool = True):
if with_no_grad:
with torch.no_grad():
mel_spec = self.z_to_mel(z)
audio = self.mel_to_audio(mel_spec, scale_dict)
return audio
else:
mel_spec = self.z_to_mel(z, with_no_grad=False)
audio = self.mel_to_audio(mel_spec, scale_dict, with_no_grad=False)
return audio
def z_to_mel(self,z:Tensor, with_no_grad:bool = True):
if with_no_grad:
with torch.no_grad():
z = (1.0 / self.scale_factor_z) * z
mel_spec = self.autoencoder.decode(z)
return mel_spec
else:
z = (1.0 / self.scale_factor_z) * z
mel_spec = self.autoencoder.decode(z)
return mel_spec
def mel_to_audio(self, mel_spec:Tensor, scale_dict:dict = None, with_no_grad:bool = True):
if with_no_grad:
with torch.no_grad():
mel_spec = mel_spec.permute(0, 1, 3, 2).squeeze(1)
audio = self.autoencoder.vocoder(mel_spec)
if scale_dict is not None: audio = self.denormalize_wav(audio)
return audio
else:
mel_spec = mel_spec.permute(0, 1, 3, 2).squeeze(1)
audio = self.autoencoder.vocoder(mel_spec)
if scale_dict is not None: audio = self.denormalize_wav(audio)
return audio