Egnalkram's picture
Upload folder using huggingface_hub
4689c2b verified
import os
from typing import Literal, Optional
import torch
import torch.nn as nn
from mmgp import offload
from shared.utils import files_locator as fl
from ..autoencoder.vae import VAE, get_my_vae
from ..bigvgan import BigVGAN
from ..bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
from ...model.utils.distributions import DiagonalGaussianDistribution
# Folder name that is joined with the BigVGANv2 file names ("config.json",
# "bigvgan_generator.pt") before each path is handed to fl.locate_file.
_BIGVGAN_V2_FOLDER = "bigvgan_v2_44khz_128band_512x"
def _resolve_bigvgan_v2_files():
    """Look up the BigVGANv2 generator weights and config via the file locator.

    Both lookups are performed (weights first, then config) before either
    result is checked.

    Returns:
        A ``(weights_path, config_path)`` tuple holding the values returned
        by ``fl.locate_file`` for the two files.

    Raises:
        FileNotFoundError: If either lookup returned ``None``.
    """
    located = [
        fl.locate_file(os.path.join(_BIGVGAN_V2_FOLDER, filename), error_if_none=False)
        for filename in ("bigvgan_generator.pt", "config.json")
    ]
    if any(path is None for path in located):
        raise FileNotFoundError(
            f"Missing BigVGANv2 files in '{_BIGVGAN_V2_FOLDER}'. "
            "Expected 'config.json' and 'bigvgan_generator.pt'."
        )
    weights_path, config_path = located
    return weights_path, config_path
def _preprocess_bigvgan_v2_state_dict(state_dict, quantization_map=None, tied_weights_map=None):
if isinstance(state_dict, dict) and isinstance(state_dict.get("generator"), dict):
state_dict = state_dict["generator"]
return state_dict, quantization_map, tied_weights_map
class AutoEncoderModule(nn.Module):
    """Inference-only wrapper bundling a VAE with a BigVGAN vocoder.

    The VAE weights are loaded from ``vae_ckpt_path``. The vocoder depends on
    ``mode``: '16k' builds ``BigVGAN`` from ``vocoder_ckpt_path``, while '44k'
    loads ``BigVGANv2`` through ``offload.fast_load_transformers_model`` from
    files found by ``_resolve_bigvgan_v2_files()``. After construction every
    parameter has ``requires_grad`` set to ``False``.
    """

    def __init__(self,
                 *,
                 vae_ckpt_path,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        """Build the VAE and the vocoder that matches ``mode``.

        Args:
            vae_ckpt_path: Checkpoint passed to ``torch.load`` (on CPU, with
                ``weights_only=True``) and loaded into the VAE.
            vocoder_ckpt_path: Checkpoint for the '16k' vocoder. Required in
                '16k' mode; ignored in '44k' mode, where the vocoder files are
                resolved by ``_resolve_bigvgan_v2_files()`` instead.
            mode: Either '16k' or '44k'.
            need_vae_encoder: When ``False``, ``self.vae.encoder`` is deleted
                after loading.

        Raises:
            ValueError: If ``mode`` is '16k' and ``vocoder_ckpt_path`` is
                ``None``, or if ``mode`` is neither '16k' nor '44k'.
            FileNotFoundError: Propagated from ``_resolve_bigvgan_v2_files()``
                in '44k' mode when the vocoder files cannot be located.
        """
        super().__init__()

        self.vae: VAE = get_my_vae(mode).eval()
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        self.vae.load_state_dict(vae_state_dict)
        self.vae.remove_weight_norm()

        if mode == '16k':
            # Explicit raise rather than `assert`: assertions are stripped
            # under `python -O`, which would let None reach BigVGAN().
            if vocoder_ckpt_path is None:
                raise ValueError("vocoder_ckpt_path is required when mode is '16k'")
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        elif mode == '44k':
            # The caller-supplied vocoder_ckpt_path is not used here; both
            # paths come from the file locator.
            weights_path, config_path = _resolve_bigvgan_v2_files()
            self.vocoder = offload.fast_load_transformers_model(
                weights_path,
                modelClass=BigVGANv2,
                forcedConfigPath=config_path,
                preprocess_sd=_preprocess_bigvgan_v2_state_dict,
                configKwargs={"use_cuda_kernel": False},
                writable_tensors=False,
                default_dtype=torch.float32,
            )
            self.vocoder.remove_weight_norm()
            self.vocoder.eval()
        else:
            raise ValueError(f'Unknown mode: {mode}')

        # Freeze everything registered on this module (VAE and vocoder).
        for param in self.parameters():
            param.requires_grad = False

        if not need_vae_encoder:
            # NOTE(review): encode() presumably cannot be used after this,
            # since it delegates to self.vae.encode -- confirm against VAE.
            del self.vae.encoder

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        """Return ``self.vae.encode(x)``, run under ``torch.inference_mode``."""
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``self.vae.decode(z)``, run under ``torch.inference_mode``."""
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Return ``self.vocoder(spec)``, run under ``torch.inference_mode``."""
        return self.vocoder(spec)