File size: 1,741 Bytes
128cb34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""mDiffAE encoder: patchify -> DiCoBlocks -> bottleneck projection."""

from __future__ import annotations

from torch import Tensor, nn

from .dico_block import DiCoBlock
from .norms import ChannelWiseRMSNorm
from .straight_through_encoder import Patchify


class Encoder(nn.Module):
    """Deterministic image encoder producing spatial latents.

    Maps an image tensor [B,3,H,W] to a latent grid [B,bottleneck_dim,h,w]
    via: Patchify -> affine RMSNorm -> a stack of unconditioned DiCoBlocks
    -> 1x1 conv channel projection -> non-affine RMSNorm.
    """

    def __init__(
        self,
        in_channels: int,
        patch_size: int,
        model_dim: int,
        depth: int,
        bottleneck_dim: int,
        mlp_ratio: float,
        depthwise_kernel_size: int,
    ) -> None:
        super().__init__()
        # Patch embedding: [B,in_channels,H,W] -> [B,model_dim,h,w].
        self.patchify = Patchify(in_channels, patch_size, model_dim)
        # Affine normalization right after patch embedding.
        self.norm_in = ChannelWiseRMSNorm(model_dim, eps=1e-6, affine=True)
        # `depth` identical DiCoBlocks; use_external_adaln=False since the
        # encoder has no conditioning signal.
        stages = []
        for _ in range(depth):
            stages.append(
                DiCoBlock(
                    model_dim,
                    mlp_ratio,
                    depthwise_kernel_size=depthwise_kernel_size,
                    use_external_adaln=False,
                )
            )
        self.blocks = nn.ModuleList(stages)
        # 1x1 conv projects model_dim channels down to the bottleneck width.
        self.to_bottleneck = nn.Conv2d(
            model_dim, bottleneck_dim, kernel_size=1, bias=True
        )
        # Final normalization has no learnable scale (affine=False).
        self.norm_out = ChannelWiseRMSNorm(bottleneck_dim, eps=1e-6, affine=False)

    def forward(self, images: Tensor) -> Tensor:
        """Encode images [B,3,H,W] in [-1,1] to latents [B,bottleneck_dim,h,w]."""
        feats = self.norm_in(self.patchify(images))
        for stage in self.blocks:
            feats = stage(feats)
        return self.norm_out(self.to_bottleneck(feats))