eigengram committed on
Commit
0769ff3
·
verified ·
1 Parent(s): 2a3efd4

feat: upload core kvcos library

Browse files
kvcos/.DS_Store ADDED
Binary file (6.15 kB). View file
 
kvcos/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """ENGRAM Protocol — KV cache fingerprinting for persistent semantic retrieval."""
2
+
3
+ from kvcos.core.types import ENGRAM_VERSION
4
+
5
+ __version__ = ENGRAM_VERSION
kvcos/api/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """ENGRAM Protocol — REST API package."""
2
+
3
+ from kvcos.api.server import create_app
4
+
5
+ __all__ = ["create_app"]
kvcos/api/routes.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — API Routes
3
+
4
+
5
+ FastAPI route handlers for the ENGRAM REST API.
6
+ All endpoints under /v1/ prefix.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from fastapi import APIRouter, HTTPException, UploadFile, File
12
+
13
+ from kvcos.api.schemas import (
14
+ DeleteResponse,
15
+ HealthResponse,
16
+ SearchRequest,
17
+ SearchResponse,
18
+ SearchResultItem,
19
+ StatsResponse,
20
+ StoreResponse,
21
+ )
22
+ from kvcos.core.types import ENGRAM_VERSION
23
+
24
# All ENGRAM endpoints live under the /v1 prefix.
router = APIRouter(prefix="/v1")


# ── Dependency stubs ──────────────────────────────────────────────────────────
# These are replaced by real instances in server.py lifespan.
# Using module-level state that the server sets during startup.

# Set by server.lifespan to an EGRRetriever instance; None until startup completes.
_retriever = None
# Set by server.lifespan to a LocalStorageBackend instance.
_storage = None
# Set by server.lifespan to a ManifoldIndex instance.
_index = None
34
+
35
+
36
def _get_retriever():
    """Return the retriever wired in by server.py, or fail 503 during startup."""
    if _retriever is not None:
        return _retriever
    raise HTTPException(503, "ENGRAM not initialized. Server starting up.")
40
+
41
+
42
def _get_storage():
    """Return the storage backend wired in by server.py, or fail 503 during startup."""
    if _storage is not None:
        return _storage
    raise HTTPException(503, "ENGRAM not initialized. Server starting up.")
46
+
47
+
48
def _get_index():
    """Return the manifold index wired in by server.py, or fail 503 during startup."""
    if _index is not None:
        return _index
    raise HTTPException(503, "ENGRAM not initialized. Server starting up.")
52
+
53
+
54
+ # ── Health ────────────────────────────────────────────────────────────────────
55
+
56
+
57
+ @router.get("/health", response_model=HealthResponse)
58
+ async def health():
59
+ """Health check endpoint."""
60
+ index = _get_index()
61
+ storage = _get_storage()
62
+ return HealthResponse(
63
+ status="ok",
64
+ version=ENGRAM_VERSION,
65
+ index_entries=index.n_entries,
66
+ storage_backend="local",
67
+ )
68
+
69
+
70
+ # ── Stats (must come before /cache/{cache_id} to avoid route shadowing) ──────
71
+
72
+
73
+ @router.get("/cache/stats", response_model=StatsResponse)
74
+ async def cache_stats():
75
+ """Get aggregate statistics for the engram store."""
76
+ storage = _get_storage()
77
+ stats = storage.stats()
78
+ return StatsResponse(
79
+ total_entries=stats["total_entries"],
80
+ total_size_bytes=stats["total_size_bytes"],
81
+ total_size_mb=round(stats["total_size_bytes"] / (1024 * 1024), 2),
82
+ avg_compression_ratio=stats["avg_compression_ratio"],
83
+ model_breakdown=stats["model_breakdown"],
84
+ )
85
+
86
+
87
+ # ── Store ─────────────────────────────────────────────────────────────────────
88
+
89
+
90
+ @router.post("/cache", response_model=StoreResponse)
91
+ async def store_cache(
92
+ agent_id: str,
93
+ task_description: str,
94
+ model_id: str,
95
+ file: UploadFile = File(...),
96
+ compression: str = "q8_0",
97
+ ):
98
+ """Store a .eng file in the engram store.
99
+
100
+ Accepts a pre-serialized .eng file upload.
101
+ The file is stored and its metadata indexed for EGR retrieval.
102
+ """
103
+ storage = _get_storage()
104
+
105
+ data = await file.read()
106
+ if len(data) == 0:
107
+ raise HTTPException(400, "Empty file upload")
108
+
109
+ import uuid
110
+ cache_id = str(uuid.uuid4())
111
+
112
+ from kvcos.core.types import EngramMetadata
113
+ from datetime import datetime, timezone
114
+
115
+ metadata: EngramMetadata = {
116
+ "engram_version": ENGRAM_VERSION,
117
+ "cache_id": cache_id,
118
+ "compression": compression,
119
+ "model_id": model_id,
120
+ "model_family": "",
121
+ "n_layers": "0",
122
+ "n_heads": "0",
123
+ "n_kv_heads": "0",
124
+ "head_dim": "0",
125
+ "context_len": "0",
126
+ "agent_id": agent_id,
127
+ "task_description": task_description,
128
+ "created_at": datetime.now(timezone.utc).isoformat(),
129
+ }
130
+
131
+ path = storage.store(cache_id, data, metadata)
132
+
133
+ return StoreResponse(
134
+ cache_id=cache_id,
135
+ size_bytes=len(data),
136
+ compression_ratio=1.0,
137
+ path=path,
138
+ )
139
+
140
+
141
+ # ── Retrieve by ID ────────────────────────────────────────────────────────────
142
+
143
+
144
+ @router.get("/cache/{cache_id}")
145
+ async def get_cache(cache_id: str):
146
+ """Retrieve a .eng file by cache ID.
147
+
148
+ Returns the raw .eng file bytes (application/octet-stream).
149
+ """
150
+ storage = _get_storage()
151
+
152
+ data = storage.get(cache_id)
153
+ if data is None:
154
+ raise HTTPException(404, f"Cache entry not found: {cache_id}")
155
+
156
+ from fastapi.responses import Response
157
+ return Response(
158
+ content=data,
159
+ media_type="application/octet-stream",
160
+ headers={"Content-Disposition": f'attachment; filename="{cache_id}.eng"'},
161
+ )
162
+
163
+
164
+ # ── Search ────────────────────────────────────────────────────────────────────
165
+
166
+
167
+ @router.post("/cache/search", response_model=SearchResponse)
168
+ async def search_cache(req: SearchRequest):
169
+ """Search for similar engram states via EGR manifold search.
170
+
171
+ Uses inner product similarity (MIPS) in the model's pre-RoPE
172
+ key manifold. D2: K→K retrieval only.
173
+ """
174
+ index = _get_index()
175
+
176
+ # For text-only search without a KV query vector, we need the
177
+ # retriever to extract a state vector first. This endpoint
178
+ # currently returns index entries matching by metadata filter.
179
+ # Full EGR vector search requires a query KV cache (via /egr/retrieve).
180
+
181
+ # Metadata-based listing with optional filters
182
+ storage = _get_storage()
183
+ entries = storage.list_entries(model_family=None, limit=req.top_k)
184
+
185
+ results = [
186
+ SearchResultItem(
187
+ cache_id=e.get("cache_id", ""),
188
+ similarity=0.0,
189
+ task_description=e.get("task_description", ""),
190
+ model_id=e.get("model_id", ""),
191
+ created_at=e.get("created_at", ""),
192
+ context_len=int(e.get("context_len", "0")),
193
+ )
194
+ for e in entries
195
+ if (req.model_id is None or e.get("model_id") == req.model_id)
196
+ ]
197
+
198
+ return SearchResponse(results=results[:req.top_k], n_searched=index.n_entries)
199
+
200
+
201
+ # ── Delete ────────────────────────────────────────────────────────────────────
202
+
203
+
204
+ @router.delete("/cache/{cache_id}", response_model=DeleteResponse)
205
+ async def delete_cache(cache_id: str):
206
+ """Delete an engram from storage and index."""
207
+ retriever = _get_retriever()
208
+ deleted = retriever.delete_engram(cache_id)
209
+ return DeleteResponse(deleted=deleted, cache_id=cache_id)
210
+
211
+
kvcos/api/schemas.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — API Schemas
3
+
4
+
5
+ Pydantic models for all REST API request/response payloads.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from pydantic import BaseModel, Field
11
+
12
+
13
+ # ── Store ─────────────────────────────────────────────────────────────────────
14
+
15
+
16
class StoreRequest(BaseModel):
    """Metadata describing a .eng upload.

    NOTE(review): routes.py's store endpoint currently takes these fields as
    query parameters rather than this body model — confirm intended usage.
    """

    agent_id: str  # owning agent identifier
    task_description: str  # human-readable task summary
    model_id: str  # model that produced the KV cache
    compression: str = "q8_0"  # compression method label
21
+
22
+
23
class StoreResponse(BaseModel):
    """Result of storing a .eng file (see routes.store_cache)."""

    cache_id: str  # server-generated UUID for the stored engram
    size_bytes: int  # size of the uploaded payload
    compression_ratio: float  # currently always 1.0 for raw uploads
    path: str  # storage path returned by the backend
28
+
29
+
30
+ # ── Retrieve ──────────────────────────────────────────────────────────────────
31
+
32
+
33
class SearchRequest(BaseModel):
    """Query body for /cache/search."""

    task_description: str  # free-text description of the task to match
    model_id: str | None = None  # optional exact model_id filter
    top_k: int = Field(default=5, ge=1, le=100)  # max results to return
    # NOTE(review): not yet applied by routes.search_cache in this commit.
    min_similarity: float | None = None
38
+
39
+
40
class SearchResultItem(BaseModel):
    """One match returned by /cache/search."""

    cache_id: str  # engram identifier usable with GET /cache/{cache_id}
    similarity: float  # currently hardcoded to 0.0 by routes.search_cache
    task_description: str
    model_id: str
    created_at: str  # ISO-8601 timestamp string
    context_len: int
47
+
48
+
49
class SearchResponse(BaseModel):
    """Envelope for /cache/search results."""

    results: list[SearchResultItem]
    n_searched: int  # total entries in the manifold index at query time
52
+
53
+
54
+ # ── Extend ────────────────────────────────────────────────────────────────────
55
+
56
+
57
class ExtendResponse(BaseModel):
    """Result of extending an engram's context.

    NOTE(review): no /extend route exists in routes.py in this commit —
    presumably reserved for a future endpoint.
    """

    cache_id: str
    new_context_len: int
60
+
61
+
62
+ # ── Delete ────────────────────────────────────────────────────────────────────
63
+
64
+
65
class DeleteResponse(BaseModel):
    """Result of DELETE /cache/{cache_id}."""

    deleted: bool  # whether the retriever actually removed the engram
    cache_id: str  # echo of the requested ID
68
+
69
+
70
+ # ── Stats ─────────────────────────────────────────────────────────────────────
71
+
72
+
73
class StatsResponse(BaseModel):
    """Aggregate store statistics from GET /cache/stats."""

    total_entries: int
    total_size_bytes: int
    total_size_mb: float  # total_size_bytes / 1 MiB, rounded to 2 decimals
    avg_compression_ratio: float
    model_breakdown: dict[str, int]  # entry count keyed by model
79
+
80
+
81
+ # ── Health ────────────────────────────────────────────────────────────────────
82
+
83
+
84
class HealthResponse(BaseModel):
    """Payload for GET /health."""

    status: str = "ok"
    version: str  # ENGRAM_VERSION of the running server
    index_entries: int  # vectors currently in the manifold index
    storage_backend: str  # e.g. "local"
kvcos/api/server.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — ENGRAM Server
3
+
4
+
5
+ FastAPI application factory with lifespan management.
6
+ Initializes storage, index, extractor, and retriever on startup.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import logging
12
+ from contextlib import asynccontextmanager
13
+
14
+ from fastapi import FastAPI
15
+
16
+ from kvcos.api import routes
17
+ from kvcos.core.config import get_config
18
+ from kvcos.core.serializer import EngramSerializer
19
+ from kvcos.core.types import ENGRAM_VERSION, StateExtractionMode
20
+ from kvcos.core.manifold_index import ManifoldIndex
21
+ from kvcos.core.retriever import EGRRetriever
22
+ from kvcos.core.state_extractor import MARStateExtractor
23
+ from kvcos.storage.local import LocalStorageBackend
24
+
25
+ logger = logging.getLogger(__name__)
26
+
27
+
28
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize ENGRAM components on startup, clean up on shutdown.

    Construction order matters: storage and index are built first because
    the retriever composes them; route handlers are wired last so a request
    can never observe a partially-constructed retriever.
    """
    config = get_config()

    # Initialize storage backend
    storage = LocalStorageBackend(data_dir=config.data_dir)

    # Initialize EGR manifold index
    index_path = config.index_dir / "egr.faiss"
    index = ManifoldIndex(dim=config.state_vec_dim, index_path=index_path)

    # Initialize state extractor
    extractor = MARStateExtractor(
        mode=StateExtractionMode.SVD_PROJECT,
        rank=config.state_vec_dim,
    )

    # Initialize retriever
    serializer = EngramSerializer()
    retriever = EGRRetriever(
        extractor=extractor,
        index=index,
        storage=storage,
        serializer=serializer,
    )

    # Wire into route handlers — routes.py reads these module globals
    # through its _get_* helpers and 503s until they are set.
    routes._storage = storage
    routes._index = index
    routes._retriever = retriever

    logger.info("ENGRAM v%s started", ENGRAM_VERSION)
    logger.info(" Storage: %s (%d entries)", config.data_dir, storage.stats()["total_entries"])
    logger.info(" Index: %s (%d vectors, dim=%d)", config.index_dir, index.n_entries, config.state_vec_dim)
    logger.info(" Backend: %s", config.backend.value)

    yield

    # Shutdown: persist index. Best-effort — a save failure must not block
    # the rest of teardown, so it is logged and swallowed.
    try:
        index.save(index_path)
        logger.info("Index saved to %s", index_path)
    except Exception as e:
        logger.warning("Failed to save index: %s", e)

    # Clear route references so post-shutdown requests fail fast with 503.
    routes._storage = None
    routes._index = None
    routes._retriever = None

    logger.info("ENGRAM shutdown complete")
80
+
81
+
82
def create_app() -> FastAPI:
    """Create the ENGRAM FastAPI application.

    Returns a fresh app with the /v1 router mounted and the lifespan
    handler attached; component wiring happens inside `lifespan`.
    """
    application = FastAPI(
        title="ENGRAM Protocol API",
        description="ENGRAM Protocol: Cognitive state, persisted.",
        version=ENGRAM_VERSION,
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
    )
    application.include_router(routes.router)
    return application
94
+
95
+
96
def main() -> None:
    """Entry point for `engram-server` console script."""
    import uvicorn

    cfg = get_config()
    uvicorn.run(
        create_app(),
        host=cfg.host,
        port=cfg.port,
        log_level="info",
    )
108
+
109
+
110
# Cached application instance; populated on first access of `server.app`.
_app_instance: FastAPI | None = None


def _get_app() -> FastAPI:
    """Lazy app factory for `uvicorn kvcos.api.server:app`.

    Defers create_app() until the attribute is actually accessed,
    avoiding side effects on module import.

    Fix: the previous version built a brand-new FastAPI application on
    every attribute access; the instance is now cached so repeated
    `server.app` lookups share one application (and one lifespan).

    Returns:
        The singleton FastAPI application.
    """
    global _app_instance
    if _app_instance is None:
        _app_instance = create_app()
    return _app_instance
117
+
118
+
119
def __getattr__(name: str) -> FastAPI:
    """PEP 562 hook exposing a lazily-created module attribute `app`."""
    if name != "app":
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return _get_app()
123
+
124
+
125
+ if __name__ == "__main__":
126
+ main()
kvcos/client/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """ENGRAM Protocol — Client library."""
2
+
3
+ from kvcos.client.python_client import EngramClient
4
+
5
+ __all__ = ["EngramClient"]
kvcos/client/python_client.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — ENGRAM Python Client
3
+
4
+
5
+ Async HTTP client wrapping all ENGRAM API endpoints.
6
+ This is what agents import to interact with the engram store.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from pathlib import Path
12
+ from typing import Any
13
+
14
+ import httpx
15
+
16
+
17
class EngramClient:
    """Python client for the ENGRAM REST API.

    Wraps the server's /v1 endpoints with a synchronous httpx.Client.

    Usage:
        client = EngramClient("http://localhost:8080")
        result = client.store_file(path, agent_id="worker", task="analyze code", model_id="llama-3.1-8b")
        matches = client.search(task_description="debug auth error", top_k=3)
        data = client.get(matches[0]["cache_id"])
    """

    def __init__(self, base_url: str = "http://localhost:8080", timeout: float = 30.0):
        # All requests are rooted at <base_url>/v1 to match the server's prefix.
        self.base_url = base_url.rstrip("/")
        self._client = httpx.Client(base_url=f"{self.base_url}/v1", timeout=timeout)

    def close(self) -> None:
        """Close the underlying HTTP connection pool."""
        self._client.close()

    def __enter__(self):
        """Support `with EngramClient(...) as client:` usage."""
        return self

    def __exit__(self, *args):
        """Close the client when the `with` block exits."""
        self.close()

    # ── Health ────────────────────────────────────────────────

    def health(self) -> dict[str, Any]:
        """Check ENGRAM server health."""
        resp = self._client.get("/health")
        resp.raise_for_status()
        return resp.json()

    # ── Store ─────────────────────────────────────────────────

    def store_file(
        self,
        file_path: Path,
        agent_id: str,
        task_description: str,
        model_id: str,
        compression: str = "q8_0",
    ) -> dict[str, Any]:
        """Upload a .eng file to the engram store.

        Args:
            file_path: Path to the .eng file
            agent_id: Agent identifier
            task_description: Human-readable description
            model_id: Model identifier
            compression: Compression method used

        Returns:
            Dict with cache_id, size_bytes, compression_ratio, path
        """
        # Metadata travels as query params (matching the server's signature);
        # only the file body is multipart.
        with open(file_path, "rb") as f:
            resp = self._client.post(
                "/cache",
                params={
                    "agent_id": agent_id,
                    "task_description": task_description,
                    "model_id": model_id,
                    "compression": compression,
                },
                files={"file": (file_path.name, f, "application/octet-stream")},
            )
        resp.raise_for_status()
        return resp.json()

    def store_bytes(
        self,
        data: bytes,
        agent_id: str,
        task_description: str,
        model_id: str,
        compression: str = "q8_0",
        filename: str = "cache.eng",
    ) -> dict[str, Any]:
        """Upload raw .eng bytes to the engram store."""
        resp = self._client.post(
            "/cache",
            params={
                "agent_id": agent_id,
                "task_description": task_description,
                "model_id": model_id,
                "compression": compression,
            },
            files={"file": (filename, data, "application/octet-stream")},
        )
        resp.raise_for_status()
        return resp.json()

    # ── Retrieve ──────────────────────────────────────────────

    def get(self, cache_id: str) -> bytes:
        """Retrieve a .eng file by cache ID.

        Returns raw bytes of the .eng file.
        """
        resp = self._client.get(f"/cache/{cache_id}")
        resp.raise_for_status()
        return resp.content

    # ── Search ────────────────────────────────────────────────

    def search(
        self,
        task_description: str,
        model_id: str | None = None,
        top_k: int = 5,
        min_similarity: float | None = None,
    ) -> list[dict[str, Any]]:
        """Search for similar engram states.

        Returns list of search result dicts with cache_id, similarity, etc.
        """
        body: dict[str, Any] = {
            "task_description": task_description,
            "top_k": top_k,
        }
        # Truthiness check: an empty model_id string is treated as "no filter".
        if model_id:
            body["model_id"] = model_id
        if min_similarity is not None:
            body["min_similarity"] = min_similarity

        resp = self._client.post("/cache/search", json=body)
        resp.raise_for_status()
        return resp.json()["results"]

    # ── Delete ────────────────────────────────────────────────

    def delete(self, cache_id: str) -> bool:
        """Delete an engram from storage and index."""
        resp = self._client.delete(f"/cache/{cache_id}")
        resp.raise_for_status()
        return resp.json()["deleted"]

    # ── Stats ─────────────────────────────────────────────────

    def stats(self) -> dict[str, Any]:
        """Get aggregate engram store statistics."""
        resp = self._client.get("/cache/stats")
        resp.raise_for_status()
        return resp.json()
kvcos/core/__init__.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ENGRAM Protocol — Core library: types, parsing, compression, serialization, retrieval."""
2
+
3
+ from kvcos.core.types import (
4
+ ENGRAM_VERSION,
5
+ AttentionType,
6
+ CacheSection,
7
+ CacheSearchResult,
8
+ CacheStats,
9
+ CompressionMethod,
10
+ EngramMetadata,
11
+ ModelCacheSpec,
12
+ StateExtractionMode,
13
+ )
14
+ from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
15
+ from kvcos.core.retriever import EGRRetriever, RetrievalResponse, RetrievalResult
16
+ from kvcos.core.state_extractor import ExtractionResult, MARStateExtractor, SVDProjection
17
+
18
+ __all__ = [
19
+ # Types
20
+ "ENGRAM_VERSION",
21
+ "AttentionType",
22
+ "CacheSection",
23
+ "CacheSearchResult",
24
+ "CacheStats",
25
+ "CompressionMethod",
26
+ "EngramMetadata",
27
+ "ModelCacheSpec",
28
+ "StateExtractionMode",
29
+ # Manifold index
30
+ "IndexEntry",
31
+ "ManifoldIndex",
32
+ # Retriever
33
+ "EGRRetriever",
34
+ "RetrievalResponse",
35
+ "RetrievalResult",
36
+ # State extraction (MAR)
37
+ "ExtractionResult",
38
+ "MARStateExtractor",
39
+ "SVDProjection",
40
+ ]
kvcos/core/blob_parser.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — llama.cpp State Blob Parser
3
+
4
+
5
+ Parses the binary state blob from llama_state_get_data() (via save_state())
6
+ into structured PyTorch tensors of shape [n_layers, n_kv_heads, n_cells, head_dim].
7
+
8
+ D1: This is the critical extraction path. The blob format is defined by
9
+ llama.cpp's llama_kv_cache::state_write() and is version-dependent.
10
+
11
+ Validated against llama-cpp-python 0.3.19 (llama.cpp b5000+).
12
+
13
+ Binary format of llama_state_get_data() output:
14
+ 1. Architecture string: uint32 str_len + str_len bytes (e.g. "llama")
15
+ 2. KV cache section (from memory->state_write()):
16
+ a. uint32 n_stream (always 1 for single-context)
17
+ b. Per stream:
18
+ - uint32 cell_count (= n_used_cells, NOT n_ctx)
19
+ - Per cell: int32 pos, uint32 n_seq_id, int32[] seq_ids
20
+ - uint32 v_trans (1 = values stored transposed)
21
+ - uint32 n_layer
22
+ - Per layer K: int32 type_k, uint64 row_size_k, bytes data[row_size_k * cell_count]
23
+ - Per layer V (non-transposed): int32 type_v, uint64 row_size_v, bytes data[row_size_v * cell_count]
24
+ - Per layer V (transposed): int32 type_v, uint32 el_size, uint32 n_embd_v_gqa,
25
+ bytes data[el_size * n_embd_v_gqa * cell_count]
26
+
27
+ WARNING: This format is not stable across llama.cpp versions.
28
+ Pin llama-cpp-python version in pyproject.toml.
29
+ """
30
+
31
+ from __future__ import annotations
32
+
33
+ import struct
34
+ from dataclasses import dataclass
35
+
36
+ import numpy as np
37
+ import torch
38
+
39
+ from kvcos.core.types import CacheSection
40
+
41
+
42
# ── GGML dtype constants ──────────────────────────────────────────────────────

# ggml tensor type IDs; only F16 is accepted by the parsers in this module.
GGML_TYPE_F32 = 0
GGML_TYPE_F16 = 1
GGML_TYPE_Q8_0 = 8
GGML_TYPE_Q4_0 = 2

# Bytes per element; fractional for block-quantized types
# (bytes per block / elements per block).
GGML_TYPE_SIZE: dict[int, float] = {
    GGML_TYPE_F32: 4.0,
    GGML_TYPE_F16: 2.0,
    GGML_TYPE_Q8_0: 34.0 / 32.0,  # 34-byte block holding 32 elements
    GGML_TYPE_Q4_0: 18.0 / 32.0,  # 18-byte block holding 32 elements
}

# Elements per quantization block.
GGML_BLOCK_SIZE: dict[int, int] = {
    GGML_TYPE_F32: 1,
    GGML_TYPE_F16: 1,
    GGML_TYPE_Q8_0: 32,
    GGML_TYPE_Q4_0: 32,
}
+ }
62
+
63
+
64
@dataclass
class CellMeta:
    """Metadata for a single KV cache cell (one pos/seq record from the blob)."""

    pos: int  # token position stored in this cell
    seq_ids: list[int]  # sequence IDs that reference this cell
70
+
71
+
72
@dataclass
class ParsedKVCache:
    """Result of parsing a llama.cpp state blob into structured engram tensors."""

    keys: torch.Tensor  # [n_layers, n_kv_heads, n_cells, head_dim] float16
    values: torch.Tensor  # [n_layers, n_kv_heads, n_cells, head_dim] float16
    cells: list[CellMeta]  # per-cell position / sequence metadata
    n_cells: int  # number of populated cells (n_used_cells, NOT n_ctx)
    n_layers: int  # layer count auto-detected from the blob
    v_trans: bool  # True if the blob stored V transposed ([n_embd_v, cells])
    arch: str  # architecture string from the blob header (e.g. "llama")
83
+
84
+
85
@dataclass
class ParsedMultiSectionCache:
    """Result of parsing an ISWA state blob with multiple KV cache sections.

    Each section is a ParsedKVCache with its own tensor shapes.
    For Gemma 4: section[0] is Global (5 layers), section[1] is SWA (25 layers).
    """

    sections: list[ParsedKVCache]
    arch: str

    @property
    def n_sections(self) -> int:
        """Number of KV cache sections parsed from the blob."""
        return len(self.sections)

    @property
    def total_layers(self) -> int:
        """Layer count summed across every section."""
        return sum(section.n_layers for section in self.sections)
103
+
104
+
105
class BlobParseError(Exception):
    """Raised when the state blob cannot be parsed.

    Covers truncated data, unsupported tensor dtypes, and header values
    inconsistent with the expected llama.cpp state format.
    """
107
+
108
+
109
def _read_u32(data: bytes, offset: int) -> tuple[int, int]:
    """Decode a little-endian uint32 at *offset*; return (value, offset + 4)."""
    (value,) = struct.unpack_from("<I", data, offset)
    return value, offset + 4
111
+
112
+
113
def _read_i32(data: bytes, offset: int) -> tuple[int, int]:
    """Decode a little-endian int32 at *offset*; return (value, offset + 4)."""
    (value,) = struct.unpack_from("<i", data, offset)
    return value, offset + 4
115
+
116
+
117
def _read_u64(data: bytes, offset: int) -> tuple[int, int]:
    """Decode a little-endian uint64 at *offset*; return (value, offset + 8)."""
    (value,) = struct.unpack_from("<Q", data, offset)
    return value, offset + 8
119
+
120
+
121
def _read_f16_block(
    data: bytes, offset: int, n_elements: int,
) -> tuple[torch.Tensor, int]:
    """Read n_elements of float16 data from bytes.

    Returns (tensor, next_offset); raises BlobParseError if the read
    would run past the end of the buffer.
    """
    n_bytes = n_elements * 2
    end = offset + n_bytes
    if end > len(data):
        raise BlobParseError(
            f"F16 read overflow: need {n_bytes}B at offset {offset}, blob is {len(data)}B"
        )
    # Copy out of the frombuffer view so the tensor owns its memory.
    raw = np.frombuffer(data, dtype=np.float16, count=n_elements, offset=offset)
    return torch.from_numpy(raw.copy()).to(torch.float16), end
132
+
133
+
134
def parse_state_blob(
    blob: bytes,
    n_kv_heads: int,
    head_dim: int,
) -> ParsedKVCache:
    """Parse a llama.cpp full-context state blob into structured KV tensors.

    Parses output of llama_state_get_data() (via save_state()):
      1. Architecture string header
      2. KV stream: cell metadata + per-layer K and V tensor data

    The parser auto-detects n_layers, cell_count, and v_trans from the blob.

    Args:
        blob: Raw bytes from save_state().llama_state
        n_kv_heads: Number of KV heads (from model spec)
        head_dim: Head dimension (from model spec)

    Returns:
        ParsedKVCache with [n_layers, n_kv_heads, n_cells, head_dim] tensors.

    Raises:
        BlobParseError: on truncated data, non-F16 tensor types, or header
            values inconsistent with the expected format.

    NOTE(review): this function duplicates the stream-parsing logic of
    _parse_single_stream; consider delegating once the stream framing for
    single-section blobs is confirmed identical.
    """
    if len(blob) < 20:
        raise BlobParseError(f"Blob too small: {len(blob)} bytes")

    offset = 0
    n_embd_kv = n_kv_heads * head_dim

    # ── 1. Architecture string ────────────────────────────────
    str_len, offset = _read_u32(blob, offset)
    # A huge "length" means we are not looking at the expected layout.
    if str_len > 100:
        raise BlobParseError(f"Arch string length {str_len} too large — format mismatch")
    arch = blob[offset : offset + str_len].decode("ascii", errors="replace")
    offset += str_len

    # ── 2. KV stream header ───────────────────────────────────
    n_stream, offset = _read_u32(blob, offset)
    if n_stream != 1:
        raise BlobParseError(f"Expected 1 KV stream, got {n_stream}")

    cell_count, offset = _read_u32(blob, offset)
    if cell_count == 0:
        raise BlobParseError("State blob has 0 cells")
    if cell_count > 200_000:
        raise BlobParseError(f"Suspiciously large cell_count: {cell_count}")

    # ── 3. Cell metadata ──────────────────────────────────────
    # Per cell: int32 pos, uint32 n_seq_id, then n_seq_id int32 seq ids.
    cells: list[CellMeta] = []
    for _ in range(cell_count):
        pos, offset = _read_i32(blob, offset)
        n_seq, offset = _read_u32(blob, offset)
        seq_ids: list[int] = []
        for _ in range(n_seq):
            sid, offset = _read_i32(blob, offset)
            seq_ids.append(sid)
        cells.append(CellMeta(pos=pos, seq_ids=seq_ids))

    # ── 4. Data section header ────────────────────────────────
    v_trans_u32, offset = _read_u32(blob, offset)
    v_trans = v_trans_u32 != 0

    n_layers, offset = _read_u32(blob, offset)
    if n_layers == 0 or n_layers > 200:
        raise BlobParseError(f"Invalid n_layers: {n_layers}")

    # ── 5. K tensor data (per layer) ──────────────────────────
    k_layers: list[torch.Tensor] = []
    for layer_idx in range(n_layers):
        type_k, offset = _read_i32(blob, offset)
        row_size_k, offset = _read_u64(blob, offset)

        if type_k != GGML_TYPE_F16:
            raise BlobParseError(
                f"Layer {layer_idx} K: unsupported type {type_k} (expected F16={GGML_TYPE_F16})"
            )

        data_bytes = row_size_k * cell_count
        n_elements = data_bytes // 2  # fp16

        # Cross-check the blob's row size against the model geometry.
        if n_elements != n_embd_kv * cell_count:
            raise BlobParseError(
                f"Layer {layer_idx} K: expected {n_embd_kv * cell_count} elements, "
                f"got {n_elements} (row_size={row_size_k}, cells={cell_count})"
            )

        tensor, offset = _read_f16_block(blob, offset, n_elements)
        # Shape: [cell_count, n_kv_heads * head_dim] → [n_kv_heads, cell_count, head_dim]
        tensor = tensor.reshape(cell_count, n_kv_heads, head_dim)
        tensor = tensor.permute(1, 0, 2).contiguous()
        k_layers.append(tensor)

    # ── 6. V tensor data (per layer) ──────────────────────────
    v_layers: list[torch.Tensor] = []
    for layer_idx in range(n_layers):
        type_v, offset = _read_i32(blob, offset)

        if type_v != GGML_TYPE_F16:
            raise BlobParseError(
                f"Layer {layer_idx} V: unsupported type {type_v} (expected F16={GGML_TYPE_F16})"
            )

        if v_trans:
            # Transposed layout carries its own element size and width header.
            el_size, offset = _read_u32(blob, offset)
            n_embd_v, offset = _read_u32(blob, offset)
            data_bytes = el_size * n_embd_v * cell_count
            n_elements = data_bytes // 2

            tensor, offset = _read_f16_block(blob, offset, n_elements)
            # V transposed: stored as [n_embd_v, cell_count] per layer
            # n_embd_v = n_kv_heads * head_dim
            tensor = tensor.reshape(n_embd_v // head_dim, head_dim, cell_count)
            # → [n_kv_heads, head_dim, cell_count] → [n_kv_heads, cell_count, head_dim]
            tensor = tensor.permute(0, 2, 1).contiguous()
        else:
            # Non-transposed V uses the same row layout as K.
            row_size_v, offset = _read_u64(blob, offset)
            data_bytes = row_size_v * cell_count
            n_elements = data_bytes // 2

            tensor, offset = _read_f16_block(blob, offset, n_elements)
            tensor = tensor.reshape(cell_count, n_kv_heads, head_dim)
            tensor = tensor.permute(1, 0, 2).contiguous()

        v_layers.append(tensor)

    # ── 7. Stack into [n_layers, n_kv_heads, n_cells, head_dim] ─
    keys = torch.stack(k_layers, dim=0)
    values = torch.stack(v_layers, dim=0)

    # Final sanity check before handing tensors to callers.
    expected_shape = (n_layers, n_kv_heads, cell_count, head_dim)
    if keys.shape != expected_shape:
        raise BlobParseError(f"K shape {keys.shape} != expected {expected_shape}")
    if values.shape != expected_shape:
        raise BlobParseError(f"V shape {values.shape} != expected {expected_shape}")

    return ParsedKVCache(
        keys=keys,
        values=values,
        cells=cells,
        n_cells=cell_count,
        n_layers=n_layers,
        v_trans=v_trans,
        arch=arch,
    )
276
+
277
+
278
def _parse_single_stream(
    blob: bytes,
    offset: int,
    n_kv_heads: int,
    head_dim: int,
    arch: str,
) -> tuple[ParsedKVCache, int]:
    """Parse one KV cache stream from blob at given offset.

    Stream layout (llama.cpp state format, as consumed here):
        u32 cell_count
        cell_count x { i32 pos, u32 n_seq, n_seq x i32 seq_id }
        u32 v_trans flag, u32 n_layers
        n_layers x K record { i32 type, u64 row_size, F16 data }
        n_layers x V record { i32 type, then either
            (v_trans)  u32 el_size, u32 n_embd_v, F16 data (stored transposed)
            (else)     u64 row_size, F16 data }

    Only GGML F16 tensors are accepted; any other type raises.

    Returns (ParsedKVCache, new_offset) so caller can continue
    parsing subsequent streams for ISWA blobs.

    Raises:
        BlobParseError: on implausible counts, unsupported tensor types,
            or element-count / shape mismatches.
    """
    n_embd_kv = n_kv_heads * head_dim

    # Cell count — bounded to catch corrupt blobs before allocating tensors
    cell_count, offset = _read_u32(blob, offset)
    if cell_count == 0:
        raise BlobParseError("Stream has 0 cells")
    if cell_count > 200_000:
        raise BlobParseError(f"Suspiciously large cell_count: {cell_count}")

    # Cell metadata: position + owning sequence ids, one record per cell
    cells: list[CellMeta] = []
    for _ in range(cell_count):
        pos, offset = _read_i32(blob, offset)
        n_seq, offset = _read_u32(blob, offset)
        seq_ids: list[int] = []
        for _ in range(n_seq):
            sid, offset = _read_i32(blob, offset)
            seq_ids.append(sid)
        cells.append(CellMeta(pos=pos, seq_ids=seq_ids))

    # Data section header: V-transposed flag, then layer count (bounded)
    v_trans_u32, offset = _read_u32(blob, offset)
    v_trans = v_trans_u32 != 0

    n_layers, offset = _read_u32(blob, offset)
    if n_layers == 0 or n_layers > 200:
        raise BlobParseError(f"Invalid n_layers: {n_layers}")

    # K layers — each row is one cell's [n_kv_heads * head_dim] F16 vector
    k_layers: list[torch.Tensor] = []
    for layer_idx in range(n_layers):
        type_k, offset = _read_i32(blob, offset)
        row_size_k, offset = _read_u64(blob, offset)

        if type_k != GGML_TYPE_F16:
            raise BlobParseError(
                f"Layer {layer_idx} K: unsupported type {type_k} (expected F16={GGML_TYPE_F16})"
            )

        data_bytes = row_size_k * cell_count
        n_elements = data_bytes // 2  # 2 bytes per F16 element

        if n_elements != n_embd_kv * cell_count:
            raise BlobParseError(
                f"Layer {layer_idx} K: expected {n_embd_kv * cell_count} elements, "
                f"got {n_elements} (row_size={row_size_k}, cells={cell_count})"
            )

        tensor, offset = _read_f16_block(blob, offset, n_elements)
        # [cells, heads, dim] -> [heads, cells, dim]
        tensor = tensor.reshape(cell_count, n_kv_heads, head_dim)
        tensor = tensor.permute(1, 0, 2).contiguous()
        k_layers.append(tensor)

    # V layers — layout depends on the v_trans flag read above
    v_layers: list[torch.Tensor] = []
    for layer_idx in range(n_layers):
        type_v, offset = _read_i32(blob, offset)

        if type_v != GGML_TYPE_F16:
            raise BlobParseError(
                f"Layer {layer_idx} V: unsupported type {type_v} (expected F16={GGML_TYPE_F16})"
            )

        if v_trans:
            # Transposed V: element size + embedding width precede the data,
            # which is stored [n_embd_v, cell_count] (dim-major)
            el_size, offset = _read_u32(blob, offset)
            n_embd_v, offset = _read_u32(blob, offset)
            data_bytes = el_size * n_embd_v * cell_count
            n_elements = data_bytes // 2

            tensor, offset = _read_f16_block(blob, offset, n_elements)
            tensor = tensor.reshape(n_embd_v // head_dim, head_dim, cell_count)
            tensor = tensor.permute(0, 2, 1).contiguous()
        else:
            # Row-major V: same layout as K
            row_size_v, offset = _read_u64(blob, offset)
            data_bytes = row_size_v * cell_count
            n_elements = data_bytes // 2

            tensor, offset = _read_f16_block(blob, offset, n_elements)
            tensor = tensor.reshape(cell_count, n_kv_heads, head_dim)
            tensor = tensor.permute(1, 0, 2).contiguous()

        v_layers.append(tensor)

    # Stack into [n_layers, n_kv_heads, n_cells, head_dim] and verify
    keys = torch.stack(k_layers, dim=0)
    values = torch.stack(v_layers, dim=0)

    expected_shape = (n_layers, n_kv_heads, cell_count, head_dim)
    if keys.shape != expected_shape:
        raise BlobParseError(f"K shape {keys.shape} != expected {expected_shape}")
    if values.shape != expected_shape:
        raise BlobParseError(f"V shape {values.shape} != expected {expected_shape}")

    parsed = ParsedKVCache(
        keys=keys,
        values=values,
        cells=cells,
        n_cells=cell_count,
        n_layers=n_layers,
        v_trans=v_trans,
        arch=arch,
    )
    return parsed, offset
392
+
393
+
394
def parse_multi_section_blob(
    blob: bytes,
    sections: tuple[CacheSection, ...],
) -> ParsedMultiSectionCache:
    """Parse an ISWA state blob with multiple sequential KV cache sections.

    ISWA models (e.g., Gemma 4) serialize several cache sections back to
    back in a single blob; each section carries its own cell metadata,
    layer count, and KV dimensions. The header's n_stream field must
    equal the number of sections the caller expects.

    Args:
        blob: Raw bytes from save_state().llama_state
        sections: Cache section specifications (order must match blob layout)

    Returns:
        ParsedMultiSectionCache with one ParsedKVCache per section.

    Raises:
        BlobParseError: if the blob is truncated, the arch string is
            implausibly long, or the stream count disagrees with sections.
    """
    if len(blob) < 20:
        raise BlobParseError(f"Blob too small: {len(blob)} bytes")

    cursor = 0

    # Header: length-prefixed architecture string.
    arch_len, cursor = _read_u32(blob, cursor)
    if arch_len > 100:
        raise BlobParseError(f"Arch string length {arch_len} too large")
    arch = blob[cursor : cursor + arch_len].decode("ascii", errors="replace")
    cursor += arch_len

    # Header: number of streams must match the expected section layout.
    n_stream, cursor = _read_u32(blob, cursor)
    if n_stream != len(sections):
        raise BlobParseError(
            f"Expected {len(sections)} streams, got {n_stream}"
        )

    # Body: consume one stream per declared section, in order.
    parsed_sections: list[ParsedKVCache] = []
    for section in sections:
        parsed, cursor = _parse_single_stream(
            blob,
            cursor,
            n_kv_heads=section.n_kv_heads,
            head_dim=section.head_dim,
            arch=arch,
        )
        parsed_sections.append(parsed)

    return ParsedMultiSectionCache(sections=parsed_sections, arch=arch)
442
+
443
+
444
+ # ── Legacy compat wrapper ────────────────────────────────────────────────────
445
+
446
+
447
def parse_seq_state_blob(
    blob: bytes,
    spec: dict,
    kv_dtype: int = GGML_TYPE_F16,
) -> ParsedKVCache:
    """Legacy wrapper — delegates to parse_state_blob.

    Kept for backward compatibility with existing tests. The kv_dtype
    argument is accepted for signature compatibility but is not
    forwarded to the delegate.
    """
    heads = spec["n_kv_heads"]
    dim = spec["head_dim"]
    return parse_state_blob(blob=blob, n_kv_heads=heads, head_dim=dim)
461
+
462
+
463
def estimate_blob_size(
    n_kv_heads: int,
    head_dim: int,
    n_layers: int,
    n_cells: int,
    v_trans: bool = True,
) -> int:
    """Estimate expected blob size in bytes for validation.

    Mirrors the serialized layout: arch-string header, per-cell metadata
    (assuming one seq_id per cell), a data-section header, then per-layer
    K and V tensor records in F16.
    """
    header_bytes = 4 + 5 + 4 + 4       # str_len + "llama" + n_stream + cell_count
    cell_meta_bytes = 12 * n_cells     # pos(4) + n_seq(4) + one seq_id(4) typical
    data_header_bytes = 4 + 4          # v_trans flag + n_layer

    payload = n_kv_heads * head_dim * 2 * n_cells  # F16 tensor data per layer
    k_layer_bytes = 4 + 8 + payload                # type(i32) + row_size(u64) + data
    # Transposed V records carry el_size + n_embd (u32 each) instead of a
    # u64 row_size; either way the per-layer metadata totals 12 bytes.
    v_layer_bytes = (4 + 4 + 4 + payload) if v_trans else (4 + 8 + payload)

    per_layer = k_layer_bytes + v_layer_bytes
    return header_bytes + cell_meta_bytes + data_header_bytes + n_layers * per_layer
kvcos/core/block_pool.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — 256-Token Block Pool Manager
3
+
4
+
5
+ Segments a full KV cache into fixed-size blocks (256 tokens each) that can be:
6
+ - Stored independently (one .eng file per block — D7)
7
+ - Retrieved individually via EGR (fine-grained cache hits)
8
+ - Composed (assemble a context from multiple blocks)
9
+ - Evicted independently (LRU per block, not per session)
10
+
11
+ Design from arXiv:2603.04428 (Persistent Q4 KV Cache, agent-memory paper).
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass, field
17
+
18
+ import torch
19
+
20
+ from kvcos.core.types import BLOCK_SIZE_TOKENS
21
+
22
+
23
@dataclass
class KVBlock:
    """One fixed-size (up to 256-token) slice of a KV cache.

    Blocks are the unit of storage, retrieval, and eviction: each holds
    the K and V activations for a contiguous token span.
    """

    block_index: int
    token_start: int
    token_end: int  # exclusive

    keys: torch.Tensor  # [n_layers, n_kv_heads, block_len, head_dim]
    values: torch.Tensor  # [n_layers, n_kv_heads, block_len, head_dim]

    @property
    def block_len(self) -> int:
        """Token count covered by this block."""
        return self.token_end - self.token_start

    @property
    def is_full(self) -> bool:
        """Whether the block spans a complete BLOCK_SIZE_TOKENS window."""
        return self.token_end - self.token_start == BLOCK_SIZE_TOKENS

    @property
    def n_layers(self) -> int:
        """Layer count, read off the keys tensor."""
        return self.keys.size(0)

    @property
    def n_kv_heads(self) -> int:
        """KV-head count, read off the keys tensor."""
        return self.keys.size(1)

    @property
    def head_dim(self) -> int:
        """Per-head embedding width, read off the keys tensor."""
        return self.keys.size(3)
53
+
54
+
55
@dataclass
class BlockPool:
    """Manages a collection of KV blocks for an agent session.

    Blocks are kept in token order; block_index always equals the
    block's position in `blocks`.
    """

    agent_id: str
    model_id: str
    blocks: list[KVBlock] = field(default_factory=list)

    @property
    def total_tokens(self) -> int:
        """Total token count across all blocks (last block may be partial)."""
        return sum(b.block_len for b in self.blocks)

    @property
    def n_blocks(self) -> int:
        """Number of blocks currently in the pool."""
        return len(self.blocks)

    def segment(
        self, keys: torch.Tensor, values: torch.Tensor,
    ) -> list[KVBlock]:
        """Segment a full KV cache into 256-token blocks.

        Replaces any existing blocks in the pool. Token positions are
        local (start at 0); only the final block may be shorter than
        BLOCK_SIZE_TOKENS.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim]
            values: [n_layers, n_kv_heads, ctx_len, head_dim]

        Raises:
            ValueError: if keys and values shapes differ.
        """
        if keys.shape != values.shape:
            raise ValueError(f"Shape mismatch: keys {keys.shape} vs values {values.shape}")

        ctx_len = keys.shape[2]
        blocks: list[KVBlock] = []

        for i in range(0, ctx_len, BLOCK_SIZE_TOKENS):
            end = min(i + BLOCK_SIZE_TOKENS, ctx_len)
            # .contiguous() so each block owns its memory independently
            # of the source tensor's layout.
            block = KVBlock(
                block_index=len(blocks),
                token_start=i,
                token_end=end,
                keys=keys[:, :, i:end, :].contiguous(),
                values=values[:, :, i:end, :].contiguous(),
            )
            blocks.append(block)

        self.blocks = blocks
        return blocks

    def assemble(
        self, block_indices: list[int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Assemble KV cache from blocks (concatenate along ctx_len dim).

        Args:
            block_indices: Optional subset (in the given order); None means
                all blocks in pool order.

        Raises:
            ValueError: if the pool is empty or the selection is empty.
        """
        if not self.blocks:
            raise ValueError("No blocks to assemble")

        selected = self.blocks if block_indices is None else [self.blocks[i] for i in block_indices]
        if not selected:
            raise ValueError("No blocks selected for assembly")

        keys = torch.cat([b.keys for b in selected], dim=2)
        values = torch.cat([b.values for b in selected], dim=2)
        return keys, values

    def append_block(self, block: KVBlock) -> None:
        """Append a block, reassigning its index to keep indices contiguous."""
        block.block_index = len(self.blocks)
        self.blocks.append(block)

    def get_block(self, index: int) -> KVBlock:
        """Return the block at `index`; raises IndexError on out-of-range."""
        if index < 0 or index >= len(self.blocks):
            raise IndexError(f"Block index {index} out of range [0, {len(self.blocks)})")
        return self.blocks[index]

    def extend(
        self, new_keys: torch.Tensor, new_values: torch.Tensor,
    ) -> list[KVBlock]:
        """Extend the pool with additional tokens, filling last block first.

        A partially-filled final block is replaced by a merged block
        (KVBlock is treated as immutable here); any overflow is segmented
        into fresh blocks whose token positions continue from the pool's
        current end.

        Returns:
            The blocks that were created or replaced by this call.
        """
        new_ctx_len = new_keys.shape[2]
        modified_blocks: list[KVBlock] = []
        offset = 0

        # Step 1: top up the trailing partial block, if any.
        if self.blocks and not self.blocks[-1].is_full:
            last = self.blocks[-1]
            space = BLOCK_SIZE_TOKENS - last.block_len
            fill = min(space, new_ctx_len)

            merged_k = torch.cat([last.keys, new_keys[:, :, :fill, :]], dim=2).contiguous()
            merged_v = torch.cat([last.values, new_values[:, :, :fill, :]], dim=2).contiguous()

            self.blocks[-1] = KVBlock(
                block_index=last.block_index,
                token_start=last.token_start,
                token_end=last.token_start + merged_k.shape[2],
                keys=merged_k,
                values=merged_v,
            )
            modified_blocks.append(self.blocks[-1])
            offset = fill

        # Step 2: segment whatever is left into new blocks, shifting their
        # (local, 0-based) token spans to continue after the current end.
        remaining = new_ctx_len - offset
        if remaining > 0:
            token_base = self.blocks[-1].token_end if self.blocks else 0
            # Throwaway pool reuses segment() without touching self.blocks.
            sub_pool = BlockPool(agent_id=self.agent_id, model_id=self.model_id)
            new_blocks = sub_pool.segment(
                new_keys[:, :, offset:, :], new_values[:, :, offset:, :],
            )
            for b in new_blocks:
                b.block_index = len(self.blocks)
                b.token_start += token_base
                b.token_end += token_base
                self.blocks.append(b)
                modified_blocks.append(b)

        return modified_blocks

    def clear(self) -> None:
        """Drop all blocks from the pool."""
        self.blocks.clear()
