| """ |
| ENGRAM Protocol — Standalone State Extraction Functions |
| |
| Contains the Engram Absolute fingerprint: compute_fourier_fingerprint(). |
| This is the primary cross-model retrieval fingerprint, validated at |
| 98% recall@1 at N=1000 with power-law decay N^-0.207. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import TYPE_CHECKING |
|
|
| import torch |
| import torch.nn.functional as F |
|
|
| if TYPE_CHECKING: |
| from kvcos.core.blob_parser import ParsedMultiSectionCache |
|
|
|
|
| def compute_fourier_fingerprint( |
| layer_keys: torch.Tensor, |
| freqs: list[int] | None = None, |
| ) -> torch.Tensor: |
| """Compute the Engram Absolute fingerprint (f0+f1) from per-layer mean keys. |
| |
| Takes the real DFT over the layer dimension, extracts amplitude at |
| the specified frequencies, normalizes each, and concatenates. |
| |
| Args: |
| layer_keys: [n_layers, dim] where dim = n_kv_heads * head_dim. |
| Must be per-layer MEAN across token positions. |
| If shape is [n_layers, n_kv, hd], reshapes automatically. |
| freqs: DFT frequency indices to concatenate. |
| Default [0, 1] = DC (f0) + first harmonic (f1). |
| |
| Returns: |
| Fingerprint tensor [dim * len(freqs)], L2-normalized, float32. |
| |
| Properties: |
| - Cross-model invariant within Llama-3.x family (cos ~0.89) |
| - Zero corpus dependency: no centroid, no basis, no training data |
| - Recall@1 N=200: 98% N=1000: 98% decay: N^-0.207 |
| """ |
| if freqs is None: |
| freqs = [0, 1] |
|
|
| if layer_keys.dim() == 3: |
| n_layers = layer_keys.shape[0] |
| layer_keys = layer_keys.reshape(n_layers, -1) |
|
|
| layer_keys = layer_keys.float() |
|
|
| F_complex = torch.fft.rfft(layer_keys, dim=0) |
| F_amp = F_complex.abs() |
|
|
| components = [] |
| for f in freqs: |
| if f >= F_amp.shape[0]: |
| raise ValueError( |
| f"Requested freq={f} but rfft produced only " |
| f"{F_amp.shape[0]} components for {layer_keys.shape[0]} layers." |
| ) |
| components.append(F.normalize(F_amp[f], dim=-1)) |
|
|
| return torch.cat(components, dim=-1) |
|
|
|
|
def compute_eigenform_score(
    layer_keys: torch.Tensor,
    noise_sigma: float = 0.001,
    n_trials: int = 3,
    freqs: list[int] | None = None,
) -> float:
    """Compute eigenform stability score via noise perturbation.

    Measures how stable the Fourier fingerprint is under small noise.
    Score near 1.0 = stable. Below 0.95 = fragile fingerprint.

    Args:
        layer_keys: [n_layers, dim] per-layer mean key vectors.
        noise_sigma: Gaussian noise standard deviation.
        n_trials: Number of perturbed copies to compare.
        freqs: DFT frequencies. Default [0, 1].

    Returns:
        float in [0, 1]. Mean pairwise cosine across noise trials.
        Returns 1.0 when n_trials < 2 (no pairs to compare).
    """
    if freqs is None:
        freqs = [0, 1]

    # Trial 0 is the clean signal; later trials add i.i.d. Gaussian noise.
    fps: list[torch.Tensor] = []
    for trial in range(n_trials):
        if trial == 0:
            perturbed = layer_keys
        else:
            perturbed = layer_keys + torch.randn_like(layer_keys) * noise_sigma
        fps.append(compute_fourier_fingerprint(perturbed.float(), freqs=freqs))

    # Mean cosine over all unordered trial pairs.
    sims = [
        F.cosine_similarity(fps[i].unsqueeze(0), fps[j].unsqueeze(0)).item()
        for i in range(n_trials)
        for j in range(i + 1, n_trials)
    ]
    if not sims:
        return 1.0
    return float(sum(sims) / len(sims))
|
|
|
|
def compute_iswa_fingerprint(
    parsed: "ParsedMultiSectionCache",
    freqs: list | None = None,
    normalize_layers: bool = True,
) -> torch.Tensor:
    """Concatenated per-section Fourier fingerprint for ISWA multi-section caches.

    Strategy A (per-section concatenation): each cache section is reduced
    to its per-layer token mean, fingerprinted independently via the v2
    function, and the section fingerprints are concatenated.

    For Gemma 4 with freqs=[0, 1]:
        Global (5 layers, 2 heads, 512 dim)  -> 1024 * 2 = 2048
        SWA    (25 layers, 8 heads, 256 dim) -> 2048 * 2 = 4096
        Total: 6144-dim fingerprint

    Because every section's sub-fingerprint is independently L2-normalized,
    the relative geometry within each attention type is preserved.

    Args:
        parsed: ParsedMultiSectionCache from parse_multi_section_blob()
        freqs: DFT frequency indices. Default [0, 1].
        normalize_layers: L2-normalize each layer before DFT (v2 behavior).

    Returns:
        Concatenated fingerprint tensor, float32.
    """
    effective_freqs = [0, 1] if freqs is None else freqs

    # Mean over dim=2 collapses the token axis into per-layer mean keys.
    # NOTE(review): assumes section.keys is [n_layers, n_kv, n_tokens, hd]
    # — confirm against parse_multi_section_blob().
    per_section = [
        compute_fourier_fingerprint_v2(
            section.keys.float().mean(dim=2),
            freqs=effective_freqs,
            normalize_layers=normalize_layers,
        )
        for section in parsed.sections
    ]
    return torch.cat(per_section, dim=-1)
|
|
|
|
| def compute_fourier_fingerprint_v2( |
| layer_keys: torch.Tensor, |
| freqs: list | None = None, |
| normalize_layers: bool = True, |
| ) -> torch.Tensor: |
| """Fourier fingerprint v2: L2-normalize each layer before DFT. |
| |
| Removes absolute magnitude scale (which differs by KV head count |
| across model families), preserves layer-progression shape. |
| |
| Within-family: same recall as v1 (98%). |
| Cross-family: f0+f1 cross-sim expected >>0.26 (v1 baseline). |
| """ |
| if freqs is None: |
| freqs = [0, 1] |
| if layer_keys.dim() == 3: |
| layer_keys = layer_keys.reshape(layer_keys.shape[0], -1) |
| layer_keys = layer_keys.float() |
| if normalize_layers: |
| layer_keys = F.normalize(layer_keys, dim=-1) |
| F_complex = torch.fft.rfft(layer_keys, dim=0) |
| F_amp = F_complex.abs() |
| components = [] |
| for f in freqs: |
| if f >= F_amp.shape[0]: |
| raise ValueError(f"freq={f} out of range for {layer_keys.shape[0]} layers") |
| components.append(F.normalize(F_amp[f], dim=-1)) |
| return torch.cat(components, dim=-1) |
|
|