Spaces:
Paused
Paused
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Literal, Sequence | |
| import numpy as np | |
# Page encoding mode identifiers; parse_page_mode_token accepts exactly these.
ModeName = Literal[
    "M0",
    "M1",
    "M2",
    "M3",
    "M4",
    "T3",
]

# Quantisation scheme names; parse_page_mode_token accepts exactly these.
QuantSchemeName = Literal[
    "affine",
    "symmetric",
    "lut",
    "sketch",
    "project",
    "turbo3",
]

# Policy sensitivity tiers that make_tier_candidates branches on.
SensitivityTier = Literal[
    "exact",
    "strict",
    "balanced",
    "aggressive",
]
@dataclass(frozen=True)
class PageStats:
    """Summary statistics for one page slice of shape ``[token_count, head_dim]``.

    Instances are built with keyword arguments (see ``observe_page``), so the
    class must be a dataclass to receive a matching ``__init__``.
    """

    # Number of rows (tokens) in the slice.
    token_count: int
    # Root-mean-square over every element of the slice.
    rms: float
    # Largest absolute element value.
    abs_max: float
    # Fraction of elements whose magnitude reaches the outlier cutoff.
    outlier_fraction: float
    # Mean over columns of (column max - column min).
    channel_range_mean: float
@dataclass(frozen=True)
class PageModeSpec:
    """Encoding choice for one page: mode, bit width, quant scheme, and provenance.

    Instances are built with keyword arguments throughout the module, so the
    class must be a dataclass to receive a matching ``__init__``.
    """

    # Page mode identifier, e.g. "M0" or "M3".
    mode: ModeName
    # Bit width used by the quant scheme.
    bits: int
    # Quantisation scheme name, e.g. "affine" or "lut".
    quant_scheme: QuantSchemeName
    # Optional dtype qualifier; parse_page_mode_token only allows it on M3.
    escape_dtype: str | None = None
    # Identifier of the policy that produced this spec.
    policy_id: str = "exact_baseline"
    # Sensitivity tier of the producing policy.
    sensitivity_tier: SensitivityTier = "exact"
    # Why a non-first choice was made ("" when the first candidate was taken).
    fallback_reason: str = ""
    # "recent" for pages inside the recent window, otherwise "aged".
    age_bucket: str = "aged"
@dataclass(frozen=True)
class LayerPolicy:
    """Ordered page-mode candidates plus the thresholds used to filter them.

    Instances are built with keyword arguments (see ``make_tier_candidates`` and
    ``make_explicit_policy``), so the class must be a dataclass.
    """

    policy_id: str
    sensitivity_tier: SensitivityTier
    # Page kind this policy targets; "K" and "V" are the values checked elsewhere.
    kind: str
    # Tried in order by choose_page_mode; the last entry is the final fallback.
    candidates: tuple[PageModeSpec, ...]
    # Spec used for pages whose token_age is below recent_window (None -> synthesised M3).
    recent_candidate: PageModeSpec | None = None
    recent_window: int = 128
    # Statistic limits consulted when deciding whether a candidate is allowed.
    outlier_fraction_threshold: float = 0.05
    abs_max_threshold: float = 6.0
    channel_range_threshold: float = 4.0
def observe_page(tensor_slice: np.ndarray) -> PageStats:
    """Compute summary statistics for one page slice.

    Args:
        tensor_slice: 2-D array-like of shape ``[token_count, head_dim]``. It is
            converted to ``float32`` before any statistic is computed.

    Returns:
        A ``PageStats`` holding the slice's RMS, absolute maximum, the fraction
        of elements whose magnitude is at least ``max(3 * rms, 6.0)``, and the
        mean per-column (max - min) range. A slice with no elements (zero
        tokens or zero columns) yields all-zero statistics, with
        ``token_count`` still set to the number of rows.

    Raises:
        ValueError: if ``tensor_slice`` is not 2-D.
    """
    values = np.asarray(tensor_slice, dtype=np.float32)
    if values.ndim != 2:
        raise ValueError("tensor_slice must have shape [token_count, head_dim]")
    token_count = int(values.shape[0])
    if values.size == 0:
        # Covers [0, d] as well as [n, 0]. numpy's max/min reductions have no
        # identity and raise on empty input, so return neutral statistics.
        return PageStats(
            token_count=token_count,
            rms=0.0,
            abs_max=0.0,
            outlier_fraction=0.0,
            channel_range_mean=0.0,
        )
    abs_values = np.abs(values, dtype=np.float32)
    # Accumulate the mean in float64 to limit rounding error on large pages.
    rms = float(np.sqrt(np.mean(np.square(values, dtype=np.float32), dtype=np.float64)))
    abs_max = float(np.max(abs_values))
    # An element counts as an outlier at >= 3x RMS, but never below an absolute floor of 6.0.
    outlier_cutoff = max(3.0 * max(rms, 1e-6), 6.0)
    outlier_fraction = float(np.mean(abs_values >= outlier_cutoff, dtype=np.float64))
    channel_range_mean = float(
        np.mean(np.max(values, axis=0) - np.min(values, axis=0), dtype=np.float64)
    )
    return PageStats(
        token_count=token_count,
        rms=rms,
        abs_max=abs_max,
        outlier_fraction=outlier_fraction,
        channel_range_mean=channel_range_mean,
    )
