| | import torch |
| | from typing import List, Tuple |
| | from torch.nn import functional as F |
| | from torch import distributed as tdist, nn as nn |
| |
|
| | from .quantizer import VectorQuantizer |
| |
|
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    """Entropy regularizer over soft code assignments.

    Encourages each latent vector to commit to a single code (low
    per-sample entropy) while spreading usage over the whole codebook
    (high batch-average entropy).

    Args:
        latent_embed: (N, D) latent vectors.
        codebook_embed: (K, D) codebook vectors.
        inv_entropy_tau: inverse temperature applied to the negative
            squared-distance logits.

    Returns:
        Scalar tensor: mean per-sample entropy minus codebook-usage entropy.
    """
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2*a.b; the
    # in-place addmm_ fuses the -2*a@b.T term into the broadcasted sums.
    sq_dist = latent_embed.square().sum(dim=1, keepdim=True) \
        + codebook_embed.square().sum(dim=1, keepdim=False)
    sq_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1)
    logits = -sq_dist.float().mul_(inv_entropy_tau)

    assign_prob = logits.softmax(dim=-1)
    assign_log_prob = logits.log_softmax(dim=-1)
    sample_entropy = (-assign_prob * assign_log_prob).sum(dim=-1).mean()

    # Entropy of the batch-averaged assignment distribution; the epsilon
    # guards log(0) for never-selected codes.
    usage = assign_prob.mean(dim=0)
    usage_entropy = (-usage * torch.log(usage + 1e-7)).sum()

    return sample_entropy - usage_entropy
| |
|
| |
|
class NormalizedEmbedding(nn.Embedding):
    """Embedding table whose rows are L2-normalized on every lookup.

    The stored weight stays unconstrained; unit-norm rows are produced
    on the fly so gradients flow through the normalization.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    def forward(self, idx):
        # Normalize the full table first, then gather the requested rows
        # with the standard functional embedding lookup.
        unit_weight = F.normalize(self.weight, dim=1)
        return F.embedding(
            idx, unit_weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse,
        )

    def get_norm_weight(self):
        """Return the whole table with each row scaled to unit L2 norm."""
        return F.normalize(self.weight, dim=1)
| |
|
| |
|
class ResConv(nn.Conv2d):
    """Convolution blended with its input: out = (1 - r)*x + r*conv(x).

    The sign of ``quant_resi`` selects the kernel size (negative -> 3x3,
    otherwise 1x1); its absolute value is the blend ratio ``r``.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 3 if quant_resi < 0 else 1
        # Same-padding so the spatial size is preserved for either kernel.
        super().__init__(
            in_channels=embed_dim, out_channels=embed_dim,
            kernel_size=kernel, stride=1, padding=kernel // 2,
        )
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        ratio = self.resi_ratio
        conv_out = super().forward(h_BChw)
        # mul_ is safe in-place: conv_out is a fresh tensor.
        return h_BChw.mul(1 - ratio) + conv_out.mul_(ratio)
| |
|
class VectorQuantizerMVQ(nn.Module):
    """Multi-codebook (product) vector quantizer.

    The feature channel dimension is divided into ``num_codebooks`` equal
    slices, each quantized independently by its own ``VectorQuantizer``;
    per-codebook outputs are concatenated and losses averaged.
    """

    def __init__(
        self,
        codebook_size,
        token_size,
        commitment_cost=0.25,
        use_l2_norm=False,
        clustering_vq=False,
        num_codebooks=16,
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        # Each sub-quantizer owns an equal share of codes and channels.
        self.codebooks = nn.ModuleList(
            VectorQuantizer(
                codebook_size=codebook_size // num_codebooks,
                token_size=token_size // num_codebooks,
                commitment_cost=commitment_cost,
                use_l2_norm=use_l2_norm,
                clustering_vq=clustering_vq,
            )
            for _ in range(num_codebooks)
        )

    def init_vocab(self, eini: float):
        """Forward vocabulary initialization to every sub-codebook."""
        for cb in self.codebooks:
            cb.init_vocab(eini)

    def f_to_idx(self, features):
        """Quantize features to indices, one slice per codebook.

        Splits ``features`` along the last dim and stacks the resulting
        per-codebook index tensors along dim 1.
        """
        step = features.shape[-1] // self.num_codebooks
        chunks = features.split(step, dim=-1)
        per_book = [cb.f_to_idx(chunk) for cb, chunk in zip(self.codebooks, chunks)]
        return torch.stack(per_book, dim=1)

    def idx_to_f(self, indices):
        """Inverse of :meth:`f_to_idx`: embed each codebook's indices and
        concatenate the slices along the last dim."""
        assert indices.shape[1] == self.num_codebooks
        parts = [
            cb.codebook(indices[:, i].flatten(start_dim=1))
            for i, cb in enumerate(self.codebooks)
        ]
        return torch.cat(parts, dim=-1)

    def get_codebook_entry(self, indices):
        """Get codebook entries for multi-codebook indices.

        Args:
            indices: Tensor of shape (N, num_codebooks) or
                (N, num_codebooks, H, W).

        Returns:
            Quantized features: channel-concatenated for the 2-D case,
            (N, C, H, W) for the 4-D case.
        """
        if len(indices.shape) == 2:
            parts = [
                cb.get_codebook_entry(indices[:, i])
                for i, cb in enumerate(self.codebooks)
            ]
            return torch.cat(parts, dim=-1)
        elif len(indices.shape) == 4:
            batch_size, _, height, width = indices.shape
            parts = []
            for i, cb in enumerate(self.codebooks):
                flat = cb.get_codebook_entry(indices[:, i].flatten())
                parts.append(flat.view(batch_size, height, width, -1))
            stacked = torch.cat(parts, dim=-1)
            # (N, H, W, C) -> (N, C, H, W)
            return stacked.permute(0, 3, 1, 2).contiguous()
        else:
            raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")

    def forward(self, features):
        """Quantize ``features`` (channels on dim 1) with all sub-codebooks.

        Returns:
            z_quantized: channel-concatenated quantized features, cast back
                to the input dtype.
            result_dict: averaged quantizer/commitment/codebook losses and
                the per-codebook indices stacked along dim 1.
        """
        step = features.shape[1] // self.num_codebooks
        chunks = features.split(step, dim=1)

        quantized, per_book_dicts = [], []
        for cb, chunk in zip(self.codebooks, chunks):
            # Sub-quantizers run in fp32; cast outputs back afterwards.
            z, rd = cb(chunk.float())
            quantized.append(z.to(features.dtype))
            per_book_dicts.append(rd)

        z_quantized = torch.cat(quantized, dim=1)

        n = self.num_codebooks
        result_dict = dict(
            quantizer_loss=sum(d['quantizer_loss'] for d in per_book_dicts) / n,
            commitment_loss=sum(d['commitment_loss'] for d in per_book_dicts) / n,
            codebook_loss=sum(d['codebook_loss'] for d in per_book_dicts) / n,
            min_encoding_indices=torch.stack(
                [d['min_encoding_indices'] for d in per_book_dicts], dim=1
            ),
        )
        return z_quantized, result_dict