File size: 3,009 Bytes
12d2e34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
The memory maths — the honest heart of the tool.

We estimate how much memory a given model needs, broken into three parts:

  1. weights      : the model itself = params x bytes-per-weight (bits / 8)
  2. kv_cache     : the model's short-term "chat memory" — grows with how much
                    text it's holding (context). This is what people forget.
  3. overhead     : runtime working space + a safety margin.

Every formula is spelled out and deliberately a little pessimistic. If we're
going to be wrong, we want to be wrong on the safe side.
"""

from dataclasses import dataclass

from .catalogue import ModelClass, QuantTier


# Storage cost of one KV-cache element, in bytes. 2 bytes == 16-bit, which
# modern runtimes can use for the cache.
_KV_BYTES: int = 2

# Grouped-query attention ("GQA") lets modern models share key/value heads,
# shrinking the KV cache a lot compared with older designs. 0.30 is a
# deliberately conservative typical reduction factor: it still assumes a
# fairly large cache, so estimates err on the safe side.
_GQA_FACTOR: float = 0.30

# Fixed runtime working space (program, buffers), in GB.
_RUNTIME_OVERHEAD_GB: float = 0.8


@dataclass
class MemoryEstimate:
    """Memory needed to run a model, split into its components (all in GB).

    The parts are kept separate so callers can show where the memory goes;
    ``total_gb`` adds them together.
    """

    # Memory taken by the model weights, in GB.
    weights_gb: float
    # Memory taken by the KV cache, in GB.
    kv_cache_gb: float
    # Runtime working space plus safety margin, in GB.
    overhead_gb: float
    # Context length (tokens) the estimate was made for.
    context_tokens: int

    @property
    def total_gb(self) -> float:
        """Return weights + KV cache + overhead, rounded to 2 decimals."""
        # Same left-to-right addition order as a single chained expression.
        subtotal = self.weights_gb + self.kv_cache_gb
        grand_total = subtotal + self.overhead_gb
        return round(grand_total, 2)


def estimate_memory(
    model: ModelClass,
    quant: QuantTier,
    *,
    context_tokens: int = 4096,
    job_overhead_factor: float = 1.0,
) -> MemoryEstimate:
    """Estimate total memory (GB) to run `model` at `quant`.

    The result is split into weights, KV cache and overhead; read the sum
    via ``MemoryEstimate.total_gb``.

    Args:
        model: the model to size. Its ``billions``, ``n_layers`` and
            ``hidden`` attributes are read.
        quant: the quantisation tier. Its ``gb_per_billion`` is read.
        context_tokens: how much text the model holds at once. 4096
            (~3000 words) is a sensible default for everyday use.
        job_overhead_factor: extra multiplier for heavier jobs (RAG, agents,
            fine-tuning) — see UseCase.overhead_factor. Values >= 2.0 are
            treated as training-style jobs.

    Returns:
        A MemoryEstimate whose components are each rounded to 2 decimals.

    Raises:
        ValueError: if ``context_tokens`` is negative or
            ``job_overhead_factor`` is not positive (either would yield a
            negative or zero memory component).
    """
    if context_tokens < 0:
        raise ValueError(f"context_tokens must be >= 0, got {context_tokens}")
    if job_overhead_factor <= 0:
        raise ValueError(
            f"job_overhead_factor must be > 0, got {job_overhead_factor}"
        )

    # 1) Weights ---------------------------------------------------------
    weights = model.billions * quant.gb_per_billion

    # 2) KV cache --------------------------------------------------------
    # bytes = 2(K and V) x layers x hidden x tokens x bytes_per_elem x gqa
    kv_bytes = (
        2 * model.n_layers * model.hidden * context_tokens
        * _KV_BYTES * _GQA_FACTOR
    )
    kv = kv_bytes / 1e9

    # 3) Overhead --------------------------------------------------------
    # A flat runtime cost, plus ~10% of the weights as working scratch,
    # all scaled by how demanding the job is.
    overhead = (_RUNTIME_OVERHEAD_GB + 0.10 * weights) * job_overhead_factor

    # For training (factor well above 1), optimiser state and activations
    # dwarf plain inference, so overhead grows with a multiple of the
    # weights. The weights and KV cache themselves are unchanged. Keep
    # whichever formula is larger: for small models the training formula
    # alone would come out *below* the scaled inference overhead, so the
    # estimate would drop as the job got heavier. That contradicts the
    # "err on the safe side" rule.
    if job_overhead_factor >= 2.0:
        training_overhead = (
            _RUNTIME_OVERHEAD_GB + weights * (job_overhead_factor - 1.0)
        )
        overhead = max(overhead, training_overhead)

    return MemoryEstimate(
        weights_gb=round(weights, 2),
        kv_cache_gb=round(kv, 2),
        overhead_gb=round(overhead, 2),
        context_tokens=context_tokens,
    )