"""
Engrammatic Geometry Retrieval: State Extraction Layer

Extracts a retrieval state vector from a KV cache tensor for MIPS-based
retrieval in EGR (Engrammatic Geometry Retrieval). The state vector is
a compact geometric fingerprint of a cognitive state, positioned in the
model's own pre-RoPE key manifold for geometrically consistent retrieval.

Three extraction modes:

  mean_pool:   Fast baseline. Mean over heads + context of key matrices
               across extraction layers. Output: [head_dim]. No learned
               parameters. Use for bootstrapping and smoke tests.

  svd_project: Truncated SVD on pre-RoPE keys, extraction layers (D3: 8-31),
               rank-160 for 8B models. Validated by ShadowKV (ICML 2025,
               ByteDance) on Llama-3.1-8B and Phi-3-Mini-128K.
               Output: [rank]. Projection is prompt-dependent: W is computed
               per cache via online SVD, not precomputed globally.
               Reference: github.com/ByteDance-Seed/ShadowKV

  xkv_project: Grouped cross-layer SVD. Groups 4 adjacent extraction layers,
               extracts shared basis vectors across the group. Achieves
               6.8x compression vs 2.5x single-layer SVD. K:V rank ratio
               1:1.5 is optimal per xKV paper.
               Reference: github.com/abdelfattah-lab/xKV
               arXiv:2503.18893
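
Illustrative sketch (not part of this module's API) of how an
xkv_project-style rank budget could be split across groups of adjacent
layers. The helper name xkv_output_dim and its default arguments are
hypothetical. The arithmetic assumes the budget is divided evenly per full
group and that leftover layers form one extra group.

```python
def xkv_output_dim(n_layers: int, rank: int, group_size: int = 4, head_dim: int = 128) -> int:
    # Hypothetical helper: length of the concatenated per-group state vector.
    n_groups = max(1, n_layers // group_size)  # number of full groups, at least one
    rank_per_group = min(max(1, rank // n_groups), head_dim)
    # Layers left over after the full groups form one extra group.
    has_remainder = n_groups * group_size < n_layers
    return (n_groups + int(has_remainder)) * rank_per_group


print(xkv_output_dim(24, 160))  # 6 groups x 26 = 156
print(xkv_output_dim(10, 160))  # 2 full groups + 1 remainder group, 3 x 80 = 240
```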

REMOVED: sals_project. Last-layer-only extraction was invalidated by
Layer-Condensed KV Cache (ACL 2024). See D3.

D4: No L2 normalization. True MIPS. L2 norm stored as metadata for
    optional downstream use.
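
Illustrative sketch of why D4 matters: under maximum inner product search
the ranking depends on vector length as well as direction, so it can differ
from a cosine ranking. The vectors and names below are made up for
illustration and use only the standard library.

```python
import math


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine(a, b):
    return dot(a, b) / (math.sqrt(dot(a, a)) * math.sqrt(dot(b, b)))


query = [1.0, 0.0]
candidates = {"long_off_axis": [3.0, 3.0], "short_aligned": [1.0, 0.0]}

best_by_ip = max(candidates, key=lambda name: dot(query, candidates[name]))
best_by_cos = max(candidates, key=lambda name: cosine(query, candidates[name]))
print(best_by_ip)   # long_off_axis (inner product 3.0 vs 1.0)
print(best_by_cos)  # short_aligned (cosine 1.0 vs about 0.707)
```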
"""

from __future__ import annotations

from dataclasses import dataclass, field

import torch
from einops import rearrange

from kvcos.core.types import (
    DEFAULT_SVD_RANK,
    ModelCacheSpec,
    StateExtractionMode,
)


@dataclass
class ExtractionResult:
    """Result of state vector extraction from a KV cache."""

    state_vec: torch.Tensor  # [d_out], the retrieval vector
    l2_norm: float  # stored as metadata per D4
    mode: StateExtractionMode
    n_layers_used: int
    n_tokens: int


@dataclass
class SVDProjection:
    """Learned SVD projection matrix for a specific cache.

    ShadowKV finding: pre-RoPE keys share low-rank subspaces WITHIN
    sequences but differ ACROSS sequences. Projection must be computed
    online per cache, not precomputed globally.
    """

    W: torch.Tensor  # [head_dim, rank], right singular vectors
    singular_values: torch.Tensor  # [rank], for diagnostics
    explained_variance_ratio: float  # fraction of variance captured
    source_shape: tuple[int, ...]  # shape of the keys used to compute this


class MARStateExtractor:
    """Extracts retrieval state vectors from KV cache tensors for EGR.

    Usage:
        extractor = MARStateExtractor(mode=StateExtractionMode.SVD_PROJECT, rank=160)
        result = extractor.extract(keys, spec)
        # result.state_vec is the retrieval vector for FAISS IndexFlatIP
        # result.l2_norm goes into .eng metadata (D4)
    """

    # Max rows fed to SVD. 8192 rows on a 128-dim matrix runs in ~15ms
    # vs ~2000ms for the full 786K-row matrix. Subspace quality is
    # preserved because SVD only needs O(head_dim²) samples to recover
    # the top singular vectors of a low-rank matrix.
    MAX_SVD_ROWS: int = 8192

    def __init__(
        self,
        mode: StateExtractionMode = StateExtractionMode.SVD_PROJECT,
        rank: int = DEFAULT_SVD_RANK,
        xkv_group_size: int = 4,
        xkv_kv_rank_ratio: float = 1.5,
        max_svd_rows: int | None = None,
        layer_range: tuple[int, int] | None = None,
        gate_start: int = 0,
    ):
        self.mode = mode
        self.rank = rank
        self.xkv_group_size = xkv_group_size
        self.xkv_kv_rank_ratio = xkv_kv_rank_ratio
        self.max_svd_rows = max_svd_rows or self.MAX_SVD_ROWS
        # Override spec extraction_layers when set. (8, 24) uses middle
        # layers which encode semantic content (Tenney 2019, Huh 2024).
        self.layer_range = layer_range
        # Skip top gate_start singular values in SVD projection.
        # Top SVs encode shared positional/syntactic structure;
        # skipping them isolates semantic content (gate_start=6 optimal).
        self.gate_start = gate_start

        # Cached projection from last extract call (for inspection/reuse)
        self._last_projection: SVDProjection | None = None

    def extract(
        self,
        keys: torch.Tensor,
        spec: ModelCacheSpec,
    ) -> ExtractionResult:
        """Extract a state vector from KV cache key tensors.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim], the K cache.
                  Should be pre-RoPE when available. Post-RoPE keys work, but with
                  reduced retrieval quality due to position-dependent distortion.
            spec: Model architecture spec (provides extraction_layers).

        Returns:
            ExtractionResult with state vector and metadata.
        """
        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Layer selection: layer_range overrides spec extraction_layers
        if self.layer_range is not None:
            start, end = self.layer_range
            start = max(0, min(start, n_layers))
            end = max(start, min(end, n_layers))
            layer_indices = list(range(start, end))
        else:
            extraction_layers = spec["extraction_layers"]
            layer_indices = [l for l in extraction_layers if l < n_layers]

        if not layer_indices:
            layer_indices = list(range(n_layers))

        selected_keys = keys[layer_indices]  # [n_selected, n_kv_heads, ctx_len, head_dim]

        match self.mode:
            case StateExtractionMode.MEAN_POOL:
                state_vec = self._mean_pool(selected_keys)
            case StateExtractionMode.SVD_PROJECT:
                state_vec = self._svd_project(selected_keys)
            case StateExtractionMode.XKV_PROJECT:
                state_vec = self._xkv_project(selected_keys)
            case _:
                raise ValueError(f"Unknown extraction mode: {self.mode}")

        # D4: No normalization. True MIPS. Store norm as metadata.
        l2_norm = float(torch.linalg.vector_norm(state_vec).item())

        return ExtractionResult(
            state_vec=state_vec,
            l2_norm=l2_norm,
            mode=self.mode,
            n_layers_used=len(layer_indices),
            n_tokens=ctx_len,
        )

    def _mean_pool(self, keys: torch.Tensor) -> torch.Tensor:
        """Fast baseline: mean over layers, heads, and context positions.

        Input:  [n_layers, n_kv_heads, ctx_len, head_dim]
        Output: [head_dim]
        """
        return keys.float().mean(dim=(0, 1, 2))

    def _svd_project(self, keys: torch.Tensor) -> torch.Tensor:
        """Truncated SVD projection on pre-RoPE keys.

        ShadowKV approach: flatten all extraction layers' keys into a 2D matrix
        [N, head_dim], compute truncated SVD, project onto top-rank singular vectors,
        then mean-pool the projected vectors.

        For large contexts (N > max_svd_rows), we subsample rows before SVD.
        SVD only needs O(head_dim²) samples to recover the top singular vectors
        of a low-rank matrix, so subsampling to 8K rows preserves subspace quality
        while reducing SVD from ~2000ms to ~15ms at 4K context.

        Input:  [n_layers, n_kv_heads, ctx_len, head_dim]
        Output: [rank]
        """
        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Total rows in the flattened matrix
        n_rows = n_layers * n_kv_heads * ctx_len

        if n_rows > self.max_svd_rows:
            # Subsample BEFORE flatten+cast to avoid allocating the full
            # float32 matrix (saves ~30ms rearrange + 100MB at 4K context).
            # Fixed seed keeps the subsample, and so the state vector, deterministic.
            gen = torch.Generator()
            gen.manual_seed(42)
            indices = torch.randperm(n_rows, generator=gen)[:self.max_svd_rows]
            svd_input = keys.reshape(n_rows, head_dim)[indices].float()
        else:
            svd_input = rearrange(keys.float(), 'l h t d -> (l h t) d')

        # Clamp rank to not exceed matrix dimensions
        max_rank = min(head_dim, svd_input.shape[0])
        effective_rank = min(self.gate_start + self.rank, max_rank)

        # Thin SVD on the (subsampled) matrix; truncation happens when slicing Vh
        _, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)

        # W = right singular vectors with gating: skip top gate_start SVs
        # to remove shared positional/syntactic structure
        W = Vh[self.gate_start:effective_rank, :].T

        # Store projection for inspection. Diagnostics cover only the
        # components kept in W, so the gated singular values are excluded.
        kept = S[self.gate_start:effective_rank]
        total_var = (S ** 2).sum()
        explained_var = (kept ** 2).sum()
        self._last_projection = SVDProjection(
            W=W,
            singular_values=kept,
            explained_variance_ratio=float((explained_var / total_var).item()) if total_var > 0 else 0.0,
            source_shape=tuple(keys.shape),
        )

        # Project subsampled rows and mean-pool -> [rank]
        # Using the subsample for projection too avoids the expensive
        # 786K x 128 matmul + mean that dominates at large contexts.
        projected = svd_input @ W
        state_vec = projected.mean(dim=0)

        return state_vec

    def _xkv_project(self, keys: torch.Tensor) -> torch.Tensor:
        """Grouped cross-layer SVD (xKV approach).

        Groups adjacent layers (default 4), computes shared SVD basis
        per group, projects keys onto that basis, then concatenates
        group state vectors.

        This captures cross-layer structure that single-layer SVD misses.
        Achieves 6.8x compression vs 2.5x for single-layer SVD on Llama-3.1-8B.

        K:V rank ratio 1:1.5 is optimal per xKV paper, but since we
        only index keys (D2: K→K retrieval), we use the K rank only.

        Input:  [n_layers, n_kv_heads, ctx_len, head_dim]
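        Output: [total_groups * rank_per_group], where total_groups is the
                number of full groups plus one remainder group when layers
                are left over.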
        """
        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Compute rank per group
        # xKV finding: K rank is lower than V rank by factor 1:1.5
        # For 160 total rank budget across groups, allocate per group
        n_groups = max(1, n_layers // self.xkv_group_size)
        rank_per_group = max(1, self.rank // n_groups)
        rank_per_group = min(rank_per_group, head_dim)

        group_vecs: list[torch.Tensor] = []

        for g in range(n_groups):
            start = g * self.xkv_group_size
            end = min(start + self.xkv_group_size, n_layers)
            group_keys = keys[start:end]  # [group_size, n_kv_heads, ctx_len, head_dim]

            # Flatten group
            n_group_rows = group_keys.shape[0] * n_kv_heads * ctx_len

            if n_group_rows > self.max_svd_rows:
                gen = torch.Generator()
                gen.manual_seed(42 + g)
                indices = torch.randperm(n_group_rows, generator=gen)[:self.max_svd_rows]
                svd_input = group_keys.reshape(n_group_rows, head_dim)[indices].float()
            else:
                svd_input = rearrange(group_keys.float(), 'l h t d -> (l h t) d')

            effective_rank = min(rank_per_group, svd_input.shape[0], head_dim)

            # Truncated SVD for this group (on subsampled data)
            U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
            W_group = Vh[:effective_rank, :].T  # [head_dim, rank_per_group]

            # Project subsampled rows and mean-pool -> [rank_per_group]
            projected = svd_input @ W_group
            group_vec = projected.mean(dim=0)
            group_vecs.append(group_vec)

        # Handle remainder layers (if n_layers not divisible by group_size)
        remainder_start = n_groups * self.xkv_group_size
        if remainder_start < n_layers:
            remainder_keys = keys[remainder_start:]
            n_rem_rows = remainder_keys.shape[0] * n_kv_heads * ctx_len

            if n_rem_rows > self.max_svd_rows:
                gen = torch.Generator()
                gen.manual_seed(42 + n_groups)
                indices = torch.randperm(n_rem_rows, generator=gen)[:self.max_svd_rows]
                svd_input = remainder_keys.reshape(n_rem_rows, head_dim)[indices].float()
            else:
                svd_input = rearrange(remainder_keys.float(), 'l h t d -> (l h t) d')

            effective_rank = min(rank_per_group, svd_input.shape[0], head_dim)
            U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
            W_rem = Vh[:effective_rank, :].T
            projected = svd_input @ W_rem
            group_vecs.append(projected.mean(dim=0))

        # Concatenate all group vectors -> [n_groups * rank_per_group + remainder]
        state_vec = torch.cat(group_vecs, dim=0)

        return state_vec

    # ── Fixed Corpus Basis (FCB) ────────────────────────────────────────────

    @classmethod
    def compute_corpus_basis(
        cls,
        key_tensors: list[torch.Tensor],
        layer_range: tuple[int, int],
        gate_start: int,
        rank: int,
        max_rows: int = 32768,
        seed: int = 42,
    ) -> torch.Tensor:
        """Compute a fixed projection matrix from a corpus of key tensors.

        Returns P: [rank, head_dim], the global semantic basis.
        Unlike per-document SVD, this basis is document-independent.
        All documents projected with P exist in the same coordinate system,
        enabling stable cross-document and cross-model comparison.
        """
        l_start, l_end = layer_range
        gen = torch.Generator()
        gen.manual_seed(seed)

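        if not key_tensors:
            # Guard against division by zero in the per-document row budget
            raise ValueError("key_tensors must contain at least one tensor")

        all_rows: list[torch.Tensor] = []
        per_doc_max = max(1, max_rows // len(key_tensors))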

        for keys in key_tensors:
            k = keys[l_start:l_end].float()
            n_rows = k.shape[0] * k.shape[1] * k.shape[2]
            flat = k.reshape(n_rows, k.shape[3])
            if flat.shape[0] > per_doc_max:
                idx = torch.randperm(flat.shape[0], generator=gen)[:per_doc_max]
                flat = flat[idx]
            all_rows.append(flat)

        corpus = torch.cat(all_rows, dim=0)
        if corpus.shape[0] > max_rows:
            idx = torch.randperm(corpus.shape[0], generator=gen)[:max_rows]
            corpus = corpus[idx]

        _, S, Vh = torch.linalg.svd(corpus, full_matrices=False)
        P = Vh[gate_start : gate_start + rank]  # [rank, head_dim]
        return P

    def extract_with_basis(
        self,
        keys: torch.Tensor,
        spec: ModelCacheSpec,
        basis: torch.Tensor,
    ) -> ExtractionResult:
        """Extract state vector using a pre-computed fixed corpus basis.

        All vectors computed with the same basis share a coordinate system,
        which is required for cross-model transfer via adapter.

        Args:
            keys: [n_layers, n_kv_heads, n_cells, head_dim]
            spec: Model spec (currently unused; when layer_range is None,
                  all layers are used)
            basis: [rank, head_dim] from compute_corpus_basis()

        Returns:
            ExtractionResult with an L2-normalized state vector. The norm
            before normalization is kept in l2_norm.
        """
        if self.layer_range is not None:
            l_start, l_end = self.layer_range
        else:
            l_start, l_end = 0, keys.shape[0]
        l_start = max(0, min(l_start, keys.shape[0]))
        l_end = max(l_start, min(l_end, keys.shape[0]))

        k = keys[l_start:l_end].float()
        n_rows = k.shape[0] * k.shape[1] * k.shape[2]
        flat = k.reshape(n_rows, k.shape[3])

        proj = flat @ basis.T  # [N_rows, rank]
        vec = proj.mean(dim=0)  # [rank]

        norm = float(torch.linalg.vector_norm(vec).item())
        vec_normed = vec / (norm + 1e-8)

        return ExtractionResult(
            state_vec=vec_normed.to(torch.float32),
            l2_norm=norm,
            mode=self.mode,
            n_layers_used=l_end - l_start,
            n_tokens=k.shape[2],
        )

    # ── Fourier Fingerprint (Engram Absolute) ────────────────────────

    @staticmethod
    def compute_fourier_fingerprint(
        keys: torch.Tensor,
        freqs: tuple[int, ...] = (0, 1),
    ) -> torch.Tensor:
        """Compute the Fourier Absolute fingerprint from KV cache keys.

        Takes the real DFT over the layer dimension, extracts the
        amplitude at the specified frequencies, normalizes each, and
        concatenates them into a single fingerprint vector.

        This fingerprint is:
          - Cross-model invariant (cos ~0.90 between 3B and 8B)
          - Corpus-independent (no basis, no center, no training)
          - Scale-stable (98% recall@1 at N=1000, decay N^-0.207)

        Args:
            keys: [n_layers, n_kv_heads, n_cells, head_dim], full KV keys.
                  All layers are used (not sliced by layer_range).
            freqs: Frequency indices to extract. Default (0, 1) = DC + 1st harmonic.
                   f=0 captures overall key magnitude profile.
                   f=1 captures dominant oscillation across depth.

        Returns:
            Fingerprint vector [n_kv_heads * head_dim * len(freqs)], L2-normalized.
        """
        # Mean over cells (tokens) per layer: [n_layers, n_kv_heads * head_dim]
        n_layers = keys.shape[0]
        layer_means = keys.float().mean(dim=2).reshape(n_layers, -1)

        # DFT over layer dimension
        F_complex = torch.fft.rfft(layer_means, dim=0)  # [n_freq, dim]
        F_amp = F_complex.abs()  # amplitude spectrum

        # Extract and normalize each frequency component
        parts = []
        for f in freqs:
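            if f >= F_amp.shape[0]:
                # Frequency out of range: use zeros on the same device and
                # dtype as F_amp so torch.cat below also works for GPU inputs
                parts.append(torch.zeros(F_amp.shape[1], dtype=F_amp.dtype, device=F_amp.device))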
            else:
                v = F_amp[f]
                parts.append(v / (v.norm() + 1e-8))

        fingerprint = torch.cat(parts, dim=0)
        return fingerprint / (fingerprint.norm() + 1e-8)

    @property
    def last_projection(self) -> SVDProjection | None:
        """Access the SVD projection from the last svd_project call.

        Useful for diagnostics: check explained_variance_ratio to validate
        that the rank is sufficient for this particular cache.
        """
        return self._last_projection

    def output_dim(self, spec: ModelCacheSpec) -> int:
        """Compute the output dimension of the state vector for a given spec.

        This is needed to initialize the FAISS index with the correct dimension.
        """
        match self.mode:
            case StateExtractionMode.MEAN_POOL:
                return spec["head_dim"]
            case StateExtractionMode.SVD_PROJECT:
                max_rank = min(self.gate_start + self.rank, spec["head_dim"])
                return max_rank - self.gate_start
            case StateExtractionMode.XKV_PROJECT:
                # Note: this uses spec["extraction_layers"]; a layer_range
                # override passed to the constructor is not reflected here.
                extraction_layers = spec["extraction_layers"]
                n_layers = len(extraction_layers)
                n_groups = max(1, n_layers // self.xkv_group_size)
                rank_per_group = max(1, self.rank // n_groups)
                rank_per_group = min(rank_per_group, spec["head_dim"])
                # A remainder group exists only when layers are left over
                # after the full groups, matching _xkv_project
                has_remainder = n_groups * self.xkv_group_size < n_layers
                total_groups = n_groups + (1 if has_remainder else 0)
                return total_groups * rank_per_group
            case _:
                raise ValueError(f"Unknown mode: {self.mode}")