| import itertools |
| import math |
| import einops |
| from dataclasses import replace, dataclass |
| from typing import Any, Callable, Iterator, List, NamedTuple, Tuple, Union, Optional |
| import torch |
| from einops import rearrange |
| from torch import nn |
| from torch.nn import functional as F |
| from enum import Enum |
| from .ltx2_common import PixelNorm, SpatioTemporalScaleFactors, VideoLatentShape, Patchifier, AudioLatentShape |
| from .ltx2_dit import PixArtAlphaCombinedTimestepSizeEmbeddings |
|
|
| VAE_SPATIAL_FACTOR = 32 |
| VAE_TEMPORAL_FACTOR = 8 |
|
|
|
|
| class VideoLatentPatchifier(Patchifier): |
| def __init__(self, patch_size: int): |
| |
| self._patch_size = ( |
| 1, |
| patch_size, |
| patch_size, |
| ) |
|
|
| @property |
| def patch_size(self) -> Tuple[int, int, int]: |
| return self._patch_size |
|
|
| def get_token_count(self, tgt_shape: VideoLatentShape) -> int: |
| return math.prod(tgt_shape.to_torch_shape()[2:]) // math.prod(self._patch_size) |
|
|
| def patchify( |
| self, |
| latents: torch.Tensor, |
| ) -> torch.Tensor: |
| latents = einops.rearrange( |
| latents, |
| "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", |
| p1=self._patch_size[0], |
| p2=self._patch_size[1], |
| p3=self._patch_size[2], |
| ) |
|
|
| return latents |
|
|
| def unpatchify( |
| self, |
| latents: torch.Tensor, |
| output_shape: VideoLatentShape, |
| ) -> torch.Tensor: |
| assert self._patch_size[0] == 1, "Temporal patch size must be 1 for symmetric patchifier" |
|
|
| patch_grid_frames = output_shape.frames // self._patch_size[0] |
| patch_grid_height = output_shape.height // self._patch_size[1] |
| patch_grid_width = output_shape.width // self._patch_size[2] |
|
|
| latents = einops.rearrange( |
| latents, |
| "b (f h w) (c p q) -> b c f (h p) (w q)", |
| f=patch_grid_frames, |
| h=patch_grid_height, |
| w=patch_grid_width, |
| p=self._patch_size[1], |
| q=self._patch_size[2], |
| ) |
|
|
| return latents |
|
|
| def unpatchify_video( |
| self, |
| latents: torch.Tensor, |
| frames: int, |
| height: int, |
| width: int, |
| ) -> torch.Tensor: |
| latents = einops.rearrange( |
| latents, |
| "b (f h w) (c p q) -> b c f (h p) (w q)", |
| f=frames, |
| h=height // self._patch_size[1], |
| w=width // self._patch_size[2], |
| p=self._patch_size[1], |
| q=self._patch_size[2], |
| ) |
| return latents |
|
|
| def get_patch_grid_bounds( |
| self, |
| output_shape: AudioLatentShape | VideoLatentShape, |
| device: Optional[torch.device] = None, |
| ) -> torch.Tensor: |
| """ |
| Return the per-dimension bounds [inclusive start, exclusive end) for every |
| patch produced by `patchify`. The bounds are expressed in the original |
| video grid coordinates: frame/time, height, and width. |
| The resulting tensor is shaped `[batch_size, 3, num_patches, 2]`, where: |
| - axis 1 (size 3) enumerates (frame/time, height, width) dimensions |
| - axis 3 (size 2) stores `[start, end)` indices within each dimension |
| Args: |
| output_shape: Video grid description containing frames, height, and width. |
| device: Device of the latent tensor. |
| """ |
| if not isinstance(output_shape, VideoLatentShape): |
| raise ValueError("VideoLatentPatchifier expects VideoLatentShape when computing coordinates") |
|
|
| frames = output_shape.frames |
| height = output_shape.height |
| width = output_shape.width |
| batch_size = output_shape.batch |
|
|
| |
| assert frames > 0, f"frames must be positive, got {frames}" |
| assert height > 0, f"height must be positive, got {height}" |
| assert width > 0, f"width must be positive, got {width}" |
| assert batch_size > 0, f"batch_size must be positive, got {batch_size}" |
|
|
| |
| |
| |
| grid_coords = torch.meshgrid( |
| torch.arange(start=0, end=frames, step=self._patch_size[0], device=device), |
| torch.arange(start=0, end=height, step=self._patch_size[1], device=device), |
| torch.arange(start=0, end=width, step=self._patch_size[2], device=device), |
| indexing="ij", |
| ) |
|
|
| |
| |
| patch_starts = torch.stack(grid_coords, dim=0) |
|
|
| |
| |
| |
| patch_size_delta = torch.tensor( |
| self._patch_size, |
| device=patch_starts.device, |
| dtype=patch_starts.dtype, |
| ).view(3, 1, 1, 1) |
|
|
| |
| |
| patch_ends = patch_starts + patch_size_delta |
|
|
| |
| |
| latent_coords = torch.stack((patch_starts, patch_ends), dim=-1) |
|
|
| |
| |
| latent_coords = einops.repeat( |
| latent_coords, |
| "c f h w bounds -> b c (f h w) bounds", |
| b=batch_size, |
| bounds=2, |
| ) |
|
|
| return latent_coords |
|
|
|
|
| class NormLayerType(Enum): |
| GROUP_NORM = "group_norm" |
| PIXEL_NORM = "pixel_norm" |
|
|
|
|
| class LogVarianceType(Enum): |
| PER_CHANNEL = "per_channel" |
| UNIFORM = "uniform" |
| CONSTANT = "constant" |
| NONE = "none" |
|
|
|
|
| class PaddingModeType(Enum): |
| ZEROS = "zeros" |
| REFLECT = "reflect" |
| REPLICATE = "replicate" |
| CIRCULAR = "circular" |
|
|
|
|
| class DualConv3d(nn.Module): |
|
|
| def __init__( |
| self, |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int, |
| stride: Union[int, Tuple[int, int, int]] = 1, |
| padding: Union[int, Tuple[int, int, int]] = 0, |
| dilation: Union[int, Tuple[int, int, int]] = 1, |
| groups: int = 1, |
| bias: bool = True, |
| padding_mode: str = "zeros", |
| ) -> None: |
| super(DualConv3d, self).__init__() |
|
|
| self.in_channels = in_channels |
| self.out_channels = out_channels |
| self.padding_mode = padding_mode |
| |
| if isinstance(kernel_size, int): |
| kernel_size = (kernel_size, kernel_size, kernel_size) |
| if kernel_size == (1, 1, 1): |
| raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.") |
| if isinstance(stride, int): |
| stride = (stride, stride, stride) |
| if isinstance(padding, int): |
| padding = (padding, padding, padding) |
| if isinstance(dilation, int): |
| dilation = (dilation, dilation, dilation) |
|
|
| |
| self.groups = groups |
| self.bias = bias |
|
|
| |
| intermediate_channels = out_channels if in_channels < out_channels else in_channels |
|
|
| |
| self.weight1 = nn.Parameter( |
| torch.Tensor( |
| intermediate_channels, |
| in_channels // groups, |
| 1, |
| kernel_size[1], |
| kernel_size[2], |
| )) |
| self.stride1 = (1, stride[1], stride[2]) |
| self.padding1 = (0, padding[1], padding[2]) |
| self.dilation1 = (1, dilation[1], dilation[2]) |
| if bias: |
| self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) |
| else: |
| self.register_parameter("bias1", None) |
|
|
| |
| self.weight2 = nn.Parameter(torch.Tensor(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)) |
| self.stride2 = (stride[0], 1, 1) |
| self.padding2 = (padding[0], 0, 0) |
| self.dilation2 = (dilation[0], 1, 1) |
| if bias: |
| self.bias2 = nn.Parameter(torch.Tensor(out_channels)) |
| else: |
| self.register_parameter("bias2", None) |
|
|
| |
| self.reset_parameters() |
|
|
| def reset_parameters(self) -> None: |
| nn.init.kaiming_uniform_(self.weight1, a=torch.sqrt(5)) |
| nn.init.kaiming_uniform_(self.weight2, a=torch.sqrt(5)) |
| if self.bias: |
| fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) |
| bound1 = 1 / torch.sqrt(fan_in1) |
| nn.init.uniform_(self.bias1, -bound1, bound1) |
| fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) |
| bound2 = 1 / torch.sqrt(fan_in2) |
| nn.init.uniform_(self.bias2, -bound2, bound2) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| use_conv3d: bool = False, |
| skip_time_conv: bool = False, |
| ) -> torch.Tensor: |
| if use_conv3d: |
| return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) |
| else: |
| return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) |
|
|
| def forward_with_3d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: |
| |
| x = F.conv3d( |
| x, |
| self.weight1, |
| self.bias1, |
| self.stride1, |
| self.padding1, |
| self.dilation1, |
| self.groups, |
| padding_mode=self.padding_mode, |
| ) |
|
|
| if skip_time_conv: |
| return x |
|
|
| |
| x = F.conv3d( |
| x, |
| self.weight2, |
| self.bias2, |
| self.stride2, |
| self.padding2, |
| self.dilation2, |
| self.groups, |
| padding_mode=self.padding_mode, |
| ) |
|
|
| return x |
|
|
| def forward_with_2d(self, x: torch.Tensor, skip_time_conv: bool = False) -> torch.Tensor: |
| b, _, _, h, w = x.shape |
|
|
| |
| x = rearrange(x, "b c d h w -> (b d) c h w") |
| |
| weight1 = self.weight1.squeeze(2) |
| |
| stride1 = (self.stride1[1], self.stride1[2]) |
| padding1 = (self.padding1[1], self.padding1[2]) |
| dilation1 = (self.dilation1[1], self.dilation1[2]) |
| x = F.conv2d( |
| x, |
| weight1, |
| self.bias1, |
| stride1, |
| padding1, |
| dilation1, |
| self.groups, |
| padding_mode=self.padding_mode, |
| ) |
|
|
| _, _, h, w = x.shape |
|
|
| if skip_time_conv: |
| x = rearrange(x, "(b d) c h w -> b c d h w", b=b) |
| return x |
|
|
| |
| x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) |
|
|
| |
| weight2 = self.weight2.squeeze(-1).squeeze(-1) |
| |
| stride2 = self.stride2[0] |
| padding2 = self.padding2[0] |
| dilation2 = self.dilation2[0] |
| x = F.conv1d( |
| x, |
| weight2, |
| self.bias2, |
| stride2, |
| padding2, |
| dilation2, |
| self.groups, |
| padding_mode=self.padding_mode, |
| ) |
| x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) |
|
|
| return x |
|
|
| @property |
| def weight(self) -> torch.Tensor: |
| return self.weight2 |
|
|
|
|
| class CausalConv3d(nn.Module): |
|
|
| def __init__( |
| self, |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int = 3, |
| stride: Union[int, Tuple[int]] = 1, |
| dilation: int = 1, |
| groups: int = 1, |
| bias: bool = True, |
| spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| ) -> None: |
| super().__init__() |
|
|
| self.in_channels = in_channels |
| self.out_channels = out_channels |
|
|
| kernel_size = (kernel_size, kernel_size, kernel_size) |
| self.time_kernel_size = kernel_size[0] |
|
|
| dilation = (dilation, 1, 1) |
|
|
| height_pad = kernel_size[1] // 2 |
| width_pad = kernel_size[2] // 2 |
| padding = (0, height_pad, width_pad) |
|
|
| self.conv = nn.Conv3d( |
| in_channels, |
| out_channels, |
| kernel_size, |
| stride=stride, |
| dilation=dilation, |
| padding=padding, |
| padding_mode=spatial_padding_mode.value, |
| groups=groups, |
| bias=bias, |
| ) |
|
|
| def forward(self, x: torch.Tensor, causal: bool = True) -> torch.Tensor: |
| if causal: |
| first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) |
| x = torch.concatenate((first_frame_pad, x), dim=2) |
| else: |
| first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) |
| last_frame_pad = x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)) |
| x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) |
| x = self.conv(x) |
| return x |
|
|
| @property |
| def weight(self) -> torch.Tensor: |
| return self.conv.weight |
|
|
|
|
| def make_conv_nd( |
| dims: Union[int, Tuple[int, int]], |
| in_channels: int, |
| out_channels: int, |
| kernel_size: int, |
| stride: int = 1, |
| padding: int = 0, |
| dilation: int = 1, |
| groups: int = 1, |
| bias: bool = True, |
| causal: bool = False, |
| spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| temporal_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| ) -> nn.Module: |
| if not (spatial_padding_mode == temporal_padding_mode or causal): |
| raise NotImplementedError("spatial and temporal padding modes must be equal") |
| if dims == 2: |
| return nn.Conv2d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=kernel_size, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| groups=groups, |
| bias=bias, |
| padding_mode=spatial_padding_mode.value, |
| ) |
| elif dims == 3: |
| if causal: |
| return CausalConv3d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=kernel_size, |
| stride=stride, |
| dilation=dilation, |
| groups=groups, |
| bias=bias, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| return nn.Conv3d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=kernel_size, |
| stride=stride, |
| padding=padding, |
| dilation=dilation, |
| groups=groups, |
| bias=bias, |
| padding_mode=spatial_padding_mode.value, |
| ) |
| elif dims == (2, 1): |
| return DualConv3d( |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=kernel_size, |
| stride=stride, |
| padding=padding, |
| bias=bias, |
| padding_mode=spatial_padding_mode.value, |
| ) |
| else: |
| raise ValueError(f"unsupported dimensions: {dims}") |
|
|
|
|
| def make_linear_nd( |
| dims: int, |
| in_channels: int, |
| out_channels: int, |
| bias: bool = True, |
| ) -> nn.Module: |
| if dims == 2: |
| return nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) |
| elif dims in (3, (2, 1)): |
| return nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias) |
| else: |
| raise ValueError(f"unsupported dimensions: {dims}") |
|
|
|
|
| def patchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: |
| """ |
| Rearrange spatial dimensions into channels. Divides image into patch_size x patch_size blocks |
| and moves pixels from each block into separate channels (space-to-depth). |
| Args: |
| x: Input tensor (4D or 5D) |
| patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, divides HxW into 4x4 blocks. |
| patch_size_t: Temporal patch size for frames. Default=1 (no temporal patching). |
| For 5D: (B, C, F, H, W) -> (B, Cx(patch_size_hw^2)x(patch_size_t), F/patch_size_t, H/patch_size_hw, W/patch_size_hw) |
| Example: (B, 3, 33, 512, 512) with patch_size_hw=4, patch_size_t=1 -> (B, 48, 33, 128, 128) |
| """ |
| if patch_size_hw == 1 and patch_size_t == 1: |
| return x |
| if x.dim() == 4: |
| x = rearrange(x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw) |
| elif x.dim() == 5: |
| x = rearrange( |
| x, |
| "b c (f p) (h q) (w r) -> b (c p r q) f h w", |
| p=patch_size_t, |
| q=patch_size_hw, |
| r=patch_size_hw, |
| ) |
| else: |
| raise ValueError(f"Invalid input shape: {x.shape}") |
|
|
| return x |
|
|
|
|
| def unpatchify(x: torch.Tensor, patch_size_hw: int, patch_size_t: int = 1) -> torch.Tensor: |
| """ |
| Rearrange channels back into spatial dimensions. Inverse of patchify - moves pixels from |
| channels back into patch_size x patch_size blocks (depth-to-space). |
| Args: |
| x: Input tensor (4D or 5D) |
| patch_size_hw: Spatial patch size for height and width. With patch_size_hw=4, expands HxW by 4x. |
| patch_size_t: Temporal patch size for frames. Default=1 (no temporal expansion). |
| For 5D: (B, Cx(patch_size_hw^2)x(patch_size_t), F, H, W) -> (B, C, Fxpatch_size_t, Hxpatch_size_hw, Wxpatch_size_hw) |
| Example: (B, 48, 33, 128, 128) with patch_size_hw=4, patch_size_t=1 -> (B, 3, 33, 512, 512) |
| """ |
| if patch_size_hw == 1 and patch_size_t == 1: |
| return x |
|
|
| if x.dim() == 4: |
| x = rearrange(x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw) |
| elif x.dim() == 5: |
| x = rearrange( |
| x, |
| "b (c p r q) f h w -> b c (f p) (h q) (w r)", |
| p=patch_size_t, |
| q=patch_size_hw, |
| r=patch_size_hw, |
| ) |
|
|
| return x |
|
|
|
|
| class PerChannelStatistics(nn.Module): |
| """ |
| Per-channel statistics for normalizing and denormalizing the latent representation. |
| This statics is computed over the entire dataset and stored in model's checkpoint under VAE state_dict. |
| """ |
|
|
| def __init__(self, latent_channels: int = 128): |
| super().__init__() |
| self.register_buffer("std-of-means", torch.empty(latent_channels)) |
| self.register_buffer("mean-of-means", torch.empty(latent_channels)) |
|
|
| def un_normalize(self, x: torch.Tensor) -> torch.Tensor: |
| return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view( |
| 1, -1, 1, 1, 1).to(x) |
|
|
| def normalize(self, x: torch.Tensor) -> torch.Tensor: |
| return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view( |
| 1, -1, 1, 1, 1).to(x) |
|
|
|
|
| class ResnetBlock3D(nn.Module): |
| r""" |
| A Resnet block. |
| Parameters: |
| in_channels (`int`): The number of channels in the input. |
| out_channels (`int`, *optional*, default to be `None`): |
| The number of output channels for the first conv layer. If None, same as `in_channels`. |
| dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. |
| groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. |
| eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. |
| """ |
|
|
| def __init__( |
| self, |
| dims: Union[int, Tuple[int, int]], |
| in_channels: int, |
| out_channels: Optional[int] = None, |
| dropout: float = 0.0, |
| groups: int = 32, |
| eps: float = 1e-6, |
| norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, |
| inject_noise: bool = False, |
| timestep_conditioning: bool = False, |
| spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| ): |
| super().__init__() |
| self.in_channels = in_channels |
| out_channels = in_channels if out_channels is None else out_channels |
| self.out_channels = out_channels |
| self.inject_noise = inject_noise |
|
|
| if norm_layer == NormLayerType.GROUP_NORM: |
| self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) |
| elif norm_layer == NormLayerType.PIXEL_NORM: |
| self.norm1 = PixelNorm() |
|
|
| self.non_linearity = nn.SiLU() |
|
|
| self.conv1 = make_conv_nd( |
| dims, |
| in_channels, |
| out_channels, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
|
|
| if inject_noise: |
| self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) |
|
|
| if norm_layer == NormLayerType.GROUP_NORM: |
| self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True) |
| elif norm_layer == NormLayerType.PIXEL_NORM: |
| self.norm2 = PixelNorm() |
|
|
| self.dropout = torch.nn.Dropout(dropout) |
|
|
| self.conv2 = make_conv_nd( |
| dims, |
| out_channels, |
| out_channels, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
|
|
| if inject_noise: |
| self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) |
|
|
| self.conv_shortcut = (make_linear_nd(dims=dims, in_channels=in_channels, out_channels=out_channels) |
| if in_channels != out_channels else nn.Identity()) |
|
|
| |
| |
| self.norm3 = (nn.GroupNorm(num_groups=1, num_channels=in_channels, eps=eps, affine=True) |
| if in_channels != out_channels else nn.Identity()) |
|
|
| self.timestep_conditioning = timestep_conditioning |
|
|
| if timestep_conditioning: |
| self.scale_shift_table = nn.Parameter(torch.zeros(4, in_channels)) |
|
|
| def _feed_spatial_noise( |
| self, |
| hidden_states: torch.Tensor, |
| per_channel_scale: torch.Tensor, |
| generator: Optional[torch.Generator] = None, |
| ) -> torch.Tensor: |
| spatial_shape = hidden_states.shape[-2:] |
| device = hidden_states.device |
| dtype = hidden_states.dtype |
|
|
| |
| spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype, generator=generator)[None] |
| scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] |
| hidden_states = hidden_states + scaled_noise |
|
|
| return hidden_states |
|
|
| def forward( |
| self, |
| input_tensor: torch.Tensor, |
| causal: bool = True, |
| timestep: Optional[torch.Tensor] = None, |
| generator: Optional[torch.Generator] = None, |
| ) -> torch.Tensor: |
| hidden_states = input_tensor |
| batch_size = hidden_states.shape[0] |
|
|
| hidden_states = self.norm1(hidden_states) |
| if self.timestep_conditioning: |
| if timestep is None: |
| raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") |
| ada_values = self.scale_shift_table[None, ..., None, None, None].to( |
| device=hidden_states.device, dtype=hidden_states.dtype) + timestep.reshape( |
| batch_size, |
| 4, |
| -1, |
| timestep.shape[-3], |
| timestep.shape[-2], |
| timestep.shape[-1], |
| ) |
| shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) |
|
|
| hidden_states = hidden_states * (1 + scale1) + shift1 |
|
|
| hidden_states = self.non_linearity(hidden_states) |
|
|
| hidden_states = self.conv1(hidden_states, causal=causal) |
|
|
| if self.inject_noise: |
| hidden_states = self._feed_spatial_noise( |
| hidden_states, |
| self.per_channel_scale1.to(device=hidden_states.device, dtype=hidden_states.dtype), |
| generator=generator, |
| ) |
|
|
| hidden_states = self.norm2(hidden_states) |
|
|
| if self.timestep_conditioning: |
| hidden_states = hidden_states * (1 + scale2) + shift2 |
|
|
| hidden_states = self.non_linearity(hidden_states) |
|
|
| hidden_states = self.dropout(hidden_states) |
|
|
| hidden_states = self.conv2(hidden_states, causal=causal) |
|
|
| if self.inject_noise: |
| hidden_states = self._feed_spatial_noise( |
| hidden_states, |
| self.per_channel_scale2.to(device=hidden_states.device, dtype=hidden_states.dtype), |
| generator=generator, |
| ) |
|
|
| input_tensor = self.norm3(input_tensor) |
|
|
| batch_size = input_tensor.shape[0] |
|
|
| input_tensor = self.conv_shortcut(input_tensor) |
|
|
| output_tensor = input_tensor + hidden_states |
|
|
| return output_tensor |
|
|
|
|
| class UNetMidBlock3D(nn.Module): |
| """ |
| A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. |
| Args: |
| in_channels (`int`): The number of input channels. |
| dropout (`float`, *optional*, defaults to 0.0): The dropout rate. |
| num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. |
| resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. |
| resnet_groups (`int`, *optional*, defaults to 32): |
| The number of groups to use in the group normalization layers of the resnet blocks. |
| norm_layer (`str`, *optional*, defaults to `group_norm`): |
| The normalization layer to use. Can be either `group_norm` or `pixel_norm`. |
| inject_noise (`bool`, *optional*, defaults to `False`): |
| Whether to inject noise into the hidden states. |
| timestep_conditioning (`bool`, *optional*, defaults to `False`): |
| Whether to condition the hidden states on the timestep. |
| Returns: |
| `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, |
| in_channels, height, width)`. |
| """ |
|
|
| def __init__( |
| self, |
| dims: Union[int, Tuple[int, int]], |
| in_channels: int, |
| dropout: float = 0.0, |
| num_layers: int = 1, |
| resnet_eps: float = 1e-6, |
| resnet_groups: int = 32, |
| norm_layer: NormLayerType = NormLayerType.GROUP_NORM, |
| inject_noise: bool = False, |
| timestep_conditioning: bool = False, |
| spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| ): |
| super().__init__() |
| resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) |
|
|
| self.timestep_conditioning = timestep_conditioning |
|
|
| if timestep_conditioning: |
| self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=in_channels * 4, |
| size_emb_dim=0) |
|
|
| self.res_blocks = nn.ModuleList([ |
| ResnetBlock3D( |
| dims=dims, |
| in_channels=in_channels, |
| out_channels=in_channels, |
| eps=resnet_eps, |
| groups=resnet_groups, |
| dropout=dropout, |
| norm_layer=norm_layer, |
| inject_noise=inject_noise, |
| timestep_conditioning=timestep_conditioning, |
| spatial_padding_mode=spatial_padding_mode, |
| ) for _ in range(num_layers) |
| ]) |
|
|
| def forward( |
| self, |
| hidden_states: torch.Tensor, |
| causal: bool = True, |
| timestep: Optional[torch.Tensor] = None, |
| generator: Optional[torch.Generator] = None, |
| ) -> torch.Tensor: |
| timestep_embed = None |
| if self.timestep_conditioning: |
| if timestep is None: |
| raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") |
| batch_size = hidden_states.shape[0] |
| timestep_embed = self.time_embedder( |
| timestep=timestep.flatten(), |
| hidden_dtype=hidden_states.dtype, |
| ) |
| timestep_embed = timestep_embed.view(batch_size, timestep_embed.shape[-1], 1, 1, 1) |
|
|
| for resnet in self.res_blocks: |
| hidden_states = resnet( |
| hidden_states, |
| causal=causal, |
| timestep=timestep_embed, |
| generator=generator, |
| ) |
|
|
| return hidden_states |
|
|
|
|
| class SpaceToDepthDownsample(nn.Module): |
|
|
| def __init__( |
| self, |
| dims: Union[int, Tuple[int, int]], |
| in_channels: int, |
| out_channels: int, |
| stride: Tuple[int, int, int], |
| spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| ): |
| super().__init__() |
| self.stride = stride |
| self.group_size = in_channels * math.prod(stride) // out_channels |
| self.conv = make_conv_nd( |
| dims=dims, |
| in_channels=in_channels, |
| out_channels=out_channels // math.prod(stride), |
| kernel_size=3, |
| stride=1, |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| causal: bool = True, |
| ) -> torch.Tensor: |
| if self.stride[0] == 2: |
| x = torch.cat([x[:, :, :1, :, :], x], dim=2) |
|
|
| |
| x_in = rearrange( |
| x, |
| "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", |
| p1=self.stride[0], |
| p2=self.stride[1], |
| p3=self.stride[2], |
| ) |
| x_in = rearrange(x_in, "b (c g) d h w -> b c g d h w", g=self.group_size) |
| x_in = x_in.mean(dim=2) |
|
|
| |
| x = self.conv(x, causal=causal) |
| x = rearrange( |
| x, |
| "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", |
| p1=self.stride[0], |
| p2=self.stride[1], |
| p3=self.stride[2], |
| ) |
|
|
| x = x + x_in |
|
|
| return x |
|
|
|
|
| class DepthToSpaceUpsample(nn.Module): |
|
|
| def __init__( |
| self, |
| dims: int | Tuple[int, int], |
| in_channels: int, |
| stride: Tuple[int, int, int], |
| residual: bool = False, |
| out_channels_reduction_factor: int = 1, |
| spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| ): |
| super().__init__() |
| self.stride = stride |
| self.out_channels = math.prod(stride) * in_channels // out_channels_reduction_factor |
| self.conv = make_conv_nd( |
| dims=dims, |
| in_channels=in_channels, |
| out_channels=self.out_channels, |
| kernel_size=3, |
| stride=1, |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| self.residual = residual |
| self.out_channels_reduction_factor = out_channels_reduction_factor |
|
|
| def forward( |
| self, |
| x: torch.Tensor, |
| causal: bool = True, |
| ) -> torch.Tensor: |
| if self.residual: |
| |
| x_in = rearrange( |
| x, |
| "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", |
| p1=self.stride[0], |
| p2=self.stride[1], |
| p3=self.stride[2], |
| ) |
| num_repeat = math.prod(self.stride) // self.out_channels_reduction_factor |
| x_in = x_in.repeat(1, num_repeat, 1, 1, 1) |
| if self.stride[0] == 2: |
| x_in = x_in[:, :, 1:, :, :] |
| x = self.conv(x, causal=causal) |
| x = rearrange( |
| x, |
| "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", |
| p1=self.stride[0], |
| p2=self.stride[1], |
| p3=self.stride[2], |
| ) |
| if self.stride[0] == 2: |
| x = x[:, :, 1:, :, :] |
| if self.residual: |
| x = x + x_in |
| return x |
|
|
|
|
| def compute_trapezoidal_mask_1d( |
| length: int, |
| ramp_left: int, |
| ramp_right: int, |
| left_starts_from_0: bool = False, |
| ) -> torch.Tensor: |
| """ |
| Generate a 1D trapezoidal blending mask with linear ramps. |
| Args: |
| length: Output length of the mask. |
| ramp_left: Fade-in length on the left. |
| ramp_right: Fade-out length on the right. |
| left_starts_from_0: Whether the ramp starts from 0 or first non-zero value. |
| Useful for temporal tiles where the first tile is causal. |
| Returns: |
| A 1D tensor of shape `(length,)` with values in [0, 1]. |
| """ |
| if length <= 0: |
| raise ValueError("Mask length must be positive.") |
|
|
| ramp_left = max(0, min(ramp_left, length)) |
| ramp_right = max(0, min(ramp_right, length)) |
|
|
| mask = torch.ones(length) |
|
|
| if ramp_left > 0: |
| interval_length = ramp_left + 1 if left_starts_from_0 else ramp_left + 2 |
| fade_in = torch.linspace(0.0, 1.0, interval_length)[:-1] |
| if not left_starts_from_0: |
| fade_in = fade_in[1:] |
| mask[:ramp_left] *= fade_in |
|
|
| if ramp_right > 0: |
| fade_out = torch.linspace(1.0, 0.0, steps=ramp_right + 2)[1:-1] |
| mask[-ramp_right:] *= fade_out |
|
|
| return mask.clamp_(0, 1) |
|
|
|
|
| @dataclass(frozen=True) |
| class SpatialTilingConfig: |
| """Configuration for dividing each frame into spatial tiles with optional overlap. |
| Args: |
| tile_size_in_pixels (int): Size of each tile in pixels. Must be at least 64 and divisible by 32. |
| tile_overlap_in_pixels (int, optional): Overlap between tiles in pixels. Must be divisible by 32. Defaults to 0. |
| """ |
|
|
| tile_size_in_pixels: int |
| tile_overlap_in_pixels: int = 0 |
|
|
| def __post_init__(self) -> None: |
| if self.tile_size_in_pixels < 64: |
| raise ValueError(f"tile_size_in_pixels must be at least 64, got {self.tile_size_in_pixels}") |
| if self.tile_size_in_pixels % 32 != 0: |
| raise ValueError(f"tile_size_in_pixels must be divisible by 32, got {self.tile_size_in_pixels}") |
| if self.tile_overlap_in_pixels % 32 != 0: |
| raise ValueError(f"tile_overlap_in_pixels must be divisible by 32, got {self.tile_overlap_in_pixels}") |
| if self.tile_overlap_in_pixels >= self.tile_size_in_pixels: |
| raise ValueError( |
| f"Overlap must be less than tile size, got {self.tile_overlap_in_pixels} and {self.tile_size_in_pixels}" |
| ) |
|
|
|
|
| @dataclass(frozen=True) |
| class TemporalTilingConfig: |
| """Configuration for dividing a video into temporal tiles (chunks of frames) with optional overlap. |
| Args: |
| tile_size_in_frames (int): Number of frames in each tile. Must be at least 16 and divisible by 8. |
| tile_overlap_in_frames (int, optional): Number of overlapping frames between consecutive tiles. |
| Must be divisible by 8. Defaults to 0. |
| """ |
|
|
| tile_size_in_frames: int |
| tile_overlap_in_frames: int = 0 |
|
|
| def __post_init__(self) -> None: |
| if self.tile_size_in_frames < 16: |
| raise ValueError(f"tile_size_in_frames must be at least 16, got {self.tile_size_in_frames}") |
| if self.tile_size_in_frames % 8 != 0: |
| raise ValueError(f"tile_size_in_frames must be divisible by 8, got {self.tile_size_in_frames}") |
| if self.tile_overlap_in_frames % 8 != 0: |
| raise ValueError(f"tile_overlap_in_frames must be divisible by 8, got {self.tile_overlap_in_frames}") |
| if self.tile_overlap_in_frames >= self.tile_size_in_frames: |
| raise ValueError( |
| f"Overlap must be less than tile size, got {self.tile_overlap_in_frames} and {self.tile_size_in_frames}" |
| ) |
|
|
|
|
| @dataclass(frozen=True) |
| class TilingConfig: |
| """Configuration for splitting video into tiles with optional overlap. |
| Attributes: |
| spatial_config: Configuration for splitting spatial dimensions into tiles. |
| temporal_config: Configuration for splitting temporal dimension into tiles. |
| """ |
|
|
| spatial_config: SpatialTilingConfig | None = None |
| temporal_config: TemporalTilingConfig | None = None |
|
|
| @classmethod |
| def default(cls) -> "TilingConfig": |
| return cls( |
| spatial_config=SpatialTilingConfig(tile_size_in_pixels=512, tile_overlap_in_pixels=64), |
| temporal_config=TemporalTilingConfig(tile_size_in_frames=64, tile_overlap_in_frames=24), |
| ) |
|
|
|
|
| @dataclass(frozen=True) |
| class DimensionIntervals: |
| """Intervals which a single dimension of the latent space is split into. |
| Each interval is defined by its start, end, left ramp, and right ramp. |
| The start and end are the indices of the first and last element (exclusive) in the interval. |
| Ramps are regions of the interval where the value of the mask tensor is |
| interpolated between 0 and 1 for blending with neighboring intervals. |
| The left ramp and right ramp values are the lengths of the left and right ramps. |
| """ |
|
|
| starts: List[int] |
| ends: List[int] |
| left_ramps: List[int] |
| right_ramps: List[int] |
|
|
|
|
| @dataclass(frozen=True) |
| class LatentIntervals: |
| """Intervals which the latent tensor of given shape is split into. |
| Each dimension of the latent space is split into intervals based on the length along said dimension. |
| """ |
|
|
| original_shape: torch.Size |
| dimension_intervals: Tuple[DimensionIntervals, ...] |
|
|
|
|
| |
| SplitOperation = Callable[[int], DimensionIntervals] |
| |
| MappingOperation = Callable[[DimensionIntervals], tuple[list[slice], list[torch.Tensor | None]]] |
|
|
|
|
| def default_split_operation(length: int) -> DimensionIntervals: |
| return DimensionIntervals(starts=[0], ends=[length], left_ramps=[0], right_ramps=[0]) |
|
|
|
|
| DEFAULT_SPLIT_OPERATION: SplitOperation = default_split_operation |
|
|
|
|
| def default_mapping_operation(_intervals: DimensionIntervals,) -> tuple[list[slice], list[torch.Tensor | None]]: |
| return [slice(0, None)], [None] |
|
|
|
|
| DEFAULT_MAPPING_OPERATION: MappingOperation = default_mapping_operation |
|
|
|
|
| class Tile(NamedTuple): |
| """ |
| Represents a single tile. |
| Attributes: |
| in_coords: |
| Tuple of slices specifying where to cut the tile from the INPUT tensor. |
| out_coords: |
| Tuple of slices specifying where this tile's OUTPUT should be placed in the reconstructed OUTPUT tensor. |
| masks_1d: |
| Per-dimension masks in OUTPUT units. |
| These are used to create all-dimensional blending mask. |
| Methods: |
| blend_mask: |
| Create a single N-D mask from the per-dimension masks. |
| """ |
|
|
| in_coords: Tuple[slice, ...] |
| out_coords: Tuple[slice, ...] |
| masks_1d: Tuple[Tuple[torch.Tensor, ...]] |
|
|
| @property |
| def blend_mask(self) -> torch.Tensor: |
| num_dims = len(self.out_coords) |
| per_dimension_masks: List[torch.Tensor] = [] |
|
|
| for dim_idx in range(num_dims): |
| mask_1d = self.masks_1d[dim_idx] |
| view_shape = [1] * num_dims |
| if mask_1d is None: |
| |
| one = torch.ones(1) |
|
|
| view_shape[dim_idx] = 1 |
| per_dimension_masks.append(one.view(*view_shape)) |
| continue |
|
|
| |
| view_shape[dim_idx] = mask_1d.shape[0] |
| per_dimension_masks.append(mask_1d.view(*view_shape)) |
|
|
| |
| combined_mask = per_dimension_masks[0] |
| for mask in per_dimension_masks[1:]: |
| combined_mask = combined_mask * mask |
|
|
| return combined_mask |
|
|
|
|
| def create_tiles_from_intervals_and_mappers( |
| intervals: LatentIntervals, |
| mappers: List[MappingOperation], |
| ) -> List[Tile]: |
| full_dim_input_slices = [] |
| full_dim_output_slices = [] |
| full_dim_masks_1d = [] |
| for axis_index in range(len(intervals.original_shape)): |
| dimension_intervals = intervals.dimension_intervals[axis_index] |
| starts = dimension_intervals.starts |
| ends = dimension_intervals.ends |
| input_slices = [slice(s, e) for s, e in zip(starts, ends, strict=True)] |
| output_slices, masks_1d = mappers[axis_index](dimension_intervals) |
| full_dim_input_slices.append(input_slices) |
| full_dim_output_slices.append(output_slices) |
| full_dim_masks_1d.append(masks_1d) |
|
|
| tiles = [] |
| tile_in_coords = list(itertools.product(*full_dim_input_slices)) |
| tile_out_coords = list(itertools.product(*full_dim_output_slices)) |
| tile_mask_1ds = list(itertools.product(*full_dim_masks_1d)) |
| for in_coord, out_coord, mask_1d in zip(tile_in_coords, tile_out_coords, tile_mask_1ds, strict=True): |
| tiles.append(Tile( |
| in_coords=in_coord, |
| out_coords=out_coord, |
| masks_1d=mask_1d, |
| )) |
| return tiles |
|
|
|
|
| def create_tiles( |
| latent_shape: torch.Size, |
| splitters: List[SplitOperation], |
| mappers: List[MappingOperation], |
| ) -> List[Tile]: |
| if len(splitters) != len(latent_shape): |
| raise ValueError(f"Number of splitters must be equal to number of dimensions in latent shape, " |
| f"got {len(splitters)} and {len(latent_shape)}") |
| if len(mappers) != len(latent_shape): |
| raise ValueError(f"Number of mappers must be equal to number of dimensions in latent shape, " |
| f"got {len(mappers)} and {len(latent_shape)}") |
| intervals = [splitter(length) for splitter, length in zip(splitters, latent_shape, strict=True)] |
| latent_intervals = LatentIntervals(original_shape=latent_shape, dimension_intervals=tuple(intervals)) |
| return create_tiles_from_intervals_and_mappers(latent_intervals, mappers) |
|
|
|
|
| def _make_encoder_block( |
| block_name: str, |
| block_config: dict[str, Any], |
| in_channels: int, |
| convolution_dimensions: int, |
| norm_layer: NormLayerType, |
| norm_num_groups: int, |
| spatial_padding_mode: PaddingModeType, |
| ) -> Tuple[nn.Module, int]: |
| out_channels = in_channels |
|
|
| if block_name == "res_x": |
| block = UNetMidBlock3D( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| num_layers=block_config["num_layers"], |
| resnet_eps=1e-6, |
| resnet_groups=norm_num_groups, |
| norm_layer=norm_layer, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "res_x_y": |
| out_channels = in_channels * block_config.get("multiplier", 2) |
| block = ResnetBlock3D( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| eps=1e-6, |
| groups=norm_num_groups, |
| norm_layer=norm_layer, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_time": |
| block = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=3, |
| stride=(2, 1, 1), |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_space": |
| block = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=3, |
| stride=(1, 2, 2), |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_all": |
| block = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=3, |
| stride=(2, 2, 2), |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_all_x_y": |
| out_channels = in_channels * block_config.get("multiplier", 2) |
| block = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| kernel_size=3, |
| stride=(2, 2, 2), |
| causal=True, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_all_res": |
| out_channels = in_channels * block_config.get("multiplier", 2) |
| block = SpaceToDepthDownsample( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| stride=(2, 2, 2), |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_space_res": |
| out_channels = in_channels * block_config.get("multiplier", 2) |
| block = SpaceToDepthDownsample( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| stride=(1, 2, 2), |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_time_res": |
| out_channels = in_channels * block_config.get("multiplier", 2) |
| block = SpaceToDepthDownsample( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| stride=(2, 1, 1), |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| else: |
| raise ValueError(f"unknown block: {block_name}") |
|
|
| return block, out_channels |
|
|
|
|
| class LTX2VideoEncoder(nn.Module): |
| _DEFAULT_NORM_NUM_GROUPS = 32 |
| """ |
| Variational Autoencoder Encoder. Encodes video frames into a latent representation. |
| The encoder compresses the input video through a series of downsampling operations controlled by |
| patch_size and encoder_blocks. The output is a normalized latent tensor with shape (B, 128, F', H', W'). |
| Compression Behavior: |
| The total compression is determined by: |
| 1. Initial spatial compression via patchify: H -> H/4, W -> W/4 (patch_size=4) |
| 2. Sequential compression through encoder_blocks based on their stride patterns |
| Compression blocks apply 2x compression in specified dimensions: |
| - "compress_time" / "compress_time_res": temporal only |
| - "compress_space" / "compress_space_res": spatial only (H and W) |
| - "compress_all" / "compress_all_res": all dimensions (F, H, W) |
| - "res_x" / "res_x_y": no compression |
| Standard LTX Video configuration: |
| - patch_size=4 |
| - encoder_blocks: 1x compress_space_res, 1x compress_time_res, 2x compress_all_res |
| - Final dimensions: F' = 1 + (F-1)/8, H' = H/32, W' = W/32 |
| - Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16) |
| - Note: Input must have 1 + 8*k frames (e.g., 1, 9, 17, 25, 33...) |
| Args: |
| convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). |
| in_channels: The number of input channels. For RGB images, this is 3. |
| out_channels: The number of output channels (latent channels). For latent channels, this is 128. |
| encoder_blocks: The list of blocks to construct the encoder. Each block is a tuple of (block_name, params) |
| where params is either an int (num_layers) or a dict with configuration. |
| patch_size: The patch size for initial spatial compression. Should be a power of 2. |
| norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. |
| latent_log_var: The log variance mode. Can be either `per_channel`, `uniform`, `constant` or `none`. |
| """ |
|
|
| def __init__( |
| self, |
| convolution_dimensions: int = 3, |
| in_channels: int = 3, |
| out_channels: int = 128, |
| patch_size: int = 4, |
| norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, |
| latent_log_var: LogVarianceType = LogVarianceType.UNIFORM, |
| encoder_spatial_padding_mode: PaddingModeType = PaddingModeType.ZEROS, |
| encoder_version: str = "ltx-2", |
| ): |
| super().__init__() |
| if encoder_version == "ltx-2": |
| encoder_blocks = [ |
| ['res_x', {'num_layers': 4}], |
| ['compress_space_res', {'multiplier': 2}], |
| ['res_x', {'num_layers': 6}], |
| ['compress_time_res', {'multiplier': 2}], |
| ['res_x', {'num_layers': 6}], |
| ['compress_all_res', {'multiplier': 2}], |
| ['res_x', {'num_layers': 2}], |
| ['compress_all_res', {'multiplier': 2}], |
| ['res_x', {'num_layers': 2}] |
| ] |
| else: |
| |
| encoder_blocks = [ |
| ["res_x", {"num_layers": 4}], |
| ["compress_space_res", {"multiplier": 2}], |
| ["res_x", {"num_layers": 6}], |
| ["compress_time_res", {"multiplier": 2}], |
| ["res_x", {"num_layers": 4}], |
| ["compress_all_res", {"multiplier": 2}], |
| ["res_x", {"num_layers": 2}], |
| ["compress_all_res", {"multiplier": 1}], |
| ["res_x", {"num_layers": 2}] |
| ] |
| self.patch_size = patch_size |
| self.norm_layer = norm_layer |
| self.latent_channels = out_channels |
| self.latent_log_var = latent_log_var |
| self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS |
|
|
| |
| self.per_channel_statistics = PerChannelStatistics(latent_channels=out_channels) |
|
|
| in_channels = in_channels * patch_size**2 |
| feature_channels = out_channels |
|
|
| self.conv_in = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=feature_channels, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| causal=True, |
| spatial_padding_mode=encoder_spatial_padding_mode, |
| ) |
|
|
| self.down_blocks = nn.ModuleList([]) |
|
|
| for block_name, block_params in encoder_blocks: |
| |
| block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params |
|
|
| block, feature_channels = _make_encoder_block( |
| block_name=block_name, |
| block_config=block_config, |
| in_channels=feature_channels, |
| convolution_dimensions=convolution_dimensions, |
| norm_layer=norm_layer, |
| norm_num_groups=self._norm_num_groups, |
| spatial_padding_mode=encoder_spatial_padding_mode, |
| ) |
|
|
| self.down_blocks.append(block) |
|
|
| |
| if norm_layer == NormLayerType.GROUP_NORM: |
| self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) |
| elif norm_layer == NormLayerType.PIXEL_NORM: |
| self.conv_norm_out = PixelNorm() |
|
|
| self.conv_act = nn.SiLU() |
|
|
| conv_out_channels = out_channels |
| if latent_log_var == LogVarianceType.PER_CHANNEL: |
| conv_out_channels *= 2 |
| elif latent_log_var in {LogVarianceType.UNIFORM, LogVarianceType.CONSTANT}: |
| conv_out_channels += 1 |
| elif latent_log_var != LogVarianceType.NONE: |
| raise ValueError(f"Invalid latent_log_var: {latent_log_var}") |
|
|
| self.conv_out = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=feature_channels, |
| out_channels=conv_out_channels, |
| kernel_size=3, |
| padding=1, |
| causal=True, |
| spatial_padding_mode=encoder_spatial_padding_mode, |
| ) |
|
|
| def forward(self, sample: torch.Tensor) -> torch.Tensor: |
| r""" |
| Encode video frames into normalized latent representation. |
| Args: |
| sample: Input video (B, C, F, H, W). F must be 1 + 8*k (e.g., 1, 9, 17, 25, 33...). |
| Returns: |
| Normalized latent means (B, 128, F', H', W') where F' = 1+(F-1)/8, H' = H/32, W' = W/32. |
| Example: (B, 3, 33, 512, 512) -> (B, 128, 5, 16, 16). |
| """ |
| |
| frames_count = sample.shape[2] |
| if ((frames_count - 1) % 8) != 0: |
| frames_to_crop = (frames_count - 1) % 8 |
| sample = sample[:, :, :-frames_to_crop, ...] |
|
|
| |
| |
| |
| sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) |
| sample = self.conv_in(sample) |
|
|
| for down_block in self.down_blocks: |
| sample = down_block(sample) |
|
|
| sample = self.conv_norm_out(sample) |
| sample = self.conv_act(sample) |
| sample = self.conv_out(sample) |
|
|
| if self.latent_log_var == LogVarianceType.UNIFORM: |
| |
| |
| |
| |
| |
|
|
| if sample.shape[1] < 2: |
| raise ValueError(f"Invalid channel count for UNIFORM mode: expected at least 2 channels " |
| f"(N means + 1 logvar), got {sample.shape[1]}") |
|
|
| |
| means = sample[:, :-1, ...] |
| logvar = sample[:, -1:, ...] |
|
|
| |
| |
| num_channels = means.shape[1] |
| repeat_shape = [1, num_channels] + [1] * (sample.ndim - 2) |
| repeated_logvar = logvar.repeat(*repeat_shape) |
|
|
| |
| sample = torch.cat([means, repeated_logvar], dim=1) |
| elif self.latent_log_var == LogVarianceType.CONSTANT: |
| sample = sample[:, :-1, ...] |
| approx_ln_0 = -30 |
| sample = torch.cat( |
| [sample, torch.ones_like(sample, device=sample.device) * approx_ln_0], |
| dim=1, |
| ) |
|
|
| |
| means, _ = torch.chunk(sample, 2, dim=1) |
| return self.per_channel_statistics.normalize(means) |
|
|
|
|
| def tiled_encode_video( |
| self, |
| video: torch.Tensor, |
| tile_size: int = 512, |
| tile_overlap: int = 128, |
| ) -> torch.Tensor: |
| """Encode video using spatial tiling for memory efficiency. |
| Splits the video into overlapping spatial tiles, encodes each tile separately, |
| and blends the results using linear feathering in the overlap regions. |
| Args: |
| video: Input tensor of shape [B, C, F, H, W] |
| tile_size: Tile size in pixels (must be divisible by 32) |
| tile_overlap: Overlap between tiles in pixels (must be divisible by 32) |
| Returns: |
| Encoded latent tensor [B, C_latent, F_latent, H_latent, W_latent] |
| """ |
| batch, _channels, frames, height, width = video.shape |
| device = video.device |
| dtype = video.dtype |
|
|
| |
| if tile_size % VAE_SPATIAL_FACTOR != 0: |
| raise ValueError(f"tile_size must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_size}") |
| if tile_overlap % VAE_SPATIAL_FACTOR != 0: |
| raise ValueError(f"tile_overlap must be divisible by {VAE_SPATIAL_FACTOR}, got {tile_overlap}") |
| if tile_overlap >= tile_size: |
| raise ValueError(f"tile_overlap ({tile_overlap}) must be less than tile_size ({tile_size})") |
|
|
| |
| if height <= tile_size and width <= tile_size: |
| return self.forward(video) |
|
|
| |
| |
| output_height = height // VAE_SPATIAL_FACTOR |
| output_width = width // VAE_SPATIAL_FACTOR |
| output_frames = 1 + (frames - 1) // VAE_TEMPORAL_FACTOR |
|
|
| |
| |
| latent_channels = 128 |
|
|
| |
| output = torch.zeros( |
| (batch, latent_channels, output_frames, output_height, output_width), |
| device=device, |
| dtype=dtype, |
| ) |
| weights = torch.zeros( |
| (batch, 1, output_frames, output_height, output_width), |
| device=device, |
| dtype=dtype, |
| ) |
|
|
| |
| |
| step_h = tile_size - tile_overlap |
| step_w = tile_size - tile_overlap |
|
|
| h_positions = list(range(0, max(1, height - tile_overlap), step_h)) |
| w_positions = list(range(0, max(1, width - tile_overlap), step_w)) |
|
|
| |
| if h_positions[-1] + tile_size < height: |
| h_positions.append(height - tile_size) |
| if w_positions[-1] + tile_size < width: |
| w_positions.append(width - tile_size) |
|
|
| |
| h_positions = sorted(set(h_positions)) |
| w_positions = sorted(set(w_positions)) |
|
|
| |
| overlap_out_h = tile_overlap // VAE_SPATIAL_FACTOR |
| overlap_out_w = tile_overlap // VAE_SPATIAL_FACTOR |
|
|
| |
| for h_pos in h_positions: |
| for w_pos in w_positions: |
| |
| h_start = max(0, h_pos) |
| w_start = max(0, w_pos) |
| h_end = min(h_start + tile_size, height) |
| w_end = min(w_start + tile_size, width) |
|
|
| |
| tile_h = ((h_end - h_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR |
| tile_w = ((w_end - w_start) // VAE_SPATIAL_FACTOR) * VAE_SPATIAL_FACTOR |
|
|
| if tile_h < VAE_SPATIAL_FACTOR or tile_w < VAE_SPATIAL_FACTOR: |
| continue |
|
|
| |
| h_end = h_start + tile_h |
| w_end = w_start + tile_w |
|
|
| |
| tile = video[:, :, :, h_start:h_end, w_start:w_end] |
|
|
| |
| encoded_tile = self.forward(tile) |
|
|
| |
| _, _, tile_out_frames, tile_out_height, tile_out_width = encoded_tile.shape |
|
|
| |
| out_h_start = h_start // VAE_SPATIAL_FACTOR |
| out_w_start = w_start // VAE_SPATIAL_FACTOR |
| out_h_end = min(out_h_start + tile_out_height, output_height) |
| out_w_end = min(out_w_start + tile_out_width, output_width) |
|
|
| |
| actual_tile_h = out_h_end - out_h_start |
| actual_tile_w = out_w_end - out_w_start |
| encoded_tile = encoded_tile[:, :, :, :actual_tile_h, :actual_tile_w] |
|
|
| |
| mask = torch.ones( |
| (1, 1, tile_out_frames, actual_tile_h, actual_tile_w), |
| device=device, |
| dtype=dtype, |
| ) |
|
|
| |
| |
| if h_pos > 0 and overlap_out_h > 0 and overlap_out_h < actual_tile_h: |
| fade_in = torch.linspace(0.0, 1.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] |
| mask[:, :, :, :overlap_out_h, :] *= fade_in.view(1, 1, 1, -1, 1) |
|
|
| |
| if h_end < height and overlap_out_h > 0 and overlap_out_h < actual_tile_h: |
| fade_out = torch.linspace(1.0, 0.0, overlap_out_h + 2, device=device, dtype=dtype)[1:-1] |
| mask[:, :, :, -overlap_out_h:, :] *= fade_out.view(1, 1, 1, -1, 1) |
|
|
| |
| if w_pos > 0 and overlap_out_w > 0 and overlap_out_w < actual_tile_w: |
| fade_in = torch.linspace(0.0, 1.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] |
| mask[:, :, :, :, :overlap_out_w] *= fade_in.view(1, 1, 1, 1, -1) |
|
|
| |
| if w_end < width and overlap_out_w > 0 and overlap_out_w < actual_tile_w: |
| fade_out = torch.linspace(1.0, 0.0, overlap_out_w + 2, device=device, dtype=dtype)[1:-1] |
| mask[:, :, :, :, -overlap_out_w:] *= fade_out.view(1, 1, 1, 1, -1) |
|
|
| |
| output[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += encoded_tile * mask |
| weights[:, :, :, out_h_start:out_h_end, out_w_start:out_w_end] += mask |
|
|
| |
| output = output / (weights + 1e-8) |
|
|
| return output |
|
|
| def encode( |
| self, |
| video: torch.Tensor, |
| tiled=False, |
| tile_size_in_pixels: Optional[int] = 512, |
| tile_overlap_in_pixels: Optional[int] = 128, |
| **kwargs, |
| ) -> torch.Tensor: |
| if video.ndim == 4: |
| video = video.unsqueeze(0) |
| |
| if tiled: |
| latents = self.tiled_encode_video( |
| video=video, |
| tile_size=tile_size_in_pixels, |
| tile_overlap=tile_overlap_in_pixels, |
| ) |
| else: |
| |
| latents = self.forward(video) |
| return latents |
|
|
|
|
| def _make_decoder_block( |
| block_name: str, |
| block_config: dict[str, Any], |
| in_channels: int, |
| convolution_dimensions: int, |
| norm_layer: NormLayerType, |
| timestep_conditioning: bool, |
| norm_num_groups: int, |
| spatial_padding_mode: PaddingModeType, |
| ) -> Tuple[nn.Module, int]: |
| out_channels = in_channels |
| if block_name == "res_x": |
| block = UNetMidBlock3D( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| num_layers=block_config["num_layers"], |
| resnet_eps=1e-6, |
| resnet_groups=norm_num_groups, |
| norm_layer=norm_layer, |
| inject_noise=block_config.get("inject_noise", False), |
| timestep_conditioning=timestep_conditioning, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "attn_res_x": |
| block = UNetMidBlock3D( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| num_layers=block_config["num_layers"], |
| resnet_groups=norm_num_groups, |
| norm_layer=norm_layer, |
| inject_noise=block_config.get("inject_noise", False), |
| timestep_conditioning=timestep_conditioning, |
| attention_head_dim=block_config["attention_head_dim"], |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "res_x_y": |
| out_channels = in_channels // block_config.get("multiplier", 2) |
| block = ResnetBlock3D( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=out_channels, |
| eps=1e-6, |
| groups=norm_num_groups, |
| norm_layer=norm_layer, |
| inject_noise=block_config.get("inject_noise", False), |
| timestep_conditioning=False, |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_time": |
| out_channels = in_channels // block_config.get("multiplier", 1) |
| block = DepthToSpaceUpsample( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| stride=(2, 1, 1), |
| out_channels_reduction_factor=block_config.get("multiplier", 1), |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_space": |
| out_channels = in_channels // block_config.get("multiplier", 1) |
| block = DepthToSpaceUpsample( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| stride=(1, 2, 2), |
| out_channels_reduction_factor=block_config.get("multiplier", 1), |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| elif block_name == "compress_all": |
| out_channels = in_channels // block_config.get("multiplier", 1) |
| block = DepthToSpaceUpsample( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| stride=(2, 2, 2), |
| residual=block_config.get("residual", False), |
| out_channels_reduction_factor=block_config.get("multiplier", 1), |
| spatial_padding_mode=spatial_padding_mode, |
| ) |
| else: |
| raise ValueError(f"unknown layer: {block_name}") |
|
|
| return block, out_channels |
|
|
|
|
| class LTX2VideoDecoder(nn.Module): |
| _DEFAULT_NORM_NUM_GROUPS = 32 |
| """ |
| Variational Autoencoder Decoder. Decodes latent representation into video frames. |
| The decoder upsamples latents through a series of upsampling operations (inverse of encoder). |
| Output dimensions: F = 8x(F'-1) + 1, H = 32xH', W = 32xW' for standard LTX Video configuration. |
| Upsampling blocks expand dimensions by 2x in specified dimensions: |
| - "compress_time": temporal only |
| - "compress_space": spatial only (H and W) |
| - "compress_all": all dimensions (F, H, W) |
| - "res_x" / "res_x_y" / "attn_res_x": no upsampling |
| Causal Mode: |
| causal=False (standard): Symmetric padding, allows future frame dependencies. |
| causal=True: Causal padding, each frame depends only on past/current frames. |
| First frame removed after temporal upsampling in both modes. Output shape unchanged. |
| Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512) for both modes. |
| Args: |
| convolution_dimensions: The number of dimensions to use in convolutions (2D or 3D). |
| in_channels: The number of input channels (latent channels). Default is 128. |
| out_channels: The number of output channels. For RGB images, this is 3. |
| decoder_blocks: The list of blocks to construct the decoder. Each block is a tuple of (block_name, params) |
| where params is either an int (num_layers) or a dict with configuration. |
| patch_size: Final spatial expansion factor. For standard LTX Video, use 4 for 4x spatial expansion: |
| H -> Hx4, W -> Wx4. Should be a power of 2. |
| norm_layer: The normalization layer to use. Can be either `group_norm` or `pixel_norm`. |
| causal: Whether to use causal convolutions. For standard LTX Video, use False for symmetric padding. |
| When True, uses causal padding (past/current frames only). |
| timestep_conditioning: Whether to condition the decoder on timestep for denoising. |
| """ |
|
|
| def __init__( |
| self, |
| convolution_dimensions: int = 3, |
| in_channels: int = 128, |
| out_channels: int = 3, |
| decoder_blocks: List[Tuple[str, int | dict]] = [], |
| patch_size: int = 4, |
| norm_layer: NormLayerType = NormLayerType.PIXEL_NORM, |
| causal: bool = False, |
| timestep_conditioning: bool = False, |
| decoder_spatial_padding_mode: PaddingModeType = PaddingModeType.REFLECT, |
| decoder_version: str = "ltx-2", |
| base_channels: int = 128, |
| ): |
| super().__init__() |
|
|
| |
| |
| |
| |
| |
| if decoder_version == "ltx-2": |
| decoder_blocks = [ |
| ['res_x', {'num_layers': 5, 'inject_noise': False}], |
| ['compress_all', {'residual': True, 'multiplier': 2}], |
| ['res_x', {'num_layers': 5, 'inject_noise': False}], |
| ['compress_all', {'residual': True, 'multiplier': 2}], |
| ['res_x', {'num_layers': 5, 'inject_noise': False}], |
| ['compress_all', {'residual': True, 'multiplier': 2}], |
| ['res_x', {'num_layers': 5, 'inject_noise': False}] |
| ] |
| else: |
| |
| decoder_blocks = [ |
| ["res_x", {"num_layers": 4}], |
| ["compress_space", {"multiplier": 2}], |
| ["res_x", {"num_layers": 6}], |
| ["compress_time", {"multiplier": 2}], |
| ["res_x", {"num_layers": 4}], |
| ["compress_all", {"multiplier": 1}], |
| ["res_x", {"num_layers": 2}], |
| ["compress_all", {"multiplier": 2}], |
| ["res_x", {"num_layers": 2}] |
| ] |
| self.video_downscale_factors = SpatioTemporalScaleFactors( |
| time=8, |
| width=32, |
| height=32, |
| ) |
|
|
| self.patch_size = patch_size |
| out_channels = out_channels * patch_size**2 |
| self.causal = causal |
| self.timestep_conditioning = timestep_conditioning |
| self._norm_num_groups = self._DEFAULT_NORM_NUM_GROUPS |
|
|
| |
| self.per_channel_statistics = PerChannelStatistics(latent_channels=in_channels) |
|
|
| |
| self.decode_noise_scale = 0.025 |
| self.decode_timestep = 0.05 |
|
|
| |
| |
| feature_channels = base_channels * 8 |
|
|
| self.conv_in = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=in_channels, |
| out_channels=feature_channels, |
| kernel_size=3, |
| stride=1, |
| padding=1, |
| causal=True, |
| spatial_padding_mode=decoder_spatial_padding_mode, |
| ) |
|
|
| self.up_blocks = nn.ModuleList([]) |
|
|
| for block_name, block_params in list(reversed(decoder_blocks)): |
| |
| block_config = {"num_layers": block_params} if isinstance(block_params, int) else block_params |
|
|
| block, feature_channels = _make_decoder_block( |
| block_name=block_name, |
| block_config=block_config, |
| in_channels=feature_channels, |
| convolution_dimensions=convolution_dimensions, |
| norm_layer=norm_layer, |
| timestep_conditioning=timestep_conditioning, |
| norm_num_groups=self._norm_num_groups, |
| spatial_padding_mode=decoder_spatial_padding_mode, |
| ) |
|
|
| self.up_blocks.append(block) |
|
|
| if norm_layer == NormLayerType.GROUP_NORM: |
| self.conv_norm_out = nn.GroupNorm(num_channels=feature_channels, num_groups=self._norm_num_groups, eps=1e-6) |
| elif norm_layer == NormLayerType.PIXEL_NORM: |
| self.conv_norm_out = PixelNorm() |
|
|
| self.conv_act = nn.SiLU() |
| self.conv_out = make_conv_nd( |
| dims=convolution_dimensions, |
| in_channels=feature_channels, |
| out_channels=out_channels, |
| kernel_size=3, |
| padding=1, |
| causal=True, |
| spatial_padding_mode=decoder_spatial_padding_mode, |
| ) |
|
|
| if timestep_conditioning: |
| self.timestep_scale_multiplier = nn.Parameter(torch.tensor(1000.0)) |
| self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings(embedding_dim=feature_channels * 2, |
| size_emb_dim=0) |
| self.last_scale_shift_table = nn.Parameter(torch.empty(2, feature_channels)) |
|
|
| def forward( |
| self, |
| sample: torch.Tensor, |
| timestep: torch.Tensor | None = None, |
| generator: torch.Generator | None = None, |
| ) -> torch.Tensor: |
| r""" |
| Decode latent representation into video frames. |
| Args: |
| sample: Latent tensor (B, 128, F', H', W'). |
| timestep: Timestep for conditioning (if timestep_conditioning=True). Uses default 0.05 if None. |
| generator: Random generator for deterministic noise injection (if inject_noise=True in blocks). |
| Returns: |
| Decoded video (B, 3, F, H, W) where F = 8x(F'-1) + 1, H = 32xH', W = 32xW'. |
| Example: (B, 128, 5, 16, 16) -> (B, 3, 33, 512, 512). |
| Note: First frame is removed after temporal upsampling regardless of causal mode. |
| When causal=False, allows future frame dependencies in convolutions but maintains same output shape. |
| """ |
| batch_size = sample.shape[0] |
|
|
| |
| if self.timestep_conditioning: |
| noise = (torch.randn( |
| sample.size(), |
| generator=generator, |
| dtype=sample.dtype, |
| device=sample.device, |
| ) * self.decode_noise_scale) |
|
|
| sample = noise + (1.0 - self.decode_noise_scale) * sample |
|
|
| |
| sample = self.per_channel_statistics.un_normalize(sample) |
|
|
| |
| if timestep is None and self.timestep_conditioning: |
| timestep = torch.full((batch_size,), self.decode_timestep, device=sample.device, dtype=sample.dtype) |
|
|
| sample = self.conv_in(sample, causal=self.causal) |
|
|
| scaled_timestep = None |
| if self.timestep_conditioning: |
| if timestep is None: |
| raise ValueError("'timestep' parameter must be provided when 'timestep_conditioning' is True") |
| scaled_timestep = timestep * self.timestep_scale_multiplier.to(sample) |
|
|
| for up_block in self.up_blocks: |
| if isinstance(up_block, UNetMidBlock3D): |
| block_kwargs = { |
| "causal": self.causal, |
| "timestep": scaled_timestep if self.timestep_conditioning else None, |
| "generator": generator, |
| } |
| sample = up_block(sample, **block_kwargs) |
| elif isinstance(up_block, ResnetBlock3D): |
| sample = up_block(sample, causal=self.causal, generator=generator) |
| else: |
| sample = up_block(sample, causal=self.causal) |
|
|
| sample = self.conv_norm_out(sample) |
|
|
| if self.timestep_conditioning: |
| embedded_timestep = self.last_time_embedder( |
| timestep=scaled_timestep.flatten(), |
| hidden_dtype=sample.dtype, |
| ) |
| embedded_timestep = embedded_timestep.view(batch_size, embedded_timestep.shape[-1], 1, 1, 1) |
| ada_values = self.last_scale_shift_table[None, ..., None, None, None].to( |
| device=sample.device, dtype=sample.dtype) + embedded_timestep.reshape( |
| batch_size, |
| 2, |
| -1, |
| embedded_timestep.shape[-3], |
| embedded_timestep.shape[-2], |
| embedded_timestep.shape[-1], |
| ) |
| shift, scale = ada_values.unbind(dim=1) |
| sample = sample * (1 + scale) + shift |
|
|
| sample = self.conv_act(sample) |
| sample = self.conv_out(sample, causal=self.causal) |
|
|
| |
| |
| |
| sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) |
|
|
| return sample |
|
|
| def _prepare_tiles( |
| self, |
| latent: torch.Tensor, |
| tiling_config: TilingConfig | None = None, |
| ) -> List[Tile]: |
| splitters = [DEFAULT_SPLIT_OPERATION] * len(latent.shape) |
| mappers = [DEFAULT_MAPPING_OPERATION] * len(latent.shape) |
| if tiling_config is not None and tiling_config.spatial_config is not None: |
| cfg = tiling_config.spatial_config |
| long_side = max(latent.shape[3], latent.shape[4]) |
|
|
| def enable_on_axis(axis_idx: int, factor: int) -> None: |
| size = cfg.tile_size_in_pixels // factor |
| overlap = cfg.tile_overlap_in_pixels // factor |
| axis_length = latent.shape[axis_idx] |
| lower_threshold = max(2, overlap + 1) |
| tile_size = max(lower_threshold, round(size * axis_length / long_side)) |
| splitters[axis_idx] = split_in_spatial(tile_size, overlap) |
| mappers[axis_idx] = to_mapping_operation(map_spatial_slice, factor) |
|
|
| enable_on_axis(3, self.video_downscale_factors.height) |
| enable_on_axis(4, self.video_downscale_factors.width) |
|
|
| if tiling_config is not None and tiling_config.temporal_config is not None: |
| cfg = tiling_config.temporal_config |
| tile_size = cfg.tile_size_in_frames // self.video_downscale_factors.time |
| overlap = cfg.tile_overlap_in_frames // self.video_downscale_factors.time |
| splitters[2] = split_in_temporal(tile_size, overlap) |
| mappers[2] = to_mapping_operation(map_temporal_slice, self.video_downscale_factors.time) |
|
|
| return create_tiles(latent.shape, splitters, mappers) |
|
|
| def tiled_decode( |
| self, |
| latent: torch.Tensor, |
| tiling_config: TilingConfig | None = None, |
| timestep: torch.Tensor | None = None, |
| generator: torch.Generator | None = None, |
| ) -> Iterator[torch.Tensor]: |
| """ |
| Decode a latent tensor into video frames using tiled processing. |
| Splits the latent tensor into tiles, decodes each tile individually, |
| and yields video chunks as they become available. |
| Args: |
| latent: Input latent tensor (B, C, F', H', W'). |
| tiling_config: Tiling configuration for the latent tensor. |
| timestep: Optional timestep for decoder conditioning. |
| generator: Optional random generator for deterministic decoding. |
| Yields: |
| Video chunks (B, C, T, H, W) by temporal slices; |
| """ |
|
|
| |
| full_video_shape = VideoLatentShape.from_torch_shape(latent.shape).upscale(self.video_downscale_factors) |
| tiles = self._prepare_tiles(latent, tiling_config) |
|
|
| temporal_groups = self._group_tiles_by_temporal_slice(tiles) |
|
|
| |
| previous_chunk = None |
| previous_weights = None |
| previous_temporal_slice = None |
|
|
| for temporal_group_tiles in temporal_groups: |
| curr_temporal_slice = temporal_group_tiles[0].out_coords[2] |
|
|
| |
| |
| |
| |
| |
| temporal_tile_buffer_shape = full_video_shape._replace(frames=curr_temporal_slice.stop - |
| curr_temporal_slice.start,) |
|
|
| buffer = torch.zeros( |
| temporal_tile_buffer_shape.to_torch_shape(), |
| device=latent.device, |
| dtype=latent.dtype, |
| ) |
|
|
| curr_weights = self._accumulate_temporal_group_into_buffer( |
| group_tiles=temporal_group_tiles, |
| buffer=buffer, |
| latent=latent, |
| timestep=timestep, |
| generator=generator, |
| ) |
|
|
| |
| if previous_chunk is not None: |
| |
| if previous_temporal_slice.stop > curr_temporal_slice.start: |
| overlap_len = previous_temporal_slice.stop - curr_temporal_slice.start |
| temporal_overlap_slice = slice(curr_temporal_slice.start - previous_temporal_slice.start, None) |
|
|
| |
| |
| |
| |
| previous_chunk[:, :, temporal_overlap_slice, :, :] += buffer[:, :, slice(0, overlap_len), :, :] |
| previous_weights[:, :, temporal_overlap_slice, :, :] += curr_weights[:, :, |
| slice(0, overlap_len), :, :] |
|
|
| buffer[:, :, slice(0, overlap_len), :, :] = previous_chunk[:, :, temporal_overlap_slice, :, :] |
| curr_weights[:, :, slice(0, overlap_len), :, :] = previous_weights[:, :, |
| temporal_overlap_slice, :, :] |
|
|
| |
| previous_weights = previous_weights.clamp(min=1e-8) |
| yield_len = curr_temporal_slice.start - previous_temporal_slice.start |
| yield (previous_chunk / previous_weights)[:, :, :yield_len, :, :] |
|
|
| |
| previous_chunk = buffer |
| previous_weights = curr_weights |
| previous_temporal_slice = curr_temporal_slice |
|
|
| |
| if previous_chunk is not None: |
| previous_weights = previous_weights.clamp(min=1e-8) |
| yield previous_chunk / previous_weights |
|
|
| def _group_tiles_by_temporal_slice(self, tiles: List[Tile]) -> List[List[Tile]]: |
| """Group tiles by their temporal output slice.""" |
| if not tiles: |
| return [] |
|
|
| groups = [] |
| current_slice = tiles[0].out_coords[2] |
| current_group = [] |
|
|
| for tile in tiles: |
| tile_slice = tile.out_coords[2] |
| if tile_slice == current_slice: |
| current_group.append(tile) |
| else: |
| groups.append(current_group) |
| current_slice = tile_slice |
| current_group = [tile] |
|
|
| |
| if current_group: |
| groups.append(current_group) |
|
|
| return groups |
|
|
| def _accumulate_temporal_group_into_buffer( |
| self, |
| group_tiles: List[Tile], |
| buffer: torch.Tensor, |
| latent: torch.Tensor, |
| timestep: torch.Tensor | None, |
| generator: torch.Generator | None, |
| ) -> torch.Tensor: |
| """ |
| Decode and accumulate all tiles of a temporal group into a local buffer. |
| The buffer is local to the group and always starts at time 0; temporal coordinates |
| are rebased by subtracting temporal_slice.start. |
| """ |
| temporal_slice = group_tiles[0].out_coords[2] |
|
|
| weights = torch.zeros_like(buffer) |
|
|
| for tile in group_tiles: |
| decoded_tile = self.forward(latent[tile.in_coords], timestep, generator) |
| mask = tile.blend_mask.to(device=buffer.device, dtype=buffer.dtype) |
| temporal_offset = tile.out_coords[2].start - temporal_slice.start |
| |
| |
| expected_temporal_len = tile.out_coords[2].stop - tile.out_coords[2].start |
| decoded_temporal_len = decoded_tile.shape[2] |
|
|
| |
| actual_temporal_len = min(expected_temporal_len, decoded_temporal_len, buffer.shape[2] - temporal_offset) |
|
|
| chunk_coords = ( |
| slice(None), |
| slice(None), |
| slice(temporal_offset, temporal_offset + actual_temporal_len), |
| tile.out_coords[3], |
| tile.out_coords[4], |
| ) |
|
|
| |
| decoded_slice = decoded_tile[:, :, :actual_temporal_len, :, :] |
| mask_slice = mask[:, :, :actual_temporal_len, :, :] if mask.shape[2] > 1 else mask |
|
|
| buffer[chunk_coords] += decoded_slice * mask_slice |
| weights[chunk_coords] += mask_slice |
|
|
| return weights |
|
|
| def decode( |
| self, |
| latent: torch.Tensor, |
| tiled=False, |
| tile_size_in_pixels: Optional[int] = 512, |
| tile_overlap_in_pixels: Optional[int] = 128, |
| tile_size_in_frames: Optional[int] = 128, |
| tile_overlap_in_frames: Optional[int] = 24, |
| ) -> torch.Tensor: |
| if tiled: |
| tiling_config = TilingConfig( |
| spatial_config=SpatialTilingConfig( |
| tile_size_in_pixels=tile_size_in_pixels, |
| tile_overlap_in_pixels=tile_overlap_in_pixels, |
| ), |
| temporal_config=TemporalTilingConfig( |
| tile_size_in_frames=tile_size_in_frames, |
| tile_overlap_in_frames=tile_overlap_in_frames, |
| ), |
| ) |
| tiles = self.tiled_decode(latent, tiling_config) |
| return torch.cat(list(tiles), dim=2) |
| else: |
| return self.forward(latent) |
|
|
| def decode_video( |
| latent: torch.Tensor, |
| video_decoder: LTX2VideoDecoder, |
| tiling_config: TilingConfig | None = None, |
| generator: torch.Generator | None = None, |
| ) -> Iterator[torch.Tensor]: |
| """ |
| Decode a video latent tensor with the given decoder. |
| Args: |
| latent: Tensor [c, f, h, w] |
| video_decoder: Decoder module. |
| tiling_config: Optional tiling settings. |
| generator: Optional random generator for deterministic decoding. |
| Yields: |
| Decoded chunk [f, h, w, c], uint8 in [0, 255]. |
| """ |
|
|
| def convert_to_uint8(frames: torch.Tensor) -> torch.Tensor: |
| frames = (((frames + 1.0) / 2.0).clamp(0.0, 1.0) * 255.0).to(torch.uint8) |
| frames = rearrange(frames[0], "c f h w -> f h w c") |
| return frames |
|
|
| if tiling_config is not None: |
| for frames in video_decoder.tiled_decode(latent, tiling_config, generator=generator): |
| return convert_to_uint8(frames) |
| else: |
| decoded_video = video_decoder(latent, generator=generator) |
| return convert_to_uint8(decoded_video) |
|
|
|
|
| def get_video_chunks_number(num_frames: int, tiling_config: TilingConfig | None = None) -> int: |
| """ |
| Get the number of video chunks for a given number of frames and tiling configuration. |
| Args: |
| num_frames: Number of frames in the video. |
| tiling_config: Tiling configuration. |
| Returns: |
| Number of video chunks. |
| """ |
| if not tiling_config or not tiling_config.temporal_config: |
| return 1 |
| cfg = tiling_config.temporal_config |
| frame_stride = cfg.tile_size_in_frames - cfg.tile_overlap_in_frames |
| return (num_frames - 1 + frame_stride - 1) // frame_stride |
|
|
|
|
| def split_in_spatial(size: int, overlap: int) -> SplitOperation: |
|
|
| def split(dimension_size: int) -> DimensionIntervals: |
| if dimension_size <= size: |
| return DEFAULT_SPLIT_OPERATION(dimension_size) |
| amount = (dimension_size + size - 2 * overlap - 1) // (size - overlap) |
| starts = [i * (size - overlap) for i in range(amount)] |
| ends = [start + size for start in starts] |
| ends[-1] = dimension_size |
| left_ramps = [0] + [overlap] * (amount - 1) |
| right_ramps = [overlap] * (amount - 1) + [0] |
| return DimensionIntervals(starts=starts, ends=ends, left_ramps=left_ramps, right_ramps=right_ramps) |
|
|
| return split |
|
|
|
|
| def split_in_temporal(size: int, overlap: int) -> SplitOperation: |
| non_causal_split = split_in_spatial(size, overlap) |
|
|
| def split(dimension_size: int) -> DimensionIntervals: |
| if dimension_size <= size: |
| return DEFAULT_SPLIT_OPERATION(dimension_size) |
| intervals = non_causal_split(dimension_size) |
| starts = intervals.starts |
| starts[1:] = [s - 1 for s in starts[1:]] |
| left_ramps = intervals.left_ramps |
| left_ramps[1:] = [r + 1 for r in left_ramps[1:]] |
| return replace(intervals, starts=starts, left_ramps=left_ramps) |
|
|
| return split |
|
|
|
|
| def to_mapping_operation( |
| map_func: Callable[[int, int, int, int, int], Tuple[slice, torch.Tensor]], |
| scale: int, |
| ) -> MappingOperation: |
|
|
| def map_op(intervals: DimensionIntervals) -> tuple[list[slice], list[torch.Tensor | None]]: |
| output_slices: list[slice] = [] |
| masks_1d: list[torch.Tensor | None] = [] |
| number_of_slices = len(intervals.starts) |
| for i in range(number_of_slices): |
| start = intervals.starts[i] |
| end = intervals.ends[i] |
| left_ramp = intervals.left_ramps[i] |
| right_ramp = intervals.right_ramps[i] |
| output_slice, mask_1d = map_func(start, end, left_ramp, right_ramp, scale) |
| output_slices.append(output_slice) |
| masks_1d.append(mask_1d) |
| return output_slices, masks_1d |
|
|
| return map_op |
|
|
|
|
| def map_temporal_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: |
| start = begin * scale |
| stop = 1 + (end - 1) * scale |
| left_ramp = 1 + (left_ramp - 1) * scale |
| right_ramp = right_ramp * scale |
|
|
| return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, True) |
|
|
|
|
| def map_spatial_slice(begin: int, end: int, left_ramp: int, right_ramp: int, scale: int) -> Tuple[slice, torch.Tensor]: |
| start = begin * scale |
| stop = end * scale |
| left_ramp = left_ramp * scale |
| right_ramp = right_ramp * scale |
|
|
| return slice(start, stop), compute_trapezoidal_mask_1d(stop - start, left_ramp, right_ramp, False) |
|
|