| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.utils.checkpoint |
|
|
| from ...configuration_utils import ConfigMixin, register_to_config |
| from ...loaders import FromOriginalModelMixin |
| from ...utils import logging |
| from ...utils.accelerate_utils import apply_forward_hook |
| from ..activations import get_activation |
| from ..modeling_outputs import AutoencoderKLOutput |
| from ..modeling_utils import ModelMixin |
| from .vae import AutoencoderMixin, DecoderOutput, DiagonalGaussianDistribution |
|
|
|
|
# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
|
|
|
|
class HunyuanImageResnetBlock(nn.Module):
    r"""
    Residual block with two convolutions and optional channel change.

    Each 3x3 convolution is preceded by a 32-group `GroupNorm` and the configured activation. When `in_channels`
    differs from `out_channels`, a 1x1 convolution projects the skip path so that it can be added to the main path.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        non_linearity (str, optional): Type of non-linearity to use. Default is "silu".
    """

    def __init__(self, in_channels: int, out_channels: int, non_linearity: str = "silu") -> None:
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.nonlinearity = get_activation(non_linearity)

        # Main path: norm -> act -> conv, twice.
        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # Skip path: only needs a projection when the channel count changes.
        if in_channels != out_channels:
            self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.conv_shortcut = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the residual block.

        Args:
            x (`torch.Tensor`): Input feature map with `in_channels` channels.

        Returns:
            `torch.Tensor`: Feature map with `out_channels` channels and the same spatial size as `x`.
        """
        residual = x

        x = self.norm1(x)
        x = self.nonlinearity(x)
        x = self.conv1(x)

        x = self.norm2(x)
        x = self.nonlinearity(x)
        x = self.conv2(x)

        # The shortcut convolution maps in_channels -> out_channels, so it must be applied to the untouched
        # input (the skip path), not to the main-path activations which already have out_channels.
        if self.conv_shortcut is not None:
            residual = self.conv_shortcut(residual)

        return x + residual
|
|
|
|
class HunyuanImageAttentionBlock(nn.Module):
    r"""
    Single-head self-attention over the spatial positions of a feature map.

    Args:
        in_channels (int): The number of channels in the input tensor.
    """

    def __init__(self, in_channels: int):
        super().__init__()

        self.norm = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
        # 1x1 convolutions act as per-pixel linear projections.
        self.to_q = nn.Conv2d(in_channels, in_channels, 1)
        self.to_k = nn.Conv2d(in_channels, in_channels, 1)
        self.to_v = nn.Conv2d(in_channels, in_channels, 1)
        self.proj = nn.Conv2d(in_channels, in_channels, 1)

    def forward(self, x):
        """Normalize `x`, attend across all spatial positions, project, and add the input back."""
        skip = x
        normed = self.norm(x)

        q_map = self.to_q(normed)
        k_map = self.to_k(normed)
        v_map = self.to_v(normed)

        bsz, dim, rows, cols = q_map.shape
        seq_len = rows * cols

        def as_tokens(feature_map):
            # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
            return feature_map.permute(0, 2, 3, 1).reshape(bsz, seq_len, dim).contiguous()

        attended = F.scaled_dot_product_attention(as_tokens(q_map), as_tokens(k_map), as_tokens(v_map))

        # (B, H*W, C) -> (B, C, H, W)
        attended = attended.reshape(bsz, rows, cols, dim).permute(0, 3, 1, 2)

        return self.proj(attended) + skip
|
|
|
|
class HunyuanImageDownsample(nn.Module):
    """
    Downsampling block that halves height and width by folding 2x2 pixel patches into channels.

    The convolution output and the raw input are both folded; the folded input is then averaged over groups of
    channels so it matches the main path and can be added as a shortcut.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        factor = 4
        if out_channels % factor != 0:
            raise ValueError(f"out_channels % factor != 0: {out_channels % factor}")

        # The conv emits out_channels / 4 channels; folding 2x2 patches multiplies that by 4.
        self.conv = nn.Conv2d(in_channels, out_channels // factor, kernel_size=3, stride=1, padding=1)
        self.group_size = factor * in_channels // out_channels

    @staticmethod
    def _fold_patches(t: torch.Tensor) -> torch.Tensor:
        """Rearrange (B, C, H, W) into (B, 4*C, H/2, W/2) with channel order (row offset, col offset, c)."""
        n, ch, rows, cols = t.shape
        t = t.reshape(n, ch, rows // 2, 2, cols // 2, 2)
        t = t.permute(0, 3, 5, 1, 2, 4)
        return t.reshape(n, 4 * ch, rows // 2, cols // 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        main = self._fold_patches(self.conv(x))
        skip = self._fold_patches(x)

        # Average every `group_size` consecutive skip channels down to the main path's channel count.
        n, _, rows, cols = skip.shape
        skip = skip.view(n, main.shape[1], self.group_size, rows, cols).mean(dim=2)
        return main + skip
|
|
|
|
class HunyuanImageUpsample(nn.Module):
    """
    Upsampling block that doubles height and width by unfolding channels into 2x2 pixel patches.

    The convolution output and a channel-repeated copy of the input are both unfolded and then summed.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        factor = 4
        # The conv emits 4 * out_channels channels; unfolding into 2x2 patches divides that by 4.
        self.conv = nn.Conv2d(in_channels, out_channels * factor, kernel_size=3, stride=1, padding=1)
        self.repeats = factor * out_channels // in_channels

    @staticmethod
    def _unfold_patches(t: torch.Tensor) -> torch.Tensor:
        """Rearrange (B, 4*C, H, W), channel order (row offset, col offset, c), into (B, C, 2*H, 2*W)."""
        n, ch, rows, cols = t.shape
        t = t.reshape(n, 2, 2, ch // 4, rows, cols)
        t = t.permute(0, 3, 4, 1, 5, 2)
        return t.reshape(n, ch // 4, rows * 2, cols * 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        main = self._unfold_patches(self.conv(x))

        # Repeat each input channel so the skip path has as many channels as the conv output.
        skip = self._unfold_patches(x.repeat_interleave(repeats=self.repeats, dim=1))
        return main + skip
|
|
|
|
class HunyuanImageMidBlock(nn.Module):
    """
    Middle block for HunyuanImageVAE encoder and decoder.

    Consists of one leading residual block followed by `num_layers` (attention, residual block) pairs, all operating
    at a constant channel count.

    Args:
        in_channels (int): Number of input channels.
        num_layers (int): Number of layers.
    """

    def __init__(self, in_channels: int, num_layers: int = 1):
        super().__init__()

        self.resnets = nn.ModuleList()
        self.attentions = nn.ModuleList()

        # Leading residual block, then one attention + residual block per layer.
        self.resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))
        for _ in range(num_layers):
            self.attentions.append(HunyuanImageAttentionBlock(in_channels))
            self.resnets.append(HunyuanImageResnetBlock(in_channels=in_channels, out_channels=in_channels))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.resnets[0](x)

        # resnets[idx + 1] is the residual block paired with attentions[idx].
        for idx, attention in enumerate(self.attentions):
            x = attention(x)
            x = self.resnets[idx + 1](x)

        return x
|
|
|
|
class HunyuanImageEncoder2D(nn.Module):
    r"""
    Encoder network that compresses input to latent representation.

    The output has `2 * z_channels` channels. A parameter-free shortcut, obtained by averaging groups of channels of
    the final feature map, is added to the result of the output convolution.

    Args:
        in_channels (int): Number of input channels.
        z_channels (int): Number of latent channels.
        block_out_channels (list of int): Output channels for each block.
        num_res_blocks (int): Number of residual blocks per block.
        spatial_compression_ratio (int): Spatial downsampling factor.
        non_linearity (str): Type of non-linearity to use. Default is "silu".
        downsample_match_channel (bool): Whether to match channels during downsampling.

    Raises:
        ValueError: If `block_out_channels[-1]` is not divisible by `2 * z_channels`.
    """

    def __init__(
        self,
        in_channels: int,
        z_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int,
        spatial_compression_ratio: int,
        non_linearity: str = "silu",
        downsample_match_channel: bool = True,
    ):
        super().__init__()
        if block_out_channels[-1] % (2 * z_channels) != 0:
            raise ValueError(
                f"block_out_channels[-1] has to be divisible by 2 * z_channels, you have block_out_channels[-1] = {block_out_channels[-1]} and z_channels = {z_channels}"
            )

        self.in_channels = in_channels
        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        self.spatial_compression_ratio = spatial_compression_ratio

        # Number of final feature channels averaged into each output channel of the shortcut.
        self.group_size = block_out_channels[-1] // (2 * z_channels)
        self.nonlinearity = get_activation(non_linearity)

        # Input convolution
        self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)

        # Down blocks: a flat list of residual blocks interleaved with downsample blocks.
        self.down_blocks = nn.ModuleList([])

        num_downsamples = np.log2(spatial_compression_ratio)
        last_stage = len(block_out_channels) - 1

        block_in_channel = block_out_channels[0]
        for i, block_out_channel in enumerate(block_out_channels):
            # Residual blocks of this stage
            for _ in range(num_res_blocks):
                self.down_blocks.append(
                    HunyuanImageResnetBlock(in_channels=block_in_channel, out_channels=block_out_channel)
                )
                block_in_channel = block_out_channel

            # Downsample after the stage while more compression is needed; never after the last stage.
            if i < num_downsamples and i != last_stage:
                if downsample_match_channel:
                    # Let the downsample block already switch to the next stage's channel count.
                    block_out_channel = block_out_channels[i + 1]
                self.down_blocks.append(
                    HunyuanImageDownsample(in_channels=block_in_channel, out_channels=block_out_channel)
                )
                block_in_channel = block_out_channel

        # Middle block
        self.mid_block = HunyuanImageMidBlock(in_channels=block_out_channels[-1], num_layers=1)

        # Output normalization and convolution
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_out_channels[-1], 2 * z_channels, kernel_size=3, stride=1, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode `x`.

        Args:
            x (`torch.Tensor`): Input tensor with `in_channels` channels.

        Returns:
            `torch.Tensor`: Tensor with `2 * z_channels` channels.
        """
        x = self.conv_in(x)

        # Down blocks
        for down_block in self.down_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                x = self._gradient_checkpointing_func(down_block, x)
            else:
                x = down_block(x)

        # Middle block
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            x = self._gradient_checkpointing_func(self.mid_block, x)
        else:
            x = self.mid_block(x)

        # Shortcut: average each group of `group_size` channels down to 2 * z_channels channels.
        B, C, H, W = x.shape
        residual = x.view(B, C // self.group_size, self.group_size, H, W).mean(dim=2)

        x = self.norm_out(x)
        x = self.nonlinearity(x)
        x = self.conv_out(x)
        return x + residual
|
|
|
|
class HunyuanImageDecoder2D(nn.Module):
    r"""
    Decoder network that reconstructs output from latent representation.

    The latent is lifted to `block_out_channels[0]` channels by a convolution plus a channel-repeated copy of itself,
    then passed through the middle block, the up blocks, and the output normalization, activation and convolution.

    Args:
        z_channels (int): Number of latent channels.
        out_channels (int): Number of output channels.
        block_out_channels (tuple[int, ...]): Output channels for each block.
        num_res_blocks (int): Number of residual blocks per block.
        spatial_compression_ratio (int): Spatial upsampling factor.
        upsample_match_channel (bool): Whether to match channels during upsampling.
        non_linearity (str): Type of non-linearity to use. Default is "silu".

    Raises:
        ValueError: If `block_out_channels[0]` is not divisible by `z_channels`.
    """

    def __init__(
        self,
        z_channels: int,
        out_channels: int,
        block_out_channels: tuple[int, ...],
        num_res_blocks: int,
        spatial_compression_ratio: int,
        upsample_match_channel: bool = True,
        non_linearity: str = "silu",
    ):
        super().__init__()
        first_width = block_out_channels[0]
        if first_width % z_channels != 0:
            raise ValueError(
                f"block_out_channels[0] should be divisible by z_channels but has block_out_channels[0] = {block_out_channels[0]} and z_channels = {z_channels}"
            )

        self.z_channels = z_channels
        self.block_out_channels = block_out_channels
        self.num_res_blocks = num_res_blocks
        # How many times each latent channel is repeated for the input shortcut.
        self.repeat = first_width // z_channels
        self.spatial_compression_ratio = spatial_compression_ratio
        self.nonlinearity = get_activation(non_linearity)

        self.conv_in = nn.Conv2d(z_channels, first_width, kernel_size=3, stride=1, padding=1)

        # Middle block
        self.mid_block = HunyuanImageMidBlock(in_channels=first_width, num_layers=1)

        # Up blocks: a flat list of residual blocks interleaved with upsample blocks.
        self.up_blocks = nn.ModuleList()
        final_stage = len(block_out_channels) - 1
        current_width = first_width
        for stage, stage_width in enumerate(block_out_channels):
            for _ in range(self.num_res_blocks + 1):
                self.up_blocks.append(HunyuanImageResnetBlock(in_channels=current_width, out_channels=stage_width))
                current_width = stage_width

            needs_upsample = stage < np.log2(spatial_compression_ratio) and stage != final_stage
            if not needs_upsample:
                continue

            # Optionally let the upsample block already switch to the next stage's channel count.
            target_width = block_out_channels[stage + 1] if upsample_match_channel else stage_width
            self.up_blocks.append(HunyuanImageUpsample(current_width, target_width))
            current_width = target_width

        # Output normalization and convolution
        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out_channels[-1], eps=1e-6, affine=True)
        self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

        self.gradient_checkpointing = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Decode the latent tensor `x` and return the result of the output convolution."""
        hidden = self.conv_in(x) + x.repeat_interleave(repeats=self.repeat, dim=1)

        checkpoint = torch.is_grad_enabled() and self.gradient_checkpointing

        # The middle block runs first, followed by every up block in order.
        for stage in (self.mid_block, *self.up_blocks):
            if checkpoint:
                hidden = self._gradient_checkpointing_func(stage, hidden)
            else:
                hidden = stage(hidden)

        hidden = self.norm_out(hidden)
        hidden = self.nonlinearity(hidden)
        return self.conv_out(hidden)
|
|
|
|
class AutoencoderKLHunyuanImage(ModelMixin, AutoencoderMixin, ConfigMixin, FromOriginalModelMixin):
    r"""
    A VAE model for 2D images with spatial tiling support.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).
    """

    _supports_gradient_checkpointing = False

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        latent_channels: int,
        block_out_channels: tuple[int, ...],
        layers_per_block: int,
        spatial_compression_ratio: int,
        sample_size: int,
        scaling_factor: float | None = None,
        downsample_match_channel: bool = True,
        upsample_match_channel: bool = True,
    ) -> None:
        super().__init__()

        self.encoder = HunyuanImageEncoder2D(
            in_channels=in_channels,
            z_channels=latent_channels,
            block_out_channels=block_out_channels,
            num_res_blocks=layers_per_block,
            spatial_compression_ratio=spatial_compression_ratio,
            downsample_match_channel=downsample_match_channel,
        )

        # The decoder mirrors the encoder, so it walks the channel widths in reverse order.
        self.decoder = HunyuanImageDecoder2D(
            z_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=list(reversed(block_out_channels)),
            num_res_blocks=layers_per_block,
            spatial_compression_ratio=spatial_compression_ratio,
            upsample_match_channel=upsample_match_channel,
        )

        # Batch slicing and spatial tiling are both off by default.
        self.use_slicing = False
        self.use_tiling = False

        # Tiling parameters: tile edge in sample space and in latent space, and the fraction of a tile that
        # overlaps with its neighbour.
        self.tile_sample_min_size = sample_size
        self.tile_latent_min_size = sample_size // spatial_compression_ratio
        self.tile_overlap_factor = 0.25

    def enable_tiling(
        self,
        tile_sample_min_size: int | None = None,
        tile_overlap_factor: float | None = None,
    ) -> None:
        r"""
        Enable spatial tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles
        to compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to
        allow processing larger images.

        Args:
            tile_sample_min_size (`int`, *optional*):
                The minimum size required for a sample to be separated into tiles across the spatial dimension. The
                current value is kept when omitted.
            tile_overlap_factor (`float`, *optional*):
                The overlap factor required for a latent to be separated into tiles across the spatial dimension. The
                current value is kept when omitted.
        """
        self.use_tiling = True
        # A tile size of 0 is meaningless, so any falsy value keeps the current size.
        if tile_sample_min_size:
            self.tile_sample_min_size = tile_sample_min_size
        # An overlap of 0.0 is a legitimate request (no blending), so only `None` keeps the current value.
        if tile_overlap_factor is not None:
            self.tile_overlap_factor = tile_overlap_factor
        self.tile_latent_min_size = self.tile_sample_min_size // self.config.spatial_compression_ratio

    def _encode(self, x: torch.Tensor):
        """Run the encoder on `x` (B, C, H, W), tiled when tiling is enabled and `x` exceeds the tile size."""
        _, _, height, width = x.shape

        if self.use_tiling and (width > self.tile_sample_min_size or height > self.tile_sample_min_size):
            return self.tiled_encode(x)

        return self.encoder(x)

    @apply_forward_hook
    def encode(
        self, x: torch.Tensor, return_dict: bool = True
    ) -> AutoencoderKLOutput | tuple[DiagonalGaussianDistribution]:
        r"""
        Encode a batch of images into latents.

        Args:
            x (`torch.Tensor`): Input batch of images.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

        Returns:
                The latent representations of the encoded images. If `return_dict` is True, a
                [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
        """
        if self.use_slicing and x.shape[0] > 1:
            # Encode one sample at a time to bound peak memory.
            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
            h = torch.cat(encoded_slices)
        else:
            h = self._encode(x)
        posterior = DiagonalGaussianDistribution(h)

        if not return_dict:
            return (posterior,)
        return AutoencoderKLOutput(latent_dist=posterior)

    def _decode(self, z: torch.Tensor, return_dict: bool = True):
        """Run the decoder on `z` (B, C, H, W), tiled when tiling is enabled and `z` exceeds the latent tile size."""
        _, _, height, width = z.shape

        if self.use_tiling and (width > self.tile_latent_min_size or height > self.tile_latent_min_size):
            return self.tiled_decode(z, return_dict=return_dict)

        dec = self.decoder(z)

        if not return_dict:
            return (dec,)

        return DecoderOutput(sample=dec)

    @apply_forward_hook
    def decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
        r"""
        Decode a batch of images.

        Args:
            z (`torch.Tensor`): Input batch of latent vectors.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        if self.use_slicing and z.shape[0] > 1:
            # Decode one sample at a time to bound peak memory.
            decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
            decoded = torch.cat(decoded_slices)
        else:
            decoded = self._decode(z).sample

        if not return_dict:
            return (decoded,)
        return DecoderOutput(sample=decoded)

    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Linearly cross-fade the bottom rows of `a` into the top rows of `b`.

        `b` is modified in place and returned. The blend covers at most `blend_extent` rows, clamped to the
        heights of `a` and `b`.
        """
        blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
        for y in range(blend_extent):
            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (
                y / blend_extent
            )
        return b

    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
        """
        Linearly cross-fade the rightmost columns of `a` into the leftmost columns of `b`.

        `b` is modified in place and returned. The blend covers at most `blend_extent` columns, clamped to the
        widths of `a` and `b`.
        """
        blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
        for x in range(blend_extent):
            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (
                x / blend_extent
            )
        return b

    def _blend_and_stitch(self, rows: list, blend_extent: int, row_limit: int) -> torch.Tensor:
        """Blend each (B, C, H, W) tile with its upper and left neighbour, crop it, and concatenate the grid."""
        result_rows = []
        for i, row in enumerate(rows):
            result_row = []
            for j, tile in enumerate(row):
                # Blend with the tile above first, then with the tile to the left.
                if i > 0:
                    tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
                if j > 0:
                    tile = self.blend_h(row[j - 1], tile, blend_extent)
                # Keep only the non-overlapping part; the remainder is covered by the next tile.
                result_row.append(tile[:, :, :row_limit, :row_limit])
            result_rows.append(torch.cat(result_row, dim=-1))

        return torch.cat(result_rows, dim=-2)

    def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
        """
        Encode input using spatial tiling strategy.

        Args:
            x (`torch.Tensor`): Input tensor of shape (B, C, H, W).

        Returns:
            `torch.Tensor`:
                The latent representation of the encoded images.
        """
        _, _, height, width = x.shape
        # Stride between tile origins in sample space, and blend width / kept size in latent space.
        overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
        row_limit = self.tile_latent_min_size - blend_extent

        rows = []
        for i in range(0, height, overlap_size):
            row = []
            for j in range(0, width, overlap_size):
                tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
                row.append(self.encoder(tile))
            rows.append(row)

        return self._blend_and_stitch(rows, blend_extent, row_limit)

    def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> DecoderOutput | tuple[torch.Tensor]:
        """
        Decode latent using spatial tiling strategy.

        Args:
            z (`torch.Tensor`): Latent tensor of shape (B, C, H, W).
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
                returned.
        """
        _, _, height, width = z.shape
        # Stride between tile origins in latent space, and blend width / kept size in sample space.
        overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
        blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
        row_limit = self.tile_sample_min_size - blend_extent

        rows = []
        for i in range(0, height, overlap_size):
            row = []
            for j in range(0, width, overlap_size):
                tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
                row.append(self.decoder(tile))
            rows.append(row)

        dec = self._blend_and_stitch(rows, blend_extent, row_limit)
        if not return_dict:
            return (dec,)
        return DecoderOutput(sample=dec)

    def forward(
        self,
        sample: torch.Tensor,
        sample_posterior: bool = False,
        return_dict: bool = True,
        generator: torch.Generator | None = None,
    ) -> DecoderOutput | tuple[torch.Tensor]:
        """
        Encode `sample` and decode the resulting latent.

        Args:
            sample (`torch.Tensor`): Input sample.
            sample_posterior (`bool`, *optional*, defaults to `False`):
                Whether to sample from the posterior instead of taking its mode.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
            generator (`torch.Generator`, *optional*):
                Generator passed to the posterior when sampling.
        """
        posterior = self.encode(sample).latent_dist
        if sample_posterior:
            z = posterior.sample(generator=generator)
        else:
            z = posterior.mode()

        return self.decode(z, return_dict=return_dict)
|
|