"""Analytical sparse-context selection kernels over corpus-derived embeddings.

Selectors pick context positions by embedding geometry (cosine similarity on
row-normalized embeddings), then aggregate only the selected states.  Three
implementations share one interface:

* ``AnalyticalSparseAttention`` -- exact scan over the full context.
* ``HashedSparseAttention``     -- random-hyperplane buckets + exact rerank.
* ``FaissSparseAttention``      -- optional FAISS backend (flat or HNSW).
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Sequence

try:  # pragma: no cover - exercised when NumPy is available in runtime envs.
    import numpy as np
except Exception:  # pragma: no cover
    np = None  # type: ignore[assignment]

try:  # pragma: no cover - optional native ANN backend.
    import faiss
except Exception:  # pragma: no cover
    faiss = None  # type: ignore[assignment]


@dataclass(frozen=True)
class SparseSelection:
    """Result of one sparse-selection query.

    ``positions`` are indices into the *context* (not the embedding table),
    ordered by descending score; ``scores`` are the cosine similarities
    aligned with ``positions``.
    """

    positions: list[int]
    scores: list[float]


def _require_numpy() -> None:
    """Raise ``RuntimeError`` if NumPy failed to import; every kernel needs it."""
    if np is None:
        raise RuntimeError("NumPy is required for the sparse-context kernel.")


def normalize_rows(matrix: object) -> object:
    """Return *matrix* (cast to float32) with each row scaled to unit L2 norm.

    Raises:
        ValueError: If the input is not rank-2.
        RuntimeError: If NumPy is unavailable.
    """
    _require_numpy()
    values = np.asarray(matrix, dtype=np.float32)
    if values.ndim != 2:
        raise ValueError("matrix must be rank-2")
    norms = np.linalg.norm(values, axis=1, keepdims=True)
    # Guard against zero rows; 1e-8 keeps the division finite.
    return values / np.maximum(norms, 1e-8)


def _top_k_descending(scores: object, k: int) -> object:
    """Indices of the ``k`` largest entries of ``scores``, best first.

    Full argsort when ``k`` covers everything; otherwise an O(n + k log k)
    ``argpartition`` followed by a sort of only the winners.  Shared by the
    exact and hashed selectors so their rerank semantics cannot drift apart.
    """
    if k >= scores.size:
        return np.argsort(scores)[::-1]
    winners = np.argpartition(scores, -k)[-k:]
    return winners[np.argsort(scores[winners])[::-1]]


class AnalyticalSparseAttention:
    """Content-dependent long-context selection from corpus-derived embeddings.

    This is Reframr's analytical sparse-context kernel: it selects positions
    by embedding geometry, then aggregates only the selected states. It does
    not contain task-specific answer strings or prompt-pattern shortcuts.
    """

    def __init__(self, embeddings: object, *, k_neighbors: int = 64) -> None:
        """Store the embedding table and precompute row-normalized vectors.

        Args:
            embeddings: Rank-2 array-like, one row per token id.
            k_neighbors: Default number of positions selected per query.

        Raises:
            ValueError: If ``embeddings`` is not rank-2.
            RuntimeError: If NumPy is unavailable.
        """
        _require_numpy()
        self.embeddings = np.asarray(embeddings, dtype=np.float32)
        if self.embeddings.ndim != 2:
            raise ValueError("embeddings must be rank-2")
        self.k_neighbors = max(1, int(k_neighbors))
        self.normalized_embeddings = normalize_rows(self.embeddings)
        # Populated by build_context_index(); consumed by the *_cached path.
        self._context_token_ids: object | None = None
        self._context_vectors: object | None = None

    @property
    def embedding_dim(self) -> int:
        """Width of one embedding row."""
        return int(self.embeddings.shape[1])

    def select_positions(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Select the top-k context positions for one query token (exact scan)."""
        token_ids = self._coerce_token_ids(context_token_ids)
        context_vectors = self.normalized_embeddings[token_ids]
        return self._select_positions_from_vectors(
            query_token_id,
            token_ids,
            context_vectors,
            top_k=top_k,
        )

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache a context so repeated queries skip the gather step."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = self.normalized_embeddings[token_ids]

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Like :meth:`select_positions` but against the cached context.

        Raises:
            RuntimeError: If no context has been cached yet.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        return self._select_positions_from_vectors(
            query_token_id,
            self._context_token_ids,
            self._context_vectors,
            top_k=top_k,
        )

    def _select_positions_from_vectors(
        self,
        query_token_id: int,
        token_ids: object,
        context_vectors: object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Score ``context_vectors`` against the query embedding and rank them."""
        if token_ids.size == 0:
            return SparseSelection(positions=[], scores=[])
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        k = min(token_ids.size, max(1, int(top_k or self.k_neighbors)))
        # Rows are unit-norm, so the dot product is cosine similarity.
        scores = context_vectors @ self.normalized_embeddings[query_id]
        selected = _top_k_descending(scores, k).tolist()
        return SparseSelection(
            positions=[int(index) for index in selected],
            scores=[float(scores[index]) for index in selected],
        )

    def sparse_output(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        context_states: object | None = None,
        *,
        top_k: int | None = None,
        temperature: float = 1.0,
    ) -> object:
        """Softmax-weighted aggregate of the states at the selected positions.

        Args:
            context_states: Optional rank-2 per-position states; defaults to
                the raw (un-normalized) embeddings of the context tokens.
            temperature: Softmax temperature; clamped to at least 1e-6.

        Returns:
            A float32 vector of width ``states.shape[1]`` (zeros when the
            context is empty).
        """
        token_ids = self._coerce_token_ids(context_token_ids)
        if context_states is None:
            states = self.embeddings[token_ids]
        else:
            states = np.asarray(context_states, dtype=np.float32)
            if states.ndim != 2 or states.shape[0] != token_ids.size:
                raise ValueError("context_states must be rank-2 and match context length")
        selection = self.select_positions(query_token_id, token_ids, top_k=top_k)
        if not selection.positions:
            return np.zeros(states.shape[1], dtype=np.float32)
        selected_states = states[np.asarray(selection.positions, dtype=np.int64)]
        scores = np.asarray(selection.scores, dtype=np.float32)
        scaled = scores / max(float(temperature), 1e-6)
        scaled -= float(scaled.max())  # shift for numerical stability
        weights = np.exp(scaled)
        weights /= max(float(weights.sum()), 1e-8)
        return weights @ selected_states

    def benchmark_selection(
        self,
        context_token_ids: Sequence[int] | object,
        query_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
        cache_context: bool = True,
    ) -> dict[str, object]:
        """Time selection over a batch of queries; returns a stats dict."""
        token_ids = self._coerce_token_ids(context_token_ids)
        queries = self._coerce_token_ids(query_token_ids)
        build_started = time.perf_counter()
        if cache_context:
            self.build_context_index(token_ids)
        build_elapsed = time.perf_counter() - build_started
        started = time.perf_counter()
        selected_total = 0
        for query_id in queries.tolist():
            if cache_context:
                selection = self.select_positions_cached(int(query_id), top_k=top_k)
            else:
                selection = self.select_positions(int(query_id), token_ids, top_k=top_k)
            selected_total += len(selection.positions)
        elapsed = time.perf_counter() - started
        return {
            "context_tokens": int(token_ids.size),
            "query_count": int(queries.size),
            "top_k": min(int(top_k or self.k_neighbors), int(token_ids.size)) if token_ids.size else 0,
            "selected_positions": int(selected_total),
            "cache_context": bool(cache_context),
            "index_build_seconds": build_elapsed,
            "seconds": elapsed,
            "queries_per_second": (float(queries.size) / elapsed) if elapsed > 0.0 else 0.0,
        }

    def _coerce_token_ids(self, token_ids: Sequence[int] | object) -> object:
        """Validate and convert token ids to a rank-1 int64 array.

        Raises:
            ValueError: If not rank-1 or any id is outside the table.
        """
        ids = np.asarray(token_ids, dtype=np.int64)
        if ids.ndim != 1:
            raise ValueError("token ids must be rank-1")
        if ids.size and (int(ids.min()) < 0 or int(ids.max()) >= self.embeddings.shape[0]):
            raise ValueError("context token id is outside the embedding table")
        return ids


class HashedSparseAttention(AnalyticalSparseAttention):
    """Approximate sparse selector using deterministic random-hyperplane buckets.

    It keeps the analytical embedding-geometry rule, but avoids scanning the
    full context for every query. Buckets are built once from signs of fixed
    hyperplane projections; each query scans only matching buckets, then
    reranks the candidate set exactly by cosine similarity.
    """

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        hash_bits: int = 12,
        probe_radius: int = 1,
        seed: int = 2026,
        candidate_multiplier: int = 12,
    ) -> None:
        """Create the selector and its fixed (seeded) hashing hyperplanes.

        Args:
            hash_bits: Number of sign bits per bucket code.
            probe_radius: Hamming radius of extra buckets probed per query
                (0, 1, or 2 are supported).
            seed: Seed for the deterministic hyperplane draw.
            candidate_multiplier: Candidate pool target as a multiple of k.
        """
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.hash_bits = max(1, int(hash_bits))
        self.probe_radius = max(0, int(probe_radius))
        self.candidate_multiplier = max(1, int(candidate_multiplier))
        rng = np.random.default_rng(int(seed))
        self.hyperplanes = rng.normal(
            size=(self.embedding_dim, self.hash_bits)
        ).astype(np.float32)
        # bucket code -> context positions (filled by build_context_index).
        self._bucket_positions: dict[int, list[int]] = {}

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context and bucket every position by its hash code."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = self.normalized_embeddings[token_ids]
        codes = self._codes_for_vectors(self._context_vectors)
        buckets: dict[int, list[int]] = {}
        for position, code in enumerate(codes.tolist()):
            buckets.setdefault(int(code), []).append(position)
        self._bucket_positions = buckets

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Bucket-probe for candidates, then rerank them exactly.

        Falls back to the exact full scan when the probed buckets yield
        fewer than k candidates, so results never degrade below k.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        k = min(self._context_token_ids.size, max(1, int(top_k or self.k_neighbors)))
        candidate_positions = self._candidate_positions(query_id, k)
        if len(candidate_positions) < k:
            return super().select_positions_cached(query_id, top_k=top_k)
        positions = np.asarray(candidate_positions, dtype=np.int64)
        query_vector = self.normalized_embeddings[query_id]
        scores = self._context_vectors[positions] @ query_vector
        selected_local = _top_k_descending(scores, k)
        selected_positions = positions[selected_local]
        return SparseSelection(
            positions=[int(index) for index in selected_positions.tolist()],
            scores=[float(scores[index]) for index in selected_local.tolist()],
        )

    def _candidate_positions(self, query_token_id: int, k: int) -> list[int]:
        """Gather up to ``k * candidate_multiplier`` positions from probed buckets."""
        query_vector = self.normalized_embeddings[int(query_token_id)].reshape(1, -1)
        query_code = int(self._codes_for_vectors(query_vector)[0])
        candidate_limit = max(k, k * self.candidate_multiplier)
        candidates: list[int] = []
        seen: set[int] = set()
        for code in self._probe_codes(query_code):
            for position in self._bucket_positions.get(code, []):
                if position in seen:
                    continue
                seen.add(position)
                candidates.append(position)
                if len(candidates) >= candidate_limit:
                    return candidates
        return candidates

    def _codes_for_vectors(self, vectors: object) -> object:
        """Hash each row to an integer code from the signs of its projections."""
        projections = np.asarray(vectors, dtype=np.float32) @ self.hyperplanes
        bits = projections >= 0.0
        codes = np.zeros(bits.shape[0], dtype=np.int64)
        for bit_index in range(self.hash_bits):
            codes |= bits[:, bit_index].astype(np.int64) << bit_index
        return codes

    def _probe_codes(self, code: int) -> list[int]:
        """The query bucket plus all buckets within the configured Hamming radius."""
        codes = [int(code)]
        if self.probe_radius >= 1:
            codes.extend(int(code) ^ (1 << bit) for bit in range(self.hash_bits))
        if self.probe_radius >= 2:
            for first in range(self.hash_bits):
                for second in range(first + 1, self.hash_bits):
                    codes.append(int(code) ^ (1 << first) ^ (1 << second))
        return codes


class FaissSparseAttention(AnalyticalSparseAttention):
    """Native FAISS-backed sparse selector over normalized embedding geometry."""

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        approximate: bool = False,
        hnsw_neighbors: int = 32,
        ef_search: int = 64,
    ) -> None:
        """Create a FAISS inner-product index (flat, or HNSW when approximate).

        Raises:
            RuntimeError: If the optional ``faiss`` dependency is missing.
        """
        if faiss is None:
            raise RuntimeError("faiss-cpu is not installed")
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.approximate = bool(approximate)
        self.hnsw_neighbors = max(4, int(hnsw_neighbors))
        # efSearch must cover at least k results or HNSW recall collapses.
        self.ef_search = max(int(k_neighbors), int(ef_search))
        self.index = self._new_index()

    def _new_index(self) -> object:
        """Build a fresh, empty index of the configured flavor."""
        if self.approximate:
            index = faiss.IndexHNSWFlat(
                self.embedding_dim,
                self.hnsw_neighbors,
                faiss.METRIC_INNER_PRODUCT,
            )
            index.hnsw.efSearch = self.ef_search
            index.hnsw.efConstruction = max(self.ef_search, self.hnsw_neighbors * 2)
            return index
        return faiss.IndexFlatIP(self.embedding_dim)

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context and (re)build the FAISS index over its vectors."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        # FAISS requires C-contiguous float32 input.
        self._context_vectors = np.ascontiguousarray(
            self.normalized_embeddings[token_ids],
            dtype=np.float32,
        )
        self.index = self._new_index()
        self.index.add(self._context_vectors)

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Query the FAISS index; drops the -1 padding ids FAISS may return."""
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        k = min(self._context_token_ids.size, max(1, int(top_k or self.k_neighbors)))
        query = np.ascontiguousarray(
            self.normalized_embeddings[query_id].reshape(1, -1),
            dtype=np.float32,
        )
        scores, indices = self.index.search(query, k)
        valid = indices[0] >= 0
        return SparseSelection(
            positions=[int(index) for index in indices[0][valid].tolist()],
            scores=[float(score) for score in scores[0][valid].tolist()],
        )


def compare_selectors(
    embeddings: object,
    context_token_ids: Sequence[int] | object,
    query_token_ids: Sequence[int] | object,
    *,
    top_k: int = 64,
    hash_bits: int = 12,
    probe_radius: int = 1,
    seed: int = 2026,
) -> dict[str, object]:
    """Measure hashed-selector recall@k against the exact selector.

    An empty exact selection counts as recall 1.0 (nothing was missed).

    Returns:
        Stats dict including ``mean_recall_at_k`` and ``min_recall_at_k``.
    """
    _require_numpy()
    exact = AnalyticalSparseAttention(embeddings, k_neighbors=top_k)
    hashed = HashedSparseAttention(
        embeddings,
        k_neighbors=top_k,
        hash_bits=hash_bits,
        probe_radius=probe_radius,
        seed=seed,
    )
    token_ids = exact._coerce_token_ids(context_token_ids)
    queries = exact._coerce_token_ids(query_token_ids)
    hashed.build_context_index(token_ids)
    recalls: list[float] = []
    for query_id in queries.tolist():
        exact_positions = set(exact.select_positions(int(query_id), token_ids, top_k=top_k).positions)
        hashed_positions = set(hashed.select_positions_cached(int(query_id), top_k=top_k).positions)
        if not exact_positions:
            recalls.append(1.0)
        else:
            recalls.append(len(exact_positions & hashed_positions) / len(exact_positions))
    return {
        "context_tokens": int(token_ids.size),
        "query_count": int(queries.size),
        "top_k": int(top_k),
        "hash_bits": int(hash_bits),
        "probe_radius": int(probe_radius),
        "mean_recall_at_k": float(sum(recalls) / len(recalls)) if recalls else 0.0,
        "min_recall_at_k": float(min(recalls)) if recalls else 0.0,
    }