"""mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection."""

from __future__ import annotations

from torch import Tensor, nn

from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


class Encoder(nn.Module):
    """Deterministic encoder mapping images to bottleneck latents.

    Input images are [B,3,H,W]; the output is [B,bottleneck_dim,h,w].
    NOTE(review): the relation between (H,W) and (h,w) is decided by
    ``Patchify``, which is defined in another module -- presumably
    H/patch_size and W/patch_size; confirm there.

    Stages, in execution order:
        1. ``patchify``      -- Patchify(in_channels, patch_size, model_dim)
        2. ``norm_in``       -- channel-wise RMSNorm with affine parameters
        3. ``blocks``        -- ``depth`` DiCoBlocks, built with
                                ``use_external_adaln=False`` (unconditioned)
        4. ``to_bottleneck`` -- 1x1 convolution model_dim -> bottleneck_dim
        5. ``norm_out``      -- channel-wise RMSNorm without affine parameters

    Args:
        in_channels: Forwarded to ``Patchify`` as its first argument.
        patch_size: Forwarded to ``Patchify`` as its second argument.
        model_dim: Channel width used by the norm and the DiCoBlocks.
        depth: Number of DiCoBlocks in the stack.
        bottleneck_dim: Number of channels of the returned latents.
        mlp_ratio: Forwarded to every ``DiCoBlock``.
        depthwise_kernel_size: Forwarded to every ``DiCoBlock``.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
    ) -> None:
        super().__init__()
        rms_eps = 1e-6

        # Submodules are created in pipeline order; keep this order so that
        # parameter initialization and state_dict layout stay the same.
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=rms_eps, affine=True)

        stack = nn.ModuleList()
        for _ in range(depth):
            stack.append(
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
            )
        self.blocks = stack

        self.to_bottleneck = nn.Conv2d(
            in_channels=model_dim,
            out_channels=bottleneck_dim,
            kernel_size=1,
            bias=True,
        )
        self.norm_out = ChannelWiseRMSNorm(
            bottleneck_dim, eps=rms_eps, affine=False
        )

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]."""
        hidden = self.norm_in(self.patchify(images))
        for layer in self.blocks:
            hidden = layer(hidden)
        return self.norm_out(self.to_bottleneck(hidden))