Spaces:
Paused
Paused
| from __future__ import annotations | |
| from dataclasses import dataclass, field | |
| from time import perf_counter | |
| from typing import Callable, Literal, Sequence | |
| import numpy as np | |
| from .attention_runtime import ( | |
| BackendName, | |
| decode_step_with_page_logits, | |
| mix_page, | |
| prepare_pages, | |
| score_page, | |
| score_pages, | |
| ) | |
| from .modes.m0_affine import dequantize_group | |
| from .modes.m1_lut import dequantize_group_lut | |
| from .modes.m2_key_sketch import reconstruct_group_m2 | |
| from .modes.m4_key_project import reconstruct_group_m4 | |
| from .modes.m3_escape import decode_escape_payload | |
| from .modes.turbo3 import dequantize_group_turbo3 | |
| from .page_cache import PreparedPageCache | |
| from .page_format import load_group_words | |
| from .packing import unpack_bits | |
| from .tracing import ExecutionTrace | |
| from .types import EncodedPage | |
| from .backends import PreparedPageTorch | |
# Any page object accepted by the helpers below: a raw encoded page, or a
# backend-prepared page that wraps the encoded page as ``source_page``.
PageLike = EncodedPage | PreparedPageTorch
# How page relevance is scored: "sketch" takes the max dot product over the
# page's sketch rows; "envelope" bounds the dot product using per-dimension
# page minima/maxima.
RelevanceMode = Literal["sketch", "envelope"]
def _decode_page_dense(page: PageLike) -> np.ndarray:
    """Decode a page's compressed payload into a dense float32 matrix.

    The decode path is chosen by ``header.mode_default``:

    * ``"M3"``: the escape payload is decoded by ``decode_escape_payload``.
    * ``"M2"``: each group is rebuilt from ``m2_sketch`` and ``m2_basis``
      (``m2_mean`` is optional).
    * ``"M4"``: each group is rebuilt from ``m2_sketch`` and ``m2_mean``
      (``m2_basis`` is optional) using ``header.project_basis``.
    * ``"M1"``: bit-packed codes are dequantized through per-group codebooks.
    * ``"T3"``: bit-packed codes are dequantized with a correction column per
      group and shared centroids (``codebooks``).
    * any other mode: decoded as M0, using ``scales`` and an optional
      ``bias`` column per group.

    Args:
        page: An encoded page, or a ``PreparedPageTorch`` whose
            ``source_page`` is decoded instead.

    Returns:
        For M2/M4/M1/T3/M0, a float32 array with ``header.token_count`` rows
        and ``header.head_dim`` columns; the group padding up to
        ``padded_head_dim`` is sliced off. For M3, whatever
        ``decode_escape_payload`` returns, cast to float32. Presumably this is
        the same shape; TODO confirm.

    Raises:
        ValueError: If the payload or metadata required by the page's mode
            is missing.
    """
    # Prepared pages wrap the encoded page; all decoding uses the encoded form.
    source_page = page.source_page if isinstance(page, PreparedPageTorch) else page
    header = source_page.header
    if header.mode_default == "M3":
        # Escape pages use their own payload format; there is no per-group loop.
        if source_page.escape_payload is None:
            raise ValueError("escape payload is missing")
        return np.asarray(
            decode_escape_payload(
                source_page.escape_payload,
                head_dim=header.head_dim,
                scales=source_page.escape_scales,
            ),
            dtype=np.float32,
        )
    if header.mode_default == "M2":
        if source_page.m2_sketch is None or source_page.m2_basis is None:
            raise ValueError("M2 page is missing sketch payload")
        # Decode at the group-padded width, then trim to head_dim on return.
        dense = np.zeros((header.token_count, header.padded_head_dim), dtype=np.float32)
        for group_index in range(header.num_groups):
            start = group_index * header.group_size
            end = start + header.group_size
            dense[:, start:end] = reconstruct_group_m2(
                source_page.m2_sketch[:, group_index, :],
                basis=source_page.m2_basis[group_index],
                mean=None if source_page.m2_mean is None else source_page.m2_mean[group_index],
            )
        return dense[:, : header.head_dim]
    if header.mode_default == "M4":
        # M4 reuses the m2_* fields: the mean is mandatory here and the basis optional.
        if source_page.m2_sketch is None or source_page.m2_mean is None:
            raise ValueError("M4 page is missing projected payload")
        dense = np.zeros((header.token_count, header.padded_head_dim), dtype=np.float32)
        for group_index in range(header.num_groups):
            start = group_index * header.group_size
            end = start + header.group_size
            dense[:, start:end] = reconstruct_group_m4(
                source_page.m2_sketch[:, group_index, :],
                mean=source_page.m2_mean[group_index],
                group_size=header.group_size,
                basis_family=header.project_basis,
                basis=None if source_page.m2_basis is None else source_page.m2_basis[group_index],
            )
        return dense[:, : header.head_dim]
    # M1, T3 and the M0 fallback all store bit-packed codes in ``payload``.
    if source_page.payload is None:
        raise ValueError(f"{header.mode_default} page is missing payload")
    dense = np.zeros((header.token_count, header.padded_head_dim), dtype=np.float32)
    for group_index in range(header.num_groups):
        words = load_group_words(source_page, group_index)
        codes = unpack_bits(words, header.bits, header.group_size)
        if header.mode_default == "M1":
            if source_page.codebooks is None:
                raise ValueError("M1 page is missing codebooks")
            group_values = dequantize_group_lut(
                codes,
                codebook=np.asarray(source_page.codebooks[group_index], dtype=np.float32),
            )
        elif header.mode_default == "T3":
            # T3 keeps its correction terms in ``scales`` and shares one centroid table.
            if source_page.scales is None or source_page.codebooks is None:
                raise ValueError("T3 page is missing correction metadata")
            group_values = dequantize_group_turbo3(
                codes,
                correction=source_page.scales[:, group_index].astype(np.float32),
                centroids=np.asarray(source_page.codebooks, dtype=np.float32),
            )
        else:
            # Any other mode is decoded as M0: scales plus an optional bias.
            if source_page.scales is None:
                raise ValueError("M0 page is missing scales")
            # Column vectors so the values broadcast across the group's columns.
            scales = source_page.scales[:, group_index].astype(np.float32)[:, None]
            bias = None
            if source_page.bias is not None:
                bias = source_page.bias[:, group_index].astype(np.float32)[:, None]
            group_values = dequantize_group(
                codes,
                scales=scales,
                bias=bias,
                bits=header.bits,
                scheme=header.quant_scheme,
            )
        start = group_index * header.group_size
        end = start + header.group_size
        dense[:, start:end] = group_values
    return dense[:, : header.head_dim]
def sketch_key_page(page: PageLike, *, sketch_size: int = 1) -> np.ndarray:
    """Return a small row-wise summary ("sketch") of a key page.

    Precomputed runtime sidecars on the encoded page are used when they can
    satisfy the request. Otherwise the page is decoded densely and its rows
    are averaged in contiguous chunks.

    Args:
        page: Encoded or prepared key page.
        sketch_size: Requested number of sketch rows; must be positive.

    Returns:
        A 2-D float32 array. With ``sketch_size == 1`` it has exactly one row.
        On the dense path it has ``min(sketch_size, token_count)`` rows, so
        pages with fewer tokens than ``sketch_size`` yield fewer rows.

    Raises:
        ValueError: If ``sketch_size`` is not positive.
    """
    if sketch_size <= 0:
        raise ValueError("sketch_size must be positive")

    def _chunk_means(rows: np.ndarray, sections: int) -> np.ndarray:
        # Average each of `sections` contiguous (possibly uneven) row chunks.
        pieces = np.array_split(rows, sections, axis=0)
        return np.stack([piece.mean(axis=0) for piece in pieces], axis=0).astype(np.float32, copy=False)

    encoded = page.source_page if isinstance(page, PreparedPageTorch) else page
    raw_sketch = encoded.runtime_page_sketch
    if raw_sketch is not None:
        cached = np.asarray(raw_sketch, dtype=np.float32)
        # A single-row sketch is just the stored page mean, when one exists.
        if sketch_size == 1 and encoded.runtime_page_mean is not None:
            return np.asarray(encoded.runtime_page_mean, dtype=np.float32)[None, :]
        cached_rows = cached.shape[0]
        if cached_rows == sketch_size:
            return cached
        # A finer stored sketch can be coarsened; a coarser one cannot be refined.
        if sketch_size > 1 and cached_rows > sketch_size:
            return _chunk_means(cached, sketch_size)

    dense = _decode_page_dense(page)
    if sketch_size == 1:
        return dense.mean(axis=0, keepdims=True)
    # Never request more chunks than there are token rows.
    return _chunk_means(dense, min(sketch_size, dense.shape[0]))
def summarize_key_page(page: PageLike) -> np.ndarray:
    """Return one summary vector for a key page: the single row of its size-1 sketch."""
    single_row_sketch = sketch_key_page(page, sketch_size=1)
    return single_row_sketch[0]
def summarize_value_page(page: PageLike) -> np.ndarray:
    """Return the mean row vector of a value page.

    Uses the encoded page's ``runtime_page_mean`` sidecar when present (cast
    to float32). Otherwise decodes the page densely and averages over rows.
    """
    encoded = page.source_page if isinstance(page, PreparedPageTorch) else page
    cached_mean = encoded.runtime_page_mean
    if cached_mean is None:
        return _decode_page_dense(page).mean(axis=0)
    return np.asarray(cached_mean, dtype=np.float32)
def envelope_key_page(page: PageLike) -> tuple[np.ndarray, np.ndarray]:
    """Return the per-column ``(minimum, maximum)`` envelope of a key page.

    Both stored sidecars (``runtime_page_min`` and ``runtime_page_max``) are
    used when present. If either is missing, both bounds are recomputed from
    the densely decoded page. The results are float32.
    """
    encoded = page.source_page if isinstance(page, PreparedPageTorch) else page
    lower = encoded.runtime_page_min
    upper = encoded.runtime_page_max
    if lower is None or upper is None:
        dense = _decode_page_dense(page)
        lower = dense.min(axis=0)
        upper = dense.max(axis=0)
    return (
        np.asarray(lower, dtype=np.float32),
        np.asarray(upper, dtype=np.float32),
    )
def score_page_relevance(
    query_slice: np.ndarray,
    *,
    relevance_mode: RelevanceMode,
    page_sketch: np.ndarray | None = None,
    page_min: np.ndarray | None = None,
    page_max: np.ndarray | None = None,
) -> float:
    """Score how relevant one page is to a query.

    Args:
        query_slice: Query vector.
        relevance_mode: ``"sketch"`` returns the largest dot product between
            the query and the sketch rows. ``"envelope"`` returns an upper
            bound on ``k @ query`` for any ``k`` with
            ``page_min <= k <= page_max``: the max is used where the query is
            positive and the min where it is negative.
        page_sketch: Required for ``"sketch"``.
        page_min: Required for ``"envelope"``.
        page_max: Required for ``"envelope"``.

    Returns:
        The score as a Python float.

    Raises:
        ValueError: If the inputs required by the mode are missing, or the
            mode is unknown.
    """
    q = np.asarray(query_slice, dtype=np.float32)
    if relevance_mode == "envelope":
        if page_min is None or page_max is None:
            raise ValueError("envelope relevance requires page_min and page_max")
        upper = np.asarray(page_max, dtype=np.float32)
        lower = np.asarray(page_min, dtype=np.float32)
        bound = upper @ np.maximum(q, 0.0) + lower @ np.minimum(q, 0.0)
        return float(bound)
    if relevance_mode != "sketch":
        raise ValueError(f"unsupported relevance_mode: {relevance_mode}")
    if page_sketch is None:
        raise ValueError("sketch relevance requires page_sketch")
    per_row = np.asarray(page_sketch, dtype=np.float32) @ q
    return float(np.max(per_row))
def select_window_page_indices(
    key_pages: Sequence[PageLike],
    *,
    recent_window_tokens: int | None = None,
    sink_window_tokens: int = 0,
) -> list[int]:
    """Return the indices of pages that overlap the sink or recent token windows.

    The sink window is tokens ``[0, sink_window_tokens)``. The recent window
    is the last ``recent_window_tokens`` tokens before the context end, which
    is the largest ``token_start + token_count`` over all pages. A
    non-positive or ``None`` window size disables that window.

    Returns:
        Ascending page indices. The list is empty when no window is enabled
        or ``key_pages`` is empty.
    """
    if not key_pages:
        return []
    spans = [
        (page.header.token_start, page.header.token_start + page.header.token_count)
        for page in key_pages
    ]
    context_end = max(end for _, end in spans)
    sink_end = max(0, sink_window_tokens)
    use_recent = recent_window_tokens is not None and recent_window_tokens > 0
    recent_start = max(0, context_end - recent_window_tokens) if use_recent else context_end
    return [
        position
        for position, (start, end) in enumerate(spans)
        if (sink_end > 0 and start < sink_end and end > 0) or (use_recent and end > recent_start)
    ]
def select_execution_page_indices(
    key_pages: Sequence[PageLike],
    *,
    recent_window_tokens: int | None = None,
    sink_window_tokens: int = 0,
    query_slice: np.ndarray | None = None,
    key_page_sketches: Sequence[np.ndarray] | None = None,
    key_page_sketch_matrix: np.ndarray | None = None,
    tail_page_sketch: np.ndarray | None = None,
    key_page_minima: Sequence[np.ndarray] | None = None,
    key_page_minima_matrix: np.ndarray | None = None,
    tail_page_minimum: np.ndarray | None = None,
    key_page_maxima: Sequence[np.ndarray] | None = None,
    key_page_maxima_matrix: np.ndarray | None = None,
    tail_page_maximum: np.ndarray | None = None,
    relevance_top_k: int = 0,
    relevance_mode: RelevanceMode = "sketch",
    stage_recorder: Callable[[str, float], None] | None = None,
    score_all_pages_with_matrices: bool = False,
    score_all_pages_min_candidate_fraction: float = 0.0,
    selector_stats_recorder: Callable[[dict[str, int | float | bool]], None] | None = None,
) -> list[int]:
    """Choose which pages take part in exact attention for one query.

    Pages in the sink/recent windows (see ``select_window_page_indices``) are
    always kept. If ``relevance_top_k > 0``, the remaining pages are scored
    against ``query_slice`` and the ``relevance_top_k`` best are added.

    Scoring inputs, per ``relevance_mode``:

    * ``"sketch"``: ``key_page_sketch_matrix`` (one stacked sketch per page)
      if given, otherwise the per-page list ``key_page_sketches``.
    * ``"envelope"``: ``key_page_minima_matrix`` and
      ``key_page_maxima_matrix`` if both are given, otherwise the per-page
      lists ``key_page_minima`` and ``key_page_maxima``.

    When the tail sidecar(s) are passed (``tail_page_sketch``, or both
    ``tail_page_minimum`` and ``tail_page_maximum``), the corresponding
    matrix covers every page except the last one, whose sidecar is the
    separate tail argument.

    Args:
        key_pages: Key pages in sequence order.
        recent_window_tokens: Size of the always-kept recent window.
        sink_window_tokens: Size of the always-kept sink window.
        query_slice: Query vector; required when ``relevance_top_k > 0``.
        key_page_sketches: Per-page sketches, aligned with ``key_pages``.
        key_page_sketch_matrix: Stacked sketches (see tail note above).
        tail_page_sketch: Sketch of the last page when it is not in the matrix.
        key_page_minima: Per-page minima, aligned with ``key_pages``.
        key_page_minima_matrix: Stacked minima (see tail note above).
        tail_page_minimum: Minimum of the last page when not in the matrix.
        key_page_maxima: Per-page maxima, aligned with ``key_pages``.
        key_page_maxima_matrix: Stacked maxima (see tail note above).
        tail_page_maximum: Maximum of the last page when not in the matrix.
        relevance_top_k: Number of non-window pages to add by score; ``0``
            disables relevance gating.
        relevance_mode: ``"sketch"`` or ``"envelope"``.
        stage_recorder: Optional callback receiving ``(stage_name,
            elapsed_ms)`` for each timed stage.
        score_all_pages_with_matrices: With the matrix inputs, score every
            row and then gather the candidates, instead of gathering
            candidate rows first.
        score_all_pages_min_candidate_fraction: Minimum candidate fraction
            for ``score_all_pages_with_matrices`` to take effect.
        selector_stats_recorder: Optional callback receiving candidate
            counts. It is only called when there are candidates.

    Returns:
        Ascending page indices. If nothing is selected, all indices are
        returned.

    Raises:
        ValueError: If required scoring inputs are missing or misaligned, or
            the relevance mode is unknown.
    """

    def _record_stage(stage: str, started_at: float | None) -> None:
        # Timing is opt-in: started_at is None whenever no recorder was given.
        if stage_recorder is None or started_at is None:
            return
        stage_recorder(stage, (perf_counter() - started_at) * 1000.0)

    def _materialize_candidate_rows(matrix: np.ndarray, direct_candidate_indices: Sequence[int]) -> np.ndarray:
        # Gather float32 matrix rows for ascending, duplicate-free page indices.
        if not direct_candidate_indices:
            return np.empty((0,) + tuple(matrix.shape[1:]), dtype=np.float32)
        first_index = int(direct_candidate_indices[0])
        last_index = int(direct_candidate_indices[-1])
        # For ascending unique indices, a matching length means a contiguous range,
        # so a slice can replace the gather.
        if last_index - first_index + 1 == len(direct_candidate_indices):
            return np.ascontiguousarray(matrix[first_index : last_index + 1], dtype=np.float32)
        return np.take(matrix, direct_candidate_indices, axis=0).astype(np.float32, copy=False)

    if not key_pages:
        return []
    # Window pages are always kept and are never relevance candidates.
    selected_indices = set(
        select_window_page_indices(
            key_pages,
            recent_window_tokens=recent_window_tokens,
            sink_window_tokens=sink_window_tokens,
        )
    )
    if relevance_top_k > 0:
        if query_slice is None:
            raise ValueError("relevance gating requires query_slice")
        candidate_index_build_started_at = perf_counter() if stage_recorder is not None else None
        # Built from an ascending range, so candidate_indices is sorted and unique.
        candidate_indices = [index for index in range(len(key_pages)) if index not in selected_indices]
        _record_stage("shortlist_candidate_builtin_candidate_index_build", candidate_index_build_started_at)
        if candidate_indices:
            candidate_fraction = float(len(candidate_indices)) / float(len(key_pages))
            use_score_all_pages = bool(
                score_all_pages_with_matrices
                and candidate_fraction >= max(0.0, float(score_all_pages_min_candidate_fraction))
            )
            if selector_stats_recorder is not None:
                selector_stats_recorder(
                    {
                        "candidate_pages": int(len(candidate_indices)),
                        "total_pages": int(len(key_pages)),
                        "candidate_fraction": float(candidate_fraction),
                        "used_score_all_pages": bool(use_score_all_pages),
                    }
                )
            query = np.asarray(query_slice, dtype=np.float32)
            if relevance_mode == "sketch":
                if key_page_sketch_matrix is not None:
                    # With a separate tail sketch, the matrix omits the last page.
                    expected_sketch_rows = len(key_pages) - 1 if tail_page_sketch is not None else len(key_pages)
                    if int(key_page_sketch_matrix.shape[0]) != expected_sketch_rows:
                        raise ValueError("key_page_sketch_matrix must align with key_pages")
                    if use_score_all_pages:
                        score_compute_started_at = perf_counter() if stage_recorder is not None else None
                        all_scores = np.max(key_page_sketch_matrix @ query, axis=1).astype(np.float32, copy=False)
                        if tail_page_sketch is not None:
                            tail_score = float(np.max(np.asarray(tail_page_sketch, dtype=np.float32) @ query))
                            all_scores = np.concatenate(
                                [all_scores, np.asarray([tail_score], dtype=np.float32)],
                                axis=0,
                            )
                        scores = np.asarray(all_scores[candidate_indices], dtype=np.float32)
                        _record_stage("shortlist_candidate_builtin_score_compute", score_compute_started_at)
                    else:
                        direct_candidate_indices = [index for index in candidate_indices if index < key_page_sketch_matrix.shape[0]]
                        # The tail page (last index) is the only candidate beyond the matrix rows.
                        tail_candidate_selected = (
                            tail_page_sketch is not None and len(candidate_indices) > len(direct_candidate_indices)
                        )
                        sidecar_stack_started_at = perf_counter() if stage_recorder is not None else None
                        candidate_sketches = _materialize_candidate_rows(
                            key_page_sketch_matrix,
                            direct_candidate_indices,
                        )
                        _record_stage("shortlist_candidate_builtin_sidecar_stack", sidecar_stack_started_at)
                        score_compute_started_at = perf_counter() if stage_recorder is not None else None
                        direct_scores = np.max(candidate_sketches @ query, axis=1).astype(np.float32, copy=False)
                        if tail_candidate_selected:
                            # Appending last keeps scores aligned with ascending candidate_indices.
                            tail_score = float(np.max(np.asarray(tail_page_sketch, dtype=np.float32) @ query))
                            scores = np.concatenate(
                                [direct_scores, np.asarray([tail_score], dtype=np.float32)],
                                axis=0,
                            )
                        else:
                            scores = direct_scores
                        _record_stage("shortlist_candidate_builtin_score_compute", score_compute_started_at)
                else:
                    if key_page_sketches is None:
                        raise ValueError("sketch relevance gating requires key_page_sketches")
                    if len(key_page_sketches) != len(key_pages):
                        raise ValueError("key_page_sketches must align with key_pages")
                    sidecar_stack_started_at = perf_counter() if stage_recorder is not None else None
                    candidate_sketch_list = [
                        np.asarray(key_page_sketches[index], dtype=np.float32) for index in candidate_indices
                    ]
                    # sketch_key_page yields fewer rows for pages shorter than the sketch
                    # size, so sketches can be ragged; np.stack needs equal shapes.
                    first_shape = candidate_sketch_list[0].shape
                    stackable = candidate_sketch_list[0].ndim == 2 and all(
                        sketch.shape == first_shape for sketch in candidate_sketch_list
                    )
                    candidate_sketches = np.stack(candidate_sketch_list, axis=0) if stackable else None
                    _record_stage("shortlist_candidate_builtin_sidecar_stack", sidecar_stack_started_at)
                    score_compute_started_at = perf_counter() if stage_recorder is not None else None
                    if candidate_sketches is not None:
                        scores = np.max(candidate_sketches @ query, axis=1).astype(np.float32, copy=False)
                    else:
                        # Same per-page score as score_page_relevance(..., "sketch").
                        scores = np.asarray(
                            [np.max(sketch @ query) for sketch in candidate_sketch_list],
                            dtype=np.float32,
                        )
                    _record_stage("shortlist_candidate_builtin_score_compute", score_compute_started_at)
            elif relevance_mode == "envelope":
                # Upper bound of k @ query over the per-dimension [min, max] box.
                positive_query = np.maximum(query, 0.0)
                negative_query = np.minimum(query, 0.0)
                if key_page_minima_matrix is not None and key_page_maxima_matrix is not None:
                    expected_envelope_rows = (
                        len(key_pages) - 1
                        if tail_page_minimum is not None and tail_page_maximum is not None
                        else len(key_pages)
                    )
                    if (
                        int(key_page_minima_matrix.shape[0]) != expected_envelope_rows
                        or int(key_page_maxima_matrix.shape[0]) != expected_envelope_rows
                    ):
                        raise ValueError("page minima and maxima matrices must align with key_pages")
                    if use_score_all_pages:
                        score_compute_started_at = perf_counter() if stage_recorder is not None else None
                        all_scores = (
                            key_page_maxima_matrix @ positive_query + key_page_minima_matrix @ negative_query
                        ).astype(np.float32, copy=False)
                        if tail_page_minimum is not None and tail_page_maximum is not None:
                            tail_score = float(
                                np.asarray(tail_page_maximum, dtype=np.float32) @ positive_query
                                + np.asarray(tail_page_minimum, dtype=np.float32) @ negative_query
                            )
                            all_scores = np.concatenate(
                                [all_scores, np.asarray([tail_score], dtype=np.float32)],
                                axis=0,
                            )
                        scores = np.asarray(all_scores[candidate_indices], dtype=np.float32)
                        _record_stage("shortlist_candidate_builtin_score_compute", score_compute_started_at)
                    else:
                        direct_candidate_indices = [index for index in candidate_indices if index < key_page_minima_matrix.shape[0]]
                        tail_candidate_selected = (
                            tail_page_minimum is not None
                            and tail_page_maximum is not None
                            and len(candidate_indices) > len(direct_candidate_indices)
                        )
                        sidecar_stack_started_at = perf_counter() if stage_recorder is not None else None
                        candidate_minima = _materialize_candidate_rows(
                            key_page_minima_matrix,
                            direct_candidate_indices,
                        )
                        candidate_maxima = _materialize_candidate_rows(
                            key_page_maxima_matrix,
                            direct_candidate_indices,
                        )
                        _record_stage("shortlist_candidate_builtin_sidecar_stack", sidecar_stack_started_at)
                        score_compute_started_at = perf_counter() if stage_recorder is not None else None
                        direct_scores = (
                            candidate_maxima @ positive_query + candidate_minima @ negative_query
                        ).astype(np.float32, copy=False)
                        if tail_candidate_selected:
                            tail_score = float(
                                np.asarray(tail_page_maximum, dtype=np.float32) @ positive_query
                                + np.asarray(tail_page_minimum, dtype=np.float32) @ negative_query
                            )
                            scores = np.concatenate(
                                [direct_scores, np.asarray([tail_score], dtype=np.float32)],
                                axis=0,
                            )
                        else:
                            scores = direct_scores
                        _record_stage("shortlist_candidate_builtin_score_compute", score_compute_started_at)
                else:
                    if key_page_minima is None or key_page_maxima is None:
                        raise ValueError("envelope relevance gating requires page minima and maxima")
                    if len(key_page_minima) != len(key_pages) or len(key_page_maxima) != len(key_pages):
                        raise ValueError("page minima and maxima must align with key_pages")
                    sidecar_stack_started_at = perf_counter() if stage_recorder is not None else None
                    candidate_minima = np.stack(
                        [np.asarray(key_page_minima[index], dtype=np.float32) for index in candidate_indices],
                        axis=0,
                    )
                    candidate_maxima = np.stack(
                        [np.asarray(key_page_maxima[index], dtype=np.float32) for index in candidate_indices],
                        axis=0,
                    )
                    _record_stage("shortlist_candidate_builtin_sidecar_stack", sidecar_stack_started_at)
                    score_compute_started_at = perf_counter() if stage_recorder is not None else None
                    scores = (candidate_maxima @ positive_query + candidate_minima @ negative_query).astype(
                        np.float32,
                        copy=False,
                    )
                    _record_stage("shortlist_candidate_builtin_score_compute", score_compute_started_at)
            else:
                raise ValueError(f"unsupported relevance_mode: {relevance_mode}")
            ranking_started_at = perf_counter() if stage_recorder is not None else None
            # Stable descending sort: ties keep ascending page order.
            ranked_candidates = [
                index
                for _, index in sorted(
                    zip(scores.tolist(), candidate_indices, strict=True),
                    key=lambda item: item[0],
                    reverse=True,
                )
            ]
            _record_stage("shortlist_candidate_builtin_ranking", ranking_started_at)
            selected_indices.update(ranked_candidates[:relevance_top_k])
    # No windows and no gating means every page is used.
    if not selected_indices:
        return list(range(len(key_pages)))
    return sorted(selected_indices)
def select_execution_page_pairs(
    key_pages: Sequence[PageLike],
    value_pages: Sequence[PageLike],
    *,
    recent_window_tokens: int | None = None,
    sink_window_tokens: int = 0,
    query_slice: np.ndarray | None = None,
    key_page_sketches: Sequence[np.ndarray] | None = None,
    key_page_minima: Sequence[np.ndarray] | None = None,
    key_page_maxima: Sequence[np.ndarray] | None = None,
    relevance_top_k: int = 0,
    relevance_mode: RelevanceMode = "sketch",
) -> tuple[list[PageLike], list[PageLike]]:
    """Select matching key/value pages for execution.

    When no window or relevance gating is enabled, copies of both inputs are
    returned unchanged. Otherwise the indices come from
    ``select_execution_page_indices``, and the key and value pages at those
    indices are returned in ascending order.

    Raises:
        ValueError: If the two page sequences differ in length, or the
            selector rejects its inputs.
    """
    if len(key_pages) != len(value_pages):
        raise ValueError("key_pages and value_pages must contain the same number of pages")
    if not key_pages:
        return [], []
    has_recent_window = recent_window_tokens is not None and recent_window_tokens > 0
    filtering_enabled = has_recent_window or sink_window_tokens > 0 or relevance_top_k > 0
    if not filtering_enabled:
        return list(key_pages), list(value_pages)
    chosen = select_execution_page_indices(
        key_pages,
        recent_window_tokens=recent_window_tokens,
        sink_window_tokens=sink_window_tokens,
        query_slice=query_slice,
        key_page_sketches=key_page_sketches,
        key_page_minima=key_page_minima,
        key_page_maxima=key_page_maxima,
        relevance_top_k=relevance_top_k,
        relevance_mode=relevance_mode,
    )
    chosen_keys: list[PageLike] = []
    chosen_values: list[PageLike] = []
    for position in chosen:
        chosen_keys.append(key_pages[position])
        chosen_values.append(value_pages[position])
    return chosen_keys, chosen_values
@dataclass
class PagedDecodeSession:
    """Decode session over an append-only sequence of key/value pages.

    At append time each page gets per-page sidecars: a key sketch, a key
    min/max envelope, and a value mean vector. Page selection at decode time
    uses these sidecars:

    * pages in the sink/recent windows are always kept;
    * up to ``relevance_top_k`` other pages are added by sidecar score;
    * if ``exact_refine_top_k > 0``, that shortlist is re-ranked by exact
      page logits and trimmed to at most ``exact_refine_top_k`` pages;
    * with ``approximate_old_pages``, unselected pages still contribute: their
      sidecar score is used as a constant logit and their mean value vector
      replaces exact mixing.
    """

    backend: BackendName = "auto"
    # Optional cache handed to prepare_pages; cleared by clear().
    cache: PreparedPageCache | None = None
    recent_window_tokens: int | None = None
    sink_window_tokens: int = 0
    relevance_top_k: int = 0
    # Number of rows per key sketch (see sketch_key_page).
    relevance_sketch_size: int = 1
    relevance_mode: RelevanceMode = "sketch"
    exact_refine_top_k: int = 0
    approximate_old_pages: bool = False
    key_pages: list[PageLike] = field(default_factory=list)
    value_pages: list[PageLike] = field(default_factory=list)
    # The sidecar lists below stay index-aligned with key_pages / value_pages.
    key_page_sketches: list[np.ndarray] = field(default_factory=list)
    key_page_minima: list[np.ndarray] = field(default_factory=list)
    key_page_maxima: list[np.ndarray] = field(default_factory=list)
    value_page_summaries: list[np.ndarray] = field(default_factory=list)
    # Page indices used by the most recent decode() call.
    last_selected_indices: list[int] = field(default_factory=list)

    def clear(self) -> None:
        """Drop all pages, sidecars and selection state, and clear the cache if one is set."""
        self.key_pages.clear()
        self.value_pages.clear()
        self.key_page_sketches.clear()
        self.key_page_minima.clear()
        self.key_page_maxima.clear()
        self.value_page_summaries.clear()
        self.last_selected_indices.clear()
        if self.cache is not None:
            self.cache.clear()

    def page_count(self) -> int:
        """Return the number of stored key pages."""
        return len(self.key_pages)

    def active_page_count(self, query_slice: np.ndarray | None = None) -> int:
        """Return the number of pages ``execution_pages`` would select.

        ``query_slice`` is required when ``relevance_top_k > 0``; without it
        the selector raises ``ValueError``. Exact refinement is not applied
        here.
        """
        return len(self.execution_pages(query_slice)[0])

    def active_token_count(self, query_slice: np.ndarray | None = None) -> int:
        """Return the total token count of the pages ``execution_pages`` would select.

        The same ``query_slice`` requirement as ``active_page_count`` applies.
        """
        return sum(page.header.token_count for page in self.execution_pages(query_slice)[0])

    def preload(
        self,
        key_pages: Sequence[PageLike],
        value_pages: Sequence[PageLike],
        *,
        prepare: bool = True,
        trace: ExecutionTrace | None = None,
    ) -> None:
        """Replace the session contents with the given pages (``clear`` then ``append``)."""
        self.clear()
        self.append(key_pages, value_pages, prepare=prepare, trace=trace)

    def append(
        self,
        key_pages: Sequence[PageLike],
        value_pages: Sequence[PageLike],
        *,
        prepare: bool = True,
        trace: ExecutionTrace | None = None,
    ) -> None:
        """Append page pairs and compute their selection sidecars.

        Args:
            key_pages: Key pages to append.
            value_pages: Value pages to append, one per key page.
            prepare: If true, pages are passed through ``prepare_pages`` for
                ``self.backend``. Otherwise they are stored as given.
            trace: Optional trace forwarded to ``prepare_pages``.

        Raises:
            ValueError: If the two sequences differ in length.
        """
        if len(key_pages) != len(value_pages):
            raise ValueError("key_pages and value_pages must contain the same number of pages")
        if prepare:
            prepared_key_pages = prepare_pages(key_pages, backend=self.backend, cache=self.cache, trace=trace)
            prepared_value_pages = prepare_pages(value_pages, backend=self.backend, cache=self.cache, trace=trace)
        else:
            prepared_key_pages = list(key_pages)
            prepared_value_pages = list(value_pages)
        self.key_pages.extend(prepared_key_pages)
        self.value_pages.extend(prepared_value_pages)
        self.key_page_sketches.extend(
            sketch_key_page(page, sketch_size=self.relevance_sketch_size) for page in prepared_key_pages
        )
        # Envelopes are computed regardless of relevance_mode, so the mode can change later.
        for page in prepared_key_pages:
            page_min, page_max = envelope_key_page(page)
            self.key_page_minima.append(page_min)
            self.key_page_maxima.append(page_max)
        self.value_page_summaries.extend(summarize_value_page(page) for page in prepared_value_pages)

    def execution_pages(self, query_slice: np.ndarray | None = None) -> tuple[list[PageLike], list[PageLike]]:
        """Return the ``(key_pages, value_pages)`` selected by windows and sidecar relevance.

        Exact refinement (``exact_refine_top_k``) is not applied here.
        """
        return select_execution_page_pairs(
            self.key_pages,
            self.value_pages,
            recent_window_tokens=self.recent_window_tokens,
            sink_window_tokens=self.sink_window_tokens,
            query_slice=query_slice,
            key_page_sketches=self.key_page_sketches,
            key_page_minima=self.key_page_minima,
            key_page_maxima=self.key_page_maxima,
            relevance_top_k=self.relevance_top_k,
            relevance_mode=self.relevance_mode,
        )

    def execution_indices(
        self,
        query_slice: np.ndarray | None = None,
        *,
        trace: ExecutionTrace | None = None,
    ) -> list[int]:
        """Return the page indices ``decode`` would use, including exact refinement."""
        return self._execution_plan(query_slice, trace=trace)[0]

    def _execution_plan(
        self,
        query_slice: np.ndarray | None = None,
        *,
        trace: ExecutionTrace | None = None,
    ) -> tuple[list[int], dict[int, np.ndarray]]:
        """Compute the selected page indices and any exact logits computed on the way.

        Stage 1 selects windows plus the top ``relevance_top_k`` pages by
        sidecar score. If refinement applies, the stage-1 relevance picks are
        scored exactly with ``score_pages`` and only the best
        ``exact_refine_top_k`` are kept. Their logits are returned, keyed by
        page index, so ``decode`` can reuse them.
        """
        stage1_indices = select_execution_page_indices(
            self.key_pages,
            recent_window_tokens=self.recent_window_tokens,
            sink_window_tokens=self.sink_window_tokens,
            query_slice=query_slice,
            key_page_sketches=self.key_page_sketches,
            key_page_minima=self.key_page_minima,
            key_page_maxima=self.key_page_maxima,
            relevance_top_k=self.relevance_top_k,
            relevance_mode=self.relevance_mode,
        )
        if query_slice is None or self.exact_refine_top_k <= 0 or self.relevance_top_k <= 0:
            return stage1_indices, {}
        if not stage1_indices:
            return stage1_indices, {}
        base_indices = set(
            select_window_page_indices(
                self.key_pages,
                recent_window_tokens=self.recent_window_tokens,
                sink_window_tokens=self.sink_window_tokens,
            )
        )
        # Only pages added by relevance scoring are eligible for refinement.
        candidate_indices = [index for index in stage1_indices if index not in base_indices]
        if not candidate_indices or self.exact_refine_top_k >= len(candidate_indices):
            return stage1_indices, {}
        candidate_logits = score_pages(
            query_slice,
            [self.key_pages[index] for index in candidate_indices],
            backend=self.backend,
            trace=trace,
        )
        # Rank candidates by their maximum exact logit.
        exact_scores = []
        for index, logits in zip(candidate_indices, candidate_logits, strict=True):
            exact_scores.append((float(np.max(logits)), index))
        chosen = [
            index
            for _, index in sorted(
                exact_scores,
                key=lambda item: item[0],
                reverse=True,
            )[: self.exact_refine_top_k]
        ]
        chosen_set = set(chosen)
        chosen_logits = {
            index: np.asarray(logits, dtype=np.float32)
            for index, logits in zip(candidate_indices, candidate_logits, strict=True)
            if index in chosen_set
        }
        return sorted(base_indices.union(chosen)), chosen_logits

    def decode(
        self,
        query_slice: np.ndarray,
        *,
        trace: ExecutionTrace | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Run one decode step for ``query_slice`` over the session's pages.

        The chosen indices are stored in ``last_selected_indices``. Exact
        attention runs over the selected pages through
        ``decode_step_with_page_logits``. If ``approximate_old_pages`` is set
        and some pages were dropped, ``_decode_with_old_page_fallback`` is
        used instead; it returns ``(logits, weights, output)`` over all pages.

        Raises:
            ValueError: If no pages have been loaded.
        """
        if not self.key_pages or not self.value_pages:
            raise ValueError("PagedDecodeSession requires preloaded pages before decode")
        selected_indices, selected_logits = self._execution_plan(query_slice, trace=trace)
        self.last_selected_indices = list(selected_indices)
        key_pages = [self.key_pages[index] for index in selected_indices]
        value_pages = [self.value_pages[index] for index in selected_indices]
        if not self.approximate_old_pages or len(selected_indices) == len(self.key_pages):
            # None entries mean "no precomputed logits for this page".
            precomputed_page_logits = [selected_logits.get(index) for index in selected_indices]
            return decode_step_with_page_logits(
                query_slice,
                key_pages,
                value_pages,
                page_logits=precomputed_page_logits,
                backend=self.backend,
                trace=trace,
            )
        return self._decode_with_old_page_fallback(query_slice, selected_indices, trace=trace)

    def _decode_with_old_page_fallback(
        self,
        query_slice: np.ndarray,
        selected_indices: Sequence[int],
        *,
        trace: ExecutionTrace | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Softmax attention where unselected pages are approximated by their sidecars.

        Selected pages get exact logits and exact value mixing. Every other
        page contributes ``token_count`` copies of its relevance score as
        logits, and its mean value vector weighted by the sum of its weights.

        Returns:
            ``(logits, weights, output)``. Logits and weights are concatenated
            over all pages in order; weights and output are normalized by the
            softmax denominator.

        Raises:
            ValueError: If the logits or the normalization denominator are invalid.
        """
        query = np.asarray(query_slice, dtype=np.float32)
        exact_index_set = set(selected_indices)
        all_logits: list[np.ndarray] = []
        max_logit = -np.inf
        for index, page in enumerate(self.key_pages):
            if index in exact_index_set:
                logits = score_page(query, page, backend=self.backend, trace=trace).astype(np.float32, copy=False)
                all_logits.append(logits)
                max_logit = max(max_logit, float(np.max(logits)))
                continue
            page_score = score_page_relevance(
                query,
                relevance_mode=self.relevance_mode,
                page_sketch=self.key_page_sketches[index],
                page_min=self.key_page_minima[index],
                page_max=self.key_page_maxima[index],
            )
            logits = np.full(page.header.token_count, page_score, dtype=np.float32)
            all_logits.append(logits)
            max_logit = max(max_logit, page_score)
        if not np.isfinite(max_logit):
            raise ValueError("failed to compute logits for session decode")
        output = np.zeros(self.value_pages[0].header.head_dim, dtype=np.float32)
        all_weights: list[np.ndarray] = []
        denom = 0.0
        for index, page in enumerate(self.key_pages):
            logits = all_logits[index]
            # Subtract the global max for a numerically stable softmax.
            weights = np.exp(logits - max_logit).astype(np.float32, copy=False)
            all_weights.append(weights)
            denom += float(np.sum(weights))
            if index in exact_index_set:
                output = mix_page(
                    weights,
                    self.value_pages[index],
                    out_acc=output,
                    backend=self.backend,
                    trace=trace,
                )
            else:
                output += float(np.sum(weights)) * self.value_page_summaries[index]
        if denom <= 0.0:
            raise ValueError("invalid normalization denominator in session fallback decode")
        logits = np.concatenate(all_logits).astype(np.float32, copy=False)
        weights = np.concatenate(all_weights).astype(np.float32, copy=False) / np.float32(denom)
        return logits, weights, output / np.float32(denom)