File size: 16,129 Bytes
52da7b7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
from __future__ import annotations

from dataclasses import dataclass
import time
from typing import Sequence

try:  # pragma: no cover - exercised when NumPy is available in runtime envs.
    import numpy as np
except Exception:  # pragma: no cover
    np = None  # type: ignore[assignment]

try:  # pragma: no cover - optional native ANN backend.
    import faiss
except Exception:  # pragma: no cover
    faiss = None  # type: ignore[assignment]


@dataclass(frozen=True)
class SparseSelection:
    """Immutable result of one sparse selection query.

    The two lists are parallel: ``scores[i]`` is the similarity score that
    belongs to the context position ``positions[i]``.
    """

    # Indices into the context sequence that were selected.
    positions: list[int]
    # Similarity score of each selected position, in the same order.
    scores: list[float]


def _require_numpy() -> None:
    """Fail fast when the optional NumPy import did not succeed.

    Raises:
        RuntimeError: if the module-level ``np`` placeholder is ``None``.
    """
    if np is not None:
        return
    raise RuntimeError("NumPy is required for the sparse-context kernel.")


def normalize_rows(matrix: object) -> object:
    """Scale every row of a rank-2 array to unit Euclidean length.

    Args:
        matrix: Anything ``np.asarray`` can turn into a float32 array.

    Returns:
        A float32 array of the same shape whose rows are divided by their
        L2 norm.

    Raises:
        RuntimeError: if NumPy is unavailable.
        ValueError: if the converted input is not two-dimensional.
    """
    _require_numpy()
    rows = np.asarray(matrix, dtype=np.float32)
    if rows.ndim == 2:
        row_norms = np.linalg.norm(rows, axis=1, keepdims=True)
        # Floor the divisor so an all-zero row stays zero instead of
        # triggering a division by zero.
        safe_norms = np.maximum(row_norms, 1e-8)
        return rows / safe_norms
    raise ValueError("matrix must be rank-2")


