Spaces:
Configuration error
Configuration error
| # pylint: disable=R0801 | |
| # src/models/unet_3d_blocks.py | |
| """ | |
| This module defines various 3D UNet blocks used in the video model. | |
| The blocks include: | |
| - UNetMidBlock3DCrossAttn: The middle block of the UNet with cross attention. | |
| - CrossAttnDownBlock3D: The downsampling block with cross attention. | |
| - DownBlock3D: The standard downsampling block without cross attention. | |
| - CrossAttnUpBlock3D: The upsampling block with cross attention. | |
| - UpBlock3D: The standard upsampling block without cross attention. | |
| These blocks are used to construct the 3D UNet architecture for video-related tasks. | |
| """ | |
| import torch | |
| from einops import rearrange | |
| from torch import nn | |
| from .motion_module import get_motion_module | |
| from .resnet import Downsample3D, ResnetBlock3D, Upsample3D | |
| from .transformer_3d import Transformer3DModel | |
def get_down_block(
    down_block_type,
    num_layers,
    in_channels,
    out_channels,
    temb_channels,
    add_downsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    audio_attention_dim=None,
    downsample_padding=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_audio_module=None,
    depth=0,
    stack_enable_blocks_name=None,
    stack_enable_blocks_depth=None,
):
    """
    Build a down block of the 3D UNet from its type name.

    A leading ``"UNetRes"`` prefix on ``down_block_type`` is ignored, so
    ``"UNetResDownBlock3D"`` and ``"DownBlock3D"`` select the same block.

    Parameters:
    - down_block_type (str): ``"DownBlock3D"`` or ``"CrossAttnDownBlock3D"``,
      optionally prefixed with ``"UNetRes"``.
    - num_layers (int): Number of resnet layers in the block.
    - in_channels (int): Channels of the block input.
    - out_channels (int): Channels of the block output.
    - temb_channels (int): Channels of the ``temb`` embedding that the block
      hands to each of its ResnetBlock3D layers.
    - add_downsample (bool): Whether the block appends a Downsample3D layer.
    - resnet_eps (float): Epsilon forwarded to the ResnetBlock3D layers.
    - resnet_act_fn: Activation forwarded to the ResnetBlock3D layers (the
      block classes default it to the string ``"swish"``).
    - attn_num_head_channels (int): Passed to CrossAttnDownBlock3D as
      ``attn_num_head_channels``; ignored by DownBlock3D.
    - The remaining keyword arguments are forwarded unchanged to the selected
      block. The attention-, audio- and stacking-related ones are only used by
      CrossAttnDownBlock3D.

    Returns:
    - nn.Module: A DownBlock3D or CrossAttnDownBlock3D instance.

    Raises:
    - ValueError: If a CrossAttnDownBlock3D is requested without
      ``cross_attention_dim``, or if the block type is unknown.
    """
    prefix = "UNetRes"
    if down_block_type.startswith(prefix):
        down_block_type = down_block_type[len(prefix):]

    # Arguments accepted by both block classes.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "temb_channels": temb_channels,
        "add_downsample": add_downsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "downsample_padding": downsample_padding,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "use_inflated_groupnorm": use_inflated_groupnorm,
        "use_motion_module": use_motion_module,
        "motion_module_type": motion_module_type,
        "motion_module_kwargs": motion_module_kwargs,
    }

    if down_block_type == "DownBlock3D":
        return DownBlock3D(**shared_kwargs)

    if down_block_type != "CrossAttnDownBlock3D":
        raise ValueError(f"{down_block_type} does not exist.")

    if cross_attention_dim is None:
        raise ValueError(
            "cross_attention_dim must be specified for CrossAttnDownBlock3D"
        )

    return CrossAttnDownBlock3D(
        **shared_kwargs,
        cross_attention_dim=cross_attention_dim,
        audio_attention_dim=audio_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_audio_module=use_audio_module,
        depth=depth,
        stack_enable_blocks_name=stack_enable_blocks_name,
        stack_enable_blocks_depth=stack_enable_blocks_depth,
    )
