from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput

from motion_module import zero_module
from resnet import InflatedConv3d, InflatedGroupNorm
from attention import TemporalBasicTransformerBlock
from transformer_3d import Transformer3DModel


class GuidanceEncoder(ModelMixin):
    """Encode a guidance/condition input into an embedding.

    The network is a stack of ``InflatedConv3d`` layers with SiLU activations
    between them: an input convolution, then for every consecutive pair of
    widths in ``block_out_channels`` one width-preserving convolution followed
    by one stride-2 convolution that changes the width, and finally an output
    convolution wrapped in ``zero_module``.

    A ``Transformer3DModel`` is constructed alongside every convolution in
    ``blocks`` and one more is stored as ``guidance_attention``.  ``forward``
    currently applies only the last entry of ``attentions``; the remaining
    attention modules are registered as submodules but are not called there.
    NOTE(review): those unused modules presumably still appear in saved
    checkpoints -- confirm before removing or renaming them.
    """

    def __init__(
        self,
        guidance_embedding_channels: int,
        guidance_input_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        attention_num_heads: int = 8,
    ) -> None:
        """Build the convolution and attention stacks.

        Args:
            guidance_embedding_channels: Number of output channels of the
                final convolution (``conv_out``).
            guidance_input_channels: Number of input channels of the first
                convolution (``conv_in``).
            block_out_channels: Channel widths of the encoder.  The first
                entry is the width produced by ``conv_in``; every further
                entry adds one (same-width conv, stride-2 conv) pair.  Must
                contain at least one entry, and at least two for ``forward``
                to work, because ``forward`` indexes ``attentions[-1]``.
            attention_num_heads: Number of heads passed to every
                ``Transformer3DModel``; the per-head size is the channel
                width floor-divided by this value.
        """
        super().__init__()
        self.guidance_input_channels = guidance_input_channels

        self.conv_in = InflatedConv3d(
            guidance_input_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1,
        )

        self.blocks = nn.ModuleList([])
        self.attentions = nn.ModuleList([])

        # Walk consecutive (in, out) width pairs.  Module creation order is
        # kept as conv, attention, conv, attention so parameter names and
        # initialisation order match the previous index-based loop.
        for channel_in, channel_out in zip(
            block_out_channels, block_out_channels[1:]
        ):
            # Width-preserving convolution and its attention.
            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.attentions.append(
                Transformer3DModel(
                    attention_num_heads,
                    channel_in // attention_num_heads,
                    channel_in,
                    norm_num_groups=1,
                    unet_use_cross_frame_attention=False,
                    unet_use_temporal_attention=False,
                )
            )

            # Stride-2 convolution that moves to the next width, and its
            # attention.
            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )
            self.attentions.append(
                Transformer3DModel(
                    attention_num_heads,
                    channel_out // attention_num_heads,
                    channel_out,
                    norm_num_groups=32,
                    unet_use_cross_frame_attention=False,
                    unet_use_temporal_attention=False,
                )
            )

        # Attention at the final width.  Not called in ``forward`` at present.
        attention_channel_out = block_out_channels[-1]
        self.guidance_attention = Transformer3DModel(
            attention_num_heads,
            attention_channel_out // attention_num_heads,
            attention_channel_out,
            norm_num_groups=32,
            unet_use_cross_frame_attention=False,
            unet_use_temporal_attention=False,
        )

        # NOTE(review): ``zero_module`` presumably zero-initialises the
        # wrapped convolution's parameters -- confirm in motion_module.
        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                guidance_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, condition):
        """Run the encoder on ``condition`` and return the embedding.

        Args:
            condition: Input accepted by ``InflatedConv3d`` with
                ``guidance_input_channels`` channels.
                NOTE(review): assumed to be a 5-D video tensor
                (batch, channels, frames, height, width) -- TODO confirm.

        Returns:
            The output of ``conv_out`` applied to the attended features.
        """
        embedding = self.conv_in(condition)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        # FIXME: Temporarily only use the last attention.
        embedding = self.attentions[-1](embedding).sample
        embedding = self.conv_out(embedding)

        return embedding