import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torch.nn.utils import weight_norm


class FactorizedVectorQuantize(nn.Module):
    """Vector quantization with a factorized (low-dimensional) codebook lookup."""

    def __init__(self, dim, codebook_size, codebook_dim, commitment, **kwargs):
        super().__init__()
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.commitment = commitment

        # Factorization: project down to the codebook dimension for the lookup,
        # then project back out to the model dimension.
        if dim != self.codebook_dim:
            self.in_proj = weight_norm(nn.Linear(dim, self.codebook_dim))
            self.out_proj = weight_norm(nn.Linear(self.codebook_dim, dim))
        else:
            self.in_proj = nn.Identity()
            self.out_proj = nn.Identity()
        self._codebook = nn.Embedding(codebook_size, self.codebook_dim)

    @property
    def codebook(self):
        return self._codebook

    def forward(self, z):
        """Quantize the input tensor using a learned codebook and return
        the corresponding codebook vectors.

        Parameters
        ----------
        z : Tensor[B x D x T]

        Returns
        -------
        Tensor[B x D x T]
            Quantized continuous representation of the input
        Tensor[B x T]
            Codebook indices (quantized discrete representation of the input)
        Tensor[B]
            Combined commitment and codebook loss: the commitment term trains
            the encoder to predict vectors closer to codebook entries, and the
            codebook term updates the codebook (zero at eval time)
        """
        z = rearrange(z, "b d t -> b t d")

        # Factorized lookup: project to the codebook dimension before quantizing.
        z_e = self.in_proj(z)
        z_e = rearrange(z_e, "b t d -> b d t")
        z_q, indices = self.decode_latents(z_e)

        if self.training:
            # Commitment loss pulls the encoder output toward its (frozen)
            # codebook entry; codebook loss moves the codebook entry toward
            # the (frozen) encoder output.
            commitment_loss = (
                F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2])
                * self.commitment
            )
            codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2])
            commit_loss = commitment_loss + codebook_loss
        else:
            commit_loss = torch.zeros(z.shape[0], device=z.device)

        # Straight-through estimator: copy gradients from z_q to z_e so the
        # encoder receives gradients through the non-differentiable lookup.
        z_q = z_e + (z_q - z_e).detach()

        # Project back up to the model dimension.
        z_q = rearrange(z_q, "b d t -> b t d")
        z_q = self.out_proj(z_q)
        z_q = rearrange(z_q, "b t d -> b d t")

        return z_q, indices, commit_loss

    def vq2emb(self, vq, proj=True):
        # Map codebook indices back to embeddings, optionally projecting
        # them up to the model dimension.
        emb = self.embed_code(vq)
        if proj:
            emb = self.out_proj(emb)
        return emb.transpose(1, 2)

    def get_emb(self):
        return self.codebook.weight

    def embed_code(self, embed_id):
        return F.embedding(embed_id, self.codebook.weight)

    def decode_code(self, embed_id):
        return self.embed_code(embed_id).transpose(1, 2)

    def decode_latents(self, latents):
        encodings = rearrange(latents, "b d t -> (b t) d")
        codebook = self.codebook.weight

        # L2-normalize encodings and codebook entries so the distance below
        # reduces to 2 - 2 * cosine similarity: the nearest neighbor is the
        # entry with the highest cosine similarity.
        encodings = F.normalize(encodings)
        codebook = F.normalize(codebook)

        # Squared Euclidean distance from each encoding to every codebook entry.
        dist = (
            encodings.pow(2).sum(1, keepdim=True)
            - 2 * encodings @ codebook.t()
            + codebook.pow(2).sum(1, keepdim=True).t()
        )
        indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0))
        z_q = self.decode_code(indices)
        return z_q, indices
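

# A minimal usage sketch, not part of the original module: the hyperparameters
# and shapes below are illustrative assumptions, chosen only to show the
# expected tensor shapes flowing through forward().
if __name__ == "__main__":
    torch.manual_seed(0)
    vq = FactorizedVectorQuantize(
        dim=256, codebook_size=1024, codebook_dim=8, commitment=0.25
    )
    vq.train()
    z = torch.randn(2, 256, 50)  # [B x D x T] encoder output
    z_q, indices, commit_loss = vq(z)
    print(z_q.shape)          # torch.Size([2, 256, 50])
    print(indices.shape)      # torch.Size([2, 50])
    print(commit_loss.shape)  # torch.Size([2]), per-example loss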