# Reframr-RFM-v2-Base — reframr/sparse_context.py
# Added in release commit 52da7b7 (OkeyMeta). Analytical sparse-context
# selection kernels: exact, hashed (LSH), and optional FAISS-backed variants.
from __future__ import annotations
from dataclasses import dataclass
import time
from typing import Sequence
try: # pragma: no cover - exercised when NumPy is available in runtime envs.
import numpy as np
except Exception: # pragma: no cover
np = None # type: ignore[assignment]
try: # pragma: no cover - optional native ANN backend.
import faiss
except Exception: # pragma: no cover
faiss = None # type: ignore[assignment]
@dataclass(frozen=True)
class SparseSelection:
    """Immutable result of one sparse-context lookup.

    As produced by the selectors in this module, both lists are parallel and
    ordered best score first.
    """

    # Indices into the supplied context sequence (positions, not token ids).
    positions: list[int]
    # Similarity score of each selected position against the query embedding.
    scores: list[float]
def _require_numpy() -> None:
if np is None:
raise RuntimeError("NumPy is required for the sparse-context kernel.")
def normalize_rows(matrix: object) -> object:
    """Return *matrix* as float32 with every row scaled to unit L2 norm.

    Near-zero rows stay (near) zero: the divisor is floored at 1e-8 instead of
    dividing by zero.

    Raises:
        RuntimeError: when NumPy is unavailable.
        ValueError: when the input is not rank-2.
    """
    # Inlined NumPy guard (same exception type and message as _require_numpy).
    if np is None:
        raise RuntimeError("NumPy is required for the sparse-context kernel.")
    rows = np.asarray(matrix, dtype=np.float32)
    if rows.ndim != 2:
        raise ValueError("matrix must be rank-2")
    lengths = np.linalg.norm(rows, axis=1, keepdims=True)
    return rows / np.maximum(lengths, 1e-8)
