# soundstream-impl / model.py
# Uploaded to the Hugging Face Hub by timofeiiz with huggingface_hub (commit ad0bcd5, verified).
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PreTrainedModel
from .config_model import SoundStreamConfig
class CausalConv1d(nn.Module):
    """1-D convolution that only sees the current and past time steps.

    The input is left-padded with ``dilation * (kernel_size - 1)`` zeros and
    then fed to an unpadded ``nn.Conv1d``. Output frame ``t`` therefore
    depends only on input frames ``<= t * stride``. With ``stride=1`` the time
    length is preserved.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
    ):
        super().__init__()
        # Span of the dilated kernel minus the current frame.
        self.left_padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
        )

    def forward(self, x):
        padded = F.pad(x, [self.left_padding, 0])
        return self.conv(padded)
class CausalConvTranspose1d(nn.Module):
    """Transposed 1-D convolution trimmed to at most ``stride`` x the input length.

    An unpadded ``nn.ConvTranspose1d`` emits ``(T - 1) * stride + kernel_size``
    frames. Only the first ``T * stride`` are kept; the trailing frames are
    dropped, so output frame ``t`` depends only on input frames ``i`` with
    ``i * stride <= t``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride
        self.conv_transpose = nn.ConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
        )

    def forward(self, x):
        upsampled = self.conv_transpose(x)
        keep = self.stride * x.size(-1)
        return upsampled[..., :keep]
class ResidualUnit(nn.Module):
    """Dilated residual unit computing ``x + Conv1x1(ELU(CausalConv7(ELU(x))))``.

    Both convolutions keep the channel count and (stride 1, causal padding)
    the time length, so the skip connection is a plain addition.
    """

    def __init__(self, channels: int, dilation: int):
        super().__init__()
        layers = [
            nn.ELU(),
            CausalConv1d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=7,
                dilation=dilation,
            ),
            nn.ELU(),
            nn.Conv1d(in_channels=channels, out_channels=channels, kernel_size=1),
        ]
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        branch = self.block(x)
        return x + branch
class EncoderBlock(nn.Module):
    """Three dilated residual units followed by a strided causal downsampling conv.

    Consumes ``channels // 2`` channels and produces ``channels`` channels.
    The final conv (kernel ``2 * s``, stride ``s``) downsamples time by ``s``.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        half = channels // 2
        residual_units = [ResidualUnit(channels=half, dilation=d) for d in (1, 3, 9)]
        downsample = CausalConv1d(
            in_channels=half,
            out_channels=channels,
            kernel_size=2 * s,
            stride=s,
        )
        self.block = nn.Sequential(*residual_units, downsample)

    def forward(self, x):
        return self.block(x)
class DecoderBlock(nn.Module):
    """Strided causal upsampling followed by three dilated residual units.

    Consumes ``channels`` channels and produces ``channels // 2`` channels.
    The transposed conv (kernel ``2 * s``, stride ``s``) upsamples time by
    exactly ``s``.
    """

    def __init__(self, channels: int, s: int):
        super().__init__()
        out_channels = channels // 2
        upsample = CausalConvTranspose1d(
            in_channels=channels,
            out_channels=out_channels,
            kernel_size=2 * s,
            stride=s,
        )
        self.block = nn.Sequential(
            upsample,
            *[ResidualUnit(channels=out_channels, dilation=d) for d in (1, 3, 9)],
        )

    def forward(self, x):
        return self.block(x)
class Encoder(nn.Module):
    """Causal convolutional encoder from a one-channel waveform to latent frames.

    The stem conv takes a single input channel, so ``audio`` is presumably
    ``(batch, 1, samples)`` (TODO confirm with callers). Four encoder blocks
    (strides 2, 4, 5, 5; total stride 200) double the channel count each time.
    A final causal conv then projects to ``dim`` channels.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        # (channel multiplier, stride) of each downsampling stage.
        stages = ((2, 2), (4, 4), (8, 5), (16, 5))
        self.encoder = nn.Sequential(
            CausalConv1d(in_channels=1, out_channels=channels, kernel_size=7),
            *[EncoderBlock(channels=mult * channels, s=stride) for mult, stride in stages],
            CausalConv1d(in_channels=16 * channels, out_channels=dim, kernel_size=3),
        )

    def forward(self, audio):
        return self.encoder(audio)
