File size: 1,838 Bytes
128cb34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
"""AdaLN-Zero modules for shared-base + low-rank-delta conditioning."""

from __future__ import annotations

from torch import Tensor, nn


class AdaLNZeroProjector(nn.Module):
    """Shared base AdaLN projection head.

    Pipeline: SiLU activation followed by one linear layer mapping
    ``d_cond`` features to ``4 * d_model`` features. The linear layer's
    weight and bias are both set to zero at construction, so a freshly built
    projector outputs an all-zero packed modulation tensor.
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model, self.d_cond = int(d_model), int(d_cond)
        self.act = nn.SiLU()
        self.proj = nn.Linear(self.d_cond, 4 * self.d_model)
        # Overwrite the default Linear init: weight and bias both start at 0.
        for param in (self.proj.weight, self.proj.bias):
            nn.init.zeros_(param)

    def forward(self, cond: Tensor) -> Tensor:
        """Activate ``cond`` [B, d_cond] with SiLU and project to [B, 4*d_model]."""
        return self.proj(self.act(cond))

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Project conditioning that the caller has already activated.

        No activation is applied here; ``act_cond`` goes straight through the
        linear layer.
        """
        packed = self.proj(act_cond)
        return packed


class AdaLNZeroLowRankDelta(nn.Module):
    """Per-layer low-rank modulation delta.

    Two bias-free linear maps are chained: ``down`` (d_cond -> rank) and
    ``up`` (rank -> 4*d_model). ``down`` is drawn from N(0, 0.02) while
    ``up`` starts at zero, so the delta is exactly zero at initialization
    and the AdaLN "zero output" property is preserved.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model, self.d_cond, self.rank = int(d_model), int(d_cond), int(rank)
        packed_width = 4 * self.d_model
        # Build both layers before re-initializing so the random stream is
        # consumed in a fixed order (down's default init, up's default init,
        # then the normal draw for down).
        self.down = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up = nn.Linear(self.rank, packed_width, bias=False)
        nn.init.zeros_(self.up.weight)
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Map activated conditioning to a packed delta [B, 4*d_model]."""
        bottleneck = self.down(act_cond)
        return self.up(bottleneck)