Spaces:
Running
on
Zero
Running
on
Zero
| """Vector quantizer. | |
| Copyright (2024) Bytedance Ltd. and/or its affiliates | |
| Licensed under the Apache License, Version 2.0 (the "License"); | |
| you may not use this file except in compliance with the License. | |
| You may obtain a copy of the License at | |
| http://www.apache.org/licenses/LICENSE-2.0 | |
| Unless required by applicable law or agreed to in writing, software | |
| distributed under the License is distributed on an "AS IS" BASIS, | |
| WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| See the License for the specific language governing permissions and | |
| limitations under the License. | |
| Reference: | |
| https://github.com/CompVis/taming-transformers/blob/master/taming/modules/vqvae/quantize.py | |
| https://github.com/google-research/magvit/blob/main/videogvt/models/vqvae.py | |
| https://github.com/CompVis/latent-diffusion/blob/main/ldm/modules/distributions/distributions.py | |
| """ | |
| from typing import Mapping, Text, Tuple | |
| import torch | |
| from einops import rearrange | |
| from torch.cuda.amp import autocast | |
class VectorQuantizer(torch.nn.Module):
    """Nearest-neighbour vector quantizer with a learnable codebook.

    Every spatial feature vector of the input is replaced by the codebook entry
    with the smallest squared Euclidean distance (optionally after L2-normalising
    both features and codebook). Gradients reach the input through a
    straight-through estimator.

    Args:
        codebook_size: Number of codebook entries.
        token_size: Dimensionality of each codebook entry; must match the channel
            dimension C of the tensor passed to ``forward``.
        commitment_cost: Weight applied to the commitment term of the loss.
        use_l2_norm: If True, features and codebook entries are L2-normalised
            before distances are computed, and the quantized output is normalised.
    """

    def __init__(self,
                 codebook_size: int = 1024,
                 token_size: int = 256,
                 commitment_cost: float = 0.25,
                 use_l2_norm: bool = False,
                 ):
        super().__init__()
        self.commitment_cost = commitment_cost
        self.embedding = torch.nn.Embedding(codebook_size, token_size)
        # Small symmetric uniform init of the codebook.
        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)
        self.use_l2_norm = use_l2_norm

    # Ensure quantization is performed using f32: the distance matrix and argmin
    # are precision-sensitive, so autocast must not downcast them. Note that
    # `autocast` is `torch.cuda.amp.autocast`, i.e. this disables CUDA autocast.
    @autocast(enabled=False)
    def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, Mapping[Text, torch.Tensor]]:
        """Quantize ``z`` against the codebook.

        Args:
            z: Tensor of shape [B, C, H, W] with C == token_size; cast to float32.

        Returns:
            ``(z_quantized, result_dict)``. ``z_quantized`` has shape [B, C, H, W]
            and carries straight-through gradients to ``z``. ``result_dict``
            contains ``quantizer_loss`` (commitment + codebook), ``commitment_loss``,
            ``codebook_loss`` and ``min_encoding_indices`` of shape [B, H, W].
        """
        z = z.float()
        # [B, C, H, W] -> [B, H, W, C]: one C-dim vector per spatial position.
        z = z.permute(0, 2, 3, 1).contiguous()
        z_flattened = z.reshape(-1, z.shape[-1])  # [B*H*W, C]

        if self.use_l2_norm:
            z_flattened = torch.nn.functional.normalize(z_flattened, dim=-1)
            embedding = torch.nn.functional.normalize(self.embedding.weight, dim=-1)
        else:
            embedding = self.embedding.weight

        # Squared distances via ||z||^2 + ||e||^2 - 2 z.e, shape [B*H*W, codebook_size].
        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(embedding ** 2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, embedding.T)

        min_encoding_indices = torch.argmin(d, dim=1)  # [B*H*W]
        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)

        if self.use_l2_norm:
            z = torch.nn.functional.normalize(z, dim=-1)

        # Commitment term only updates the input (codebook detached); codebook
        # term only updates the codebook (input detached).
        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) ** 2)
        codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)
        loss = commitment_loss + codebook_loss

        # Straight-through estimator: forward value is z_quantized, gradient flows to z.
        z_quantized = z + (z_quantized - z).detach()

        # [B, H, W, C] -> [B, C, H, W] to match the input layout.
        z_quantized = z_quantized.permute(0, 3, 1, 2).contiguous()

        result_dict = dict(
            quantizer_loss=loss,
            commitment_loss=commitment_loss,
            codebook_loss=codebook_loss,
            min_encoding_indices=min_encoding_indices.view(
                z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3])
        )
        return z_quantized, result_dict

    def get_codebook_entry(self, indices):
        """Look up codebook vectors.

        Args:
            indices: Either a 1-D tensor of integer codebook indices, or a 2-D
                tensor of shape [N, codebook_size] holding per-entry weights over
                the codebook (presumably one-hot or soft assignments).

        Returns:
            Tensor of shape [N, token_size]; L2-normalised if ``use_l2_norm``.

        Raises:
            NotImplementedError: If ``indices`` is neither 1-D nor 2-D.
        """
        if indices.dim() == 1:
            z_quantized = self.embedding(indices)
        elif indices.dim() == 2:
            # Weighted sum of codebook rows: [N, K] x [K, D] -> [N, D].
            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)
        else:
            raise NotImplementedError(
                f"get_codebook_entry expects 1-D or 2-D indices, got {indices.dim()}-D")
        if self.use_l2_norm:
            z_quantized = torch.nn.functional.normalize(z_quantized, dim=-1)
        return z_quantized
