ideogram4 / diffusers_src /src /diffusers /models /autoencoders /autoencoder_vidtok.py
multimodalart's picture
multimodalart HF Staff
Embed diffusers PR source; install locally
b8c861f verified
raw
history blame
64.5 kB
# Copyright 2025 The VidTok team, MSRA & Shanghai Jiao Tong University and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import logging
from ...utils.accelerate_utils import apply_forward_hook
from ..modeling_outputs import AutoencoderKLOutput
from ..modeling_utils import ModelMixin
from .vae import DecoderOutput, DiagonalGaussianDistribution
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class FSQRegularizer(nn.Module):
r"""
Finite Scalar Quantization: VQ-VAE Made Simple - https://arxiv.org/abs/2309.15505 Code adapted from
https://github.com/lucidrains/vector-quantize-pytorch/blob/master/vector_quantize_pytorch/finite_scalar_quantization.py
Args:
levels (`List[int]`):
A list of quantization levels.
dim (`int`, *optional*, defaults to `None`):
The dimension of latent codes.
num_codebooks (`int`, defaults to 1):
The number of codebooks.
keep_num_codebooks_dim (`bool`, *optional*, defaults to `None`):
Whether to keep the number of codebook dim.
"""
def __init__(
self,
levels: List[int],
dim: Optional[int] = None,
num_codebooks: int = 1,
keep_num_codebooks_dim: Optional[bool] = None,
):
super().__init__()
_levels = torch.tensor(levels, dtype=torch.int32)
self.register_buffer("_levels", _levels, persistent=False)
_basis = torch.cumprod(torch.tensor([1] + levels[:-1]), dim=0, dtype=torch.int32)
self.register_buffer("_basis", _basis, persistent=False)
codebook_dim = len(levels)
self.codebook_dim = codebook_dim
effective_codebook_dim = codebook_dim * num_codebooks
self.num_codebooks = num_codebooks
self.effective_codebook_dim = effective_codebook_dim
if keep_num_codebooks_dim is None:
keep_num_codebooks_dim = num_codebooks > 1
self.keep_num_codebooks_dim = keep_num_codebooks_dim
self.dim = len(_levels) * num_codebooks if dim is None else dim
has_projections = self.dim != effective_codebook_dim
self.project_in = nn.Linear(self.dim, effective_codebook_dim) if has_projections else nn.Identity()
self.project_out = nn.Linear(effective_codebook_dim, self.dim) if has_projections else nn.Identity()
self.has_projections = has_projections
self.codebook_size = self._levels.prod().item()
implicit_codebook = self.indices_to_codes(torch.arange(self.codebook_size), project_out=False)
self.register_buffer("implicit_codebook", implicit_codebook, persistent=False)
self.register_buffer("zero", torch.tensor(0.0), persistent=False)
self.global_codebook_usage = torch.zeros([2**self.codebook_dim, self.num_codebooks], dtype=torch.long)
def quantize(self, z: torch.Tensor, eps: float = 1e-3) -> torch.Tensor:
r"""Quantizes z, returns quantized zhat, same shape as z."""
half_l = (self._levels - 1) * (1 + eps) / 2
offset = torch.where(self._levels % 2 == 0, 0.5, 0.0)
shift = (offset / half_l).atanh()
z = (z + shift).tanh() * half_l - offset
zhat = z.round()
quantized = z + (zhat - z).detach()
half_width = self._levels // 2
return quantized / half_width
def codes_to_indices(self, zhat: torch.Tensor) -> torch.Tensor:
r"""Converts a `code` to an index in the codebook."""
half_width = self._levels // 2
zhat = (zhat * half_width) + half_width
return (zhat * self._basis).sum(dim=-1).to(torch.int32)
def indices_to_codes(self, indices: torch.Tensor, project_out: bool = True) -> torch.Tensor:
r"""Inverse of `codes_to_indices`."""
is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
indices = indices.unsqueeze(-1)
codes_non_centered = (indices // self._basis) % self._levels
half_width = self._levels // 2
codes = (codes_non_centered - half_width) / half_width
if self.keep_num_codebooks_dim:
codes = codes.reshape(*codes.shape[:-2], -1)
if project_out:
codes = self.project_out(codes)
if is_img_or_video:
codes = codes.permute(0, -1, *range(1, codes.dim() - 1))
return codes
def forward(self, z: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
einstein notation b - batch n - sequence (or flattened spatial dimensions) d - feature dimension c - number of
codebook dim
"""
is_img_or_video = z.ndim >= 4
if is_img_or_video:
if z.ndim == 5:
b, d, t, h, w = z.shape
is_video = True
else:
b, d, h, w = z.shape
is_video = False
z = z.reshape(b, d, -1).permute(0, 2, 1)
z = self.project_in(z)
b, n, _ = z.shape
z = z.reshape(b, n, self.num_codebooks, -1)
orig_dtype = z.dtype
z = z.float()
codes = self.quantize(z)
indices = self.codes_to_indices(codes)
codes = codes.type(orig_dtype)
codes = codes.reshape(b, n, -1)
out = self.project_out(codes)
# reconstitute image or video dimensions
if is_img_or_video:
if is_video:
out = out.reshape(b, t, h, w, d).permute(0, 4, 1, 2, 3)
indices = indices.reshape(b, t, h, w, 1)
else:
out = out.reshape(b, h, w, d).permute(0, 3, 1, 2)
indices = indices.reshape(b, h, w, 1)
if not self.keep_num_codebooks_dim:
indices = indices.squeeze(-1)
return out, indices
class VidTokDownsample2D(nn.Module):
r"""A 2D downsampling layer used in VidTok Model."""
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
def forward(self, x: torch.Tensor) -> torch.Tensor:
pad = (0, 1, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
x = self.conv(x)
return x
class VidTokUpsample2D(nn.Module):
r"""A 2D upsampling layer used in VidTok Model."""
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.interpolate(x.to(torch.float32), scale_factor=2.0, mode="nearest").to(x.dtype)
x = self.conv(x)
return x
class VidTokLayerNorm(nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=True)
def forward(self, x: torch.Tensor) -> torch.Tensor:
if x.dim() == 5:
x = x.permute(0, 2, 3, 4, 1)
x = self.norm(x)
x = x.permute(0, 4, 1, 2, 3)
elif x.dim() == 4:
x = x.permute(0, 2, 3, 1)
x = self.norm(x)
x = x.permute(0, 3, 1, 2)
else:
x = x.permute(0, 2, 1)
x = self.norm(x)
x = x.permute(0, 2, 1)
return x
class VidTokCausalConv1d(nn.Module):
r"""A 1D causal convolution layer that pads the input tensor to ensure causality in VidTok Model."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
padding: int = 0,
):
super().__init__()
self.time_pad = dilation * (kernel_size - 1) + (1 - stride)
self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation)
self.is_first_chunk = True
self.causal_cache = None
self.cache_offset = 0
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.is_first_chunk:
first_frame_pad = x[:, :, :1].repeat((1, 1, self.time_pad))
else:
first_frame_pad = self.causal_cache
if self.time_pad != 0:
first_frame_pad = first_frame_pad[:, :, -self.time_pad :]
else:
first_frame_pad = first_frame_pad[:, :, 0:0]
x = torch.concatenate((first_frame_pad, x), dim=2)
if self.cache_offset == 0:
self.causal_cache = x.clone()
else:
self.causal_cache = x[:, :, : -self.cache_offset].clone()
return self.conv(x)
class VidTokCausalConv3d(nn.Module):
r"""A 3D causal convolution layer that pads the input tensor to ensure causality in VidTok Model."""
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int, int]],
stride: Union[int, Tuple[int, int, int]] = 1,
dilation: Union[int, Tuple[int, int, int]] = 1,
padding: Union[int, Tuple[int, int, int]] = 0,
pad_mode: str = "constant",
):
super().__init__()
self.pad_mode = pad_mode
if isinstance(kernel_size, int):
kernel_size = (kernel_size,) * 3
if isinstance(dilation, int):
dilation = (dilation,) * 3
if isinstance(stride, int):
stride = (stride,) * 3
time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
time_pad = dilation[0] * (time_kernel_size - 1) + (1 - stride[0])
height_pad = dilation[1] * (height_kernel_size - 1) + (1 - stride[1])
width_pad = dilation[2] * (width_kernel_size - 1) + (1 - stride[2])
self.time_pad = time_pad
self.spatial_padding = (
width_pad // 2,
width_pad - width_pad // 2,
height_pad // 2,
height_pad - height_pad // 2,
0,
0,
)
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, dilation=dilation)
self.is_first_chunk = True
self.causal_cache = None
self.cache_offset = 0
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.is_first_chunk:
first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_pad, 1, 1))
else:
first_frame_pad = self.causal_cache
if self.time_pad != 0:
first_frame_pad = first_frame_pad[:, :, -self.time_pad :]
else:
first_frame_pad = first_frame_pad[:, :, 0:0]
x = torch.concatenate((first_frame_pad, x), dim=2)
if self.cache_offset == 0:
self.causal_cache = x.clone()
else:
self.causal_cache = x[:, :, : -self.cache_offset].clone()
x = F.pad(x, self.spatial_padding, mode=self.pad_mode)
return self.conv(x)
class VidTokDownsample3D(nn.Module):
r"""A 3D downsampling layer used in VidTok Model."""
def __init__(self, in_channels: int, out_channels: int, mix_factor: float = 2.0, is_causal: bool = True):
super().__init__()
self.is_causal = is_causal
self.kernel_size = (3, 3, 3)
self.avg_pool = nn.AvgPool3d((3, 1, 1), stride=(2, 1, 1))
make_conv_cls = VidTokCausalConv3d if self.is_causal else nn.Conv3d
self.conv = make_conv_cls(in_channels, out_channels, 3, stride=(2, 1, 1), padding=(0, 1, 1))
self.mix_factor = nn.Parameter(torch.Tensor([mix_factor]))
if self.is_causal:
self.is_first_chunk = True
self.causal_cache = None
def forward(self, x: torch.Tensor) -> torch.Tensor:
alpha = torch.sigmoid(self.mix_factor)
if self.is_causal:
pad = (0, 0, 0, 0, 1, 0)
if self.is_first_chunk:
x_pad = torch.nn.functional.pad(x, pad, mode="replicate")
else:
x_pad = torch.concatenate((self.causal_cache, x), dim=2)
self.causal_cache = x_pad[:, :, -1:].clone()
if x_pad.device.type == "cpu" and x_pad.dtype == torch.bfloat16:
# PyTorch's avg_pool3d lacks CPU support for BFloat16.
# To avoid errors, we cast to float32, perform the pooling,
# and then cast back to BFloat16 to maintain the expected dtype.
x1 = self.avg_pool(x_pad.float()).to(torch.bfloat16)
else:
x1 = self.avg_pool(x_pad)
else:
pad = (0, 0, 0, 0, 0, 1)
x = F.pad(x, pad, mode="constant", value=0)
if x.device.type == "cpu" and x.dtype == torch.bfloat16:
# PyTorch's avg_pool3d lacks CPU support for BFloat16.
# To avoid errors, we cast to float32, perform the pooling,
# and then cast back to BFloat16 to maintain the expected dtype.
x1 = self.avg_pool(x.float()).to(torch.bfloat16)
else:
x1 = self.avg_pool(x)
x2 = self.conv(x)
return alpha * x1 + (1 - alpha) * x2
class VidTokUpsample3D(nn.Module):
r"""A 3D upsampling layer used in VidTok Model."""
def __init__(
self,
in_channels: int,
out_channels: int,
mix_factor: float = 2.0,
num_temp_upsample: int = 1,
is_causal: bool = True,
):
super().__init__()
make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d
self.conv = make_conv_cls(in_channels, out_channels, 3, padding=1)
self.mix_factor = nn.Parameter(torch.Tensor([mix_factor]))
self.is_causal = is_causal
if self.is_causal:
self.enable_cached = True
self.interpolation_mode = "trilinear"
self.is_first_chunk = True
self.causal_cache = None
self.num_temp_upsample = num_temp_upsample
else:
self.enable_cached = False
self.interpolation_mode = "nearest"
def forward(self, x: torch.Tensor) -> torch.Tensor:
alpha = torch.sigmoid(self.mix_factor)
if not self.is_causal:
xlst = [
F.interpolate(
sx.unsqueeze(0).to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode
).to(x.dtype)
for sx in x
]
x = torch.cat(xlst, dim=0)
else:
if not self.enable_cached:
x = F.interpolate(x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode).to(
x.dtype
)
elif not self.is_first_chunk:
x = torch.cat([self.causal_cache, x], dim=2)
self.causal_cache = x[:, :, -2 * self.num_temp_upsample : -self.num_temp_upsample].clone()
x = F.interpolate(x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode).to(
x.dtype
)
x = x[:, :, 2 * self.num_temp_upsample :]
else:
self.causal_cache = x[:, :, -self.num_temp_upsample :].clone()
x, _x = x[:, :, : self.num_temp_upsample], x[:, :, self.num_temp_upsample :]
x = F.interpolate(x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode).to(
x.dtype
)
if _x.shape[-3] > 0:
_x = F.interpolate(
_x.to(torch.float32), scale_factor=[2.0, 1.0, 1.0], mode=self.interpolation_mode
).to(_x.dtype)
x = torch.concat([x, _x], dim=2)
x_ = self.conv(x)
return alpha * x + (1 - alpha) * x_
class VidTokAttnBlock(nn.Module):
r"""A 3D self-attention block used in VidTok Model."""
def __init__(self, in_channels: int, is_causal: bool = True):
super().__init__()
make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d
self.norm = VidTokLayerNorm(dim=in_channels, eps=1e-6)
self.q = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.k = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.v = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
self.proj_out = make_conv_cls(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
def attention(self, hidden_states: torch.Tensor) -> torch.Tensor:
r"""Implement self-attention."""
hidden_states = self.norm(hidden_states)
q = self.q(hidden_states)
k = self.k(hidden_states)
v = self.v(hidden_states)
b, c, t, h, w = q.shape
q, k, v = [x.permute(0, 2, 3, 4, 1).reshape(b, t, -1, c).contiguous() for x in [q, k, v]]
hidden_states = F.scaled_dot_product_attention(q, k, v) # scale is dim ** -0.5 per default
return hidden_states.reshape(b, t, h, w, c).permute(0, 4, 1, 2, 3)
def forward(self, x: torch.Tensor) -> torch.Tensor:
hidden_states = x
hidden_states = self.attention(hidden_states)
hidden_states = self.proj_out(hidden_states)
return x + hidden_states
class VidTokResnetBlock(nn.Module):
r"""A versatile ResNet block used in VidTok Model."""
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
conv_shortcut: bool = False,
dropout: float = 0.0,
temb_channels: int = 512,
btype: str = "3d",
is_causal: bool = True,
):
super().__init__()
assert btype in ["1d", "2d", "3d"], f"Invalid btype: {btype}"
if btype == "2d":
make_conv_cls = nn.Conv2d
elif btype == "1d":
make_conv_cls = VidTokCausalConv1d if is_causal else nn.Conv1d
else:
make_conv_cls = VidTokCausalConv3d if is_causal else nn.Conv3d
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.nonlinearity = nn.SiLU()
self.norm1 = VidTokLayerNorm(dim=in_channels, eps=1e-6)
self.conv1 = make_conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels > 0:
self.temb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = VidTokLayerNorm(dim=out_channels, eps=1e-6)
self.dropout = nn.Dropout(dropout)
self.conv2 = make_conv_cls(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = make_conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
else:
self.nin_shortcut = make_conv_cls(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, x: torch.Tensor, temb: Optional[torch.Tensor]) -> torch.Tensor:
hidden_states = x
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
hidden_states = hidden_states + self.temb_proj(self.nonlinearity(temb))[:, :, None, None]
hidden_states = self.norm2(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)
return x + hidden_states
class VidTokEncoder3D(nn.Module):
r"""
The `VidTokEncoder3D` layer of a variational autoencoder that encodes its input into a latent representation.
Args:
in_channels (`int`):
The number of input channels.
ch (`int`):
The number of the basic channel.
ch_mult (`List[int]`, defaults to `[1, 2, 4, 8]`):
The multiple of the basic channel for each block.
num_res_blocks (`int`, defaults to 2):
The number of resblocks.
dropout (`float`, defaults to 0.0):
Dropout rate.
z_channels (`int`, defaults to 4):
The number of latent channels.
double_z (`bool`, defaults to `True`):
Whether or not to double the z_channels.
spatial_ds (`List`, *optional*, defaults to `None`):
Spatial downsample layers.
tempo_ds (`List`, *optional*, defaults to `None`):
Temporal downsample layers.
is_causal (`bool`, defaults to `True`):
Whether it is a causal module.
"""
def __init__(
self,
in_channels: int,
ch: int,
ch_mult: List[int] = [1, 2, 4, 8],
num_res_blocks: int = 2,
dropout: float = 0.0,
z_channels: int = 4,
double_z: bool = True,
spatial_ds: Optional[List] = None,
tempo_ds: Optional[List] = None,
is_causal: bool = True,
):
super().__init__()
self.is_causal = is_causal
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.in_channels = in_channels
self.nonlinearity = nn.SiLU()
make_conv_cls = VidTokCausalConv3d if self.is_causal else nn.Conv3d
self.conv_in = make_conv_cls(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.spatial_ds = list(range(0, self.num_resolutions - 1)) if spatial_ds is None else spatial_ds
self.tempo_ds = [self.num_resolutions - 2, self.num_resolutions - 3] if tempo_ds is None else tempo_ds
self.down = nn.ModuleList()
self.down_temporal = nn.ModuleList()
for i_level in range(self.num_resolutions):
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
block = nn.ModuleList()
attn = nn.ModuleList()
block_temporal = nn.ModuleList()
attn_temporal = nn.ModuleList()
for i_block in range(self.num_res_blocks):
block.append(
VidTokResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
btype="2d",
)
)
block_temporal.append(
VidTokResnetBlock(
in_channels=block_out,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
btype="1d",
is_causal=self.is_causal,
)
)
block_in = block_out
down = nn.Module()
down.block = block
down.attn = attn
down_temporal = nn.Module()
down_temporal.block = block_temporal
down_temporal.attn = attn_temporal
if i_level in self.spatial_ds:
down.downsample = VidTokDownsample2D(block_in)
if i_level in self.tempo_ds:
down_temporal.downsample = VidTokDownsample3D(block_in, block_in, is_causal=self.is_causal)
self.down.append(down)
self.down_temporal.append(down_temporal)
# middle
self.mid = nn.Module()
self.mid.block_1 = VidTokResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
btype="3d",
is_causal=self.is_causal,
)
self.mid.attn_1 = VidTokAttnBlock(block_in, is_causal=self.is_causal)
self.mid.block_2 = VidTokResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
btype="3d",
is_causal=self.is_causal,
)
# end
self.norm_out = VidTokLayerNorm(dim=block_in, eps=1e-6)
self.conv_out = make_conv_cls(
block_in,
2 * z_channels if double_z else z_channels,
kernel_size=3,
stride=1,
padding=1,
)
self.gradient_checkpointing = False
def forward(self, x: torch.Tensor) -> torch.Tensor:
temb = None
B, _, T, H, W = x.shape
hs = [self.conv_in(x)]
if torch.is_grad_enabled() and self.gradient_checkpointing:
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self._gradient_checkpointing_func(
self.down[i_level].block[i_block], hidden_states, temb
)
hidden_states = (
hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T)
)
hidden_states = self._gradient_checkpointing_func(
self.down_temporal[i_level].block[i_block], hidden_states, temb
)
hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2)
hs.append(hidden_states)
if i_level in self.spatial_ds:
# spatial downsample
hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self._gradient_checkpointing_func(self.down[i_level].downsample, hidden_states)
hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4)
if i_level in self.tempo_ds:
# temporal downsample
hidden_states = self._gradient_checkpointing_func(
self.down_temporal[i_level].downsample, hidden_states
)
hs.append(hidden_states)
B, _, T, H, W = hidden_states.shape
# middle
hidden_states = hs[-1]
hidden_states = self._gradient_checkpointing_func(self.mid.block_1, hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(self.mid.attn_1, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid.block_2, hidden_states, temb)
else:
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self.down[i_level].block[i_block](hidden_states, temb)
hidden_states = (
hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T)
)
hidden_states = self.down_temporal[i_level].block[i_block](hidden_states, temb)
hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2)
hs.append(hidden_states)
if i_level in self.spatial_ds:
# spatial downsample
hidden_states = hs[-1].permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self.down[i_level].downsample(hidden_states)
hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4)
if i_level in self.tempo_ds:
# temporal downsample
hidden_states = self.down_temporal[i_level].downsample(hidden_states)
hs.append(hidden_states)
B, _, T, H, W = hidden_states.shape
# middle
hidden_states = hs[-1]
hidden_states = self.mid.block_1(hidden_states, temb)
hidden_states = self.mid.attn_1(hidden_states)
hidden_states = self.mid.block_2(hidden_states, temb)
# end
hidden_states = self.norm_out(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv_out(hidden_states)
return hidden_states
class VidTokDecoder3D(nn.Module):
r"""
The `VidTokDecoder3D` layer of a variational autoencoder that decodes its latent representation into an output
video.
Args:
ch (`int`):
The number of the basic channel.
ch_mult (`List[int]`, defaults to `[1, 2, 4, 8]`):
The multiple of the basic channel for each block.
num_res_blocks (`int`, defaults to 2):
The number of resblocks.
dropout (`float`, defaults to 0.0):
Dropout rate.
z_channels (`int`, defaults to 4):
The number of latent channels.
out_channels (`int`, defaults to 3):
The number of output channels.
spatial_us (`List`, *optional*, defaults to `None`):
Spatial upsample layers.
tempo_us (`List`, *optional*, defaults to `None`):
Temporal upsample layers.
is_causal (`bool`, defaults to `True`):
Whether it is a causal module.
"""
def __init__(
self,
ch: int,
ch_mult: List[int] = [1, 2, 4, 8],
num_res_blocks: int = 2,
dropout: float = 0.0,
z_channels: int = 4,
out_channels: int = 3,
spatial_us: Optional[List] = None,
tempo_us: Optional[List] = None,
is_causal: bool = True,
):
super().__init__()
self.is_causal = is_causal
self.ch = ch
self.temb_ch = 0
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.nonlinearity = nn.SiLU()
block_in = ch * ch_mult[self.num_resolutions - 1]
make_conv_cls = VidTokCausalConv3d if self.is_causal else nn.Conv3d
self.conv_in = make_conv_cls(z_channels, block_in, kernel_size=3, stride=1, padding=1)
# middle
self.mid = nn.Module()
self.mid.block_1 = VidTokResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
btype="3d",
is_causal=self.is_causal,
)
self.mid.attn_1 = VidTokAttnBlock(block_in, is_causal=self.is_causal)
self.mid.block_2 = VidTokResnetBlock(
in_channels=block_in,
out_channels=block_in,
temb_channels=self.temb_ch,
dropout=dropout,
btype="3d",
is_causal=self.is_causal,
)
# upsampling
self.spatial_us = list(range(1, self.num_resolutions)) if spatial_us is None else spatial_us
self.tempo_us = [1, 2] if tempo_us is None else tempo_us
self.up = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
VidTokResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
btype="2d",
)
)
block_in = block_out
up = nn.Module()
up.block = block
up.attn = attn
if i_level in self.spatial_us:
up.upsample = VidTokUpsample2D(block_in)
self.up.insert(0, up)
num_temp_upsample = 1
self.up_temporal = nn.ModuleList()
for i_level in reversed(range(self.num_resolutions)):
block = nn.ModuleList()
attn = nn.ModuleList()
block_in = ch * ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for i_block in range(self.num_res_blocks + 1):
block.append(
VidTokResnetBlock(
in_channels=block_in,
out_channels=block_out,
temb_channels=self.temb_ch,
dropout=dropout,
btype="1d",
is_causal=self.is_causal,
)
)
block_in = block_out
up_temporal = nn.Module()
up_temporal.block = block
up_temporal.attn = attn
if i_level in self.tempo_us:
up_temporal.upsample = VidTokUpsample3D(
block_in, block_in, num_temp_upsample=num_temp_upsample, is_causal=self.is_causal
)
num_temp_upsample *= 2
self.up_temporal.insert(0, up_temporal)
# end
self.norm_out = VidTokLayerNorm(dim=block_in, eps=1e-6)
self.conv_out = make_conv_cls(block_in, out_channels, kernel_size=3, stride=1, padding=1)
self.gradient_checkpointing = False
def forward(self, z: torch.Tensor) -> torch.Tensor:
temb = None
B, _, T, H, W = z.shape
hidden_states = self.conv_in(z)
if torch.is_grad_enabled() and self.gradient_checkpointing:
# middle
hidden_states = self._gradient_checkpointing_func(self.mid.block_1, hidden_states, temb)
hidden_states = self._gradient_checkpointing_func(self.mid.attn_1, hidden_states)
hidden_states = self._gradient_checkpointing_func(self.mid.block_2, hidden_states, temb)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self._gradient_checkpointing_func(
self.up[i_level].block[i_block], hidden_states, temb
)
hidden_states = (
hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T)
)
hidden_states = self._gradient_checkpointing_func(
self.up_temporal[i_level].block[i_block], hidden_states, temb
)
hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2)
if i_level in self.spatial_us:
# spatial upsample
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self._gradient_checkpointing_func(self.up[i_level].upsample, hidden_states)
hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4)
if i_level in self.tempo_us:
# temporal upsample
hidden_states = self._gradient_checkpointing_func(
self.up_temporal[i_level].upsample, hidden_states
)
B, _, T, H, W = hidden_states.shape
else:
# middle
hidden_states = self.mid.block_1(hidden_states, temb)
hidden_states = self.mid.attn_1(hidden_states)
hidden_states = self.mid.block_2(hidden_states, temb)
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self.up[i_level].block[i_block](hidden_states, temb)
hidden_states = (
hidden_states.reshape(B, T, -1, H, W).permute(0, 3, 4, 2, 1).reshape(B * H * W, -1, T)
)
hidden_states = self.up_temporal[i_level].block[i_block](hidden_states, temb)
hidden_states = hidden_states.reshape(B, H, W, -1, T).permute(0, 3, 4, 1, 2)
if i_level in self.spatial_us:
# spatial upsample
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(B * T, -1, H, W)
hidden_states = self.up[i_level].upsample(hidden_states)
hidden_states = hidden_states.reshape(B, T, -1, *hidden_states.shape[-2:]).permute(0, 2, 1, 3, 4)
if i_level in self.tempo_us:
# temporal upsample
hidden_states = self.up_temporal[i_level].upsample(hidden_states)
B, _, T, H, W = hidden_states.shape
# end
hidden_states = self.norm_out(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
out = self.conv_out(hidden_states)
return out
class AutoencoderVidTok(ModelMixin, ConfigMixin):
r"""
A VAE model for encoding videos into latents and decoding latent representations into videos, supporting both
continuous and discrete latent representations. Used in [VidTok](https://github.com/microsoft/VidTok).
This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
for all models (such as downloading or saving).
Args:
in_channels (`int`, defaults to 3):
The number of input channels.
out_channels (`int`, defaults to 3):
The number of output channels.
ch (`int`, defaults to 128):
The number of the basic channel.
ch_mult (`List[int]`, defaults to `[1, 2, 4, 4]`):
The multiple of the basic channel for each block.
z_channels (`int`, defaults to 4):
The number of latent channels.
double_z (`bool`, defaults to `True`):
Whether or not to double the z_channels.
num_res_blocks (`int`, defaults to 2):
The number of resblocks.
spatial_ds (`List`, *optional*, defaults to `None`):
Spatial downsample layers.
spatial_us (`List`, *optional*, defaults to `None`):
Spatial upsample layers.
tempo_ds (`List`, *optional*, defaults to `None`):
Temporal downsample layers.
tempo_us (`List`, *optional*, defaults to `None`):
Temporal upsample layers.
dropout (`float`, defaults to 0.0):
Dropout rate.
regularizer (`str`, defaults to `"kl"`):
The regularizer type - "kl" for continuous cases and "fsq" for discrete cases.
codebook_size (`int`, defaults to 262144):
The codebook size used only in discrete cases.
is_causal (`bool`, defaults to `True`):
Whether it is a causal module.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
in_channels: int = 3,
out_channels: int = 3,
ch: int = 128,
ch_mult: List[int] = [1, 2, 4, 4],
z_channels: int = 4,
double_z: bool = True,
num_res_blocks: int = 2,
spatial_ds: Optional[List] = None,
spatial_us: Optional[List] = None,
tempo_ds: Optional[List] = None,
tempo_us: Optional[List] = None,
dropout: float = 0.0,
regularizer: str = "kl",
codebook_size: int = 262144,
is_causal: bool = True,
):
super().__init__()
self.is_causal = is_causal
self.encoder = VidTokEncoder3D(
in_channels=in_channels,
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
dropout=dropout,
z_channels=z_channels,
double_z=double_z,
spatial_ds=spatial_ds,
tempo_ds=tempo_ds,
is_causal=self.is_causal,
)
self.decoder = VidTokDecoder3D(
ch=ch,
ch_mult=ch_mult,
num_res_blocks=num_res_blocks,
dropout=dropout,
z_channels=z_channels,
out_channels=out_channels,
spatial_us=spatial_us,
tempo_us=tempo_us,
is_causal=self.is_causal,
)
self.temporal_compression_ratio = 2 ** len(self.encoder.tempo_ds)
self.regularizer = regularizer
if self.regularizer not in ["kl", "fsq"]:
raise ValueError(f"Invalid regularizer: {self.regularizer}. Only `kl` and `fsq` are supported.")
if self.regularizer == "fsq":
if z_channels != int(math.log(codebook_size, 8)):
raise ValueError(
f"When using the `fsq` regularizer, `z_channels` must be {int(math.log(codebook_size, 8))}, the"
f" log base 8 of the `codebook_size` {codebook_size}, but got {z_channels}."
)
if double_z:
raise ValueError("When using the `fsq` regularizer, `double_z` must be `False`.")
self.regularization = FSQRegularizer(levels=[8] * z_channels)
self.use_slicing = False
self.use_tiling = False
# Decode more latent frames at once
self.num_sample_frames_batch_size = 16
self.num_latent_frames_batch_size = self.num_sample_frames_batch_size // self.temporal_compression_ratio
# We make the minimum height and width of sample for tiling half that of the generally supported
self.tile_sample_min_height = 256
self.tile_sample_min_width = 256
self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** len(self.encoder.spatial_ds)))
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** len(self.encoder.spatial_ds)))
self.tile_overlap_factor_height = 0.0 # 1 / 8
self.tile_overlap_factor_width = 0.0 # 1 / 8
@staticmethod
def _pad_at_dim(
t: torch.Tensor, pad: Tuple[int], dim: int = -1, pad_mode: str = "constant", value: float = 0.0
) -> torch.Tensor:
r"""Pad function. Supported pad_mode: `constant`, `replicate`, `reflect`."""
dims_from_right = (-dim - 1) if dim < 0 else (t.ndim - dim - 1)
zeros = (0, 0) * dims_from_right
if pad_mode == "constant":
return F.pad(t, (*zeros, *pad), value=value)
return F.pad(t, (*zeros, *pad), mode=pad_mode)
def enable_tiling(
self,
tile_sample_min_height: Optional[int] = None,
tile_sample_min_width: Optional[int] = None,
tile_overlap_factor_height: Optional[float] = None,
tile_overlap_factor_width: Optional[float] = None,
) -> None:
r"""
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
processing larger images.
Args:
tile_sample_min_height (`int`, *optional*, defaults to `None`):
The minimum height required for a sample to be separated into tiles across the height dimension.
tile_sample_min_width (`int`, *optional*, defaults to `None`):
The minimum width required for a sample to be separated into tiles across the width dimension.
tile_overlap_factor_height (`float`, *optional*, defaults to `None`):
The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
tile_overlap_factor_width (`float`, *optional*, defaults to `None`):
The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
value might cause more tiles to be processed leading to slow down of the decoding process.
"""
self.use_tiling = True
self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** len(self.encoder.spatial_ds)))
self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** len(self.encoder.spatial_ds)))
self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
def disable_tiling(self) -> None:
r"""
Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_tiling = False
def enable_slicing(self) -> None:
r"""
Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
"""
self.use_slicing = True
def disable_slicing(self) -> None:
r"""
Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
decoding in one step.
"""
self.use_slicing = False
def _encode(self, x: torch.Tensor) -> torch.Tensor:
self._empty_causal_cached(self.encoder)
self._set_first_chunk(True)
if self.use_tiling:
return self.tiled_encode(x)
return self.encoder(x)
@apply_forward_hook
def encode(self, x: torch.Tensor) -> Union[AutoencoderKLOutput, Tuple[torch.Tensor, torch.Tensor]]:
r"""
Encode a batch of images into latents.
Args:
x (`torch.Tensor`): Input batch of images.
Returns:
`AutoencoderKLOutput` or `Tuple[torch.Tensor]`:
The latent representations of the encoded videos. If the regularizer is `kl`, an `AutoencoderKLOutput`
is returned, otherwise a tuple of `torch.Tensor` is returned.
"""
if self.use_slicing and x.shape[0] > 1:
encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
z = torch.cat(encoded_slices)
else:
z = self._encode(x)
if self.regularizer == "kl":
posterior = DiagonalGaussianDistribution(z)
return AutoencoderKLOutput(latent_dist=posterior)
else:
quant_z, indices = self.regularization(z)
return quant_z, indices
def _decode(self, z: torch.Tensor, decode_from_indices: bool = False) -> torch.Tensor:
self._empty_causal_cached(self.decoder)
self._set_first_chunk(True)
if not self.is_causal and z.shape[-3] % self.num_latent_frames_batch_size != 0:
assert z.shape[-3] >= self.num_latent_frames_batch_size, (
f"Too short latent frames. At least {self.num_latent_frames_batch_size} frames."
)
z = z[..., : (z.shape[-3] // self.num_latent_frames_batch_size * self.num_latent_frames_batch_size), :, :]
if decode_from_indices:
z = self.tile_indices_to_latent(z) if self.use_tiling else self.indices_to_latent(z)
dec = self.tiled_decode(z) if self.use_tiling else self.decoder(z)
return dec
@apply_forward_hook
def decode(self, z: torch.Tensor, decode_from_indices: bool = False) -> torch.Tensor:
r"""
Decode a batch of images from latents.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
decode_from_indices (`bool`): If decode from indices or decode from latent code.
Returns:
`torch.Tensor`: The decoded images.
"""
if self.use_slicing and z.shape[0] > 1:
decoded_slices = [self._decode(z_slice, decode_from_indices=decode_from_indices) for z_slice in z.split(1)]
decoded = torch.cat(decoded_slices)
else:
decoded = self._decode(z, decode_from_indices=decode_from_indices)
if self.is_causal:
decoded = decoded[:, :, self.temporal_compression_ratio - 1 :, :, :]
return decoded
def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[3], b.shape[3], blend_extent)
for y in range(blend_extent):
b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
y / blend_extent
)
return b
def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
blend_extent = min(a.shape[4], b.shape[4], blend_extent)
for x in range(blend_extent):
b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
x / blend_extent
)
return b
def build_chunk_start_end(self, t, decoder_mode=False):
if self.is_causal:
start_end = [[0, self.temporal_compression_ratio]] if not decoder_mode else [[0, 1]]
start = start_end[0][-1]
else:
start_end, start = [], 0
end = start
while True:
if start >= t:
break
end = min(
t, end + (self.num_latent_frames_batch_size if decoder_mode else self.num_sample_frames_batch_size)
)
start_end.append([start, end])
start = end
if len(start_end) > (2 if self.is_causal else 1):
if start_end[-1][1] - start_end[-1][0] < (
self.num_latent_frames_batch_size if decoder_mode else self.num_sample_frames_batch_size
):
start_end[-2] = [start_end[-2][0], start_end[-1][1]]
start_end = start_end[:-1]
return start_end
def _set_first_chunk(self, is_first_chunk=True):
for module in self.modules():
if hasattr(module, "is_first_chunk"):
module.is_first_chunk = is_first_chunk
def _empty_causal_cached(self, parent):
for name, module in parent.named_modules():
if hasattr(module, "causal_cache"):
module.causal_cache = None
def _set_cache_offset(self, modules, cache_offset=0):
for module in modules:
for submodule in module.modules():
if hasattr(submodule, "cache_offset"):
submodule.cache_offset = cache_offset
def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""
Encode a batch of images using a tiled encoder.
When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is
different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
output, but they should be much less noticeable.
Args:
x (`torch.Tensor`): Input batch of videos.
Returns:
`torch.Tensor`: The latent representation of the encoded videos.
"""
num_frames, height, width = x.shape[-3:]
overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_latent_min_height - blend_extent_height
row_limit_width = self.tile_latent_min_width - blend_extent_width
# Split x into overlapping tiles and encode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
start_end = self.build_chunk_start_end(num_frames)
time = []
for idx, (start_frame, end_frame) in enumerate(start_end):
self._set_first_chunk(idx == 0)
tile = x[
:,
:,
start_frame:end_frame,
i : i + self.tile_sample_min_height,
j : j + self.tile_sample_min_width,
]
tile = self.encoder(tile)
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
enc = torch.cat(result_rows, dim=3)
return enc
def indices_to_latent(self, token_indices: torch.Tensor) -> torch.Tensor:
r"""
Transform indices to latent code.
Args:
token_indices (`torch.Tensor`): Token indices.
Returns:
`torch.Tensor`: Latent code corresponding to the input token indices.
"""
b, t, h, w = token_indices.shape
token_indices = token_indices.unsqueeze(-1).reshape(b, -1, 1)
codes = self.regularization.indices_to_codes(token_indices)
codes = codes.permute(0, 2, 3, 1).reshape(b, codes.shape[2], -1)
z = self.regularization.project_out(codes)
return z.reshape(b, t, h, w, -1).permute(0, 4, 1, 2, 3)
def tile_indices_to_latent(self, token_indices: torch.Tensor) -> torch.Tensor:
r"""
Transform indices to latent code with tiling inference.
Args:
token_indices (`torch.Tensor`): Token indices.
Returns:
`torch.Tensor`: Latent code corresponding to the input token indices.
"""
num_frames = token_indices.shape[1]
start_end = self.build_chunk_start_end(num_frames, decoder_mode=True)
result_z = []
for start, end in start_end:
chunk_z = self.indices_to_latent(token_indices[:, start:end, :, :])
result_z.append(chunk_z.clone())
return torch.cat(result_z, dim=2)
def tiled_decode(self, z: torch.Tensor) -> torch.Tensor:
r"""
Decode a batch of images using a tiled decoder.
Args:
z (`torch.Tensor`): Input batch of latent vectors.
Returns:
`torch.Tensor`: Reconstructed batch of videos.
"""
num_frames, height, width = z.shape[-3:]
overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
row_limit_height = self.tile_sample_min_height - blend_extent_height
row_limit_width = self.tile_sample_min_width - blend_extent_width
# Split z into overlapping tiles and decode them separately.
# The tiles have an overlap to avoid seams between tiles.
rows = []
for i in range(0, height, overlap_height):
row = []
for j in range(0, width, overlap_width):
if self.is_causal:
assert self.temporal_compression_ratio in [
2,
4,
8,
], "Only support 2x, 4x or 8x temporal downsampling now."
if self.temporal_compression_ratio == 4:
self._set_cache_offset([self.decoder], 1)
self._set_cache_offset([self.decoder.up_temporal[2].upsample, self.decoder.up_temporal[1]], 2)
self._set_cache_offset(
[self.decoder.up_temporal[1].upsample, self.decoder.up_temporal[0], self.decoder.conv_out],
4,
)
elif self.temporal_compression_ratio == 2:
self._set_cache_offset([self.decoder], 1)
self._set_cache_offset(
[
self.decoder.up_temporal[2].upsample,
self.decoder.up_temporal[1],
self.decoder.up_temporal[0],
self.decoder.conv_out,
],
2,
)
else:
self._set_cache_offset([self.decoder], 1)
self._set_cache_offset([self.decoder.up_temporal[3].upsample, self.decoder.up_temporal[2]], 2)
self._set_cache_offset([self.decoder.up_temporal[2].upsample, self.decoder.up_temporal[1]], 4)
self._set_cache_offset(
[self.decoder.up_temporal[1].upsample, self.decoder.up_temporal[0], self.decoder.conv_out],
8,
)
start_end = self.build_chunk_start_end(num_frames, decoder_mode=True)
time = []
for idx, (start_frame, end_frame) in enumerate(start_end):
self._set_first_chunk(idx == 0)
tile = z[
:,
:,
start_frame : (end_frame + 1 if self.is_causal and end_frame + 1 <= num_frames else end_frame),
i : i + self.tile_latent_min_height,
j : j + self.tile_latent_min_width,
]
tile = self.decoder(tile)
if self.is_causal and end_frame + 1 <= num_frames:
tile = tile[:, :, : -self.temporal_compression_ratio]
time.append(tile)
row.append(torch.cat(time, dim=2))
rows.append(row)
result_rows = []
for i, row in enumerate(rows):
result_row = []
for j, tile in enumerate(row):
# blend the above tile and the left tile
# to the current tile and add the current tile to the result row
if i > 0:
tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
if j > 0:
tile = self.blend_h(row[j - 1], tile, blend_extent_width)
result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
result_rows.append(torch.cat(result_row, dim=4))
dec = torch.cat(result_rows, dim=3)
return dec
def forward(
self,
sample: torch.Tensor,
sample_posterior: bool = True,
encoder_mode: bool = False,
return_dict: bool = True,
generator: Optional[torch.Generator] = None,
) -> Union[torch.Tensor, DecoderOutput]:
r"""
Args:
sample (`torch.Tensor`): Input sample.
sample_posterior (`bool`, *optional*, defaults to `True`):
Whether to sample from the posterior.
encoder_mode (`bool`, *optional*, defaults to `False`):
If `True`, only run the encoder and return the encoded latent without decoding.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
generator (`torch.Generator`, *optional*):
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make sampling
deterministic.
Returns:
[`~models.vae.DecoderOutput`] or `torch.Tensor`:
If `return_dict` is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `torch.Tensor`
is returned.
"""
x = sample
res = 1 if self.is_causal else 0
if self.is_causal:
if x.shape[2] % self.temporal_compression_ratio != res:
time_padding = self.temporal_compression_ratio - x.shape[2] % self.temporal_compression_ratio + res
x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate")
else:
time_padding = 0
else:
if x.shape[2] % self.num_sample_frames_batch_size != res:
if not encoder_mode:
time_padding = (
self.num_sample_frames_batch_size - x.shape[2] % self.num_sample_frames_batch_size + res
)
x = self._pad_at_dim(x, (0, time_padding), dim=2, pad_mode="replicate")
else:
assert x.shape[2] >= self.num_sample_frames_batch_size, (
f"Too short video. At least {self.num_sample_frames_batch_size} frames."
)
x = x[:, :, : x.shape[2] // self.num_sample_frames_batch_size * self.num_sample_frames_batch_size]
else:
time_padding = 0
if self.is_causal:
x = self._pad_at_dim(x, (self.temporal_compression_ratio - 1, 0), dim=2, pad_mode="replicate")
if self.regularizer == "kl":
posterior = self.encode(x).latent_dist
if sample_posterior:
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
if encoder_mode:
return z
else:
z, indices = self.encode(x)
if encoder_mode:
return z, indices
dec = self.decode(z)
if time_padding != 0:
dec = dec[:, :, :-time_padding, :, :]
if not return_dict:
return dec
return DecoderOutput(sample=dec)