DotCache-Arena / dotcache /planner.py
DeanoCalver's picture
Initial DotCache Arena Space upload
751ad26 verified
Raw
History Blame Contribute Delete
12.4 kB
from __future__ import annotations
from dataclasses import dataclass
from typing import Literal, Sequence
import numpy as np
ModeName = Literal["M0", "M1", "M2", "M3", "M4", "T3"]
QuantSchemeName = Literal["affine", "symmetric", "lut", "sketch", "project", "turbo3"]
SensitivityTier = Literal["exact", "strict", "balanced", "aggressive"]
@dataclass(frozen=True, slots=True)
class PageStats:
token_count: int
rms: float
abs_max: float
outlier_fraction: float
channel_range_mean: float
@dataclass(frozen=True, slots=True)
class PageModeSpec:
mode: ModeName
bits: int
quant_scheme: QuantSchemeName
escape_dtype: str | None = None
policy_id: str = "exact_baseline"
sensitivity_tier: SensitivityTier = "exact"
fallback_reason: str = ""
age_bucket: str = "aged"
@dataclass(frozen=True, slots=True)
class LayerPolicy:
policy_id: str
sensitivity_tier: SensitivityTier
kind: str
candidates: tuple[PageModeSpec, ...]
recent_candidate: PageModeSpec | None = None
recent_window: int = 128
outlier_fraction_threshold: float = 0.05
abs_max_threshold: float = 6.0
channel_range_threshold: float = 4.0
def observe_page(tensor_slice: np.ndarray) -> PageStats:
values = np.asarray(tensor_slice, dtype=np.float32)
if values.ndim != 2:
raise ValueError("tensor_slice must have shape [token_count, head_dim]")
if values.shape[0] == 0:
return PageStats(
token_count=0,
rms=0.0,
abs_max=0.0,
outlier_fraction=0.0,
channel_range_mean=0.0,
)
abs_values = np.abs(values, dtype=np.float32)
rms = float(np.sqrt(np.mean(np.square(values, dtype=np.float32), dtype=np.float64)))
abs_max = float(np.max(abs_values))
outlier_cutoff = max(3.0 * max(rms, 1e-6), 6.0)
outlier_fraction = float(np.mean(abs_values >= outlier_cutoff, dtype=np.float64))
channel_range_mean = float(np.mean(np.max(values, axis=0) - np.min(values, axis=0), dtype=np.float64))
return PageStats(
token_count=int(values.shape[0]),
rms=rms,
abs_max=abs_max,
outlier_fraction=outlier_fraction,
channel_range_mean=channel_range_mean,
)
def choose_page_mode(
layer: int,
kind: str,
token_age: int,
page_stats: PageStats | None,
*,
layer_policy: LayerPolicy,
) -> PageModeSpec:
del layer
if token_age < int(layer_policy.recent_window):
recent_candidate = layer_policy.recent_candidate
if recent_candidate is not None:
return PageModeSpec(
mode=recent_candidate.mode,
bits=recent_candidate.bits,
quant_scheme=recent_candidate.quant_scheme,
escape_dtype=recent_candidate.escape_dtype,
policy_id=layer_policy.policy_id,
sensitivity_tier=layer_policy.sensitivity_tier,
fallback_reason="recent_window",
age_bucket="recent",
)
return PageModeSpec(
mode="M3",
bits=layer_policy.candidates[-1].bits if layer_policy.candidates else 4,
quant_scheme=layer_policy.candidates[-1].quant_scheme if layer_policy.candidates else "affine",
escape_dtype=None,
policy_id=layer_policy.policy_id,
sensitivity_tier=layer_policy.sensitivity_tier,
fallback_reason="recent_window",
age_bucket="recent",
)
if not layer_policy.candidates:
fallback_mode: ModeName = "M0" if kind == "K" else "M0"
return PageModeSpec(
mode=fallback_mode,
bits=4,
quant_scheme="affine",
policy_id=layer_policy.policy_id,
sensitivity_tier=layer_policy.sensitivity_tier,
fallback_reason="no_candidates",
age_bucket="aged",
)
stats = page_stats
failure_reasons: list[str] = []
for index, candidate in enumerate(layer_policy.candidates):
if _candidate_is_allowed(candidate, kind=kind, stats=stats, policy=layer_policy):
fallback_reason = "" if index == 0 else "+".join(failure_reasons) or "fallback"
return PageModeSpec(
mode=candidate.mode,
bits=candidate.bits,
quant_scheme=candidate.quant_scheme,
escape_dtype=candidate.escape_dtype,
policy_id=layer_policy.policy_id,
sensitivity_tier=layer_policy.sensitivity_tier,
fallback_reason=fallback_reason,
age_bucket="aged",
)
failure_reasons.append(f"{candidate.mode.lower()}_stats")
safest = layer_policy.candidates[-1]
return PageModeSpec(
mode=safest.mode,
bits=safest.bits,
quant_scheme=safest.quant_scheme,
escape_dtype=safest.escape_dtype,
policy_id=layer_policy.policy_id,
sensitivity_tier=layer_policy.sensitivity_tier,
fallback_reason="+".join(failure_reasons) if failure_reasons else "threshold_fallback",
age_bucket="aged",
)
def choose_mode(
layer: int,
head: int,
token_age: int,
stats: dict[str, float | bool] | None = None,
*,
recent_window: int = 128,
error_threshold: float | None = None,
) -> str:
del layer
del head
if token_age < recent_window:
return "M3"
if stats is None:
return "M0"
if bool(stats.get("force_escape", False)):
return "M3"
quant_error = float(stats.get("quant_error", 0.0))
if error_threshold is not None and quant_error > error_threshold:
return "M3"
return "M0"
def parse_page_mode_token(token: str) -> PageModeSpec:
parts = [part.strip() for part in token.split("/") if part.strip()]
if len(parts) not in (3, 4):
raise ValueError("page mode tokens must use MODE/SCHEME/BITS[/ESCAPE_DTYPE], for example M0/affine/4 or M3/affine/4/int8")
mode_text, scheme_text, bits_text = parts[:3]
mode = mode_text.upper()
if mode not in {"M0", "M1", "M2", "M3", "M4", "T3"}:
raise ValueError(f"unsupported page mode: {mode_text}")
quant_scheme = scheme_text.lower()
if quant_scheme not in {"affine", "symmetric", "lut", "sketch", "project", "turbo3"}:
raise ValueError(f"unsupported quant scheme: {scheme_text}")
bits = int(bits_text)
escape_dtype = None
if len(parts) == 4:
escape_dtype = parts[3].lower()
if escape_dtype not in {"float16", "float32", "int8"}:
raise ValueError(f"unsupported escape dtype: {parts[3]}")
if mode != "M3":
raise ValueError("escape dtype qualifiers are only supported for M3 page modes")
return PageModeSpec(mode=mode, bits=bits, quant_scheme=quant_scheme, escape_dtype=escape_dtype)
def make_tier_candidates(
*,
kind: str,
sensitivity_tier: SensitivityTier,
default_bits: int,
default_quant_scheme: str,
default_mode: str,
recent_window: int,
recent_escape_dtype: str = "float16",
prefer_project_key_mode: bool = False,
) -> LayerPolicy:
def candidate(mode: ModeName, scheme: QuantSchemeName, bits: int) -> PageModeSpec:
return PageModeSpec(mode=mode, quant_scheme=scheme, bits=bits, sensitivity_tier=sensitivity_tier)
exact_mode = candidate(
default_mode if default_mode in {"M0", "M1", "M2", "M3", "M4", "T3"} else "M0",
default_quant_scheme if default_quant_scheme in {"affine", "symmetric", "lut", "sketch", "project", "turbo3"} else "affine",
default_bits,
)
if sensitivity_tier == "exact":
candidates = (exact_mode,)
thresholds = (0.0, float("inf"), float("inf"))
elif kind == "K":
if sensitivity_tier == "strict":
candidates = (candidate("M0", "affine", 4),)
thresholds = (0.02, 4.5, 3.0)
elif sensitivity_tier == "aggressive":
candidates = (
candidate("M0", "affine", 2),
candidate("M4", "project", 4) if prefer_project_key_mode else candidate("M2", "sketch", 4),
candidate("M0", "affine", 4),
)
thresholds = (0.10, 8.0, 5.5)
else:
candidates = (
candidate("M0", "affine", 2),
candidate("M4", "project", 4) if prefer_project_key_mode else candidate("M2", "sketch", 4),
candidate("M0", "affine", 4),
)
thresholds = (0.05, 6.0, 4.0)
else:
if sensitivity_tier == "strict":
candidates = (
candidate("M1", "lut", 4),
candidate("M0", "affine", 4),
)
thresholds = (0.02, 4.5, 3.0)
elif sensitivity_tier == "aggressive":
candidates = (
candidate("M0", "affine", 2),
candidate("M0", "affine", 3),
candidate("M1", "lut", 4),
candidate("M0", "affine", 4),
)
thresholds = (0.10, 8.0, 5.5)
else:
candidates = (
candidate("M0", "affine", 3),
candidate("M1", "lut", 4),
candidate("M0", "affine", 4),
)
thresholds = (0.05, 6.0, 4.0)
return LayerPolicy(
policy_id=f"{kind.lower()}_{sensitivity_tier}",
sensitivity_tier=sensitivity_tier,
kind=kind,
candidates=candidates,
recent_candidate=PageModeSpec(
mode="M3",
bits=default_bits,
quant_scheme="affine",
escape_dtype=recent_escape_dtype,
sensitivity_tier=sensitivity_tier,
),
recent_window=0 if sensitivity_tier == "exact" else recent_window,
outlier_fraction_threshold=float(thresholds[0]),
abs_max_threshold=float(thresholds[1]),
channel_range_threshold=float(thresholds[2]),
)
def make_explicit_policy(
*,
kind: str,
policy_id: str,
sensitivity_tier: SensitivityTier,
candidates: Sequence[PageModeSpec],
recent_window: int,
recent_escape_dtype: str = "float16",
) -> LayerPolicy:
if not candidates:
raise ValueError("explicit policies must provide at least one candidate")
recent_candidate = next((candidate for candidate in candidates if candidate.mode == "M3"), None)
if recent_candidate is None:
recent_candidate = PageModeSpec(
mode="M3",
bits=4,
quant_scheme="affine",
escape_dtype=recent_escape_dtype,
sensitivity_tier=sensitivity_tier,
)
return LayerPolicy(
policy_id=policy_id,
sensitivity_tier=sensitivity_tier,
kind=kind,
candidates=tuple(candidates),
recent_candidate=recent_candidate,
recent_window=recent_window,
)
def _candidate_is_allowed(
candidate: PageModeSpec,
*,
kind: str,
stats: PageStats | None,
policy: LayerPolicy,
) -> bool:
if stats is None or candidate.mode == "M0" and candidate.bits >= 4:
return True
if candidate.mode == "M3":
return True
if candidate.mode == "M2":
if kind != "K":
return False
return (
stats.outlier_fraction <= policy.outlier_fraction_threshold
and stats.channel_range_mean <= policy.channel_range_threshold
)
if candidate.mode == "M4":
if kind != "K":
return False
return (
stats.outlier_fraction <= policy.outlier_fraction_threshold
and stats.channel_range_mean <= policy.channel_range_threshold
)
if candidate.mode == "M1":
if kind != "V":
return False
return (
stats.outlier_fraction <= policy.outlier_fraction_threshold
and stats.channel_range_mean <= policy.channel_range_threshold
)
if candidate.mode == "T3":
return False
if candidate.mode == "M0" and candidate.bits <= 2:
return (
stats.outlier_fraction <= policy.outlier_fraction_threshold
and stats.abs_max <= policy.abs_max_threshold
)
if candidate.mode == "M0" and candidate.bits == 3:
return (
stats.outlier_fraction <= (policy.outlier_fraction_threshold * 1.25)
and stats.abs_max <= (policy.abs_max_threshold * 1.2)
)
return True