# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple, Dict, Optional, Union, Any
from contextlib import nullcontext
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from diffusers.schedulers.scheduling_utils import SchedulerMixin
from diffusers import StableDiffusionPipeline
from huggingface_hub import PyTorchModelHubMixin

from fourm.vq.quantizers import VectorQuantizerLucid, Memcodes
import fourm.vq.models.vit_models as vit_models
import fourm.vq.models.unet.unet as unet
import fourm.vq.models.uvit as uvit
import fourm.vq.models.controlnet as controlnet
from fourm.vq.models.mlp_models import build_mlp
from fourm.vq.scheduling import DDPMScheduler, DDIMScheduler, PNDMScheduler, PipelineCond
from fourm.utils import denormalize

# If freeze_enc is True, the following modules will be frozen
FREEZE_MODULES = ['encoder', 'quant_proj', 'quantize', 'cls_emb']

class VQ(nn.Module, PyTorchModelHubMixin):
    """Base class for VQVAE and DiVAE models. Implements the encoder and quantizer,
    and can be used as such without a decoder after training.

    Args:
        image_size: Input and target image size.
        image_size_enc: Input image size for the encoder. Defaults to image_size. Change this
            when loading weights from a tokenizer trained on a different image size.
        n_channels: Number of input channels.
        n_labels: Number of classes for semantic segmentation.
        enc_type: String identifier specifying the encoder architecture. See vq/vit_models.py
            and vq/mlp_models.py for available architectures.
        patch_proj: Whether or not to use a ViT-style patch-wise linear projection in the encoder.
        post_mlp: Whether or not to add a small point-wise MLP before the quantizer.
        patch_size: Patch size for the encoder.
        quant_type: String identifier specifying the quantizer implementation. Can be 'lucid' or 'memcodes'.
        codebook_size: Number of codebook entries.
        num_codebooks: Number of "parallel" codebooks to use. Only relevant for 'lucid' and 'memcodes'
            quantizers. When using this, the tokens will be of shape B N_C H_Q W_Q, where N_C is
            the number of codebooks.
        latent_dim: Dimensionality of the latent code. Can be small when using norm_codes=True,
            see the ViT-VQGAN paper (https://arxiv.org/abs/2110.04627) for details.
        norm_codes: Whether or not to normalize the codebook entries to the unit sphere.
            See the ViT-VQGAN paper (https://arxiv.org/abs/2110.04627) for details.
        norm_latents: Whether or not to normalize the latent codes to the unit sphere for computing
            the commitment loss.
        sync_codebook: Enable this when training on multiple GPUs, and disable for single GPUs,
            e.g. at inference.
        ema_decay: Decay rate for the exponential moving average of the codebook entries.
        threshold_ema_dead_code: Threshold for replacing stale codes that are used less than the
            indicated exponential moving average of the codebook entries.
        code_replacement_policy: Policy for replacing stale codes. Can be 'batch_random' or 'linde_buzo_gray'.
        commitment_weight: Weight for the quantizer commitment loss.
        kmeans_init: Whether or not to initialize the codebook entries with k-means clustering.
        ckpt_path: Path to a checkpoint to load the model weights from.
        ignore_keys: List of keys to ignore when loading the state_dict from the above checkpoint.
        freeze_enc: Whether or not to freeze the encoder weights. See FREEZE_MODULES for the list of modules.
        undo_std: Whether or not to undo any ImageNet standardization and transform the images to [-1, 1]
            before feeding the input to the encoder.
        config: Dictionary containing the model configuration. Only used when loading
            from Huggingface Hub. Ignore otherwise.
    """
    def __init__(self,
                 image_size: int = 224,
                 image_size_enc: Optional[int] = None,
                 n_channels: int = 3,
                 n_labels: Optional[int] = None,
                 enc_type: str = 'vit_b_enc',
                 patch_proj: bool = True,
                 post_mlp: bool = False,
                 patch_size: int = 16,
                 quant_type: str = 'lucid',
                 codebook_size: Union[int, str] = 16384,
                 num_codebooks: int = 1,
                 latent_dim: int = 32,
                 norm_codes: bool = True,
                 norm_latents: bool = False,
                 sync_codebook: bool = True,
                 ema_decay: float = 0.99,
                 threshold_ema_dead_code: float = 0.25,
                 code_replacement_policy: str = 'batch_random',
                 commitment_weight: float = 1.0,
                 kmeans_init: bool = False,
                 ckpt_path: Optional[str] = None,
                 ignore_keys: List[str] = [
                     'decoder', 'loss',
                     'post_quant_conv', 'post_quant_proj',
                     'encoder.pos_emb',
                 ],
                 freeze_enc: bool = False,
                 undo_std: bool = False,
                 config: Optional[Dict[str, Any]] = None,
                 **kwargs):
        if config is not None:
            config = copy.deepcopy(config)
            self.__init__(**config)
            return
        super().__init__()

        self.image_size = image_size
        self.n_channels = n_channels
        self.n_labels = n_labels
        self.enc_type = enc_type
        self.patch_proj = patch_proj
        self.post_mlp = post_mlp
        self.patch_size = patch_size
        self.quant_type = quant_type
        self.codebook_size = codebook_size
        self.num_codebooks = num_codebooks
        self.latent_dim = latent_dim
        self.norm_codes = norm_codes
        self.norm_latents = norm_latents
        self.sync_codebook = sync_codebook
        self.ema_decay = ema_decay
        self.threshold_ema_dead_code = threshold_ema_dead_code
        self.code_replacement_policy = code_replacement_policy
        self.commitment_weight = commitment_weight
        self.kmeans_init = kmeans_init
        self.ckpt_path = ckpt_path
        self.ignore_keys = ignore_keys
        self.freeze_enc = freeze_enc
        self.undo_std = undo_std

        # For semantic segmentation
        if n_labels is not None:
            self.cls_emb = nn.Embedding(num_embeddings=n_labels, embedding_dim=n_channels)
            self.colorize = torch.randn(3, n_labels, 1, 1)
        else:
            self.cls_emb = None

        # Init encoder
        image_size_enc = image_size_enc or image_size
        if 'vit' in enc_type:
            self.encoder = getattr(vit_models, enc_type)(
                in_channels=n_channels, patch_size=patch_size,
                resolution=image_size_enc, patch_proj=patch_proj, post_mlp=post_mlp,
            )
            self.enc_dim = self.encoder.dim_tokens
        elif 'MLP' in enc_type:
            self.encoder = build_mlp(model_id=enc_type, dim_in=n_channels, dim_out=None)
            self.enc_dim = self.encoder.dim_out
        else:
            raise NotImplementedError(f'{enc_type} not implemented.')

        # Encoder -> quantizer projection
        self.quant_proj = torch.nn.Conv2d(self.enc_dim, self.latent_dim, 1)

        # Init quantizer
        if quant_type == 'lucid':
            self.quantize = VectorQuantizerLucid(
                dim=latent_dim,
                codebook_size=codebook_size,
                codebook_dim=latent_dim,
                heads=num_codebooks,
                use_cosine_sim=norm_codes,
                threshold_ema_dead_code=threshold_ema_dead_code,
                code_replacement_policy=code_replacement_policy,
                sync_codebook=sync_codebook,
                decay=ema_decay,
                commitment_weight=self.commitment_weight,
                norm_latents=norm_latents,
                kmeans_init=kmeans_init,
            )
        elif quant_type == 'memcodes':
            self.quantize = Memcodes(
                dim=latent_dim, codebook_size=codebook_size,
                heads=num_codebooks, temperature=1.,
            )
        else:
            raise ValueError(f'{quant_type} not a valid quant_type.')

        # Load checkpoint
        if ckpt_path is not None:
            self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys)

        # Freeze encoder
        if freeze_enc:
            for module_name, module in self.named_children():
                if module_name not in FREEZE_MODULES:
                    continue
                for param in module.parameters():
                    param.requires_grad = False
                module.eval()
    def train(self, mode: bool = True) -> 'VQ':
        """Override the default train() to set the training mode for all modules
        except the frozen ones when freeze_enc is True.

        Args:
            mode: Whether to set the model to training mode (True) or evaluation mode (False).
        """
        if not isinstance(mode, bool):
            raise ValueError("training mode is expected to be boolean")
        self.training = mode
        for module_name, module in self.named_children():
            if self.freeze_enc and module_name in FREEZE_MODULES:
                continue
            module.train(mode)
        return self
    def init_from_ckpt(self, path: str, ignore_keys: List[str] = list()) -> 'VQ':
        """Loads the state_dict from a checkpoint file and initializes the model with it.
        Renames the keys in the state_dict if necessary (e.g. when loading VQ-GAN weights).

        Args:
            path: Path to the checkpoint file.
            ignore_keys: List of keys to ignore when loading the state_dict.

        Returns:
            self
        """
        ckpt = torch.load(path, map_location="cpu")
        sd = ckpt['model'] if 'model' in ckpt else ckpt['state_dict']

        # Compatibility with ViT-VQGAN weights
        if 'quant_conv.0.weight' in sd and 'quant_conv.0.bias' in sd:
            print("Renaming quant_conv.0 to quant_proj")
            sd['quant_proj.weight'] = sd['quant_conv.0.weight']
            sd['quant_proj.bias'] = sd['quant_conv.0.bias']
            del sd['quant_conv.0.weight']
            del sd['quant_conv.0.bias']
        elif 'quant_conv.weight' in sd and 'quant_conv.bias' in sd:
            print("Renaming quant_conv to quant_proj")
            sd['quant_proj.weight'] = sd['quant_conv.weight']
            sd['quant_proj.bias'] = sd['quant_conv.bias']
            del sd['quant_conv.weight']
            del sd['quant_conv.bias']

        if 'post_quant_conv.0.weight' in sd and 'post_quant_conv.0.bias' in sd:
            print("Renaming post_quant_conv.0 to post_quant_proj")
            sd['post_quant_proj.weight'] = sd['post_quant_conv.0.weight']
            sd['post_quant_proj.bias'] = sd['post_quant_conv.0.bias']
            del sd['post_quant_conv.0.weight']
            del sd['post_quant_conv.0.bias']
        elif 'post_quant_conv.weight' in sd and 'post_quant_conv.bias' in sd:
            print("Renaming post_quant_conv to post_quant_proj")
            sd['post_quant_proj.weight'] = sd['post_quant_conv.weight']
            sd['post_quant_proj.bias'] = sd['post_quant_conv.bias']
            del sd['post_quant_conv.weight']
            del sd['post_quant_conv.bias']

        keys = list(sd.keys())
        for k in keys:
            for ik in ignore_keys:
                if k.startswith(ik):
                    print(f"Deleting key {k} from state_dict.")
                    del sd[k]
        msg = self.load_state_dict(sd, strict=False)
        print(msg)
        print(f"Restored from {path}")
        return self
    def prepare_input(self, x: torch.Tensor) -> torch.Tensor:
        """Preprocesses the input image tensor before feeding it to the encoder.
        If self.undo_std, the input is first denormalized from the ImageNet
        standardization to [-1, 1]. If semantic segmentation is performed, the
        class indices are embedded.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            Preprocessed input tensor of shape B C H W
        """
        if self.undo_std:
            x = 2.0 * denormalize(x) - 1.0
        if self.cls_emb is not None:
            x = rearrange(self.cls_emb(x), 'b h w c -> b c h w')
        return x

    def to_rgb(self, x: torch.Tensor) -> torch.Tensor:
        """When semantic segmentation is performed, this function converts the
        class embeddings to RGB.

        Args:
            x: Input tensor of shape B C H W

        Returns:
            RGB tensor of shape B C H W
        """
        x = F.conv2d(x, weight=self.colorize)
        x = (x - x.min()) / (x.max() - x.min())
        return x
    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
        """Encodes an input image tensor and quantizes the latent code.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            quant: Quantized latent code of shape B D_Q H_Q W_Q
            code_loss: Codebook loss
            tokens: Quantized indices of shape B H_Q W_Q
        """
        x = self.prepare_input(x)
        h = self.encoder(x)
        h = self.quant_proj(h)
        quant, code_loss, tokens = self.quantize(h)
        return quant, code_loss, tokens

    def tokenize(self, x: torch.Tensor) -> torch.LongTensor:
        """Tokenizes an input image tensor.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            Quantized indices of shape B H_Q W_Q
        """
        _, _, tokens = self.encode(x)
        return tokens

    def autoencode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Autoencodes an input image tensor by encoding it, quantizing the latent code,
        and decoding it back to an image. Implemented by subclasses.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            Reconstructed image tensor of shape B C H W
        """
        pass

    def decode_quant(self, quant: torch.Tensor, **kwargs) -> torch.Tensor:
        """Decodes quantized latent codes back to an image. Implemented by subclasses.

        Args:
            quant: Quantized latent code of shape B D_Q H_Q W_Q

        Returns:
            Decoded image tensor of shape B C H W
        """
        pass
    def tokens_to_embedding(self, tokens: torch.LongTensor) -> torch.Tensor:
        """Looks up the codebook entries corresponding to the discrete tokens.

        Args:
            tokens: Quantized indices of shape B H_Q W_Q

        Returns:
            Quantized latent code of shape B D_Q H_Q W_Q
        """
        return self.quantize.indices_to_embedding(tokens)
    def decode_tokens(self, tokens: torch.LongTensor, **kwargs) -> torch.Tensor:
        """Decodes discrete tokens back to an image.

        Args:
            tokens: Quantized indices of shape B H_Q W_Q

        Returns:
            Decoded image tensor of shape B C H W
        """
        quant = self.tokens_to_embedding(tokens)
        dec = self.decode_quant(quant, **kwargs)
        return dec

    def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the encoder and quantizer.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            quant: Quantized latent code of shape B D_Q H_Q W_Q
            code_loss: Codebook loss
        """
        quant, code_loss, _ = self.encode(x)
        return quant, code_loss
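
# Minimal usage sketch (illustrative, not part of the library API): encode a
# dummy batch into discrete tokens and look the codebook entries back up. The
# hyperparameters here are assumptions chosen to keep the example small, and
# sync_codebook is disabled since this is single-process inference.
def _example_vq_tokenize() -> None:
    vq = VQ(image_size=224, enc_type='vit_b_enc', patch_size=16,
            codebook_size=1024, latent_dim=32, sync_codebook=False)
    vq.eval()
    x = torch.randn(2, 3, 224, 224)  # stand-in for a real image batch
    with torch.no_grad():
        tokens = vq.tokenize(x)                 # B H_Q W_Q discrete indices
        quant = vq.tokens_to_embedding(tokens)  # B D_Q H_Q W_Q codebook vectors
    print(tokens.shape, quant.shape)
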
class VQVAE(VQ):
    """VQ-VAE model = simple encoder + decoder with a discrete bottleneck and a
    basic reconstruction loss (optionally with perceptual loss), i.e. no diffusion
    and no GAN discriminator.

    Args:
        dec_type: String identifier specifying the decoder architecture.
            See vq/vit_models.py and vq/mlp_models.py for available architectures.
        out_conv: Whether or not to add final conv layers to the ViT decoder.
        image_size_dec: Image size for the decoder. Defaults to self.image_size.
            Change this when loading weights from a tokenizer decoder trained on a
            different image size.
        patch_size_dec: Patch size for the decoder. Defaults to self.patch_size.
        config: Dictionary containing the model configuration. Only used when loading
            from Huggingface Hub. Ignore otherwise.
    """
    def __init__(self,
                 dec_type: str = 'vit_b_dec',
                 out_conv: bool = False,
                 image_size_dec: Optional[int] = None,
                 patch_size_dec: Optional[int] = None,
                 config: Optional[Dict[str, Any]] = None,
                 *args,
                 **kwargs):
        if config is not None:
            config = copy.deepcopy(config)
            self.__init__(**config)
            return

        # Don't want to load the weights just yet
        self.original_ckpt_path = kwargs.get('ckpt_path', None)
        kwargs['ckpt_path'] = None
        super().__init__(*args, **kwargs)
        self.ckpt_path = self.original_ckpt_path

        # Init decoder
        out_channels = self.n_channels if self.n_labels is None else self.n_labels
        image_size_dec = image_size_dec or self.image_size
        patch_size = patch_size_dec or self.patch_size
        if 'vit' in dec_type:
            self.decoder = getattr(vit_models, dec_type)(
                out_channels=out_channels, patch_size=patch_size,
                resolution=image_size_dec, out_conv=out_conv, post_mlp=self.post_mlp,
                patch_proj=self.patch_proj,
            )
            self.dec_dim = self.decoder.dim_tokens
        elif 'MLP' in dec_type:
            self.decoder = build_mlp(model_id=dec_type, dim_in=None, dim_out=out_channels)
            self.dec_dim = self.decoder.dim_in
        else:
            raise NotImplementedError(f'{dec_type} not implemented.')

        # Quantizer -> decoder projection
        self.post_quant_proj = torch.nn.Conv2d(self.latent_dim, self.dec_dim, 1)

        # Load checkpoint
        if self.ckpt_path is not None:
            self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys)
    def decode_quant(self, quant: torch.Tensor, **kwargs) -> torch.Tensor:
        """Decodes quantized latent codes back to an image.

        Args:
            quant: Quantized latent code of shape B D_Q H_Q W_Q

        Returns:
            Decoded image tensor of shape B C H W
        """
        quant = self.post_quant_proj(quant)
        dec = self.decoder(quant)
        return dec

    def forward(self, x: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the encoder, quantizer, and decoder.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            dec: Decoded image tensor of shape B C H W
            code_loss: Codebook loss
        """
        with torch.no_grad() if self.freeze_enc else nullcontext():
            quant, code_loss, _ = self.encode(x)
        dec = self.decode_quant(quant)
        return dec, code_loss

    def autoencode(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Autoencodes an input image tensor by encoding it, quantizing the
        latent code, and decoding it back to an image.

        Args:
            x: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation

        Returns:
            Reconstructed image tensor of shape B C H W
        """
        dec, _ = self.forward(x)
        return dec
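
# Round-trip sketch for VQVAE (illustrative): reconstruct a dummy batch once via
# autoencode() and once via the discrete-token path. The hyperparameters are
# assumptions for a small, randomly initialized model, so the "reconstruction"
# is meaningless until real weights are loaded via ckpt_path.
def _example_vqvae_roundtrip() -> None:
    vqvae = VQVAE(dec_type='vit_b_dec', image_size=224, codebook_size=1024,
                  latent_dim=32, sync_codebook=False)
    vqvae.eval()
    x = torch.randn(2, 3, 224, 224)
    with torch.no_grad():
        recon = vqvae.autoencode(x)           # encoder -> quantizer -> decoder
        tokens = vqvae.tokenize(x)            # B H_Q W_Q discrete indices
        recon2 = vqvae.decode_tokens(tokens)  # same decoder, fed from tokens
    assert recon.shape == x.shape and recon2.shape == x.shape
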
class DiVAE(VQ):
    """DiVAE ("Diffusion VQ-VAE") model = simple encoder + diffusion decoder with
    a discrete bottleneck, inspired by https://arxiv.org/abs/2206.00386.

    Args:
        dec_type: String identifier specifying the decoder architecture.
            See vq/models/unet/unet.py and vq/models/uvit.py for available architectures.
        num_train_timesteps: Number of diffusion timesteps to use for training.
        cls_free_guidance_dropout: Dropout probability for classifier-free guidance.
        masked_cfg: Whether or not to randomly mask out conditioning tokens.
            cls_free_guidance_dropout must be > 0.0 for this to have any effect, and
            decides how often masking is performed. E.g. with 0.5, half of the time
            the conditioning tokens will be randomly masked, and half the time they
            will be kept as is.
        masked_cfg_low: Lower bound of the number of tokens to mask out.
        masked_cfg_high: Upper bound of the number of tokens to mask out (inclusive).
            Defaults to the total number of tokens (H_Q * W_Q) if set to None.
        scheduler: String identifier specifying the diffusion scheduler to use.
            Can be 'ddpm' or 'ddim'.
        beta_schedule: String identifier specifying the beta schedule to use for
            the diffusion process. Can be 'linear', 'squaredcos_cap_v2' (cosine), or
            'shifted_cosine:{shift_amount}'; see vq/scheduling for details.
        prediction_type: String identifier specifying the type of prediction to use.
            Can be 'sample', 'epsilon', or 'v_prediction'; see vq/scheduling for details.
        clip_sample: Whether or not to clip the samples to [-1, 1], at inference only.
        thresholding: Whether or not to use dynamic thresholding (introduced by Imagen,
            https://arxiv.org/abs/2205.11487) for the diffusion process, at inference only.
        conditioning: String identifier specifying the way to condition the diffusion
            decoder. Can be 'concat' or 'xattn'. See the models for details (only relevant to UViT).
        dec_transformer_dropout: Dropout rate for the transformer layers in the
            diffusion decoder (only relevant to UViT models).
        zero_terminal_snr: Whether or not to enforce zero terminal SNR, i.e. the SNR
            at the last timestep is set to zero. This is useful for preventing the model
            from "cheating" by using information in the last timestep to reconstruct the image.
            See https://arxiv.org/abs/2305.08891.
        image_size_dec: Image size for the decoder. Defaults to image_size.
            Change this when loading weights from a tokenizer decoder trained on a
            different image size.
        config: Dictionary containing the model configuration. Only used when loading
            from Huggingface Hub. Ignore otherwise.
    """
    def __init__(self,
                 dec_type: str = 'unet_patched',
                 num_train_timesteps: int = 1000,
                 cls_free_guidance_dropout: float = 0.0,
                 masked_cfg: bool = False,
                 masked_cfg_low: int = 0,
                 masked_cfg_high: Optional[int] = None,
                 scheduler: str = 'ddpm',
                 beta_schedule: str = 'squaredcos_cap_v2',
                 prediction_type: str = 'v_prediction',
                 clip_sample: bool = False,
                 thresholding: bool = True,
                 conditioning: str = 'concat',
                 dec_transformer_dropout: float = 0.2,
                 zero_terminal_snr: bool = True,
                 image_size_dec: Optional[int] = None,
                 config: Optional[Dict[str, Any]] = None,
                 *args, **kwargs):
        if config is not None:
            config = copy.deepcopy(config)
            self.__init__(**config)
            return

        # Don't want to load the weights just yet
        self.original_ckpt_path = kwargs.get('ckpt_path', None)
        kwargs['ckpt_path'] = None
        super().__init__(*args, **kwargs)
        self.ckpt_path = self.original_ckpt_path

        self.num_train_timesteps = num_train_timesteps
        self.beta_schedule = beta_schedule
        self.prediction_type = prediction_type
        self.clip_sample = clip_sample
        self.thresholding = thresholding
        self.zero_terminal_snr = zero_terminal_snr

        if cls_free_guidance_dropout > 0.0:
            self.cfg_dist = torch.distributions.Bernoulli(probs=cls_free_guidance_dropout)
        else:
            self.cfg_dist = None
        self.masked_cfg = masked_cfg
        self.masked_cfg_low = masked_cfg_low
        self.masked_cfg_high = masked_cfg_high

        # Init diffusion decoder
        image_size_dec = image_size_dec or self.image_size
        if 'unet_' in dec_type:
            self.decoder = getattr(unet, dec_type)(
                in_channels=self.n_channels,
                out_channels=self.n_channels,
                cond_channels=self.latent_dim,
                image_size=image_size_dec,
            )
        elif 'uvit_' in dec_type:
            self.decoder = getattr(uvit, dec_type)(
                sample_size=image_size_dec,
                in_channels=self.n_channels,
                out_channels=self.n_channels,
                cond_dim=self.latent_dim,
                cond_type=conditioning,
                mid_drop_rate=dec_transformer_dropout,
            )
        else:
            raise NotImplementedError(f'dec_type {dec_type} not implemented.')

        # Init training diffusion scheduler / default pipeline for generation
        scheduler_cls = DDPMScheduler if scheduler == 'ddpm' else DDIMScheduler
        self.noise_scheduler = scheduler_cls(
            num_train_timesteps=num_train_timesteps,
            thresholding=thresholding,
            clip_sample=clip_sample,
            beta_schedule=beta_schedule,
            prediction_type=prediction_type,
            zero_terminal_snr=zero_terminal_snr,
        )
        self.pipeline = PipelineCond(model=self.decoder, scheduler=self.noise_scheduler)

        # Load checkpoint
        if self.ckpt_path is not None:
            self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys)
    def sample_mask(self, quant: torch.Tensor, low: int = 0, high: Optional[int] = None) -> torch.BoolTensor:
        """Returns a mask of shape B H_Q W_Q, where True = masked-out, False = keep.

        Args:
            quant: Dequantized latent tensor of shape B D_Q H_Q W_Q
            low: Lower bound of the number of tokens to mask out
            high: Upper bound of the number of tokens to mask out (inclusive).
                Defaults to the total number of tokens (H_Q * W_Q) if set to None.

        Returns:
            Boolean mask of shape B H_Q W_Q
        """
        B, _, H_Q, W_Q = quant.shape
        num_tokens = H_Q * W_Q
        high = high if high is not None else num_tokens
        zero_idxs = torch.randint(low=low, high=high + 1, size=(B,), device=quant.device)
        noise = torch.rand(B, num_tokens, device=quant.device)
        ids_arange_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        mask = torch.where(ids_arange_shuffle < zero_idxs.unsqueeze(1), 0, 1)
        mask = rearrange(mask, 'b (h w) -> b h w', h=H_Q, w=W_Q).bool()
        return mask
    def _get_pipeline(self, scheduler: Optional[SchedulerMixin] = None) -> PipelineCond:
        """Creates a conditional diffusion pipeline with the given scheduler.

        Args:
            scheduler: Scheduler to use for the diffusion pipeline.
                If None, the default scheduler will be used.

        Returns:
            Conditional diffusion pipeline.
        """
        return PipelineCond(model=self.decoder, scheduler=scheduler) if scheduler is not None else self.pipeline

    def decode_quant(self,
                     quant: torch.Tensor,
                     timesteps: Optional[int] = None,
                     scheduler: Optional[SchedulerMixin] = None,
                     generator: Optional[torch.Generator] = None,
                     image_size: Optional[Union[Tuple[int, int], int]] = None,
                     verbose: bool = False,
                     scheduler_timesteps_mode: str = 'trailing',
                     orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None) -> torch.Tensor:
        """Decodes quantized latent codes back to an image.

        Args:
            quant: Quantized latent code of shape B D_Q H_Q W_Q
            timesteps: Number of diffusion timesteps to use. Defaults to self.num_train_timesteps.
            scheduler: Scheduler to use for the diffusion pipeline. Defaults to the training scheduler.
            generator: Random number generator to use for sampling. By default, generations are stochastic.
            image_size: Image size to use for the diffusion pipeline. Defaults to the decoder image size.
            verbose: Whether or not to print a progress bar.
            scheduler_timesteps_mode: The mode to use for the DDIMScheduler. One of 'trailing',
                'linspace', 'leading'. See https://arxiv.org/abs/2305.08891 for more details.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL (https://arxiv.org/abs/2307.01952) for more details.

        Returns:
            Decoded image tensor of shape B C H W
        """
        pipeline = self._get_pipeline(scheduler)
        dec = pipeline(
            quant, timesteps=timesteps, generator=generator, image_size=image_size,
            verbose=verbose, scheduler_timesteps_mode=scheduler_timesteps_mode, orig_res=orig_res,
        )
        return dec
    def decode_tokens(self, tokens: torch.LongTensor, **kwargs) -> torch.Tensor:
        """See `decode_quant` for details on the optional args."""
        return super().decode_tokens(tokens, **kwargs)

    def autoencode(self,
                   input_clean: torch.Tensor,
                   timesteps: Optional[int] = None,
                   scheduler: Optional[SchedulerMixin] = None,
                   generator: Optional[torch.Generator] = None,
                   verbose: bool = True,
                   scheduler_timesteps_mode: str = 'trailing',
                   orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None,
                   **kwargs) -> torch.Tensor:
        """Autoencodes an input image tensor by encoding it, quantizing the latent code,
        and decoding it back to an image.

        Args:
            input_clean: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation
            timesteps: Number of diffusion timesteps to use. Defaults to self.num_train_timesteps.
            scheduler: Scheduler to use for the diffusion pipeline. Defaults to the training scheduler.
            generator: Random number generator to use for sampling. By default, generations are stochastic.
            verbose: Whether or not to print a progress bar.
            scheduler_timesteps_mode: The mode to use for the DDIMScheduler. One of 'trailing',
                'linspace', 'leading'. See https://arxiv.org/abs/2305.08891 for more details.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL (https://arxiv.org/abs/2307.01952) for more details.

        Returns:
            Reconstructed image tensor of shape B C H W
        """
        pipeline = self._get_pipeline(scheduler)
        quant, _, _ = self.encode(input_clean)
        image_size = input_clean.shape[-1]
        dec = pipeline(
            quant, timesteps=timesteps, generator=generator, image_size=image_size,
            verbose=verbose, scheduler_timesteps_mode=scheduler_timesteps_mode, orig_res=orig_res,
        )
        return dec
    def forward(self,
                input_clean: torch.Tensor,
                input_noised: torch.Tensor,
                timesteps: Union[torch.Tensor, float, int],
                cond_mask: Optional[torch.Tensor] = None,
                orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the encoder, quantizer, and decoder.

        Args:
            input_clean: Clean input image tensor of shape B C H W,
                or B H W in case of semantic segmentation. Used for encoding.
            input_noised: Noised input image tensor of shape B C H W. Used as
                input to the diffusion decoder.
            timesteps: Timesteps to condition the diffusion decoder on.
            cond_mask: Optional mask for the diffusion conditioning.
                True = masked-out, False = keep.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL (https://arxiv.org/abs/2307.01952) for more details.

        Returns:
            dec: Decoded image tensor of shape B C H W
            code_loss: Codebook loss
        """
        with torch.no_grad() if self.freeze_enc else nullcontext():
            quant, code_loss, _ = self.encode(input_clean)

        if cond_mask is None and self.cfg_dist is not None and self.training:
            # Create a random mask for each batch element. True = masked-out, False = keep
            B, _, H_Q, W_Q = quant.shape
            cond_mask = self.cfg_dist.sample((B,)).to(quant.device, dtype=torch.bool)
            cond_mask = repeat(cond_mask, 'b -> b h w', h=H_Q, w=W_Q)
            if self.masked_cfg:
                mask = self.sample_mask(quant, low=self.masked_cfg_low, high=self.masked_cfg_high)
                cond_mask = (mask * cond_mask)

        dec = self.decoder(input_noised, timesteps, quant, cond_mask=cond_mask, orig_res=orig_res)
        return dec, code_loss
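
# Inference sketch for DiVAE (illustrative): tokenize an image and sample a
# reconstruction with a DDIM scheduler over a reduced number of steps. The
# scheduler kwargs mirror the training configuration stored on the model; the
# 25-step budget and 'trailing' mode are assumptions, not tuned values.
def _example_divae_decode() -> None:
    divae = DiVAE(dec_type='unet_patched', image_size=224, codebook_size=1024,
                  latent_dim=32, sync_codebook=False)
    divae.eval()
    x = torch.randn(1, 3, 224, 224)
    ddim = DDIMScheduler(
        num_train_timesteps=divae.num_train_timesteps,
        beta_schedule=divae.beta_schedule,
        prediction_type=divae.prediction_type,
        clip_sample=divae.clip_sample,
        thresholding=divae.thresholding,
        zero_terminal_snr=divae.zero_terminal_snr,
    )
    with torch.no_grad():
        tokens = divae.tokenize(x)
        dec = divae.decode_tokens(tokens, timesteps=25, scheduler=ddim,
                                  scheduler_timesteps_mode='trailing')
    print(dec.shape)  # B C H W
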
class VQControlNet(VQ):
    """VQControlNet model = simple pretrained encoder + a ControlNet decoder conditioned on tokens.

    Args:
        sd_path: Path to the Stable Diffusion weights for training the ControlNet.
        image_size_sd: Stable Diffusion input image size. Defaults to image_size.
            Change this to the image size that Stable Diffusion was trained on.
        pretrained_cn: Whether to use pretrained Stable Diffusion weights for the control model.
        cls_free_guidance_dropout: Dropout probability for classifier-free guidance.
        masked_cfg: Whether or not to randomly mask out conditioning tokens.
            cls_free_guidance_dropout must be > 0.0 for this to have any effect, and
            decides how often masking is performed. E.g. with 0.5, half of the time
            the conditioning tokens will be randomly masked, and half the time they
            will be kept as is.
        masked_cfg_low: Lower bound of the number of tokens to mask out.
        masked_cfg_high: Upper bound of the number of tokens to mask out (inclusive).
            Defaults to the total number of tokens (H_Q * W_Q) if set to None.
        enable_xformer: Enables xFormers memory-efficient attention.
        adapter: Path to the adapter model weights. The adapter model is initially trained to
            map the tokens to a VAE latent-like representation. The output of the adapter model
            is then passed as the condition to train the ControlNet. By default, no adapter is used.
        config: Dictionary containing the model configuration. Only used when loading
            from Huggingface Hub. Ignore otherwise.
    """
    def __init__(self,
                 sd_path: str = "runwayml/stable-diffusion-v1-5",
                 image_size_sd: Optional[int] = None,
                 pretrained_cn: bool = False,
                 cls_free_guidance_dropout: float = 0.0,
                 masked_cfg: bool = False,
                 masked_cfg_low: int = 0,
                 masked_cfg_high: Optional[int] = None,
                 enable_xformer: bool = False,
                 adapter: Optional[str] = None,
                 config: Optional[Dict[str, Any]] = None,
                 *args, **kwargs):
        if config is not None:
            config = copy.deepcopy(config)
            self.__init__(**config)
            return

        # Don't want to load the weights just yet
        self.original_ckpt_path = kwargs.get('ckpt_path', None)
        kwargs['ckpt_path'] = None
        super().__init__(*args, **kwargs)
        self.ckpt_path = self.original_ckpt_path

        if cls_free_guidance_dropout > 0.0:
            self.cfg_dist = torch.distributions.Bernoulli(probs=cls_free_guidance_dropout)
        else:
            self.cfg_dist = None
        self.masked_cfg = masked_cfg
        self.masked_cfg_low = masked_cfg_low
        self.masked_cfg_high = masked_cfg_high

        self.image_size_sd = self.image_size if image_size_sd is None else image_size_sd

        sd_pipeline = StableDiffusionPipeline.from_pretrained(sd_path)

        try:
            import xformers
            XFORMERS_AVAILABLE = True
        except ImportError:
            print("xFormers not available")
            XFORMERS_AVAILABLE = False

        enable_xformer = enable_xformer and XFORMERS_AVAILABLE
        if enable_xformer:
            print('Enabling xFormers for Stable Diffusion')
            sd_pipeline.enable_xformers_memory_efficient_attention()

        self.decoder = getattr(controlnet, 'controlnet')(
            in_channels=4,
            cond_channels=self.latent_dim,
            sd_pipeline=sd_pipeline,
            image_size=self.image_size_sd,
            pretrained_cn=pretrained_cn,
            enable_xformer=enable_xformer,
            adapter=adapter,
        )

        # Use the default ControlNet pipeline for both training and generation
        self.noise_scheduler = PNDMScheduler(**sd_pipeline.scheduler.config)
        self.vae = sd_pipeline.vae
        self._freeze_vae()
        self.pipeline = PipelineCond(model=self.decoder, scheduler=self.noise_scheduler)

        # Load checkpoint
        if self.ckpt_path is not None:
            self.init_from_ckpt(self.ckpt_path, ignore_keys=self.ignore_keys)
    def sample_mask(self, quant: torch.Tensor, low: int = 0, high: Optional[int] = None) -> torch.BoolTensor:
        """Returns a mask of shape B H_Q W_Q, where True = masked-out, False = keep.

        Args:
            quant: Dequantized latent tensor of shape B D_Q H_Q W_Q
            low: Lower bound of the number of tokens to mask out
            high: Upper bound of the number of tokens to mask out (inclusive).
                Defaults to the total number of tokens (H_Q * W_Q) if set to None.

        Returns:
            Boolean mask of shape B H_Q W_Q
        """
        B, _, H_Q, W_Q = quant.shape
        num_tokens = H_Q * W_Q
        high = high if high is not None else num_tokens
        zero_idxs = torch.randint(low=low, high=high + 1, size=(B,), device=quant.device)
        noise = torch.rand(B, num_tokens, device=quant.device)
        ids_arange_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
        mask = torch.where(ids_arange_shuffle < zero_idxs.unsqueeze(1), 0, 1)
        mask = rearrange(mask, 'b (h w) -> b h w', h=H_Q, w=W_Q).bool()
        return mask
    def decode_quant(self,
                     quant: torch.Tensor,
                     timesteps: Optional[int] = None,
                     generator: Optional[torch.Generator] = None,
                     image_size: Optional[Union[Tuple[int, int], int]] = None,
                     verbose: bool = False,
                     vae_decode: bool = False,
                     scheduler_timesteps_mode: str = 'leading',
                     prompt: Optional[Union[List[str], str]] = None,
                     orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None,
                     guidance_scale: float = 0.0,
                     cond_scale: float = 1.0) -> torch.Tensor:
        """Decodes quantized latent codes back to an image.

        Args:
            quant: Quantized latent code of shape B D_Q H_Q W_Q
            timesteps: Number of diffusion timesteps to use. Defaults to self.num_train_timesteps.
            generator: Random number generator to use for sampling. By default, generations are stochastic.
            image_size: Image size to use for the diffusion pipeline. Defaults to the decoder image size.
            verbose: Whether or not to print a progress bar.
            vae_decode: If True, decodes the latent output of Stable Diffusion into an image.
            scheduler_timesteps_mode: The mode to use for the DDIMScheduler. One of 'trailing',
                'linspace', 'leading'. See https://arxiv.org/abs/2305.08891 for more details.
            prompt: The input prompts for the ControlNet.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL (https://arxiv.org/abs/2307.01952) for more details.
            guidance_scale: Classifier-free guidance scale.
            cond_scale: Scale that the output of the control model is multiplied by before being
                added to the Stable Diffusion layers of the ControlNet.

        Returns:
            Decoded tensor of shape B C H W
        """
        dec = self.pipeline(
            quant, timesteps=timesteps, generator=generator, image_size=image_size,
            verbose=verbose, scheduler_timesteps_mode=scheduler_timesteps_mode, prompt=prompt,
            guidance_scale=guidance_scale, cond_scale=cond_scale,
        )

        if vae_decode:
            return self.vae_decode(dec)

        return dec

    def decode_tokens(self, tokens: torch.LongTensor, **kwargs) -> torch.Tensor:
        """See `decode_quant` for details on the optional args."""
        return super().decode_tokens(tokens, **kwargs)
    def vae_encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encodes the input image into the VAE latent representation.

        Args:
            x: Input images

        Returns:
            Encoded latent tensor of shape B C H W
        """
        z = self.vae.encode(x).latent_dist.sample()
        z = z * self.vae.config.scaling_factor
        return z

    def vae_decode(self, x: torch.Tensor, clip: bool = True) -> torch.Tensor:
        """Decodes the VAE latent representation back into an image.

        Args:
            x: VAE latent representation
            clip: If True, clips the decoded image to [-1, 1].

        Returns:
            Decoded image of shape B C H W
        """
        x = self.vae.decode(x / self.vae.config.scaling_factor).sample
        if clip:
            x = torch.clip(x, min=-1, max=1)
        return x
    def autoencode(self,
                   input_clean: torch.Tensor,
                   timesteps: Optional[int] = None,
                   generator: Optional[torch.Generator] = None,
                   image_size: Optional[Union[Tuple[int, int], int]] = None,
                   verbose: bool = False,
                   vae_decode: bool = False,
                   scheduler_timesteps_mode: str = 'leading',
                   prompt: Optional[Union[List[str], str]] = None,
                   orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None,
                   guidance_scale: float = 0.0,
                   cond_scale: float = 1.0) -> torch.Tensor:
        """Autoencodes an input image tensor by encoding it, quantizing the latent code,
        and decoding it back to an image.

        Args:
            input_clean: Input image tensor of shape B C H W,
                or B H W in case of semantic segmentation
            timesteps: Number of diffusion timesteps to use. Defaults to self.num_train_timesteps.
            generator: Random number generator to use for sampling. By default, generations are stochastic.
            image_size: Image size to use for the diffusion pipeline. Defaults to the decoder image size.
            verbose: Whether or not to print a progress bar.
            vae_decode: If True, decodes the latent output of Stable Diffusion into an image.
            scheduler_timesteps_mode: The mode to use for the DDIMScheduler. One of 'trailing',
                'linspace', 'leading'. See https://arxiv.org/abs/2305.08891 for more details.
            prompt: The input prompts for the ControlNet.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL (https://arxiv.org/abs/2307.01952) for more details.
            guidance_scale: Classifier-free guidance scale.
            cond_scale: Scale that the output of the control model is multiplied by before being
                added to the Stable Diffusion layers of the ControlNet.

        Returns:
            Reconstructed tensor of shape B C H W
        """
        quant, _, _ = self.encode(input_clean)
        dec = self.pipeline(
            quant, timesteps=timesteps, generator=generator,
            verbose=verbose, scheduler_timesteps_mode=scheduler_timesteps_mode, prompt=prompt,
            guidance_scale=guidance_scale, cond_scale=cond_scale,
        )

        if vae_decode:
            return self.vae_decode(dec)

        return dec
    def forward(self,
                input_clean: torch.Tensor,
                input_noised: torch.Tensor,
                timesteps: Union[torch.Tensor, float, int],
                cond_mask: Optional[torch.Tensor] = None,
                prompt: Optional[Union[List[str], str]] = None,
                orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass of the encoder, quantizer, and decoder.

        Args:
            input_clean: Clean input image tensor of shape B C H W,
                or B H W in case of semantic segmentation. Used for encoding.
            input_noised: Noised input image tensor of shape B C H W. Used as
                input to the diffusion decoder.
            timesteps: Timesteps to condition the diffusion decoder on.
            cond_mask: Optional mask for the diffusion conditioning.
                True = masked-out, False = keep.
            prompt: ControlNet input prompt. Defaults to an empty string.
            orig_res: The original resolution of the image to condition the diffusion on. Ignored if None.
                See SDXL (https://arxiv.org/abs/2307.01952) for more details.

        Returns:
            dec: Decoded image tensor of shape B C H W
            code_loss: Codebook loss
        """
        with torch.no_grad() if self.freeze_enc else nullcontext():
            quant, code_loss, _ = self.encode(input_clean)

        if cond_mask is None and self.cfg_dist is not None and self.training:
            # Create a random mask for each batch element. True = masked-out, False = keep
            B, _, H_Q, W_Q = quant.shape
            cond_mask = self.cfg_dist.sample((B,)).to(quant.device, dtype=torch.bool)
            cond_mask = repeat(cond_mask, 'b -> b h w', h=H_Q, w=W_Q)
            if self.masked_cfg:
                mask = self.sample_mask(quant, low=self.masked_cfg_low, high=self.masked_cfg_high)
                cond_mask = (mask * cond_mask)

        dec = self.decoder(input_noised, timesteps, quant, cond_mask=cond_mask, orig_res=orig_res, prompt=prompt)
        return dec, code_loss
    def _freeze_vae(self):
        """Freezes the VAE."""
        for param in self.vae.parameters():
            param.requires_grad = False
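
# Inference sketch for VQControlNet (illustrative): constructing the model
# downloads the Stable Diffusion v1.5 weights, so this is not runnable offline.
# With vae_decode=True the sampled SD latent is mapped back to pixel space via
# vae_decode(); the empty prompt, 50-step budget, and guidance/conditioning
# scales are assumptions, not tuned values.
def _example_vqcontrolnet_decode() -> None:
    model = VQControlNet(sd_path="runwayml/stable-diffusion-v1-5",
                         codebook_size=1024, latent_dim=32, sync_codebook=False)
    model.eval()
    x = torch.randn(1, 3, model.image_size, model.image_size)
    with torch.no_grad():
        tokens = model.tokenize(x)
        img = model.decode_tokens(tokens, timesteps=50, vae_decode=True,
                                  prompt="", guidance_scale=0.0, cond_scale=1.0)
    print(img.shape)  # B 3 H W, clipped to [-1, 1] by the VAE decode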