|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import logging |
|
|
|
|
|
# Type aliases used for readability in annotations below.
# NOTE(review): `Indices` conceptually holds integer codebook indices, but is
# aliased to FloatTensor to match how the rest of this module passes them
# around — confirm before tightening to an integer tensor type.
Codeword = torch.FloatTensor


Indices = torch.FloatTensor
|
|
|
|
|
|
|
|
def round_ste(z):
    """Round `z` to the nearest integer with straight-through gradients.

    The forward pass yields `torch.round(z)`; the backward pass treats the
    operation as the identity, so gradients flow through unchanged.
    """
    # Detach the rounding residual so only the identity term carries grad.
    return z + (torch.round(z) - z).detach()
|
|
|
|
|
|
|
|
class FSQ(nn.Module): |
|
|
"""Quantizer.""" |
|
|
|
|
|
def __init__(self, levels: list, eps: float = 1e-3, l2_norm: bool = False, batch_norm: bool = False): |
|
|
super().__init__() |
|
|
|
|
|
self._levels = levels |
|
|
self._eps = eps |
|
|
self.l2_norm = l2_norm |
|
|
self.batch_norm = batch_norm |
|
|
self._levels_np = torch.Tensor(levels) |
|
|
self._basis = torch.cat((torch.Tensor([1]), torch.cumprod(self._levels_np[:-1], dim=0))) |
|
|
self._implicit_codebook = self.indexes_to_codes(torch.arange(self.codebook_size)) |
|
|
logging.info(f'levels: {levels}') |
|
|
|
|
|
if self.batch_norm: |
|
|
self.bn = nn.BatchNorm1d(self.num_dimensions, momentum=0.01, eps=1e-3) |
|
|
|
|
|
@property |
|
|
def num_dimensions(self) -> int: |
|
|
"""Number of dimensions expected from inputs.""" |
|
|
return len(self._levels) |
|
|
|
|
|
@property |
|
|
def codebook_size(self) -> int: |
|
|
"""Size of the codebook.""" |
|
|
return np.prod(self._levels) |
|
|
|
|
|
@property |
|
|
def codebook(self): |
|
|
"""Returns the implicit codebook. Shape (prod(levels), num_dimensions).""" |
|
|
return self._implicit_codebook |
|
|
|
|
|
def bound(self, z: torch.FloatTensor) -> torch.FloatTensor: |
|
|
"""Bound `z`, an array of shape (..., d).""" |
|
|
half_l = (self._levels_np - 1) * (1 - self._eps) / 2 |
|
|
offset = torch.where(self._levels_np % 2 == 1, 0.0, 0.5) |
|
|
shift = torch.tan(offset / half_l) |
|
|
return torch.tanh(z + shift) * half_l - offset |
|
|
|
|
|
def quantize(self, z: torch.FloatTensor) -> Codeword: |
|
|
"""Quanitzes z, returns quantized zhat, same shape as z.""" |
|
|
quantized = round_ste(self.bound(z)) |
|
|
|
|
|
|
|
|
half_width = torch.div(self._levels_np, 2, rounding_mode='floor') |
|
|
return quantized / half_width |
|
|
|
|
|
def _scale_and_shift(self, zhat_normalized): |
|
|
|
|
|
half_width = torch.div(self._levels_np, 2, rounding_mode='floor') |
|
|
return (zhat_normalized * half_width) + half_width |
|
|
|
|
|
def _scale_and_shift_inverse(self, zhat): |
|
|
|
|
|
half_width = torch.div(self._levels_np, 2, rounding_mode='floor') |
|
|
return (zhat - half_width) / half_width |
|
|
|
|
|
def codes_to_indexes(self, zhat: Codeword) -> Indices: |
|
|
"""Converts a `code` to an index in the codebook.""" |
|
|
zhat = self._scale_and_shift(zhat) |
|
|
return torch.sum(zhat * self._basis, axis=-1) |
|
|
|
|
|
def indexes_to_codes(self, indices: Indices) -> Codeword: |
|
|
"""Inverse of `indexes_to_codes`.""" |
|
|
indices = indices.unsqueeze(-1) |
|
|
codes_non_centered = torch.remainder( |
|
|
torch.div(indices, self._basis, rounding_mode='floor'), self._levels_np |
|
|
) |
|
|
return self._scale_and_shift_inverse(codes_non_centered) |
|
|
|
|
|
def forward(self, z: torch.FloatTensor) -> Codeword: |
|
|
|
|
|
cuda_index = z.get_device() |
|
|
self._levels_np = self._levels_np.to(f'cuda:{cuda_index}') |
|
|
self._basis = self._basis.to(f'cuda:{cuda_index}') |
|
|
self._implicit_codebook = self._implicit_codebook.to(f'cuda:{cuda_index}') |
|
|
|
|
|
if self.l2_norm: |
|
|
z = nn.functional.normalize(z, p=2, dim=-1) |
|
|
|
|
|
if self.batch_norm: |
|
|
self.bn = self.bn.to(f'cuda:{cuda_index}') |
|
|
|
|
|
z = z.permute(0, 2, 1) |
|
|
z = self.bn(z) |
|
|
z = z.permute(0, 2, 1) |
|
|
|
|
|
zhat = self.quantize(z) |
|
|
|
|
|
return zhat |
|
|
|
|
|
|
|
|
|