class Decoder(nn.Module):
    """Causal convolutional decoder from latent frames back to a one-channel waveform.

    Mirrors ``Encoder``. A 7-tap causal conv lifts ``dim`` to ``16 * channels``.
    Four decoder blocks (strides 5, 5, 4, 2) halve the channel count each time.
    A final 7-tap causal conv emits a single output channel.
    """

    def __init__(self, channels: int = 16, dim: int = 512):
        super().__init__()
        # (channel multiplier, stride) of each upsampling stage.
        stages = ((16, 5), (8, 5), (4, 4), (2, 2))
        self.decoder = nn.Sequential(
            CausalConv1d(in_channels=dim, out_channels=16 * channels, kernel_size=7),
            *[DecoderBlock(channels=mult * channels, s=stride) for mult, stride in stages],
            CausalConv1d(in_channels=channels, out_channels=1, kernel_size=7),
        )

    def forward(self, quantized):
        return self.decoder(quantized)
@torch.no_grad()
def _k_means(vectors, num_clusters, num_iters):
    """Run Lloyd's k-means on the rows of ``vectors``.

    Args:
        vectors: 2-D tensor ``(n, d)`` of points.
        num_clusters: number of centroids ``k``.
        num_iters: number of assignment/update rounds.

    Returns:
        ``(centroids, counts)``. ``centroids`` is a ``(k, d)`` tensor.
        ``counts`` is a length-``k`` tensor (dtype of ``vectors``) giving how
        many points are nearest to each final centroid.

    Initial centroids are drawn without replacement when ``n >= k`` and with
    replacement otherwise. A cluster left empty after an update is re-seeded
    with a randomly chosen point.
    """
    num_points = vectors.size(0)
    dev = vectors.device

    def squared_distances(cents):
        # ||v||^2 - 2 v.c + ||c||^2 for every (point, centroid) pair.
        return (
            vectors.pow(2).sum(1, keepdim=True)
            - 2 * vectors @ cents.t()
            + cents.pow(2).sum(1)
        )

    if num_points < num_clusters:
        seeds = torch.randint(0, num_points, (num_clusters,), device=dev)
    else:
        seeds = torch.randperm(num_points, device=dev)[:num_clusters]
    centroids = vectors[seeds].clone()

    for _ in range(num_iters):
        labels = squared_distances(centroids).argmin(1)
        sizes = torch.bincount(labels, minlength=num_clusters).to(vectors.dtype)
        totals = torch.zeros_like(centroids)
        totals.index_add_(0, labels, vectors)
        filled = sizes > 0
        if filled.any():
            centroids[filled] = totals[filled] / sizes[filled].unsqueeze(1)
        unfilled = ~filled
        if unfilled.any():
            refill = torch.randint(0, num_points, (int(unfilled.sum()),), device=dev)
            centroids[unfilled] = vectors[refill]

    final_labels = squared_distances(centroids).argmin(1)
    counts = torch.bincount(final_labels, minlength=num_clusters).to(vectors.dtype)
    return centroids, counts