class AnalyticalSparseAttention:
    """Exact top-k context selector driven purely by embedding geometry.

    For a query token, every context position is scored by the dot product
    of row-normalized embeddings (cosine similarity), the best ``k``
    positions are kept, and only those states are aggregated. The class
    holds no task-specific answer strings or prompt-pattern shortcuts.
    """

    def __init__(self, embeddings: object, *, k_neighbors: int = 64) -> None:
        """Store the embedding table and its row-normalized copy.

        Args:
            embeddings: Rank-2 table indexed by token id; converted to float32.
            k_neighbors: Default number of positions to select; values below
                1 are raised to 1.

        Raises:
            RuntimeError: if NumPy is unavailable.
            ValueError: if ``embeddings`` is not rank-2.
        """
        _require_numpy()
        self.embeddings = np.asarray(embeddings, dtype=np.float32)
        if self.embeddings.ndim != 2:
            raise ValueError("embeddings must be rank-2")
        requested_neighbors = int(k_neighbors)
        self.k_neighbors = max(1, requested_neighbors)
        self.normalized_embeddings = normalize_rows(self.embeddings)
        # Both stay None until build_context_index() has been called.
        self._context_token_ids: object | None = None
        self._context_vectors: object | None = None

    @property
    def embedding_dim(self) -> int:
        """Number of columns in the embedding table."""
        _, width = self.embeddings.shape
        return int(width)

    def select_positions(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Select the best context positions for a query without caching.

        The context ids are validated and their normalized vectors gathered
        on every call; see ``select_positions_cached`` for the reusable path.
        """
        ids = self._coerce_token_ids(context_token_ids)
        gathered = self.normalized_embeddings[ids]
        return self._select_positions_from_vectors(
            query_token_id, ids, gathered, top_k=top_k
        )

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Validate a context and cache its ids and normalized vectors."""
        ids = self._coerce_token_ids(context_token_ids)
        self._context_vectors = self.normalized_embeddings[ids]
        self._context_token_ids = ids

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Select positions against the context stored by build_context_index().

        Raises:
            RuntimeError: if no context has been indexed yet.
        """
        cached_ids = self._context_token_ids
        cached_vectors = self._context_vectors
        if cached_ids is None or cached_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        return self._select_positions_from_vectors(
            query_token_id, cached_ids, cached_vectors, top_k=top_k
        )

    def _select_positions_from_vectors(
        self,
        query_token_id: int,
        token_ids: object,
        context_vectors: object,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Score ``context_vectors`` against the query and keep the top ``k``.

        ``k`` is ``top_k`` when that is truthy, otherwise ``self.k_neighbors``;
        it is then raised to at least 1 and capped at the context length.
        Results are ordered by descending score.

        Raises:
            ValueError: if the context is non-empty and ``query_token_id`` is
                not a valid row of the embedding table.
        """
        context_size = token_ids.size
        if context_size == 0:
            # An empty context short-circuits before the query id is checked.
            return SparseSelection(positions=[], scores=[])
        query_id = int(query_token_id)
        vocab_size = self.normalized_embeddings.shape[0]
        if not 0 <= query_id < vocab_size:
            raise ValueError("query_token_id is outside the embedding table")
        requested = int(top_k or self.k_neighbors)
        k = min(context_size, max(1, requested))
        similarities = context_vectors @ self.normalized_embeddings[query_id]
        if k < similarities.size:
            # Partition first so only the k survivors need a full sort.
            shortlist = np.argpartition(similarities, -k)[-k:]
            order = shortlist[np.argsort(similarities[shortlist])[::-1]]
        else:
            order = np.argsort(similarities)[::-1]
        return SparseSelection(
            positions=order.tolist(),
            scores=similarities[order].tolist(),
        )

    def sparse_output(
        self,
        query_token_id: int,
        context_token_ids: Sequence[int] | object,
        context_states: object | None = None,
        *,
        top_k: int | None = None,
        temperature: float = 1.0,
    ) -> object:
        """Return the softmax-weighted mix of the selected context states.

        Args:
            query_token_id: Token whose embedding drives the selection.
            context_token_ids: Rank-1 token ids of the context.
            context_states: Optional rank-2 array with one row per context
                position; the raw embeddings of the context tokens are used
                when omitted.
            top_k: Per-call override of the number of selected positions.
            temperature: Softmax temperature, floored at 1e-6.

        Returns:
            A float32 vector with ``states.shape[1]`` entries; all zeros when
            nothing was selected.

        Raises:
            ValueError: if ``context_states`` is not rank-2 or its row count
                differs from the context length.
        """
        ids = self._coerce_token_ids(context_token_ids)
        if context_states is not None:
            states = np.asarray(context_states, dtype=np.float32)
            if not (states.ndim == 2 and states.shape[0] == ids.size):
                raise ValueError("context_states must be rank-2 and match context length")
        else:
            states = self.embeddings[ids]
        picked = self.select_positions(query_token_id, ids, top_k=top_k)
        if not picked.positions:
            return np.zeros(states.shape[1], dtype=np.float32)
        rows = states[np.asarray(picked.positions, dtype=np.int64)]
        logits = np.asarray(picked.scores, dtype=np.float32)
        logits = logits / max(float(temperature), 1e-6)
        # Shift by the maximum before exponentiating for numerical stability.
        logits -= float(logits.max())
        weights = np.exp(logits)
        weights /= max(float(weights.sum()), 1e-8)
        return weights @ rows

    def benchmark_selection(
        self,
        context_token_ids: Sequence[int] | object,
        query_token_ids: Sequence[int] | object,
        *,
        top_k: int | None = None,
        cache_context: bool = True,
    ) -> dict[str, object]:
        """Time selection for a batch of queries over one context.

        With ``cache_context`` true the context index is (re)built on this
        instance first, replacing any previously cached context, and the
        cached selection path is timed; otherwise the uncached path is timed.

        Returns:
            A dict with the context and query sizes, the reported ``top_k``,
            the total number of selected positions, index-build and query
            timings in seconds, and the resulting queries per second.
        """
        ids = self._coerce_token_ids(context_token_ids)
        queries = self._coerce_token_ids(query_token_ids)

        build_started = time.perf_counter()
        if cache_context:
            self.build_context_index(ids)
        build_elapsed = time.perf_counter() - build_started

        started = time.perf_counter()
        selected_total = 0
        for raw_query in queries.tolist():
            query_id = int(raw_query)
            selection = (
                self.select_positions_cached(query_id, top_k=top_k)
                if cache_context
                else self.select_positions(query_id, ids, top_k=top_k)
            )
            selected_total += len(selection.positions)
        elapsed = time.perf_counter() - started

        if ids.size:
            reported_k = min(int(top_k or self.k_neighbors), int(ids.size))
        else:
            reported_k = 0
        if elapsed > 0.0:
            throughput = float(queries.size) / elapsed
        else:
            throughput = 0.0
        return {
            "context_tokens": int(ids.size),
            "query_count": int(queries.size),
            "top_k": reported_k,
            "selected_positions": int(selected_total),
            "cache_context": bool(cache_context),
            "index_build_seconds": build_elapsed,
            "seconds": elapsed,
            "queries_per_second": throughput,
        }

    def _coerce_token_ids(self, token_ids: Sequence[int] | object) -> object:
        """Convert token ids to a validated rank-1 int64 array.

        Raises:
            ValueError: if the ids are not rank-1, or any id is negative or
                not a valid row of the embedding table.
        """
        ids = np.asarray(token_ids, dtype=np.int64)
        if ids.ndim != 1:
            raise ValueError("token ids must be rank-1")
        if ids.size == 0:
            # Nothing to range-check (min/max are undefined on empty arrays).
            return ids
        lowest, highest = int(ids.min()), int(ids.max())
        if lowest < 0 or highest >= self.embeddings.shape[0]:
            raise ValueError("context token id is outside the embedding table")
        return ids


def compare_selectors(
    embeddings: object,
    context_token_ids: Sequence[int] | object,
    query_token_ids: Sequence[int] | object,
    *,
    top_k: int = 64,
    hash_bits: int = 12,
    probe_radius: int = 1,
    seed: int = 2026,
) -> dict[str, object]:
    """Measure how well the hashed selector reproduces the exact top-k sets.

    An exact ``AnalyticalSparseAttention`` and a ``HashedSparseAttention``
    are built over the same embeddings. For each query, recall is the
    fraction of exactly selected positions that the hashed selector also
    returned (1.0 when the exact selection is empty).

    Args:
        embeddings: Embedding table shared by both selectors.
        context_token_ids: Rank-1 token ids forming the context.
        query_token_ids: Rank-1 token ids to query with.
        top_k: Number of positions requested from both selectors.
        hash_bits: Forwarded to ``HashedSparseAttention``.
        probe_radius: Forwarded to ``HashedSparseAttention``.
        seed: Forwarded to ``HashedSparseAttention``.

    Returns:
        A dict echoing the sizes and parameters plus the mean and minimum
        recall over all queries (both 0.0 when there are no queries).
    """
    _require_numpy()
    reference = AnalyticalSparseAttention(embeddings, k_neighbors=top_k)
    approximate = HashedSparseAttention(
        embeddings,
        k_neighbors=top_k,
        hash_bits=hash_bits,
        probe_radius=probe_radius,
        seed=seed,
    )
    context = reference._coerce_token_ids(context_token_ids)
    queries = reference._coerce_token_ids(query_token_ids)
    approximate.build_context_index(context)

    def recall_for(query_id: int) -> float:
        # Exact selection first, then the hashed one, compared as position sets.
        truth = set(reference.select_positions(query_id, context, top_k=top_k).positions)
        found = set(approximate.select_positions_cached(query_id, top_k=top_k).positions)
        if truth:
            return len(truth & found) / len(truth)
        return 1.0

    recalls: list[float] = [recall_for(int(raw_query)) for raw_query in queries.tolist()]
    if recalls:
        mean_recall = float(sum(recalls) / len(recalls))
        worst_recall = float(min(recalls))
    else:
        mean_recall = 0.0
        worst_recall = 0.0
    return {
        "context_tokens": int(context.size),
        "query_count": int(queries.size),
        "top_k": int(top_k),
        "hash_bits": int(hash_bits),
        "probe_radius": int(probe_radius),
        "mean_recall_at_k": mean_recall,
        "min_recall_at_k": worst_recall,
    }

class HashedSparseAttention(AnalyticalSparseAttention):
    """Approximate sparse selector using deterministic random-hyperplane buckets.

    It keeps the analytical embedding-geometry rule, but avoids scanning the full
    context for every query. Buckets are built once from signs of fixed
    hyperplane projections; each query scans only matching buckets, then reranks
    the candidate set exactly by cosine similarity.
    """

    # Bucket codes are packed into signed 64-bit integers, so only bits
    # 0..62 can be used without overflowing or disagreeing with the
    # Python-int XOR performed while probing.
    _MAX_HASH_BITS = 63

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        hash_bits: int = 12,
        probe_radius: int = 1,
        seed: int = 2026,
        candidate_multiplier: int = 12,
    ) -> None:
        """Create the selector and draw its fixed hyperplanes.

        Args:
            embeddings: Rank-2 embedding table indexed by token id.
            k_neighbors: Default number of positions to select.
            hash_bits: Number of hyperplanes (bits per bucket code); clamped
                to the range 1..63 so codes fit in an int64.
            probe_radius: Hamming radius of neighbouring buckets to probe;
                negative values become 0 and values above 2 behave like 2.
            seed: Seed for the hyperplane generator, making buckets
                reproducible.
            candidate_multiplier: Candidates gathered per requested result
                before exact reranking; values below 1 are raised to 1.
        """
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.hash_bits = min(self._MAX_HASH_BITS, max(1, int(hash_bits)))
        self.probe_radius = max(0, int(probe_radius))
        self.candidate_multiplier = max(1, int(candidate_multiplier))
        rng = np.random.default_rng(int(seed))
        self.hyperplanes = rng.normal(
            size=(self.embedding_dim, self.hash_bits)
        ).astype(np.float32)
        # Maps a bucket code to the context positions hashed into it.
        self._bucket_positions: dict[int, list[int]] = {}

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Cache the context vectors and group positions by bucket code."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = self.normalized_embeddings[token_ids]
        codes = self._codes_for_vectors(self._context_vectors)
        buckets: dict[int, list[int]] = {}
        for position, code in enumerate(codes.tolist()):
            buckets.setdefault(code, []).append(position)
        self._bucket_positions = buckets

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Select positions from probed buckets, reranked by exact similarity.

        When the probed buckets hold fewer than ``k`` candidates, the exact
        full-context selection of the base class is used instead.

        Raises:
            RuntimeError: if no context has been indexed yet.
            ValueError: if ``query_token_id`` is not a valid embedding row.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        context_size = int(self._context_token_ids.size)
        if context_size == 0:
            # Nothing indexed: state the empty result explicitly rather than
            # relying on zero-length array arithmetic below.
            return SparseSelection(positions=[], scores=[])
        k = min(context_size, max(1, int(top_k or self.k_neighbors)))
        candidate_positions = self._candidate_positions(query_id, k)
        if len(candidate_positions) < k:
            # Too few candidates to fill k slots; fall back to the exact scan.
            return super().select_positions_cached(query_id, top_k=top_k)
        positions = np.asarray(candidate_positions, dtype=np.int64)
        query_vector = self.normalized_embeddings[query_id]
        scores = self._context_vectors[positions] @ query_vector
        if k >= scores.size:
            selected_local = np.argsort(scores)[::-1]
        else:
            # Partition first so only the k survivors need a full sort.
            selected_local = np.argpartition(scores, -k)[-k:]
            selected_local = selected_local[np.argsort(scores[selected_local])[::-1]]
        return SparseSelection(
            positions=positions[selected_local].tolist(),
            scores=scores[selected_local].tolist(),
        )

    def _candidate_positions(self, query_token_id: int, k: int) -> list[int]:
        """Collect up to ``k * candidate_multiplier`` positions from probed buckets.

        Buckets are visited in the order produced by ``_probe_codes``. Every
        position lives in exactly one bucket and the probe codes are pairwise
        distinct, so no position can be collected twice.
        """
        query_vector = self.normalized_embeddings[int(query_token_id)].reshape(1, -1)
        query_code = int(self._codes_for_vectors(query_vector)[0])
        candidate_limit = max(k, k * self.candidate_multiplier)
        candidates: list[int] = []
        for code in self._probe_codes(query_code):
            bucket = self._bucket_positions.get(code)
            if not bucket:
                continue
            candidates.extend(bucket[: candidate_limit - len(candidates)])
            if len(candidates) >= candidate_limit:
                break
        return candidates

    def _codes_for_vectors(self, vectors: object) -> object:
        """Return one int64 bucket code per row of ``vectors``.

        Bit ``i`` of a code is set when the row's projection onto hyperplane
        ``i`` is non-negative.
        """
        projections = np.asarray(vectors, dtype=np.float32) @ self.hyperplanes
        bits = (projections >= 0.0).astype(np.int64)
        # Weighted sum of the bits with powers of two packs them into a code.
        bit_weights = np.left_shift(
            np.int64(1), np.arange(self.hash_bits, dtype=np.int64)
        )
        return bits @ bit_weights

    def _probe_codes(self, code: int) -> list[int]:
        """List the bucket codes to scan: exact match first, then neighbours.

        Radius 1 adds every code that differs in one bit; radius 2 also adds
        every code that differs in two bits. Larger radii add nothing more.
        """
        codes = [int(code)]
        if self.probe_radius >= 1:
            codes.extend(int(code) ^ (1 << bit) for bit in range(self.hash_bits))
        if self.probe_radius >= 2:
            for first in range(self.hash_bits):
                for second in range(first + 1, self.hash_bits):
                    codes.append(int(code) ^ (1 << first) ^ (1 << second))
        return codes


