sage-t2i / model /config.py
itriedcoding's picture
Upload folder using huggingface_hub
2d7087a verified
Raw
History Blame Contribute Delete
1.72 kB
"""Sage-T2I Configuration."""
from dataclasses import dataclass, field
from typing import Literal
@dataclass
class DiTConfig:
    """Size hyperparameters for the DiT model plus a derived parameter estimate.

    ``num_params`` is recomputed in ``__post_init__`` from the other fields,
    so any value passed for it to the constructor is overwritten. It is not
    refreshed if a field is mutated after construction.
    """

    in_channels: int = 4
    hidden_size: int = 3072
    num_layers: int = 24
    num_heads: int = 24
    mlp_ratio: float = 4.0
    patch_size: int = 2
    image_size: int = 256
    context_dim: int = 768
    # Derived value; always overwritten by __post_init__.
    num_params: int = 0

    def __post_init__(self) -> None:
        """Populate ``num_params`` from the configured sizes."""
        self.num_params = self._compute_params()

    def _compute_params(self) -> int:
        """Return the closed-form parameter-count estimate for this config.

        The estimate is the sum of the embedding terms, ``num_layers`` copies
        of the per-block term, and the final-layer terms listed below.
        ``num_heads`` does not enter the formula.
        """
        d = self.hidden_size
        c = self.context_dim
        p = self.patch_size
        im = self.image_size
        ic = self.in_channels

        patch_vol = p * p * ic
        # NOTE(review): the fixed 8 is presumably the latent downsampling
        # factor applied before patchification -- confirm against the model.
        num_patches = (im // 8 // p) ** 2

        embed = patch_vol * d
        pos = num_patches * d
        # NOTE(review): 256 looks like a fixed timestep-embedding input
        # width -- confirm against the model code.
        t_emb = 256 * d + d * d * 2
        c_emb = c * d + d

        block = (
            3 * d * d  # QKV self-attn
            + d * d  # self-attn out
            + 2 * d * d  # cross Q + out
            + 2 * c * d  # cross K, V
            + 3 * d * int(d * self.mlp_ratio)  # FF gate/up/down
            + 6 * d * d  # adaLN modulation
        )

        total = embed + pos + t_emb + c_emb + block * self.num_layers
        total += 2 * d  # final norm
        total += 2 * d * d  # final adaLN
        total += patch_vol * d  # final proj
        return total

    @classmethod
    def from_preset(cls, size: Literal["small", "large"]) -> "DiTConfig":
        """Build a config from a named preset.

        Args:
            size: ``"small"`` or ``"large"``.

        Returns:
            The preset config. Any other value falls back to the class
            defaults (kept for backward compatibility with existing callers).
        """
        if size == "small":
            return cls(hidden_size=1024, num_layers=20, num_heads=16)
        if size == "large":
            return cls(hidden_size=3072, num_layers=24, num_heads=24)
        return cls()