class VectorQuantizer(nn.Module):
    """Single-codebook vector quantizer whose codebook is learned by EMA.

    The codebook and its statistics are buffers, not parameters:

    * ``embedding``: ``(codebook_size, latent_dim)`` code vectors.
    * ``ema_n``: decayed number of latents assigned to each code. It is in the
      same units as the per-batch ``bincount`` added in ``_update_ema``, and
      ``dead_code_threshold`` is compared against it.
    * ``ema_s``: decayed sum of the latents assigned to each code.

    ``embedding`` is kept equal to ``ema_s / ema_n``. On the first
    training-mode forward pass the codebook is initialised with k-means on
    that batch. Codes whose ``ema_n`` drops below ``dead_code_threshold`` are
    re-seeded from random latents of the current batch.

    Args:
        codebook_size: number of code vectors.
        latent_dim: size of each code vector (channel dim ``D`` of the latent).
        decay: EMA decay applied to ``ema_n`` / ``ema_s`` every training step.
        dead_code_threshold: minimum ``ema_n`` for a code to be kept.
        kmeans_iters: Lloyd iterations for the one-time codebook init.
    """

    def __init__(
        self,
        codebook_size: int,
        latent_dim: int,
        decay: float = 0.99,
        dead_code_threshold: float = 2.0,
        kmeans_iters: int = 50,
    ):
        super().__init__()
        self.codebook_size = codebook_size
        self.latent_dim = latent_dim
        self.decay = decay
        self.dead_code_threshold = dead_code_threshold
        self.kmeans_iters = kmeans_iters
        # Guards the ema_s / ema_n division for codes that were never assigned.
        self.eps = 1e-8
        # Persistent buffer: a loaded checkpoint is not re-initialised by k-means.
        self.register_buffer("initialized", torch.tensor(False, dtype=torch.bool))
        self.register_buffer("embedding", torch.randn(codebook_size, latent_dim))
        self.register_buffer("ema_n", torch.zeros(codebook_size))
        self.register_buffer("ema_s", torch.zeros(codebook_size, latent_dim))

    def forward(self, latent):
        """Quantize ``latent`` of shape ``(B, D, T)`` to its nearest codes.

        In training mode this also initialises the codebook (first call only)
        and updates it with EMA statistics and dead-code replacement.

        Returns:
            dict with these keys:

            * ``"quantized"``: ``(B, D, T)`` straight-through output.
            * ``"indices"``: ``(B, T)`` code indices.
            * ``"commitment_loss"``: MSE between the latents and their
              detached codes.
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)
        if self.training and not self.initialized:
            self._init_codebook(flat)
            self.initialized.fill_(True)
        # Codes are looked up before the EMA update, so this step's output
        # uses the pre-update codebook.
        idx, quantized = self._nearest(flat)
        if self.training:
            self._update_ema(flat, idx)
            self._replace_dead_codes(flat)
        commit_loss = F.mse_loss(flat, quantized.detach())
        # Straight-through estimator: the forward value is the code, and the
        # gradient passes to `latent` unchanged.
        quantized_ste = flat + (quantized - flat).detach()
        return {
            "quantized": quantized_ste.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
            "commitment_loss": commit_loss,
        }

    def _nearest(self, flat):
        """Return ``(indices, codes)`` of the nearest code (squared L2) for each row of ``flat``."""
        # ||x||^2 - 2 x.e + ||e||^2, without materialising pairwise differences.
        dists = (
            flat.pow(2).sum(1, keepdim=True)
            - 2 * flat @ self.embedding.t()
            + self.embedding.pow(2).sum(1)
        )
        idx = dists.argmin(1)
        # Advanced indexing copies, so later in-place codebook updates do not
        # alter the returned codes.
        return idx, self.embedding[idx]

    @torch.no_grad()
    def _init_codebook(self, flat):
        """Initialise the codebook and its EMA statistics with k-means on ``flat``."""
        centroids, counts = _k_means(flat, self.codebook_size, self.kmeans_iters)
        centroids = centroids.to(self.embedding.dtype)
        # Seed ema_n with raw cluster sizes. _update_ema adds raw per-batch
        # bincounts and dead_code_threshold is compared against ema_n. If the
        # counts were normalised to mean 1 instead, most codes would fall
        # under the threshold and be overwritten on the very first step,
        # discarding the k-means result. The clamp keeps ema_s / ema_n finite
        # for empty clusters.
        counts = counts.clamp_min(1.0).to(self.ema_n.dtype)
        self.embedding.copy_(centroids)
        self.ema_n.copy_(counts)
        self.ema_s.copy_(centroids * counts.unsqueeze(1))

    @torch.no_grad()
    def _update_ema(self, flat, indices):
        """Blend this batch's per-code counts and sums into the EMA statistics."""
        # Under autocast the latents may be lower precision than the fp32
        # buffers; index_add_ requires matching dtypes.
        flat = flat.to(self.ema_s.dtype)
        bins = torch.bincount(indices, minlength=self.codebook_size).to(flat.dtype)
        sums = torch.zeros_like(self.ema_s)
        sums.index_add_(0, indices, flat)
        self.ema_n.mul_(self.decay).add_(bins, alpha=1 - self.decay)
        self.ema_s.mul_(self.decay).add_(sums, alpha=1 - self.decay)
        self.embedding.copy_(self.ema_s / (self.ema_n.unsqueeze(1) + self.eps))

    @torch.no_grad()
    def _replace_dead_codes(self, flat):
        """Re-seed codes whose ``ema_n`` is below the threshold with random rows of ``flat``."""
        dead = self.ema_n < self.dead_code_threshold
        if not dead.any():
            return
        num_dead = int(dead.sum())
        picks = flat[torch.randint(0, flat.size(0), (num_dead,), device=flat.device)]
        picks = picks.to(self.embedding.dtype)
        self.embedding[dead] = picks
        self.ema_n[dead] = self.dead_code_threshold
        # Keep the invariant embedding == ema_s / ema_n. Storing the bare picks
        # in ema_s would make the next _update_ema shrink each re-seeded code
        # by a factor of 1 / dead_code_threshold.
        self.ema_s[dead] = picks * self.dead_code_threshold

    @torch.no_grad()
    def quantize(self, latent):
        """Inference-only nearest-code lookup for ``latent`` of shape ``(B, D, T)``.

        Performs no codebook updates. Returns ``"quantized"`` ``(B, D, T)`` and
        ``"indices"`` ``(B, T)``.
        """
        B, D, T = latent.shape
        flat = latent.transpose(1, 2).reshape(-1, D)
        idx, quantized = self._nearest(flat)
        return {
            "quantized": quantized.reshape(B, T, D).transpose(1, 2).contiguous(),
            "indices": idx.reshape(B, T),
        }

    def decode_indices(self, indices):
        """Map ``(B, T)`` code indices back to ``(B, D, T)`` code vectors."""
        return self.embedding[indices].transpose(1, 2).contiguous()
