| """ |
| Engrammatic Geometry Retrieval β State Extraction Layer |
| |
| |
| Extracts a retrieval state vector from a KV cache tensor for MIPS-based |
| retrieval in EGR (Engrammatic Geometry Retrieval). The state vector is |
| a compact geometric fingerprint of a cognitive state β positioned in the |
| model's own pre-RoPE key manifold for geometrically consistent retrieval. |
| |
| Three extraction modes: |
| |
| mean_pool: Fast baseline. Mean over heads + context of key matrices |
| across extraction layers. Output: [head_dim]. No learned |
| parameters. Use for bootstrapping and smoke tests. |
| |
| svd_project: Truncated SVD on pre-RoPE keys, extraction layers (D3: 8-31), |
| rank-160 for 8B models. Validated by ShadowKV (ICML 2025, |
| ByteDance) on Llama-3.1-8B and Phi-3-Mini-128K. |
| Output: [rank]. Projection is prompt-dependent β W computed |
| per cache via online SVD, not precomputed globally. |
| Reference: github.com/ByteDance-Seed/ShadowKV |
| |
| xkv_project: Grouped cross-layer SVD. Groups 4 adjacent extraction layers, |
| extracts shared basis vectors across the group. Achieves |
| 6.8x compression vs 2.5x single-layer SVD. K:V rank ratio |
| 1:1.5 is optimal per xKV paper. |
| Reference: github.com/abdelfattah-lab/xKV |
| arXiv:2503.18893 |
| |
| REMOVED: sals_project β last-layer-only extraction invalidated by |
| Layer-Condensed KV Cache (ACL 2024). See D3. |
| |
| D4: No L2 normalization. True MIPS. L2 norm stored as metadata for |
| optional downstream use. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass, field |
|
|
| import torch |
| from einops import rearrange |
|
|
| from kvcos.core.types import ( |
| DEFAULT_SVD_RANK, |
| ModelCacheSpec, |
| StateExtractionMode, |
| ) |
|
|
|
|
@dataclass
class ExtractionResult:
    """Result of state vector extraction from a KV cache."""

    # The retrieval state vector; its dimensionality depends on the
    # extraction mode (see MARStateExtractor.output_dim).
    state_vec: torch.Tensor
    # L2 norm of the extracted vector before any normalization — stored
    # as metadata per D4 (no normalization baked into the vector).
    l2_norm: float
    # Extraction mode that produced this result.
    mode: StateExtractionMode
    # Number of cache layers that contributed to the vector.
    n_layers_used: int
    # Context length (token positions) of the source cache.
    n_tokens: int
|
|
|
|
@dataclass
class SVDProjection:
    """Learned SVD projection matrix for a specific cache.

    ShadowKV finding: pre-RoPE keys share low-rank subspaces WITHIN
    sequences but differ ACROSS sequences. Projection must be computed
    online per cache, not precomputed globally.
    """

    # Projection matrix [head_dim, kept_rank]: transposed right singular
    # vectors for the kept rank window (rows gate_start..effective_rank of Vh).
    W: torch.Tensor
    # Singular values for the top effective_rank components.
    singular_values: torch.Tensor
    # Fraction of total squared-singular-value mass captured by the top
    # effective_rank components; 0.0 when the input had zero variance.
    explained_variance_ratio: float
    # Shape of the key tensor the projection was computed from.
    source_shape: tuple[int, ...]
|
|
|
|
| class MARStateExtractor: |
| """Extracts retrieval state vectors from KV cache tensors for EGR. |
| |
| Usage: |
| extractor = MARStateExtractor(mode="svd_project", rank=160) |
| result = extractor.extract(keys, spec) |
| # result.state_vec is the retrieval vector for FAISS IndexFlatIP |
| # result.l2_norm goes into .eng metadata (D4) |
| """ |
|
|
| |
| |
| |
| |
| MAX_SVD_ROWS: int = 8192 |
|
|
| def __init__( |
| self, |
| mode: StateExtractionMode = StateExtractionMode.SVD_PROJECT, |
| rank: int = DEFAULT_SVD_RANK, |
| xkv_group_size: int = 4, |
| xkv_kv_rank_ratio: float = 1.5, |
| max_svd_rows: int | None = None, |
| layer_range: tuple[int, int] | None = None, |
| gate_start: int = 0, |
| ): |
| self.mode = mode |
| self.rank = rank |
| self.xkv_group_size = xkv_group_size |
| self.xkv_kv_rank_ratio = xkv_kv_rank_ratio |
| self.max_svd_rows = max_svd_rows or self.MAX_SVD_ROWS |
| |
| |
| self.layer_range = layer_range |
| |
| |
| |
| self.gate_start = gate_start |
|
|
| |
| self._last_projection: SVDProjection | None = None |
|
|
| def extract( |
| self, |
| keys: torch.Tensor, |
| spec: ModelCacheSpec, |
| ) -> ExtractionResult: |
| """Extract a state vector from KV cache key tensors. |
| |
| Args: |
| keys: [n_layers, n_kv_heads, ctx_len, head_dim] β the K cache. |
| Must be pre-RoPE if available. Post-RoPE works but with |
| reduced retrieval quality due to position-dependent distortion. |
| spec: Model architecture spec (provides extraction_layers). |
| |
| Returns: |
| ExtractionResult with state vector and metadata. |
| """ |
| n_layers, n_kv_heads, ctx_len, head_dim = keys.shape |
|
|
| |
| if self.layer_range is not None: |
| start, end = self.layer_range |
| start = max(0, min(start, n_layers)) |
| end = max(start, min(end, n_layers)) |
| layer_indices = list(range(start, end)) |
| else: |
| extraction_layers = spec["extraction_layers"] |
| layer_indices = [l for l in extraction_layers if l < n_layers] |
|
|
| if not layer_indices: |
| layer_indices = list(range(n_layers)) |
|
|
| selected_keys = keys[layer_indices] |
|
|
| match self.mode: |
| case StateExtractionMode.MEAN_POOL: |
| state_vec = self._mean_pool(selected_keys) |
| case StateExtractionMode.SVD_PROJECT: |
| state_vec = self._svd_project(selected_keys) |
| case StateExtractionMode.XKV_PROJECT: |
| state_vec = self._xkv_project(selected_keys) |
| case _: |
| raise ValueError(f"Unknown extraction mode: {self.mode}") |
|
|
| |
| l2_norm = float(torch.linalg.vector_norm(state_vec).item()) |
|
|
| return ExtractionResult( |
| state_vec=state_vec, |
| l2_norm=l2_norm, |
| mode=self.mode, |
| n_layers_used=len(layer_indices), |
| n_tokens=ctx_len, |
| ) |
|
|
| def _mean_pool(self, keys: torch.Tensor) -> torch.Tensor: |
| """Fast baseline: mean over layers, heads, and context positions. |
| |
| Input: [n_layers, n_kv_heads, ctx_len, head_dim] |
| Output: [head_dim] |
| """ |
| return keys.float().mean(dim=(0, 1, 2)) |
|
|
| def _svd_project(self, keys: torch.Tensor) -> torch.Tensor: |
| """Truncated SVD projection on pre-RoPE keys. |
| |
| ShadowKV approach: flatten all extraction layers' keys into a 2D matrix |
| [N, head_dim], compute truncated SVD, project onto top-rank singular vectors, |
| then mean-pool the projected vectors. |
| |
| For large contexts (N > max_svd_rows), we subsample rows before SVD. |
| SVD only needs O(head_dimΒ²) samples to recover the top singular vectors |
| of a low-rank matrix, so subsampling to 8K rows preserves subspace quality |
| while reducing SVD from ~2000ms to ~15ms at 4K context. |
| |
| Input: [n_layers, n_kv_heads, ctx_len, head_dim] |
| Output: [rank] |
| """ |
| n_layers, n_kv_heads, ctx_len, head_dim = keys.shape |
|
|
| |
| n_rows = n_layers * n_kv_heads * ctx_len |
|
|
| if n_rows > self.max_svd_rows: |
| |
| |
| gen = torch.Generator() |
| gen.manual_seed(42) |
| indices = torch.randperm(n_rows, generator=gen)[:self.max_svd_rows] |
| flat_keys = keys.reshape(n_rows, head_dim)[indices].float() |
| svd_input = flat_keys |
| else: |
| flat_keys = rearrange(keys.float(), 'l h t d -> (l h t) d') |
| svd_input = flat_keys |
|
|
| |
| max_rank = min(head_dim, svd_input.shape[0]) |
| effective_rank = min(self.gate_start + self.rank, max_rank) |
|
|
| |
| U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False) |
|
|
| |
| |
| W = Vh[self.gate_start:effective_rank, :].T |
|
|
| |
| total_var = (S ** 2).sum() |
| explained_var = (S[:effective_rank] ** 2).sum() |
| self._last_projection = SVDProjection( |
| W=W, |
| singular_values=S[:effective_rank], |
| explained_variance_ratio=float((explained_var / total_var).item()) if total_var > 0 else 0.0, |
| source_shape=tuple(keys.shape), |
| ) |
|
|
| |
| |
| |
| projected = svd_input @ W |
| state_vec = projected.mean(dim=0) |
|
|
| return state_vec |
|
|
| def _xkv_project(self, keys: torch.Tensor) -> torch.Tensor: |
| """Grouped cross-layer SVD (xKV approach). |
| |
| Groups adjacent layers (default 4), computes shared SVD basis |
| per group, projects keys onto that basis, then concatenates |
| group state vectors. |
| |
| This captures cross-layer structure that single-layer SVD misses. |
| Achieves 6.8x vs 2.5x for single-layer SVD on Llama-3.1-8B. |
| |
| K:V rank ratio 1:1.5 is optimal per xKV paper, but since we |
| only index keys (D2: KβK retrieval), we use the K rank only. |
| |
| Input: [n_layers, n_kv_heads, ctx_len, head_dim] |
| Output: [n_groups * rank_per_group] |
| """ |
| n_layers, n_kv_heads, ctx_len, head_dim = keys.shape |
|
|
| |
| |
| |
| n_groups = max(1, n_layers // self.xkv_group_size) |
| rank_per_group = max(1, self.rank // n_groups) |
| rank_per_group = min(rank_per_group, head_dim) |
|
|
| group_vecs: list[torch.Tensor] = [] |
|
|
| for g in range(n_groups): |
| start = g * self.xkv_group_size |
| end = min(start + self.xkv_group_size, n_layers) |
| group_keys = keys[start:end] |
|
|
| |
| n_group_rows = group_keys.shape[0] * n_kv_heads * ctx_len |
|
|
| if n_group_rows > self.max_svd_rows: |
| gen = torch.Generator() |
| gen.manual_seed(42 + g) |
| indices = torch.randperm(n_group_rows, generator=gen)[:self.max_svd_rows] |
| svd_input = group_keys.reshape(n_group_rows, head_dim)[indices].float() |
| else: |
| svd_input = rearrange(group_keys.float(), 'l h t d -> (l h t) d') |
|
|
| effective_rank = min(rank_per_group, svd_input.shape[0], head_dim) |
|
|
| |
| U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False) |
| W_group = Vh[:effective_rank, :].T |
|
|
| |
| projected = svd_input @ W_group |
| group_vec = projected.mean(dim=0) |
| group_vecs.append(group_vec) |
|
|
| |
| remainder_start = n_groups * self.xkv_group_size |
| if remainder_start < n_layers: |
| remainder_keys = keys[remainder_start:] |
| n_rem_rows = remainder_keys.shape[0] * n_kv_heads * ctx_len |
|
|
| if n_rem_rows > self.max_svd_rows: |
| gen = torch.Generator() |
| gen.manual_seed(42 + n_groups) |
| indices = torch.randperm(n_rem_rows, generator=gen)[:self.max_svd_rows] |
| svd_input = remainder_keys.reshape(n_rem_rows, head_dim)[indices].float() |
| else: |
| svd_input = rearrange(remainder_keys.float(), 'l h t d -> (l h t) d') |
|
|
| effective_rank = min(rank_per_group, svd_input.shape[0], head_dim) |
| U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False) |
| W_rem = Vh[:effective_rank, :].T |
| projected = svd_input @ W_rem |
| group_vecs.append(projected.mean(dim=0)) |
|
|
| |
| state_vec = torch.cat(group_vecs, dim=0) |
|
|
| return state_vec |
|
|
| |
|
|
| @classmethod |
| def compute_corpus_basis( |
| cls, |
| key_tensors: list[torch.Tensor], |
| layer_range: tuple[int, int], |
| gate_start: int, |
| rank: int, |
| max_rows: int = 32768, |
| seed: int = 42, |
| ) -> torch.Tensor: |
| """Compute a fixed projection matrix from a corpus of key tensors. |
| |
| Returns P: [rank, head_dim] β the global semantic basis. |
| Unlike per-document SVD, this basis is document-independent. |
| All documents projected with P exist in the same coordinate system, |
| enabling stable cross-document and cross-model comparison. |
| """ |
| l_start, l_end = layer_range |
| gen = torch.Generator() |
| gen.manual_seed(seed) |
|
|
| all_rows: list[torch.Tensor] = [] |
| per_doc_max = max(1, max_rows // len(key_tensors)) |
|
|
| for keys in key_tensors: |
| k = keys[l_start:l_end].float() |
| n_rows = k.shape[0] * k.shape[1] * k.shape[2] |
| flat = k.reshape(n_rows, k.shape[3]) |
| if flat.shape[0] > per_doc_max: |
| idx = torch.randperm(flat.shape[0], generator=gen)[:per_doc_max] |
| flat = flat[idx] |
| all_rows.append(flat) |
|
|
| corpus = torch.cat(all_rows, dim=0) |
| if corpus.shape[0] > max_rows: |
| idx = torch.randperm(corpus.shape[0], generator=gen)[:max_rows] |
| corpus = corpus[idx] |
|
|
| _, S, Vh = torch.linalg.svd(corpus, full_matrices=False) |
| P = Vh[gate_start : gate_start + rank] |
| return P |
|
|
| def extract_with_basis( |
| self, |
| keys: torch.Tensor, |
| spec: ModelCacheSpec, |
| basis: torch.Tensor, |
| ) -> ExtractionResult: |
| """Extract state vector using a pre-computed fixed corpus basis. |
| |
| All vectors computed with the same basis share a coordinate system, |
| which is required for cross-model transfer via adapter. |
| |
| Args: |
| keys: [n_layers, n_kv_heads, n_cells, head_dim] |
| spec: Model spec (used for layer_range fallback) |
| basis: [rank, head_dim] from compute_corpus_basis() |
| |
| Returns: |
| ExtractionResult with L2-normalized state vector |
| """ |
| if self.layer_range is not None: |
| l_start, l_end = self.layer_range |
| else: |
| l_start, l_end = 0, keys.shape[0] |
| l_start = max(0, min(l_start, keys.shape[0])) |
| l_end = max(l_start, min(l_end, keys.shape[0])) |
|
|
| k = keys[l_start:l_end].float() |
| n_rows = k.shape[0] * k.shape[1] * k.shape[2] |
| flat = k.reshape(n_rows, k.shape[3]) |
|
|
| proj = flat @ basis.T |
| vec = proj.mean(dim=0) |
|
|
| norm = float(torch.linalg.vector_norm(vec).item()) |
| vec_normed = vec / (norm + 1e-8) |
|
|
| return ExtractionResult( |
| state_vec=vec_normed.to(torch.float32), |
| l2_norm=norm, |
| mode=self.mode, |
| n_layers_used=l_end - l_start, |
| n_tokens=k.shape[2], |
| ) |
|
|
| |
|
|
| @staticmethod |
| def compute_fourier_fingerprint( |
| keys: torch.Tensor, |
| freqs: tuple[int, ...] = (0, 1), |
| ) -> torch.Tensor: |
| """Compute the Fourier Absolute fingerprint from KV cache keys. |
| |
| Takes the real DFT over the layer dimension, extracts the |
| amplitude at the specified frequencies, normalizes each, and |
| concatenates them into a single fingerprint vector. |
| |
| This fingerprint is: |
| - Cross-model invariant (cos ~0.90 between 3B and 8B) |
| - Corpus-independent (no basis, no center, no training) |
| - Scale-stable (98% recall@1 at N=1000, decay N^-0.207) |
| |
| Args: |
| keys: [n_layers, n_kv_heads, n_cells, head_dim] β full KV keys. |
| All layers are used (not sliced by layer_range). |
| freqs: Frequency indices to extract. Default (0, 1) = DC + 1st harmonic. |
| f=0 captures overall key magnitude profile. |
| f=1 captures dominant oscillation across depth. |
| |
| Returns: |
| Fingerprint vector [dim * len(freqs)], L2-normalized. |
| """ |
| |
| n_layers = keys.shape[0] |
| layer_means = keys.float().mean(dim=2).reshape(n_layers, -1) |
|
|
| |
| F_complex = torch.fft.rfft(layer_means, dim=0) |
| F_amp = F_complex.abs() |
|
|
| |
| parts = [] |
| for f in freqs: |
| if f >= F_amp.shape[0]: |
| |
| parts.append(torch.zeros(F_amp.shape[1])) |
| else: |
| v = F_amp[f] |
| parts.append(v / (v.norm() + 1e-8)) |
|
|
| fingerprint = torch.cat(parts, dim=0) |
| return fingerprint / (fingerprint.norm() + 1e-8) |
|
|
| @property |
| def last_projection(self) -> SVDProjection | None: |
| """Access the SVD projection from the last svd_project call. |
| |
| Useful for diagnostics: check explained_variance_ratio to validate |
| that the rank is sufficient for this particular cache. |
| """ |
| return self._last_projection |
|
|
| def output_dim(self, spec: ModelCacheSpec) -> int: |
| """Compute the output dimension of the state vector for a given spec. |
| |
| This is needed to initialize the FAISS index with the correct dimension. |
| """ |
| match self.mode: |
| case StateExtractionMode.MEAN_POOL: |
| return spec["head_dim"] |
| case StateExtractionMode.SVD_PROJECT: |
| max_rank = min(self.gate_start + self.rank, spec["head_dim"]) |
| return max_rank - self.gate_start |
| case StateExtractionMode.XKV_PROJECT: |
| extraction_layers = spec["extraction_layers"] |
| n_layers = len(extraction_layers) |
| n_groups = max(1, n_layers // self.xkv_group_size) |
| rank_per_group = max(1, self.rank // n_groups) |
| rank_per_group = min(rank_per_group, spec["head_dim"]) |
| |
| has_remainder = (n_layers % self.xkv_group_size) != 0 |
| total_groups = n_groups + (1 if has_remainder else 0) |
| return total_groups * rank_per_group |
| case _: |
| raise ValueError(f"Unknown mode: {self.mode}") |
|
|