"""Model utility functions.""" from __future__ import annotations import math from typing import TYPE_CHECKING from llm_lab.config import ModelConfig if TYPE_CHECKING: from .llm_model import LLMModel def count_parameters_detailed(model: "LLMModel") -> dict: """Print a detailed breakdown of the model's parameter count by component.""" total = 0 breakdown = {} # Embedding emb_params = model.token_embedding.weight.numel() breakdown["token_embedding"] = emb_params total += emb_params # Per layer layer_total = 0 layer_detail = {} layer = model.layers[0] for name, param in layer.named_parameters(): layer_detail[name] = param.numel() layer_total += param.numel() breakdown["per_layer"] = layer_detail breakdown["per_layer_total"] = layer_total breakdown["all_layers_total"] = layer_total * len(model.layers) total += layer_total * len(model.layers) # Final norm norm_params = model.final_norm.weight.numel() breakdown["final_norm"] = norm_params total += norm_params # LM head (weight tying, so 0 additional parameters) breakdown["lm_head"] = "weight tying (0 additional)" breakdown["total"] = total return breakdown def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict: """Estimate GPU memory usage of the model. Args: dtype_bytes: 2 (bf16/fp16) or 4 (fp32) """ # Approximate parameter count emb = config.vocab_size * config.hidden_dim per_layer = ( config.hidden_dim * (config.num_heads + 2 * config.num_kv_heads) * config.head_dim # QKV + config.num_heads * config.head_dim * config.hidden_dim # O proj + 3 * config.hidden_dim * config.intermediate_dim # SwiGLU (gate + up + down) + 2 * config.hidden_dim # 2 × RMSNorm ) total_params = emb + per_layer * config.num_layers + config.hidden_dim model_gb = total_params * dtype_bytes / 1e9 optimizer_gb = total_params * 8 / 1e9 # AdamW: 2 states × fp32 gradient_gb = total_params * dtype_bytes / 1e9 # Activation memory (assuming activation checkpointing is applied) # Rough estimate: batch_size × seq_len × hidden_dim × num_layers × factor activation_gb = ( batch_size * config.max_seq_len * config.hidden_dim * 4 # bytes * math.sqrt(config.num_layers) # effect of checkpointing / 1e9 ) return { "total_parameters": total_params, "model_weights_gb": round(model_gb, 2), "optimizer_states_gb": round(optimizer_gb, 2), "gradients_gb": round(gradient_gb, 2), "activations_estimated_gb": round(activation_gb, 2), "total_estimated_gb": round(model_gb + optimizer_gb + gradient_gb + activation_gb, 2), }