from __future__ import annotations

import itertools
import time
from dataclasses import dataclass
from typing import Sequence
|
|
| try: |
| import numpy as np |
| except Exception: |
| np = None |
|
|
| try: |
| import faiss |
| except Exception: |
| faiss = None |
|
|
|
|
@dataclass(frozen=True)
class SparseSelection:
    """Outcome of one sparse-context selection.

    ``positions`` are indices into the *context sequence* (not the embedding
    table), ordered by descending score; ``scores`` holds the matching
    cosine-similarity values in the same order.
    """

    # Context positions, best match first.
    positions: list[int]
    # Cosine score per selected position (parallel to ``positions``).
    scores: list[float]
|
|
|
def _require_numpy() -> None:
    """Raise ``RuntimeError`` when the optional NumPy dependency is missing."""
    if np is not None:
        return
    raise RuntimeError("NumPy is required for the sparse-context kernel.")
|
|
|
|
def normalize_rows(matrix: object) -> object:
    """Return *matrix* as float32 with every row scaled to unit L2 norm.

    Row norms are clamped to at least ``1e-8`` before dividing, so a zero
    row yields zeros instead of NaN/inf.

    Raises:
        ValueError: if the input is not two-dimensional.
        RuntimeError: if NumPy is unavailable.
    """
    _require_numpy()
    rows = np.asarray(matrix, dtype=np.float32)
    if rows.ndim != 2:
        raise ValueError("matrix must be rank-2")
    lengths = np.linalg.norm(rows, axis=1, keepdims=True)
    safe_lengths = np.maximum(lengths, 1e-8)
    return rows / safe_lengths
