File size: 4,063 Bytes
1ed770c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""DiCo block: conv path (1x1 -> depthwise -> SiLU -> CCA -> 1x1) + GELU MLP."""

from __future__ import annotations

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from .compact_channel_attention import CompactChannelAttention
from .conv_mlp import ConvMLP
from .norms import ChannelWiseRMSNorm


class DiCoBlock(nn.Module):
    """DiCo-style convolutional block with two gated residual sub-paths.

    Sub-path 1 (conv): pre-norm -> 1x1 conv -> depthwise kxk conv -> SiLU
    -> compact channel attention -> 1x1 conv.
    Sub-path 2 (MLP): pre-norm -> ConvMLP (GELU inside, per module docstring).

    Conditioning modes:
    - Unconditioned (encoder): each sub-path output is scaled by a learned
      per-channel gate, zero-initialised so the block starts as identity.
    - External AdaLN (decoder): ``adaln_m`` packs per-sample modulation
      [B, 4*C] as (scale_attn, gate_attn, scale_mlp, gate_mlp). Scales act
      multiplicatively as (1 + scale) on the normalised input; gates pass
      through tanh before weighting the sub-path output.
    """

    def __init__(
        self,
        channels: int,
        mlp_ratio: float,
        *,
        depthwise_kernel_size: int = 7,
        use_external_adaln: bool = False,
        norm_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.channels = int(channels)
        self.use_external_adaln = bool(use_external_adaln)

        # Affine-free pre-norms: all scaling comes from gates / AdaLN.
        self.norm1 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)
        self.norm2 = ChannelWiseRMSNorm(self.channels, eps=norm_eps, affine=False)

        # Conv sub-path layers: pointwise, depthwise (same padding), pointwise.
        self.conv1 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.conv2 = nn.Conv2d(
            self.channels,
            self.channels,
            kernel_size=depthwise_kernel_size,
            padding=depthwise_kernel_size // 2,
            groups=self.channels,
            bias=True,
        )
        self.conv3 = nn.Conv2d(self.channels, self.channels, kernel_size=1, bias=True)
        self.cca = CompactChannelAttention(self.channels)

        # MLP hidden width: round(channels * mlp_ratio), floored at 1.
        hidden_channels = max(int(round(float(self.channels) * mlp_ratio)), 1)
        self.mlp = ConvMLP(self.channels, hidden_channels, norm_eps=norm_eps)

        # Learned residual gates exist only in encoder mode; in decoder mode
        # the equivalent gates arrive packed inside adaln_m.
        if not self.use_external_adaln:
            self.gate_attn = nn.Parameter(torch.zeros(self.channels))
            self.gate_mlp = nn.Parameter(torch.zeros(self.channels))

    def _conv_path(self, x_mod: Tensor) -> Tensor:
        """Conv sub-path on an already normalised (and modulated) input."""
        h = self.conv2(self.conv1(x_mod))
        h = F.silu(h)
        h = h * self.cca(h)
        return self.conv3(h)

    def forward(self, x: Tensor, *, adaln_m: Tensor | None = None) -> Tensor:
        batch, chans = x.shape[0], x.shape[1]

        if self.use_external_adaln:
            if adaln_m is None:
                raise ValueError(
                    "adaln_m required for externally-conditioned DiCoBlock"
                )
            # Unpack [B, 4*C] into four [B, C, 1, 1] modulation tensors,
            # cast to the input's device/dtype.
            packed = adaln_m.to(device=x.device, dtype=x.dtype)
            scale_a, gate_a, scale_m, gate_m = (
                part.view(batch, chans, 1, 1) for part in packed.chunk(4, dim=-1)
            )

            # Conv sub-path with AdaLN scale on input, tanh-gated residual.
            h = self._conv_path(self.norm1(x) * (1.0 + scale_a))
            x = x + torch.tanh(gate_a) * h

            # MLP sub-path, same modulation scheme.
            h = self.mlp(self.norm2(x) * (1.0 + scale_m))
            return x + torch.tanh(gate_m) * h

        if adaln_m is not None:
            raise ValueError("adaln_m must be None for unconditioned DiCoBlock")

        # Encoder mode: learned per-channel gate on each residual branch.
        h = self._conv_path(self.norm1(x))
        gate = self.gate_attn.view(1, self.channels, 1, 1).to(
            dtype=h.dtype, device=h.device
        )
        x = x + gate * h

        h = self.mlp(self.norm2(x))
        gate = self.gate_mlp.view(1, self.channels, 1, 1).to(
            dtype=h.dtype, device=h.device
        )
        return x + gate * h