from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DecoderTiny, EncoderTiny

@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.

    """

    latents: torch.Tensor

class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        upsample_fn (`str`, *optional*, defaults to `"nearest"`):
            Interpolation mode used for upsampling in the decoder.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`int`, *optional*, defaults to 3):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
            however, no such scaling factor was used, hence the value of 1.0 as the default.
        shift_factor (`float`, *optional*, defaults to 0.0):
            Shift applied to the latents alongside `scaling_factor`. It is stored in the config but is not used by
            this model's encode/decode paths.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
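
    Example (a minimal usage sketch; the `madebyollin/taesd` checkpoint id and the input shape are
    illustrative assumptions, not fixed by this module):

        ```py
        >>> import torch
        >>> from diffusers import AutoencoderTiny

        >>> vae = AutoencoderTiny.from_pretrained("madebyollin/taesd")  # assumed TAESD checkpoint
        >>> image = torch.rand(1, 3, 512, 512)  # one 512x512 RGB image scaled to [0, 1]

        >>> with torch.no_grad():
        ...     latents = vae.encode(image).latents  # (1, 4, 64, 64): 8x spatial compression
        ...     reconstruction = vae.decode(latents).sample  # (1, 3, 512, 512)
        ```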
| | """ |

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64),
        act_fn: str = "relu",
        upsample_fn: str = "nearest",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
        shift_factor: float = 0.0,
    ):
        super().__init__()

        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
            upsample_fn=upsample_fn,
        )

        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

        self.use_slicing = False
        self.use_tiling = False
|
| | |
| | self.spatial_scale_factor = 2**out_channels |
| | self.tile_overlap_factor = 0.125 |
| | self.tile_sample_min_size = 512 |
| | self.tile_latent_min_size = self.tile_sample_min_size // self.spatial_scale_factor |
| |
|
| | self.register_to_config(block_out_channels=decoder_block_out_channels) |
| | self.register_to_config(force_upcast=False) |
| |
|
| | def _set_gradient_checkpointing(self, module, value: bool = False) -> None: |
| | if isinstance(module, (EncoderTiny, DecoderTiny)): |
| | module.gradient_checkpointing = value |
| |
|
| | def scale_latents(self, x: torch.Tensor) -> torch.Tensor: |
| | """raw latents -> [0, 1]""" |
| | return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1) |
| |
|
| | def unscale_latents(self, x: torch.Tensor) -> torch.Tensor: |
| | """[0, 1] -> raw latents""" |
| | return x.sub(self.latent_shift).mul(2 * self.latent_magnitude) |
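
    # Worked example of the two mappings above with the defaults
    # (latent_magnitude=3, latent_shift=0.5): scale_latents maps raw latents in
    # [-3, 3] linearly onto [0, 1] via x / 6 + 0.5 (so -3 -> 0.0, 0 -> 0.5,
    # 3 -> 1.0), and unscale_latents inverts it exactly for values the clamp
    # did not clip.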

    def enable_slicing(self) -> None:
        r"""
        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
        """
        self.use_slicing = True

    def disable_slicing(self) -> None:
        r"""
        Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.use_slicing = False

    def enable_tiling(self, use_tiling: bool = True) -> None:
        r"""
        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.
        """
        self.use_tiling = use_tiling

    def disable_tiling(self) -> None:
        r"""
        Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
        decoding in one step.
        """
        self.enable_tiling(False)
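
    # A minimal usage sketch for the toggles above (the tensor `image` and its
    # large 1x3x2048x2048 shape are assumptions for illustration):
    #
    #   vae.enable_tiling()   # encode/decode 512px tiles instead of the full image
    #   vae.enable_slicing()  # process batch elements one at a time
    #   latents = vae.encode(image).latents
    #   recon = vae.decode(latents).sample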

    def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Encode a batch of images using a tiled encoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output.

        Args:
            x (`torch.Tensor`): Input batch of images.

        Returns:
            `torch.Tensor`: Encoded batch of images.
        """
        # scale of encoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_sample_min_size

        # number of pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tiles index (up/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size / sf) / (blend_size / sf - 1)] * 2, indexing="ij")
        )
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # output array
        out = torch.zeros(x.shape[0], 4, x.shape[-2] // sf, x.shape[-1] // sf, device=x.device)
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # view of the output region this tile maps to
                tile_out = out[..., i // sf : (i + tile_size) // sf, j // sf : (j + tile_size) // sf]
                tile = self.encoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output (tiles on the top/left edge have no neighbour to blend with)
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = blend_mask_i * blend_mask_j
                tile, blend_mask = tile[..., :h, :w], blend_mask[..., :h, :w]
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        return out
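
    # Worked example of the tile geometry above for a 1024x1024 input with the
    # defaults (tile_size=512, tile_overlap_factor=0.125): blend_size=64 and
    # traverse_size=448, so tiles start at rows/cols 0, 448, 896, and each tile
    # shares a 64-pixel band with its upper/left neighbour in which the linear
    # blend mask ramps from 0 to 1.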

    def _tiled_decode(self, x: torch.Tensor) -> torch.Tensor:
        r"""Decode a batch of latent representations using a tiled decoder.

        When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several
        steps. This is useful to keep memory use constant regardless of image size. To avoid tiling artifacts, the
        tiles overlap and are blended together to form a smooth output.

        Args:
            x (`torch.Tensor`): Input batch of latents.

        Returns:
            `torch.Tensor`: Decoded batch of images.
        """
        # scale of decoder output relative to input
        sf = self.spatial_scale_factor
        tile_size = self.tile_latent_min_size

        # number of latent pixels to blend and to traverse between tiles
        blend_size = int(tile_size * self.tile_overlap_factor)
        traverse_size = tile_size - blend_size

        # tiles index (up/left)
        ti = range(0, x.shape[-2], traverse_size)
        tj = range(0, x.shape[-1], traverse_size)

        # mask for blending
        blend_masks = torch.stack(
            torch.meshgrid([torch.arange(tile_size * sf) / (blend_size * sf - 1)] * 2, indexing="ij")
        )
        blend_masks = blend_masks.clamp(0, 1).to(x.device)

        # output array
        out = torch.zeros(x.shape[0], 3, x.shape[-2] * sf, x.shape[-1] * sf, device=x.device)
        for i in ti:
            for j in tj:
                tile_in = x[..., i : i + tile_size, j : j + tile_size]
                # view of the output region this tile maps to
                tile_out = out[..., i * sf : (i + tile_size) * sf, j * sf : (j + tile_size) * sf]
                tile = self.decoder(tile_in)
                h, w = tile.shape[-2], tile.shape[-1]
                # blend tile result into output (tiles on the top/left edge have no neighbour to blend with)
                blend_mask_i = torch.ones_like(blend_masks[0]) if i == 0 else blend_masks[0]
                blend_mask_j = torch.ones_like(blend_masks[1]) if j == 0 else blend_masks[1]
                blend_mask = (blend_mask_i * blend_mask_j)[..., :h, :w]
                tile_out.copy_(blend_mask * tile + (1 - blend_mask) * tile_out)
        return out

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> Union[AutoencoderTinyOutput, Tuple[torch.Tensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [
                self._tiled_encode(x_slice) if self.use_tiling else self.encoder(x_slice) for x_slice in x.split(1)
            ]
            output = torch.cat(output)
        else:
            output = self._tiled_encode(x) if self.use_tiling else self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(
        self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        if self.use_slicing and x.shape[0] > 1:
            output = [
                self._tiled_decode(x_slice) if self.use_tiling else self.decoder(x_slice) for x_slice in x.split(1)
            ]
            output = torch.cat(output)
        else:
            output = self._tiled_decode(x) if self.use_tiling else self.decoder(x)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)
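
    # Sketch of how these methods slot into a diffusion pipeline (for this
    # model the default scaling_factor is 1.0, so the multiplications below are
    # no-ops, but they keep the call pattern identical to AutoencoderKL):
    #
    #   latents = vae.encode(image).latents * vae.config.scaling_factor
    #   ... run the denoising loop on `latents` ...
    #   image = vae.decode(latents / vae.config.scaling_factor).sample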

    def forward(
        self,
        sample: torch.Tensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        r"""
        Args:
            sample (`torch.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        enc = self.encode(sample).latents

        # scale latents to be in [0, 1], then quantize latents to a byte tensor,
        # as if we were storing the latents in an RGBA uint8 image
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()

        # unquantize latents back into [0, 1], then unscale latents back to their original range,
        # as if we were loading the latents from an RGBA uint8 image
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)

        dec = self.decode(unscaled_enc).sample

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
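
    # Note: forward() deliberately round-trips the latents through uint8
    # quantization (256 levels across the clamped [0, 1] range), mimicking
    # storage of the 4-channel latents as an RGBA image. Calling
    # decode(encode(x).latents) directly skips that quantization and is the
    # higher-fidelity path.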