from contextlib import nullcontext
from typing import Optional, Tuple, Literal, Callable, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
from einops import rearrange

from common.distributed.advanced import get_sequence_parallel_world_size
from common.logger import get_logger
from models.video_vae_v3.modules.causal_inflation_lib import (
    InflatedCausalConv3d,
    causal_norm_wrapper,
    init_causal_conv3d,
    remove_head,
)
from models.video_vae_v3.modules.context_parallel_lib import (
    causal_conv_gather_outputs,
    causal_conv_slice_inputs,
)
from models.video_vae_v3.modules.global_config import set_norm_limit
from models.video_vae_v3.modules.types import (
    CausalAutoencoderOutput,
    CausalDecoderOutput,
    CausalEncoderOutput,
    MemoryState,
    _inflation_mode_t,
    _memory_device_t,
    _receptive_field_t,
    _selective_checkpointing_t,
)

logger = get_logger(__name__)

def gradient_checkpointing(module: Union[Callable, nn.Module], *args, enabled: bool, **kwargs):
    # Run the module under activation checkpointing when enabled (training-time memory saving);
    # otherwise call it directly. Non-reentrant checkpointing is used so keyword args are allowed.
    if enabled:
        return checkpoint(module, *args, use_reentrant=False, **kwargs)
    return module(*args, **kwargs)

class ResnetBlock2D(nn.Module):
    r"""
    A ResNet block.

    Parameters:
        in_channels (`int`): The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `None`):
            The number of output channels for the first conv2d layer.
            If None, same as `in_channels`.
        dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
    """

    def __init__(
        self, *, in_channels: int, out_channels: Optional[int] = None, dropout: float = 0.0
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.nonlinearity = nn.SiLU()

        self.norm1 = torch.nn.GroupNorm(
            num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
        )

        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = torch.nn.GroupNorm(
            num_groups=32, num_channels=out_channels, eps=1e-6, affine=True
        )

        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.use_in_shortcut = self.in_channels != out_channels

        self.conv_shortcut = None
        if self.use_in_shortcut:
            self.conv_shortcut = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0
            )

    def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
        hidden = input_tensor

        hidden = self.norm1(hidden)
        hidden = self.nonlinearity(hidden)
        hidden = self.conv1(hidden)

        hidden = self.norm2(hidden)
        hidden = self.nonlinearity(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor)

        output_tensor = input_tensor + hidden

        return output_tensor

class Upsample3D(nn.Module):
    """A 3D upsampling layer."""

    def __init__(
        self,
        channels: int,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_up: bool = False,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__()
        self.channels = channels
        self.conv = init_causal_conv3d(
            self.channels, self.channels, kernel_size=3, padding=1, inflation_mode=inflation_mode
        )

        self.temporal_up = temporal_up
        self.spatial_up = spatial_up
        self.temporal_ratio = 2 if temporal_up else 1
        self.spatial_ratio = 2 if spatial_up else 1
        self.slicing = slicing

        # 1x1x1 conv that expands channels by the total upscale ratio; the expanded channels are
        # later rearranged into space/time (depth-to-space). Initializing the weight with stacked
        # identity matrices and a zero bias makes the layer start as nearest-neighbor upsampling.
        upscale_ratio = (self.spatial_ratio**2) * self.temporal_ratio
        self.upscale_conv = nn.Conv3d(
            self.channels, self.channels * upscale_ratio, kernel_size=1, padding=0
        )
        identity = (
            torch.eye(self.channels).repeat(upscale_ratio, 1).reshape_as(self.upscale_conv.weight)
        )

        self.upscale_conv.weight.data.copy_(identity)
        nn.init.zeros_(self.upscale_conv.bias)
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        return gradient_checkpointing(
            self.custom_forward,
            hidden_states,
            memory_state,
            enabled=self.training and self.gradient_checkpointing,
        )

    def custom_forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.slicing:
            # Split along the temporal dimension to bound peak memory.
            split_size = hidden_states.size(2) // 2
            hidden_states = list(
                hidden_states.split([split_size, hidden_states.size(2) - split_size], dim=2)
            )
        else:
            hidden_states = [hidden_states]

        for i in range(len(hidden_states)):
            hidden_states[i] = self.upscale_conv(hidden_states[i])
            hidden_states[i] = rearrange(
                hidden_states[i],
                "b (x y z c) f h w -> b c (f z) (h x) (w y)",
                x=self.spatial_ratio,
                y=self.spatial_ratio,
                z=self.temporal_ratio,
            )

        # Unless we are continuing from cached memory, drop the extra leading frame that
        # temporal upsampling duplicates from the causal head frame.
        if self.temporal_up and memory_state != MemoryState.ACTIVE:
            hidden_states[0] = remove_head(hidden_states[0])

        if self.slicing:
            hidden_states = self.conv(hidden_states, memory_state=memory_state)
            return torch.cat(hidden_states, dim=2)
        else:
            return self.conv(hidden_states[0], memory_state=memory_state)

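# NOTE (Upsample3D): upscale_conv expands channels by spatial_ratio**2 * temporal_ratio and the
# einops rearrange folds them back into space/time, e.g. with spatial_ratio = temporal_ratio = 2 a
# (b, 8c, f, h, w) tensor becomes (b, c, 2f, 2h, 2w). Because the 1x1x1 conv starts as stacked
# identities with zero bias, the layer is initialized to plain nearest-neighbor upsampling and can
# then learn a better interpolation during training.
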
class Downsample3D(nn.Module):
    """A 3D downsampling layer."""

    def __init__(
        self,
        channels: int,
        inflation_mode: _inflation_mode_t = "tail",
        temporal_down: bool = False,
        spatial_down: bool = True,
    ):
        super().__init__()
        self.channels = channels
        self.temporal_down = temporal_down
        self.spatial_down = spatial_down

        self.temporal_ratio = 2 if temporal_down else 1
        self.spatial_ratio = 2 if spatial_down else 1

        self.temporal_kernel = 3 if temporal_down else 1
        self.spatial_kernel = 3 if spatial_down else 1

        self.conv = init_causal_conv3d(
            self.channels,
            self.channels,
            kernel_size=(self.temporal_kernel, self.spatial_kernel, self.spatial_kernel),
            stride=(self.temporal_ratio, self.spatial_ratio, self.spatial_ratio),
            padding=((1 if self.temporal_down else 0), 0, 0),
            inflation_mode=inflation_mode,
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        return gradient_checkpointing(
            self.custom_forward,
            hidden_states,
            memory_state,
            enabled=self.training and self.gradient_checkpointing,
        )

    def custom_forward(
        self,
        hidden_states: torch.FloatTensor,
        memory_state: MemoryState,
    ) -> torch.FloatTensor:
        assert hidden_states.shape[1] == self.channels

        if self.spatial_down:
            # Pad one column on the right and one row at the bottom so the stride-2 conv
            # halves even spatial sizes exactly.
            hidden_states = F.pad(hidden_states, (0, 1, 0, 1), mode="constant", value=0)

        hidden_states = self.conv(hidden_states, memory_state=memory_state)
        return hidden_states

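# NOTE (Downsample3D): spatial downsampling uses an asymmetric (right/bottom) zero pad followed by
# a stride-2 conv with no requested spatial padding, so an even H x W input comes out exactly
# H/2 x W/2 (assuming InflatedCausalConv3d applies the requested spatial padding like a plain
# Conv3d). Temporal downsampling instead relies on the causal padding handled inside the conv.
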
class ResnetBlock3D(ResnetBlock2D):
    def __init__(
        self,
        *args,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.conv1 = init_causal_conv3d(
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.conv2 = init_causal_conv3d(
            self.out_channels,
            self.out_channels,
            kernel_size=(1, 3, 3) if time_receptive_field == "half" else (3, 3, 3),
            stride=1,
            padding=(0, 1, 1) if time_receptive_field == "half" else (1, 1, 1),
            inflation_mode=inflation_mode,
        )

        if self.use_in_shortcut:
            self.conv_shortcut = init_causal_conv3d(
                self.in_channels,
                self.out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=(self.conv_shortcut.bias is not None),
                inflation_mode=inflation_mode,
            )
        self.gradient_checkpointing = False

    def forward(self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET):
        return gradient_checkpointing(
            self.custom_forward,
            input_tensor,
            memory_state,
            enabled=self.training and self.gradient_checkpointing,
        )

    def custom_forward(
        self, input_tensor: torch.Tensor, memory_state: MemoryState = MemoryState.UNSET
    ):
        assert memory_state != MemoryState.UNSET
        hidden_states = input_tensor

        hidden_states = causal_norm_wrapper(self.norm1, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.conv1(hidden_states, memory_state=memory_state)

        hidden_states = causal_norm_wrapper(self.norm2, hidden_states)
        hidden_states = self.nonlinearity(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.conv2(hidden_states, memory_state=memory_state)

        if self.conv_shortcut is not None:
            input_tensor = self.conv_shortcut(input_tensor, memory_state=memory_state)

        output_tensor = input_tensor + hidden_states

        return output_tensor

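# NOTE (ResnetBlock3D): with time_receptive_field="half" only conv1 mixes information across time
# (causal 3x3x3 kernel) while conv2 stays per-frame (1x3x3); with "full" both convs are causal
# 3x3x3, extending each block's temporal receptive field.
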
class DownEncoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        add_downsample: bool = True,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_down: bool = True,
        spatial_down: bool = True,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        self.downsamplers = None
        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        channels=out_channels,
                        inflation_mode=inflation_mode,
                        temporal_down=temporal_down,
                        spatial_down=spatial_down,
                    )
                ]
            )

    def forward(
        self, hidden_states: torch.FloatTensor, memory_state: MemoryState
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, memory_state=memory_state)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states, memory_state=memory_state)

        return hidden_states

class UpDecoderBlock3D(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        add_upsample: bool = True,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up: bool = True,
        spatial_up: bool = True,
        slicing: bool = False,
    ):
        super().__init__()
        resnets = []

        for i in range(num_layers):
            input_channels = in_channels if i == 0 else out_channels

            resnets.append(
                ResnetBlock3D(
                    in_channels=input_channels,
                    out_channels=out_channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                )
            )

        self.resnets = nn.ModuleList(resnets)

        self.upsamplers = None
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [
                    Upsample3D(
                        channels=out_channels,
                        inflation_mode=inflation_mode,
                        temporal_up=temporal_up,
                        spatial_up=spatial_up,
                        slicing=slicing,
                    )
                ]
            )

    def forward(
        self, hidden_states: torch.FloatTensor, memory_state: MemoryState
    ) -> torch.FloatTensor:
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, memory_state=memory_state)

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, memory_state=memory_state)

        return hidden_states

class UNetMidBlock3D(nn.Module):
    def __init__(
        self,
        channels: int,
        dropout: float = 0.0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
    ):
        super().__init__()
        self.resnets = nn.ModuleList(
            [
                ResnetBlock3D(
                    in_channels=channels,
                    out_channels=channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                ),
                ResnetBlock3D(
                    in_channels=channels,
                    out_channels=channels,
                    dropout=dropout,
                    inflation_mode=inflation_mode,
                    time_receptive_field=time_receptive_field,
                ),
            ]
        )

    def forward(self, hidden_states: torch.Tensor, memory_state: MemoryState):
        for resnet in self.resnets:
            hidden_states = resnet(hidden_states, memory_state)
        return hidden_states

class Encoder3D(nn.Module):
    r"""
    The `Encoder` layer of a variational autoencoder that encodes
    its input into a latent representation.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        double_z: bool = True,
        temporal_down_num: int = 2,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block

        self.temporal_down_num = temporal_down_num

        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.down_blocks = nn.ModuleList([])

        # Down blocks. Only the last `temporal_down_num` downsampling stages reduce time.
        output_channel = block_out_channels[0]
        for i in range(len(block_out_channels)):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1
            is_temporal_down_block = i >= len(block_out_channels) - self.temporal_down_num - 1

            down_block = DownEncoderBlock3D(
                num_layers=self.layers_per_block,
                in_channels=input_channel,
                out_channels=output_channel,
                add_downsample=not is_final_block,
                temporal_down=is_temporal_down_block,
                spatial_down=True,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.down_blocks.append(down_block)

        # Mid block.
        self.mid_block = UNetMidBlock3D(
            channels=block_out_channels[-1],
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # Output projection.
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[-1], num_groups=32, eps=1e-6
        )
        self.conv_act = nn.SiLU()

        conv_out_channels = 2 * out_channels if double_z else out_channels
        self.conv_out = init_causal_conv3d(
            block_out_channels[-1], conv_out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        assert len(selective_checkpointing) == len(self.down_blocks)
        self.set_gradient_checkpointing(selective_checkpointing)

    def set_gradient_checkpointing(self, checkpointing_types):
        gradient_checkpointing = []
        for down_block, sac_type in zip(self.down_blocks, checkpointing_types):
            if sac_type == "coarse":
                gradient_checkpointing.append(True)
            elif sac_type == "fine":
                for n, m in down_block.named_modules():
                    if hasattr(m, "gradient_checkpointing"):
                        m.gradient_checkpointing = True
                        logger.debug(f"set gradient_checkpointing: {n}")
                gradient_checkpointing.append(False)
            else:
                gradient_checkpointing.append(False)
        self.gradient_checkpointing = gradient_checkpointing
        logger.info(f"[Encoder3D] gradient_checkpointing: {checkpointing_types}")

    def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor:
        r"""The forward method of the `Encoder` class."""
        sample = self.conv_in(sample, memory_state=memory_state)

        # Down blocks.
        for down_block, sac in zip(self.down_blocks, self.gradient_checkpointing):
            sample = gradient_checkpointing(
                down_block,
                sample,
                memory_state=memory_state,
                enabled=self.training and sac,
            )

        # Mid block.
        sample = self.mid_block(sample, memory_state=memory_state)

        # Post-process.
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)

        return sample

class Decoder3D(nn.Module):
    r"""
    The `Decoder` layer of a variational autoencoder that
    decodes its latent representation into an output sample.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 2,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        temporal_up_num: int = 2,
        slicing_up_num: int = 0,
        selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
    ):
        super().__init__()
        self.layers_per_block = layers_per_block
        self.temporal_up_num = temporal_up_num

        self.conv_in = init_causal_conv3d(
            in_channels,
            block_out_channels[-1],
            kernel_size=3,
            stride=1,
            padding=1,
            inflation_mode=inflation_mode,
        )

        self.up_blocks = nn.ModuleList([])

        # Mid block.
        self.mid_block = UNetMidBlock3D(
            channels=block_out_channels[-1],
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # Up blocks. The first `temporal_up_num` stages also upsample time; the last
        # `slicing_up_num` stages upsample in temporal slices to bound peak memory.
        reversed_block_out_channels = list(reversed(block_out_channels))
        output_channel = reversed_block_out_channels[0]
        for i in range(len(reversed_block_out_channels)):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]

            is_final_block = i == len(block_out_channels) - 1
            is_temporal_up_block = i < self.temporal_up_num
            is_slicing_up_block = i >= len(block_out_channels) - slicing_up_num

            up_block = UpDecoderBlock3D(
                num_layers=self.layers_per_block + 1,
                in_channels=prev_output_channel,
                out_channels=output_channel,
                add_upsample=not is_final_block,
                temporal_up=is_temporal_up_block,
                slicing=is_slicing_up_block,
                inflation_mode=inflation_mode,
                time_receptive_field=time_receptive_field,
            )
            self.up_blocks.append(up_block)

        # Output projection.
        self.conv_norm_out = nn.GroupNorm(
            num_channels=block_out_channels[0], num_groups=32, eps=1e-6
        )
        self.conv_act = nn.SiLU()
        self.conv_out = init_causal_conv3d(
            block_out_channels[0], out_channels, 3, padding=1, inflation_mode=inflation_mode
        )

        assert len(selective_checkpointing) == len(self.up_blocks)
        self.set_gradient_checkpointing(selective_checkpointing)

    def set_gradient_checkpointing(self, checkpointing_types):
        gradient_checkpointing = []
        for up_block, sac_type in zip(self.up_blocks, checkpointing_types):
            if sac_type == "coarse":
                gradient_checkpointing.append(True)
            elif sac_type == "fine":
                for n, m in up_block.named_modules():
                    if hasattr(m, "gradient_checkpointing"):
                        m.gradient_checkpointing = True
                        logger.debug(f"set gradient_checkpointing: {n}")
                gradient_checkpointing.append(False)
            else:
                gradient_checkpointing.append(False)
        self.gradient_checkpointing = gradient_checkpointing
        logger.info(f"[Decoder3D] gradient_checkpointing: {checkpointing_types}")

    def forward(self, sample: torch.FloatTensor, memory_state: MemoryState) -> torch.FloatTensor:
        r"""The forward method of the `Decoder` class."""

        sample = self.conv_in(sample, memory_state=memory_state)

        # Mid block.
        sample = self.mid_block(sample, memory_state=memory_state)

        # Up blocks.
        for up_block, sac in zip(self.up_blocks, self.gradient_checkpointing):
            sample = gradient_checkpointing(
                up_block,
                sample,
                memory_state=memory_state,
                enabled=self.training and sac,
            )

        # Post-process.
        sample = causal_norm_wrapper(self.conv_norm_out, sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample, memory_state=memory_state)

        return sample

class VideoAutoencoderKL(nn.Module):
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        latent_channels: int = 4,
        use_quant_conv: bool = True,
        use_post_quant_conv: bool = True,
        enc_selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
        dec_selective_checkpointing: Tuple[_selective_checkpointing_t, ...] = ("none",),
        temporal_scale_num: int = 3,
        slicing_up_num: int = 0,
        inflation_mode: _inflation_mode_t = "tail",
        time_receptive_field: _receptive_field_t = "half",
        slicing_sample_min_size: Optional[int] = None,
        spatial_downsample_factor: int = 16,
        temporal_downsample_factor: int = 8,
        freeze_encoder: bool = False,
    ):
        super().__init__()
        self.spatial_downsample_factor = spatial_downsample_factor
        self.temporal_downsample_factor = temporal_downsample_factor
        self.freeze_encoder = freeze_encoder
        if slicing_sample_min_size is None:
            slicing_sample_min_size = temporal_downsample_factor
        self.slicing_sample_min_size = slicing_sample_min_size
        self.slicing_latent_min_size = slicing_sample_min_size // (2**temporal_scale_num)

        # Encoder.
        self.encoder = Encoder3D(
            in_channels=in_channels,
            out_channels=latent_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            double_z=True,
            temporal_down_num=temporal_scale_num,
            selective_checkpointing=enc_selective_checkpointing,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        # Decoder.
        self.decoder = Decoder3D(
            in_channels=latent_channels,
            out_channels=out_channels,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            temporal_up_num=temporal_scale_num,
            slicing_up_num=slicing_up_num,
            selective_checkpointing=dec_selective_checkpointing,
            inflation_mode=inflation_mode,
            time_receptive_field=time_receptive_field,
        )

        self.quant_conv = (
            init_causal_conv3d(
                in_channels=2 * latent_channels,
                out_channels=2 * latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_quant_conv
            else None
        )
        self.post_quant_conv = (
            init_causal_conv3d(
                in_channels=latent_channels,
                out_channels=latent_channels,
                kernel_size=1,
                inflation_mode=inflation_mode,
            )
            if use_post_quant_conv
            else None
        )

        self.use_slicing = False

    def enable_slicing(self):
        self.use_slicing = True

    def disable_slicing(self):
        self.use_slicing = False

    def encode(self, x: torch.FloatTensor) -> CausalEncoderOutput:
        if x.ndim == 4:
            x = x.unsqueeze(2)
        h = self.slicing_encode(x)
        p = DiagonalGaussianDistribution(h)
        z = p.sample()
        return CausalEncoderOutput(z, p)

    def decode(self, z: torch.FloatTensor) -> CausalDecoderOutput:
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = self.slicing_decode(z)
        return CausalDecoderOutput(x)

    def _encode(self, x: torch.Tensor, memory_state: MemoryState) -> torch.Tensor:
        x = causal_conv_slice_inputs(x, self.slicing_sample_min_size, memory_state=memory_state)
        h = self.encoder(x, memory_state=memory_state)
        h = self.quant_conv(h, memory_state=memory_state) if self.quant_conv is not None else h
        h = causal_conv_gather_outputs(h)
        return h

    def _decode(self, z: torch.Tensor, memory_state: MemoryState) -> torch.Tensor:
        z = causal_conv_slice_inputs(z, self.slicing_latent_min_size, memory_state=memory_state)
        z = (
            self.post_quant_conv(z, memory_state=memory_state)
            if self.post_quant_conv is not None
            else z
        )
        x = self.decoder(z, memory_state=memory_state)
        x = causal_conv_gather_outputs(x)
        return x

    def slicing_encode(self, x: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (x.shape[2] - 1) > self.slicing_sample_min_size * sp_size:
            x_slices = x[:, :, 1:].split(split_size=self.slicing_sample_min_size * sp_size, dim=2)
            encoded_slices = [
                self._encode(
                    torch.cat((x[:, :, :1], x_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for x_idx in range(1, len(x_slices)):
                encoded_slices.append(
                    self._encode(x_slices[x_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(encoded_slices, dim=2)
        else:
            return self._encode(x, memory_state=MemoryState.DISABLED)

    def slicing_decode(self, z: torch.Tensor) -> torch.Tensor:
        sp_size = get_sequence_parallel_world_size()
        if self.use_slicing and (z.shape[2] - 1) > self.slicing_latent_min_size * sp_size:
            z_slices = z[:, :, 1:].split(split_size=self.slicing_latent_min_size * sp_size, dim=2)
            decoded_slices = [
                self._decode(
                    torch.cat((z[:, :, :1], z_slices[0]), dim=2),
                    memory_state=MemoryState.INITIALIZING,
                )
            ]
            for z_idx in range(1, len(z_slices)):
                decoded_slices.append(
                    self._decode(z_slices[z_idx], memory_state=MemoryState.ACTIVE)
                )
            return torch.cat(decoded_slices, dim=2)
        else:
            return self._decode(z, memory_state=MemoryState.DISABLED)

    def forward(self, x: torch.FloatTensor) -> CausalAutoencoderOutput:
        with torch.no_grad() if self.freeze_encoder else nullcontext():
            z, p = self.encode(x)
        x = self.decode(z).sample
        return CausalAutoencoderOutput(x, z, p)

    def preprocess(self, x: torch.Tensor):
        # Accept single images (4D) or videos with 1 + k * temporal_downsample_factor frames.
        assert x.ndim == 4 or x.size(2) % self.temporal_downsample_factor == 1
        return x

    def postprocess(self, x: torch.Tensor):
        # No-op; kept for interface symmetry with preprocess.
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: _memory_device_t,
    ):
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
            self.slicing_sample_min_size = split_size
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        else:
            self.disable_slicing()
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)

    def set_memory_limit(self, conv_max_mem: Optional[float], norm_max_mem: Optional[float]):
        set_norm_limit(norm_max_mem)
        for m in self.modules():
            if isinstance(m, InflatedCausalConv3d):
                m.set_memory_limit(conv_max_mem if conv_max_mem is not None else float("inf"))

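# NOTE (VideoAutoencoderKL): when slicing is enabled, slicing_encode/slicing_decode split the clip
# along time into chunks of slicing_sample_min_size / slicing_latent_min_size frames (scaled by the
# sequence-parallel world size). The first chunk keeps the leading frame and runs with
# MemoryState.INITIALIZING so the causal convs populate their temporal caches; later chunks run
# with MemoryState.ACTIVE and reuse those caches, so the concatenated result is intended to match
# an unsliced pass while peak memory stays bounded.
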
class VideoAutoencoderKLWrapper(VideoAutoencoderKL):
    def __init__(
        self, *args, spatial_downsample_factor: int, temporal_downsample_factor: int, **kwargs
    ):
        # Forward the factors to the base class so they are not reset to its defaults.
        super().__init__(
            *args,
            spatial_downsample_factor=spatial_downsample_factor,
            temporal_downsample_factor=temporal_downsample_factor,
            **kwargs,
        )

    def forward(self, x) -> CausalAutoencoderOutput:
        z, _, p = self.encode(x)
        x, _ = self.decode(z)
        return CausalAutoencoderOutput(x, z, None, p)

    def encode(self, x) -> CausalEncoderOutput:
        if x.ndim == 4:
            x = x.unsqueeze(2)
        p = super().encode(x).latent_dist
        z = p.sample().squeeze(2)
        return CausalEncoderOutput(z, None, p)

    def decode(self, z) -> CausalDecoderOutput:
        if z.ndim == 4:
            z = z.unsqueeze(2)
        x = super().decode(z).sample.squeeze(2)
        return CausalDecoderOutput(x, None)

    def preprocess(self, x):
        # Accept single images (4D) or videos with 1 + 4k frames.
        assert x.ndim == 4 or x.size(2) % 4 == 1
        return x

    def postprocess(self, x):
        # No-op; kept for interface symmetry with preprocess.
        return x

    def set_causal_slicing(
        self,
        *,
        split_size: Optional[int],
        memory_device: Optional[Literal["cpu", "same"]],
    ):
        assert (
            split_size is None or memory_device is not None
        ), "if split_size is set, memory_device must not be None."
        if split_size is not None:
            self.enable_slicing()
        else:
            self.disable_slicing()
        self.slicing_sample_min_size = split_size
        if split_size is not None:
            self.slicing_latent_min_size = split_size // self.temporal_downsample_factor
        for module in self.modules():
            if isinstance(module, InflatedCausalConv3d):
                module.set_memory_device(memory_device)
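

# Minimal smoke-test sketch (illustrative, not part of the library API). It assumes the causal-conv
# and distributed helpers imported above behave as their names suggest (e.g. a sequence-parallel
# world size of 1 outside distributed runs) and that clips carry 1 + k * temporal_factor frames.
# The small channel counts and factors below are placeholders chosen only to keep the sketch cheap.
if __name__ == "__main__":
    vae = VideoAutoencoderKL(
        in_channels=3,
        out_channels=3,
        block_out_channels=(32, 64),
        layers_per_block=1,
        latent_channels=4,
        temporal_scale_num=1,
        enc_selective_checkpointing=("none", "none"),
        dec_selective_checkpointing=("none", "none"),
        spatial_downsample_factor=2,
        temporal_downsample_factor=2,
    ).eval()

    # One clip: batch 1, RGB, 3 frames (1 + 1 * 2), 64 x 64 pixels.
    video = torch.randn(1, 3, 3, 64, 64)
    with torch.no_grad():
        enc = vae.encode(vae.preprocess(video))  # enc[0]: sampled latent, enc[1]: posterior
        dec = vae.decode(enc[0])                 # dec[0]: reconstructed clip
    print("latent:", tuple(enc[0].shape), "reconstruction:", tuple(dec[0].shape))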