# ismail / Model_Architecture / model_size.py
# Author: ikaganacar — "Better Configuration Implementation" (commit 7557c9f)
import sys
from pathlib import Path
# Add the Model_Architecture directory to path
sys.path.insert(0, str(Path(__file__).parent))
from model import ModelArgs
def estimate_model_size(args: "ModelArgs"):
    """Print a detailed parameter-count and memory breakdown for the model.

    Walks the architecture described by ``args`` — token/output embeddings,
    MLA-style attention, dense FFN layers, and MoE FFN layers — printing
    per-component parameter counts, then estimates weight / KV-cache /
    activation / optimizer memory and reports which common GPU sizes fit.

    Args:
        args: Model hyperparameters (dims, layer counts, MoE and MLA config).
            Only attribute access is used, so any object exposing the
            expected fields works.

    Returns:
        dict: ``total_params``, ``weight_memory_gb`` (BF16 weights only),
        ``inference_memory_gb``, and ``training_memory_gb``.
    """

    def _report_gpu_fit(label, required_gb):
        # Print the smallest common GPU size that can hold `required_gb`,
        # or a ">80GB" warning when none fits. Shared by the inference and
        # training reports so the threshold table exists in one place.
        for threshold in (8, 16, 24, 32, 40, 48, 80):
            if required_gb <= threshold:
                print(f" ✅ {label} fits in {threshold}GB GPU")
                return
        print(f" ❌ {label} requires >80GB GPU")

    print(f"\n{'='*70}")
    print(f"MODEL ARCHITECTURE ANALYSIS: ismail")
    print(f"{'='*70}\n")

    # Display configuration
    print(f"📋 CONFIGURATION:")
    print(f" Model dimension (dim): {args.dim}")
    print(f" Vocabulary size: {args.vocab_size:,}")
    print(f" Number of layers: {args.n_layers}")
    print(f" Dense layers: {args.n_dense_layers}")
    print(f" MoE layers: {args.n_layers - args.n_dense_layers}")
    print(f" Attention heads: {args.n_heads}")
    print(f" Max sequence length: {args.max_seq_len}")
    print(f" Max batch size: {args.max_batch_size}")
    print(f" \nMoE Configuration:")
    print(f" Routed experts: {args.n_routed_experts}")
    print(f" Shared experts: {args.n_shared_experts}")
    print(f" Activated experts: {args.n_activated_experts}")
    print(f" \nMLA Configuration:")
    print(f" Q LoRA rank: {args.q_lora_rank}")
    print(f" KV LoRA rank: {args.kv_lora_rank}")
    print(f" QK nope head dim: {args.qk_nope_head_dim}")
    print(f" QK rope head dim: {args.qk_rope_head_dim}")
    print(f" V head dim: {args.v_head_dim}")

    print(f"\n{'='*70}")
    print(f"🔢 PARAMETER COUNT BY COMPONENT:")
    print(f"{'='*70}\n")

    # 1. Embeddings: input embedding table + output projection.
    # NOTE(review): counted as two independent matrices — if model.py ties
    # input/output embeddings this double-counts; confirm against model.py.
    tok_embed_params = args.vocab_size * args.dim
    output_params = args.vocab_size * args.dim
    total_embed_params = tok_embed_params + output_params
    print(f" Token Embeddings: {tok_embed_params:>15,} params")
    print(f" Output Layer: {output_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Total Embeddings: {total_embed_params:>15,} params\n")

    # 2. Attention (per layer). Q is either a full projection
    # (q_lora_rank == 0) or a low-rank factorization dim -> rank -> heads.
    if args.q_lora_rank == 0:
        wq_params = args.dim * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
        wq_norm_params = 0
    else:
        # Down- then up-projection, plus a norm over the LoRA rank.
        wq_params = args.dim * args.q_lora_rank + args.q_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.qk_rope_head_dim)
        wq_norm_params = args.q_lora_rank
    # KV path is always factored: down-projection to (kv_lora_rank + rope dim),
    # then up-projection to per-head nope-K and V.
    wkv_a_params = args.dim * (args.kv_lora_rank + args.qk_rope_head_dim)
    kv_norm_params = args.kv_lora_rank
    wkv_b_params = args.kv_lora_rank * args.n_heads * (args.qk_nope_head_dim + args.v_head_dim)
    wo_params = args.n_heads * args.v_head_dim * args.dim
    attn_norm_params = args.dim  # pre-attention norm weight
    attn_params_per_layer = wq_params + wq_norm_params + wkv_a_params + kv_norm_params + wkv_b_params + wo_params + attn_norm_params
    print(f" Attention (per layer):")
    if args.q_lora_rank > 0:
        print(f" WQ (LoRA): {wq_params:>15,} params")
        print(f" Q Norm: {wq_norm_params:>15,} params")
    else:
        print(f" WQ: {wq_params:>15,} params")
    print(f" WKV_A: {wkv_a_params:>15,} params")
    print(f" KV Norm: {kv_norm_params:>15,} params")
    print(f" WKV_B: {wkv_b_params:>15,} params")
    print(f" WO: {wo_params:>15,} params")
    print(f" Attn Norm: {attn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {attn_params_per_layer:>15,} params\n")

    # 3. Dense FFN (gated MLP: three weight matrices + pre-FFN norm).
    dense_w1_params = args.dim * args.inter_dim
    dense_w2_params = args.inter_dim * args.dim
    dense_w3_params = args.dim * args.inter_dim
    ffn_norm_params = args.dim
    dense_ffn_per_layer = dense_w1_params + dense_w2_params + dense_w3_params + ffn_norm_params
    print(f" Dense FFN (per layer):")
    print(f" FC1 (W1): {dense_w1_params:>15,} params")
    print(f" FC2 (W3): {dense_w3_params:>15,} params")
    print(f" FC3 (W2): {dense_w2_params:>15,} params")
    print(f" FFN Norm: {ffn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {dense_ffn_per_layer:>15,} params\n")

    # 4. MoE FFN: router gate + N routed experts + shared expert(s).
    gate_params = args.n_routed_experts * args.dim
    if args.use_routing_bias:
        gate_params += args.n_routed_experts  # one routing bias per expert
    expert_w1_params = args.dim * args.moe_inter_dim
    expert_w2_params = args.moe_inter_dim * args.dim
    expert_w3_params = args.dim * args.moe_inter_dim
    per_expert_params = expert_w1_params + expert_w2_params + expert_w3_params
    routed_experts_params = args.n_routed_experts * per_expert_params
    # Shared experts modeled as one gated MLP of width n_shared * moe_inter_dim.
    shared_w1_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_w2_params = (args.n_shared_experts * args.moe_inter_dim) * args.dim
    shared_w3_params = args.dim * (args.n_shared_experts * args.moe_inter_dim)
    shared_experts_params = shared_w1_params + shared_w2_params + shared_w3_params
    moe_ffn_per_layer = gate_params + routed_experts_params + shared_experts_params + ffn_norm_params
    print(f" MoE FFN (per layer):")
    print(f" Gate: {gate_params:>15,} params")
    print(f" Routed Experts ({args.n_routed_experts}x): {routed_experts_params:>15,} params")
    print(f" Per expert: {per_expert_params:>15,} params")
    print(f" Shared Experts: {shared_experts_params:>15,} params")
    print(f" FFN Norm: {ffn_norm_params:>15,} params")
    print(f" {'─' * 50}")
    print(f" Subtotal: {moe_ffn_per_layer:>15,} params\n")

    # 5. Final norm before the output head.
    final_norm_params = args.dim

    # Totals: the first n_dense_layers use the dense FFN, the rest use MoE.
    dense_layer_params = attn_params_per_layer + dense_ffn_per_layer
    moe_layer_params = attn_params_per_layer + moe_ffn_per_layer
    total_dense_params = args.n_dense_layers * dense_layer_params
    total_moe_params = (args.n_layers - args.n_dense_layers) * moe_layer_params
    total_params = total_embed_params + total_dense_params + total_moe_params + final_norm_params
    print(f" Layer Summary:")
    print(f" Dense layers ({args.n_dense_layers}x): {total_dense_params:>15,} params")
    print(f" MoE layers ({args.n_layers - args.n_dense_layers}x): {total_moe_params:>15,} params")
    print(f" Final Norm: {final_norm_params:>15,} params")
    print(f"\n{'='*70}")
    print(f"📊 TOTAL PARAMETERS: {total_params:>15,} ({total_params/1e6:.2f}M)")
    print(f"{'='*70}\n")

    # Memory calculations
    print(f"{'='*70}")
    print(f"💾 MEMORY USAGE:")
    print(f"{'='*70}\n")
    bytes_per_param_bf16 = 2
    bytes_per_param_fp32 = 4

    # Model weights, in GiB (1024**3).
    weight_memory_bf16 = total_params * bytes_per_param_bf16 / (1024**3)
    weight_memory_fp32 = total_params * bytes_per_param_fp32 / (1024**3)
    print(f" Model Weights:")
    print(f" BF16 (inference): {weight_memory_bf16:>10.3f} GB")
    print(f" FP32 (training): {weight_memory_fp32:>10.3f} GB\n")

    # KV cache: only the compressed latent (kv_lora_rank) plus the rope key
    # is cached per token — not full per-head K/V.
    kv_cache_per_layer = args.max_batch_size * args.max_seq_len * (args.kv_lora_rank + args.qk_rope_head_dim)
    total_kv_cache = kv_cache_per_layer * args.n_layers * bytes_per_param_bf16 / (1024**3)
    print(f" KV Cache (BF16):")
    print(f" Per layer: {kv_cache_per_layer * bytes_per_param_bf16 / (1024**3):>10.3f} GB")
    print(f" Total ({args.n_layers} layers): {total_kv_cache:>10.3f} GB\n")

    # Activations: rough estimate — batch * seq * dim per layer at 4 bytes
    # (presumably FP32 activations; coarse by design).
    activation_memory = (args.max_batch_size * args.max_seq_len * args.dim * args.n_layers * 4) / (1024**3)
    print(f" Activations (estimate): {activation_memory:>10.3f} GB\n")

    # Training overhead: FP32 gradients (same size as weights) plus Adam's
    # momentum + variance buffers (2x weights).
    gradients_memory = weight_memory_fp32
    optimizer_states = weight_memory_fp32 * 2
    training_overhead = gradients_memory + optimizer_states
    print(f" Training Overhead (FP32):")
    print(f" Gradients: {gradients_memory:>10.3f} GB")
    print(f" Optimizer states (Adam): {optimizer_states:>10.3f} GB")
    print(f" Total overhead: {training_overhead:>10.3f} GB\n")

    # Total estimates
    inference_total = weight_memory_bf16 + total_kv_cache + activation_memory
    training_total = weight_memory_fp32 + total_kv_cache + activation_memory + training_overhead
    print(f"{'='*70}")
    print(f" INFERENCE (BF16): {inference_total:>10.3f} GB")
    print(f" TRAINING (FP32 + Adam): {training_total:>10.3f} GB")
    print(f"{'='*70}\n")

    # Which common GPU sizes fit each workload.
    print(f"{'='*70}")
    print(f"🎯 MEMORY ANALYSIS:")
    print(f"{'='*70}\n")
    _report_gpu_fit("Inference", inference_total)
    _report_gpu_fit("Training", training_total)
    print(f"\n{'='*70}\n")

    return {
        'total_params': total_params,
        'weight_memory_gb': weight_memory_bf16,
        'inference_memory_gb': inference_total,
        'training_memory_gb': training_total
    }
if __name__ == "__main__":
    import json

    # Prefer a config.json sitting next to this script; fall back to the
    # defaults baked into ModelArgs. (Path is already imported at module
    # top, so no re-import is needed here.)
    config_path = Path(__file__).parent / "config.json"
    if config_path.exists():
        print(f"📄 Loading configuration from {config_path}")
        # Explicit encoding: JSON is UTF-8 by spec, and the platform default
        # may differ (e.g. cp1252 on Windows).
        with open(config_path, encoding="utf-8") as f:
            config = json.load(f)
        args = ModelArgs(**config["model"])
    else:
        print("⚠️ config.json not found, using default ModelArgs")
        args = ModelArgs()

    # Run estimation
    results = estimate_model_size(args)