| import os
|
| from typing import Literal, Optional
|
|
|
| import torch
|
| import torch.nn as nn
|
| from mmgp import offload
|
| from shared.utils import files_locator as fl
|
|
|
| from ..autoencoder.vae import VAE, get_my_vae
|
| from ..bigvgan import BigVGAN
|
| from ..bigvgan_v2.bigvgan import BigVGAN as BigVGANv2
|
| from ...model.utils.distributions import DiagonalGaussianDistribution
|
|
|
# Sub-folder that is joined with each BigVGANv2 file name before the lookup
# through ``fl.locate_file``.
_BIGVGAN_V2_FOLDER = "bigvgan_v2_44khz_128band_512x"


def _resolve_bigvgan_v2_files():
    """Locate the BigVGANv2 generator weights and the matching config file.

    Each file name is joined onto ``_BIGVGAN_V2_FOLDER`` and passed to
    ``fl.locate_file`` with ``error_if_none=False``; both lookups are always
    performed (weights first, then config) before the results are checked.

    Returns:
        A ``(weights_path, config_path)`` tuple.

    Raises:
        FileNotFoundError: If either lookup returned ``None``.
    """
    weights_path, config_path = (
        fl.locate_file(os.path.join(_BIGVGAN_V2_FOLDER, filename), error_if_none=False)
        for filename in ("bigvgan_generator.pt", "config.json")
    )

    both_found = weights_path is not None and config_path is not None
    if not both_found:
        raise FileNotFoundError(
            f"Missing BigVGANv2 files in '{_BIGVGAN_V2_FOLDER}'. "
            "Expected 'config.json' and 'bigvgan_generator.pt'."
        )

    return weights_path, config_path
|
|
|
|
|
| def _preprocess_bigvgan_v2_state_dict(state_dict, quantization_map=None, tied_weights_map=None):
|
| if isinstance(state_dict, dict) and isinstance(state_dict.get("generator"), dict):
|
| state_dict = state_dict["generator"]
|
| return state_dict, quantization_map, tied_weights_map
|
|
|
|
|
class AutoEncoderModule(nn.Module):
    """Frozen VAE plus vocoder bundle used for inference only.

    The VAE is built by ``get_my_vae(mode)`` and loaded from
    ``vae_ckpt_path``. The vocoder depends on ``mode``:

    * ``'16k'``: ``BigVGAN`` built from ``vocoder_ckpt_path``.
    * ``'44k'``: ``BigVGANv2`` loaded through
      ``offload.fast_load_transformers_model`` from the files returned by
      ``_resolve_bigvgan_v2_files()``; ``vocoder_ckpt_path`` is not used.

    All parameters have ``requires_grad`` set to ``False`` and every public
    method runs under ``torch.inference_mode()``.
    """

    def __init__(self,
                 *,
                 vae_ckpt_path,
                 vocoder_ckpt_path: Optional[str] = None,
                 mode: Literal['16k', '44k'],
                 need_vae_encoder: bool = True):
        """Build the VAE and the vocoder for ``mode`` and freeze them.

        Args:
            vae_ckpt_path: Checkpoint passed to ``torch.load`` for the VAE
                state dict.
            vocoder_ckpt_path: Checkpoint for the ``'16k'`` vocoder. Required
                in ``'16k'`` mode, ignored in ``'44k'`` mode.
            mode: ``'16k'`` or ``'44k'``.
            need_vae_encoder: When ``False`` the ``encoder`` attribute of the
                VAE is deleted after loading.

        Raises:
            ValueError: If ``mode`` is ``'16k'`` and ``vocoder_ckpt_path`` is
                ``None``, or if ``mode`` is neither ``'16k'`` nor ``'44k'``.
            FileNotFoundError: In ``'44k'`` mode when the BigVGANv2 files
                cannot be located.
        """
        super().__init__()

        # Explicit exception instead of ``assert`` so the check survives
        # ``python -O``; done first so no checkpoint is loaded for a call
        # that cannot succeed.
        if mode == '16k' and vocoder_ckpt_path is None:
            raise ValueError("vocoder_ckpt_path is required when mode is '16k'")

        self.vae: VAE = get_my_vae(mode).eval()
        vae_state_dict = torch.load(vae_ckpt_path, weights_only=True, map_location='cpu')
        self.vae.load_state_dict(vae_state_dict)
        self.vae.remove_weight_norm()

        if mode == '16k':
            self.vocoder = BigVGAN(vocoder_ckpt_path).eval()
        elif mode == '44k':
            # The 44k vocoder always comes from the located BigVGANv2 folder,
            # regardless of the ``vocoder_ckpt_path`` argument.
            bigvgan_weights_path, bigvgan_config_path = _resolve_bigvgan_v2_files()
            self.vocoder = offload.fast_load_transformers_model(
                bigvgan_weights_path,
                modelClass=BigVGANv2,
                forcedConfigPath=bigvgan_config_path,
                preprocess_sd=_preprocess_bigvgan_v2_state_dict,
                configKwargs={"use_cuda_kernel": False},
                writable_tensors=False,
                default_dtype=torch.float32,
            )
            self.vocoder.remove_weight_norm()
            self.vocoder.eval()
        else:
            raise ValueError(f'Unknown mode: {mode}')

        # Inference-only module: no parameter should accumulate gradients.
        for param in self.parameters():
            param.requires_grad = False

        if not need_vae_encoder:
            # NOTE(review): ``encode`` presumably stops working after this,
            # since it delegates to the VAE -- confirm against ``VAE.encode``.
            del self.vae.encoder

    @torch.inference_mode()
    def encode(self, x: torch.Tensor) -> DiagonalGaussianDistribution:
        """Return ``self.vae.encode(x)``."""
        return self.vae.encode(x)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Return ``self.vae.decode(z)``."""
        return self.vae.decode(z)

    @torch.inference_mode()
    def vocode(self, spec: torch.Tensor) -> torch.Tensor:
        """Return the result of calling the vocoder on ``spec``."""
        return self.vocoder(spec)
|
|
|