import torch
import torch.nn as nn


class ScalarQuantizationLayer(nn.Module):
    """Projects inputs to a low-dimensional latent, scalar-quantizes each
    latent value to a fixed grid of 2 * scale + 1 levels in [-1, 1], and
    projects back out."""

    def __init__(self, in_dim: int, out_dim: int, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.scale = scale

        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden):
        # Compress into the latent space and bound values to [-1, 1].
        hidden = self.in_proj(hidden)
        hidden = torch.tanh(hidden)

        # Snap each latent value to the nearest of the 2 * scale + 1 grid points.
        quantized = torch.round(hidden * self.scale) / self.scale
        if self.training:
            # Straight-through estimator: the forward pass uses the quantized
            # values, but gradients flow through `hidden` as if rounding were
            # the identity (round() has zero gradient almost everywhere).
            hidden = hidden + (quantized - hidden).detach()
        else:
            hidden = quantized

        return self.out_proj(hidden)
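

# Minimal usage sketch (the dimensions below are hypothetical, not part of the
# layer): push a batch through the 64-dim, 19-level latent in both modes and
# confirm that the straight-through pass delivers gradients to in_proj.
if __name__ == "__main__":
    layer = ScalarQuantizationLayer(in_dim=512, out_dim=512)
    x = torch.randn(8, 512)

    layer.train()
    out = layer(x)
    out.sum().backward()  # gradients reach in_proj despite the rounding step
    assert layer.in_proj.weight.grad is not None

    layer.eval()
    with torch.no_grad():
        out_eval = layer(x)  # hard quantization, no gradient tricks needed
    print(out.shape, out_eval.shape)  # torch.Size([8, 512]) torch.Size([8, 512])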