class ResidualVectorQuantizer(nn.Module):
    """Stack of ``num_quantizers`` vector quantizers applied to successive residuals.

    Stage ``i`` quantizes what stages ``0..i-1`` left unexplained. The
    quantized latent is the sum of all stage outputs.
    """

    def __init__(
        self,
        latent_dim: int,
        num_quantizers: int = 8,
        codebook_size: int = 1024,
    ):
        super().__init__()
        self.num_quantizers = num_quantizers
        self.quantizers = nn.ModuleList(
            VectorQuantizer(codebook_size, latent_dim) for _ in range(num_quantizers)
        )

    def forward(self, latent):
        """Quantize ``latent`` ``(B, D, T)`` with every stage.

        Returns:
            dict with these keys:

            * ``"quantized"``: same shape as ``latent``.
            * ``"indices"``: ``(B, num_quantizers, T)``.
            * ``"commitment_loss"``: per-stage commitment losses summed.
        """
        residual = latent
        quantized = torch.zeros_like(latent)
        total_commit = latent.new_zeros(())
        all_indices = []
        for vq in self.quantizers:
            out = vq(residual)
            quantized = quantized + out["quantized"]
            # NOTE(review): residuals are built from detached stage outputs, so
            # each stage's straight-through output forwards the decoder
            # gradient to `latent`; the encoder receives it once per stage.
            # Confirm this scaling is intended.
            residual = residual - out["quantized"].detach()
            total_commit = total_commit + out["commitment_loss"]
            all_indices.append(out["indices"])
        return {
            "quantized": quantized,
            "indices": torch.stack(all_indices, dim=1),
            "commitment_loss": total_commit,
        }

    @torch.no_grad()
    def encode(self, latent):
        """Return ``(B, num_quantizers, T)`` code indices for ``latent`` ``(B, D, T)``."""
        residual = latent
        all_indices = []
        for vq in self.quantizers:
            out = vq.quantize(residual)
            all_indices.append(out["indices"])
            residual = residual - out["quantized"]
        return torch.stack(all_indices, dim=1)

    @torch.no_grad()
    def decode(self, indices):
        """Sum the code vectors selected by ``indices``.

        Args:
            indices: ``(B, S, T)`` integer codes as produced by ``encode``, with
                ``1 <= S <= num_quantizers``. With ``S < num_quantizers`` only
                the first ``S`` stages are decoded, which gives a
                lower-bitrate reconstruction.

        Raises:
            ValueError: if ``S`` is 0 or exceeds ``num_quantizers``.
        """
        num_stages = indices.size(1)
        if not 1 <= num_stages <= self.num_quantizers:
            raise ValueError(
                f"indices has {num_stages} quantizer stages; "
                f"expected between 1 and {self.num_quantizers}"
            )
        quantized = None
        for vq, stage_indices in zip(self.quantizers, indices.unbind(1)):
            stage = vq.decode_indices(stage_indices)
            quantized = stage if quantized is None else quantized + stage
        return quantized
