# Provenance: mdiffae-v1 repo, m_diffae/encoder.py
# (uploaded by data-archetype via huggingface_hub, revision 128cb34, verified)
"""mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection."""
from __future__ import annotations
from torch import Tensor, nn
from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify
class Encoder(nn.Module):
    """Deterministic image encoder producing spatial latent maps.

    Maps an image batch ``[B, 3, H, W]`` to latents
    ``[B, bottleneck_dim, h, w]`` through the pipeline:
    Patchify -> RMSNorm (affine) -> stack of unconditioned DiCoBlocks
    -> 1x1 Conv projection -> RMSNorm (no affine).

    Args:
        in_channels: Number of input image channels.
        patch_size: Spatial size of each patch fed to ``Patchify``.
        model_dim: Channel width used throughout the DiCoBlock trunk.
        depth: Number of DiCoBlocks in the trunk.
        bottleneck_dim: Channel width of the output latent map.
        mlp_ratio: Expansion ratio of each DiCoBlock's MLP.
        depthwise_kernel_size: Kernel size of each block's depthwise conv.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
    ) -> None:
        super().__init__()
        # Stem: fold patches into channel dimension, then normalize.
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        # Trunk: `depth` identical blocks, run without external AdaLN
        # conditioning (the encoder is deterministic and unconditioned).
        trunk = (
            DiCoBlock(
                model_dim,
                mlp_ratio,
                depthwise_kernel_size=depthwise_kernel_size,
                use_external_adaln=False,
            )
            for _ in range(depth)
        )
        self.blocks = nn.ModuleList(trunk)
        # Head: pointwise projection into the bottleneck, then a
        # non-affine RMSNorm so latents have a fixed per-channel scale.
        self.to_bottleneck = nn.Conv2d(
            model_dim, bottleneck_dim, kernel_size=1, bias=True
        )
        self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)

    def forward(self, images: Tensor) -> Tensor:
        """Encode ``images`` ``[B, 3, H, W]`` (expected in ``[-1, 1]``)
        into latents ``[B, bottleneck_dim, h, w]``."""
        feats = self.norm_in(self.patchify(images))
        for layer in self.blocks:
            feats = layer(feats)
        return self.norm_out(self.to_bottleneck(feats))