class FaissSparseAttention(AnalyticalSparseAttention):
    """Native FAISS-backed sparse selector over normalized embedding geometry.

    Context vectors are the row-normalized embeddings, so the inner-product
    indexes used here rank by cosine similarity.
    """

    def __init__(
        self,
        embeddings: object,
        *,
        k_neighbors: int = 64,
        approximate: bool = False,
        hnsw_neighbors: int = 32,
        ef_search: int = 64,
    ) -> None:
        """Create the selector with an empty FAISS index.

        Args:
            embeddings: Rank-2 embedding table indexed by token id.
            k_neighbors: Default number of positions to select.
            approximate: Use an HNSW index instead of an exact flat index.
            hnsw_neighbors: HNSW graph degree; raised to at least 4.
            ef_search: HNSW search breadth; raised to at least ``k_neighbors``.

        Raises:
            RuntimeError: if the optional ``faiss`` import failed.
        """
        if faiss is None:
            raise RuntimeError("faiss-cpu is not installed")
        super().__init__(embeddings, k_neighbors=k_neighbors)
        self.approximate = bool(approximate)
        self.hnsw_neighbors = max(4, int(hnsw_neighbors))
        self.ef_search = max(int(k_neighbors), int(ef_search))
        self.index = self._new_index()

    def _new_index(self) -> object:
        """Build a fresh, empty inner-product index of the configured kind."""
        if self.approximate:
            index = faiss.IndexHNSWFlat(
                self.embedding_dim,
                self.hnsw_neighbors,
                faiss.METRIC_INNER_PRODUCT,
            )
            index.hnsw.efSearch = self.ef_search
            index.hnsw.efConstruction = max(self.ef_search, self.hnsw_neighbors * 2)
            return index
        return faiss.IndexFlatIP(self.embedding_dim)

    def build_context_index(self, context_token_ids: Sequence[int] | object) -> None:
        """Replace the index with one holding the given context's vectors."""
        token_ids = self._coerce_token_ids(context_token_ids)
        self._context_token_ids = token_ids
        self._context_vectors = np.ascontiguousarray(
            self.normalized_embeddings[token_ids],
            dtype=np.float32,
        )
        self.index = self._new_index()
        if token_ids.size:
            # An empty context leaves the fresh index empty; queries are
            # answered without touching FAISS in that case.
            self.index.add(self._context_vectors)

    def select_positions_cached(
        self,
        query_token_id: int,
        *,
        top_k: int | None = None,
    ) -> SparseSelection:
        """Return the FAISS top-k context positions for a query token.

        An empty indexed context yields an empty selection, matching the
        other selectors, instead of asking FAISS for ``k == 0`` neighbours.

        Raises:
            RuntimeError: if no context has been indexed yet.
            ValueError: if ``query_token_id`` is not a valid embedding row.
        """
        if self._context_token_ids is None or self._context_vectors is None:
            raise RuntimeError("call build_context_index() before select_positions_cached()")
        query_id = int(query_token_id)
        if query_id < 0 or query_id >= self.normalized_embeddings.shape[0]:
            raise ValueError("query_token_id is outside the embedding table")
        context_size = int(self._context_token_ids.size)
        if context_size == 0:
            # The FAISS search wrapper requires k > 0, so short-circuit here.
            return SparseSelection(positions=[], scores=[])
        k = min(context_size, max(1, int(top_k or self.k_neighbors)))
        query = np.ascontiguousarray(
            self.normalized_embeddings[query_id].reshape(1, -1),
            dtype=np.float32,
        )
        scores, indices = self.index.search(query, k)
        # Negative labels are dropped; presumably FAISS uses them to pad the
        # result when fewer than k neighbours are found (confirm in its docs).
        valid = indices[0] >= 0
        return SparseSelection(
            positions=indices[0][valid].tolist(),
            scores=scores[0][valid].tolist(),
        )