File size: 2,819 Bytes
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
858e8b2
8a58ffe
 
 
 
 
 
 
858e8b2
8a58ffe
 
858e8b2
8a58ffe
858e8b2
8a58ffe
 
 
 
 
 
 
 
 
 
 
 
 
858e8b2
 
8a58ffe
858e8b2
 
8a58ffe
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Model utility functions."""

from __future__ import annotations

import math
from typing import TYPE_CHECKING

from llm_lab.config import ModelConfig

if TYPE_CHECKING:
    from .llm_model import LLMModel


def count_parameters_detailed(model: "LLMModel") -> dict:
    """Return a per-component breakdown of the model's parameter count.

    Nothing is printed; the breakdown is returned as a dict with these keys
    (in this order):

    - ``token_embedding``: element count of ``model.token_embedding.weight``.
    - ``per_layer``: ``{param_name: numel}`` for the first entry of
      ``model.layers`` (empty dict when the model has no layers).
    - ``per_layer_total``: sum of ``per_layer`` (0 when there are no layers).
    - ``all_layers_total``: element count summed over *every* entry of
      ``model.layers``. Summing avoids assuming all layers are identical.
    - ``final_norm``: element count of ``model.final_norm.weight``.
    - ``lm_head``: a note that the LM head is assumed to reuse the embedding
      weight, so it adds nothing to the total.
    - ``total``: ``token_embedding + all_layers_total + final_norm``.

    Args:
        model: Model exposing ``token_embedding.weight``, an iterable/sized
            ``layers`` whose items provide ``named_parameters()``, and
            ``final_norm.weight``.

    Returns:
        The breakdown dict described above.
    """
    breakdown = {}

    # Embedding
    emb_params = model.token_embedding.weight.numel()
    breakdown["token_embedding"] = emb_params

    # Representative per-parameter detail, taken from the first layer only.
    # Guarded so a model with zero layers does not raise IndexError.
    layer_detail = {}
    if len(model.layers) > 0:
        for name, param in model.layers[0].named_parameters():
            layer_detail[name] = param.numel()
    breakdown["per_layer"] = layer_detail
    breakdown["per_layer_total"] = sum(layer_detail.values())

    # Count every layer explicitly instead of multiplying the first layer's
    # size by the layer count, which is wrong for heterogeneous stacks.
    all_layers_total = sum(
        param.numel()
        for layer in model.layers
        for _, param in layer.named_parameters()
    )
    breakdown["all_layers_total"] = all_layers_total

    # Final norm
    norm_params = model.final_norm.weight.numel()
    breakdown["final_norm"] = norm_params

    # LM head (weight tying, so 0 additional parameters)
    breakdown["lm_head"] = "weight tying (0 additional)"
    breakdown["total"] = emb_params + all_layers_total + norm_params

    return breakdown


def estimate_memory_gb(config: ModelConfig, batch_size: int = 4, dtype_bytes: int = 2) -> dict:
    """Roughly estimate the GPU memory (GB, where 1 GB = 1e9 bytes) for training.

    The parameter count is derived analytically from ``config``. It covers:
    Q/K/V projections sized by ``num_heads`` and ``num_kv_heads``, an output
    projection, a three-matrix SwiGLU MLP, and two norm weight vectors per
    layer. Added on top are the token embedding and one final norm vector.
    No separate LM-head matrix is counted, so it is assumed tied to the
    embedding.

    Args:
        config: Model hyperparameters. Reads ``vocab_size``, ``hidden_dim``,
            ``num_heads``, ``num_kv_heads``, ``head_dim``, ``intermediate_dim``,
            ``num_layers`` and ``max_seq_len``.
        batch_size: Sequences per step; only affects the activation term.
        dtype_bytes: Bytes per element for weights and gradients,
            2 (bf16/fp16) or 4 (fp32).

    Returns:
        Dict with the approximate parameter count and per-category GB figures.
        Each figure is rounded to two decimals; ``total_estimated_gb`` rounds the
        sum of the unrounded categories.
    """
    hidden = config.hidden_dim
    q_heads = config.num_heads
    kv_heads = config.num_kv_heads
    head_size = config.head_dim

    # Approximate parameter count
    embedding_params = config.vocab_size * hidden
    qkv_params = hidden * (q_heads + 2 * kv_heads) * head_size
    out_proj_params = q_heads * head_size * hidden
    mlp_params = 3 * hidden * config.intermediate_dim  # gate, up and down matrices
    norm_params = 2 * hidden  # two norm weight vectors per layer
    params_per_layer = qkv_params + out_proj_params + mlp_params + norm_params
    total_params = embedding_params + params_per_layer * config.num_layers + hidden  # + final norm

    per_gb = 1e9
    weights_gb = total_params * dtype_bytes / per_gb
    optimizer_gb = total_params * 8 / per_gb  # AdamW: 2 states x fp32 (4 bytes each)
    grads_gb = total_params * dtype_bytes / per_gb

    # Heuristic activation term: 4 bytes per (token x hidden) element regardless
    # of dtype_bytes, scaled by sqrt(num_layers) to approximate activation
    # checkpointing.
    activation_elements = batch_size * config.max_seq_len * hidden * 4
    activations_gb = activation_elements * math.sqrt(config.num_layers) / per_gb

    # Explicit left-to-right addition keeps the float result stable.
    grand_total_gb = weights_gb + optimizer_gb + grads_gb + activations_gb

    return {
        "total_parameters": total_params,
        "model_weights_gb": round(weights_gb, 2),
        "optimizer_states_gb": round(optimizer_gb, 2),
        "gradients_gb": round(grads_gb, 2),
        "activations_estimated_gb": round(activations_gb, 2),
        "total_estimated_gb": round(grand_total_gb, 2),
    }