|
|
|
|
class AnalyticalSparseAttention:
    """Content-dependent long-context selection from corpus-derived embeddings.

    This is Reframr's analytical sparse-context kernel: it selects positions by
    embedding geometry, then aggregates only the selected states. It does not
    contain task-specific answer strings or prompt-pattern shortcuts.
    """

    def __init__(self, embeddings: object, *, k_neighbors: int = 64) -> None:
        """Store the embedding table and precompute unit-norm rows.

        Args:
            embeddings: rank-2 array-like; row ``i`` is the embedding of
                token id ``i``.
            k_neighbors: default number of positions selected per query
                (clamped to at least 1).

        Raises:
            ValueError: if ``embeddings`` is not two-dimensional.
            RuntimeError: if NumPy is not installed.
        """
        _require_numpy()
        self.embeddings = np.asarray(embeddings, dtype=np.float32)
        if self.embeddings.ndim != 2:
            raise ValueError("embeddings must be rank-2")
        self.k_neighbors = max(1, int(k_neighbors))
        # Rows are unit-normalized so the dot products below are cosine scores.
        self.normalized_embeddings = normalize_rows(self.embeddings)
        # Filled by build_context_index() for the *_cached() fast path.
        self._context_token_ids: object | None = None
        self._context_vectors: object | None = None

    @property
    def embedding_dim(self) -> int:
        """Width of one embedding vector."""
        return int(self.embeddings.shape[1])

    def select_positions(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Select the top-k context positions for one query token.

        Args:
            query_token_id: row index into the embedding table.
            context_token_ids: rank-1 sequence of token ids forming the context.
            top_k: per-call override for ``k_neighbors`` (``None`` or 0 keeps
                the default).

        Returns:
            A ``SparseSelection`` whose positions index into
            ``context_token_ids``, sorted by descending cosine similarity.
        """
        token_ids = self._coerce_token_ids(context_token_ids)
        context_vectors = self.normalized_embeddings[token_ids]
        return self._select_positions_from_vectors(
            query_token_id,
            token_ids,
            context_vectors,
            top_k=top_k,
        )

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context ids and their normalized vectors for reuse."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = self.normalized_embeddings[token_ids]

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Like ``select_positions`` but against the prebuilt context cache.

        Raises:
            RuntimeError: if ``build_context_index()`` has not been called.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        return self._select_positions_from_vectors(
            query_token_id,
            self._context_token_ids,
            self._context_vectors,
            top_k=top_k,
        )

    def _select_positions_from_vectors(
        self,
        query_token_id: int,
        token_ids: object,
        context_vectors: object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Rank already-gathered context vectors against the query embedding.

        Raises:
            ValueError: if ``query_token_id`` is outside the embedding table.
        """
        if token_ids.size == 0:
            return SparseSelection(positions=[], scores=[])
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        k = min(token_ids.size, max(1, int(top_k or self.k_neighbors)))
        query_vector = self.normalized_embeddings[query_id]
        scores = context_vectors @ query_vector
        if k >= scores.size:
            # Keeping everything: one full descending sort is simplest.
            selected = np.argsort(scores)[::-1]
        else:
            # argpartition isolates the top-k set in O(n); sort only those k.
            selected = np.argpartition(scores, -k)[-k:]
            selected = selected[np.argsort(scores[selected])[::-1]]
        # Single tolist() conversion shared by both output fields.
        chosen = selected.tolist()
        return SparseSelection(
            positions=[int(index) for index in chosen],
            scores=[float(scores[index]) for index in chosen],
        )

    def sparse_output(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        context_states: object | None = None,
        *,
        top_k: int | None = None,
        temperature: float = 1.0,
    ) -> object:
        """Softmax-weighted aggregate of the states at the selected positions.

        Args:
            query_token_id: row index into the embedding table.
            context_token_ids: rank-1 token ids forming the context.
            context_states: optional rank-2 per-position states; defaults to
                the raw embeddings of the context tokens.
            top_k: per-call override for ``k_neighbors``.
            temperature: softmax temperature (clamped to at least ``1e-6``).

        Returns:
            A float32 vector of width ``states.shape[1]``; all zeros when the
            context is empty.

        Raises:
            ValueError: on mismatched ``context_states`` shape or bad ids.
        """
        token_ids = self._coerce_token_ids(context_token_ids)
        if context_states is None:
            states = self.embeddings[token_ids]
        else:
            states = np.asarray(context_states, dtype=np.float32)
            if states.ndim != 2 or states.shape[0] != token_ids.size:
                raise ValueError("context_states must be rank-2 and match context length")
        # Gather the normalized vectors once and rank directly; going through
        # select_positions() would re-validate and re-gather the same ids.
        context_vectors = self.normalized_embeddings[token_ids]
        selection = self._select_positions_from_vectors(
            query_token_id,
            token_ids,
            context_vectors,
            top_k=top_k,
        )
        if not selection.positions:
            return np.zeros(states.shape[1], dtype=np.float32)
        selected_states = states[np.asarray(selection.positions, dtype=np.int64)]
        scores = np.asarray(selection.scores, dtype=np.float32)
        # Max-shifted softmax; the clamps keep it finite for temperature -> 0
        # and for a degenerate all-zero weight sum.
        scaled = scores / max(float(temperature), 1e-6)
        scaled -= float(scaled.max())
        weights = np.exp(scaled)
        weights /= max(float(weights.sum()), 1e-8)
        return weights @ selected_states

    def benchmark_selection(
        self,
        context_token_ids: Sequence[int] | object,
        query_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
        cache_context: bool = True,
    ) -> dict[str, object]:
        """Time repeated selection of every query against one fixed context.

        Args:
            context_token_ids: rank-1 token ids forming the context.
            query_token_ids: rank-1 token ids to run as queries.
            top_k: per-call override for ``k_neighbors``.
            cache_context: build the context index once and use the cached
                path; otherwise re-derive the context for every query.

        Returns:
            A dict of counters and timings; ``index_build_seconds`` is exactly
            0.0 when ``cache_context`` is False (no index is built).
        """
        token_ids = self._coerce_token_ids(context_token_ids)
        queries = self._coerce_token_ids(query_token_ids)
        build_elapsed = 0.0
        if cache_context:
            build_started = time.perf_counter()
            self.build_context_index(token_ids)
            build_elapsed = time.perf_counter() - build_started
        started = time.perf_counter()
        selected_total = 0
        for query_id in queries.tolist():
            if cache_context:
                selection = self.select_positions_cached(int(query_id), top_k=top_k)
            else:
                selection = self.select_positions(int(query_id), token_ids, top_k=top_k)
            selected_total += len(selection.positions)
        elapsed = time.perf_counter() - started
        return {
            "context_tokens": int(token_ids.size),
            "query_count": int(queries.size),
            "top_k": min(int(top_k or self.k_neighbors), int(token_ids.size)) if token_ids.size else 0,
            "selected_positions": int(selected_total),
            "cache_context": bool(cache_context),
            "index_build_seconds": build_elapsed,
            "seconds": elapsed,
            "queries_per_second": (float(queries.size) / elapsed) if elapsed > 0.0 else 0.0,
        }

    def _coerce_token_ids(self, token_ids: Sequence[int] | object) -> object:
        """Validate and convert token ids to a rank-1 int64 array.

        Raises:
            ValueError: if the ids are not rank-1 or any id falls outside
                the embedding table.
        """
        ids = np.asarray(token_ids, dtype=np.int64)
        if ids.ndim != 1:
            raise ValueError("token ids must be rank-1")
        if ids.size and (int(ids.min()) < 0 or int(ids.max()) >= self.embeddings.shape[0]):
            raise ValueError("context token id is outside the embedding table")
        return ids
|
|
|
|
def compare_selectors(
    embeddings: object,
    context_token_ids: Sequence[int] | object,
    query_token_ids: Sequence[int] | object,
    *,
    top_k: int = 64,
    hash_bits: int = 12,
    probe_radius: int = 1,
    seed: int = 2026,
) -> dict[str, object]:
    """Measure hashed-selector recall against exact selection, query by query.

    For each query the exact top-k positions serve as ground truth and the
    hashed selector's overlap with that set is recorded. An empty ground-truth
    set counts as perfect recall.

    Returns:
        A dict of run parameters plus mean/min recall@k over all queries.
    """
    _require_numpy()
    exact = AnalyticalSparseAttention(embeddings, k_neighbors=top_k)
    hashed = HashedSparseAttention(
        embeddings,
        k_neighbors=top_k,
        hash_bits=hash_bits,
        probe_radius=probe_radius,
        seed=seed,
    )
    token_ids = exact._coerce_token_ids(context_token_ids)
    queries = exact._coerce_token_ids(query_token_ids)
    hashed.build_context_index(token_ids)
    recalls: list[float] = []
    for raw_query in queries.tolist():
        query = int(raw_query)
        truth = set(exact.select_positions(query, token_ids, top_k=top_k).positions)
        approx = set(hashed.select_positions_cached(query, top_k=top_k).positions)
        recalls.append(1.0 if not truth else len(truth & approx) / len(truth))
    return {
        "context_tokens": int(token_ids.size),
        "query_count": int(queries.size),
        "top_k": int(top_k),
        "hash_bits": int(hash_bits),
        "probe_radius": int(probe_radius),
        "mean_recall_at_k": float(sum(recalls) / len(recalls)) if recalls else 0.0,
        "min_recall_at_k": float(min(recalls)) if recalls else 0.0,
    }
|
|
class HashedSparseAttention(AnalyticalSparseAttention):
    """Approximate sparse selector using deterministic random-hyperplane buckets.

    It keeps the analytical embedding-geometry rule, but avoids scanning the full
    context for every query. Buckets are built once from signs of fixed
    hyperplane projections; each query scans only matching buckets, then reranks
    the candidate set exactly by cosine similarity.
    """

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        hash_bits: int = 12,
        probe_radius: int = 1,
        seed: int = 2026,
        candidate_multiplier: int = 12,
    ) -> None:
        """Configure the LSH parameters and draw the fixed hyperplanes.

        Args:
            embeddings: rank-2 embedding table (see base class).
            k_neighbors: default top-k per query.
            hash_bits: number of hyperplanes, i.e. bits per bucket code.
            probe_radius: maximum Hamming distance of probed bucket codes.
            seed: RNG seed, so hyperplanes (and thus buckets) are deterministic.
            candidate_multiplier: scan at most ``k * candidate_multiplier``
                candidate positions before the exact rerank.
        """
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.hash_bits = max(1, int(hash_bits))
        self.probe_radius = max(0, int(probe_radius))
        self.candidate_multiplier = max(1, int(candidate_multiplier))
        rng = np.random.default_rng(int(seed))
        # Fixed Gaussian hyperplanes: the sign of each projection is one bit
        # of a position's bucket code.
        self.hyperplanes = rng.normal(
            size=(self.embedding_dim, self.hash_bits)
        ).astype(np.float32)
        self._bucket_positions: dict[int, list[int]] = {}

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context and group its positions into hash buckets."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = self.normalized_embeddings[token_ids]
        codes = self._codes_for_vectors(self._context_vectors)
        buckets: dict[int, list[int]] = {}
        for position, code in enumerate(codes.tolist()):
            buckets.setdefault(int(code), []).append(position)
        self._bucket_positions = buckets

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Approximate top-k: probe buckets, then rerank candidates exactly.

        Falls back to the base class's exact full scan whenever probing yields
        fewer than k candidates, so the result never shrinks below k.

        Raises:
            RuntimeError: if ``build_context_index()`` has not been called.
            ValueError: if ``query_token_id`` is outside the embedding table.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        k = min(self._context_token_ids.size, max(1, int(top_k or self.k_neighbors)))
        candidate_positions = self._candidate_positions(query_id, k)
        if len(candidate_positions) < k:
            # Too few bucket hits: defer to the exact scan for full recall.
            return super().select_positions_cached(query_id, top_k=top_k)
        positions = np.asarray(candidate_positions, dtype=np.int64)
        query_vector = self.normalized_embeddings[query_id]
        scores = self._context_vectors[positions] @ query_vector
        if k >= scores.size:
            selected_local = np.argsort(scores)[::-1]
        else:
            selected_local = np.argpartition(scores, -k)[-k:]
            selected_local = selected_local[np.argsort(scores[selected_local])[::-1]]
        selected_positions = positions[selected_local]
        # Single tolist() conversion reused for the score lookup.
        chosen_local = selected_local.tolist()
        return SparseSelection(
            positions=[int(index) for index in selected_positions.tolist()],
            scores=[float(scores[index]) for index in chosen_local],
        )

    def _candidate_positions(self, query_token_id: int, k: int) -> list[int]:
        """Collect up to ``k * candidate_multiplier`` positions from probed buckets."""
        query_vector = self.normalized_embeddings[int(query_token_id)].reshape(1, -1)
        query_code = int(self._codes_for_vectors(query_vector)[0])
        candidate_limit = max(k, k * self.candidate_multiplier)
        candidates: list[int] = []
        seen: set[int] = set()
        for code in self._probe_codes(query_code):
            for position in self._bucket_positions.get(code, []):
                if position in seen:
                    continue
                seen.add(position)
                candidates.append(position)
                if len(candidates) >= candidate_limit:
                    return candidates
        return candidates

    def _codes_for_vectors(self, vectors: object) -> object:
        """Hash each row of ``vectors`` to an integer bucket code (int64)."""
        projections = np.asarray(vectors, dtype=np.float32) @ self.hyperplanes
        bits = projections >= 0.0
        codes = np.zeros(bits.shape[0], dtype=np.int64)
        for bit_index in range(self.hash_bits):
            codes |= bits[:, bit_index].astype(np.int64) << bit_index
        return codes

    def _probe_codes(self, code: int) -> list[int]:
        """Bucket codes within Hamming distance ``probe_radius`` of ``code``.

        The exact code comes first, then distance-1 codes, then distance-2,
        and so on, so nearer buckets are scanned before the candidate budget
        runs out.  Unlike the previous implementation, which silently capped
        the radius at 2, every requested distance is enumerated; the number of
        codes grows as C(hash_bits, distance) per level.
        """
        base = int(code)
        codes = [base]
        for distance in range(1, self.probe_radius + 1):
            for bits in itertools.combinations(range(self.hash_bits), distance):
                flipped = base
                for bit in bits:
                    flipped ^= 1 << bit
                codes.append(flipped)
        return codes
