"""The memory maths — the honest heart of the tool.

We estimate how much memory a given model needs, broken into three parts:

1. weights  : the model itself = params x bits-per-weight.
2. kv_cache : the model's short-term "chat memory". It grows with how much
              text the model is holding (the context), and it is the part
              people most often forget.
3. overhead : runtime working space plus a safety margin.

Every formula is spelled out and deliberately a little pessimistic: if we
are going to be wrong, we want to be wrong on the safe side.
"""

from dataclasses import dataclass

from .catalogue import ModelClass, QuantTier

# Bytes per element in the KV cache. Modern runtimes can store it at 16-bit.
_KV_BYTES = 2

# Modern models share key/value heads (this is called "GQA"), which cuts the
# KV cache dramatically vs. older designs. ~0.30 is a conservative typical
# factor (i.e. we still assume KV is fairly chunky to stay safe).
_GQA_FACTOR = 0.30

# Flat runtime working space (program, buffers) in GB.
_RUNTIME_OVERHEAD_GB = 0.8

# Fraction of the weights reserved as working scratch space.
_SCRATCH_FRACTION = 0.10

# job_overhead_factor at or above which a job is treated as training-style
# (overhead modelled as extra copies of the weights).
_TRAINING_FACTOR_THRESHOLD = 2.0


@dataclass
class MemoryEstimate:
    """Breakdown of an estimated memory footprint, in GB."""

    weights_gb: float  # memory for the model weights
    kv_cache_gb: float  # memory for the KV cache at `context_tokens`
    overhead_gb: float  # runtime working space + safety margin
    context_tokens: int  # context length the estimate was computed for

    @property
    def total_gb(self) -> float:
        """Sum of the three parts, rounded to 2 decimal places."""
        return round(self.weights_gb + self.kv_cache_gb + self.overhead_gb, 2)


def estimate_memory(
    model: ModelClass,
    quant: QuantTier,
    *,
    context_tokens: int = 4096,
    job_overhead_factor: float = 1.0,
) -> MemoryEstimate:
    """Estimate total memory (GB) to run `model` at `quant`.

    Args:
        model: the model class; ``billions``, ``n_layers`` and ``hidden``
            are read from it.
        quant: the quantisation tier; ``gb_per_billion`` is read from it.
        context_tokens: how much text the model holds at once. 4096
            (~3000 words) is a sensible default for everyday use.
        job_overhead_factor: extra multiplier for heavier jobs (RAG, agents,
            fine-tuning) — see UseCase.overhead_factor. Values of 2.0 or
            more are treated as training-style jobs, whose overhead is
            modelled as ``factor - 1`` extra copies of the weights (never
            less than the plain scaled overhead).

    Returns:
        A MemoryEstimate whose parts are each rounded to 2 decimal places.

    Raises:
        ValueError: if ``context_tokens`` or ``job_overhead_factor`` is
            negative (either would produce a negative memory figure).
    """
    if context_tokens < 0:
        raise ValueError(f"context_tokens must be >= 0, got {context_tokens}")
    if job_overhead_factor < 0:
        raise ValueError(
            f"job_overhead_factor must be >= 0, got {job_overhead_factor}"
        )

    # 1) Weights ---------------------------------------------------------
    weights = model.billions * quant.gb_per_billion

    # 2) KV cache --------------------------------------------------------
    # bytes = 2 (K and V) x layers x hidden x tokens x bytes_per_elem x gqa
    kv_bytes = (
        2
        * model.n_layers
        * model.hidden
        * context_tokens
        * _KV_BYTES
        * _GQA_FACTOR
    )
    kv = kv_bytes / 1e9

    # 3) Overhead --------------------------------------------------------
    # A flat runtime cost, plus a fraction of the weights as working
    # scratch, all scaled by how demanding the job is.
    overhead = (
        _RUNTIME_OVERHEAD_GB + _SCRATCH_FRACTION * weights
    ) * job_overhead_factor

    # For training-style jobs (factor well above 1), optimiser state and
    # activations dwarf plain inference, so charge (factor - 1) extra
    # copies of the weights on top of the flat runtime cost. The weights
    # and KV cache figures themselves are unchanged. Keep whichever
    # overhead is larger so that crossing the threshold can never *lower*
    # the estimate (small models would otherwise get less overhead at
    # factor 2.0 than at 1.99).
    if job_overhead_factor >= _TRAINING_FACTOR_THRESHOLD:
        training_overhead = _RUNTIME_OVERHEAD_GB + weights * (
            job_overhead_factor - 1.0
        )
        overhead = max(overhead, training_overhead)

    return MemoryEstimate(
        weights_gb=round(weights, 2),
        kv_cache_gb=round(kv, 2),
        overhead_gb=round(overhead, 2),
        context_tokens=context_tokens,
    )