| | """ |
| | Lookup Free Quantization |
| | Proposed in https://arxiv.org/abs/2310.05737 |
| | |
| | In the simplest setup, each dimension is quantized into {-1, 1}. |
| | An entropy penalty is used to encourage utilization. |
| | |
| | Refer to |
| | https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/lookup_free_quantization.py |
| | https://github.com/theAdamColton/ijepa-enhanced/blob/7edef5f7288ae8f537f0db8a10044a2a487f70c9/ijepa_enhanced/lfq.py |
| | """ |

from math import log2, ceil
from collections import namedtuple

import torch
from torch import nn, einsum
import torch.nn.functional as F
from torch.nn import Module

from einops import rearrange, reduce, pack, unpack


LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs'])


def exists(v):
    return v is not None

def default(*args):
    for arg in args:
        if exists(arg):
            return arg() if callable(arg) else arg
    return None

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]


def entropy(prob):
    return (-prob * torch.log(prob + 1e-5)).sum(dim=-1)
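
# For intuition (an added sketch, not part of the original file): entropy of a
# uniform distribution over K outcomes is log(K), e.g. entropy(torch.full((4,), 0.25))
# is approximately log(4) ~= 1.386, while a one-hot distribution gives approximately 0.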


def mult_along_first_dims(x, y):
    """
    returns x * y elementwise, with y broadcast over the trailing
    dimensions of x (x and y share their leading dimensions)
    """
    ndim_to_expand = x.ndim - y.ndim
    for _ in range(ndim_to_expand):
        y = y.unsqueeze(-1)
    return x * y

def masked_mean(x, m):
    """
    takes the mean of the elements of x selected by the mask m
    the mean is taken over the shared leading dims of m
    equivalent to: x[m].mean(dim=0)

    The benefit of using masked_mean rather than
    tensor indexing is that masked_mean is much faster
    under torch.compile on batches.

    The drawback is larger floating point error.
    """
    x = mult_along_first_dims(x, m)
    x = x / m.sum()
    return x.sum(tuple(range(m.ndim)))
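
# Shape sketch (illustrative, not from the original file): with x of shape
# (batch, seq, codes) and a boolean keep-mask m of shape (batch, seq),
# masked_mean(x, m) averages over the masked-in (batch, seq) positions and
# returns a tensor of shape (codes,), matching x[m].mean(dim=0) up to
# floating point error.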


def entropy_loss(
    logits,
    mask=None,
    temperature=0.01,
    sample_minimization_weight=1.0,
    batch_maximization_weight=1.0,
    eps=1e-5,
):
    """
    Entropy loss of unnormalized logits

    logits: Affinities are over the last dimension

    https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279
    LANGUAGE MODEL BEATS DIFFUSION -- TOKENIZER IS KEY TO VISUAL GENERATION (2024)
    """
    probs = F.softmax(logits / temperature, -1)
    log_probs = F.log_softmax(logits / temperature + eps, -1)

    if mask is not None:
        avg_probs = masked_mean(probs, mask)
    else:
        avg_probs = reduce(probs, "... D -> D", "mean")

    avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps))

    sample_entropy = -torch.sum(probs * log_probs, -1)
    if mask is not None:
        sample_entropy = masked_mean(sample_entropy, mask).mean()
    else:
        sample_entropy = torch.mean(sample_entropy)

    loss = (sample_minimization_weight * sample_entropy) - (
        batch_maximization_weight * avg_entropy
    )

    return sample_entropy, avg_entropy, loss
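
# The two terms pull against each other (a note added for clarity, not from the
# original sources): minimizing sample_entropy makes each token commit confidently
# to a single code, while maximizing avg_entropy spreads the batch-averaged code
# distribution over the whole codebook, discouraging collapse. A usage sketch with
# an assumed logits tensor of shape (batch, seq, codebook_size):
#   sample_ent, codebook_ent, aux = entropy_loss(logits)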


class GFQ(Module):
    def __init__(
        self,
        *,
        dim,
        num_codebooks = 1,
        sample_minimization_weight=1.0,
        batch_maximization_weight=1.0,
    ):
        super().__init__()
        self.token_factorization = num_codebooks > 1
        self.codebook_dim = dim // num_codebooks
        self.codebook_size = 2 ** self.codebook_dim
        self.dim = dim
        self.num_codebooks = num_codebooks
        self.vocab_size = num_codebooks * self.codebook_size

        self.sample_minimization_weight = sample_minimization_weight
        self.batch_maximization_weight = batch_maximization_weight

        # powers of two used to turn the sign pattern of each codebook chunk into an integer index
        self.factorized_bits = [self.codebook_dim] * num_codebooks
        for i, factorized_bit in enumerate(self.factorized_bits):
            self.register_buffer(f"mask_{i}", 2 ** torch.arange(factorized_bit), persistent=False)

        # single mask for the non-factorized path (referenced as self.mask below)
        self.register_buffer("mask", 2 ** torch.arange(self.codebook_dim), persistent=False)

        # enumerate every code once so the full codebook is available for the entropy loss
        all_codes = torch.arange(self.codebook_size)
        bits = self.indices_to_bits(all_codes)
        codebook = bits * 2.0 - 1.0
        self.register_buffer('codebook', codebook, persistent = False)
        self.register_buffer('zero', torch.tensor(0.), persistent = False)

    @property
    def dtype(self):
        return self.codebook.dtype

    def indices_to_bits(self, x):
        """
        x: long tensor of indices

        returns bits, least significant bit first
        """
        mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long)
        x = (x.unsqueeze(-1) & mask) != 0
        return x
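
    # Worked example (illustrative): with codebook_dim = 4, the index 5 = 0b0101
    # maps to the bit pattern [True, False, True, False] (bit i is the coefficient
    # of 2 ** i), which decode() then turns into the code [+1, -1, +1, -1].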

    def get_codebook_entry(self, x, bhwc, index_order):
        """
        maps flat codebook indices x of shape (b, h * w) back to a grid of
        +/-1 codes of shape (b, c, h, w); index_order selects which factorized
        codebook's bit mask to use when token factorization is enabled
        """
        mask = getattr(self, f"mask_{index_order}") if self.token_factorization else self.mask
        mask = mask.to(device=x.device, dtype=torch.long)

        x = (x.unsqueeze(-1) & mask) != 0
        x = x * 2.0 - 1.0
        b, h, w, c = bhwc
        x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c)
        x = rearrange(x, "b h w c -> b c h w")
        return x
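
    # Shape sketch (illustrative): for a 16 x 16 grid with codebook_dim = 9,
    # get_codebook_entry(indices, (b, 16, 16, 9), index_order=0) takes indices of
    # shape (b, 256) and returns +/-1 codes of shape (b, 9, 16, 16).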

    def bits_to_indices(self, bits):
        """
        bits: bool tensor where the last dimension is the bit dimension,
        least significant bit first

        returns indices, which are long integers from 0 to self.codebook_size - 1
        """
        assert bits.shape[-1] == self.codebook_dim
        indices = 2 ** torch.arange(
            0,
            self.codebook_dim,
            1,
            dtype=torch.long,
            device=bits.device,
        )
        return (bits * indices).sum(-1)
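
    # Note (added for clarity): bits_to_indices is the inverse of indices_to_bits,
    # so bits_to_indices(indices_to_bits(idx)) == idx for any idx in
    # [0, codebook_size), since both use the same least-significant-bit-first order.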

    def decode(self, x):
        """
        x: ... NH
        where NH is the number of codebook heads
        A long tensor of codebook indices, containing values from
        0 to self.codebook_size - 1
        """
        x = self.indices_to_bits(x)
        x = x.to(self.dtype)
        x = x * 2 - 1
        x = rearrange(x, "... NC Z -> ... (NC Z)")
        return x
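
    # Shape sketch (illustrative): indices of shape (b, n, NH) decode to codes of
    # shape (b, n, NH * codebook_dim), each entry being -1 or +1.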

    def forward(
        self,
        x,
        inv_temperature = 100.,
        return_loss_breakdown = False,
        mask = None,
        return_loss = True,
    ):
        """
        einstein notation
        b - batch
        n - sequence (or flattened spatial dimensions)
        d - feature dimension, which is also log2(codebook size)
        c - number of codebooks
        """
        x = rearrange(x, 'b d ... -> b ... d')
        x, ps = pack_one(x, 'b * d')
        x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks)

        # quantize by sign: every dimension becomes -1 or +1
        codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype)
        quantized = torch.where(x > 0, codebook_value, -codebook_value)

        # convert the sign pattern of each codebook chunk into integer indices
        if self.token_factorization:
            quantized = rearrange(quantized, 'b n c d -> b n 1 (c d)')
            indices_list = []
            begin = 0
            end = 0
            for i, factorized_bit in enumerate(self.factorized_bits):
                end += factorized_bit
                # use a separate name so the `mask` argument (needed for the
                # commitment loss below) is not overwritten
                bit_mask = getattr(self, f"mask_{i}")
                indices = reduce((quantized[..., begin:end] > 0).int() * bit_mask.int(), "b n c d -> b n c", "sum")
                indices_list.append(indices)
                begin += factorized_bit
            quantized = rearrange(quantized, 'b n 1 (c d) -> b n c d', c = self.num_codebooks)
        else:
            indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum')

        # entropy auxiliary loss, only computed during training
        if self.training and return_loss:
            # affinities between the continuous features and every code in the codebook
            logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook)

            per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss(
                logits = logits,
                temperature = 1. / inv_temperature,
                sample_minimization_weight = self.sample_minimization_weight,
                batch_maximization_weight = self.batch_maximization_weight
            )

            avg_probs = self.zero
        else:
            per_sample_entropy = codebook_entropy = self.zero
            entropy_aux_loss = self.zero
            avg_probs = self.zero

        # commitment loss, pulling the continuous features towards their quantized values
        if self.training:
            commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none')

            if exists(mask):
                commit_loss = commit_loss[mask]

            commit_loss = commit_loss.mean()
        else:
            commit_loss = self.zero

        # straight through estimator: gradients flow to x as if quantization were the identity
        if self.training:
            quantized = x + (quantized - x).detach()

        # merge the codebook dimension back into the feature dimension
        quantized = rearrange(quantized, 'b n c d -> b n (c d)')

        # restore the original spatial layout
        quantized = unpack_one(quantized, ps, 'b * d')
        quantized = rearrange(quantized, 'b ... d -> b d ...')

        if self.token_factorization:
            indices_ = []
            for indices in indices_list:
                indices = unpack_one(indices, ps, "b * c")
                indices = indices.flatten()
                indices_.append(indices)
            indices = indices_
        else:
            indices = unpack_one(indices, ps, 'b * c')
            indices = indices.flatten()

        ret = (quantized, entropy_aux_loss, indices)

        if not return_loss_breakdown:
            return ret

        return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs)
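
    # Usage sketch for the loss breakdown (illustrative; the 0.1 and 0.25 weights
    # below are assumptions, not values defined in this file):
    #   (quantized, aux_loss, indices), breakdown = quantizer(x, return_loss_breakdown=True)
    #   total_loss = recon_loss + 0.1 * aux_loss + 0.25 * breakdown.commitment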


if __name__ == "__main__":
    # with num_codebooks = 1, the codebook size is 2 ** dim = 2 ** 18
    quantizer = GFQ(
        dim = 18,
        num_codebooks = 1,
        sample_minimization_weight = 1.0,
        batch_maximization_weight = 1.0
    )

    image_feats = torch.randn(2, 18, 16, 16)

    quantized, entropy_aux_loss, indices = quantizer(image_feats, inv_temperature=100.)

    assert image_feats.shape == quantized.shape

    # rebuild the codes from the flattened indices and check the round trip
    codes = quantizer.decode(rearrange(indices, '(b h w) -> b h w 1', b=2, h=16, w=16))
    codes = rearrange(codes, 'b h w d -> b d h w')
    assert (quantized == codes).all()
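
    # A second, hedged sketch of the token-factorized path (num_codebooks = 2 is an
    # illustrative choice, not a value prescribed by this file): dim = 18 is split
    # into two codebooks of 2 ** 9 codes each, and forward returns one index tensor
    # per codebook.
    quantizer_factorized = GFQ(
        dim = 18,
        num_codebooks = 2,
        sample_minimization_weight = 1.0,
        batch_maximization_weight = 1.0
    )

    quantized_f, aux_loss_f, (indices_a, indices_b) = quantizer_factorized(image_feats)

    assert quantized_f.shape == image_feats.shape
    assert indices_a.shape == (2 * 16 * 16,)
    assert indices_b.shape == (2 * 16 * 16,)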