import torch
import torch.nn as nn


class ScalarQuantizationLayer(nn.Module):
    """Bottlenecks activations through a scalar-quantized latent space:
    project in, bound with tanh, round each channel to a fixed grid,
    and project back out."""

    def __init__(self, in_dim, out_dim, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.scale = scale
        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden):
        # Project into the latent space and bound each channel to (-1, 1)
        # so the rounding grid below covers a fixed range.
        hidden = self.in_proj(hidden)
        hidden = torch.tanh(hidden)
        if self.training:
            # Straight-through estimator: use the quantized values on the
            # forward pass, but let gradients bypass the non-differentiable
            # round() as if it were the identity.
            quantized = torch.round(hidden * self.scale) / self.scale
            hidden = hidden + (quantized - hidden).detach()
        else:
            # At inference there is no backward pass, so quantize directly.
            hidden = torch.round(hidden * self.scale) / self.scale
        return self.out_proj(hidden)
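
A minimal usage sketch (the dimensions and batch shape below are illustrative assumptions, not from the source). With the default scale=9, the tanh bound means each latent channel is snapped to one of 19 evenly spaced levels, -9/9 through 9/9:

# Illustrative sizes; any in_dim/out_dim work.
layer = ScalarQuantizationLayer(in_dim=128, out_dim=128)
x = torch.randn(4, 128)        # (batch, in_dim)

layer.train()
y = layer(x)                   # quantized forward; gradients flow straight through
assert y.shape == (4, 128)

layer.eval()
with torch.no_grad():
    y_eval = layer(x)          # hard quantization, no gradient path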