|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent)) |
|
|
|
|
|
from model import ModelArgs |
|
|
|
|
|
def _print_gpu_fit(workload: str, required_gb: float) -> None:
    """Print the smallest common GPU size (8–80 GB) that fits `required_gb`.

    Prints a ✅ line for the first threshold that fits, or a ❌ line when
    the requirement exceeds 80 GB.
    """
    for threshold in (8, 16, 24, 32, 40, 48, 80):
        if required_gb <= threshold:
            print(f" ✅ {workload} fits in {threshold}GB GPU")
            return
    print(f" ❌ {workload} requires >80GB GPU")


def estimate_model_size(args: ModelArgs):
    """Calculate detailed model size and parameter count.

    Walks the configuration of an MLA (multi-head latent attention) + MoE
    transformer: prints the configuration, a per-component parameter count,
    weight / KV-cache / activation memory estimates, and which common GPU
    sizes the model fits in for inference and training.

    Args:
        args: model hyper-parameters (dims, layer counts, MoE/MLA settings).

    Returns:
        dict with keys 'total_params', 'weight_memory_gb',
        'inference_memory_gb' and 'training_memory_gb'.
    """
    print(f"\n{'='*70}")
    print("MODEL ARCHITECTURE ANALYSIS: ismail")
    print(f"{'='*70}\n")

    # ---------------- Configuration dump ----------------
    print("📋 CONFIGURATION:")
    print(f" Model dimension (dim): {args.dim}")
    print(f" Vocabulary size: {args.vocab_size:,}")
    print(f" Number of layers: {args.n_layers}")
    print(f" Dense layers: {args.n_dense_layers}")
    print(f" MoE layers: {args.n_layers - args.n_dense_layers}")
    print(f" Attention heads: {args.n_heads}")
    print(f" Max sequence length: {args.max_seq_len}")
    print(f" Max batch size: {args.max_batch_size}")
    print(" \nMoE Configuration:")
    print(f" Routed experts: {args.n_routed_experts}")
    print(f" Shared experts: {args.n_shared_experts}")
    print(f" Activated experts: {args.n_activated_experts}")
    print(" \nMLA Configuration:")
    print(f" Q LoRA rank: {args.q_lora_rank}")
    print(f" KV LoRA rank: {args.kv_lora_rank}")
    print(f" QK nope head dim: {args.qk_nope_head_dim}")
    print(f" QK rope head dim: {args.qk_rope_head_dim}")
    print(f" V head dim: {args.v_head_dim}")

    print(f"\n{'='*70}")
    print("🔢 PARAMETER COUNT BY COMPONENT:")
    print(f"{'='*70}\n")

    # ---------------- Embeddings ----------------
    # Token embedding and output (LM head) are counted separately, i.e.
    # weights are assumed untied.
    tok_embed_params = args.vocab_size * args.dim
    output_params = args.vocab_size * args.dim
    total_embed_params = tok_embed_params + output_params
    print(f" Token Embeddings: {tok_embed_params:>15,} params")
    print(f" Output Layer: {output_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Total Embeddings: {total_embed_params:>15,} params\n")

    # ---------------- MLA attention (per layer) ----------------
    qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim
    if args.q_lora_rank == 0:
        # Full-rank query projection, no extra norm.
        wq_params = args.dim * args.n_heads * qk_head_dim
        wq_norm_params = 0
    else:
        # Low-rank (LoRA-style) factored query projection + its norm.
        wq_params = args.dim * args.q_lora_rank + args.q_lora_rank * args.n_heads * qk_head_dim
        wq_norm_params = args.q_lora_rank

    # KV down-projection produces the compressed latent plus the rope key.
    wkv_a_params = args.dim * (args.kv_lora_rank + args.qk_rope_head_dim)
    kv_norm_params = args.kv_lora_rank
    # KV up-projection expands the latent back to per-head K(nope) and V.
    wkv_b_params = args.kv_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.v_head_dim)
    wo_params = args.n_heads * args.v_head_dim * args.dim
    attn_norm_params = args.dim

    attn_params_per_layer = wq_params + wq_norm_params + wkv_a_params + kv_norm_params + wkv_b_params + wo_params + attn_norm_params

    print(" Attention (per layer):")
    if args.q_lora_rank > 0:
        print(f" WQ (LoRA): {wq_params:>15,} params")
        print(f" Q Norm: {wq_norm_params:>15,} params")
    else:
        print(f" WQ: {wq_params:>15,} params")
    print(f" WKV_A: {wkv_a_params:>15,} params")
    print(f" KV Norm: {kv_norm_params:>15,} params")
    print(f" WKV_B: {wkv_b_params:>15,} params")
    print(f" WO: {wo_params:>15,} params")
    print(f" Attn Norm: {attn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {attn_params_per_layer:>15,} params\n")

    # ---------------- Dense FFN (per layer) ----------------
    # Gated MLP: W1 (gate), W3 (up), W2 (down); listed in forward-pass order.
    dense_w1_params = args.dim * args.inter_dim
    dense_w2_params = args.inter_dim * args.dim
    dense_w3_params = args.dim * args.inter_dim
    ffn_norm_params = args.dim
    dense_ffn_per_layer = dense_w1_params + dense_w2_params + dense_w3_params + ffn_norm_params

    print(" Dense FFN (per layer):")
    print(f" FC1 (W1): {dense_w1_params:>15,} params")
    print(f" FC2 (W3): {dense_w3_params:>15,} params")
    print(f" FC3 (W2): {dense_w2_params:>15,} params")
    print(f" FFN Norm: {ffn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {dense_ffn_per_layer:>15,} params\n")

    # ---------------- MoE FFN (per layer) ----------------
    # Router: one score per routed expert, optional per-expert bias.
    gate_params = args.n_routed_experts * args.dim
    if args.use_routing_bias:
        gate_params += args.n_routed_experts

    # Each routed expert is a gated MLP with moe_inter_dim hidden width.
    expert_w1_params = args.dim * args.moe_inter_dim
    expert_w2_params = args.moe_inter_dim * args.dim
    expert_w3_params = args.dim * args.moe_inter_dim
    per_expert_params = expert_w1_params + expert_w2_params + expert_w3_params
    routed_experts_params = args.n_routed_experts * per_expert_params

    # Shared experts are fused into one wider gated MLP.
    shared_w1_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_w2_params = (args.n_shared_experts * args.moe_inter_dim) * args.dim
    shared_w3_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_experts_params = shared_w1_params + shared_w2_params + shared_w3_params

    moe_ffn_per_layer = gate_params + routed_experts_params + shared_experts_params + ffn_norm_params

    print(" MoE FFN (per layer):")
    print(f" Gate: {gate_params:>15,} params")
    print(f" Routed Experts ({args.n_routed_experts}x): {routed_experts_params:>15,} params")
    print(f" Per expert: {per_expert_params:>15,} params")
    print(f" Shared Experts: {shared_experts_params:>15,} params")
    print(f" FFN Norm: {ffn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {moe_ffn_per_layer:>15,} params\n")

    # ---------------- Totals ----------------
    final_norm_params = args.dim

    dense_layer_params = attn_params_per_layer + dense_ffn_per_layer
    moe_layer_params = attn_params_per_layer + moe_ffn_per_layer

    total_dense_params = args.n_dense_layers * dense_layer_params
    total_moe_params = (args.n_layers - args.n_dense_layers) * moe_layer_params

    total_params = total_embed_params + total_dense_params + total_moe_params + final_norm_params

    print(" Layer Summary:")
    print(f" Dense layers ({args.n_dense_layers}x): {total_dense_params:>15,} params")
    print(f" MoE layers ({args.n_layers - args.n_dense_layers}x): {total_moe_params:>15,} params")
    print(f" Final Norm: {final_norm_params:>15,} params")

    print(f"\n{'='*70}")
    print(f"📊 TOTAL PARAMETERS: {total_params:>15,} ({total_params/1e6:.2f}M)")
    print(f"{'='*70}\n")

    # ---------------- Memory usage ----------------
    print(f"{'='*70}")
    print("💾 MEMORY USAGE:")
    print(f"{'='*70}\n")

    bytes_per_param_bf16 = 2
    bytes_per_param_fp32 = 4

    weight_memory_bf16 = total_params * bytes_per_param_bf16 / (1024**3)
    weight_memory_fp32 = total_params * bytes_per_param_fp32 / (1024**3)

    print(" Model Weights:")
    print(f" BF16 (inference): {weight_memory_bf16:>10.3f} GB")
    print(f" FP32 (training): {weight_memory_fp32:>10.3f} GB\n")

    # MLA caches only the compressed latent + rope key per token, not full K/V.
    kv_cache_per_layer = args.max_batch_size * args.max_seq_len * (args.kv_lora_rank + args.qk_rope_head_dim)
    total_kv_cache = kv_cache_per_layer * args.n_layers * bytes_per_param_bf16 / (1024**3)

    print(" KV Cache (BF16):")
    print(f" Per layer: {kv_cache_per_layer * bytes_per_param_bf16 / (1024**3):>10.3f} GB")
    print(f" Total ({args.n_layers} layers): {total_kv_cache:>10.3f} GB\n")

    # Rough estimate: 4 bytes per residual-stream element across all layers.
    activation_memory = (args.max_batch_size * args.max_seq_len * args.dim * args.n_layers * 4) / (1024**3)

    print(f" Activations (estimate): {activation_memory:>10.3f} GB\n")

    # Adam training: FP32 gradients + two FP32 moment buffers per parameter.
    gradients_memory = weight_memory_fp32
    optimizer_states = weight_memory_fp32 * 2
    training_overhead = gradients_memory + optimizer_states

    print(" Training Overhead (FP32):")
    print(f" Gradients: {gradients_memory:>10.3f} GB")
    print(f" Optimizer states (Adam): {optimizer_states:>10.3f} GB")
    print(f" Total overhead: {training_overhead:>10.3f} GB\n")

    inference_total = weight_memory_bf16 + total_kv_cache + activation_memory
    training_total = weight_memory_fp32 + total_kv_cache + activation_memory + training_overhead

    print(f"{'='*70}")
    print(f" INFERENCE (BF16): {inference_total:>10.3f} GB")
    print(f" TRAINING (FP32 + Adam): {training_total:>10.3f} GB")
    print(f"{'='*70}\n")

    # ---------------- GPU fit analysis ----------------
    print(f"{'='*70}")
    print("🎯 MEMORY ANALYSIS:")
    print(f"{'='*70}\n")

    _print_gpu_fit("Inference", inference_total)
    _print_gpu_fit("Training", training_total)

    print(f"\n{'='*70}\n")

    return {
        'total_params': total_params,
        'weight_memory_gb': weight_memory_bf16,
        'inference_memory_gb': inference_total,
        'training_memory_gb': training_total,
    }
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    import json

    # Prefer a config.json sitting next to this script; fall back to the
    # ModelArgs defaults when it is absent. (Path is imported at module top.)
    config_path = Path(__file__).parent / "config.json"
    if config_path.exists():
        print(f"📄 Loading configuration from {config_path}")
        with open(config_path) as f:
            config = json.load(f)
        args = ModelArgs(**config["model"])
    else:
        print("⚠️ config.json not found, using default ModelArgs")
        args = ModelArgs()

    # The report dict return value is not needed here; the function prints.
    estimate_model_size(args)