Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved | |
| import math | |
| from abc import ABCMeta, abstractmethod | |
| from typing import Union | |
| import dacvae | |
| import torch | |
| from sam_audio.model.config import DACVAEConfig | |
class Encoder(torch.nn.Module, metaclass=ABCMeta):
    """Abstract base class for modules that encode a waveform tensor.

    The class uses ``ABCMeta`` so that ``forward`` can be declared abstract:
    a subclass that does not override it cannot be instantiated, instead of
    silently inheriting a stub that returns ``None``.
    """

    @abstractmethod
    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode ``waveform`` and return the encoded tensor.

        Args:
            waveform: Input audio tensor. The expected layout is defined by
                the concrete subclass.

        Returns:
            The encoded representation produced by the subclass.
        """
class Codec(Encoder):
    """Abstract encoder that can also decode and convert between index spaces.

    Concrete codecs must implement ``decode`` and the two index-conversion
    methods. ``cast_to_int`` is a shared helper that those implementations
    call as ``self.cast_to_int(value)``.
    """

    @abstractmethod
    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
        """Decode ``encoded_frames`` and return the resulting tensor."""

    @abstractmethod
    def wav_idx_to_feature_idx(
        self, wav_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]:
        """Convert a waveform index (or tensor of indices) to a feature index.

        Args:
            wav_idx: Index or indices in waveform samples.
            sample_rate: Optional sample rate associated with ``wav_idx``;
                how ``None`` is handled is up to the implementation.
        """

    @abstractmethod
    def feature_idx_to_wav_idx(
        self, feature_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]:
        """Convert a feature index (or tensor of indices) to a waveform index.

        Args:
            feature_idx: Index or indices in encoded-feature frames.
            sample_rate: Optional sample rate for the returned waveform
                index; how ``None`` is handled is up to the implementation.
        """

    # Must be a staticmethod: the function takes no ``self``, yet subclasses
    # invoke it through an instance (``self.cast_to_int(value)``). Without the
    # decorator the instance would be bound to ``x`` and the call would fail.
    @staticmethod
    def cast_to_int(
        x: Union[int, torch.Tensor],
    ) -> Union[int, torch.Tensor]:
        """Convert ``x`` to an integer value.

        Args:
            x: A tensor or a Python number.

        Returns:
            ``x.int()`` when ``x`` is a tensor, otherwise ``int(x)``.
        """
        if isinstance(x, torch.Tensor):
            return x.int()
        return int(x)
class DACVAEEncoder(Encoder):
    """Encoder built from the encoder and quantizer of a ``dacvae.DACVAE`` model.

    The underlying model is constructed from ``config`` and put in eval mode.
    ``forward`` runs without gradient tracking and returns the first of the
    two chunks produced by the quantizer's input projection.
    """

    def __init__(self, config: DACVAEConfig) -> None:
        """Build the DACVAE model from ``config`` and keep the parts needed.

        Args:
            config: Supplies the DACVAE constructor arguments as well as
                ``hop_length`` and ``sample_rate``.
        """
        super().__init__()
        # Config fields forwarded by keyword to the DACVAE constructor.
        constructor_fields = (
            "encoder_dim",
            "encoder_rates",
            "latent_dim",
            "decoder_dim",
            "decoder_rates",
            "n_codebooks",
            "codebook_size",
            "codebook_dim",
            "quantizer_dropout",
            "sample_rate",
        )
        constructor_kwargs = {
            field: getattr(config, field) for field in constructor_fields
        }
        backbone = dacvae.DACVAE(**constructor_kwargs).eval()
        self._setup_model(backbone)
        self.hop_length = config.hop_length
        self.sample_rate = config.sample_rate

    def _setup_model(self, model):
        """Attach the sub-modules of ``model`` used for encoding."""
        self.encoder = model.encoder
        self.quantizer = model.quantizer

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode ``waveform`` with gradients disabled.

        The waveform is padded with ``_pad``, passed through the encoder and
        then through the quantizer's input projection. The projection is
        split in two along dim 1 and only the first half is returned.
        """
        with torch.no_grad():
            latent = self.encoder(self._pad(waveform))
            projected = self.quantizer.in_proj(latent)
            mean, _scale = projected.chunk(2, dim=1)
        return mean

    def _pad(self, wavs):
        """Reflect-pad the end of the last dim up to a multiple of ``hop_length``.

        Returns ``wavs`` unchanged when its last dimension is already a
        multiple of ``hop_length``.
        """
        remainder = wavs.size(-1) % self.hop_length
        if not remainder:
            return wavs
        right_padding = (0, self.hop_length - remainder)
        return torch.nn.functional.pad(wavs, right_padding, "reflect")
class DACVAE(DACVAEEncoder, Codec):
    """DACVAE codec: the encoder plus the model's decoder and index conversions.

    Index conversions use ``hop_length`` and ``sample_rate`` stored on the
    instance; results are converted to integers with ``cast_to_int``.
    """

    def _setup_model(self, model):
        """Attach the encoder parts via the parent, then the decoder."""
        super()._setup_model(model)
        self.decoder = model.decoder

    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
        """Apply the quantizer's output projection and then the decoder."""
        return self.decoder(self.quantizer.out_proj(encoded_frames))

    def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
        """Convert a feature index to a waveform index.

        Computes ``feature_idx * hop_length * (sample_rate / self.sample_rate)``
        and casts the result to an integer.

        Args:
            feature_idx: Feature index, as a number or a tensor.
            sample_rate: Sample rate for the returned waveform index;
                defaults to ``self.sample_rate`` when ``None``.
        """
        wav_rate = self.sample_rate if sample_rate is None else sample_rate
        rate_ratio = wav_rate / self.sample_rate
        return self.cast_to_int(feature_idx * self.hop_length * rate_ratio)

    def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
        """Convert a waveform index to a feature index, rounding up.

        The index is first rescaled to ``self.sample_rate`` and rounded up,
        then divided by ``hop_length`` and rounded up again before the
        integer cast.

        Args:
            wav_idx: Waveform index, as a number or a tensor.
            sample_rate: Sample rate that ``wav_idx`` refers to; defaults to
                ``self.sample_rate`` when ``None``.
        """
        # Tensors need the torch rounding function; plain numbers use math.
        round_up = torch.ceil if torch.is_tensor(wav_idx) else math.ceil
        wav_rate = self.sample_rate if sample_rate is None else sample_rate
        resampled_length = round_up(self.sample_rate * wav_idx / wav_rate)
        frame_count = round_up(resampled_length / self.hop_length)
        return self.cast_to_int(frame_count)