kvcos/core/cache_spec.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Model Architecture Registry
3
+
4
+
5
+ Contains ModelCacheSpec definitions for known models and utilities
6
+ to look up specs by model_id or infer model family from string.
7
+
8
+ D3: extraction_layers set to middle-to-deep (8-31 for 32-layer models)
9
+ per ShadowKV validation. Early layers (0-7) and final layer preserved.
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
15
+
16
# ── Pre-registered Model Specs ────────────────────────────────────────────────
# Each spec records the KV-cache-relevant architecture parameters of a model.
# extraction_layers skips roughly the first quarter of layers (D3).

# Llama 3.1 8B — Primary Phase 1 target (D1, D6)
# GQA: 32 query heads, 8 KV heads, head_dim 128
LLAMA_3_1_8B = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)

# Llama 3.1 8B base (non-instruct) — same architecture as the instruct variant
LLAMA_3_1_8B_BASE = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)

# Phi-3-Mini-128K — Secondary Phase 1 target
# ShadowKV validated SVD on this model (D3)
# MHA: 32 query heads, 32 KV heads (no GQA), head_dim 96
PHI_3_MINI = ModelCacheSpec(
    model_id="microsoft/Phi-3-mini-128k-instruct",
    model_family="phi",
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,  # Phi-3-Mini uses MHA, not GQA
    head_dim=96,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)

# Gemma 2 2B — NOTE: QK-Norm model, SVD behavior may differ (T3 caveat)
GEMMA_2_2B = ModelCacheSpec(
    model_id="google/gemma-2-2b-it",
    model_family="gemma",
    n_layers=26,
    n_heads=8,
    n_kv_heads=4,
    head_dim=256,
    rope_enabled=True,
    extraction_layers=tuple(range(6, 26)),
)

# Qwen 2.5 7B
QWEN_2_5_7B = ModelCacheSpec(
    model_id="Qwen/Qwen2.5-7B-Instruct",
    model_family="qwen",
    n_layers=28,
    n_heads=28,
    n_kv_heads=4,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(7, 28)),
)

# Mistral 7B v0.3
MISTRAL_7B = ModelCacheSpec(
    model_id="mistralai/Mistral-7B-Instruct-v0.3",
    model_family="mistral",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)


# Gemma 4 26B-A4B — ISWA model (Interleaved Sliding Window Attention)
# Dual KV cache: Global (full context) + SWA (sliding window 1024 tokens)
# MoE: 128 experts, 8 active — does NOT affect KV cache (FFN-only)
# Reverse-engineered from llama.cpp b5200+ state blob format.
# Top-level n_kv_heads/head_dim mirror the dominant (SWA) section; the
# true per-section layout lives in cache_sections.
GEMMA_4_26B_A4B = ModelCacheSpec(
    model_id="google/gemma-4-26b-a4b-it",
    model_family="gemma",
    n_layers=30,  # total: 5 global + 25 SWA
    n_heads=32,
    n_kv_heads=8,  # dominant section (SWA)
    head_dim=256,  # dominant section (SWA)
    rope_enabled=True,
    extraction_layers=tuple(range(8, 30)),
    cache_sections=(
        CacheSection(
            attention_type=AttentionType.FULL,
            n_layers=5,
            n_kv_heads=2,
            head_dim=512,
        ),
        CacheSection(
            attention_type=AttentionType.SLIDING,
            n_layers=25,
            n_kv_heads=8,
            head_dim=256,
            window_size=1024,
        ),
    ),
)
123
+
124
+
125
# ── Registry ──────────────────────────────────────────────────────────────────

# Exact model_id → spec. Extended at runtime via register_model_spec().
_REGISTRY: dict[str, ModelCacheSpec] = {
    spec["model_id"]: spec
    for spec in [
        LLAMA_3_1_8B,
        LLAMA_3_1_8B_BASE,
        PHI_3_MINI,
        GEMMA_2_2B,
        GEMMA_4_26B_A4B,
        QWEN_2_5_7B,
        MISTRAL_7B,
    ]
}

# Substring → family name, probed in insertion order by infer_model_family().
# NOTE(review): the "meta-llama", "microsoft/phi", and "google/gemma" entries
# are unreachable — the shorter substrings listed before them always match
# first for any id that would match them. Harmless, but dead entries.
_FAMILY_MAP: dict[str, str] = {
    "llama": "llama",
    "meta-llama": "llama",
    "phi": "phi",
    "microsoft/phi": "phi",
    "gemma": "gemma",
    "google/gemma": "gemma",
    "qwen": "qwen",
    "mistral": "mistral",
    "deepseek": "deepseek",
}
151
+
152
+
153
def get_model_spec(model_id: str) -> ModelCacheSpec | None:
    """Return the registered ModelCacheSpec for *model_id*, or None if unknown."""
    try:
        return _REGISTRY[model_id]
    except KeyError:
        return None
156
+
157
+
158
def register_model_spec(spec: ModelCacheSpec) -> None:
    """Add *spec* to the runtime registry, replacing any entry with the same id."""
    key = spec["model_id"]
    _REGISTRY[key] = spec
161
+
162
+
163
def infer_model_family(model_id: str) -> str:
    """Best-effort family name ("llama", "phi", ...) inferred from *model_id*.

    Probes _FAMILY_MAP substrings in insertion order against the lowercased
    id; falls back to "unknown" when nothing matches.
    """
    needle = model_id.lower()
    return next(
        (family for fragment, family in _FAMILY_MAP.items() if fragment in needle),
        "unknown",
    )
