import torch
import torch.nn as nn
import torch.nn.functional as F

from .layers import kmeans, sinkhorn_algorithm
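

# For reference: `sinkhorn_algorithm` is imported from `.layers` and is not
# shown in this file. The sketch below illustrates what such a helper
# typically computes — alternating row/column normalization of
# exp(-d / epsilon) toward uniform marginals (Sinkhorn-Knopp). This is an
# assumption for illustration, not the project's actual implementation.
def _sinkhorn_sketch(distances, epsilon, n_iters):
    Q = torch.exp(-distances / epsilon)         # Gibbs kernel from distances
    Q = Q / Q.sum()                             # normalize total mass to 1
    n, m = Q.shape                              # n latents, m codes
    for _ in range(n_iters):
        Q = Q / Q.sum(dim=1, keepdim=True) / n  # rows -> uniform marginal
        Q = Q / Q.sum(dim=0, keepdim=True) / m  # cols -> uniform marginal
    return Q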


class VectorQuantizer(nn.Module):
    """Vector quantizer with optional k-means codebook initialization and
    optional Sinkhorn-based (balanced) code assignment."""

    def __init__(self, n_e, e_dim,
                 beta=0.25, kmeans_init=False, kmeans_iters=10,
                 sk_epsilon=0.01, sk_iters=100):
        super().__init__()
        self.n_e = n_e                    # codebook size
        self.e_dim = e_dim                # code dimensionality
        self.beta = beta                  # commitment loss weight
        self.kmeans_init = kmeans_init
        self.kmeans_iters = kmeans_iters
        self.sk_epsilon = sk_epsilon      # Sinkhorn entropy regularization
        self.sk_iters = sk_iters          # Sinkhorn iteration count

        self.embedding = nn.Embedding(self.n_e, self.e_dim)
        if not kmeans_init:
            self.initted = True
            self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        else:
            # Defer initialization to the first training batch (see init_emb).
            self.initted = False
            self.embedding.weight.data.zero_()

    def get_codebook(self):
        return self.embedding.weight

    def get_codebook_entry(self, indices, shape=None):
        z_q = self.embedding(indices)
        if shape is not None:
            z_q = z_q.view(shape)
        return z_q

    def init_emb(self, data):
        # Initialize the codebook from k-means centers of the first batch.
        centers = kmeans(
            data,
            self.n_e,
            self.kmeans_iters,
        )
        self.embedding.weight.data.copy_(centers)
        self.initted = True

    @staticmethod
    def center_distance_for_constraint(distances):
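        # Rescale distances to roughly [-1, 1] so that exp(-d / epsilon) in
        # the Sinkhorn iterations stays numerically stable. Illustrative
        # worked example: distances [0, 2, 4] -> middle = 2, amplitude ~= 2,
        # centered ~= [-1, 0, 1].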
        max_distance = distances.max()
        min_distance = distances.min()

        middle = (max_distance + min_distance) / 2
        amplitude = max_distance - middle + 1e-5
        assert amplitude > 0
        centered_distances = (distances - middle) / amplitude
        return centered_distances

    def forward(self, x, use_sk=True):
        latent = x.view(-1, self.e_dim)

        # Lazily initialize the codebook with k-means on the first batch.
        if not self.initted and self.training:
            self.init_emb(latent)

        # Squared Euclidean distances ||z||^2 + ||e||^2 - 2 z.e between
        # each latent vector and every codebook entry.
        d = torch.sum(latent**2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1, keepdim=True).t() - \
            2 * torch.matmul(latent, self.embedding.weight.t())

        if not use_sk or self.sk_epsilon <= 0:
            # Plain nearest-neighbor assignment.
            indices = torch.argmin(d, dim=-1)
        else:
            # Balanced assignment: solve an entropic optimal-transport
            # problem over the rescaled distances with Sinkhorn-Knopp.
            d = self.center_distance_for_constraint(d)
            d = d.double()
            Q = sinkhorn_algorithm(d, self.sk_epsilon, self.sk_iters)
            # Warn before sanitizing; otherwise the non-finite entries would
            # already have been replaced and this check could never fire.
            if torch.isnan(Q).any() or torch.isinf(Q).any():
                print("Sinkhorn algorithm returned nan/inf values.")
            Q = torch.nan_to_num(Q, nan=Q[torch.isfinite(Q)].min().item())
            indices = torch.argmax(Q, dim=-1)

        x_q = self.embedding(indices).view(x.shape)

        # VQ-VAE losses: the codebook term moves code vectors toward the
        # encoder outputs; the commitment term (weighted by beta) keeps the
        # encoder outputs close to their assigned codes.
        commitment_loss = F.mse_loss(x_q.detach(), x)
        codebook_loss = F.mse_loss(x_q, x.detach())
        loss = codebook_loss + self.beta * commitment_loss

        # Straight-through estimator: the forward pass uses x_q, while
        # gradients flow to x as if quantization were the identity.
        x_q = x + (x_q - x).detach()

        indices = indices.view(x.shape[:-1])

        return x_q, loss, indices
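

# Minimal usage sketch (shapes and hyperparameters below are illustrative
# assumptions, not values taken from this repository):
#
#     vq = VectorQuantizer(n_e=256, e_dim=64)
#     x = torch.randn(8, 64)                   # batch of 8 latent vectors
#     x_q, loss, indices = vq(x, use_sk=True)
#     # x_q: quantized latents with straight-through gradients back to x
#     # loss: added to the task loss; indices: assigned code ids, shape (8,)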