import math

import torch
from torch import nn

from .fvq import FactorizedVectorQuantize


class ResidualVQ(nn.Module):
    """Residual vector quantization, following Algorithm 1 in the
    SoundStream paper: https://arxiv.org/pdf/2107.03312.pdf

    Each quantizer stage encodes the residual left by the previous
    stages; the decoded output is the sum of all stage outputs.
    """

    def __init__(self, *, num_quantizers, codebook_size, **kwargs):
        super().__init__()
        VQ = FactorizedVectorQuantize
        # Each entry of `codebook_size` is a bit width: stage i gets a
        # codebook with 2**codebook_size[i] entries.
        if isinstance(codebook_size, int):
            codebook_size = [codebook_size] * num_quantizers
        self.layers = nn.ModuleList(
            [VQ(codebook_size=2**size, **kwargs) for size in codebook_size]
        )
        self.num_quantizers = num_quantizers
        self.quantizer_dropout = kwargs.get("quantizer_dropout", 0.0)
        self.dropout_type = kwargs.get("dropout_type", None)

    def forward(self, x, n_quantizers=None):
        quantized_out = 0.0
        residual = x

        all_losses = []
        all_indices = []
        all_quantized = []

        if n_quantizers is None:
            n_quantizers = self.num_quantizers
        if self.training:
            # Quantizer dropout for bitrate scalability: each sample in the
            # batch gets its own budget of quantizer stages. The default
            # budget (num_quantizers + 1) keeps every stage active.
            n_quantizers = torch.ones((x.shape[0],)) * self.num_quantizers + 1
            if self.dropout_type == "linear":
                # Uniform over 1..num_quantizers stages.
                dropout = torch.randint(1, self.num_quantizers + 1, (x.shape[0],))
            elif self.dropout_type == "exp":
                # Powers of two: 2..num_quantizers // 2 stages.
                dropout = torch.randint(
                    1, int(math.log2(self.num_quantizers)), (x.shape[0],)
                )
                dropout = torch.pow(2, dropout)
            if self.dropout_type is not None:
                # Only the first `n_dropout` samples of the batch get a
                # reduced budget; `dropout` is undefined otherwise.
                n_dropout = int(x.shape[0] * self.quantizer_dropout)
                n_quantizers[:n_dropout] = dropout[:n_dropout]
            n_quantizers = n_quantizers.to(x.device)

        for idx, layer in enumerate(self.layers):
            # At inference time, stop once the requested number of stages
            # has been applied.
            if not self.training and idx >= n_quantizers:
                break
            quantized, indices, loss = layer(residual)

            # Stage `idx` is active only for samples whose budget exceeds it.
            mask = (
                torch.full((x.shape[0],), fill_value=idx, device=x.device)
                < n_quantizers
            )

            # Each stage quantizes what the previous stages left behind.
            residual = residual - quantized

            quantized_out = quantized_out + quantized * mask[:, None, None]

            # Stages masked out by quantizer dropout contribute no loss.
            loss = (loss * mask).mean()

            all_indices.append(indices)
            all_losses.append(loss)
            all_quantized.append(quantized)

        all_losses, all_indices, all_quantized = map(
            torch.stack, (all_losses, all_indices, all_quantized)
        )
        return quantized_out, all_indices, all_losses, all_quantized
|
    def vq2emb(self, vq):
        # Decode stacked code indices back to the summed embedding.
        quantized_out = 0.0
        for idx, layer in enumerate(self.layers):
            quantized = layer.vq2emb(vq[idx])
            quantized_out = quantized_out + quantized
        return quantized_out

    def get_emb(self):
        # Return the codebook embeddings of every quantizer stage.
        return [layer.get_emb() for layer in self.layers]
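

# A minimal, self-contained sketch of the residual recurrence that `forward`
# implements, for illustration only. It deliberately bypasses
# FactorizedVectorQuantize (whose signature lives in .fvq and is not shown
# here) and uses plain nearest-neighbour lookup against random codebooks.
# Each stage snaps the current residual to its nearest code; the running sum
# approximates the input and the residual shrinks stage by stage.
if __name__ == "__main__":
    torch.manual_seed(0)
    num_stages, num_codes, dim = 4, 64, 8
    codebooks = [torch.randn(num_codes, dim) for _ in range(num_stages)]

    x = torch.randn(2, dim)
    residual, approx = x, torch.zeros_like(x)
    for cb in codebooks:
        dists = torch.cdist(residual, cb)  # (batch, num_codes)
        codes = cb[dists.argmin(dim=-1)]   # nearest-neighbour lookup
        approx = approx + codes            # running reconstruction
        residual = residual - codes        # what the next stage will see
        print(f"mean residual norm: {residual.norm(dim=-1).mean():.3f}")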