Instructions to use timofeiiz/soundstream-impl with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use timofeiiz/soundstream-impl with Transformers:
# Load model directly
from transformers import AutoModel
model = AutoModel.from_pretrained("timofeiiz/soundstream-impl", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
| import torch | |
| import torch.nn.functional as F | |
| from torch import nn | |
| from transformers import PreTrainedModel | |
| from .config_model import SoundStreamConfig | |
class CausalConv1d(nn.Module):
    """1-D convolution that pads only on the left of the last axis.

    The input is zero-padded with ``dilation * (kernel_size - 1)`` samples
    before the convolution, so an output frame never depends on input
    samples located after it.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        # Number of past samples one kernel window spans beyond its last tap.
        self.left_padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(self, x):
        """Left-pad the last axis of ``x`` with zeros, then convolve."""
        padded = F.pad(x, (self.left_padding, 0))
        return self.conv(padded)
class CausalConvTranspose1d(nn.Module):
    """Transposed 1-D convolution cropped to ``input_length * stride`` samples.

    The transposed convolution produces extra trailing samples; only the
    leading ``input_length * stride`` samples of its output are returned.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride
        self.conv_transpose = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
        )

    def forward(self, x):
        """Upsample ``x`` along its last axis and drop the trailing overhang."""
        upsampled = self.conv_transpose(x)
        keep = self.stride * x.shape[-1]
        return upsampled[..., :keep]
class ResidualUnit(nn.Module):
    """Residual branch: ELU, dilated causal conv (k=7), ELU, pointwise conv (k=1).

    The branch keeps the channel count unchanged and its output is added
    back onto the input.
    """

    def __init__(self, channels: int, dilation: int):
        super().__init__()
        dilated_conv = CausalConv1d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=7,
            dilation=dilation,
        )
        pointwise_conv = nn.Conv1d(channels, channels, kernel_size=1)
        self.block = nn.Sequential(nn.ELU(), dilated_conv, nn.ELU(), pointwise_conv)

    def forward(self, x):
        """Return ``x`` plus the output of the residual branch."""
        branch = self.block(x)
        return x + branch
class EncoderBlock(nn.Module):
    """Three residual units (dilations 1, 3, 9) followed by a strided causal conv.

    The residual units operate on ``channels // 2`` channels; the final
    convolution (kernel ``2 * s``, stride ``s``) outputs ``channels`` channels.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        half = channels // 2
        layers = [ResidualUnit(channels=half, dilation=d) for d in (1, 3, 9)]
        layers.append(
            CausalConv1d(
                in_channels=half,
                out_channels=channels,
                kernel_size=2 * s,
                stride=s,
            )
        )
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the residual stack and the strided convolution."""
        return self.block(x)
class DecoderBlock(nn.Module):
    """Strided causal transposed conv followed by three residual units.

    The transposed convolution (kernel ``2 * s``, stride ``s``) maps
    ``channels`` to ``channels // 2`` channels; the residual units
    (dilations 1, 3, 9) then operate on ``channels // 2`` channels.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        half = channels // 2
        upsample = CausalConvTranspose1d(
            in_channels=channels,
            out_channels=half,
            kernel_size=2 * s,
            stride=s,
        )
        residual_stack = [ResidualUnit(channels=half, dilation=d) for d in (1, 3, 9)]
        self.block = nn.Sequential(upsample, *residual_stack)

    def forward(self, x):
        """Run ``x`` through the transposed convolution and the residual stack."""
        return self.block(x)
class Encoder(nn.Module):
    """Convolutional encoder from single-channel audio to ``dim``-channel latents.

    Layout: causal conv (k=7) to ``channels`` channels, four ``EncoderBlock``
    stages with output widths 2x/4x/8x/16x ``channels`` and strides 2, 4, 5, 5,
    then a causal conv (k=3) projecting to ``dim`` channels.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        # (output-width multiplier, stride) for each encoder stage, in order.
        stage_specs = ((2, 2), (4, 4), (8, 5), (16, 5))

        layers = [CausalConv1d(in_channels=1, out_channels=channels, kernel_size=7)]
        for multiplier, stride in stage_specs:
            layers.append(EncoderBlock(channels=multiplier * channels, s=stride))
        layers.append(
            CausalConv1d(in_channels=16 * channels, out_channels=dim, kernel_size=3)
        )
        self.encoder = nn.Sequential(*layers)

    def forward(self, audio):
        """Encode ``audio`` into the latent sequence."""
        return self.encoder(audio)
