Audio-to-Audio
English
cn
CleanMel / model /vocos /pretrained.py
SaoYear's picture
Upload folder using huggingface_hub
cf82a4e verified
from __future__ import annotations
from typing import Any, Dict, Tuple, Union, Optional
import torch
import yaml
from huggingface_hub import hf_hub_download
from torch import nn
from model.vocos.feature_extractors import FeatureExtractor, EncodecFeatures
from model.vocos.heads import FourierHead
from model.vocos.models import Backbone
def instantiate_class(args: Union[Any, Tuple[Any, ...]], init: Dict[str, Any]) -> Any:
    """Instantiate the class named by ``init["class_path"]``.

    Args:
        args: Positional arguments required for instantiation. A value that is
            not a tuple is treated as a single positional argument.
        init: Dict of the form {"class_path":...,"init_args":...}.
            ``class_path`` is a dotted ``"package.module.ClassName"`` string.
            ``init_args`` holds the keyword arguments; it may be absent or
            ``None`` (e.g. an empty ``init_args:`` entry in a YAML config), in
            which case no keyword arguments are passed.

    Returns:
        The instantiated class object.

    Raises:
        KeyError: If ``init`` has no ``"class_path"`` entry.
        ValueError: If ``class_path`` contains no ``"."`` separator.
        ImportError: If the module part of ``class_path`` cannot be imported.
        AttributeError: If the module does not define the requested name.
    """
    # ``or {}`` also covers an explicit ``None``, which ``.get(key, {})`` alone
    # would return unchanged and which cannot be unpacked with ``**``.
    kwargs = init.get("init_args") or {}
    if not isinstance(args, tuple):
        args = (args,)
    class_module, class_name = init["class_path"].rsplit(".", 1)
    # A non-empty ``fromlist`` makes ``__import__`` return the leaf module
    # rather than the top-level package.
    module = __import__(class_module, fromlist=[class_name])
    args_class = getattr(module, class_name)
    return args_class(*args, **kwargs)
