# hsaq-tools / assignment_v2.py
# Author: mxguru1
# Commit message: Add KV interception hooks + generalised allocator + smoke tests (2/3: assignment_v2.py)
# Commit: 55f5f5e (verified)
"""
Sovereign Hive — Greedy resource allocator (v2)
================================================
What changed in v2:
- The allocator no longer hardcodes "weights" as the cost dimension. It now
works over any (cost_per_unit, unit_count) pair, so KV-cache options
(cost = bytes_per_kv_token × max_seq_len) flow through the same algorithm
as weight options (cost = bytes_per_param × param_count).
- Existing LayerOption / LayerCandidate / assign_bit_widths names are
preserved as thin aliases over the generic core, so call sites that
haven't been ported yet keep working unchanged.
- assign_combined() runs two independent allocations (one per budget) and
returns a CombinedAssignmentResult. Weight budget and KV budget do NOT
fungibly trade — saving weight bytes can't pay for KV bytes — because
the two pools live in physically different VRAM regions at inference.
The right interface is "two budgets, both must fit," not one combined
pot.
Why two-budgets-not-one:
Weight VRAM is static across the run. KV VRAM scales with context length
at inference. You commit to a max ctx upfront (e.g. 4K, 8K), size the
KV reserve for that, and the weights get what's left. Letting the
allocator decide to spend "saved weight bytes" on extra KV precision is
unsafe: it produces a config that fits at low ctx but OOMs at high ctx.
Algorithm (unchanged in spirit, generalized in code):
1. Start: every candidate at its cheapest option.
2. While budget allows: globally pick the (candidate, upgrade) pair
with the highest drift-reduction-per-extra-byte; apply.
3. Stop: no upgrade fits or no upgrade reduces drift.
Complexity unchanged: O(C × O^2) per pass, converges in ≤ C × (O-1) passes,
where C = number of candidates and O = options per candidate. Milliseconds.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Literal
# ---------------------------------------------------------------------------
# Generic option / candidate types
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class GenericOption:
    """A single selectable option belonging to one candidate.

    Selecting the option costs ``cost_per_unit × unit_count`` bytes in total,
    where ``unit_count`` is supplied by the owning candidate. ``drift`` is the
    measured quality cost of the option (lower is better).
    """

    # Bytes consumed per unit of the owning candidate.
    cost_per_unit: float
    # Measured quality cost of picking this option; lower is better.
    drift: float
    # Caller-defined identification used to reconstruct the option later,
    # e.g. ('hqq', 4) for weights or ('hqq_g64', 4, 4) for a K/V split.
    label: tuple = field(default=())
@dataclass
class GenericCandidate:
    """A single allocation site (e.g. one weight layer or one KV layer)."""

    candidate_id: tuple  # (layer_idx, component) — caller-defined
    unit_count: int  # params for weights, max_seq_len for KV, etc.
    options: list[GenericOption]

    def cheapest(self) -> GenericOption:
        """Return the option with the lowest ``cost_per_unit``.

        When several options share the lowest cost, the one with the lowest
        ``drift`` is returned: it costs the same number of bytes, and an
        allocator that only upgrades to strictly costlier options would
        otherwise never reach it.

        Raises:
            ValueError: if the candidate has no options at all.
        """
        if not self.options:
            raise ValueError(f"Candidate {self.candidate_id!r} has no options")
        return min(self.options, key=lambda o: (o.cost_per_unit, o.drift))
@dataclass
class GenericAssignment:
    """The option selected for one candidate, with the bytes it consumes."""

    # Identifier of the candidate this decision belongs to.
    candidate_id: tuple
    # The option that ended up selected for that candidate.
    chosen: GenericOption
    # Total bytes charged to the candidate for the selected option.
    bytes_used: float
@dataclass
class GenericAssignmentResult:
    """Outcome of one allocation run, with gigabyte-denominated views.

    All ``*_gb`` properties use decimal gigabytes (bytes divided by 1e9).
    """

    assignments: list[GenericAssignment]  # one entry per candidate
    total_drift: float  # sum of the chosen options' drift
    total_bytes: float  # bytes consumed by all chosen options
    budget_bytes: float  # the budget the run was given
    saturated: bool  # flag reported by the allocator that built this result

    @staticmethod
    def _as_gb(num_bytes: float) -> float:
        """Convert a byte count into decimal gigabytes."""
        return num_bytes / 1e9

    @property
    def total_gb(self) -> float:
        """Bytes consumed, expressed in GB."""
        return self._as_gb(self.total_bytes)

    @property
    def budget_gb(self) -> float:
        """Budget, expressed in GB."""
        return self._as_gb(self.budget_bytes)

    @property
    def headroom_gb(self) -> float:
        """Unused part of the budget, expressed in GB."""
        unused_bytes = self.budget_bytes - self.total_bytes
        return self._as_gb(unused_bytes)
class BudgetInfeasibleError(Exception):
    """Raised when even the cheapest assignment does not fit the budget.

    Attributes:
        current_bytes: bytes required by the cheapest assignment.
        budget_bytes: the budget that was exceeded.
        label: human-readable name of the budget used in the message.
    """

    def __init__(
        self,
        current_bytes: float,
        budget_bytes: float,
        label: str = "budget",
    ):
        super().__init__(
            f"Even the cheapest assignment ({current_bytes / 1e9:.2f} GB) exceeds "
            f"the {label} ({budget_bytes / 1e9:.2f} GB). Reduce candidate count, "
            "increase aggressiveness of cheapest option, or relax the budget."
        )
        self.current_bytes = current_bytes
        self.budget_bytes = budget_bytes
        self.label = label

    def __reduce__(self):
        # Exception.args only holds the formatted message, so the default
        # pickle/copy protocol would call __init__ with a single argument and
        # fail. Rebuild from the original constructor arguments instead.
        return (type(self), (self.current_bytes, self.budget_bytes, self.label))
# ---------------------------------------------------------------------------
# Core algorithm (generic)
# ---------------------------------------------------------------------------
def assign_greedy(
    candidates: list[GenericCandidate],
    budget_bytes: float,
    *,
    budget_label: str = "budget",
) -> GenericAssignmentResult:
    """Greedy allocation by drift-reduction-per-byte ratio.

    Every candidate starts at its cheapest option. Each pass then applies the
    single upgrade, across all candidates, that fits in the remaining budget
    and has the highest drift reduction per extra byte. The loop ends when no
    drift-reducing upgrade fits any more.

    Args:
        candidates: allocation sites; each ``candidate_id`` must be unique.
        budget_bytes: total bytes available; must be positive.
        budget_label: name of the budget used in the infeasibility message.

    Returns:
        A GenericAssignmentResult with one assignment per candidate, sorted by
        ``candidate_id``. ``saturated`` is True when the loop stopped while at
        least one drift-reducing upgrade still existed but none fit.

    Raises:
        ValueError: if ``candidates`` is empty, ``budget_bytes`` is not
            positive, or two candidates share a ``candidate_id``.
        BudgetInfeasibleError: if even the cheapest assignment overshoots.
    """
    if not candidates:
        raise ValueError("No candidates provided")
    if budget_bytes <= 0:
        raise ValueError(f"Non-positive budget: {budget_bytes}")

    # Initialize at cheapest option per candidate.
    current: dict[tuple, GenericOption] = {}
    bytes_used: dict[tuple, float] = {}
    cand_by_id: dict[tuple, GenericCandidate] = {}
    for c in candidates:
        key = c.candidate_id
        if key in cand_by_id:
            # State is keyed by candidate_id; a repeated id would overwrite the
            # earlier candidate and undercount the bytes actually required.
            raise ValueError(f"Duplicate candidate_id: {key!r}")
        cheapest = c.cheapest()
        current[key] = cheapest
        bytes_used[key] = cheapest.cost_per_unit * c.unit_count
        cand_by_id[key] = c

    total_bytes = sum(bytes_used.values())
    if total_bytes > budget_bytes:
        raise BudgetInfeasibleError(total_bytes, budget_bytes, budget_label)

    def best_upgrade(key: tuple, spent_bytes: float):
        """Return ``(best, exists)`` for one candidate.

        ``best`` is the (ratio, target_option, extra_bytes) triple with the
        highest ratio among upgrades that still fit the budget, or None.
        ``exists`` tells whether any drift-reducing upgrade exists at all,
        affordable or not.
        """
        cand = cand_by_id[key]
        cur = current[key]
        best = None
        exists = False
        for opt in cand.options:
            # Only strictly costlier, strictly lower-drift options are upgrades.
            if opt.cost_per_unit <= cur.cost_per_unit or opt.drift >= cur.drift:
                continue
            extra_bytes = (opt.cost_per_unit - cur.cost_per_unit) * cand.unit_count
            if extra_bytes <= 0:
                continue
            exists = True
            # Filter by affordability *before* ranking, so a cheaper upgrade
            # is still considered when the best-ratio one does not fit.
            if spent_bytes + extra_bytes > budget_bytes:
                continue
            ratio = (cur.drift - opt.drift) / extra_bytes
            if best is None or ratio > best[0]:
                best = (ratio, opt, extra_bytes)
        return best, exists

    saturated = False
    while True:
        winner = None  # (ratio, key, target_option, extra_bytes)
        any_available = False
        for key in current:
            up, exists = best_upgrade(key, total_bytes)
            any_available = any_available or exists
            if up is None:
                continue
            ratio, target, extra = up
            # Strict comparison: on equal ratios the first candidate wins.
            if winner is None or ratio > winner[0]:
                winner = (ratio, key, target, extra)
        if winner is None:
            saturated = any_available
            break
        _, winner_key, winner_opt, winner_extra = winner
        bytes_used[winner_key] += winner_extra
        total_bytes += winner_extra
        current[winner_key] = winner_opt

    assignments = [
        GenericAssignment(
            candidate_id=key,
            chosen=current[key],
            bytes_used=bytes_used[key],
        )
        for key in sorted(current)
    ]
    total_drift = sum(a.chosen.drift for a in assignments)
    return GenericAssignmentResult(
        assignments=assignments,
        total_drift=total_drift,
        total_bytes=total_bytes,
        budget_bytes=budget_bytes,
        saturated=saturated,
    )
# ---------------------------------------------------------------------------
# Combined weight + KV allocation
# ---------------------------------------------------------------------------
@dataclass
class CombinedAssignmentResult:
    """Pair of allocation results: one for weights, optionally one for KV."""

    weights: GenericAssignmentResult
    kv: GenericAssignmentResult | None  # None if no KV candidates provided

    @property
    def total_drift(self) -> float:
        """Weight drift plus KV drift (KV counts as 0.0 when absent)."""
        combined = self.weights.total_drift
        combined += self.kv.total_drift if self.kv else 0.0
        return combined

    @property
    def total_gb(self) -> float:
        """Weight GB plus KV GB (KV counts as 0.0 when absent)."""
        combined = self.weights.total_gb
        combined += self.kv.total_gb if self.kv else 0.0
        return combined
def assign_combined(
    weight_candidates: list[GenericCandidate],
    kv_candidates: list[GenericCandidate] | None,
    weight_budget_bytes: float,
    kv_budget_bytes: float,
) -> CombinedAssignmentResult:
    """Allocate weights and KV separately, each under its own budget.

    The weight allocation always runs first. The KV allocation runs only when
    ``kv_candidates`` is non-empty; otherwise the result's ``kv`` is None.
    Bytes left over in one budget are never handed to the other — the module
    docstring explains why the two pools must not trade.
    """
    weights_alloc = assign_greedy(
        weight_candidates,
        weight_budget_bytes,
        budget_label="weight budget",
    )
    if not kv_candidates:
        return CombinedAssignmentResult(weights=weights_alloc, kv=None)

    kv_alloc = assign_greedy(
        kv_candidates,
        kv_budget_bytes,
        budget_label="KV budget",
    )
    return CombinedAssignmentResult(weights=weights_alloc, kv=kv_alloc)
# ---------------------------------------------------------------------------
# Back-compat: existing names that callers in pipeline.py / hunter use
# ---------------------------------------------------------------------------
# These keep the v1 public surface intact. New code should use the generic
# names above. The aliases construct GenericCandidate/Option under the hood
# and translate results back into the old shapes.
# Quantizer names used in the v1 weight-option type hints.
Quantizer = Literal[
    "hqq",
    "awq",
    "gptq",
]
# Bit-widths used in the v1 weight-option type hints.
BitWidth = Literal[
    2,
    3,
    4,
]
@dataclass(frozen=True)
class LayerOption:
    """One weight-quantization choice for a layer/component (v1 shape)."""

    bits: BitWidth
    quantizer: Quantizer
    drift: float
    bytes_per_param: float

    def to_generic(self) -> GenericOption:
        """Convert to a GenericOption labelled ``(quantizer, bits)``."""
        tag = (self.quantizer, self.bits)
        return GenericOption(
            cost_per_unit=self.bytes_per_param,
            drift=self.drift,
            label=tag,
        )

    @classmethod
    def from_generic(cls, g: GenericOption) -> LayerOption:
        """Rebuild a LayerOption from a GenericOption made by ``to_generic``.

        ``g.label`` must unpack into exactly ``(quantizer, bits)``.
        """
        quantizer_name, bit_width = g.label
        return cls(
            bits=bit_width,
            quantizer=quantizer_name,
            drift=g.drift,
            bytes_per_param=g.cost_per_unit,
        )
@dataclass
class LayerCandidate:
    """One weight layer/component and its quantization options (v1 shape)."""

    layer_idx: int
    component: str
    param_count: int
    options: list[LayerOption]

    def cheapest(self) -> LayerOption:
        """Return the option with the lowest ``bytes_per_param``.

        Ties on cost are broken by the lowest ``drift``, since an equal-cost
        option with less drift is strictly better.

        Raises:
            ValueError: if the candidate has no options at all.
        """
        if not self.options:
            raise ValueError(
                f"Candidate {(self.layer_idx, self.component)!r} has no options"
            )
        return min(self.options, key=lambda o: (o.bytes_per_param, o.drift))

    def to_generic(self) -> GenericCandidate:
        """Convert to a GenericCandidate keyed by ``(layer_idx, component)``.

        ``param_count`` becomes the generic ``unit_count``.
        """
        return GenericCandidate(
            candidate_id=(self.layer_idx, self.component),
            unit_count=self.param_count,
            options=[o.to_generic() for o in self.options],
        )
@dataclass
class Assignment:
    """The weight option selected for one layer/component (v1 shape)."""

    # Layer index and component name identify the allocation site.
    layer_idx: int
    component: str
    # The weight-quantization option that was selected.
    chosen: LayerOption
    # Total bytes charged for the selected option.
    bytes_used: float
@dataclass
class AssignmentResult:
    """Weight allocation result in the v1 shape (GB-denominated)."""

    assignments: list[Assignment]
    total_drift: float
    total_weights_gb: float
    budget_gb: float
    headroom_gb: float
    saturated: bool

    @property
    def by_layer(self) -> dict[tuple[int, str], Assignment]:
        """Index the assignments by ``(layer_idx, component)``.

        If two assignments share a key, the later one wins.
        """
        index: dict[tuple[int, str], Assignment] = {}
        for entry in self.assignments:
            index[(entry.layer_idx, entry.component)] = entry
        return index
def assign_bit_widths(
    candidates: list[LayerCandidate],
    weight_budget_gb: float,
) -> AssignmentResult:
    """v1 API — preserved. Delegates to the generic allocator.

    The GB budget is converted to bytes (× 1e9), the candidates are converted
    to their generic form, and the generic result is translated back into the
    v1 ``Assignment`` / ``AssignmentResult`` shapes.
    """
    allocation = assign_greedy(
        [cand.to_generic() for cand in candidates],
        budget_bytes=weight_budget_gb * 1e9,
        budget_label="weight budget",
    )

    # Each candidate_id is the (layer_idx, component) pair built by
    # LayerCandidate.to_generic; unpack it back into the v1 fields.
    per_layer: list[Assignment] = [
        Assignment(
            layer_idx=layer_idx,
            component=component,
            chosen=LayerOption.from_generic(ga.chosen),
            bytes_used=ga.bytes_used,
        )
        for ga in allocation.assignments
        for layer_idx, component in (ga.candidate_id,)
    ]

    used_gb = allocation.total_gb
    return AssignmentResult(
        assignments=per_layer,
        total_drift=allocation.total_drift,
        total_weights_gb=used_gb,
        budget_gb=weight_budget_gb,
        headroom_gb=weight_budget_gb - used_gb,
        saturated=allocation.saturated,
    )
def pareto_frontier(
    candidates: list[LayerCandidate],
    budgets_gb: list[float],
) -> list[AssignmentResult]:
    """v1 API — preserved.

    Runs ``assign_bit_widths`` once per budget, in the given order, and
    collects the results. Budgets that raise BudgetInfeasibleError are
    skipped; any other exception propagates.
    """
    frontier: list[AssignmentResult] = []
    for budget_gb in budgets_gb:
        try:
            point = assign_bit_widths(candidates, budget_gb)
        except BudgetInfeasibleError:
            continue
        else:
            frontier.append(point)
    return frontier
# ---------------------------------------------------------------------------
# KV-specific convenience wrappers
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class KVOption:
    """One KV-cache quantization choice for an attention layer."""

    k_bits: int
    v_bits: int
    quantizer: str
    drift: float
    bytes_per_kv_token: float

    def to_generic(self) -> GenericOption:
        """Convert to a GenericOption labelled ``(quantizer, k_bits, v_bits)``."""
        tag = (self.quantizer, self.k_bits, self.v_bits)
        return GenericOption(
            cost_per_unit=self.bytes_per_kv_token,
            drift=self.drift,
            label=tag,
        )

    @classmethod
    def from_generic(cls, g: GenericOption) -> KVOption:
        """Rebuild a KVOption from a GenericOption made by ``to_generic``.

        ``g.label`` must unpack into exactly ``(quantizer, k_bits, v_bits)``.
        """
        quantizer_name, key_bits, value_bits = g.label
        return cls(
            k_bits=key_bits,
            v_bits=value_bits,
            quantizer=quantizer_name,
            drift=g.drift,
            bytes_per_kv_token=g.cost_per_unit,
        )
@dataclass
class KVCandidate:
    """One attention layer and its KV-cache quantization options."""

    layer_idx: int
    # Carried on the candidate but not read by to_generic().
    num_kv_heads: int
    head_dim: int
    options: list[KVOption]

    def to_generic(self, max_seq_len: int) -> GenericCandidate:
        """Convert to a GenericCandidate keyed by ``(layer_idx, "kv")``.

        For KV the generic ``unit_count`` is the number of tokens the cache is
        being sized for, i.e. ``max_seq_len``.
        """
        generic_options = [kv_opt.to_generic() for kv_opt in self.options]
        return GenericCandidate(
            candidate_id=(self.layer_idx, "kv"),
            unit_count=max_seq_len,
            options=generic_options,
        )
@dataclass
class KVAssignment:
    """The KV-cache option selected for one attention layer."""

    # Index of the attention layer the decision belongs to.
    layer_idx: int
    # The KV quantization option that was selected.
    chosen: KVOption
    # Total bytes charged for the selected option.
    bytes_used: float
@dataclass
class KVAssignmentResult:
    """KV-cache allocation result (GB-denominated)."""

    assignments: list[KVAssignment]  # one entry per attention layer
    total_drift: float  # summed drift of the chosen options
    total_kv_gb: float  # GB consumed by the chosen options
    budget_gb: float  # the KV budget the run was given, in GB
    headroom_gb: float  # budget_gb minus total_kv_gb
    saturated: bool  # flag forwarded from the generic allocator
    max_seq_len: int  # context length the cache was sized for
def assign_kv_bits(
    candidates: list[KVCandidate],
    kv_budget_gb: float,
    max_seq_len: int,
) -> KVAssignmentResult:
    """Pick KV bit-widths per attention layer under a KV-cache budget.

    ``max_seq_len`` is the context length the cache is sized for; every
    candidate is costed at that full length, so the budget has to cover the
    worst case (the cache cannot be re-quantized mid-generation). The GB
    budget is converted to bytes (× 1e9) before delegating to the generic
    allocator, and the result is translated back into KV-specific shapes.
    """
    allocation = assign_greedy(
        [kv_cand.to_generic(max_seq_len) for kv_cand in candidates],
        budget_bytes=kv_budget_gb * 1e9,
        budget_label="KV cache budget",
    )

    # Each candidate_id is the (layer_idx, "kv") pair built by
    # KVCandidate.to_generic; only the layer index is kept.
    per_layer: list[KVAssignment] = [
        KVAssignment(
            layer_idx=layer_idx,
            chosen=KVOption.from_generic(ga.chosen),
            bytes_used=ga.bytes_used,
        )
        for ga in allocation.assignments
        for layer_idx, _component in (ga.candidate_id,)
    ]

    used_gb = allocation.total_gb
    return KVAssignmentResult(
        assignments=per_layer,
        total_drift=allocation.total_drift,
        total_kv_gb=used_gb,
        budget_gb=kv_budget_gb,
        headroom_gb=kv_budget_gb - used_gb,
        saturated=allocation.saturated,
        max_seq_len=max_seq_len,
    )