| | |
| | |
| |
|
| | from typing import List, Optional, Tuple, Union |
| | from functools import partial |
| | import math |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | from einops import rearrange |
| |
|
| | from comfy.ldm.modules.attention import optimized_attention |
| |
|
| | import comfy.ops |
| | ops = comfy.ops.disable_weight_init |
| |
|
| | |
| | |
| |
|
| |
|
def cast_tuple(t, length=1):
    """Return `t` unchanged if it is already a tuple, otherwise repeat it `length` times."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
| |
|
| |
|
class GroupNormSpatial(ops.GroupNorm):
    """
    GroupNorm applied per-frame.

    The time axis is folded into the batch so each frame is normalized
    independently; frames are processed in chunks to bound peak memory.
    """

    def forward(self, x: torch.Tensor, *, chunk_size: int = 8):
        batch, channels, frames, height, width = x.shape
        # Fold T into the batch dimension: one GroupNorm call per frame.
        flat = rearrange(x, "B C T H W -> (B T) C H W")
        normed = torch.empty_like(flat)
        total = batch * frames
        for start in range(0, total, chunk_size):
            stop = start + chunk_size
            normed[start:stop] = super().forward(flat[start:stop])
        return rearrange(normed, "(B T) C H W -> B C T H W", B=batch, T=frames)
| |
|
class PConv3d(ops.Conv3d):
    """3D convolution with explicit temporal padding (causal or centered).

    Spatial padding is delegated to the base Conv3d; temporal padding is
    applied manually in `forward` so it can be front-loaded for causal mode.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]],
        causal: bool = True,
        context_parallel: bool = True,
        **kwargs,
    ):
        self.causal = causal
        self.context_parallel = context_parallel
        kernel_size = cast_tuple(kernel_size, 3)
        stride = cast_tuple(stride, 3)
        # "Same" spatial padding; temporal padding stays 0 here (done in forward).
        pad_h = (kernel_size[1] - 1) // 2
        pad_w = (kernel_size[2] - 1) // 2

        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=(1, 1, 1),
            padding=(0, pad_h, pad_w),
            **kwargs,
        )

    def forward(self, x: torch.Tensor):
        # Total temporal context needed by the kernel.
        t_context = self.kernel_size[0] - 1
        if self.causal:
            # All padding in front: output frame t never sees frames > t.
            pad_front, pad_back = t_context, 0
        else:
            pad_front = t_context // 2
            pad_back = t_context - pad_front

        # Only replicate padding is supported by this implementation.
        assert self.padding_mode == "replicate"
        mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
        x = F.pad(x, (0, 0, 0, 0, pad_front, pad_back), mode=mode)
        return super().forward(x)
| |
|
| |
|
class Conv1x1(ops.Linear):
    """*1x1 Conv implemented with a linear layer."""

    def __init__(self, in_features: int, out_features: int, *args, **kwargs):
        super().__init__(in_features, out_features, *args, **kwargs)

    def forward(self, x: torch.Tensor):
        """Apply a pointwise projection over the channel dimension.

        Args:
            x: Input tensor. Shape: [B, C, *] or [B, *, C].

        Returns:
            x: Output tensor. Shape: [B, C', *] or [B, *, C'].
        """
        # Move channels last, apply the linear layer, move channels back.
        channels_last = x.movedim(1, -1)
        projected = super().forward(channels_last)
        return projected.movedim(-1, 1)
