# Scraped-source metadata: file size 4,861 bytes, revision 161aead.
from typing import Dict, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import AutoencoderKL
from diffusers.configuration_utils import register_to_config
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.utils.torch_utils import randn_tensor
class AutoencoderKLNextStep(AutoencoderKL):
    """``AutoencoderKL`` variant with optional patch-wise latent normalization.

    Adds three config options on top of the diffusers ``AutoencoderKL``:

    * ``deterministic`` -- forwarded to ``DiagonalGaussianDistribution`` when the
      posterior is built in :meth:`encode`.
    * ``normalize_latents`` -- apply a parameter-free ``layer_norm`` (eps=1e-6)
      over the channel axis of the posterior mean (the log-variance is untouched).
    * ``patch_size`` -- when set, the mean is folded into ``p x p`` spatial patches
      (space-to-depth) before that normalization and unfolded afterwards, so the
      normalization statistics cover every channel of a patch.

    :meth:`forward` can additionally perturb the latent with Gaussian noise whose
    per-sample scale is drawn from ``Uniform(0, noise_strength)``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 4,
        norm_num_groups: int = 32,
        sample_size: int = 32,
        scaling_factor: float = 0.18215,
        shift_factor: Optional[float] = None,
        latents_mean: Optional[Tuple[float, ...]] = None,
        latents_std: Optional[Tuple[float, ...]] = None,
        force_upcast: bool = True,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        mid_block_add_attention: bool = True,
        deterministic: bool = False,
        normalize_latents: bool = False,
        patch_size: Optional[int] = None,
    ):
        """Build the autoencoder.

        Every argument except the last three is forwarded unchanged to
        ``AutoencoderKL.__init__``. ``@register_to_config`` also records the
        three extra arguments below in the model config.

        Args:
            deterministic: Passed as ``deterministic=`` to the
                ``DiagonalGaussianDistribution`` created by :meth:`encode`.
            normalize_latents: Layer-normalize the posterior mean over channels.
            patch_size: Patch edge length used around that normalization, or
                ``None`` to normalize per spatial location.
        """
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            down_block_types=down_block_types,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            latent_channels=latent_channels,
            norm_num_groups=norm_num_groups,
            sample_size=sample_size,
            scaling_factor=scaling_factor,
            shift_factor=shift_factor,
            latents_mean=latents_mean,
            latents_std=latents_std,
            force_upcast=force_upcast,
            use_quant_conv=use_quant_conv,
            use_post_quant_conv=use_post_quant_conv,
            mid_block_add_attention=mid_block_add_attention,
        )
        self.deterministic = deterministic
        self.normalize_latents = normalize_latents
        self.patch_size = patch_size

    def patchify(self, x: torch.Tensor) -> torch.Tensor:
        """Fold non-overlapping ``p x p`` spatial patches into the channel axis.

        Args:
            x: Tensor of shape ``(B, C, H, W)``. ``H`` and ``W`` must be multiples
                of ``self.patch_size``.

        Returns:
            Tensor of shape ``(B, C * p**2, H // p, W // p)``. The new channel index
            is laid out as ``(c, row_in_patch, col_in_patch)``.

        Raises:
            ValueError: If ``H`` or ``W`` is not divisible by the patch size.
        """
        b, c, h, w = x.shape
        p = self.patch_size
        if h % p or w % p:
            raise ValueError(
                f"Spatial size ({h}, {w}) must be divisible by patch_size={p}."
            )
        h_, w_ = h // p, w // p
        # (B, C, H', p, W', p) -> (B, C, p, p, H', W'); same as einsum "bchpwq->bcpqhw".
        x = x.reshape(b, c, h_, p, w_, p).permute(0, 1, 3, 5, 2, 4)
        return x.reshape(b, c * p**2, h_, w_)

    def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
        """Exact inverse of :meth:`patchify`.

        Args:
            x: Tensor of shape ``(B, C * p**2, H', W')``.

        Returns:
            Tensor of shape ``(B, C, H' * p, W' * p)``.

        Raises:
            ValueError: If the channel count is not a multiple of ``p**2``.
        """
        b, channels, h_, w_ = x.shape
        p = self.patch_size
        if channels % (p**2):
            raise ValueError(
                f"Channel count {channels} must be divisible by patch_size**2={p**2}."
            )
        c = channels // (p**2)
        # (B, C, p, p, H', W') -> (B, C, H', p, W', p); same as einsum "bcpqhw->bchpwq".
        x = x.reshape(b, c, p, p, h_, w_).permute(0, 1, 4, 2, 5, 3)
        return x.reshape(b, c, h_ * p, w_ * p)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
        """Encode a batch into a latent posterior, optionally normalizing its mean.

        Args:
            x: Input batch. When ``self.use_slicing`` is enabled and the batch has
                more than one element, samples are encoded one at a time and the
                results are concatenated along dim 0.
            return_dict: Return an ``AutoencoderKLOutput`` if True, otherwise a
                1-tuple ``(posterior,)``.

        Returns:
            The ``DiagonalGaussianDistribution`` built from ``[mean, logvar]``,
            wrapped as described by ``return_dict``.
        """
        if self.use_slicing and x.shape[0] > 1:
            moments = torch.cat([self._encode(x_slice) for x_slice in x.split(1)])
        else:
            moments = self._encode(x)

        # The encoder output stacks mean and log-variance along the channel axis.
        mean, logvar = torch.chunk(moments, 2, dim=1)

        if self.normalize_latents:
            # Patchify/unpatchify is an exact round trip, so it only matters
            # around the normalization; skip it otherwise.
            if self.patch_size is not None:
                mean = self.patchify(mean)
            # Channels-last so layer_norm normalizes over the channel axis only.
            mean = mean.permute(0, 2, 3, 1)
            mean = F.layer_norm(mean, mean.shape[-1:], eps=1e-6)
            mean = mean.permute(0, 3, 1, 2)
            if self.patch_size is not None:
                mean = self.unpatchify(mean)

        moments = torch.cat([mean, logvar], dim=1).contiguous()
        posterior = DiagonalGaussianDistribution(moments, deterministic=self.deterministic)
        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: Optional[torch.Generator] = None,
        noise_strength: float = 0.0,
    ) -> Union[DecoderOutput, torch.Tensor]:
        """Encode, optionally perturb, and decode ``sample``.

        Args:
            sample: Input batch passed to :meth:`encode`.
            sample_posterior: Draw ``z`` from the posterior (using ``generator``)
                instead of taking its mode.
            return_dict: Return a ``DecoderOutput`` if True, otherwise ``(dec,)``.
            generator: RNG used only for posterior sampling. The additive noise
                below does not use it.
            noise_strength: If > 0, add ``s * N(0, I)`` noise to ``z``, with one
                ``s ~ Uniform(0, noise_strength)`` drawn per batch element.

        Returns:
            The decoded reconstruction, wrapped as described by ``return_dict``.
        """
        posterior = self.encode(sample).latent_dist
        z = posterior.sample(generator=generator) if sample_posterior else posterior.mode()

        if noise_strength > 0.0:
            scale = torch.distributions.Uniform(0, noise_strength).sample((z.shape[0],))
            # Cast to z's dtype as well as its device; a float32 scale would
            # otherwise silently upcast half-precision latents before decode.
            scale = scale.reshape(-1, 1, 1, 1).to(device=z.device, dtype=z.dtype)
            z = z + scale * randn_tensor(z.shape, device=z.device, dtype=z.dtype)

        dec = self.decode(z).sample
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)