"""Model utility functions."""
from __future__ import annotations
import math
from typing import TYPE_CHECKING
from llm_lab.config import ModelConfig
if TYPE_CHECKING:
from .llm_model import LLMModel
def count_parameters_detailed(model: "LLMModel") -> dict:
    """Return a detailed breakdown of the model's parameter count by component.

    Args:
        model: Model to inspect; must expose ``token_embedding``, ``layers``
            and ``final_norm`` attributes.

    Returns:
        Mapping of component name to parameter count:
        ``token_embedding``, ``per_layer`` (per-parameter counts for one
        transformer layer), ``per_layer_total``, ``all_layers_total``,
        ``final_norm``, ``lm_head`` (a descriptive string — weight tying adds
        no parameters), and ``total``.
    """
    breakdown: dict = {}

    # Token embedding table.
    emb_params = model.token_embedding.weight.numel()
    breakdown["token_embedding"] = emb_params

    # Per-layer counts: only the first layer is inspected, then scaled by the
    # layer count (assumes all layers are structurally identical).
    layer_detail = {
        name: param.numel() for name, param in model.layers[0].named_parameters()
    }
    layer_total = sum(layer_detail.values())
    breakdown["per_layer"] = layer_detail
    breakdown["per_layer_total"] = layer_total
    breakdown["all_layers_total"] = layer_total * len(model.layers)

    # Final norm.
    norm_params = model.final_norm.weight.numel()
    breakdown["final_norm"] = norm_params

    # LM head shares the embedding weights (weight tying), so it contributes
    # no additional parameters.
    breakdown["lm_head"] = "weight tying (0 additional)"

    breakdown["total"] = emb_params + layer_total * len(model.layers) + norm_params
    return breakdown
def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
    """Estimate GPU memory usage of the model in gigabytes.

    Args:
        config: Model configuration providing vocab/hidden/layer dimensions.
        batch_size: Training batch size used for the activation estimate.
        dtype_bytes: Bytes per weight element — 2 (bf16/fp16) or 4 (fp32).

    Returns:
        Dict with the total parameter count plus per-component GB estimates
        (weights, optimizer states, gradients, activations) and their sum,
        each rounded to two decimals.
    """
    # --- approximate parameter count ---
    embedding_params = config.vocab_size * config.hidden_dim
    qkv_proj = config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim
    out_proj = config.num_heads * config.head_dim * config.hidden_dim
    swiglu = 3 * config.hidden_dim * config.intermediate_dim  # gate + up + down
    norms = 2 * config.hidden_dim  # two RMSNorms per layer
    params_per_layer = qkv_proj + out_proj + swiglu + norms
    # Trailing hidden_dim term accounts for the final norm.
    total_params = embedding_params + params_per_layer * config.num_layers + config.hidden_dim

    # --- memory components ---
    weights_gb = total_params * dtype_bytes / 1e9
    # AdamW keeps two fp32 states (momentum + variance) per parameter.
    optimizer_gb = total_params * 8 / 1e9
    gradients_gb = total_params * dtype_bytes / 1e9
    # Rough activation estimate with checkpointing: sqrt(num_layers) scaling.
    # NOTE(review): activations are costed at a fixed 4 bytes/element here,
    # independent of dtype_bytes — presumably assuming fp32 activations; confirm.
    activations_gb = (
        batch_size * config.max_seq_len * config.hidden_dim * 4
        * math.sqrt(config.num_layers)
        / 1e9
    )

    return {
        "total_parameters": total_params,
        "model_weights_gb": round(weights_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "gradients_gb": round(gradients_gb, 2),
        "activations_estimated_gb": round(activations_gb, 2),
        "total_estimated_gb": round(weights_gb + optimizer_gb + gradients_gb + activations_gb, 2),
    }
|