| |
|
| |
|
class DepthToSpaceTime(nn.Module):
    """Pixel-shuffle across time and space.

    Folds channel groups of size (texp * sexp * sexp) into the T, H, and W
    axes, expanding them by texp, sexp, and sexp respectively.
    """

    def __init__(
        self,
        temporal_expansion: int,
        spatial_expansion: int,
    ):
        super().__init__()
        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

    def extra_repr(self):
        return f"texp={self.temporal_expansion}, sexp={self.spatial_expansion}"

    def forward(self, x: torch.Tensor):
        """Rearrange channels into space-time.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Rearranged tensor. Shape: [B, C/(st*s*s), T*st, H*s, W*s].
        """
        tfac = self.temporal_expansion
        sfac = self.spatial_expansion
        x = rearrange(
            x,
            "B (C tfac hfac wfac) T H W -> B C (T tfac) (H hfac) (W wfac)",
            tfac=tfac,
            hfac=sfac,
            wfac=sfac,
        )

        if tfac > 1:
            # Drop the first tfac-1 frames — presumably artifacts of causal
            # temporal upsampling; confirm against the decoder's frame count
            # contract (T + 1 = (t - 1) * 4 in Decoder.forward's docstring).
            assert all(x.shape)
            x = x[:, :, tfac - 1 :]
            assert all(x.shape)

        return x
| |
|
| |
|
def norm_fn(
    in_channels: int,
    affine: bool = True,
):
    """Build the per-frame GroupNorm (32 groups) used throughout the VAE."""
    return GroupNormSpatial(num_groups=32, num_channels=in_channels, affine=affine)
| |
|
| |
|
class ResBlock(nn.Module):
    """Residual block that preserves the spatial dimensions.

    norm -> SiLU -> conv -> norm -> SiLU -> conv, with a skip connection,
    optionally followed by a temporal attention block. When
    `prune_bottleneck` is set, the inner width is halved.
    """

    def __init__(
        self,
        channels: int,
        *,
        affine: bool = True,
        attn_block: Optional[nn.Module] = None,
        causal: bool = True,
        prune_bottleneck: bool = False,
        padding_mode: str,
        bias: bool = True,
    ):
        """
        Args:
            channels: Input/output channel count (preserved by the block).
            affine: Whether the GroupNorms learn affine parameters.
            attn_block: Optional module applied after the residual sum.
            causal: Must be True; passed through to the temporal convs.
            prune_bottleneck: Halve the width between the two convs.
            padding_mode: Padding mode for the PConv3d layers.
            bias: Whether the convs use a bias term.
        """
        super().__init__()
        self.channels = channels

        assert causal
        # Inner (bottleneck) width between the two convolutions.
        hidden = channels // 2 if prune_bottleneck else channels
        self.stack = nn.Sequential(
            norm_fn(channels, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=channels,
                out_channels=hidden,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding_mode=padding_mode,
                bias=bias,
                causal=causal,
            ),
            # FIX: this norm previously used `channels` unconditionally, which
            # crashed at forward time when prune_bottleneck=True (GroupNorm's
            # num_channels must match the incoming tensor, which is `hidden`).
            norm_fn(hidden, affine=affine),
            nn.SiLU(inplace=True),
            PConv3d(
                in_channels=hidden,
                out_channels=channels,
                kernel_size=(3, 3, 3),
                stride=(1, 1, 1),
                padding_mode=padding_mode,
                bias=bias,
                causal=causal,
            ),
        )

        self.attn_block = attn_block if attn_block else nn.Identity()

    def forward(self, x: torch.Tensor):
        """Forward pass.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].
        """
        residual = x
        x = self.stack(x)
        x = x + residual
        del residual  # free the reference before attention to reduce peak memory

        return self.attn_block(x)
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention over the temporal axis.

    Attention is computed along T independently for every spatial (h, w)
    position; spatial locations are folded into the batch dimension.
    """

    def __init__(
        self,
        dim: int,
        head_dim: int = 32,
        qkv_bias: bool = False,
        out_bias: bool = True,
        qk_norm: bool = True,
    ) -> None:
        super().__init__()
        self.head_dim = head_dim
        # Head count derived from dim; assumes dim is divisible by head_dim
        # (not validated here).
        self.num_heads = dim // head_dim
        self.qk_norm = qk_norm

        # Fused query/key/value projection and output projection.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)
        self.out = nn.Linear(dim, dim, bias=out_bias)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Compute temporal self-attention.

        Args:
            x: Input tensor. Shape: [B, C, T, H, W].

        Returns:
            x: Output tensor. Shape: [B, C, T, H, W].
        """
        B, _, T, H, W = x.shape

        if T == 1:
            # Single timestep: softmax over one key assigns it weight 1, so
            # attention reduces to the value path — project to v and apply
            # the output projection, skipping q/k entirely.
            x = x.movedim(1, -1)
            qkv = self.qkv(x)
            _, _, x = qkv.chunk(3, dim=-1)
            x = self.out(x)
            return x.movedim(-1, 1)

        # Fold spatial positions into the batch; sequence axis is time.
        x = rearrange(x, "B C t h w -> (B h w) t C")
        qkv = self.qkv(x)

        # view: [(B h w), t, 3, heads, head_dim] -> transpose(1, 3):
        # [(B h w), heads, 3, t, head_dim] -> unbind(2): three tensors of
        # shape [(B h w), heads, t, head_dim].
        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, self.head_dim).transpose(1, 3).unbind(2)

        if self.qk_norm:
            # L2-normalize queries and keys (cosine-similarity attention).
            q = F.normalize(q, p=2, dim=-1)
            k = F.normalize(k, p=2, dim=-1)

        # skip_reshape=True: q/k/v are assumed to already be in
        # [batch, heads, seq, head_dim] layout — confirm against
        # optimized_attention's contract.
        x = optimized_attention(q, k, v, self.num_heads, skip_reshape=True)

        assert x.size(0) == q.size(0)

        x = self.out(x)
        x = rearrange(x, "(B h w) t C -> B C t h w", B=B, h=H, w=W)
        return x
