| """Model utility functions.""" |
|
|
| from __future__ import annotations |
|
|
| import math |
| from typing import TYPE_CHECKING |
|
|
| from llm_lab.config import ModelConfig |
|
|
| if TYPE_CHECKING: |
| from .llm_model import LLMModel |
|
|
|
|
def count_parameters_detailed(model: "LLMModel") -> dict:
    """Return a detailed breakdown of the model's parameter count by component.

    Counts the token embedding, the transformer layers (layer 0 is counted
    and scaled by the layer count, which assumes all layers are identical),
    and the final norm. The LM head is assumed to share weights with the
    token embedding ("weight tying"), so it contributes no extra parameters.

    Args:
        model: Model exposing ``token_embedding``, ``layers`` and
            ``final_norm`` attributes.

    Returns:
        Mapping from component name to parameter count, including per-layer
        detail and a ``"total"`` entry.
    """
    breakdown: dict = {}

    # Token embedding table.
    emb_params = model.token_embedding.weight.numel()
    breakdown["token_embedding"] = emb_params
    total = emb_params

    # Per-layer parameters. All layers are assumed to share the same
    # architecture, so count layer 0 once and scale by len(model.layers).
    layer_detail: dict = {}
    layer_total = 0
    if model.layers:  # guard: an empty layer stack would raise IndexError
        for name, param in model.layers[0].named_parameters():
            n = param.numel()
            layer_detail[name] = n
            layer_total += n

    breakdown["per_layer"] = layer_detail
    breakdown["per_layer_total"] = layer_total
    breakdown["all_layers_total"] = layer_total * len(model.layers)
    total += layer_total * len(model.layers)

    # Final normalization layer (weight only).
    norm_params = model.final_norm.weight.numel()
    breakdown["final_norm"] = norm_params
    total += norm_params

    # Output projection reuses the embedding weights, so it adds nothing.
    breakdown["lm_head"] = "weight tying (0 additional)"
    breakdown["total"] = total

    return breakdown
|
|
|
|
def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
    """Estimate GPU memory usage of the model.

    Args:
        config: Model architecture configuration.
        batch_size: Training batch size used for the activation estimate.
        dtype_bytes: 2 (bf16/fp16) or 4 (fp32)

    Returns:
        Dict with the total parameter count and per-component memory
        estimates in (decimal) gigabytes, rounded to two places.
    """
    d = config.hidden_dim

    # Parameter count: embedding table + N identical layers + final norm.
    embedding_params = config.vocab_size * d
    attn_qkv = d * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim
    attn_out = config.num_heads * config.head_dim * d
    # 3 projections of size d x intermediate — presumably a gated MLP
    # (gate/up/down); confirm against the model definition.
    mlp = 3 * d * config.intermediate_dim
    norms = 2 * d
    params_per_layer = attn_qkv + attn_out + mlp + norms
    total_params = embedding_params + params_per_layer * config.num_layers + d

    gb = 1e9
    model_gb = total_params * dtype_bytes / gb
    optimizer_gb = total_params * 8 / gb  # e.g. Adam: two fp32 moments per param
    gradient_gb = total_params * dtype_bytes / gb

    # Rough activation heuristic; grows only with sqrt(depth) rather than
    # linearly (consistent with activation checkpointing).
    activation_gb = batch_size * config.max_seq_len * d * 4 * math.sqrt(config.num_layers) / gb

    component_sum = model_gb + optimizer_gb + gradient_gb + activation_gb
    return {
        "total_parameters": total_params,
        "model_weights_gb": round(model_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "gradients_gb": round(gradient_gb, 2),
        "activations_estimated_gb": round(activation_gb, 2),
        "total_estimated_gb": round(component_sum, 2),
    }
|
|