class AnalyticalSparseAttention:
    """Content-dependent long-context selection from corpus-derived embeddings.

    This is Reframr's analytical sparse-context kernel: positions are chosen
    purely by embedding geometry (similarity of unit-normalized rows), and only
    the chosen states are aggregated. It contains no task-specific answer
    strings or prompt-pattern shortcuts.
    """

    def __init__(self, embeddings: object, *, k_neighbors: int = 64) -> None:
        """Store the embedding table and precompute its unit-norm rows.

        Raises:
            RuntimeError: when NumPy is unavailable.
            ValueError: when *embeddings* is not rank-2.
        """
        _require_numpy()
        self.embeddings = np.asarray(embeddings, dtype=np.float32)
        if self.embeddings.ndim != 2:
            raise ValueError("embeddings must be rank-2")
        self.k_neighbors = max(1, int(k_neighbors))
        self.normalized_embeddings = normalize_rows(self.embeddings)
        # Cache slots written by build_context_index(); None until then.
        self._context_token_ids: object | None = None
        self._context_vectors: object | None = None

    @property
    def embedding_dim(self) -> int:
        """Width of a single embedding row."""
        return int(self.embeddings.shape[1])

    def select_positions(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Rank the given context against one query token and keep the best positions."""
        ids = self._coerce_token_ids(context_token_ids)
        return self._select_positions_from_vectors(
            query_token_id,
            ids,
            self.normalized_embeddings[ids],
            top_k=top_k,
        )

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache a context so select_positions_cached() can skip re-gathering vectors."""
        ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = ids
        self._context_vectors = self.normalized_embeddings[ids]

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Like select_positions(), but over the context cached earlier.

        Raises:
            RuntimeError: when build_context_index() has not been called.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        return self._select_positions_from_vectors(
            query_token_id,
            self._context_token_ids,
            self._context_vectors,
            top_k=top_k,
        )

    def _select_positions_from_vectors(
        self,
        query_token_id: int,
        token_ids: object,
        context_vectors: object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Score *context_vectors* against the query row; return top-k, best first."""
        if token_ids.size == 0:
            return SparseSelection(positions=[], scores=[])
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        budget = min(token_ids.size, max(1, int(top_k or self.k_neighbors)))
        similarities = context_vectors @ self.normalized_embeddings[query_id]
        if budget >= similarities.size:
            ranked = np.argsort(similarities)[::-1]
        else:
            # argpartition isolates the top slice in O(n); only that slice is sorted.
            ranked = np.argpartition(similarities, -budget)[-budget:]
            ranked = ranked[np.argsort(similarities[ranked])[::-1]]
        chosen = ranked.tolist()
        return SparseSelection(
            positions=[int(position) for position in chosen],
            scores=[float(similarities[position]) for position in chosen],
        )

    def sparse_output(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        context_states: object | None = None,
        *,
        top_k: int | None = None,
        temperature: float = 1.0,
    ) -> object:
        """Softmax-weighted aggregate of the selected context states.

        When *context_states* is omitted, the raw embedding rows of the context
        tokens are aggregated instead. Returns a zero vector when nothing is
        selected.
        """
        ids = self._coerce_token_ids(context_token_ids)
        if context_states is None:
            states = self.embeddings[ids]
        else:
            states = np.asarray(context_states, dtype=np.float32)
            if states.ndim != 2 or states.shape[0] != ids.size:
                raise ValueError("context_states must be rank-2 and match context length")
        selection = self.select_positions(query_token_id, ids, top_k=top_k)
        if not selection.positions:
            return np.zeros(states.shape[1], dtype=np.float32)
        picked = states[np.asarray(selection.positions, dtype=np.int64)]
        # Temperature-scaled, max-shifted softmax; both divisors are floored
        # to keep the arithmetic finite.
        logits = np.asarray(selection.scores, dtype=np.float32)
        logits = logits / max(float(temperature), 1e-6)
        logits -= float(logits.max())
        weights = np.exp(logits)
        weights /= max(float(weights.sum()), 1e-8)
        return weights @ picked

    def benchmark_selection(
        self,
        context_token_ids: Sequence[int] | object,
        query_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
        cache_context: bool = True,
    ) -> dict[str, object]:
        """Time the selection loop and report throughput statistics."""
        ids = self._coerce_token_ids(context_token_ids)
        query_ids = self._coerce_token_ids(query_token_ids)
        build_start = time.perf_counter()
        if cache_context:
            self.build_context_index(ids)
        build_seconds = time.perf_counter() - build_start
        loop_start = time.perf_counter()
        total_selected = 0
        for raw_query in query_ids.tolist():
            if cache_context:
                picked = self.select_positions_cached(int(raw_query), top_k=top_k)
            else:
                picked = self.select_positions(int(raw_query), ids, top_k=top_k)
            total_selected += len(picked.positions)
        loop_seconds = time.perf_counter() - loop_start
        return {
            "context_tokens": int(ids.size),
            "query_count": int(query_ids.size),
            "top_k": min(int(top_k or self.k_neighbors), int(ids.size)) if ids.size else 0,
            "selected_positions": int(total_selected),
            "cache_context": bool(cache_context),
            "index_build_seconds": build_seconds,
            "seconds": loop_seconds,
            "queries_per_second": (float(query_ids.size) / loop_seconds) if loop_seconds > 0.0 else 0.0,
        }

    def _coerce_token_ids(self, token_ids: Sequence[int] | object) -> object:
        """Validate token ids as a rank-1 int64 array within the table bounds."""
        ids = np.asarray(token_ids, dtype=np.int64)
        if ids.ndim != 1:
            raise ValueError("token ids must be rank-1")
        if ids.size and (int(ids.min()) < 0 or int(ids.max()) >= self.embeddings.shape[0]):
            raise ValueError("context token id is outside the embedding table")
        return ids
def compare_selectors(
    embeddings: object,
    context_token_ids: Sequence[int] | object,
    query_token_ids: Sequence[int] | object,
    *,
    top_k: int = 64,
    hash_bits: int = 12,
    probe_radius: int = 1,
    seed: int = 2026,
) -> dict[str, object]:
    """Measure hashed-selector recall against the exact selector.

    Every query is run through both AnalyticalSparseAttention (ground truth)
    and HashedSparseAttention; the report carries mean and minimum recall@k of
    the hashed positions versus the exact positions.
    """
    _require_numpy()
    exact_selector = AnalyticalSparseAttention(embeddings, k_neighbors=top_k)
    hashed_selector = HashedSparseAttention(
        embeddings,
        k_neighbors=top_k,
        hash_bits=hash_bits,
        probe_radius=probe_radius,
        seed=seed,
    )
    context_ids = exact_selector._coerce_token_ids(context_token_ids)
    query_ids = exact_selector._coerce_token_ids(query_token_ids)
    hashed_selector.build_context_index(context_ids)
    recall_values: list[float] = []
    for raw_query in query_ids.tolist():
        truth = set(
            exact_selector.select_positions(int(raw_query), context_ids, top_k=top_k).positions
        )
        guess = set(
            hashed_selector.select_positions_cached(int(raw_query), top_k=top_k).positions
        )
        # An empty ground-truth selection counts as perfect recall.
        recall_values.append(1.0 if not truth else len(truth & guess) / len(truth))
    return {
        "context_tokens": int(context_ids.size),
        "query_count": int(query_ids.size),
        "top_k": int(top_k),
        "hash_bits": int(hash_bits),
        "probe_radius": int(probe_radius),
        "mean_recall_at_k": float(sum(recall_values) / len(recall_values)) if recall_values else 0.0,
        "min_recall_at_k": float(min(recall_values)) if recall_values else 0.0,
    }
class HashedSparseAttention(AnalyticalSparseAttention):
    """Approximate sparse selector using deterministic random-hyperplane buckets.

    It keeps the analytical embedding-geometry rule, but avoids scanning the full
    context for every query. Buckets are built once from signs of fixed
    hyperplane projections; each query scans only matching buckets, then reranks
    the candidate set exactly by cosine similarity.
    """

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        hash_bits: int = 12,
        probe_radius: int = 1,
        seed: int = 2026,
        candidate_multiplier: int = 12,
    ) -> None:
        """Build the hashed selector.

        Args:
            embeddings: rank-2 embedding table (rows are token vectors).
            k_neighbors: default number of positions kept per query.
            hash_bits: number of hyperplanes, i.e. bits per bucket code.
            probe_radius: Hamming distance of extra buckets probed per query
                (0 = exact bucket only; values above 2 behave like 2 — see
                _probe_codes).
            seed: RNG seed for the hyperplanes; fixed so hashing is deterministic.
            candidate_multiplier: candidate-pool budget as a multiple of k.
        """
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.hash_bits = max(1, int(hash_bits))
        self.probe_radius = max(0, int(probe_radius))
        self.candidate_multiplier = max(1, int(candidate_multiplier))
        # Fixed Gaussian hyperplanes: the sign of each projection is one hash bit.
        rng = np.random.default_rng(int(seed))
        self.hyperplanes = rng.normal(
            size=(self.embedding_dim, self.hash_bits)
        ).astype(np.float32)
        # bucket code -> context positions; filled by build_context_index().
        self._bucket_positions: dict[int, list[int]] = {}

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context and group its positions into hash buckets."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = self.normalized_embeddings[token_ids]
        codes = self._codes_for_vectors(self._context_vectors)
        # Positions are appended in context order, so each bucket preserves it.
        buckets: dict[int, list[int]] = {}
        for position, code in enumerate(codes.tolist()):
            buckets.setdefault(int(code), []).append(position)
        self._bucket_positions = buckets

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Approximate top-k via bucket probing, reranked exactly by similarity.

        Falls back to the exact base-class scan when probing yields fewer than
        k candidates, so a full-size selection is always returned.

        Raises:
            RuntimeError: when build_context_index() has not been called.
            ValueError: when the query id is outside the embedding table.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        k = min(self._context_token_ids.size, max(1, int(top_k or self.k_neighbors)))
        candidate_positions = self._candidate_positions(query_id, k)
        if len(candidate_positions) < k:
            # Too few bucket hits: the exact scan guarantees k results.
            return super().select_positions_cached(query_id, top_k=top_k)
        positions = np.asarray(candidate_positions, dtype=np.int64)
        query_vector = self.normalized_embeddings[query_id]
        scores = self._context_vectors[positions] @ query_vector
        # Exact rerank of the candidate pool, descending by similarity.
        if k >= scores.size:
            selected_local = np.argsort(scores)[::-1]
        else:
            selected_local = np.argpartition(scores, -k)[-k:]
            selected_local = selected_local[np.argsort(scores[selected_local])[::-1]]
        selected_positions = positions[selected_local]
        return SparseSelection(
            positions=[int(index) for index in selected_positions.tolist()],
            scores=[float(scores[index]) for index in selected_local.tolist()],
        )

    def _candidate_positions(self, query_token_id: int, k: int) -> list[int]:
        """Collect up to max(k, k * candidate_multiplier) candidate positions.

        Buckets are visited exact-code first, then at increasing Hamming
        distance (see _probe_codes); collection stops mid-iteration once the
        budget is reached, so closer buckets take priority.
        """
        query_vector = self.normalized_embeddings[int(query_token_id)].reshape(1, -1)
        query_code = int(self._codes_for_vectors(query_vector)[0])
        candidate_limit = max(k, k * self.candidate_multiplier)
        candidates: list[int] = []
        # `seen` is defensive de-duplication; each position lives in one bucket,
        # and probe codes appear unique, so duplicates are not expected here.
        seen: set[int] = set()
        for code in self._probe_codes(query_code):
            for position in self._bucket_positions.get(code, []):
                if position in seen:
                    continue
                seen.add(position)
                candidates.append(position)
                if len(candidates) >= candidate_limit:
                    return candidates
        return candidates

    def _codes_for_vectors(self, vectors: object) -> object:
        """Hash each row to an int64 code from the signs of its projections."""
        projections = np.asarray(vectors, dtype=np.float32) @ self.hyperplanes
        bits = projections >= 0.0
        # Pack the sign bits little-endian: hyperplane i contributes bit i.
        codes = np.zeros(bits.shape[0], dtype=np.int64)
        for bit_index in range(self.hash_bits):
            codes |= bits[:, bit_index].astype(np.int64) << bit_index
        return codes

    def _probe_codes(self, code: int) -> list[int]:
        """Return bucket codes to scan: exact code, then Hamming-1, then Hamming-2.

        probe_radius values above 2 are not expanded further (they behave like 2).
        """
        codes = [int(code)]
        if self.probe_radius >= 1:
            codes.extend(int(code) ^ (1 << bit) for bit in range(self.hash_bits))
        if self.probe_radius >= 2:
            for first in range(self.hash_bits):
                for second in range(first + 1, self.hash_bits):
                    codes.append(int(code) ^ (1 << first) ^ (1 << second))
        return codes
class FaissSparseAttention(AnalyticalSparseAttention):
    """Native FAISS-backed sparse selector over normalized embedding geometry.

    Uses inner-product search, which on unit-norm rows equals cosine
    similarity: an exact IndexFlatIP by default, or an approximate HNSW index
    when requested.
    """

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        approximate: bool = False,
        hnsw_neighbors: int = 32,
        ef_search: int = 64,
    ) -> None:
        """Create the selector and an empty index.

        Raises:
            RuntimeError: when the optional faiss import is unavailable.
        """
        if faiss is None:
            raise RuntimeError("faiss-cpu is not installed")
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.approximate = bool(approximate)
        # HNSW knobs: graph degree floor of 4, search breadth at least k.
        self.hnsw_neighbors = max(4, int(hnsw_neighbors))
        self.ef_search = max(int(k_neighbors), int(ef_search))
        self.index = self._new_index()

    def _new_index(self) -> object:
        """Build an empty FAISS inner-product index (flat exact, or HNSW)."""
        if not self.approximate:
            return faiss.IndexFlatIP(self.embedding_dim)
        hnsw_index = faiss.IndexHNSWFlat(
            self.embedding_dim,
            self.hnsw_neighbors,
            faiss.METRIC_INNER_PRODUCT,
        )
        hnsw_index.hnsw.efSearch = self.ef_search
        hnsw_index.hnsw.efConstruction = max(self.ef_search, self.hnsw_neighbors * 2)
        return hnsw_index

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context and load its vectors into a fresh FAISS index."""
        ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = ids
        # FAISS requires contiguous float32 input.
        self._context_vectors = np.ascontiguousarray(
            self.normalized_embeddings[ids],
            dtype=np.float32,
        )
        self.index = self._new_index()
        self.index.add(self._context_vectors)

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Search the FAISS index for the top context positions for one query.

        Raises:
            RuntimeError: when build_context_index() has not been called.
            ValueError: when the query id is outside the embedding table.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        budget = min(self._context_token_ids.size, max(1, int(top_k or self.k_neighbors)))
        query_matrix = np.ascontiguousarray(
            self.normalized_embeddings[query_id].reshape(1, -1),
            dtype=np.float32,
        )
        distances, neighbor_ids = self.index.search(query_matrix, budget)
        # FAISS pads missing neighbors with -1; keep only real hits.
        keep = neighbor_ids[0] >= 0
        return SparseSelection(
            positions=[int(position) for position in neighbor_ids[0][keep].tolist()],
            scores=[float(score) for score in distances[0][keep].tolist()],
        )