| |
|
| |
|
class AttentionBlock(nn.Module):
    """Pre-norm temporal attention with a residual connection."""

    def __init__(
        self,
        dim: int,
        **attn_kwargs,
    ) -> None:
        super().__init__()
        self.norm = norm_fn(dim)
        self.attn = Attention(dim, **attn_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        attended = self.attn(self.norm(x))
        return residual + attended
| |
|
| |
|
class CausalUpsampleBlock(nn.Module):
    """Residual blocks followed by a channel projection and a
    depth-to-space-time upsample."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks: int,
        *,
        temporal_expansion: int = 2,
        spatial_expansion: int = 2,
        **block_kwargs,
    ):
        super().__init__()

        self.blocks = nn.Sequential(
            *[block_fn(in_channels, **block_kwargs) for _ in range(num_res_blocks)]
        )

        self.temporal_expansion = temporal_expansion
        self.spatial_expansion = spatial_expansion

        # Project up to the channel count DepthToSpaceTime will fold into
        # the T/H/W axes.
        self.proj = Conv1x1(
            in_channels,
            out_channels * temporal_expansion * (spatial_expansion**2),
        )

        self.d2st = DepthToSpaceTime(
            temporal_expansion=temporal_expansion, spatial_expansion=spatial_expansion
        )

    def forward(self, x):
        return self.d2st(self.proj(self.blocks(x)))
| |
|
| |
|
def block_fn(channels, *, affine: bool = True, has_attention: bool = False, **block_kwargs):
    """Construct a ResBlock, optionally with a temporal attention block."""
    attn = AttentionBlock(channels) if has_attention else None
    return ResBlock(channels, affine=affine, attn_block=attn, **block_kwargs)
| |
|
| |
|
class DownsampleBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_res_blocks,
        *,
        temporal_reduction=2,
        spatial_reduction=2,
        **block_kwargs,
    ):
        """
        Downsample block for the VAE encoder.

        Args:
            in_channels: Number of input channels.
            out_channels: Number of output channels.
            num_res_blocks: Number of residual blocks.
            temporal_reduction: Temporal reduction factor.
            spatial_reduction: Spatial reduction factor.
        """
        super().__init__()

        # Channel change is expected to coincide with downsampling.
        assert in_channels != out_channels

        # Strided conv performs the actual space-time reduction (kernel ==
        # stride, so windows do not overlap).
        layers = [
            PConv3d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(temporal_reduction, spatial_reduction, spatial_reduction),
                stride=(temporal_reduction, spatial_reduction, spatial_reduction),
                padding_mode="replicate",
                bias=block_kwargs["bias"],
            )
        ]
        layers.extend(
            block_fn(out_channels, **block_kwargs) for _ in range(num_res_blocks)
        )

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)
| |
|
| |
|
def add_fourier_features(inputs: torch.Tensor, start=6, stop=8, step=1):
    """Concatenate sinusoidal Fourier features onto the channel dimension.

    For each octave k in range(start, stop, step), each input channel x
    contributes sin(2^k * 2*pi * x) and cos(2^k * 2*pi * x).

    Args:
        inputs: Input tensor. Shape: [B, C, T, H, W].
        start: First octave exponent (inclusive).
        stop: Last octave exponent (exclusive).
        step: Octave step.

    Returns:
        Tensor of shape [B, C * (1 + 2 * num_freqs), T, H, W]:
        [inputs, sin features, cos features] along the channel dim.
    """
    num_freqs = (stop - start) // step
    assert inputs.ndim == 5
    C = inputs.size(1)

    # Angular frequency per octave: 2^k * 2*pi.
    freqs = torch.arange(start, stop, step, dtype=inputs.dtype, device=inputs.device)
    assert num_freqs == len(freqs)
    w = torch.pow(2.0, freqs) * (2 * torch.pi)
    # Tile frequencies once per channel and shape for broadcasting over
    # B, T, H, W. (Removed a dead duplicate re-assignment of C here.)
    w = w.repeat(C)[None, :, None, None, None]

    # repeat_interleave repeats each channel num_freqs times, so channel
    # i // num_freqs lines up with frequency i % num_freqs from w's tiling.
    h = inputs.repeat_interleave(num_freqs, dim=1)
    h = w * h

    return torch.cat(
        [
            inputs,
            torch.sin(h),
            torch.cos(h),
        ],
        dim=1,
    )
| |
|
| |
|
class FourierFeatures(nn.Module):
    """Module wrapper around `add_fourier_features` with fixed octave range."""

    def __init__(self, start: int = 6, stop: int = 8, step: int = 1):
        super().__init__()
        self.start = start
        self.stop = stop
        self.step = step

    def forward(self, inputs):
        """Append sin/cos Fourier features along the channel dimension.

        Args:
            inputs: Input tensor. Shape: [B, C, T, H, W]

        Returns:
            h: Output tensor. Shape: [B, (1 + 2 * num_freqs) * C, T, H, W]
        """
        return add_fourier_features(inputs, self.start, self.stop, self.step)
| |
|
| |
|
class Decoder(nn.Module):
    """Causal video VAE decoder: maps latents back to pixel space.

    Structure: a 1x1 conv lift + res blocks at the lowest resolution, a
    stack of causal upsample blocks (highest channel count first), then
    res blocks at full resolution and a 1x1 output projection.
    """

    def __init__(
        self,
        *,
        out_channels: int = 3,
        latent_dim: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        temporal_expansions: Optional[List[int]] = None,
        spatial_expansions: Optional[List[int]] = None,
        has_attention: List[bool],
        output_norm: bool = True,
        nonlinearity: str = "silu",
        output_nonlinearity: str = "silu",
        causal: bool = True,
        **block_kwargs,
    ):
        """
        Args:
            out_channels: Pixel-space output channels (e.g. 3 for RGB).
            latent_dim: Channel count of the input latent.
            base_channels: Base width; level i uses
                base_channels * channel_multipliers[i].
            channel_multipliers: Per-level width multipliers
                (index 0 = highest resolution).
            num_res_blocks: Res-block count per stage; needs
                len(channel_multipliers) + 1 entries (first stage, one per
                upsample block, last stage).
            temporal_expansions: Temporal upsample factor per upsample block.
            spatial_expansions: Spatial upsample factor per upsample block.
            has_attention: Per-stage temporal-attention flags; same length
                as num_res_blocks.
            output_norm: Must be False (not implemented here).
            nonlinearity: Must be "silu".
            output_nonlinearity: Final activation; "silu" or falsy for none.
            causal: Must be True.
            **block_kwargs: Forwarded to the residual blocks.
        """
        super().__init__()
        self.input_channels = latent_dim
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.output_nonlinearity = output_nonlinearity
        assert nonlinearity == "silu"
        assert causal

        # Per-level channel widths; decoder walks ch from last (widest,
        # lowest resolution) to first.
        ch = [mult * base_channels for mult in channel_multipliers]
        self.num_up_blocks = len(ch) - 1
        assert len(num_res_blocks) == self.num_up_blocks + 2

        blocks = []

        # Stage 1: lift latent to the widest channel count, then res blocks.
        first_block = [
            ops.Conv3d(latent_dim, ch[-1], kernel_size=(1, 1, 1))
        ]
        for _ in range(num_res_blocks[-1]):
            first_block.append(
                block_fn(
                    ch[-1],
                    has_attention=has_attention[-1],
                    causal=causal,
                    **block_kwargs,
                )
            )
        blocks.append(nn.Sequential(*first_block))

        assert len(temporal_expansions) == len(spatial_expansions) == self.num_up_blocks
        assert len(num_res_blocks) == len(has_attention) == self.num_up_blocks + 2

        upsample_block_fn = CausalUpsampleBlock

        # Stage 2: upsample blocks, consuming the config lists from the end
        # (negative indexing) since we go from low to high resolution.
        for i in range(self.num_up_blocks):
            block = upsample_block_fn(
                ch[-i - 1],
                ch[-i - 2],
                num_res_blocks=num_res_blocks[-i - 2],
                has_attention=has_attention[-i - 2],
                temporal_expansion=temporal_expansions[-i - 1],
                spatial_expansion=spatial_expansions[-i - 1],
                causal=causal,
                **block_kwargs,
            )
            blocks.append(block)

        assert not output_norm

        # Stage 3: res blocks at full resolution.
        last_block = []
        for _ in range(num_res_blocks[0]):
            last_block.append(
                block_fn(
                    ch[0], has_attention=has_attention[0], causal=causal, **block_kwargs
                )
            )
        blocks.append(nn.Sequential(*last_block))

        self.blocks = nn.ModuleList(blocks)
        self.output_proj = Conv1x1(ch[0], out_channels)

    def forward(self, x):
        """Forward pass.

        Args:
            x: Latent tensor. Shape: [B, input_channels, t, h, w]. Scaled [-1, 1].

        Returns:
            x: Reconstructed video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1].
               T + 1 = (t - 1) * 4.
               H = h * 16, W = w * 16.
        """
        for block in self.blocks:
            x = block(x)

        if self.output_nonlinearity == "silu":
            # In-place only at inference; training needs the activation input
            # for the backward pass.
            x = F.silu(x, inplace=not self.training)
        else:
            assert (
                not self.output_nonlinearity
            )  # StyleGAN3 omits the to-RGB nonlinearity — TODO confirm intent

        return self.output_proj(x).contiguous()
| |
|
class LatentDistribution:
    """Diagonal Gaussian over latents, parameterized by mean and log-variance."""

    def __init__(self, mean: torch.Tensor, logvar: torch.Tensor):
        """Initialize latent distribution.

        Args:
            mean: Mean of the distribution. Shape: [B, C, T, H, W].
            logvar: Logarithm of variance of the distribution. Shape: [B, C, T, H, W].
        """
        assert mean.shape == logvar.shape
        self.mean = mean
        self.logvar = logvar

    def sample(self, temperature=1.0, generator: torch.Generator = None, noise=None):
        """Draw a reparameterized sample: mean + std * noise.

        temperature == 0.0 returns the mean; only 0.0 and 1.0 are supported.
        """
        if temperature == 0.0:
            return self.mean

        if noise is None:
            noise = torch.randn(self.mean.shape, device=self.mean.device, dtype=self.mean.dtype, generator=generator)
        else:
            assert noise.device == self.mean.device
            noise = noise.to(self.mean.dtype)

        if temperature != 1.0:
            raise NotImplementedError(f"Temperature {temperature} is not supported.")

        std = torch.exp(0.5 * self.logvar)
        return self.mean + std * noise

    def mode(self):
        """Return the distribution mode (the mean, for a Gaussian)."""
        return self.mean
| |
|
class Encoder(nn.Module):
    """Causal video VAE encoder: maps pixel-space video to a latent Gaussian.

    Structure: Fourier features + 1x1 lift, res blocks at full resolution,
    a stack of downsample blocks, res blocks at the lowest resolution, then
    norm + SiLU + 1x1 projection to (mean, logvar).
    """

    def __init__(
        self,
        *,
        in_channels: int,
        base_channels: int,
        channel_multipliers: List[int],
        num_res_blocks: List[int],
        latent_dim: int,
        temporal_reductions: List[int],
        spatial_reductions: List[int],
        prune_bottlenecks: List[bool],
        has_attentions: List[bool],
        affine: bool = True,
        bias: bool = True,
        input_is_conv_1x1: bool = False,
        padding_mode: str,
    ):
        """
        Args:
            in_channels: Input channels AFTER Fourier features are appended
                (the raw video has fewer; see FourierFeatures).
            base_channels: Base width; level i uses
                base_channels * channel_multipliers[i].
            channel_multipliers: Per-level width multipliers.
            num_res_blocks: Res-block count per stage; needs
                len(channel_multipliers) + 1 entries.
            latent_dim: Latent channel count (output has 2x for mean/logvar).
            temporal_reductions: Temporal stride per downsample block.
            spatial_reductions: Spatial stride per downsample block.
            prune_bottlenecks: Per-stage bottleneck-pruning flags; same
                length as num_res_blocks.
            has_attentions: Per-stage temporal-attention flags; same length
                as num_res_blocks.
            affine: Whether GroupNorms learn affine parameters.
            bias: Whether convs use bias terms.
            input_is_conv_1x1: Use a Conv1x1 (linear) lift instead of Conv3d.
            padding_mode: Padding mode for the PConv3d layers.
        """
        super().__init__()
        self.temporal_reductions = temporal_reductions
        self.spatial_reductions = spatial_reductions
        self.base_channels = base_channels
        self.channel_multipliers = channel_multipliers
        self.num_res_blocks = num_res_blocks
        self.latent_dim = latent_dim

        self.fourier_features = FourierFeatures()
        ch = [mult * base_channels for mult in channel_multipliers]
        num_down_blocks = len(ch) - 1
        assert len(num_res_blocks) == num_down_blocks + 2

        # Lift the (Fourier-augmented) input to the first channel width.
        layers = (
            [ops.Conv3d(in_channels, ch[0], kernel_size=(1, 1, 1), bias=True)]
            if not input_is_conv_1x1
            else [Conv1x1(in_channels, ch[0])]
        )

        assert len(prune_bottlenecks) == num_down_blocks + 2
        assert len(has_attentions) == num_down_blocks + 2
        # Bind the shared block kwargs once.
        block = partial(block_fn, padding_mode=padding_mode, affine=affine, bias=bias)

        # Stage 1: res blocks at full resolution, consuming flag index 0.
        for _ in range(num_res_blocks[0]):
            layers.append(block(ch[0], has_attention=has_attentions[0], prune_bottleneck=prune_bottlenecks[0]))
        # Shift the flag lists so downsample block i reads original index
        # i + 1, and the final stage reads the last entry.
        prune_bottlenecks = prune_bottlenecks[1:]
        has_attentions = has_attentions[1:]

        assert len(temporal_reductions) == len(spatial_reductions) == len(ch) - 1
        # Stage 2: strided downsample blocks.
        for i in range(num_down_blocks):
            layer = DownsampleBlock(
                ch[i],
                ch[i + 1],
                num_res_blocks=num_res_blocks[i + 1],
                temporal_reduction=temporal_reductions[i],
                spatial_reduction=spatial_reductions[i],
                prune_bottleneck=prune_bottlenecks[i],
                has_attention=has_attentions[i],
                affine=affine,
                bias=bias,
                padding_mode=padding_mode,
            )

            layers.append(layer)

        # Stage 3: res blocks at the lowest resolution.
        for _ in range(num_res_blocks[-1]):
            layers.append(block(ch[-1], has_attention=has_attentions[-1], prune_bottleneck=prune_bottlenecks[-1]))

        self.layers = nn.Sequential(*layers)

        # Output head: norm + SiLU (in forward) + projection to mean/logvar.
        self.output_norm = norm_fn(ch[-1])
        self.output_proj = Conv1x1(ch[-1], 2 * latent_dim, bias=False)

    @property
    def temporal_downsample(self):
        # Total temporal reduction factor across all downsample blocks.
        return math.prod(self.temporal_reductions)

    @property
    def spatial_downsample(self):
        # Total spatial reduction factor across all downsample blocks.
        return math.prod(self.spatial_reductions)

    def forward(self, x) -> LatentDistribution:
        """Forward pass.

        Args:
            x: Input video tensor. Shape: [B, C, T, H, W]. Scaled to [-1, 1]

        Returns:
            means: Latent tensor. Shape: [B, latent_dim, t, h, w]. Scaled [-1, 1].
                   h = H // 8, w = W // 8, t - 1 = (T - 1) // 6
            logvar: Shape: [B, latent_dim, t, h, w].
        """
        assert x.ndim == 5, f"Expected 5D input, got {x.shape}"
        x = self.fourier_features(x)

        x = self.layers(x)

        x = self.output_norm(x)
        x = F.silu(x, inplace=True)
        x = self.output_proj(x)

        # First half of channels are means, second half log-variances.
        means, logvar = torch.chunk(x, 2, dim=1)

        assert means.ndim == 5
        assert logvar.shape == means.shape
        assert means.size(1) == self.latent_dim

        return LatentDistribution(means, logvar)
| |
|
| |
|
class VideoVAE(nn.Module):
    """Causal video VAE pairing a fixed-config Encoder and Decoder."""

    def __init__(self):
        super().__init__()
        # Config shared between encoder and decoder.
        channel_multipliers = [1, 2, 4, 6]
        num_res_blocks = [3, 3, 4, 6, 3]
        latent_dim = 12
        temporal_scales = [1, 2, 3]
        spatial_scales = [2, 2, 2]

        # in_channels = 3 RGB channels x (1 + 2 * 2 Fourier frequencies) = 15.
        self.encoder = Encoder(
            in_channels=15,
            base_channels=64,
            channel_multipliers=channel_multipliers,
            num_res_blocks=num_res_blocks,
            latent_dim=latent_dim,
            temporal_reductions=temporal_scales,
            spatial_reductions=spatial_scales,
            prune_bottlenecks=[False, False, False, False, False],
            has_attentions=[False, True, True, True, True],
            affine=True,
            bias=True,
            input_is_conv_1x1=True,
            padding_mode="replicate"
        )
        self.decoder = Decoder(
            out_channels=3,
            base_channels=128,
            channel_multipliers=channel_multipliers,
            temporal_expansions=temporal_scales,
            spatial_expansions=spatial_scales,
            num_res_blocks=num_res_blocks,
            latent_dim=latent_dim,
            has_attention=[False, False, False, False, False],
            padding_mode="replicate",
            output_norm=False,
            nonlinearity="silu",
            output_nonlinearity="silu",
            causal=True,
        )

    def encode(self, x):
        """Encode video to the latent distribution's mode (the mean)."""
        return self.encoder(x).mode()

    def decode(self, x):
        """Decode latents back to pixel-space video."""
        return self.decoder(x)
| |
|