class Decoder(nn.Module):
    """Convolutional decoder from ``dim``-channel latents to single-channel audio.

    Layout: causal conv (k=7) to ``16 * channels`` channels, four
    ``DecoderBlock`` stages with input widths 16x/8x/4x/2x ``channels`` and
    strides 5, 5, 4, 2, then a causal conv (k=7) down to one channel.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        # (input-width multiplier, stride) for each decoder stage, in order.
        stage_specs = ((16, 5), (8, 5), (4, 4), (2, 2))

        layers = [
            CausalConv1d(in_channels=dim, out_channels=16 * channels, kernel_size=7)
        ]
        for multiplier, stride in stage_specs:
            layers.append(DecoderBlock(channels=multiplier * channels, s=stride))
        layers.append(CausalConv1d(in_channels=channels, out_channels=1, kernel_size=7))
        self.decoder = nn.Sequential(*layers)

    def forward(self, quantized):
        """Decode the ``quantized`` latent sequence into a waveform."""
        return self.decoder(quantized)
| def _k_means(vectors, num_clusters, num_iters): | |
| n = vectors.size(0) | |
| device = vectors.device | |
| if n >= num_clusters: | |
| init_indices = torch.randperm(n, device=device)[:num_clusters] | |
| else: | |
| init_indices = torch.randint(0, n, (num_clusters,), device=device) | |
| centroids = vectors[init_indices].clone() | |
| for _ in range(num_iters): | |
| dists = ( | |
| vectors.pow(2).sum(1, keepdim=True) | |
| - 2 * vectors @ centroids.t() | |
| + centroids.pow(2).sum(1) | |
| ) | |
| assignments = dists.argmin(1) | |
| counts = torch.bincount(assignments, minlength=num_clusters).to(vectors.dtype) | |
| sums = torch.zeros_like(centroids) | |
| sums.index_add_(0, assignments, vectors) | |
| non_empty = counts > 0 | |
| if non_empty.any(): | |
| centroids[non_empty] = sums[non_empty] / counts[non_empty].unsqueeze(1) | |
| empty = ~non_empty | |
| if empty.any(): | |
| centroids[empty] = vectors[torch.randint(0, n, (int(empty.sum()),), device=device)] | |
| dists = ( | |
| vectors.pow(2).sum(1, keepdim=True) | |
| - 2 * vectors @ centroids.t() | |
| + centroids.pow(2).sum(1) | |
| ) | |
| counts = torch.bincount(dists.argmin(1), minlength=num_clusters).to(vectors.dtype) | |
| return centroids, counts | |
class VectorQuantizer(nn.Module):
    """Single-codebook vector quantizer with EMA codebook updates.

    The codebook and its running statistics are stored as buffers (not
    parameters): ``embedding`` holds the code vectors, ``ema_n`` the running
    per-code usage and ``ema_s`` the running per-code sum of assigned vectors.
    In training mode the codebook is initialised with k-means on the first
    batch, updated by exponential moving average on every forward pass, and
    codes whose running usage drops below ``dead_code_threshold`` are re-seeded
    from random vectors of the current batch.

    Args:
        codebook_size: Number of code vectors.
        latent_dim: Dimensionality of each code vector.
        decay: EMA decay applied to ``ema_n`` and ``ema_s``.
        dead_code_threshold: Codes with ``ema_n`` below this value are replaced.
        kmeans_iters: Number of k-means iterations used for initialisation.
    """

    def __init__(
        self,
        codebook_size: int,
        latent_dim: int,
        decay: float = 0.99,
        dead_code_threshold: float = 2.0,
        kmeans_iters: int = 50,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.decay = decay
        self.dead_code_threshold = dead_code_threshold
        self.kmeans_iters = kmeans_iters
        self.eps = 1e-8
        self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))
        self.register_buffer("embedding", torch.randn(codebook_size, latent_dim))
        self.register_buffer("ema_n", torch.zeros(codebook_size))
        self.register_buffer("ema_s", torch.zeros(codebook_size, latent_dim))

    def forward(self, latent):
        """Quantize ``latent`` of shape ``(B, D, T)``.

        Returns:
            A dict with ``"quantized"`` (shape ``(B, D, T)``, straight-through
            estimator: forward value is the selected code, gradient flows to
            ``latent``), ``"indices"`` (shape ``(B, T)``) and
            ``"commitment_loss"`` (MSE between ``latent`` and its detached
            codes).
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)

        if self.training and not self.initialized:
            self._init_codebook(flat)
            self.initialized.fill_(True)

        idx, quantized = self._nearest(flat)

        if self.training:
            self._update_ema(flat, idx)
            self._replace_dead_codes(flat)

        commit_loss = F.mse_loss(flat, quantized.detach())
        # Straight-through estimator: value of `quantized`, gradient of `flat`.
        quantized_ste = flat + (quantized - flat).detach()
        return {
            "quantized": quantized_ste.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
            "commitment_loss": commit_loss,
        }

    # NOTE: every helper below that reads activations or writes buffers runs
    # under ``torch.no_grad()``. Without it, in-place buffer updates fed by
    # activations that require grad would turn the buffers into non-leaf
    # tensors whose autograd graph grows on every training step.

    @torch.no_grad()
    def _nearest(self, flat):
        """Return ``(indices, codes)`` of the nearest code for each row of ``flat``."""
        # Squared Euclidean distance expanded as |x|^2 - 2 x.e + |e|^2.
        dists = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ self.embedding.t()
            + self.embedding.pow(2).sum(1)
        )
        idx = dists.argmin(1)
        return idx, self.embedding[idx]

    @torch.no_grad()
    def _init_codebook(self, flat):
        """Initialise the codebook and EMA statistics from k-means on ``flat``."""
        centroids, counts = _k_means(flat, self.codebook_size, self.kmeans_iters)
        counts = counts.clamp_min(1.0)
        # Usage weights normalised to mean 1; ema_s is kept equal to
        # centroid * ema_n so that ema_s / ema_n reproduces the centroid.
        w = counts / counts.mean()
        self.embedding.copy_(centroids)
        self.ema_n.copy_(w)
        self.ema_s.copy_(centroids * w.unsqueeze(1))

    @torch.no_grad()
    def _update_ema(self, flat, indices):
        """Fold the current batch assignment into the EMA statistics and codebook."""
        bins = torch.bincount(indices, minlength=self.codebook_size).to(flat.dtype)
        sums = torch.zeros_like(self.ema_s)
        sums.index_add_(0, indices, flat)
        self.ema_n.mul_(self.decay).add_(bins, alpha=1 - self.decay)
        self.ema_s.mul_(self.decay).add_(sums, alpha=1 - self.decay)
        self.embedding.copy_(self.ema_s / (self.ema_n.unsqueeze(1) + self.eps))

    @torch.no_grad()
    def _replace_dead_codes(self, flat):
        """Re-seed codes whose running usage fell below ``dead_code_threshold``."""
        dead = self.ema_n < self.dead_code_threshold
        if not dead.any():
            return
        n = int(dead.sum())
        picks = flat[torch.randint(0, flat.size(0), (n,), device=flat.device)]
        self.embedding[dead] = picks
        # Keep ema_s / ema_n equal to the new code. Storing the raw pick while
        # resetting ema_n to the threshold would make the next EMA update
        # rescale the freshly seeded code by roughly 1 / dead_code_threshold.
        self.ema_s[dead] = picks * self.dead_code_threshold
        self.ema_n[dead] = self.dead_code_threshold

    def quantize(self, latent):
        """Quantize ``latent`` of shape ``(B, D, T)`` without updating the codebook.

        Returns:
            A dict with ``"quantized"`` (shape ``(B, D, T)``) and ``"indices"``
            (shape ``(B, T)``).
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)
        idx, quantized = self._nearest(flat)
        return {
            "quantized": quantized.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
        }

    def decode_indices(self, indices):
        """Look up codes for ``indices`` and swap axes 1 and 2 of the result."""
        return self.embedding[indices].transpose(1, 2).contiguous()
class ResidualVectorQuantizer(nn.Module):
    """Cascade of ``VectorQuantizer`` stages, each quantizing the previous residual.

    Args:
        latent_dim: Dimensionality of the latent vectors and of every codebook.
        num_quantizers: Number of cascaded quantizer stages.
        codebook_size: Number of codes in each stage's codebook.
    """

    def __init__(
        self,
        latent_dim: int,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        stages = [VectorQuantizer(codebook_size, latent_dim) for _ in range(num_quantizers)]
        self.quantizers = nn.ModuleList(stages)

    def forward(self, latent):
        """Quantize ``latent`` stage by stage.

        Returns:
            A dict with ``"quantized"`` (sum of all stage outputs),
            ``"indices"`` (per-stage indices stacked along dim 1) and
            ``"commitment_loss"`` (sum of the per-stage commitment losses).
        """
        remaining = latent
        accumulated = torch.zeros_like(latent)
        commit_sum = latent.new_zeros(())
        stage_indices = []

        for stage in self.quantizers:
            stage_out = stage(remaining)
            stage_quantized = stage_out["quantized"]
            accumulated = accumulated + stage_quantized
            # The next stage sees what is left after removing this stage's
            # code; the code is detached before the subtraction.
            remaining = remaining - stage_quantized.detach()
            commit_sum = commit_sum + stage_out["commitment_loss"]
            stage_indices.append(stage_out["indices"])

        return {
            "quantized": accumulated,
            "indices": torch.stack(stage_indices, dim=1),
            "commitment_loss": commit_sum,
        }

    def encode(self, latent):
        """Return the per-stage code indices for ``latent``, stacked along dim 1."""
        remaining = latent
        stage_indices = []
        for stage in self.quantizers:
            stage_out = stage.quantize(remaining)
            stage_indices.append(stage_out["indices"])
            remaining = remaining - stage_out["quantized"]
        return torch.stack(stage_indices, dim=1)

    def decode(self, indices):
        """Sum the code vectors selected by ``indices[:, i]`` over all stages ``i``.

        Returns ``None`` when the module holds no quantizer stages.
        """
        total = None
        for stage_no, stage in enumerate(self.quantizers):
            contribution = stage.decode_indices(indices[:, stage_no])
            if total is None:
                total = contribution
                continue
            total = total + contribution
        return total
class SoundStreamCodec(PreTrainedModel):
    """Encoder, residual vector quantizer and decoder wrapped as a ``PreTrainedModel``.

    The input is padded on the right so that its length is a multiple of the
    product of the encoder strides before it is encoded.
    """

    config_class = SoundStreamConfig

    def __init__(self, config):
        super().__init__(config)
        self.strides = (2, 4, 5, 5)
        # Total downsampling factor is the product of the strides.
        factor = 1
        for stride in self.strides:
            factor *= stride
        self.downsampling_factor = factor

        latent_dim = config.latent_dim
        self.encoder = Encoder(channels=config.channels, dim=latent_dim)
        self.quantizer = ResidualVectorQuantizer(
            latent_dim=latent_dim,
            codebook_size=config.codebook_size,
            num_quantizers=config.num_quantizers,
        )
        self.decoder = Decoder(channels=config.channels, dim=latent_dim)
        self.post_init()

    def forward(self, audio, **kwargs):
        """Encode, quantize and reconstruct ``audio``.

        Returns:
            A dict with ``"reconstructed_audio"`` (cropped to the input
            length along the last axis), ``"latent"`` (encoder output) and
            every entry of the quantizer's output dict.
        """
        n_samples = audio.size(-1)
        padded = self._pad_to_stride(audio)
        latent = self.encoder(padded)
        q_out = self.quantizer(latent)
        decoded = self.decoder(q_out["quantized"])

        result = {
            "reconstructed_audio": decoded[..., :n_samples],
            "latent": latent,
        }
        result.update(q_out)
        return result

    def encode(self, audio):
        """Return the quantizer code indices for ``audio``."""
        latent = self.encoder(self._pad_to_stride(audio))
        return self.quantizer.encode(latent)

    def decode(self, indices, original_length=None):
        """Reconstruct audio from code ``indices``.

        When ``original_length`` is given, the output is cropped to that many
        samples along the last axis.
        """
        waveform = self.decoder(self.quantizer.decode(indices))
        if original_length is None:
            return waveform
        return waveform[..., :original_length]

    def _pad_to_stride(self, audio):
        """Right-pad ``audio`` (replicating the edge) to a multiple of the stride product."""
        overhang = audio.size(-1) % self.downsampling_factor
        if overhang:
            missing = self.downsampling_factor - overhang
            return F.pad(audio, (0, missing), mode="replicate")
        return audio