# prithivMLmods's picture
# update [.sam_audio]
# 70af406 verified
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved
import math
from abc import ABCMeta, abstractmethod
from typing import Union
import dacvae
import torch
from sam_audio.model.config import DACVAEConfig
class Encoder(torch.nn.Module, metaclass=ABCMeta):
    """Abstract base for audio encoders: maps a waveform tensor to a tensor
    of encoded (latent) frames."""

    @abstractmethod
    def forward(self, waveform: torch.Tensor) -> torch.Tensor: ...
class Codec(Encoder):
    """An :class:`Encoder` that can additionally decode latent frames back to
    audio and translate between waveform sample indices and feature-frame
    indices (in either direction, optionally at a caller-supplied rate)."""

    @abstractmethod
    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor: ...

    @abstractmethod
    def wav_idx_to_feature_idx(
        self, wav_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]: ...

    @abstractmethod
    def feature_idx_to_wav_idx(
        self, feature_idx: Union[torch.Tensor, int], sample_rate=None
    ) -> Union[torch.Tensor, int]: ...

    @staticmethod
    def cast_to_int(x: Union[int, torch.Tensor]) -> Union[int, torch.Tensor]:
        """Truncate ``x`` to an integer, preserving tensor-ness of the input."""
        return x.int() if isinstance(x, torch.Tensor) else int(x)
class DACVAEEncoder(Encoder):
    """Encoder wrapper around a ``dacvae.DACVAE`` model.

    Builds the underlying model from ``config``, keeps only the sub-modules
    needed for encoding, and exposes :meth:`forward`, which returns the VAE
    posterior mean for a (right-padded) input waveform.
    """

    def __init__(self, config: DACVAEConfig) -> None:
        super().__init__()
        self.hop_length = config.hop_length
        self.sample_rate = config.sample_rate
        full_model = dacvae.DACVAE(
            encoder_dim=config.encoder_dim,
            encoder_rates=config.encoder_rates,
            latent_dim=config.latent_dim,
            decoder_dim=config.decoder_dim,
            decoder_rates=config.decoder_rates,
            n_codebooks=config.n_codebooks,
            codebook_size=config.codebook_size,
            codebook_dim=config.codebook_dim,
            quantizer_dropout=config.quantizer_dropout,
            sample_rate=config.sample_rate,
        ).eval()
        self._setup_model(full_model)

    def _setup_model(self, model):
        # Retain only the pieces needed for encoding; the rest of the model
        # (e.g. the decoder) is dropped by this base class.
        self.encoder = model.encoder
        self.quantizer = model.quantizer

    def forward(self, waveform: torch.Tensor) -> torch.Tensor:
        """Encode ``waveform`` into latent frames (no gradients tracked)."""
        with torch.no_grad():
            latent = self.encoder(self._pad(waveform))
            # in_proj produces [mean, scale] stacked along the channel dim;
            # only the mean is used as the deterministic encoding.
            mean, _scale = self.quantizer.in_proj(latent).chunk(2, dim=1)
        return mean

    def _pad(self, wavs: torch.Tensor) -> torch.Tensor:
        """Right-pad ``wavs`` (reflect mode) to a multiple of ``hop_length``."""
        remainder = wavs.size(-1) % self.hop_length
        if not remainder:
            return wavs
        return torch.nn.functional.pad(
            wavs, (0, self.hop_length - remainder), "reflect"
        )
class DACVAE(DACVAEEncoder, Codec):
    """Full DACVAE codec: encoding is inherited from :class:`DACVAEEncoder`;
    this subclass adds decoding and waveform/feature index conversion."""

    def _setup_model(self, model):
        # Keep the decoder as well, so latents can be turned back into audio.
        super()._setup_model(model)
        self.decoder = model.decoder

    def decode(self, encoded_frames: torch.Tensor) -> torch.Tensor:
        """Map latent frames back to a waveform.

        NOTE(review): unlike ``forward``, this does not run under
        ``torch.no_grad()`` — confirm whether gradient tracking here is
        intentional.
        """
        return self.decoder(self.quantizer.out_proj(encoded_frames))

    def feature_idx_to_wav_idx(self, feature_idx, sample_rate=None):
        """Convert a feature-frame index into a waveform sample index at
        ``sample_rate`` (defaults to the model's native rate, in which case
        the rate ratio is 1)."""
        target_rate = self.sample_rate if sample_rate is None else sample_rate
        samples = feature_idx * self.hop_length * (target_rate / self.sample_rate)
        return self.cast_to_int(samples)

    def wav_idx_to_feature_idx(self, wav_idx, sample_rate=None):
        """Convert a waveform sample index (given at ``sample_rate``) into the
        number of feature frames needed to cover it (ceiling division)."""
        source_rate = self.sample_rate if sample_rate is None else sample_rate
        # torch.ceil for tensors, math.ceil for plain ints.
        ceil_fn = torch.ceil if torch.is_tensor(wav_idx) else math.ceil
        native_samples = ceil_fn(self.sample_rate * wav_idx / source_rate)
        return self.cast_to_int(ceil_fn(native_samples / self.hop_length))