File size: 803 Bytes
6766eda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch
import torch.nn as nn
class ScalarQuantizationLayer(nn.Module):
    """Project to a latent space, scalar-quantize, and project back out.

    Latent activations are squashed with ``tanh`` into ``[-1, 1]`` and then
    rounded to a uniform grid with spacing ``1 / scale`` (i.e. at most
    ``2 * scale + 1`` distinct levels per dimension).  During training a
    straight-through estimator (STE) lets gradients pass through the
    non-differentiable rounding step as if it were the identity.

    Args:
        in_dim: Size of the input feature dimension.
        out_dim: Size of the output feature dimension.
        latent_dim: Width of the quantized latent space.
        scale: Number of quantization steps per unit of the tanh range.
    """

    def __init__(self, in_dim: int, out_dim: int, latent_dim: int = 64, scale: int = 9):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.latent_dim = latent_dim
        self.scale = scale
        self.in_proj = nn.Linear(in_dim, latent_dim)
        self.out_proj = nn.Linear(latent_dim, out_dim)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """Quantize ``hidden`` in latent space and return the output projection.

        Args:
            hidden: Tensor whose last dimension equals ``in_dim``.

        Returns:
            Tensor whose last dimension equals ``out_dim``.
        """
        hidden = torch.tanh(self.in_proj(hidden))
        # Snap each latent value to the nearest multiple of 1/scale.
        # Computed once here; previously duplicated in both branches below.
        quantized = torch.round(hidden * self.scale) / self.scale
        if self.training:
            # Straight-through estimator: the forward value equals
            # `quantized`, but the backward pass treats rounding as the
            # identity so gradients reach `in_proj`.
            hidden = hidden + (quantized - hidden).detach()
        else:
            hidden = quantized
        return self.out_proj(hidden)