import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel

from .config_model import SoundStreamConfig


class CausalConv1d(nn.Module):
    """1-D convolution whose input is padded on the left only.

    The input is left-padded with ``dilation * (kernel_size - 1)`` zeros
    before the convolution, so no output sample depends on input samples
    to its right.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        dilation: Convolution dilation.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        self.left_padding = dilation * (kernel_size - 1)
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(self, x):
        """Left-pad the last dimension of ``x`` and apply the convolution."""
        x = F.pad(x, (self.left_padding, 0))
        return self.conv(x)


class CausalConvTranspose1d(nn.Module):
    """Transposed 1-D convolution trimmed to ``input_length * stride``.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Transposed-convolution kernel size.
        stride: Upsampling factor.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride
        self.conv_transpose = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
        )

    def forward(self, x):
        """Upsample ``x`` and drop the trailing samples past ``T * stride``."""
        target_length = x.shape[-1] * self.stride
        x = self.conv_transpose(x)
        # The transposed convolution emits extra samples at the end; keep
        # only the first ``target_length``.
        return x[..., :target_length]


class ResidualUnit(nn.Module):
    """Residual branch: ELU -> dilated causal conv (k=7) -> ELU -> 1x1 conv.

    Args:
        channels: Channel count of both input and output.
        dilation: Dilation of the kernel-size-7 causal convolution.
    """

    def __init__(self, channels: int, dilation: int):
        super().__init__()
        self.block = nn.Sequential(
            nn.ELU(),
            CausalConv1d(
                kernel_size=7,
                in_channels=channels,
                out_channels=channels,
                dilation=dilation,
            ),
            nn.ELU(),
            nn.Conv1d(kernel_size=1, in_channels=channels, out_channels=channels),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        return x + self.block(x)


class EncoderBlock(nn.Module):
    """Three residual units followed by a strided causal convolution.

    The residual units run at ``channels // 2`` channels with dilations
    1, 3 and 9; the final convolution (kernel ``2 * s``, stride ``s``)
    produces ``channels`` output channels.

    Args:
        channels: Number of OUTPUT channels (the input has ``channels // 2``).
        s: Stride of the final convolution.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        self.block = nn.Sequential(
            ResidualUnit(channels=channels // 2, dilation=1),
            ResidualUnit(channels=channels // 2, dilation=3),
            ResidualUnit(channels=channels // 2, dilation=9),
            CausalConv1d(
                kernel_size=2 * s,
                in_channels=channels // 2,
                out_channels=channels,
                stride=s,
            ),
        )

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)


class DecoderBlock(nn.Module):
    """Strided transposed convolution followed by three residual units.

    The transposed convolution (kernel ``2 * s``, stride ``s``) maps
    ``channels`` to ``channels // 2`` channels; the residual units then run
    at ``channels // 2`` channels with dilations 1, 3 and 9.

    Args:
        channels: Number of INPUT channels (the output has ``channels // 2``).
        s: Stride of the transposed convolution.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        self.block = nn.Sequential(
            CausalConvTranspose1d(
                kernel_size=2 * s,
                in_channels=channels,
                out_channels=channels // 2,
                stride=s,
            ),
            ResidualUnit(channels=channels // 2, dilation=1),
            ResidualUnit(channels=channels // 2, dilation=3),
            ResidualUnit(channels=channels // 2, dilation=9),
        )

    def forward(self, x):
        """Apply the block to ``x``."""
        return self.block(x)


class Encoder(nn.Module):
    """Convolutional encoder from 1-channel audio to ``dim``-channel latents.

    Uses four ``EncoderBlock`` stages with strides 2, 4, 5 and 5.

    Args:
        channels: Base channel count after the first convolution.
        dim: Number of channels of the produced latent.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        self.encoder = nn.Sequential(
            CausalConv1d(kernel_size=7, in_channels=1, out_channels=channels),
            EncoderBlock(channels=2 * channels, s=2),
            EncoderBlock(channels=4 * channels, s=4),
            EncoderBlock(channels=8 * channels, s=5),
            EncoderBlock(channels=16 * channels, s=5),
            CausalConv1d(kernel_size=3, in_channels=16 * channels, out_channels=dim),
        )

    def forward(self, audio):
        """Encode ``audio`` (1 input channel) into a latent sequence."""
        return self.encoder(audio)


class Decoder(nn.Module):
    """Convolutional decoder from ``dim``-channel latents to 1-channel audio.

    Uses four ``DecoderBlock`` stages with strides 5, 5, 4 and 2.

    Args:
        channels: Base channel count before the last convolution.
        dim: Number of channels of the incoming latent.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        self.decoder = nn.Sequential(
            CausalConv1d(kernel_size=7, in_channels=dim, out_channels=16 * channels),
            DecoderBlock(channels=16 * channels, s=5),
            DecoderBlock(channels=8 * channels, s=5),
            DecoderBlock(channels=4 * channels, s=4),
            DecoderBlock(channels=2 * channels, s=2),
            CausalConv1d(kernel_size=7, in_channels=channels, out_channels=1),
        )

    def forward(self, quantized):
        """Decode a latent sequence into a 1-channel signal."""
        return self.decoder(quantized)


def _squared_distances(points, centers):
    """Return the ``(N, K)`` matrix of squared Euclidean distances.

    Args:
        points: Tensor of shape ``(N, D)``.
        centers: Tensor of shape ``(K, D)``.
    """
    # ||p - c||^2 = ||p||^2 - 2 p.c + ||c||^2, evaluated without building
    # the (N, K, D) difference tensor.
    return (
        points.pow(2).sum(1, keepdim=True)
        - 2 * points @ centers.t()
        + centers.pow(2).sum(1)
    )


@torch.no_grad()
def _k_means(vectors, num_clusters, num_iters):
    """Run Lloyd's k-means on the rows of ``vectors``.

    Args:
        vectors: Tensor of shape ``(n, D)``.
        num_clusters: Number of centroids to fit.
        num_iters: Number of assignment/update iterations.

    Returns:
        ``(centroids, counts)`` where ``centroids`` has shape
        ``(num_clusters, D)`` and ``counts`` holds, per centroid, the number
        of rows of ``vectors`` assigned to it after the last update.
    """
    n = vectors.size(0)
    device = vectors.device
    if n >= num_clusters:
        # Enough points: seed with distinct rows.
        init_indices = torch.randperm(n, device=device)[:num_clusters]
    else:
        # Fewer points than clusters: seed by sampling with replacement.
        init_indices = torch.randint(0, n, (num_clusters,), device=device)
    centroids = vectors[init_indices].clone()

    for _ in range(num_iters):
        assignments = _squared_distances(vectors, centroids).argmin(1)
        counts = torch.bincount(assignments, minlength=num_clusters).to(vectors.dtype)
        sums = torch.zeros_like(centroids)
        sums.index_add_(0, assignments, vectors)

        non_empty = counts > 0
        if non_empty.any():
            centroids[non_empty] = sums[non_empty] / counts[non_empty].unsqueeze(1)
        empty = ~non_empty
        if empty.any():
            # Re-seed clusters that received no points with random rows.
            centroids[empty] = vectors[
                torch.randint(0, n, (int(empty.sum()),), device=device)
            ]

    # Final assignment against the final centroids to report cluster sizes.
    final_assignments = _squared_distances(vectors, centroids).argmin(1)
    counts = torch.bincount(final_assignments, minlength=num_clusters).to(vectors.dtype)
    return centroids, counts


class VectorQuantizer(nn.Module):
    """Vector quantizer with an EMA-updated codebook held in buffers.

    The codebook (``embedding``) is not trained by gradient descent: in
    training mode it is recomputed as ``ema_s / ema_n`` from exponential
    moving averages of the per-code assignment counts (``ema_n``) and
    per-code sums of assigned vectors (``ema_s``). On the first training
    call the codebook is initialised with k-means on that batch.

    Args:
        codebook_size: Number of codes.
        latent_dim: Dimensionality of each code.
        decay: EMA decay applied to ``ema_n`` and ``ema_s``.
        dead_code_threshold: Codes whose ``ema_n`` falls below this value
            are re-seeded from the current batch.
        kmeans_iters: Number of k-means iterations used for initialisation.
    """

    def __init__(
        self,
        codebook_size: int,
        latent_dim: int,
        decay: float = 0.99,
        dead_code_threshold: float = 2.0,
        kmeans_iters: int = 50,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.decay = decay
        self.dead_code_threshold = dead_code_threshold
        self.kmeans_iters = kmeans_iters
        self.eps = 1e-8
        self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))
        self.register_buffer("embedding", torch.randn(codebook_size, latent_dim))
        self.register_buffer("ema_n", torch.zeros(codebook_size))
        self.register_buffer("ema_s", torch.zeros(codebook_size, latent_dim))

    def forward(self, latent):
        """Quantize ``latent`` and, in training mode, update the codebook.

        Args:
            latent: Tensor of shape ``(B, D, T)``.

        Returns:
            Dict with ``"quantized"`` (``(B, D, T)``, straight-through:
            values of the selected codes, gradient of the identity w.r.t.
            ``latent``), ``"indices"`` (``(B, T)``) and
            ``"commitment_loss"`` (scalar MSE between ``latent`` and the
            detached selected codes).
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)

        if self.training and not self.initialized:
            self._init_codebook(flat)
            self.initialized.fill_(True)

        # Indexing the codebook copies the rows, so the in-place codebook
        # updates below do not alter ``quantized``.
        idx, quantized = self._nearest(flat)

        if self.training:
            self._update_ema(flat, idx)
            self._replace_dead_codes(flat)

        commit_loss = F.mse_loss(flat, quantized.detach())
        # Straight-through estimator.
        quantized_ste = flat + (quantized - flat).detach()
        return {
            "quantized": quantized_ste.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
            "commitment_loss": commit_loss,
        }

    def _nearest(self, flat):
        """Return ``(indices, codes)`` of the nearest code for each row."""
        idx = _squared_distances(flat, self.embedding).argmin(1)
        return idx, self.embedding[idx]

    @torch.no_grad()
    def _init_codebook(self, flat):
        """Initialise the codebook and EMA statistics with k-means."""
        # Work in the buffer dtype so reduced-precision activations (e.g.
        # under autocast) do not degrade the clustering; no-op otherwise.
        flat = flat.to(self.embedding.dtype)
        centroids, counts = _k_means(flat, self.codebook_size, self.kmeans_iters)
        counts = counts.clamp_min(1.0)
        # Relative cluster sizes (mean 1), so ``ema_s / ema_n == centroids``.
        w = counts / counts.mean()
        self.embedding.copy_(centroids)
        self.ema_n.copy_(w)
        self.ema_s.copy_(centroids * w.unsqueeze(1))

    @torch.no_grad()
    def _update_ema(self, flat, indices):
        """Fold the batch assignments into the EMA statistics and codebook."""
        # ``index_add_`` requires matching dtypes; the buffers may be wider
        # than the activations (e.g. float32 buffers, half activations).
        flat = flat.to(self.ema_s.dtype)
        bins = torch.bincount(indices, minlength=self.codebook_size).to(self.ema_n.dtype)
        sums = torch.zeros_like(self.ema_s)
        sums.index_add_(0, indices, flat)
        self.ema_n.mul_(self.decay).add_(bins, alpha=1 - self.decay)
        self.ema_s.mul_(self.decay).add_(sums, alpha=1 - self.decay)
        self.embedding.copy_(self.ema_s / (self.ema_n.unsqueeze(1) + self.eps))

    @torch.no_grad()
    def _replace_dead_codes(self, flat):
        """Re-seed codes whose EMA count is below ``dead_code_threshold``."""
        dead = self.ema_n < self.dead_code_threshold
        if not dead.any():
            return
        n = int(dead.sum())
        picks = flat[torch.randint(0, flat.size(0), (n,), device=flat.device)]
        picks = picks.to(self.embedding.dtype)
        self.embedding[dead] = picks
        # The codebook is always recomputed as ``ema_s / ema_n``; scale the
        # sum by the count assigned below so a re-seeded code keeps its
        # value on the next EMA update instead of shrinking by
        # 1 / dead_code_threshold.
        self.ema_s[dead] = picks * self.dead_code_threshold
        self.ema_n[dead] = self.dead_code_threshold

    @torch.no_grad()
    def quantize(self, latent):
        """Quantize ``latent`` of shape ``(B, D, T)`` without any update.

        Returns:
            Dict with ``"quantized"`` (``(B, D, T)``) and ``"indices"``
            (``(B, T)``).
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)
        idx, quantized = self._nearest(flat)
        return {
            "quantized": quantized.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
        }

    def decode_indices(self, indices):
        """Look up codes for ``indices`` and swap the last two dimensions."""
        return self.embedding[indices].transpose(1, 2).contiguous()


class ResidualVectorQuantizer(nn.Module):
    """Cascade of ``VectorQuantizer`` stages, each quantizing the residual.

    Args:
        latent_dim: Dimensionality of the latent vectors.
        num_quantizers: Number of cascaded stages.
        codebook_size: Number of codes per stage.
    """

    def __init__(
        self,
        latent_dim: int,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.quantizers = nn.ModuleList(
            VectorQuantizer(codebook_size, latent_dim) for _ in range(num_quantizers)
        )

    def forward(self, latent):
        """Quantize ``latent`` with every stage and sum the results.

        Returns:
            Dict with ``"quantized"`` (same shape as ``latent``),
            ``"indices"`` (per-stage indices stacked along dim 1) and
            ``"commitment_loss"`` (sum of the per-stage losses).
        """
        residual = latent
        quantized = torch.zeros_like(latent)
        total_commit = latent.new_zeros(())
        all_indices = []
        for vq in self.quantizers:
            out = vq(residual)
            quantized = quantized + out["quantized"]
            # The next stage sees what is left after removing this stage's
            # (detached) output.
            residual = residual - out["quantized"].detach()
            total_commit = total_commit + out["commitment_loss"]
            all_indices.append(out["indices"])
        return {
            "quantized": quantized,
            "indices": torch.stack(all_indices, dim=1),
            "commitment_loss": total_commit,
        }

    @torch.no_grad()
    def encode(self, latent):
        """Return the per-stage code indices stacked along dim 1."""
        residual = latent
        all_indices = []
        for vq in self.quantizers:
            out = vq.quantize(residual)
            all_indices.append(out["indices"])
            residual = residual - out["quantized"]
        return torch.stack(all_indices, dim=1)

    @torch.no_grad()
    def decode(self, indices):
        """Sum the codes selected by ``indices[:, i]`` over all stages."""
        quantized = None
        for i, vq in enumerate(self.quantizers):
            stage = vq.decode_indices(indices[:, i])
            quantized = stage if quantized is None else quantized + stage
        return quantized


class SoundStreamCodec(PreTrainedModel):
    """Encoder -> residual vector quantizer -> decoder audio codec.

    Reads ``channels``, ``latent_dim``, ``codebook_size`` and
    ``num_quantizers`` from ``config``.
    """

    config_class = SoundStreamConfig

    def __init__(self, config):
        super().__init__(config)
        # Must match the strides hard-coded in ``Encoder``.
        self.strides = (2, 4, 5, 5)
        self.downsampling_factor = 1
        for s in self.strides:
            self.downsampling_factor *= s
        self.encoder = Encoder(channels=config.channels, dim=config.latent_dim)
        self.quantizer = ResidualVectorQuantizer(
            latent_dim=config.latent_dim,
            codebook_size=config.codebook_size,
            num_quantizers=config.num_quantizers,
        )
        self.decoder = Decoder(channels=config.channels, dim=config.latent_dim)
        self.post_init()

    def forward(self, audio, **kwargs):
        """Encode, quantize and reconstruct ``audio``.

        The input is right-padded to a multiple of the downsampling factor
        and the reconstruction is cropped back to the original length.
        Extra keyword arguments are accepted and ignored.

        Returns:
            Dict with ``"reconstructed_audio"``, ``"latent"`` and every key
            returned by the quantizer (``"quantized"``, ``"indices"``,
            ``"commitment_loss"``).
        """
        original_length = audio.size(-1)
        audio = self._pad_to_stride(audio)
        latent = self.encoder(audio)
        q_out = self.quantizer(latent)
        reconstructed = self.decoder(q_out["quantized"])
        reconstructed = reconstructed[..., :original_length]
        return {
            "reconstructed_audio": reconstructed,
            "latent": latent,
            **q_out,
        }

    @torch.no_grad()
    def encode(self, audio):
        """Return the stacked code indices for ``audio``."""
        audio = self._pad_to_stride(audio)
        return self.quantizer.encode(self.encoder(audio))

    @torch.no_grad()
    def decode(self, indices, original_length=None):
        """Reconstruct audio from code indices.

        Args:
            indices: Code indices as returned by ``encode``.
            original_length: If given, crop the output to this many samples.
        """
        out = self.decoder(self.quantizer.decode(indices))
        if original_length is not None:
            out = out[..., :original_length]
        return out

    def _pad_to_stride(self, audio):
        """Right-pad ``audio`` to a multiple of ``downsampling_factor``."""
        remainder = audio.size(-1) % self.downsampling_factor
        if remainder == 0:
            return audio
        return F.pad(
            audio, (0, self.downsampling_factor - remainder), mode="replicate"
        )