# Uploaded by APGASU · folder "scripts" · commit 7bef20f (verified)
import torch
from typing import List, Tuple
from torch.nn import functional as F
from torch import distributed as tdist, nn as nn
from .quantizer import VectorQuantizer
def get_entropy_loss(latent_embed, codebook_embed, inv_entropy_tau):
    """Entropy regularizer for soft codebook assignments.

    Each latent vector is softly assigned to every codebook entry with
    ``softmax(-inv_entropy_tau * ||z - e||^2)``. The loss is the mean entropy
    of those per-sample distributions minus the entropy of their batch
    average, so minimizing it favours confident per-sample assignments and
    spread-out overall codebook usage.

    Args:
        latent_embed: 2-D tensor ``(N, D)`` of latent vectors.
        codebook_embed: 2-D tensor ``(vocab_size, D)`` of codebook entries.
        inv_entropy_tau: Inverse temperature multiplying the negative squared
            distances before the softmax.

    Returns:
        0-dim float32 tensor: ``per_sample_entropy - codebook_entropy``.
    """
    # ||z||^2 (N, 1) + ||e||^2 (vocab_size,) broadcast to (N, vocab_size),
    # then add -2 * z @ e^T to obtain the squared Euclidean distances.
    norms = latent_embed.square().sum(dim=1, keepdim=True) + codebook_embed.square().sum(dim=1, keepdim=False)
    sq_dist = torch.addmm(norms, latent_embed, codebook_embed.T, beta=1, alpha=-2)

    scores = sq_dist.float().mul(inv_entropy_tau).neg()
    probs = scores.softmax(dim=-1)
    log_probs = scores.log_softmax(dim=-1)

    # Average entropy of each sample's assignment distribution.
    sample_entropy = (probs * log_probs).sum(dim=-1).neg().mean()

    # Entropy of the batch-averaged assignment distribution; the epsilon
    # keeps log() finite for unused entries.
    marginal = probs.mean(dim=0)
    usage_entropy = -(marginal * torch.log(marginal + 1e-7)).sum()

    return sample_entropy - usage_entropy
