from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.utils.torch_utils import randn_tensor


class AutoencoderKLNextStep(AutoencoderKL):
    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        # Fold non-overlapping p x p spatial patches into the channel dimension:
        # (b, c, h, w) -> (b, c * p**2, h // p, w // p).
        b, c, h, w = x.shape
        p = self.patch_size
        h_, w_ = h // p, w // p
        x = x.reshape(b, c, h_, p, w_, p)
        x = torch.einsum("bchpwq->bcpqhw", x)
        x = x.reshape(b, c * p**2, h_, w_)
        return x

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        # Inverse of `patchify`: (b, c * p**2, h // p, w // p) -> (b, c, h, w).
        b, _, h_, w_ = x.shape
        p = self.patch_size
        c = x.shape[1] // (p**2)
        x = x.reshape(b, c, p, p, h_, w_)
        x = torch.einsum("bcpqhw->bchpwq", x)
        x = x.reshape(b, c, h_ * p, w_ * p)
        return x

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        if self.use_slicing and x.shape[0] > 1:
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)

        mean, logvar = torch.chunk(h, 2, dim=1)
        # Optionally layer-norm the posterior mean over channels. When patch_size
        # is set, the normalization is applied over c * p**2 patch channels.
        if self.patch_size is not None:
            mean = self.patchify(mean)
        if self.normalize_latents:
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
        if self.patch_size is not None:
            mean = self.unpatchify(mean)
        h = torch.cat([mean, logvar], dim=1).contiguous()

        posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, torch.Tensor]:
        x = sample
        posterior = self.encode(x).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()
        if noise_strength > 0.0:
            # Perturb the latents with Gaussian noise whose per-sample scale is
            # drawn uniformly from (0, noise_strength).
            p = torch.distributions.Uniform(0, noise_strength)
            z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
                z.shape, device=z.device, dtype=z.dtype
            )
        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
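

if __name__ == "__main__":
    # Minimal smoke-test sketch, not from the source: the block layout,
    # 256x256 input size, latent_channels=16, and patch_size=2 below are
    # illustrative assumptions, not the checkpointed NextStep configuration.
    vae = AutoencoderKLNextStep(
        down_block_types=("DownEncoderBlock2D",) * 4,
        up_block_types=("UpDecoderBlock2D",) * 4,
        block_out_channels=(128, 256, 512, 512),
        layers_per_block=2,
        latent_channels=16,
        normalize_latents=True,
        patch_size=2,
    )
    vae.eval()

    with torch.no_grad():
        x = torch.randn(1, 3, 256, 256)
        posterior = vae.encode(x).latent_dist  # mean layer-normed over 2x2 patch channels
        z = posterior.mode()                   # (1, 16, 32, 32) latent (8x downsampling)
        recon = vae.decode(z).sample           # (1, 3, 256, 256) reconstruction
    print(z.shape, recon.shape)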