from typing import Mapping, Text, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from einops import rearrange
| |
|
| |
|
class SoftVectorQuantizer(torch.nn.Module):
    """Differentiable ("soft") vector quantizer with multiple codebooks.

    Instead of a hard nearest-neighbour lookup, every token is replaced by a
    softmax-weighted combination of codebook entries (temperature ``tau``), so
    gradients flow through the quantization step.  The flattened spatial
    sequence (length ``h * w``) is split into ``num_codebooks`` contiguous
    segments, and segment ``k`` is quantized against codebook ``k``.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False,
                 entropy_loss_ratio: float = 0.01,
                 tau: float = 0.07,
                 num_codebooks: int = 1,
                 show_usage: bool = False
                 ):
        """
        Args:
            codebook_size: number of entries per codebook.
            token_size: dimensionality of each code / token embedding.
            commitment_cost: kept for API compatibility; the soft quantizer
                always reports a commitment loss of 0.
            use_l2_norm: L2-normalize tokens and codebook entries before
                computing similarities.
            clustering_vq: kept for API compatibility; unused here.
            entropy_loss_ratio: weight of the entropy regularizer added to the
                quantizer loss during training.
            tau: softmax temperature of the soft assignment.
            num_codebooks: number of independent codebooks; the flattened
                sequence length must be divisible by this.
            show_usage: track recently emitted code indices in a rolling
                buffer (diagnostics only).
        """
        super().__init__()

        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost
        self.use_l2_norm = use_l2_norm
        self.clustering_vq = clustering_vq

        # Aliases kept because other code may refer to either naming scheme.
        self.num_codebooks = num_codebooks
        self.n_e = codebook_size
        self.e_dim = token_size
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = use_l2_norm
        self.show_usage = show_usage
        self.tau = tau

        # (num_codebooks, codebook_size, token_size).  The randn values are
        # fully overwritten by the uniform init below; randn is kept so the
        # RNG stream matches the original implementation.
        self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size))
        self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)

        if self.l2_norm:
            self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1)

        if self.show_usage:
            # Rolling window of the most recently emitted indices per codebook.
            self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536))

    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Soft-quantize a (b, c, h, w) feature map.

        Returns:
            z_quantized: (b, c, h, w) quantized features (softmax mixture of
                codebook entries).
            result_dict: ``quantizer_loss`` / ``commitment_loss`` /
                ``codebook_loss`` scalars and ``min_encoding_indices`` of
                shape (b, h, w) — the argmax (hard) code per position.
        """
        # Quantization is always done in full precision (autocast disabled).
        z = z.float()
        original_shape = z.shape  # (b, c, h, w)

        # (b, c, h, w) -> (b, h*w, c).  Pure-torch equivalent of the einops
        # rearrange used previously; removes the third-party dependency.
        z = z.permute(0, 2, 3, 1).contiguous()
        z = z.view(z.size(0), -1, z.size(-1))

        batch_size, seq_length, _ = z.shape

        assert seq_length % self.num_codebooks == 0, \
            f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})"

        segment_length = seq_length // self.num_codebooks
        # Segment k of every sample is quantized with codebook k.
        z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim)

        embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding
        if self.l2_norm:
            z_segments = F.normalize(z_segments, p=2, dim=-1)

        # (n, b*s, e): codebook-major layout for the batched einsums below.
        z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, self.e_dim)

        # Similarities against a detached codebook: the codebook only receives
        # gradients through the soft combination below, not through the logits.
        logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach())

        probs = F.softmax(logits / self.tau, dim=-1)

        # Soft quantization: convex combination of codebook entries.
        z_q = torch.einsum('nbk, nke -> nbe', probs, embedding)

        # Back to (b, n, s, e).
        z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous()

        # Hard (argmax) assignments, reported for monitoring / tokenization.
        indices = torch.argmax(probs, dim=-1)           # (n, b*s)
        indices = indices.transpose(0, 1).contiguous()  # (b*s, n)

        if self.show_usage and self.training:
            cur_len = indices.size(0)  # loop-invariant: hoisted
            for k in range(self.num_codebooks):
                # Shift the rolling buffer left, append the newest indices.
                self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone())
                self.codebook_used[k, -cur_len:].copy_(indices[:, k])

        if self.training:
            # Only the entropy regularizer contributes; soft assignment needs
            # no commitment/codebook losses.
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e))
            quantizer_loss = entropy_loss
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        else:
            quantizer_loss = torch.tensor(0.0, device=z.device)
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)

        # (b, n, s, e) -> (b, h*w, e) -> (b, h, w, c) -> (b, c, h, w).
        z_q = z_q.view(batch_size, -1, self.e_dim)
        z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1])
        z_quantized = z_q.permute(0, 3, 1, 2).contiguous()

        # BUGFIX: `indices` is (b*s, n) with batch-major / segment-minor flat
        # order.  The previous `.view(batch_size, num_codebooks, segment_length)`
        # silently interleaved the codebook and segment axes whenever
        # num_codebooks > 1.  View as (b, s, n) first, then swap to (b, n, s)
        # so the layout matches z_quantized's codebook-major sequence order.
        min_encoding_indices = (
            indices.view(batch_size, segment_length, self.num_codebooks)
            .transpose(1, 2)
            .reshape(batch_size, original_shape[2], original_shape[3])
        )

        result_dict = dict(
            quantizer_loss=quantizer_loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices,
        )

        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        """Look up codebook entries (VectorQuantizer-compatible API).

        NOTE(review): only the first codebook (``self.embedding[0]``) is used
        here, so for num_codebooks > 1 this is a partial decode — confirm
        against the callers before relying on it with multiple codebooks.
        """
        if len(indices.shape) == 1:
            # Integer indices -> direct embedding lookup.
            z_quantized = self.embedding[0][indices]
        elif len(indices.shape) == 2:
            # Probability / one-hot rows -> weighted combination of entries.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0])
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
| |
|
| |
|
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regularizer for codebook logits.

    Encourages each sample's assignment distribution to be confident (low
    per-sample entropy) while keeping the batch-average usage of codes spread
    out (high average entropy).

    Args:
        affinity: logits of shape (..., codebook_size).
        loss_type: only "softmax" is supported.
        temperature: softmax temperature applied to the logits.

    Returns:
        Scalar tensor ``sample_entropy - avg_entropy`` (lower is better).

    Raises:
        ValueError: if ``loss_type`` is not "softmax".
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # BUGFIX: the previous in-place `/=` mutated the caller's tensor, since
    # reshape usually returns a view of `affinity`.  Divide out-of-place.
    flat_affinity = flat_affinity / temperature
    probs = F.softmax(flat_affinity, dim=-1)
    # The +1e-5 offset is kept from the original implementation; a uniform
    # shift of the logits leaves log-softmax unchanged up to numerical noise.
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-6))
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss