from __future__ import annotations from dataclasses import dataclass, field from time import perf_counter from typing import Any, Sequence import numpy as np from .attention_runtime import BackendName, decode_multi_query_step, prepare_pages, score_pages from .backends import ( PreparedPageTorch, clear_prepared_chunk_cache, cuda_available, decode_grouped_multiquery_step_prepared_torch_tensor, decode_multi_query_step_torch_tensor, mps_available, prepare_m0_affine_pages_from_tensor_torch, prepared_chunk_cache_resident_bytes, set_prepared_chunk_cache_budget_override, ) from .config import DotCacheConfig from .decode_reference import decode_page from .encode import encode_page from .planner import PageModeSpec, choose_page_mode, observe_page, parse_page_mode_token from .page_cache import PreparedPageCache from .packing import words_per_group from .selector_baselines import adjust_linear_selector_model_logits, LinearSelectorModel, load_linear_selector_model from .session_runtime import PagedDecodeSession, score_page_relevance, select_execution_page_indices, select_window_page_indices from .tracing import ExecutionTrace from .types import EncodedPage, PageHeader from .modes.m2_key_sketch import segment_ids_for_token_count from .modes.m4_key_project import fit_shared_project_basis PageLike = EncodedPage | PreparedPageTorch _DECODE_STAGE_TIMING_STAGES = ( "prepare_pages_with_tail", "prepare_layout_build", "m2_prefilter", "query_export", "shortlist_selection", "shortlist_base_window", "shortlist_candidate_scoring", "shortlist_candidate_approx_scoring", "shortlist_candidate_ranking", "shortlist_candidate_secondary_scoring", "shortlist_candidate_neighbor_rescue", "shortlist_candidate_builtin_selection", "shortlist_candidate_builtin_candidate_index_build", "shortlist_candidate_builtin_sidecar_stack", "shortlist_candidate_builtin_score_compute", "shortlist_candidate_builtin_ranking", "shortlist_exact_selection", "shortlist_union_rescue", "shortlist_materialization", "grouping_validation", "chunk_budget_sync", "backend_call_wall", "backend_call_non_backend", ) def _empty_decode_stage_timing_totals() -> dict[str, float]: return {stage: 0.0 for stage in _DECODE_STAGE_TIMING_STAGES} def _decode_stage_summary_key(stage: str) -> str: return f"execution_decode_{stage}_ms_total" def _backend_trace_ms_total(trace: ExecutionTrace | None) -> float: if trace is None: return 0.0 return float( trace.prepare_ms_total + trace.score_ms_total + trace.mix_ms_total + trace.softmax_ms_total + trace.unpack_ms_total + trace.fwht_ms_total + trace.chunk_assembly_ms_total ) def default_q_head_to_kv_head(num_attention_heads: int, num_key_value_heads: int) -> np.ndarray: if num_attention_heads <= 0: raise ValueError("num_attention_heads must be positive") if num_key_value_heads <= 0: raise ValueError("num_key_value_heads must be positive") if num_attention_heads % num_key_value_heads != 0: raise ValueError("num_attention_heads must be divisible by num_key_value_heads for the Llama path") return (np.arange(num_attention_heads, dtype=np.int64) // (num_attention_heads // num_key_value_heads)).astype( np.int64, copy=False, ) def _group_query_heads(mapping: np.ndarray, *, num_key_value_heads: int) -> tuple[tuple[int, ...], ...]: grouped: list[list[int]] = [[] for _ in range(num_key_value_heads)] for q_head_id, kv_head_id in enumerate(mapping.tolist()): if kv_head_id < 0 or kv_head_id >= num_key_value_heads: raise ValueError("q_head_to_kv_head contains an invalid KV head id") grouped[kv_head_id].append(q_head_id) return tuple(tuple(group) for group in grouped) def _page_header(page: PageLike) -> PageHeader: return page.header if not isinstance(page, PreparedPageTorch) else page.header def _page_token_range(page: PageLike) -> dict[str, int]: header = _page_header(page) return { "token_start": int(header.token_start), "token_end": int(header.token_start + header.token_count), "token_count": int(header.token_count), } def _page_age_bucket(page: PageLike, *, context_length: int) -> str: header = _page_header(page) if context_length <= 0: return "recent" age_fraction = max(0.0, min(1.0, 1.0 - (float(header.token_start + header.token_count) / float(context_length)))) if age_fraction < 0.25: return "recent" if age_fraction < 0.75: return "middle" return "old" def _recent_old_bonus_weight( page: PageLike, *, recent_start: int, bonus_window: int, ) -> float: if bonus_window <= 0: return 0.0 header = _page_header(page) page_end = int(header.token_start + header.token_count) if page_end > recent_start: return 0.0 distance = max(0, int(recent_start) - page_end) if distance >= int(bonus_window): return 0.0 return float(1.0 - (float(distance) / float(max(int(bonus_window), 1)))) def _rank_correlation(lhs: Sequence[float], rhs: Sequence[float]) -> float | None: if len(lhs) != len(rhs): raise ValueError("rank correlation inputs must have matching lengths") if len(lhs) < 2: return None lhs_array = np.asarray(lhs, dtype=np.float32) rhs_array = np.asarray(rhs, dtype=np.float32) lhs_std = float(np.std(lhs_array)) rhs_std = float(np.std(rhs_array)) if lhs_std <= 0.0 or rhs_std <= 0.0: return None return float(np.corrcoef(lhs_array, rhs_array)[0, 1]) def _score_page_relevance_for_mode( query_slice: np.ndarray, page: PageLike, *, relevance_mode: str, ) -> float | None: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if relevance_mode == "sketch" and source_page.runtime_page_sketch is None: return None if relevance_mode == "envelope" and (source_page.runtime_page_min is None or source_page.runtime_page_max is None): return None return float( score_page_relevance( np.asarray(query_slice, dtype=np.float32), relevance_mode=relevance_mode, page_sketch=None if source_page.runtime_page_sketch is None else np.asarray(source_page.runtime_page_sketch, dtype=np.float32), page_min=None if source_page.runtime_page_min is None else np.asarray(source_page.runtime_page_min, dtype=np.float32), page_max=None if source_page.runtime_page_max is None else np.asarray(source_page.runtime_page_max, dtype=np.float32), ) ) def _page_has_m2_sidecar(page: PageLike) -> bool: if isinstance(page, PreparedPageTorch): return page.m2_sketch is not None and page.m2_basis is not None and page.m2_mean is not None return page.m2_sketch is not None and page.m2_basis is not None and page.m2_mean is not None def _page_m2_prefilter_score_numpy(queries: np.ndarray, page: PageLike) -> float: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if source_page.m2_sketch is None or source_page.m2_basis is None or source_page.m2_mean is None: raise ValueError("page is missing M2 sidecar payload") query_groups = queries.reshape(queries.shape[0], source_page.header.num_groups, source_page.header.group_size) logits = np.zeros((queries.shape[0], source_page.header.token_count), dtype=np.float32) for group_index in range(source_page.header.num_groups): group_basis = source_page.m2_basis[group_index].astype(np.float32) group_mean = source_page.m2_mean[group_index].astype(np.float32) if group_basis.ndim == 2: q_proj = query_groups[:, group_index, :] @ group_basis.T logits += np.einsum("tr,qr->qt", source_page.m2_sketch[:, group_index, :].astype(np.float32), q_proj) logits += np.einsum("g,qg->q", group_mean, query_groups[:, group_index, :])[:, None] continue segment_ids = segment_ids_for_token_count(source_page.header.token_count, int(group_basis.shape[0])) q_proj = np.einsum("srg,qg->qsr", group_basis, query_groups[:, group_index, :]) logits += np.einsum( "tr,qtr->qt", source_page.m2_sketch[:, group_index, :].astype(np.float32), q_proj[:, segment_ids, :], ) logits += np.einsum("tg,qg->qt", group_mean[segment_ids], query_groups[:, group_index, :]) return float(np.max(logits)) def _page_m2_prefilter_score_torch(queries, page: PageLike) -> float: try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for torch-side M2 prefiltering") from exc if not torch.is_tensor(queries): raise TypeError("queries must be a torch.Tensor") prepared = page if isinstance(page, PreparedPageTorch) else None if prepared is None or prepared.m2_sketch is None or prepared.m2_basis is None or prepared.m2_mean is None: return _page_m2_prefilter_score_numpy(queries.detach().cpu().numpy().astype(np.float32, copy=False), page) query_groups = queries.reshape(int(queries.shape[0]), prepared.header.num_groups, prepared.header.group_size) logits = torch.zeros((int(queries.shape[0]), prepared.header.token_count), dtype=torch.float32, device=queries.device) for group_index in range(prepared.header.num_groups): group_basis = prepared.m2_basis[group_index] group_mean = prepared.m2_mean[group_index] group_sketch = prepared.m2_sketch[:, group_index, :] work_dtype = torch.promote_types(query_groups.dtype, group_basis.dtype) work_dtype = torch.promote_types(work_dtype, group_sketch.dtype) work_dtype = torch.promote_types(work_dtype, group_mean.dtype) qg = query_groups[:, group_index, :].to(dtype=work_dtype) group_basis = group_basis.to(dtype=work_dtype) group_mean = group_mean.to(dtype=work_dtype) group_sketch = group_sketch.to(dtype=work_dtype) if group_basis.dim() == 2: q_proj = torch.einsum("qg,rg->qr", qg, group_basis) logits += torch.einsum("tr,qr->qt", group_sketch, q_proj) logits += torch.einsum("g,qg->q", group_mean, qg)[:, None] continue segment_ids = torch.from_numpy( segment_ids_for_token_count(prepared.header.token_count, int(group_basis.shape[0])) ).to(device=queries.device) q_proj = torch.einsum("srg,qg->qsr", group_basis, qg) logits += torch.einsum("tr,qtr->qt", group_sketch, q_proj[:, segment_ids, :]) logits += torch.einsum("tg,qg->qt", group_mean[segment_ids], qg) return float(torch.max(logits).item()) def _pages_can_batch_m2_prefilter(pages: Sequence[PageLike]) -> bool: if not pages: return False first = pages[0] first_source = first.source_page if isinstance(first, PreparedPageTorch) else first if first_source.m2_sketch is None or first_source.m2_basis is None or first_source.m2_mean is None: return False token_count = int(first_source.header.token_count) num_groups = int(first_source.header.num_groups) group_size = int(first_source.header.group_size) sketch_dim = int(first_source.m2_sketch.shape[-1]) segment_count = int(first_source.m2_basis.shape[1]) if first_source.m2_basis.ndim == 4 else 1 prepared_device = first.device_type if isinstance(first, PreparedPageTorch) else None for page in pages[1:]: source = page.source_page if isinstance(page, PreparedPageTorch) else page if source.m2_sketch is None or source.m2_basis is None or source.m2_mean is None: return False if int(source.header.token_count) != token_count: return False if int(source.header.num_groups) != num_groups or int(source.header.group_size) != group_size: return False if int(source.m2_sketch.shape[-1]) != sketch_dim: return False if (int(source.m2_basis.shape[1]) if source.m2_basis.ndim == 4 else 1) != segment_count: return False if prepared_device is not None: if not isinstance(page, PreparedPageTorch) or page.device_type != prepared_device: return False return True def _page_m2_prefilter_scores_numpy(queries: np.ndarray, pages: Sequence[PageLike]) -> np.ndarray: source_pages = [page.source_page if isinstance(page, PreparedPageTorch) else page for page in pages] first = source_pages[0] query_groups = queries.reshape(queries.shape[0], first.header.num_groups, first.header.group_size) sketch = np.stack([page.m2_sketch for page in source_pages], axis=0).astype(np.float32, copy=False) basis = np.stack([page.m2_basis for page in source_pages], axis=0).astype(np.float32, copy=False) mean = np.stack([page.m2_mean for page in source_pages], axis=0).astype(np.float32, copy=False) if basis.ndim == 4: q_proj = np.einsum("pgrd,qgd->qpgr", basis, query_groups) logits = np.einsum("ptgr,qpgr->qpt", sketch, q_proj) logits += np.einsum("pgd,qgd->qp", mean, query_groups)[:, :, None] return np.max(logits, axis=(0, 2)).astype(np.float32, copy=False) segment_ids = segment_ids_for_token_count(first.header.token_count, int(basis.shape[2])) logits = np.zeros((queries.shape[0], len(source_pages), first.header.token_count), dtype=np.float32) for group_index in range(first.header.num_groups): q_proj = np.einsum("psrd,qd->qpsr", basis[:, group_index], query_groups[:, group_index, :]) logits += np.einsum("ptr,qptr->qpt", sketch[:, :, group_index, :], q_proj[:, :, segment_ids, :]) logits += np.einsum("ptg,qg->qpt", mean[:, group_index, segment_ids, :], query_groups[:, group_index, :]) return np.max(logits, axis=(0, 2)).astype(np.float32, copy=False) def _page_m2_prefilter_scores_torch(queries, pages: Sequence[PageLike]) -> np.ndarray: try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for torch-side M2 prefiltering") from exc if not torch.is_tensor(queries): raise TypeError("queries must be a torch.Tensor") if not all(isinstance(page, PreparedPageTorch) for page in pages): return _page_m2_prefilter_scores_numpy(queries.detach().cpu().numpy().astype(np.float32, copy=False), pages) first = pages[0] query_groups = queries.reshape(int(queries.shape[0]), first.header.num_groups, first.header.group_size) sketch = torch.stack([page.m2_sketch for page in pages], dim=0) basis = torch.stack([page.m2_basis for page in pages], dim=0) mean = torch.stack([page.m2_mean for page in pages], dim=0) if basis.dim() == 4: q_proj = torch.einsum("pgrd,qgd->qpgr", basis, query_groups) logits = torch.einsum("ptgr,qpgr->qpt", sketch, q_proj) logits += torch.einsum("pgd,qgd->qp", mean, query_groups)[:, :, None] return torch.amax(logits, dim=(0, 2)).detach().cpu().numpy().astype(np.float32, copy=False) segment_ids = torch.from_numpy(segment_ids_for_token_count(first.header.token_count, int(basis.shape[2]))).to(device=queries.device) logits = torch.zeros((int(queries.shape[0]), len(pages), first.header.token_count), dtype=torch.float32, device=queries.device) for group_index in range(first.header.num_groups): group_basis = basis[:, group_index] group_sketch = sketch[:, :, group_index, :] group_mean = mean[:, group_index, segment_ids, :] work_dtype = torch.promote_types(query_groups.dtype, group_basis.dtype) work_dtype = torch.promote_types(work_dtype, group_sketch.dtype) work_dtype = torch.promote_types(work_dtype, group_mean.dtype) qg = query_groups[:, group_index, :].to(dtype=work_dtype) group_basis = group_basis.to(dtype=work_dtype) group_sketch = group_sketch.to(dtype=work_dtype) group_mean = group_mean.to(dtype=work_dtype) q_proj = torch.einsum("psrd,qg->qpsr", group_basis, qg) logits += torch.einsum("ptr,qptr->qpt", group_sketch, q_proj[:, :, segment_ids, :]) logits += torch.einsum("ptg,qg->qpt", group_mean, qg) return torch.amax(logits, dim=(0, 2)).detach().cpu().numpy().astype(np.float32, copy=False) def _grouped_pages_can_batch( key_pages_by_group: Sequence[Sequence[PageLike]], value_pages_by_group: Sequence[Sequence[PageLike]], query_groups: Sequence[Any], ) -> bool: return _grouped_pages_batch_rejection_reason(key_pages_by_group, value_pages_by_group, query_groups) is None def _grouped_pages_batch_rejection_reason( key_pages_by_group: Sequence[Sequence[PageLike]], value_pages_by_group: Sequence[Sequence[PageLike]], query_groups: Sequence[Any], ) -> str | None: def _page_batch_signature(page: PreparedPageTorch) -> tuple[int | str, ...]: sketch = page.m2_sketch basis = page.m2_basis sketch_dim = int(sketch.shape[-1]) if sketch is not None else 0 segment_count = int(basis.shape[1]) if basis is not None and int(basis.dim()) == 4 else 1 centered = int(page.m2_mean is not None) header = page.header return ( page.device_type, header.kind, header.mode_default, header.escape_dtype if header.mode_default == "M3" else "", header.token_count, header.head_dim, header.padded_head_dim, header.group_size, header.num_groups, header.bits, header.words_per_group, header.layout, header.quant_scheme, sketch_dim, segment_count, centered, ) if not key_pages_by_group: return "no_key_groups" if len(key_pages_by_group) != len(value_pages_by_group): return "group_count_mismatch" group_count = len(key_pages_by_group) if len(query_groups) != group_count: return "query_group_count_mismatch" try: query_count = int(query_groups[0].shape[0]) except Exception: return "query_shape_invalid" page_count = len(key_pages_by_group[0]) if page_count == 0: return "page_count_zero" for group_index in range(group_count): if len(key_pages_by_group[group_index]) != page_count or len(value_pages_by_group[group_index]) != page_count: return "page_count_mismatch" if int(query_groups[group_index].shape[0]) != query_count: return "query_count_mismatch" if not all(isinstance(page, PreparedPageTorch) for page in key_pages_by_group[group_index]): return "key_page_not_prepared" if not all(isinstance(page, PreparedPageTorch) for page in value_pages_by_group[group_index]): return "value_page_not_prepared" if any(page.device_type != key_pages_by_group[0][0].device_type for page in key_pages_by_group[group_index]): return "key_device_mismatch" if any(page.device_type != value_pages_by_group[0][0].device_type for page in value_pages_by_group[group_index]): return "value_device_mismatch" return None @dataclass(slots=True) class _PreparedDecodeViewLayout: grouped_batch_signature: tuple[tuple[tuple[Any, ...], tuple[Any, ...]], ...] key_chunk_lengths: tuple[int, ...] value_chunk_lengths: tuple[int, ...] @dataclass(slots=True) class _ExecutionBuiltinSelectorCache: page_signature: tuple[tuple[int, int, int], ...] = () sketch_matrix: np.ndarray | None = None minima_matrix: np.ndarray | None = None maxima_matrix: np.ndarray | None = None def resident_bytes(self) -> int: total = 0 if self.sketch_matrix is not None: total += int(self.sketch_matrix.nbytes) if self.minima_matrix is not None: total += int(self.minima_matrix.nbytes) if self.maxima_matrix is not None: total += int(self.maxima_matrix.nbytes) return int(total) def _prepared_page_group_signature(page: PreparedPageTorch) -> tuple[Any, ...]: basis = getattr(page, "m2_basis", None) if basis is None: segment_count = 0 else: ndim = int(basis.ndim) if hasattr(basis, "ndim") else int(basis.dim()) segment_count = int(basis.shape[1]) if ndim == 4 else 1 return ( page.device_type, page.header.mode_default, page.header.token_count, page.header.head_dim, page.header.padded_head_dim, page.header.group_size, page.header.num_groups, page.header.bits, page.header.words_per_group, page.header.layout, page.header.quant_scheme, int(page.m2_sketch.shape[-1]) if page.m2_sketch is not None else 0, segment_count, ) def _prepared_page_chunk_lengths(pages: Sequence[PreparedPageTorch]) -> tuple[int, ...]: if not pages: return () lengths: list[int] = [] current_signature: tuple[Any, ...] | None = None current_length = 0 for page in pages: signature = _prepared_page_group_signature(page) if current_signature is None or signature == current_signature: current_signature = signature current_length += 1 continue lengths.append(current_length) current_signature = signature current_length = 1 if current_length > 0: lengths.append(current_length) return tuple(lengths) def _prepared_page_aligned_chunk_lengths( key_pages: Sequence[PreparedPageTorch], value_pages: Sequence[PreparedPageTorch], ) -> tuple[int, ...]: if len(key_pages) != len(value_pages): return () if not key_pages: return () lengths: list[int] = [] current_length = 0 current_key_signature: tuple[Any, ...] | None = None current_value_signature: tuple[Any, ...] | None = None for key_page, value_page in zip(key_pages, value_pages, strict=True): key_signature = _prepared_page_group_signature(key_page) value_signature = _prepared_page_group_signature(value_page) if ( current_length > 0 and (key_signature != current_key_signature or value_signature != current_value_signature) ): lengths.append(current_length) current_length = 0 if current_length == 0: current_key_signature = key_signature current_value_signature = value_signature current_length += 1 if current_length > 0: lengths.append(current_length) return tuple(lengths) def _build_prepared_decode_view_layout( key_pages: Sequence[PageLike], value_pages: Sequence[PageLike], ) -> _PreparedDecodeViewLayout | None: if len(key_pages) != len(value_pages) or not key_pages: return None if not all(isinstance(page, PreparedPageTorch) for page in key_pages): return None if not all(isinstance(page, PreparedPageTorch) for page in value_pages): return None prepared_key_pages = tuple(key_pages) prepared_value_pages = tuple(value_pages) return _PreparedDecodeViewLayout( grouped_batch_signature=tuple( (_prepared_page_group_signature(key_page), _prepared_page_group_signature(value_page)) for key_page, value_page in zip(prepared_key_pages, prepared_value_pages, strict=True) ), key_chunk_lengths=_prepared_page_chunk_lengths(prepared_key_pages), value_chunk_lengths=_prepared_page_chunk_lengths(prepared_value_pages), ) def _grouped_layouts_can_batch( layouts: Sequence[_PreparedDecodeViewLayout | None], query_groups: Sequence[Any], ) -> bool: return _grouped_layout_batch_rejection_reason(layouts, query_groups) is None def _grouped_layout_batch_rejection_reason( layouts: Sequence[_PreparedDecodeViewLayout | None], query_groups: Sequence[Any], ) -> str | None: if not layouts or any(layout is None for layout in layouts): return "layout_missing" try: query_count = int(query_groups[0].shape[0]) except Exception: return "query_shape_invalid" first_layout = layouts[0] assert first_layout is not None for group_index in range(1, len(layouts)): layout = layouts[group_index] assert layout is not None if layout.grouped_batch_signature != first_layout.grouped_batch_signature: return "layout_signature_mismatch" if int(query_groups[group_index].shape[0]) != query_count: return "query_count_mismatch" return None def _normalize_prefill_tensor( values: np.ndarray, *, num_key_value_heads: int, head_dim: int, name: str, ) -> np.ndarray: array = np.asarray(values, dtype=np.float32) if array.ndim == 4: if array.shape[0] != 1: raise ValueError(f"{name} batch dimension must be 1 for the Phase 5 Llama path") array = array[0] if array.ndim != 3: raise ValueError(f"{name} must have shape [kv_heads, seq_len, head_dim] or [1, kv_heads, seq_len, head_dim]") if array.shape[0] != num_key_value_heads: raise ValueError(f"{name} must contain {num_key_value_heads} KV heads") if array.shape[2] != head_dim: raise ValueError(f"{name} head_dim must equal {head_dim}") return array def _normalize_step_tensor( values: np.ndarray, *, num_key_value_heads: int, head_dim: int, name: str, ) -> np.ndarray: array = np.asarray(values, dtype=np.float32) if array.ndim == 4: if array.shape[0] != 1: raise ValueError(f"{name} batch dimension must be 1 for the Phase 5 Llama path") array = array[0] if array.ndim != 3: raise ValueError(f"{name} must have shape [kv_heads, token_count, head_dim]") if array.shape[0] != num_key_value_heads: raise ValueError(f"{name} must contain {num_key_value_heads} KV heads") if array.shape[2] != head_dim: raise ValueError(f"{name} head_dim must equal {head_dim}") return array def _normalize_query_step(query_step: np.ndarray, *, num_attention_heads: int, head_dim: int) -> np.ndarray: queries = np.asarray(query_step, dtype=np.float32) if queries.ndim == 4: if queries.shape[0] != 1 or queries.shape[2] != 1: raise ValueError("query_step must have shape [q_heads, head_dim] or [1, q_heads, 1, head_dim]") queries = queries[0, :, 0, :] if queries.ndim != 2: raise ValueError("query_step must have shape [q_heads, head_dim]") if queries.shape[0] != num_attention_heads: raise ValueError(f"query_step must contain {num_attention_heads} query heads") if queries.shape[1] != head_dim: raise ValueError(f"query_step head_dim must equal {head_dim}") return queries def _normalize_prefill_tensor_torch(values, *, num_key_value_heads: int, head_dim: int, name: str): try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for torch-native prefill ingest") from exc if not torch.is_tensor(values): raise TypeError(f"{name} must be a torch.Tensor") array = values.detach().to(dtype=torch.float32) if array.ndim == 4: if int(array.shape[0]) != 1: raise ValueError(f"{name} batch dimension must be 1 for the Phase 5 Llama path") array = array[0] if array.ndim != 3: raise ValueError(f"{name} must have shape [kv_heads, seq_len, head_dim] or [1, kv_heads, seq_len, head_dim]") if int(array.shape[0]) != num_key_value_heads: raise ValueError(f"{name} must contain {num_key_value_heads} KV heads") if int(array.shape[2]) != head_dim: raise ValueError(f"{name} head_dim must equal {head_dim}") return array @dataclass(slots=True) class _TailPageBuilder: config: DotCacheConfig layer_id: int kv_head_id: int token_start: int | None = None key_rows: list[np.ndarray] = field(default_factory=list) value_rows: list[np.ndarray] = field(default_factory=list) @property def token_count(self) -> int: return len(self.key_rows) def clear(self) -> None: self.token_start = None self.key_rows.clear() self.value_rows.clear() def _should_build_execution_runtime_metadata(self, *, kind: str) -> bool: if kind != "K": return False return self.config.execution_shortlist_enabled() def load_prefill_remainder( self, key_rows: np.ndarray, value_rows: np.ndarray, *, token_start: int, ) -> None: self.clear() if key_rows.shape[0] != value_rows.shape[0]: raise ValueError("prefill remainder key/value rows must align") if key_rows.shape[0] == 0: return self.token_start = int(token_start) self.key_rows.extend(np.asarray(key_rows, dtype=np.float32)) self.value_rows.extend(np.asarray(value_rows, dtype=np.float32)) def append_step_rows( self, key_rows: np.ndarray, value_rows: np.ndarray, *, token_start: int, sequence_length: int | None = None, ) -> tuple[list[EncodedPage], list[EncodedPage]]: if key_rows.shape != value_rows.shape: raise ValueError("step key/value rows must align") if key_rows.ndim != 2: raise ValueError("step rows must have shape [token_count, head_dim]") if key_rows.shape[0] == 0: return [], [] finalized_key_pages: list[EncodedPage] = [] finalized_value_pages: list[EncodedPage] = [] expected_token = self.next_token_index if expected_token is not None and token_start != expected_token: raise ValueError(f"tail-page append expected token_index {expected_token}, received {token_start}") if self.token_start is None: self.token_start = int(token_start) for offset in range(key_rows.shape[0]): self.key_rows.append(np.asarray(key_rows[offset], dtype=np.float32)) self.value_rows.append(np.asarray(value_rows[offset], dtype=np.float32)) if len(self.key_rows) < self.config.tokens_per_page: continue if self.token_start is None: raise RuntimeError("tail-page token_start is missing while finalizing a page") dense_keys = np.stack(self.key_rows, axis=0).astype(np.float32, copy=False) dense_values = np.stack(self.value_rows, axis=0).astype(np.float32, copy=False) current_sequence_length = int(sequence_length if sequence_length is not None else (token_start + key_rows.shape[0])) key_page_mode = self._select_page_mode( dense_keys, kind="K", layer_id=self.layer_id, kv_head_id=self.kv_head_id, token_start=self.token_start, sequence_length=current_sequence_length, stage="decode", ) key_mode = None if key_page_mode is not None else self.config.resolve_page_mode(kind="K", layer_id=self.layer_id, kv_head_id=self.kv_head_id) value_page_mode = self._select_page_mode( dense_values, kind="V", layer_id=self.layer_id, kv_head_id=self.kv_head_id, token_start=self.token_start, sequence_length=current_sequence_length, stage="decode", ) value_mode = None if value_page_mode is not None else self.config.resolve_page_mode(kind="V", layer_id=self.layer_id, kv_head_id=self.kv_head_id) finalized_key_pages.append( encode_page( dense_keys, self.config, kind="K", layer_id=self.layer_id, kv_head_id=self.kv_head_id, token_start=self.token_start, mode=key_mode, page_mode=key_page_mode, ) ) finalized_value_pages.append( encode_page( dense_values, self.config, kind="V", layer_id=self.layer_id, kv_head_id=self.kv_head_id, token_start=self.token_start, mode=value_mode, page_mode=value_page_mode, ) ) self.key_rows.clear() self.value_rows.clear() self.token_start += self.config.tokens_per_page if self.token_count == 0: self.token_start = None return finalized_key_pages, finalized_value_pages @property def next_token_index(self) -> int | None: if self.token_start is None: return None return self.token_start + self.token_count def build_temp_pages(self) -> tuple[EncodedPage, EncodedPage] | None: if self.token_count == 0: return None if self.token_start is None: raise RuntimeError("tail-page token_start is missing") dense_keys = np.stack(self.key_rows, axis=0).astype(np.float32, copy=False) dense_values = np.stack(self.value_rows, axis=0).astype(np.float32, copy=False) return ( encode_page( dense_keys, self.config, kind="K", layer_id=self.layer_id, kv_head_id=self.kv_head_id, token_start=self.token_start, mode="M3", build_runtime_metadata=self._should_build_execution_runtime_metadata(kind="K"), ), encode_page( dense_values, self.config, kind="V", layer_id=self.layer_id, kv_head_id=self.kv_head_id, token_start=self.token_start, mode="M3", build_runtime_metadata=False, ), ) def _tail_escape_dtype_numpy(dtype_name: str) -> np.dtype: if dtype_name == "float16": return np.float16 if dtype_name == "float32": return np.float32 if dtype_name == "int8": return np.int8 raise ValueError(f"unsupported tail escape dtype: {dtype_name}") def _quantize_tail_rows_numpy(rows: np.ndarray, dtype_name: str) -> tuple[np.ndarray, np.ndarray | None]: values = np.asarray(rows, dtype=np.float32) if dtype_name in {"float16", "float32"}: return values.astype(_tail_escape_dtype_numpy(dtype_name), copy=False), None if dtype_name == "int8": row_absmax = np.max(np.abs(values), axis=1) scales = np.maximum(row_absmax / 127.0, 1e-8).astype(np.float16, copy=False) quantized = np.clip(np.rint(values / scales[:, None]), -127.0, 127.0).astype(np.int8, copy=False) return quantized, scales raise ValueError(f"unsupported tail escape dtype: {dtype_name}") @dataclass(slots=True) class _PersistentTailPage: config: DotCacheConfig layer_id: int kv_head_id: int kind: str device_type: str source_page: EncodedPage | None = None prepared_page: PreparedPageTorch | None = None host_buffer: np.ndarray | None = None host_scales: np.ndarray | None = None token_count: int = 0 resident_nbytes: int = 0 def clear(self) -> None: self.token_count = 0 if self.source_page is not None: self.source_page.header.token_count = 0 self.source_page.escape_payload = None if self.host_buffer is None else self.host_buffer[:0] self.source_page.escape_scales = None if self.host_scales is None else self.host_scales[:0] self.source_page.runtime_page_mean = None self.source_page.runtime_page_sketch = None self.source_page.runtime_page_min = None self.source_page.runtime_page_max = None def _should_build_execution_runtime_metadata(self) -> bool: if self.kind != "K": return False return self.config.execution_shortlist_enabled() def _refresh_runtime_metadata(self) -> None: if self.source_page is None or not self._should_build_execution_runtime_metadata(): return if self.token_count <= 0: self.source_page.runtime_page_mean = None self.source_page.runtime_page_sketch = None self.source_page.runtime_page_min = None self.source_page.runtime_page_max = None return dense = self.materialize_rows() self.source_page.runtime_page_mean = dense.mean(axis=0).astype(np.float32, copy=False) self.source_page.runtime_page_sketch = self.source_page.runtime_page_mean[None, :] self.source_page.runtime_page_min = dense.min(axis=0).astype(np.float32, copy=False) self.source_page.runtime_page_max = dense.max(axis=0).astype(np.float32, copy=False) def _ensure_allocated(self, *, token_start: int) -> bool: if self.source_page is not None and self.prepared_page is not None and self.host_buffer is not None: self.source_page.header.token_start = int(token_start) self.prepared_page.header.token_start = int(token_start) return False try: import torch except ImportError as exc: # pragma: no cover - torch is required only for the MPS tail path raise RuntimeError("torch is required for the persistent torch tail path") from exc dtype_name = self.config.escape_dtype np_dtype = _tail_escape_dtype_numpy(dtype_name) torch_dtype = getattr(torch, dtype_name) host_buffer = np.zeros((self.config.tokens_per_page, self.config.head_dim), dtype=np_dtype) host_scales = None if dtype_name != "int8" else np.zeros((self.config.tokens_per_page,), dtype=np.float16) header = PageHeader( layer_id=self.layer_id, kv_head_id=self.kv_head_id, kind=self.kind, token_start=int(token_start), token_count=0, head_dim=self.config.head_dim, padded_head_dim=self.config.padded_head_dim, group_size=self.config.group_size, num_groups=self.config.num_groups, bits=self.config.bits_k if self.kind == "K" else self.config.bits_v, words_per_group=words_per_group(self.config.group_size, self.config.bits_k if self.kind == "K" else self.config.bits_v), mode_default="M3", layout=self.config.payload_layout_k if self.kind == "K" else self.config.payload_layout_v, quant_scheme=self.config.quant_scheme_k if self.kind == "K" else self.config.quant_scheme_v, escape_dtype=dtype_name, ) source_page = EncodedPage( header=header, escape_payload=host_buffer[:0], escape_scales=None if host_scales is None else host_scales[:0], ) device_payload = torch.zeros( (self.config.tokens_per_page, self.config.head_dim), dtype=torch_dtype, device=self.device_type, ) device_scales = None if dtype_name == "int8": scale_dtype = torch.float32 if self.device_type == "mps" else torch.float16 device_scales = torch.zeros((self.config.tokens_per_page,), dtype=scale_dtype, device=self.device_type) prepared_page = PreparedPageTorch( device_type=self.device_type, source_page=source_page, header=header, escape_payload=device_payload, escape_scales=device_scales, host_to_device_nbytes=int(device_payload.numel() * device_payload.element_size()) + (0 if device_scales is None else int(device_scales.numel() * device_scales.element_size())), ) self.source_page = source_page self.prepared_page = prepared_page self.host_buffer = host_buffer self.host_scales = host_scales self.resident_nbytes = int(device_payload.numel() * device_payload.element_size()) + ( 0 if device_scales is None else int(device_scales.numel() * device_scales.element_size()) ) return True def load_rows( self, rows: np.ndarray, *, token_start: int, trace: ExecutionTrace | None = None, ) -> None: values = np.asarray(rows, dtype=np.float32) if values.ndim != 2 or values.shape[1] != self.config.head_dim: raise ValueError("tail rows must have shape [token_count, head_dim]") self.clear() if values.shape[0] == 0: return self._ensure_allocated(token_start=token_start) self.append_rows(values, token_start=token_start, trace=trace) def prepare_append_span(self, *, token_start: int, row_count: int) -> tuple[int, int]: if row_count < 0: raise ValueError("row_count must be non-negative") self._ensure_allocated(token_start=token_start if self.token_count == 0 else self.source_page.header.token_start) if self.source_page is None or self.prepared_page is None or self.host_buffer is None: raise RuntimeError("persistent tail page is not initialized") expected_token = self.source_page.header.token_start + self.token_count if token_start != expected_token: raise ValueError(f"persistent tail expected token_start {expected_token}, received {token_start}") end = self.token_count + row_count if end > self.config.tokens_per_page: raise ValueError("persistent tail cannot exceed tokens_per_page") start = self.token_count self.source_page.header.token_count = end self.prepared_page.header.token_count = end self.source_page.escape_payload = self.host_buffer[:end] self.source_page.escape_scales = None if self.host_scales is None else self.host_scales[:end] self.token_count = end return start, end def append_rows_from_device( self, *, rows: np.ndarray, device_rows: Any, token_start: int, ) -> None: import torch values = np.asarray(rows, dtype=np.float32) if values.ndim != 2 or values.shape[1] != self.config.head_dim: raise ValueError("tail rows must have shape [token_count, head_dim]") if values.shape[0] == 0: return if self.host_buffer is None or self.prepared_page is None: raise RuntimeError("persistent tail page is not initialized") start, end = self.prepare_append_span(token_start=token_start, row_count=values.shape[0]) converted, scales = _quantize_tail_rows_numpy(values, self.config.escape_dtype) self.host_buffer[start:end] = converted if self.host_scales is not None and scales is not None: self.host_scales[start:end] = scales if self.prepared_page.escape_payload.dtype == device_rows.dtype: self.prepared_page.escape_payload[start:end, : self.config.head_dim] = device_rows else: if self.config.escape_dtype != "int8": self.prepared_page.escape_payload[start:end, : self.config.head_dim] = device_rows.to( dtype=self.prepared_page.escape_payload.dtype ) else: row_scales = torch.clamp(device_rows.abs().amax(dim=-1) / 127.0, min=1e-8).to( dtype=self.prepared_page.escape_scales.dtype ) quantized = torch.clamp(torch.round(device_rows / row_scales[:, None]), -127.0, 127.0).to(dtype=torch.int8) self.prepared_page.escape_payload[start:end, : self.config.head_dim] = quantized if self.prepared_page.escape_scales is None: raise RuntimeError("int8 persistent tail is missing escape scales") self.prepared_page.escape_scales[start:end] = row_scales self._refresh_runtime_metadata() return if self.prepared_page.escape_scales is not None and scales is not None: self.prepared_page.escape_scales[start:end] = torch.from_numpy(np.ascontiguousarray(scales)).to( device=self.device_type, dtype=self.prepared_page.escape_scales.dtype, ) self._refresh_runtime_metadata() def append_device_rows( self, device_rows, *, token_start: int, ) -> None: try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for the persistent torch tail path") from exc if not torch.is_tensor(device_rows): raise TypeError("append_device_rows requires a torch.Tensor") if device_rows.ndim != 2 or int(device_rows.shape[1]) != self.config.head_dim: raise ValueError("tail rows must have shape [token_count, head_dim]") if int(device_rows.shape[0]) == 0: return if self.prepared_page is None: self._ensure_allocated(token_start=token_start if self.token_count == 0 else self.source_page.header.token_start) if self.prepared_page is None: raise RuntimeError("persistent tail page is not initialized") start, end = self.prepare_append_span(token_start=token_start, row_count=int(device_rows.shape[0])) if self.config.escape_dtype != "int8": self.prepared_page.escape_payload[start:end, : self.config.head_dim] = device_rows.to( dtype=self.prepared_page.escape_payload.dtype ) self._refresh_runtime_metadata() return row_scales = torch.clamp(device_rows.abs().amax(dim=-1) / 127.0, min=1e-8).to( dtype=self.prepared_page.escape_scales.dtype ) quantized = torch.clamp(torch.round(device_rows / row_scales[:, None]), -127.0, 127.0).to(dtype=torch.int8) self.prepared_page.escape_payload[start:end, : self.config.head_dim] = quantized if self.prepared_page.escape_scales is None: raise RuntimeError("int8 persistent tail is missing escape scales") self.prepared_page.escape_scales[start:end] = row_scales self._refresh_runtime_metadata() def materialize_rows(self) -> np.ndarray: if self.prepared_page is None or self.token_count <= 0: return np.zeros((0, self.config.head_dim), dtype=np.float32) payload = self.prepared_page.escape_payload[: self.token_count, : self.config.head_dim].detach().cpu().numpy() if self.prepared_page.header.escape_dtype == "int8": if self.prepared_page.escape_scales is None: raise RuntimeError("int8 persistent tail is missing escape scales") scales = self.prepared_page.escape_scales[: self.token_count].detach().cpu().numpy() return payload.astype(np.float32, copy=False) * scales.astype(np.float32, copy=False)[:, None] return payload.astype(np.float32, copy=False) def append_rows( self, rows: np.ndarray, *, token_start: int, trace: ExecutionTrace | None = None, ) -> None: values = np.asarray(rows, dtype=np.float32) if values.ndim != 2 or values.shape[1] != self.config.head_dim: raise ValueError("tail rows must have shape [token_count, head_dim]") if values.shape[0] == 0: return self._ensure_allocated(token_start=token_start if self.token_count == 0 else self.source_page.header.token_start) if self.source_page is None or self.prepared_page is None or self.host_buffer is None: raise RuntimeError("persistent tail page is not initialized") converted, scales = _quantize_tail_rows_numpy(values, self.config.escape_dtype) try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for the persistent torch tail path") from exc row_tensor = torch.from_numpy(np.ascontiguousarray(converted)).to(device=self.device_type) start, end = self.prepare_append_span(token_start=token_start, row_count=values.shape[0]) self.host_buffer[start:end] = converted if self.host_scales is not None and scales is not None: self.host_scales[start:end] = scales self.prepared_page.escape_payload[start:end, : self.config.head_dim] = row_tensor if trace is not None: trace.record_host_to_device(int(row_tensor.numel() * row_tensor.element_size())) if self.prepared_page.escape_scales is not None and scales is not None: scale_tensor = torch.from_numpy(np.ascontiguousarray(scales)).to( device=self.device_type, dtype=self.prepared_page.escape_scales.dtype, ) self.prepared_page.escape_scales[start:end] = scale_tensor if trace is not None: trace.record_host_to_device(int(scale_tensor.numel() * scale_tensor.element_size())) self._refresh_runtime_metadata() @property def active_page(self) -> PreparedPageTorch | None: if self.token_count <= 0: return None return self.prepared_page @dataclass(slots=True) class _HeadSessionState: session: PagedDecodeSession tail: _TailPageBuilder persistent_key_tail: _PersistentTailPage | None = None persistent_value_tail: _PersistentTailPage | None = None decode_key_pages_with_tail: list[PageLike] | None = None decode_value_pages_with_tail: list[PageLike] | None = None decode_view_layout: _PreparedDecodeViewLayout | None = None execution_builtin_selector_cache: _ExecutionBuiltinSelectorCache | None = None sequence_length: int = 0 tracked_direct_prepared_pages: dict[int, int] = field(default_factory=dict) tracked_tail_resident_bytes: int = 0 def invalidate_decode_views(self) -> None: self.decode_key_pages_with_tail = None self.decode_value_pages_with_tail = None self.decode_view_layout = None self.execution_builtin_selector_cache = None def clear(self, *, clear_prepared_cache: bool) -> None: self.session.key_pages.clear() self.session.value_pages.clear() self.session.key_page_sketches.clear() self.session.key_page_minima.clear() self.session.key_page_maxima.clear() self.session.value_page_summaries.clear() self.session.last_selected_indices.clear() if clear_prepared_cache and self.session.cache is not None: self.session.cache.clear() self.tail.clear() if self.persistent_key_tail is not None: self.persistent_key_tail.clear() if self.persistent_value_tail is not None: self.persistent_value_tail.clear() self.invalidate_decode_views() self.sequence_length = 0 class ModelPagedKVCache: def __init__( self, *, config: DotCacheConfig, num_hidden_layers: int, num_attention_heads: int, num_key_value_heads: int, backend: BackendName = "auto", cache: PreparedPageCache | None = None, ) -> None: self.config = config self.num_hidden_layers = int(num_hidden_layers) self.num_attention_heads = int(num_attention_heads) self.num_key_value_heads = int(num_key_value_heads) self.backend = backend self.cache = cache if cache is not None else PreparedPageCache() self.default_q_head_to_kv_head = default_q_head_to_kv_head(self.num_attention_heads, self.num_key_value_heads) self.default_grouped_query_heads = _group_query_heads( self.default_q_head_to_kv_head, num_key_value_heads=self.num_key_value_heads, ) self._states: dict[tuple[int, int], _HeadSessionState] = {} self._m2_prefilter_invocations = 0 self._m2_prefilter_candidate_pages = 0 self._m2_prefilter_selected_pages = 0 self._decode_path_counts: dict[str, int] = { "grouped_batched": 0, "per_kv_fallback": 0, } self._decode_path_counts_by_layer: dict[int, dict[str, int]] = {} self._execution_shortlist_invocations = 0 self._execution_shortlist_applied = 0 self._execution_shortlist_group_union_applied = 0 self._execution_shortlist_grouping_rejections = 0 self._execution_shortlist_grouping_rejection_reason_counts: dict[str, int] = {} self._execution_shortlist_grouping_rejection_reason_counts_by_layer: dict[int, dict[str, int]] = {} self._execution_shortlist_total_pages = 0 self._execution_shortlist_selected_pages = 0 self._execution_shortlist_invocations_by_layer: dict[int, int] = {} self._execution_shortlist_applied_by_layer: dict[int, int] = {} self._execution_shortlist_group_union_applied_by_layer: dict[int, int] = {} self._execution_shortlist_grouping_rejections_by_layer: dict[int, int] = {} self._execution_shortlist_total_pages_by_layer: dict[int, int] = {} self._execution_shortlist_selected_pages_by_layer: dict[int, int] = {} self._execution_shortlist_trace_records: list[dict[str, object]] = [] self._execution_exact_refine_invocations = 0 self._execution_exact_refine_candidate_pages = 0 self._execution_exact_refine_selected_pages = 0 self._execution_exact_refine_invocations_by_layer: dict[int, int] = {} self._execution_exact_refine_candidate_pages_by_layer: dict[int, int] = {} self._execution_exact_refine_selected_pages_by_layer: dict[int, int] = {} self._decode_grouped_batch_rejection_reason_counts: dict[str, int] = {} self._decode_grouped_batch_rejection_reason_counts_by_layer: dict[int, dict[str, int]] = {} self._decode_stage_timings = _empty_decode_stage_timing_totals() self._decode_stage_timings_by_layer: dict[int, dict[str, float]] = {} self._direct_prepared_page_resident_bytes = 0 self._direct_prepared_page_refcounts: dict[int, int] = {} self._direct_prepared_page_sizes: dict[int, int] = {} self._tail_resident_bytes = 0 self._chunk_budget_dirty_marks = 0 self._chunk_budget_dirty_transitions = 0 self._chunk_budget_dirty_reason_counts: dict[str, int] = {} self._chunk_budget_sync_invocations = 0 self._chunk_budget_sync_clean_skips = 0 self._chunk_budget_sync_dirty_invocations = 0 self._chunk_budget_override_calls = 0 self._chunk_budget_override_budget_change_calls = 0 self._chunk_budget_override_same_budget_calls = 0 self._chunk_budget_freeze_override_calls = 0 self._builtin_selector_score_all_pages_calls = 0 self._builtin_selector_candidate_only_calls = 0 self._builtin_selector_candidate_pages = 0 self._builtin_selector_total_pages = 0 self._builtin_selector_candidate_fraction_sum = 0.0 self._builtin_selector_candidate_fraction_max = 0.0 self._builtin_selector_cache_hits = 0 self._builtin_selector_cache_builds = 0 self._builtin_selector_cache_build_bytes = 0 self._builtin_selector_cache_build_bytes_max = 0 self._execution_value_escape_cache: dict[tuple[int, int, int, str, str], PageLike] = {} self._execution_value_escape_source_pages: dict[tuple[int, str], EncodedPage] = {} self._execution_value_escape_cache_hits = 0 self._execution_value_escape_source_registrations = 0 self._execution_value_escape_prepared_page_builds = 0 self._execution_value_escape_prewarm_invocations = 0 self._execution_value_escape_prewarm_pages = 0 self._execution_value_escape_prewarm_ms_total = 0.0 self._execution_value_escape_builds = 0 self._execution_value_escape_applied_pages = 0 self._prepared_chunk_cache_frozen_budget_bytes: int | None = None self._prepared_chunk_cache_applied_budget_bytes: int | None = None self._prepared_chunk_cache_budget_dirty = True self._learned_page_selector_model: LinearSelectorModel | None = None self._learned_page_selector_invocations = 0 self._learned_page_selector_predictions: dict[str, int] = {} self._learned_page_selector_fallbacks = 0 self._learned_page_selector_ms_total = 0.0 self._learned_page_selector_invocations_by_stage: dict[str, int] = {} self._learned_page_selector_fallbacks_by_stage: dict[str, int] = {} self._learned_page_selector_ms_total_by_stage: dict[str, float] = {} self._learned_page_selector_predictions_by_stage: dict[str, dict[str, int]] = {} if self.config.learned_page_selector_enabled(): self._learned_page_selector_model = load_linear_selector_model(str(self.config.learned_page_selector_path)) if float(self.config.learned_page_selector_logit_offset) != 0.0: self._learned_page_selector_model = adjust_linear_selector_model_logits( self._learned_page_selector_model, candidate_logit_offsets={ str(self.config.learned_page_selector_target_candidate): float( self.config.learned_page_selector_logit_offset ) }, ) @property def resident_bytes(self) -> int: return self.resident_byte_summary()["resident_bytes"] @staticmethod def _prepared_page_resident_bytes(page: PreparedPageTorch) -> int: resident_nbytes = int(page.resident_nbytes) return resident_nbytes if resident_nbytes > 0 else int(page.host_to_device_nbytes) def _collect_state_direct_prepared_pages(self, state: _HeadSessionState) -> dict[int, int]: direct_pages: dict[int, int] = {} for page in state.session.key_pages: if isinstance(page, PreparedPageTorch) and not self.cache.owns_prepared_page(page): direct_pages[id(page)] = self._prepared_page_resident_bytes(page) for page in state.session.value_pages: if isinstance(page, PreparedPageTorch) and not self.cache.owns_prepared_page(page): direct_pages[id(page)] = self._prepared_page_resident_bytes(page) return direct_pages @staticmethod def _collect_state_tail_resident_bytes(state: _HeadSessionState) -> int: total = 0 if state.persistent_key_tail is not None: total += int(state.persistent_key_tail.resident_nbytes) if state.persistent_value_tail is not None: total += int(state.persistent_value_tail.resident_nbytes) return total def _reset_resident_accounting(self) -> None: self._direct_prepared_page_resident_bytes = 0 self._direct_prepared_page_refcounts.clear() self._direct_prepared_page_sizes.clear() self._tail_resident_bytes = 0 for state in self._states.values(): state.tracked_direct_prepared_pages.clear() state.tracked_tail_resident_bytes = 0 def _refresh_state_resident_accounting(self, state: _HeadSessionState) -> bool: resident_bytes_changed = False new_direct_pages = self._collect_state_direct_prepared_pages(state) old_direct_pages = state.tracked_direct_prepared_pages removed_page_ids = set(old_direct_pages) - set(new_direct_pages) for page_id in removed_page_ids: refcount = int(self._direct_prepared_page_refcounts.get(page_id, 0)) - 1 if refcount <= 0: self._direct_prepared_page_refcounts.pop(page_id, None) removed_size = int(self._direct_prepared_page_sizes.pop(page_id, old_direct_pages[page_id])) self._direct_prepared_page_resident_bytes = max( 0, int(self._direct_prepared_page_resident_bytes) - removed_size, ) else: self._direct_prepared_page_refcounts[page_id] = refcount resident_bytes_changed = True added_page_ids = set(new_direct_pages) - set(old_direct_pages) for page_id in added_page_ids: page_size = int(new_direct_pages[page_id]) refcount = int(self._direct_prepared_page_refcounts.get(page_id, 0)) if refcount <= 0: self._direct_prepared_page_sizes[page_id] = page_size self._direct_prepared_page_resident_bytes += page_size resident_bytes_changed = True self._direct_prepared_page_refcounts[page_id] = refcount + 1 for page_id, page_size in new_direct_pages.items(): if page_id not in old_direct_pages: continue previous_size = int(self._direct_prepared_page_sizes.get(page_id, old_direct_pages[page_id])) if previous_size == int(page_size): continue self._direct_prepared_page_sizes[page_id] = int(page_size) self._direct_prepared_page_resident_bytes += int(page_size) - previous_size resident_bytes_changed = True new_tail_resident_bytes = self._collect_state_tail_resident_bytes(state) if new_tail_resident_bytes != int(state.tracked_tail_resident_bytes): self._tail_resident_bytes += int(new_tail_resident_bytes) - int(state.tracked_tail_resident_bytes) state.tracked_tail_resident_bytes = int(new_tail_resident_bytes) resident_bytes_changed = True state.tracked_direct_prepared_pages = new_direct_pages return resident_bytes_changed def _rebuild_resident_accounting(self) -> None: self._reset_resident_accounting() for state in self._states.values(): self._refresh_state_resident_accounting(state) def _reset_chunk_budget_tracking(self) -> None: self._chunk_budget_dirty_marks = 0 self._chunk_budget_dirty_transitions = 0 self._chunk_budget_dirty_reason_counts = {} self._chunk_budget_sync_invocations = 0 self._chunk_budget_sync_clean_skips = 0 self._chunk_budget_sync_dirty_invocations = 0 self._chunk_budget_override_calls = 0 self._chunk_budget_override_budget_change_calls = 0 self._chunk_budget_override_same_budget_calls = 0 self._chunk_budget_freeze_override_calls = 0 def _reset_builtin_selector_tracking(self) -> None: self._builtin_selector_score_all_pages_calls = 0 self._builtin_selector_candidate_only_calls = 0 self._builtin_selector_candidate_pages = 0 self._builtin_selector_total_pages = 0 self._builtin_selector_candidate_fraction_sum = 0.0 self._builtin_selector_candidate_fraction_max = 0.0 self._builtin_selector_cache_hits = 0 self._builtin_selector_cache_builds = 0 self._builtin_selector_cache_build_bytes = 0 self._builtin_selector_cache_build_bytes_max = 0 def _reset_execution_value_escape_tracking(self) -> None: self._execution_value_escape_cache.clear() self._execution_value_escape_source_pages.clear() self._execution_value_escape_cache_hits = 0 self._execution_value_escape_source_registrations = 0 self._execution_value_escape_prepared_page_builds = 0 self._execution_value_escape_prewarm_invocations = 0 self._execution_value_escape_prewarm_pages = 0 self._execution_value_escape_prewarm_ms_total = 0.0 self._execution_value_escape_builds = 0 self._execution_value_escape_applied_pages = 0 def _kv_resident_byte_summary(self) -> dict[str, int]: static_resident_bytes = int(self._direct_prepared_page_resident_bytes) tail_resident_bytes = int(self._tail_resident_bytes) kv_resident_bytes = int(self.cache.resident_bytes) + static_resident_bytes + tail_resident_bytes return { "prepared_page_cache_resident_bytes": int(self.cache.resident_bytes), "direct_page_resident_bytes": int(static_resident_bytes), "tail_resident_bytes": int(tail_resident_bytes), "kv_resident_bytes": int(kv_resident_bytes), } def _prepared_chunk_cache_budget_bytes(self, *, kv_resident_bytes: int | None = None) -> int: if self._torch_device_type is None: return 0 budget_ratio = float(self.config.prepared_chunk_cache_budget_ratio) if budget_ratio <= 0.0 or int(self.config.prepared_chunk_cache_max_bytes) <= 0: return 0 if kv_resident_bytes is None: kv_resident_bytes = int(self._kv_resident_byte_summary()["kv_resident_bytes"]) adaptive_budget = max( int(self.config.prepared_chunk_cache_min_bytes), int(kv_resident_bytes * budget_ratio), ) return min(int(self.config.prepared_chunk_cache_max_bytes), adaptive_budget) def _sync_prepared_chunk_cache_budget(self, *, freeze_during_decode: bool = False) -> None: if self._torch_device_type is None: return self._chunk_budget_sync_invocations += 1 if not self._prepared_chunk_cache_budget_dirty: self._chunk_budget_sync_clean_skips += 1 return self._chunk_budget_sync_dirty_invocations += 1 if bool(freeze_during_decode and self.config.execution_freeze_chunk_budget_during_decode): if self._prepared_chunk_cache_frozen_budget_bytes is None: self._prepared_chunk_cache_frozen_budget_bytes = int(self._prepared_chunk_cache_budget_bytes()) if self._prepared_chunk_cache_applied_budget_bytes != int(self._prepared_chunk_cache_frozen_budget_bytes): applied_budget_bytes = self._prepared_chunk_cache_applied_budget_bytes set_prepared_chunk_cache_budget_override( max_resident_bytes=self._prepared_chunk_cache_frozen_budget_bytes, ) self._chunk_budget_override_calls += 1 self._chunk_budget_freeze_override_calls += 1 if applied_budget_bytes == int(self._prepared_chunk_cache_frozen_budget_bytes): self._chunk_budget_override_same_budget_calls += 1 else: self._chunk_budget_override_budget_change_calls += 1 self._prepared_chunk_cache_applied_budget_bytes = int(self._prepared_chunk_cache_frozen_budget_bytes) self._prepared_chunk_cache_budget_dirty = False return self._prepared_chunk_cache_frozen_budget_bytes = None budget_bytes = int(self._prepared_chunk_cache_budget_bytes()) applied_budget_bytes = self._prepared_chunk_cache_applied_budget_bytes set_prepared_chunk_cache_budget_override( max_resident_bytes=budget_bytes, ) self._chunk_budget_override_calls += 1 if applied_budget_bytes == budget_bytes: self._chunk_budget_override_same_budget_calls += 1 else: self._chunk_budget_override_budget_change_calls += 1 self._prepared_chunk_cache_applied_budget_bytes = budget_bytes self._prepared_chunk_cache_budget_dirty = False def _mark_prepared_chunk_cache_budget_dirty(self, *, reason: str) -> None: if self._torch_device_type is None: return self._chunk_budget_dirty_marks += 1 self._chunk_budget_dirty_reason_counts[str(reason)] = ( int(self._chunk_budget_dirty_reason_counts.get(str(reason), 0)) + 1 ) if not self._prepared_chunk_cache_budget_dirty: self._chunk_budget_dirty_transitions += 1 self._prepared_chunk_cache_budget_dirty = True def chunk_budget_summary(self) -> dict[str, object]: return { "execution_chunk_budget_dirty_marks": int(self._chunk_budget_dirty_marks), "execution_chunk_budget_dirty_transitions": int(self._chunk_budget_dirty_transitions), "execution_chunk_budget_dirty_reason_counts": { reason: int(count) for reason, count in sorted(self._chunk_budget_dirty_reason_counts.items()) }, "execution_chunk_budget_sync_invocations": int(self._chunk_budget_sync_invocations), "execution_chunk_budget_sync_clean_skips": int(self._chunk_budget_sync_clean_skips), "execution_chunk_budget_sync_dirty_invocations": int(self._chunk_budget_sync_dirty_invocations), "execution_chunk_budget_override_calls": int(self._chunk_budget_override_calls), "execution_chunk_budget_override_budget_change_calls": int( self._chunk_budget_override_budget_change_calls ), "execution_chunk_budget_override_same_budget_calls": int( self._chunk_budget_override_same_budget_calls ), "execution_chunk_budget_freeze_override_calls": int(self._chunk_budget_freeze_override_calls), } def builtin_selector_summary(self) -> dict[str, int | float]: return { "execution_builtin_selector_score_all_pages_calls": int(self._builtin_selector_score_all_pages_calls), "execution_builtin_selector_candidate_only_calls": int(self._builtin_selector_candidate_only_calls), "execution_builtin_selector_candidate_pages": int(self._builtin_selector_candidate_pages), "execution_builtin_selector_total_pages": int(self._builtin_selector_total_pages), "execution_builtin_selector_candidate_fraction_sum": float( self._builtin_selector_candidate_fraction_sum ), "execution_builtin_selector_candidate_fraction_max": float( self._builtin_selector_candidate_fraction_max ), "execution_builtin_selector_cache_hits": int(self._builtin_selector_cache_hits), "execution_builtin_selector_cache_builds": int(self._builtin_selector_cache_builds), "execution_builtin_selector_cache_build_bytes": int(self._builtin_selector_cache_build_bytes), "execution_builtin_selector_cache_build_bytes_max": int(self._builtin_selector_cache_build_bytes_max), } def execution_value_escape_summary(self) -> dict[str, int | str | list[int]]: return { "execution_value_escape_layers": [int(layer_id) for layer_id in self.config.execution_value_escape_layers], "execution_value_escape_mode": str(self.config.execution_value_escape_mode), "execution_value_escape_old_only": bool(self.config.execution_value_escape_old_only), "execution_value_escape_top_k": int(self.config.execution_value_escape_top_k), "execution_value_escape_prewarm": bool(self.config.execution_value_escape_prewarm), "execution_value_escape_prewarm_min_context": int(self.config.execution_value_escape_prewarm_min_context), "execution_value_escape_cache_hits": int(self._execution_value_escape_cache_hits), "execution_value_escape_source_registrations": int(self._execution_value_escape_source_registrations), "execution_value_escape_prepared_page_builds": int(self._execution_value_escape_prepared_page_builds), "execution_value_escape_prewarm_invocations": int(self._execution_value_escape_prewarm_invocations), "execution_value_escape_prewarm_pages": int(self._execution_value_escape_prewarm_pages), "execution_value_escape_prewarm_ms_total": float(self._execution_value_escape_prewarm_ms_total), "execution_value_escape_builds": int(self._execution_value_escape_builds), "execution_value_escape_applied_pages": int(self._execution_value_escape_applied_pages), } def _execution_value_escape_cache_key(self, page: PageLike, *, escape_mode: str) -> tuple[int, int, int, str, str]: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page header = source_page.header return ( id(source_page), int(header.token_start), int(header.token_count), str(escape_mode), str(self._torch_device_type or "cpu_ref"), ) def _execution_value_escape_source_key(self, page: PageLike, *, escape_mode: str) -> tuple[int, str]: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page return (id(source_page), str(escape_mode)) def _maybe_register_execution_value_escape_source( self, source_page: EncodedPage, *, dense_values: np.ndarray, escape_mode: str, ) -> None: if not self.config.execution_value_escape_enabled_for_layer(layer_id=int(source_page.header.layer_id)): return if str(source_page.header.kind) != "V": return if str(source_page.header.mode_default) == str(escape_mode): return source_key = self._execution_value_escape_source_key(source_page, escape_mode=escape_mode) if source_key in self._execution_value_escape_source_pages: return self._execution_value_escape_source_pages[source_key] = encode_page( np.asarray(dense_values, dtype=np.float32, copy=False), self.config, kind="V", layer_id=int(source_page.header.layer_id), kv_head_id=int(source_page.header.kv_head_id), token_start=int(source_page.header.token_start), mode=str(escape_mode), build_runtime_metadata=False, ) self._execution_value_escape_source_registrations += 1 self._execution_value_escape_builds += 1 def _prepare_execution_value_escape_page( self, page: PageLike, *, escape_mode: str, trace: ExecutionTrace | None = None, ) -> PageLike: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if str(source_page.header.kind) != "V": raise ValueError("execution value escape requires V pages") if str(source_page.header.mode_default) == str(escape_mode): return page source_key = self._execution_value_escape_source_key(page, escape_mode=escape_mode) exact_source_page = self._execution_value_escape_source_pages.get(source_key) if exact_source_page is not None: exact_cache_key = self._execution_value_escape_cache_key(exact_source_page, escape_mode=escape_mode) cached_exact_page = self._execution_value_escape_cache.get(exact_cache_key) if cached_exact_page is not None: self._execution_value_escape_cache_hits += 1 return cached_exact_page prepared_exact_page = prepare_pages( [exact_source_page], backend=self.backend, cache=self.cache, trace=trace, )[0] self._execution_value_escape_cache[exact_cache_key] = prepared_exact_page self._execution_value_escape_prepared_page_builds += 1 self._execution_value_escape_builds += 1 return prepared_exact_page cache_key = self._execution_value_escape_cache_key(page, escape_mode=escape_mode) cached_page = self._execution_value_escape_cache.get(cache_key) if cached_page is not None: self._execution_value_escape_cache_hits += 1 return cached_page dense_values = decode_page(source_page).astype(np.float32, copy=False) escaped_page = encode_page( dense_values, self.config, kind="V", layer_id=int(source_page.header.layer_id), kv_head_id=int(source_page.header.kv_head_id), token_start=int(source_page.header.token_start), mode=str(escape_mode), build_runtime_metadata=False, ) prepared_escape_page = prepare_pages( [escaped_page], backend=self.backend, cache=self.cache, trace=trace, )[0] self._execution_value_escape_cache[cache_key] = prepared_escape_page self._execution_value_escape_prepared_page_builds += 1 self._execution_value_escape_builds += 1 return prepared_escape_page def _maybe_prewarm_execution_value_escape_pages( self, state: _HeadSessionState, *, trace: ExecutionTrace | None = None, ) -> None: if not bool(self.config.execution_value_escape_prewarm): return layer_id = int(state.tail.layer_id) if not self.config.execution_value_escape_enabled_for_layer(layer_id=layer_id): return min_context = max(0, int(self.config.execution_value_escape_prewarm_min_context)) if min_context > 0 and int(state.sequence_length) < min_context: return if bool(self.config.execution_value_escape_old_only) or int(self.config.execution_value_escape_top_k) > 0: return if not state.session.value_pages: return started_at = perf_counter() prewarmed_pages = 0 escape_mode = str(self.config.execution_value_escape_mode) for page in state.session.value_pages: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if str(source_page.header.kind) != "V": continue if str(source_page.header.mode_default) == escape_mode: continue prepared_page = self._prepare_execution_value_escape_page(page, escape_mode=escape_mode, trace=trace) if prepared_page is not page: prewarmed_pages += 1 if prewarmed_pages <= 0: return self._execution_value_escape_prewarm_invocations += 1 self._execution_value_escape_prewarm_pages += int(prewarmed_pages) self._execution_value_escape_prewarm_ms_total += float((perf_counter() - started_at) * 1000.0) def _apply_execution_value_escape( self, *, layer_id: int, key_pages_by_group: Sequence[Sequence[PageLike]], value_pages_by_group: Sequence[Sequence[PageLike]], context_lengths_by_group: Sequence[int] | None = None, representative_queries_by_group: Sequence[np.ndarray] | None = None, trace: ExecutionTrace | None = None, ) -> tuple[list[Sequence[PageLike]], bool]: if not self.config.execution_value_escape_enabled_for_layer(layer_id=layer_id): return [list(pages) for pages in value_pages_by_group], False escape_mode = str(self.config.execution_value_escape_mode) escape_top_k = max(0, int(self.config.execution_value_escape_top_k)) escaped_groups: list[Sequence[PageLike]] = [] any_applied = False for group_index, (key_pages, value_pages) in enumerate(zip(key_pages_by_group, value_pages_by_group, strict=True)): escaped_pages: list[PageLike] = [] group_applied = False eligible_indices: set[int] | None = None if bool(self.config.execution_value_escape_old_only): context_length = None if context_lengths_by_group is not None and group_index < len(context_lengths_by_group): context_length = int(context_lengths_by_group[group_index]) layer_recent_window = self.config.resolve_execution_recent_window_for_context( layer_id=layer_id, context_length=context_length, ) eligible_indices = set(range(len(key_pages))) - set( select_window_page_indices( key_pages, recent_window_tokens=layer_recent_window if layer_recent_window > 0 else None, sink_window_tokens=int(self.config.execution_sink_window), ) ) if escape_top_k > 0: query_slice = None if representative_queries_by_group is not None and group_index < len(representative_queries_by_group): query_slice = np.asarray(representative_queries_by_group[group_index], dtype=np.float32) ranked_candidate_indices: list[int] = [] if query_slice is not None: scored_indices: list[tuple[float, int]] = [] candidate_pool = range(len(key_pages)) if eligible_indices is None else sorted(eligible_indices) for page_index in candidate_pool: score = _score_page_relevance_for_mode( query_slice, key_pages[page_index], relevance_mode=self.config.execution_relevance_mode, ) if score is not None: scored_indices.append((float(score), int(page_index))) if scored_indices: scored_indices.sort(key=lambda item: item[0], reverse=True) ranked_candidate_indices = [index for _, index in scored_indices[:escape_top_k]] if ranked_candidate_indices: eligible_indices = set(ranked_candidate_indices) elif eligible_indices is None: eligible_indices = set(range(min(len(key_pages), escape_top_k))) for page_index, page in enumerate(value_pages): if eligible_indices is not None and page_index not in eligible_indices: escaped_pages.append(page) continue source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if str(source_page.header.mode_default) == escape_mode: escaped_pages.append(page) continue escaped_page = self._prepare_execution_value_escape_page( page, escape_mode=escape_mode, trace=trace, ) escaped_pages.append(escaped_page) if escaped_page is not page: group_applied = True self._execution_value_escape_applied_pages += 1 escaped_groups.append(escaped_pages if group_applied else list(value_pages)) any_applied = any_applied or group_applied return escaped_groups, any_applied def resident_byte_summary(self) -> dict[str, int]: summary = self._kv_resident_byte_summary() chunk_resident_bytes = prepared_chunk_cache_resident_bytes() if self._torch_device_type is not None else 0 budget_bytes = self._prepared_chunk_cache_budget_bytes(kv_resident_bytes=int(summary["kv_resident_bytes"])) return { **summary, "prepared_chunk_cache_budget_bytes": int(budget_bytes), "prepared_chunk_resident_bytes": int(chunk_resident_bytes), "resident_bytes": int(summary["kv_resident_bytes"] + chunk_resident_bytes), } def _record_decode_path(self, layer_id: int, path_name: str) -> None: if path_name not in self._decode_path_counts: raise ValueError(f"unknown decode path: {path_name}") self._decode_path_counts[path_name] += 1 layer_counts = self._decode_path_counts_by_layer.setdefault(int(layer_id), {}) layer_counts[path_name] = layer_counts.get(path_name, 0) + 1 def decode_path_summary(self) -> dict[str, object]: return { "decode_path_counts": dict(sorted(self._decode_path_counts.items())), "decode_path_counts_by_layer": { str(layer_id): dict(sorted(counts.items())) for layer_id, counts in sorted(self._decode_path_counts_by_layer.items()) }, "decode_grouped_batch_rejection_reason_counts": dict( sorted(self._decode_grouped_batch_rejection_reason_counts.items()) ), "decode_grouped_batch_rejection_reason_counts_by_layer": { str(layer_id): dict(sorted(counts.items())) for layer_id, counts in sorted(self._decode_grouped_batch_rejection_reason_counts_by_layer.items()) }, } def _record_execution_shortlist( self, *, layer_id: int, total_pages: int, selected_pages: int, applied: bool, group_union_applied: bool = False, grouping_rejected: bool = False, grouping_rejection_reason: str | None = None, ) -> None: self._execution_shortlist_invocations += 1 self._execution_shortlist_total_pages += int(total_pages) self._execution_shortlist_selected_pages += int(selected_pages) self._execution_shortlist_invocations_by_layer[int(layer_id)] = ( self._execution_shortlist_invocations_by_layer.get(int(layer_id), 0) + 1 ) self._execution_shortlist_total_pages_by_layer[int(layer_id)] = ( self._execution_shortlist_total_pages_by_layer.get(int(layer_id), 0) + int(total_pages) ) self._execution_shortlist_selected_pages_by_layer[int(layer_id)] = ( self._execution_shortlist_selected_pages_by_layer.get(int(layer_id), 0) + int(selected_pages) ) if applied: self._execution_shortlist_applied += 1 self._execution_shortlist_applied_by_layer[int(layer_id)] = ( self._execution_shortlist_applied_by_layer.get(int(layer_id), 0) + 1 ) if group_union_applied: self._execution_shortlist_group_union_applied += 1 self._execution_shortlist_group_union_applied_by_layer[int(layer_id)] = ( self._execution_shortlist_group_union_applied_by_layer.get(int(layer_id), 0) + 1 ) if grouping_rejected: self._execution_shortlist_grouping_rejections += 1 self._execution_shortlist_grouping_rejections_by_layer[int(layer_id)] = ( self._execution_shortlist_grouping_rejections_by_layer.get(int(layer_id), 0) + 1 ) if grouping_rejection_reason: self._execution_shortlist_grouping_rejection_reason_counts[grouping_rejection_reason] = ( self._execution_shortlist_grouping_rejection_reason_counts.get(grouping_rejection_reason, 0) + 1 ) layer_reason_counts = self._execution_shortlist_grouping_rejection_reason_counts_by_layer.setdefault( int(layer_id), {} ) layer_reason_counts[grouping_rejection_reason] = ( layer_reason_counts.get(grouping_rejection_reason, 0) + 1 ) def _record_decode_grouped_batch_rejection(self, *, layer_id: int, reason: str) -> None: self._decode_grouped_batch_rejection_reason_counts[reason] = ( self._decode_grouped_batch_rejection_reason_counts.get(reason, 0) + 1 ) layer_reason_counts = self._decode_grouped_batch_rejection_reason_counts_by_layer.setdefault(int(layer_id), {}) layer_reason_counts[reason] = layer_reason_counts.get(reason, 0) + 1 def _record_execution_exact_refine( self, *, layer_id: int, candidate_pages: int, selected_pages: int, ) -> None: self._execution_exact_refine_invocations += 1 self._execution_exact_refine_candidate_pages += int(candidate_pages) self._execution_exact_refine_selected_pages += int(selected_pages) self._execution_exact_refine_invocations_by_layer[int(layer_id)] = ( self._execution_exact_refine_invocations_by_layer.get(int(layer_id), 0) + 1 ) self._execution_exact_refine_candidate_pages_by_layer[int(layer_id)] = ( self._execution_exact_refine_candidate_pages_by_layer.get(int(layer_id), 0) + int(candidate_pages) ) self._execution_exact_refine_selected_pages_by_layer[int(layer_id)] = ( self._execution_exact_refine_selected_pages_by_layer.get(int(layer_id), 0) + int(selected_pages) ) def _record_decode_stage_timing(self, *, layer_id: int, stage: str, ms: float) -> None: if stage not in _DECODE_STAGE_TIMING_STAGES: raise ValueError(f"unknown decode stage timing: {stage}") self._decode_stage_timings[stage] += float(ms) layer_timings = self._decode_stage_timings_by_layer.setdefault(int(layer_id), {}) layer_timings[stage] = float(layer_timings.get(stage, 0.0) + float(ms)) def _record_builtin_selector_stats( self, *, candidate_pages: int, total_pages: int, candidate_fraction: float, used_score_all_pages: bool, ) -> None: if used_score_all_pages: self._builtin_selector_score_all_pages_calls += 1 else: self._builtin_selector_candidate_only_calls += 1 self._builtin_selector_candidate_pages += int(candidate_pages) self._builtin_selector_total_pages += int(total_pages) self._builtin_selector_candidate_fraction_sum += float(candidate_fraction) self._builtin_selector_candidate_fraction_max = max( float(self._builtin_selector_candidate_fraction_max), float(candidate_fraction), ) def decode_stage_runtime_totals(self) -> dict[str, float]: return { _decode_stage_summary_key(stage): float(self._decode_stage_timings.get(stage, 0.0)) for stage in _DECODE_STAGE_TIMING_STAGES } def _execution_builtin_selector_cache_for_state( self, state: _HeadSessionState, *, relevance_mode: str, ) -> _ExecutionBuiltinSelectorCache | None: direct_key_pages = state.session.key_pages if not direct_key_pages: return None page_signature = tuple( ( id(page.source_page if isinstance(page, PreparedPageTorch) else page), int(_page_header(page).token_start), int(_page_header(page).token_count), ) for page in direct_key_pages ) cache = state.execution_builtin_selector_cache if cache is not None and cache.page_signature == page_signature: if relevance_mode == "sketch" and cache.sketch_matrix is not None: self._builtin_selector_cache_hits += 1 return cache if relevance_mode == "envelope" and cache.minima_matrix is not None and cache.maxima_matrix is not None: self._builtin_selector_cache_hits += 1 return cache if relevance_mode == "sketch": sketches = [] for page in direct_key_pages: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if source_page.runtime_page_sketch is None: return None sketches.append(np.asarray(source_page.runtime_page_sketch, dtype=np.float32)) cache = _ExecutionBuiltinSelectorCache( page_signature=page_signature, sketch_matrix=np.stack(sketches, axis=0), ) elif relevance_mode == "envelope": minima = [] maxima = [] for page in direct_key_pages: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page if source_page.runtime_page_min is None or source_page.runtime_page_max is None: return None minima.append(np.asarray(source_page.runtime_page_min, dtype=np.float32)) maxima.append(np.asarray(source_page.runtime_page_max, dtype=np.float32)) cache = _ExecutionBuiltinSelectorCache( page_signature=page_signature, minima_matrix=np.stack(minima, axis=0), maxima_matrix=np.stack(maxima, axis=0), ) else: return None build_bytes = int(cache.resident_bytes()) self._builtin_selector_cache_builds += 1 self._builtin_selector_cache_build_bytes += int(build_bytes) self._builtin_selector_cache_build_bytes_max = max( int(self._builtin_selector_cache_build_bytes_max), int(build_bytes), ) state.execution_builtin_selector_cache = cache return cache def _should_prewarm_execution_builtin_selector_cache(self) -> bool: return bool( self.config.execution_builtin_selector_cache and ( int(self.config.execution_relevance_top_k) > 0 or bool(self.config.execution_relevance_top_k_overrides) or bool(self.config.execution_relevance_top_k_context_overrides) ) ) def _maybe_prewarm_execution_builtin_selector_cache(self, state: _HeadSessionState) -> None: if not self._should_prewarm_execution_builtin_selector_cache(): return if not state.session.key_pages: return if not all(isinstance(page, PreparedPageTorch) for page in state.session.key_pages): return self._execution_builtin_selector_cache_for_state( state, relevance_mode=self.config.execution_relevance_mode, ) def _execution_builtin_selector_matrices( self, *, layer_id: int, kv_head_id: int | None, key_pages: Sequence[PageLike], relevance_mode: str, score_all_pages_with_matrices: bool, ) -> tuple[np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None, np.ndarray | None]: if not self.config.execution_builtin_selector_cache or kv_head_id is None: return None, None, None, None, None, None state = self._state(layer_id, kv_head_id) cache = self._execution_builtin_selector_cache_for_state(state, relevance_mode=relevance_mode) if cache is None: return None, None, None, None, None, None direct_key_pages = state.session.key_pages direct_count = len(direct_key_pages) if len(key_pages) not in (direct_count, direct_count + 1): return None, None, None, None, None, None if any(key_pages[index] is not direct_key_pages[index] for index in range(direct_count)): return None, None, None, None, None, None def _tail_source_page() -> EncodedPage | None: if len(key_pages) != direct_count + 1: return None tail_page = key_pages[-1] return tail_page.source_page if isinstance(tail_page, PreparedPageTorch) else tail_page tail_source_page = _tail_source_page() if relevance_mode == "sketch": if cache.sketch_matrix is None: return None, None, None, None, None, None if tail_source_page is None: return cache.sketch_matrix, None, None, None, None, None if tail_source_page.runtime_page_sketch is None: return None, None, None, None, None, None if score_all_pages_with_matrices: return ( cache.sketch_matrix, None, None, np.asarray(tail_source_page.runtime_page_sketch, dtype=np.float32), None, None, ) return ( np.concatenate( [ cache.sketch_matrix, np.asarray(tail_source_page.runtime_page_sketch, dtype=np.float32)[None, ...], ], axis=0, ), None, None, None, None, None, ) if cache.minima_matrix is None or cache.maxima_matrix is None: return None, None, None, None, None, None if tail_source_page is None: return None, cache.minima_matrix, cache.maxima_matrix, None, None, None if tail_source_page.runtime_page_min is None or tail_source_page.runtime_page_max is None: return None, None, None, None, None, None if score_all_pages_with_matrices: return ( None, cache.minima_matrix, cache.maxima_matrix, None, np.asarray(tail_source_page.runtime_page_min, dtype=np.float32), np.asarray(tail_source_page.runtime_page_max, dtype=np.float32), ) return ( None, np.concatenate( [ cache.minima_matrix, np.asarray(tail_source_page.runtime_page_min, dtype=np.float32)[None, :], ], axis=0, ), np.concatenate( [ cache.maxima_matrix, np.asarray(tail_source_page.runtime_page_max, dtype=np.float32)[None, :], ], axis=0, ), None, None, None, ) def decode_stage_summary(self) -> dict[str, object]: summary: dict[str, object] = self.decode_stage_runtime_totals() for stage in _DECODE_STAGE_TIMING_STAGES: summary[f"{_decode_stage_summary_key(stage)}_by_layer"] = { str(layer_id): float(layer_timings.get(stage, 0.0)) for layer_id, layer_timings in sorted(self._decode_stage_timings_by_layer.items()) if float(layer_timings.get(stage, 0.0)) != 0.0 } return summary def execution_shortlist_summary(self) -> dict[str, object]: return { "execution_shortlist_invocations": int(self._execution_shortlist_invocations), "execution_shortlist_applied": int(self._execution_shortlist_applied), "execution_shortlist_group_union_applied": int(self._execution_shortlist_group_union_applied), "execution_shortlist_grouping_rejections": int(self._execution_shortlist_grouping_rejections), "execution_shortlist_total_pages": int(self._execution_shortlist_total_pages), "execution_shortlist_selected_pages": int(self._execution_shortlist_selected_pages), "execution_shortlist_invocations_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_shortlist_invocations_by_layer.items()) }, "execution_shortlist_applied_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_shortlist_applied_by_layer.items()) }, "execution_shortlist_group_union_applied_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_shortlist_group_union_applied_by_layer.items()) }, "execution_shortlist_grouping_rejection_reason_counts": dict( sorted(self._execution_shortlist_grouping_rejection_reason_counts.items()) ), "execution_shortlist_grouping_rejection_reason_counts_by_layer": { str(layer_id): dict(sorted(counts.items())) for layer_id, counts in sorted(self._execution_shortlist_grouping_rejection_reason_counts_by_layer.items()) }, "execution_shortlist_grouping_rejections_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_shortlist_grouping_rejections_by_layer.items()) }, "execution_shortlist_total_pages_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_shortlist_total_pages_by_layer.items()) }, "execution_shortlist_selected_pages_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_shortlist_selected_pages_by_layer.items()) }, "execution_shortlist_trace_records": list(self._execution_shortlist_trace_records), "execution_exact_refine_invocations": int(self._execution_exact_refine_invocations), "execution_exact_refine_candidate_pages": int(self._execution_exact_refine_candidate_pages), "execution_exact_refine_selected_pages": int(self._execution_exact_refine_selected_pages), "execution_exact_refine_invocations_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_exact_refine_invocations_by_layer.items()) }, "execution_exact_refine_candidate_pages_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_exact_refine_candidate_pages_by_layer.items()) }, "execution_exact_refine_selected_pages_by_layer": { str(layer_id): int(count) for layer_id, count in sorted(self._execution_exact_refine_selected_pages_by_layer.items()) }, } def _execution_exact_refine_enabled(self, *, layer_id: int) -> bool: if self.config.execution_exact_refine_top_k <= 0: return False if not self.config.execution_exact_refine_layers: return False return int(layer_id) in {int(value) for value in self.config.execution_exact_refine_layers} def _execution_exact_promote_enabled(self, *, layer_id: int, context_length: int | None = None) -> bool: enabled, _ = self._execution_exact_promote_policy_status(layer_id=layer_id, context_length=context_length) return enabled def _execution_exact_promote_policy_status( self, *, layer_id: int, context_length: int | None = None, ) -> tuple[bool, str | None]: if self.config.execution_exact_promote_top_k <= 0: return False, "top_k_disabled" if not self.config.execution_exact_promote_layers: return False, "no_layers_configured" if int(layer_id) not in {int(value) for value in self.config.execution_exact_promote_layers}: return False, "layer_not_selected" if ( int(self.config.execution_exact_promote_max_context) > 0 and context_length is not None and int(context_length) > int(self.config.execution_exact_promote_max_context) ): return False, "context_exceeds_max_context" return True, None def _execution_exact_promote_status( self, *, layer_id: int, context_length: int | None = None, boundary_margin_normalized: float | None = None, ) -> tuple[bool, str | None]: enabled, reason = self._execution_exact_promote_policy_status( layer_id=layer_id, context_length=context_length, ) if not enabled: return False, reason if ( float(self.config.execution_exact_promote_min_margin_threshold) > 0.0 and ( boundary_margin_normalized is None or boundary_margin_normalized < float(self.config.execution_exact_promote_min_margin_threshold) ) ): return False, "below_min_margin_threshold" if ( float(self.config.execution_exact_promote_margin_threshold) > 0.0 and boundary_margin_normalized is not None and boundary_margin_normalized > float(self.config.execution_exact_promote_margin_threshold) ): return False, "above_max_margin_threshold" return True, None def _execution_secondary_relevance_enabled(self, *, layer_id: int) -> bool: return self.config.execution_secondary_relevance_enabled_for_layer(layer_id=layer_id) def _execution_recent_neighbor_rescue_enabled(self, *, layer_id: int) -> bool: return self.config.execution_recent_neighbor_rescue_enabled_for_layer(layer_id=layer_id) def _execution_exact_promote_union_rescue_enabled(self, *, layer_id: int) -> bool: return ( int(self.config.execution_exact_promote_union_rescue_top_k) > 0 and int(layer_id) in {int(value) for value in self.config.execution_exact_promote_layers} ) def _apply_execution_exact_promote_union_rescue( self, *, layer_id: int, selected_indices_by_group: Sequence[list[int] | None], key_pages_by_group: Sequence[Sequence[PageLike]], representative_queries: Sequence[np.ndarray], shortlist_traces_by_group: Sequence[dict[str, object] | None], trace: ExecutionTrace | None = None, ) -> tuple[list[list[int] | None], list[dict[str, object]]]: if not self._execution_exact_promote_union_rescue_enabled(layer_id=layer_id): return [None if indices is None else list(indices) for indices in selected_indices_by_group], [] adjusted_indices_by_group = [None if indices is None else list(indices) for indices in selected_indices_by_group] baseline_selected_index_sets = [ set() if indices is None else {int(index) for index in indices} for indices in selected_indices_by_group ] rescue_records: list[dict[str, object]] = [] union_rescue_top_k = int(self.config.execution_exact_promote_union_rescue_top_k) eligible_group_records: list[dict[str, object]] = [] for group_index, (indices, key_pages, query_slice, shortlist_trace) in enumerate( zip( adjusted_indices_by_group, key_pages_by_group, representative_queries, shortlist_traces_by_group, strict=True, ) ): if indices is None: rescue_records.append( { "record_type": "union_rescue", "layer_id": int(layer_id), "group_index": int(group_index), "applied": False, "disable_reason": "no_shortlist", } ) continue if shortlist_trace is None: rescue_records.append( { "record_type": "union_rescue", "layer_id": int(layer_id), "group_index": int(group_index), "applied": False, "disable_reason": "missing_shortlist_trace", } ) continue context_length = shortlist_trace.get("context_length") policy_enabled, policy_disable_reason = self._execution_exact_promote_policy_status( layer_id=layer_id, context_length=None if context_length is None else int(context_length), ) if not policy_enabled: rescue_records.append( { "record_type": "union_rescue", "layer_id": int(layer_id), "group_index": int(group_index), "applied": False, "disable_reason": policy_disable_reason, } ) continue base_indices = {int(index) for index in shortlist_trace.get("base_indices", [])} selected_index_set = baseline_selected_index_sets[group_index] other_selected_indices = set().union( *[ selected_index_set for other_group_index, selected_index_set in enumerate(baseline_selected_index_sets) if other_group_index != group_index ] ) novel_candidate_indices = [ int(index) for index in range(len(key_pages)) if int(index) not in base_indices and int(index) not in selected_index_set and int(index) not in other_selected_indices ] if not novel_candidate_indices: rescue_records.append( { "record_type": "union_rescue", "layer_id": int(layer_id), "group_index": int(group_index), "applied": False, "disable_reason": "no_novel_candidates", "base_count": int(len(base_indices)), "selected_count": int(len(selected_index_set)), } ) continue eligible_group_records.append( { "group_index": int(group_index), "kv_head_id": shortlist_trace.get("kv_head_id"), "indices": indices, "key_pages": key_pages, "query_slice": np.asarray(query_slice, dtype=np.float32), "base_indices": base_indices, "selected_index_set": selected_index_set, "novel_candidate_indices": novel_candidate_indices, } ) if not eligible_group_records: return adjusted_indices_by_group, rescue_records scored_union_candidates: list[tuple[float, int, int]] = [] for eligible_group_record in eligible_group_records: novel_candidate_indices = list(eligible_group_record["novel_candidate_indices"]) novel_candidate_logits = score_pages( eligible_group_record["query_slice"], [eligible_group_record["key_pages"][index] for index in novel_candidate_indices], backend=self.backend, trace=trace, ) for index, logits in zip(novel_candidate_indices, novel_candidate_logits, strict=True): scored_union_candidates.append( ( float(np.max(np.asarray(logits, dtype=np.float32))), int(eligible_group_record["group_index"]), int(index), ) ) selected_by_group: dict[int, list[int]] = {} seen_indices: set[int] = set() for _, group_index, index in sorted(scored_union_candidates, key=lambda item: item[0], reverse=True): if int(index) in seen_indices: continue selected_by_group.setdefault(int(group_index), []).append(int(index)) seen_indices.add(int(index)) if len(seen_indices) >= union_rescue_top_k: break for eligible_group_record in eligible_group_records: group_index = int(eligible_group_record["group_index"]) chosen_novel_indices = selected_by_group.get(group_index, []) if not chosen_novel_indices: rescue_records.append( { "record_type": "union_rescue", "layer_id": int(layer_id), "group_index": int(group_index), "applied": False, "disable_reason": "novel_candidates_not_selected", "base_count": int(len(eligible_group_record["base_indices"])), "selected_count": int(len(eligible_group_record["selected_index_set"])), "novel_candidate_count": int(len(eligible_group_record["novel_candidate_indices"])), } ) continue group_indices = eligible_group_record["indices"] group_key_pages = eligible_group_record["key_pages"] adjusted_indices_by_group[group_index] = sorted( set(group_indices).union(chosen_novel_indices) ) rescue_records.append( { "record_type": "union_rescue", "layer_id": int(layer_id), "group_index": int(group_index), "kv_head_id": eligible_group_record["kv_head_id"], "applied": True, "disable_reason": None, "base_count": int(len(eligible_group_record["base_indices"])), "selected_count": int(len(eligible_group_record["selected_index_set"])), "novel_candidate_count": int(len(eligible_group_record["novel_candidate_indices"])), "selected_novel_count": int(len(chosen_novel_indices)), "selected_novel_indices": [int(index) for index in chosen_novel_indices], "selected_novel_page_ranges": [ f"{int(index)}:{int(_page_header(group_key_pages[index]).token_start)}-{int(_page_header(group_key_pages[index]).token_start + _page_header(group_key_pages[index]).token_count)}" for index in chosen_novel_indices ], } ) return adjusted_indices_by_group, rescue_records def _execution_shortlist_page_indices( self, key_pages: Sequence[PageLike], *, layer_id: int, kv_head_id: int | None = None, query_slice: np.ndarray, context_length_override: int | None = None, trace: ExecutionTrace | None = None, ) -> list[int] | None: capture_stage_timings = bool(trace is not None and trace.capture_timings) def _stage_start() -> float | None: return perf_counter() if capture_stage_timings else None def _stage_finish(stage: str, started_at: float | None) -> None: if started_at is None: return self._record_decode_stage_timing( layer_id=int(layer_id), stage=stage, ms=(perf_counter() - started_at) * 1000.0, ) if self.config.execution_shortlist_disabled_for_layer(layer_id=layer_id): return None context_length = int(context_length_override) if context_length_override is not None else None if context_length is None and key_pages: context_length = max( int((page.source_page if isinstance(page, PreparedPageTorch) else page).header.token_start) + int((page.source_page if isinstance(page, PreparedPageTorch) else page).header.token_count) for page in key_pages ) layer_recent_window = int( self.config.resolve_execution_recent_window_for_context( layer_id=layer_id, context_length=context_length, ) ) layer_relevance_top_k = int( self.config.resolve_execution_relevance_top_k_for_context( layer_id=layer_id, context_length=context_length, ) ) if ( layer_recent_window <= 0 and self.config.execution_sink_window <= 0 and layer_relevance_top_k <= 0 ): return None promote_candidate_expansion_enabled, promote_candidate_expansion_disable_reason = ( self._execution_exact_promote_policy_status( layer_id=layer_id, context_length=context_length, ) ) def _page_range_labels(indices: Sequence[int]) -> list[str]: labels: list[str] = [] for index in indices: header = _page_header(key_pages[int(index)]) labels.append( f"{int(index)}:{int(header.token_start)}-{int(header.token_start + header.token_count)}" ) return labels def _record_shortlist_trace( *, base_index_set: set[int], stage1_selected_indices: Sequence[int], final_selected_indices: Sequence[int], promote_enabled: bool, promote_disable_reason: str | None, promote_candidate_indices: Sequence[int] | None = None, promote_selected_indices: Sequence[int] | None = None, promote_target_old_page_count: int | None = None, ) -> None: stage1_old_indices = [int(index) for index in stage1_selected_indices if int(index) not in base_index_set] final_old_indices = [int(index) for index in final_selected_indices if int(index) not in base_index_set] promote_candidate_indices = ( [] if promote_candidate_indices is None else [int(index) for index in promote_candidate_indices] ) promote_selected_indices = ( [] if promote_selected_indices is None else [int(index) for index in promote_selected_indices] ) self._execution_shortlist_trace_records.append( { "record_type": "shortlist_group", "layer_id": int(layer_id), "kv_head_id": None if kv_head_id is None else int(kv_head_id), "context_length": None if context_length is None else int(context_length), "layer_recent_window": int(layer_recent_window), "layer_relevance_top_k": int(layer_relevance_top_k), "candidate_relevance_top_k": int(candidate_relevance_top_k), "base_count": int(len(base_index_set)), "stage1_count": int(len(stage1_selected_indices)), "final_count": int(len(final_selected_indices)), "stage1_old_count": int(len(stage1_old_indices)), "final_old_count": int(len(final_old_indices)), "base_indices": [int(index) for index in sorted(base_index_set)], "stage1_indices": [int(index) for index in stage1_selected_indices], "final_indices": [int(index) for index in final_selected_indices], "base_page_ranges": _page_range_labels(sorted(base_index_set)), "stage1_old_page_ranges": _page_range_labels(stage1_old_indices), "final_old_page_ranges": _page_range_labels(final_old_indices), "exact_promote_candidate_expansion_enabled": bool(promote_candidate_expansion_enabled), "exact_promote_candidate_expansion_disable_reason": promote_candidate_expansion_disable_reason, "exact_promote_enabled": bool(promote_enabled), "exact_promote_disable_reason": promote_disable_reason, "promote_candidate_count": int(len(promote_candidate_indices)), "promote_selected_count": int(len(promote_selected_indices)), "promote_target_old_page_count": ( None if promote_target_old_page_count is None else int(promote_target_old_page_count) ), "promote_candidate_indices": [int(index) for index in promote_candidate_indices], "promote_selected_indices": [int(index) for index in promote_selected_indices], "promote_candidate_page_ranges": _page_range_labels(promote_candidate_indices), "promote_selected_page_ranges": _page_range_labels(promote_selected_indices), "boundary_margin_normalized": ( None if boundary_margin_normalized is None else float(boundary_margin_normalized) ), } ) key_page_sketches: list[np.ndarray] = [] key_page_minima: list[np.ndarray] = [] key_page_maxima: list[np.ndarray] = [] for page in key_pages: source_page = page.source_page if isinstance(page, PreparedPageTorch) else page sketch = source_page.runtime_page_sketch page_min = source_page.runtime_page_min page_max = source_page.runtime_page_max if self.config.execution_relevance_mode == "sketch": if sketch is None: return None key_page_sketches.append(np.asarray(sketch, dtype=np.float32)) else: if page_min is None or page_max is None: return None key_page_minima.append(np.asarray(page_min, dtype=np.float32)) key_page_maxima.append(np.asarray(page_max, dtype=np.float32)) candidate_relevance_top_k = int(layer_relevance_top_k) if promote_candidate_expansion_enabled: candidate_relevance_top_k = max( candidate_relevance_top_k, int(layer_relevance_top_k) + int(self.config.execution_exact_promote_top_k) * 2, ) if self._execution_exact_refine_enabled(layer_id=layer_id): candidate_relevance_top_k = max( candidate_relevance_top_k, int(self.config.execution_exact_refine_top_k) * 2, ) base_window_started_at = _stage_start() base_indices = set( select_window_page_indices( key_pages, recent_window_tokens=layer_recent_window if layer_recent_window > 0 else None, sink_window_tokens=int(self.config.execution_sink_window), ) ) _stage_finish("shortlist_base_window", base_window_started_at) use_recent_old_bonus = self.config.execution_recent_old_bonus_enabled_for_layer(layer_id=layer_id) use_secondary_relevance_rescue = self._execution_secondary_relevance_enabled(layer_id=layer_id) use_recent_neighbor_rescue = self._execution_recent_neighbor_rescue_enabled(layer_id=layer_id) use_confidence_gated_exact_promote = ( promote_candidate_expansion_enabled and float(self.config.execution_exact_promote_margin_threshold) > 0.0 ) boundary_margin_normalized = None shortlist_candidate_scoring_started_at = _stage_start() if ( use_recent_old_bonus or use_secondary_relevance_rescue or use_recent_neighbor_rescue or use_confidence_gated_exact_promote ): if candidate_relevance_top_k > 0: candidate_indices = [index for index in range(len(key_pages)) if index not in base_indices] if candidate_indices: approx_scores: list[float] = [] shortlist_candidate_approx_scoring_started_at = _stage_start() for index in candidate_indices: approx_score = _score_page_relevance_for_mode( np.asarray(query_slice, dtype=np.float32), key_pages[index], relevance_mode=self.config.execution_relevance_mode, ) if approx_score is None: return None approx_scores.append(float(approx_score)) _stage_finish("shortlist_candidate_approx_scoring", shortlist_candidate_approx_scoring_started_at) shortlist_candidate_ranking_started_at = _stage_start() score_scale = max(float(np.std(np.asarray(approx_scores, dtype=np.float32))), 1e-6) recent_start = int(context_length) - int(layer_recent_window) if layer_recent_window > 0 and context_length is not None else int(context_length or 0) adjusted_scores = [ score + ( float(self.config.execution_recent_old_bonus_strength) * score_scale * _recent_old_bonus_weight( key_pages[index], recent_start=int(recent_start), bonus_window=int(self.config.execution_recent_old_bonus_window), ) ) for index, score in zip(candidate_indices, approx_scores, strict=True) ] if len(adjusted_scores) > candidate_relevance_top_k and candidate_relevance_top_k > 0: sorted_scores = sorted(adjusted_scores, reverse=True) boundary_margin_normalized = float( (sorted_scores[candidate_relevance_top_k - 1] - sorted_scores[candidate_relevance_top_k]) / max(float(np.std(np.asarray(adjusted_scores, dtype=np.float32))), 1e-6) ) ranked_candidates = [ index for _, index in sorted( zip(adjusted_scores, candidate_indices, strict=True), key=lambda item: item[0], reverse=True, ) ] stage1_ranked_candidates = ranked_candidates[:candidate_relevance_top_k] stage1_indices = sorted(base_indices.union(stage1_ranked_candidates)) _stage_finish("shortlist_candidate_ranking", shortlist_candidate_ranking_started_at) if use_recent_neighbor_rescue and layer_recent_window > 0 and context_length is not None: shortlist_candidate_neighbor_rescue_started_at = _stage_start() recent_start = int(context_length) - int(layer_recent_window) primary_top_indices = stage1_ranked_candidates[:layer_relevance_top_k] anchor_pages = [ index for index in primary_top_indices if int(_page_header(key_pages[index]).token_start + _page_header(key_pages[index]).token_count) <= int(self.config.execution_recent_neighbor_rescue_anchor_window) ] recent_old_indices = [ index for index in primary_top_indices if ( int(_page_header(key_pages[index]).token_start + _page_header(key_pages[index]).token_count) <= int(recent_start) and int(_page_header(key_pages[index]).token_start + _page_header(key_pages[index]).token_count) > int(recent_start - layer_recent_window) ) ] if ( len(anchor_pages) >= int(self.config.execution_recent_neighbor_rescue_min_anchor_pages) and recent_old_indices ): rescue_indices: list[int] = [] probe_index = min(recent_old_indices) - 1 stage1_index_set = set(stage1_indices) while probe_index >= 0 and len(rescue_indices) < int(self.config.execution_recent_neighbor_rescue_top_k): page_end = int( _page_header(key_pages[probe_index]).token_start + _page_header(key_pages[probe_index]).token_count ) if page_end <= int(recent_start - layer_recent_window): break if probe_index not in base_indices and probe_index not in stage1_index_set: rescue_indices.append(int(probe_index)) probe_index -= 1 if rescue_indices: stage1_indices = sorted(stage1_index_set.union(rescue_indices)) _stage_finish("shortlist_candidate_neighbor_rescue", shortlist_candidate_neighbor_rescue_started_at) if use_secondary_relevance_rescue and layer_relevance_top_k > 0: shortlist_candidate_secondary_scoring_started_at = _stage_start() secondary_scores: list[float] = [] for index in candidate_indices: secondary_score = _score_page_relevance_for_mode( np.asarray(query_slice, dtype=np.float32), key_pages[index], relevance_mode=self.config.execution_secondary_relevance_mode, ) if secondary_score is None: secondary_scores = [] break secondary_scores.append(float(secondary_score)) if secondary_scores: secondary_ranked_candidates = [ index for _, index in sorted( zip(secondary_scores, candidate_indices, strict=True), key=lambda item: item[0], reverse=True, ) ] primary_top_indices = stage1_ranked_candidates[:layer_relevance_top_k] secondary_top_indices = secondary_ranked_candidates[:layer_relevance_top_k] overlap_budget = min(len(primary_top_indices), len(secondary_top_indices)) overlap_ratio = ( float(len(set(primary_top_indices) & set(secondary_top_indices))) / float(overlap_budget) if overlap_budget > 0 else 1.0 ) if overlap_ratio < float(self.config.execution_secondary_relevance_min_overlap): rescue_indices: list[int] = [] for index in secondary_ranked_candidates: if index in stage1_indices: continue rescue_indices.append(int(index)) if len(rescue_indices) >= int(self.config.execution_secondary_relevance_top_k): break if rescue_indices: stage1_indices = sorted(set(stage1_indices).union(rescue_indices)) _stage_finish("shortlist_candidate_secondary_scoring", shortlist_candidate_secondary_scoring_started_at) else: stage1_indices = sorted(base_indices) else: stage1_indices = sorted(base_indices) else: builtin_sketch_matrix = None builtin_minima_matrix = None builtin_maxima_matrix = None builtin_tail_sketch = None builtin_tail_minimum = None builtin_tail_maximum = None builtin_matrix_prepare_started_at = _stage_start() if self.config.execution_builtin_selector_cache: ( builtin_sketch_matrix, builtin_minima_matrix, builtin_maxima_matrix, builtin_tail_sketch, builtin_tail_minimum, builtin_tail_maximum, ) = self._execution_builtin_selector_matrices( layer_id=int(layer_id), kv_head_id=kv_head_id, key_pages=key_pages, relevance_mode=self.config.execution_relevance_mode, score_all_pages_with_matrices=( self.config.execution_builtin_selector_score_all_pages and not self.config.execution_builtin_selector_candidate_only ), ) if ( builtin_sketch_matrix is not None or builtin_minima_matrix is not None or builtin_maxima_matrix is not None ): _stage_finish("shortlist_candidate_builtin_sidecar_stack", builtin_matrix_prepare_started_at) shortlist_candidate_builtin_selection_started_at = _stage_start() stage1_indices = select_execution_page_indices( key_pages, recent_window_tokens=layer_recent_window if layer_recent_window > 0 else None, sink_window_tokens=int(self.config.execution_sink_window), query_slice=np.asarray(query_slice, dtype=np.float32), key_page_sketches=key_page_sketches if key_page_sketches else None, key_page_sketch_matrix=builtin_sketch_matrix, tail_page_sketch=builtin_tail_sketch, key_page_minima=key_page_minima if key_page_minima else None, key_page_minima_matrix=builtin_minima_matrix, tail_page_minimum=builtin_tail_minimum, key_page_maxima=key_page_maxima if key_page_maxima else None, key_page_maxima_matrix=builtin_maxima_matrix, tail_page_maximum=builtin_tail_maximum, relevance_top_k=candidate_relevance_top_k, relevance_mode=self.config.execution_relevance_mode, score_all_pages_with_matrices=( self.config.execution_builtin_selector_score_all_pages and not self.config.execution_builtin_selector_candidate_only ), score_all_pages_min_candidate_fraction=self.config.execution_builtin_selector_score_all_pages_min_candidate_fraction, selector_stats_recorder=lambda stats: self._record_builtin_selector_stats( candidate_pages=int(stats["candidate_pages"]), total_pages=int(stats["total_pages"]), candidate_fraction=float(stats["candidate_fraction"]), used_score_all_pages=bool(stats["used_score_all_pages"]), ), stage_recorder=lambda stage, ms: self._record_decode_stage_timing( layer_id=int(layer_id), stage=stage, ms=float(ms), ), ) _stage_finish("shortlist_candidate_builtin_selection", shortlist_candidate_builtin_selection_started_at) _stage_finish("shortlist_candidate_scoring", shortlist_candidate_scoring_started_at) if not self._execution_exact_refine_enabled(layer_id=layer_id): promote_enabled, promote_disable_reason = self._execution_exact_promote_status( layer_id=layer_id, context_length=context_length, boundary_margin_normalized=boundary_margin_normalized, ) if not promote_enabled: _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=stage1_indices, promote_enabled=False, promote_disable_reason=promote_disable_reason, ) return stage1_indices candidate_indices = [index for index in stage1_indices if index not in base_indices] if not candidate_indices: _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=stage1_indices, promote_enabled=True, promote_disable_reason=None, promote_candidate_indices=[], promote_selected_indices=[], promote_target_old_page_count=0, ) return stage1_indices target_old_page_count = min( len(candidate_indices), int(layer_relevance_top_k) + int(self.config.execution_exact_promote_top_k), ) if target_old_page_count >= len(candidate_indices): _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=stage1_indices, promote_enabled=True, promote_disable_reason=None, promote_candidate_indices=candidate_indices, promote_selected_indices=candidate_indices, promote_target_old_page_count=target_old_page_count, ) return stage1_indices shortlist_exact_selection_started_at = _stage_start() candidate_logits = score_pages( np.asarray(query_slice, dtype=np.float32), [key_pages[index] for index in candidate_indices], backend=self.backend, trace=trace, ) chosen = [ index for _, index in sorted( ( (float(np.max(np.asarray(logits, dtype=np.float32))), index) for index, logits in zip(candidate_indices, candidate_logits, strict=True) ), key=lambda item: item[0], reverse=True, )[:target_old_page_count] ] final_indices = sorted(base_indices.union(chosen)) _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=final_indices, promote_enabled=True, promote_disable_reason=None, promote_candidate_indices=candidate_indices, promote_selected_indices=chosen, promote_target_old_page_count=target_old_page_count, ) _stage_finish("shortlist_exact_selection", shortlist_exact_selection_started_at) return final_indices base_indices = set( select_window_page_indices( key_pages, recent_window_tokens=layer_recent_window if layer_recent_window > 0 else None, sink_window_tokens=int(self.config.execution_sink_window), ) ) candidate_indices = [index for index in stage1_indices if index not in base_indices] if not candidate_indices: self._record_execution_exact_refine(layer_id=layer_id, candidate_pages=0, selected_pages=0) _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=stage1_indices, promote_enabled=False, promote_disable_reason="exact_refine_enabled", ) return stage1_indices top_k = min(int(self.config.execution_exact_refine_top_k), len(candidate_indices)) if top_k >= len(candidate_indices): self._record_execution_exact_refine( layer_id=layer_id, candidate_pages=len(candidate_indices), selected_pages=len(candidate_indices), ) _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=stage1_indices, promote_enabled=False, promote_disable_reason="exact_refine_enabled", ) return stage1_indices shortlist_exact_selection_started_at = _stage_start() candidate_logits = score_pages( np.asarray(query_slice, dtype=np.float32), [key_pages[index] for index in candidate_indices], backend=self.backend, trace=trace, ) chosen = [ index for _, index in sorted( ( (float(np.max(np.asarray(logits, dtype=np.float32))), index) for index, logits in zip(candidate_indices, candidate_logits, strict=True) ), key=lambda item: item[0], reverse=True, )[:top_k] ] self._record_execution_exact_refine( layer_id=layer_id, candidate_pages=len(candidate_indices), selected_pages=len(chosen), ) final_indices = sorted(base_indices.union(chosen)) _record_shortlist_trace( base_index_set=base_indices, stage1_selected_indices=stage1_indices, final_selected_indices=final_indices, promote_enabled=False, promote_disable_reason="exact_refine_enabled", ) _stage_finish("shortlist_exact_selection", shortlist_exact_selection_started_at) return final_indices def _should_build_execution_runtime_metadata(self, *, kind: str) -> bool: if kind != "K": return False return self.config.execution_shortlist_enabled() def clear(self) -> None: for state in self._states.values(): state.clear(clear_prepared_cache=False) self.cache.clear() clear_prepared_chunk_cache() self._m2_prefilter_invocations = 0 self._m2_prefilter_candidate_pages = 0 self._m2_prefilter_selected_pages = 0 self._decode_path_counts = { "grouped_batched": 0, "per_kv_fallback": 0, } self._decode_path_counts_by_layer = {} self._execution_shortlist_invocations = 0 self._execution_shortlist_applied = 0 self._execution_shortlist_group_union_applied = 0 self._execution_shortlist_grouping_rejections = 0 self._execution_shortlist_grouping_rejection_reason_counts = {} self._execution_shortlist_grouping_rejection_reason_counts_by_layer = {} self._execution_shortlist_total_pages = 0 self._execution_shortlist_selected_pages = 0 self._execution_shortlist_invocations_by_layer = {} self._execution_shortlist_applied_by_layer = {} self._execution_shortlist_group_union_applied_by_layer = {} self._execution_shortlist_grouping_rejections_by_layer = {} self._execution_shortlist_total_pages_by_layer = {} self._execution_shortlist_selected_pages_by_layer = {} self._execution_shortlist_trace_records = [] self._execution_exact_refine_invocations = 0 self._execution_exact_refine_candidate_pages = 0 self._execution_exact_refine_selected_pages = 0 self._execution_exact_refine_invocations_by_layer = {} self._execution_exact_refine_candidate_pages_by_layer = {} self._execution_exact_refine_selected_pages_by_layer = {} self._decode_grouped_batch_rejection_reason_counts = {} self._decode_grouped_batch_rejection_reason_counts_by_layer = {} self._decode_stage_timings = _empty_decode_stage_timing_totals() self._decode_stage_timings_by_layer = {} self._reset_resident_accounting() self._reset_chunk_budget_tracking() self._reset_builtin_selector_tracking() self._reset_execution_value_escape_tracking() self._prepared_chunk_cache_frozen_budget_bytes = None self._prepared_chunk_cache_applied_budget_bytes = None self._prepared_chunk_cache_budget_dirty = True def clear_layer(self, layer_id: int) -> None: self._validate_layer_id(layer_id) layer_keys = [key for key in self._states if key[0] == layer_id] if not layer_keys: return for key in layer_keys: self._states[key].clear(clear_prepared_cache=False) self.cache.clear() clear_prepared_chunk_cache() self._rebuild_resident_accounting() self._decode_path_counts_by_layer.pop(int(layer_id), None) self._decode_path_counts = { "grouped_batched": 0, "per_kv_fallback": 0, } for counts in self._decode_path_counts_by_layer.values(): for path_name, count in counts.items(): if path_name in self._decode_path_counts: self._decode_path_counts[path_name] += int(count) self._execution_shortlist_invocations_by_layer.pop(int(layer_id), None) self._execution_shortlist_applied_by_layer.pop(int(layer_id), None) self._execution_shortlist_group_union_applied_by_layer.pop(int(layer_id), None) self._execution_shortlist_grouping_rejections_by_layer.pop(int(layer_id), None) self._execution_shortlist_grouping_rejection_reason_counts_by_layer.pop(int(layer_id), None) self._execution_shortlist_total_pages_by_layer.pop(int(layer_id), None) self._execution_shortlist_selected_pages_by_layer.pop(int(layer_id), None) self._execution_exact_refine_invocations_by_layer.pop(int(layer_id), None) self._execution_exact_refine_candidate_pages_by_layer.pop(int(layer_id), None) self._execution_exact_refine_selected_pages_by_layer.pop(int(layer_id), None) self._decode_grouped_batch_rejection_reason_counts_by_layer.pop(int(layer_id), None) self._prepared_chunk_cache_frozen_budget_bytes = None self._prepared_chunk_cache_applied_budget_bytes = None self._prepared_chunk_cache_budget_dirty = True self._decode_stage_timings_by_layer.pop(int(layer_id), None) self._execution_shortlist_invocations = sum(self._execution_shortlist_invocations_by_layer.values()) self._execution_shortlist_applied = sum(self._execution_shortlist_applied_by_layer.values()) self._execution_shortlist_group_union_applied = sum(self._execution_shortlist_group_union_applied_by_layer.values()) self._execution_shortlist_grouping_rejections = sum(self._execution_shortlist_grouping_rejections_by_layer.values()) self._execution_shortlist_grouping_rejection_reason_counts = {} for layer_reason_counts in self._execution_shortlist_grouping_rejection_reason_counts_by_layer.values(): for reason, count in layer_reason_counts.items(): self._execution_shortlist_grouping_rejection_reason_counts[reason] = ( self._execution_shortlist_grouping_rejection_reason_counts.get(reason, 0) + int(count) ) self._execution_shortlist_total_pages = sum(self._execution_shortlist_total_pages_by_layer.values()) self._execution_shortlist_selected_pages = sum(self._execution_shortlist_selected_pages_by_layer.values()) self._execution_exact_refine_invocations = sum(self._execution_exact_refine_invocations_by_layer.values()) self._execution_exact_refine_candidate_pages = sum(self._execution_exact_refine_candidate_pages_by_layer.values()) self._execution_exact_refine_selected_pages = sum(self._execution_exact_refine_selected_pages_by_layer.values()) self._decode_grouped_batch_rejection_reason_counts = {} for layer_reason_counts in self._decode_grouped_batch_rejection_reason_counts_by_layer.values(): for reason, count in layer_reason_counts.items(): self._decode_grouped_batch_rejection_reason_counts[reason] = ( self._decode_grouped_batch_rejection_reason_counts.get(reason, 0) + int(count) ) self._decode_stage_timings = _empty_decode_stage_timing_totals() for layer_timings in self._decode_stage_timings_by_layer.values(): for stage, value in layer_timings.items(): if stage in self._decode_stage_timings: self._decode_stage_timings[stage] += float(value) self._reset_builtin_selector_tracking() self._reset_execution_value_escape_tracking() self._prepared_chunk_cache_frozen_budget_bytes = None self._prepared_chunk_cache_budget_dirty = True def _grouped_query_heads_for_mapping(self, q_head_to_kv_head: Sequence[int] | np.ndarray) -> tuple[tuple[int, ...], ...]: mapping = np.asarray(q_head_to_kv_head, dtype=np.int64) if mapping.shape != (self.num_attention_heads,): raise ValueError("q_head_to_kv_head must have shape [num_attention_heads]") if np.array_equal(mapping, self.default_q_head_to_kv_head): return self.default_grouped_query_heads return _group_query_heads(mapping, num_key_value_heads=self.num_key_value_heads) def _encode_full_prefill_pages( self, layer_id: int, keys: np.ndarray, values: np.ndarray, *, sequence_length: int, full_tokens: int, ) -> tuple[list[list[EncodedPage]], list[list[EncodedPage]]]: key_pages_by_head: list[list[EncodedPage]] = [[] for _ in range(self.num_key_value_heads)] value_pages_by_head: list[list[EncodedPage]] = [[] for _ in range(self.num_key_value_heads)] if full_tokens <= 0: return key_pages_by_head, value_pages_by_head page_size = self.config.tokens_per_page full_page_count = full_tokens // page_size build_key_sidecar = ( self.config.m2_prefilter_top_k > 0 and full_page_count >= int(self.config.m2_prefilter_min_pages) ) full_keys = np.ascontiguousarray(keys[:, :full_tokens], dtype=np.float32) full_values = np.ascontiguousarray(values[:, :full_tokens], dtype=np.float32) shared_m4_basis_by_head: list[np.ndarray | None] = [None] * self.num_key_value_heads for kv_head_id in range(self.num_key_value_heads): resolved_basis = self.config.resolve_m4_project_basis_k(layer_id=layer_id) if resolved_basis != "svd_shared": continue key_page_mode = self._select_page_mode( full_keys[kv_head_id, :page_size], kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=0, sequence_length=sequence_length, stage="prefill", ) key_mode_name = ( key_page_mode.mode if key_page_mode is not None else self.config.resolve_page_mode(kind="K", layer_id=layer_id, kv_head_id=kv_head_id) ) if key_mode_name != "M4": continue shared_m4_basis_by_head[kv_head_id] = fit_shared_project_basis( full_keys[kv_head_id, :full_tokens], group_size=self.config.group_size, project_dim=self.config.resolve_m4_project_dim_k(layer_id=layer_id), page_size=page_size, ).astype(np.float16, copy=False) for page_start in range(0, full_tokens, page_size): page_end = page_start + page_size for kv_head_id in range(self.num_key_value_heads): key_page_mode = self._select_page_mode( full_keys[kv_head_id, page_start:page_end], kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=page_start, sequence_length=sequence_length, stage="prefill", ) key_pages_by_head[kv_head_id].append( encode_page( full_keys[kv_head_id, page_start:page_end], self.config, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=page_start, page_mode=key_page_mode, build_runtime_metadata=self._should_build_execution_runtime_metadata(kind="K"), build_m2_sidecar=build_key_sidecar, m4_basis_override=( shared_m4_basis_by_head[kv_head_id] if ( ( key_page_mode.mode if key_page_mode is not None else self.config.resolve_page_mode(kind="K", layer_id=layer_id, kv_head_id=kv_head_id) ) == "M4" and shared_m4_basis_by_head[kv_head_id] is not None ) else None ), ) ) dense_value_page = full_values[kv_head_id, page_start:page_end] value_page = encode_page( dense_value_page, self.config, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=page_start, page_mode=self._select_page_mode( dense_value_page, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=page_start, sequence_length=sequence_length, stage="prefill", ), build_runtime_metadata=False, ) self._maybe_register_execution_value_escape_source( value_page, dense_values=dense_value_page, escape_mode=str(self.config.execution_value_escape_mode), ) value_pages_by_head[kv_head_id].append(value_page) return key_pages_by_head, value_pages_by_head def _can_direct_prepare_full_prefill_pages_torch(self) -> bool: if not self._use_persistent_torch_tail: return False if self.config.learned_page_selector_enabled(): return False if int(self.config.m2_prefilter_top_k) > 0: return False if self.config.has_mode_overrides() or self.config.has_policy_overrides(): return False if self.config.default_mode_k != "M0" or self.config.default_mode_v != "M0": return False if self.config.quant_scheme_k != "affine" or self.config.quant_scheme_v != "affine": return False if self.config.payload_layout_k != "group_major" or self.config.payload_layout_v != "group_major": return False return True def _select_page_mode( self, values: np.ndarray, *, kind: str, layer_id: int, kv_head_id: int, token_start: int, sequence_length: int, stage: str = "unknown", ) -> PageModeSpec | None: learned_page_mode = self._select_page_mode_with_learned_selector( values, kind=kind, layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start, sequence_length=sequence_length, stage=stage, ) if learned_page_mode is not None: return learned_page_mode if not self.config.has_policy_overrides(kind=kind) and not self.config.has_mode_overrides(kind=kind): return None layer_policy = self.config.resolve_layer_policy(kind=kind, layer_id=layer_id, kv_head_id=kv_head_id) page_stats = observe_page(values) token_age = max(0, int(sequence_length) - int(token_start) - 1) return choose_page_mode( int(layer_id), kind, token_age, page_stats, layer_policy=layer_policy, ) def _select_page_mode_with_learned_selector( self, values: np.ndarray, *, kind: str, layer_id: int, kv_head_id: int, token_start: int, sequence_length: int, stage: str, ) -> PageModeSpec | None: model = self._learned_page_selector_model if model is None: return None if not self.config.learned_page_selector_applies_to_kind(kind=str(kind)): return None stage_name = str(stage) started_at = perf_counter() page_stats = observe_page(values) row = { "stage": stage_name, "kind": str(kind), "prompt_family": self.config.learned_page_selector_prompt_family, "prompt_variant": self.config.learned_page_selector_prompt_variant, "query_present": False, "layer_fraction": float(int(layer_id) / max(self.num_hidden_layers - 1, 1)), "kv_head_fraction": float(int(kv_head_id) / max(self.num_key_value_heads - 1, 1)), "token_start": int(token_start), "token_age": max(0, int(sequence_length) - int(token_start) - int(values.shape[0])), "token_count": int(values.shape[0]), "head_dim": int(values.shape[1]), "safe_candidate_count": 0.0, "trace_rms": float(page_stats.rms), "trace_abs_max": float(page_stats.abs_max), "trace_channel_range_mean": float(page_stats.channel_range_mean), "trace_outlier_fraction": float(page_stats.outlier_fraction), "age_per_token": float(max(0, int(sequence_length) - int(token_start) - int(values.shape[0])) / max(int(values.shape[0]), 1)), } predicted = model.predict_row(row) elapsed_ms = float((perf_counter() - started_at) * 1000.0) self._learned_page_selector_ms_total += elapsed_ms self._learned_page_selector_invocations += 1 self._learned_page_selector_invocations_by_stage[stage_name] = ( self._learned_page_selector_invocations_by_stage.get(stage_name, 0) + 1 ) self._learned_page_selector_ms_total_by_stage[stage_name] = ( float(self._learned_page_selector_ms_total_by_stage.get(stage_name, 0.0)) + elapsed_ms ) if predicted is None: self._learned_page_selector_fallbacks += 1 self._learned_page_selector_fallbacks_by_stage[stage_name] = ( self._learned_page_selector_fallbacks_by_stage.get(stage_name, 0) + 1 ) return None try: page_mode = parse_page_mode_token(predicted) except ValueError: self._learned_page_selector_fallbacks += 1 self._learned_page_selector_fallbacks_by_stage[stage_name] = ( self._learned_page_selector_fallbacks_by_stage.get(stage_name, 0) + 1 ) return None token = f"{page_mode.mode}/{page_mode.quant_scheme}/{page_mode.bits}" + ( "" if page_mode.escape_dtype is None else f"/{page_mode.escape_dtype}" ) self._learned_page_selector_predictions[token] = self._learned_page_selector_predictions.get(token, 0) + 1 stage_predictions = self._learned_page_selector_predictions_by_stage.setdefault(stage_name, {}) stage_predictions[token] = stage_predictions.get(token, 0) + 1 return page_mode def prepare_static_pages(self, *, trace: ExecutionTrace | None = None) -> None: if self._torch_device_type is None: return key_refs: list[tuple[_HeadSessionState, int]] = [] key_pages: list[PageLike] = [] value_refs: list[tuple[_HeadSessionState, int]] = [] value_pages: list[PageLike] = [] touched_states: dict[int, _HeadSessionState] = {} for state in self._states.values(): for index, page in enumerate(state.session.key_pages): if isinstance(page, PreparedPageTorch): continue key_refs.append((state, index)) key_pages.append(page) touched_states[id(state)] = state for index, page in enumerate(state.session.value_pages): if isinstance(page, PreparedPageTorch): continue value_refs.append((state, index)) value_pages.append(page) touched_states[id(state)] = state if key_pages: prepared_keys = prepare_pages(key_pages, backend=self.backend, cache=self.cache, trace=trace) for (state, index), prepared in zip(key_refs, prepared_keys, strict=True): state.session.key_pages[index] = prepared state.invalidate_decode_views() if value_pages: prepared_values = prepare_pages(value_pages, backend=self.backend, cache=self.cache, trace=trace) for (state, index), prepared in zip(value_refs, prepared_values, strict=True): state.session.value_pages[index] = prepared state.invalidate_decode_views() if key_pages or value_pages: for state in touched_states.values(): self._refresh_state_resident_accounting(state) self._mark_prepared_chunk_cache_budget_dirty(reason="prepare_static_pages") for state in self._states.values(): self._maybe_prewarm_execution_builtin_selector_cache(state) self._maybe_prewarm_execution_value_escape_pages(state, trace=trace) def _ensure_prepared_static_pages( self, state: _HeadSessionState, *, trace: ExecutionTrace | None = None, ) -> None: if self._torch_device_type is None: return state_changed = False if state.session.key_pages and not all(isinstance(page, PreparedPageTorch) for page in state.session.key_pages): state.session.key_pages = prepare_pages( state.session.key_pages, backend=self.backend, cache=self.cache, trace=trace, ) state.invalidate_decode_views() state_changed = True if state.session.value_pages and not all(isinstance(page, PreparedPageTorch) for page in state.session.value_pages): state.session.value_pages = prepare_pages( state.session.value_pages, backend=self.backend, cache=self.cache, trace=trace, ) state.invalidate_decode_views() state_changed = True if state_changed: self._refresh_state_resident_accounting(state) self._mark_prepared_chunk_cache_budget_dirty(reason="ensure_prepared_static_pages") self._maybe_prewarm_execution_builtin_selector_cache(state) self._maybe_prewarm_execution_value_escape_pages(state, trace=trace) def _validate_layer_id(self, layer_id: int) -> None: if layer_id < 0 or layer_id >= self.num_hidden_layers: raise ValueError(f"layer_id must be in [0, {self.num_hidden_layers})") def _state(self, layer_id: int, kv_head_id: int) -> _HeadSessionState: self._validate_layer_id(layer_id) if kv_head_id < 0 or kv_head_id >= self.num_key_value_heads: raise ValueError(f"kv_head_id must be in [0, {self.num_key_value_heads})") key = (layer_id, kv_head_id) state = self._states.get(key) if state is None: torch_device_type = self._torch_device_type state = _HeadSessionState( session=PagedDecodeSession(backend=self.backend, cache=self.cache), tail=_TailPageBuilder(self.config, layer_id=layer_id, kv_head_id=kv_head_id), persistent_key_tail=_PersistentTailPage( self.config, layer_id=layer_id, kv_head_id=kv_head_id, kind="K", device_type=torch_device_type, ) if torch_device_type is not None else None, persistent_value_tail=_PersistentTailPage( self.config, layer_id=layer_id, kv_head_id=kv_head_id, kind="V", device_type=torch_device_type, ) if torch_device_type is not None else None, ) self._states[key] = state return state @property def _torch_device_type(self) -> str | None: if self.backend == "torch_mps": return "mps" if mps_available() else None if self.backend == "torch_cuda": return "cuda" if cuda_available() else None if self.backend == "auto": if cuda_available(): return "cuda" if mps_available(): return "mps" return None @property def _use_persistent_torch_tail(self) -> bool: return self._torch_device_type is not None def layer_sequence_length(self, layer_id: int) -> int: self._validate_layer_id(layer_id) lengths = {self._state(layer_id, kv_head_id).sequence_length for kv_head_id in range(self.num_key_value_heads)} if len(lengths) > 1: raise RuntimeError(f"layer {layer_id} KV heads disagree on sequence length") return next(iter(lengths), 0) def _m2_prefilter_pages_numpy( self, queries: np.ndarray, key_pages: Sequence[PageLike], value_pages: Sequence[PageLike], ) -> tuple[list[PageLike], list[PageLike]]: top_k = int(self.config.m2_prefilter_top_k) if top_k <= 0 or len(key_pages) <= top_k: return list(key_pages), list(value_pages) always_keep: list[int] = [] candidate_indices: list[int] = [] candidate_scores: list[float] = [] for page_index, page in enumerate(key_pages): if not _page_has_m2_sidecar(page): always_keep.append(page_index) continue candidate_indices.append(page_index) candidate_pages = [key_pages[index] for index in candidate_indices] if ( len(candidate_indices) <= top_k or len(candidate_indices) < int(self.config.m2_prefilter_min_pages) ): return list(key_pages), list(value_pages) if _pages_can_batch_m2_prefilter(candidate_pages): score_array = _page_m2_prefilter_scores_numpy(queries, candidate_pages) else: for page in candidate_pages: candidate_scores.append(_page_m2_prefilter_score_numpy(queries, page)) score_array = np.asarray(candidate_scores, dtype=np.float32) selected_order = np.argpartition(score_array, -top_k)[-top_k:] selected_indices = sorted(always_keep + [candidate_indices[index] for index in selected_order.tolist()]) self._m2_prefilter_invocations += 1 self._m2_prefilter_candidate_pages += len(candidate_indices) self._m2_prefilter_selected_pages += len(selected_indices) return [key_pages[index] for index in selected_indices], [value_pages[index] for index in selected_indices] def _m2_prefilter_pages_torch( self, queries, key_pages: Sequence[PageLike], value_pages: Sequence[PageLike], ) -> tuple[list[PageLike], list[PageLike]]: top_k = int(self.config.m2_prefilter_top_k) if top_k <= 0 or len(key_pages) <= top_k: return list(key_pages), list(value_pages) always_keep: list[int] = [] candidate_indices: list[int] = [] candidate_scores: list[float] = [] for page_index, page in enumerate(key_pages): if not _page_has_m2_sidecar(page): always_keep.append(page_index) continue candidate_indices.append(page_index) candidate_pages = [key_pages[index] for index in candidate_indices] if ( len(candidate_indices) <= top_k or len(candidate_indices) < int(self.config.m2_prefilter_min_pages) ): return list(key_pages), list(value_pages) if _pages_can_batch_m2_prefilter(candidate_pages): score_array = _page_m2_prefilter_scores_torch(queries, candidate_pages) else: for page in candidate_pages: candidate_scores.append(_page_m2_prefilter_score_torch(queries, page)) score_array = np.asarray(candidate_scores, dtype=np.float32) selected_order = np.argpartition(score_array, -top_k)[-top_k:] selected_indices = sorted(always_keep + [candidate_indices[index] for index in selected_order.tolist()]) self._m2_prefilter_invocations += 1 self._m2_prefilter_candidate_pages += len(candidate_indices) self._m2_prefilter_selected_pages += len(selected_indices) return [key_pages[index] for index in selected_indices], [value_pages[index] for index in selected_indices] def page_mode_summary(self) -> dict[str, object]: counts: dict[str, int] = { "total_static_pages": 0, "m0_pages": 0, "m1_pages": 0, "m2_pages": 0, "m4_pages": 0, "m3_pages": 0, "m2_sidecar_pages": 0, "requested_m1_pages": 0, "m1_fallback_pages": 0, "active_tail_pages": 0, "k_total_static_pages": 0, "v_total_static_pages": 0, "k_m0_pages": 0, "k_m1_pages": 0, "k_m2_pages": 0, "k_m4_pages": 0, "k_m3_pages": 0, "k_m2_sidecar_pages": 0, "v_m0_pages": 0, "v_m1_pages": 0, "v_m2_pages": 0, "v_m4_pages": 0, "v_m3_pages": 0, "v_m2_sidecar_pages": 0, "k_requested_m1_pages": 0, "v_requested_m1_pages": 0, "k_m1_fallback_pages": 0, "v_m1_fallback_pages": 0, } m1_trial_errors: list[float] = [] m1_trial_token_p95_errors: list[float] = [] k_m1_trial_errors: list[float] = [] k_m1_trial_token_p95_errors: list[float] = [] v_m1_trial_errors: list[float] = [] v_m1_trial_token_p95_errors: list[float] = [] policy_tier_counts: dict[str, int] = {} fallback_reason_counts: dict[str, int] = {} signature_counts: dict[str, int] = {} layer_kind_mode_counts: dict[str, int] = {} def visit_page(page: PageLike) -> None: source = page.source_page if isinstance(page, PreparedPageTorch) else page counts["total_static_pages"] += 1 kind_prefix = str(source.header.kind).lower() counts[f"{kind_prefix}_total_static_pages"] += 1 mode_name = str(source.header.mode_default) key = f"{mode_name.lower()}_pages" if key in counts: counts[key] += 1 kind_key = f"{kind_prefix}_{mode_name.lower()}_pages" if kind_key in counts: counts[kind_key] += 1 policy_tier_counts[source.header.sensitivity_tier] = policy_tier_counts.get(source.header.sensitivity_tier, 0) + 1 fallback_key = source.header.fallback_reason or "none" fallback_reason_counts[fallback_key] = fallback_reason_counts.get(fallback_key, 0) + 1 signature = f"{source.header.kind}:{source.header.mode_default}:{source.header.quant_scheme}:{source.header.bits}" if source.header.mode_default == "M3": signature = f"{signature}:{source.header.escape_dtype}" signature_counts[signature] = signature_counts.get(signature, 0) + 1 layer_mode_key = f"layer:{source.header.layer_id}:{source.header.kind}:{source.header.mode_default}:{source.header.bits}" if source.header.mode_default == "M3": layer_mode_key = f"{layer_mode_key}:{source.header.escape_dtype}" layer_kind_mode_counts[layer_mode_key] = layer_kind_mode_counts.get(layer_mode_key, 0) + 1 if source.m2_sketch is not None and source.m2_basis is not None and source.header.mode_default != "M2": counts["m2_sidecar_pages"] += 1 counts[f"{kind_prefix}_m2_sidecar_pages"] += 1 if source.requested_mode == "M1": counts["requested_m1_pages"] += 1 counts[f"{kind_prefix}_requested_m1_pages"] += 1 if source.header.mode_default != "M1": counts["m1_fallback_pages"] += 1 counts[f"{kind_prefix}_m1_fallback_pages"] += 1 if source.trial_quant_error is not None: error_value = float(source.trial_quant_error) m1_trial_errors.append(error_value) if kind_prefix == "k": k_m1_trial_errors.append(error_value) else: v_m1_trial_errors.append(error_value) if source.trial_token_p95_error is not None: token_error_value = float(source.trial_token_p95_error) m1_trial_token_p95_errors.append(token_error_value) if kind_prefix == "k": k_m1_trial_token_p95_errors.append(token_error_value) else: v_m1_trial_token_p95_errors.append(token_error_value) for state in self._states.values(): for page in state.session.key_pages: visit_page(page) for page in state.session.value_pages: visit_page(page) if state.tail.token_count > 0: counts["active_tail_pages"] += 2 summary: dict[str, float | int] = dict(counts) summary["policy_tier_counts"] = dict(sorted(policy_tier_counts.items())) summary["fallback_reason_counts"] = dict(sorted(fallback_reason_counts.items())) summary["mode_signature_counts"] = dict(sorted(signature_counts.items())) summary["layer_kind_mode_counts"] = dict(sorted(layer_kind_mode_counts.items())) total_buckets = len(signature_counts) single_page_buckets = sum(1 for count in signature_counts.values() if count == 1) total_pages = int(counts["total_static_pages"]) summary["fragmentation_total_buckets"] = total_buckets summary["fragmentation_single_page_buckets"] = single_page_buckets summary["fragmentation_avg_pages_per_bucket"] = ( float(total_pages / total_buckets) if total_buckets > 0 else 0.0 ) summary["fragmentation_single_page_bucket_fraction"] = ( float(single_page_buckets / total_buckets) if total_buckets > 0 else 0.0 ) if m1_trial_errors: errors = np.asarray(m1_trial_errors, dtype=np.float32) summary["m1_trial_error_mean"] = float(np.mean(errors)) summary["m1_trial_error_max"] = float(np.max(errors)) summary["m1_trial_error_p95"] = float(np.percentile(errors, 95)) else: summary["m1_trial_error_mean"] = 0.0 summary["m1_trial_error_max"] = 0.0 summary["m1_trial_error_p95"] = 0.0 if m1_trial_token_p95_errors: errors = np.asarray(m1_trial_token_p95_errors, dtype=np.float32) summary["m1_trial_token_p95_error_mean"] = float(np.mean(errors)) summary["m1_trial_token_p95_error_max"] = float(np.max(errors)) summary["m1_trial_token_p95_error_p95"] = float(np.percentile(errors, 95)) else: summary["m1_trial_token_p95_error_mean"] = 0.0 summary["m1_trial_token_p95_error_max"] = 0.0 summary["m1_trial_token_p95_error_p95"] = 0.0 for prefix, error_values in (("k", k_m1_trial_errors), ("v", v_m1_trial_errors)): if error_values: errors = np.asarray(error_values, dtype=np.float32) summary[f"{prefix}_m1_trial_error_mean"] = float(np.mean(errors)) summary[f"{prefix}_m1_trial_error_max"] = float(np.max(errors)) summary[f"{prefix}_m1_trial_error_p95"] = float(np.percentile(errors, 95)) else: summary[f"{prefix}_m1_trial_error_mean"] = 0.0 summary[f"{prefix}_m1_trial_error_max"] = 0.0 summary[f"{prefix}_m1_trial_error_p95"] = 0.0 for prefix, error_values in (("k", k_m1_trial_token_p95_errors), ("v", v_m1_trial_token_p95_errors)): if error_values: errors = np.asarray(error_values, dtype=np.float32) summary[f"{prefix}_m1_trial_token_p95_error_mean"] = float(np.mean(errors)) summary[f"{prefix}_m1_trial_token_p95_error_max"] = float(np.max(errors)) summary[f"{prefix}_m1_trial_token_p95_error_p95"] = float(np.percentile(errors, 95)) else: summary[f"{prefix}_m1_trial_token_p95_error_mean"] = 0.0 summary[f"{prefix}_m1_trial_token_p95_error_max"] = 0.0 summary[f"{prefix}_m1_trial_token_p95_error_p95"] = 0.0 summary["m2_prefilter_top_k"] = int(self.config.m2_prefilter_top_k) summary["m2_prefilter_min_pages"] = int(self.config.m2_prefilter_min_pages) summary["m2_prefilter_invocations"] = int(self._m2_prefilter_invocations) summary["m2_prefilter_candidate_pages"] = int(self._m2_prefilter_candidate_pages) summary["m2_prefilter_selected_pages"] = int(self._m2_prefilter_selected_pages) summary["learned_page_selector_enabled"] = bool(self._learned_page_selector_model is not None) summary["learned_page_selector_path"] = ( None if self.config.learned_page_selector_path is None else str(self.config.learned_page_selector_path) ) summary["learned_page_selector_prompt_family"] = ( None if self.config.learned_page_selector_prompt_family is None else str(self.config.learned_page_selector_prompt_family) ) summary["learned_page_selector_prompt_variant"] = ( None if self.config.learned_page_selector_prompt_variant is None else str(self.config.learned_page_selector_prompt_variant) ) summary["learned_page_selector_profile"] = str(self.config.learned_page_selector_profile) summary["learned_page_selector_scope"] = str(self.config.learned_page_selector_scope) summary["learned_page_selector_target_candidate"] = str(self.config.learned_page_selector_target_candidate) summary["learned_page_selector_logit_offset"] = float(self.config.learned_page_selector_logit_offset) summary["learned_page_selector_invocations"] = int(self._learned_page_selector_invocations) summary["learned_page_selector_fallbacks"] = int(self._learned_page_selector_fallbacks) summary["learned_page_selector_ms_total"] = float(self._learned_page_selector_ms_total) summary["learned_page_selector_invocations_by_stage"] = { stage: int(count) for stage, count in sorted(self._learned_page_selector_invocations_by_stage.items()) } summary["learned_page_selector_fallbacks_by_stage"] = { stage: int(count) for stage, count in sorted(self._learned_page_selector_fallbacks_by_stage.items()) } summary["learned_page_selector_ms_total_by_stage"] = { stage: float(ms) for stage, ms in sorted(self._learned_page_selector_ms_total_by_stage.items()) } summary["learned_page_selector_prediction_counts"] = { token: int(count) for token, count in sorted(self._learned_page_selector_predictions.items()) } summary["learned_page_selector_prediction_counts_by_stage"] = { stage: {token: int(count) for token, count in sorted(stage_counts.items())} for stage, stage_counts in sorted(self._learned_page_selector_predictions_by_stage.items()) } return summary def _batch_upload_persistent_tail_rows( self, tails: Sequence[_PersistentTailPage | None], rows_by_head: np.ndarray, *, token_start: int, trace: ExecutionTrace | None = None, ) -> None: active_pairs = [(tail, rows_by_head[index]) for index, tail in enumerate(tails) if tail is not None] if not active_pairs: return non_empty_pairs = [(tail, rows) for tail, rows in active_pairs if rows.shape[0] > 0] if not non_empty_pairs: return try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for the persistent torch tail path") from exc contiguous_rows = np.ascontiguousarray(np.stack([rows.astype(np.float16, copy=False) for _, rows in non_empty_pairs], axis=0)) device_rows = torch.from_numpy(contiguous_rows).to(device=non_empty_pairs[0][0].device_type) if trace is not None: trace.record_host_to_device(int(device_rows.numel() * device_rows.element_size())) for batch_index, (tail, rows) in enumerate(non_empty_pairs): if tail is None: continue if tail.host_buffer is None or tail.prepared_page is None: tail._ensure_allocated(token_start=token_start if tail.token_count == 0 else tail.source_page.header.token_start) tail.append_rows_from_device( rows=rows, device_rows=device_rows[batch_index], token_start=token_start, ) def _batch_append_persistent_tail_tensors( self, tails: Sequence[_PersistentTailPage | None], rows_by_head, *, token_start: int, ) -> None: try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for the persistent torch tail path") from exc if not torch.is_tensor(rows_by_head): raise TypeError("rows_by_head must be a torch.Tensor") if rows_by_head.ndim != 3: raise ValueError("rows_by_head must have shape [kv_heads, token_count, head_dim]") for index, tail in enumerate(tails): if tail is None or int(rows_by_head[index].shape[0]) == 0: continue if tail.prepared_page is None: tail._ensure_allocated(token_start=token_start if tail.token_count == 0 else tail.source_page.header.token_start) tail.append_device_rows(rows_by_head[index], token_start=token_start) def ingest_prefill_cache( self, layer_id: int, layer_k: np.ndarray, layer_v: np.ndarray, *, trace: ExecutionTrace | None = None, ) -> None: keys = _normalize_prefill_tensor( layer_k, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="layer_k", ) values = _normalize_prefill_tensor( layer_v, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="layer_v", ) if keys.shape[1] != values.shape[1]: raise ValueError("layer_k and layer_v sequence lengths must match") seq_len = int(keys.shape[1]) full_page_count = seq_len // self.config.tokens_per_page full_tokens = full_page_count * self.config.tokens_per_page preload_key_pages_by_head, preload_value_pages_by_head = self._encode_full_prefill_pages( layer_id, keys, values, sequence_length=seq_len, full_tokens=full_tokens, ) for kv_head_id in range(self.num_key_value_heads): state = self._state(layer_id, kv_head_id) state.clear(clear_prepared_cache=False) preload_key_pages = preload_key_pages_by_head[kv_head_id] preload_value_pages = preload_value_pages_by_head[kv_head_id] if preload_key_pages: state.session.append(preload_key_pages, preload_value_pages, prepare=False, trace=trace) state.invalidate_decode_views() remainder_keys = keys[kv_head_id, full_tokens:] remainder_values = values[kv_head_id, full_tokens:] state.tail.load_prefill_remainder(remainder_keys, remainder_values, token_start=full_tokens) state.sequence_length = seq_len if self._use_persistent_torch_tail: key_tails = [self._state(layer_id, kv_head_id).persistent_key_tail for kv_head_id in range(self.num_key_value_heads)] value_tails = [self._state(layer_id, kv_head_id).persistent_value_tail for kv_head_id in range(self.num_key_value_heads)] for kv_head_id in range(self.num_key_value_heads): tail = key_tails[kv_head_id] if tail is not None: tail.clear() if keys[kv_head_id, full_tokens:].shape[0] > 0: tail._ensure_allocated(token_start=full_tokens) tail = value_tails[kv_head_id] if tail is not None: tail.clear() if values[kv_head_id, full_tokens:].shape[0] > 0: tail._ensure_allocated(token_start=full_tokens) self._batch_upload_persistent_tail_rows( key_tails, keys[:, full_tokens:], token_start=full_tokens, trace=trace, ) self._batch_upload_persistent_tail_rows( value_tails, values[:, full_tokens:], token_start=full_tokens, trace=trace, ) self._rebuild_resident_accounting() self._mark_prepared_chunk_cache_budget_dirty(reason="ingest_prefill_cache") def ingest_prefill_cache_torch( self, layer_id: int, layer_k, layer_v, *, trace: ExecutionTrace | None = None, ) -> None: try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for ingest_prefill_cache_torch") from exc if self._torch_device_type is None: raise RuntimeError("ingest_prefill_cache_torch is only available for a torch accelerator backend") keys = _normalize_prefill_tensor_torch( layer_k, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="layer_k", ) values = _normalize_prefill_tensor_torch( layer_v, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="layer_v", ) if tuple(keys.shape) != tuple(values.shape): raise ValueError("layer_k and layer_v sequence lengths must match") seq_len = int(keys.shape[1]) full_page_count = seq_len // self.config.tokens_per_page full_tokens = full_page_count * self.config.tokens_per_page direct_prepare_full_pages = ( self._can_direct_prepare_full_prefill_pages_torch() and full_tokens > 0 and full_tokens == seq_len ) if direct_prepare_full_pages: page_size = int(self.config.tokens_per_page) full_key_pages_by_head = [ prepare_m0_affine_pages_from_tensor_torch( keys[kv_head_id, :full_tokens].reshape(full_page_count, page_size, self.config.head_dim), config=self.config, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=0, device_type=self._torch_device_type, build_runtime_metadata=self._should_build_execution_runtime_metadata(kind="K"), ) for kv_head_id in range(self.num_key_value_heads) ] full_value_pages_by_head = [ prepare_m0_affine_pages_from_tensor_torch( values[kv_head_id, :full_tokens].reshape(full_page_count, page_size, self.config.head_dim), config=self.config, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=0, device_type=self._torch_device_type, ) for kv_head_id in range(self.num_key_value_heads) ] elif full_tokens > 0: full_keys_cpu = keys[:, :full_tokens].detach().cpu().numpy() full_values_cpu = values[:, :full_tokens].detach().cpu().numpy() preload_key_pages_by_head, preload_value_pages_by_head = self._encode_full_prefill_pages( layer_id, full_keys_cpu, full_values_cpu, sequence_length=seq_len, full_tokens=full_tokens, ) else: preload_key_pages_by_head = [[] for _ in range(self.num_key_value_heads)] preload_value_pages_by_head = [[] for _ in range(self.num_key_value_heads)] if direct_prepare_full_pages: preload_key_pages_by_head = full_key_pages_by_head preload_value_pages_by_head = full_value_pages_by_head if not self._use_persistent_torch_tail: remainder_keys_cpu = keys[:, full_tokens:].detach().cpu().numpy() remainder_values_cpu = values[:, full_tokens:].detach().cpu().numpy() for kv_head_id in range(self.num_key_value_heads): state = self._state(layer_id, kv_head_id) state.clear(clear_prepared_cache=False) preload_key_pages = preload_key_pages_by_head[kv_head_id] preload_value_pages = preload_value_pages_by_head[kv_head_id] if preload_key_pages: state.session.append(preload_key_pages, preload_value_pages, prepare=False, trace=trace) state.invalidate_decode_views() if self._use_persistent_torch_tail: state.tail.clear() else: remainder_keys = remainder_keys_cpu[kv_head_id] remainder_values = remainder_values_cpu[kv_head_id] state.tail.load_prefill_remainder(remainder_keys, remainder_values, token_start=full_tokens) state.sequence_length = seq_len if self._use_persistent_torch_tail: key_tails = [self._state(layer_id, kv_head_id).persistent_key_tail for kv_head_id in range(self.num_key_value_heads)] value_tails = [self._state(layer_id, kv_head_id).persistent_value_tail for kv_head_id in range(self.num_key_value_heads)] for kv_head_id in range(self.num_key_value_heads): tail = key_tails[kv_head_id] if tail is not None: tail.clear() if int(keys[kv_head_id, full_tokens:].shape[0]) > 0: tail._ensure_allocated(token_start=full_tokens) tail = value_tails[kv_head_id] if tail is not None: tail.clear() if int(values[kv_head_id, full_tokens:].shape[0]) > 0: tail._ensure_allocated(token_start=full_tokens) self._batch_append_persistent_tail_tensors( key_tails, keys[:, full_tokens:], token_start=full_tokens, ) self._batch_append_persistent_tail_tensors( value_tails, values[:, full_tokens:], token_start=full_tokens, ) self._rebuild_resident_accounting() self._mark_prepared_chunk_cache_budget_dirty(reason="ingest_prefill_cache_torch") def append_step( self, layer_id: int, key_step: np.ndarray, value_step: np.ndarray, token_index: int, *, trace: ExecutionTrace | None = None, ) -> None: keys = _normalize_step_tensor( key_step, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="key_step", ) values = _normalize_step_tensor( value_step, num_key_value_heads=self.num_key_value_heads, head_dim=self.config.head_dim, name="value_step", ) if keys.shape[1] != values.shape[1]: raise ValueError("key_step and value_step token counts must match") token_count = int(keys.shape[1]) resident_bytes_changed = False if self._use_persistent_torch_tail: key_tails = [self._state(layer_id, kv_head_id).persistent_key_tail for kv_head_id in range(self.num_key_value_heads)] value_tails = [self._state(layer_id, kv_head_id).persistent_value_tail for kv_head_id in range(self.num_key_value_heads)] for kv_head_id in range(self.num_key_value_heads): key_tail = key_tails[kv_head_id] value_tail = value_tails[kv_head_id] if key_tail is not None: resident_bytes_changed = ( key_tail._ensure_allocated( token_start=token_index if key_tail.token_count == 0 else key_tail.source_page.header.token_start ) or resident_bytes_changed ) if value_tail is not None: resident_bytes_changed = ( value_tail._ensure_allocated( token_start=token_index if value_tail.token_count == 0 else value_tail.source_page.header.token_start ) or resident_bytes_changed ) self._batch_upload_persistent_tail_rows(key_tails, keys, token_start=token_index, trace=trace) self._batch_upload_persistent_tail_rows(value_tails, values, token_start=token_index, trace=trace) for kv_head_id in range(self.num_key_value_heads): state = self._state(layer_id, kv_head_id) if state.sequence_length != token_index: raise ValueError( f"layer {layer_id} kv_head {kv_head_id} expected token_index {state.sequence_length}, received {token_index}" ) finalized_key_pages, finalized_value_pages = state.tail.append_step_rows( keys[kv_head_id], values[kv_head_id], token_start=token_index, sequence_length=token_index + token_count, ) if finalized_key_pages: state.session.append(finalized_key_pages, finalized_value_pages, trace=trace) state.invalidate_decode_views() if state.persistent_key_tail is not None and state.persistent_value_tail is not None: state.persistent_key_tail.clear() state.persistent_value_tail.clear() state.sequence_length += token_count if resident_bytes_changed: for kv_head_id in range(self.num_key_value_heads): self._refresh_state_resident_accounting(self._state(layer_id, kv_head_id)) self._mark_prepared_chunk_cache_budget_dirty(reason="append_step_tail_alloc") return def append_step_torch( self, layer_id: int, key_step, value_step, token_index: int, *, trace: ExecutionTrace | None = None, ) -> None: try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for append_step_torch") from exc if not torch.is_tensor(key_step) or not torch.is_tensor(value_step): raise TypeError("append_step_torch requires torch.Tensor inputs") if self._torch_device_type is None: raise RuntimeError("append_step_torch is only available for a torch accelerator backend") keys = key_step.detach().to(dtype=torch.float32) values = value_step.detach().to(dtype=torch.float32) if keys.ndim == 4: if int(keys.shape[0]) != 1: raise ValueError("key_step batch dimension must be 1 for the Phase 5 Llama path") keys = keys[0] if values.ndim == 4: if int(values.shape[0]) != 1: raise ValueError("value_step batch dimension must be 1 for the Phase 5 Llama path") values = values[0] if keys.ndim != 3 or values.ndim != 3: raise ValueError("key_step and value_step must have shape [kv_heads, token_count, head_dim]") if int(keys.shape[0]) != self.num_key_value_heads or int(values.shape[0]) != self.num_key_value_heads: raise ValueError(f"append steps must contain {self.num_key_value_heads} KV heads") if int(keys.shape[2]) != self.config.head_dim or int(values.shape[2]) != self.config.head_dim: raise ValueError(f"append steps head_dim must equal {self.config.head_dim}") if tuple(keys.shape) != tuple(values.shape): raise ValueError("key_step and value_step token counts must match") token_count = int(keys.shape[1]) if not self._use_persistent_torch_tail: self.append_step( layer_id, keys.cpu().numpy(), values.cpu().numpy(), token_index, trace=trace, ) return key_tails = [self._state(layer_id, kv_head_id).persistent_key_tail for kv_head_id in range(self.num_key_value_heads)] value_tails = [self._state(layer_id, kv_head_id).persistent_value_tail for kv_head_id in range(self.num_key_value_heads)] resident_bytes_changed = False for kv_head_id in range(self.num_key_value_heads): state = self._state(layer_id, kv_head_id) if state.sequence_length != token_index: raise ValueError( f"layer {layer_id} kv_head {kv_head_id} expected token_index {state.sequence_length}, received {token_index}" ) if key_tails[kv_head_id] is not None: resident_bytes_changed = ( key_tails[kv_head_id]._ensure_allocated( token_start=token_index if key_tails[kv_head_id].token_count == 0 else key_tails[kv_head_id].source_page.header.token_start ) or resident_bytes_changed ) if value_tails[kv_head_id] is not None: resident_bytes_changed = ( value_tails[kv_head_id]._ensure_allocated( token_start=token_index if value_tails[kv_head_id].token_count == 0 else value_tails[kv_head_id].source_page.header.token_start ) or resident_bytes_changed ) self._batch_append_persistent_tail_tensors(key_tails, keys, token_start=token_index) self._batch_append_persistent_tail_tensors(value_tails, values, token_start=token_index) for kv_head_id in range(self.num_key_value_heads): state = self._state(layer_id, kv_head_id) if state.persistent_key_tail is None or state.persistent_value_tail is None: raise RuntimeError("persistent torch tail path requires allocated key/value tails") if state.tail.token_count > 0: state.tail.clear() if state.persistent_key_tail.token_count >= self.config.tokens_per_page: token_start_full = state.persistent_key_tail.source_page.header.token_start dense_keys = state.persistent_key_tail.materialize_rows() dense_values = state.persistent_value_tail.materialize_rows() finalized_key_page = encode_page( dense_keys, self.config, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start_full, mode=None, page_mode=( self._select_page_mode( dense_keys, kind="K", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start_full, sequence_length=int(token_index + token_count), stage="decode", ) ), build_runtime_metadata=self._should_build_execution_runtime_metadata(kind="K"), build_m2_sidecar=( self.config.m2_prefilter_top_k > 0 and (len(state.session.key_pages) + 1) >= int(self.config.m2_prefilter_min_pages) ), ) finalized_value_page = encode_page( dense_values, self.config, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start_full, mode=None, page_mode=( self._select_page_mode( dense_values, kind="V", layer_id=layer_id, kv_head_id=kv_head_id, token_start=token_start_full, sequence_length=int(token_index + token_count), stage="decode", ) ), build_runtime_metadata=False, ) state.session.append([finalized_key_page], [finalized_value_page], trace=trace) state.invalidate_decode_views() state.persistent_key_tail.clear() state.persistent_value_tail.clear() state.sequence_length += token_count if resident_bytes_changed: for kv_head_id in range(self.num_key_value_heads): self._refresh_state_resident_accounting(self._state(layer_id, kv_head_id)) self._mark_prepared_chunk_cache_budget_dirty(reason="append_step_torch_tail_alloc") def _prepared_pages_with_tail( self, layer_id: int, kv_head_id: int, *, trace: ExecutionTrace | None = None, ) -> tuple[list[PageLike], list[PageLike], _PreparedDecodeViewLayout | None]: capture_stage_timings = bool(trace is not None and trace.capture_timings) def _record_layout_build_timing(started_at: float | None) -> None: if started_at is None: return self._record_decode_stage_timing( layer_id=int(layer_id), stage="prepare_layout_build", ms=(perf_counter() - started_at) * 1000.0, ) state = self._state(layer_id, kv_head_id) self._ensure_prepared_static_pages(state, trace=trace) if state.persistent_key_tail is not None and state.persistent_value_tail is not None: prepared_key_tail = state.persistent_key_tail.active_page prepared_value_tail = state.persistent_value_tail.active_page if prepared_key_tail is not None and prepared_value_tail is not None: cached_key_pages = state.decode_key_pages_with_tail cached_value_pages = state.decode_value_pages_with_tail if ( cached_key_pages is not None and cached_value_pages is not None and cached_key_pages and cached_value_pages and cached_key_pages[-1] is prepared_key_tail and cached_value_pages[-1] is prepared_value_tail and len(cached_key_pages) == len(state.session.key_pages) + 1 and len(cached_value_pages) == len(state.session.value_pages) + 1 ): return cached_key_pages, cached_value_pages, state.decode_view_layout key_pages = list(state.session.key_pages) value_pages = list(state.session.value_pages) key_pages.append(prepared_key_tail) value_pages.append(prepared_value_tail) state.decode_key_pages_with_tail = key_pages state.decode_value_pages_with_tail = value_pages layout_started_at = perf_counter() if capture_stage_timings else None state.decode_view_layout = _build_prepared_decode_view_layout(key_pages, value_pages) _record_layout_build_timing(layout_started_at) return key_pages, value_pages, state.decode_view_layout state.invalidate_decode_views() layout_started_at = perf_counter() if capture_stage_timings else None state.decode_view_layout = _build_prepared_decode_view_layout(state.session.key_pages, state.session.value_pages) _record_layout_build_timing(layout_started_at) return state.session.key_pages, state.session.value_pages, state.decode_view_layout temp_pages = state.tail.build_temp_pages() if temp_pages is None: layout_started_at = perf_counter() if capture_stage_timings else None state.decode_view_layout = _build_prepared_decode_view_layout(state.session.key_pages, state.session.value_pages) _record_layout_build_timing(layout_started_at) return state.session.key_pages, state.session.value_pages, state.decode_view_layout temp_key_page, temp_value_page = temp_pages # Temporary live-tail pages are rebuilt on demand and should not go # through the shared prepared-page cache keyed by object id. prepared_temp_key_page = prepare_pages([temp_key_page], backend=self.backend, cache=None, trace=trace)[0] prepared_temp_value_page = prepare_pages([temp_value_page], backend=self.backend, cache=None, trace=trace)[0] key_pages = list(state.session.key_pages) value_pages = list(state.session.value_pages) key_pages.append(prepared_temp_key_page) value_pages.append(prepared_temp_value_page) layout_started_at = perf_counter() if capture_stage_timings else None layout = _build_prepared_decode_view_layout(key_pages, value_pages) _record_layout_build_timing(layout_started_at) return key_pages, value_pages, layout def decode_layer( self, layer_id: int, query_step: np.ndarray, q_head_to_kv_head: Sequence[int] | np.ndarray, *, query_scale: float = 1.0, trace: ExecutionTrace | None = None, ) -> np.ndarray: queries = _normalize_query_step( query_step, num_attention_heads=self.num_attention_heads, head_dim=self.config.head_dim, ) scaled_queries = queries * np.float32(query_scale) grouped_query_heads = self._grouped_query_heads_for_mapping(q_head_to_kv_head) outputs = np.zeros((self.num_attention_heads, self.config.head_dim), dtype=np.float32) for kv_head_id, q_head_ids in enumerate(grouped_query_heads): if not q_head_ids: continue key_pages, value_pages, _ = self._prepared_pages_with_tail(layer_id, kv_head_id, trace=trace) if not key_pages: raise ValueError(f"layer {layer_id} kv_head {kv_head_id} has no cached tokens to decode against") kv_queries = scaled_queries[list(q_head_ids)] key_pages, value_pages = self._m2_prefilter_pages_numpy(kv_queries, key_pages, value_pages) self._sync_prepared_chunk_cache_budget() _, _, kv_outputs = decode_multi_query_step( kv_queries, key_pages, value_pages, backend=self.backend, trace=trace, ) outputs[list(q_head_ids)] = kv_outputs return outputs def analyze_execution_shortlist_layer( self, layer_id: int, query_step: np.ndarray, q_head_to_kv_head: Sequence[int] | np.ndarray, *, query_scale: float = 1.0, prefer_grouped_batching: bool = True, trace: ExecutionTrace | None = None, ) -> dict[str, object]: queries = _normalize_query_step( query_step, num_attention_heads=self.num_attention_heads, head_dim=self.config.head_dim, ) scaled_queries = queries * np.float32(query_scale) grouped_query_heads = self._grouped_query_heads_for_mapping(q_head_to_kv_head) layer_prefer_grouped_batching = ( bool(prefer_grouped_batching) and not self.config.execution_grouped_batching_disabled_for_layer(layer_id=layer_id) ) shortlist_enabled = self.config.execution_shortlist_enabled() group_entries: list[dict[str, object]] = [] raw_selected_indices_by_group: list[list[int] | None] = [] full_selected_indices_by_group: list[list[int]] = [] key_pages_by_group: list[list[PageLike]] = [] window_index_sets_by_group: list[set[int]] = [] representative_queries: list[np.ndarray] = [] shortlist_trace_records_by_group: list[dict[str, object] | None] = [] for kv_head_id, q_head_ids in enumerate(grouped_query_heads): if not q_head_ids: continue key_pages, value_pages, _ = self._prepared_pages_with_tail(layer_id, kv_head_id, trace=trace) if not key_pages: raise ValueError(f"layer {layer_id} kv_head {kv_head_id} has no cached tokens to decode against") state = self._state(layer_id, kv_head_id) kv_queries = scaled_queries[list(q_head_ids)] key_pages, _ = self._m2_prefilter_pages_numpy(kv_queries, key_pages, value_pages) representative_query = kv_queries.mean(axis=0).astype(np.float32, copy=False) raw_selected_indices = None page_max_context_length = ( max(_page_header(page).token_start + _page_header(page).token_count for page in key_pages) if key_pages else 0 ) context_length = int(state.sequence_length) if int(state.sequence_length) > 0 else int(page_max_context_length) layer_recent_window = int( self.config.resolve_execution_recent_window_for_context( layer_id=layer_id, context_length=context_length, ) ) if shortlist_enabled: trace_record_count = len(self._execution_shortlist_trace_records) raw_selected_indices = self._execution_shortlist_page_indices( key_pages, layer_id=layer_id, kv_head_id=int(kv_head_id), query_slice=representative_query, context_length_override=int(context_length) if int(context_length) > 0 else None, trace=trace, ) shortlist_trace_record = ( dict(self._execution_shortlist_trace_records[-1]) if len(self._execution_shortlist_trace_records) > trace_record_count else None ) else: shortlist_trace_record = None selected_indices = ( list(range(len(key_pages))) if raw_selected_indices is None else [int(index) for index in raw_selected_indices] ) window_index_set = set( select_window_page_indices( key_pages, recent_window_tokens=layer_recent_window if layer_recent_window > 0 else None, sink_window_tokens=int(self.config.execution_sink_window), ) ) group_entries.append( { "kv_head_id": int(kv_head_id), "query_head_ids": list(q_head_ids), "layer_recent_window": int(layer_recent_window), "context_length_page_max": int(page_max_context_length), "context_length_effective": int(context_length), "context_length_override_applied": bool(int(state.sequence_length) > 0), } ) raw_selected_indices_by_group.append(raw_selected_indices) full_selected_indices_by_group.append(selected_indices) key_pages_by_group.append(key_pages) window_index_sets_by_group.append(window_index_set) representative_queries.append(representative_query) shortlist_trace_records_by_group.append(shortlist_trace_record) shortlist_attempted = any(indices is not None for indices in raw_selected_indices_by_group) union_rescue_records: list[dict[str, object]] = [] if shortlist_attempted and len(group_entries) > 1 and layer_prefer_grouped_batching: adjusted_selected_indices_by_group, union_rescue_records = self._apply_execution_exact_promote_union_rescue( layer_id=layer_id, selected_indices_by_group=full_selected_indices_by_group, key_pages_by_group=key_pages_by_group, representative_queries=representative_queries, shortlist_traces_by_group=shortlist_trace_records_by_group, trace=trace, ) full_selected_indices_by_group = [ list(range(len(key_pages))) if indices is None else [int(index) for index in indices] for key_pages, indices in zip(key_pages_by_group, adjusted_selected_indices_by_group, strict=True) ] raw_selected_indices_by_group = [ None if indices is None else [int(index) for index in indices] for indices in adjusted_selected_indices_by_group ] self._execution_shortlist_trace_records.extend(union_rescue_records) union_indices = sorted( { index for indices in raw_selected_indices_by_group if indices is not None for index in indices } ) union_active = bool(union_indices) and len(group_entries) > 1 and layer_prefer_grouped_batching layer_exact_top_budget_total = 0 layer_exact_top_overlap_total = 0 layer_union_added_pages_total = 0 layer_missed_age_buckets = {"recent": 0, "middle": 0, "old": 0} layer_recalls: list[float] = [] layer_first_missed_ranks: list[int] = [] for group_index, entry in enumerate(group_entries): key_pages = key_pages_by_group[group_index] selected_indices = full_selected_indices_by_group[group_index] union_rescue_record = next( ( record for record in union_rescue_records if int(record.get("group_index", -1)) == int(group_index) ), None, ) if union_active: final_indices = list(union_indices) union_added_indices = sorted(set(union_indices) - set(selected_indices)) else: final_indices = list(selected_indices) union_added_indices = [] window_index_set = window_index_sets_by_group[group_index] representative_query = representative_queries[group_index] old_candidate_indices = [index for index in range(len(key_pages)) if index not in window_index_set] selected_old_indices = [index for index in final_indices if index not in window_index_set] exact_top_budget = min(len(selected_old_indices), len(old_candidate_indices)) approx_top_budget = min( int( self.config.resolve_execution_relevance_top_k_for_context( layer_id=layer_id, context_length=context_length, ) ), len(old_candidate_indices), ) exact_rank_by_index: dict[int, int] = {} exact_top_indices: list[int] = [] approx_rank_by_index: dict[int, int] = {} approx_top_indices: list[int] = [] approx_scores_by_index: dict[int, float] = {} exact_scores_by_index: dict[int, float] = {} approx_boundary_margin = None approx_boundary_margin_normalized = None secondary_rank_by_index: dict[int, int] = {} secondary_top_indices: list[int] = [] secondary_scores_by_index: dict[int, float] = {} recent_neighbor_anchor_pages = 0 recent_neighbor_recent_old_pages = 0 if old_candidate_indices: for index in old_candidate_indices: approx_score = _score_page_relevance_for_mode( representative_query, key_pages[index], relevance_mode=self.config.execution_relevance_mode, ) if approx_score is None: raise ValueError( f"missing {self.config.execution_relevance_mode} relevance sidecars for layer {layer_id}" ) approx_scores_by_index[int(index)] = float(approx_score) approx_ranked_pairs = sorted( ((score, int(index)) for index, score in approx_scores_by_index.items()), key=lambda item: item[0], reverse=True, ) if len(approx_ranked_pairs) > approx_top_budget and approx_top_budget > 0: approx_boundary_margin = float( approx_ranked_pairs[approx_top_budget - 1][0] - approx_ranked_pairs[approx_top_budget][0] ) approx_std = max( float(np.std(np.asarray(list(approx_scores_by_index.values()), dtype=np.float32))), 1e-6, ) approx_boundary_margin_normalized = float(approx_boundary_margin / approx_std) approx_rank_by_index = {int(index): rank for rank, (_, index) in enumerate(approx_ranked_pairs, start=1)} approx_top_indices = [int(index) for _, index in approx_ranked_pairs[:approx_top_budget]] if self._execution_secondary_relevance_enabled(layer_id=layer_id): secondary_scores_missing = False for index in old_candidate_indices: secondary_score = _score_page_relevance_for_mode( representative_query, key_pages[index], relevance_mode=self.config.execution_secondary_relevance_mode, ) if secondary_score is None: secondary_scores_missing = True secondary_scores_by_index = {} secondary_rank_by_index = {} secondary_top_indices = [] break secondary_scores_by_index[int(index)] = float(secondary_score) if not secondary_scores_missing: secondary_ranked_pairs = sorted( ((score, int(index)) for index, score in secondary_scores_by_index.items()), key=lambda item: item[0], reverse=True, ) secondary_rank_by_index = { int(index): rank for rank, (_, index) in enumerate(secondary_ranked_pairs, start=1) } secondary_top_indices = [int(index) for _, index in secondary_ranked_pairs[:approx_top_budget]] if self._execution_recent_neighbor_rescue_enabled(layer_id=layer_id) and layer_recent_window > 0: recent_start = int(context_length) - int(layer_recent_window) recent_neighbor_anchor_pages = sum( 1 for index in approx_top_indices if int(_page_header(key_pages[index]).token_start + _page_header(key_pages[index]).token_count) <= int(self.config.execution_recent_neighbor_rescue_anchor_window) ) recent_neighbor_recent_old_pages = sum( 1 for index in approx_top_indices if ( int(_page_header(key_pages[index]).token_start + _page_header(key_pages[index]).token_count) <= int(recent_start) and int(_page_header(key_pages[index]).token_start + _page_header(key_pages[index]).token_count) > int(recent_start - layer_recent_window) ) ) candidate_logits = score_pages( representative_query, [key_pages[index] for index in old_candidate_indices], backend=self.backend, trace=trace, ) ranked_pairs = sorted( ( ( float(np.max(np.asarray(logits, dtype=np.float32))), int(index), ) for index, logits in zip(old_candidate_indices, candidate_logits, strict=True) ), key=lambda item: item[0], reverse=True, ) exact_scores_by_index = { int(index): float(score) for score, index in ranked_pairs } exact_rank_by_index = {int(index): rank for rank, (_, index) in enumerate(ranked_pairs, start=1)} exact_top_indices = [int(index) for _, index in ranked_pairs[:exact_top_budget]] exact_top_index_set = set(exact_top_indices) selected_old_index_set = set(selected_old_indices) exact_top_overlap = len(selected_old_index_set & exact_top_index_set) approx_top_index_set = set(approx_top_indices) approx_exact_top_overlap = len(approx_top_index_set & exact_top_index_set) secondary_top_index_set = set(secondary_top_indices) secondary_primary_top_overlap = len(secondary_top_index_set & approx_top_index_set) secondary_exact_top_overlap = len(secondary_top_index_set & exact_top_index_set) if exact_top_budget > 0: exact_top_recall = float(exact_top_overlap) / float(exact_top_budget) layer_recalls.append(exact_top_recall) else: exact_top_recall = 1.0 if exact_top_budget > 0: approx_exact_top_recall = float(approx_exact_top_overlap) / float(exact_top_budget) secondary_exact_top_recall = float(secondary_exact_top_overlap) / float(exact_top_budget) else: approx_exact_top_recall = 1.0 secondary_exact_top_recall = 1.0 secondary_primary_top_recall = ( float(secondary_primary_top_overlap) / float(max(min(len(secondary_top_indices), len(approx_top_indices)), 1)) if secondary_top_indices and approx_top_indices else (1.0 if not secondary_top_indices else 0.0) ) secondary_triggered = bool( secondary_top_indices and secondary_primary_top_recall < float(self.config.execution_secondary_relevance_min_overlap) ) exact_promote_candidate_expansion_enabled, exact_promote_candidate_expansion_reason = ( self._execution_exact_promote_policy_status(layer_id=layer_id, context_length=context_length) ) exact_promote_enabled, exact_promote_disable_reason = self._execution_exact_promote_status( layer_id=layer_id, context_length=context_length, boundary_margin_normalized=approx_boundary_margin_normalized, ) recent_neighbor_rescue_triggered = bool( self._execution_recent_neighbor_rescue_enabled(layer_id=layer_id) and recent_neighbor_anchor_pages >= int(self.config.execution_recent_neighbor_rescue_min_anchor_pages) and recent_neighbor_recent_old_pages > 0 ) missed_exact_indices = [index for index in exact_top_indices if index not in selected_old_index_set] first_missed_exact_rank = ( int(exact_rank_by_index[missed_exact_indices[0]]) if missed_exact_indices else None ) if first_missed_exact_rank is not None: layer_first_missed_ranks.append(first_missed_exact_rank) scorer_missed_exact_indices = [index for index in exact_top_indices if index not in approx_top_index_set] first_scorer_missed_exact_rank = ( int(exact_rank_by_index[scorer_missed_exact_indices[0]]) if scorer_missed_exact_indices else None ) missed_exact_age_buckets = {"recent": 0, "middle": 0, "old": 0} for index in missed_exact_indices: age_bucket = _page_age_bucket(key_pages[index], context_length=int(context_length)) missed_exact_age_buckets[age_bucket] += 1 layer_missed_age_buckets[age_bucket] += 1 scorer_missed_exact_age_buckets = {"recent": 0, "middle": 0, "old": 0} for index in scorer_missed_exact_indices: age_bucket = _page_age_bucket(key_pages[index], context_length=int(context_length)) scorer_missed_exact_age_buckets[age_bucket] += 1 exact_top1_approx_rank = None approx_top1_exact_rank = None secondary_top1_exact_rank = None primary_top1_secondary_rank = None score_rank_correlation = None score_value_correlation = None mean_abs_rank_error = None if exact_top_indices: exact_top1_approx_rank = approx_rank_by_index.get(int(exact_top_indices[0])) if approx_top_indices: approx_top1_exact_rank = exact_rank_by_index.get(int(approx_top_indices[0])) primary_top1_secondary_rank = secondary_rank_by_index.get(int(approx_top_indices[0])) if secondary_top_indices: secondary_top1_exact_rank = exact_rank_by_index.get(int(secondary_top_indices[0])) if exact_rank_by_index and approx_rank_by_index: shared_indices = [index for index in old_candidate_indices if index in exact_rank_by_index and index in approx_rank_by_index] if shared_indices: exact_ranks = [float(exact_rank_by_index[index]) for index in shared_indices] approx_ranks = [float(approx_rank_by_index[index]) for index in shared_indices] score_rank_correlation = _rank_correlation(approx_ranks, exact_ranks) mean_abs_rank_error = float( np.mean( np.abs( np.asarray(approx_ranks, dtype=np.float32) - np.asarray(exact_ranks, dtype=np.float32) ) ) ) exact_scores = [float(exact_scores_by_index[index]) for index in shared_indices] approx_scores = [float(approx_scores_by_index[index]) for index in shared_indices] score_value_correlation = _rank_correlation(approx_scores, exact_scores) union_added_old_indices = [index for index in union_added_indices if index not in window_index_set] union_added_ranks = [ int(exact_rank_by_index[index]) for index in union_added_old_indices if index in exact_rank_by_index ] union_added_mean_exact_rank = ( float(sum(union_added_ranks) / len(union_added_ranks)) if union_added_ranks else None ) layer_exact_top_budget_total += int(exact_top_budget) layer_exact_top_overlap_total += int(exact_top_overlap) layer_union_added_pages_total += int(len(union_added_indices)) entry.update( { "context_length": int(context_length), "total_pages": int(len(key_pages)), "window_pages": int(len(window_index_set)), "old_candidate_pages": int(len(old_candidate_indices)), "selected_pages": int(len(final_indices)), "selected_old_pages": int(len(selected_old_indices)), "exact_top_budget": int(exact_top_budget), "exact_top_overlap": int(exact_top_overlap), "exact_top_recall": float(exact_top_recall), "approx_top_budget": int(approx_top_budget), "approx_exact_top_overlap": int(approx_exact_top_overlap), "approx_exact_top_recall": float(approx_exact_top_recall), "secondary_relevance_mode": ( self.config.execution_secondary_relevance_mode if self._execution_secondary_relevance_enabled(layer_id=layer_id) else None ), "secondary_primary_top_overlap": int(secondary_primary_top_overlap), "secondary_primary_top_recall": float(secondary_primary_top_recall), "secondary_exact_top_overlap": int(secondary_exact_top_overlap), "secondary_exact_top_recall": float(secondary_exact_top_recall), "secondary_triggered": bool(secondary_triggered), "exact_promote_candidate_expansion_enabled": bool(exact_promote_candidate_expansion_enabled), "exact_promote_candidate_expansion_disable_reason": exact_promote_candidate_expansion_reason, "exact_promote_enabled": bool(exact_promote_enabled), "exact_promote_disable_reason": exact_promote_disable_reason, "union_exact_promote_rescue_applied": bool( union_rescue_record is not None and bool(union_rescue_record.get("applied", False)) ), "union_exact_promote_rescue_disable_reason": ( None if union_rescue_record is None else union_rescue_record.get("disable_reason") ), "union_exact_promote_rescue_selected_novel_count": int( 0 if union_rescue_record is None else union_rescue_record.get("selected_novel_count", 0) ), "union_exact_promote_rescue_selected_novel_page_ranges": ( [] if union_rescue_record is None else list(union_rescue_record.get("selected_novel_page_ranges", [])) ), "recent_neighbor_anchor_pages": int(recent_neighbor_anchor_pages), "recent_neighbor_recent_old_pages": int(recent_neighbor_recent_old_pages), "recent_neighbor_rescue_triggered": bool(recent_neighbor_rescue_triggered), "approx_boundary_margin": approx_boundary_margin, "approx_boundary_margin_normalized": approx_boundary_margin_normalized, "exact_top1_approx_rank": exact_top1_approx_rank, "approx_top1_exact_rank": approx_top1_exact_rank, "secondary_top1_exact_rank": secondary_top1_exact_rank, "primary_top1_secondary_rank": primary_top1_secondary_rank, "first_scorer_missed_exact_rank": first_scorer_missed_exact_rank, "score_rank_correlation": score_rank_correlation, "score_value_correlation": score_value_correlation, "mean_abs_rank_error": mean_abs_rank_error, "first_missed_exact_rank": first_missed_exact_rank, "union_active": bool(union_active), "union_added_pages": int(len(union_added_indices)), "union_added_exact_top_hits": int(sum(1 for index in union_added_old_indices if index in exact_top_index_set)), "union_added_mean_exact_rank": union_added_mean_exact_rank, "selected_old_page_ranges": [_page_token_range(key_pages[index]) for index in selected_old_indices], "top_approx_page_ranges": [_page_token_range(key_pages[index]) for index in approx_top_indices], "top_secondary_page_ranges": [_page_token_range(key_pages[index]) for index in secondary_top_indices], "top_exact_page_ranges": [_page_token_range(key_pages[index]) for index in exact_top_indices], "missed_exact_page_ranges": [_page_token_range(key_pages[index]) for index in missed_exact_indices], "missed_exact_age_buckets": missed_exact_age_buckets, "scorer_missed_exact_page_ranges": [_page_token_range(key_pages[index]) for index in scorer_missed_exact_indices], "scorer_missed_exact_age_buckets": scorer_missed_exact_age_buckets, } ) first_missed_exact_rank_min = min(layer_first_missed_ranks) if layer_first_missed_ranks else None return { "layer_id": int(layer_id), "shortlist_enabled": bool(shortlist_enabled), "shortlist_attempted": bool(shortlist_attempted), "grouped_batching_enabled": bool(layer_prefer_grouped_batching), "union_active": bool(union_active), "group_count": int(len(group_entries)), "exact_top_budget_total": int(layer_exact_top_budget_total), "exact_top_overlap_total": int(layer_exact_top_overlap_total), "exact_top_recall_mean": ( float(sum(layer_recalls) / len(layer_recalls)) if layer_recalls else 1.0 ), "exact_top_recall_min": float(min(layer_recalls)) if layer_recalls else 1.0, "first_missed_exact_rank_min": first_missed_exact_rank_min, "union_added_pages_total": int(layer_union_added_pages_total), "missed_exact_age_buckets": dict(layer_missed_age_buckets), "groups": group_entries, } def decode_layer_torch( self, layer_id: int, query_step, q_head_to_kv_head: Sequence[int] | np.ndarray, *, query_scale: float = 1.0, prefer_grouped_batching: bool = True, trace: ExecutionTrace | None = None, ): try: import torch except ImportError as exc: # pragma: no cover raise RuntimeError("torch is required for decode_layer_torch") from exc if not torch.is_tensor(query_step): raise TypeError("decode_layer_torch requires a torch.Tensor query_step") if self._torch_device_type is None: raise RuntimeError("decode_layer_torch is only available for a torch accelerator backend") if query_step.ndim == 4: if tuple(query_step.shape[:1] + query_step.shape[2:3]) != (1, 1): raise ValueError("query_step must have shape [q_heads, head_dim] or [1, q_heads, 1, head_dim]") queries = query_step[0, :, 0, :] elif query_step.ndim == 2: queries = query_step else: raise ValueError("query_step must have shape [q_heads, head_dim]") if int(queries.shape[0]) != self.num_attention_heads: raise ValueError(f"query_step must contain {self.num_attention_heads} query heads") if int(queries.shape[1]) != self.config.head_dim: raise ValueError(f"query_step head_dim must equal {self.config.head_dim}") scaled_queries = queries.to(dtype=torch.float32) * float(query_scale) grouped_query_heads = self._grouped_query_heads_for_mapping(q_head_to_kv_head) layer_prefer_grouped_batching = ( bool(prefer_grouped_batching) and not self.config.execution_grouped_batching_disabled_for_layer(layer_id=layer_id) ) capture_stage_timings = bool(trace is not None and trace.capture_timings) def _stage_start() -> float | None: return perf_counter() if capture_stage_timings else None def _stage_finish(stage: str, started_at: float | None) -> None: if started_at is None: return self._record_decode_stage_timing( layer_id=int(layer_id), stage=stage, ms=(perf_counter() - started_at) * 1000.0, ) outputs = torch.zeros( (self.num_attention_heads, self.config.head_dim), dtype=torch.float32, device=scaled_queries.device, ) active_q_head_ids: list[tuple[int, ...]] = [] active_queries: list[Any] = [] active_key_pages: list[Sequence[PageLike]] = [] active_value_pages: list[Sequence[PageLike]] = [] active_layouts: list[_PreparedDecodeViewLayout | None] = [] active_context_lengths: list[int] = [] active_representative_queries: list[np.ndarray] = [] original_key_pages_by_group: list[Sequence[PageLike]] = [] original_value_pages_by_group: list[Sequence[PageLike]] = [] original_layouts: list[_PreparedDecodeViewLayout | None] = [] shortlist_selected_indices_by_group: list[list[int] | None] = [] shortlist_trace_records_by_group: list[dict[str, object] | None] = [] for kv_head_id, q_head_ids in enumerate(grouped_query_heads): if not q_head_ids: continue prepare_started_at = _stage_start() key_pages, value_pages, decode_layout = self._prepared_pages_with_tail(layer_id, kv_head_id, trace=trace) _stage_finish("prepare_pages_with_tail", prepare_started_at) if not key_pages: raise ValueError(f"layer {layer_id} kv_head {kv_head_id} has no cached tokens to decode against") state = self._state(layer_id, kv_head_id) kv_queries = scaled_queries[list(q_head_ids)] m2_prefilter_started_at = _stage_start() key_pages, value_pages = self._m2_prefilter_pages_torch(kv_queries, key_pages, value_pages) _stage_finish("m2_prefilter", m2_prefilter_started_at) selected_indices = None shortlist_trace_record = None representative_query = kv_queries.mean(dim=0).detach().cpu().numpy().astype(np.float32, copy=False) if self.config.execution_shortlist_enabled(): query_export_started_at = _stage_start() _stage_finish("query_export", query_export_started_at) trace_record_count = len(self._execution_shortlist_trace_records) shortlist_started_at = _stage_start() selected_indices = self._execution_shortlist_page_indices( key_pages, layer_id=layer_id, kv_head_id=int(kv_head_id), query_slice=representative_query, context_length_override=int(state.sequence_length) if int(state.sequence_length) > 0 else None, trace=trace, ) if len(self._execution_shortlist_trace_records) > trace_record_count: shortlist_trace_record = dict(self._execution_shortlist_trace_records[-1]) _stage_finish("shortlist_selection", shortlist_started_at) active_q_head_ids.append(q_head_ids) active_queries.append(kv_queries) original_key_pages_by_group.append(key_pages) original_value_pages_by_group.append(value_pages) original_layouts.append(decode_layout) active_key_pages.append(key_pages) active_value_pages.append(value_pages) active_layouts.append(decode_layout) active_context_lengths.append(int(state.sequence_length)) active_representative_queries.append(representative_query) shortlist_selected_indices_by_group.append(selected_indices) shortlist_trace_records_by_group.append(shortlist_trace_record) shortlist_group_union_applied = False shortlist_attempted = any(indices is not None for indices in shortlist_selected_indices_by_group) shortlist_applied = False shortlist_selected_pages = 0 shortlist_total_pages = 0 if shortlist_attempted: total_pages_per_group = [len(pages) for pages in original_key_pages_by_group] shortlist_total_pages = int(sum(total_pages_per_group)) if len(active_queries) > 1 and layer_prefer_grouped_batching: query_export_started_at = _stage_start() representative_queries = [ query.mean(dim=0).detach().cpu().numpy().astype(np.float32, copy=False) for query in active_queries ] _stage_finish("query_export", query_export_started_at) union_rescue_started_at = _stage_start() shortlist_selected_indices_by_group, union_rescue_records = self._apply_execution_exact_promote_union_rescue( layer_id=layer_id, selected_indices_by_group=shortlist_selected_indices_by_group, key_pages_by_group=original_key_pages_by_group, representative_queries=representative_queries, shortlist_traces_by_group=shortlist_trace_records_by_group, trace=trace, ) self._execution_shortlist_trace_records.extend(union_rescue_records) _stage_finish("shortlist_union_rescue", union_rescue_started_at) shortlist_materialization_started_at = _stage_start() union_indices = sorted( { index for indices in shortlist_selected_indices_by_group if indices is not None for index in indices } ) use_union_indices = bool(union_indices) and len(active_queries) > 1 and layer_prefer_grouped_batching if use_union_indices: shortlist_group_union_applied = True shortlist_selected_pages = int(len(union_indices) * len(active_queries)) if len(union_indices) < total_pages_per_group[0]: shortlist_applied = True active_key_pages = [[pages[index] for index in union_indices] for pages in original_key_pages_by_group] active_value_pages = [[pages[index] for index in union_indices] for pages in original_value_pages_by_group] active_layouts = [None] * len(active_key_pages) else: shortlisted_key_pages: list[Sequence[PageLike]] = [] shortlisted_value_pages: list[Sequence[PageLike]] = [] shortlisted_layouts: list[_PreparedDecodeViewLayout | None] = [] for key_pages, value_pages, decode_layout, selected_indices in zip( original_key_pages_by_group, original_value_pages_by_group, original_layouts, shortlist_selected_indices_by_group, strict=True, ): if selected_indices is None: shortlisted_key_pages.append(key_pages) shortlisted_value_pages.append(value_pages) shortlisted_layouts.append(decode_layout) shortlist_selected_pages += int(len(key_pages)) continue shortlist_selected_pages += int(len(selected_indices)) if len(selected_indices) < len(key_pages): shortlist_applied = True shortlisted_key_pages.append([key_pages[index] for index in selected_indices]) shortlisted_value_pages.append([value_pages[index] for index in selected_indices]) shortlisted_layouts.append(None) else: shortlisted_key_pages.append(key_pages) shortlisted_value_pages.append(value_pages) shortlisted_layouts.append(decode_layout) active_key_pages = shortlisted_key_pages active_value_pages = shortlisted_value_pages active_layouts = shortlisted_layouts _stage_finish("shortlist_materialization", shortlist_materialization_started_at) value_escape_applied = False if self.config.execution_value_escape_enabled_for_layer(layer_id=layer_id): active_value_pages, value_escape_applied = self._apply_execution_value_escape( layer_id=layer_id, key_pages_by_group=active_key_pages, value_pages_by_group=active_value_pages, context_lengths_by_group=active_context_lengths, representative_queries_by_group=active_representative_queries, trace=trace, ) if value_escape_applied: active_layouts = [None] * len(active_layouts) chunk_budget_sync_started_at = _stage_start() self._sync_prepared_chunk_cache_budget( freeze_during_decode=bool(self.config.execution_freeze_chunk_budget_during_decode) ) _stage_finish("chunk_budget_sync", chunk_budget_sync_started_at) grouping_validation_started_at = _stage_start() if shortlist_attempted and shortlist_applied and layer_prefer_grouped_batching and len(active_queries) > 1: shortlist_grouping_rejection_reason = _grouped_pages_batch_rejection_reason( active_key_pages, active_value_pages, active_queries, ) shortlisted_can_batch = shortlist_grouping_rejection_reason is None if not shortlisted_can_batch: self._record_execution_shortlist( layer_id=layer_id, total_pages=shortlist_total_pages, selected_pages=shortlist_selected_pages, applied=False, group_union_applied=shortlist_group_union_applied, grouping_rejected=True, grouping_rejection_reason=shortlist_grouping_rejection_reason, ) active_key_pages = list(original_key_pages_by_group) active_value_pages = list(original_value_pages_by_group) active_layouts = list(original_layouts) shortlist_applied = False else: self._record_execution_shortlist( layer_id=layer_id, total_pages=shortlist_total_pages, selected_pages=shortlist_selected_pages, applied=True, group_union_applied=shortlist_group_union_applied, ) elif shortlist_attempted: self._record_execution_shortlist( layer_id=layer_id, total_pages=shortlist_total_pages, selected_pages=shortlist_selected_pages, applied=shortlist_applied, group_union_applied=shortlist_group_union_applied, ) grouped_layout_rejection_reason = None grouped_page_rejection_reason = None if layer_prefer_grouped_batching: grouped_layout_rejection_reason = _grouped_layout_batch_rejection_reason(active_layouts, active_queries) grouped_page_rejection_reason = _grouped_pages_batch_rejection_reason( active_key_pages, active_value_pages, active_queries, ) cached_group_layout = layer_prefer_grouped_batching and grouped_layout_rejection_reason is None grouped_path_ready = cached_group_layout or (layer_prefer_grouped_batching and grouped_page_rejection_reason is None) _stage_finish("grouping_validation", grouping_validation_started_at) if grouped_path_ready: self._record_decode_path(layer_id, "grouped_batched") key_chunk_lengths = active_layouts[0].key_chunk_lengths if cached_group_layout and active_layouts[0] is not None else None value_chunk_lengths = active_layouts[0].value_chunk_lengths if cached_group_layout and active_layouts[0] is not None else None backend_started_at = _stage_start() backend_trace_before = _backend_trace_ms_total(trace) if capture_stage_timings else 0.0 _, _, grouped_outputs = decode_grouped_multiquery_step_prepared_torch_tensor( active_queries, active_key_pages, active_value_pages, key_chunk_lengths=key_chunk_lengths, value_chunk_lengths=value_chunk_lengths, compact_grouped_chunk=bool(self.config.execution_grouped_decode_compact), compact_grouped_mix_chunk=bool(self.config.execution_grouped_mix_compact), disable_packed_grouped_cuda_mix=bool(self.config.execution_grouped_mix_disable_packed_cuda), trace=trace, ) backend_call_ms = 0.0 if backend_started_at is not None: backend_call_ms = (perf_counter() - backend_started_at) * 1000.0 self._record_decode_stage_timing( layer_id=int(layer_id), stage="backend_call_wall", ms=backend_call_ms, ) backend_trace_after = _backend_trace_ms_total(trace) self._record_decode_stage_timing( layer_id=int(layer_id), stage="backend_call_non_backend", ms=backend_call_ms - float(backend_trace_after - backend_trace_before), ) for q_head_ids, kv_outputs in zip(active_q_head_ids, grouped_outputs, strict=True): outputs[list(q_head_ids)] = kv_outputs return outputs if layer_prefer_grouped_batching: grouped_batch_rejection_reason = grouped_page_rejection_reason if grouped_batch_rejection_reason is None and grouped_layout_rejection_reason is not None: grouped_batch_rejection_reason = f"layout_{grouped_layout_rejection_reason}" if grouped_batch_rejection_reason is None and len(active_queries) <= 1: grouped_batch_rejection_reason = "single_query_group" if grouped_batch_rejection_reason is None: grouped_batch_rejection_reason = "unknown" self._record_decode_grouped_batch_rejection( layer_id=int(layer_id), reason=grouped_batch_rejection_reason, ) self._record_decode_path(layer_id, "per_kv_fallback") for q_head_ids, kv_queries, key_pages, value_pages in zip( active_q_head_ids, active_queries, active_key_pages, active_value_pages, strict=True, ): backend_started_at = _stage_start() backend_trace_before = _backend_trace_ms_total(trace) if capture_stage_timings else 0.0 _, _, kv_outputs = decode_multi_query_step_torch_tensor( kv_queries, key_pages, value_pages, device_type=self._torch_device_type, trace=trace, ) if backend_started_at is not None: backend_call_ms = (perf_counter() - backend_started_at) * 1000.0 self._record_decode_stage_timing( layer_id=int(layer_id), stage="backend_call_wall", ms=backend_call_ms, ) backend_trace_after = _backend_trace_ms_total(trace) self._record_decode_stage_timing( layer_id=int(layer_id), stage="backend_call_non_backend", ms=backend_call_ms - float(backend_trace_after - backend_trace_before), ) outputs[list(q_head_ids)] = kv_outputs return outputs