def _stamp_page_mode(
    source: PageModeSpec,
    layer_policy: LayerPolicy,
    *,
    fallback_reason: str,
    age_bucket: str,
) -> PageModeSpec:
    """Copy the encoding fields of ``source`` and tag them with the policy's identity."""
    return PageModeSpec(
        mode=source.mode,
        bits=source.bits,
        quant_scheme=source.quant_scheme,
        escape_dtype=source.escape_dtype,
        policy_id=layer_policy.policy_id,
        sensitivity_tier=layer_policy.sensitivity_tier,
        fallback_reason=fallback_reason,
        age_bucket=age_bucket,
    )


def choose_page_mode(
    layer: int,
    kind: str,
    token_age: int,
    page_stats: PageStats | None,
    *,
    layer_policy: LayerPolicy,
) -> PageModeSpec:
    """Select the encoding for one page according to ``layer_policy``.

    Selection order:
      1. ``token_age < layer_policy.recent_window``: the policy's
         ``recent_candidate``. If it has none, an M3 spec reuses the last
         candidate's bits and scheme (4-bit affine when there are no
         candidates). The reason is tagged "recent_window" / "recent".
      2. No candidates at all: 4-bit affine M0, reason "no_candidates".
      3. The first candidate accepted by ``_candidate_is_allowed``. When it is
         not the first candidate, ``fallback_reason`` lists the rejected modes,
         e.g. "m0_stats+m2_stats".
      4. Nothing accepted: the last candidate, with every rejection listed.

    Args:
        layer: Layer index. Currently unused; kept for interface compatibility.
        kind: Page kind, compared with "K"/"V" when filtering candidates.
        token_age: Age of the page, compared against the recent window.
        page_stats: Statistics for the page, or ``None`` if not observed.
        layer_policy: Candidates, thresholds, and identity to apply.

    Returns:
        A new ``PageModeSpec`` carrying the policy's ``policy_id`` and
        ``sensitivity_tier``.
    """
    del layer  # unused
    if token_age < int(layer_policy.recent_window):
        recent = layer_policy.recent_candidate
        if recent is not None:
            return _stamp_page_mode(
                recent, layer_policy, fallback_reason="recent_window", age_bucket="recent"
            )
        # No dedicated recent candidate: use M3 with the last candidate's encoding.
        last = layer_policy.candidates[-1] if layer_policy.candidates else None
        return PageModeSpec(
            mode="M3",
            bits=last.bits if last is not None else 4,
            quant_scheme=last.quant_scheme if last is not None else "affine",
            escape_dtype=None,
            policy_id=layer_policy.policy_id,
            sensitivity_tier=layer_policy.sensitivity_tier,
            fallback_reason="recent_window",
            age_bucket="recent",
        )
    if not layer_policy.candidates:
        return PageModeSpec(
            mode="M0",
            bits=4,
            quant_scheme="affine",
            policy_id=layer_policy.policy_id,
            sensitivity_tier=layer_policy.sensitivity_tier,
            fallback_reason="no_candidates",
            age_bucket="aged",
        )
    failure_reasons: list[str] = []
    for index, candidate in enumerate(layer_policy.candidates):
        if _candidate_is_allowed(candidate, kind=kind, stats=page_stats, policy=layer_policy):
            reason = "" if index == 0 else ("+".join(failure_reasons) or "fallback")
            return _stamp_page_mode(
                candidate, layer_policy, fallback_reason=reason, age_bucket="aged"
            )
        failure_reasons.append(f"{candidate.mode.lower()}_stats")
    # Every candidate was rejected: fall back to the last (final) entry.
    return _stamp_page_mode(
        layer_policy.candidates[-1],
        layer_policy,
        fallback_reason="+".join(failure_reasons) if failure_reasons else "threshold_fallback",
        age_bucket="aged",
    )
