import sys
from pathlib import Path

# Add the Model_Architecture directory to path
sys.path.insert(0, str(Path(__file__).parent))

from model import ModelArgs

def estimate_model_size(args: ModelArgs):
    """Calculate detailed model size and parameter count"""

    print(f"\n{'='*70}")
    print(f"MODEL ARCHITECTURE ANALYSIS: ismail")
    print(f"{'='*70}\n")

    # Display configuration
    print(f"📋 CONFIGURATION:")
    print(f"   Model dimension (dim):        {args.dim}")
    print(f"   Vocabulary size:              {args.vocab_size:,}")
    print(f"   Number of layers:             {args.n_layers}")
    print(f"   Dense layers:                 {args.n_dense_layers}")
    print(f"   MoE layers:                   {args.n_layers - args.n_dense_layers}")
    print(f"   Attention heads:              {args.n_heads}")
    print(f"   Max sequence length:          {args.max_seq_len}")
    print(f"   Max batch size:               {args.max_batch_size}")
    print(f"   \nMoE Configuration:")
    print(f"   Routed experts:               {args.n_routed_experts}")
    print(f"   Shared experts:               {args.n_shared_experts}")
    print(f"   Activated experts:            {args.n_activated_experts}")
    print(f"   \nMLA Configuration:")
    print(f"   Q LoRA rank:                  {args.q_lora_rank}")
    print(f"   KV LoRA rank:                 {args.kv_lora_rank}")
    print(f"   QK nope head dim:             {args.qk_nope_head_dim}")
    print(f"   QK rope head dim:             {args.qk_rope_head_dim}")
    print(f"   V head dim:                   {args.v_head_dim}")

    # Calculate parameters by component
    print(f"\n{'='*70}")
    print(f"🔢 PARAMETER COUNT BY COMPONENT:")
    print(f"{'='*70}\n")

    # 1. Embeddings
    tok_embed_params = args.vocab_size * args.dim
    output_params = args.vocab_size * args.dim
    total_embed_params = tok_embed_params + output_params
    print(f"   Token Embeddings:             {tok_embed_params:>15,} params")
    print(f"   Output Layer:                 {output_params:>15,} params")
    print(f"   {'─' * 50}")
    print(f"   Total Embeddings:             {total_embed_params:>15,} params\n")

    # 2. Attention (per layer)
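    # MLA parameter layout: the query projection is optionally factorized through a
    # low-rank bottleneck (dim -> q_lora_rank -> n_heads * (d_nope + d_rope)), which
    # replaces the full dim -> n_heads * (d_nope + d_rope) matrix. The KV path is a
    # down-projection to a compressed latent of kv_lora_rank (plus a decoupled RoPE
    # key of qk_rope_head_dim) followed by an up-projection to per-head keys/values.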
    if args.q_lora_rank == 0:
        wq_params = args.dim * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
        wq_norm_params = 0
    else:
        wq_params = args.dim * args.q_lora_rank + args.q_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
        wq_norm_params = args.q_lora_rank

    wkv_a_params = args.dim * (args.kv_lora_rank + args.qk_rope_head_dim)
    kv_norm_params = args.kv_lora_rank
    wkv_b_params = args.kv_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.v_head_dim)
    wo_params = args.n_heads * args.v_head_dim * args.dim
    attn_norm_params = args.dim

    attn_params_per_layer = wq_params + wq_norm_params + wkv_a_params + kv_norm_params + wkv_b_params + wo_params + attn_norm_params

    print(f"   Attention (per layer):")
    if args.q_lora_rank > 0:
        print(f"      WQ (LoRA):                 {wq_params:>15,} params")
        print(f"      Q Norm:                    {wq_norm_params:>15,} params")
    else:
        print(f"      WQ:                        {wq_params:>15,} params")
    print(f"      WKV_A:                     {wkv_a_params:>15,} params")
    print(f"      KV Norm:                   {kv_norm_params:>15,} params")
    print(f"      WKV_B:                     {wkv_b_params:>15,} params")
    print(f"      WO:                        {wo_params:>15,} params")
    print(f"      Attn Norm:                 {attn_norm_params:>15,} params")
    print(f"   {'─' * 50}")
    print(f"      Subtotal:                  {attn_params_per_layer:>15,} params\n")

    # 3. Dense FFN
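    # Gated FFN with three matrices (gate w1 and up w3: dim -> inter_dim,
    # down w2: inter_dim -> dim), assuming the usual SwiGLU layout where
    # out = w2(silu(w1(x)) * w3(x)).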
    dense_w1_params = args.dim * args.inter_dim
    dense_w2_params = args.inter_dim * args.dim
    dense_w3_params = args.dim * args.inter_dim
    ffn_norm_params = args.dim
    dense_ffn_per_layer = dense_w1_params + dense_w2_params + dense_w3_params + ffn_norm_params

    print(f"   Dense FFN (per layer):")
    print(f"      FC1 (W1):                  {dense_w1_params:>15,} params")
    print(f"      FC2 (W3):                  {dense_w3_params:>15,} params")
    print(f"      FC3 (W2):                  {dense_w2_params:>15,} params")
    print(f"      FFN Norm:                  {ffn_norm_params:>15,} params")
    print(f"   {'─' * 50}")
    print(f"      Subtotal:                  {dense_ffn_per_layer:>15,} params\n")

    # 4. MoE FFN
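    # The router ("gate") is a linear map producing one score per routed expert,
    # optionally with a per-expert bias. Each routed expert is a small gated FFN
    # with hidden size moe_inter_dim; the shared experts are fused into one gated
    # FFN of width n_shared_experts * moe_inter_dim that every token passes through.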
    gate_params = args.n_routed_experts * args.dim
    if args.use_routing_bias:
        gate_params += args.n_routed_experts

    expert_w1_params = args.dim * args.moe_inter_dim
    expert_w2_params = args.moe_inter_dim * args.dim
    expert_w3_params = args.dim * args.moe_inter_dim
    per_expert_params = expert_w1_params + expert_w2_params + expert_w3_params
    routed_experts_params = args.n_routed_experts * per_expert_params

    shared_w1_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_w2_params = (args.n_shared_experts * args.moe_inter_dim) * args.dim
    shared_w3_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_experts_params = shared_w1_params + shared_w2_params + shared_w3_params

    moe_ffn_per_layer = gate_params + routed_experts_params + shared_experts_params + ffn_norm_params

    print(f"   MoE FFN (per layer):")
    print(f"      Gate:                      {gate_params:>15,} params")
    print(f"      Routed Experts ({args.n_routed_experts}x):       {routed_experts_params:>15,} params")
    print(f"         Per expert:             {per_expert_params:>15,} params")
    print(f"      Shared Experts:            {shared_experts_params:>15,} params")
    print(f"      FFN Norm:                  {ffn_norm_params:>15,} params")
    print(f"   {'─' * 50}")
    print(f"      Subtotal:                  {moe_ffn_per_layer:>15,} params\n")

    # 5. Final Norm
    final_norm_params = args.dim

    # Total calculation
    dense_layer_params = attn_params_per_layer + dense_ffn_per_layer
    moe_layer_params = attn_params_per_layer + moe_ffn_per_layer

    total_dense_params = args.n_dense_layers * dense_layer_params
    total_moe_params = (args.n_layers - args.n_dense_layers) * moe_layer_params

    total_params = total_embed_params + total_dense_params + total_moe_params + final_norm_params

    print(f"   Layer Summary:")
    print(f"      Dense layers ({args.n_dense_layers}x):        {total_dense_params:>15,} params")
    print(f"      MoE layers ({args.n_layers - args.n_dense_layers}x):          {total_moe_params:>15,} params")
    print(f"      Final Norm:                {final_norm_params:>15,} params")

    print(f"\n{'='*70}")
    print(f"📊 TOTAL PARAMETERS:              {total_params:>15,} ({total_params/1e6:.2f}M)")
    print(f"{'='*70}\n")

    # Memory calculations
    print(f"{'='*70}")
    print(f"💾 MEMORY USAGE:")
    print(f"{'='*70}\n")

    bytes_per_param_bf16 = 2
    bytes_per_param_fp32 = 4

    # Model weights
    weight_memory_bf16 = total_params * bytes_per_param_bf16 / (1024**3)
    weight_memory_fp32 = total_params * bytes_per_param_fp32 / (1024**3)

    print(f"   Model Weights:")
    print(f"      BF16 (inference):          {weight_memory_bf16:>10.3f} GB")
    print(f"      FP32 (training):           {weight_memory_fp32:>10.3f} GB\n")

    # KV Cache
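    # MLA caches only the compressed KV latent plus the shared RoPE key per token,
    # i.e. (kv_lora_rank + qk_rope_head_dim) values per position, instead of the
    # n_heads * (d_k + d_v) values a standard multi-head KV cache would store.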
    kv_cache_per_layer = args.max_batch_size * args.max_seq_len * (args.kv_lora_rank + args.qk_rope_head_dim)
    total_kv_cache = kv_cache_per_layer * args.n_layers * bytes_per_param_bf16 / (1024**3)

    print(f"   KV Cache (BF16):")
    print(f"      Per layer:                 {kv_cache_per_layer * bytes_per_param_bf16 / (1024**3):>10.3f} GB")
    print(f"      Total ({args.n_layers} layers):         {total_kv_cache:>10.3f} GB\n")

    # Activations (rough estimate)
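    # Very rough: assumes ~4 bytes per residual-stream element per layer and
    # ignores attention scores, FFN intermediates, and activation checkpointing.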
    activation_memory = (args.max_batch_size * args.max_seq_len * args.dim * args.n_layers * 4) / (1024**3)

    print(f"   Activations (estimate):       {activation_memory:>10.3f} GB\n")

    # Training overhead
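    # Assumes plain FP32 training with vanilla Adam, which keeps two FP32 moment
    # buffers (momentum and variance) per parameter on top of the FP32 gradients.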
    gradients_memory = weight_memory_fp32  # Same size as weights
    optimizer_states = weight_memory_fp32 * 2  # Adam: 2x for momentum + variance
    training_overhead = gradients_memory + optimizer_states

    print(f"   Training Overhead (FP32):")
    print(f"      Gradients:                 {gradients_memory:>10.3f} GB")
    print(f"      Optimizer states (Adam):   {optimizer_states:>10.3f} GB")
    print(f"      Total overhead:            {training_overhead:>10.3f} GB\n")

    # Total estimates
    inference_total = weight_memory_bf16 + total_kv_cache + activation_memory
    training_total = weight_memory_fp32 + total_kv_cache + activation_memory + training_overhead
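    # Caveat: this reuses the inference-time KV-cache and activation figures; real
    # training stores activations for backprop and typically keeps no KV cache, so
    # treat the training number as a loose lower bound.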

    print(f"{'='*70}")
    print(f"   INFERENCE (BF16):             {inference_total:>10.3f} GB")
    print(f"   TRAINING (FP32 + Adam):       {training_total:>10.3f} GB")
    print(f"{'='*70}\n")

    # Memory analysis
    print(f"{'='*70}")
    print(f"🎯 MEMORY ANALYSIS:")
    print(f"{'='*70}\n")

    gpu_tiers_gb = [8, 16, 24, 32, 40, 48, 80]

    for threshold in gpu_tiers_gb:
        if inference_total <= threshold:
            print(f"   ✅ Inference fits in {threshold}GB GPU")
            break
    else:
        print(f"   ❌ Inference requires >80GB GPU")

    for threshold in gpu_tiers_gb:
        if training_total <= threshold:
            print(f"   ✅ Training fits in {threshold}GB GPU")
            break
    else:
        print(f"   ❌ Training requires >80GB GPU")

    print(f"\n{'='*70}\n")

    return {
        'total_params': total_params,
        'weight_memory_gb': weight_memory_bf16,
        'inference_memory_gb': inference_total,
        'training_memory_gb': training_total
    }


if __name__ == "__main__":
    import json

    # Try to load from config.json, otherwise use defaults
    config_path = Path(__file__).parent / "config.json"
    if config_path.exists():
        print(f"📄 Loading configuration from {config_path}")
        with open(config_path) as f:
            config = json.load(f)
        args = ModelArgs(**config["model"])
    else:
        print("⚠️  config.json not found, using default ModelArgs")
        args = ModelArgs()

    # Run estimation
    results = estimate_model_size(args)
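
    # Example of overriding a field before estimating (a sketch; assumes ModelArgs
    # is a dataclass whose fields match the attribute names used above):
    #   args.max_seq_len = 4096
    #   results = estimate_model_size(args)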