File size: 1,717 Bytes
2d7087a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
"""Sage-T2I Configuration."""
from dataclasses import dataclass, field
from typing import Literal

@dataclass
class DiTConfig:
    """Hyperparameters for a DiT-style denoiser, plus an analytic parameter count.

    ``num_params`` is derived from the other fields in ``__post_init__``;
    any value supplied for it at construction time is overwritten.
    """

    in_channels: int = 4
    hidden_size: int = 3072
    num_layers: int = 24
    num_heads: int = 24
    mlp_ratio: float = 4.0
    patch_size: int = 2
    image_size: int = 256
    context_dim: int = 768
    # Kept as a regular init field (not field(init=False)) so that
    # DiTConfig(**dataclasses.asdict(cfg)) still round-trips. The value is
    # always recomputed in __post_init__, and it is NOT refreshed if other
    # fields are mutated after construction.
    num_params: int = 0

    def __post_init__(self) -> None:
        self.num_params = self._compute_params()

    def _compute_params(self) -> int:
        """Return the parameter count implied by the current hyperparameters.

        The count sums the weight-matrix sizes listed below. Bias vectors
        are included only where explicitly added (condition embedding and
        final norm).
        """
        d = self.hidden_size
        c = self.context_dim
        p = self.patch_size
        patch_vol = p * p * self.in_channels
        # NOTE(review): 8 is presumably the latent (VAE) spatial downsampling
        # factor applied to image_size before patching — confirm.
        num_patches = (self.image_size // 8 // p) ** 2
        ff_dim = int(d * self.mlp_ratio)

        embed = patch_vol * d          # patch embedding
        pos = num_patches * d          # positional embedding
        # 256 is presumably the timestep frequency-feature width — confirm.
        t_emb = 256 * d + d * d * 2
        c_emb = c * d + d              # condition projection (+ bias)

        block = (
            3 * d * d                  # QKV self-attn
            + d * d                    # self-attn out
            + 2 * d * d                # cross Q + out
            + 2 * c * d                # cross K, V
            + 3 * d * ff_dim           # FF gate/up/down
            + 6 * d * d                # adaLN modulation
        )
        total = embed + pos + t_emb + c_emb + block * self.num_layers
        total += 2 * d                 # final norm
        total += 2 * d * d             # final adaLN
        total += patch_vol * d         # final proj
        return total

    @classmethod
    def from_preset(cls, size: Literal["small", "large"]) -> "DiTConfig":
        """Build a config from a named size preset.

        Args:
            size: Preset name, either ``"small"`` or ``"large"``.

        Returns:
            A new ``DiTConfig`` with the preset's overrides applied.

        Raises:
            ValueError: If *size* is not a known preset name. Previously an
                unknown name silently yielded the default (large-sized) config.
        """
        presets = {
            "small": {"hidden_size": 1024, "num_layers": 20, "num_heads": 16},
            "large": {"hidden_size": 3072, "num_layers": 24, "num_heads": 24},
        }
        try:
            overrides = presets[size]
        except KeyError:
            raise ValueError(
                f"unknown preset {size!r}; expected one of {sorted(presets)}"
            ) from None
        return cls(**overrides)