"""
Sovereign Hive — Greedy resource allocator (v2)
================================================

What changed in v2:
- The allocator no longer hardcodes "weights" as the cost dimension. It now
  works over any (cost_per_unit, unit_count) pair, so KV-cache options
  (cost = bytes_per_kv_token × max_seq_len) flow through the same algorithm
  as weight options (cost = bytes_per_param × param_count).
- Existing LayerOption / LayerCandidate / assign_bit_widths names are
  preserved as thin aliases over the generic core, so call sites that
  haven't been ported yet keep working unchanged.
- assign_combined() runs two independent allocations (one per budget) and
  returns a CombinedAssignmentResult. The weight budget and the KV budget
  are NOT fungible: saved weight bytes can't pay for KV bytes, because
  the two pools live in physically different VRAM regions at inference.
  The right interface is "two budgets, both must fit," not one combined
  pot.
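
A minimal sketch of the (cost_per_unit, unit_count) pairing described above.
The helper name and all figures are hypothetical, chosen only to show that
weight sites and KV sites reduce to the same multiplication:

```python
def site_bytes(cost_per_unit: float, unit_count: int) -> float:
    # Total bytes for one allocation site if this option is chosen.
    return cost_per_unit * unit_count

# Weight site (assumed): 0.5 bytes per parameter, 1,000,000 parameters.
weight_bytes = site_bytes(0.5, 1_000_000)
# KV site (assumed): 1024 bytes per cached token, sized for 8192 tokens.
kv_bytes = site_bytes(1024.0, 8192)
```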

Why two-budgets-not-one:
  Weight VRAM is static across the run. KV VRAM scales with context length
  at inference. You commit to a max ctx upfront (e.g. 4K, 8K), size the
  KV reserve for that, and the weights get what's left. Letting the
  allocator decide to spend "saved weight bytes" on extra KV precision is
  unsafe: it produces a config that fits at low ctx but OOMs at high ctx.
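
A small worked example of that failure mode, as a sketch only. Every number
below is an assumption picked to make the arithmetic easy to follow; none of
them comes from a real model or device:

```python
# All figures are hypothetical and only illustrate the scaling.
VRAM_BYTES = 5.0e9               # assumed device capacity
WEIGHT_BYTES = 4.0e9             # static for the whole run
BYTES_PER_KV_TOKEN = 131_072.0   # assumed, summed over all layers

def fits(ctx_len: int) -> bool:
    # KV bytes grow linearly with context length; weight bytes do not.
    return WEIGHT_BYTES + BYTES_PER_KV_TOKEN * ctx_len <= VRAM_BYTES

fits_4k = fits(4096)   # about 4.54e9 bytes in total
fits_8k = fits(8192)   # about 5.07e9 bytes in total
```

With these numbers the same configuration fits at a 4K context and exceeds
capacity at 8K, which is why the KV reserve is sized for the maximum context
before the weights are allocated.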

Algorithm (unchanged in spirit, generalized in code):
  1. Start: every candidate at its cheapest option.
  2. While budget allows: globally pick the (candidate, upgrade) pair
     with the highest drift-reduction-per-extra-byte; apply.
  3. Stop: no upgrade fits or no upgrade reduces drift.
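
The three steps can be sketched in a standalone form as follows. This is an
illustration, not the module's implementation: the function name `greedy` and
the plain (total_bytes, drift) tuples are assumptions made to keep the sketch
self-contained.

```python
def greedy(cands, budget):
    # cands: one list of (total_bytes, drift) options per candidate.
    # Step 1: start every candidate at its cheapest option.
    chosen = [min(opts) for opts in cands]
    used = sum(b for b, _ in chosen)
    if used > budget:
        raise ValueError("cheapest assignment exceeds budget")
    while True:
        best = None  # (ratio, candidate_index, option)
        for i, opts in enumerate(cands):
            cur_b, cur_d = chosen[i]
            for b, d in opts:
                extra = b - cur_b
                # Skip options that cost no more, do not reduce drift,
                # or do not fit in the remaining budget.
                if extra <= 0 or d >= cur_d or used + extra > budget:
                    continue
                ratio = (cur_d - d) / extra
                if best is None or ratio > best[0]:
                    best = (ratio, i, (b, d))
        # Step 3: stop when no upgrade fits or reduces drift.
        if best is None:
            return chosen, used
        # Step 2: apply the globally best upgrade.
        _, i, opt = best
        used += opt[0] - chosen[i][0]
        chosen[i] = opt

cands = [[(100, 1.0), (200, 0.4)], [(100, 0.9), (300, 0.1)]]
chosen, used = greedy(cands, budget=400)
```

In this toy input the first candidate's upgrade wins on ratio, and the second
candidate's upgrade then no longer fits, so the loop stops with headroom left.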

Complexity unchanged: O(C × O) per pass, converges in ≤ C × (O-1) passes
(O(C^2 × O^2) worst case overall), where C = number of candidates and
O = options per candidate. Milliseconds.
"""

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Literal

# ---------------------------------------------------------------------------
# Generic option / candidate types
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class GenericOption:
    """One option for one candidate.

    cost_per_unit × unit_count = total bytes if chosen.
    drift is the measured quality cost (lower is better).
    label carries caller-defined identification used to reconstruct the
    original option (e.g. ('hqq', 4) for weights; ('hqq_g64', 4, 4) for K/V split).
    """
    cost_per_unit: float
    drift: float
    label: tuple = ()      # caller-defined identification of the option


@dataclass
class GenericCandidate:
    """A single allocation site (e.g. one weight layer or one KV layer)."""
    candidate_id: tuple        # (layer_idx, component) — caller-defined
    unit_count: int            # params for weights, max_seq_len for KV, etc.
    options: list[GenericOption]

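    def cheapest(self) -> GenericOption:
        # Break cost ties by drift: the allocator never upgrades to an
        # equal-cost option, so the lower-drift one must win here.
        return min(self.options, key=lambda o: (o.cost_per_unit, o.drift))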


@dataclass
class GenericAssignment:
    candidate_id: tuple
    chosen: GenericOption
    bytes_used: float


@dataclass
class GenericAssignmentResult:
    assignments: list[GenericAssignment]
    total_drift: float
    total_bytes: float
    budget_bytes: float
    saturated: bool

    @property
    def total_gb(self) -> float:
        return self.total_bytes / 1e9

    @property
    def budget_gb(self) -> float:
        return self.budget_bytes / 1e9

    @property
    def headroom_gb(self) -> float:
        return (self.budget_bytes - self.total_bytes) / 1e9


class BudgetInfeasibleError(Exception):
    def __init__(self, current_bytes: float, budget_bytes: float, label: str = "budget"):
        super().__init__(
            f"Even the cheapest assignment ({current_bytes / 1e9:.2f} GB) exceeds "
            f"the {label} ({budget_bytes / 1e9:.2f} GB). Reduce the candidate count, "
            "add a cheaper (more aggressive) option, or relax the budget."
        )
        self.current_bytes = current_bytes
        self.budget_bytes = budget_bytes


# ---------------------------------------------------------------------------
# Core algorithm (generic)
# ---------------------------------------------------------------------------


def assign_greedy(
    candidates: list[GenericCandidate],
    budget_bytes: float,
    *,
    budget_label: str = "budget",
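) -> GenericAssignmentResult:
    """Greedy allocation by drift-reduction-per-byte ratio.

    Raises BudgetInfeasibleError if even the cheapest assignment overshoots.
    The result's `saturated` flag is True when drift-reducing upgrades remain
    but none fits in the leftover budget.
    """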
    if not candidates:
        raise ValueError("No candidates provided")
    if budget_bytes <= 0:
        raise ValueError(f"Non-positive budget: {budget_bytes}")

    # Initialize at cheapest option per candidate.
    current: dict[tuple, GenericOption] = {}
    bytes_used: dict[tuple, float] = {}
    cand_by_id: dict[tuple, GenericCandidate] = {}

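    for c in candidates:
        key = c.candidate_id
        if key in cand_by_id:
            # Duplicate ids would overwrite each other and undercount bytes.
            raise ValueError(f"Duplicate candidate_id: {key!r}")
        cheapest = c.cheapest()
        current[key] = cheapest
        bytes_used[key] = cheapest.cost_per_unit * c.unit_count
        cand_by_id[key] = c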

    total_bytes = sum(bytes_used.values())
    if total_bytes > budget_bytes:
        raise BudgetInfeasibleError(total_bytes, budget_bytes, budget_label)

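    def best_upgrade(key: tuple, remaining: float):
        """Return (best, exists) for this candidate.

        best is the affordable (ratio, target_option, extra_bytes) with the
        highest ratio, or None. exists is True if any drift-reducing upgrade
        exists at all, including ones that do not fit in `remaining` bytes.
        """
        cand = cand_by_id[key]
        cur = current[key]
        best = None
        exists = False
        for opt in cand.options:
            if opt.cost_per_unit <= cur.cost_per_unit:
                continue
            if opt.drift >= cur.drift:
                continue
            drift_reduction = cur.drift - opt.drift
            extra_bytes = (opt.cost_per_unit - cur.cost_per_unit) * cand.unit_count
            if extra_bytes <= 0:
                continue
            exists = True
            # Filter by budget here, so a smaller upgrade for the same
            # candidate is still considered when the best-ratio one is too big.
            if extra_bytes > remaining:
                continue
            ratio = drift_reduction / extra_bytes
            if best is None or ratio > best[0]:
                best = (ratio, opt, extra_bytes)
        return best, exists

    saturated = False
    while True:
        winner_key = None
        winner_ratio = -1.0
        winner_opt = None
        winner_extra = 0.0
        any_available = False
        remaining = budget_bytes - total_bytes

        for key in current:
            up, exists = best_upgrade(key, remaining)
            any_available = any_available or exists
            if up is None:
                continue
            ratio, target, extra = up
            if ratio > winner_ratio:
                winner_ratio = ratio
                winner_key = key
                winner_opt = target
                winner_extra = extra

        if winner_key is None:
            # Saturated: upgrades exist, but none fits the remaining budget.
            saturated = any_available
            break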

        bytes_used[winner_key] += winner_extra
        total_bytes += winner_extra
        current[winner_key] = winner_opt

    assignments = [
        GenericAssignment(
            candidate_id=key,
            chosen=current[key],
            bytes_used=bytes_used[key],
        )
        for key in sorted(current.keys())
    ]
    total_drift = sum(a.chosen.drift for a in assignments)
    return GenericAssignmentResult(
        assignments=assignments,
        total_drift=total_drift,
        total_bytes=total_bytes,
        budget_bytes=budget_bytes,
        saturated=saturated,
    )


# ---------------------------------------------------------------------------
# Combined weight + KV allocation
# ---------------------------------------------------------------------------


@dataclass
class CombinedAssignmentResult:
    """Result of running greedy allocation independently on two budgets."""
    weights: GenericAssignmentResult
    kv: GenericAssignmentResult | None  # None if no KV candidates provided

    @property
    def total_drift(self) -> float:
        kv_drift = self.kv.total_drift if self.kv else 0.0
        return self.weights.total_drift + kv_drift

    @property
    def total_gb(self) -> float:
        kv_gb = self.kv.total_gb if self.kv else 0.0
        return self.weights.total_gb + kv_gb


def assign_combined(
    weight_candidates: list[GenericCandidate],
    kv_candidates: list[GenericCandidate] | None,
    weight_budget_bytes: float,
    kv_budget_bytes: float,
) -> CombinedAssignmentResult:
    """Run two independent greedy allocations under their respective budgets.

    The budgets do NOT trade — see module docstring. Saved weight bytes
    cannot be reassigned to KV at inference because the two pools live in
    different VRAM regions and the KV pool scales with context length.
    """
    weight_result = assign_greedy(
        weight_candidates, weight_budget_bytes, budget_label="weight budget"
    )
    kv_result = None
    if kv_candidates:
        kv_result = assign_greedy(
            kv_candidates, kv_budget_bytes, budget_label="KV budget"
        )
    return CombinedAssignmentResult(weights=weight_result, kv=kv_result)


# ---------------------------------------------------------------------------
# Back-compat: existing names that callers in pipeline.py / hunter use
# ---------------------------------------------------------------------------
# These keep the v1 public surface intact. New code should use the generic
# names above. The aliases construct GenericCandidate/Option under the hood
# and translate results back into the old shapes.

Quantizer = Literal["hqq", "awq", "gptq"]
BitWidth = Literal[2, 3, 4]


@dataclass(frozen=True)
class LayerOption:
    """Weight-quantization option for one layer/component."""
    bits: BitWidth
    quantizer: Quantizer
    drift: float
    bytes_per_param: float

    def to_generic(self) -> GenericOption:
        return GenericOption(
            cost_per_unit=self.bytes_per_param,
            drift=self.drift,
            label=(self.quantizer, self.bits),
        )

    @classmethod
    def from_generic(cls, g: GenericOption) -> LayerOption:
        # label = (quantizer, bits)
        quantizer, bits = g.label
        return cls(
            bits=bits,
            quantizer=quantizer,
            drift=g.drift,
            bytes_per_param=g.cost_per_unit,
        )


@dataclass
class LayerCandidate:
    layer_idx: int
    component: str
    param_count: int
    options: list[LayerOption]

    def cheapest(self) -> LayerOption:
        # Break cost ties by drift, matching GenericCandidate.cheapest().
        return min(self.options, key=lambda o: (o.bytes_per_param, o.drift))

    def to_generic(self) -> GenericCandidate:
        return GenericCandidate(
            candidate_id=(self.layer_idx, self.component),
            unit_count=self.param_count,
            options=[o.to_generic() for o in self.options],
        )


@dataclass
class Assignment:
    layer_idx: int
    component: str
    chosen: LayerOption
    bytes_used: float


@dataclass
class AssignmentResult:
    assignments: list[Assignment]
    total_drift: float
    total_weights_gb: float
    budget_gb: float
    headroom_gb: float
    saturated: bool

    @property
    def by_layer(self) -> dict[tuple[int, str], Assignment]:
        return {(a.layer_idx, a.component): a for a in self.assignments}


def assign_bit_widths(
    candidates: list[LayerCandidate],
    weight_budget_gb: float,
) -> AssignmentResult:
    """v1 API — preserved. Delegates to the generic allocator."""
    generic_cands = [c.to_generic() for c in candidates]
    gen_result = assign_greedy(
        generic_cands,
        budget_bytes=weight_budget_gb * 1e9,
        budget_label="weight budget",
    )

    # Translate back to v1 shapes
    assignments: list[Assignment] = []
    for ga in gen_result.assignments:
        layer_idx, component = ga.candidate_id
        assignments.append(Assignment(
            layer_idx=layer_idx,
            component=component,
            chosen=LayerOption.from_generic(ga.chosen),
            bytes_used=ga.bytes_used,
        ))
    return AssignmentResult(
        assignments=assignments,
        total_drift=gen_result.total_drift,
        total_weights_gb=gen_result.total_gb,
        budget_gb=weight_budget_gb,
        headroom_gb=weight_budget_gb - gen_result.total_gb,
        saturated=gen_result.saturated,
    )


def pareto_frontier(
    candidates: list[LayerCandidate],
    budgets_gb: list[float],
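) -> list[AssignmentResult]:
    """v1 API, preserved. Runs assign_bit_widths once per budget.

    Budgets that cannot fit even the cheapest assignment are skipped, so the
    returned list may be shorter than budgets_gb.
    """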
    results: list[AssignmentResult] = []
    for b in budgets_gb:
        try:
            results.append(assign_bit_widths(candidates, b))
        except BudgetInfeasibleError:
            continue
    return results


# ---------------------------------------------------------------------------
# KV-specific convenience wrappers
# ---------------------------------------------------------------------------


@dataclass(frozen=True)
class KVOption:
    """KV-cache quantization option for one attention layer."""
    k_bits: int
    v_bits: int
    quantizer: str
    drift: float
    bytes_per_kv_token: float

    def to_generic(self) -> GenericOption:
        return GenericOption(
            cost_per_unit=self.bytes_per_kv_token,
            drift=self.drift,
            label=(self.quantizer, self.k_bits, self.v_bits),
        )

    @classmethod
    def from_generic(cls, g: GenericOption) -> KVOption:
        quantizer, k_bits, v_bits = g.label
        return cls(
            k_bits=k_bits,
            v_bits=v_bits,
            quantizer=quantizer,
            drift=g.drift,
            bytes_per_kv_token=g.cost_per_unit,
        )


@dataclass
class KVCandidate:
    layer_idx: int
    num_kv_heads: int
    head_dim: int
    options: list[KVOption]

    def to_generic(self, max_seq_len: int) -> GenericCandidate:
        # unit_count for KV is the number of tokens we're sizing the cache for.
        return GenericCandidate(
            candidate_id=(self.layer_idx, "kv"),
            unit_count=max_seq_len,
            options=[o.to_generic() for o in self.options],
        )


@dataclass
class KVAssignment:
    layer_idx: int
    chosen: KVOption
    bytes_used: float


@dataclass
class KVAssignmentResult:
    assignments: list[KVAssignment]
    total_drift: float
    total_kv_gb: float
    budget_gb: float
    headroom_gb: float
    saturated: bool
    max_seq_len: int


def assign_kv_bits(
    candidates: list[KVCandidate],
    kv_budget_gb: float,
    max_seq_len: int,
) -> KVAssignmentResult:
    """Allocate KV bit-widths across attention layers under a KV-cache budget.

    max_seq_len is the context length you're sizing the cache for. The budget
    must fit the worst case (full max_seq_len) because the cache cannot be
    re-quantized mid-generation.
    """
    generic_cands = [c.to_generic(max_seq_len) for c in candidates]
    gen_result = assign_greedy(
        generic_cands,
        budget_bytes=kv_budget_gb * 1e9,
        budget_label="KV cache budget",
    )

    assignments: list[KVAssignment] = []
    for ga in gen_result.assignments:
        layer_idx, _component = ga.candidate_id
        assignments.append(KVAssignment(
            layer_idx=layer_idx,
            chosen=KVOption.from_generic(ga.chosen),
            bytes_used=ga.bytes_used,
        ))
    return KVAssignmentResult(
        assignments=assignments,
        total_drift=gen_result.total_drift,
        total_kv_gb=gen_result.total_gb,
        budget_gb=kv_budget_gb,
        headroom_gb=kv_budget_gb - gen_result.total_gb,
        saturated=gen_result.saturated,
        max_seq_len=max_seq_len,
    )