"""Sage-T2I Configuration.""" from dataclasses import dataclass, field from typing import Literal @dataclass class DiTConfig: in_channels: int = 4 hidden_size: int = 3072 num_layers: int = 24 num_heads: int = 24 mlp_ratio: float = 4.0 patch_size: int = 2 image_size: int = 256 context_dim: int = 768 num_params: int = 0 def __post_init__(self): self.num_params = self._compute_params() def _compute_params(self): d = self.hidden_size h = self.num_heads c = self.context_dim p = self.patch_size im = self.image_size ic = self.in_channels patch_vol = p * p * ic num_patches = (im // 8 // p) ** 2 embed = patch_vol * d pos = num_patches * d t_emb = 256 * d + d * d * 2 c_emb = c * d + d block = ( 3 * d * d # QKV self-attn + d * d # self-attn out + 2 * d * d # cross Q + out + 2 * c * d # cross K, V + 3 * d * int(d * self.mlp_ratio) # FF gate/up/down + 6 * d * d # adaLN modulation ) total = embed + pos + t_emb + c_emb + block * self.num_layers total += 2 * d # final norm total += 2 * d * d # final adaLN total += patch_vol * d # final proj return total @classmethod def from_preset(cls, size: Literal["small", "large"]): if size == "small": return cls(hidden_size=1024, num_layers=20, num_heads=16) elif size == "large": return cls(hidden_size=3072, num_layers=24, num_heads=24) return cls()