import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Mapping, Text, Tuple
from einops import rearrange
from torch.cuda.amp import autocast
class SoftVectorQuantizer(torch.nn.Module):
    """Soft (softmax-weighted) vector quantizer with optional multiple codebooks.

    Each input token is quantized as a probability-weighted sum of codebook
    entries (softmax over code similarities at temperature ``tau``) instead of
    a hard nearest-neighbour pick, which keeps the operation differentiable.
    The constructor argument names and the ``(z_quantized, result_dict)``
    return contract mirror the hard ``VectorQuantizer`` API.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 clustering_vq: bool = False,
                 entropy_loss_ratio: float = 0.01,
                 tau: float = 0.07,
                 num_codebooks: int = 1,
                 show_usage: bool = False
                 ) -> None:
        """Build the quantizer.

        Args:
            codebook_size: number of entries per codebook.
            token_size: dimensionality of each code / token.
            commitment_cost: stored for API compatibility; not read by the
                soft-quantization path in ``forward``.
            use_l2_norm: L2-normalize codes and inputs before matching.
            clustering_vq: stored for API compatibility; not read here.
            entropy_loss_ratio: weight of the entropy regularizer used as
                ``quantizer_loss`` during training.
            tau: softmax temperature for the soft assignment.
            num_codebooks: the token sequence is split into this many
                contiguous segments, each matched against its own codebook.
            show_usage: track recently used code indices per codebook.
        """
        super().__init__()
        # Map new parameter names to internal names for compatibility
        self.codebook_size = codebook_size
        self.token_size = token_size
        self.commitment_cost = commitment_cost
        self.use_l2_norm = use_l2_norm
        self.clustering_vq = clustering_vq
        # Keep soft quantization specific parameters
        self.num_codebooks = num_codebooks
        self.n_e = codebook_size
        self.e_dim = token_size
        self.entropy_loss_ratio = entropy_loss_ratio
        self.l2_norm = use_l2_norm
        self.show_usage = show_usage
        self.tau = tau
        # Single embedding layer for all codebooks:
        # (num_codebooks, codebook_size, token_size).
        self.embedding = nn.Parameter(torch.randn(num_codebooks, codebook_size, token_size))
        self.embedding.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
        if self.l2_norm:
            self.embedding.data = F.normalize(self.embedding.data, p=2, dim=-1)
        if self.show_usage:
            # Rolling window of the 65536 most recently emitted code indices
            # per codebook, used for the usage statistic in forward().
            self.register_buffer("codebook_used", torch.zeros(num_codebooks, 65536))

    # Ensure quantization is performed using f32
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Soft-quantize a feature map.

        Args:
            z: features in (B, C, H, W) layout (fixed by the rearrange below).
               C must equal ``token_size`` and H*W must be divisible by
               ``num_codebooks``.

        Returns:
            z_quantized: soft-quantized features, same (B, C, H, W) shape.
            result_dict: ``quantizer_loss`` / ``commitment_loss`` /
                ``codebook_loss`` (non-zero only in training) and
                ``min_encoding_indices`` of shape (B, H, W).
        """
        z = z.float()
        original_shape = z.shape
        # Handle input reshaping to match VectorQuantizer format
        z = rearrange(z, 'b c h w -> b h w c').contiguous()
        z = z.view(z.size(0), -1, z.size(-1))
        batch_size, seq_length, _ = z.shape
        # Ensure sequence length is divisible by number of codebooks
        assert seq_length % self.num_codebooks == 0, \
            f"Sequence length ({seq_length}) must be divisible by number of codebooks ({self.num_codebooks})"
        segment_length = seq_length // self.num_codebooks
        z_segments = z.view(batch_size, self.num_codebooks, segment_length, self.e_dim)
        # Apply L2 norm if needed
        embedding = F.normalize(self.embedding, p=2, dim=-1) if self.l2_norm else self.embedding
        if self.l2_norm:
            z_segments = F.normalize(z_segments, p=2, dim=-1)
        # (num_codebooks, batch*segment, e_dim): each codebook only sees the
        # tokens of its own segment.
        z_flat = z_segments.permute(1, 0, 2, 3).contiguous().view(self.num_codebooks, -1, self.e_dim)
        # Embedding is detached here, so codebook gradients flow only through
        # the soft combination below, not through the assignment logits.
        logits = torch.einsum('nbe, nke -> nbk', z_flat, embedding.detach())
        # Calculate probabilities (soft quantization)
        probs = F.softmax(logits / self.tau, dim=-1)
        # Soft quantize
        z_q = torch.einsum('nbk, nke -> nbe', probs, embedding)
        # Reshape back
        z_q = z_q.view(self.num_codebooks, batch_size, segment_length, self.e_dim).permute(1, 0, 2, 3).contiguous()
        # Calculate cosine similarity
        # NOTE(review): zq_z_cos is computed but never returned or logged
        # in this block — confirm whether it is still needed.
        with torch.no_grad():
            zq_z_cos = F.cosine_similarity(
                z_segments.view(-1, self.e_dim),
                z_q.view(-1, self.e_dim),
                dim=-1
            ).mean()
        # Get indices for usage tracking
        indices = torch.argmax(probs, dim=-1)  # (num_codebooks, batch_size * segment_length)
        indices = indices.transpose(0, 1).contiguous()  # (batch_size * segment_length, num_codebooks)
        # Track codebook usage
        if self.show_usage and self.training:
            for k in range(self.num_codebooks):
                cur_len = indices.size(0)
                # Shift the rolling buffer left and append the newest indices.
                self.codebook_used[k, :-cur_len].copy_(self.codebook_used[k, cur_len:].clone())
                self.codebook_used[k, -cur_len:].copy_(indices[:, k])
        # Calculate losses if training
        if self.training:
            # Soft quantization doesn't have traditional commitment/codebook loss
            # Map entropy loss to quantizer_loss for compatibility
            entropy_loss = self.entropy_loss_ratio * compute_entropy_loss(logits.view(-1, self.n_e))
            quantizer_loss = entropy_loss
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        else:
            quantizer_loss = torch.tensor(0.0, device=z.device)
            commitment_loss = torch.tensor(0.0, device=z.device)
            codebook_loss = torch.tensor(0.0, device=z.device)
        # Calculate codebook usage
        # NOTE(review): codebook_usage, avg_probs and max_probs below are
        # computed but not included in result_dict — confirm whether they
        # should be returned or dropped.
        codebook_usage = torch.tensor([
            len(torch.unique(self.codebook_used[k])) / self.n_e
            for k in range(self.num_codebooks)
        ]).mean() if self.show_usage else 0
        z_q = z_q.view(batch_size, -1, self.e_dim)
        # Reshape back to original input shape to match VectorQuantizer
        z_q = z_q.view(batch_size, original_shape[2], original_shape[3], original_shape[1])
        z_quantized = rearrange(z_q, 'b h w c -> b c h w').contiguous()
        # Calculate average probabilities
        avg_probs = torch.mean(torch.mean(probs, dim=-1))
        max_probs = torch.mean(torch.max(probs, dim=-1)[0])
        # Return format matching VectorQuantizer
        # NOTE(review): the first .view below reads `indices` as if its memory
        # layout were (batch, codebook, segment), but after the transpose
        # above it is (batch*segment, codebook)-ordered; the two agree only
        # when num_codebooks == 1 — verify the multi-codebook case.
        result_dict = dict(
            quantizer_loss=quantizer_loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=indices.view(batch_size, self.num_codebooks, segment_length).view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )
        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        """Added for compatibility with VectorQuantizer API.

        1-D ``indices`` are treated as hard code ids (embedding lookup);
        2-D ``indices`` are treated as per-sample weights over the codebook
        (soft lookup). Both branches use only the first codebook.
        """
        if len(indices.shape) == 1:
            # For single codebook case
            z_quantized = self.embedding[0][indices]
        elif len(indices.shape) == 2:
            # (batch, codebook_size) weights x (codebook_size, token_size) codes.
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding[0])
        else:
            raise NotImplementedError
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
def compute_entropy_loss(affinity, loss_type="softmax", temperature=0.01):
    """Entropy regularizer for (soft) vector quantization.

    Encourages each row's assignment distribution to be confident (low
    per-sample entropy) while keeping the batch-averaged distribution spread
    over the whole codebook (high average entropy).

    Args:
        affinity: tensor of unnormalized scores whose last dimension indexes
            the codebook entries; leading dimensions are flattened.
        loss_type: only "softmax" is supported.
        temperature: scaling applied to the scores before the softmax.

    Returns:
        Scalar tensor ``sample_entropy - avg_entropy`` (lower is better).

    Raises:
        ValueError: if ``loss_type`` is not "softmax".
    """
    flat_affinity = affinity.reshape(-1, affinity.shape[-1])
    # Fix: divide out-of-place. The original `flat_affinity /= temperature`
    # mutated the caller's tensor in place (reshape can return a view), which
    # silently rescaled `logits` in forward() and can trip autograd's
    # in-place-modification checks.
    flat_affinity = flat_affinity / temperature
    probs = F.softmax(flat_affinity, dim=-1)
    # The 1e-5 shift is applied uniformly, matching the reference
    # implementation of this loss.
    log_probs = F.log_softmax(flat_affinity + 1e-5, dim=-1)
    if loss_type == "softmax":
        target_probs = probs
    else:
        raise ValueError("Entropy loss {} not supported".format(loss_type))
    avg_probs = torch.mean(target_probs, dim=0)
    # High entropy of the batch-averaged distribution => diverse code usage.
    avg_entropy = - torch.sum(avg_probs * torch.log(avg_probs + 1e-6))
    # Low per-sample entropy => confident (near one-hot) assignments.
    sample_entropy = - torch.mean(torch.sum(target_probs * log_probs, dim=-1))
    loss = sample_entropy - avg_entropy
    return loss