class Vocos(nn.Module):
    """
    The Vocos class represents a Fourier-based neural vocoder for audio synthesis.
    This class is primarily designed for inference, with support for loading from pretrained
    model checkpoints. It consists of three main components: a feature extractor,
    a backbone, and a head.
    """

    def __init__(
        self, feature_extractor: nn.Module, backbone: Backbone, head: FourierHead,
    ):
        """
        Store the three sub-modules.

        Args:
            feature_extractor (nn.Module): Module stored as ``self.feature_extractor``. It is not called by
                ``forward``/``decode``; it is only consulted by ``from_pretrained`` and ``codes_to_features``.
            backbone (Backbone): Module applied to the input features in ``decode``.
            head (FourierHead): Module applied to the backbone output in ``decode``.
        """
        super().__init__()
        self.feature_extractor = feature_extractor
        self.backbone = backbone
        self.head = head

    @classmethod
    def from_hparams(cls, config_path: str) -> "Vocos":
        """
        Class method to create a new Vocos model instance from hyperparameters stored in a yaml configuration file.

        Args:
            config_path (str): Path to a yaml file with ``feature_extractor``, ``backbone`` and ``head`` entries,
                each of the form {"class_path":...,"init_args":...}.

        Returns:
            Vocos: A model built from the configuration; no weights are loaded here.
        """
        # Explicit encoding so parsing does not depend on the platform's default locale.
        with open(config_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
        feature_extractor = instantiate_class(args=(), init=config["feature_extractor"])
        backbone = instantiate_class(args=(), init=config["backbone"])
        head = instantiate_class(args=(), init=config["head"])
        return cls(feature_extractor=feature_extractor, backbone=backbone, head=head)

    @classmethod
    def from_pretrained(cls, config_path: str, model_path: str, model: Optional[nn.Module] = None) -> "Vocos":
        """
        Class method to create a Vocos model instance and load pre-trained weights from a local checkpoint file.

        Args:
            config_path (str): Path to the yaml configuration. Only read when ``model`` is None.
            model_path (str): Path to the checkpoint file passed to ``torch.load``.
            model (nn.Module, optional): An already constructed model to load the weights into. When None,
                the model is built with ``from_hparams(config_path)``.

        Returns:
            Vocos: The model with the checkpoint weights loaded, switched to eval mode.
        """
        if model is None:
            model = cls.from_hparams(config_path)
        # NOTE(review): torch.load relies on pickle; only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location="cpu")
        # Keep only the entries that belong to the three sub-modules; any other
        # keys stored in the checkpoint are dropped before loading.
        prefixes = ("backbone", "feature_extractor", "head")
        state_dict = {
            key: value
            for key, value in state_dict.items()
            if key.startswith(prefixes)
        }
        if isinstance(model.feature_extractor, EncodecFeatures):
            # The encodec weights are taken from the freshly constructed model
            # rather than from the checkpoint.
            encodec_parameters = {
                "feature_extractor.encodec." + key: value
                for key, value in model.feature_extractor.encodec.state_dict().items()
            }
            state_dict.update(encodec_parameters)
        model.load_state_dict(state_dict)
        model.eval()
        return model

    @torch.inference_mode()
    def forward(self, features_input: torch.Tensor, X_norm, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio from already calculated features and rescale the result. The features are passed
        to ``decode`` (backbone followed by head) and the decoded output is divided by ``X_norm``.

        Args:
            features_input (Tensor): The input tensor of features, forwarded unchanged to ``decode``.
            X_norm: Divisor applied to the decoded output. Presumably the normalization factor that was applied
                upstream -- TODO confirm against the caller.
            **kwargs: Extra keyword arguments forwarded to the backbone.

        Returns:
            Tensor: The decoded output divided by ``X_norm``.
        """
        audio_output = self.decode(features_input, **kwargs)
        return audio_output / X_norm

    @torch.inference_mode()
    def decode(self, features_input: torch.Tensor, **kwargs: Any) -> torch.Tensor:
        """
        Method to decode audio waveform from already calculated features. The features input is passed through
        the backbone and the head to reconstruct the audio output.

        Args:
            features_input (Tensor): The input tensor of features of shape (B, C, L), where B is the batch size,
                C denotes the feature dimension, and L is the sequence length.
            **kwargs: Extra keyword arguments forwarded to the backbone.

        Returns:
            Tensor: The output tensor representing the reconstructed audio waveform of shape (B, T).
        """
        x = self.backbone(features_input, **kwargs)
        audio_output = self.head(x)
        return audio_output

    @torch.inference_mode()
    def codes_to_features(self, codes: torch.Tensor) -> torch.Tensor:
        """
        Transforms an input sequence of discrete tokens (codes) into feature embeddings using the feature extractor's
        codebook weights.

        Args:
            codes (Tensor): The input tensor. Expected shape is (K, L) or (K, B, L),
                where K is the number of codebooks, B is the batch size and L is the sequence length.

        Returns:
            Tensor: Features of shape (B, C, L), where B is the batch size, C denotes the feature dimension,
                and L is the sequence length.

        Raises:
            AssertionError: If the feature extractor is not an instance of EncodecFeatures.
        """
        assert isinstance(
            self.feature_extractor, EncodecFeatures
        ), "Feature extractor should be an instance of EncodecFeatures"

        # Insert a batch axis so that (K, L) input becomes (K, 1, L).
        if codes.dim() == 2:
            codes = codes.unsqueeze(1)

        n_bins = self.feature_extractor.encodec.quantizer.bins
        # Shift the indices of codebook k by k * n_bins so all codebooks can be
        # looked up in one embedding table (presumably codebook_weights stacks
        # the codebooks along its first axis -- TODO confirm).
        offsets = torch.arange(0, n_bins * len(codes), n_bins, device=codes.device)
        embeddings_idxs = codes + offsets.view(-1, 1, 1)
        # Sum the per-codebook embeddings over the codebook axis.
        features = torch.nn.functional.embedding(embeddings_idxs, self.feature_extractor.codebook_weights).sum(dim=0)
        # (B, L, C) -> (B, C, L)
        features = features.transpose(1, 2)
        return features
if __name__ == "__main__":
    # Ad-hoc profiling entry point: loads a checkpoint, moves the model to the
    # "meta" device and reports the FLOPs of one decode call plus the
    # parameter count.
    config_file = "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/pretrained_rec_normed.yaml"
    checkpoint_file = "/nvmework3/shaonian/MelSpatialNet/MelSpatialNet/models/vocos/pretrained/vocos_hop128_clip1e-5_rts.ckpt"
    model = Vocos.from_pretrained(config_file, checkpoint_file).to("meta")

    # Random input of shape (1, 80, 501), moved to the meta device.
    features = torch.randn(1, 80, 501).to("meta")

    from torch.utils.flop_counter import FlopCounterMode  # requires torch>=2.1.0

    with FlopCounterMode(model, display=False) as flop_counter:
        decoded = model.decode(features)

    flops_forward_eval = flop_counter.get_total_flops()
    params_eval = sum(weight.numel() for weight in model.parameters())
    # NOTE(review): the 4e9 divisor presumably converts to GFLOPs per second of
    # audio (the input would span about 4 s) -- TODO confirm.
    print(f"flops_forward={flops_forward_eval/4e9:.2f}G, params={params_eval/1e6:.2f} M")