|
|
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.utils.torch_utils import randn_tensor
|
|
|
|
|
|
|
|
class AutoencoderKLNextStep(AutoencoderKL):
    """``AutoencoderKL`` with an optionally layer-normalized posterior mean.

    Compared with the stock ``AutoencoderKL`` this subclass adds three
    config options:

    * ``deterministic`` -- passed through to ``DiagonalGaussianDistribution``
      when :meth:`encode` builds the posterior.
    * ``normalize_latents`` -- when ``True``, :meth:`encode` applies a
      parameter-free layer norm (``eps=1e-6``) to the posterior mean over its
      channel axis. The log-variance half is left untouched.
    * ``patch_size`` -- when set (and ``normalize_latents`` is ``True``), the
      mean is first folded into ``patch_size x patch_size`` patches so each
      layer norm covers ``C * patch_size**2`` values, then unfolded again.

    :meth:`forward` can additionally perturb the latent with Gaussian noise
    (``noise_strength``) before decoding.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float, ...]] = None,
        latents_std: Optional[Tuple[float, ...]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        """Build the autoencoder.

        All arguments up to ``mid_block_add_attention`` are forwarded unchanged
        to ``AutoencoderKL.__init__``. The last three are stored as instance
        attributes (``@register_to_config`` also records them in the config).

        Args:
            deterministic: Forwarded to ``DiagonalGaussianDistribution`` in
                :meth:`encode`.
            normalize_latents: Layer-normalize the posterior mean in
                :meth:`encode`.
            patch_size: Optional patch edge length used to group spatial
                positions before that layer norm. ``None`` normalizes each
                spatial position independently over channels.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold non-overlapping ``p x p`` spatial patches into the channel axis.

        Maps ``(B, C, H, W) -> (B, C * p * p, H // p, W // p)`` with
        ``p = self.patch_size``. Output channels are ordered channel-major,
        then patch row, then patch column. Exact inverse of :meth:`unpatchify`.

        Raises:
            ValueError: If ``H`` or ``W`` is not divisible by ``p``.
        """
        b, c, h, w = x.shape
        p = self.patch_size
        if h % p or w % p:
            raise ValueError(
                f"Spatial size ({h}, {w}) must be divisible by patch_size={p}."
            )
        h_, w_ = h // p, w // p

        x = x.reshape(b, c, h_, p, w_, p)
        x = torch.einsum("bchpwq->bcpqhw", x)
        return x.reshape(b, c * p * p, h_, w_)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Unfold channel-packed ``p x p`` patches back into spatial positions.

        Maps ``(B, C * p * p, h, w) -> (B, C, h * p, w * p)`` with
        ``p = self.patch_size``. Exact inverse of :meth:`patchify`.

        Raises:
            ValueError: If the channel count is not divisible by ``p * p``.
        """
        b, packed_c, h_, w_ = x.shape
        p = self.patch_size
        if packed_c % (p * p):
            raise ValueError(
                f"Channel count {packed_c} must be divisible by patch_size**2={p * p}."
            )
        c = packed_c // (p * p)

        x = x.reshape(b, c, p, p, h_, w_)
        x = torch.einsum("bcpqhw->bchpwq", x)
        return x.reshape(b, c, h_ * p, w_ * p)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode a batch into a diagonal Gaussian posterior.

        Args:
            x: Input batch for ``self._encode``. When ``self.use_slicing`` is
                enabled and the batch has more than one item, samples are
                encoded one at a time and concatenated.
            return_dict: Return an ``AutoencoderKLOutput`` if ``True``, otherwise
                a one-element tuple.

        Returns:
            ``AutoencoderKLOutput(latent_dist=posterior)`` or ``(posterior,)``.
        """
        if self.use_slicing and x.shape[0] > 1:
            h = torch.cat([self._encode(x_slice) for x_slice in x.split(1)])
        else:
            h = self._encode(x)

        # Only touch the encoder output when normalizing: without normalization
        # patchify/unpatchify are an exact round trip, so re-splitting and
        # re-concatenating would just copy the same values.
        if self.normalize_latents:
            # First half of dim 1 is the mean, second half the log-variance;
            # only the mean is normalized.
            mean, logvar = torch.chunk(h, 2, dim=1)
            if self.patch_size is not None:
                mean = self.patchify(mean)
            # Channels-last so F.layer_norm normalizes over the (possibly
            # patch-expanded) channel axis; no affine weight/bias.
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
            if self.patch_size is not None:
                mean = self.unpatchify(mean)
            h = torch.cat([mean, logvar], dim=1).contiguous()

        posterior = DiagonalGaussianDistribution(h, deterministic=self.deterministic)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, Tuple[torch.Tensor]]:
        """Encode ``sample``, optionally perturb the latent, and decode it.

        Args:
            sample: Input batch passed to :meth:`encode`.
            sample_posterior: Draw ``z`` from the posterior (using
                ``generator``) instead of taking its mode.
            return_dict: Return a ``DecoderOutput`` if ``True``, otherwise a
                one-element tuple.
            generator: RNG used for posterior sampling and for the Gaussian
                latent noise.
            noise_strength: If ``> 0``, add ``s * N(0, I)`` noise to ``z`` where
                ``s`` is drawn per batch item from ``Uniform(0, noise_strength)``.
                Note: ``s`` is drawn from the global torch RNG, not ``generator``.

        Returns:
            ``DecoderOutput(sample=dec)`` or ``(dec,)``.
        """
        posterior = self.encode(sample).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()

        if noise_strength > 0.0:
            scale = torch.distributions.Uniform(0.0, noise_strength).sample((z.shape[0],))
            # Match z's dtype as well as its device: the float32 scale would
            # otherwise promote half-precision latents to float32 before decode.
            scale = scale.reshape(-1, 1, 1, 1).to(device=z.device, dtype=z.dtype)
            noise = randn_tensor(z.shape, generator=generator, device=z.device, dtype=z.dtype)
            z = z + scale * noise

        dec = self.decode(z).sample

        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)
|
|
|