# NOTE(review): the three lines below are Hugging Face page residue
# (uploader avatar caption, commit message, commit hash) that leaked into
# the source dump; commented out so the module parses.
# lavinal712's picture
# init
# 161aead
from typing import Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
class AutoencoderKLNextStep(AutoencoderKL):
    """``AutoencoderKL`` variant with three opt-in extensions.

    On top of the stock diffusers VAE this subclass adds:
      * ``deterministic`` — build zero-variance posteriors (mode == sample).
      * ``normalize_latents`` — layer-normalize the posterior mean over the
        channel axis during :meth:`encode`.
      * ``patch_size`` — if set, latents are (un)patchified around that
        normalization step so the norm runs over patch-packed channels.

    All remaining arguments are forwarded untouched to ``AutoencoderKL``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float]] = None,
        latents_std: Optional[Tuple[float]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        # Everything the stock AutoencoderKL understands is forwarded as-is.
        base_kwargs = dict(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        super().__init__(**base_kwargs)
        # Extensions specific to this subclass.
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size
def patchify(self, x: torch.Tensor) -> torch.Tensor:
b, c, h, w = x.shape
p = self.patch_size
h_, w_ = h // p, w // p
x = x.reshape(b, c, h_, p, w_, p)
x = torch.einsum("bchpwq->bcpqhw", x)
x = x.reshape(b, c * p ** 2, h_, w_)
return x
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
b, _, h_, w_ = x.shape
p = self.patch_size
c = x.shape[1] // (p ** 2)
x = x.reshape(b, c, p, p, h_, w_)
x = torch.einsum("bcpqhw->bchpwq", x)
x = x.reshape(b, c, h_ * p, w_ * p)
return x
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
h = torch.cat(encoded_slices)
else:
h = self._encode(x)
mean, logvar = torch.chunk(h, 2, dim=1)
if self.patch_size is not None:
mean = self.patchify(mean)
if self.normalize_latents:
mean = mean.permute(0, 2, 3, 1)
mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
mean = mean.permute(0, 3, 1, 2)
if self.patch_size is not None:
mean = self.unpatchify(mean)
h = torch.cat([mean, logvar], dim=1).contiguous()
posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)
if not return_dict:
return (posterior,)
return AutoencoderKLOutput(latent_dist=posterior)
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
noise_strength: float = 0.0,
) -> Union[DecoderOutput, torch.Tensor]:
x = sample
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
if noise_strength > 0.0:
p = torch.distributions.Uniform(0, noise_strength)
z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
z.shape, device=z.device, dtype=z.dtype
)
dec = self.decode(z).sample
if not return_dict:
return (dec,)
return DecoderOutput(sample=dec)