Spaces:
Paused
Paused
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from math import ceil | |
| import math | |
| from .modes.m4_key_project import valid_m4_basis_families | |
| from .planner import LayerPolicy, PageModeSpec, make_explicit_policy, make_tier_candidates, parse_page_mode_token | |
| _VALID_KEY_MODES = ("M0", "M1", "M2", "M3", "M4", "T3") | |
| _VALID_VALUE_MODES = ("M0", "M1", "M3", "T3") | |
| _VALID_M4_BASIS_FAMILIES = valid_m4_basis_families() | |
| def _parse_mode_override_spec(spec: str, *, allowed_modes: tuple[str, ...], field_name: str) -> tuple[int, int | None, str]: | |
| if "=" not in spec: | |
| raise ValueError(f"{field_name} entries must use layer:<id>=<mode> or layer:<id>:kv:<id>=<mode>") | |
| target, mode = spec.split("=", 1) | |
| mode = mode.strip() | |
| if mode not in allowed_modes: | |
| allowed = ", ".join(allowed_modes) | |
| raise ValueError(f"{field_name} mode must be one of {allowed}") | |
| parts = target.strip().split(":") | |
| if len(parts) == 2 and parts[0] == "layer": | |
| return int(parts[1]), None, mode | |
| if len(parts) == 4 and parts[0] == "layer" and parts[2] == "kv": | |
| return int(parts[1]), int(parts[3]), mode | |
| raise ValueError(f"{field_name} entries must use layer:<id>=<mode> or layer:<id>:kv:<id>=<mode>") | |
| def _parse_layer_value_spec( | |
| spec: str, | |
| *, | |
| field_name: str, | |
| allowed_values: tuple[str, ...], | |
| ) -> tuple[int, str]: | |
| if "=" not in spec: | |
| raise ValueError(f"{field_name} entries must use layer:<id>=<value>") | |
| target, value = spec.split("=", 1) | |
| parts = target.strip().split(":") | |
| if len(parts) != 2 or parts[0] != "layer": | |
| raise ValueError(f"{field_name} entries must use layer:<id>=<value>") | |
| value = value.strip() | |
| if value not in allowed_values: | |
| allowed = ", ".join(allowed_values) | |
| raise ValueError(f"{field_name} values must be one of {allowed}") | |
| return int(parts[1]), value | |
| def _parse_layer_candidate_spec(spec: str, *, field_name: str) -> tuple[int, tuple[PageModeSpec, ...]]: | |
| if "=" not in spec: | |
| raise ValueError(f"{field_name} entries must use layer:<id>=MODE/SCHEME/BITS[,MODE/SCHEME/BITS...]") | |
| target, raw_candidates = spec.split("=", 1) | |
| parts = target.strip().split(":") | |
| if len(parts) != 2 or parts[0] != "layer": | |
| raise ValueError(f"{field_name} entries must use layer:<id>=MODE/SCHEME/BITS[,MODE/SCHEME/BITS...]") | |
| candidates = tuple( | |
| parse_page_mode_token(token.strip()) | |
| for token in raw_candidates.split(",") | |
| if token.strip() | |
| ) | |
| if not candidates: | |
| raise ValueError(f"{field_name} entries must include at least one candidate") | |
| return int(parts[1]), candidates | |
| def _parse_layer_positive_int_spec(spec: str, *, field_name: str) -> tuple[int, int]: | |
| if "=" not in spec: | |
| raise ValueError(f"{field_name} entries must use layer:<id>=<positive_int>") | |
| target, value = spec.split("=", 1) | |
| parts = target.strip().split(":") | |
| if len(parts) != 2 or parts[0] != "layer": | |
| raise ValueError(f"{field_name} entries must use layer:<id>=<positive_int>") | |
| parsed_value = int(value.strip()) | |
| if parsed_value <= 0: | |
| raise ValueError(f"{field_name} values must be positive integers") | |
| return int(parts[1]), parsed_value | |
| def _parse_layer_context_positive_int_spec(spec: str, *, field_name: str) -> tuple[int, int, int]: | |
| if "=" not in spec: | |
| raise ValueError(f"{field_name} entries must use layer:<id>:min_ctx:<non_negative_int>=<positive_int>") | |
| target, value = spec.split("=", 1) | |
| parts = target.strip().split(":") | |
| if len(parts) != 4 or parts[0] != "layer" or parts[2] != "min_ctx": | |
| raise ValueError(f"{field_name} entries must use layer:<id>:min_ctx:<non_negative_int>=<positive_int>") | |
| min_context = int(parts[3]) | |
| if min_context < 0: | |
| raise ValueError(f"{field_name} min_ctx must be non-negative") | |
| parsed_value = int(value.strip()) | |
| if parsed_value <= 0: | |
| raise ValueError(f"{field_name} values must be positive integers") | |
| return int(parts[1]), min_context, parsed_value | |
| class DotCacheConfig: | |
| head_dim: int | |
| group_size: int = 32 | |
| bits_k: int = 4 | |
| bits_v: int = 4 | |
| tokens_per_page: int = 64 | |
| recent_window: int = 128 | |
| sink_window: int = 0 | |
| execution_recent_window: int = 0 | |
| execution_sink_window: int = 0 | |
| execution_recent_window_overrides: tuple[str, ...] = () | |
| execution_recent_window_context_overrides: tuple[str, ...] = () | |
| execution_relevance_top_k: int = 0 | |
| execution_relevance_mode: str = "envelope" | |
| execution_relevance_top_k_overrides: tuple[str, ...] = () | |
| execution_relevance_top_k_context_overrides: tuple[str, ...] = () | |
| execution_full_context_layers: tuple[int, ...] = () | |
| execution_disable_grouped_batching_layers: tuple[int, ...] = () | |
| execution_recent_old_bonus_window: int = 0 | |
| execution_recent_old_bonus_strength: float = 0.0 | |
| execution_recent_old_bonus_layers: tuple[int, ...] = () | |
| execution_secondary_relevance_mode: str = "" | |
| execution_secondary_relevance_top_k: int = 0 | |
| execution_secondary_relevance_min_overlap: float = 0.0 | |
| execution_secondary_relevance_layers: tuple[int, ...] = () | |
| execution_recent_neighbor_rescue_top_k: int = 0 | |
| execution_recent_neighbor_rescue_anchor_window: int = 0 | |
| execution_recent_neighbor_rescue_min_anchor_pages: int = 0 | |
| execution_recent_neighbor_rescue_layers: tuple[int, ...] = () | |
| execution_exact_promote_top_k: int = 0 | |
| execution_exact_promote_min_margin_threshold: float = 0.0 | |
| execution_exact_promote_max_context: int = 0 | |
| execution_exact_promote_margin_threshold: float = 0.0 | |
| execution_exact_promote_layers: tuple[int, ...] = () | |
| execution_exact_promote_union_rescue_top_k: int = 0 | |
| execution_grouped_decode_compact: bool = False | |
| execution_grouped_mix_compact: bool = False | |
| execution_grouped_mix_disable_packed_cuda: bool = False | |
| execution_freeze_chunk_budget_during_decode: bool = False | |
| execution_builtin_selector_cache: bool = False | |
| execution_builtin_selector_score_all_pages: bool = False | |
| execution_builtin_selector_candidate_only: bool = False | |
| execution_builtin_selector_score_all_pages_min_candidate_fraction: float = 0.0 | |
| execution_value_escape_layers: tuple[int, ...] = () | |
| execution_value_escape_mode: str = "M3" | |
| execution_value_escape_old_only: bool = False | |
| execution_value_escape_top_k: int = 0 | |
| execution_value_escape_prewarm: bool = False | |
| execution_value_escape_prewarm_min_context: int = 0 | |
| execution_exact_refine_top_k: int = 0 | |
| execution_exact_refine_layers: tuple[int, ...] = () | |
| store_scales_dtype: str = "float16" | |
| store_bias_dtype: str = "float16" | |
| payload_layout_k: str = "group_major" | |
| payload_layout_v: str = "group_major" | |
| default_mode_k: str = "M0" | |
| default_mode_v: str = "M0" | |
| quant_scheme_k: str = "affine" | |
| quant_scheme_v: str = "affine" | |
| escape_dtype: str = "float16" | |
| recent_page_escape_dtype: str = "float16" | |
| m2_sketch_dim_k: int = 8 | |
| m4_project_basis_k: str = "hadamard" | |
| m4_project_basis_k_overrides: tuple[str, ...] = () | |
| m4_project_dim_k_overrides: tuple[str, ...] = () | |
| m2_center_k: bool = False | |
| m2_segment_count_k: int = 1 | |
| m2_adaptive_segments_k: bool = False | |
| m2_adaptive_min_improvement_k: float = 0.1 | |
| m2_prefilter_top_k: int = 0 | |
| m2_prefilter_min_pages: int = 8 | |
| prefer_m4_project_k: bool = False | |
| lut_refine_steps: int = 6 | |
| preconditioner: str = "none" | |
| precondition_strength: float = 2.0 | |
| m1_segment_count_k: int = 1 | |
| m1_segment_count_v: int = 1 | |
| m1_fallback_to_m0: bool = True | |
| m1_error_threshold: float = 0.35 | |
| m1_token_p95_error_threshold: float = 1000000.0 | |
| prepared_chunk_cache_budget_ratio: float = 0.5 | |
| prepared_chunk_cache_min_bytes: int = 1 * 1024 * 1024 | |
| prepared_chunk_cache_max_bytes: int = 64 * 1024 * 1024 | |
| key_mode_overrides: tuple[str, ...] = () | |
| value_mode_overrides: tuple[str, ...] = () | |
| key_policy_tier: str = "exact" | |
| value_policy_tier: str = "exact" | |
| key_layer_sensitivity: tuple[str, ...] = () | |
| value_layer_sensitivity: tuple[str, ...] = () | |
| key_policy_overrides: tuple[str, ...] = () | |
| value_policy_overrides: tuple[str, ...] = () | |
| learned_page_selector_path: str | None = None | |
| learned_page_selector_prompt_family: str | None = None | |
| learned_page_selector_prompt_variant: str | None = None | |
| learned_page_selector_profile: str = "quality" | |
| learned_page_selector_scope: str = "KV" | |
| learned_page_selector_target_candidate: str = "M3/affine/4/float16" | |
| learned_page_selector_logit_offset: float = 0.0 | |
| def __post_init__(self) -> None: | |
| if self.head_dim <= 0: | |
| raise ValueError("head_dim must be positive") | |
| if self.group_size <= 0: | |
| raise ValueError("group_size must be positive") | |
| if self.bits_k not in (2, 3, 4): | |
| raise ValueError("bits_k must be 2, 3, or 4 for the current runtime") | |
| if self.bits_v not in (2, 3, 4): | |
| raise ValueError("bits_v must be 2, 3, or 4 for the current runtime") | |
| if self.tokens_per_page <= 0: | |
| raise ValueError("tokens_per_page must be positive") | |
| if self.execution_recent_window < 0: | |
| raise ValueError("execution_recent_window must be non-negative") | |
| if self.execution_sink_window < 0: | |
| raise ValueError("execution_sink_window must be non-negative") | |
| for spec in self.execution_recent_window_overrides: | |
| _parse_layer_positive_int_spec(spec, field_name="execution_recent_window_overrides") | |
| for spec in self.execution_recent_window_context_overrides: | |
| _parse_layer_context_positive_int_spec(spec, field_name="execution_recent_window_context_overrides") | |
| if self.execution_relevance_top_k < 0: | |
| raise ValueError("execution_relevance_top_k must be non-negative") | |
| if self.execution_relevance_mode not in ("sketch", "envelope"): | |
| raise ValueError("execution_relevance_mode must be sketch or envelope") | |
| for spec in self.execution_relevance_top_k_overrides: | |
| _parse_layer_positive_int_spec(spec, field_name="execution_relevance_top_k_overrides") | |
| for spec in self.execution_relevance_top_k_context_overrides: | |
| _parse_layer_context_positive_int_spec(spec, field_name="execution_relevance_top_k_context_overrides") | |
| for layer_id in self.execution_full_context_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_full_context_layers must be non-negative") | |
| for layer_id in self.execution_disable_grouped_batching_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_disable_grouped_batching_layers must be non-negative") | |
| if self.execution_recent_old_bonus_window < 0: | |
| raise ValueError("execution_recent_old_bonus_window must be non-negative") | |
| if self.execution_recent_old_bonus_strength < 0: | |
| raise ValueError("execution_recent_old_bonus_strength must be non-negative") | |
| for layer_id in self.execution_recent_old_bonus_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_recent_old_bonus_layers must be non-negative") | |
| if self.execution_secondary_relevance_mode not in ("", "sketch", "envelope"): | |
| raise ValueError("execution_secondary_relevance_mode must be empty, sketch, or envelope") | |
| if self.execution_secondary_relevance_top_k < 0: | |
| raise ValueError("execution_secondary_relevance_top_k must be non-negative") | |
| if not 0.0 <= float(self.execution_secondary_relevance_min_overlap) <= 1.0: | |
| raise ValueError("execution_secondary_relevance_min_overlap must be between 0 and 1") | |
| for layer_id in self.execution_secondary_relevance_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_secondary_relevance_layers must be non-negative") | |
| if self.execution_recent_neighbor_rescue_top_k < 0: | |
| raise ValueError("execution_recent_neighbor_rescue_top_k must be non-negative") | |
| if self.execution_recent_neighbor_rescue_anchor_window < 0: | |
| raise ValueError("execution_recent_neighbor_rescue_anchor_window must be non-negative") | |
| if self.execution_recent_neighbor_rescue_min_anchor_pages < 0: | |
| raise ValueError("execution_recent_neighbor_rescue_min_anchor_pages must be non-negative") | |
| for layer_id in self.execution_recent_neighbor_rescue_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_recent_neighbor_rescue_layers must be non-negative") | |
| if self.execution_exact_promote_top_k < 0: | |
| raise ValueError("execution_exact_promote_top_k must be non-negative") | |
| if self.execution_exact_promote_min_margin_threshold < 0: | |
| raise ValueError("execution_exact_promote_min_margin_threshold must be non-negative") | |
| if self.execution_exact_promote_max_context < 0: | |
| raise ValueError("execution_exact_promote_max_context must be non-negative") | |
| if self.execution_exact_promote_margin_threshold < 0: | |
| raise ValueError("execution_exact_promote_margin_threshold must be non-negative") | |
| for layer_id in self.execution_exact_promote_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_exact_promote_layers must be non-negative") | |
| if self.execution_exact_promote_union_rescue_top_k < 0: | |
| raise ValueError("execution_exact_promote_union_rescue_top_k must be non-negative") | |
| for layer_id in self.execution_value_escape_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_value_escape_layers must be non-negative") | |
| if self.execution_value_escape_mode not in _VALID_VALUE_MODES: | |
| allowed = ", ".join(_VALID_VALUE_MODES) | |
| raise ValueError(f"execution_value_escape_mode must be one of {allowed}") | |
| if self.execution_exact_refine_top_k < 0: | |
| raise ValueError("execution_exact_refine_top_k must be non-negative") | |
| for layer_id in self.execution_exact_refine_layers: | |
| if int(layer_id) < 0: | |
| raise ValueError("execution_exact_refine_layers must be non-negative") | |
| if self.payload_layout_k not in ("group_major", "token_major"): | |
| raise ValueError("payload_layout_k must be group_major or token_major") | |
| if self.payload_layout_v not in ("group_major", "token_major"): | |
| raise ValueError("payload_layout_v must be group_major or token_major") | |
| if self.default_mode_k not in _VALID_KEY_MODES: | |
| raise ValueError("default_mode_k must be M0, M1, M2, M3, M4, or T3") | |
| if self.default_mode_v not in _VALID_VALUE_MODES: | |
| raise ValueError("default_mode_v must be M0, M1, M3, or T3") | |
| if self.quant_scheme_k not in ("affine", "symmetric", "lut", "sketch", "project", "turbo3"): | |
| raise ValueError("quant_scheme_k must be affine, symmetric, lut, sketch, project, or turbo3") | |
| if self.quant_scheme_v not in ("affine", "symmetric", "lut", "turbo3"): | |
| raise ValueError("quant_scheme_v must be affine, symmetric, lut, or turbo3") | |
| if self.escape_dtype not in ("float16", "float32", "int8"): | |
| raise ValueError("escape_dtype must be float16, float32, or int8") | |
| if self.recent_page_escape_dtype not in ("float16", "float32", "int8"): | |
| raise ValueError("recent_page_escape_dtype must be float16, float32, or int8") | |
| if self.m2_sketch_dim_k <= 0: | |
| raise ValueError("m2_sketch_dim_k must be positive") | |
| if self.m4_project_basis_k not in _VALID_M4_BASIS_FAMILIES: | |
| allowed = ", ".join(_VALID_M4_BASIS_FAMILIES) | |
| raise ValueError(f"m4_project_basis_k must be one of {allowed}") | |
| for spec in self.m4_project_basis_k_overrides: | |
| _parse_layer_value_spec( | |
| spec, | |
| field_name="m4_project_basis_k_overrides", | |
| allowed_values=_VALID_M4_BASIS_FAMILIES, | |
| ) | |
| for spec in self.m4_project_dim_k_overrides: | |
| _parse_layer_positive_int_spec(spec, field_name="m4_project_dim_k_overrides") | |
| if not isinstance(self.m2_center_k, bool): | |
| raise ValueError("m2_center_k must be a bool") | |
| if self.m2_segment_count_k <= 0: | |
| raise ValueError("m2_segment_count_k must be positive") | |
| if not isinstance(self.m2_adaptive_segments_k, bool): | |
| raise ValueError("m2_adaptive_segments_k must be a bool") | |
| if self.m2_adaptive_min_improvement_k < 0: | |
| raise ValueError("m2_adaptive_min_improvement_k must be non-negative") | |
| if self.m2_prefilter_top_k < 0: | |
| raise ValueError("m2_prefilter_top_k must be non-negative") | |
| if self.m2_prefilter_min_pages < 0: | |
| raise ValueError("m2_prefilter_min_pages must be non-negative") | |
| if self.lut_refine_steps < 0: | |
| raise ValueError("lut_refine_steps must be non-negative") | |
| if self.preconditioner not in ("none", "tanh"): | |
| raise ValueError("preconditioner must be none or tanh") | |
| if self.precondition_strength <= 0: | |
| raise ValueError("precondition_strength must be positive") | |
| if self.m1_segment_count_k <= 0: | |
| raise ValueError("m1_segment_count_k must be positive") | |
| if self.m1_segment_count_v <= 0: | |
| raise ValueError("m1_segment_count_v must be positive") | |
| if self.m1_error_threshold <= 0: | |
| raise ValueError("m1_error_threshold must be positive") | |
| if self.m1_token_p95_error_threshold <= 0: | |
| raise ValueError("m1_token_p95_error_threshold must be positive") | |
| if self.prepared_chunk_cache_budget_ratio < 0: | |
| raise ValueError("prepared_chunk_cache_budget_ratio must be non-negative") | |
| if self.prepared_chunk_cache_min_bytes < 0: | |
| raise ValueError("prepared_chunk_cache_min_bytes must be non-negative") | |
| if self.prepared_chunk_cache_max_bytes < 0: | |
| raise ValueError("prepared_chunk_cache_max_bytes must be non-negative") | |
| if ( | |
| self.prepared_chunk_cache_max_bytes > 0 | |
| and self.prepared_chunk_cache_min_bytes > self.prepared_chunk_cache_max_bytes | |
| ): | |
| raise ValueError("prepared_chunk_cache_min_bytes must not exceed prepared_chunk_cache_max_bytes") | |
| for spec in self.key_mode_overrides: | |
| _parse_mode_override_spec(spec, allowed_modes=_VALID_KEY_MODES, field_name="key_mode_overrides") | |
| for spec in self.value_mode_overrides: | |
| _parse_mode_override_spec(spec, allowed_modes=_VALID_VALUE_MODES, field_name="value_mode_overrides") | |
| for field_name, tier in (("key_policy_tier", self.key_policy_tier), ("value_policy_tier", self.value_policy_tier)): | |
| if tier not in ("exact", "strict", "balanced", "aggressive"): | |
| raise ValueError(f"{field_name} must be exact, strict, balanced, or aggressive") | |
| for spec in self.key_layer_sensitivity: | |
| _parse_layer_value_spec( | |
| spec, | |
| field_name="key_layer_sensitivity", | |
| allowed_values=("strict", "balanced", "aggressive"), | |
| ) | |
| for spec in self.value_layer_sensitivity: | |
| _parse_layer_value_spec( | |
| spec, | |
| field_name="value_layer_sensitivity", | |
| allowed_values=("strict", "balanced", "aggressive"), | |
| ) | |
| for spec in self.key_policy_overrides: | |
| _parse_layer_candidate_spec(spec, field_name="key_policy_overrides") | |
| for spec in self.value_policy_overrides: | |
| _parse_layer_candidate_spec(spec, field_name="value_policy_overrides") | |
| if self.learned_page_selector_path is not None and not str(self.learned_page_selector_path).strip(): | |
| raise ValueError("learned_page_selector_path must be a non-empty string when provided") | |
| if self.learned_page_selector_prompt_family is not None and not str(self.learned_page_selector_prompt_family).strip(): | |
| raise ValueError("learned_page_selector_prompt_family must be a non-empty string when provided") | |
| if self.learned_page_selector_prompt_variant is not None and not str(self.learned_page_selector_prompt_variant).strip(): | |
| raise ValueError("learned_page_selector_prompt_variant must be a non-empty string when provided") | |
| if str(self.learned_page_selector_profile) not in {"quality", "systems", "manual"}: | |
| raise ValueError("learned_page_selector_profile must be quality, systems, or manual") | |
| if str(self.learned_page_selector_scope) not in {"KV", "K", "V"}: | |
| raise ValueError("learned_page_selector_scope must be KV, K, or V") | |
| if not str(self.learned_page_selector_target_candidate).strip(): | |
| raise ValueError("learned_page_selector_target_candidate must be a non-empty string") | |
| if not math.isfinite(float(self.learned_page_selector_logit_offset)): | |
| raise ValueError("learned_page_selector_logit_offset must be finite") | |
| def num_groups(self) -> int: | |
| return ceil(self.head_dim / self.group_size) | |
| def padded_head_dim(self) -> int: | |
| return self.num_groups * self.group_size | |
| def has_mode_overrides(self, *, kind: str | None = None) -> bool: | |
| if kind == "K": | |
| return bool(self.key_mode_overrides) | |
| if kind == "V": | |
| return bool(self.value_mode_overrides) | |
| return bool(self.key_mode_overrides or self.value_mode_overrides) | |
| def has_policy_overrides(self, *, kind: str | None = None) -> bool: | |
| if kind == "K": | |
| return bool(self.key_layer_sensitivity or self.key_policy_overrides or self.key_policy_tier != "exact") | |
| if kind == "V": | |
| return bool(self.value_layer_sensitivity or self.value_policy_overrides or self.value_policy_tier != "exact") | |
| return bool( | |
| self.key_layer_sensitivity | |
| or self.value_layer_sensitivity | |
| or self.key_policy_overrides | |
| or self.value_policy_overrides | |
| or self.key_policy_tier != "exact" | |
| or self.value_policy_tier != "exact" | |
| ) | |
| def learned_page_selector_enabled(self) -> bool: | |
| return self.learned_page_selector_path is not None and bool(str(self.learned_page_selector_path).strip()) | |
| def learned_page_selector_applies_to_kind(self, *, kind: str) -> bool: | |
| scope = str(self.learned_page_selector_scope) | |
| if scope == "KV": | |
| return kind in {"K", "V"} | |
| return str(kind) == scope | |
| def resolve_page_mode(self, *, kind: str, layer_id: int, kv_head_id: int) -> str: | |
| if kind == "K": | |
| resolved = self.default_mode_k | |
| specs = self.key_mode_overrides | |
| allowed_modes = _VALID_KEY_MODES | |
| field_name = "key_mode_overrides" | |
| elif kind == "V": | |
| resolved = self.default_mode_v | |
| specs = self.value_mode_overrides | |
| allowed_modes = _VALID_VALUE_MODES | |
| field_name = "value_mode_overrides" | |
| else: | |
| raise ValueError("kind must be K or V") | |
| for spec in specs: | |
| override_layer_id, override_kv_head_id, override_mode = _parse_mode_override_spec( | |
| spec, | |
| allowed_modes=allowed_modes, | |
| field_name=field_name, | |
| ) | |
| if override_layer_id != int(layer_id): | |
| continue | |
| if override_kv_head_id is not None and override_kv_head_id != int(kv_head_id): | |
| continue | |
| resolved = override_mode | |
| return resolved | |
| def resolve_m4_project_dim_k(self, *, layer_id: int) -> int: | |
| resolved = int(self.m2_sketch_dim_k) | |
| for spec in self.m4_project_dim_k_overrides: | |
| override_layer_id, override_dim = _parse_layer_positive_int_spec( | |
| spec, | |
| field_name="m4_project_dim_k_overrides", | |
| ) | |
| if override_layer_id == int(layer_id): | |
| resolved = int(override_dim) | |
| return resolved | |
| def resolve_execution_relevance_top_k(self, *, layer_id: int) -> int: | |
| resolved = int(self.execution_relevance_top_k) | |
| for spec in self.execution_relevance_top_k_overrides: | |
| override_layer_id, override_value = _parse_layer_positive_int_spec( | |
| spec, | |
| field_name="execution_relevance_top_k_overrides", | |
| ) | |
| if override_layer_id == int(layer_id): | |
| resolved = int(override_value) | |
| return resolved | |
| def resolve_execution_recent_window(self, *, layer_id: int) -> int: | |
| resolved = int(self.execution_recent_window) | |
| for spec in self.execution_recent_window_overrides: | |
| override_layer_id, override_value = _parse_layer_positive_int_spec( | |
| spec, | |
| field_name="execution_recent_window_overrides", | |
| ) | |
| if override_layer_id == int(layer_id): | |
| resolved = int(override_value) | |
| return resolved | |
| def execution_shortlist_enabled(self) -> bool: | |
| return ( | |
| self.execution_recent_window > 0 | |
| or self.execution_sink_window > 0 | |
| or bool(self.execution_recent_window_overrides) | |
| or bool(self.execution_recent_window_context_overrides) | |
| or self.execution_relevance_top_k > 0 | |
| or bool(self.execution_relevance_top_k_overrides) | |
| or bool(self.execution_relevance_top_k_context_overrides) | |
| ) | |
| def execution_shortlist_disabled_for_layer(self, *, layer_id: int) -> bool: | |
| return int(layer_id) in {int(value) for value in self.execution_full_context_layers} | |
| def execution_grouped_batching_disabled_for_layer(self, *, layer_id: int) -> bool: | |
| return int(layer_id) in {int(value) for value in self.execution_disable_grouped_batching_layers} | |
| def execution_value_escape_enabled_for_layer(self, *, layer_id: int) -> bool: | |
| if not self.execution_value_escape_layers: | |
| return False | |
| return int(layer_id) in {int(value) for value in self.execution_value_escape_layers} | |
| def execution_recent_old_bonus_enabled_for_layer(self, *, layer_id: int) -> bool: | |
| if self.execution_recent_old_bonus_window <= 0 or self.execution_recent_old_bonus_strength <= 0: | |
| return False | |
| if not self.execution_recent_old_bonus_layers: | |
| return False | |
| return int(layer_id) in {int(value) for value in self.execution_recent_old_bonus_layers} | |
| def execution_secondary_relevance_enabled_for_layer(self, *, layer_id: int) -> bool: | |
| if self.execution_secondary_relevance_mode not in ("sketch", "envelope"): | |
| return False | |
| if self.execution_secondary_relevance_top_k <= 0: | |
| return False | |
| if not self.execution_secondary_relevance_layers: | |
| return False | |
| return int(layer_id) in {int(value) for value in self.execution_secondary_relevance_layers} | |
| def execution_recent_neighbor_rescue_enabled_for_layer(self, *, layer_id: int) -> bool: | |
| if self.execution_recent_neighbor_rescue_top_k <= 0: | |
| return False | |
| if self.execution_recent_neighbor_rescue_anchor_window <= 0: | |
| return False | |
| if self.execution_recent_neighbor_rescue_min_anchor_pages <= 0: | |
| return False | |
| if not self.execution_recent_neighbor_rescue_layers: | |
| return False | |
| return int(layer_id) in {int(value) for value in self.execution_recent_neighbor_rescue_layers} | |
| def resolve_execution_relevance_top_k_for_context(self, *, layer_id: int, context_length: int | None = None) -> int: | |
| resolved = self.resolve_execution_relevance_top_k(layer_id=layer_id) | |
| if context_length is None: | |
| return resolved | |
| best_min_context = -1 | |
| for spec in self.execution_relevance_top_k_context_overrides: | |
| override_layer_id, min_context, override_value = _parse_layer_context_positive_int_spec( | |
| spec, | |
| field_name="execution_relevance_top_k_context_overrides", | |
| ) | |
| if override_layer_id != int(layer_id): | |
| continue | |
| if int(context_length) < int(min_context) or int(min_context) < best_min_context: | |
| continue | |
| resolved = int(override_value) | |
| best_min_context = int(min_context) | |
| return resolved | |
| def resolve_execution_recent_window_for_context(self, *, layer_id: int, context_length: int | None = None) -> int: | |
| resolved = self.resolve_execution_recent_window(layer_id=layer_id) | |
| if context_length is None: | |
| return resolved | |
| best_min_context = -1 | |
| for spec in self.execution_recent_window_context_overrides: | |
| override_layer_id, min_context, override_value = _parse_layer_context_positive_int_spec( | |
| spec, | |
| field_name="execution_recent_window_context_overrides", | |
| ) | |
| if override_layer_id != int(layer_id): | |
| continue | |
| if int(context_length) < int(min_context) or int(min_context) < best_min_context: | |
| continue | |
| resolved = int(override_value) | |
| best_min_context = int(min_context) | |
| return resolved | |
| def resolve_m4_project_basis_k(self, *, layer_id: int) -> str: | |
| resolved = self.m4_project_basis_k | |
| for spec in self.m4_project_basis_k_overrides: | |
| override_layer_id, override_basis = _parse_layer_value_spec( | |
| spec, | |
| field_name="m4_project_basis_k_overrides", | |
| allowed_values=_VALID_M4_BASIS_FAMILIES, | |
| ) | |
| if override_layer_id == int(layer_id): | |
| resolved = override_basis | |
| return resolved | |
| def resolve_layer_policy(self, *, kind: str, layer_id: int, kv_head_id: int) -> LayerPolicy: | |
| if kind == "K": | |
| default_mode = self.default_mode_k | |
| default_bits = self.bits_k | |
| default_quant_scheme = self.quant_scheme_k | |
| default_tier = self.key_policy_tier | |
| sensitivity_specs = self.key_layer_sensitivity | |
| explicit_specs = self.key_policy_overrides | |
| mode_overrides = self.key_mode_overrides | |
| elif kind == "V": | |
| default_mode = self.default_mode_v | |
| default_bits = self.bits_v | |
| default_quant_scheme = self.quant_scheme_v | |
| default_tier = self.value_policy_tier | |
| sensitivity_specs = self.value_layer_sensitivity | |
| explicit_specs = self.value_policy_overrides | |
| mode_overrides = self.value_mode_overrides | |
| else: | |
| raise ValueError("kind must be K or V") | |
| resolved_mode = self.resolve_page_mode(kind=kind, layer_id=layer_id, kv_head_id=kv_head_id) | |
| if resolved_mode != default_mode: | |
| override_scheme = ( | |
| "lut" if resolved_mode == "M1" | |
| else "sketch" if resolved_mode == "M2" | |
| else "project" if resolved_mode == "M4" | |
| else "turbo3" if resolved_mode == "T3" | |
| else default_quant_scheme | |
| ) | |
| return make_explicit_policy( | |
| kind=kind, | |
| policy_id=f"{kind.lower()}_mode_override_layer_{int(layer_id)}", | |
| sensitivity_tier="exact", | |
| candidates=(PageModeSpec(mode=resolved_mode, bits=default_bits, quant_scheme=override_scheme),), | |
| recent_escape_dtype=self.recent_page_escape_dtype, | |
| recent_window=0, | |
| ) | |
| for spec in explicit_specs: | |
| override_layer_id, candidates = _parse_layer_candidate_spec(spec, field_name="key_policy_overrides" if kind == "K" else "value_policy_overrides") | |
| if override_layer_id == int(layer_id): | |
| return make_explicit_policy( | |
| kind=kind, | |
| policy_id=f"{kind.lower()}_policy_override_layer_{int(layer_id)}", | |
| sensitivity_tier="balanced", | |
| candidates=candidates, | |
| recent_escape_dtype=self.recent_page_escape_dtype, | |
| recent_window=self.recent_window, | |
| ) | |
| tier = default_tier | |
| for spec in sensitivity_specs: | |
| override_layer_id, override_tier = _parse_layer_value_spec( | |
| spec, | |
| field_name="key_layer_sensitivity" if kind == "K" else "value_layer_sensitivity", | |
| allowed_values=("strict", "balanced", "aggressive"), | |
| ) | |
| if override_layer_id == int(layer_id): | |
| tier = override_tier | |
| return make_tier_candidates( | |
| kind=kind, | |
| sensitivity_tier=tier, | |
| default_bits=default_bits, | |
| default_quant_scheme=default_quant_scheme, | |
| default_mode=default_mode, | |
| recent_escape_dtype=self.recent_page_escape_dtype, | |
| recent_window=self.recent_window, | |
| prefer_project_key_mode=self.prefer_m4_project_k if kind == "K" else False, | |
| ) | |