def get_up_block(
    up_block_type,
    num_layers,
    in_channels,
    out_channels,
    prev_output_channel,
    temb_channels,
    add_upsample,
    resnet_eps,
    resnet_act_fn,
    attn_num_head_channels,
    resnet_groups=None,
    cross_attention_dim=None,
    audio_attention_dim=None,
    dual_cross_attention=False,
    use_linear_projection=False,
    only_cross_attention=False,
    upcast_attention=False,
    resnet_time_scale_shift="default",
    unet_use_cross_frame_attention=None,
    unet_use_temporal_attention=None,
    use_inflated_groupnorm=None,
    use_motion_module=None,
    motion_module_type=None,
    motion_module_kwargs=None,
    use_audio_module=None,
    depth=0,
    stack_enable_blocks_name=None,
    stack_enable_blocks_depth=None,
):
    """
    Build an up block of the 3D UNet from its type name.

    A leading ``"UNetRes"`` prefix on ``up_block_type`` is ignored, so
    ``"UNetResUpBlock3D"`` and ``"UpBlock3D"`` select the same block.

    Parameters:
    - up_block_type (str): ``"UpBlock3D"`` or ``"CrossAttnUpBlock3D"``,
      optionally prefixed with ``"UNetRes"``.
    - num_layers (int): Number of resnet layers in the block.
    - in_channels (int): Input channels, forwarded to the block.
    - out_channels (int): Output channels, forwarded to the block.
    - prev_output_channel (int): Channels of the previous block's output,
      forwarded to the block.
    - temb_channels (int): Channels of the ``temb`` embedding, forwarded to the
      block.
    - add_upsample (bool): Whether the block appends an upsampling layer.
    - resnet_eps (float): Epsilon forwarded to the block's resnet layers.
    - resnet_act_fn: Activation forwarded to the block's resnet layers.
    - attn_num_head_channels (int): Passed to CrossAttnUpBlock3D as
      ``attn_num_head_channels``; ignored by UpBlock3D.
    - The remaining keyword arguments are forwarded unchanged to the selected
      block. The attention-, audio- and stacking-related ones are only used by
      CrossAttnUpBlock3D.

    Returns:
    - nn.Module: An UpBlock3D or CrossAttnUpBlock3D instance.

    Raises:
    - ValueError: If a CrossAttnUpBlock3D is requested without
      ``cross_attention_dim``, or if the block type is unknown.
    """
    prefix = "UNetRes"
    if up_block_type.startswith(prefix):
        up_block_type = up_block_type[len(prefix):]

    # Arguments accepted by both block classes.
    shared_kwargs = {
        "num_layers": num_layers,
        "in_channels": in_channels,
        "out_channels": out_channels,
        "prev_output_channel": prev_output_channel,
        "temb_channels": temb_channels,
        "add_upsample": add_upsample,
        "resnet_eps": resnet_eps,
        "resnet_act_fn": resnet_act_fn,
        "resnet_groups": resnet_groups,
        "resnet_time_scale_shift": resnet_time_scale_shift,
        "use_inflated_groupnorm": use_inflated_groupnorm,
        "use_motion_module": use_motion_module,
        "motion_module_type": motion_module_type,
        "motion_module_kwargs": motion_module_kwargs,
    }

    if up_block_type == "UpBlock3D":
        return UpBlock3D(**shared_kwargs)

    if up_block_type != "CrossAttnUpBlock3D":
        raise ValueError(f"{up_block_type} does not exist.")

    if cross_attention_dim is None:
        raise ValueError(
            "cross_attention_dim must be specified for CrossAttnUpBlock3D"
        )

    return CrossAttnUpBlock3D(
        **shared_kwargs,
        cross_attention_dim=cross_attention_dim,
        audio_attention_dim=audio_attention_dim,
        attn_num_head_channels=attn_num_head_channels,
        dual_cross_attention=dual_cross_attention,
        use_linear_projection=use_linear_projection,
        only_cross_attention=only_cross_attention,
        upcast_attention=upcast_attention,
        unet_use_cross_frame_attention=unet_use_cross_frame_attention,
        unet_use_temporal_attention=unet_use_temporal_attention,
        use_audio_module=use_audio_module,
        depth=depth,
        stack_enable_blocks_name=stack_enable_blocks_name,
        stack_enable_blocks_depth=stack_enable_blocks_depth,
    )
