| """Sage-T2I Configuration.""" | |
| from dataclasses import dataclass, field | |
| from typing import Literal | |
@dataclass
class DiTConfig:
    """Hyper-parameters for a DiT-style diffusion transformer.

    ``num_params`` is derived from the other fields in ``__post_init__`` and
    is therefore not a constructor argument.
    """

    in_channels: int = 4  # channels per spatial position fed to the patchifier
    hidden_size: int = 3072  # transformer width ``d``
    num_layers: int = 24  # number of transformer blocks
    num_heads: int = 24  # attention heads per block (not used by the estimate)
    mlp_ratio: float = 4.0  # feed-forward width = int(hidden_size * mlp_ratio)
    patch_size: int = 2  # side length of a square patch
    image_size: int = 256  # image side length; patches tile image_size // 8
    context_dim: int = 768  # width of the conditioning sequence (cross-attn K/V)
    # Estimated parameter count. Always recomputed in __post_init__, so it is
    # excluded from __init__ rather than silently discarding a caller's value.
    num_params: int = field(default=0, init=False)

    def __post_init__(self) -> None:
        # Derived state: refresh the estimate from the other fields.
        self.num_params = self._compute_params()

    def _compute_params(self) -> int:
        """Return an estimated parameter count for this configuration.

        Sums the patch embedding, positional embedding, timestep and context
        embedders, ``num_layers`` identical transformer blocks, and the final
        norm / adaLN / output projection. Per-block norms are not counted, so
        this is an estimate rather than an exact module count.
        """
        d = self.hidden_size
        c = self.context_dim
        p = self.patch_size
        ff = int(d * self.mlp_ratio)
        patch_vol = p * p * self.in_channels  # values in one flattened patch
        # Patches tile an (image_size // 8) grid -- presumably the latent of an
        # 8x-downsampling autoencoder; TODO confirm against the VAE in use.
        num_patches = (self.image_size // 8 // p) ** 2

        embed = patch_vol * d  # patch embedding
        pos = num_patches * d  # positional embedding, one d-vector per patch
        # NOTE(review): a 256 -> d -> d timestep MLP would be 256*d + d*d;
        # confirm the factor of 2 matches the actual embedder.
        t_emb = 256 * d + d * d * 2
        c_emb = c * d + d  # context projection (weight + presumably a bias)
        block = (
            3 * d * d  # QKV self-attn
            + d * d  # self-attn out
            + 2 * d * d  # cross Q + out
            + 2 * c * d  # cross K, V
            + 3 * d * ff  # FF gate/up/down
            + 6 * d * d  # adaLN modulation
        )
        total = embed + pos + t_emb + c_emb + block * self.num_layers
        total += 2 * d  # final norm
        total += 2 * d * d  # final adaLN
        total += patch_vol * d  # final proj
        return total

    @classmethod
    def from_preset(
        cls, size: Literal["small", "large"], **overrides
    ) -> "DiTConfig":
        """Build a config from a named size preset.

        Args:
            size: ``"small"`` (width 1024, 20 layers, 16 heads) or ``"large"``
                (width 3072, 24 layers, 24 heads). Any other value falls back
                to the class defaults.
            **overrides: Field values applied on top of the preset, e.g.
                ``image_size=512``. They take precedence over preset values.

        Returns:
            A new ``DiTConfig`` with ``num_params`` computed.
        """
        if size == "small":
            preset = {"hidden_size": 1024, "num_layers": 20, "num_heads": 16}
        elif size == "large":
            preset = {"hidden_size": 3072, "num_layers": 24, "num_heads": 24}
        else:
            # Best-effort fallback kept from the original: unknown names use
            # the class defaults.
            preset = {}
        # Merge so an override of a preset key replaces it instead of
        # raising "got multiple values for keyword argument".
        return cls(**{**preset, **overrides})