# File size: 1,717 Bytes
# 2d7087a
"""Sage-T2I Configuration."""
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class DiTConfig:
    """Hyper-parameters for the DiT model plus a derived parameter estimate.

    ``num_params`` is recomputed from the other fields in ``__post_init__``
    using the closed-form formula in ``_compute_params``; no model is
    instantiated.  Any value passed for ``num_params`` to the constructor is
    therefore overwritten.  The estimate is NOT refreshed if a field is
    mutated after construction.
    """

    # Channels per input element; contributes to the per-patch volume.
    in_channels: int = 4
    # Model width (``d`` in the parameter formula).
    hidden_size: int = 3072
    # Number of transformer blocks.
    num_layers: int = 24
    # Attention heads; does not enter the parameter estimate.
    num_heads: int = 24
    # Feed-forward inner width is ``int(hidden_size * mlp_ratio)``.
    mlp_ratio: float = 4.0
    # Side length of a square patch.
    patch_size: int = 2
    # Image side length; reduced by ``// 8 // patch_size`` to get patches/side.
    image_size: int = 256
    # Width of the conditioning context fed to cross-attention K/V.
    context_dim: int = 768
    # Derived: estimated parameter count, set in ``__post_init__``.
    num_params: int = 0

    def __post_init__(self) -> None:
        """Populate ``num_params`` from the configured dimensions."""
        self.num_params = self._compute_params()

    def _compute_params(self) -> int:
        """Return the estimated number of parameters for this configuration.

        The total is the sum of the input embeddings, ``num_layers`` identical
        transformer blocks and the output head, each expressed as products of
        the configured dimensions.
        """
        d = self.hidden_size
        c = self.context_dim
        p = self.patch_size

        patch_vol = p * p * self.in_channels
        # NOTE(review): the hard-coded 8 is presumably the spatial downsampling
        # factor of the latent encoder -- confirm against the model code.
        num_patches = (self.image_size // 8 // p) ** 2

        embed = patch_vol * d
        pos = num_patches * d
        # NOTE(review): 256 looks like the width of the timestep frequency
        # embedding feeding a two-layer d x d projection -- confirm.
        t_emb = 256 * d + d * d * 2
        c_emb = c * d + d

        block = (
            3 * d * d  # QKV self-attn
            + d * d  # self-attn out
            + 2 * d * d  # cross Q + out
            + 2 * c * d  # cross K, V
            + 3 * d * int(d * self.mlp_ratio)  # FF gate/up/down
            + 6 * d * d  # adaLN modulation
        )

        total = embed + pos + t_emb + c_emb + block * self.num_layers
        total += 2 * d  # final norm
        total += 2 * d * d  # final adaLN
        total += patch_vol * d  # final proj
        return total

    @classmethod
    def from_preset(cls, size: Literal["small", "large"]) -> "DiTConfig":
        """Build a configuration from a named preset.

        Args:
            size: ``"small"`` or ``"large"``.

        Returns:
            The matching configuration.  Any other value falls back to the
            class defaults rather than raising.
        """
        if size == "small":
            return cls(hidden_size=1024, num_layers=20, num_heads=16)
        if size == "large":
            return cls(hidden_size=3072, num_layers=24, num_heads=24)
        return cls()
|