class NormalizedEmbedding(nn.Embedding):
    """``nn.Embedding`` whose lookups return L2-normalized rows.

    The stored ``weight`` parameter itself is left unnormalized;
    ``F.normalize(..., dim=1)`` is applied to it on every access, so each
    returned row has unit L2 norm.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        super().__init__(num_embeddings=num_embeddings, embedding_dim=embedding_dim)

    def get_norm_weight(self):
        """Return ``weight`` with every row rescaled to unit L2 norm."""
        return F.normalize(self.weight, dim=1)

    def forward(self, idx):
        """Look up ``idx`` in the row-normalized weight matrix."""
        return F.embedding(
            input=idx,
            weight=self.get_norm_weight(),
            padding_idx=self.padding_idx,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            sparse=self.sparse,
        )
class ResConv(nn.Conv2d):
    """Square ``nn.Conv2d`` blended with its own input.

    ``forward(h) = (1 - r) * h + r * conv(h)`` where ``r = abs(quant_resi)``.
    A negative ``quant_resi`` selects a 3x3 kernel, otherwise 1x1. Stride is 1
    and padding is ``kernel // 2``, so the spatial size is preserved.
    """

    def __init__(self, embed_dim, quant_resi):
        if quant_resi < 0:
            kernel = 3
        else:
            kernel = 1
        super().__init__(
            in_channels=embed_dim,
            out_channels=embed_dim,
            kernel_size=kernel,
            stride=1,
            padding=kernel // 2,
        )
        # Weight given to the convolution branch in the blend.
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        ratio = self.resi_ratio
        kept = h_BChw.mul(1 - ratio)
        mixed = super().forward(h_BChw).mul_(ratio)
        return kept + mixed
class VectorQuantizerMVQ(nn.Module):
    """Multi-codebook vector quantizer.

    The channel dimension of the input is split into ``num_codebooks`` equal
    chunks, and each chunk is quantized independently by its own
    ``VectorQuantizer``. The quantized chunks are concatenated back in order.
    Scalar losses are averaged over the codebooks, and per-codebook indices
    are stacked along dim 1.
    """

    # Losses reported by every sub-quantizer; ``forward`` averages each one
    # over the codebooks (key order is preserved in the returned dict).
    _LOSS_KEYS = ('quantizer_loss', 'commitment_loss', 'codebook_loss')

    def __init__(
        self,
        codebook_size,
        token_size,
        commitment_cost=0.25,
        use_l2_norm=False,
        clustering_vq=False,
        num_codebooks=16
    ):
        """
        Args:
            codebook_size: Total number of entries across all codebooks. Each
                sub-quantizer gets ``codebook_size // num_codebooks`` entries
                (integer division; any remainder is discarded).
            token_size: Total channel width. Each sub-quantizer handles
                ``token_size // num_codebooks`` channels.
            commitment_cost, use_l2_norm, clustering_vq: Forwarded unchanged
                to every ``VectorQuantizer``.
            num_codebooks: Number of independent sub-quantizers.
        """
        super().__init__()
        self.num_codebooks = num_codebooks
        self.codebooks = nn.ModuleList([
            VectorQuantizer(
                codebook_size=codebook_size // num_codebooks,
                token_size=token_size // num_codebooks,
                commitment_cost=commitment_cost,
                use_l2_norm=use_l2_norm,
                clustering_vq=clustering_vq,
            )
            for _ in range(num_codebooks)
        ])

    def _split_channels(self, features, dim):
        """Split ``features`` into ``num_codebooks`` equal chunks along ``dim``.

        Raises:
            ValueError: if the size of ``dim`` is not divisible by
                ``num_codebooks``. Without this check, ``Tensor.split`` yields
                an extra remainder chunk that would be silently dropped.
        """
        num_channels = features.shape[dim]
        if num_channels % self.num_codebooks != 0:
            raise ValueError(
                f"features dim {dim} has size {num_channels}, which is not "
                f"divisible by num_codebooks={self.num_codebooks}"
            )
        return features.split(num_channels // self.num_codebooks, dim=dim)

    def init_vocab(self, eini: float):
        """Call ``init_vocab(eini)`` on every sub-quantizer."""
        for codebook in self.codebooks:
            codebook.init_vocab(eini)

    def f_to_idx(self, features):
        """Return per-codebook indices for ``features``, stacked along dim 1.

        ``features`` is split along its LAST dimension (``forward`` splits
        along dim 1 instead), and each chunk goes to the matching
        sub-quantizer's ``f_to_idx``.

        Raises:
            ValueError: if the last dimension is not divisible by
                ``num_codebooks``.
        """
        chunks = self._split_channels(features, dim=-1)
        return torch.stack(
            [codebook.f_to_idx(chunk) for codebook, chunk in zip(self.codebooks, chunks)],
            dim=1,
        )

    def idx_to_f(self, indices):
        """Map stacked per-codebook indices back to concatenated embeddings.

        ``indices[:, i]`` is flattened from dim 1 onward, so ``indices`` must
        have rank >= 3, e.g. ``(B, num_codebooks, H, W)``. Each flattened slice
        is looked up in the i-th sub-quantizer's ``codebook`` module, and the
        results are concatenated along the last dimension.

        Raises:
            ValueError: if ``indices.shape[1] != num_codebooks``.
        """
        if indices.shape[1] != self.num_codebooks:
            raise ValueError(
                f"expected indices.shape[1] == {self.num_codebooks}, "
                f"got {indices.shape[1]}"
            )
        latent_features = [
            codebook.codebook(indices[:, i].flatten(start_dim=1))
            for i, codebook in enumerate(self.codebooks)
        ]
        return torch.cat(latent_features, dim=-1)

    def get_codebook_entry(self, indices):
        """Get codebook entries for multi-codebook indices.

        Args:
            indices: Tensor of shape (N, num_codebooks) or (B, num_codebooks, H, W).

        Returns:
            z_quantized: ``(N, C)`` for 2-D input, ``(B, C, H, W)`` for 4-D input,
            where C is the concatenated width of all sub-codebook entries.

        Raises:
            NotImplementedError: for any other rank.
        """
        if len(indices.shape) == 2:
            return torch.cat(
                [
                    codebook.get_codebook_entry(indices[:, i])
                    for i, codebook in enumerate(self.codebooks)
                ],
                dim=-1,
            )
        elif len(indices.shape) == 4:
            batch_size, _, height, width = indices.shape
            latent_features = []
            for i, codebook in enumerate(self.codebooks):
                # Look up all B*H*W indices of this codebook at once, then
                # restore the spatial layout as (B, H, W, per-codebook width).
                latent_feature = codebook.get_codebook_entry(indices[:, i].flatten())
                latent_features.append(latent_feature.view(batch_size, height, width, -1))
            latent_features = torch.cat(latent_features, dim=-1)  # (B, H, W, C)
            return latent_features.permute(0, 3, 1, 2).contiguous()  # (B, C, H, W)
        else:
            raise NotImplementedError(f"Unsupported indices shape: {indices.shape}")

    def forward(self, features):
        """Quantize ``features`` chunk-wise along dim 1.

        Returns:
            ``(z_quantized, result_dict)``. ``z_quantized`` has the dtype of
            ``features`` and concatenates the per-codebook outputs along dim 1.
            ``result_dict`` holds the codebook-averaged ``quantizer_loss``,
            ``commitment_loss`` and ``codebook_loss``, plus
            ``min_encoding_indices``: the per-codebook indices stacked along
            dim 1. If each sub-quantizer returns (B, H, W) indices (assumed,
            not visible here), this is (B, num_codebooks, H, W).

        Raises:
            ValueError: if ``features.shape[1]`` is not divisible by
                ``num_codebooks``.
        """
        chunks = self._split_channels(features, dim=1)
        latent_features = []
        all_result_dicts = []
        for codebook, chunk in zip(self.codebooks, chunks):
            # Each chunk is quantized in float32 (``.float()``); the result is
            # cast back to the caller's dtype.
            latent_feature, result_dict = codebook(chunk.float())
            latent_features.append(latent_feature.to(features.dtype))
            all_result_dicts.append(result_dict)

        z_quantized = torch.cat(latent_features, dim=1)

        result_dict = {
            key: sum(rd[key] for rd in all_result_dicts) / self.num_codebooks
            for key in self._LOSS_KEYS
        }
        result_dict['min_encoding_indices'] = torch.stack(
            [rd['min_encoding_indices'] for rd in all_result_dicts], dim=1
        )
        return z_quantized, result_dict