# Copyright (C) 2025. Huawei Technologies Co., Ltd. All Rights Reserved. (authors: Dehua Tao)
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging

import numpy as np
import torch
import torch.nn as nn

# Type aliases for readability. Note that `codes_to_indexes` returns a float
# tensor, so cast its result (e.g. to long) before using it for tensor indexing.
Codeword = torch.FloatTensor
Indices = torch.FloatTensor
def round_ste(z):
"""Round with straight through gradients."""
zhat = torch.round(z)
return z + (zhat - z).detach()
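
# Sanity sketch of the straight-through estimator (illustrative only): the
# forward value equals torch.round(z), while the gradient is that of the
# identity, e.g.
#   z = torch.tensor([0.2, 0.7], requires_grad=True)
#   round_ste(z).sum().backward()
#   assert torch.equal(z.grad, torch.ones(2))
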
class FSQ(nn.Module):
    """Finite Scalar Quantizer (FSQ)."""

    def __init__(self, levels: list, eps: float = 1e-3, l2_norm: bool = False, batch_norm: bool = False):
        super().__init__()
        self._levels = levels
        self._eps = eps
        self.l2_norm = l2_norm
        self.batch_norm = batch_norm
        # Despite the `_np` suffix (carried over from the NumPy/JAX original),
        # these are torch tensors; `forward` moves them to the input's device.
        self._levels_np = torch.tensor(levels, dtype=torch.float32)
        self._basis = torch.cat((torch.tensor([1.0]), torch.cumprod(self._levels_np[:-1], dim=0)))
        self._implicit_codebook = self.indexes_to_codes(torch.arange(self.codebook_size))
        logging.info(f'levels: {levels}')
        if self.batch_norm:
            self.bn = nn.BatchNorm1d(self.num_dimensions, momentum=0.01, eps=1e-3)
@property
def num_dimensions(self) -> int:
"""Number of dimensions expected from inputs."""
return len(self._levels)
@property
def codebook_size(self) -> int:
"""Size of the codebook."""
return np.prod(self._levels)
@property
def codebook(self):
"""Returns the implicit codebook. Shape (prod(levels), num_dimensions)."""
return self._implicit_codebook
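
    # Example (illustrative): FSQ(levels=[8, 5, 5, 5]) expects 4-dimensional
    # inputs and defines an implicit codebook of 8 * 5 * 5 * 5 == 1000 codes.
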
    def bound(self, z: torch.FloatTensor) -> torch.FloatTensor:
        """Bound `z`, an array of shape (..., d)."""
        half_l = (self._levels_np - 1) * (1 - self._eps) / 2
        offset = torch.where(self._levels_np % 2 == 1, 0.0, 0.5)
        # The shift recenters dimensions with an even number of levels. `tan`
        # follows the FSQ reference code; `atanh` would cancel the offset
        # exactly, but the two agree closely for the small arguments used here.
        shift = torch.tan(offset / half_l)
        return torch.tanh(z + shift) * half_l - offset
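
    # Example (illustrative): with an odd L == 5 and eps == 1e-3, `bound`
    # squashes each coordinate into roughly (-1.998, 1.998), so rounding
    # afterwards yields integers in {-2, ..., 2}.
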
def quantize(self, z: torch.FloatTensor) -> Codeword:
"""Quanitzes z, returns quantized zhat, same shape as z."""
quantized = round_ste(self.bound(z))
# Renormalize to [-1, 1].
half_width = torch.div(self._levels_np, 2, rounding_mode='floor')
return quantized / half_width
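
    # Example (illustrative): with L == 5 levels per dimension, the quantized
    # values lie on the grid {-1, -0.5, 0, 0.5, 1} (the integers {-2, ..., 2}
    # renormalized by half_width == 2).
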
def _scale_and_shift(self, zhat_normalized):
# Scale and shift to range [0, ..., L-1]
half_width = torch.div(self._levels_np, 2, rounding_mode='floor')
return (zhat_normalized * half_width) + half_width
def _scale_and_shift_inverse(self, zhat):
# Note that array(x) // 2 != tensor(x) // 2 when x is negative
half_width = torch.div(self._levels_np, 2, rounding_mode='floor')
return (zhat - half_width) / half_width
def codes_to_indexes(self, zhat: Codeword) -> Indices:
"""Converts a `code` to an index in the codebook."""
zhat = self._scale_and_shift(zhat)
        return torch.sum(zhat * self._basis, dim=-1)
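
    # Example (illustrative): for levels == [5, 5, 5] the basis is [1, 5, 25],
    # so the shifted digits (3, 1, 4) encode index 3 + 1 * 5 + 4 * 25 == 108,
    # i.e. a little-endian mixed-radix positional encoding.
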
def indexes_to_codes(self, indices: Indices) -> Codeword:
"""Inverse of `indexes_to_codes`."""
indices = indices.unsqueeze(-1)
codes_non_centered = torch.remainder(
torch.div(indices, self._basis, rounding_mode='floor'), self._levels_np
)
return self._scale_and_shift_inverse(codes_non_centered)
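
    # Round-trip sketch (illustrative): for any valid index i,
    # codes_to_indexes(indexes_to_codes(i)) == i, which is also how
    # `_implicit_codebook` is materialized in `__init__`.
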
    def forward(self, z: torch.FloatTensor) -> Codeword:
        # Move the constant tensors to the input's device. Using `z.device`
        # also works on CPU, where `z.get_device()` would return -1 and make
        # an f'cuda:{index}' device string invalid.
        device = z.device
        self._levels_np = self._levels_np.to(device)
        self._basis = self._basis.to(device)
        self._implicit_codebook = self._implicit_codebook.to(device)
        if self.l2_norm:
            z = nn.functional.normalize(z, p=2, dim=-1)
        if self.batch_norm:
            self.bn = self.bn.to(device)
            # BatchNorm1d normalizes over channels and expects (batch, channels,
            # time), so swap the last two axes around the normalization.
            z = z.permute(0, 2, 1)
            z = self.bn(z)
            z = z.permute(0, 2, 1)
        zhat = self.quantize(z)
        return zhat
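

if __name__ == '__main__':
    # Minimal usage sketch (shapes and levels are illustrative, not from the
    # original code): quantize a batch of 3-dimensional latents with 5 levels
    # per dimension, then round-trip through the implicit codebook indices.
    fsq = FSQ(levels=[5, 5, 5])
    z = torch.randn(2, 10, 3)           # (batch, time, num_dimensions)
    zhat = fsq(z)                       # same shape, values on the FSQ grid
    idx = fsq.codes_to_indexes(zhat)    # (batch, time), float indices in [0, 124]
    assert torch.allclose(fsq.indexes_to_codes(idx), zhat)
    print(zhat.shape, idx.shape)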