def choose_mode(
    layer: int,
    head: int,
    token_age: int,
    stats: dict[str, float | bool] | None = None,
    *,
    recent_window: int = 128,
    error_threshold: float | None = None,
) -> str:
    """Return a bare mode name ("M3" or "M0") for one page.

    "M3" is chosen when the page is within ``recent_window``, when
    ``stats["force_escape"]`` is truthy, or when ``stats["quant_error"]``
    exceeds ``error_threshold``. In every other case, including
    ``stats is None``, "M0" is returned. ``layer`` and ``head`` are accepted
    but ignored.
    """
    del layer, head
    if token_age < recent_window:
        return "M3"
    if stats is None:
        return "M0"
    if stats.get("force_escape", False):
        return "M3"
    # Converted unconditionally, so a malformed quant_error raises even without a threshold.
    measured_error = float(stats.get("quant_error", 0.0))
    too_lossy = error_threshold is not None and measured_error > error_threshold
    return "M3" if too_lossy else "M0"
def parse_page_mode_token(token: str) -> PageModeSpec:
    """Parse a ``MODE/SCHEME/BITS[/ESCAPE_DTYPE]`` token into a ``PageModeSpec``.

    Segments are whitespace-stripped and empty segments are skipped. MODE is
    upper-cased; SCHEME and ESCAPE_DTYPE are lower-cased.

    Examples: ``"M0/affine/4"``, ``"M3/affine/4/int8"``.

    Raises:
        ValueError: on a wrong segment count, an unknown mode, scheme or escape
            dtype, a BITS field that is not a positive integer, or an escape
            dtype attached to a non-M3 mode.
    """
    parts = [part.strip() for part in token.split("/") if part.strip()]
    if len(parts) not in (3, 4):
        raise ValueError("page mode tokens must use MODE/SCHEME/BITS[/ESCAPE_DTYPE], for example M0/affine/4 or M3/affine/4/int8")
    mode_text, scheme_text, bits_text = parts[:3]
    mode = mode_text.upper()
    if mode not in {"M0", "M1", "M2", "M3", "M4", "T3"}:
        raise ValueError(f"unsupported page mode: {mode_text}")
    quant_scheme = scheme_text.lower()
    if quant_scheme not in {"affine", "symmetric", "lut", "sketch", "project", "turbo3"}:
        raise ValueError(f"unsupported quant scheme: {scheme_text}")
    try:
        bits = int(bits_text)
    except ValueError as err:
        raise ValueError(f"page mode bits must be an integer, got {bits_text!r}") from err
    if bits <= 0:
        raise ValueError(f"page mode bits must be a positive integer, got {bits}")
    escape_dtype = None
    if len(parts) == 4:
        escape_dtype = parts[3].lower()
        if escape_dtype not in {"float16", "float32", "int8"}:
            raise ValueError(f"unsupported escape dtype: {parts[3]}")
        if mode != "M3":
            raise ValueError("escape dtype qualifiers are only supported for M3 page modes")
    return PageModeSpec(mode=mode, bits=bits, quant_scheme=quant_scheme, escape_dtype=escape_dtype)
def make_tier_candidates(
    *,
    kind: str,
    sensitivity_tier: SensitivityTier,
    default_bits: int,
    default_quant_scheme: str,
    default_mode: str,
    recent_window: int,
    recent_escape_dtype: str = "float16",
    prefer_project_key_mode: bool = False,
) -> LayerPolicy:
    """Build the built-in ``LayerPolicy`` for a ``kind`` / ``sensitivity_tier`` pair.

    * "exact": a single candidate made from ``default_mode`` /
      ``default_quant_scheme`` / ``default_bits``. Unrecognised names fall back
      to M0 / affine. The recent window is forced to 0.
    * Other tiers: a fixed candidate ladder that always ends in 4-bit affine M0.
      ``kind == "K"`` selects the key ladders and any other kind selects the
      value ladders. Tiers other than "strict"/"aggressive" use the "balanced"
      ladder and limits.
    * ``prefer_project_key_mode`` swaps the middle key rung from M2/sketch to
      M4/project.

    The recent candidate is always an M3/affine spec at ``default_bits`` with
    ``recent_escape_dtype``.
    """

    def spec(mode: ModeName, scheme: QuantSchemeName, bits: int) -> PageModeSpec:
        return PageModeSpec(mode=mode, quant_scheme=scheme, bits=bits, sensitivity_tier=sensitivity_tier)

    if sensitivity_tier == "exact":
        exact_mode = default_mode if default_mode in {"M0", "M1", "M2", "M3", "M4", "T3"} else "M0"
        exact_scheme = (
            default_quant_scheme
            if default_quant_scheme in {"affine", "symmetric", "lut", "sketch", "project", "turbo3"}
            else "affine"
        )
        ladder = (spec(exact_mode, exact_scheme, default_bits),)
        # (outlier_fraction, abs_max, channel_range) limits.
        limits = (0.0, float("inf"), float("inf"))
    else:
        if sensitivity_tier == "strict":
            limits = (0.02, 4.5, 3.0)
        elif sensitivity_tier == "aggressive":
            limits = (0.10, 8.0, 5.5)
        else:
            limits = (0.05, 6.0, 4.0)

        if kind == "K":
            if sensitivity_tier == "strict":
                ladder = (spec("M0", "affine", 4),)
            else:
                middle_rung = (
                    spec("M4", "project", 4) if prefer_project_key_mode else spec("M2", "sketch", 4)
                )
                ladder = (spec("M0", "affine", 2), middle_rung, spec("M0", "affine", 4))
        elif sensitivity_tier == "strict":
            ladder = (spec("M1", "lut", 4), spec("M0", "affine", 4))
        elif sensitivity_tier == "aggressive":
            ladder = (
                spec("M0", "affine", 2),
                spec("M0", "affine", 3),
                spec("M1", "lut", 4),
                spec("M0", "affine", 4),
            )
        else:
            ladder = (spec("M0", "affine", 3), spec("M1", "lut", 4), spec("M0", "affine", 4))

    outlier_limit, abs_max_limit, channel_range_limit = limits
    escape_spec = PageModeSpec(
        mode="M3",
        bits=default_bits,
        quant_scheme="affine",
        escape_dtype=recent_escape_dtype,
        sensitivity_tier=sensitivity_tier,
    )
    return LayerPolicy(
        policy_id=f"{kind.lower()}_{sensitivity_tier}",
        sensitivity_tier=sensitivity_tier,
        kind=kind,
        candidates=ladder,
        recent_candidate=escape_spec,
        recent_window=0 if sensitivity_tier == "exact" else recent_window,
        outlier_fraction_threshold=float(outlier_limit),
        abs_max_threshold=float(abs_max_limit),
        channel_range_threshold=float(channel_range_limit),
    )