|
|
|
|
class FaissSparseAttention(AnalyticalSparseAttention):
    """Native FAISS-backed sparse selector over normalized embedding geometry."""

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        approximate: bool = False,
        hnsw_neighbors: int = 32,
        ef_search: int = 64,
    ) -> None:
        """Configure exact (flat) or approximate (HNSW) inner-product search.

        Raises:
            RuntimeError: if the optional faiss dependency is missing.
        """
        if faiss is None:
            raise RuntimeError("faiss-cpu is not installed")
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.approximate = bool(approximate)
        self.hnsw_neighbors = max(4, int(hnsw_neighbors))
        # efSearch never drops below the default selection width.
        self.ef_search = max(int(k_neighbors), int(ef_search))
        self.index = self._new_index()

    def _new_index(self) -> object:
        """Build an empty FAISS index matching the configured mode."""
        if not self.approximate:
            # Exact inner product over unit vectors == cosine search.
            return faiss.IndexFlatIP(self.embedding_dim)
        hnsw = faiss.IndexHNSWFlat(
            self.embedding_dim,
            self.hnsw_neighbors,
            faiss.METRIC_INNER_PRODUCT,
        )
        hnsw.hnsw.efSearch = self.ef_search
        hnsw.hnsw.efConstruction = max(self.ef_search, self.hnsw_neighbors * 2)
        return hnsw

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context and load its normalized vectors into a fresh index."""
        ids = self._coerce_token_ids(context_token_ids)
        vectors = np.ascontiguousarray(
            self.normalized_embeddings[ids],
            dtype=np.float32,
        )
        self._context_token_ids = ids
        self._context_vectors = vectors
        self.index = self._new_index()
        self.index.add(vectors)

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Top-k context positions for one query via the prebuilt FAISS index.

        Raises:
            RuntimeError: if ``build_context_index()`` has not been called.
            ValueError: if ``query_token_id`` is outside the embedding table.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if not 0 <= query_id < self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        requested = max(1, int(top_k or self.k_neighbors))
        k = min(self._context_token_ids.size, requested)
        query = np.ascontiguousarray(
            self.normalized_embeddings[query_id].reshape(1, -1),
            dtype=np.float32,
        )
        scores, indices = self.index.search(query, k)
        # FAISS pads short result lists with -1; drop those slots.
        keep = indices[0] >= 0
        return SparseSelection(
            positions=[int(position) for position in indices[0][keep].tolist()],
            scores=[float(score) for score in scores[0][keep].tolist()],
        )
|