"""AdaLN-Zero modules for shared-base + low-rank-delta conditioning."""
from __future__ import annotations
from torch import Tensor, nn
class AdaLNZeroProjector(nn.Module):
    """Shared-base AdaLN-Zero projection head.

    Applies SiLU to the conditioning vector, then a single Linear layer
    mapping ``d_cond -> 4 * d_model``, yielding a packed modulation tensor
    of shape ``[B, 4 * d_model]``. Both the weight and bias of the
    projection start at zero, so the module emits all-zero modulation at
    initialization (the "zero" in AdaLN-Zero).
    """

    def __init__(self, d_model: int, d_cond: int) -> None:
        super().__init__()
        self.d_model = int(d_model)
        self.d_cond = int(d_cond)
        self.act = nn.SiLU()
        self.proj = nn.Linear(self.d_cond, 4 * self.d_model)
        # Zero init keeps modulation inert until training moves it.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, cond: Tensor) -> Tensor:
        """Activate ``cond`` [B, d_cond] and project to packed modulation [B, 4*d_model]."""
        return self.forward_activated(self.act(cond))

    def forward_activated(self, act_cond: Tensor) -> Tensor:
        """Project conditioning that has already been passed through SiLU."""
        return self.proj(act_cond)
class AdaLNZeroLowRankDelta(nn.Module):
    """Per-layer low-rank AdaLN-Zero delta.

    Factors a ``d_cond -> 4 * d_model`` projection through a rank-``rank``
    bottleneck: a bias-free down-projection followed by a bias-free
    up-projection. The up-projection's weight starts at zero, so the delta
    contributes exactly nothing at initialization — preserving the AdaLN
    "zero output" property of the shared base.
    """

    def __init__(self, *, d_model: int, d_cond: int, rank: int) -> None:
        super().__init__()
        self.d_model = int(d_model)
        self.d_cond = int(d_cond)
        self.rank = int(rank)
        self.down = nn.Linear(self.d_cond, self.rank, bias=False)
        self.up = nn.Linear(self.rank, 4 * self.d_model, bias=False)
        # Small random down-projection; zeroed up-projection keeps the
        # packed delta identically zero until training updates it.
        nn.init.normal_(self.down.weight, mean=0.0, std=0.02)
        nn.init.zeros_(self.up.weight)

    def forward(self, act_cond: Tensor) -> Tensor:
        """Map activated conditioning [B, d_cond] to a packed delta [B, 4*d_model]."""
        hidden = self.down(act_cond)
        return self.up(hidden)