|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass |
|
|
from typing import Tuple, Union |
|
|
|
|
|
import torch |
|
|
|
|
|
from ..configuration_utils import ConfigMixin, register_to_config |
|
|
from ..utils import BaseOutput, apply_forward_hook |
|
|
from .modeling_utils import ModelMixin |
|
|
from .vae import DecoderOutput, DecoderTiny, EncoderTiny |
|
|
|
|
|
|
|
|
@dataclass
class AutoencoderTinyOutput(BaseOutput):
    """
    Output of AutoencoderTiny encoding method.

    Args:
        latents (`torch.Tensor`): Encoded outputs of the `Encoder`.
    """

    # Raw (unscaled) latent tensor returned by `EncoderTiny`; presumably
    # (batch, latent_channels, h', w') — confirm against EncoderTiny.
    latents: torch.Tensor
|
|
|
|
|
|
|
|
class AutoencoderTiny(ModelMixin, ConfigMixin):
    r"""
    A tiny distilled VAE model for encoding images into latents and decoding latent representations into images.

    [`AutoencoderTiny`] is a wrapper around the original implementation of `TAESD`.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented for
    all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        encoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each encoder block. The length of the
            tuple should be equal to the number of encoder blocks.
        decoder_block_out_channels (`Tuple[int]`, *optional*, defaults to `(64, 64, 64, 64)`):
            Tuple of integers representing the number of output channels for each decoder block. The length of the
            tuple should be equal to the number of decoder blocks.
        act_fn (`str`, *optional*, defaults to `"relu"`):
            Activation function to be used throughout the model.
        latent_channels (`int`, *optional*, defaults to 4):
            Number of channels in the latent representation. The latent space acts as a compressed representation of
            the input image.
        upsampling_scaling_factor (`int`, *optional*, defaults to 2):
            Scaling factor for upsampling in the decoder. It determines the size of the output image during the
            upsampling process.
        num_encoder_blocks (`Tuple[int]`, *optional*, defaults to `(1, 3, 3, 3)`):
            Tuple of integers representing the number of encoder blocks at each stage of the encoding process. The
            length of the tuple should be equal to the number of stages in the encoder. Each stage has a different
            number of encoder blocks.
        num_decoder_blocks (`Tuple[int]`, *optional*, defaults to `(3, 3, 3, 1)`):
            Tuple of integers representing the number of decoder blocks at each stage of the decoding process. The
            length of the tuple should be equal to the number of stages in the decoder. Each stage has a different
            number of decoder blocks.
        latent_magnitude (`int`, *optional*, defaults to 3):
            Magnitude of the latent representation. This parameter scales the latent representation values to control
            the extent of information preservation.
        latent_shift (`float`, *optional*, defaults to 0.5):
            Shift applied to the latent representation. This parameter controls the center of the latent space.
        scaling_factor (`float`, *optional*, defaults to 1.0):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. For this Autoencoder,
            however, no such scaling factor was used, hence the value of 1.0 as the default.
        force_upcast (`bool`, *optional*, defaults to `False`):
            If enabled it will force the VAE to run in float32 for high image resolution pipelines, such as SD-XL. VAE
            can be fine-tuned / trained to a lower range without losing too much precision, in which case
            `force_upcast` can be set to `False` (see this fp16-friendly
            [AutoEncoder](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        encoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        decoder_block_out_channels: Tuple[int] = (64, 64, 64, 64),
        act_fn: str = "relu",
        latent_channels: int = 4,
        upsampling_scaling_factor: int = 2,
        num_encoder_blocks: Tuple[int] = (1, 3, 3, 3),
        num_decoder_blocks: Tuple[int] = (3, 3, 3, 1),
        latent_magnitude: int = 3,
        latent_shift: float = 0.5,
        # NOTE: annotation fixed from `float` to `bool` — the default and all
        # uses of this flag are boolean.
        force_upcast: bool = False,
        scaling_factor: float = 1.0,
    ):
        super().__init__()

        # Each stage needs both a block count and an output-channel width; a
        # length mismatch would mis-build the sub-networks, so fail fast here.
        if len(encoder_block_out_channels) != len(num_encoder_blocks):
            raise ValueError("`encoder_block_out_channels` should have the same length as `num_encoder_blocks`.")
        if len(decoder_block_out_channels) != len(num_decoder_blocks):
            raise ValueError("`decoder_block_out_channels` should have the same length as `num_decoder_blocks`.")

        self.encoder = EncoderTiny(
            in_channels=in_channels,
            out_channels=latent_channels,
            num_blocks=num_encoder_blocks,
            block_out_channels=encoder_block_out_channels,
            act_fn=act_fn,
        )

        self.decoder = DecoderTiny(
            in_channels=latent_channels,
            out_channels=out_channels,
            num_blocks=num_decoder_blocks,
            block_out_channels=decoder_block_out_channels,
            upsampling_scaling_factor=upsampling_scaling_factor,
            act_fn=act_fn,
        )

        # Cached on the instance so `scale_latents`/`unscale_latents` don't
        # need to go through the config on every call.
        self.latent_magnitude = latent_magnitude
        self.latent_shift = latent_shift
        self.scaling_factor = scaling_factor

    def _set_gradient_checkpointing(self, module, value=False):
        # Only the encoder/decoder submodules support checkpointing.
        if isinstance(module, (EncoderTiny, DecoderTiny)):
            module.gradient_checkpointing = value

    def scale_latents(self, x):
        """raw latents -> [0, 1]"""
        return x.div(2 * self.latent_magnitude).add(self.latent_shift).clamp(0, 1)

    def unscale_latents(self, x):
        """[0, 1] -> raw latents"""
        return x.sub(self.latent_shift).mul(2 * self.latent_magnitude)

    @apply_forward_hook
    def encode(
        self, x: torch.FloatTensor, return_dict: bool = True
    ) -> Union[AutoencoderTinyOutput, Tuple[torch.FloatTensor]]:
        """Encode an image batch into (raw, unscaled) latents.

        Args:
            x (`torch.FloatTensor`): Input image batch.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return an [`AutoencoderTinyOutput`] instead of a plain tuple.
        """
        output = self.encoder(x)

        if not return_dict:
            return (output,)

        return AutoencoderTinyOutput(latents=output)

    @apply_forward_hook
    def decode(self, x: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        """Decode latents into an image batch.

        Args:
            x (`torch.FloatTensor`): Latent batch to decode.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`DecoderOutput`] instead of a plain tuple.
        """
        output = self.decoder(x)
        # Map the decoder output from [0, 1] to the [-1, 1] range used by the
        # rest of the pipeline.
        output = output.mul_(2).sub_(1)

        if not return_dict:
            return (output,)

        return DecoderOutput(sample=output)

    def forward(
        self,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ) -> Union[DecoderOutput, Tuple[torch.FloatTensor]]:
        r"""
        Encode `sample`, round-trip the latents through 8-bit quantization, and decode.

        Args:
            sample (`torch.FloatTensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
        """
        enc = self.encode(sample).latents
        # Scale latents to [0, 1] and quantize to a byte tensor, emulating
        # storage of the latents in an 8-bit image file.
        scaled_enc = self.scale_latents(enc).mul_(255).round_().byte()
        # BUGFIX: dequantize back into [0, 1] before unscaling. Feeding the raw
        # byte tensor (values 0-255) to `unscale_latents`, which expects [0, 1],
        # would inflate the latent range ~255x.
        unscaled_enc = self.unscale_latents(scaled_enc / 255.0)
        # BUGFIX: `decode` already returns a `DecoderOutput`; take `.sample` so
        # we don't nest a `DecoderOutput` inside another `DecoderOutput` below.
        dec = self.decode(unscaled_enc).sample

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
|
|
|