def make_explicit_policy(
    *,
    kind: str,
    policy_id: str,
    sensitivity_tier: SensitivityTier,
    candidates: Sequence[PageModeSpec],
    recent_window: int,
    recent_escape_dtype: str = "float16",
) -> LayerPolicy:
    """Wrap a caller-supplied candidate list in a ``LayerPolicy``.

    The first M3 candidate, if any, also serves as the recent-window candidate.
    Otherwise a 4-bit affine M3 spec using ``recent_escape_dtype`` is
    synthesised. Statistic thresholds keep the ``LayerPolicy`` defaults.

    Raises:
        ValueError: if ``candidates`` is empty.
    """
    if not candidates:
        raise ValueError("explicit policies must provide at least one candidate")
    for entry in candidates:
        if entry.mode == "M3":
            recent_spec = entry
            break
    else:
        recent_spec = PageModeSpec(
            mode="M3",
            bits=4,
            quant_scheme="affine",
            escape_dtype=recent_escape_dtype,
            sensitivity_tier=sensitivity_tier,
        )
    return LayerPolicy(
        policy_id=policy_id,
        sensitivity_tier=sensitivity_tier,
        kind=kind,
        candidates=tuple(candidates),
        recent_candidate=recent_spec,
        recent_window=recent_window,
    )
def _candidate_is_allowed(
    candidate: PageModeSpec,
    *,
    kind: str,
    stats: PageStats | None,
    policy: LayerPolicy,
) -> bool:
    """Decide whether ``candidate`` may encode a page with the given statistics.

    Args:
        candidate: Spec under consideration; only ``mode`` and ``bits`` are read.
        kind: Page kind. M2/M4 require "K" and M1 requires "V".
        stats: Observed page statistics, or ``None`` if unavailable.
        policy: Source of the outlier / abs-max / channel-range thresholds.

    Returns:
        ``True`` when the candidate is acceptable for this page.
    """
    mode = candidate.mode
    # With no stats nothing can be rejected; M3 and M0 at >=4 bits are always accepted.
    if stats is None or mode == "M3" or (mode == "M0" and candidate.bits >= 4):
        return True
    if mode == "T3":
        return False
    if mode in ("M1", "M2", "M4"):
        required_kind = "V" if mode == "M1" else "K"
        return kind == required_kind and (
            stats.outlier_fraction <= policy.outlier_fraction_threshold
            and stats.channel_range_mean <= policy.channel_range_threshold
        )
    if mode == "M0":
        if candidate.bits <= 2:
            return (
                stats.outlier_fraction <= policy.outlier_fraction_threshold
                and stats.abs_max <= policy.abs_max_threshold
            )
        if candidate.bits == 3:
            # 3-bit pages get 25% more outlier headroom and a 20% larger abs-max limit.
            return (
                stats.outlier_fraction <= (policy.outlier_fraction_threshold * 1.25)
                and stats.abs_max <= (policy.abs_max_threshold * 1.2)
            )
    return True