| from __future__ import annotations |
|
|
| from typing import Any, Dict, Tuple, Union, Optional |
|
|
| import torch |
| import yaml |
| from torch import nn |
| from model.vocos.feature_extractors import FeatureExtractor, EncodecFeatures |
| from model.vocos.heads import FourierHead |
| from model.vocos.models import Backbone |
|
|
|
|
| def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any: |
| """Instantiates a class with the given args and init. |
| |
| Args: |
| args: Positional arguments required for instantiation. |
| init: Dict of the form {"class_path":...,"init_args":...}. |
| |
| Returns: |
| The instantiated class object. |
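
    Example (illustrative; any importable class in the
    {"class_path": ..., "init_args": ...} layout works):
        >>> init = {"class_path": "torch.nn.Linear",
        ...         "init_args": {"in_features": 4, "out_features": 2}}
        >>> layer = instantiate_class(args=(), init=init)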
| """ |
| kwargs = init.get("init_args", {}) |
| if not isinstance(args, tuple): |
| args = (args,) |
| class_module, class_name = init["class_path"].rsplit(".", 1) |
| module = __import__(class_module, fromlist=[class_name]) |
| args_class = getattr(module, class_name) |
| return args_class(*args, **kwargs) |
|
|
|
|
| class Vocos(nn.Module): |
| """ |
| The Vocos class represents a Fourier-based neural vocoder for audio synthesis. |
| This class is primarily designed for inference, with support for loading from pretrained |
| model checkpoints. It consists of three main components: a feature extractor, |
| a backbone, and a head. |
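
    Example (paths and shapes are placeholders):
        >>> model = Vocos.from_hparams("vocos_config.yaml")
        >>> features = torch.randn(1, 80, 100)   # (B, C, L) features
        >>> audio = model.decode(features)       # -> (B, T) waveform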
| """ |
|
|
| def __init__( |
        self, feature_extractor: FeatureExtractor, backbone: Backbone, head: FourierHead,
| ): |
| super().__init__() |
| self.feature_extractor = feature_extractor |
| self.backbone = backbone |
| self.head = head |
|
|
| @classmethod |
| def from_hparams(cls, config_path: str) -> "Vocos": |
| """ |
| Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file. |
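
        The yaml file is expected to provide three top-level sections, each in the
        {"class_path": ..., "init_args": ...} layout consumed by instantiate_class,
        e.g. (illustrative):

            feature_extractor:
              class_path: model.vocos.feature_extractors.EncodecFeatures
              init_args: ...
            backbone:
              class_path: ...
            head:
              class_path: ...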
| """ |
| with open(config_path, "r") as f: |
| config = yaml.safe_load(f) |
| feature_extractor = instantiate_class(args=(), init=config["feature_extractor"]) |
| backbone = instantiate_class(args=(), init=config["backbone"]) |
| head = instantiate_class(args=(), init=config["head"]) |
| model = cls(feature_extractor=feature_extractor, backbone=backbone, head=head) |
| return model |
|
|
| @classmethod |
    def from_pretrained(cls, config_path: str, model_path: str, model: Optional[nn.Module] = None) -> "Vocos":
        """
        Class method to create a new Vocos model instance from a local yaml configuration file
        and a pre-trained checkpoint.
        """
        if model is None:
            model = cls.from_hparams(config_path)
| state_dict = torch.load(model_path, map_location="cpu") |
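        # Keep only inference-time weights; a training checkpoint may also carry
        # entries under other prefixes (e.g. discriminator or optimizer state).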
| prefixes = ("backbone", "feature_extractor", "head") |
| state_dict = { |
| key: value |
| for key, value in state_dict.items() |
| if any(key.startswith(prefix) for prefix in prefixes) |
| } |
| if isinstance(model.feature_extractor, EncodecFeatures): |
| encodec_parameters = { |
| "feature_extractor.encodec." + key: value |
| for key, value in model.feature_extractor.encodec.state_dict().items() |
| } |
| state_dict.update(encodec_parameters) |
| model.load_state_dict(state_dict) |
| model.eval() |
| return model |
|
|
| @torch.inference_mode() |
    def forward(self, features_input: torch.Tensor, X_norm: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to synthesize an audio waveform from already calculated features. The features are passed
        through the backbone and the head, and the resulting waveform is divided by X_norm to undo the
        normalization applied on the input side.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.
            X_norm (Tensor): Normalization factor(s) the reconstructed waveform is divided by.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        audio_output = self.decode(features_input, **kwargs)
        return audio_output / X_norm
|
|
| @torch.inference_mode() |
| def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor: |
| """ |
| Method to decode audio waveform from already calculated features. The features input is passed through |
| the backbone and the head to reconstruct the audio output. |
| |
| Args: |
| features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size, |
| C denotes the feature dimension, and L is the sequence length. |
| |
| Returns: |
| Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T). |
| """ |
| x = self.backbone(features_input, **kwargs) |
| audio_output = self.head(x) |
| return audio_output |
|
|
| @torch.inference_mode() |
| def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor: |
| """ |
| Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's |
| codebook weights. |
| |
| Args: |
| codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L), |
| where K is the number of codebooks, B is the batch size and L is the sequence length. |
| |
| Returns: |
| Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension, |
| and L is the sequence length. |
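
        Example (shapes and bin count are illustrative):
            >>> codes = torch.randint(0, 1024, (2, 100))   # K=2 codebooks, L=100
            >>> features = model.codes_to_features(codes)  # -> (1, C, 100)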
| """ |
| assert isinstance( |
| self.feature_extractor, EncodecFeatures |
| ), "Feature extractor should be an instance of EncodecFeatures" |
|
|
| if codes.dim() == 2: |
| codes = codes.unsqueeze(1) |
|
|
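        # Each of the K codebooks occupies a contiguous block of n_bins rows in the
        # concatenated codebook_weights table, so shift the k-th codebook's codes by
        # k * n_bins before the shared embedding lookup, then sum over codebooks.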
| n_bins = self.feature_extractor.encodec.quantizer.bins |
| offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device) |
| embeddings_idxs = codes + offsets.view(-1, 1, 1) |
| features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0) |
| features = features.transpose(1, 2) |
|
|
| return features |
|
|
|
|
| if __name__ == "__main__": |
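    # Instantiate from a local config/checkpoint pair, then move the model to the
    # meta device so FLOPs can be counted without allocating real storage.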
| model = Vocos.from_pretrained( |
| "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/pretrained_rec_normed.yaml", |
| "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/vocos_hop128_clip1e-5_rts.ckpt").to("meta") |
    x = torch.randn(1, 80, 501, device="meta")  # (B, C=80, L=501) dummy features
| from torch.utils.flop_counter import FlopCounterMode |
| with FlopCounterMode(model, display=False) as fcm: |
| y = model.decode(x) |
| flops_forward_eval = fcm.get_total_flops() |
|
|
| params_eval = sum(param.numel() for param in model.parameters()) |
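    # 501 frames at hop 128 is roughly 4 s of audio (assuming a 16 kHz model),
    # so dividing by 4e9 reports GFLOPs per second of synthesized audio.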
| print(f"flops_forward={flops_forward_eval/4e9:.2f}G, params={params_eval/1e6:.2f} M") |