170
+
171
+
172
def make_spec_from_metadata(
    model_id: str,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    head_dim: int,
    rope_enabled: bool = True,
) -> ModelCacheSpec:
    """Create a ModelCacheSpec from raw parameters.

    Automatically sets extraction_layers to middle-to-deep range (D3):
    the first quarter of layers (at least one) is skipped.
    """
    first_kept = max(1, n_layers // 4)

    return ModelCacheSpec(
        model_id=model_id,
        model_family=infer_model_family(model_id),
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        head_dim=head_dim,
        rope_enabled=rope_enabled,
        extraction_layers=tuple(range(first_kept, n_layers)),
    )
197
+
198
+
199
def is_iswa_spec(spec: ModelCacheSpec) -> bool:
    """Report whether *spec* describes a multi-section (ISWA) KV cache layout."""
    return "cache_sections" in spec
202
+
203
+
204
def validate_kv_shape(
    spec: ModelCacheSpec,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
) -> bool:
    """Validate that KV tensor dimensions match the model spec."""
    expected = (spec["n_layers"], spec["n_kv_heads"], spec["head_dim"])
    return expected == (n_layers, n_kv_heads, head_dim)
kvcos/core/compression.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — KV Cache Compression Layer
3
+
4
+
5
+ Implements:
6
+ - FP16 passthrough (no compression)
7
+ - Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0
8
+ Phase 1 production fallback. ~2x compression, <5% speed hit (D5).
9
+ - PolarQuant: MSE-optimal random rotation + Lloyd-Max codebook at 3 bits.
10
+ QJL REMOVED — confirmed harmful by 6+ independent implementations (D5).
11
+ Softmax amplifies QJL variance, making two-stage worse than MSE-only.
12
+
13
+ Reference: TheTom/turboquant_plus (511+ tests, most mature impl)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from kvcos.core.types import CompressionMethod
24
+
25
# ── Q8_0 Constants ────────────────────────────────────────────────────────────
# Matches llama.cpp's GGML_TYPE_Q8_0 block size: one scale per 32 values.
Q8_GROUP_SIZE = 32
27
+
28
+
29
@dataclass(frozen=True)
class CompressionResult:
    """Result of compressing a KV cache tensor."""

    # Compressed (or round-tripped/dequantized) tensor ready for storage.
    data: torch.Tensor
    # Which compression method produced this result.
    method: CompressionMethod
    # dtype of the input tensor, so callers can restore it on decompress.
    original_dtype: torch.dtype
    # original_bytes / stored_bytes for the on-disk representation.
    compression_ratio: float
    # Method-specific string key/value details (group size, seed, ...).
    metadata: dict[str, str]
38
+
39
+
40
+ # ── FP16 Passthrough ──────────────────────────────────────────────────────────
41
+
42
+
43
def compress_fp16(kv: torch.Tensor) -> CompressionResult:
    """No-op compression: ensure tensor is FP16.

    Fix: compression_ratio is now computed from actual byte counts
    instead of being hardcoded to 1.0, consistent with the other
    compressors (which all report original_bytes / stored_bytes).
    A float32 input therefore correctly reports ~2.0; an fp16 input
    still reports 1.0.
    """
    original_bytes = kv.numel() * kv.element_size()
    data = kv.to(torch.float16).contiguous()
    compressed_bytes = data.numel() * data.element_size()

    return CompressionResult(
        data=data,
        method=CompressionMethod.FP16,
        original_dtype=kv.dtype,
        # Guard against empty tensors, matching the other compressors.
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={},
    )
53
+
54
+
55
def decompress_fp16(data: torch.Tensor) -> torch.Tensor:
    """Identity decompression — cast stored data back to float16."""
    return data.to(dtype=torch.float16)
57
+
58
+
59
+ # ── Q8_0 Quantization ────────────────────────────────────────────────────────
60
+ # Matches llama.cpp GGML_TYPE_Q8_0 layout:
61
+ # 32-element groups, 1 float16 scale per group, 32 int8 values
62
+ # Storage: (32*1 + 2) / (32*2) = 34/64 ≈ 1.88x compression
63
+
64
+
65
def compress_q8_0(kv: torch.Tensor) -> CompressionResult:
    """Quantize KV cache to Q8_0 (int8 with per-group scale).

    Stores dequantized bfloat16 for safetensors compatibility —
    safetensors doesn't support int8+scale pairs natively.

    Each group of Q8_GROUP_SIZE values along the last dim shares one
    scale (max|group| / 127), matching llama.cpp's GGML_TYPE_Q8_0.
    The reported compression_ratio reflects the actual bfloat16 bytes
    stored, not the theoretical int8 layout.
    """
    original_dtype = kv.dtype
    original_bytes = kv.numel() * kv.element_size()

    kv_flat = kv.float().contiguous()
    orig_shape = kv_flat.shape

    # Pad the last dim up to a multiple of the group size so it can be
    # reshaped into whole groups; padding is stripped after round-trip.
    last_dim = orig_shape[-1]
    pad_amount = (Q8_GROUP_SIZE - last_dim % Q8_GROUP_SIZE) % Q8_GROUP_SIZE
    if pad_amount > 0:
        kv_flat = torch.nn.functional.pad(kv_flat, (0, pad_amount))

    # View as [..., n_groups, Q8_GROUP_SIZE]
    new_shape = kv_flat.shape[:-1] + (-1, Q8_GROUP_SIZE)
    grouped = kv_flat.reshape(new_shape)

    # Symmetric per-group scale; clamp avoids divide-by-zero on all-zero groups.
    scales = grouped.abs().amax(dim=-1, keepdim=True) / 127.0
    scales = scales.clamp(min=1e-10)

    quantized = torch.clamp(torch.round(grouped / scales), -127, 127)
    dequantized = (quantized * scales).reshape(kv_flat.shape)

    # Strip padding introduced above.
    if pad_amount > 0:
        dequantized = dequantized[..., :last_dim]

    dequantized = dequantized.reshape(orig_shape).to(torch.bfloat16)
    compressed_bytes = dequantized.numel() * 2  # bfloat16 = 2 bytes/element

    return CompressionResult(
        data=dequantized,
        method=CompressionMethod.Q8_0,
        original_dtype=original_dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={"q8_group_size": str(Q8_GROUP_SIZE)},
    )
104
+
105
+
106
def decompress_q8_0(data: torch.Tensor) -> torch.Tensor:
    """Cast stored (already-dequantized bfloat16) Q8_0 data to float16."""
    return data.to(dtype=torch.float16)
108
+
109
+
110
+ # ── PolarQuant (Phase 2 — TurboQuant without QJL) ────────────────────────────
111
+ # QJL is INTENTIONALLY ABSENT per D5.
112
+
113
+
114
class PolarQuantConfig:
    """Configuration for PolarQuant compression.

    Holds the quantizer parameters plus per-dimension caches for the
    rotation matrix and codebook, so repeated compress calls reuse them.
    """

    def __init__(self, bits: int = 3, seed: int = 42):
        # bits → 2**bits codebook levels; seed fixes the rotation so
        # compression is reproducible across runs.
        self.bits = bits
        self.n_centroids = 2**bits
        self.seed = seed
        self._rotation_cache: dict[int, torch.Tensor] = {}
        self._codebook_cache: dict[int, torch.Tensor] = {}

    def get_rotation_matrix(self, dim: int, device: torch.device) -> torch.Tensor:
        """Get fixed random orthogonal rotation matrix R ∈ R^(d×d).

        Built via QR of a seeded Gaussian, with the sign correction
        (multiplying columns by sign(diag(R))) that makes the QR-based
        orthogonal sample unique/deterministic.

        NOTE(review): the cache stores the CPU tensor; `.to(device)` on a
        non-CPU device re-copies on every call — consider caching per device.
        """
        if dim not in self._rotation_cache:
            rng = np.random.RandomState(self.seed)
            gaussian = rng.randn(dim, dim).astype(np.float32)
            q, r = np.linalg.qr(gaussian)
            d = np.diag(r)
            ph = np.sign(d)
            q *= ph[np.newaxis, :]
            self._rotation_cache[dim] = torch.from_numpy(q)
        return self._rotation_cache[dim].to(device)

    def get_lloyd_max_codebook(self, dim: int) -> torch.Tensor:
        """Lloyd-Max optimal centroids for N(0,1), 3-bit (8 levels).

        NOTE(review): the table contains -0.000 and 0.000 — two identical
        levels, i.e. only 7 distinct values out of 8. Published 8-level
        Lloyd-Max tables for a unit Gaussian have distinct inner levels
        (≈ ±0.245); confirm whether the duplicate zero is intentional.
        NOTE(review): the codebook does not depend on `dim`, so keying the
        cache by dim only duplicates the same tensor per head_dim.
        """
        if dim not in self._codebook_cache:
            codebook = torch.tensor(
                [-1.748, -1.050, -0.501, -0.000, 0.000, 0.501, 1.050, 1.748],
                dtype=torch.float32,
            )
            self._codebook_cache[dim] = codebook
        return self._codebook_cache[dim]
145
+
146
+
147
# Module-level shared instance: rotation/codebook caches persist across calls.
_POLAR_CONFIG = PolarQuantConfig()
148
+
149
+
150
def compress_polarquant(kv: torch.Tensor) -> CompressionResult:
    """Compress using PolarQuant (3-bit Lloyd-Max after random rotation).

    Phase 2 implementation. Currently stores dequantized bfloat16.
    True 3-bit packed storage is Phase 2+.

    Pipeline: rotate rows by a fixed orthogonal R, normalize each rotated
    dimension by its std, snap to the nearest codebook level, then invert
    (denormalize and rotate back with R.T, which equals R^-1 since R is
    orthogonal).
    """
    original_dtype = kv.dtype
    original_bytes = kv.numel() * kv.element_size()
    device = kv.device

    kv_float = kv.float().contiguous()
    orig_shape = kv_float.shape

    # Treat every head vector as a row: [N, head_dim]
    head_dim = orig_shape[-1]
    flat = kv_float.reshape(-1, head_dim)

    R = _POLAR_CONFIG.get_rotation_matrix(head_dim, device)
    rotated = flat @ R

    # Per-dimension std over all rows; clamp avoids divide-by-zero.
    dim_std = rotated.std(dim=0, keepdim=True).clamp(min=1e-10)
    normalized = rotated / dim_std

    # Nearest-centroid assignment. NOTE: materializes an [N, head_dim, 8]
    # distance tensor — memory scales with 8x the input.
    codebook = _POLAR_CONFIG.get_lloyd_max_codebook(head_dim).to(device)
    distances = (normalized.unsqueeze(-1) - codebook.unsqueeze(0).unsqueeze(0)) ** 2
    indices = distances.argmin(dim=-1)

    # Invert the pipeline: centroid lookup → denormalize → un-rotate.
    dequantized = codebook[indices]
    dequantized = dequantized * dim_std
    R_inv = R.T
    dequantized = dequantized @ R_inv

    dequantized = dequantized.reshape(orig_shape).to(torch.bfloat16)
    compressed_bytes = dequantized.numel() * 2  # stored as bfloat16 for now

    return CompressionResult(
        data=dequantized,
        method=CompressionMethod.POLARQUANT,
        original_dtype=original_dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={
            "polarquant_bits": "3",
            "polarquant_seed": str(_POLAR_CONFIG.seed),
            "qjl_enabled": "false",  # D5: QJL permanently disabled
        },
    )
195
+
196
+
197
def decompress_polarquant(data: torch.Tensor) -> torch.Tensor:
    """Cast stored (already-dequantized) PolarQuant data to float16."""
    return data.to(dtype=torch.float16)
199
+
200
+
201
+ # ── INT8 Quantization (Phase 2 — true on-disk compression) ───────────────────
202
+ # Stores actual int8 tensors in safetensors (1 byte/element vs 2 for fp16).
203
+ # Per-row symmetric quantization: scale = max(abs(row)) / 127.
204
+ # Separate scale tensor stored alongside quantized data.
205
+ # 2x on-disk compression with cos_sim > 0.999.
206
+
207
+
208
@dataclass(frozen=True)
class Int8CompressedPair:
    """INT8 quantized tensor + per-row scales.

    Produced by compress_int8_tensor(); decompress_int8_tensor() rebuilds
    the float16 tensor as quantized * scales.
    """

    quantized: torch.Tensor  # int8 [same shape as input]
    scales: torch.Tensor  # float16 [shape[:-1]] — one scale per row
214
+
215
+
216
def compress_int8_tensor(kv: torch.Tensor) -> Int8CompressedPair:
    """Quantize a KV tensor to int8 with per-row symmetric scales.

    Each row (last-dim vector) gets one scale so that its max magnitude
    maps to 127; all-zero rows are protected by the 1e-8 clamp.

    Args:
        kv: [..., head_dim] tensor (any dtype)

    Returns:
        Int8CompressedPair with int8 data and float16 scales
    """
    shape = kv.shape
    rows = kv.float().reshape(-1, shape[-1])

    # Symmetric per-row scale: max|row| / 127, floored to avoid div-by-zero.
    per_row_scale = rows.abs().amax(dim=1).clamp(min=1e-8) / 127.0
    q = (rows / per_row_scale.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)

    return Int8CompressedPair(
        quantized=q.reshape(shape),
        scales=per_row_scale.to(torch.float16).reshape(shape[:-1]),
    )
238
+
239
+
240
def decompress_int8_tensor(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Dequantize an int8 tensor using per-row scales.

    Returns a float16 tensor of the original shape.
    """
    widened = quantized.float()
    scale_col = scales.float().unsqueeze(-1)
    return (widened * scale_col).to(torch.float16)
246
+
247
+
248
def compress_int8(kv: torch.Tensor) -> CompressionResult:
    """INT8 compression — returns dequantized float16 for CompressionResult compat.

    True int8 on-disk storage is handled by the serializer, which calls
    compress_int8_tensor() directly; this wrapper exists only so the
    dispatcher API can report a round-tripped tensor plus size accounting.
    """
    pair = compress_int8_tensor(kv)
    roundtrip = decompress_int8_tensor(pair.quantized, pair.scales)

    source_bytes = kv.numel() * kv.element_size()
    # On-disk accounting: 1 byte per int8 element + 2 bytes per fp16 scale.
    stored_bytes = pair.quantized.numel() + pair.scales.numel() * 2

    ratio = source_bytes / stored_bytes if stored_bytes > 0 else 1.0
    return CompressionResult(
        data=roundtrip,
        method=CompressionMethod.INT8,
        original_dtype=kv.dtype,
        compression_ratio=ratio,
        metadata={"int8_scale_dtype": "float16"},
    )
269
+
270
+
271
+ # ── LAYER_DELTA Compression ──────────────────────────────────────────────────
272
+ # Stores layer 0 as fp16 baseline, layers 1..N as int8 deltas from previous.
273
+ # Inter-layer residuals are typically small (adjacent layers are correlated),
274
+ # so int8 quantization of deltas achieves better fidelity than direct int8.
275
+ # On-disk: ~(1/N) fp16 + ((N-1)/N) int8 ≈ slightly better than straight INT8.
276
+
277
+
278
@dataclass(frozen=True)
class LayerDeltaCompressed:
    """Layer-delta compressed: fp16 baseline + int8 deltas.

    Layer 0 is stored whole; layers 1..n_layers-1 are stored as int8
    quantized residuals with per-row fp16 scales, to be accumulated by
    decompress_layer_delta().
    """

    baseline: torch.Tensor  # [n_kv_heads, n_cells, head_dim] fp16
    delta_quantized: list[torch.Tensor]  # each int8 [n_kv_heads, n_cells, head_dim]
    delta_scales: list[torch.Tensor]  # each fp16 [n_kv_heads, n_cells]
    n_layers: int
286
+
287
+
288
def compress_layer_delta(kv: torch.Tensor) -> LayerDeltaCompressed:
    """Compress KV tensor using inter-layer delta encoding.

    Closed-loop (DPCM-style): each layer's residual is taken against the
    RECONSTRUCTED previous layer rather than the original one. The
    decoder accumulates from the fp16 baseline, so encoding against the
    reconstruction prevents int8 quantization error from compounding
    across layers (the previous open-loop version drifted with depth).

    Args:
        kv: [n_layers, n_kv_heads, n_cells, head_dim]

    Returns:
        LayerDeltaCompressed with fp16 baseline + int8 deltas
    """
    n_layers = kv.shape[0]
    baseline = kv[0].to(torch.float16)

    deltas: list[torch.Tensor] = []
    scales: list[torch.Tensor] = []

    # Track exactly what the decoder will have reconstructed so far.
    prev_recon = baseline.float()

    for i in range(1, n_layers):
        delta = kv[i].float() - prev_recon
        flat = delta.reshape(-1, delta.shape[-1])
        # Per-row symmetric scale: max|row| maps to 127; clamp guards zeros.
        row_scale = flat.abs().amax(dim=1).clamp(min=1e-8) / 127.0
        q = (flat / row_scale.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
        deltas.append(q.reshape(delta.shape))
        fp16_scale = row_scale.to(torch.float16)
        scales.append(fp16_scale.reshape(delta.shape[:-1]))

        # Mirror the decoder's arithmetic (fp16 scales widened to float)
        # so encoder and decoder reconstructions stay bit-identical.
        dequant = (q.float() * fp16_scale.float().unsqueeze(1)).reshape(delta.shape)
        prev_recon = prev_recon + dequant

    return LayerDeltaCompressed(
        baseline=baseline, delta_quantized=deltas,
        delta_scales=scales, n_layers=n_layers,
    )
315
+
316
+
317
def decompress_layer_delta(data: LayerDeltaCompressed) -> torch.Tensor:
    """Rebuild the stacked fp16 tensor from baseline + accumulated int8 deltas."""
    rebuilt = [data.baseline.float()]
    for q, s in zip(data.delta_quantized, data.delta_scales):
        rows = q.float().reshape(-1, q.shape[-1])
        step = (rows * s.float().reshape(-1).unsqueeze(1)).reshape(q.shape)
        rebuilt.append(rebuilt[-1] + step)
    return torch.stack(rebuilt).to(torch.float16)
325
+
326
+
327
def compress_layer_delta_result(kv: torch.Tensor) -> CompressionResult:
    """Layer-delta wrapper for CompressionResult API.

    Compresses then immediately decompresses so the result carries a
    round-tripped float16 tensor; compression_ratio is computed from the
    would-be on-disk layout (fp16 baseline + int8 deltas + fp16 scales),
    not from the returned tensor's actual size.
    """
    compressed = compress_layer_delta(kv)
    decompressed = decompress_layer_delta(compressed)

    original_bytes = kv.numel() * kv.element_size()
    # On-disk: baseline fp16 + (N-1) int8 deltas + (N-1) fp16 scale rows
    n = compressed.n_layers
    per_layer_elements = kv[0].numel()
    scale_elements = kv.shape[1] * kv.shape[2]  # n_kv_heads * n_cells
    compressed_bytes = (
        per_layer_elements * 2  # baseline fp16
        + (n - 1) * per_layer_elements * 1  # int8 deltas
        + (n - 1) * scale_elements * 2  # fp16 scales
    )

    return CompressionResult(
        data=decompressed,
        method=CompressionMethod.LAYER_DELTA,
        original_dtype=kv.dtype,
        compression_ratio=original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
        metadata={"delta_n_layers": str(n)},
    )
350
+
351
+
352
+ # ── Dispatcher ────────────────────────────────────────────────────────────────
353
+
354
+
355
def compress(kv: torch.Tensor, method: CompressionMethod) -> CompressionResult:
    """Compress a KV cache tensor using the specified method.

    Q4_0 is deliberately redirected to Q8_0 (with a UserWarning) because
    of its dequantization slowdown at long context (D5). Unrecognized
    methods raise ValueError.
    """
    if method == CompressionMethod.FP16:
        return compress_fp16(kv)
    if method == CompressionMethod.Q8_0:
        return compress_q8_0(kv)
    if method == CompressionMethod.POLARQUANT:
        return compress_polarquant(kv)
    if method == CompressionMethod.INT8:
        return compress_int8(kv)
    if method == CompressionMethod.LAYER_DELTA:
        return compress_layer_delta_result(kv)
    if method == CompressionMethod.Q4_0:
        import warnings

        warnings.warn(
            "Q4_0 has 92% dequantization slowdown at 64K+ context. "
            "Using Q8_0 instead. See D5.",
            UserWarning,
            stacklevel=2,
        )
        return compress_q8_0(kv)
    raise ValueError(f"Unknown compression method: {method}")
380
+
381
+
382
def decompress(data: torch.Tensor, method: CompressionMethod) -> torch.Tensor:
    """Decompress a KV cache tensor for the given method."""
    if method == CompressionMethod.FP16:
        return decompress_fp16(data)
    if method in (CompressionMethod.Q8_0, CompressionMethod.Q4_0):
        return decompress_q8_0(data)
    if method == CompressionMethod.POLARQUANT:
        return decompress_polarquant(data)
    if method in (CompressionMethod.INT8, CompressionMethod.LAYER_DELTA):
        # CompressionResult already carries dequantized float16 data.
        return data.to(torch.float16)
    raise ValueError(f"Unknown compression method: {method}")
kvcos/core/config.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Centralized Configuration
3
+
4
+
5
+ Single source of truth for all runtime configuration.
6
+ Uses pydantic-settings for validation and type coercion.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from functools import lru_cache
12
+ from pathlib import Path
13
+
14
+ from pydantic_settings import BaseSettings, SettingsConfigDict
15
+
16
+ from kvcos.core.types import CompressionMethod, IndexBackend, StorageBackend
17
+
18
+
19
class EngramConfig(BaseSettings):
    """ENGRAM runtime configuration.

    Values are read from environment variables carrying the ENGRAM_ prefix,
    or from a .env file in the project root, and validated / type-coerced by
    pydantic-settings. Unknown environment keys are ignored.
    """

    model_config = SettingsConfigDict(
        env_prefix="ENGRAM_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",
    )

    # Server binding
    port: int = 8080
    host: str = "0.0.0.0"

    # Local storage layout and default compression
    data_dir: Path = Path.home() / ".engram" / "data"
    backend: StorageBackend = StorageBackend.LOCAL
    default_compression: CompressionMethod = CompressionMethod.Q8_0

    # FAISS index (decision D2)
    index_backend: IndexBackend = IndexBackend.FAISS_FLAT_IP
    index_dir: Path = Path.home() / ".engram" / "index"
    # State vector dimension — must match the extractor's output:
    # 128 for mean_pool (head_dim), 160 for svd_project (rank-160).
    state_vec_dim: int = 160

    # LLM runtime (decision D1)
    model_path: str = ""  # path to a GGUF model file
    n_gpu_layers: int = 0  # D1: CPU-only Phase 1 (avoids Issue #743)
    n_ctx: int = 16384  # D6: 16K context for the Phase 1 demo target

    # Phase 2: remote storage backends
    redis_url: str = "redis://localhost:6379"
    redis_max_memory_gb: float = 2.0
    s3_bucket: str = "engram-cache"
    s3_region: str = "eu-central-1"
    cloudflare_r2_endpoint: str = ""

    # Phase 2: semantic index services
    qdrant_url: str = "http://localhost:6333"
    qdrant_api_key: str = ""
    qdrant_collection: str = "engram_states"
    cohere_api_key: str = ""

    # Phase 4: cross-model transfer adapters
    adapter_enabled: bool = False
    adapter_checkpoint_dir: Path = Path.home() / ".engram" / "adapters"

    def ensure_dirs(self) -> None:
        """Create the data and index directories if they don't exist."""
        for directory in (self.data_dir, self.index_dir):
            directory.mkdir(parents=True, exist_ok=True)
75
+
76
+
77
@lru_cache(maxsize=1)
def get_config() -> EngramConfig:
    """Return the process-wide config singleton.

    The first call constructs the config (reading env vars / .env) and
    creates its directories; later calls are served from the lru_cache.
    """
    cfg = EngramConfig()
    cfg.ensure_dirs()
    return cfg
kvcos/core/fingerprint.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Standalone State Extraction Functions
3
+
4
+ Contains the Engram Absolute fingerprint: compute_fourier_fingerprint().
5
+ This is the primary cross-model retrieval fingerprint, validated at
6
+ 98% recall@1 at N=1000 with power-law decay N^-0.207.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from typing import TYPE_CHECKING
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+
16
+ if TYPE_CHECKING:
17
+ from kvcos.core.blob_parser import ParsedMultiSectionCache
18
+
19
+
20
def compute_fourier_fingerprint(
    layer_keys: torch.Tensor,
    freqs: list[int] | None = None,
) -> torch.Tensor:
    """Engram Absolute fingerprint: DFT amplitudes over the layer axis.

    The real DFT is taken along the layer dimension, the amplitude row for
    each requested frequency is L2-normalized, and the rows are concatenated.
    Default freqs=[0, 1] gives the f0 (DC) + f1 (first harmonic) form.

    Args:
        layer_keys: [n_layers, dim] per-layer MEAN keys across token
            positions, where dim = n_kv_heads * head_dim. A 3D
            [n_layers, n_kv, hd] input is flattened automatically.
        freqs: DFT frequency indices to concatenate (default [0, 1]).

    Returns:
        Float32 tensor of shape [dim * len(freqs)]; each frequency slice is
        individually L2-normalized.

    Raises:
        ValueError: If a requested frequency exceeds the rfft bin count.

    Properties (per project validation notes):
        - Cross-model invariant within Llama-3.x family (cos ~0.89)
        - Zero corpus dependency: no centroid, no basis, no training data
        - Recall@1 N=200: 98% N=1000: 98% decay: N^-0.207
    """
    selected = [0, 1] if freqs is None else freqs

    if layer_keys.dim() == 3:
        layer_keys = layer_keys.reshape(layer_keys.shape[0], -1)

    flat = layer_keys.float()

    amplitudes = torch.fft.rfft(flat, dim=0).abs()
    n_bins = amplitudes.shape[0]

    parts = []
    for freq in selected:
        if freq >= n_bins:
            raise ValueError(
                f"Requested freq={freq} but rfft produced only "
                f"{n_bins} components for {flat.shape[0]} layers."
            )
        parts.append(F.normalize(amplitudes[freq], dim=-1))

    return torch.cat(parts, dim=-1)
66
+
67
+
68
def compute_eigenform_score(
    layer_keys: torch.Tensor,
    noise_sigma: float = 0.001,
    n_trials: int = 3,
    freqs: list | None = None,
) -> float:
    """Stability score of the Fourier fingerprint under small perturbations.

    Recomputes the fingerprint on noise-perturbed copies of the input and
    averages pairwise cosine similarity. A value near 1.0 means the
    fingerprint is stable; below 0.95 it is considered fragile.

    Args:
        layer_keys: [n_layers, dim] per-layer mean key vectors.
        noise_sigma: Standard deviation of the added Gaussian noise.
        n_trials: Number of copies compared (the first is unperturbed).
        freqs: DFT frequencies. Default [0, 1].

    Returns:
        float in [0, 1]: mean pairwise cosine across trials (1.0 when
        fewer than two trials are available).
    """
    if freqs is None:
        freqs = [0, 1]

    fingerprints = []
    for trial in range(n_trials):
        if trial == 0:
            # Keep one unperturbed reference copy.
            perturbed = layer_keys
        else:
            perturbed = layer_keys + torch.randn_like(layer_keys) * noise_sigma
        fingerprints.append(compute_fourier_fingerprint(perturbed.float(), freqs=freqs))

    total = 0.0
    n_pairs = 0
    for a in range(n_trials):
        for b in range(a + 1, n_trials):
            total += F.cosine_similarity(
                fingerprints[a].unsqueeze(0), fingerprints[b].unsqueeze(0)
            ).item()
            n_pairs += 1

    if n_pairs == 0:
        return 1.0
    return float(total / n_pairs)
98
+
99
+
100
def compute_iswa_fingerprint(
    parsed: "ParsedMultiSectionCache",
    freqs: list | None = None,
    normalize_layers: bool = True,
) -> torch.Tensor:
    """Concatenated Fourier fingerprint for ISWA multi-section caches.

    Strategy A (per-section concatenation): every cache section is reduced
    to per-layer mean keys (mean over the token axis), fingerprinted with
    compute_fourier_fingerprint_v2, and the per-section fingerprints are
    concatenated. Each section's sub-fingerprint is independently
    L2-normalized per frequency, preserving the relative geometry within
    each attention type.

    For Gemma 4 with freqs=[0, 1]:
        Global (5 layers, 2 heads, 512 dim) → 1024 * 2 = 2048
        SWA (25 layers, 8 heads, 256 dim) → 2048 * 2 = 4096
        Total: 6144-dim fingerprint

    Args:
        parsed: ParsedMultiSectionCache from parse_multi_section_blob().
        freqs: DFT frequency indices. Default [0, 1].
        normalize_layers: L2-normalize each layer before DFT (v2 behavior).

    Returns:
        Concatenated fingerprint tensor, float32.
    """
    active_freqs = [0, 1] if freqs is None else freqs

    # Mean over tokens: [n_layers, n_kv_heads, n_cells, head_dim]
    # → [n_layers, n_kv_heads, head_dim] (flattened inside v2).
    per_section = [
        compute_fourier_fingerprint_v2(
            section.keys.float().mean(dim=2),
            freqs=active_freqs,
            normalize_layers=normalize_layers,
        )
        for section in parsed.sections
    ]

    return torch.cat(per_section, dim=-1)
138
+
139
+
140
def compute_fourier_fingerprint_v2(
    layer_keys: torch.Tensor,
    freqs: list | None = None,
    normalize_layers: bool = True,
) -> torch.Tensor:
    """Fourier fingerprint v2: each layer is L2-normalized before the DFT.

    Normalizing per layer removes absolute magnitude scale (which differs
    by KV head count across model families) while preserving the
    layer-progression shape.

    Within-family: same recall as v1 (98%).
    Cross-family: f0+f1 cross-sim expected >>0.26 (v1 baseline).

    Args:
        layer_keys: [n_layers, dim] (a 3D input is flattened automatically).
        freqs: DFT frequency indices. Default [0, 1].
        normalize_layers: When True, L2-normalize each layer row first.

    Returns:
        Float32 tensor [dim * len(freqs)], one normalized slice per freq.

    Raises:
        ValueError: If a requested frequency exceeds the rfft bin count.
    """
    selected = [0, 1] if freqs is None else freqs

    if layer_keys.dim() == 3:
        layer_keys = layer_keys.reshape(layer_keys.shape[0], -1)
    layer_keys = layer_keys.float()

    if normalize_layers:
        layer_keys = F.normalize(layer_keys, dim=-1)

    amplitudes = torch.fft.rfft(layer_keys, dim=0).abs()

    parts = []
    for freq in selected:
        if freq >= amplitudes.shape[0]:
            raise ValueError(f"freq={freq} out of range for {layer_keys.shape[0]} layers")
        parts.append(F.normalize(amplitudes[freq], dim=-1))

    return torch.cat(parts, dim=-1)
kvcos/core/manifold_index.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engrammatic Geometry Retrieval — Manifold Index
3
+
4
+
5
+ FAISS-backed MIPS (Maximum Inner Product Search) index for EGR retrieval.
6
+ Indexes state vectors extracted from .eng files by MARStateExtractor.
7
+
8
+ D2: FAISS IndexFlatIP for K→K retrieval only. Never Q→K.
9
+ faiss.serialize_index() for persistence (not write_index — avoids
10
+ platform incompatibility Issue #3888). Atomic write via temp + rename.
11
+ MKL build enforced at import time.
12
+
13
+ D4: No L2 normalization. True MIPS. Raw inner product scores.
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass, field
19
+ from pathlib import Path
20
+
21
+ import faiss
22
+ import numpy as np
23
+ import torch
24
+
25
+ from kvcos.core.types import CacheSearchResult
26
+
27
+
28
@dataclass
class IndexEntry:
    """Metadata associated with an indexed state vector.

    One entry is kept per FAISS row (in a list parallel to the index) and
    is persisted in the JSON sidecar written by ManifoldIndex.save().
    """

    cache_id: str  # engram identifier; used for position tracking and removal
    task_description: str  # human-readable task text, surfaced in search results
    model_id: str  # identifier of the model that produced the state vector
    created_at: str  # creation timestamp (stored as a string)
    context_len: int  # context length (token count) of the cached state
    l2_norm: float  # D4: stored for optional downstream use
38
+
39
+
40
class ManifoldIndex:
    """FAISS-backed inner product index for EGR state vectors.

    Stores state vectors and associated metadata for MIPS retrieval.
    Persistence via faiss.serialize_index() with atomic file writes
    (D2 — avoids the write_index platform incompatibility, Issue #3888).
    D4: vectors are indexed WITHOUT L2 normalization — true MIPS with raw
    inner product scores.

    Usage:
        index = ManifoldIndex(dim=160)
        index.add(state_vec, entry)
        results = index.search(query_vec, top_k=5)
        index.save(Path("~/.engram/index/egr.faiss"))
    """

    def __init__(self, dim: int, index_path: Path | None = None):
        """Initialize the manifold index.

        Args:
            dim: Dimension of state vectors (must match MARStateExtractor output).
            index_path: Optional path to load an existing index from disk.
        """
        self.dim = dim
        self._entries: list[IndexEntry] = []
        self._id_to_position: dict[str, int] = {}  # cache_id → FAISS row position

        if index_path and index_path.exists():
            self._index = self._load_index(index_path)
        else:
            # D2: IndexFlatIP — exact MIPS, correct for Phase 1 corpus sizes (<100K)
            self._index = faiss.IndexFlatIP(dim)

    @property
    def n_entries(self) -> int:
        """Number of rows held in FAISS (includes shadowed/removed rows
        that have not yet been compacted by rebuild())."""
        return self._index.ntotal

    def add(
        self,
        state_vec: torch.Tensor | np.ndarray,
        entry: IndexEntry,
    ) -> None:
        """Add a state vector and its metadata to the index.

        Re-adding an existing cache_id appends a fresh row and repoints the
        id→position mapping at it. FAISS IndexFlat cannot update in place,
        so the old row stays in storage but is filtered out of search()
        ("shadowed") until rebuild() compacts the index.

        Args:
            state_vec: [dim] state vector (D4: NOT normalized)
            entry: Associated metadata for this engram

        Raises:
            ValueError: If the vector shape does not match the index dim.
        """
        vec = self._to_numpy(state_vec)

        if vec.shape != (self.dim,):
            raise ValueError(
                f"State vector dim {vec.shape} != index dim ({self.dim},)"
            )

        position = self._index.ntotal
        self._index.add(vec.reshape(1, -1).astype(np.float32))
        self._entries.append(entry)
        self._id_to_position[entry.cache_id] = position

    def search(
        self,
        query_vec: torch.Tensor | np.ndarray,
        top_k: int = 5,
        min_similarity: float | None = None,
        model_id: str | None = None,
    ) -> list[CacheSearchResult]:
        """Search for the most similar engram states via MIPS.

        Args:
            query_vec: [dim] query state vector
            top_k: Number of results to return
            min_similarity: Minimum inner product score threshold
            model_id: Optional filter by model ID

        Returns:
            List of CacheSearchResult sorted by similarity (descending)

        Raises:
            ValueError: If the query vector shape does not match the index dim.
        """
        if self._index.ntotal == 0:
            return []

        vec = self._to_numpy(query_vec)
        if vec.shape != (self.dim,):
            raise ValueError(
                f"Query vector dim {vec.shape} != index dim ({self.dim},)"
            )

        # Always over-fetch: shadowed duplicates, remove()d ids, model_id
        # and min_similarity filters can all discard hits. (Previously only
        # the model_id case over-fetched, so an index containing superseded
        # rows could return fewer than top_k valid results.)
        search_k = min(top_k * 3, self._index.ntotal)
        scores, indices = self._index.search(
            vec.reshape(1, -1).astype(np.float32), search_k
        )

        results: list[CacheSearchResult] = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads with -1 when fewer than search_k rows exist.
            if idx < 0 or idx >= len(self._entries):
                continue

            entry = self._entries[idx]

            # Skip rows superseded by a later add() for the same cache_id,
            # or removed via remove().
            if self._id_to_position.get(entry.cache_id) != idx:
                continue

            # Apply filters
            if model_id and entry.model_id != model_id:
                continue
            if min_similarity is not None and score < min_similarity:
                continue

            results.append(CacheSearchResult(
                cache_id=entry.cache_id,
                similarity=float(score),
                task_description=entry.task_description,
                model_id=entry.model_id,
                created_at=entry.created_at,
                context_len=entry.context_len,
            ))

            if len(results) >= top_k:
                break

        return results

    def remove(self, cache_id: str) -> bool:
        """Mark a cache entry as removed from the index.

        FAISS IndexFlat doesn't support deletion. We remove the entry from
        the position tracking so it is filtered out of search results; the
        raw vector remains in FAISS until the next rebuild().

        Args:
            cache_id: ID to remove

        Returns:
            True if the entry was found and removed from tracking
        """
        if cache_id in self._id_to_position:
            del self._id_to_position[cache_id]
            return True
        return False

    def rebuild(self) -> int:
        """Rebuild the index from only active entries.

        Removes gaps left by remove() calls and rows shadowed by duplicate
        add() calls. Returns the count of active entries.
        """
        active_positions = set(self._id_to_position.values())
        if len(active_positions) == len(self._entries):
            return len(active_positions)  # No gaps

        # Snapshot the raw FAISS storage ONCE as a [ntotal, dim] view.
        # (Previously this wrapper was recreated inside the loop for every
        # active entry — O(n) redundant work per row.)
        flat = faiss.rev_swig_ptr(
            self._index.get_xb(), self._index.ntotal * self.dim
        ).reshape(-1, self.dim)

        # Collect active vectors and entries in their original order.
        new_entries: list[IndexEntry] = []
        vectors: list[np.ndarray] = []
        for pos, entry in enumerate(self._entries):
            if (
                pos in active_positions
                and self._id_to_position.get(entry.cache_id) == pos
            ):
                vectors.append(flat[pos].copy())
                new_entries.append(entry)

        # Rebuild from scratch via add() so tracking stays consistent.
        self._index = faiss.IndexFlatIP(self.dim)
        self._entries = []
        self._id_to_position = {}

        for vec, entry in zip(vectors, new_entries):
            self.add(torch.from_numpy(vec), entry)

        return self.n_entries

    def save(self, path: Path) -> None:
        """Persist the index to disk.

        D2: Uses faiss.serialize_index() (not write_index) to avoid
        platform incompatibility. Atomic write via temp file + rename.
        Metadata is saved as a sidecar .meta.json file next to the index.
        """
        import json

        path.parent.mkdir(parents=True, exist_ok=True)

        # D2: serialize_index returns a numpy uint8 array — write raw bytes
        index_bytes: np.ndarray = faiss.serialize_index(self._index)

        # Atomic write for the FAISS index: write to a temp file, then rename.
        tmp_path = path.with_suffix(".faiss.tmp")
        try:
            tmp_path.write_bytes(index_bytes.tobytes())
            tmp_path.rename(path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        # Save the metadata sidecar with the same atomic pattern.
        meta_path = path.with_suffix(".meta.json")
        meta_tmp = meta_path.with_suffix(".json.tmp")
        try:
            sidecar = {
                "dim": self.dim,
                "entries": [
                    {
                        "cache_id": e.cache_id,
                        "task_description": e.task_description,
                        "model_id": e.model_id,
                        "created_at": e.created_at,
                        "context_len": e.context_len,
                        "l2_norm": e.l2_norm,
                    }
                    for e in self._entries
                ],
                "id_to_position": self._id_to_position,
            }
            meta_tmp.write_text(json.dumps(sidecar, indent=2))
            meta_tmp.rename(meta_path)
        except Exception:
            meta_tmp.unlink(missing_ok=True)
            raise

    def _load_index(self, path: Path) -> faiss.IndexFlatIP:
        """Load a FAISS index and its metadata sidecar from disk.

        D2: Uses faiss.deserialize_index() from raw bytes (not read_index).
        Restores self._entries and self._id_to_position when the sidecar
        exists; a missing sidecar leaves them empty.
        """
        import json

        raw = np.frombuffer(path.read_bytes(), dtype=np.uint8)
        index = faiss.deserialize_index(raw)

        meta_path = path.with_suffix(".meta.json")
        if meta_path.exists():
            sidecar = json.loads(meta_path.read_text())
            self._entries = [
                IndexEntry(**e) for e in sidecar.get("entries", [])
            ]
            # JSON object keys are strings; positions must come back as ints.
            self._id_to_position = {
                k: int(v) for k, v in sidecar.get("id_to_position", {}).items()
            }

        return index

    @staticmethod
    def _to_numpy(vec: torch.Tensor | np.ndarray) -> np.ndarray:
        """Convert a vector to a numpy float32 array (CPU, detached)."""
        if isinstance(vec, torch.Tensor):
            return vec.detach().cpu().float().numpy()
        return vec.astype(np.float32)
kvcos/core/retriever.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engrammatic Geometry Retrieval — Retriever
3
+
4
+
5
+ Orchestrates the full EGR retrieval pipeline:
6
+ 1. Extract state vector from query KV cache (MARStateExtractor)
7
+ 2. Search manifold index for similar engram states (ManifoldIndex)
8
+ 3. Load matched .eng files from storage (StorageBackend)
9
+ 4. Return ranked results with KV tensors ready for injection
10
+
11
+ This is the primary interface agents use for retrieval.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ from dataclasses import dataclass
17
+ from pathlib import Path
18
+
19
+ import torch
20
+
21
+ from kvcos.core.serializer import EngramSerializer
22
+ from kvcos.core.types import (
23
+ CacheSearchResult,
24
+ CompressionMethod,
25
+ EngramMetadata,
26
+ ModelCacheSpec,
27
+ StateExtractionMode,
28
+ )
29
+ from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
30
+ from kvcos.core.state_extractor import ExtractionResult, MARStateExtractor
31
+ from kvcos.storage.backends import StorageBackend
32
+
33
+
34
@dataclass
class RetrievalResult:
    """A single retrieval result with loaded KV tensors.

    In metadata-only retrieval (load_tensors=False) the keys/values fields
    hold empty tensors and only metadata is populated.
    """

    cache_id: str  # identifier of the matched engram
    similarity: float  # raw inner-product (MIPS) score from the manifold index
    task_description: str  # task text stored alongside the engram
    model_id: str  # model that produced the matched cache
    keys: torch.Tensor  # [n_layers, n_kv_heads, ctx_len, head_dim]
    values: torch.Tensor  # [n_layers, n_kv_heads, ctx_len, head_dim]
    metadata: EngramMetadata  # metadata read from the stored .eng file
45
+
46
+
47
@dataclass
class RetrievalResponse:
    """Full response from a retrieval query."""

    query_extraction: ExtractionResult  # state-vector extraction computed for the query
    results: list[RetrievalResult]  # ranked matches, highest similarity first
    n_searched: int  # total entries in the index
54
+
55
+
56
class EGRRetriever:
    """Engrammatic Geometry Retrieval — full pipeline.

    Connects MARStateExtractor → ManifoldIndex → StorageBackend
    into a single retrieval call.

    Usage:
        retriever = EGRRetriever(extractor, index, storage)

        # Store an engram
        retriever.index_engram(keys, values, spec, agent_id, task_desc, model_id)

        # Retrieve similar engrams
        response = retriever.retrieve(query_keys, spec, top_k=3)
        for result in response.results:
            print(result.similarity, result.task_description)
            # result.keys / result.values ready for injection
    """

    def __init__(
        self,
        extractor: MARStateExtractor,
        index: ManifoldIndex,
        storage: StorageBackend,
        serializer: EngramSerializer | None = None,
    ):
        """Wire the pipeline components together.

        Args:
            extractor: Produces state vectors from KV caches.
            index: MIPS index over stored state vectors.
            storage: Backend holding the serialized .eng files.
            serializer: Optional serializer override (a fresh EngramSerializer
                is created when omitted).
        """
        self.extractor = extractor
        self.index = index
        self.storage = storage
        self._serializer = serializer or EngramSerializer()

    def index_engram(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        spec: ModelCacheSpec,
        agent_id: str,
        task_description: str,
        model_id: str,
        cache_id: str | None = None,
        compression: CompressionMethod = CompressionMethod.Q8_0,
        output_dir: Path | None = None,
        extra_metadata: dict[str, str] | None = None,
    ) -> str:
        """Extract state vector, store .eng file, and add to index.

        This is the "write" path: compute once → store → index → reuse forever.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim]
            values: same shape as keys
            spec: Model architecture spec
            agent_id: Agent identifier
            task_description: Human-readable task description (searchable)
            model_id: Full model identifier
            cache_id: Explicit ID (auto-generated if None)
            compression: Compression method for storage
            output_dir: Directory for .eng file (uses a temp dir if None)
            extra_metadata: Additional metadata key-value pairs

        Returns:
            cache_id of the stored engram
        """
        import uuid
        from datetime import datetime, timezone

        from kvcos.core.types import ENG_FILE_EXTENSION

        cid = cache_id or str(uuid.uuid4())

        # 1. Extract state vector
        extraction = self.extractor.extract(keys, spec)

        # 2. Serialize to .eng file
        if output_dir:
            output_path = output_dir / f"{cid}{ENG_FILE_EXTENSION}"
        else:
            # Use a temp path; storage backend will move it.
            # NOTE(review): the mkdtemp directory is never cleaned up here —
            # confirm the backend moves (not copies) the file.
            import tempfile
            output_path = Path(tempfile.mkdtemp()) / f"{cid}{ENG_FILE_EXTENSION}"

        merge_meta = {
            "state_vec_norm": str(extraction.l2_norm),
            "extraction_mode": extraction.mode.value,
        }
        if extra_metadata:
            merge_meta.update(extra_metadata)

        # The serializer's summary return value is not needed here; metadata
        # is re-read from the written file below so the backend stores
        # exactly what is on disk.
        self._serializer.serialize(
            keys=keys,
            values=values,
            agent_id=agent_id,
            task_description=task_description,
            model_id=model_id,
            output_path=output_path,
            compression=compression,
            cache_id=cid,
            extra_metadata=merge_meta,
        )

        # 3. Store in backend
        metadata = self._serializer.read_metadata_only(output_path)
        self.storage.store_file(cid, output_path, metadata)

        # 4. Add to manifold index
        now = datetime.now(timezone.utc).isoformat()
        entry = IndexEntry(
            cache_id=cid,
            task_description=task_description,
            model_id=model_id,
            created_at=now,
            context_len=keys.shape[2],  # ctx_len axis of the canonical 4D layout
            l2_norm=extraction.l2_norm,
        )
        self.index.add(extraction.state_vec, entry)

        return cid

    def retrieve(
        self,
        query_keys: torch.Tensor,
        spec: ModelCacheSpec,
        top_k: int = 5,
        min_similarity: float | None = None,
        model_id: str | None = None,
        load_tensors: bool = True,
    ) -> RetrievalResponse:
        """Retrieve similar engram states for a query KV cache.

        This is the "read" path: extract query vector → search index →
        load matching .eng files. Entries that can no longer be loaded
        (missing path, deserialization failure) are silently skipped.

        Args:
            query_keys: [n_layers, n_kv_heads, ctx_len, head_dim] query K cache
            spec: Model architecture spec
            top_k: Number of results to return
            min_similarity: Minimum MIPS score threshold
            model_id: Filter by model ID
            load_tensors: If True, load full KV tensors from storage.
                If False, return metadata only (faster for previewing).

        Returns:
            RetrievalResponse with ranked results
        """
        # 1. Extract query state vector
        query_extraction = self.extractor.extract(query_keys, spec)

        # 2. Search manifold index
        # NOTE(review): results are accessed by subscript below but are
        # constructed with kwargs in ManifoldIndex — consistent only if
        # CacheSearchResult is a TypedDict; confirm in kvcos.core.types.
        search_results = self.index.search(
            query_vec=query_extraction.state_vec,
            top_k=top_k,
            min_similarity=min_similarity,
            model_id=model_id,
        )

        # 3. Load matching engrams from storage
        results: list[RetrievalResult] = []
        for sr in search_results:
            if load_tensors:
                path = self.storage.get_path(sr["cache_id"])
                if path is None:
                    continue

                try:
                    keys, values, metadata = self._serializer.deserialize(path)
                except Exception:
                    # Best-effort: a corrupt/unreadable file drops this hit.
                    continue

                results.append(RetrievalResult(
                    cache_id=sr["cache_id"],
                    similarity=sr["similarity"],
                    task_description=sr["task_description"],
                    model_id=sr["model_id"],
                    keys=keys,
                    values=values,
                    metadata=metadata,
                ))
            else:
                # Metadata-only mode: empty tensors stand in for KV data.
                metadata = self.storage.get_metadata(sr["cache_id"])
                if metadata is None:
                    continue

                results.append(RetrievalResult(
                    cache_id=sr["cache_id"],
                    similarity=sr["similarity"],
                    task_description=sr["task_description"],
                    model_id=sr["model_id"],
                    keys=torch.empty(0),
                    values=torch.empty(0),
                    metadata=metadata,
                ))

        return RetrievalResponse(
            query_extraction=query_extraction,
            results=results,
            n_searched=self.index.n_entries,
        )

    def delete_engram(self, cache_id: str) -> bool:
        """Remove an engram from both index and storage.

        Returns True if either the index or the storage backend had it.
        """
        idx_removed = self.index.remove(cache_id)
        store_removed = self.storage.delete(cache_id)
        return idx_removed or store_removed

    def save_index(self, path: Path) -> None:
        """Persist the manifold index to disk."""
        self.index.save(path)
kvcos/core/serializer.py ADDED
@@ -0,0 +1,274 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — .eng File Serializer
3
+
4
+
5
+ .eng = safetensors container with:
6
+ - __metadata__: JSON-stringified EngramMetadata (all string values per D7)
7
+ - Tensor keys: layer_{i}_keys, layer_{i}_values
8
+ - Each tensor: [n_kv_heads, ctx_len, head_dim] at compressed dtype
9
+
10
+ D7: safetensors confirmed. GGUF rejected. String-only metadata values.
11
+ Reference: arXiv:2603.04428 uses identical safetensors approach.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import hashlib
17
+ import uuid
18
+ from datetime import datetime, timezone
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ import torch
23
+ from safetensors.torch import load_file, save_file
24
+
25
+ from kvcos.core.cache_spec import infer_model_family
26
+ from kvcos.core.compression import CompressionResult, compress, decompress
27
+ from kvcos.core.types import (
28
+ ENGRAM_VERSION,
29
+ ENG_FILE_EXTENSION,
30
+ CompressionMethod,
31
+ EngramMetadata,
32
+ )
33
+
34
+
35
class SerializationError(Exception):
    """Raised when serialization or deserialization fails.

    EngramSerializer raises this for malformed inputs, e.g. a keys/values
    shape mismatch or tensors that are not 4D.
    """
37
+
38
+
39
+ class EngramSerializer:
40
+ """Serializes/deserializes KV cache tensors to/from .eng files.
41
+
42
+ Canonical shape for KV tensors in ENGRAM:
43
+ keys: [n_layers, n_kv_heads, ctx_len, head_dim]
44
+ values: [n_layers, n_kv_heads, ctx_len, head_dim]
45
+ """
46
+
47
    def serialize(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        agent_id: str,
        task_description: str,
        model_id: str,
        output_path: Path,
        compression: CompressionMethod = CompressionMethod.Q8_0,
        cache_id: str | None = None,
        parent_cache_id: str | None = None,
        input_tokens: list[int] | None = None,
        extra_metadata: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """Serialize KV cache tensors to a .eng file.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim]
            values: [n_layers, n_kv_heads, ctx_len, head_dim]
            agent_id: Identifier for the agent that produced this state
            task_description: Human-readable description (used for EGR search)
            model_id: Full model identifier
            output_path: Path to write .eng file
            compression: Compression method to apply
            cache_id: Explicit cache ID (auto-generated if None)
            parent_cache_id: ID of parent for delta chains
            input_tokens: Token IDs that generated this state (for hash)
            extra_metadata: Additional string key-value pairs

        Returns:
            Dict with cache_id, size_bytes, compression_ratio, path

        Raises:
            SerializationError: if keys/values shapes differ or are not 4D.
        """
        # Validate shape contract before any tensor work.
        if keys.shape != values.shape:
            raise SerializationError(
                f"Keys/values shape mismatch: {keys.shape} vs {values.shape}"
            )
        if keys.dim() != 4:
            raise SerializationError(
                f"Expected 4D [n_layers, n_kv_heads, ctx_len, head_dim], "
                f"got {keys.dim()}D: {keys.shape}"
            )

        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Per-layer tensor dict fed to safetensors save_file below.
        tensors: dict[str, torch.Tensor] = {}

        if compression == CompressionMethod.INT8:
            from kvcos.core.compression import compress_int8_tensor

            # True int8 quantization with per-row scales; one quantized
            # tensor and one scale tensor per layer, for keys and values.
            k_pair = compress_int8_tensor(keys)
            v_pair = compress_int8_tensor(values)
            for i in range(n_layers):
                tensors[f"layer_{i}_keys"] = k_pair.quantized[i].contiguous()
                tensors[f"layer_{i}_keys_scale"] = k_pair.scales[i].contiguous()
                tensors[f"layer_{i}_values"] = v_pair.quantized[i].contiguous()
                tensors[f"layer_{i}_values_scale"] = v_pair.scales[i].contiguous()
            # Reuse k_compressed for metadata only — actual INT8 data is
            # already written per-layer above via k_pair/v_pair.
            # NOTE(review): compress() here runs a second full compression
            # pass just to obtain .metadata — confirm whether the
            # compression module offers a cheaper metadata-only path.
            k_compressed = compress(keys, compression)
            v_compressed = k_compressed
        elif compression == CompressionMethod.LAYER_DELTA:
            from kvcos.core.compression import compress_layer_delta

            k_ld = compress_layer_delta(keys)
            v_ld = compress_layer_delta(values)
            # Layer 0: fp16 baseline
            tensors["layer_0_keys"] = k_ld.baseline.contiguous()
            tensors["layer_0_values"] = v_ld.baseline.contiguous()
            # Layers 1..N: int8 deltas + fp16 scales
            for i in range(n_layers - 1):
                tensors[f"layer_{i+1}_keys"] = k_ld.delta_quantized[i].contiguous()
                tensors[f"layer_{i+1}_keys_scale"] = k_ld.delta_scales[i].contiguous()
                tensors[f"layer_{i+1}_values"] = v_ld.delta_quantized[i].contiguous()
                tensors[f"layer_{i+1}_values_scale"] = v_ld.delta_scales[i].contiguous()
            # Reuse k_compressed for metadata only — actual layer-delta data
            # is already written above via k_ld/v_ld.
            # NOTE(review): same double-compression concern as the INT8
            # branch above — confirm cost is acceptable.
            k_compressed = compress(keys, compression)
            v_compressed = k_compressed
        else:
            # Generic path (FP16/Q8_0/...): compress returns per-layer data.
            k_compressed = compress(keys, compression)
            v_compressed = compress(values, compression)
            for i in range(n_layers):
                tensors[f"layer_{i}_keys"] = k_compressed.data[i].contiguous()
                tensors[f"layer_{i}_values"] = v_compressed.data[i].contiguous()

        cid = cache_id or str(uuid.uuid4())
        now = datetime.now(timezone.utc).isoformat()

        # Optional provenance hash over the input token ids (little-endian
        # 4-byte encoding per token, then SHA-256).
        token_hash = ""
        if input_tokens:
            token_bytes = b"".join(t.to_bytes(4, "little") for t in input_tokens)
            token_hash = f"sha256:{hashlib.sha256(token_bytes).hexdigest()}"

        # All values must be strings per safetensors metadata spec (D7);
        # ints are stringified here and again defensively below.
        metadata: EngramMetadata = {
            "engram_version": ENGRAM_VERSION,
            "cache_id": cid,
            "compression": compression.value,
            "model_id": model_id,
            "model_family": infer_model_family(model_id),
            "n_layers": str(n_layers),
            "n_heads": str(n_kv_heads),
            "n_kv_heads": str(n_kv_heads),
            "head_dim": str(head_dim),
            "context_len": str(ctx_len),
            "agent_id": agent_id,
            "task_description": task_description,
            "created_at": now,
        }

        if parent_cache_id:
            metadata["parent_cache_id"] = parent_cache_id
        if token_hash:
            metadata["token_hash"] = token_hash
        # Compression-specific parameters are namespaced with a prefix.
        for key, val in k_compressed.metadata.items():
            metadata[f"compression_{key}"] = val  # type: ignore[literal-required]
        if extra_metadata:
            for key, val in extra_metadata.items():
                metadata[key] = val  # type: ignore[literal-required]

        output_path.parent.mkdir(parents=True, exist_ok=True)

        str_metadata: dict[str, str] = {k: str(v) for k, v in metadata.items()}
        save_file(tensors, str(output_path), metadata=str_metadata)

        # Ratio is measured against the uncompressed input dtype size.
        original_bytes = (keys.numel() + values.numel()) * keys.element_size()
        compressed_bytes = output_path.stat().st_size

        return {
            "cache_id": cid,
            "size_bytes": compressed_bytes,
            "compression_ratio": original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
            "path": str(output_path),
            "n_layers": n_layers,
            "context_len": ctx_len,
        }
182
+
183
    def deserialize(
        self,
        path: Path,
        target_compression: CompressionMethod | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, EngramMetadata]:
        """Deserialize a .eng file into KV cache tensors.

        Args:
            path: Path to an existing .eng (safetensors) file.
            target_compression: When not None, triggers a decompress pass
                using the *stored* compression method after layer assembly.

        Returns:
            (keys, values, metadata) where tensors are
            [n_layers, n_kv_heads, ctx_len, head_dim].

        Raises:
            SerializationError: if the file is missing or a per-layer
                tensor (or INT8 scale) is absent.
        """
        if not path.exists():
            raise SerializationError(f"Engram file not found: {path}")

        tensors = load_file(str(path))
        metadata = self._read_metadata(path)

        # Prefer the recorded layer count; fall back to scanning the
        # "layer_<i>_..." tensor names when the header lacks it.
        n_layers = int(metadata.get("n_layers", "0"))
        if n_layers == 0:
            n_layers = (
                max(int(k.split("_")[1]) for k in tensors if k.startswith("layer_")) + 1
            )

        stored_compression = metadata.get("compression", "fp16")
        is_int8 = stored_compression == CompressionMethod.INT8.value
        is_layer_delta = stored_compression == CompressionMethod.LAYER_DELTA.value

        k_layers: list[torch.Tensor] = []
        v_layers: list[torch.Tensor] = []

        if is_layer_delta:
            from kvcos.core.compression import decompress_int8_tensor

            # Layer 0: fp16 baseline
            k_layers.append(tensors["layer_0_keys"].float())
            v_layers.append(tensors["layer_0_values"].float())
            # Layers 1..N: accumulate int8 deltas onto the previous layer.
            for i in range(1, n_layers):
                k_delta = decompress_int8_tensor(
                    tensors[f"layer_{i}_keys"], tensors[f"layer_{i}_keys_scale"]
                )
                v_delta = decompress_int8_tensor(
                    tensors[f"layer_{i}_values"], tensors[f"layer_{i}_values_scale"]
                )
                k_layers.append(k_layers[-1] + k_delta.float())
                v_layers.append(v_layers[-1] + v_delta.float())
            # Accumulation is done in float32; cast back to fp16 for output.
            keys = torch.stack([l.to(torch.float16) for l in k_layers], dim=0)
            values = torch.stack([l.to(torch.float16) for l in v_layers], dim=0)
        else:
            for i in range(n_layers):
                k_key = f"layer_{i}_keys"
                v_key = f"layer_{i}_values"
                if k_key not in tensors or v_key not in tensors:
                    raise SerializationError(f"Missing tensor for layer {i}")

                if is_int8:
                    from kvcos.core.compression import decompress_int8_tensor

                    k_scale_key = f"layer_{i}_keys_scale"
                    v_scale_key = f"layer_{i}_values_scale"
                    if k_scale_key not in tensors or v_scale_key not in tensors:
                        raise SerializationError(f"Missing INT8 scale for layer {i}")
                    k_layers.append(decompress_int8_tensor(tensors[k_key], tensors[k_scale_key]))
                    v_layers.append(decompress_int8_tensor(tensors[v_key], tensors[v_scale_key]))
                else:
                    k_layers.append(tensors[k_key])
                    v_layers.append(tensors[v_key])

            keys = torch.stack(k_layers, dim=0)
            values = torch.stack(v_layers, dim=0)

        # NOTE(review): for INT8/LAYER_DELTA the tensors above are already
        # dequantized; running decompress() again with the stored method
        # may double-apply. Also target_compression itself is only used as
        # an on/off switch here — confirm intended semantics.
        if target_compression is not None:
            stored = CompressionMethod(metadata.get("compression", "fp16"))
            keys = decompress(keys, stored)
            values = decompress(values, stored)

        return keys, values, metadata  # type: ignore[return-value]
259
+
260
+ def _read_metadata(self, path: Path) -> dict[str, str]:
261
+ """Read only the metadata header (no tensor data loaded)."""
262
+ from safetensors import safe_open
263
+
264
+ metadata: dict[str, str] = {}
265
+ with safe_open(str(path), framework="pt") as f:
266
+ raw_meta = f.metadata()
267
+ if raw_meta:
268
+ metadata = dict(raw_meta)
269
+ return metadata
270
+
271
+ def read_metadata_only(self, path: Path) -> EngramMetadata:
272
+ """Read just the metadata from a .eng file. Efficient for indexing."""
273
+ raw = self._read_metadata(path)
274
+ return raw # type: ignore[return-value]
kvcos/core/state_extractor.py ADDED
@@ -0,0 +1,489 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Engrammatic Geometry Retrieval — State Extraction Layer
3
+
4
+
5
+ Extracts a retrieval state vector from a KV cache tensor for MIPS-based
6
+ retrieval in EGR (Engrammatic Geometry Retrieval). The state vector is
7
+ a compact geometric fingerprint of a cognitive state — positioned in the
8
+ model's own pre-RoPE key manifold for geometrically consistent retrieval.
9
+
10
+ Three extraction modes:
11
+
12
+ mean_pool: Fast baseline. Mean over heads + context of key matrices
13
+ across extraction layers. Output: [head_dim]. No learned
14
+ parameters. Use for bootstrapping and smoke tests.
15
+
16
+ svd_project: Truncated SVD on pre-RoPE keys, extraction layers (D3: 8-31),
17
+ rank-160 for 8B models. Validated by ShadowKV (ICML 2025,
18
+ ByteDance) on Llama-3.1-8B and Phi-3-Mini-128K.
19
+ Output: [rank]. Projection is prompt-dependent — W computed
20
+ per cache via online SVD, not precomputed globally.
21
+ Reference: github.com/ByteDance-Seed/ShadowKV
22
+
23
+ xkv_project: Grouped cross-layer SVD. Groups 4 adjacent extraction layers,
24
+ extracts shared basis vectors across the group. Achieves
25
+ 6.8x compression vs 2.5x single-layer SVD. K:V rank ratio
26
+ 1:1.5 is optimal per xKV paper.
27
+ Reference: github.com/abdelfattah-lab/xKV
28
+ arXiv:2503.18893
29
+
30
+ REMOVED: sals_project — last-layer-only extraction invalidated by
31
+ Layer-Condensed KV Cache (ACL 2024). See D3.
32
+
33
+ D4: No L2 normalization. True MIPS. L2 norm stored as metadata for
34
+ optional downstream use.
35
+ """
36
+
37
+ from __future__ import annotations
38
+
39
+ from dataclasses import dataclass, field
40
+
41
+ import torch
42
+ from einops import rearrange
43
+
44
+ from kvcos.core.types import (
45
+ DEFAULT_SVD_RANK,
46
+ ModelCacheSpec,
47
+ StateExtractionMode,
48
+ )
49
+
50
+
51
@dataclass
class ExtractionResult:
    """Result of state vector extraction from a KV cache.

    Produced by MARStateExtractor.extract / extract_with_basis.
    """

    state_vec: torch.Tensor  # [d_out] — the retrieval vector
    l2_norm: float  # stored as metadata per D4 (vector itself is NOT normalized)
    mode: StateExtractionMode  # extraction mode that produced the vector
    n_layers_used: int  # number of layers actually selected
    n_tokens: int  # context length (ctx_len) of the source cache
60
+
61
+
62
@dataclass
class SVDProjection:
    """Learned SVD projection matrix for a specific cache.

    ShadowKV finding: pre-RoPE keys share low-rank subspaces WITHIN
    sequences but differ ACROSS sequences. Projection must be computed
    online per cache, not precomputed globally.
    """

    W: torch.Tensor  # [head_dim, rank] — right singular vectors
    singular_values: torch.Tensor  # [rank] — for diagnostics
    explained_variance_ratio: float  # fraction of variance captured
    source_shape: tuple[int, ...]  # shape of the keys used to compute this
75
+
76
+
77
class MARStateExtractor:
    """Extracts retrieval state vectors from KV cache tensors for EGR.

    Usage:
        extractor = MARStateExtractor(mode="svd_project", rank=160)
        result = extractor.extract(keys, spec)
        # result.state_vec is the retrieval vector for FAISS IndexFlatIP
        # result.l2_norm goes into .eng metadata (D4)
    """

    # Max rows fed to SVD. 8192 rows on a 128-dim matrix runs in ~15ms
    # vs ~2000ms for the full 786K-row matrix. Subspace quality is
    # preserved because SVD only needs O(head_dim²) samples to recover
    # the top singular vectors of a low-rank matrix.
    MAX_SVD_ROWS: int = 8192

    def __init__(
        self,
        mode: StateExtractionMode = StateExtractionMode.SVD_PROJECT,
        rank: int = DEFAULT_SVD_RANK,
        xkv_group_size: int = 4,
        xkv_kv_rank_ratio: float = 1.5,
        max_svd_rows: int | None = None,
        layer_range: tuple[int, int] | None = None,
        gate_start: int = 0,
    ):
        self.mode = mode
        self.rank = rank
        self.xkv_group_size = xkv_group_size
        # NOTE(review): xkv_kv_rank_ratio is stored but never read in this
        # class (K→K retrieval per D2 uses the K rank only) — confirm it is
        # consumed elsewhere or reserved for future V-side ranking.
        self.xkv_kv_rank_ratio = xkv_kv_rank_ratio
        self.max_svd_rows = max_svd_rows or self.MAX_SVD_ROWS
        # Override spec extraction_layers when set. (8, 24) uses middle
        # layers which encode semantic content (Tenney 2019, Huh 2024).
        self.layer_range = layer_range
        # Skip top gate_start singular values in SVD projection.
        # Top SVs encode shared positional/syntactic structure;
        # skipping them isolates semantic content (gate_start=6 optimal).
        self.gate_start = gate_start

        # Cached projection from last extract call (for inspection/reuse)
        self._last_projection: SVDProjection | None = None

    def extract(
        self,
        keys: torch.Tensor,
        spec: ModelCacheSpec,
    ) -> ExtractionResult:
        """Extract a state vector from KV cache key tensors.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim] — the K cache.
                Must be pre-RoPE if available. Post-RoPE works but with
                reduced retrieval quality due to position-dependent distortion.
            spec: Model architecture spec (provides extraction_layers).

        Returns:
            ExtractionResult with state vector and metadata.

        Raises:
            ValueError: if self.mode is not a known StateExtractionMode.
        """
        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Layer selection: layer_range overrides spec extraction_layers
        if self.layer_range is not None:
            start, end = self.layer_range
            # Clamp the range into [0, n_layers] so slicing never fails.
            start = max(0, min(start, n_layers))
            end = max(start, min(end, n_layers))
            layer_indices = list(range(start, end))
        else:
            extraction_layers = spec["extraction_layers"]
            layer_indices = [l for l in extraction_layers if l < n_layers]

        # Fallback: if clamping emptied the selection, use all layers.
        if not layer_indices:
            layer_indices = list(range(n_layers))

        selected_keys = keys[layer_indices]  # [n_selected, n_kv_heads, ctx_len, head_dim]

        match self.mode:
            case StateExtractionMode.MEAN_POOL:
                state_vec = self._mean_pool(selected_keys)
            case StateExtractionMode.SVD_PROJECT:
                state_vec = self._svd_project(selected_keys)
            case StateExtractionMode.XKV_PROJECT:
                state_vec = self._xkv_project(selected_keys)
            case _:
                raise ValueError(f"Unknown extraction mode: {self.mode}")

        # D4: No normalization. True MIPS. Store norm as metadata.
        l2_norm = float(torch.linalg.vector_norm(state_vec).item())

        return ExtractionResult(
            state_vec=state_vec,
            l2_norm=l2_norm,
            mode=self.mode,
            n_layers_used=len(layer_indices),
            n_tokens=ctx_len,
        )

    def _mean_pool(self, keys: torch.Tensor) -> torch.Tensor:
        """Fast baseline: mean over layers, heads, and context positions.

        Input: [n_layers, n_kv_heads, ctx_len, head_dim]
        Output: [head_dim]
        """
        return keys.float().mean(dim=(0, 1, 2))

    def _svd_project(self, keys: torch.Tensor) -> torch.Tensor:
        """Truncated SVD projection on pre-RoPE keys.

        ShadowKV approach: flatten all extraction layers' keys into a 2D matrix
        [N, head_dim], compute truncated SVD, project onto top-rank singular vectors,
        then mean-pool the projected vectors.

        For large contexts (N > max_svd_rows), we subsample rows before SVD.
        SVD only needs O(head_dim²) samples to recover the top singular vectors
        of a low-rank matrix, so subsampling to 8K rows preserves subspace quality
        while reducing SVD from ~2000ms to ~15ms at 4K context.

        Input: [n_layers, n_kv_heads, ctx_len, head_dim]
        Output: [rank]
        """
        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Total rows in the flattened matrix
        n_rows = n_layers * n_kv_heads * ctx_len

        if n_rows > self.max_svd_rows:
            # Subsample BEFORE flatten+cast to avoid allocating the full
            # float32 matrix (saves ~30ms rearrange + 100MB at 4K context).
            # Fixed seed keeps extraction deterministic across calls.
            gen = torch.Generator()
            gen.manual_seed(42)
            indices = torch.randperm(n_rows, generator=gen)[:self.max_svd_rows]
            flat_keys = keys.reshape(n_rows, head_dim)[indices].float()
            svd_input = flat_keys
        else:
            flat_keys = rearrange(keys.float(), 'l h t d -> (l h t) d')
            svd_input = flat_keys

        # Clamp rank to not exceed matrix dimensions
        max_rank = min(head_dim, svd_input.shape[0])
        effective_rank = min(self.gate_start + self.rank, max_rank)

        # Truncated SVD on (subsampled) matrix
        U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)

        # W = right singular vectors with gating: skip top gate_start SVs
        # to remove shared positional/syntactic structure
        W = Vh[self.gate_start:effective_rank, :].T

        # Store projection for inspection
        # NOTE(review): explained_var sums SVs 0..effective_rank, which
        # includes the gate_start components that W deliberately skips —
        # diagnostic only, confirm this is the intended definition.
        total_var = (S ** 2).sum()
        explained_var = (S[:effective_rank] ** 2).sum()
        self._last_projection = SVDProjection(
            W=W,
            singular_values=S[:effective_rank],
            explained_variance_ratio=float((explained_var / total_var).item()) if total_var > 0 else 0.0,
            source_shape=tuple(keys.shape),
        )

        # Project subsampled rows and mean-pool → [rank]
        # Using the subsample for projection too avoids the expensive
        # 786K × 128 matmul + mean that dominates at large contexts.
        projected = svd_input @ W
        state_vec = projected.mean(dim=0)

        return state_vec

    def _xkv_project(self, keys: torch.Tensor) -> torch.Tensor:
        """Grouped cross-layer SVD (xKV approach).

        Groups adjacent layers (default 4), computes shared SVD basis
        per group, projects keys onto that basis, then concatenates
        group state vectors.

        This captures cross-layer structure that single-layer SVD misses.
        Achieves 6.8x vs 2.5x for single-layer SVD on Llama-3.1-8B.

        K:V rank ratio 1:1.5 is optimal per xKV paper, but since we
        only index keys (D2: K→K retrieval), we use the K rank only.

        Input: [n_layers, n_kv_heads, ctx_len, head_dim]
        Output: [n_groups * rank_per_group]
        """
        n_layers, n_kv_heads, ctx_len, head_dim = keys.shape

        # Compute rank per group
        # xKV finding: K rank is lower than V rank by factor 1:1.5
        # For 160 total rank budget across groups, allocate per group
        n_groups = max(1, n_layers // self.xkv_group_size)
        rank_per_group = max(1, self.rank // n_groups)
        rank_per_group = min(rank_per_group, head_dim)

        group_vecs: list[torch.Tensor] = []

        for g in range(n_groups):
            start = g * self.xkv_group_size
            end = min(start + self.xkv_group_size, n_layers)
            group_keys = keys[start:end]  # [group_size, n_kv_heads, ctx_len, head_dim]

            # Flatten group
            n_group_rows = group_keys.shape[0] * n_kv_heads * ctx_len

            if n_group_rows > self.max_svd_rows:
                # Per-group seed offset keeps subsamples independent
                # but deterministic.
                gen = torch.Generator()
                gen.manual_seed(42 + g)
                indices = torch.randperm(n_group_rows, generator=gen)[:self.max_svd_rows]
                svd_input = group_keys.reshape(n_group_rows, head_dim)[indices].float()
            else:
                svd_input = rearrange(group_keys.float(), 'l h t d -> (l h t) d')

            effective_rank = min(rank_per_group, svd_input.shape[0], head_dim)

            # Truncated SVD for this group (on subsampled data)
            U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
            W_group = Vh[:effective_rank, :].T  # [head_dim, rank_per_group]

            # Project subsampled rows and mean-pool → [rank_per_group]
            projected = svd_input @ W_group
            group_vec = projected.mean(dim=0)
            group_vecs.append(group_vec)

        # Handle remainder layers (if n_layers not divisible by group_size)
        remainder_start = n_groups * self.xkv_group_size
        if remainder_start < n_layers:
            remainder_keys = keys[remainder_start:]
            n_rem_rows = remainder_keys.shape[0] * n_kv_heads * ctx_len

            if n_rem_rows > self.max_svd_rows:
                gen = torch.Generator()
                gen.manual_seed(42 + n_groups)
                indices = torch.randperm(n_rem_rows, generator=gen)[:self.max_svd_rows]
                svd_input = remainder_keys.reshape(n_rem_rows, head_dim)[indices].float()
            else:
                svd_input = rearrange(remainder_keys.float(), 'l h t d -> (l h t) d')

            effective_rank = min(rank_per_group, svd_input.shape[0], head_dim)
            U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
            W_rem = Vh[:effective_rank, :].T
            projected = svd_input @ W_rem
            group_vecs.append(projected.mean(dim=0))

        # Concatenate all group vectors → [n_groups * rank_per_group + remainder]
        state_vec = torch.cat(group_vecs, dim=0)

        return state_vec

    # ── Fixed Corpus Basis (FCB) ────────────────────────────────────────────

    @classmethod
    def compute_corpus_basis(
        cls,
        key_tensors: list[torch.Tensor],
        layer_range: tuple[int, int],
        gate_start: int,
        rank: int,
        max_rows: int = 32768,
        seed: int = 42,
    ) -> torch.Tensor:
        """Compute a fixed projection matrix from a corpus of key tensors.

        Returns P: [rank, head_dim] — the global semantic basis.
        Unlike per-document SVD, this basis is document-independent.
        All documents projected with P exist in the same coordinate system,
        enabling stable cross-document and cross-model comparison.

        Args:
            key_tensors: list of [n_layers, n_kv_heads, ctx_len, head_dim]
                key caches (one per corpus document).
            layer_range: (start, end) layer slice taken from every tensor.
            gate_start: number of top singular vectors to skip.
            rank: number of singular vectors kept after gating.
            max_rows: cap on total rows fed to the final SVD.
            seed: RNG seed for deterministic subsampling.
        """
        l_start, l_end = layer_range
        gen = torch.Generator()
        gen.manual_seed(seed)

        all_rows: list[torch.Tensor] = []
        # Budget rows evenly across documents so no single document
        # dominates the corpus SVD.
        per_doc_max = max(1, max_rows // len(key_tensors))

        for keys in key_tensors:
            k = keys[l_start:l_end].float()
            n_rows = k.shape[0] * k.shape[1] * k.shape[2]
            flat = k.reshape(n_rows, k.shape[3])
            if flat.shape[0] > per_doc_max:
                idx = torch.randperm(flat.shape[0], generator=gen)[:per_doc_max]
                flat = flat[idx]
            all_rows.append(flat)

        corpus = torch.cat(all_rows, dim=0)
        if corpus.shape[0] > max_rows:
            idx = torch.randperm(corpus.shape[0], generator=gen)[:max_rows]
            corpus = corpus[idx]

        _, S, Vh = torch.linalg.svd(corpus, full_matrices=False)
        P = Vh[gate_start : gate_start + rank]  # [rank, head_dim]
        return P

    def extract_with_basis(
        self,
        keys: torch.Tensor,
        spec: ModelCacheSpec,
        basis: torch.Tensor,
    ) -> ExtractionResult:
        """Extract state vector using a pre-computed fixed corpus basis.

        All vectors computed with the same basis share a coordinate system,
        which is required for cross-model transfer via adapter.

        Args:
            keys: [n_layers, n_kv_heads, n_cells, head_dim]
            spec: Model spec (used for layer_range fallback)
            basis: [rank, head_dim] from compute_corpus_basis()

        Returns:
            ExtractionResult with L2-normalized state vector
            (unlike extract(), which leaves the vector unnormalized per D4).
        """
        if self.layer_range is not None:
            l_start, l_end = self.layer_range
        else:
            l_start, l_end = 0, keys.shape[0]
        l_start = max(0, min(l_start, keys.shape[0]))
        l_end = max(l_start, min(l_end, keys.shape[0]))

        k = keys[l_start:l_end].float()
        n_rows = k.shape[0] * k.shape[1] * k.shape[2]
        flat = k.reshape(n_rows, k.shape[3])

        proj = flat @ basis.T  # [N_rows, rank]
        vec = proj.mean(dim=0)  # [rank]

        # Epsilon guards against division by a zero-norm vector.
        norm = float(torch.linalg.vector_norm(vec).item())
        vec_normed = vec / (norm + 1e-8)

        return ExtractionResult(
            state_vec=vec_normed.to(torch.float32),
            l2_norm=norm,
            mode=self.mode,
            n_layers_used=l_end - l_start,
            n_tokens=k.shape[2],
        )

    # ── Fourier Fingerprint (Engram Absolute) ────────────────────────

    @staticmethod
    def compute_fourier_fingerprint(
        keys: torch.Tensor,
        freqs: tuple[int, ...] = (0, 1),
    ) -> torch.Tensor:
        """Compute the Fourier Absolute fingerprint from KV cache keys.

        Takes the real DFT over the layer dimension, extracts the
        amplitude at the specified frequencies, normalizes each, and
        concatenates them into a single fingerprint vector.

        This fingerprint is:
        - Cross-model invariant (cos ~0.90 between 3B and 8B)
        - Corpus-independent (no basis, no center, no training)
        - Scale-stable (98% recall@1 at N=1000, decay N^-0.207)

        Args:
            keys: [n_layers, n_kv_heads, n_cells, head_dim] — full KV keys.
                All layers are used (not sliced by layer_range).
            freqs: Frequency indices to extract. Default (0, 1) = DC + 1st harmonic.
                f=0 captures overall key magnitude profile.
                f=1 captures dominant oscillation across depth.

        Returns:
            Fingerprint vector [dim * len(freqs)], L2-normalized.
        """
        # Mean over cells (tokens) per layer: [n_layers, n_kv_heads * head_dim]
        n_layers = keys.shape[0]
        layer_means = keys.float().mean(dim=2).reshape(n_layers, -1)

        # DFT over layer dimension
        F_complex = torch.fft.rfft(layer_means, dim=0)  # [n_freq, dim]
        F_amp = F_complex.abs()  # amplitude spectrum

        # Extract and normalize each frequency component
        parts = []
        for f in freqs:
            if f >= F_amp.shape[0]:
                # Frequency out of range — use zeros
                # NOTE(review): torch.zeros allocates on the default
                # device; may mismatch CUDA inputs in torch.cat — confirm.
                parts.append(torch.zeros(F_amp.shape[1]))
            else:
                v = F_amp[f]
                parts.append(v / (v.norm() + 1e-8))

        fingerprint = torch.cat(parts, dim=0)
        return fingerprint / (fingerprint.norm() + 1e-8)

    @property
    def last_projection(self) -> SVDProjection | None:
        """Access the SVD projection from the last svd_project call.

        Useful for diagnostics: check explained_variance_ratio to validate
        that the rank is sufficient for this particular cache.
        """
        return self._last_projection

    def output_dim(self, spec: ModelCacheSpec) -> int:
        """Compute the output dimension of the state vector for a given spec.

        This is needed to initialize the FAISS index with the correct dimension.

        NOTE(review): for XKV_PROJECT this derives n_layers from
        spec["extraction_layers"], while extract() may instead use
        self.layer_range — confirm dimensions agree when layer_range is set.
        """
        match self.mode:
            case StateExtractionMode.MEAN_POOL:
                return spec["head_dim"]
            case StateExtractionMode.SVD_PROJECT:
                max_rank = min(self.gate_start + self.rank, spec["head_dim"])
                return max_rank - self.gate_start
            case StateExtractionMode.XKV_PROJECT:
                extraction_layers = spec["extraction_layers"]
                n_layers = len(extraction_layers)
                n_groups = max(1, n_layers // self.xkv_group_size)
                rank_per_group = max(1, self.rank // n_groups)
                rank_per_group = min(rank_per_group, spec["head_dim"])
                # Groups + possible remainder group
                has_remainder = (n_layers % self.xkv_group_size) != 0
                total_groups = n_groups + (1 if has_remainder else 0)
                return total_groups * rank_per_group
            case _:
                raise ValueError(f"Unknown mode: {self.mode}")
kvcos/core/types.py ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Core Type Definitions
3
+
4
+
5
+ All enums, TypedDicts, constants, and type aliases live here.
6
+ Every downstream module imports from this file. No circular dependencies.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from dataclasses import dataclass
12
+ from enum import StrEnum
13
+ from typing import TypedDict
14
+
15
# ── Constants ─────────────────────────────────────────────────────────────────
# Protocol-wide constants; every downstream kvcos module imports from here.

ENGRAM_VERSION = "0.1.0"  # protocol/file-format version written into .eng headers
ENG_FILE_EXTENSION = ".eng"  # ENGRAM file format extension
BLOCK_SIZE_TOKENS = 256  # 256-token blocks per arXiv:2603.04428
DEFAULT_SVD_RANK = 160  # ShadowKV default for 8B models
DEFAULT_LATENT_DIM = 512  # MLA: full KV info recoverable from 512-dim
MAX_CONTEXT_TOKENS = 131072  # 128K max supported context
23
+
24
+
25
+ # ── Enums ─────────────────────────────────────────────────────────────────────
26
+
27
+
28
class CompressionMethod(StrEnum):
    """Supported KV cache compression methods.

    Values are the strings stored in .eng metadata under "compression".

    Phase 1: Q8_0, FP16
    Phase 2: POLARQUANT (TurboQuant without QJL — QJL removed per D5)
    """

    FP16 = "fp16"
    Q8_0 = "q8_0"  # llama.cpp GGML_TYPE_Q8_0: ~2x compression, <5% speed hit
    Q4_0 = "q4_0"  # NOT recommended at 64K+ (92% dequant slowdown)
    POLARQUANT = "polarquant_3bit"  # PolarQuant only, no QJL
    INT8 = "int8"  # Phase 2: true int8 + per-row scale, 2x on-disk compression
    LAYER_DELTA = "layer_delta"  # Phase 2: fp16 baseline + int8 inter-layer deltas
41
+
42
+
43
class StorageBackend(StrEnum):
    """Supported storage backends for persisted engram files."""

    LOCAL = "local"  # Phase 1: local filesystem
    REDIS = "redis"  # Phase 2
    S3 = "s3"  # Phase 2
49
+
50
+
51
class StateExtractionMode(StrEnum):
    """EGR (Engrammatic Geometry Retrieval) state vector extraction modes.

    mean_pool:   Fast baseline. Mean over heads + context of key matrices.
    svd_project: Truncated SVD on pre-RoPE keys, layers 8-31, rank-160.
                 Validated by ShadowKV (ICML 2025) on Llama-3.1-8B.
    xkv_project: Grouped cross-layer SVD, 4-layer groups, K:V rank 1:1.5.
                 From xKV (arXiv:2503.18893). 6.8x compression.

    REMOVED: sals_project — last-layer-only extraction invalidated by
    Layer-Condensed KV Cache (ACL 2024). See D3.
    """

    MEAN_POOL = "mean_pool"
    SVD_PROJECT = "svd_project"
    XKV_PROJECT = "xkv_project"
67
+
68
+
69
class IndexBackend(StrEnum):
    """EGR manifold index backends (vector index over state vectors)."""

    FAISS_FLAT_IP = "faiss_flat_ip"  # Phase 1: exact MIPS
    FAISS_IVF_IP = "faiss_ivf_ip"  # Phase 2: approximate MIPS for >100K vectors
    QDRANT_DOT = "qdrant_dot"  # Phase 2: production persistent index
75
+
76
+
77
class AttentionType(StrEnum):
    """KV cache attention mechanism per layer group.

    Standard models use FULL for all layers.
    ISWA models (Gemma 4) interleave FULL (global) and SLIDING (SWA) sections.
    """

    FULL = "full"  # Full-context attention (standard)
    SLIDING = "sliding"  # Sliding window attention (SWA)
86
+
87
+
88
+ # ── Data Classes ─────────────────────────────────────────────────────────────
89
+
90
+
91
@dataclass(frozen=True)
class CacheSection:
    """One section of a multi-section KV cache.

    Standard models have a single implicit section covering all layers.
    ISWA models serialize multiple sections sequentially in the state blob,
    each with its own n_layers, n_kv_heads, and head_dim.

    Reverse-engineered from Gemma 4 26B-A4B (llama.cpp b5200+):
        Section 0 (Global): 5 layers, 2 KV heads, head_dim=512
        Section 1 (SWA):    25 layers, 8 KV heads, head_dim=256
    """

    attention_type: AttentionType
    n_layers: int
    n_kv_heads: int
    head_dim: int
    window_size: int | None = None  # SWA window size in tokens (None for full)

    @property
    def n_embd_kv(self) -> int:
        """Total KV embedding dimension for this section (heads × head_dim)."""
        return self.n_kv_heads * self.head_dim
114
+
115
+
116
+ # ── TypedDicts ────────────────────────────────────────────────────────────────
117
+
118
+
119
class _ModelCacheSpecRequired(TypedDict):
    """Required fields for ModelCacheSpec (internal base).

    Split out so ModelCacheSpec can add optional keys with total=False
    while keeping these mandatory.
    """

    model_id: str  # e.g. "meta-llama/Llama-3.1-8B-Instruct"
    model_family: str  # e.g. "llama"
    n_layers: int  # total transformer layers
    n_heads: int  # query heads (may differ from KV heads in GQA)
    n_kv_heads: int  # key/value heads (GQA-aware)
    head_dim: int  # dimension per head
    rope_enabled: bool  # whether model uses RoPE
    extraction_layers: tuple[int, ...]  # layers for EGR state extraction (D3)
130
+
131
+
132
class ModelCacheSpec(_ModelCacheSpecRequired, total=False):
    """Architecture-agnostic specification of a model's KV cache layout.

    Used to validate .eng files and ensure correct tensor shapes.

    For standard models (Llama, Phi, Qwen, Mistral):
        n_kv_heads and head_dim describe the single uniform KV cache.
        cache_sections is absent.

    For ISWA models (Gemma 4):
        cache_sections lists per-section dimensions. Each section has its
        own n_layers, n_kv_heads, and head_dim. The top-level n_kv_heads
        and head_dim reflect the dominant (largest) section.
        The state blob contains multiple sequential KV streams.
    """

    cache_sections: tuple[CacheSection, ...]  # optional; ISWA models only
149
+
150
+
151
class EngramMetadata(TypedDict, total=False):
    """Metadata stored in .eng file header (safetensors __metadata__).

    All values are strings per safetensors spec (D7).
    Optional fields use total=False.

    NOTE(review): total=False applies to ALL keys here, so the "Required"
    section below is documentation only and is not enforced by the type
    checker — confirm whether a required/optional split (as done for
    ModelCacheSpec) is intended.
    """

    # Required
    engram_version: str
    cache_id: str
    compression: str  # CompressionMethod value
    model_id: str
    model_family: str
    n_layers: str  # stringified int
    n_heads: str
    n_kv_heads: str
    head_dim: str
    context_len: str
    agent_id: str
    task_description: str
    created_at: str  # ISO 8601

    # Optional
    expires_at: str
    parent_cache_id: str
    delta_from: str
    token_hash: str  # SHA-256 of input tokens
    state_vec_norm: str  # L2 norm of state vector (D4: stored as metadata)
    extraction_mode: str  # StateExtractionMode value
    block_index: str  # block position within a multi-block cache
    total_blocks: str
182
+
183
+
184
class CacheSearchResult(TypedDict):
    """Result from EGR manifold search over cached engram states."""

    cache_id: str
    similarity: float  # raw inner product score (not normalized — D4)
    task_description: str
    model_id: str
    created_at: str
    context_len: int
193
+
194
+
195
class CacheStats(TypedDict):
    """Aggregate statistics for the engram store."""

    total_entries: int
    total_size_bytes: int
    avg_compression_ratio: float
    model_breakdown: dict[str, int]  # model_family → count
kvcos/engram/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # EIGENGRAM format package
2
+ from .format import EigramEncoder, EigramDecoder, EIGENGRAM_MAGIC, EIGENGRAM_VERSION
3
+ from .writer import write_eigengram
4
+ from .reader import read_eigengram, load_eigengram_index
kvcos/engram/__main__.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EIGENGRAM command-line interface.
3
+
4
+ Usage:
5
+ python -m kvcos.engram encode --model <gguf> --text "..." --out doc.eng
6
+ python -m kvcos.engram search --model <gguf> --query "..." index/*.eng
7
+ python -m kvcos.engram inspect doc.eng
8
+ python -m kvcos.engram list index/*.eng
9
+
10
+ Commands:
11
+ encode Run a document through a GGUF model and write a .eng file.
12
+ search Query .eng files using a text query and a model.
13
+ inspect Print all fields from .eng files (no model needed).
14
+ list Print a summary table of .eng files (no model needed).
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import argparse
20
+ import gc
21
+ import glob
22
+ import os
23
+ import sys
24
+
25
+ import torch
26
+
27
+
28
+ def _resolve_paths(patterns: list[str]) -> list[str]:
29
+ """Expand glob patterns, return sorted list of .eng paths."""
30
+ paths = []
31
+ for p in patterns:
32
+ expanded = glob.glob(p)
33
+ if expanded:
34
+ paths.extend(expanded)
35
+ elif os.path.exists(p):
36
+ paths.append(p)
37
+ else:
38
+ print(f"Warning: no files matched '{p}'", file=sys.stderr)
39
+ return sorted(set(paths))
40
+
41
+
42
def cmd_encode(args: argparse.Namespace) -> None:
    """Encode a document as a .eng EIGENGRAM file.

    Resolves the document text (--text wins over --file), derives default
    output path / description / cache_id when absent, runs the text through
    write_eigengram, and prints a short result summary.

    Exits with status 1 when neither --text nor --file is given, or when
    --file does not exist.
    """
    from kvcos.engram.writer import write_eigengram

    if args.text:
        text = args.text
    elif args.file:
        if not os.path.exists(args.file):
            print(f"Error: input file not found: {args.file}", file=sys.stderr)
            sys.exit(1)
        # Fix: read via a context manager with explicit UTF-8 — the original
        # `open(args.file).read()` leaked the file handle and depended on the
        # platform default encoding.
        with open(args.file, encoding="utf-8") as fh:
            text = fh.read().strip()
    else:
        print("Error: provide --text or --file", file=sys.stderr)
        sys.exit(1)

    # Defaults: output next to the input file; description/id from the text.
    output_path = args.out or (
        os.path.splitext(args.file)[0] + ".eng" if args.file else "output.eng"
    )
    task_desc = args.description or text[:80]
    cache_id_val = args.id or text[:64]

    print("Encoding document...")
    print(f" Model: {args.model}")
    print(f" Text: {text[:60]}{'...' if len(text) > 60 else ''}")
    print(f" Output: {output_path}")
    print()

    result = write_eigengram(
        model_path=args.model,
        text=text,
        output_path=output_path,
        cache_id=cache_id_val,
        task_description=task_desc,
        basis_path=args.basis,
    )

    print("Done.")
    print(f" File size : {result['file_size_bytes']} bytes")
    print(f" Model ID : {result['model_id']}")
    print(f" SCS : {result['scs']:.4f}")
    print(f" Basis rank: {result['basis_rank']}")
83
+
84
+
85
def cmd_search(args: argparse.Namespace) -> None:
    """Search .eng files using a text query.

    Pipeline: load the FCDB basis, encode the query through the GGUF
    model to obtain its KV cache, reduce it to a fingerprint vector
    (fourier / perdoc / fcdb), then rank the .eng index entries by
    similarity in a ManifoldIndex and print the top-k table.
    """
    from llama_cpp import Llama

    from kvcos.core.blob_parser import parse_state_blob
    from kvcos.engram.reader import load_eigengram_index
    from kvcos.core.manifold_index import ManifoldIndex

    paths = _resolve_paths(args.eng_files)
    if not paths:
        print("No .eng files found.", file=sys.stderr)
        sys.exit(1)

    fingerprint = args.fingerprint
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load basis files from trusted sources.
    saved = torch.load(args.basis, weights_only=False)
    P = saved["basis"]  # FCDB projection matrix, shape (RANK, head_dim)
    center = saved["joint_center"]
    # Hard-coded fingerprint hyperparameters: layer slice, SVD gate offset.
    # Presumably these must match the values used at encode time — confirm
    # against kvcos.engram.writer.
    LR = (8, 24)
    GATE = 6
    RANK = P.shape[0]

    print(f"Query: {args.query}")
    print(f"Index: {len(paths)} files")
    print(f"Fingerprint: {fingerprint}")
    print()

    # Encode the query through the model just to populate its KV cache.
    llm = Llama(model_path=args.model, n_ctx=2048, n_gpu_layers=-1, verbose=False)
    meta = llm.metadata
    # Derive KV geometry from GGUF metadata; defaults look Llama-7B-like.
    n_kv = int(meta.get("llama.attention.head_count_kv", "8"))
    hd = int(meta.get("llama.embedding_length", "4096")) // int(
        meta.get("llama.attention.head_count", "32")
    )
    llm.reset()
    llm(args.query.strip(), max_tokens=1, temperature=0.0)
    p_q = parse_state_blob(
        bytes(llm.save_state().llama_state), n_kv_heads=n_kv, head_dim=hd
    )
    # Free the model before the (potentially large) index search.
    del llm
    gc.collect()

    if fingerprint == "fourier":
        from kvcos.core.fingerprint import compute_fourier_fingerprint

        # Use ALL layers for Fourier (not sliced)
        layer_means = p_q.keys.float().mean(dim=2).reshape(p_q.keys.shape[0], -1)
        query_vec = compute_fourier_fingerprint(layer_means, freqs=[0, 1])
        dim = query_vec.shape[0]
    elif fingerprint == "perdoc":
        # Per-query SVD: project key rows onto right-singular vectors
        # GATE..GATE+RANK, then mean-pool and L2-normalize.
        k_q = p_q.keys[LR[0] : LR[1]].float().reshape(-1, hd)
        _, _, Vh = torch.linalg.svd(k_q, full_matrices=False)
        proj_q = (k_q @ Vh[GATE : GATE + RANK].T).mean(0)
        query_vec = proj_q / (proj_q.norm() + 1e-8)
        dim = RANK
    else:  # fcdb
        # Offset from the corpus joint center, projected through the
        # shared FCDB basis; normalized before and after projection.
        k_q = p_q.keys[LR[0] : LR[1]].float().reshape(-1, hd)
        mean_q = k_q.mean(0)
        delta_q = mean_q - center
        delta_q = delta_q / (delta_q.norm() + 1e-8)
        query_vec = delta_q @ P.T
        query_vec = query_vec / (query_vec.norm() + 1e-8)
        dim = RANK

    # Build the in-memory index from the matching fingerprint stream.
    vecs, entries = load_eigengram_index(paths, fingerprint=fingerprint)
    idx = ManifoldIndex(dim=dim)
    for v, e in zip(vecs, entries):
        idx.add(v, e)

    top_k = min(args.top_k, len(paths))
    results = idx.search(query_vec, top_k=top_k)

    print(f"Results (top {top_k}):")
    print(f" {'#':<3} {'sim':>7} {'cache_id':<20} description")
    print(f" {'---'} {'-------'} {'--------------------'} {'----------------------------------------'}")
    for i, r in enumerate(results):
        desc = r.get("task_description", "")[:40]
        cid = r.get("cache_id", "")[:20]
        print(f" {i + 1:<3} {r['similarity']:>+.4f} {cid:<20} {desc}")
162
+
163
+
164
def cmd_inspect(args: argparse.Namespace) -> None:
    """Print all fields of .eng files in readable format.

    Files that fail to parse are reported inline and skipped; the
    command keeps going over the remaining paths.
    """
    from kvcos.engram.reader import read_eigengram

    paths = _resolve_paths(args.eng_files)
    if not paths:
        print("No .eng files found.", file=sys.stderr)
        sys.exit(1)

    for path in paths:
        try:
            rec = read_eigengram(path)
        except Exception as e:
            # Best-effort: report the broken file and continue.
            print(f" {path}: ERROR - {e}")
            continue

        size = os.path.getsize(path)
        print(f"{'=' * 55}")
        print(f" File: {path} ({size} bytes)")
        print(f" Format: EGR1 v{rec['version']}")
        print(f" Created: {rec['created_at']} UTC")
        print(f" Model: {rec['model_id']}")
        print(f" cache_id: {rec['cache_id']}")
        print(f" Description: {rec['task_description']}")
        print()
        print(f" Basis rank: {rec['basis_rank']}")
        print(f" N corpus: {rec['n_corpus']}")
        print(f" Layer range: {rec['layer_range']}")
        print(f" Context len: {rec['context_len']} KV rows")
        print(f" L2 norm: {rec['l2_norm']:.4f}")
        print(f" SCS: {rec['scs']:.4f}")
        print(f" Margin proof: {rec['margin_proof']:.4f}")
        print(f" Corpus hash: {rec['corpus_hash']}")
        print(f" vec_perdoc: [{rec['vec_perdoc'].shape[0]}] norm={rec['vec_perdoc'].norm():.4f}")
        print(f" vec_fcdb: [{rec['vec_fcdb'].shape[0]}] norm={rec['vec_fcdb'].norm():.4f}")
        print()
200
+
201
+
202
def cmd_list(args: argparse.Namespace) -> None:
    """Print a one-line summary table of .eng files.

    Unreadable files are listed with an ERROR note instead of aborting
    the whole table.
    """
    from kvcos.engram.reader import read_eigengram

    paths = _resolve_paths(args.eng_files)
    if not paths:
        print("No .eng files found.", file=sys.stderr)
        sys.exit(1)

    hdr = f"{'filename':<30} {'model':<14} {'scs':>6} {'bytes':>5} description"
    print(hdr)
    print("-" * len(hdr))

    for path in paths:
        # Truncate to 29 chars so the 30-wide column keeps a separator space.
        fname = os.path.basename(path)[:29]
        try:
            rec = read_eigengram(path)
            size = os.path.getsize(path)
            print(
                f"{fname:<30} {rec['model_id'][:14]:<14} "
                f"{rec['scs']:>6.3f} {size:>5} "
                f"{rec['task_description'][:40]}"
            )
        except Exception as e:
            print(f"{fname:<30} ERROR: {e}")
227
+
228
+
229
def main() -> None:
    """EIGENGRAM CLI entry point.

    Builds the argparse tree (encode / search / inspect / list) and
    dispatches to the matching cmd_* handler.
    """
    parser = argparse.ArgumentParser(
        prog="python -m kvcos.engram",
        description="EIGENGRAM CLI - encode and search KV-cache semantic certificates.",
    )
    subcommands = parser.add_subparsers(dest="command", required=True)

    encode_p = subcommands.add_parser("encode", help="Encode a document as a .eng file.")
    encode_p.add_argument("--model", required=True, help="Path to GGUF model file.")
    encode_p.add_argument("--text", help="Document text.")
    encode_p.add_argument("--file", help="Path to a text file to encode.")
    encode_p.add_argument("--out", help="Output .eng file path.")
    encode_p.add_argument("--id", help="Unique cache_id.")
    encode_p.add_argument("--description", help="Human-readable description.")
    encode_p.add_argument("--basis", default="results/corpus_basis_fcdb_v2.pt", help="FCDB v2 basis path.")

    search_p = subcommands.add_parser("search", help="Search .eng files with a query.")
    search_p.add_argument("--model", required=True, help="GGUF model for query encoding.")
    search_p.add_argument("--query", required=True, help="Query text.")
    search_p.add_argument("--fingerprint", default="fourier", choices=["perdoc", "fcdb", "fourier"])
    search_p.add_argument("--top-k", type=int, default=5, dest="top_k")
    search_p.add_argument("--basis", default="results/corpus_basis_fcdb_v2.pt")
    search_p.add_argument("eng_files", nargs="+", help=".eng file paths or globs.")

    inspect_p = subcommands.add_parser("inspect", help="Print all fields of .eng files.")
    inspect_p.add_argument("eng_files", nargs="+")

    list_p = subcommands.add_parser("list", help="Summary table of .eng files.")
    list_p.add_argument("eng_files", nargs="+")

    args = parser.parse_args()
    # required=True above guarantees args.command is one of the four names.
    if args.command == "encode":
        cmd_encode(args)
    elif args.command == "search":
        cmd_search(args)
    elif args.command == "inspect":
        cmd_inspect(args)
    else:
        cmd_list(args)
262
+
263
+
264
+ if __name__ == "__main__":
265
+ main()
kvcos/engram/chunker.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/chunker.py — Markdown-aware semantic chunker.
3
+
4
+ Splits markdown files into chunks suitable for .eng indexing.
5
+ Each chunk gets its own fingerprint and becomes independently
6
+ retrievable via HNSW.
7
+
8
+ Strategy:
9
+ 1. Split on H1/H2 headers first (natural semantic boundaries)
10
+ 2. If a section exceeds max_chars, split on H3/H4
11
+ 3. If still too large, split on paragraph boundaries
12
+ 4. Never break mid-paragraph (preserve semantic coherence)
13
+
14
+ Each chunk carries context: the file's title + parent headers
15
+ are prepended so the fingerprint captures the full meaning.
16
+ """
17
+
18
+ from __future__ import annotations
19
+
20
+ import re
21
+ from dataclasses import dataclass
22
+ from typing import Sequence
23
+
24
+
25
@dataclass(frozen=True)
class Chunk:
    """One semantic chunk from a markdown file.

    Immutable value object produced by chunk_markdown(); char offsets
    index into the original file content.
    """
    text: str  # Chunk content (with context header prepended)
    raw_text: str  # Original content without context header
    char_start: int  # Start offset in original file
    char_end: int  # End offset in original file
    index: int  # 0-based chunk index
    headers: tuple[str, ...]  # Header hierarchy (e.g., ("# Title", "## Section"))

    @property
    def char_count(self) -> int:
        """Length of the context-prefixed text, not of the raw file span."""
        return len(self.text)
38
+
39
+
40
# Regex for markdown headers (ATX style: # through ######).
# MULTILINE makes ^ anchor at every line start; group(1) is the hash run,
# group(2) the header text.
_HEADER_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
42
+
43
+
44
+ def _header_level(line: str) -> int:
45
+ """Return header level (1-6) or 0 if not a header."""
46
+ m = re.match(r"^(#{1,6})\s+", line)
47
+ return len(m.group(1)) if m else 0
48
+
49
+
50
+ def _split_by_headers(
51
+ content: str,
52
+ max_level: int = 2,
53
+ ) -> list[tuple[int, int, list[str]]]:
54
+ """
55
+ Split content into sections by header level.
56
+
57
+ Returns list of (start_offset, end_offset, header_stack) tuples.
58
+ max_level: split on headers of this level or lower (1=H1, 2=H2, etc.)
59
+ """
60
+ sections: list[tuple[int, int, list[str]]] = []
61
+ header_stack: list[str] = []
62
+ current_start = 0
63
+
64
+ for m in _HEADER_RE.finditer(content):
65
+ level = len(m.group(1))
66
+ header_text = m.group(0).strip()
67
+
68
+ if level <= max_level and m.start() > current_start:
69
+ # Close previous section
70
+ section_text = content[current_start:m.start()].strip()
71
+ if section_text:
72
+ sections.append((
73
+ current_start,
74
+ m.start(),
75
+ list(header_stack),
76
+ ))
77
+ current_start = m.start()
78
+
79
+ # Update header stack
80
+ if level <= max_level:
81
+ # Trim stack to parent level and push current
82
+ header_stack = [
83
+ h for h in header_stack
84
+ if _header_level(h) < level
85
+ ]
86
+ header_stack.append(header_text)
87
+
88
+ # Final section
89
+ if current_start < len(content):
90
+ final_text = content[current_start:].strip()
91
+ if final_text:
92
+ sections.append((
93
+ current_start,
94
+ len(content),
95
+ list(header_stack),
96
+ ))
97
+
98
+ return sections
99
+
100
+
101
def _split_paragraphs(
    text: str,
    max_chars: int,
    base_offset: int = 0,
) -> list[tuple[int, int]]:
    """
    Split text into chunks at paragraph boundaries.

    Returns list of (start_offset, end_offset) tuples relative
    to the original file (offset by base_offset).

    NOTE(review): offsets are reconstructed from cumulative paragraph
    lengths assuming every separator is exactly two characters long,
    but the split pattern matches runs of two OR MORE newlines. Runs of
    three-plus newlines therefore make the computed offsets drift from
    the true file positions, and the final end can overshoot len(text)
    by up to 2. Confirm whether downstream consumers need exact offsets
    before tightening this.
    """
    paragraphs = re.split(r"\n\n+", text)
    chunks: list[tuple[int, int]] = []
    current_start = 0
    current_len = 0

    for para in paragraphs:
        para_len = len(para) + 2  # +2 for the \n\n separator

        if current_len + para_len > max_chars and current_len > 0:
            # Close current chunk
            chunks.append((
                base_offset + current_start,
                base_offset + current_start + current_len,
            ))
            current_start = current_start + current_len
            current_len = 0

        current_len += para_len

    # Final chunk
    if current_len > 0:
        chunks.append((
            base_offset + current_start,
            base_offset + current_start + current_len,
        ))

    return chunks
139
+
140
+
141
def chunk_markdown(
    content: str,
    max_chars: int = 2000,
    min_chars: int = 100,
    context_prefix: str = "",
) -> list[Chunk]:
    """
    Split a markdown file into semantic chunks.

    Args:
        content: Full markdown file content.
        max_chars: Target maximum chars per chunk (soft limit).
        min_chars: Minimum chars — smaller sections merge with next.
        context_prefix: Prepended to each chunk for context
        (e.g., "Source: geodesic3.md | Project: engram").

    Returns:
        List of Chunk objects, ordered by position in file.

    NOTE(review): min_chars is accepted but never referenced in the body —
    small-section merging is driven purely by max_chars in Phase 4. Either
    wire min_chars in or drop it from the docs.
    """
    if not content.strip():
        return []

    # If the whole file fits in one chunk, return it directly
    if len(content) <= max_chars:
        full_text = f"{context_prefix}\n\n{content}" if context_prefix else content
        return [Chunk(
            text=full_text,
            raw_text=content,
            char_start=0,
            char_end=len(content),
            index=0,
            headers=(),
        )]

    # Phase 1: Split on H1/H2 boundaries
    sections = _split_by_headers(content, max_level=2)

    # If no headers found, treat as single block
    if not sections:
        sections = [(0, len(content), [])]

    # Phase 2: Sub-split large sections on H3/H4
    refined: list[tuple[int, int, list[str]]] = []
    for start, end, headers in sections:
        section_text = content[start:end]
        if len(section_text) > max_chars:
            # Offsets returned by the sub-split are relative to the
            # section, so shift them back by the section start.
            subsections = _split_by_headers(section_text, max_level=4)
            if len(subsections) > 1:
                for sub_start, sub_end, sub_headers in subsections:
                    refined.append((
                        start + sub_start,
                        start + sub_end,
                        headers + sub_headers,
                    ))
            else:
                refined.append((start, end, headers))
        else:
            refined.append((start, end, headers))

    # Phase 3: Paragraph-split anything still over max_chars
    final_ranges: list[tuple[int, int, list[str]]] = []
    for start, end, headers in refined:
        section_text = content[start:end]
        if len(section_text) > max_chars:
            para_ranges = _split_paragraphs(section_text, max_chars, start)
            for p_start, p_end in para_ranges:
                final_ranges.append((p_start, p_end, headers))
        else:
            final_ranges.append((start, end, headers))

    # Phase 4: Greedily pack sections into chunks up to max_chars.
    # Keep merging consecutive sections while their combined size
    # stays under max_chars. This prevents over-fragmentation of
    # files with many small header sections.
    merged: list[tuple[int, int, list[str]]] = []
    for start, end, headers in final_ranges:
        chunk_text = content[start:end].strip()
        if not chunk_text:
            continue

        if merged:
            prev_start, prev_end, prev_headers = merged[-1]
            prev_len = prev_end - prev_start
            curr_len = end - start

            # Merge if combined chunk stays under max_chars.
            # Note: the merged chunk keeps the FIRST section's headers;
            # the absorbed section's headers are dropped.
            if (prev_len + curr_len) <= max_chars:
                merged[-1] = (prev_start, end, prev_headers)
                continue

        merged.append((start, end, headers))

    # Phase 5: Build Chunk objects with context
    chunks: list[Chunk] = []
    for idx, (start, end, headers) in enumerate(merged):
        raw = content[start:end].strip()
        if not raw:
            continue

        # Build context header: optional file-level prefix plus the
        # breadcrumb of enclosing markdown headers.
        parts = []
        if context_prefix:
            parts.append(context_prefix)
        if headers:
            parts.append(" > ".join(headers))

        prefix = "\n".join(parts)
        text = f"{prefix}\n\n{raw}" if prefix else raw

        chunks.append(Chunk(
            text=text,
            raw_text=raw,
            char_start=start,
            char_end=end,
            index=idx,
            headers=tuple(headers),
        ))

    # Re-index after merging — the `continue` above can leave gaps in
    # `index`, so rebuild with contiguous 0..n-1 indices.
    return [
        Chunk(
            text=c.text,
            raw_text=c.raw_text,
            char_start=c.char_start,
            char_end=c.char_end,
            index=i,
            headers=c.headers,
        )
        for i, c in enumerate(chunks)
    ]
271
+
272
+
273
def slug_from_path(path: str) -> str:
    """
    Generate a kebab-case slug from a file path.

    Examples:
        "geodesic3.md" → "geodesic3"
        "EIGENGRAM_SPEC.md" → "eigengram-spec"
        "coding-style.md" → "coding-style"
    """
    # Keep only the filename, then drop the (last) extension.
    stem = path.rsplit("/", 1)[-1].rsplit(".", 1)[0]
    # Underscores/whitespace → hyphens, lowercase, strip anything that
    # is not [a-z0-9-], then collapse hyphen runs and trim the ends.
    slug = re.sub(r"[_\s]+", "-", stem).lower()
    slug = re.sub(r"[^a-z0-9-]", "", slug)
    return re.sub(r"-+", "-", slug).strip("-")
291
+
292
+
293
+ def eng_filename(
294
+ project: str,
295
+ slug: str,
296
+ date: str,
297
+ chunk_index: int | None = None,
298
+ chunk_total: int | None = None,
299
+ time_str: str = "",
300
+ ) -> str:
301
+ """
302
+ Generate .eng filename following the naming convention.
303
+
304
+ Format: <slug>[_<chunk>]_<date>[_<time>].eng
305
+
306
+ Args:
307
+ project: Project namespace (used for directory, not filename)
308
+ slug: Kebab-case file identifier
309
+ date: ISO date string (YYYY-MM-DD)
310
+ chunk_index: 0-based chunk index (None if single chunk)
311
+ chunk_total: Total chunks (None if single chunk)
312
+ time_str: Optional HHmm time string
313
+
314
+ Returns:
315
+ Filename (not full path) like "geodesic3_001_2026-04-02.eng"
316
+ """
317
+ parts = [slug]
318
+
319
+ if chunk_index is not None and chunk_total is not None and chunk_total > 1:
320
+ parts.append(f"{chunk_index + 1:03d}")
321
+
322
+ parts.append(date)
323
+
324
+ if time_str:
325
+ parts.append(time_str)
326
+
327
+ return "_".join(parts) + ".eng"
kvcos/engram/embedder.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/embedder.py — Unified text-to-fingerprint embedding.
3
+
4
+ Three strategies, tried in priority order:
5
+ 1. llama_cpp: Native ENGRAM KV-cache Fourier pipeline (2048-dim)
6
+ 2. sbert: Sentence-transformers all-MiniLM-L6-v2 (384-dim)
7
+ 3. hash: Deterministic SHA256-seeded pseudo-fingerprint (2048-dim)
8
+
9
+ The chosen strategy is cached after first call. The fingerprint
10
+ source tag travels with every .eng file so retrieval knows what
11
+ comparison is valid.
12
+
13
+ Usage:
14
+ from kvcos.engram.embedder import get_fingerprint
15
+ fp, source = get_fingerprint("some text")
16
+ # fp: torch.Tensor, source: "llama_cpp"|"sbert"|"hash-fallback"
17
+ """
18
+
19
+ from __future__ import annotations
20
+
21
+ import hashlib
22
+ import os
23
+ from pathlib import Path
24
+ from typing import Protocol
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+
30
class Embedder(Protocol):
    """Protocol for text → fingerprint embedding.

    Members:
        embed(text): return the fingerprint vector for the text.
        source: strategy tag ("llama_cpp" | "sbert" | "hash-fallback").
        dim: dimensionality of vectors returned by embed().
    """
    def embed(self, text: str) -> torch.Tensor: ...
    @property
    def source(self) -> str: ...
    @property
    def dim(self) -> int: ...
37
+
38
+
39
+ # ── Strategy 1: Native ENGRAM (llama_cpp) ────────────────────────────
40
+
41
class LlamaCppEmbedder:
    """KV-cache Fourier fingerprint via local GGUF model.

    Uses the full ENGRAM pipeline:
        text → LlamaCppBridge (generate → KV cache) → Fourier DFT → fingerprint

    Supports both standard and ISWA models:
        Standard (Llama): 2048-dim (8 × 128 × 2)
        ISWA (Gemma 4): 6144-dim (1024×2 + 2048×2)
    """

    def __init__(self, model_path: str) -> None:
        """Load the GGUF model CPU-only and precompute the fingerprint dim.

        Raises whatever LlamaCppBridge raises when the model cannot be
        loaded — _create_embedder() catches that and falls through to the
        next strategy.
        """
        from integrations.llama_cpp_bridge import LlamaCppBridge
        from kvcos.core.cache_spec import is_iswa_spec
        from kvcos.core.fingerprint import compute_fourier_fingerprint_v2, compute_iswa_fingerprint

        self._bridge = LlamaCppBridge(
            model_path,
            n_ctx=2048,
            n_gpu_layers=0,
            verbose=False,
        )
        self._spec = self._bridge.load_model()
        self._is_iswa = is_iswa_spec(self._spec)
        # Bind the fingerprint functions once so embed() needs no imports.
        self._compute_standard = compute_fourier_fingerprint_v2
        self._compute_iswa = compute_iswa_fingerprint

        if self._is_iswa:
            # ISWA: each cache section contributes n_kv_heads*head_dim
            # values at 2 frequencies.
            sections = self._spec["cache_sections"]
            self._dim = sum(s.n_kv_heads * s.head_dim * 2 for s in sections)
        else:
            self._dim = self._spec["n_kv_heads"] * self._spec["head_dim"] * 2

    def embed(self, text: str) -> torch.Tensor:
        """Generate text through model, extract KV keys, compute Fourier fp."""
        # Reset so the cache holds only this text's tokens.
        self._bridge.llm.reset()
        self._bridge.generate(text, max_tokens=1)

        if self._is_iswa:
            parsed = self._bridge.extract_kv_cache_iswa()
            return self._compute_iswa(parsed, freqs=[0, 1])

        parsed = self._bridge.extract_kv_cache()
        # Mean over the token axis — assumes keys are
        # (layers, kv_heads, tokens, head_dim); TODO confirm with bridge.
        layer_keys = parsed.keys.float().mean(dim=2)
        return self._compute_standard(layer_keys, freqs=[0, 1])

    @property
    def source(self) -> str:
        """Strategy tag recorded alongside fingerprints."""
        return "llama_cpp"

    @property
    def dim(self) -> int:
        """Fingerprint dimensionality for the loaded model."""
        return self._dim
94
+
95
+
96
+ # ── Strategy 2: Sentence-transformers ────────────────────────────────
97
+
98
class SBertEmbedder:
    """Semantic fingerprint via sentence-transformers.

    Uses all-MiniLM-L6-v2 (80MB, 384-dim). Downloads on first use.
    Subsequent calls use the cached model (~50ms per text on CPU).
    """

    MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(self) -> None:
        """Load the SBERT model, silencing Hugging Face logging noise first."""
        import logging
        import warnings
        # Suppress noisy HF/tokenizer/sbert/safetensors warnings on load
        for name in (
            "sentence_transformers",
            "transformers",
            "transformers.modeling_utils",
            "huggingface_hub",
            "huggingface_hub.utils",
            "safetensors",
        ):
            logging.getLogger(name).setLevel(logging.CRITICAL)
        # Suppress the HF_TOKEN and load-report warnings.
        # setdefault: respect values the user already exported.
        os.environ.setdefault("HF_HUB_DISABLE_PROGRESS_BARS", "1")
        os.environ.setdefault("TRANSFORMERS_VERBOSITY", "error")
        os.environ.setdefault("HF_HUB_VERBOSITY", "error")
        warnings.filterwarnings("ignore", category=FutureWarning)
        # Import deliberately AFTER the env vars are set so they take effect.
        from sentence_transformers import SentenceTransformer
        self._model = SentenceTransformer(self.MODEL_NAME)
        self._dim = self._model.get_sentence_embedding_dimension()

    def embed(self, text: str) -> torch.Tensor:
        """Return the L2-normalized sentence embedding as a float32 tensor."""
        # encode returns numpy array
        vec = self._model.encode(text, normalize_embeddings=True)
        return torch.from_numpy(vec.astype(np.float32))

    @property
    def source(self) -> str:
        """Strategy tag recorded alongside fingerprints."""
        return "sbert"

    @property
    def dim(self) -> int:
        """Embedding dimensionality reported by the loaded model."""
        return self._dim
141
+
142
+
143
+ # ── Strategy 3: Hash fallback ────────────────────────────────────────
144
+
145
class HashEmbedder:
    """Deterministic pseudo-fingerprint from SHA256 hash.

    No semantic meaning — same text always maps to same vector,
    but unrelated texts have random cosine similarity (~0).
    """

    def __init__(self, dim: int = 2048) -> None:
        self._dim = dim

    def embed(self, text: str) -> torch.Tensor:
        """Map *text* to a unit-norm float32 vector seeded by its SHA-256."""
        digest = hashlib.sha256(text.encode()).hexdigest()
        # First 8 hex chars → 32-bit seed for a reproducible RNG.
        rng = np.random.RandomState(int(digest[:8], 16))
        vec = rng.randn(self._dim).astype("float32")
        vec /= np.linalg.norm(vec) + 1e-8
        return torch.from_numpy(vec)

    @property
    def source(self) -> str:
        """Strategy tag recorded alongside fingerprints."""
        return "hash-fallback"

    @property
    def dim(self) -> int:
        """Configured fingerprint dimensionality."""
        return self._dim
169
+
170
+
171
+ # ── Singleton factory ────────────────────────────────────────────────
172
+
173
+ _cached_embedder: Embedder | None = None
174
+
175
+
176
def _create_embedder() -> Embedder:
    """Try strategies in priority order, return first that works."""
    # Strategy 1: native ENGRAM via a local GGUF model, if configured.
    gguf_path = os.environ.get("ENGRAM_MODEL_PATH", "")
    if gguf_path and Path(gguf_path).exists():
        try:
            return LlamaCppEmbedder(gguf_path)
        except Exception:
            pass  # fall through to the next strategy

    # Strategy 2: sentence-transformers (downloads the model on first use).
    try:
        return SBertEmbedder()
    except Exception:
        pass  # fall through to the guaranteed fallback

    # Strategy 3: hash fallback — always available.
    return HashEmbedder()
196
+
197
+
198
def get_embedder() -> Embedder:
    """Return the process-wide embedder singleton, creating it on first use."""
    global _cached_embedder
    if _cached_embedder is not None:
        return _cached_embedder
    _cached_embedder = _create_embedder()
    return _cached_embedder
204
+
205
+
206
def get_fingerprint(text: str) -> tuple[torch.Tensor, str]:
    """
    Compute fingerprint for text using best available strategy.

    Returns:
        (fingerprint_tensor, source_tag)
    """
    active = get_embedder()
    return active.embed(text), active.source
216
+
217
+
218
def reset_embedder() -> None:
    """Reset the cached embedder (for testing or strategy switching).

    The next get_embedder()/get_fingerprint() call re-runs strategy
    selection from scratch.
    """
    global _cached_embedder
    _cached_embedder = None
kvcos/engram/format.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EIGENGRAM binary format codec (EGR1 v1.0).
3
+
4
+ An EIGENGRAM (.eng) file is a self-contained semantic certificate
5
+ for a KV-cache document. It encodes two fingerprint vectors, the
6
+ shared coordinate system they live in, and enough metadata to
7
+ reproduce the query fold-in without access to the original text
8
+ or model.
9
+
10
+ Design goals:
11
+ - Portable: pure binary, no JSON, no pickle, no protobuf.
12
+ - Versioned: magic bytes + version field.
13
+ - Self-contained: joint_center embedded for query fold-in.
14
+ - Compact: float16 vectors, ~800 bytes total per document.
15
+
16
+ Dual-fingerprint architecture:
17
+ vec_perdoc - per-document SVD projection (same-model, margin ~0.37)
18
+ vec_fcdb - FCDB projection (cross-model, margin ~0.013)
19
+
20
+ Binary layout (little-endian, 99-byte fixed header):
21
+ Offset Size Type Field
22
+ 0 4 bytes magic = "EGR1"
23
+ 4 1 uint8 version (currently 1)
24
+ 5 32 ascii corpus_hash
25
+ 37 20 ascii created_at
26
+ 57 16 ascii model_id (null-padded)
27
+ 73 2 uint16 basis_rank R
28
+ 75 2 uint16 n_corpus
29
+ 77 2 int8x2 layer_range
30
+ 79 4 uint32 context_len
31
+ 83 4 float32 l2_norm
32
+ 87 4 float32 scs
33
+ 91 4 float32 margin_proof
34
+ 95 2 uint16 task_desc_len
35
+ 97 2 uint16 cache_id_len
36
+ Variable:
37
+ 99 R*2 float16 vec_perdoc
38
+ +R*2 R*2 float16 vec_fcdb
39
+ +2R*2 256 float16 joint_center (128 x float16)
40
+ +256 var utf-8 task_description
41
+ +var var utf-8 cache_id
42
+
43
+ Total for R=116: ~800 bytes.
44
+
45
+ Compatibility: readers MUST reject magic != "EGR1" or version mismatch.
46
+ """
47
+
48
+ from __future__ import annotations
49
+
50
+ import struct
51
+
52
+ import numpy as np
53
+ import torch
54
+
55
+ EIGENGRAM_MAGIC = b"EGR1"
56
+ EIGENGRAM_VERSION = 1
57
+
58
+
59
+ class EigramEncoder:
60
+ """Encode and decode EIGENGRAM binary certificates.
61
+
62
+ A single instance handles both directions. EigramDecoder is an alias.
63
+ Float16 storage preserves cosine similarity to > 0.999.
64
+ """
65
+
66
    def encode(
        self,
        vec_perdoc: torch.Tensor,
        vec_fcdb: torch.Tensor,
        joint_center: torch.Tensor,
        corpus_hash: str,
        model_id: str,
        basis_rank: int,
        n_corpus: int,
        layer_range: tuple[int, int],
        context_len: int,
        l2_norm: float,
        scs: float,
        margin_proof: float,
        task_description: str,
        cache_id: str,
        vec_fourier: torch.Tensor | None = None,
        local_density: int = 0,
        eigenform_score: float = 1.0,
        confusion_flag: bool = False,
        vec_fourier_v2: torch.Tensor | None = None,
    ) -> bytes:
        """Serialise all fields into an EIGENGRAM binary blob.

        Layout: fixed little-endian header, then float16 vector payloads
        (vec_perdoc, vec_fcdb, joint_center[:128]), then UTF-8 strings,
        then optional trailing extensions (vec_fourier; v1.2
        confusion_flag + vec_fourier_v2).

        NOTE(review): the module docstring documents a 99-byte header, but
        this writer appends fourier_dim (u16), local_density (u16) and
        eigenform_score (f32) after cache_id_len — 107 header bytes total.
        Confirm the reader and the docstring agree with this layout.
        """
        from datetime import datetime, timezone

        # Strings are hard-truncated to their header-advertised maxima.
        td_b = task_description.encode("utf-8")[:256]
        ci_b = cache_id.encode("utf-8")[:64]
        now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")

        buf = bytearray()
        buf += EIGENGRAM_MAGIC
        buf += struct.pack("<B", EIGENGRAM_VERSION)
        # Fixed-width ASCII fields, null-padded on the right.
        buf += corpus_hash.encode("ascii")[:32].ljust(32, b"\x00")
        buf += now.encode("ascii")[:20].ljust(20, b"\x00")
        buf += model_id.encode("ascii")[:16].ljust(16, b"\x00")
        buf += struct.pack("<H", basis_rank)
        buf += struct.pack("<H", n_corpus)
        buf += struct.pack("<bb", layer_range[0], layer_range[1])
        buf += struct.pack("<I", context_len)
        buf += struct.pack("<f", l2_norm)
        buf += struct.pack("<f", scs)
        buf += struct.pack("<f", margin_proof)
        buf += struct.pack("<H", len(td_b))
        buf += struct.pack("<H", len(ci_b))
        # fourier_dim: 0 if no vec_fourier, else len(vec_fourier)
        fourier_dim = len(vec_fourier) if vec_fourier is not None else 0
        buf += struct.pack("<H", fourier_dim)
        buf += struct.pack("<H", local_density)
        buf += struct.pack("<f", eigenform_score)

        # Vector payloads stored as float16 (module docstring: cosine
        # similarity preserved to > 0.999). joint_center is capped at 128.
        buf += vec_perdoc.to(torch.float16).numpy().tobytes()
        buf += vec_fcdb.to(torch.float16).numpy().tobytes()
        buf += joint_center[:128].to(torch.float16).numpy().tobytes()

        buf += td_b
        buf += ci_b

        # Append vec_fourier if present (backward-compatible extension)
        if vec_fourier is not None:
            buf += vec_fourier.to(torch.float16).numpy().tobytes()

        # v1.2 extension: confusion_flag + vec_fourier_v2
        # Written only when at least one is non-default, preserving
        # backward compat with readers that stop after vec_fourier.
        if confusion_flag or vec_fourier_v2 is not None:
            buf += struct.pack("<B", 1 if confusion_flag else 0)
            v2_dim = len(vec_fourier_v2) if vec_fourier_v2 is not None else 0
            buf += struct.pack("<H", v2_dim)
            if vec_fourier_v2 is not None:
                buf += vec_fourier_v2.to(torch.float16).numpy().tobytes()

        return bytes(buf)
138
+
139
+ def decode(self, data: bytes) -> dict:
140
+ """Deserialise an EIGENGRAM binary blob into a dict.
141
+
142
+ Returns dict with all fields. Vectors upcast to float32.
143
+ Raises ValueError on magic/version mismatch.
144
+ """
145
+ if len(data) < 4 or data[:4] != EIGENGRAM_MAGIC:
146
+ raise ValueError(
147
+ f"Invalid EIGENGRAM magic: {data[:4]!r} (expected {EIGENGRAM_MAGIC!r})"
148
+ )
149
+
150
+ off = 4
151
+ version = struct.unpack_from("<B", data, off)[0]; off += 1
152
+ if version != EIGENGRAM_VERSION:
153
+ raise ValueError(
154
+ f"Unsupported EIGENGRAM version {version} "
155
+ f"(this reader supports v{EIGENGRAM_VERSION})"
156
+ )
157
+
158
+ corpus_hash = data[off : off + 32].rstrip(b"\x00").decode("ascii"); off += 32
159
+ created_at = data[off : off + 20].rstrip(b"\x00").decode("ascii"); off += 20
160
+ model_id = data[off : off + 16].rstrip(b"\x00").decode("ascii"); off += 16
161
+
162
+ basis_rank = struct.unpack_from("<H", data, off)[0]; off += 2
163
+ n_corpus = struct.unpack_from("<H", data, off)[0]; off += 2
164
+ lr0, lr1 = struct.unpack_from("<bb", data, off); off += 2
165
+ context_len = struct.unpack_from("<I", data, off)[0]; off += 4
166
+ l2_norm = struct.unpack_from("<f", data, off)[0]; off += 4
167
+ scs = struct.unpack_from("<f", data, off)[0]; off += 4
168
+ margin_proof = struct.unpack_from("<f", data, off)[0]; off += 4
169
+ td_len = struct.unpack_from("<H", data, off)[0]; off += 2
170
+ ci_len = struct.unpack_from("<H", data, off)[0]; off += 2
171
+
172
+ # v1.1 extension fields: fourier_dim + local_density
173
+ # Detect by checking if file has extra bytes beyond v1.0 layout
174
+ fourier_dim = 0
175
+ local_density = 0
176
+ expected_old_size = off + basis_rank * 4 + 256 + td_len + ci_len
177
+ eigenform_score = 1.0
178
+ if len(data) > expected_old_size + 4:
179
+ fourier_dim = struct.unpack_from("<H", data, off)[0]; off += 2
180
+ local_density = struct.unpack_from("<H", data, off)[0]; off += 2
181
+ eigenform_score = struct.unpack_from("<f", data, off)[0]; off += 4
182
+ # If the file was written with fourier_dim field but is old format,
183
+ # we already consumed 2 bytes. This is safe because old files
184
+ # won't have extra bytes.
185
+
186
+ R = basis_rank
187
+ vec_perdoc = torch.from_numpy(
188
+ np.frombuffer(data, dtype=np.float16, count=R, offset=off).copy()
189
+ ).float(); off += R * 2
190
+
191
+ vec_fcdb = torch.from_numpy(
192
+ np.frombuffer(data, dtype=np.float16, count=R, offset=off).copy()
193
+ ).float(); off += R * 2
194
+
195
+ joint_center = torch.from_numpy(
196
+ np.frombuffer(data, dtype=np.float16, count=128, offset=off).copy()
197
+ ).float(); off += 128 * 2
198
+
199
+ task_description = data[off : off + td_len].decode("utf-8", errors="replace"); off += td_len
200
+ cache_id = data[off : off + ci_len].decode("utf-8", errors="replace"); off += ci_len
201
+
202
+ # Read vec_fourier if present
203
+ vec_fourier = None
204
+ if fourier_dim > 0 and off + fourier_dim * 2 <= len(data):
205
+ vec_fourier = torch.from_numpy(
206
+ np.frombuffer(data, dtype=np.float16, count=fourier_dim, offset=off).copy()
207
+ ).float()
208
+ off += fourier_dim * 2
209
+
210
+ # v1.2 extension: confusion_flag + vec_fourier_v2
211
+ confusion_flag = False
212
+ vec_fourier_v2 = None
213
+ if off + 3 <= len(data): # 1 byte flag + 2 byte dim minimum
214
+ confusion_flag = bool(struct.unpack_from("<B", data, off)[0])
215
+ off += 1
216
+ v2_dim = struct.unpack_from("<H", data, off)[0]
217
+ off += 2
218
+ if v2_dim > 0 and off + v2_dim * 2 <= len(data):
219
+ vec_fourier_v2 = torch.from_numpy(
220
+ np.frombuffer(data, dtype=np.float16, count=v2_dim, offset=off).copy()
221
+ ).float()
222
+
223
+ result = {
224
+ "version": version,
225
+ "corpus_hash": corpus_hash,
226
+ "created_at": created_at,
227
+ "model_id": model_id,
228
+ "basis_rank": basis_rank,
229
+ "n_corpus": n_corpus,
230
+ "layer_range": (lr0, lr1),
231
+ "context_len": context_len,
232
+ "l2_norm": l2_norm,
233
+ "scs": scs,
234
+ "margin_proof": margin_proof,
235
+ "vec_perdoc": vec_perdoc,
236
+ "vec_fcdb": vec_fcdb,
237
+ "joint_center": joint_center,
238
+ "task_description": task_description,
239
+ "cache_id": cache_id,
240
+ }
241
+ if vec_fourier is not None:
242
+ result["vec_fourier"] = vec_fourier
243
+ if vec_fourier_v2 is not None:
244
+ result["vec_fourier_v2"] = vec_fourier_v2
245
+ result["local_density"] = local_density
246
+ result["eigenform_score"] = eigenform_score
247
+ result["confusion_flag"] = confusion_flag
248
+ return result
249
+
250
+
251
+ EigramDecoder = EigramEncoder
kvcos/engram/hnsw_index.py ADDED
@@ -0,0 +1,205 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM HNSW Index — O(log N) approximate nearest neighbor retrieval.
3
+
4
+ Wraps faiss.IndexHNSWFlat for production-scale ENGRAM search.
5
+ Primary fingerprint: v2 layer-normalized Fourier f0+f1.
6
+
7
+ Usage:
8
+ idx = EngramIndex(dim=2048)
9
+ idx.add_batch(doc_ids, vectors)
10
+ results = idx.search(query_fp, top_k=5)
11
+ idx.save('index/hnsw')
12
+ idx2 = EngramIndex.load('index/hnsw')
13
+ """
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ import os
20
+ from dataclasses import dataclass
21
+
22
+ import faiss
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn.functional as F
26
+
27
+ logger = logging.getLogger(__name__)
28
+
29
+
30
@dataclass
class HNSWResult:
    """Single HNSW search result."""

    doc_id: str          # identifier of the matched document
    score: float         # cosine similarity (1 - L2^2/2 on normalized vectors)
    rank: int            # 0-based position in the result list
    margin: float = 0.0  # score gap to the runner-up; set on the top result only
38
+
39
+
40
class EngramIndex:
    """HNSW-backed ENGRAM retrieval index.

    Vectors are L2-normalized before insertion so that faiss' L2 metric is
    monotone in cosine similarity (cosine = 1 - L2^2/2).

    HNSW parameters:
        M=32: graph degree (higher = better recall, more memory)
        efConstruction=200: build-time search width
        efSearch=64: query-time search width
    """

    M = 32
    EF_CONSTRUCTION = 200
    EF_SEARCH = 64

    def __init__(self, dim: int = 2048):
        self._dim = dim
        self._index: faiss.IndexHNSWFlat | None = None
        self._ids: list[str] = []
        self._id_to_pos: dict[str, int] = {}
        self._n_docs: int = 0

    def add_batch(
        self,
        doc_ids: list[str],
        vectors: torch.Tensor,
    ) -> None:
        """Build HNSW index from vectors.

        Args:
            doc_ids: list of document identifiers
            vectors: [N, dim] tensor of fingerprints
        """
        # Normalize so L2 distance is a monotone proxy for cosine similarity.
        matrix = F.normalize(vectors.float(), dim=-1).numpy().astype("float32")
        self._dim = matrix.shape[1]
        self._ids = list(doc_ids)
        self._id_to_pos = {cid: i for i, cid in enumerate(doc_ids)}
        self._n_docs = len(doc_ids)

        self._index = faiss.IndexHNSWFlat(self._dim, self.M)
        self._index.hnsw.efConstruction = self.EF_CONSTRUCTION
        self._index.hnsw.efSearch = self.EF_SEARCH
        self._index.add(matrix)

    def build(
        self,
        eng_files: list[str],
        fp_key: str = "vec_fourier_v2",
        verbose: bool = True,
    ) -> None:
        """Build HNSW index from list of .eng file paths.

        Args:
            eng_files: List of paths to .eng encoded files.
            fp_key: Fingerprint field to index.
                Default 'vec_fourier_v2' (S3 validated, 99.5% recall).
                Falls back to 'vec_fourier' if v2 not present.
            verbose: Log a summary and any fallback count.

        Raises:
            ValueError: if no file yields a usable fingerprint.
        """
        from kvcos.engram.reader import read_eigengram

        doc_ids = []
        vecs = []
        missing_v2 = 0

        for fp in eng_files:
            data = read_eigengram(fp)
            cid = data.get("cache_id")
            if not cid:
                continue

            vec = data.get(fp_key)
            if vec is None:
                vec = data.get("vec_fourier")
                if vec is None:
                    # No fingerprint at all — skip without counting it as a
                    # fallback (previously this doc inflated missing_v2).
                    continue
                missing_v2 += 1

            doc_ids.append(cid)
            vecs.append(vec.float())

        if not vecs:
            raise ValueError(
                f"No valid fingerprints found in {len(eng_files)} files"
            )

        if missing_v2 > 0 and verbose:
            logger.warning(
                "%d docs missing %s, used vec_fourier fallback",
                missing_v2, fp_key,
            )

        self.add_batch(doc_ids, torch.stack(vecs))

        if verbose:
            logger.info("HNSW index built: %d docs, dim=%d", self._n_docs, self._dim)
            logger.info("M=%d, efC=%d, efS=%d", self.M, self.EF_CONSTRUCTION, self.EF_SEARCH)

    def search(
        self,
        query_fp: torch.Tensor,
        top_k: int = 5,
    ) -> list[HNSWResult]:
        """Search the HNSW index.

        Returns list of HNSWResult sorted by score descending.
        HNSW uses L2 on normalized vectors: cosine = 1 - L2^2/2.

        Raises:
            RuntimeError: if the index has not been built or loaded.
        """
        if self._index is None:
            raise RuntimeError("Index not built. Call add_batch() or load() first.")

        qn = F.normalize(query_fp.float().unsqueeze(0), dim=-1).numpy().astype("float32")
        # Fetch one extra neighbor so the top result's margin can be computed.
        D, I = self._index.search(qn, min(top_k + 1, self._n_docs))

        results = []
        for rank, (dist, idx) in enumerate(zip(D[0], I[0])):
            if idx < 0:  # faiss pads with -1 when fewer neighbors exist
                continue
            cosine_sim = float(1.0 - dist / 2.0)
            results.append(HNSWResult(
                doc_id=self._ids[idx], score=cosine_sim, rank=rank,
            ))

        if len(results) >= 2:
            results[0].margin = results[0].score - results[1].score
        return results[:top_k]

    def save(self, path: str) -> None:
        """Save index to disk (faiss + JSON metadata).

        Raises:
            RuntimeError: if there is no built index to persist.
        """
        if self._index is None:
            raise RuntimeError("Index not built. Call add_batch() or load() first.")
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        faiss.write_index(self._index, path + ".faiss")
        meta_path = path + ".meta.json"
        with open(meta_path, "w") as f:
            json.dump({
                "ids": self._ids,
                "id_to_pos": self._id_to_pos,
                "dim": self._dim,
                "n_docs": self._n_docs,
            }, f, indent=2)

    @classmethod
    def load(cls, path: str) -> EngramIndex:
        """Load index from disk (counterpart of save())."""
        obj = cls()
        obj._index = faiss.read_index(path + ".faiss")
        meta_path = path + ".meta.json"
        with open(meta_path, "r") as f:
            meta = json.load(f)
        obj._ids = meta["ids"]
        obj._id_to_pos = meta["id_to_pos"]
        obj._dim = meta["dim"]
        obj._n_docs = meta["n_docs"]
        return obj

    def __len__(self) -> int:
        return self._n_docs

    def get_vector(self, doc_id: str) -> torch.Tensor | None:
        """Return stored (normalized) vector for doc_id, or None if not found."""
        pos = self._id_to_pos.get(doc_id)
        if pos is None:
            return None
        vec_np = np.zeros(self._dim, dtype="float32")
        self._index.reconstruct(pos, vec_np)
        return torch.from_numpy(vec_np)

    def __repr__(self) -> str:
        return f"EngramIndex(n={self._n_docs}, dim={self._dim}, M={self.M})"
kvcos/engram/index_c.py ADDED
@@ -0,0 +1,426 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/index_c.py — Confidence history index for ENGRAM.
3
+
4
+ Stores retrieval confidence records across sessions.
5
+ Makes the system self-aware: chronic failures are known before retrieval.
6
+
7
+ Schema:
8
+ retrievals: one row per geodesic_retrieve() call
9
+ confusion_pairs: doc pairs that confuse each other (confidence<threshold)
10
+ doc_stats: per-doc aggregate reliability scores
11
+
12
+ Usage:
13
+ ic = IndexC.open("results/index_c.db")
14
+ ic.record(session_id="s1", query_doc_id="doc_146", result=geodesic_result)
15
+ prior = ic.prior("doc_146")
16
+ pairs = ic.confusion_registry()
17
+ rmap = ic.reliability_map()
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ import sqlite3
24
+ import time
25
+ from dataclasses import dataclass
26
+ from pathlib import Path
27
+
28
+
29
@dataclass
class ConfidenceRecord:
    # One row of the `retrievals` table (see SCHEMA / session_history()).
    session_id: str        # session this retrieval belongs to
    query_doc_id: str      # doc used as the query
    result_doc_id: str     # doc returned by retrieval
    confidence: str        # tier label, e.g. "high" / "medium" / "low"
    margin: float          # score gap to the runner-up result
    stages_used: int       # number of retrieval stages executed
    constraint_used: bool  # whether a constraint was applied
    correct: bool          # whether the retrieval was judged correct
    ts: float              # Unix timestamp of the retrieval
40
+
41
+
42
@dataclass
class DocPrior:
    """Prior confidence distribution for a doc_id."""

    doc_id: str
    n_high: int
    n_medium: int
    n_low: int
    n_total: int
    reliability: float
    is_chronic_failure: bool

    @property
    def dominant_confidence(self) -> str:
        """Most frequent confidence tier; ties resolve toward higher tiers."""
        if not self.n_total:
            return "unknown"
        # max() returns the first maximal entry, so ordering high → low
        # reproduces the original dict-insertion-order tie-breaking.
        tallies = (
            ("high", self.n_high),
            ("medium", self.n_medium),
            ("low", self.n_low),
        )
        label, _count = max(tallies, key=lambda pair: pair[1])
        return label
64
+
65
+
66
@dataclass
class ConfusionPair:
    # A pair of docs that have confused each other at least once
    # (recorded by IndexC when a retrieval is marked incorrect).
    doc_a: str          # query-side doc of the confusion
    doc_b: str          # result-side doc of the confusion
    n_confusions: int   # how many times this pair has been confused
    first_seen: float   # Unix timestamp of the first confusion
    last_seen: float    # Unix timestamp of the most recent confusion
+
74
+
75
# SQLite DDL executed on every IndexC open; idempotent via IF NOT EXISTS.
# Tables mirror the ConfidenceRecord / ConfusionPair / DocPrior dataclasses.
SCHEMA = """
CREATE TABLE IF NOT EXISTS retrievals (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    query_doc_id TEXT NOT NULL,
    result_doc_id TEXT NOT NULL,
    confidence TEXT NOT NULL,
    margin REAL NOT NULL,
    stages_used INTEGER NOT NULL,
    constraint_used INTEGER NOT NULL,
    correct INTEGER NOT NULL,
    ts REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_ret_query ON retrievals(query_doc_id);
CREATE INDEX IF NOT EXISTS idx_ret_result ON retrievals(result_doc_id);
CREATE INDEX IF NOT EXISTS idx_ret_conf ON retrievals(confidence);
CREATE INDEX IF NOT EXISTS idx_ret_sess ON retrievals(session_id);

CREATE TABLE IF NOT EXISTS confusion_pairs (
    doc_a TEXT NOT NULL,
    doc_b TEXT NOT NULL,
    n_confusions INTEGER NOT NULL DEFAULT 1,
    first_seen REAL NOT NULL,
    last_seen REAL NOT NULL,
    PRIMARY KEY (doc_a, doc_b)
);

CREATE TABLE IF NOT EXISTS doc_stats (
    doc_id TEXT PRIMARY KEY,
    n_high INTEGER NOT NULL DEFAULT 0,
    n_medium INTEGER NOT NULL DEFAULT 0,
    n_low INTEGER NOT NULL DEFAULT 0,
    reliability REAL NOT NULL DEFAULT 1.0,
    last_updated REAL NOT NULL
);
"""
112
+
113
+
114
class IndexC:
    """
    Confidence history index.

    Backed by SQLite — append-only, persistent across sessions.
    Provides priors for geodesic_retrieve() to pre-apply constraints
    on docs that are known chronic failures.
    """

    # Fraction of low-confidence retrievals above which prior() flags a doc
    # as a chronic failure.
    CHRONIC_FAILURE_THRESHOLD = 0.5

    def __init__(self, db_path: str):
        self._db_path = str(db_path)
        # isolation_level=None → autocommit mode; WAL allows concurrent readers.
        self._conn = sqlite3.connect(
            db_path, check_same_thread=False, isolation_level=None
        )
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.executescript(SCHEMA)

    @classmethod
    def open(cls, db_path: str | Path) -> "IndexC":
        """Open (or create) the Index-C database at db_path."""
        os.makedirs(Path(db_path).parent, exist_ok=True)
        return cls(str(db_path))

    # ── WRITE ────────────────────────────────────────────────────────

    def record(
        self,
        session_id: str,
        query_doc_id: str,
        result_doc_id: str,
        confidence: str,
        margin: float,
        stages_used: int = 1,
        constraint_used: bool = False,
        correct: bool = True,
        ts: float | None = None,
    ) -> None:
        """Log one retrieval result to the index.

        Also updates the confusion registry (when incorrect) and the
        per-doc aggregate stats.
        """
        # Explicit None check: `ts or time.time()` silently discarded a
        # caller-supplied ts of 0.0.
        if ts is None:
            ts = time.time()
        with self._conn:
            self._conn.execute(
                """INSERT INTO retrievals
                   (session_id, query_doc_id, result_doc_id, confidence,
                    margin, stages_used, constraint_used, correct, ts)
                   VALUES (?,?,?,?,?,?,?,?,?)""",
                (session_id, query_doc_id, result_doc_id, confidence,
                 float(margin), int(stages_used), int(constraint_used),
                 int(correct), float(ts)),
            )
        if not correct:
            self._register_confusion(query_doc_id, result_doc_id, ts)
        self._update_doc_stats(query_doc_id, confidence, ts)

    def _register_confusion(
        self, doc_a: str, doc_b: str, ts: float
    ) -> None:
        """Insert or increment confusion pair."""
        existing = self._conn.execute(
            "SELECT n_confusions FROM confusion_pairs "
            "WHERE doc_a=? AND doc_b=?",
            (doc_a, doc_b),
        ).fetchone()
        if existing:
            self._conn.execute(
                "UPDATE confusion_pairs SET n_confusions=n_confusions+1, "
                "last_seen=? WHERE doc_a=? AND doc_b=?",
                (ts, doc_a, doc_b),
            )
        else:
            self._conn.execute(
                "INSERT INTO confusion_pairs "
                "(doc_a, doc_b, n_confusions, first_seen, last_seen) "
                "VALUES (?,?,1,?,?)",
                (doc_a, doc_b, ts, ts),
            )

    def _update_doc_stats(
        self, doc_id: str, confidence: str, ts: float
    ) -> None:
        """Upsert doc_stats row for doc_id.

        Unknown confidence labels are counted as "medium".
        """
        col_map = {"high": "n_high", "medium": "n_medium", "low": "n_low"}
        col = col_map.get(confidence, "n_medium")

        existing = self._conn.execute(
            "SELECT n_high, n_medium, n_low FROM doc_stats WHERE doc_id=?",
            (doc_id,),
        ).fetchone()

        if existing:
            n_high, n_medium, n_low = existing
            if col == "n_high":
                n_high += 1
            elif col == "n_medium":
                n_medium += 1
            else:
                n_low += 1
            n_total = n_high + n_medium + n_low
            # Reliability = share of non-low retrievals.
            reliability = (n_high + n_medium) / n_total if n_total > 0 else 1.0
            self._conn.execute(
                "UPDATE doc_stats SET n_high=?, n_medium=?, n_low=?, "
                "reliability=?, last_updated=? WHERE doc_id=?",
                (n_high, n_medium, n_low, reliability, ts, doc_id),
            )
        else:
            vals = {"n_high": 0, "n_medium": 0, "n_low": 0}
            vals[col] = 1
            reliability = (vals["n_high"] + vals["n_medium"]) / 1
            self._conn.execute(
                "INSERT INTO doc_stats "
                "(doc_id, n_high, n_medium, n_low, reliability, last_updated) "
                "VALUES (?,?,?,?,?,?)",
                (doc_id, vals["n_high"], vals["n_medium"],
                 vals["n_low"], reliability, ts),
            )

    # ── READ ─────────────────────────────────────────────────────────

    def prior(self, doc_id: str) -> DocPrior:
        """Return prior confidence distribution for doc_id.

        A doc with no history gets a neutral prior (reliability 1.0).
        """
        row = self._conn.execute(
            "SELECT n_high, n_medium, n_low, reliability "
            "FROM doc_stats WHERE doc_id=?",
            (doc_id,),
        ).fetchone()
        if not row:
            return DocPrior(
                doc_id=doc_id, n_high=0, n_medium=0, n_low=0,
                n_total=0, reliability=1.0, is_chronic_failure=False,
            )
        n_high, n_medium, n_low, reliability = row
        n_total = n_high + n_medium + n_low
        return DocPrior(
            doc_id=doc_id,
            n_high=n_high,
            n_medium=n_medium,
            n_low=n_low,
            n_total=n_total,
            reliability=reliability,
            is_chronic_failure=(
                n_low / n_total > self.CHRONIC_FAILURE_THRESHOLD
                if n_total > 0
                else False
            ),
        )

    def confusion_registry(
        self, min_confusions: int = 1
    ) -> list[ConfusionPair]:
        """Return known confusion pairs with >= min_confusions."""
        rows = self._conn.execute(
            "SELECT doc_a, doc_b, n_confusions, first_seen, last_seen "
            "FROM confusion_pairs WHERE n_confusions >= ? "
            "ORDER BY n_confusions DESC",
            (min_confusions,),
        ).fetchall()
        return [
            ConfusionPair(
                doc_a=r[0], doc_b=r[1], n_confusions=r[2],
                first_seen=r[3], last_seen=r[4],
            )
            for r in rows
        ]

    def reliability_map(self) -> dict[str, float]:
        """Return {doc_id: reliability_score} for all tracked docs."""
        rows = self._conn.execute(
            "SELECT doc_id, reliability FROM doc_stats ORDER BY reliability"
        ).fetchall()
        return {r[0]: float(r[1]) for r in rows}

    def session_history(self, session_id: str) -> list[ConfidenceRecord]:
        """Return all records for a session_id, oldest first."""
        rows = self._conn.execute(
            "SELECT session_id, query_doc_id, result_doc_id, confidence, "
            "margin, stages_used, constraint_used, correct, ts "
            "FROM retrievals WHERE session_id=? ORDER BY ts",
            (session_id,),
        ).fetchall()
        return [
            ConfidenceRecord(
                session_id=r[0], query_doc_id=r[1], result_doc_id=r[2],
                confidence=r[3], margin=float(r[4]), stages_used=int(r[5]),
                constraint_used=bool(r[6]), correct=bool(r[7]), ts=float(r[8]),
            )
            for r in rows
        ]

    def n_sessions(self) -> int:
        """Number of distinct sessions recorded."""
        row = self._conn.execute(
            "SELECT COUNT(DISTINCT session_id) FROM retrievals"
        ).fetchone()
        return row[0] if row else 0

    # ── RECENCY-WEIGHTED RELIABILITY ────────────────────────────────

    def weighted_reliability(
        self,
        doc_id: str,
        decay: float = 0.85,
    ) -> float:
        """
        Exponentially weighted reliability score.
        Newer retrievals have higher weight.
        decay=0.85: a failure 5 sessions ago counts 0.85^5 = 0.44
        of a failure last session.

        Returns float in [0, 1]. Returns 1.0 if no history.
        """
        rows = self._conn.execute(
            """SELECT correct FROM retrievals
               WHERE query_doc_id=?
               ORDER BY ts ASC""",
            (doc_id,),
        ).fetchall()
        if not rows:
            return 1.0
        history = [bool(r[0]) for r in rows]
        n = len(history)
        # Oldest record gets decay^(n-1), newest gets decay^0 = 1.
        weights = [decay ** (n - 1 - i) for i in range(n)]
        total_w = sum(weights)
        score = sum(w * int(h) for w, h in zip(weights, history))
        return round(score / total_w, 6) if total_w > 0 else 1.0

    # ── INDEX GROWTH REVALIDATION ────────────────────────────────────

    def on_document_added(
        self,
        new_doc_id: str,
        new_vec,  # torch.Tensor [dim]
        hnsw_index,  # EngramIndex instance
        revalidation_radius: float = 0.85,
        density_threshold: int = 3,
    ) -> list[str]:
        """
        Call after adding a document to the HNSW index.
        Recomputes local_density for neighbors of new_doc_id
        whose similarity > revalidation_radius.

        Returns list of doc_ids whose density was updated.

        Why this matters:
          At N=200, doc_042 is in sparse space (density=1).
          At N=500, doc_201 lands near doc_042 (cosine=0.88).
          doc_042 is now in a denser region — its confidence tier
          has degraded. This method detects and records that.
        """
        updated: list[str] = []
        try:
            results = hnsw_index.search(new_vec, top_k=20)
        except Exception:
            # Best-effort: an unsearchable index means nothing to revalidate.
            return updated

        for r in results:
            if r.doc_id == new_doc_id:
                continue
            if r.score < revalidation_radius:
                continue

            # Recompute density for this neighbor
            neighbor_vec = hnsw_index.get_vector(r.doc_id)
            if neighbor_vec is None:
                continue

            all_results = hnsw_index.search(neighbor_vec, top_k=50)
            new_density = sum(
                1 for x in all_results
                if x.doc_id != r.doc_id and x.score > revalidation_radius
            )

            # Use the module-level `time` import (previously __import__("time")).
            ts = time.time()
            existing = self._conn.execute(
                "SELECT n_high FROM doc_stats WHERE doc_id=?",
                (r.doc_id,),
            ).fetchone()

            if existing:
                self._conn.execute(
                    "UPDATE doc_stats SET last_updated=? WHERE doc_id=?",
                    (ts, r.doc_id),
                )
            else:
                self._conn.execute(
                    "INSERT OR IGNORE INTO doc_stats "
                    "(doc_id, n_high, n_medium, n_low, reliability, last_updated) "
                    "VALUES (?,0,0,0,1.0,?)",
                    (r.doc_id, ts),
                )

            # If density crossed threshold, register as needing constraint activation
            if new_density > density_threshold:
                self._register_confusion(r.doc_id, new_doc_id, ts)

            updated.append(r.doc_id)

        # Autocommit mode makes this a no-op; kept once, outside the loop,
        # as an explicit flush point.
        self._conn.commit()
        return updated

    def close(self) -> None:
        self._conn.close()

    def __repr__(self) -> str:
        n = self._conn.execute(
            "SELECT COUNT(*) FROM retrievals"
        ).fetchone()[0]
        return f"IndexC(db={self._db_path!r}, n_records={n})"
+ return f"IndexC(db={self._db_path!r}, n_records={n})"
kvcos/engram/knowledge_index.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/knowledge_index.py — HNSW index over the knowledge store.
3
+
4
+ Builds and maintains a faiss HNSW index over all .eng files in
5
+ ~/.engram/knowledge/. Supports dynamic dimension (384 for sbert,
6
+ 2048 for llama_cpp/hash) — determined at build time from the first
7
+ .eng file.
8
+
9
+ Usage:
10
+ # Build from all knowledge .eng files
11
+ kidx = KnowledgeIndex.build_from_knowledge_dir()
12
+ results = kidx.search("HNSW recall benchmark", k=5)
13
+ kidx.save()
14
+
15
+ # Load pre-built index
16
+ kidx = KnowledgeIndex.load()
17
+ results = kidx.search("testing patterns", k=3)
18
+
19
+ Index files:
20
+ ~/.engram/index/knowledge.faiss
21
+ ~/.engram/index/knowledge.meta
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import json
27
+ import logging
28
+ import os
29
+ from dataclasses import dataclass
30
+ from pathlib import Path
31
+
32
+ import faiss
33
+ import numpy as np
34
+ import torch
35
+ import torch.nn.functional as F
36
+
37
+ from kvcos.engram.embedder import get_fingerprint
38
+ from kvcos.engram.format import EigramEncoder
39
+
40
+ logger = logging.getLogger(__name__)
41
+
42
+
43
+ INDEX_DIR = Path(
44
+ os.environ.get("ENGRAM_INDEX_DIR", "~/.engram/index")
45
+ ).expanduser()
46
+
47
+ KNOWLEDGE_DIR = Path(
48
+ os.environ.get("ENGRAM_KNOWLEDGE_DIR", "~/.engram/knowledge")
49
+ ).expanduser()
50
+
51
+ INDEX_NAME = "knowledge"
52
+
53
+ _encoder = EigramEncoder()
54
+
55
+
56
@dataclass(frozen=True)
class KnowledgeResult:
    """Single search result from the knowledge index."""
    doc_id: str          # cache_id of the matched .eng entry
    score: float         # cosine similarity (1 - L2^2/2 on normalized vectors)
    rank: int            # 0-based position in the result list
    source_path: str     # original source file path from sidecar metadata
    project: str         # project namespace from sidecar metadata
    content: str         # chunk description text
    chunk_info: str      # "2/5" format
    headers: list[str]   # markdown-style section headers from sidecar metadata
    margin: float = 0.0  # score gap to runner-up; set on the top result only
+
69
+
70
class KnowledgeIndex:
    """HNSW index over the ENGRAM knowledge store.

    Parameters match EngramIndex for consistency:
        M=32, efConstruction=200, efSearch=64
    """

    M = 32
    EF_CONSTRUCTION = 200
    EF_SEARCH = 64

    def __init__(self, dim: int = 384) -> None:
        self._dim = dim
        self._index: faiss.IndexHNSWFlat | None = None
        self._meta: list[dict] = []  # per-vector metadata
        self._n_docs: int = 0

    @classmethod
    def build_from_knowledge_dir(
        cls,
        knowledge_dir: Path | None = None,
        verbose: bool = True,
    ) -> KnowledgeIndex:
        """Build HNSW index from all .eng files in the knowledge directory.

        Raises:
            ValueError: if the directory has no .eng files, or none of them
                yields a usable fingerprint.
        """
        if knowledge_dir is None:
            knowledge_dir = KNOWLEDGE_DIR

        # rglob("*.eng") already yields only .eng files; sort oldest-first
        # so positions are stable across rebuilds of an append-only store.
        eng_files = sorted(knowledge_dir.rglob("*.eng"), key=os.path.getmtime)

        if not eng_files:
            raise ValueError(f"No .eng files found in {knowledge_dir}")

        vectors: list[torch.Tensor] = []
        metas: list[dict] = []
        skipped = 0

        for p in eng_files:
            try:
                data = _encoder.decode(p.read_bytes())

                # Prefer the v2 fingerprint, fall back to v1.
                fp = data.get("vec_fourier_v2")
                if fp is None:
                    fp = data.get("vec_fourier")
                if fp is None:
                    skipped += 1
                    continue

                # Load sidecar metadata
                meta_path = Path(str(p) + ".meta.json")
                meta = {}
                if meta_path.exists():
                    meta = json.loads(meta_path.read_text())

                # Prefer the sidecar description when present; fall back to
                # the (possibly truncated) description embedded in the binary.
                description = meta.get("task_description", "") or \
                    data.get("task_description", "")

                vectors.append(fp.float())
                metas.append({
                    "doc_id": data.get("cache_id", p.stem),
                    "source_path": meta.get("source_path", ""),
                    "project": meta.get("project", ""),
                    "content": description,
                    "chunk_index": meta.get("chunk_index", 0),
                    "chunk_total": meta.get("chunk_total", 1),
                    "headers": meta.get("headers", []),
                    "fp_source": meta.get("fp_source", "unknown"),
                })
            except Exception as exc:
                # Best-effort scan: a corrupt file must not abort the build.
                logger.debug("Skipping %s: %s", p, exc)
                skipped += 1

        if not vectors:
            raise ValueError(
                f"No valid fingerprints in {len(eng_files)} .eng files"
            )

        # Stack and determine dimension from actual data
        matrix = torch.stack(vectors)
        dim = matrix.shape[1]

        # Normalize for cosine similarity via L2
        matrix = F.normalize(matrix, dim=-1).numpy().astype("float32")

        # Build HNSW
        obj = cls(dim=dim)
        obj._index = faiss.IndexHNSWFlat(dim, cls.M)
        obj._index.hnsw.efConstruction = cls.EF_CONSTRUCTION
        obj._index.hnsw.efSearch = cls.EF_SEARCH
        obj._index.add(matrix)
        obj._meta = metas
        obj._n_docs = len(metas)

        if verbose:
            projects = {m["project"] for m in metas}
            logger.info("Knowledge HNSW: %d vectors, dim=%d", obj._n_docs, dim)
            logger.info("Projects: %s", sorted(projects))
            if skipped:
                logger.warning("Skipped: %d files (no fingerprint)", skipped)

        return obj

    def search(
        self,
        query: str | torch.Tensor,
        k: int = 5,
    ) -> list[KnowledgeResult]:
        """
        Search the knowledge index.

        Args:
            query: Search text (will be fingerprinted) or pre-computed tensor.
            k: Number of results to return.

        Returns:
            List of KnowledgeResult sorted by score descending.

        Raises:
            RuntimeError: if the index has not been built or loaded.
        """
        if self._index is None:
            raise RuntimeError("Index not built. Call build_from_knowledge_dir() first.")

        if isinstance(query, str):
            query_fp, _ = get_fingerprint(query)
        else:
            query_fp = query

        qn = F.normalize(
            query_fp.float().unsqueeze(0), dim=-1
        ).numpy().astype("float32")

        # Fetch one extra neighbor so the top result's margin can be computed.
        top = min(k + 1, self._n_docs)
        D, I = self._index.search(qn, top)

        results: list[KnowledgeResult] = []
        for rank, (dist, idx) in enumerate(zip(D[0], I[0])):
            if idx < 0 or idx >= len(self._meta):  # faiss pads with -1
                continue
            meta = self._meta[idx]
            cosine = float(1.0 - dist / 2.0)
            ci = meta.get("chunk_index", 0)
            ct = meta.get("chunk_total", 1)

            results.append(KnowledgeResult(
                doc_id=meta["doc_id"],
                score=cosine,
                rank=rank,
                source_path=meta.get("source_path", ""),
                project=meta.get("project", ""),
                content=meta.get("content", ""),
                chunk_info=f"{ci + 1}/{ct}",
                headers=meta.get("headers", []),
            ))

        # Set margin (gap to runner-up) on the top result. KnowledgeResult is
        # frozen, so build the updated copy via dataclasses.replace instead of
        # re-listing every field by hand.
        if len(results) >= 2:
            from dataclasses import replace
            results[0] = replace(
                results[0], margin=results[0].score - results[1].score
            )

        return results[:k]

    def save(self, index_dir: Path | None = None) -> Path:
        """Save index to disk; returns the path of the .faiss file.

        Raises:
            RuntimeError: if there is no built index to persist.
        """
        if self._index is None:
            raise RuntimeError("Index not built. Call build_from_knowledge_dir() first.")
        if index_dir is None:
            index_dir = INDEX_DIR
        index_dir.mkdir(parents=True, exist_ok=True)

        faiss_path = index_dir / f"{INDEX_NAME}.faiss"
        meta_path = index_dir / f"{INDEX_NAME}.meta.json"

        faiss.write_index(self._index, str(faiss_path))
        with open(meta_path, "w") as f:
            json.dump({
                "meta": self._meta,
                "dim": self._dim,
                "n_docs": self._n_docs,
            }, f, indent=2)

        return faiss_path

    @classmethod
    def load(cls, index_dir: Path | None = None) -> KnowledgeIndex:
        """Load pre-built index from disk (counterpart of save())."""
        if index_dir is None:
            index_dir = INDEX_DIR

        faiss_path = index_dir / f"{INDEX_NAME}.faiss"
        meta_path = index_dir / f"{INDEX_NAME}.meta.json"

        if not faiss_path.exists():
            raise FileNotFoundError(
                f"No knowledge index at {faiss_path}. "
                "Build with KnowledgeIndex.build_from_knowledge_dir()"
            )

        obj = cls()
        obj._index = faiss.read_index(str(faiss_path))
        with open(meta_path, "r") as f:
            data = json.load(f)
        obj._meta = data["meta"]
        obj._dim = data["dim"]
        obj._n_docs = data["n_docs"]
        return obj

    def __len__(self) -> int:
        return self._n_docs

    def __repr__(self) -> str:
        return (
            f"KnowledgeIndex(n={self._n_docs}, dim={self._dim}, "
            f"M={self.M})"
        )
+ )
kvcos/engram/manifest.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/manifest.py — Knowledge index manifest registry.
3
+
4
+ Tracks which source files have been indexed into .eng files,
5
+ their content hashes for incremental re-indexing, and chunk
6
+ metadata for multi-chunk files.
7
+
8
+ Storage: JSON file at ~/.engram/manifest.json (human-readable,
9
+ git-friendly, easily inspectable).
10
+
11
+ Thread safety: reads are lock-free, writes use atomic rename.
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import hashlib
17
+ import json
18
+ import os
19
+ import tempfile
20
+ import time
21
+ from dataclasses import asdict, dataclass, field
22
+ from pathlib import Path
23
+ from typing import Iterator
24
+
25
+
26
@dataclass(frozen=True)
class ChunkRecord:
    """One indexed chunk from a source file.

    A source file may be split into several chunks (see chunk_index /
    chunk_total); each chunk is serialized to its own .eng file and
    tracked here together with its character span in the source text.
    """
    eng_path: str          # Absolute path to .eng file
    chunk_index: int       # 0-based chunk index within source
    chunk_total: int       # Total chunks for this source
    char_start: int        # Start offset in source content
    char_end: int          # End offset in source content
    indexed_at: float      # Unix timestamp of indexing
+
36
+
37
@dataclass(frozen=True)
class SourceRecord:
    """Registry entry for one indexed source file.

    Immutable snapshot of a source at index time: its content hash,
    project namespace, size, and the chunks it was split into.
    """
    source_path: str       # Absolute path to original .md file
    content_hash: str      # SHA-256 of file content at index time
    project: str           # Project namespace (e.g., "engram", "_global")
    file_size: int         # Bytes at index time
    chunks: tuple[ChunkRecord, ...] = ()
    indexed_at: float = 0.0
    last_verified: float = 0.0

    @property
    def eng_paths(self) -> list[str]:
        """All .eng file paths for this source, in chunk order."""
        return [chunk.eng_path for chunk in self.chunks]
52
+
53
+
54
def _content_hash(content: str) -> str:
    """SHA-256 hex digest of string content (UTF-8 encoded)."""
    digest = hashlib.sha256()
    digest.update(content.encode("utf-8"))
    return digest.hexdigest()
57
+
58
+
59
def _file_hash(path: Path) -> str:
    """SHA-256 hex digest of a file's bytes on disk."""
    with open(path, "rb") as fh:
        payload = fh.read()
    return hashlib.sha256(payload).hexdigest()
62
+
63
+
64
class Manifest:
    """
    Knowledge index manifest — tracks source-to-.eng mappings.

    Immutable-style operations: all mutations return new state
    and write atomically to disk.

    Usage:
        m = Manifest.load()
        m = m.register(source_path, content_hash, project, file_size, chunks)
        # m is now updated and persisted to disk
    """

    def __init__(
        self,
        records: dict[str, SourceRecord],
        manifest_path: Path,
    ) -> None:
        """Hold an in-memory snapshot of records plus the on-disk path.

        Args:
            records: {source_path: SourceRecord} mapping.
            manifest_path: Where _persist() writes the JSON manifest.
        """
        self._records = dict(records)  # defensive copy
        self._path = manifest_path

    @classmethod
    def load(cls, manifest_path: Path | None = None) -> Manifest:
        """Load manifest from disk, or create empty if not found.

        The path defaults to $ENGRAM_MANIFEST_PATH, falling back to
        ~/.engram/manifest.json.
        """
        if manifest_path is None:
            manifest_path = Path(
                os.environ.get("ENGRAM_MANIFEST_PATH",
                               "~/.engram/manifest.json")
            ).expanduser()

        if manifest_path.exists():
            # Explicit utf-8: the manifest is always written as UTF-8 JSON,
            # so don't depend on the locale's default encoding.
            data = json.loads(manifest_path.read_text(encoding="utf-8"))
            records = {}
            for key, rec_data in data.get("sources", {}).items():
                # Chunks are serialized as plain dicts by asdict();
                # rebuild the frozen ChunkRecord instances here.
                chunks = tuple(
                    ChunkRecord(**c) for c in rec_data.pop("chunks", [])
                )
                records[key] = SourceRecord(**rec_data, chunks=chunks)
            return cls(records, manifest_path)

        return cls({}, manifest_path)

    def register(
        self,
        source_path: str,
        content_hash: str,
        project: str,
        file_size: int,
        chunks: list[ChunkRecord],
    ) -> Manifest:
        """
        Register a newly indexed source file. Returns updated Manifest.

        Overwrites any existing record for the same source_path
        (re-index scenario). The updated manifest is persisted to disk
        before this method returns.
        """
        now = time.time()
        record = SourceRecord(
            source_path=source_path,
            content_hash=content_hash,
            project=project,
            file_size=file_size,
            chunks=tuple(chunks),
            indexed_at=now,
            last_verified=now,
        )

        new_records = dict(self._records)
        new_records[source_path] = record

        new_manifest = Manifest(new_records, self._path)
        new_manifest._persist()
        return new_manifest

    def unregister(self, source_path: str) -> Manifest:
        """Remove a source from the manifest. Returns updated Manifest."""
        new_records = {
            k: v for k, v in self._records.items()
            if k != source_path
        }
        new_manifest = Manifest(new_records, self._path)
        new_manifest._persist()
        return new_manifest

    def needs_reindex(self, source_path: str, current_hash: str) -> bool:
        """Check if a source file needs re-indexing (content changed).

        Unknown sources always need indexing.
        """
        record = self._records.get(source_path)
        if record is None:
            return True
        return record.content_hash != current_hash

    def get_record(self, source_path: str) -> SourceRecord | None:
        """Look up a source record by path, or None if unregistered."""
        return self._records.get(source_path)

    def get_project_records(self, project: str) -> list[SourceRecord]:
        """All records for a given project namespace."""
        return [
            r for r in self._records.values()
            if r.project == project
        ]

    def all_records(self) -> Iterator[SourceRecord]:
        """Iterate over all registered source records."""
        yield from self._records.values()

    @property
    def total_sources(self) -> int:
        """Number of registered source files."""
        return len(self._records)

    @property
    def total_chunks(self) -> int:
        """Total number of indexed chunks across all sources."""
        return sum(len(r.chunks) for r in self._records.values())

    @property
    def projects(self) -> set[str]:
        """Distinct project namespaces present in the manifest."""
        return {r.project for r in self._records.values()}

    def summary(self) -> dict:
        """Quick stats for display."""
        return {
            "total_sources": self.total_sources,
            "total_chunks": self.total_chunks,
            "projects": sorted(self.projects),
            "manifest_path": str(self._path),
        }

    def _persist(self) -> None:
        """Atomic write to disk via tempfile + fsync + rename.

        fsync before os.replace makes the write durable: without it,
        a crash shortly after the rename could leave a truncated or
        empty manifest on disk even though the rename itself is atomic.
        """
        self._path.parent.mkdir(parents=True, exist_ok=True)

        serializable = {
            "version": 1,
            "updated_at": time.time(),
            "sources": {
                key: asdict(rec) for key, rec in self._records.items()
            },
        }

        # Atomic write: write to temp in the same directory, then rename.
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self._path.parent),
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w", encoding="utf-8") as f:
                json.dump(serializable, f, indent=2)
                f.flush()
                os.fsync(f.fileno())  # ensure bytes hit disk before rename
            os.replace(tmp_path, str(self._path))
        except Exception:
            # Clean up temp file on failure
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def __len__(self) -> int:
        return self.total_sources

    def __contains__(self, source_path: str) -> bool:
        return source_path in self._records

    def __repr__(self) -> str:
        return (
            f"Manifest({self.total_sources} sources, "
            f"{self.total_chunks} chunks, "
            f"projects={sorted(self.projects)})"
        )
+ )
kvcos/engram/metadata_disambiguate.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/metadata_disambiguate.py
3
+
4
+ Stage 4 retrieval: activates when the fingerprint pipeline returns LOW.
5
+ Uses .eng metadata fields (domain, context_len, l2_norm, task_description)
6
+ to break ties that the Fourier fingerprint cannot resolve.
7
+
8
+ Returns Stage4Result with confidence='low-metadata' and metadata_used=True.
9
+
10
+ Design note (VRCM source):
11
+ When constraint satisfaction fails on the fingerprint axis, switch to
12
+ orthogonal axes — metadata fields that are independent of spectral structure.
13
+ The medicine/biology failure exists because their f0+f1 profiles are
14
+ spectrally identical. Their metadata is NOT identical: context_len differs,
15
+ l2_norm differs, task_description keywords differ. That orthogonal signal
16
+ is what Stage 4 exploits.
17
+ """
18
+
19
+ from __future__ import annotations
20
+ import re
21
+ from dataclasses import dataclass
22
+
23
+
24
@dataclass
class Stage4Result:
    """Outcome of Stage-4 metadata disambiguation for one candidate."""
    doc_id: str                # cache_id of the winning candidate
    meta_score: float          # sum of domain/len/norm/keyword sub-scores
    confidence: str = 'low-metadata'
    metadata_used: bool = True
    domain_matched: bool = False
    # dict | None (not plain dict): None is the sentinel default,
    # replaced with a fresh dict in __post_init__.
    score_breakdown: dict | None = None

    def __post_init__(self):
        # None-sentinel pattern avoids a shared mutable default dict.
        if self.score_breakdown is None:
            self.score_breakdown = {}
36
+
37
+
38
def _keyword_overlap(text_a: str, text_b: str) -> float:
    """Jaccard overlap on lowercase word sets, excluding stopwords.

    Tokens must be longer than 2 characters; non-alphanumerics are
    treated as separators. Returns 0.0 when either side is empty.
    """
    STOPWORDS = {
        'the', 'a', 'an', 'is', 'are', 'was', 'were', 'in', 'of',
        'to', 'and', 'or', 'that', 'this', 'it', 'with', 'for',
        'on', 'at', 'by', 'from', 'be', 'has', 'have', 'had',
    }

    def tokenize(text):
        cleaned = re.sub(r'[^a-z0-9 ]', ' ', (text or '').lower())
        return {tok for tok in cleaned.split()
                if len(tok) > 2 and tok not in STOPWORDS}

    set_a = tokenize(text_a)
    set_b = tokenize(text_b)
    if not set_a or not set_b:
        return 0.0
    return len(set_a & set_b) / len(set_a | set_b)
53
+
54
+
55
def metadata_disambiguate(
    candidates: list[dict],
    query_metadata: dict,
    domain_bonus: float = 0.3,
    max_len: int = 8192,
    max_norm: float = 10.0,
) -> Stage4Result | None:
    """
    Stage 4 disambiguation using .eng metadata fields.

    Args:
        candidates: list of dicts from eng_index values.
            Each must have: cache_id, task_description,
            context_len (optional), l2_norm (optional),
            metadata dict with domain (optional).
        query_metadata: dict with same structure as one candidate.
        domain_bonus: score added for exact domain match (default 0.3).
        max_len: normalisation constant for context_len diff.
        max_norm: normalisation constant for l2_norm diff.

    Returns:
        Stage4Result for the highest meta-scoring candidate, or None
        if candidates list is empty. Ties keep the earliest candidate.
    """
    if not candidates:
        return None

    best: Stage4Result | None = None

    q_domain = (query_metadata.get('metadata') or {}).get('domain', '')
    q_len = float(query_metadata.get('context_len') or 512)
    q_norm = float(query_metadata.get('l2_norm') or 1.0)
    q_desc = (query_metadata.get('task_description') or '')[:80]

    for cand in candidates:
        c_domain = (cand.get('metadata') or {}).get('domain', '')
        c_len = float(cand.get('context_len') or 512)
        c_norm = float(cand.get('l2_norm') or 1.0)
        c_desc = (cand.get('task_description') or '')[:80]
        c_id = cand.get('cache_id', '')

        # BUG FIX: coerce to a real bool. `q_domain and c_domain and ...`
        # short-circuits to '' (a str) when either domain is empty, and
        # that str previously leaked into the bool-annotated
        # domain_matched field of Stage4Result.
        domain_match = bool(q_domain and c_domain and q_domain == c_domain)
        domain_score = domain_bonus if domain_match else 0.0
        # Length/norm scores decay linearly with the absolute difference,
        # clamped to [0, 1]; max(..., 1) guards against zero constants.
        len_score = 1.0 - min(abs(q_len - c_len) / max(max_len, 1), 1.0)
        norm_score = 1.0 - min(abs(q_norm - c_norm) / max(max_norm, 1), 1.0)
        kw_score = _keyword_overlap(q_desc, c_desc)
        meta_score = domain_score + len_score + norm_score + kw_score

        result = Stage4Result(
            doc_id=c_id,
            meta_score=meta_score,
            confidence='low-metadata',
            metadata_used=True,
            domain_matched=domain_match,
            score_breakdown={
                'domain': round(domain_score, 3),
                'len': round(len_score, 3),
                'norm': round(norm_score, 3),
                'kw': round(kw_score, 3),
            },
        )
        # Strict '>' so the earliest candidate wins ties.
        if best is None or meta_score > best.meta_score:
            best = result

    return best
kvcos/engram/reader.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EIGENGRAM reader: .eng file -> IndexEntry + fingerprint vectors
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ from pathlib import Path
8
+
9
+ from .format import EigramDecoder
10
+ from kvcos.core.manifold_index import IndexEntry
11
+
12
+ _decoder = EigramDecoder()
13
+
14
+
15
def read_eigengram(path: str) -> dict:
    """Read a .eng file and return decoded fields."""
    eng_file = Path(path)
    if not eng_file.exists():
        raise FileNotFoundError(f"EIGENGRAM not found: {path}")
    return _decoder.decode(eng_file.read_bytes())
21
+
22
+
23
def load_eigengram_index(
    paths: list[str],
    fingerprint: str = "perdoc",
) -> tuple[list, list]:
    """Load multiple .eng files for ManifoldIndex.

    fingerprint: 'perdoc' (same-model) | 'fcdb' (cross-model)

    Returns (vecs, entries) ready for ManifoldIndex.add().
    """
    if fingerprint not in ("perdoc", "fcdb", "fourier"):
        raise ValueError(f"fingerprint must be 'perdoc', 'fcdb', or 'fourier', got '{fingerprint}'")

    vec_key = f"vec_{fingerprint}"
    vecs: list = []
    entries: list = []

    for eng_path in paths:
        record = read_eigengram(eng_path)
        vecs.append(record[vec_key])
        entry = IndexEntry(
            cache_id=record["cache_id"],
            task_description=record["task_description"],
            model_id=record["model_id"],
            created_at=record["created_at"],
            context_len=record["context_len"],
            l2_norm=record["l2_norm"],
        )
        entries.append(entry)

    return vecs, entries
kvcos/engram/retrieval.py ADDED
@@ -0,0 +1,513 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM constrained retrieval — apophatic negative constraint layer.
3
+
4
+ Implements constrained_retrieve() which penalizes candidates too
5
+ similar to known confusion partners, resolving dense-region failures.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ from dataclasses import dataclass, field
11
+
12
+ import torch
13
+ import torch.nn.functional as F
14
+
15
+
16
@dataclass
class CosineResult:
    """Single retrieval result."""

    doc_id: str               # identifier of the matched document
    score: float              # final ranking score (penalty-adjusted in constrained mode)
    cos_score: float          # raw cosine similarity, before any penalty
    margin: float = 0.0       # gap to the runner-up; set only on the top result
    constrained: bool = False # True when negative constraints were applied
25
+
26
+
27
@dataclass
class EngramQuery:
    """Query with optional negative constraints.

    like: Fingerprint to match (positive constraint).
    unlike: Fingerprints to avoid (negative constraints).
    min_margin: Minimum acceptable score gap.
    domain_hint: Optional domain tag (not consumed by the search
        functions in this module — presumably used elsewhere; verify).
    fingerprint: Which fingerprint field to use ('fourier', 'fcdb', 'perdoc').
    """

    like: torch.Tensor
    unlike: list[torch.Tensor] = field(default_factory=list)
    min_margin: float = 0.001
    domain_hint: str | None = None
    fingerprint: str = "fourier"
42
+
43
+
44
def cosine_search(
    query_fp: torch.Tensor,
    index: dict[str, torch.Tensor],
    top_k: int = 5,
) -> list[CosineResult]:
    """Standard unconstrained cosine similarity search.

    Normalizes the query and every indexed fingerprint, ranks by
    cosine similarity, and annotates the top result with its margin
    over the runner-up.
    """
    if not index:
        return []

    ids = list(index.keys())
    unit_query = F.normalize(query_fp.unsqueeze(0).float(), dim=-1)
    unit_docs = F.normalize(torch.stack([index[d] for d in ids]).float(), dim=-1)
    scores = (unit_query @ unit_docs.T).squeeze(0)

    k = min(top_k, len(ids))
    hits = []
    for idx in scores.topk(k).indices.tolist():
        value = float(scores[idx].item())
        hits.append(CosineResult(doc_id=ids[idx], score=value, cos_score=value))

    if len(hits) >= 2:
        hits[0].margin = hits[0].score - hits[1].score
    return hits
71
+
72
+
73
def constrained_retrieve(
    query: EngramQuery,
    index: dict[str, torch.Tensor],
    top_k: int = 5,
    neg_weight: float = 0.5,
    neg_threshold: float = 0.85,
) -> list[CosineResult]:
    """Retrieval with negative (apophatic) constraint layer.

    Penalizes candidates too similar to `unlike` fingerprints.
    In dense regions, this discriminates between docs that would
    otherwise have identical cosine scores.

    Algorithm:
        1. Compute cosine similarity to query (positive score)
        2. For each unlike fingerprint, compute sim to each candidate
        3. Subtract penalty: neg_weight * max(0, sim_to_unlike - threshold)
        4. Sort by adjusted score
    """
    if not index:
        return []

    ids = list(index.keys())
    unit_docs = F.normalize(torch.stack([index[d] for d in ids]).float(), dim=-1)
    unit_query = F.normalize(query.like.unsqueeze(0).float(), dim=-1)
    base_scores = (unit_query @ unit_docs.T).squeeze(0)

    # Start from the raw cosine scores and subtract one penalty term
    # per negative constraint (looping over an empty list is a no-op).
    penalized = base_scores.clone()
    for neg_fp in query.unlike:
        unit_neg = F.normalize(neg_fp.unsqueeze(0).float(), dim=-1)
        neg_sims = (unit_neg @ unit_docs.T).squeeze(0)
        penalized = penalized - neg_weight * torch.clamp(
            neg_sims - neg_threshold, min=0
        )

    k = min(top_k, len(ids))
    hits = []
    for idx in penalized.topk(k).indices.tolist():
        hits.append(
            CosineResult(
                doc_id=ids[idx],
                score=float(penalized[idx].item()),
                cos_score=float(base_scores[idx].item()),
                constrained=bool(query.unlike),
            )
        )

    if len(hits) >= 2:
        hits[0].margin = hits[0].score - hits[1].score
    return hits
123
+
124
+
125
+ # ── TWO-STAGE GEODESIC RETRIEVAL ──────────────────────────────────────
126
+
127
+ from enum import Enum
128
+
129
+
130
class RetrievalConfidence(Enum):
    """Confidence tier assigned by the geodesic retrieval pipeline."""
    HIGH = "high" # margin > 5x threshold, single pass sufficient
    MEDIUM = "medium" # margin > threshold, or resolved by stage-2
    LOW = "low" # margin < threshold after stage-2 — uncertain
134
+
135
+
136
@dataclass
class GeodesicResult:
    """Result from geodesic_retrieve() and its prior/stage-4 variants."""

    doc_id: str                 # winning document id ("" when no results)
    score: float                # score of the winning result
    margin: float               # gap to the runner-up at the deciding stage
    confidence: RetrievalConfidence
    stages_used: int = 1        # 1 = single pass, 2 = two-stage, 3 = constrained
                                # (0 = constraint preempted via priors)
    constraint_used: bool = False  # True when the negative-constraint layer ran
    stage4_used: bool = False      # True when metadata disambiguation ran
    stage4_doc_id: str = ""        # doc chosen by Stage 4 (if it ran)
148
+
149
+
150
def geodesic_retrieve(
    query_fp: torch.Tensor,
    hnsw_index,  # EngramIndex instance
    eng_index: dict,  # {doc_id: eng_data} for constraint layer
    margin_threshold: float = 0.005,
    correction_weight: float = 0.3,
    top_k: int = 5,
) -> GeodesicResult:
    """
    Two-stage geodesic retrieval with automatic confidence scoring.

    Stage 1: HNSW approximate nearest-neighbor search.
        If margin(top1, top2) >= margin_threshold -> HIGH or MEDIUM.
        Return immediately.

    Stage 2: Activated when margin < margin_threshold.
        Interpolate query fingerprint toward Stage-1 top-1 result.
        The interpolation weight (correction_weight=0.3) bends the
        geodesic toward the probable destination without assuming
        the Stage-1 answer is correct.
        If Stage-2 margin >= threshold -> MEDIUM confidence.
        If Stage-2 margin still < threshold -> LOW confidence.

    Stage 3: If confusion_flag is set on Stage-2 top result AND
        eng_index is provided -> activate negative constraint.
        Uses the confusion partner fingerprint as unlike constraint.

    Args:
        query_fp: [dim] query fingerprint (v2 recommended).
        hnsw_index: Built EngramIndex instance.
        eng_index: Dict {doc_id: eng_data} loaded from .eng files.
            Used for Stage-2 interpolation and Stage-3
            constraint layer. Pass empty dict {} to disable.
        margin_threshold: Minimum margin for MEDIUM confidence.
            Default 0.005 (below S3 mean margin of 0.009).
        correction_weight: Interpolation weight for Stage-2 trajectory
            correction. 0.3 = 30% pull toward top-1.
            Range: 0.1 (gentle) to 0.5 (aggressive).
        top_k: Candidates per search pass.

    Returns:
        GeodesicResult with doc_id, score, margin, confidence, stages_used.

    Usage:
        result = geodesic_retrieve(query_fp, idx, eng_index={})
        if result.confidence == RetrievalConfidence.LOW:
            # Flag for human review or return with uncertainty warning
            pass
    """
    # Stage 1: HNSW search
    s1_results = hnsw_index.search(query_fp, top_k=top_k)
    if len(s1_results) < 2:
        # Zero or one candidate: no margin can be computed.
        return GeodesicResult(
            doc_id=s1_results[0].doc_id if s1_results else "",
            score=s1_results[0].score if s1_results else 0.0,
            margin=0.0,
            confidence=RetrievalConfidence.LOW,
            stages_used=1,
        )

    s1_margin = s1_results[0].margin

    # High confidence: single pass sufficient
    if s1_margin >= margin_threshold * 5:
        return GeodesicResult(
            doc_id=s1_results[0].doc_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.HIGH,
            stages_used=1,
        )

    # Medium confidence: above threshold but not high
    if s1_margin >= margin_threshold:
        return GeodesicResult(
            doc_id=s1_results[0].doc_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.MEDIUM,
            stages_used=1,
        )

    # Stage 2: trajectory correction
    # Retrieve top-1 fingerprint from eng_index for interpolation
    top1_id = s1_results[0].doc_id
    top1_eng = eng_index.get(top1_id, {})
    top1_fp = top1_eng.get("vec_fourier_v2")
    if top1_fp is None:
        top1_fp = top1_eng.get("vec_fourier")

    if top1_fp is not None:
        # Bend geodesic toward Stage-1 top-1
        refined_fp = F.normalize(
            (1 - correction_weight) * query_fp.float()
            + correction_weight * top1_fp.float(),
            dim=-1,
        )
        s2_results = hnsw_index.search(refined_fp, top_k=top_k)
        if not s2_results:
            # BUG FIX: the original indexed s2_results[0] unguarded
            # below; an empty Stage-2 search would raise IndexError.
            # Fall back to the Stage-1 answer at LOW confidence.
            return GeodesicResult(
                doc_id=top1_id,
                score=s1_results[0].score,
                margin=s1_margin,
                confidence=RetrievalConfidence.LOW,
                stages_used=2,
            )
        s2_margin = s2_results[0].margin if len(s2_results) >= 2 else 0.0

        # Stage 3: check confusion_flag on Stage-2 top result
        s2_top_id = s2_results[0].doc_id
        s2_top_eng = eng_index.get(s2_top_id, {})

        if s2_top_eng.get("confusion_flag") and eng_index:
            # Activate negative constraint: find confusion partner fps.
            # v2 fingerprint preferred; fall back to v1.
            def _pick_fp(d: dict) -> torch.Tensor | None:
                v = d.get("vec_fourier_v2")
                return v if v is not None else d.get("vec_fourier")

            confusion_fps = [
                _pick_fp(d)
                for did, d in eng_index.items()
                if d.get("confusion_flag")
                and did != s2_top_id
                and _pick_fp(d) is not None
            ]
            if confusion_fps:
                # Build flat index for constrained_retrieve
                flat_index = {
                    did: _pick_fp(d)
                    for did, d in eng_index.items()
                    if _pick_fp(d) is not None
                }
                q_constrained = EngramQuery(
                    like=refined_fp,
                    unlike=confusion_fps[:3],  # top 3 confusion partners
                    min_margin=margin_threshold,
                )
                s3_results = constrained_retrieve(
                    q_constrained,
                    flat_index,
                )
                if s3_results:
                    s3_margin = s3_results[0].margin
                    s3_conf = (
                        RetrievalConfidence.MEDIUM
                        if s3_margin >= margin_threshold
                        else RetrievalConfidence.LOW
                    )
                    return GeodesicResult(
                        doc_id=s3_results[0].doc_id,
                        score=s3_results[0].score,
                        margin=s3_margin,
                        confidence=s3_conf,
                        stages_used=3,
                        constraint_used=True,
                    )

        if s2_margin >= margin_threshold:
            return GeodesicResult(
                doc_id=s2_top_id,
                score=s2_results[0].score,
                margin=s2_margin,
                confidence=RetrievalConfidence.MEDIUM,
                stages_used=2,
            )
        else:
            # Both stages low margin — return LOW confidence
            return GeodesicResult(
                doc_id=s2_top_id,
                score=s2_results[0].score,
                margin=s2_margin,
                confidence=RetrievalConfidence.LOW,
                stages_used=2,
            )
    else:
        # No vector for interpolation — return Stage-1 with LOW confidence
        return GeodesicResult(
            doc_id=top1_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.LOW,
            stages_used=1,
        )
325
+
326
+
327
+
328
def geodesic_retrieve_with_prior(
    query_fp: torch.Tensor,
    hnsw_index,
    eng_index: dict,
    index_c=None,
    query_doc_id: str | None = None,
    margin_threshold: float = 0.005,
    correction_weight: float = 0.3,
    top_k: int = 5,
) -> GeodesicResult:
    """
    Prior-aware geodesic retrieval. Uses IndexC history to pre-apply
    constraints on known chronic failures — skipping Stages 1 and 2.

    When index_c and query_doc_id are provided:
      - If doc is a chronic failure: apply Stage 3 (constraint) immediately.
        This avoids 2 wasted HNSW passes before getting to the constraint.
      - If doc has prior LOW history (not yet chronic): tighten (raise)
        the margin threshold in proportion to the LOW fraction.
      - If no prior: standard 3-stage geodesic_retrieve().

    Args:
        query_fp: [dim] query fingerprint.
        hnsw_index: Built EngramIndex instance.
        eng_index: {doc_id: eng_data} from .eng files.
        index_c: IndexC instance, or None to disable prior mode.
        query_doc_id: doc_id being queried (for prior lookup).
        margin_threshold: Base margin threshold. Raised if prior is LOW.
        correction_weight: Stage-2 interpolation weight.
        top_k: Candidates per HNSW pass.

    Returns:
        GeodesicResult. For chronic failures: stages_used=0 (preempted).
    """
    if index_c is None or query_doc_id is None:
        return geodesic_retrieve(
            query_fp, hnsw_index, eng_index,
            margin_threshold=margin_threshold,
            correction_weight=correction_weight,
            top_k=top_k,
        )

    prior = index_c.prior(query_doc_id)

    # Preemptive mode: known chronic failure
    if prior.is_chronic_failure and prior.n_total >= 2:
        # Partners come from both sides of each registered confusion pair.
        pairs = index_c.confusion_registry(min_confusions=1)
        partners = [
            p.doc_b for p in pairs if p.doc_a == query_doc_id
        ] + [
            p.doc_a for p in pairs if p.doc_b == query_doc_id
        ]

        # Collect up to 3 partner fingerprints (v2 preferred, v1 fallback).
        unlike_fps = []
        for partner_id in partners[:3]:
            partner_eng = eng_index.get(partner_id, {})
            fp = partner_eng.get("vec_fourier_v2")
            if fp is None:
                fp = partner_eng.get("vec_fourier")
            if fp is not None:
                unlike_fps.append(fp)

        if unlike_fps:
            # Flat {doc_id: fingerprint} view for constrained_retrieve().
            flat_index: dict[str, torch.Tensor] = {}
            for did, d in eng_index.items():
                v = d.get("vec_fourier_v2")
                if v is None:
                    v = d.get("vec_fourier")
                if v is not None:
                    flat_index[did] = v

            q_constrained = EngramQuery(
                like=query_fp,
                unlike=unlike_fps,
                min_margin=margin_threshold,
            )
            s3 = constrained_retrieve(q_constrained, flat_index)
            if s3:
                s3_margin = s3[0].margin
                return GeodesicResult(
                    doc_id=s3[0].doc_id,
                    score=s3[0].score,
                    margin=s3_margin,
                    confidence=(
                        RetrievalConfidence.MEDIUM
                        if s3_margin >= margin_threshold
                        else RetrievalConfidence.LOW
                    ),
                    stages_used=0,
                    constraint_used=True,
                )

    # Prior LOW but not yet chronic: tighten threshold
    # (a 100% LOW history doubles the required margin).
    if prior.n_low > 0 and prior.n_total > 0:
        low_frac = prior.n_low / prior.n_total
        margin_threshold = margin_threshold * (1 + low_frac)

    return geodesic_retrieve(
        query_fp, hnsw_index, eng_index,
        margin_threshold=margin_threshold,
        correction_weight=correction_weight,
        top_k=top_k,
    )
430
+
431
+
432
def geodesic_retrieve_stage4(
    query_fp: torch.Tensor,
    hnsw_index,
    eng_index: dict,
    query_metadata: dict | None = None,
    index_c=None,
    query_doc_id: str | None = None,
    margin_threshold: float = 0.005,
    correction_weight: float = 0.3,
    top_k: int = 5,
) -> GeodesicResult:
    """
    Full pipeline: prior-aware geodesic retrieval with Stage 4 fallback.

    Extends geodesic_retrieve_with_prior() with a Stage 4 metadata
    disambiguation layer. When confidence is LOW and query_metadata
    is provided, calls metadata_disambiguate() on top candidates
    from the last HNSW pass before giving up.

    Confidence tier:
        HIGH -> fingerprint, 0% error rate target
        MEDIUM -> fingerprint, 0% error rate target
        LOW -> fingerprint failed, Stage 4 unavailable
        low-metadata -> fingerprint failed, Stage 4 used secondary signal

    Args:
        query_fp: [dim] query fingerprint (vec_fourier_v2).
        hnsw_index: Built EngramIndex.
        eng_index: {doc_id: eng_data} from .eng files.
        query_metadata: dict with task_description, context_len, l2_norm,
            metadata.domain -- from the query doc's .eng data.
            If None, Stage 4 is disabled.
        index_c: IndexC instance for prior lookup.
        query_doc_id: doc_id of the query source (for priors).
        margin_threshold, correction_weight, top_k: as in base function.
    """
    # Local import avoids a circular dependency at module load time.
    from kvcos.engram.metadata_disambiguate import metadata_disambiguate

    base = geodesic_retrieve_with_prior(
        query_fp, hnsw_index, eng_index,
        index_c=index_c,
        query_doc_id=query_doc_id,
        margin_threshold=margin_threshold,
        correction_weight=correction_weight,
        top_k=top_k,
    )

    # Only activate Stage 4 on LOW confidence with metadata available
    if base.confidence != RetrievalConfidence.LOW:
        return base
    if query_metadata is None:
        return base

    # Get top candidates from HNSW for metadata scoring.
    # BUG FIX: honor the caller's top_k — this search previously
    # hard-coded top_k=5, silently ignoring the parameter.
    candidates_hnsw = hnsw_index.search(query_fp, top_k=top_k)
    candidates_meta = [
        eng_index[r.doc_id]
        for r in candidates_hnsw
        if r.doc_id in eng_index
    ]

    if not candidates_meta:
        return base

    s4 = metadata_disambiguate(candidates_meta, query_metadata)
    if s4 is None:
        return base

    # Note: confidence stays LOW -- Stage 4 is a tiebreaker, not a promotion.
    # Callers should check stage4_used=True to distinguish "failed silently"
    # from "failed with secondary signal". The confidence tier string
    # 'low-metadata' is available via Stage4Result for logging.
    return GeodesicResult(
        doc_id=s4.doc_id,
        score=base.score,
        margin=base.margin,
        confidence=RetrievalConfidence.LOW,  # still LOW
        stages_used=base.stages_used,
        constraint_used=base.constraint_used,
        stage4_used=True,
        stage4_doc_id=s4.doc_id,
    )
kvcos/engram/session_propagator.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ kvcos/engram/session_propagator.py — Session start/end for ENGRAM.
3
+
4
+ Bridges geodesic_retrieve() results and IndexC persistence.
5
+ Call session_start() at the top of each session to load priors.
6
+ Call session_end() at the bottom to persist results.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ import time
12
+ from dataclasses import dataclass
13
+ from pathlib import Path
14
+
15
+ from kvcos.engram.index_c import IndexC, DocPrior
16
+
17
+
18
@dataclass
class SessionSummary:
    """Aggregate outcome of one retrieval session, returned by session_end()."""

    session_id: str                              # id passed to SessionPropagator.__init__
    n_total: int                                 # number of buffered retrieval records
    n_correct: int                               # records flagged correct=True
    n_high: int                                  # records with confidence == "high"
    n_medium: int                                # records with confidence == "medium"
    n_low: int                                   # records with confidence == "low"
    n_preempted: int                             # records with stages_used == 0
    new_confusion_pairs: list[tuple[str, str]]   # confusion pairs first seen this session
    recall: float                                # n_correct / n_total (0.0 when n_total == 0)
    duration_s: float                            # wall-clock seconds since session_start()
30
+
31
+
32
class SessionPropagator:
    """
    Session-scoped buffer for IndexC writes.

    Retrieval outcomes are collected in memory via record() during the
    session and flushed to IndexC in a single pass by session_end().
    """

    def __init__(self, db_path: str, session_id: str):
        self._db_path = str(db_path)
        self._session_id = session_id
        self._ic: IndexC | None = None
        self._records: list[dict] = []
        self._start_ts: float = 0.0
        self._started: bool = False

    def session_start(self) -> dict[str, DocPrior]:
        """
        Open IndexC and return {doc_id: DocPrior} for every known doc.
        Call once at the top of each session.
        """
        self._ic = IndexC.open(self._db_path)
        self._start_ts = time.time()
        self._started = True

        # One prior per doc currently present in the reliability map.
        priors: dict[str, DocPrior] = {}
        for doc_id in self._ic.reliability_map():
            priors[doc_id] = self._ic.prior(doc_id)
        return priors

    @property
    def index_c(self) -> IndexC:
        """The live IndexC handle; only valid between start and end."""
        if self._ic is None:
            raise RuntimeError("Call session_start() first.")
        return self._ic

    def record(
        self,
        query_doc_id: str,
        result_doc_id: str,
        confidence: str,
        margin: float,
        stages_used: int = 1,
        constraint_used: bool = False,
        correct: bool = True,
    ) -> None:
        """Buffer one retrieval result for this session."""
        entry = {
            "query_doc_id": query_doc_id,
            "result_doc_id": result_doc_id,
            "confidence": confidence,
            "margin": float(margin),
            "stages_used": int(stages_used),
            "constraint_used": bool(constraint_used),
            "correct": bool(correct),
            "ts": time.time(),
        }
        self._records.append(entry)

    def session_end(self) -> SessionSummary:
        """Flush all buffered records to IndexC and return a summary."""
        if not self._started or self._ic is None:
            raise RuntimeError("Call session_start() before session_end().")

        def _pairs() -> set[tuple[str, str]]:
            # Snapshot of the confusion registry as (doc_a, doc_b) tuples.
            return {
                (p.doc_a, p.doc_b)
                for p in self._ic.confusion_registry(min_confusions=1)
            }

        pairs_before = _pairs()

        for entry in self._records:
            self._ic.record(
                session_id=self._session_id,
                query_doc_id=entry["query_doc_id"],
                result_doc_id=entry["result_doc_id"],
                confidence=entry["confidence"],
                margin=entry["margin"],
                stages_used=entry["stages_used"],
                constraint_used=entry["constraint_used"],
                correct=entry["correct"],
                ts=entry["ts"],
            )

        # Pairs that appeared only after this session's writes.
        fresh_pairs = list(_pairs() - pairs_before)

        total = len(self._records)
        correct_n = sum(1 for e in self._records if e["correct"])
        tier_counts = {"high": 0, "medium": 0, "low": 0}
        preempted = 0
        for e in self._records:
            tier_counts[e["confidence"]] = tier_counts.get(e["confidence"], 0) + 1
            if e["stages_used"] == 0:
                preempted += 1

        summary = SessionSummary(
            session_id=self._session_id,
            n_total=total,
            n_correct=correct_n,
            n_high=tier_counts["high"],
            n_medium=tier_counts["medium"],
            n_low=tier_counts["low"],
            n_preempted=preempted,
            new_confusion_pairs=fresh_pairs,
            recall=correct_n / total if total > 0 else 0.0,
            duration_s=time.time() - self._start_ts,
        )

        # Close and reset so the propagator can host a fresh session.
        self._ic.close()
        self._ic = None
        self._started = False
        self._records = []

        return summary

    def summary_str(self, s: SessionSummary) -> str:
        """One-line human-readable rendering of a SessionSummary."""
        parts = [
            f"{s.n_total} retrievals",
            f"recall={s.recall:.1%}",
            f"H={s.n_high}/M={s.n_medium}/L={s.n_low}",
            f"preempted={s.n_preempted}",
            f"new_pairs={len(s.new_confusion_pairs)}",
            f"{s.duration_s:.1f}s",
        ]
        return f"Session {s.session_id}: " + " | ".join(parts)
kvcos/engram/writer.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ EIGENGRAM writer: text + model -> .eng file
3
+ """
4
+
5
+ from __future__ import annotations
6
+
7
+ import gc
8
+ import hashlib
9
+ import os
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from llama_cpp import Llama
14
+
15
+ from kvcos.core.blob_parser import parse_state_blob
16
+ from .format import EigramEncoder
17
+
18
+ _encoder = EigramEncoder()
19
+
20
+ _DEFAULT_LR = (8, 24)
21
+ _DEFAULT_GATE = 6
22
+ _DEFAULT_RANK = 116
23
+
24
+
25
+ def _get_model_id(model_path: str) -> str:
26
+ name = os.path.basename(model_path)
27
+ if "3B" in name or "3b" in name:
28
+ return "Llama-3.2-3B"
29
+ if "8B" in name or "8b" in name:
30
+ return "Llama-3.1-8B"
31
+ return name[:15]
32
+
33
+
34
+ def _corpus_hash(basis_path: str) -> str:
35
+ raw = Path(basis_path).read_bytes()
36
+ return hashlib.sha256(raw).hexdigest()[:32]
37
+
38
+
39
def write_eigengram(
    model_path: str,
    text: str,
    output_path: str,
    cache_id: str = "",
    task_description: str = "",
    layer_range: tuple[int, int] = _DEFAULT_LR,
    gate: int = _DEFAULT_GATE,
    rank_perdoc: int = _DEFAULT_RANK,
    basis_path: str = "results/corpus_basis_fcdb_v2.pt",
) -> dict:
    """Encode a document as an EIGENGRAM (.eng) file.

    Pipeline: evaluate `text` through the GGUF model, snapshot the
    llama.cpp KV-cache state, pool key vectors from `layer_range`, derive
    two unit fingerprints (per-doc SVD projection and corpus FCDB
    projection), and serialize them via EigramEncoder.

    Args:
        model_path: path to the GGUF model file.
        text: document text to fingerprint (stripped before evaluation).
        output_path: destination .eng file; parent dirs are created.
        cache_id: optional identifier stored in the certificate.
        task_description: optional description; defaults to text[:100].
        layer_range: (l0, l1) half-open slice of layers whose keys are pooled.
        gate: number of leading singular directions skipped for the
            per-doc fingerprint.
        rank_perdoc: number of singular directions kept after the gate.
        basis_path: torch checkpoint holding keys "basis", "joint_center",
            and "n_docs" (the shared FCDB corpus basis).

    Returns:
        dict summarizing the written file: output_path, model_id,
        corpus_hash, basis_rank, n_corpus, file_size_bytes, scs, l2_norm,
        layer_range.
    """
    # NOTE(review): weights_only=False unpickles arbitrary objects — only
    # load basis checkpoints from trusted sources.
    saved = torch.load(basis_path, weights_only=False)
    P_fcdb = saved["basis"]          # projection matrix; rows = basis_rank
    center = saved["joint_center"]   # corpus center vector
    n_corpus = int(saved["n_docs"])
    basis_rank = P_fcdb.shape[0]

    # Read attention geometry from GGUF metadata; string defaults match an
    # 8B-class Llama if the keys are absent.
    llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=-1, verbose=False)
    meta = llm.metadata
    n_kv = int(meta.get("llama.attention.head_count_kv", "8"))
    hd = int(meta.get("llama.embedding_length", "4096")) // int(
        meta.get("llama.attention.head_count", "32")
    )

    # Force a forward pass (a single generated token suffices), snapshot
    # the KV state, then free the model before the heavy tensor work.
    llm.reset()
    llm(text.strip(), max_tokens=1, temperature=0.0)
    state_bytes = bytes(llm.save_state().llama_state)
    del llm
    gc.collect()

    p = parse_state_blob(state_bytes, n_kv_heads=n_kv, head_dim=hd)
    l0, l1 = layer_range
    # Pool the selected layers' keys into rows of width head_dim.
    k = p.keys[l0:l1].float().reshape(-1, hd)
    mean_v = k.mean(0)
    l2_norm = float(mean_v.norm().item())

    # Per-doc SVD fingerprint: cap SVD input at 8192 rows via a seeded
    # permutation (deterministic across runs), then project onto singular
    # directions [gate, gate + rank_perdoc) and average.
    if k.shape[0] > 8192:
        gen = torch.Generator()
        gen.manual_seed(42)
        idx = torch.randperm(k.shape[0], generator=gen)[:8192]
        svd_input = k[idx]
    else:
        svd_input = k
    _, _, Vh = torch.linalg.svd(svd_input, full_matrices=False)
    proj = (svd_input @ Vh[gate : gate + rank_perdoc].T).mean(0)
    vec_perdoc = proj / (proj.norm() + 1e-8)

    # FCDB fingerprint: unit direction of the mean key relative to the
    # corpus center, projected onto the shared basis and re-normalized.
    delta = mean_v - center
    delta = delta / (delta.norm() + 1e-8)
    vec_fcdb = delta @ P_fcdb.T
    vec_fcdb = vec_fcdb / (vec_fcdb.norm() + 1e-8)

    # SCS: energy of delta's image under P_fcdb^T P_fcdb relative to
    # delta's energy (the subspace-captured fraction, assuming the basis
    # rows are orthonormal — TODO confirm against basis construction).
    scs = float(
        ((delta @ P_fcdb.T @ P_fcdb) ** 2).sum().item()
        / ((delta**2).sum().item() + 1e-12)
    )

    corpus_h = _corpus_hash(basis_path)
    model_id = _get_model_id(model_path)

    cert = _encoder.encode(
        vec_perdoc=vec_perdoc,
        vec_fcdb=vec_fcdb,
        joint_center=center,
        corpus_hash=corpus_h,
        model_id=model_id,
        basis_rank=basis_rank,
        n_corpus=n_corpus,
        layer_range=layer_range,
        context_len=int(k.shape[0]),
        l2_norm=l2_norm,
        scs=scs,
        margin_proof=0.0,  # not computed at write time
        task_description=task_description or text[:100],
        cache_id=cache_id or "",
    )

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(cert)

    return {
        "output_path": output_path,
        "model_id": model_id,
        "corpus_hash": corpus_h,
        "basis_rank": basis_rank,
        "n_corpus": n_corpus,
        "file_size_bytes": len(cert),
        "scs": round(scs, 4),
        "l2_norm": round(l2_norm, 4),
        "layer_range": layer_range,
    }
kvcos/mar/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ENGRAM Protocol — MAR (Manifold Attention Retrieval)
2
+
3
+ Backward compatibility re-exports. All classes have moved to kvcos.core.
4
+ Import from kvcos.core directly for new code.
5
+ """
6
+
7
+ from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
8
+ from kvcos.core.retriever import EGRRetriever, RetrievalResponse, RetrievalResult
9
+ from kvcos.core.state_extractor import ExtractionResult, MARStateExtractor, SVDProjection
10
+
11
+ __all__ = [
12
+ "IndexEntry",
13
+ "ManifoldIndex",
14
+ "EGRRetriever",
15
+ "RetrievalResponse",
16
+ "RetrievalResult",
17
+ "ExtractionResult",
18
+ "MARStateExtractor",
19
+ "SVDProjection",
20
+ ]
kvcos/storage/__init__.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ """ENGRAM Protocol — Storage backends for .eng files."""
2
+
3
+ from kvcos.storage.backends import StorageBackend
4
+ from kvcos.storage.local import LocalStorageBackend
5
+
6
+ __all__ = [
7
+ "StorageBackend",
8
+ "LocalStorageBackend",
9
+ ]
kvcos/storage/backends.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Abstract Storage Backend
3
+
4
+
5
+ All storage backends (local, redis, S3) implement this interface.
6
+ Phase 1 uses local disk only.
7
+ """
8
+
9
+ from __future__ import annotations
10
+
11
+ from abc import ABC, abstractmethod
12
+ from pathlib import Path
13
+
14
+ from kvcos.core.types import CacheStats, EngramMetadata
15
+
16
+
17
class StorageBackend(ABC):
    """Abstract interface for engram storage backends.

    Implemented by LocalStorageBackend in Phase 1; other backends (redis,
    S3) plug in behind the same interface later. All operations are
    synchronous in Phase 1.
    """

    @abstractmethod
    def store(self, cache_id: str, data: bytes, metadata: EngramMetadata) -> str:
        """Store a .eng file from raw bytes. Returns storage path/key."""
        ...

    @abstractmethod
    def store_file(self, cache_id: str, source_path: Path, metadata: EngramMetadata) -> str:
        """Store a .eng file from a local path (zero-copy when possible)."""
        ...

    @abstractmethod
    def get(self, cache_id: str) -> bytes | None:
        """Retrieve a .eng file as bytes. None if not found."""
        ...

    @abstractmethod
    def get_path(self, cache_id: str) -> Path | None:
        """Get local filesystem path for a cache entry. None if not found."""
        ...

    @abstractmethod
    def get_metadata(self, cache_id: str) -> EngramMetadata | None:
        """Read only metadata (header-only, no tensor data loaded)."""
        ...

    @abstractmethod
    def delete(self, cache_id: str) -> bool:
        """Delete a cache entry. Returns True if something was deleted."""
        ...

    @abstractmethod
    def list_entries(
        self,
        agent_id: str | None = None,
        model_family: str | None = None,
        limit: int = 100,
    ) -> list[EngramMetadata]:
        """List up to `limit` cache entries, optionally filtered by
        agent_id and/or model_family."""
        ...

    @abstractmethod
    def exists(self, cache_id: str) -> bool:
        """Check if a cache entry exists."""
        ...

    @abstractmethod
    def stats(self) -> CacheStats:
        """Get aggregate statistics for the store."""
        ...
kvcos/storage/local.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ENGRAM Protocol — Local Disk Storage Backend
3
+
4
+
5
+ Directory layout:
6
+ {data_dir}/{model_family}/{agent_id}/{date}/{cache_id}.eng
7
+
8
+ Phase 1 production backend. Zero infrastructure dependencies.
9
+ Uses safetensors header-only read for metadata operations.
10
+ D7: One safetensors file per 256-token block.
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ import shutil
17
+ from collections import defaultdict
18
+ from datetime import datetime, timezone
19
+ from pathlib import Path
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+ from kvcos.core.serializer import EngramSerializer
24
+ from kvcos.core.types import ENG_FILE_EXTENSION, CacheStats, EngramMetadata
25
+ from kvcos.storage.backends import StorageBackend
26
+
27
+
28
class LocalStorageBackend(StorageBackend):
    """Local filesystem storage for .eng files.

    Layout: {data_dir}/{model_family}/{agent_id}/{YYYY-MM-DD}/{cache_id}.eng
    An in-memory {cache_id: Path} index is rebuilt at construction time.
    """

    def __init__(self, data_dir: Path):
        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self._serializer = EngramSerializer()
        self._index: dict[str, Path] = {}  # cache_id → file path
        self._rebuild_index()

    def _rebuild_index(self) -> None:
        """Scan data directory and rebuild the in-memory path index."""
        self._index.clear()
        for eng_file in self.data_dir.rglob(f"*{ENG_FILE_EXTENSION}"):
            cache_id = eng_file.stem
            try:
                meta = self._serializer.read_metadata_only(eng_file)
                if "cache_id" in meta:
                    cache_id = meta["cache_id"]
            except Exception as e:
                # Fall back to the filename stem when the header is unreadable.
                logger.debug("Skipping metadata for %s: %s", eng_file.name, e)
            self._index[cache_id] = eng_file

    def _resolve_path(self, metadata: EngramMetadata) -> Path:
        """Determine storage path from metadata; creates parent dirs."""
        model_family = metadata.get("model_family", "unknown")
        agent_id = metadata.get("agent_id", "default")
        date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        cache_id = metadata.get("cache_id", "unknown")
        path = self.data_dir / model_family / agent_id / date_str / f"{cache_id}{ENG_FILE_EXTENSION}"
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

    def store(self, cache_id: str, data: bytes, metadata: EngramMetadata) -> str:
        """Store raw .eng bytes atomically. Returns the final file path."""
        metadata_copy = dict(metadata)
        metadata_copy["cache_id"] = cache_id
        path = self._resolve_path(metadata_copy)  # type: ignore[arg-type]

        # Write-then-replace for atomicity. Path.replace (unlike rename)
        # also overwrites an existing destination on Windows.
        tmp_path = path.with_name(path.name + ".tmp")
        try:
            tmp_path.write_bytes(data)
            tmp_path.replace(path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        self._index[cache_id] = path
        return str(path)

    def store_file(self, cache_id: str, source_path: Path, metadata: EngramMetadata) -> str:
        """Copy a local .eng file into the store atomically."""
        metadata_copy = dict(metadata)
        metadata_copy["cache_id"] = cache_id
        dest_path = self._resolve_path(metadata_copy)  # type: ignore[arg-type]

        # Source already in place: just (re)register it.
        if source_path == dest_path:
            self._index[cache_id] = dest_path
            return str(dest_path)

        # Copy-then-replace; replace() overwrites cross-platform.
        tmp_path = dest_path.with_name(dest_path.name + ".tmp")
        try:
            shutil.copy2(str(source_path), str(tmp_path))
            tmp_path.replace(dest_path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        self._index[cache_id] = dest_path
        return str(dest_path)

    def get(self, cache_id: str) -> bytes | None:
        """Full file contents, or None if unknown/deleted on disk."""
        path = self._index.get(cache_id)
        if path is None or not path.exists():
            return None
        return path.read_bytes()

    def get_path(self, cache_id: str) -> Path | None:
        """Filesystem path for an entry, or None if unknown/deleted."""
        path = self._index.get(cache_id)
        if path is None or not path.exists():
            return None
        return path

    def get_metadata(self, cache_id: str) -> EngramMetadata | None:
        """Header-only metadata read; None on missing or unreadable file."""
        path = self._index.get(cache_id)
        if path is None or not path.exists():
            return None
        try:
            return self._serializer.read_metadata_only(path)
        except Exception as e:
            logger.warning("Failed to read metadata for %s: %s", cache_id, e)
            return None

    def delete(self, cache_id: str) -> bool:
        """Delete an entry and prune now-empty parent directories."""
        path = self._index.pop(cache_id, None)
        if path is None or not path.exists():
            return False

        path.unlink()

        # Walk upward removing empty dirs, stopping at data_dir.
        parent = path.parent
        try:
            while parent != self.data_dir:
                if not any(parent.iterdir()):
                    parent.rmdir()
                    parent = parent.parent
                else:
                    break
        except OSError:
            # Best-effort cleanup; a non-removable dir is not an error.
            pass

        return True

    def list_entries(
        self,
        agent_id: str | None = None,
        model_family: str | None = None,
        limit: int = 100,
    ) -> list[EngramMetadata]:
        """List entries newest-first, optionally filtered.

        Fix: the limit was previously applied before sorting, so callers
        got an arbitrary subset rather than the most recent entries. All
        matches are now collected, sorted by created_at descending, and
        only then truncated to `limit`.
        """
        results: list[EngramMetadata] = []

        for cache_id, path in self._index.items():
            if not path.exists():
                continue
            try:
                meta = self._serializer.read_metadata_only(path)
            except Exception as e:
                logger.debug("Skipping %s in list_entries: %s", cache_id, e)
                continue
            if agent_id and meta.get("agent_id") != agent_id:
                continue
            if model_family and meta.get("model_family") != model_family:
                continue
            results.append(meta)

        results.sort(key=lambda m: m.get("created_at", ""), reverse=True)
        return results[:limit]

    def exists(self, cache_id: str) -> bool:
        """True if the entry is indexed and still present on disk."""
        path = self._index.get(cache_id)
        return path is not None and path.exists()

    def stats(self) -> CacheStats:
        """Aggregate entry count, total bytes, and per-model breakdown."""
        total_entries = 0
        total_size = 0
        model_counts: dict[str, int] = defaultdict(int)

        for cache_id, path in self._index.items():
            if not path.exists():
                continue
            total_entries += 1
            total_size += path.stat().st_size
            try:
                meta = self._serializer.read_metadata_only(path)
                model_counts[meta.get("model_family", "unknown")] += 1
            except Exception as e:
                logger.debug("Metadata read failed for %s: %s", cache_id, e)
                model_counts["unknown"] += 1

        return CacheStats(
            total_entries=total_entries,
            total_size_bytes=total_size,
            avg_compression_ratio=0.0,
            model_breakdown=dict(model_counts),
        )

    def vacuum(self) -> int:
        """Remove stale index entries for deleted files. Returns count removed."""
        stale = [cid for cid, path in self._index.items() if not path.exists()]
        for cid in stale:
            del self._index[cid]
        return len(stale)
+ return len(stale)