import torch
from typing import List, Tuple
from torch.nn import functional as F
from torch import distributed as tdist, nn as nn
from .quantizer import VectorQuantizer
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    """Entropy regularizer over soft codebook assignments.

    Pushes each sample's assignment distribution to be confident (low
    per-sample entropy) while keeping the batch-averaged usage spread
    across the codebook (high codebook entropy).

    Args:
        latent_embed: (N, D) latent vectors.
        codebook_embed: (V, D) codebook vectors.
        inv_entropy_tau: inverse temperature scaling the negative squared
            distances before the softmax.

    Returns:
        Scalar tensor: per_sample_entropy - codebook_entropy.
    """
    # Squared Euclidean distances via ||a||^2 + ||b||^2 - 2*a@b.T -> (N, V).
    sq_dist = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
    sq_dist.addmm_(latent_embed, codebook_embed.T, alpha=-2, beta=1)
    logits = -sq_dist.float().mul_(inv_entropy_tau)

    # Soft assignment distribution per sample, (N, V).
    prob = logits.softmax(dim=-1)
    log_prob = logits.log_softmax(dim=-1)
    per_sample_entropy = torch.mean((-prob * log_prob).sum(dim=-1))

    # Entropy of the batch-averaged assignment distribution, (V,).
    avg_prob = prob.mean(dim=0)
    codebook_entropy = (-avg_prob * torch.log(avg_prob + 1e-7)).sum()

    return per_sample_entropy - codebook_entropy
class NormalizedEmbedding(nn.Embedding):
    """Embedding table whose rows are L2-normalized at every lookup."""

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    def forward(self, idx):
        # Normalize the whole table first, then gather the requested rows,
        # forwarding the standard nn.Embedding options unchanged.
        unit_weight = F.normalize(self.weight, dim=1)
        return F.embedding(
            idx, unit_weight, self.padding_idx, self.max_norm,
            self.norm_type, self.scale_grad_by_freq, self.sparse
        )

    def get_norm_weight(self):
        """Return the row-wise L2-normalized weight matrix."""
        return F.normalize(self.weight, dim=1)
class ResConv(nn.Conv2d):
    """Residual convolution mixer: out = (1 - r) * x + r * conv(x).

    A negative ``quant_resi`` selects a 3x3 kernel, a non-negative one a
    1x1 kernel; the mixing ratio r is |quant_resi| either way.
    """

    def __init__(self, embed_dim, quant_resi):
        kernel = 1 if quant_resi >= 0 else 3
        super().__init__(in_channels=embed_dim, out_channels=embed_dim,
                         kernel_size=kernel, stride=1, padding=kernel // 2)
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        conv_out = super().forward(h_BChw)
        # In-place scale on the fresh conv output is safe; the input is not mutated.
        return h_BChw.mul(1 - self.resi_ratio) + conv_out.mul_(self.resi_ratio)
class VectorQuantizerMVQ(nn.Module):
    """Multi-codebook vector quantizer.

    The channel dimension is split into ``num_codebooks`` equal chunks and
    each chunk is quantized by its own independent ``VectorQuantizer``.
    Vocabulary size and token size are divided evenly across the books.
    """

    def __init__(
        self,
        codebook_size,
        token_size,
        commitment_cost=0.25,
        use_l2_norm=False,
        clustering_vq=False,
        num_codebooks=16
    ):
        super().__init__()
        self.num_codebooks = num_codebooks
        # One sub-quantizer per chunk; each gets an equal share of the
        # vocabulary and of the token (channel) dimension.
        self.codebooks = nn.ModuleList(
            VectorQuantizer(
                codebook_size=codebook_size // num_codebooks,
                token_size=token_size // num_codebooks,
                commitment_cost=commitment_cost,
                use_l2_norm=use_l2_norm,
                clustering_vq=clustering_vq,
            )
            for _ in range(num_codebooks)
        )

    def init_vocab(self, eini: float):
        """Delegate vocabulary initialization to every sub-codebook."""
        for book in self.codebooks:
            book.init_vocab(eini)

    def f_to_idx(self, features):
        """Encode features to indices; returns one index tensor per book,
        stacked on dim 1.

        NOTE(review): splits on the LAST dim here, while forward() splits on
        dim 1 — presumably callers pass channel-last features to this method;
        confirm against call sites.
        """
        width = features.shape[-1] // self.num_codebooks
        chunks = features.split(width, dim=-1)
        per_book = [book.f_to_idx(part) for book, part in zip(self.codebooks, chunks)]
        return torch.stack(per_book, dim=1)

    def idx_to_f(self, indices):
        """Decode indices (B, num_codebooks, ...) back to latent features,
        concatenated along the last dimension."""
        assert indices.shape[1] == self.num_codebooks
        parts = [
            book.codebook(indices[:, i].flatten(start_dim=1))
            for i, book in enumerate(self.codebooks)
        ]
        return torch.cat(parts, dim=-1)

    def get_codebook_entry(self, indices):
        """Get codebook entries for multi-codebook indices.
        Args:
            indices: Tensor of shape (N, num_codebooks) or (N, num_codebooks, H, W)
        Returns:
            z_quantized: Quantized features
        """
        if indices.ndim == 2:
            # (N, num_codebooks): one lookup per book, concat on channels.
            parts = [
                book.get_codebook_entry(indices[:, i])
                for i, book in enumerate(self.codebooks)
            ]
            return torch.cat(parts, dim=-1)
        if indices.ndim == 4:
            # (B, num_codebooks, H, W): flatten spatial dims for lookup,
            # then restore (B, H, W, sub_dim) per book.
            batch_size, _, height, width = indices.shape
            parts = []
            for i, book in enumerate(self.codebooks):
                flat = indices[:, i].flatten()
                entry = book.get_codebook_entry(flat).view(batch_size, height, width, -1)
                parts.append(entry)
            stacked = torch.cat(parts, dim=-1)  # (B, H, W, C)
            # Rearrange channel-last -> channel-first.
            return stacked.permute(0, 3, 1, 2).contiguous()
        raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")

    def forward(self, features):
        """Quantize (B, C, ...) features chunk-wise; returns the concatenated
        quantized tensor plus a result dict with book-averaged losses and the
        stacked per-book indices."""
        width = features.shape[1] // self.num_codebooks
        chunks = features.split(width, dim=1)

        quantized_parts, book_results = [], []
        for book, part in zip(self.codebooks, chunks):
            # Each sub-quantizer runs in float32; cast back to the input dtype.
            quantized, book_result = book(part.float())
            quantized_parts.append(quantized.to(features.dtype))
            book_results.append(book_result)

        z_quantized = torch.cat(quantized_parts, dim=1)

        def _mean_over_books(key):
            # Average a scalar loss entry across all sub-codebooks.
            return sum(rd[key] for rd in book_results) / self.num_codebooks

        # Each book reports indices shaped (B, H, W); stack -> (B, num_codebooks, H, W).
        all_indices = torch.stack(
            [rd['min_encoding_indices'] for rd in book_results], dim=1
        )
        result_dict = dict(
            quantizer_loss=_mean_over_books('quantizer_loss'),
            commitment_loss=_mean_over_books('commitment_loss'),
            codebook_loss=_mean_over_books('codebook_loss'),
            min_encoding_indices=all_indices
        )
        return z_quantized, result_dict