Spaces:
Paused
Paused
| from typing import Optional, Tuple, Union | |
| from diffusers import AutoencoderTiny | |
| from diffusers.models.autoencoders.vae import ( | |
| EncoderTiny, | |
| get_activation, | |
| AutoencoderTinyBlock, | |
| DecoderOutput | |
| ) | |
| from diffusers.utils.accelerate_utils import apply_forward_hook | |
| from diffusers.configuration_utils import register_to_config | |
| import torch | |
| import torch.nn as nn | |
| class DecoderTinyWithPooledExits(nn.Module): | |
| def __init__( | |
| self, | |
| in_channels: int, | |
| out_channels: int, | |
| num_blocks: Tuple[int, ...], | |
| block_out_channels: Tuple[int, ...], | |
| upsampling_scaling_factor: int, | |
| act_fn: str, | |
| upsample_fn: str, | |
| ): | |
| super().__init__() | |
| layers = [] | |
| self.ordered_layers = [] | |
| l = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, padding=1) | |
| self.ordered_layers.append(l) | |
| layers.append(l) | |
| l = get_activation(act_fn) | |
| self.ordered_layers.append(l) | |
| layers.append(l) | |
| pooled_exits = [] | |
| for i, num_block in enumerate(num_blocks): | |
| is_final_block = i == (len(num_blocks) - 1) | |
| num_channels = block_out_channels[i] | |
| for _ in range(num_block): | |
| l = AutoencoderTinyBlock(num_channels, num_channels, act_fn) | |
| layers.append(l) | |
| self.ordered_layers.append(l) | |
| if not is_final_block: | |
| l = nn.Upsample( | |
| scale_factor=upsampling_scaling_factor, mode=upsample_fn | |
| ) | |
| layers.append(l) | |
| self.ordered_layers.append(l) | |
| conv_out_channel = num_channels if not is_final_block else out_channels | |
| l = nn.Conv2d( | |
| num_channels, | |
| conv_out_channel, | |
| kernel_size=3, | |
| padding=1, | |
| bias=is_final_block, | |
| ) | |
| layers.append(l) | |
| self.ordered_layers.append(l) | |
| if not is_final_block: | |
| p = nn.Conv2d( | |
| conv_out_channel, | |
| out_channels=3, | |
| kernel_size=3, | |
| padding=1, | |
| bias=True, | |
| ) | |
| p._is_pooled_exit = True | |
| pooled_exits.append(p) | |
| self.ordered_layers.append(p) | |
| self.layers = nn.ModuleList(layers) | |
| self.pooled_exits = nn.ModuleList(pooled_exits) | |
| self.gradient_checkpointing = False | |
| def forward(self, x: torch.Tensor, pooled_outputs=False) -> torch.Tensor: | |
| r"""The forward method of the `DecoderTiny` class.""" | |
| # Clamp. | |
| x = torch.tanh(x / 3) * 3 | |
| pooled_output_list = [] | |
| for layer in self.ordered_layers: | |
| # see if is pooled exit | |
| try: | |
| if hasattr(layer, '_is_pooled_exit') and layer._is_pooled_exit: | |
| if pooled_outputs: | |
| pooled_output = layer(x) | |
| pooled_output_list.append(pooled_output) | |
| else: | |
| if torch.is_grad_enabled() and self.gradient_checkpointing: | |
| x = self._gradient_checkpointing_func(layer, x) | |
| else: | |
| x = layer(x) | |
| except RuntimeError as e: | |
| raise e | |
| # scale image from [0, 1] to [-1, 1] to match diffusers convention | |
| x = x.mul(2).sub(1) | |
| if pooled_outputs: | |
| return x, pooled_output_list | |
| return x | |
| class AutoencoderTinyWithPooledExits(AutoencoderTiny): | |
| def __init__( | |
| self, | |
| in_channels: int = 3, | |
| out_channels: int = 3, | |
| encoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), | |
| decoder_block_out_channels: Tuple[int, ...] = (64, 64, 64, 64), | |
| act_fn: str = "relu", | |
| upsample_fn: str = "nearest", | |
| latent_channels: int = 4, | |
| upsampling_scaling_factor: int = 2, | |
| num_encoder_blocks: Tuple[int, ...] = (1, 3, 3, 3), | |
| num_decoder_blocks: Tuple[int, ...] = (3, 3, 3, 1), | |
| latent_magnitude: int = 3, | |
| latent_shift: float = 0.5, | |
| force_upcast: bool = False, | |
| scaling_factor: float = 1.0, | |
| shift_factor: float = 0.0, | |
| ): | |
| super(AutoencoderTiny, self).__init__() | |
| if len(encoder_block_out_channels) != len(num_encoder_blocks): | |
| raise ValueError( | |
| "`encoder_block_out_channels` should have the same length as `num_encoder_blocks`." | |
| ) | |
| if len(decoder_block_out_channels) != len(num_decoder_blocks): | |
| raise ValueError( | |
| "`decoder_block_out_channels` should have the same length as `num_decoder_blocks`." | |
| ) | |
| self.encoder = EncoderTiny( | |
| in_channels=in_channels, | |
| out_channels=latent_channels, | |
| num_blocks=num_encoder_blocks, | |
| block_out_channels=encoder_block_out_channels, | |
| act_fn=act_fn, | |
| ) | |
| self.decoder = DecoderTinyWithPooledExits( | |
| in_channels=latent_channels, | |
| out_channels=out_channels, | |
| num_blocks=num_decoder_blocks, | |
| block_out_channels=decoder_block_out_channels, | |
| upsampling_scaling_factor=upsampling_scaling_factor, | |
| act_fn=act_fn, | |
| upsample_fn=upsample_fn, | |
| ) | |
| self.latent_magnitude = latent_magnitude | |
| self.latent_shift = latent_shift | |
| self.scaling_factor = scaling_factor | |
| self.use_slicing = False | |
| self.use_tiling = False | |
| # only relevant if vae tiling is enabled | |
| self.spatial_scale_factor = 2**out_channels | |
| self.tile_overlap_factor = 0.125 | |
| self.tile_sample_min_size = 512 | |
| self.tile_latent_min_size = ( | |
| self.tile_sample_min_size // self.spatial_scale_factor | |
| ) | |
| self.register_to_config(block_out_channels=decoder_block_out_channels) | |
| self.register_to_config(force_upcast=False) | |
| def decode_with_pooled_exits( | |
| self, x: torch.Tensor, generator: Optional[torch.Generator] = None, return_dict: bool = False | |
| ) -> Union[DecoderOutput, Tuple[torch.Tensor]]: | |
| output, pooled_outputs = self.decoder(x, pooled_outputs=True) | |
| if not return_dict: | |
| return (output, pooled_outputs) | |
| return DecoderOutput(sample=output) | |