from state import Model, Parallelism, Training


class MemoryCalculation:
    def __init__(self, model: Model, parallelism: Parallelism, training: Training):
        self.model = model
        self.parallelism = parallelism
        self.training = training

    def calculate_num_parameters(self) -> float:
        # https://michaelwornow.net/2024/01/18/counting-params-in-transformer
        # https://huggingface.co/spaces/nanotron/ultrascale-playbook?section=memory_usage_in_transformers

        # Biases are included wherever they could appear rather than being
        # toggled per model; they are tiny relative to the weights, so this
        # gives a slight upper bound.

        h, i, l, v, e = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.experts,
        )
        tp, pp, ep = (
            self.parallelism.tensor_parallelism,
            self.parallelism.pipeline_parallelism,
            self.parallelism.expert_parallelism,
        )
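
        # Convention for the sharding below (a sketch, not tied to one
        # framework): a trailing "/ tp" or "/ ep" marks a tensor sharded
        # across that parallel group; anything without it is replicated.
        # Sanity check: with e = 1 and no parallelism this reduces to roughly
        # l * (4*h*h + 3*h*i) + 2*v*h, which for a Llama-7B-like config
        # (h=4096, i=11008, l=32, v=32000, untied embeddings) gives ~6.7e9.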

        # Embedding layers
        input_embedding = v * h / tp
        unembedding = 0
        if not self.model.weight_tied_embeddings:
            unembedding = h * v / tp
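        # With tied embeddings the input matrix is reused as the output
        # projection, so it is only counted once.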

        # Attention
        layer_norm_attn_in = 2 * h  # weight and bias, not tp sharded
        qkv = (3 * h * h + 3 * h) / tp  # column-parallel: weights and biases sharded
        attn_output_proj = h * h / tp + h  # row-parallel: weight sharded, bias replicated
        attn = layer_norm_attn_in + qkv + attn_output_proj

        # MLP
        layer_norm_mlp_in = 2 * h  # weight and bias, not tp sharded
        router = h * e + e  # assumed replicated on every rank for simplicity
        mlp_up_proj = (h * i + i) / tp  # column-parallel: weight and bias sharded
        mlp_gate_proj = (h * i + i) / tp  # column-parallel: weight and bias sharded
        mlp_down_proj = i * h / tp + h  # row-parallel: weight sharded, bias replicated
        expert = mlp_up_proj + mlp_gate_proj + mlp_down_proj
        experts = expert * e / ep  # experts sharded across the expert-parallel group
        mlp = layer_norm_mlp_in + router + experts

        layer = attn + mlp
        layers = layer * l

        final_layer_norm = 2 * h  # weight and bias, not tp sharded

        # Pipeline parallelism plus weight tying make it hard to know which
        # stage holds the embeddings, so assume the "worst" case: the larger
        # embedding matrix sits on the last stage together with the final
        # layer norm (which is pretty small by comparison).
        if pp == 1:
            total_params = input_embedding + layers + unembedding + final_layer_norm
        else:
            total_params = (
                max(input_embedding, unembedding) + layers / pp + final_layer_norm
            )
        return total_params

    def calculate_parameter_memory(self) -> float:
        # assuming full fp32, i.e. 4 bytes (32 bits) per parameter; mixed
        # precision would lower this, so treat the result as an upper bound
        return self.calculate_num_parameters() * 4

    def calculate_activation_parameters(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#activations-and-batch-size
        # https://arxiv.org/abs/2205.05198
        # pp is not considered: most pp schedules keep multiple micro-batches
        # in flight to shrink the bubble, so per-stage activation memory stays
        # roughly the same as without pp
        b, s = self.training.batch_size, self.training.sequence_length
        h, i, l, v, e, ae = (
            self.model.hidden_dim,
            self.model.intermediate_size,
            self.model.num_layers,
            self.model.vocab_size,
            self.model.experts,
            self.model.active_experts,
        )
        tp, cp = (
            self.parallelism.tensor_parallelism,
            self.parallelism.context_parallelism,
        )
        if self.training.gradient_checkpointing:
            # full recomputation: only the input to each layer is kept
            embed = 0
            layer = s * b * h / cp / tp  # sharded by cp and the sequence-parallel tp group
            layers = layer * l
            final_layer_out = s * b * h / cp / tp
            final_norm = s * b * h / cp / tp
            unembed = s * b * v / cp / tp  # vocab sharded across tp
            logits = s * b * v / cp / tp  # vocab sharded across tp
            # total: s * b * (l * h + 2 * h + 2 * v) / (cp * tp)
            num_activations = (
                embed + layers + final_layer_out + final_norm + unembed + logits
            )
            return num_activations
        else:
            # assume flash attention, i.e. selective recomputation
            # assume tensor parallel + sequence parallel as described in https://arxiv.org/abs/2205.05198
            # each variable below counts the stored activation outputs
            # Attention block
            layer_in = s * b * h / cp / tp  # sharded by cp and the sequence-parallel tp group
            attn_norm = s * b * h / cp / tp
            flash = s * b * h / cp / tp  # everything inside attention is recomputed by flash attention
            projection = s * b * h / cp / tp
            attn = layer_in + attn_norm + flash + projection
            # MLP block
            mlp_norm = s * b * h / cp / tp
            router = s * b * e / cp / tp  # sp-sharded, consistent with the sp-sharded mlp_norm output
            mlp_up = s * b * i / cp / tp
            mlp_gate = s * b * i / cp / tp
            hadamard_swiglu = s * b * i / cp / tp  # elementwise product inside the SwiGLU
            mlp_down = s * b * h / cp / tp
            expert = mlp_up + mlp_gate + hadamard_swiglu + mlp_down
            experts = expert * ae  # only active experts store activations
            mlp = mlp_norm + router + experts
            # per layer: s * b * (5 * h + e + ae * (3 * i + h)) / (cp * tp)
            layer = attn + mlp
            layers = layer * l
            # Other
            embed = 0
            final_layer_out = s * b * h / cp / tp
            final_norm = s * b * h / cp / tp
            unembed = s * b * v / cp / tp  # vocab sharded across tp
            logits = s * b * v / cp / tp
            num_activations = (
                embed + layers + final_layer_out + final_norm + unembed + logits
            )
            return num_activations

    def calculate_activation_memory(self) -> float:
        # assuming full fp32, i.e. 4 bytes (32 bits) per activation value
        return self.calculate_activation_parameters() * 4

    def calculate_gradient_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#gradients
        # gradients have the same shape (and, here, the same dtype) as the parameters
        return self.calculate_parameter_memory()

    def calculate_optimizer_memory(self) -> float:
        # https://blog.eleuther.ai/transformer-math/#optimizer-states
        # https://www.determined.ai/blog/act-mem-2, https://web.archive.org/web/20250308172134/https://www.determined.ai/blog/act-mem-2
        # Adam keeps two states per parameter (first and second moments),
        # each the same size as the parameters themselves
        return 2 * self.calculate_parameter_memory()
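

# Illustrative usage sketch. It assumes Model, Parallelism, and Training in
# state.py are dataclasses whose fields match the attribute names used above;
# all numeric values below are hypothetical.
if __name__ == "__main__":
    model = Model(
        hidden_dim=4096,
        intermediate_size=14336,
        num_layers=32,
        vocab_size=128256,
        experts=8,
        active_experts=2,
        weight_tied_embeddings=False,
    )
    parallelism = Parallelism(
        tensor_parallelism=8,
        pipeline_parallelism=2,
        expert_parallelism=4,
        context_parallelism=1,
    )
    training = Training(
        batch_size=1, sequence_length=8192, gradient_checkpointing=True
    )

    calc = MemoryCalculation(model, parallelism, training)
    gib = 1024**3
    print(f"parameters:  {calc.calculate_parameter_memory() / gib:.2f} GiB")
    print(f"activations: {calc.calculate_activation_memory() / gib:.2f} GiB")
    print(f"gradients:   {calc.calculate_gradient_memory() / gib:.2f} GiB")
    print(f"optimizer:   {calc.calculate_optimizer_memory() / gib:.2f} GiB")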