class UNetMidBlock3DCrossAttn(nn.Module):
    """
    Middle block of the 3D UNet.

    Layout: one leading ResnetBlock3D, then ``num_layers`` stages. Each stage
    runs, in order: cross attention over ``encoder_hidden_states``, an optional
    audio Transformer3DModel over ``audio_embedding``, an optional motion
    module, and a trailing ResnetBlock3D. All layers keep ``in_channels``
    channels.

    Parameters:
    - in_channels (int): Channels of the input and of every internal layer.
    - temb_channels (int): Channels of the ``temb`` embedding given to each
      ResnetBlock3D.
    - dropout, resnet_eps, resnet_time_scale_shift, resnet_act_fn,
      resnet_pre_norm, output_scale_factor, use_inflated_groupnorm: Forwarded
      to every ResnetBlock3D.
    - resnet_groups (int): Group count for the resnets and the attention
      norms. If None, ``min(in_channels // 4, 32)`` is used.
    - attn_num_head_channels (int): First argument of each Transformer3DModel;
      ``in_channels`` is divided by it to form the second argument.
    - cross_attention_dim (int): Context dimension of the cross attention.
    - audio_attention_dim (int): Context dimension of the audio modules.
    - dual_cross_attention (bool): Unsupported. If True, NotImplementedError is
      raised as soon as a stage is built.
    - use_linear_projection, upcast_attention: Forwarded to both
      Transformer3DModel kinds.
    - unet_use_cross_frame_attention, unet_use_temporal_attention: Forwarded to
      the cross-attention layers only.
    - use_motion_module, motion_module_type, motion_module_kwargs: Whether and
      how to build one motion module per stage via ``get_motion_module``.
    - use_audio_module, depth, stack_enable_blocks_name,
      stack_enable_blocks_depth: Whether to build one audio Transformer3DModel
      per stage, and options forwarded to it (with ``unet_block_name="mid"``).
    """

    def __init__(
        self,
        in_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        output_scale_factor=1.0,
        cross_attention_dim=1280,
        audio_attention_dim=1024,
        dual_cross_attention=False,
        use_linear_projection=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_audio_module=None,
        depth=0,
        stack_enable_blocks_name=None,
        stack_enable_blocks_depth=None,
    ):
        super().__init__()
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        if resnet_groups is None:
            resnet_groups = min(in_channels // 4, 32)

        def make_resnet():
            # Every resnet in this block maps in_channels -> in_channels.
            return ResnetBlock3D(
                in_channels=in_channels,
                out_channels=in_channels,
                temb_channels=temb_channels,
                eps=resnet_eps,
                groups=resnet_groups,
                dropout=dropout,
                time_embedding_norm=resnet_time_scale_shift,
                non_linearity=resnet_act_fn,
                output_scale_factor=output_scale_factor,
                pre_norm=resnet_pre_norm,
                use_inflated_groupnorm=use_inflated_groupnorm,
            )

        # Sub-modules are constructed in a fixed order: the leading resnet,
        # then attention / audio / motion / resnet per stage. Keep this order
        # so parameter initialisation consumes the RNG identically.
        resnet_list = [make_resnet()]
        attention_list = []
        audio_list = []
        motion_list = []

        for _ in range(num_layers):
            if dual_cross_attention:
                raise NotImplementedError

            attention_list.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=in_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )

            if use_audio_module:
                audio_list.append(
                    Transformer3DModel(
                        attn_num_head_channels,
                        in_channels // attn_num_head_channels,
                        in_channels=in_channels,
                        num_layers=1,
                        cross_attention_dim=audio_attention_dim,
                        norm_num_groups=resnet_groups,
                        use_linear_projection=use_linear_projection,
                        upcast_attention=upcast_attention,
                        use_audio_module=use_audio_module,
                        depth=depth,
                        unet_block_name="mid",
                        stack_enable_blocks_name=stack_enable_blocks_name,
                        stack_enable_blocks_depth=stack_enable_blocks_depth,
                    )
                )
            else:
                audio_list.append(None)

            if use_motion_module:
                motion_list.append(
                    get_motion_module(
                        in_channels=in_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                motion_list.append(None)

            resnet_list.append(make_resnet())

        # Registration order determines state_dict / parameters() order.
        self.attentions = nn.ModuleList(attention_list)
        self.resnets = nn.ModuleList(resnet_list)
        self.audio_modules = nn.ModuleList(audio_list)
        self.motion_modules = nn.ModuleList(motion_list)

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        audio_embedding=None,
        motion_scale=None,
    ):
        """
        Run the middle block.

        Args:
            hidden_states (Tensor): Input feature map. It is indexed below as a
                5-D tensor whose axis 2 is the frame axis.
            temb (Tensor, optional): Embedding passed to every ResnetBlock3D.
            encoder_hidden_states (Tensor, optional): Context for the cross
                attention and the motion modules.
            attention_mask, full_mask, face_mask, lip_mask (Tensor, optional):
                Forwarded to the audio modules.
            audio_embedding (Tensor, optional): Context for the audio modules.
            motion_scale (optional): Forwarded to the audio modules.

        Returns:
            Tensor: The transformed hidden states.
        """
        hidden_states = self.resnets[0](hidden_states, temb)

        stages = zip(
            self.attentions, self.resnets[1:], self.audio_modules, self.motion_modules
        )
        for cross_attn, trailing_resnet, audio_block, motion_block in stages:
            hidden_states, attn_extras = cross_attn(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                return_dict=False,
            )

            cached_frames = attn_extras[0]
            if len(cached_frames) == 0:
                # No cached frames: substitute 4 all-zero frames (hard-coded
                # count), allocated with torch defaults and moved to the right
                # device/dtype only if a motion module consumes them.
                dims = hidden_states.shape
                motion_frames = torch.zeros(dims[0], dims[1], 4, dims[3], dims[4])
            else:
                # NOTE(review): d1 is taken from the LAST axis of hidden_states
                # but lands in the second-to-last output position; this only
                # round-trips for square maps or if the attention flattened
                # tokens width-major — confirm against Transformer3DModel.
                motion_frames = rearrange(
                    cached_frames[0],
                    "b f (d1 d2) c -> b c f d1 d2",
                    d1=hidden_states.size(-1),
                )
            prefix_len = motion_frames.size(2)

            if audio_block is not None:
                audio_out = audio_block(
                    hidden_states,
                    encoder_hidden_states=audio_embedding,
                    attention_mask=attention_mask,
                    full_mask=full_mask,
                    face_mask=face_mask,
                    lip_mask=lip_mask,
                    motion_scale=motion_scale,
                    return_dict=False,
                )
                hidden_states = audio_out[0]

            if motion_block is not None:
                motion_frames = motion_frames.to(
                    device=hidden_states.device, dtype=hidden_states.dtype
                )
                # Prepend the motion frames along the frame axis, run the
                # motion module, then drop the prepended frames again.
                if prefix_len > 0:
                    motion_input = torch.cat([motion_frames, hidden_states], dim=2)
                else:
                    motion_input = hidden_states
                motion_out = motion_block(
                    motion_input, encoder_hidden_states=encoder_hidden_states
                )
                hidden_states = motion_out[:, :, prefix_len:]

            hidden_states = trailing_resnet(hidden_states, temb)

        return hidden_states
class CrossAttnDownBlock3D(nn.Module):
    """
    A 3D downsampling block with cross attention for the U-Net architecture.

    Each of the ``num_layers`` layers runs, in order: a ResnetBlock3D, a
    Transformer3DModel cross attention over ``encoder_hidden_states``, an
    optional audio Transformer3DModel over ``audio_embedding``, and an optional
    motion module. An optional Downsample3D follows the last layer.

    Parameters:
    - in_channels (int): Channels of the block input; only the first layer
      consumes it, later layers map out_channels -> out_channels.
    - out_channels (int): Channels produced by every layer and the downsampler.
    - temb_channels (int): Channels of the ``temb`` embedding given to each
      ResnetBlock3D.
    - dropout, resnet_eps, resnet_time_scale_shift, resnet_act_fn,
      resnet_pre_norm, output_scale_factor, use_inflated_groupnorm: Forwarded
      to every ResnetBlock3D.
    - resnet_groups (int): Group count for the resnets and the attention norms.
    - attn_num_head_channels (int): First argument of each Transformer3DModel;
      channels are divided by it to form the second argument.
    - cross_attention_dim (int): Context dimension of the cross attention.
    - audio_attention_dim (int): Context dimension of the audio modules.
    - downsample_padding (int): Padding of the Downsample3D layer.
    - add_downsample (bool): Whether to append a Downsample3D layer.
    - dual_cross_attention (bool): Unsupported. If True, NotImplementedError is
      raised as soon as a layer is built.
    - use_linear_projection, only_cross_attention, upcast_attention: Forwarded
      to both Transformer3DModel kinds.
    - unet_use_cross_frame_attention, unet_use_temporal_attention: Forwarded to
      the cross-attention layers only.
    - use_motion_module, motion_module_type, motion_module_kwargs: Whether and
      how to build one motion module per layer via ``get_motion_module``.
    - use_audio_module, depth, stack_enable_blocks_name,
      stack_enable_blocks_depth: Whether to build one audio Transformer3DModel
      per layer, and options forwarded to it (with ``unet_block_name="down"``).

    Forward method:
    Returns ``(hidden_states, output_states)``. ``output_states`` holds each
    layer's output plus, when downsampling, the downsampled result. When
    ``gradient_checkpointing`` is enabled during training, every sub-module call
    is wrapped in ``torch.utils.checkpoint.checkpoint``. The computation itself
    is the same with or without checkpointing.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        audio_attention_dim=1024,
        output_scale_factor=1.0,
        downsample_padding=1,
        add_downsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_audio_module=None,
        depth=0,
        stack_enable_blocks_name=None,
        stack_enable_blocks_depth=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_modules = []
        motion_modules = []

        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels

        # Sub-modules are built per layer in the order resnet / attention /
        # audio / motion; keep it so parameter initialisation is unchanged.
        for i in range(num_layers):
            # Only the first layer sees the block input width.
            in_channels = in_channels if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            # TODO: check dimensions. The second argument below is derived from
            # in_channels (the block input width for layer 0), whereas the
            # module's in_channels is out_channels; the cross attention above
            # uses out_channels for both. Left unchanged because it presumably
            # determines parameter shapes of existing checkpoints — confirm
            # before touching.
            audio_modules.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=audio_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_audio_module=use_audio_module,
                    depth=depth,
                    unet_block_name="down",
                    stack_enable_blocks_name=stack_enable_blocks_name,
                    stack_enable_blocks_depth=stack_enable_blocks_depth,
                )
                if use_audio_module
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )

        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.audio_modules = nn.ModuleList(audio_modules)
        self.motion_modules = nn.ModuleList(motion_modules)

        if add_downsample:
            self.downsamplers = nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
        else:
            self.downsamplers = None

        self.gradient_checkpointing = False

    @staticmethod
    def _create_custom_forward(module, return_dict=None):
        """Wrap ``module`` in a plain function usable by torch checkpointing.

        When ``return_dict`` is not None it is passed as a keyword argument.
        """

        def custom_forward(*inputs):
            if return_dict is not None:
                return module(*inputs, return_dict=return_dict)
            return module(*inputs)

        return custom_forward

    def _call_module(self, module, *inputs, use_checkpoint=False, return_dict=None):
        """Call ``module(*inputs)``, through gradient checkpointing if requested.

        Both paths pass exactly the same positional inputs, so enabling gradient
        checkpointing never changes what is computed.
        """
        custom_forward = self._create_custom_forward(module, return_dict=return_dict)
        if use_checkpoint:
            return torch.utils.checkpoint.checkpoint(custom_forward, *inputs)
        return custom_forward(*inputs)

    @staticmethod
    def _get_motion_frames(motion_frame, hidden_states):
        """Turn the attention's cached motion frames into a frame-axis prefix.

        Returns a tensor laid out ``b c f d1 d2``. If the attention returned no
        cached frames, 4 all-zero frames (hard-coded count) are returned,
        allocated with torch's default device/dtype.
        """
        if len(motion_frame[0]) > 0:
            # NOTE(review): d1 is taken from the LAST axis of hidden_states but
            # lands in the second-to-last output position; this only
            # round-trips for square maps or if the attention flattened tokens
            # width-major — confirm against Transformer3DModel.
            return rearrange(
                motion_frame[0][0],
                "b f (d1 d2) c -> b c f d1 d2",
                d1=hidden_states.size(-1),
            )
        return torch.zeros(
            hidden_states.shape[0],
            hidden_states.shape[1],
            4,
            hidden_states.shape[3],
            hidden_states.shape[4],
        )

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        audio_embedding=None,
        motion_scale=None,
    ):
        """
        Defines the forward pass for the CrossAttnDownBlock3D class.

        Parameters:
        - hidden_states (torch.Tensor): Input feature map. It is indexed below
          as a 5-D tensor whose axis 2 is the frame axis.
        - temb (torch.Tensor, optional): Embedding passed to each ResnetBlock3D.
        - encoder_hidden_states (torch.Tensor, optional): Context for the cross
          attention and the motion modules.
        - attention_mask, full_mask, face_mask, lip_mask (torch.Tensor,
          optional): Forwarded to the audio modules.
        - audio_embedding (torch.Tensor, optional): Context for the audio
          modules.
        - motion_scale (optional): Forwarded to the audio modules.

        Returns:
        - tuple: ``(hidden_states, output_states)``, where ``output_states`` is
          a tuple with each layer's output, followed by the downsampled output
          when a downsampler is present.
        """
        output_states = ()
        use_checkpoint = self.training and self.gradient_checkpointing

        for resnet, attn, audio_module, motion_module in zip(
            self.resnets, self.attentions, self.audio_modules, self.motion_modules
        ):
            hidden_states = self._call_module(
                resnet, hidden_states, temb, use_checkpoint=use_checkpoint
            )
            hidden_states, motion_frame = self._call_module(
                attn,
                hidden_states,
                encoder_hidden_states,
                use_checkpoint=use_checkpoint,
                return_dict=False,
            )
            motion_frames = self._get_motion_frames(motion_frame, hidden_states)
            n_motion_frames = motion_frames.size(2)

            if audio_module is not None:
                hidden_states = self._call_module(
                    audio_module,
                    hidden_states,
                    audio_embedding,
                    attention_mask,
                    full_mask,
                    face_mask,
                    lip_mask,
                    motion_scale,
                    use_checkpoint=use_checkpoint,
                    return_dict=False,
                )[0]

            if motion_module is not None:
                motion_frames = motion_frames.to(
                    device=hidden_states.device, dtype=hidden_states.dtype
                )
                # Prepend the motion frames on the frame axis, run the motion
                # module, then drop the prepended frames again.
                motion_input = (
                    torch.cat([motion_frames, hidden_states], dim=2)
                    if n_motion_frames > 0
                    else hidden_states
                )
                hidden_states = self._call_module(
                    motion_module,
                    motion_input,
                    encoder_hidden_states,
                    use_checkpoint=use_checkpoint,
                )
                hidden_states = hidden_states[:, :, n_motion_frames:]

            output_states += (hidden_states,)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states
class DownBlock3D(nn.Module):
    """
    Plain 3D down block (no attention) for the U-Net architecture.

    Layout: ``num_layers`` ResnetBlock3D layers, each optionally followed by a
    motion module, then an optional Downsample3D.

    Parameters:
    - in_channels (int): Channels of the block input; only the first resnet
      consumes it, later resnets map out_channels -> out_channels.
    - out_channels (int): Channels produced by every layer and the downsampler.
    - temb_channels (int): Channels of the ``temb`` embedding given to each
      ResnetBlock3D.
    - dropout (float): Forwarded to each ResnetBlock3D.
    - num_layers (int): Number of resnet (and motion-module) layers.
    - resnet_eps (float): Forwarded to each ResnetBlock3D as ``eps``.
    - resnet_time_scale_shift (str): Forwarded as ``time_embedding_norm``.
    - resnet_act_fn (str): Forwarded as ``non_linearity``.
    - resnet_groups (int): Forwarded as ``groups``.
    - resnet_pre_norm (bool): Forwarded as ``pre_norm``.
    - output_scale_factor (float): Forwarded to each ResnetBlock3D.
    - add_downsample (bool): Whether to append a Downsample3D layer.
    - downsample_padding (int): Padding of the Downsample3D layer.
    - use_inflated_groupnorm (bool): Forwarded to each ResnetBlock3D.
    - use_motion_module (bool): Whether to build a motion module per layer.
    - motion_module_type (str), motion_module_kwargs (dict): Forwarded to
      ``get_motion_module``.

    Forward method:
    Returns ``(hidden_states, output_states)``. ``output_states`` holds each
    layer's output plus, when downsampling, the downsampled result. During
    training with ``gradient_checkpointing`` enabled, only the resnet calls are
    checkpointed.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_downsample=True,
        downsample_padding=1,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnet_layers = []
        temporal_layers = []

        # Build resnet/motion pairs layer by layer (construction order matters
        # for parameter initialisation).
        for layer_idx in range(num_layers):
            layer_in = out_channels if layer_idx > 0 else in_channels
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=layer_in,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if use_motion_module:
                temporal_layers.append(
                    get_motion_module(
                        in_channels=out_channels,
                        motion_module_type=motion_module_type,
                        motion_module_kwargs=motion_module_kwargs,
                    )
                )
            else:
                temporal_layers.append(None)

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(temporal_layers)

        self.downsamplers = (
            nn.ModuleList(
                [
                    Downsample3D(
                        out_channels,
                        use_conv=True,
                        out_channels=out_channels,
                        padding=downsample_padding,
                        name="op",
                    )
                ]
            )
            if add_downsample
            else None
        )

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        temb=None,
        encoder_hidden_states=None,
    ):
        """
        Run the down block.

        Args:
            hidden_states (Tensor): Input feature map.
            temb (Tensor, optional): Embedding passed to each ResnetBlock3D.
            encoder_hidden_states (Tensor, optional): Context passed to the
                motion modules.

        Returns:
            tuple: ``(hidden_states, output_states)``, where ``output_states``
            is a tuple with each layer's output, followed by the downsampled
            output when a downsampler is present.
        """

        def as_function(module):
            # torch.utils.checkpoint needs a plain callable taking *inputs.
            def run(*inputs):
                return module(*inputs)

            return run

        checkpointing = self.training and self.gradient_checkpointing
        collected = []

        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            if checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_function(resnet), hidden_states, temb
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            if motion_module is not None:
                hidden_states = motion_module(
                    hidden_states, encoder_hidden_states=encoder_hidden_states
                )

            collected.append(hidden_states)

        if self.downsamplers is not None:
            for downsampler in self.downsamplers:
                hidden_states = downsampler(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
class CrossAttnUpBlock3D(nn.Module):
    """
    3D upsampling block with cross attention for the U-Net architecture.

    Each layer pops one skip tensor off the end of ``res_hidden_states_tuple``,
    concatenates it with the current hidden states along dim 1, and applies a
    ResnetBlock3D, a cross-attention Transformer3DModel, an optional audio
    Transformer3DModel and an optional motion module. An optional Upsample3D
    runs after all layers.

    Parameters:
    - in_channels (int): Channel count of the skip tensor consumed by the last layer.
    - out_channels (int): Number of output channels of every layer.
    - prev_output_channel (int): Channels of the hidden states entering the first layer.
    - temb_channels (int): Channels of the embedding passed to the residual blocks.
    - dropout (float): Dropout rate for the residual blocks.
    - num_layers (int): Number of resnet/attention layers.
    - resnet_eps (float): Epsilon for the residual blocks.
    - resnet_time_scale_shift (str): Forwarded to ResnetBlock3D as ``time_embedding_norm``.
    - resnet_act_fn (str): Activation function used in the residual blocks.
    - resnet_groups (int): Group count for the residual blocks and transformers.
    - resnet_pre_norm (bool): Whether the residual blocks use pre-normalization.
    - attn_num_head_channels (int): Number of attention heads of the transformers.
    - cross_attention_dim (int): Context dimension of the cross-attention transformers.
    - audio_attention_dim (int): Context dimension of the audio transformers.
    - output_scale_factor (float): Output scale factor of the residual blocks.
    - add_upsample (bool): Whether to append an Upsample3D layer.
    - dual_cross_attention (bool): Unsupported; raises NotImplementedError when True
      (and num_layers >= 1).
    - use_linear_projection, only_cross_attention, upcast_attention (bool):
      Forwarded to every Transformer3DModel.
    - unet_use_cross_frame_attention, unet_use_temporal_attention:
      Forwarded to the cross-attention Transformer3DModel layers.
    - use_motion_module (bool): Whether to add a motion module after each layer.
    - use_inflated_groupnorm (bool): Forwarded to ResnetBlock3D.
    - motion_module_type (str), motion_module_kwargs (dict): Passed to get_motion_module.
    - use_audio_module (bool): Whether to add an audio transformer after each layer.
    - depth, stack_enable_blocks_name, stack_enable_blocks_depth:
      Forwarded to the audio Transformer3DModel layers.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        prev_output_channel: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        attn_num_head_channels=1,
        cross_attention_dim=1280,
        audio_attention_dim=1024,
        output_scale_factor=1.0,
        add_upsample=True,
        dual_cross_attention=False,
        use_linear_projection=False,
        only_cross_attention=False,
        upcast_attention=False,
        unet_use_cross_frame_attention=None,
        unet_use_temporal_attention=None,
        use_motion_module=None,
        use_inflated_groupnorm=None,
        motion_module_type=None,
        motion_module_kwargs=None,
        use_audio_module=None,
        depth=0,
        stack_enable_blocks_name=None,
        stack_enable_blocks_depth=None,
    ):
        super().__init__()
        resnets = []
        attentions = []
        audio_modules = []
        motion_modules = []
        # NOTE(review): presumably read by the parent UNet to decide which
        # kwargs to pass to this block — confirm against the caller.
        self.has_cross_attention = True
        self.attn_num_head_channels = attn_num_head_channels
        for i in range(num_layers):
            # The last layer consumes the skip tensor with ``in_channels``
            # channels; the first layer takes the previous block's output.
            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
            resnet_in_channels = prev_output_channel if i == 0 else out_channels
            resnets.append(
                ResnetBlock3D(
                    in_channels=resnet_in_channels + res_skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            if dual_cross_attention:
                raise NotImplementedError
            attentions.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    out_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=cross_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    unet_use_cross_frame_attention=unet_use_cross_frame_attention,
                    unet_use_temporal_attention=unet_use_temporal_attention,
                )
            )
            # NOTE(review): the audio transformer's head dim is derived from
            # in_channels while its in_channels is out_channels (the attention
            # layer above uses out_channels for both). Kept unchanged because
            # it determines parameter shapes — confirm against checkpoints.
            audio_modules.append(
                Transformer3DModel(
                    attn_num_head_channels,
                    in_channels // attn_num_head_channels,
                    in_channels=out_channels,
                    num_layers=1,
                    cross_attention_dim=audio_attention_dim,
                    norm_num_groups=resnet_groups,
                    use_linear_projection=use_linear_projection,
                    only_cross_attention=only_cross_attention,
                    upcast_attention=upcast_attention,
                    use_audio_module=use_audio_module,
                    depth=depth,
                    unet_block_name="up",
                    stack_enable_blocks_name=stack_enable_blocks_name,
                    stack_enable_blocks_depth=stack_enable_blocks_depth,
                )
                if use_audio_module
                else None
            )
            motion_modules.append(
                get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
                if use_motion_module
                else None
            )
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)
        self.audio_modules = nn.ModuleList(audio_modules)
        self.motion_modules = nn.ModuleList(motion_modules)
        if add_upsample:
            self.upsamplers = nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
        else:
            self.upsamplers = None
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        encoder_hidden_states=None,
        upsample_size=None,
        attention_mask=None,
        full_mask=None,
        face_mask=None,
        lip_mask=None,
        audio_embedding=None,
        motion_scale=None,
    ):
        """
        Forward pass for the CrossAttnUpBlock3D class.

        During training with gradient checkpointing, the motion frames returned
        by the attention layer are prepended along dim 2 before the motion
        module and stripped afterwards. When the attention layer returns none,
        4 zero frames are used instead.

        Args:
            hidden_states (Tensor): Input hidden states. Dim 1 is concatenated with
                skip tensors and dim 2 holds frames (5-D; presumably
                (b, c, f, h, w) — the rearrange pattern below assumes so).
            res_hidden_states_tuple (Tuple[Tensor]): Skip tensors; one is popped
                from the end per layer.
            temb (Tensor, optional): Embedding passed to each ResnetBlock3D.
            encoder_hidden_states (Tensor, optional): Context for the
                cross-attention transformers and the motion modules.
            upsample_size (optional): Passed to each upsampler.
            attention_mask, full_mask, face_mask, lip_mask (Tensor, optional):
                Forwarded to each audio module.
            audio_embedding (Tensor, optional): Context for the audio modules.
            motion_scale (optional): Forwarded to each audio module, in both the
                checkpointed and the regular path.

        Returns:
            Tensor: Hidden states after all layers and optional upsampling.
        """
        for resnet, attn, audio_module, motion_module in zip(
            self.resnets, self.attentions, self.audio_modules, self.motion_modules
        ):
            # Pop the most recent skip tensor and concatenate on the channel dim.
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

            if self.training and self.gradient_checkpointing:

                def create_custom_forward(module, return_dict=None):
                    # Binds ``module`` now so checkpoint recomputation reuses it.
                    def custom_forward(*inputs):
                        if return_dict is not None:
                            return module(*inputs, return_dict=return_dict)
                        return module(*inputs)

                    return custom_forward

                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(resnet), hidden_states, temb
                )
                hidden_states, motion_frame = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(attn, return_dict=False),
                    hidden_states,
                    encoder_hidden_states,
                )
                if len(motion_frame[0]) > 0:
                    # d1 must be the height (dim -2) so the result is
                    # (b, c, f, h, w) and can be concatenated with hidden_states
                    # on dim 2. d1=width only worked for square feature maps.
                    # Assumes tokens are flattened row-major as (h w).
                    motion_frames = rearrange(
                        motion_frame[0][0],
                        "b f (d1 d2) c -> b c f d1 d2",
                        d1=hidden_states.size(-2),
                    )
                else:
                    # NOTE(review): fixed count of 4 zero frames when the
                    # attention layer yields none — confirm it matches the
                    # number of motion frames used elsewhere.
                    motion_frames = torch.zeros(
                        hidden_states.shape[0],
                        hidden_states.shape[1],
                        4,
                        hidden_states.shape[3],
                        hidden_states.shape[4],
                    )
                n_motion_frames = motion_frames.size(2)

                if audio_module is not None:
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(audio_module, return_dict=False),
                        hidden_states,
                        audio_embedding,
                        attention_mask,
                        full_mask,
                        face_mask,
                        lip_mask,
                        motion_scale,
                    )[0]

                if motion_module is not None:
                    motion_frames = motion_frames.to(
                        device=hidden_states.device, dtype=hidden_states.dtype
                    )
                    _hidden_states = (
                        torch.cat([motion_frames, hidden_states], dim=2)
                        if n_motion_frames > 0
                        else hidden_states
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(motion_module),
                        _hidden_states,
                        encoder_hidden_states,
                    )
                    # Drop the prepended motion frames again.
                    hidden_states = hidden_states[:, :, n_motion_frames:]
            else:
                hidden_states = resnet(hidden_states, temb)
                hidden_states = attn(
                    hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                ).sample
                if audio_module is not None:
                    # Same positional order as the checkpointed call above, so
                    # motion_scale is no longer silently dropped outside training.
                    hidden_states = audio_module(
                        hidden_states,
                        audio_embedding,
                        attention_mask,
                        full_mask,
                        face_mask,
                        lip_mask,
                        motion_scale,
                    ).sample
                if motion_module is not None:
                    hidden_states = motion_module(
                        hidden_states, encoder_hidden_states=encoder_hidden_states
                    )

        if self.upsamplers is not None:
            for upsampler in self.upsamplers:
                hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states
class UpBlock3D(nn.Module):
    """
    3D upsampling block without cross attention for the U-Net architecture.

    Each layer pops one skip tensor off the end of ``res_hidden_states_tuple``,
    concatenates it with the current hidden states along dim 1, and runs a
    ResnetBlock3D followed, when enabled, by a motion module. An optional
    Upsample3D runs after all layers. Gradient checkpointing (resnets only)
    is used while training when ``gradient_checkpointing`` is set.

    Parameters:
    - in_channels (int): Channel count of the skip tensor consumed by the last layer.
    - prev_output_channel (int): Channels of the hidden states entering the first layer.
    - out_channels (int): Number of output channels of every layer.
    - temb_channels (int): Channels of the embedding passed to the residual blocks.
    - dropout (float): Dropout rate for the residual blocks.
    - num_layers (int): Number of residual layers.
    - resnet_eps (float): Epsilon for the residual blocks.
    - resnet_time_scale_shift (str): Forwarded to ResnetBlock3D as ``time_embedding_norm``.
    - resnet_act_fn (str): Activation function used in the residual blocks.
    - resnet_groups (int): Group count for the residual blocks.
    - resnet_pre_norm (bool): Whether the residual blocks use pre-normalization.
    - output_scale_factor (float): Output scale factor of the residual blocks.
    - add_upsample (bool): Whether to append an Upsample3D layer.
    - use_inflated_groupnorm (bool): Forwarded to ResnetBlock3D.
    - use_motion_module (bool): Whether to add a motion module after each layer.
    - motion_module_type (str), motion_module_kwargs (dict): Passed to get_motion_module.
    """

    def __init__(
        self,
        in_channels: int,
        prev_output_channel: int,
        out_channels: int,
        temb_channels: int,
        dropout: float = 0.0,
        num_layers: int = 1,
        resnet_eps: float = 1e-6,
        resnet_time_scale_shift: str = "default",
        resnet_act_fn: str = "swish",
        resnet_groups: int = 32,
        resnet_pre_norm: bool = True,
        output_scale_factor=1.0,
        add_upsample=True,
        use_inflated_groupnorm=None,
        use_motion_module=None,
        motion_module_type=None,
        motion_module_kwargs=None,
    ):
        super().__init__()
        resnet_layers = []
        motion_layers = []
        last_index = num_layers - 1
        for index in range(num_layers):
            # First layer reads the previous block's output; the last layer's
            # skip tensor carries ``in_channels`` channels.
            skip_channels = in_channels if index == last_index else out_channels
            input_channels = out_channels if index else prev_output_channel
            # Resnet and motion module are built alternately per layer so that
            # parameter initialisation happens in the same order as before.
            resnet_layers.append(
                ResnetBlock3D(
                    in_channels=input_channels + skip_channels,
                    out_channels=out_channels,
                    temb_channels=temb_channels,
                    eps=resnet_eps,
                    groups=resnet_groups,
                    dropout=dropout,
                    time_embedding_norm=resnet_time_scale_shift,
                    non_linearity=resnet_act_fn,
                    output_scale_factor=output_scale_factor,
                    pre_norm=resnet_pre_norm,
                    use_inflated_groupnorm=use_inflated_groupnorm,
                )
            )
            motion_layer = None
            if use_motion_module:
                motion_layer = get_motion_module(
                    in_channels=out_channels,
                    motion_module_type=motion_module_type,
                    motion_module_kwargs=motion_module_kwargs,
                )
            motion_layers.append(motion_layer)

        self.resnets = nn.ModuleList(resnet_layers)
        self.motion_modules = nn.ModuleList(motion_layers)
        self.upsamplers = (
            nn.ModuleList(
                [Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]
            )
            if add_upsample
            else None
        )
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states,
        res_hidden_states_tuple,
        temb=None,
        upsample_size=None,
        encoder_hidden_states=None,
    ):
        """
        Forward pass for the UpBlock3D class.

        Args:
            hidden_states (Tensor): Input hidden states; skip tensors are
                concatenated onto dim 1.
            res_hidden_states_tuple (Tuple[Tensor]): Skip tensors; one is popped
                from the end per layer.
            temb (Tensor, optional): Embedding passed to each ResnetBlock3D.
            upsample_size (optional): Passed to each upsampler.
            encoder_hidden_states (Tensor, optional): Forwarded to each motion
                module as ``encoder_hidden_states``.

        Returns:
            Tensor: Hidden states after all layers and optional upsampling.
        """

        def as_plain_call(module):
            # Bind ``module`` eagerly so checkpoint recomputation reuses it.
            def call(*inputs):
                return module(*inputs)

            return call

        checkpointing = self.training and self.gradient_checkpointing
        pending = res_hidden_states_tuple
        for resnet, motion_module in zip(self.resnets, self.motion_modules):
            skip = pending[-1]
            pending = pending[:-1]
            hidden_states = torch.cat([hidden_states, skip], dim=1)

            if checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    as_plain_call(resnet), hidden_states, temb
                )
            else:
                hidden_states = resnet(hidden_states, temb)

            if motion_module is not None:
                hidden_states = motion_module(
                    hidden_states, encoder_hidden_states=encoder_hidden_states
                )

        for upsampler in self.upsamplers or ():
            hidden_states = upsampler(hidden_states, upsample_size)
        return hidden_states