Spaces:
Running on Zero
Running on Zero
File size: 3,009 Bytes
12d2e34 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 | """
The memory maths — the honest heart of the tool.
We estimate how much memory a given model needs, broken into three parts:
1. weights : the model itself = params x bits-per-weight
2. kv_cache : the model's short-term "chat memory" — grows with how much
text it's holding (context). This is what people forget.
3. overhead : runtime working space + a safety margin.
Every formula is spelled out and deliberately a little pessimistic. If we're
going to be wrong, we want to be wrong on the safe side.
"""
from dataclasses import dataclass
from .catalogue import ModelClass, QuantTier
# Bytes per element in the KV cache. Modern runtimes can store it at 16-bit,
# i.e. 2 bytes per element.
_KV_BYTES = 2

# Modern models share key/value heads (this is called "GQA"), which cuts the
# KV cache dramatically vs. older designs. ~0.30 is a conservative typical
# factor (i.e. we still assume KV is fairly chunky to stay safe).
# estimate_memory() multiplies its KV-cache byte count by this factor.
_GQA_FACTOR = 0.30

# Flat runtime working space (program, buffers) in GB. It is always part of
# the overhead term computed by estimate_memory(), regardless of model size.
_RUNTIME_OVERHEAD_GB = 0.8
@dataclass
class MemoryEstimate:
    """Memory estimate for one model, broken down by component.

    The three ``*_gb`` fields are sizes in GB; ``context_tokens`` records the
    context length the estimate was made for.
    """

    weights_gb: float  # the model weights themselves
    kv_cache_gb: float  # the KV cache ("chat memory") at `context_tokens`
    overhead_gb: float  # runtime working space
    context_tokens: int

    @property
    def total_gb(self) -> float:
        """Return the sum of the three components in GB, to 2 decimals."""
        # Keep the left-to-right addition order: weights, then KV, then
        # overhead, so the floating-point result is reproducible.
        combined = self.weights_gb + self.kv_cache_gb + self.overhead_gb
        return round(combined, 2)
def estimate_memory(
    model: ModelClass,
    quant: QuantTier,
    *,
    context_tokens: int = 4096,
    job_overhead_factor: float = 1.0,
) -> MemoryEstimate:
    """Estimate total memory (GB) to run `model` at `quant`.

    Args:
        model: the model to size; its `billions`, `n_layers` and `hidden`
            attributes are read.
        quant: the quantisation tier; its `gb_per_billion` attribute is read.
        context_tokens: how much text the model holds at once. 4096
            (~3000 words) is a sensible default for everyday use.
        job_overhead_factor: extra multiplier for heavier jobs (RAG, agents,
            fine-tuning) — see UseCase.overhead_factor. Values of 2.0 or
            more switch to the training overhead formula.

    Returns:
        A MemoryEstimate whose weights, KV-cache and overhead components are
        each rounded to 2 decimal places.

    Raises:
        ValueError: if `context_tokens` or `job_overhead_factor` is negative
            (either would produce a negative memory figure).
    """
    if context_tokens < 0:
        raise ValueError(f"context_tokens must be >= 0, got {context_tokens}")
    if job_overhead_factor < 0:
        raise ValueError(
            f"job_overhead_factor must be >= 0, got {job_overhead_factor}"
        )

    # Share of the weights reserved as working scratch for inference jobs.
    scratch_fraction = 0.10
    # From this factor upwards the job is treated as training.
    training_threshold = 2.0

    # 1) Weights ---------------------------------------------------------
    weights = model.billions * quant.gb_per_billion

    # 2) KV cache --------------------------------------------------------
    # bytes = 2(K and V) x layers x hidden x tokens x bytes_per_elem x gqa
    kv_bytes = (
        2 * model.n_layers * model.hidden * context_tokens
        * _KV_BYTES * _GQA_FACTOR
    )
    kv = kv_bytes / 1e9

    # 3) Overhead --------------------------------------------------------
    if job_overhead_factor >= training_threshold:
        # Training: optimiser state and activations dwarf plain inference,
        # so the overhead grows with the weights themselves. The weights and
        # KV figures are reported unchanged; only the overhead term inflates.
        overhead = _RUNTIME_OVERHEAD_GB + weights * (job_overhead_factor - 1.0)
    else:
        # Inference: a flat runtime cost, plus ~10% of the weights as working
        # scratch, all scaled by how demanding the job is.
        overhead = (
            _RUNTIME_OVERHEAD_GB + scratch_fraction * weights
        ) * job_overhead_factor

    return MemoryEstimate(
        weights_gb=round(weights, 2),
        kv_cache_gb=round(kv, 2),
        overhead_gb=round(overhead, 2),
        context_tokens=context_tokens,
    )
|