"""mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection."""

from __future__ import annotations

from torch import Tensor, nn

from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
class Encoder(nn.Module):
    """Deterministic encoder: Image [B,3,H,W] -> latents [B,bottleneck_dim,h,w].

    Pipeline: Patchify -> RMSNorm -> DiCoBlocks (unconditioned) -> Conv1x1 -> RMSNorm(no affine)
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
    ) -> None:
        super().__init__()
        # Embed non-overlapping patches into model_dim channels.
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        # Sequential applies the blocks in order and keeps the same
        # "blocks.<i>.*" parameter names a ModuleList would produce.
        self.blocks = nn.Sequential(
            *(
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
                for _ in range(depth)
            )
        )
        # 1x1 conv projects model_dim down to the bottleneck width.
        self.to_bottleneck = nn.Conv2d(
            model_dim, bottleneck_dim, kernel_size=1, bias=True
        )
        # No learnable affine here: latents are left normalized per channel.
        self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]."""
        features = self.norm_in(self.patchify(images))
        features = self.blocks(features)
        return self.norm_out(self.to_bottleneck(features))