class DiagonalGaussianDistribution(object):
    """Axis-aligned Gaussian parameterised by a mean / log-variance tensor pair."""

    def __init__(self, parameters, deterministic=False):
        """Initializes a Gaussian distribution instance given the parameters.

        Args:
            parameters (torch.Tensor): The parameters for the Gaussian distribution. It is expected
                to be in shape [B, 2 * C, *], where B is batch size, and C is the embedding dimension.
                First C channels are used for mean and last C are used for logvar in the Gaussian distribution.
            deterministic (bool): Whether to use deterministic sampling. When it is true, the sampling result
                is purely based on mean (i.e., std = 0).
        """
        self.parameters = parameters
        # Statistics are kept in float32 regardless of the input dtype.
        self.mean, self.logvar = torch.chunk(parameters.float(), 2, dim=1)
        # Clamp log-variance so exp() below stays finite.
        self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
        self.deterministic = deterministic
        self.std = torch.exp(0.5 * self.logvar)
        self.var = torch.exp(self.logvar)
        if self.deterministic:
            # Zero spread: sample() collapses to the mean. zeros_like already
            # places the tensor on the parameters' device.
            self.var = self.std = torch.zeros_like(self.mean)

    def sample(self):
        """Return a reparameterised sample ``mean + std * eps`` with ``eps ~ N(0, I)``.

        The noise is drawn directly on the mean's device/dtype, avoiding a
        host-to-device copy.
        """
        noise = torch.randn_like(self.mean)
        return self.mean + self.std * noise

    def mode(self):
        """Return the mode of the distribution, i.e. its mean."""
        return self.mean

    def kl(self):
        """KL divergence to the standard normal N(0, I).

        Returns:
            The closed-form KL summed over dims 1 and 2 only; e.g. for a 4-D
            parameter tensor of shape [B, 2C, H, W] the result has shape [B, W]
            (callers presumably reduce the rest — verify). In deterministic
            mode, a shape-[1] zero tensor on the parameters' device.
        """
        if self.deterministic:
            # Same shape/dtype as the former torch.Tensor([0.]) but on the right
            # device, so it can be added to GPU losses.
            return torch.zeros(1, device=self.parameters.device)
        return 0.5 * torch.sum(torch.pow(self.mean, 2)
                               + self.var - 1.0 - self.logvar,
                               dim=[1, 2])