# Source: dpss-exp3-TTS/VoxCPM/src/voxcpm/modules/layers/scalar_quantization_layer.py
# Uploader: lglg666
# Commit 6766eda (verified): "Upload folder using huggingface_hub"
import torch
import torch.nn as nn
class ScalarQuantizationLayer(nn.Module):
    """Bottleneck that rounds a tanh-squashed latent onto a uniform grid.

    The input is linearly projected to ``latent_dim`` features, passed
    through ``tanh``, rounded to the nearest multiple of ``1 / scale``, and
    then linearly projected to ``out_dim`` features.

    Args:
        in_dim: Feature size expected by the input projection.
        out_dim: Feature size produced by the output projection.
        latent_dim: Feature size of the quantized latent in between.
        scale: Number of grid steps per unit; rounded values are multiples
            of ``1 / scale``.
    """

    def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        # Keep the configuration around for introspection by callers.
        self.in_dim, self.out_dim = in_dim, out_dim
        self.latent_dim, self.scale = latent_dim, scale

        # Input projection is built first, then the output projection.
        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden):
        """Quantize ``hidden`` in the latent space and project it back out.

        Args:
            hidden: Tensor accepted by ``in_proj`` (``nn.Linear`` acts on the
                last dimension, which must therefore be ``in_dim``).

        Returns:
            The output of ``out_proj`` applied to the quantized latent.
        """
        latent = torch.tanh(self.in_proj(hidden))
        snapped = torch.round(latent * self.scale) / self.scale

        if not self.training:
            # Evaluation: feed the rounded latent straight through.
            return self.out_proj(snapped)

        # Training: add the rounding offset as a detached constant so the
        # backward pass sees an identity in place of the rounding step
        # (straight-through gradient).
        offset = (snapped - latent).detach()
        return self.out_proj(latent + offset)