StableFaceEmotion/guidance_encoder.py
from dataclasses import dataclass
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils import BaseOutput

from motion_module import zero_module
from resnet import InflatedConv3d, InflatedGroupNorm
from attention import TemporalBasicTransformerBlock
from transformer_3d import Transformer3DModel

class GuidanceEncoder(ModelMixin):
    """Maps a per-frame guidance video into a `guidance_embedding_channels`-wide
    feature volume using strided 3D convolutions and 3D attention blocks."""

    def __init__(
        self,
        guidance_embedding_channels: int,
        guidance_input_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        attention_num_heads: int = 8,
    ):
        super().__init__()
        self.guidance_input_channels = guidance_input_channels

        # Project the raw guidance frames into the first feature width.
        self.conv_in = InflatedConv3d(
            guidance_input_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        self.blocks = nn.ModuleList([])
        self.attentions = nn.ModuleList([])

        # Each stage: a same-resolution conv, then a stride-2 conv that
        # downsamples spatially and widens the channels; each conv is paired
        # with a Transformer3DModel attention block.
        for i in range(len(block_out_channels) - 1):
            channel_in = block_out_channels[i]
            channel_out = block_out_channels[i + 1]

            self.blocks.append(
                InflatedConv3d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            self.attentions.append(
                Transformer3DModel(
                    attention_num_heads,
                    channel_in // attention_num_heads,
                    channel_in,
                    norm_num_groups=1,
                    unet_use_cross_frame_attention=False,
                    unet_use_temporal_attention=False,
                )
            )

            self.blocks.append(
                InflatedConv3d(
                    channel_in, channel_out, kernel_size=3, padding=1, stride=2
                )
            )
            self.attentions.append(
                Transformer3DModel(
                    attention_num_heads,
                    channel_out // attention_num_heads,
                    channel_out,
                    norm_num_groups=32,
                    unet_use_cross_frame_attention=False,
                    unet_use_temporal_attention=False,
                )
            )

        # Extra attention over the final feature map (currently unused in
        # forward, see the FIXME below).
        attention_channel_out = block_out_channels[-1]
        self.guidance_attention = Transformer3DModel(
            attention_num_heads,
            attention_channel_out // attention_num_heads,
            attention_channel_out,
            norm_num_groups=32,
            unet_use_cross_frame_attention=False,
            unet_use_temporal_attention=False,
        )

        # Zero-initialized output projection, so the encoder contributes
        # nothing at the start of training.
        self.conv_out = zero_module(
            InflatedConv3d(
                block_out_channels[-1],
                guidance_embedding_channels,
                kernel_size=3,
                padding=1,
            )
        )

    def forward(self, condition):
        embedding = self.conv_in(condition)
        embedding = F.silu(embedding)

        for block in self.blocks:
            embedding = block(embedding)
            embedding = F.silu(embedding)

        # FIXME: Temporarily only use the last attention.
        embedding = self.attentions[-1](embedding).sample

        embedding = self.conv_out(embedding)

        return embedding
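

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test, assuming InflatedConv3d consumes 5D tensors shaped
# (batch, channels, frames, height, width) as in AnimateDiff-style codebases
# and that the sibling modules imported above are on the path; the embedding
# width (320) and tensor sizes below are arbitrary assumptions.
if __name__ == "__main__":
    encoder = GuidanceEncoder(guidance_embedding_channels=320)
    condition = torch.randn(1, 3, 8, 64, 64)  # (B, C, F, H, W)
    with torch.no_grad():
        sample = encoder(condition)
    # Three stride-2 convolutions reduce H and W by a factor of 8; because
    # conv_out is zero-initialized, the output is all zeros before training.
    print(sample.shape)  # expected: torch.Size([1, 320, 8, 8, 8])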