| """ |
| Hugging Face-compatible implementation of Open-MAGVIT2 |
| Code reference: https://github.com/TencentARC/Open-MAGVIT2 |
| """ |
|
|
|
|
| from math import log2, ceil |
| from collections import namedtuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from einops import rearrange, reduce, pack, unpack |
| from torch import einsum |
| from torch.nn import Module |
| from transformers import PreTrainedModel |
|
|
| from .configuration_lfq_tokenizer import LFQTokenizerConfig |
|
|
|
|
| def swish(x): |
| # swish / SiLU activation: x * sigmoid(x) |
| return x * torch.sigmoid(x) |
|
|
|
|
| class ResBlock(nn.Module): |
| def __init__(self, |
| in_filters, |
| out_filters, |
| use_conv_shortcut = False |
| ) -> None: |
| super().__init__() |
|
|
| self.in_filters = in_filters |
| self.out_filters = out_filters |
| self.use_conv_shortcut = use_conv_shortcut |
|
|
| self.norm1 = nn.GroupNorm(32, in_filters, eps=1e-6) |
| self.norm2 = nn.GroupNorm(32, out_filters, eps=1e-6) |
|
|
| self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False) |
| self.conv2 = nn.Conv2d(out_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False) |
|
|
| if in_filters != out_filters: |
| if self.use_conv_shortcut: |
| self.conv_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(3, 3), padding=1, bias=False) |
| else: |
| self.nin_shortcut = nn.Conv2d(in_filters, out_filters, kernel_size=(1, 1), padding=0, bias=False) |
| |
|
|
| def forward(self, x, **kwargs): |
| residual = x |
|
|
| x = self.norm1(x) |
| x = swish(x) |
| x = self.conv1(x) |
| x = self.norm2(x) |
| x = swish(x) |
| x = self.conv2(x) |
| if self.in_filters != self.out_filters: |
| if self.use_conv_shortcut: |
| residual = self.conv_shortcut(residual) |
| else: |
| residual = self.nin_shortcut(residual) |
|
|
| return x + residual |
|
|
| class Encoder(nn.Module): |
| def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)): |
| # note: out_ch is accepted for config symmetry with Decoder but is not used by the encoder |
| super().__init__() |
|
|
| self.in_channels = in_channels |
| self.z_channels = z_channels |
|
|
| self.num_res_blocks = num_res_blocks |
| self.num_blocks = len(ch_mult) |
| |
| self.conv_in = nn.Conv2d(in_channels, |
| ch, |
| kernel_size=(3, 3), |
| padding=1, |
| bias=False |
| ) |
|
|
| |
| self.down = nn.ModuleList() |
|
|
| in_ch_mult = (1,)+tuple(ch_mult) |
| for i_level in range(self.num_blocks): |
| block = nn.ModuleList() |
| block_in = ch*in_ch_mult[i_level] |
| block_out = ch*ch_mult[i_level] |
| for _ in range(self.num_res_blocks): |
| block.append(ResBlock(block_in, block_out)) |
| block_in = block_out |
| |
| down = nn.Module() |
| down.block = block |
| if i_level < self.num_blocks - 1: |
| down.downsample = nn.Conv2d(block_out, block_out, kernel_size=(3, 3), stride=(2, 2), padding=1) |
|
|
| self.down.append(down) |
| |
| |
| self.mid_block = nn.ModuleList() |
| for res_idx in range(self.num_res_blocks): |
| self.mid_block.append(ResBlock(block_in, block_in)) |
| |
| |
| self.norm_out = nn.GroupNorm(32, block_out, eps=1e-6) |
| self.conv_out = nn.Conv2d(block_out, z_channels, kernel_size=(1, 1)) |
| |
| def forward(self, x): |
|
|
| |
| x = self.conv_in(x) |
| for i_level in range(self.num_blocks): |
| for i_block in range(self.num_res_blocks): |
| x = self.down[i_level].block[i_block](x) |
| |
| if i_level < self.num_blocks - 1: |
| x = self.down[i_level].downsample(x) |
| |
| |
| for res in range(self.num_res_blocks): |
| x = self.mid_block[res](x) |
| |
|
|
| x = self.norm_out(x) |
| x = swish(x) |
| x = self.conv_out(x) |
|
|
| return x |
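| # Shape sketch (illustrative, not from the reference code): with the default |
| # ch_mult=(1, 2, 2, 4) the encoder applies len(ch_mult) - 1 = 3 stride-2 downsamples, |
| # so a (B, in_channels, 128, 128) input yields a (B, z_channels, 16, 16) latent map. |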
|
|
|
|
| class Decoder(nn.Module): |
| def __init__(self, *, ch, out_ch, in_channels, num_res_blocks, z_channels, ch_mult=(1, 2, 2, 4)) -> None: |
| super().__init__() |
|
|
| self.ch = ch |
| self.num_blocks = len(ch_mult) |
| self.num_res_blocks = num_res_blocks |
| self.in_channels = in_channels |
|
|
| block_in = ch*ch_mult[self.num_blocks-1] |
|
|
| self.conv_in = nn.Conv2d( |
| z_channels, block_in, kernel_size=(3, 3), padding=1, bias=True |
| ) |
|
|
| self.mid_block = nn.ModuleList() |
| for res_idx in range(self.num_res_blocks): |
| self.mid_block.append(ResBlock(block_in, block_in)) |
| |
| self.up = nn.ModuleList() |
|
|
| for i_level in reversed(range(self.num_blocks)): |
| block = nn.ModuleList() |
| block_out = ch*ch_mult[i_level] |
| for i_block in range(self.num_res_blocks): |
| block.append(ResBlock(block_in, block_out)) |
| block_in = block_out |
| |
| up = nn.Module() |
| up.block = block |
| if i_level > 0: |
| up.upsample = Upsampler(block_in) |
| self.up.insert(0, up) # prepend so that self.up[i_level] matches the encoder level indexing |
| |
| self.norm_out = nn.GroupNorm(32, block_in, eps=1e-6) |
|
|
| self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=(3, 3), padding=1) |
| |
| def forward(self, z): |
|
|
| z = self.conv_in(z) |
|
|
| |
| for res in range(self.num_res_blocks): |
| z = self.mid_block[res](z) |
| |
| |
| for i_level in reversed(range(self.num_blocks)): |
| for i_block in range(self.num_res_blocks): |
| z = self.up[i_level].block[i_block](z) |
| |
| if i_level > 0: |
| z = self.up[i_level].upsample(z) |
| |
| z = self.norm_out(z) |
| z = swish(z) |
| z = self.conv_out(z) |
|
|
| return z |
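| # The decoder mirrors the encoder: len(ch_mult) - 1 depth-to-space upsamplings, so a |
| # (B, z_channels, 16, 16) latent decodes back to a (B, out_ch, 128, 128) image |
| # (illustrative sizes, assuming the default ch_mult). |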
|
|
|
|
| def depth_to_space(x: torch.Tensor, block_size: int) -> torch.Tensor: |
| """ Depth-to-Space DCR mode (depth-column-row) core implementation. |
| |
| Args: |
| x (torch.Tensor): input tensor. The channels-first (*CHW) layout is supported. |
| block_size (int): block side size |
| """ |
| |
| if x.dim() < 3: |
| raise ValueError( |
| f"Expecting a channels-first (*CHW) tensor of at least 3 dimensions" |
| ) |
| c, h, w = x.shape[-3:] |
|
|
| s = block_size**2 |
| if c % s != 0: |
| raise ValueError( |
| f"Expecting a channels-first (*CHW) tensor with C divisible by {s}, but got C={c} channels" |
| ) |
|
|
| outer_dims = x.shape[:-3] |
|
|
| |
| x = x.view(-1, block_size, block_size, c // s, h, w) |
|
|
| |
| x = x.permute(0, 3, 4, 1, 5, 2) |
|
|
| |
| x = x.contiguous().view(*outer_dims, c // s, h * block_size, |
| w * block_size) |
|
|
| return x |
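| # Example (a sketch, not part of the reference code): depth_to_space trades channels |
| # for spatial resolution, so with block_size=2 a (1, 4, 2, 2) tensor becomes (1, 1, 4, 4): |
| # |
| # x = torch.arange(16.).view(1, 4, 2, 2) |
| # y = depth_to_space(x, block_size=2) # y.shape == torch.Size([1, 1, 4, 4]) |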
|
|
| class Upsampler(nn.Module): |
| def __init__( |
| self, |
| dim, |
| dim_out = None |
| ): |
| super().__init__() |
| dim_out = dim * 4 # note: the dim_out argument is ignored; channels are expanded 4x for a 2x depth-to-space |
| self.conv1 = nn.Conv2d(dim, dim_out, (3, 3), padding=1) |
| self.depth2space = depth_to_space |
|
|
| def forward(self, x): |
| """ |
| input_image: [B C H W] |
| """ |
| out = self.conv1(x) |
| out = self.depth2space(out, block_size=2) |
| return out |
|
|
|
|
| class AdaptiveGroupNorm(nn.Module): |
| def __init__(self, z_channel, in_filters, num_groups=32, eps=1e-6): |
| super().__init__() |
| self.gn = nn.GroupNorm(num_groups=num_groups, num_channels=in_filters, eps=eps, affine=False) |
| |
| self.gamma = nn.Linear(z_channel, in_filters) |
| self.beta = nn.Linear(z_channel, in_filters) |
| self.eps = eps |
| |
| def forward(self, x, quantizer): |
| B, C, _, _ = x.shape |
| |
| # the scale is predicted from the per-channel std of the quantizer output |
| scale = rearrange(quantizer, "b c h w -> b c (h w)") |
| scale = scale.var(dim=-1) + self.eps |
| scale = scale.sqrt() |
| scale = self.gamma(scale).view(B, C, 1, 1) |
| |
| # the bias is predicted from the per-channel mean of the quantizer output |
| bias = rearrange(quantizer, "b c h w -> b c (h w)") |
| bias = bias.mean(dim=-1) |
| bias = self.beta(bias).view(B, C, 1, 1) |
| |
| x = self.gn(x) |
| x = scale * x + bias |
|
|
| return x |
|
|
|
|
| |
|
|
| LossBreakdown = namedtuple('LossBreakdown', ['per_sample_entropy', 'codebook_entropy', 'commitment', 'avg_probs']) |
|
|
| |
|
|
| def exists(v): |
| return v is not None |
|
|
| def default(*args): |
| for arg in args: |
| if exists(arg): |
| return arg() if callable(arg) else arg |
| return None |
|
|
| def pack_one(t, pattern): |
| return pack([t], pattern) |
|
|
| def unpack_one(t, ps, pattern): |
| return unpack(t, ps, pattern)[0] |
|
|
| |
|
|
| def entropy(prob): |
| return (-prob * torch.log(prob + 1e-5)).sum(dim=-1) |
|
|
| |
|
|
| def mult_along_first_dims(x, y): |
| """ |
| returns x * y elementwise along the leading dimensions of y |
| """ |
| ndim_to_expand = x.ndim - y.ndim |
| for _ in range(ndim_to_expand): |
| y = y.unsqueeze(-1) |
| return x * y |
|
|
| def masked_mean(x, m): |
| """ |
| takes the mean of the elements of x that are not masked |
| the mean is taken along the shared leading dims of m |
| equivalent to: x[m].mean(tuple(range(m.ndim))) |
| |
| The benefit of using masked_mean rather than using |
| tensor indexing is that masked_mean is much faster |
| for torch-compile on batches. |
| |
| The drawback is larger floating point errors |
| """ |
| x = mult_along_first_dims(x, m) |
| x = x / m.sum() |
| return x.sum(tuple(range(m.ndim))) |
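| # Illustrative sketch (not from the reference code): with x = torch.ones(2, 3, 5) and |
| # m = torch.tensor([True, False]), masked_mean(x, m) averages over the masked leading |
| # (batch) dim and returns a (3, 5) tensor of ones, matching x[m].mean(dim=0). |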
|
|
| def entropy_loss( |
| logits, |
| mask=None, |
| temperature=0.01, |
| sample_minimization_weight=1.0, |
| batch_maximization_weight=1.0, |
| eps=1e-5, |
| ): |
| """ |
| Entropy loss of unnormalized logits |
| |
| logits: Affinities are over the last dimension |
| |
| https://github.com/google-research/magvit/blob/05e8cfd6559c47955793d70602d62a2f9b0bdef5/videogvt/train_lib/losses.py#L279 |
| LANGUAGE MODEL BEATS DIFFUSION — TOKENIZER IS KEY TO VISUAL GENERATION (2024) |
| """ |
| probs = F.softmax(logits / temperature, -1) |
| log_probs = F.log_softmax(logits / temperature + eps, -1) |
|
|
| if mask is not None: |
| avg_probs = masked_mean(probs, mask) |
| else: |
| avg_probs = reduce(probs, "... D -> D", "mean") |
|
|
| avg_entropy = -torch.sum(avg_probs * torch.log(avg_probs + eps)) |
|
|
| sample_entropy = -torch.sum(probs * log_probs, -1) |
| if mask is not None: |
| sample_entropy = masked_mean(sample_entropy, mask).mean() |
| else: |
| sample_entropy = torch.mean(sample_entropy) |
|
|
| loss = (sample_minimization_weight * sample_entropy) - ( |
| batch_maximization_weight * avg_entropy |
| ) |
|
|
| return sample_entropy, avg_entropy, loss |
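| # Usage sketch (illustrative): for unnormalized logits of shape (..., codebook_size), |
| # entropy_loss returns (per-sample entropy, batch entropy, combined loss), where |
| # loss = sample_minimization_weight * H(softmax(logits_i / T)) |
| #      - batch_maximization_weight * H(mean_i softmax(logits_i / T)). |
| # Minimizing it makes each sample's code distribution confident while keeping the |
| # overall codebook usage spread out. |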
|
|
|
|
| class LFQ(Module): |
| def __init__( |
| self, |
| *, |
| dim = None, |
| codebook_size = None, |
| num_codebooks = 1, |
| sample_minimization_weight=1.0, |
| batch_maximization_weight=1.0, |
| token_factorization = False, |
| ): |
| super().__init__() |
|
|
| |
|
|
| assert exists(dim) or exists(codebook_size), 'either dim or codebook_size must be specified for LFQ' |
| assert not exists(codebook_size) or log2(codebook_size).is_integer(), f'your codebook size must be a power of 2 for lookup free quantization (suggested {2 ** ceil(log2(codebook_size))})' |
|
|
| self.codebook_size = default(codebook_size, lambda: 2 ** dim) |
| self.codebook_dim = int(log2(self.codebook_size)) |
|
|
| codebook_dims = self.codebook_dim * num_codebooks |
| dim = default(dim, codebook_dims) |
|
|
| has_projections = dim != codebook_dims |
| self.has_projections = has_projections |
|
|
| self.dim = dim |
| self.num_codebooks = num_codebooks |
| |
| |
| self.sample_minimization_weight = sample_minimization_weight |
| self.batch_maximization_weight = batch_maximization_weight |
|
|
| |
| # token factorization splits each code into two half-width sub-tokens (pre / post) |
| self.token_factorization = token_factorization |
| if not self.token_factorization: |
| self.register_buffer('mask', 2 ** torch.arange(self.codebook_dim - 1, -1, -1), persistent=False) |
| else: |
| k = self.codebook_dim // 2 |
| self.register_buffer("mask", 2 ** torch.arange(k - 1, -1, -1), persistent=False) |
|
|
| self.register_buffer('zero', torch.tensor(0.), persistent = False) |
|
|
| |
| # enumerate the full codebook (entries in {-1, +1} per bit) for the entropy aux loss |
| all_codes = torch.arange(self.codebook_size) |
| bits = self.indices_to_bits(all_codes) |
| codebook = bits * 2.0 - 1.0 |
|
|
| self.register_buffer('codebook', codebook, persistent = False) |
|
|
| @property |
| def dtype(self): |
| return self.codebook.dtype |
| |
| def indices_to_bits(self, x): |
| """ |
| x: long tensor of indices for constructing codebook, but actually not utilized in all the experiments. |
| |
| returns big endian bits |
| """ |
| mask = 2 ** torch.arange(self.codebook_dim, device=x.device, dtype=torch.long) |
| |
| x = (x.unsqueeze(-1) & mask) != 0 |
| return x |
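| # Example (illustrative, assuming codebook_dim == 4): index 5 == 0b0101 maps to the |
| # bits [True, False, True, False] (least-significant bit first), and the corresponding |
| # codebook entry is [+1, -1, +1, -1]. |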
|
|
| def get_codebook_entry(self, x, bhwc): |
| if self.token_factorization: |
| k = self.codebook_dim // 2 |
| mask = 2 ** torch.arange(k - 1, -1, -1, device=x.device, dtype=torch.long) |
| else: |
| mask = 2 ** torch.arange(self.codebook_dim-1, -1, -1, device=x.device, dtype=torch.long) |
| |
| x = (x.unsqueeze(-1) & mask) != 0 |
| x = x * 2.0 - 1.0 |
| |
| b, h, w, c = bhwc |
| x = rearrange(x, "b (h w) c -> b h w c", h=h, w=w, c=c) |
| x = rearrange(x, "b h w c -> b c h w") |
| return x |
|
|
| def bits_to_indices(self, bits): |
| """ |
| bits: bool tensor of big endian bits, where the last dimension is the bit dimension |
| |
| returns indices, which are long integers from 0 to self.codebook_size |
| """ |
| assert bits.shape[-1] == self.codebook_dim |
| indices = 2 ** torch.arange( |
| 0, |
| self.codebook_dim, |
| 1, |
| dtype=torch.long, |
| device=bits.device, |
| ) |
| return (bits * indices).sum(-1) |
| |
| def decode(self, x): |
| """ |
| x: long tensor of codebook indices of shape (..., NC), where NC is the |
| number of codebook heads; values lie in [0, self.codebook_size). |
| """ |
| x = self.indices_to_bits(x) |
| # map {0, 1} bits to {-1., +1.} codes |
| x = x.to(self.dtype) |
| x = x * 2 - 1 |
| x = rearrange(x, "... NC Z -> ... (NC Z)") |
| return x |
|
|
| def forward( |
| self, |
| x, |
| return_loss_breakdown = False, |
| mask = None, |
| return_loss = True, |
| ): |
| """ |
| einstein notation |
| b - batch |
| n - sequence (or flattened spatial dimensions) |
| d - feature dimension, which is also log2(codebook size) |
| c - number of codebooks |
| """ |
|
|
|
|
| # move channels last and flatten the spatial dims: (b, d, ...) -> (b, n, d) |
| x = rearrange(x, 'b d ... -> b ... d') |
| x, ps = pack_one(x, 'b * d') |
| |
|
|
| # split the feature dimension across codebook heads |
| x = rearrange(x, 'b n (c d) -> b n c d', c = self.num_codebooks) |
|
|
|
|
| # quantize by sign: every latent dimension becomes +1 or -1 |
| codebook_value = torch.Tensor([1.0]).to(device=x.device, dtype=x.dtype) |
| quantized = torch.where(x > 0, codebook_value, -codebook_value) |
|
|
| |
| # convert the ±1 codes to integer token indices (most-significant bit first) |
| if self.token_factorization: |
| k = self.codebook_dim // 2 |
| indices_pre = reduce((quantized[..., :k] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") |
| indices_post = reduce((quantized[..., k:] > 0).int() * self.mask.int(), "b n c d -> b n c", "sum") |
| else: |
| indices = reduce((quantized > 0).int() * self.mask.int(), 'b n c d -> b n c', 'sum') |
|
|
| |
|
|
| # entropy auxiliary loss (training only): each sample's code distribution should be |
| # confident while the batch-average code usage stays diverse |
| if self.training and return_loss: |
| logits = 2 * einsum('... i d, j d -> ... i j', x, self.codebook) |
| |
| per_sample_entropy, codebook_entropy, entropy_aux_loss = entropy_loss( |
| logits = logits, |
| sample_minimization_weight = self.sample_minimization_weight, |
| batch_maximization_weight = self.batch_maximization_weight |
| ) |
|
|
| avg_probs = self.zero |
| else: |
| # inference (or loss disabled): no auxiliary losses |
| per_sample_entropy = codebook_entropy = self.zero |
| entropy_aux_loss = self.zero |
| avg_probs = self.zero |
|
|
| |
|
|
| # commitment loss between the continuous features and their quantized values |
| if self.training: |
| commit_loss = F.mse_loss(x, quantized.detach(), reduction = 'none') |
|
|
| if exists(mask): |
| commit_loss = commit_loss[mask] |
|
|
| commit_loss = commit_loss.mean() |
| else: |
| commit_loss = self.zero |
|
|
|
|
| |
|
|
| # straight-through estimator: the forward pass uses the quantized values, |
| # gradients flow back to the encoder as if quantization were the identity |
| quantized = x + (quantized - x).detach() |
|
|
| |
|
|
| # merge codebook heads back into the feature dimension |
| quantized = rearrange(quantized, 'b n c d -> b n (c d)') |
|
|
| |
|
|
| # restore the original spatial layout |
| quantized = unpack_one(quantized, ps, 'b * d') |
| quantized = rearrange(quantized, 'b ... d -> b d ...') |
|
|
| |
| if self.token_factorization: |
| indices_pre = unpack_one(indices_pre, ps, "b * c") |
| indices_post = unpack_one(indices_post, ps, "b * c") |
| indices_pre = indices_pre.flatten() |
| indices_post = indices_post.flatten() |
| indices = (indices_pre, indices_post) |
| else: |
| indices = unpack_one(indices, ps, 'b * c') |
| indices = indices.flatten() |
|
|
| ret = (quantized, entropy_aux_loss, indices) |
|
|
| if not return_loss_breakdown: |
| return ret |
|
|
| return ret, LossBreakdown(per_sample_entropy, codebook_entropy, commit_loss, avg_probs) |
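| # Minimal usage sketch for the quantizer alone (illustrative hyper-parameters, not the |
| # values used by any released checkpoint): |
| # |
| # quantizer = LFQ(dim=10, codebook_size=2 ** 10) |
| # latents = torch.randn(1, 10, 16, 16) # (batch, dim, height, width) encoder features |
| # quantized, aux_loss, indices = quantizer(latents) |
| # # quantized: (1, 10, 16, 16) with entries in {-1, +1}; indices: 256 integer token ids |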
|
|
|
|
| class LFQTokenizer(PreTrainedModel): |
| config_class = LFQTokenizerConfig |
|
|
| def __init__(self, config: LFQTokenizerConfig): |
| super().__init__(config) |
|
|
| self.encoder = Encoder(**config.encoder_decoder_config) |
| self.decoder = Decoder(**config.encoder_decoder_config) |
| self.quantize = LFQ(**config.quantizer_config) |
|
|
| def encode(self, x): |
| h = self.encoder(x) |
| (quant, emb_loss, info), loss_breakdown = self.quantize(h, return_loss_breakdown=True) |
| return quant, emb_loss, info, loss_breakdown |
|
|
| def decode(self, quant): |
| return self.decoder(quant) |
|
|
| def forward(self, input): |
| # returns the reconstruction, the quantizer auxiliary loss, and the loss breakdown |
| quant, diff, _, loss_breakdown = self.encode(input) |
| dec = self.decoder(quant) |
| return dec, diff, loss_breakdown |
|
|
| def tokenize(self, input): |
| _, _, tokens, _ = self.encode(input) |
| return tokens |
|
|
| def get_last_layer(self): |
| return self.decoder.conv_out.weight |
|
|
| def decode_tokens(self, tokens, shape: tuple): |
| """Decode integer tokens back to an image; `shape` is the (b, h, w, c) layout of the latent grid.""" |
| if self.quantize.token_factorization: |
| tokens_pre, tokens_post = tokens[0], tokens[1] |
| quant_pre = self.quantize.get_codebook_entry(tokens_pre, shape) |
| quant_post = self.quantize.get_codebook_entry(tokens_post, shape) |
| quant = torch.concat([quant_pre, quant_post], dim=1) |
| return self.decode(quant) |
| else: |
| if tokens.ndim == 1: |
| batch_size = shape[0] |
| tokens = tokens.view(batch_size, -1) |
| quant = self.quantize.get_codebook_entry(tokens, shape) |
| return self.decode(quant) |
|
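| # End-to-end usage sketch (illustrative; "path/to/checkpoint", the image size, and the |
| # latent grid below are placeholders, not values taken from this repository): |
| # |
| # config = LFQTokenizerConfig.from_pretrained("path/to/checkpoint") |
| # model = LFQTokenizer.from_pretrained("path/to/checkpoint", config=config) |
| # images = torch.randn(1, 3, 128, 128) # (batch, channels, height, width) |
| # tokens = model.tokenize(images) # flattened integer token ids |
| # # (h, w) is the latent grid, i.e. the image size divided by 2 ** (len(ch_mult) - 1) |
| # recon = model.decode_tokens(tokens, shape=(1, h, w, model.quantize.codebook_dim)) |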
|