| from typing import Tuple |
|
|
| import torch.nn as nn |
| import torch.nn.functional as F |
| import torch.nn.init as init |
| from diffusers.models.modeling_utils import ModelMixin |
| from einops import rearrange |
|
|
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and return that same module.

    Wrapping a freshly built layer such as ``nn.Linear`` with this makes the
    layer output zeros until its parameters are trained.
    """
    params = module.parameters()
    for tensor in params:
        nn.init.zeros_(tensor)
    return module
|
|
|
|
class PoseGuider(ModelMixin):
    """Encode a conditioning image into a sequence of embedding tokens.

    Pipeline:

    1. A stem 3x3 convolution is applied.
    2. For each consecutive pair in ``block_out_channels``, a same-resolution
       3x3 conv is followed by a stride-2 3x3 conv. Every convolution,
       including the stem, is followed by SiLU.
    3. The resulting feature map is cut into non-overlapping
       ``patch_size`` x ``patch_size`` patches, and each patch is flattened
       into one token.
    4. A zero-initialised linear layer projects every token to
       ``conditioning_embedding_channels``.

    Args:
        conditioning_embedding_channels: Width of each output token.
        conditioning_channels: Channel count of the input image.
        block_out_channels: Channel widths of the conv stages. The number of
            stride-2 downsamplings is ``len(block_out_channels) - 1``.
        patch_size: Side length of the square patches folded into each token.
            Defaults to 2, the value previously hard-coded.

    Raises:
        ValueError: If ``patch_size`` is not a positive integer.
    """

    def __init__(
        self,
        conditioning_embedding_channels: int,
        conditioning_channels: int = 3,
        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
        patch_size: int = 2,
    ):
        super().__init__()
        if patch_size < 1:
            raise ValueError(f"patch_size must be a positive integer, got {patch_size}")
        self.patch_size = patch_size

        self.conv_in = nn.Conv2d(
            conditioning_channels, block_out_channels[0], kernel_size=3, padding=1
        )

        # Registration order fixes the parameter names (blocks.0, blocks.1, ...).
        # Keep it stable so existing checkpoints still load.
        self.blocks = nn.ModuleList([])
        for channel_in, channel_out in zip(block_out_channels[:-1], block_out_channels[1:]):
            # Refine features at the current resolution ...
            self.blocks.append(
                nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)
            )
            # ... then downsample H and W by 2 (ceil) while changing channel width.
            self.blocks.append(
                nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)
            )

        # After patchifying in forward(), each token holds c * patch_size**2 values.
        # Zero-initialised, so the projection outputs zeros until it is trained.
        self.out = zero_module(
            nn.Linear(
                block_out_channels[-1] * patch_size * patch_size,
                conditioning_embedding_channels,
            )
        )

    def forward(self, conditioning):
        """Map a conditioning image to a sequence of embedding tokens.

        Args:
            conditioning: 4-D tensor of shape (B, conditioning_channels, H, W).
                After the stride-2 convolutions, the feature map's height and
                width must both be divisible by ``patch_size``. With the
                defaults, H and W that are multiples of 16 satisfy this.

        Returns:
            Tensor of shape (B, (H' // patch_size) * (W' // patch_size),
            conditioning_embedding_channels), where H' x W' is the spatial
            size after downsampling.
        """
        embedding = F.silu(self.conv_in(conditioning))
        for block in self.blocks:
            embedding = F.silu(block(embedding))

        # Fold each patch_size x patch_size spatial patch into the channel axis,
        # producing one token per patch.
        p = self.patch_size
        embedding = rearrange(
            embedding, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=p, pw=p
        )
        return self.out(embedding)
|
|
if __name__ == "__main__":
    import torch

    # Smoke test. A 1024x768 input goes through 3 stride-2 convs (-> 128x96)
    # and is then split into 2x2 patches (64 * 48 = 3072 tokens), each
    # projected to 3072 channels.
    model = PoseGuider(
        conditioning_embedding_channels=3072,
        block_out_channels=(16, 32, 96, 256),
    )
    inp = torch.randn((4, 3, 1024, 768))
    # Inference only: avoid building an autograd graph for this large input.
    with torch.no_grad():
        out = model(inp)
    assert out.shape == (4, 64 * 48, 3072), tuple(out.shape)
    print(tuple(out.shape))
|
|
|
|