class SoundStreamCodec(PreTrainedModel):
    """SoundStream-style codec: causal conv encoder, residual VQ, causal conv decoder.

    Before encoding, waveforms are right-padded by edge replication to a
    multiple of the total encoder stride. ``forward`` and ``decode`` (when
    given ``original_length``) trim the reconstruction back to the requested
    length.
    """

    config_class = SoundStreamConfig

    def __init__(self, config):
        super().__init__(config)
        # Must match the strides hard-coded in Encoder / Decoder.
        self.strides = (2, 4, 5, 5)
        factor = 1
        for stride in self.strides:
            factor = factor * stride
        self.downsampling_factor = factor
        self.encoder = Encoder(channels=config.channels, dim=config.latent_dim)
        self.quantizer = ResidualVectorQuantizer(
            latent_dim=config.latent_dim,
            num_quantizers=config.num_quantizers,
            codebook_size=config.codebook_size,
        )
        self.decoder = Decoder(channels=config.channels, dim=config.latent_dim)
        self.post_init()

    def forward(self, audio, **kwargs):
        """Encode, quantize and decode ``audio``; extra ``kwargs`` are ignored.

        Returns a dict with these entries:

        * ``"reconstructed_audio"``: trimmed to the input length.
        * ``"latent"``: the pre-quantization latent.
        * Every entry of the quantizer output: ``"quantized"``, ``"indices"``,
          ``"commitment_loss"``.
        """
        num_samples = audio.size(-1)
        latent = self.encoder(self._pad_to_stride(audio))
        q_out = self.quantizer(latent)
        decoded = self.decoder(q_out["quantized"])
        return {
            "reconstructed_audio": decoded[..., :num_samples],
            "latent": latent,
            **q_out,
        }

    @torch.no_grad()
    def encode(self, audio):
        """Return residual-VQ code indices for ``audio``."""
        latent = self.encoder(self._pad_to_stride(audio))
        return self.quantizer.encode(latent)

    @torch.no_grad()
    def decode(self, indices, original_length=None):
        """Reconstruct a waveform from code indices, optionally trimmed to ``original_length``."""
        waveform = self.decoder(self.quantizer.decode(indices))
        if original_length is None:
            return waveform
        return waveform[..., :original_length]

    def _pad_to_stride(self, audio):
        """Right-pad ``audio`` (replicating the last sample) to a multiple of the total stride."""
        missing = -audio.size(-1) % self.downsampling_factor
        if missing == 0:
            return audio
        return F.pad(audio, (0, missing), mode="replicate")