feat: upload core kvcos library
Browse files
- kvcos/.DS_Store +0 -0
- kvcos/__init__.py +5 -0
- kvcos/api/__init__.py +5 -0
- kvcos/api/routes.py +211 -0
- kvcos/api/schemas.py +88 -0
- kvcos/api/server.py +126 -0
- kvcos/client/__init__.py +5 -0
- kvcos/client/python_client.py +158 -0
- kvcos/core/__init__.py +40 -0
- kvcos/core/blob_parser.py +482 -0
- kvcos/core/block_pool.py +167 -0
- kvcos/core/cache_spec.py +215 -0
- kvcos/core/compression.py +395 -0
- kvcos/core/config.py +82 -0
- kvcos/core/fingerprint.py +167 -0
- kvcos/core/manifold_index.py +294 -0
- kvcos/core/retriever.py +263 -0
- kvcos/core/serializer.py +274 -0
- kvcos/core/state_extractor.py +489 -0
- kvcos/core/types.py +201 -0
- kvcos/engram/__init__.py +4 -0
- kvcos/engram/__main__.py +265 -0
- kvcos/engram/chunker.py +327 -0
- kvcos/engram/embedder.py +221 -0
- kvcos/engram/format.py +251 -0
- kvcos/engram/hnsw_index.py +205 -0
- kvcos/engram/index_c.py +426 -0
- kvcos/engram/knowledge_index.py +289 -0
- kvcos/engram/manifest.py +232 -0
- kvcos/engram/metadata_disambiguate.py +119 -0
- kvcos/engram/reader.py +54 -0
- kvcos/engram/retrieval.py +513 -0
- kvcos/engram/session_propagator.py +159 -0
- kvcos/engram/writer.py +134 -0
- kvcos/mar/__init__.py +20 -0
- kvcos/storage/__init__.py +9 -0
- kvcos/storage/backends.py +71 -0
- kvcos/storage/local.py +202 -0
kvcos/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
kvcos/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ENGRAM Protocol — KV cache fingerprinting for persistent semantic retrieval."""

from kvcos.core.types import ENGRAM_VERSION

# The package version tracks the on-disk engram format version.
__version__ = ENGRAM_VERSION
kvcos/api/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ENGRAM Protocol — REST API package."""

from kvcos.api.server import create_app

__all__ = ["create_app"]
kvcos/api/routes.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — API Routes
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
FastAPI route handlers for the ENGRAM REST API.
|
| 6 |
+
All endpoints under /v1/ prefix.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from fastapi import APIRouter, HTTPException, UploadFile, File
|
| 12 |
+
|
| 13 |
+
from kvcos.api.schemas import (
|
| 14 |
+
DeleteResponse,
|
| 15 |
+
HealthResponse,
|
| 16 |
+
SearchRequest,
|
| 17 |
+
SearchResponse,
|
| 18 |
+
SearchResultItem,
|
| 19 |
+
StatsResponse,
|
| 20 |
+
StoreResponse,
|
| 21 |
+
)
|
| 22 |
+
from kvcos.core.types import ENGRAM_VERSION
|
| 23 |
+
|
| 24 |
+
router = APIRouter(prefix="/v1")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ── Dependency stubs ──────────────────────────────────────────────────────────
|
| 28 |
+
# These are replaced by real instances in server.py lifespan.
|
| 29 |
+
# Using module-level state that the server sets during startup.
|
| 30 |
+
|
| 31 |
+
_retriever = None
|
| 32 |
+
_storage = None
|
| 33 |
+
_index = None
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def _get_retriever():
|
| 37 |
+
if _retriever is None:
|
| 38 |
+
raise HTTPException(503, "ENGRAM not initialized. Server starting up.")
|
| 39 |
+
return _retriever
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def _get_storage():
|
| 43 |
+
if _storage is None:
|
| 44 |
+
raise HTTPException(503, "ENGRAM not initialized. Server starting up.")
|
| 45 |
+
return _storage
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _get_index():
|
| 49 |
+
if _index is None:
|
| 50 |
+
raise HTTPException(503, "ENGRAM not initialized. Server starting up.")
|
| 51 |
+
return _index
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
# ── Health ────────────────────────────────────────────────────────────────────
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@router.get("/health", response_model=HealthResponse)
|
| 58 |
+
async def health():
|
| 59 |
+
"""Health check endpoint."""
|
| 60 |
+
index = _get_index()
|
| 61 |
+
storage = _get_storage()
|
| 62 |
+
return HealthResponse(
|
| 63 |
+
status="ok",
|
| 64 |
+
version=ENGRAM_VERSION,
|
| 65 |
+
index_entries=index.n_entries,
|
| 66 |
+
storage_backend="local",
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
# ── Stats (must come before /cache/{cache_id} to avoid route shadowing) ──────
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
@router.get("/cache/stats", response_model=StatsResponse)
|
| 74 |
+
async def cache_stats():
|
| 75 |
+
"""Get aggregate statistics for the engram store."""
|
| 76 |
+
storage = _get_storage()
|
| 77 |
+
stats = storage.stats()
|
| 78 |
+
return StatsResponse(
|
| 79 |
+
total_entries=stats["total_entries"],
|
| 80 |
+
total_size_bytes=stats["total_size_bytes"],
|
| 81 |
+
total_size_mb=round(stats["total_size_bytes"] / (1024 * 1024), 2),
|
| 82 |
+
avg_compression_ratio=stats["avg_compression_ratio"],
|
| 83 |
+
model_breakdown=stats["model_breakdown"],
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# ── Store ─────────────────────────────────────────────────────────────────────
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@router.post("/cache", response_model=StoreResponse)
|
| 91 |
+
async def store_cache(
|
| 92 |
+
agent_id: str,
|
| 93 |
+
task_description: str,
|
| 94 |
+
model_id: str,
|
| 95 |
+
file: UploadFile = File(...),
|
| 96 |
+
compression: str = "q8_0",
|
| 97 |
+
):
|
| 98 |
+
"""Store a .eng file in the engram store.
|
| 99 |
+
|
| 100 |
+
Accepts a pre-serialized .eng file upload.
|
| 101 |
+
The file is stored and its metadata indexed for EGR retrieval.
|
| 102 |
+
"""
|
| 103 |
+
storage = _get_storage()
|
| 104 |
+
|
| 105 |
+
data = await file.read()
|
| 106 |
+
if len(data) == 0:
|
| 107 |
+
raise HTTPException(400, "Empty file upload")
|
| 108 |
+
|
| 109 |
+
import uuid
|
| 110 |
+
cache_id = str(uuid.uuid4())
|
| 111 |
+
|
| 112 |
+
from kvcos.core.types import EngramMetadata
|
| 113 |
+
from datetime import datetime, timezone
|
| 114 |
+
|
| 115 |
+
metadata: EngramMetadata = {
|
| 116 |
+
"engram_version": ENGRAM_VERSION,
|
| 117 |
+
"cache_id": cache_id,
|
| 118 |
+
"compression": compression,
|
| 119 |
+
"model_id": model_id,
|
| 120 |
+
"model_family": "",
|
| 121 |
+
"n_layers": "0",
|
| 122 |
+
"n_heads": "0",
|
| 123 |
+
"n_kv_heads": "0",
|
| 124 |
+
"head_dim": "0",
|
| 125 |
+
"context_len": "0",
|
| 126 |
+
"agent_id": agent_id,
|
| 127 |
+
"task_description": task_description,
|
| 128 |
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
| 129 |
+
}
|
| 130 |
+
|
| 131 |
+
path = storage.store(cache_id, data, metadata)
|
| 132 |
+
|
| 133 |
+
return StoreResponse(
|
| 134 |
+
cache_id=cache_id,
|
| 135 |
+
size_bytes=len(data),
|
| 136 |
+
compression_ratio=1.0,
|
| 137 |
+
path=path,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
# ── Retrieve by ID ────────────────────────────────────────────────────────────
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
@router.get("/cache/{cache_id}")
|
| 145 |
+
async def get_cache(cache_id: str):
|
| 146 |
+
"""Retrieve a .eng file by cache ID.
|
| 147 |
+
|
| 148 |
+
Returns the raw .eng file bytes (application/octet-stream).
|
| 149 |
+
"""
|
| 150 |
+
storage = _get_storage()
|
| 151 |
+
|
| 152 |
+
data = storage.get(cache_id)
|
| 153 |
+
if data is None:
|
| 154 |
+
raise HTTPException(404, f"Cache entry not found: {cache_id}")
|
| 155 |
+
|
| 156 |
+
from fastapi.responses import Response
|
| 157 |
+
return Response(
|
| 158 |
+
content=data,
|
| 159 |
+
media_type="application/octet-stream",
|
| 160 |
+
headers={"Content-Disposition": f'attachment; filename="{cache_id}.eng"'},
|
| 161 |
+
)
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
# ── Search ────────────────────────────────────────────────────────────────────
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
@router.post("/cache/search", response_model=SearchResponse)
|
| 168 |
+
async def search_cache(req: SearchRequest):
|
| 169 |
+
"""Search for similar engram states via EGR manifold search.
|
| 170 |
+
|
| 171 |
+
Uses inner product similarity (MIPS) in the model's pre-RoPE
|
| 172 |
+
key manifold. D2: K→K retrieval only.
|
| 173 |
+
"""
|
| 174 |
+
index = _get_index()
|
| 175 |
+
|
| 176 |
+
# For text-only search without a KV query vector, we need the
|
| 177 |
+
# retriever to extract a state vector first. This endpoint
|
| 178 |
+
# currently returns index entries matching by metadata filter.
|
| 179 |
+
# Full EGR vector search requires a query KV cache (via /egr/retrieve).
|
| 180 |
+
|
| 181 |
+
# Metadata-based listing with optional filters
|
| 182 |
+
storage = _get_storage()
|
| 183 |
+
entries = storage.list_entries(model_family=None, limit=req.top_k)
|
| 184 |
+
|
| 185 |
+
results = [
|
| 186 |
+
SearchResultItem(
|
| 187 |
+
cache_id=e.get("cache_id", ""),
|
| 188 |
+
similarity=0.0,
|
| 189 |
+
task_description=e.get("task_description", ""),
|
| 190 |
+
model_id=e.get("model_id", ""),
|
| 191 |
+
created_at=e.get("created_at", ""),
|
| 192 |
+
context_len=int(e.get("context_len", "0")),
|
| 193 |
+
)
|
| 194 |
+
for e in entries
|
| 195 |
+
if (req.model_id is None or e.get("model_id") == req.model_id)
|
| 196 |
+
]
|
| 197 |
+
|
| 198 |
+
return SearchResponse(results=results[:req.top_k], n_searched=index.n_entries)
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ── Delete ────────────────────────────────────────────────────────────────────
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@router.delete("/cache/{cache_id}", response_model=DeleteResponse)
|
| 205 |
+
async def delete_cache(cache_id: str):
|
| 206 |
+
"""Delete an engram from storage and index."""
|
| 207 |
+
retriever = _get_retriever()
|
| 208 |
+
deleted = retriever.delete_engram(cache_id)
|
| 209 |
+
return DeleteResponse(deleted=deleted, cache_id=cache_id)
|
| 210 |
+
|
| 211 |
+
|
kvcos/api/schemas.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""
ENGRAM Protocol — API Schemas

Pydantic models for all REST API request/response payloads.
"""

from __future__ import annotations

from pydantic import BaseModel, Field


# ── Store ─────────────────────────────────────────────────────────────────────


class StoreRequest(BaseModel):
    """Metadata accompanying a .eng upload."""

    agent_id: str
    task_description: str
    model_id: str
    compression: str = "q8_0"


class StoreResponse(BaseModel):
    """Result of storing an engram."""

    cache_id: str
    size_bytes: int
    compression_ratio: float
    path: str


# ── Retrieve ──────────────────────────────────────────────────────────────────


class SearchRequest(BaseModel):
    """Parameters for an engram similarity search."""

    task_description: str
    model_id: str | None = None
    top_k: int = Field(default=5, ge=1, le=100)
    min_similarity: float | None = None


class SearchResultItem(BaseModel):
    """A single match returned by a search."""

    cache_id: str
    similarity: float
    task_description: str
    model_id: str
    created_at: str
    context_len: int


class SearchResponse(BaseModel):
    """Ordered matches plus how many index entries were scanned."""

    results: list[SearchResultItem]
    n_searched: int


# ── Extend ────────────────────────────────────────────────────────────────────


class ExtendResponse(BaseModel):
    """Result of extending an engram's context."""

    cache_id: str
    new_context_len: int


# ── Delete ────────────────────────────────────────────────────────────────────


class DeleteResponse(BaseModel):
    """Result of a delete request."""

    deleted: bool
    cache_id: str


# ── Stats ─────────────────────────────────────────────────────────────────────


class StatsResponse(BaseModel):
    """Aggregate statistics for the whole engram store."""

    total_entries: int
    total_size_bytes: int
    total_size_mb: float
    avg_compression_ratio: float
    model_breakdown: dict[str, int]


# ── Health ────────────────────────────────────────────────────────────────────


class HealthResponse(BaseModel):
    """Server liveness and basic inventory info."""

    status: str = "ok"
    version: str
    index_entries: int
    storage_backend: str
kvcos/api/server.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — ENGRAM Server
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
FastAPI application factory with lifespan management.
|
| 6 |
+
Initializes storage, index, extractor, and retriever on startup.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from contextlib import asynccontextmanager
|
| 13 |
+
|
| 14 |
+
from fastapi import FastAPI
|
| 15 |
+
|
| 16 |
+
from kvcos.api import routes
|
| 17 |
+
from kvcos.core.config import get_config
|
| 18 |
+
from kvcos.core.serializer import EngramSerializer
|
| 19 |
+
from kvcos.core.types import ENGRAM_VERSION, StateExtractionMode
|
| 20 |
+
from kvcos.core.manifold_index import ManifoldIndex
|
| 21 |
+
from kvcos.core.retriever import EGRRetriever
|
| 22 |
+
from kvcos.core.state_extractor import MARStateExtractor
|
| 23 |
+
from kvcos.storage.local import LocalStorageBackend
|
| 24 |
+
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@asynccontextmanager
|
| 29 |
+
async def lifespan(app: FastAPI):
|
| 30 |
+
"""Initialize ENGRAM components on startup, clean up on shutdown."""
|
| 31 |
+
config = get_config()
|
| 32 |
+
|
| 33 |
+
# Initialize storage backend
|
| 34 |
+
storage = LocalStorageBackend(data_dir=config.data_dir)
|
| 35 |
+
|
| 36 |
+
# Initialize EGR manifold index
|
| 37 |
+
index_path = config.index_dir / "egr.faiss"
|
| 38 |
+
index = ManifoldIndex(dim=config.state_vec_dim, index_path=index_path)
|
| 39 |
+
|
| 40 |
+
# Initialize state extractor
|
| 41 |
+
extractor = MARStateExtractor(
|
| 42 |
+
mode=StateExtractionMode.SVD_PROJECT,
|
| 43 |
+
rank=config.state_vec_dim,
|
| 44 |
+
)
|
| 45 |
+
|
| 46 |
+
# Initialize retriever
|
| 47 |
+
serializer = EngramSerializer()
|
| 48 |
+
retriever = EGRRetriever(
|
| 49 |
+
extractor=extractor,
|
| 50 |
+
index=index,
|
| 51 |
+
storage=storage,
|
| 52 |
+
serializer=serializer,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
# Wire into route handlers
|
| 56 |
+
routes._storage = storage
|
| 57 |
+
routes._index = index
|
| 58 |
+
routes._retriever = retriever
|
| 59 |
+
|
| 60 |
+
logger.info("ENGRAM v%s started", ENGRAM_VERSION)
|
| 61 |
+
logger.info(" Storage: %s (%d entries)", config.data_dir, storage.stats()["total_entries"])
|
| 62 |
+
logger.info(" Index: %s (%d vectors, dim=%d)", config.index_dir, index.n_entries, config.state_vec_dim)
|
| 63 |
+
logger.info(" Backend: %s", config.backend.value)
|
| 64 |
+
|
| 65 |
+
yield
|
| 66 |
+
|
| 67 |
+
# Shutdown: persist index
|
| 68 |
+
try:
|
| 69 |
+
index.save(index_path)
|
| 70 |
+
logger.info("Index saved to %s", index_path)
|
| 71 |
+
except Exception as e:
|
| 72 |
+
logger.warning("Failed to save index: %s", e)
|
| 73 |
+
|
| 74 |
+
# Clear route references
|
| 75 |
+
routes._storage = None
|
| 76 |
+
routes._index = None
|
| 77 |
+
routes._retriever = None
|
| 78 |
+
|
| 79 |
+
logger.info("ENGRAM shutdown complete")
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def create_app() -> FastAPI:
|
| 83 |
+
"""Create the ENGRAM FastAPI application."""
|
| 84 |
+
app = FastAPI(
|
| 85 |
+
title="ENGRAM Protocol API",
|
| 86 |
+
description="ENGRAM Protocol: Cognitive state, persisted.",
|
| 87 |
+
version=ENGRAM_VERSION,
|
| 88 |
+
lifespan=lifespan,
|
| 89 |
+
docs_url="/docs",
|
| 90 |
+
redoc_url="/redoc",
|
| 91 |
+
)
|
| 92 |
+
app.include_router(routes.router)
|
| 93 |
+
return app
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def main() -> None:
|
| 97 |
+
"""Entry point for `engram-server` console script."""
|
| 98 |
+
import uvicorn
|
| 99 |
+
|
| 100 |
+
config = get_config()
|
| 101 |
+
application = create_app()
|
| 102 |
+
uvicorn.run(
|
| 103 |
+
application,
|
| 104 |
+
host=config.host,
|
| 105 |
+
port=config.port,
|
| 106 |
+
log_level="info",
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def _get_app() -> FastAPI:
|
| 111 |
+
"""Lazy app factory for `uvicorn kvcos.api.server:app`.
|
| 112 |
+
|
| 113 |
+
Defers create_app() until the attribute is actually accessed,
|
| 114 |
+
avoiding side effects on module import.
|
| 115 |
+
"""
|
| 116 |
+
return create_app()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def __getattr__(name: str) -> FastAPI:
|
| 120 |
+
if name == "app":
|
| 121 |
+
return _get_app()
|
| 122 |
+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
if __name__ == "__main__":
|
| 126 |
+
main()
|
kvcos/client/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ENGRAM Protocol — Client library."""

from kvcos.client.python_client import EngramClient

__all__ = ["EngramClient"]
kvcos/client/python_client.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — ENGRAM Python Client
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Async HTTP client wrapping all ENGRAM API endpoints.
|
| 6 |
+
This is what agents import to interact with the engram store.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from typing import Any
|
| 13 |
+
|
| 14 |
+
import httpx
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class EngramClient:
|
| 18 |
+
"""Python client for the ENGRAM REST API.
|
| 19 |
+
|
| 20 |
+
Usage:
|
| 21 |
+
client = EngramClient("http://localhost:8080")
|
| 22 |
+
result = client.store_file(path, agent_id="worker", task="analyze code", model_id="llama-3.1-8b")
|
| 23 |
+
matches = client.search(task_description="debug auth error", top_k=3)
|
| 24 |
+
data = client.get(matches[0]["cache_id"])
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, base_url: str = "http://localhost:8080", timeout: float = 30.0):
|
| 28 |
+
self.base_url = base_url.rstrip("/")
|
| 29 |
+
self._client = httpx.Client(base_url=f"{self.base_url}/v1", timeout=timeout)
|
| 30 |
+
|
| 31 |
+
def close(self) -> None:
|
| 32 |
+
self._client.close()
|
| 33 |
+
|
| 34 |
+
def __enter__(self):
|
| 35 |
+
return self
|
| 36 |
+
|
| 37 |
+
def __exit__(self, *args):
|
| 38 |
+
self.close()
|
| 39 |
+
|
| 40 |
+
# ── Health ────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
def health(self) -> dict[str, Any]:
|
| 43 |
+
"""Check ENGRAM server health."""
|
| 44 |
+
resp = self._client.get("/health")
|
| 45 |
+
resp.raise_for_status()
|
| 46 |
+
return resp.json()
|
| 47 |
+
|
| 48 |
+
# ── Store ─────────────────────────────────────────────────
|
| 49 |
+
|
| 50 |
+
def store_file(
|
| 51 |
+
self,
|
| 52 |
+
file_path: Path,
|
| 53 |
+
agent_id: str,
|
| 54 |
+
task_description: str,
|
| 55 |
+
model_id: str,
|
| 56 |
+
compression: str = "q8_0",
|
| 57 |
+
) -> dict[str, Any]:
|
| 58 |
+
"""Upload a .eng file to the engram store.
|
| 59 |
+
|
| 60 |
+
Args:
|
| 61 |
+
file_path: Path to the .eng file
|
| 62 |
+
agent_id: Agent identifier
|
| 63 |
+
task_description: Human-readable description
|
| 64 |
+
model_id: Model identifier
|
| 65 |
+
compression: Compression method used
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
Dict with cache_id, size_bytes, compression_ratio, path
|
| 69 |
+
"""
|
| 70 |
+
with open(file_path, "rb") as f:
|
| 71 |
+
resp = self._client.post(
|
| 72 |
+
"/cache",
|
| 73 |
+
params={
|
| 74 |
+
"agent_id": agent_id,
|
| 75 |
+
"task_description": task_description,
|
| 76 |
+
"model_id": model_id,
|
| 77 |
+
"compression": compression,
|
| 78 |
+
},
|
| 79 |
+
files={"file": (file_path.name, f, "application/octet-stream")},
|
| 80 |
+
)
|
| 81 |
+
resp.raise_for_status()
|
| 82 |
+
return resp.json()
|
| 83 |
+
|
| 84 |
+
def store_bytes(
|
| 85 |
+
self,
|
| 86 |
+
data: bytes,
|
| 87 |
+
agent_id: str,
|
| 88 |
+
task_description: str,
|
| 89 |
+
model_id: str,
|
| 90 |
+
compression: str = "q8_0",
|
| 91 |
+
filename: str = "cache.eng",
|
| 92 |
+
) -> dict[str, Any]:
|
| 93 |
+
"""Upload raw .eng bytes to the engram store."""
|
| 94 |
+
resp = self._client.post(
|
| 95 |
+
"/cache",
|
| 96 |
+
params={
|
| 97 |
+
"agent_id": agent_id,
|
| 98 |
+
"task_description": task_description,
|
| 99 |
+
"model_id": model_id,
|
| 100 |
+
"compression": compression,
|
| 101 |
+
},
|
| 102 |
+
files={"file": (filename, data, "application/octet-stream")},
|
| 103 |
+
)
|
| 104 |
+
resp.raise_for_status()
|
| 105 |
+
return resp.json()
|
| 106 |
+
|
| 107 |
+
# ── Retrieve ──────────────────────────────────────────────
|
| 108 |
+
|
| 109 |
+
def get(self, cache_id: str) -> bytes:
|
| 110 |
+
"""Retrieve a .eng file by cache ID.
|
| 111 |
+
|
| 112 |
+
Returns raw bytes of the .eng file.
|
| 113 |
+
"""
|
| 114 |
+
resp = self._client.get(f"/cache/{cache_id}")
|
| 115 |
+
resp.raise_for_status()
|
| 116 |
+
return resp.content
|
| 117 |
+
|
| 118 |
+
# ── Search ────────────────────────────────────────────────
|
| 119 |
+
|
| 120 |
+
def search(
|
| 121 |
+
self,
|
| 122 |
+
task_description: str,
|
| 123 |
+
model_id: str | None = None,
|
| 124 |
+
top_k: int = 5,
|
| 125 |
+
min_similarity: float | None = None,
|
| 126 |
+
) -> list[dict[str, Any]]:
|
| 127 |
+
"""Search for similar engram states.
|
| 128 |
+
|
| 129 |
+
Returns list of search result dicts with cache_id, similarity, etc.
|
| 130 |
+
"""
|
| 131 |
+
body: dict[str, Any] = {
|
| 132 |
+
"task_description": task_description,
|
| 133 |
+
"top_k": top_k,
|
| 134 |
+
}
|
| 135 |
+
if model_id:
|
| 136 |
+
body["model_id"] = model_id
|
| 137 |
+
if min_similarity is not None:
|
| 138 |
+
body["min_similarity"] = min_similarity
|
| 139 |
+
|
| 140 |
+
resp = self._client.post("/cache/search", json=body)
|
| 141 |
+
resp.raise_for_status()
|
| 142 |
+
return resp.json()["results"]
|
| 143 |
+
|
| 144 |
+
# ── Delete ────────────────────────────────────────────────
|
| 145 |
+
|
| 146 |
+
def delete(self, cache_id: str) -> bool:
|
| 147 |
+
"""Delete an engram from storage and index."""
|
| 148 |
+
resp = self._client.delete(f"/cache/{cache_id}")
|
| 149 |
+
resp.raise_for_status()
|
| 150 |
+
return resp.json()["deleted"]
|
| 151 |
+
|
| 152 |
+
# ── Stats ─────────────────────────────────────────────────
|
| 153 |
+
|
| 154 |
+
def stats(self) -> dict[str, Any]:
|
| 155 |
+
"""Get aggregate engram store statistics."""
|
| 156 |
+
resp = self._client.get("/cache/stats")
|
| 157 |
+
resp.raise_for_status()
|
| 158 |
+
return resp.json()
|
kvcos/core/__init__.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""ENGRAM Protocol — Core library: types, parsing, compression, serialization, retrieval."""

from kvcos.core.types import (
    ENGRAM_VERSION,
    AttentionType,
    CacheSection,
    CacheSearchResult,
    CacheStats,
    CompressionMethod,
    EngramMetadata,
    ModelCacheSpec,
    StateExtractionMode,
)
from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
from kvcos.core.retriever import EGRRetriever, RetrievalResponse, RetrievalResult
from kvcos.core.state_extractor import ExtractionResult, MARStateExtractor, SVDProjection

__all__ = [
    # Shared types and enums
    "ENGRAM_VERSION",
    "AttentionType",
    "CacheSection",
    "CacheSearchResult",
    "CacheStats",
    "CompressionMethod",
    "EngramMetadata",
    "ModelCacheSpec",
    "StateExtractionMode",
    # EGR manifold index
    "IndexEntry",
    "ManifoldIndex",
    # EGR retriever
    "EGRRetriever",
    "RetrievalResponse",
    "RetrievalResult",
    # MAR state extraction
    "ExtractionResult",
    "MARStateExtractor",
    "SVDProjection",
]
kvcos/core/blob_parser.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — llama.cpp State Blob Parser
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Parses the binary state blob from llama_state_get_data() (via save_state())
|
| 6 |
+
into structured PyTorch tensors of shape [n_layers, n_kv_heads, n_cells, head_dim].
|
| 7 |
+
|
| 8 |
+
D1: This is the critical extraction path. The blob format is defined by
|
| 9 |
+
llama.cpp's llama_kv_cache::state_write() and is version-dependent.
|
| 10 |
+
|
| 11 |
+
Validated against llama-cpp-python 0.3.19 (llama.cpp b5000+).
|
| 12 |
+
|
| 13 |
+
Binary format of llama_state_get_data() output:
|
| 14 |
+
1. Architecture string: uint32 str_len + str_len bytes (e.g. "llama")
|
| 15 |
+
2. KV cache section (from memory->state_write()):
|
| 16 |
+
a. uint32 n_stream (always 1 for single-context)
|
| 17 |
+
b. Per stream:
|
| 18 |
+
- uint32 cell_count (= n_used_cells, NOT n_ctx)
|
| 19 |
+
- Per cell: int32 pos, uint32 n_seq_id, int32[] seq_ids
|
| 20 |
+
- uint32 v_trans (1 = values stored transposed)
|
| 21 |
+
- uint32 n_layer
|
| 22 |
+
- Per layer K: int32 type_k, uint64 row_size_k, bytes data[row_size_k * cell_count]
|
| 23 |
+
- Per layer V (non-transposed): int32 type_v, uint64 row_size_v, bytes data[row_size_v * cell_count]
|
| 24 |
+
- Per layer V (transposed): int32 type_v, uint32 el_size, uint32 n_embd_v_gqa,
|
| 25 |
+
bytes data[el_size * n_embd_v_gqa * cell_count]
|
| 26 |
+
|
| 27 |
+
WARNING: This format is not stable across llama.cpp versions.
|
| 28 |
+
Pin llama-cpp-python version in pyproject.toml.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
from __future__ import annotations
|
| 32 |
+
|
| 33 |
+
import struct
|
| 34 |
+
from dataclasses import dataclass
|
| 35 |
+
|
| 36 |
+
import numpy as np
|
| 37 |
+
import torch
|
| 38 |
+
|
| 39 |
+
from kvcos.core.types import CacheSection
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
# ── GGML dtype constants ──────────────────────────────────────────────────────

# ggml_type enum values, mirrored from llama.cpp's ggml.h.
GGML_TYPE_F32 = 0
GGML_TYPE_F16 = 1
GGML_TYPE_Q8_0 = 8
GGML_TYPE_Q4_0 = 2

# Average bytes per element for each dtype. Quantized types are fractional:
# the ratio encodes (bytes per quantization block) / (elements per block),
# e.g. Q8_0 stores a 32-element block in 34 bytes.
GGML_TYPE_SIZE: dict[int, float] = {
    GGML_TYPE_F32: 4.0,
    GGML_TYPE_F16: 2.0,
    GGML_TYPE_Q8_0: 34.0 / 32.0,  # 34 bytes per 32-element block
    GGML_TYPE_Q4_0: 18.0 / 32.0,  # 18 bytes per 32-element block
}

# Elements per quantization block (1 for unquantized scalar types).
GGML_BLOCK_SIZE: dict[int, int] = {
    GGML_TYPE_F32: 1,
    GGML_TYPE_F16: 1,
    GGML_TYPE_Q8_0: 32,
    GGML_TYPE_Q4_0: 32,
}
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@dataclass
class CellMeta:
    """Per-cell metadata decoded from a llama.cpp KV stream.

    Attributes:
        pos: The cell's token position within its sequence.
        seq_ids: Sequence ids that reference this cell (one per n_seq entry
            in the blob's cell record).
    """

    pos: int
    seq_ids: list[int]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
class ParsedKVCache:
    """Structured result of parsing a llama.cpp state blob.

    Holds the decoded K/V tensors plus the stream metadata needed to
    reconstruct the cache layout (cell positions, layer count, V layout).
    """

    # Both tensors are [n_layers, n_kv_heads, n_cells, head_dim], float16.
    keys: torch.Tensor
    values: torch.Tensor
    # Per-cell position/sequence metadata, in blob order.
    cells: list[CellMeta]
    # Number of occupied cells (== len(cells)).
    n_cells: int
    # Number of transformer layers in this stream.
    n_layers: int
    # True when V data was stored transposed in the blob.
    v_trans: bool
    # Architecture string from the blob header (e.g. "llama").
    arch: str
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@dataclass
class ParsedMultiSectionCache:
    """Parsed ISWA state blob: one ParsedKVCache per KV cache section.

    Each section carries its own tensor shapes. For Gemma 4, section[0]
    is the Global cache (5 layers) and section[1] the SWA cache (25 layers).
    """

    sections: list[ParsedKVCache]
    arch: str

    @property
    def n_sections(self) -> int:
        """Number of parsed cache sections."""
        return len(self.sections)

    @property
    def total_layers(self) -> int:
        """Layer count summed over every section."""
        count = 0
        for section in self.sections:
            count += section.n_layers
        return count
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class BlobParseError(Exception):
    """Raised when the state blob cannot be parsed (truncated data,
    unsupported KV dtype, or a size/format mismatch)."""
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _read_u32(data: bytes, offset: int) -> tuple[int, int]:
    """Decode a little-endian uint32 at *offset*; return (value, offset + 4)."""
    (value,) = struct.unpack_from("<I", data, offset)
    return value, offset + 4
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def _read_i32(data: bytes, offset: int) -> tuple[int, int]:
    """Decode a little-endian signed int32 at *offset*; return (value, offset + 4)."""
    (value,) = struct.unpack_from("<i", data, offset)
    return value, offset + 4
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def _read_u64(data: bytes, offset: int) -> tuple[int, int]:
    """Decode a little-endian uint64 at *offset*; return (value, offset + 8)."""
    (value,) = struct.unpack_from("<Q", data, offset)
    return value, offset + 8
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _read_f16_block(
    data: bytes, offset: int, n_elements: int,
) -> tuple[torch.Tensor, int]:
    """Read *n_elements* float16 values from *data* starting at *offset*.

    Args:
        data: Source blob.
        offset: Byte offset of the first element.
        n_elements: Number of fp16 elements to read.

    Returns:
        (tensor, new_offset) — a 1-D float16 tensor that owns its memory,
        and the offset just past the consumed bytes.

    Raises:
        BlobParseError: if the read would run past the end of *data*.
    """
    n_bytes = n_elements * 2  # fp16 = 2 bytes per element
    if offset + n_bytes > len(data):
        raise BlobParseError(
            f"F16 read overflow: need {n_bytes}B at offset {offset}, blob is {len(data)}B"
        )
    # frombuffer returns a read-only view into `data`; copy so the resulting
    # tensor owns writable memory. The array is already float16, so the
    # previous trailing `.to(torch.float16)` cast was redundant and is removed.
    arr = np.frombuffer(data, dtype=np.float16, count=n_elements, offset=offset)
    return torch.from_numpy(arr.copy()), offset + n_bytes
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def parse_state_blob(
    blob: bytes,
    n_kv_heads: int,
    head_dim: int,
) -> ParsedKVCache:
    """Parse a llama.cpp full-context state blob into structured KV tensors.

    Parses output of llama_state_get_data() (via save_state()):
      1. Architecture string header
      2. Exactly one KV stream: cell metadata + per-layer K and V tensor data

    The parser auto-detects n_layers, cell_count, and v_trans from the blob.

    Args:
        blob: Raw bytes from save_state().llama_state
        n_kv_heads: Number of KV heads (from model spec)
        head_dim: Head dimension (from model spec)

    Returns:
        ParsedKVCache with [n_layers, n_kv_heads, n_cells, head_dim] tensors.

    Raises:
        BlobParseError: on truncated/oversized fields, a multi-stream blob
            (use parse_multi_section_blob for ISWA models), non-F16 KV data,
            or a shape mismatch against n_kv_heads/head_dim.
    """
    if len(blob) < 20:
        raise BlobParseError(f"Blob too small: {len(blob)} bytes")

    offset = 0

    # ── 1. Architecture string ────────────────────────────────
    str_len, offset = _read_u32(blob, offset)
    if str_len > 100:
        raise BlobParseError(f"Arch string length {str_len} too large — format mismatch")
    arch = blob[offset : offset + str_len].decode("ascii", errors="replace")
    offset += str_len

    # ── 2. KV stream header ───────────────────────────────────
    n_stream, offset = _read_u32(blob, offset)
    if n_stream != 1:
        raise BlobParseError(f"Expected 1 KV stream, got {n_stream}")

    # ── 3. Stream body ────────────────────────────────────────
    # This function previously duplicated the ~100-line stream decoder
    # verbatim; delegate to _parse_single_stream so the format is decoded
    # in exactly one place. Stream-level error messages now come from
    # _parse_single_stream (e.g. "Stream has 0 cells").
    parsed, _ = _parse_single_stream(
        blob, offset,
        n_kv_heads=n_kv_heads,
        head_dim=head_dim,
        arch=arch,
    )
    return parsed
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def _parse_single_stream(
    blob: bytes,
    offset: int,
    n_kv_heads: int,
    head_dim: int,
    arch: str,
) -> tuple[ParsedKVCache, int]:
    """Parse one KV cache stream from *blob* starting at *offset*.

    Decodes, in order: cell_count, per-cell metadata, the v_trans flag,
    n_layers, all K layer payloads, then all V layer payloads. Only F16
    KV data is supported.

    Args:
        blob: Full state blob.
        offset: Byte offset where this stream begins.
        n_kv_heads: Expected KV heads (used to reshape K/V rows).
        head_dim: Expected per-head dimension.
        arch: Architecture string from the blob header (stored on the result).

    Returns:
        (ParsedKVCache, new_offset) so the caller can continue parsing
        subsequent streams for ISWA blobs.

    Raises:
        BlobParseError: on zero/oversized cell counts, invalid layer counts,
            non-F16 dtypes, payload-size mismatches, or final shape mismatch.
    """
    n_embd_kv = n_kv_heads * head_dim

    # Cell count — number of *used* cells, not n_ctx.
    cell_count, offset = _read_u32(blob, offset)
    if cell_count == 0:
        raise BlobParseError("Stream has 0 cells")
    if cell_count > 200_000:
        raise BlobParseError(f"Suspiciously large cell_count: {cell_count}")

    # Cell metadata: int32 pos + uint32 n_seq + n_seq × int32 seq_id per cell.
    cells: list[CellMeta] = []
    for _ in range(cell_count):
        pos, offset = _read_i32(blob, offset)
        n_seq, offset = _read_u32(blob, offset)
        seq_ids: list[int] = []
        for _ in range(n_seq):
            sid, offset = _read_i32(blob, offset)
            seq_ids.append(sid)
        cells.append(CellMeta(pos=pos, seq_ids=seq_ids))

    # Data section header: v_trans flag (nonzero = V stored transposed),
    # then the layer count for this stream.
    v_trans_u32, offset = _read_u32(blob, offset)
    v_trans = v_trans_u32 != 0

    n_layers, offset = _read_u32(blob, offset)
    if n_layers == 0 or n_layers > 200:
        raise BlobParseError(f"Invalid n_layers: {n_layers}")

    # K layers: int32 type + uint64 row_size + row_size*cell_count bytes each.
    k_layers: list[torch.Tensor] = []
    for layer_idx in range(n_layers):
        type_k, offset = _read_i32(blob, offset)
        row_size_k, offset = _read_u64(blob, offset)

        if type_k != GGML_TYPE_F16:
            raise BlobParseError(
                f"Layer {layer_idx} K: unsupported type {type_k} (expected F16={GGML_TYPE_F16})"
            )

        data_bytes = row_size_k * cell_count
        n_elements = data_bytes // 2  # fp16 = 2 bytes/element

        # row_size must correspond to exactly n_kv_heads*head_dim fp16 values.
        if n_elements != n_embd_kv * cell_count:
            raise BlobParseError(
                f"Layer {layer_idx} K: expected {n_embd_kv * cell_count} elements, "
                f"got {n_elements} (row_size={row_size_k}, cells={cell_count})"
            )

        tensor, offset = _read_f16_block(blob, offset, n_elements)
        # [cell_count, n_kv_heads*head_dim] → [n_kv_heads, cell_count, head_dim]
        tensor = tensor.reshape(cell_count, n_kv_heads, head_dim)
        tensor = tensor.permute(1, 0, 2).contiguous()
        k_layers.append(tensor)

    # V layers: header layout depends on v_trans (see module docstring).
    v_layers: list[torch.Tensor] = []
    for layer_idx in range(n_layers):
        type_v, offset = _read_i32(blob, offset)

        if type_v != GGML_TYPE_F16:
            raise BlobParseError(
                f"Layer {layer_idx} V: unsupported type {type_v} (expected F16={GGML_TYPE_F16})"
            )

        if v_trans:
            # Transposed layout: uint32 el_size + uint32 n_embd_v, then
            # el_size*n_embd_v*cell_count bytes stored as [n_embd_v, cell_count].
            el_size, offset = _read_u32(blob, offset)
            n_embd_v, offset = _read_u32(blob, offset)
            data_bytes = el_size * n_embd_v * cell_count
            n_elements = data_bytes // 2

            tensor, offset = _read_f16_block(blob, offset, n_elements)
            # [n_embd_v, cell_count] → [n_kv_heads, head_dim, cell_count]
            # (n_embd_v is expected to equal n_kv_heads*head_dim).
            tensor = tensor.reshape(n_embd_v // head_dim, head_dim, cell_count)
            # → [n_kv_heads, cell_count, head_dim]
            tensor = tensor.permute(0, 2, 1).contiguous()
        else:
            # Row layout, same as K: uint64 row_size then row-major data.
            row_size_v, offset = _read_u64(blob, offset)
            data_bytes = row_size_v * cell_count
            n_elements = data_bytes // 2

            tensor, offset = _read_f16_block(blob, offset, n_elements)
            tensor = tensor.reshape(cell_count, n_kv_heads, head_dim)
            tensor = tensor.permute(1, 0, 2).contiguous()

        v_layers.append(tensor)

    # Stack into [n_layers, n_kv_heads, cell_count, head_dim] and validate.
    keys = torch.stack(k_layers, dim=0)
    values = torch.stack(v_layers, dim=0)

    expected_shape = (n_layers, n_kv_heads, cell_count, head_dim)
    if keys.shape != expected_shape:
        raise BlobParseError(f"K shape {keys.shape} != expected {expected_shape}")
    if values.shape != expected_shape:
        raise BlobParseError(f"V shape {values.shape} != expected {expected_shape}")

    parsed = ParsedKVCache(
        keys=keys,
        values=values,
        cells=cells,
        n_cells=cell_count,
        n_layers=n_layers,
        v_trans=v_trans,
        arch=arch,
    )
    return parsed, offset
|
| 392 |
+
|
| 393 |
+
|
| 394 |
+
def parse_multi_section_blob(
    blob: bytes,
    sections: tuple[CacheSection, ...],
) -> ParsedMultiSectionCache:
    """Parse an ISWA state blob containing several sequential KV cache streams.

    ISWA models (e.g., Gemma 4) serialize one stream per cache section, each
    with its own cell metadata, layer count, and KV dimensions. The blob's
    n_stream header must equal len(sections).

    Args:
        blob: Raw bytes from save_state().llama_state
        sections: Cache section specifications, in the same order as the
            blob's streams.

    Returns:
        ParsedMultiSectionCache with one ParsedKVCache per section.
    """
    if len(blob) < 20:
        raise BlobParseError(f"Blob too small: {len(blob)} bytes")

    # Header: uint32 length-prefixed ASCII architecture string.
    str_len, cursor = _read_u32(blob, 0)
    if str_len > 100:
        raise BlobParseError(f"Arch string length {str_len} too large")
    arch = blob[cursor : cursor + str_len].decode("ascii", errors="replace")
    cursor += str_len

    # The declared section layout must match the stream count in the blob.
    n_stream, cursor = _read_u32(blob, cursor)
    if n_stream != len(sections):
        raise BlobParseError(
            f"Expected {len(sections)} streams, got {n_stream}"
        )

    # Streams are laid out back-to-back; thread the cursor through each one.
    parsed_sections: list[ParsedKVCache] = []
    for section in sections:
        section_cache, cursor = _parse_single_stream(
            blob, cursor,
            n_kv_heads=section.n_kv_heads,
            head_dim=section.head_dim,
            arch=arch,
        )
        parsed_sections.append(section_cache)

    return ParsedMultiSectionCache(sections=parsed_sections, arch=arch)
|
| 442 |
+
|
| 443 |
+
|
| 444 |
+
# ── Legacy compat wrapper ────────────────────────────────────────────────────
|
| 445 |
+
|
| 446 |
+
|
| 447 |
+
def parse_seq_state_blob(
    blob: bytes,
    spec: dict,
    kv_dtype: int = GGML_TYPE_F16,
) -> ParsedKVCache:
    """Legacy wrapper around parse_state_blob.

    Kept for backward compatibility with existing tests; pulls n_kv_heads
    and head_dim out of *spec* and forwards. *kv_dtype* is accepted but
    unused (the parser only supports F16).
    """
    heads = spec["n_kv_heads"]
    dim = spec["head_dim"]
    return parse_state_blob(blob=blob, n_kv_heads=heads, head_dim=dim)
|
| 461 |
+
|
| 462 |
+
|
| 463 |
+
def estimate_blob_size(
    n_kv_heads: int,
    head_dim: int,
    n_layers: int,
    n_cells: int,
    v_trans: bool = True,
    arch: str = "llama",
) -> int:
    """Estimate the expected state-blob size in bytes, for validation.

    Mirrors the blob layout documented at the top of this module, assuming
    F16 KV data and one seq_id per cell (the typical single-sequence case).

    Args:
        n_kv_heads: KV heads per layer.
        head_dim: Per-head embedding dimension.
        n_layers: Transformer layers in the stream.
        n_cells: Occupied KV cells.
        v_trans: Whether V is stored transposed (changes the per-layer V
            header from u64 row_size to u32 el_size + u32 n_embd_v; both
            headers happen to be 12 bytes, but keep them distinct for clarity).
        arch: Architecture string in the blob header. Previously hard-coded
            to 5 bytes ("llama"); the default preserves that behavior.

    Returns:
        Estimated blob size in bytes.
    """
    # str_len(4) + arch bytes + n_stream(4) + cell_count(4)
    header = 4 + len(arch) + 4 + 4
    # Per cell: pos(4) + n_seq(4) + one seq_id(4)
    cell_meta = n_cells * 12
    # v_trans(4) + n_layer(4)
    data_header = 4 + 4

    n_embd_kv = n_kv_heads * head_dim
    layer_data = n_embd_kv * 2 * n_cells  # fp16 payload per layer
    k_per_layer = 4 + 8 + layer_data      # type(i32) + row_size(u64) + data
    if v_trans:
        v_per_layer = 4 + 4 + 4 + layer_data  # type + el_size + n_embd_v + data
    else:
        v_per_layer = 4 + 8 + layer_data      # type + row_size + data

    return header + cell_meta + data_header + n_layers * (k_per_layer + v_per_layer)
|
kvcos/core/block_pool.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — 256-Token Block Pool Manager
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Segments a full KV cache into fixed-size blocks (256 tokens each) that can be:
|
| 6 |
+
- Stored independently (one .eng file per block — D7)
|
| 7 |
+
- Retrieved individually via EGR (fine-grained cache hits)
|
| 8 |
+
- Composed (assemble a context from multiple blocks)
|
| 9 |
+
- Evicted independently (LRU per block, not per session)
|
| 10 |
+
|
| 11 |
+
Design from arXiv:2603.04428 (Persistent Q4 KV Cache, agent-memory paper).
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass, field
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
|
| 20 |
+
from kvcos.core.types import BLOCK_SIZE_TOKENS
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
@dataclass
class KVBlock:
    """One fixed-size (256-token) slice of KV cache data.

    Token coordinates are relative to the full context; ``token_end`` is
    exclusive, so the block holds tokens [token_start, token_end).
    """

    block_index: int
    token_start: int
    token_end: int  # exclusive bound

    # Both tensors: [n_layers, n_kv_heads, block_len, head_dim]
    keys: torch.Tensor
    values: torch.Tensor

    @property
    def block_len(self) -> int:
        """Number of tokens held by this block."""
        span = self.token_end - self.token_start
        return span

    @property
    def is_full(self) -> bool:
        """True when the block holds exactly BLOCK_SIZE_TOKENS tokens."""
        return self.block_len == BLOCK_SIZE_TOKENS

    @property
    def n_layers(self) -> int:
        """Layer count, taken from the keys tensor's first dim."""
        return self.keys.shape[0]

    @property
    def n_kv_heads(self) -> int:
        """KV head count, taken from the keys tensor's second dim."""
        return self.keys.shape[1]

    @property
    def head_dim(self) -> int:
        """Per-head dimension, taken from the keys tensor's last dim."""
        return self.keys.shape[3]
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
class BlockPool:
    """Manages a collection of KV blocks for an agent session.

    Blocks are kept in context order; ``block_index`` always equals the
    block's position in ``blocks`` and token ranges are contiguous.
    """

    agent_id: str   # owning agent session
    model_id: str   # model the KV data was produced by
    blocks: list[KVBlock] = field(default_factory=list)

    @property
    def total_tokens(self) -> int:
        """Sum of token counts across all blocks."""
        return sum(b.block_len for b in self.blocks)

    @property
    def n_blocks(self) -> int:
        """Number of blocks currently in the pool."""
        return len(self.blocks)

    def segment(
        self, keys: torch.Tensor, values: torch.Tensor,
    ) -> list[KVBlock]:
        """Segment a full KV cache into 256-token blocks.

        Replaces the pool's current contents; only the final block may be
        shorter than BLOCK_SIZE_TOKENS.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim]
            values: [n_layers, n_kv_heads, ctx_len, head_dim]

        Returns:
            The new list of blocks (also stored on self.blocks).

        Raises:
            ValueError: if keys and values differ in shape.
        """
        if keys.shape != values.shape:
            raise ValueError(f"Shape mismatch: keys {keys.shape} vs values {values.shape}")

        ctx_len = keys.shape[2]
        blocks: list[KVBlock] = []

        for i in range(0, ctx_len, BLOCK_SIZE_TOKENS):
            end = min(i + BLOCK_SIZE_TOKENS, ctx_len)
            # .contiguous() so each block owns a dense copy of its slice.
            block = KVBlock(
                block_index=len(blocks),
                token_start=i,
                token_end=end,
                keys=keys[:, :, i:end, :].contiguous(),
                values=values[:, :, i:end, :].contiguous(),
            )
            blocks.append(block)

        self.blocks = blocks
        return blocks

    def assemble(
        self, block_indices: list[int] | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Assemble a KV cache from blocks (concatenated along the ctx_len dim).

        Args:
            block_indices: Optional subset of block indices, in the order they
                should be concatenated; None means all blocks in pool order.

        Raises:
            ValueError: if the pool is empty or the selection is empty.
        """
        if not self.blocks:
            raise ValueError("No blocks to assemble")

        selected = self.blocks if block_indices is None else [self.blocks[i] for i in block_indices]
        if not selected:
            raise ValueError("No blocks selected for assembly")

        keys = torch.cat([b.keys for b in selected], dim=2)
        values = torch.cat([b.values for b in selected], dim=2)
        return keys, values

    def append_block(self, block: KVBlock) -> None:
        """Append *block*, overwriting its block_index to keep indices dense.

        NOTE(review): token_start/token_end are not adjusted here — the
        caller is expected to supply consistent token coordinates.
        """
        block.block_index = len(self.blocks)
        self.blocks.append(block)

    def get_block(self, index: int) -> KVBlock:
        """Return the block at *index*; raises IndexError when out of range."""
        if index < 0 or index >= len(self.blocks):
            raise IndexError(f"Block index {index} out of range [0, {len(self.blocks)})")
        return self.blocks[index]

    def extend(
        self, new_keys: torch.Tensor, new_values: torch.Tensor,
    ) -> list[KVBlock]:
        """Extend the pool with additional tokens, filling the last block first.

        A partially-filled trailing block is rebuilt (replaced, not mutated)
        with as many new tokens as it can hold; remaining tokens are segmented
        into fresh blocks whose token ranges continue from the pool's end.

        Args:
            new_keys: [n_layers, n_kv_heads, new_ctx_len, head_dim]
            new_values: same shape as new_keys.

        Returns:
            The blocks that were created or rebuilt by this call.
        """
        new_ctx_len = new_keys.shape[2]
        modified_blocks: list[KVBlock] = []
        offset = 0  # tokens of new data already consumed

        # Top up a partial trailing block before creating new ones.
        if self.blocks and not self.blocks[-1].is_full:
            last = self.blocks[-1]
            space = BLOCK_SIZE_TOKENS - last.block_len
            fill = min(space, new_ctx_len)

            merged_k = torch.cat([last.keys, new_keys[:, :, :fill, :]], dim=2).contiguous()
            merged_v = torch.cat([last.values, new_values[:, :, :fill, :]], dim=2).contiguous()

            self.blocks[-1] = KVBlock(
                block_index=last.block_index,
                token_start=last.token_start,
                token_end=last.token_start + merged_k.shape[2],
                keys=merged_k,
                values=merged_v,
            )
            modified_blocks.append(self.blocks[-1])
            offset = fill

        remaining = new_ctx_len - offset
        if remaining > 0:
            token_base = self.blocks[-1].token_end if self.blocks else 0
            # Use a scratch pool so segment() doesn't clobber self.blocks,
            # then rebase the new blocks' indices and token ranges onto
            # this pool's coordinates.
            sub_pool = BlockPool(agent_id=self.agent_id, model_id=self.model_id)
            new_blocks = sub_pool.segment(
                new_keys[:, :, offset:, :], new_values[:, :, offset:, :],
            )
            for b in new_blocks:
                b.block_index = len(self.blocks)
                b.token_start += token_base
                b.token_end += token_base
                self.blocks.append(b)
                modified_blocks.append(b)

        return modified_blocks

    def clear(self) -> None:
        """Drop all blocks from the pool."""
        self.blocks.clear()
|
kvcos/core/cache_spec.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Model Architecture Registry
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Contains ModelCacheSpec definitions for known models and utilities
|
| 6 |
+
to look up specs by model_id or infer model family from string.
|
| 7 |
+
|
| 8 |
+
D3: extraction_layers set to middle-to-deep (8-31 for 32-layer models)
|
| 9 |
+
per ShadowKV validation. Early layers (0-7) and final layer preserved.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
from kvcos.core.types import AttentionType, CacheSection, ModelCacheSpec
|
| 15 |
+
|
| 16 |
+
# ── Pre-registered Model Specs ────────────────────────────────────────────────

# Llama 3.1 8B — Primary Phase 1 target (D1, D6)
# GQA: 32 query heads, 8 KV heads, head_dim 128
LLAMA_3_1_8B = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),  # layers 8-31 (D3)
)

# Llama 3.1 8B base (non-instruct) — same architecture as the instruct variant.
LLAMA_3_1_8B_BASE = ModelCacheSpec(
    model_id="meta-llama/Llama-3.1-8B",
    model_family="llama",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)

# Phi-3-Mini-128K — Secondary Phase 1 target
# ShadowKV validated SVD on this model (D3)
# MHA: 32 query heads, 32 KV heads (no GQA), head_dim 96
PHI_3_MINI = ModelCacheSpec(
    model_id="microsoft/Phi-3-mini-128k-instruct",
    model_family="phi",
    n_layers=32,
    n_heads=32,
    n_kv_heads=32,  # Phi-3-Mini uses MHA, not GQA
    head_dim=96,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)

# Gemma 2 2B — NOTE: QK-Norm model, SVD behavior may differ (T3 caveat)
GEMMA_2_2B = ModelCacheSpec(
    model_id="google/gemma-2-2b-it",
    model_family="gemma",
    n_layers=26,
    n_heads=8,
    n_kv_heads=4,
    head_dim=256,
    rope_enabled=True,
    extraction_layers=tuple(range(6, 26)),
)

# Qwen 2.5 7B
QWEN_2_5_7B = ModelCacheSpec(
    model_id="Qwen/Qwen2.5-7B-Instruct",
    model_family="qwen",
    n_layers=28,
    n_heads=28,
    n_kv_heads=4,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(7, 28)),
)

# Mistral 7B v0.3
MISTRAL_7B = ModelCacheSpec(
    model_id="mistralai/Mistral-7B-Instruct-v0.3",
    model_family="mistral",
    n_layers=32,
    n_heads=32,
    n_kv_heads=8,
    head_dim=128,
    rope_enabled=True,
    extraction_layers=tuple(range(8, 32)),
)


# Gemma 4 26B-A4B — ISWA model (Interleaved Sliding Window Attention)
# Dual KV cache: Global (full context) + SWA (sliding window 1024 tokens)
# MoE: 128 experts, 8 active — does NOT affect KV cache (FFN-only)
# Reverse-engineered from llama.cpp b5200+ state blob format.
# Top-level n_kv_heads/head_dim describe the dominant (SWA) section only;
# per-section geometry lives in cache_sections.
GEMMA_4_26B_A4B = ModelCacheSpec(
    model_id="google/gemma-4-26b-a4b-it",
    model_family="gemma",
    n_layers=30,  # total: 5 global + 25 SWA
    n_heads=32,
    n_kv_heads=8,  # dominant section (SWA)
    head_dim=256,  # dominant section (SWA)
    rope_enabled=True,
    extraction_layers=tuple(range(8, 30)),
    cache_sections=(
        CacheSection(
            attention_type=AttentionType.FULL,
            n_layers=5,
            n_kv_heads=2,
            head_dim=512,
        ),
        CacheSection(
            attention_type=AttentionType.SLIDING,
            n_layers=25,
            n_kv_heads=8,
            head_dim=256,
            window_size=1024,
        ),
    ),
)


# ── Registry ──────────────────────────────────────────────────────────────────

# Runtime registry keyed by exact model_id; extended via register_model_spec().
_REGISTRY: dict[str, ModelCacheSpec] = {
    spec["model_id"]: spec
    for spec in [
        LLAMA_3_1_8B,
        LLAMA_3_1_8B_BASE,
        PHI_3_MINI,
        GEMMA_2_2B,
        GEMMA_4_26B_A4B,
        QWEN_2_5_7B,
        MISTRAL_7B,
    ]
}

# Substring → family name used by infer_model_family(); matched in insertion
# order against the lowercased model_id.
_FAMILY_MAP: dict[str, str] = {
    "llama": "llama",
    "meta-llama": "llama",
    "phi": "phi",
    "microsoft/phi": "phi",
    "gemma": "gemma",
    "google/gemma": "gemma",
    "qwen": "qwen",
    "mistral": "mistral",
    "deepseek": "deepseek",
}
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def get_model_spec(model_id: str) -> ModelCacheSpec | None:
    """Return the registered ModelCacheSpec for *model_id*, or None."""
    try:
        return _REGISTRY[model_id]
    except KeyError:
        return None
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def register_model_spec(spec: ModelCacheSpec) -> None:
    """Add *spec* to the runtime registry, keyed by its model_id."""
    key = spec["model_id"]
    _REGISTRY[key] = spec
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def infer_model_family(model_id: str) -> str:
    """Infer the model family from a model_id string.

    Case-insensitive substring match against _FAMILY_MAP, first hit wins;
    falls back to "unknown" when nothing matches.
    """
    needle = model_id.lower()
    return next(
        (family for key, family in _FAMILY_MAP.items() if key in needle),
        "unknown",
    )
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def make_spec_from_metadata(
    model_id: str,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    head_dim: int,
    rope_enabled: bool = True,
) -> ModelCacheSpec:
    """Build a ModelCacheSpec from raw architecture parameters.

    extraction_layers defaults to the middle-to-deep range (D3): the
    first quarter of layers (but at least one) is skipped.
    """
    first_kept = max(1, n_layers // 4)
    return ModelCacheSpec(
        model_id=model_id,
        model_family=infer_model_family(model_id),
        n_layers=n_layers,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        head_dim=head_dim,
        rope_enabled=rope_enabled,
        extraction_layers=tuple(range(first_kept, n_layers)),
    )
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def is_iswa_spec(spec: ModelCacheSpec) -> bool:
    """True when *spec* describes an ISWA (multi-section) cache layout."""
    has_sections = "cache_sections" in spec
    return has_sections
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def validate_kv_shape(
    spec: ModelCacheSpec,
    n_layers: int,
    n_kv_heads: int,
    head_dim: int,
) -> bool:
    """Check that KV tensor dimensions exactly match the model spec."""
    expected = (spec["n_layers"], spec["n_kv_heads"], spec["head_dim"])
    return expected == (n_layers, n_kv_heads, head_dim)
|
kvcos/core/compression.py
ADDED
|
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — KV Cache Compression Layer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Implements:
|
| 6 |
+
- FP16 passthrough (no compression)
|
| 7 |
+
- Q8_0: group quantization matching llama.cpp GGML_TYPE_Q8_0
|
| 8 |
+
Phase 1 production fallback. ~2x compression, <5% speed hit (D5).
|
| 9 |
+
- PolarQuant: MSE-optimal random rotation + Lloyd-Max codebook at 3 bits.
|
| 10 |
+
QJL REMOVED — confirmed harmful by 6+ independent implementations (D5).
|
| 11 |
+
Softmax amplifies QJL variance, making two-stage worse than MSE-only.
|
| 12 |
+
|
| 13 |
+
Reference: TheTom/turboquant_plus (511+ tests, most mature impl)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from kvcos.core.types import CompressionMethod
|
| 24 |
+
|
| 25 |
+
# ── Q8_0 Constants ────────────────────────────────────────────────────────────
|
| 26 |
+
Q8_GROUP_SIZE = 32
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass(frozen=True)
class CompressionResult:
    """Result of compressing a KV cache tensor."""

    # Payload to serialize (for Q8_0/PolarQuant/INT8 this is the
    # dequantized reconstruction, not packed bytes).
    data: torch.Tensor
    # Scheme that produced `data`.
    method: CompressionMethod
    # dtype of the input tensor, so callers can restore it later.
    original_dtype: torch.dtype
    # original bytes / (modelled) compressed bytes; 1.0 = no size win.
    compression_ratio: float
    # Scheme-specific string key/values (e.g. group size, seed).
    metadata: dict[str, str]
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# ── FP16 Passthrough ──────────────────────────────────────────────────────────
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def compress_fp16(kv: torch.Tensor) -> CompressionResult:
    """Identity 'compression': cast to contiguous FP16, ratio 1.0."""
    as_half = kv.to(torch.float16).contiguous()
    return CompressionResult(
        data=as_half,
        method=CompressionMethod.FP16,
        original_dtype=kv.dtype,
        compression_ratio=1.0,
        metadata={},
    )
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def decompress_fp16(data: torch.Tensor) -> torch.Tensor:
    """Return *data* as float16 (no-op when already half precision)."""
    return data.half()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
# ── Q8_0 Quantization ────────────────────────────────────────────────────────
|
| 60 |
+
# Matches llama.cpp GGML_TYPE_Q8_0 layout:
|
| 61 |
+
# 32-element groups, 1 float16 scale per group, 32 int8 values
|
| 62 |
+
# Storage: (32*1 + 2) / (32*2) = 34/64 ≈ 1.88x compression
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
def compress_q8_0(kv: torch.Tensor) -> CompressionResult:
    """Quantize a KV tensor to Q8_0 (int8 with one scale per 32-element group).

    Mirrors llama.cpp's GGML_TYPE_Q8_0 layout. The returned `data` is the
    *dequantized* bfloat16 tensor — safetensors has no native int8+scale
    pairing, so true packed storage is handled elsewhere; the ratio field
    reflects the fp16 carrier actually stored here.
    """
    src_dtype = kv.dtype
    src_bytes = kv.numel() * kv.element_size()

    work = kv.float().contiguous()
    shape = work.shape

    # Pad the last axis so it divides evenly into quantization groups.
    tail = shape[-1]
    pad = (Q8_GROUP_SIZE - tail % Q8_GROUP_SIZE) % Q8_GROUP_SIZE
    if pad > 0:
        work = torch.nn.functional.pad(work, (0, pad))

    groups = work.reshape(work.shape[:-1] + (-1, Q8_GROUP_SIZE))

    # Symmetric per-group scale; clamp avoids division by zero on
    # all-zero groups.
    scale = groups.abs().amax(dim=-1, keepdim=True) / 127.0
    scale = scale.clamp(min=1e-10)

    q = torch.round(groups / scale).clamp(-127, 127)
    recon = (q * scale).reshape(work.shape)

    if pad > 0:
        recon = recon[..., :tail]

    recon = recon.reshape(shape).to(torch.bfloat16)
    out_bytes = recon.numel() * 2

    return CompressionResult(
        data=recon,
        method=CompressionMethod.Q8_0,
        original_dtype=src_dtype,
        compression_ratio=src_bytes / out_bytes if out_bytes > 0 else 1.0,
        metadata={"q8_group_size": str(Q8_GROUP_SIZE)},
    )
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def decompress_q8_0(data: torch.Tensor) -> torch.Tensor:
    """Cast the stored (already dequantized) carrier tensor to float16."""
    return data.half()
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
# ── PolarQuant (Phase 2 — TurboQuant without QJL) ────────────────────────────
|
| 111 |
+
# QJL is INTENTIONALLY ABSENT per D5.
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class PolarQuantConfig:
    """Configuration + cached artifacts for PolarQuant compression.

    Holds the bit width, the deterministic rotation seed, and per-dim
    caches so the rotation matrix and codebook are built only once per
    head dimension.
    """

    def __init__(self, bits: int = 3, seed: int = 42):
        # Bits per quantized coordinate; 3 → 8 codebook levels.
        self.bits = bits
        self.n_centroids = 2**bits
        # Fixed seed makes the random rotation reproducible across runs.
        self.seed = seed
        # dim → cached rotation matrix / codebook, built lazily.
        self._rotation_cache: dict[int, torch.Tensor] = {}
        self._codebook_cache: dict[int, torch.Tensor] = {}

    def get_rotation_matrix(self, dim: int, device: torch.device) -> torch.Tensor:
        """Get fixed random orthogonal rotation matrix R ∈ R^(d×d).

        Built via QR of a seeded Gaussian; multiplying Q's columns by the
        signs of R's diagonal removes the QR sign convention so the
        result is a properly distributed orthogonal matrix.
        """
        if dim not in self._rotation_cache:
            rng = np.random.RandomState(self.seed)
            gaussian = rng.randn(dim, dim).astype(np.float32)
            q, r = np.linalg.qr(gaussian)
            d = np.diag(r)
            ph = np.sign(d)
            q *= ph[np.newaxis, :]
            self._rotation_cache[dim] = torch.from_numpy(q)
        return self._rotation_cache[dim].to(device)

    def get_lloyd_max_codebook(self, dim: int) -> torch.Tensor:
        """Lloyd-Max optimal centroids for N(0,1), 3-bit (8 levels).

        NOTE(review): the table contains both -0.000 and 0.000 — two
        coincident centroids, wasting one of the eight levels. Confirm
        against a reference Lloyd-Max table for 8-level N(0,1). Also note
        the codebook does not actually depend on `dim` despite being
        cached per-dim.
        """
        if dim not in self._codebook_cache:
            codebook = torch.tensor(
                [-1.748, -1.050, -0.501, -0.000, 0.000, 0.501, 1.050, 1.748],
                dtype=torch.float32,
            )
            self._codebook_cache[dim] = codebook
        return self._codebook_cache[dim]


# Module-level singleton used by compress_polarquant().
_POLAR_CONFIG = PolarQuantConfig()
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def compress_polarquant(kv: torch.Tensor) -> CompressionResult:
    """PolarQuant: random orthogonal rotation + 3-bit Lloyd-Max codebook.

    Phase 2 path; QJL is intentionally absent (D5). The returned `data`
    is the dequantized bfloat16 reconstruction — true 3-bit packed
    storage is Phase 2+.
    """
    src_dtype = kv.dtype
    src_bytes = kv.numel() * kv.element_size()
    device = kv.device

    work = kv.float().contiguous()
    shape = work.shape

    dim = shape[-1]
    rows = work.reshape(-1, dim)

    # Rotate into a basis where coordinates are approximately Gaussian.
    rot = _POLAR_CONFIG.get_rotation_matrix(dim, device)
    spun = rows @ rot

    # Per-dimension std normalization so the N(0,1) codebook applies.
    std = spun.std(dim=0, keepdim=True).clamp(min=1e-10)
    unit = spun / std

    # Nearest-centroid assignment: brute-force squared distance over the
    # 8 codebook levels.
    book = _POLAR_CONFIG.get_lloyd_max_codebook(dim).to(device)
    dist = (unit.unsqueeze(-1) - book.unsqueeze(0).unsqueeze(0)) ** 2
    idx = dist.argmin(dim=-1)

    # Undo normalization and rotation (orthogonal ⇒ inverse is transpose).
    recon = book[idx] * std
    recon = recon @ rot.T

    recon = recon.reshape(shape).to(torch.bfloat16)
    out_bytes = recon.numel() * 2

    return CompressionResult(
        data=recon,
        method=CompressionMethod.POLARQUANT,
        original_dtype=src_dtype,
        compression_ratio=src_bytes / out_bytes if out_bytes > 0 else 1.0,
        metadata={
            "polarquant_bits": "3",
            "polarquant_seed": str(_POLAR_CONFIG.seed),
            "qjl_enabled": "false",  # D5: QJL permanently disabled
        },
    )
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def decompress_polarquant(data: torch.Tensor) -> torch.Tensor:
    """Cast the stored (already dequantized) carrier tensor to float16."""
    return data.half()
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ── INT8 Quantization (Phase 2 — true on-disk compression) ───────────────────
|
| 202 |
+
# Stores actual int8 tensors in safetensors (1 byte/element vs 2 for fp16).
|
| 203 |
+
# Per-row symmetric quantization: scale = max(abs(row)) / 127.
|
| 204 |
+
# Separate scale tensor stored alongside quantized data.
|
| 205 |
+
# 2x on-disk compression with cos_sim > 0.999.
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
@dataclass(frozen=True)
class Int8CompressedPair:
    """INT8 quantized tensor + per-row scales."""

    # int8 data, same shape as the input tensor.
    quantized: torch.Tensor
    # float16 scales, shape = input shape minus the last axis
    # (one symmetric scale per flattened row).
    scales: torch.Tensor
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
def compress_int8_tensor(kv: torch.Tensor) -> Int8CompressedPair:
    """Symmetric per-row int8 quantization: scale = max|row| / 127.

    Args:
        kv: [..., head_dim] tensor (any dtype).

    Returns:
        Int8CompressedPair with int8 data (input shape) and float16
        scales (input shape minus the last axis).
    """
    shape = kv.shape
    rows = kv.float().reshape(-1, shape[-1])

    # Clamp keeps all-zero rows from producing a zero scale.
    peak = rows.abs().amax(dim=1, keepdim=True).clamp(min=1e-8)
    scale = peak / 127.0

    q = (rows / scale).round().clamp(-127, 127).to(torch.int8)

    return Int8CompressedPair(
        quantized=q.reshape(shape),
        scales=scale.squeeze(1).to(torch.float16).reshape(shape[:-1]),
    )
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
def decompress_int8_tensor(quantized: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """Invert compress_int8_tensor: per-row rescale back to float16."""
    restored = quantized.float() * scales.float().unsqueeze(-1)
    return restored.half()
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def compress_int8(kv: torch.Tensor) -> CompressionResult:
    """INT8 compression via the CompressionResult dispatcher API.

    Returns the dequantized float16 reconstruction; the serializer calls
    compress_int8_tensor() directly when it wants true int8 on-disk
    storage. compression_ratio models the on-disk layout (int8 data +
    fp16 per-row scales), not the in-memory carrier.
    """
    pair = compress_int8_tensor(kv)
    restored = decompress_int8_tensor(pair.quantized, pair.scales)

    src_bytes = kv.numel() * kv.element_size()
    disk_bytes = pair.quantized.numel() * 1 + pair.scales.numel() * 2

    return CompressionResult(
        data=restored,
        method=CompressionMethod.INT8,
        original_dtype=kv.dtype,
        compression_ratio=src_bytes / disk_bytes if disk_bytes > 0 else 1.0,
        metadata={"int8_scale_dtype": "float16"},
    )
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# ── LAYER_DELTA Compression ──────────────────────────────────────────────────
|
| 272 |
+
# Stores layer 0 as fp16 baseline, layers 1..N as int8 deltas from previous.
|
| 273 |
+
# Inter-layer residuals are typically small (adjacent layers are correlated),
|
| 274 |
+
# so int8 quantization of deltas achieves better fidelity than direct int8.
|
| 275 |
+
# On-disk: ~(1/N) fp16 + ((N-1)/N) int8 ≈ slightly better than straight INT8.
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
@dataclass(frozen=True)
class LayerDeltaCompressed:
    """Layer-delta compressed: fp16 baseline + int8 deltas."""

    # Layer 0 stored verbatim: [n_kv_heads, n_cells, head_dim] fp16.
    baseline: torch.Tensor
    # Residuals layer i - layer i-1, each int8 [n_kv_heads, n_cells, head_dim].
    delta_quantized: list[torch.Tensor]
    # Per-row scales for each delta, each fp16 [n_kv_heads, n_cells].
    delta_scales: list[torch.Tensor]
    # Total layer count (deltas hold n_layers - 1 entries).
    n_layers: int
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
def compress_layer_delta(kv: torch.Tensor) -> LayerDeltaCompressed:
    """Delta-encode layers: fp16 layer 0, int8 residuals for layers 1..N-1.

    Adjacent layers are correlated, so the residuals are small and int8
    quantization of deltas keeps better fidelity than direct int8.

    Args:
        kv: [n_layers, n_kv_heads, n_cells, head_dim]
    """
    n = kv.shape[0]
    base = kv[0].to(torch.float16)

    q_list: list[torch.Tensor] = []
    s_list: list[torch.Tensor] = []

    for layer in range(1, n):
        residual = kv[layer].float() - kv[layer - 1].float()
        rows = residual.reshape(-1, residual.shape[-1])
        # Symmetric per-row scale; clamp guards all-zero rows.
        scale = rows.abs().amax(dim=1).clamp(min=1e-8) / 127.0
        q = (rows / scale.unsqueeze(1)).round().clamp(-127, 127).to(torch.int8)
        q_list.append(q.reshape(residual.shape))
        s_list.append(scale.to(torch.float16).reshape(residual.shape[:-1]))

    return LayerDeltaCompressed(
        baseline=base,
        delta_quantized=q_list,
        delta_scales=s_list,
        n_layers=n,
    )
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def decompress_layer_delta(data: LayerDeltaCompressed) -> torch.Tensor:
    """Rebuild the full [n_layers, ...] tensor by accumulating deltas."""
    rebuilt = [data.baseline.float()]
    for q, s in zip(data.delta_quantized, data.delta_scales):
        rows = q.float().reshape(-1, q.shape[-1])
        step = (rows * s.float().reshape(-1).unsqueeze(1)).reshape(q.shape)
        rebuilt.append(rebuilt[-1] + step)
    return torch.stack(rebuilt).to(torch.float16)
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
def compress_layer_delta_result(kv: torch.Tensor) -> CompressionResult:
    """Layer-delta wrapper for the CompressionResult dispatcher API.

    compression_ratio models the on-disk layout: one fp16 baseline layer,
    (N-1) int8 delta layers, and (N-1) fp16 scale tensors.
    """
    packed = compress_layer_delta(kv)
    restored = decompress_layer_delta(packed)

    src_bytes = kv.numel() * kv.element_size()
    n = packed.n_layers
    layer_elems = kv[0].numel()
    # assumes kv is 4-D [n_layers, n_kv_heads, n_cells, head_dim] — one
    # scale per (head, cell) row; TODO confirm no 3-D callers exist.
    scale_elems = kv.shape[1] * kv.shape[2]
    disk_bytes = (
        layer_elems * 2  # baseline fp16
        + (n - 1) * layer_elems * 1  # int8 deltas
        + (n - 1) * scale_elems * 2  # fp16 scales
    )

    return CompressionResult(
        data=restored,
        method=CompressionMethod.LAYER_DELTA,
        original_dtype=kv.dtype,
        compression_ratio=src_bytes / disk_bytes if disk_bytes > 0 else 1.0,
        metadata={"delta_n_layers": str(n)},
    )
|
| 350 |
+
|
| 351 |
+
|
| 352 |
+
# ── Dispatcher ────────────────────────────────────────────────────────────────
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
def compress(kv: torch.Tensor, method: CompressionMethod) -> CompressionResult:
    """Dispatch to the compressor for *method*.

    Q4_0 is redirected to Q8_0 with a UserWarning (D5: 92% dequant
    slowdown at 64K+ context). Unknown methods raise ValueError.
    """
    if method == CompressionMethod.FP16:
        return compress_fp16(kv)
    if method == CompressionMethod.Q8_0:
        return compress_q8_0(kv)
    if method == CompressionMethod.POLARQUANT:
        return compress_polarquant(kv)
    if method == CompressionMethod.INT8:
        return compress_int8(kv)
    if method == CompressionMethod.LAYER_DELTA:
        return compress_layer_delta_result(kv)
    if method == CompressionMethod.Q4_0:
        import warnings

        warnings.warn(
            "Q4_0 has 92% dequantization slowdown at 64K+ context. "
            "Using Q8_0 instead. See D5.",
            UserWarning,
            stacklevel=2,
        )
        return compress_q8_0(kv)
    raise ValueError(f"Unknown compression method: {method}")
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def decompress(data: torch.Tensor, method: CompressionMethod) -> torch.Tensor:
    """Dispatch to the decompressor for *method*; raises on unknown methods."""
    if method == CompressionMethod.FP16:
        return decompress_fp16(data)
    if method in (CompressionMethod.Q8_0, CompressionMethod.Q4_0):
        return decompress_q8_0(data)
    if method == CompressionMethod.POLARQUANT:
        return decompress_polarquant(data)
    if method in (CompressionMethod.INT8, CompressionMethod.LAYER_DELTA):
        # These were stored already dequantized; just normalize the dtype.
        return data.to(torch.float16)
    raise ValueError(f"Unknown compression method: {method}")
|
kvcos/core/config.py
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Centralized Configuration
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Single source of truth for all runtime configuration.
|
| 6 |
+
Uses pydantic-settings for validation and type coercion.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from functools import lru_cache
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from pydantic_settings import BaseSettings, SettingsConfigDict
|
| 15 |
+
|
| 16 |
+
from kvcos.core.types import CompressionMethod, IndexBackend, StorageBackend
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EngramConfig(BaseSettings):
    """ENGRAM runtime configuration.

    Loaded from environment variables with ENGRAM_ prefix,
    or from a .env file in the project root.
    """

    model_config = SettingsConfigDict(
        env_prefix="ENGRAM_",
        env_file=".env",
        env_file_encoding="utf-8",
        extra="ignore",  # unknown env vars are silently ignored
    )

    # ── Server ────────────────────────────────────────────────
    port: int = 8080
    host: str = "0.0.0.0"

    # ── Storage ───────────────────────────────────────────────
    data_dir: Path = Path.home() / ".engram" / "data"
    backend: StorageBackend = StorageBackend.LOCAL
    default_compression: CompressionMethod = CompressionMethod.Q8_0

    # ── FAISS Index (D2) ──────────────────────────────────────
    index_backend: IndexBackend = IndexBackend.FAISS_FLAT_IP
    index_dir: Path = Path.home() / ".engram" / "index"
    # State vector dimension — must match extraction output
    # 128 for mean_pool (head_dim), 160 for svd_project (rank-160)
    state_vec_dim: int = 160

    # ── LLM Runtime (D1) ──────────────────────────────────────
    model_path: str = ""  # Path to GGUF model file
    n_gpu_layers: int = 0  # D1: CPU-only Phase 1 (avoids Issue #743)
    n_ctx: int = 16384  # D6: 16K context for Phase 1 demo target

    # ── Phase 2: Remote backends ──────────────────────────────
    redis_url: str = "redis://localhost:6379"
    redis_max_memory_gb: float = 2.0
    s3_bucket: str = "engram-cache"
    s3_region: str = "eu-central-1"
    cloudflare_r2_endpoint: str = ""

    # ── Phase 2: Semantic index ───────────────────────────────
    qdrant_url: str = "http://localhost:6333"
    qdrant_api_key: str = ""
    qdrant_collection: str = "engram_states"
    cohere_api_key: str = ""

    # ── Phase 4: Cross-model transfer ─────────────────────────
    adapter_enabled: bool = False
    adapter_checkpoint_dir: Path = Path.home() / ".engram" / "adapters"

    def ensure_dirs(self) -> None:
        """Create required directories if they don't exist."""
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.index_dir.mkdir(parents=True, exist_ok=True)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@lru_cache(maxsize=1)
def get_config() -> EngramConfig:
    """Build (once) and return the process-wide EngramConfig singleton."""
    cfg = EngramConfig()
    cfg.ensure_dirs()
    return cfg
|
kvcos/core/fingerprint.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Standalone State Extraction Functions
|
| 3 |
+
|
| 4 |
+
Contains the Engram Absolute fingerprint: compute_fourier_fingerprint().
|
| 5 |
+
This is the primary cross-model retrieval fingerprint, validated at
|
| 6 |
+
98% recall@1 at N=1000 with power-law decay N^-0.207.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from typing import TYPE_CHECKING
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn.functional as F
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from kvcos.core.blob_parser import ParsedMultiSectionCache
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def compute_fourier_fingerprint(
    layer_keys: torch.Tensor,
    freqs: list[int] | None = None,
) -> torch.Tensor:
    """Engram Absolute fingerprint (f0+f1) from per-layer mean keys.

    Takes the real DFT over the layer axis, extracts amplitude at each
    requested frequency, L2-normalizes each amplitude vector, and
    concatenates them.

    Args:
        layer_keys: [n_layers, dim] where dim = n_kv_heads * head_dim;
            must be the per-layer MEAN across token positions. A 3-D
            [n_layers, n_kv, hd] input is flattened automatically.
        freqs: DFT frequency indices to concatenate; default [0, 1]
            (DC + first harmonic).

    Returns:
        Fingerprint tensor [dim * len(freqs)], float32, each frequency
        slice independently L2-normalized.

    Raises:
        ValueError: when a requested frequency exceeds what rfft produces
            for the given layer count.
    """
    if freqs is None:
        freqs = [0, 1]

    if layer_keys.dim() == 3:
        layer_keys = layer_keys.reshape(layer_keys.shape[0], -1)

    layer_keys = layer_keys.float()

    amps = torch.fft.rfft(layer_keys, dim=0).abs()

    parts: list[torch.Tensor] = []
    for freq in freqs:
        if freq >= amps.shape[0]:
            raise ValueError(
                f"Requested freq={freq} but rfft produced only "
                f"{amps.shape[0]} components for {layer_keys.shape[0]} layers."
            )
        parts.append(F.normalize(amps[freq], dim=-1))

    return torch.cat(parts, dim=-1)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def compute_eigenform_score(
    layer_keys: torch.Tensor,
    noise_sigma: float = 0.001,
    n_trials: int = 3,
    freqs: list | None = None,
) -> float:
    """Stability of the Fourier fingerprint under Gaussian perturbation.

    Trial 0 fingerprints the clean input; each later trial adds
    N(0, noise_sigma^2) noise first. Returns the mean pairwise cosine
    similarity across trial fingerprints: near 1.0 = stable, below ~0.95
    = fragile. Nondeterministic unless the caller seeds torch.

    Args:
        layer_keys: [n_layers, dim] per-layer mean key vectors.
        noise_sigma: Gaussian noise standard deviation.
        n_trials: number of fingerprints to compare (1 ⇒ score 1.0).
        freqs: DFT frequencies; default [0, 1].

    Returns:
        float in [0, 1].
    """
    if freqs is None:
        freqs = [0, 1]

    prints: list[torch.Tensor] = []
    for trial in range(n_trials):
        if trial == 0:
            perturbed = layer_keys
        else:
            perturbed = layer_keys + torch.randn_like(layer_keys) * noise_sigma
        prints.append(compute_fourier_fingerprint(perturbed.float(), freqs=freqs))

    pair_idx = [(a, b) for a in range(n_trials) for b in range(a + 1, n_trials)]
    if not pair_idx:
        return 1.0

    total = sum(
        F.cosine_similarity(prints[a].unsqueeze(0), prints[b].unsqueeze(0)).item()
        for a, b in pair_idx
    )
    return float(total / len(pair_idx))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def compute_iswa_fingerprint(
    parsed: "ParsedMultiSectionCache",
    freqs: list | None = None,
    normalize_layers: bool = True,
) -> torch.Tensor:
    """Per-section Fourier fingerprint for ISWA multi-section caches.

    Strategy A (per-section concatenation): for each cache section,
    mean-pool keys over the token axis, fingerprint the result with
    compute_fourier_fingerprint_v2, then concatenate. Each section's
    sub-fingerprint stays independently L2-normalized, preserving the
    relative geometry within each attention type.

    For Gemma 4 with freqs=[0, 1]:
        Global (5 layers, 2 heads, 512 dim) → 1024 * 2 = 2048
        SWA (25 layers, 8 heads, 256 dim) → 2048 * 2 = 4096
        Total: 6144-dim fingerprint

    Args:
        parsed: ParsedMultiSectionCache from parse_multi_section_blob().
        freqs: DFT frequency indices; default [0, 1].
        normalize_layers: L2-normalize each layer before DFT (v2 behavior).

    Returns:
        Concatenated fingerprint tensor, float32.
    """
    if freqs is None:
        freqs = [0, 1]

    pieces: list[torch.Tensor] = []
    for sec in parsed.sections:
        # keys: [n_layers, n_kv_heads, n_cells, head_dim] — average the
        # token (cell) axis before fingerprinting.
        pooled = sec.keys.float().mean(dim=2)
        pieces.append(
            compute_fourier_fingerprint_v2(
                pooled, freqs=freqs, normalize_layers=normalize_layers
            )
        )

    return torch.cat(pieces, dim=-1)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def compute_fourier_fingerprint_v2(
|
| 141 |
+
layer_keys: torch.Tensor,
|
| 142 |
+
freqs: list | None = None,
|
| 143 |
+
normalize_layers: bool = True,
|
| 144 |
+
) -> torch.Tensor:
|
| 145 |
+
"""Fourier fingerprint v2: L2-normalize each layer before DFT.
|
| 146 |
+
|
| 147 |
+
Removes absolute magnitude scale (which differs by KV head count
|
| 148 |
+
across model families), preserves layer-progression shape.
|
| 149 |
+
|
| 150 |
+
Within-family: same recall as v1 (98%).
|
| 151 |
+
Cross-family: f0+f1 cross-sim expected >>0.26 (v1 baseline).
|
| 152 |
+
"""
|
| 153 |
+
if freqs is None:
|
| 154 |
+
freqs = [0, 1]
|
| 155 |
+
if layer_keys.dim() == 3:
|
| 156 |
+
layer_keys = layer_keys.reshape(layer_keys.shape[0], -1)
|
| 157 |
+
layer_keys = layer_keys.float()
|
| 158 |
+
if normalize_layers:
|
| 159 |
+
layer_keys = F.normalize(layer_keys, dim=-1)
|
| 160 |
+
F_complex = torch.fft.rfft(layer_keys, dim=0)
|
| 161 |
+
F_amp = F_complex.abs()
|
| 162 |
+
components = []
|
| 163 |
+
for f in freqs:
|
| 164 |
+
if f >= F_amp.shape[0]:
|
| 165 |
+
raise ValueError(f"freq={f} out of range for {layer_keys.shape[0]} layers")
|
| 166 |
+
components.append(F.normalize(F_amp[f], dim=-1))
|
| 167 |
+
return torch.cat(components, dim=-1)
|
kvcos/core/manifold_index.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engrammatic Geometry Retrieval — Manifold Index
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
FAISS-backed MIPS (Maximum Inner Product Search) index for EGR retrieval.
|
| 6 |
+
Indexes state vectors extracted from .eng files by MARStateExtractor.
|
| 7 |
+
|
| 8 |
+
D2: FAISS IndexFlatIP for K→K retrieval only. Never Q→K.
|
| 9 |
+
faiss.serialize_index() for persistence (not write_index — avoids
|
| 10 |
+
platform incompatibility Issue #3888). Atomic write via temp + rename.
|
| 11 |
+
MKL build enforced at import time.
|
| 12 |
+
|
| 13 |
+
D4: No L2 normalization. True MIPS. Raw inner product scores.
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass, field
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import faiss
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
|
| 25 |
+
from kvcos.core.types import CacheSearchResult
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
@dataclass
class IndexEntry:
    """Metadata associated with an indexed state vector.

    One IndexEntry is stored per FAISS row; search results are built
    from these fields.
    """

    # ID of the engram this state vector was extracted from.
    cache_id: str
    # Human-readable task description (surfaced in search results).
    task_description: str
    # Full model identifier; used for optional search-time filtering.
    model_id: str
    # Creation timestamp as a string (callers pass ISO-8601 — see EGRRetriever).
    created_at: str
    # Context length (token count) of the source KV cache.
    context_len: int
    l2_norm: float  # D4: stored for optional downstream use
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class ManifoldIndex:
    """FAISS-backed inner product index for EGR state vectors.

    Stores state vectors and associated metadata for MIPS retrieval.
    Persistence via faiss.serialize_index() with atomic file writes.

    Usage:
        index = ManifoldIndex(dim=160)
        index.add(state_vec, entry)
        results = index.search(query_vec, top_k=5)
        index.save(Path("~/.engram/index/egr.faiss"))
    """

    def __init__(self, dim: int, index_path: Path | None = None):
        """Initialize the manifold index.

        Args:
            dim: Dimension of state vectors (must match MARStateExtractor output).
            index_path: Optional path to load an existing index from disk.
        """
        self.dim = dim
        self._entries: list[IndexEntry] = []
        self._id_to_position: dict[str, int] = {}  # cache_id → FAISS row position

        if index_path and index_path.exists():
            self._index = self._load_index(index_path)
        else:
            # D2: IndexFlatIP — exact MIPS, correct for Phase 1 corpus sizes (<100K)
            self._index = faiss.IndexFlatIP(dim)

    @property
    def n_entries(self) -> int:
        """Number of indexed state vectors (FAISS row count, incl. shadowed rows)."""
        return self._index.ntotal

    def add(
        self,
        state_vec: torch.Tensor | np.ndarray,
        entry: IndexEntry,
    ) -> None:
        """Add a state vector and its metadata to the index.

        Re-adding an existing cache_id shadows the old row: FAISS IndexFlat
        has no in-place update, so the stale vector stays in the index but is
        filtered out of search results by the position check in search().

        Args:
            state_vec: [dim] state vector (D4: NOT normalized)
            entry: Associated metadata for this engram

        Raises:
            ValueError: if the vector dimension does not match the index.
        """
        vec = self._to_numpy(state_vec)

        if vec.shape != (self.dim,):
            raise ValueError(
                f"State vector dim {vec.shape} != index dim ({self.dim},)"
            )

        position = self._index.ntotal
        self._index.add(vec.reshape(1, -1).astype(np.float32))
        self._entries.append(entry)
        self._id_to_position[entry.cache_id] = position

    def search(
        self,
        query_vec: torch.Tensor | np.ndarray,
        top_k: int = 5,
        min_similarity: float | None = None,
        model_id: str | None = None,
    ) -> list[CacheSearchResult]:
        """Search for the most similar engram states via MIPS.

        Args:
            query_vec: [dim] query state vector
            top_k: Number of results to return
            min_similarity: Minimum inner product score threshold
            model_id: Optional filter by model ID

        Returns:
            List of CacheSearchResult sorted by similarity (descending)
        """
        if self._index.ntotal == 0:
            return []

        vec = self._to_numpy(query_vec)
        if vec.shape != (self.dim,):
            raise ValueError(
                f"Query vector dim {vec.shape} != index dim ({self.dim},)"
            )

        # Oversample whenever post-search filtering can drop hits: explicit
        # filters (model_id / min_similarity) AND shadowed or removed rows,
        # which exist whenever the tracking map is smaller than the FAISS
        # row count. (Bug fix: previously only model_id triggered
        # oversampling, so shadowed/removed entries or a min_similarity
        # threshold could silently shrink the result set below top_k.)
        can_filter = (
            bool(model_id)
            or min_similarity is not None
            or len(self._id_to_position) != self._index.ntotal
        )
        search_k = min(top_k * 3 if can_filter else top_k, self._index.ntotal)
        scores, indices = self._index.search(
            vec.reshape(1, -1).astype(np.float32), search_k
        )

        results: list[CacheSearchResult] = []
        for score, idx in zip(scores[0], indices[0]):
            if idx < 0 or idx >= len(self._entries):
                continue

            entry = self._entries[idx]

            # Skip if this cache_id has been superseded by a later add
            # or dropped via remove()
            if self._id_to_position.get(entry.cache_id) != idx:
                continue

            # Apply filters
            if model_id and entry.model_id != model_id:
                continue
            if min_similarity is not None and score < min_similarity:
                continue

            results.append(CacheSearchResult(
                cache_id=entry.cache_id,
                similarity=float(score),
                task_description=entry.task_description,
                model_id=entry.model_id,
                created_at=entry.created_at,
                context_len=entry.context_len,
            ))

            if len(results) >= top_k:
                break

        return results

    def remove(self, cache_id: str) -> bool:
        """Mark a cache entry as removed from the index.

        FAISS IndexFlat doesn't support deletion. We remove from the
        metadata tracking so the entry is filtered out of search results.
        The vector remains in FAISS until the next rebuild.

        Args:
            cache_id: ID to remove

        Returns:
            True if the entry was found and removed from tracking
        """
        if cache_id in self._id_to_position:
            del self._id_to_position[cache_id]
            return True
        return False

    def rebuild(self) -> int:
        """Rebuild the index from only active entries.

        Removes gaps left by remove() calls and rows shadowed by re-adds.
        Returns count of active entries.
        """
        active_positions = set(self._id_to_position.values())
        if len(active_positions) == len(self._entries):
            return len(active_positions)  # No gaps

        # Snapshot the raw FAISS buffer once (hoisted out of the loop —
        # previously rev_swig_ptr was re-invoked per surviving entry).
        all_vecs = faiss.rev_swig_ptr(
            self._index.get_xb(), self._index.ntotal * self.dim
        ).reshape(-1, self.dim)

        # Collect active vectors and entries
        new_entries: list[IndexEntry] = []
        vectors: list[np.ndarray] = []
        for pos, entry in enumerate(self._entries):
            if pos in active_positions and self._id_to_position.get(entry.cache_id) == pos:
                vectors.append(all_vecs[pos].copy())
                new_entries.append(entry)

        # Rebuild
        self._index = faiss.IndexFlatIP(self.dim)
        self._entries = []
        self._id_to_position = {}

        for vec, entry in zip(vectors, new_entries):
            self.add(torch.from_numpy(vec), entry)

        return self.n_entries

    def save(self, path: Path) -> None:
        """Persist the index to disk.

        D2: Uses faiss.serialize_index() (not write_index) to avoid
        platform incompatibility. Atomic write via temp file + replace.
        Metadata saved as a sidecar .json file.
        """
        import json

        path.parent.mkdir(parents=True, exist_ok=True)

        # D2: serialize_index returns numpy uint8 array — write raw bytes
        index_bytes: np.ndarray = faiss.serialize_index(self._index)

        # Atomic write for FAISS index. Path.replace (not rename) so an
        # existing file is overwritten atomically on every platform —
        # rename raises FileExistsError on Windows when the target exists.
        tmp_path = path.with_suffix(".faiss.tmp")
        try:
            tmp_path.write_bytes(index_bytes.tobytes())
            tmp_path.replace(path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        # Save metadata sidecar
        meta_path = path.with_suffix(".meta.json")
        meta_tmp = meta_path.with_suffix(".json.tmp")
        try:
            sidecar = {
                "dim": self.dim,
                "entries": [
                    {
                        "cache_id": e.cache_id,
                        "task_description": e.task_description,
                        "model_id": e.model_id,
                        "created_at": e.created_at,
                        "context_len": e.context_len,
                        "l2_norm": e.l2_norm,
                    }
                    for e in self._entries
                ],
                "id_to_position": self._id_to_position,
            }
            meta_tmp.write_text(json.dumps(sidecar, indent=2))
            meta_tmp.replace(meta_path)
        except Exception:
            meta_tmp.unlink(missing_ok=True)
            raise

    def _load_index(self, path: Path) -> faiss.IndexFlatIP:
        """Load a FAISS index and its metadata sidecar from disk.

        D2: Uses faiss.deserialize_index() from raw bytes (not read_index).
        Missing sidecar is tolerated: the index loads with empty metadata.
        """
        import json

        raw = np.frombuffer(path.read_bytes(), dtype=np.uint8)
        index = faiss.deserialize_index(raw)

        meta_path = path.with_suffix(".meta.json")
        if meta_path.exists():
            sidecar = json.loads(meta_path.read_text())
            self._entries = [
                IndexEntry(**e) for e in sidecar.get("entries", [])
            ]
            self._id_to_position = {
                k: int(v) for k, v in sidecar.get("id_to_position", {}).items()
            }

        return index

    @staticmethod
    def _to_numpy(vec: torch.Tensor | np.ndarray) -> np.ndarray:
        """Convert a vector to numpy float32 (detached and moved to CPU)."""
        if isinstance(vec, torch.Tensor):
            return vec.detach().cpu().float().numpy()
        return vec.astype(np.float32)
|
kvcos/core/retriever.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engrammatic Geometry Retrieval — Retriever
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Orchestrates the full EGR retrieval pipeline:
|
| 6 |
+
1. Extract state vector from query KV cache (MARStateExtractor)
|
| 7 |
+
2. Search manifold index for similar engram states (ManifoldIndex)
|
| 8 |
+
3. Load matched .eng files from storage (StorageBackend)
|
| 9 |
+
4. Return ranked results with KV tensors ready for injection
|
| 10 |
+
|
| 11 |
+
This is the primary interface agents use for retrieval.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
from kvcos.core.serializer import EngramSerializer
|
| 22 |
+
from kvcos.core.types import (
|
| 23 |
+
CacheSearchResult,
|
| 24 |
+
CompressionMethod,
|
| 25 |
+
EngramMetadata,
|
| 26 |
+
ModelCacheSpec,
|
| 27 |
+
StateExtractionMode,
|
| 28 |
+
)
|
| 29 |
+
from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
|
| 30 |
+
from kvcos.core.state_extractor import ExtractionResult, MARStateExtractor
|
| 31 |
+
from kvcos.storage.backends import StorageBackend
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
@dataclass
class RetrievalResult:
    """A single retrieval result with loaded KV tensors."""

    # Unique engram identifier used by the index and storage backend.
    cache_id: str
    # Raw MIPS (inner product) score from the manifold index.
    similarity: float
    # Human-readable task description attached when the engram was indexed.
    task_description: str
    # Full model identifier of the engram's source model.
    model_id: str
    keys: torch.Tensor  # [n_layers, n_kv_heads, ctx_len, head_dim]
    values: torch.Tensor  # [n_layers, n_kv_heads, ctx_len, head_dim]
    # NOTE: keys/values are empty (numel 0) when retrieved with load_tensors=False.
    # Full metadata block read from the stored .eng file.
    metadata: EngramMetadata
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class RetrievalResponse:
    """Full response from a retrieval query."""

    # State-vector extraction computed from the query KV cache.
    query_extraction: ExtractionResult
    # Ranked results, descending by similarity; may be fewer than top_k.
    results: list[RetrievalResult]
    n_searched: int  # total entries in the index
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class EGRRetriever:
    """Engrammatic Geometry Retrieval — full pipeline.

    Connects MARStateExtractor → ManifoldIndex → StorageBackend
    into a single retrieval call.

    Usage:
        retriever = EGRRetriever(extractor, index, storage)

        # Store an engram
        retriever.index_engram(keys, values, spec, agent_id, task_desc, model_id)

        # Retrieve similar engrams
        response = retriever.retrieve(query_keys, spec, top_k=3)
        for result in response.results:
            print(result.similarity, result.task_description)
            # result.keys / result.values ready for injection
    """

    def __init__(
        self,
        extractor: MARStateExtractor,
        index: ManifoldIndex,
        storage: StorageBackend,
        serializer: EngramSerializer | None = None,
    ):
        # Pipeline stages: state extraction → MIPS index → blob storage.
        self.extractor = extractor
        self.index = index
        self.storage = storage
        self._serializer = serializer or EngramSerializer()

    def index_engram(
        self,
        keys: torch.Tensor,
        values: torch.Tensor,
        spec: ModelCacheSpec,
        agent_id: str,
        task_description: str,
        model_id: str,
        cache_id: str | None = None,
        compression: CompressionMethod = CompressionMethod.Q8_0,
        output_dir: Path | None = None,
        extra_metadata: dict[str, str] | None = None,
    ) -> str:
        """Extract state vector, store .eng file, and add to index.

        This is the "write" path: compute once → store → index → reuse forever.

        Args:
            keys: [n_layers, n_kv_heads, ctx_len, head_dim]
            values: same shape as keys
            spec: Model architecture spec
            agent_id: Agent identifier
            task_description: Human-readable task description (searchable)
            model_id: Full model identifier
            cache_id: Explicit ID (auto-generated if None)
            compression: Compression method for storage
            output_dir: Directory for .eng file (uses a temp scratch dir if None)
            extra_metadata: Additional metadata key-value pairs

        Returns:
            cache_id of the stored engram
        """
        import shutil
        import tempfile
        import uuid
        from datetime import datetime, timezone

        from kvcos.core.types import ENG_FILE_EXTENSION

        cid = cache_id or str(uuid.uuid4())

        # 1. Extract state vector
        extraction = self.extractor.extract(keys, spec)

        # 2. Serialize to .eng file
        tmp_dir: Path | None = None
        if output_dir:
            output_path = output_dir / f"{cid}{ENG_FILE_EXTENSION}"
        else:
            # Scratch directory; removed once the storage backend has the file.
            # (Bug fix: mkdtemp() directories previously leaked on every call.)
            tmp_dir = Path(tempfile.mkdtemp())
            output_path = tmp_dir / f"{cid}{ENG_FILE_EXTENSION}"

        merge_meta = {
            "state_vec_norm": str(extraction.l2_norm),
            "extraction_mode": extraction.mode.value,
        }
        if extra_metadata:
            merge_meta.update(extra_metadata)

        try:
            self._serializer.serialize(
                keys=keys,
                values=values,
                agent_id=agent_id,
                task_description=task_description,
                model_id=model_id,
                output_path=output_path,
                compression=compression,
                cache_id=cid,
                extra_metadata=merge_meta,
            )

            # 3. Store in backend
            metadata = self._serializer.read_metadata_only(output_path)
            self.storage.store_file(cid, output_path, metadata)
        finally:
            if tmp_dir is not None:
                shutil.rmtree(tmp_dir, ignore_errors=True)

        # 4. Add to manifold index
        now = datetime.now(timezone.utc).isoformat()
        entry = IndexEntry(
            cache_id=cid,
            task_description=task_description,
            model_id=model_id,
            created_at=now,
            context_len=keys.shape[2],
            l2_norm=extraction.l2_norm,
        )
        self.index.add(extraction.state_vec, entry)

        return cid

    def retrieve(
        self,
        query_keys: torch.Tensor,
        spec: ModelCacheSpec,
        top_k: int = 5,
        min_similarity: float | None = None,
        model_id: str | None = None,
        load_tensors: bool = True,
    ) -> RetrievalResponse:
        """Retrieve similar engram states for a query KV cache.

        This is the "read" path: extract query vector → search index →
        load matching .eng files. Engrams that fail to load (missing path,
        corrupt file) are silently skipped.

        Args:
            query_keys: [n_layers, n_kv_heads, ctx_len, head_dim] query K cache
            spec: Model architecture spec
            top_k: Number of results to return
            min_similarity: Minimum MIPS score threshold
            model_id: Filter by model ID
            load_tensors: If True, load full KV tensors from storage.
                If False, return metadata only (faster for previewing).

        Returns:
            RetrievalResponse with ranked results
        """
        # 1. Extract query state vector
        query_extraction = self.extractor.extract(query_keys, spec)

        # 2. Search manifold index
        search_results = self.index.search(
            query_vec=query_extraction.state_vec,
            top_k=top_k,
            min_similarity=min_similarity,
            model_id=model_id,
        )

        # 3. Load matching engrams from storage
        results: list[RetrievalResult] = []
        for sr in search_results:
            if load_tensors:
                path = self.storage.get_path(sr["cache_id"])
                if path is None:
                    continue

                try:
                    keys, values, metadata = self._serializer.deserialize(path)
                except Exception:
                    # Best-effort: a corrupt/unreadable file drops this hit.
                    continue

                results.append(RetrievalResult(
                    cache_id=sr["cache_id"],
                    similarity=sr["similarity"],
                    task_description=sr["task_description"],
                    model_id=sr["model_id"],
                    keys=keys,
                    values=values,
                    metadata=metadata,
                ))
            else:
                # Metadata-only mode: empty tensors stand in for the KV data.
                metadata = self.storage.get_metadata(sr["cache_id"])
                if metadata is None:
                    continue

                results.append(RetrievalResult(
                    cache_id=sr["cache_id"],
                    similarity=sr["similarity"],
                    task_description=sr["task_description"],
                    model_id=sr["model_id"],
                    keys=torch.empty(0),
                    values=torch.empty(0),
                    metadata=metadata,
                ))

        return RetrievalResponse(
            query_extraction=query_extraction,
            results=results,
            n_searched=self.index.n_entries,
        )

    def delete_engram(self, cache_id: str) -> bool:
        """Remove an engram from both index and storage.

        Returns True if either the index or the storage backend had it.
        """
        idx_removed = self.index.remove(cache_id)
        store_removed = self.storage.delete(cache_id)
        return idx_removed or store_removed

    def save_index(self, path: Path) -> None:
        """Persist the manifold index to disk."""
        self.index.save(path)
|
kvcos/core/serializer.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — .eng File Serializer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
.eng = safetensors container with:
|
| 6 |
+
- __metadata__: JSON-stringified EngramMetadata (all string values per D7)
|
| 7 |
+
- Tensor keys: layer_{i}_keys, layer_{i}_values
|
| 8 |
+
- Each tensor: [n_kv_heads, ctx_len, head_dim] at compressed dtype
|
| 9 |
+
|
| 10 |
+
D7: safetensors confirmed. GGUF rejected. String-only metadata values.
|
| 11 |
+
Reference: arXiv:2603.04428 uses identical safetensors approach.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import hashlib
|
| 17 |
+
import uuid
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
from typing import Any
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from safetensors.torch import load_file, save_file
|
| 24 |
+
|
| 25 |
+
from kvcos.core.cache_spec import infer_model_family
|
| 26 |
+
from kvcos.core.compression import CompressionResult, compress, decompress
|
| 27 |
+
from kvcos.core.types import (
|
| 28 |
+
ENGRAM_VERSION,
|
| 29 |
+
ENG_FILE_EXTENSION,
|
| 30 |
+
CompressionMethod,
|
| 31 |
+
EngramMetadata,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class SerializationError(Exception):
    """Raised when serialization or deserialization fails.

    Examples: keys/values shape mismatch, or tensors that are not 4D
    [n_layers, n_kv_heads, ctx_len, head_dim].
    """
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class EngramSerializer:
|
| 40 |
+
"""Serializes/deserializes KV cache tensors to/from .eng files.
|
| 41 |
+
|
| 42 |
+
Canonical shape for KV tensors in ENGRAM:
|
| 43 |
+
keys: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 44 |
+
values: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 45 |
+
"""
|
| 46 |
+
|
| 47 |
+
def serialize(
|
| 48 |
+
self,
|
| 49 |
+
keys: torch.Tensor,
|
| 50 |
+
values: torch.Tensor,
|
| 51 |
+
agent_id: str,
|
| 52 |
+
task_description: str,
|
| 53 |
+
model_id: str,
|
| 54 |
+
output_path: Path,
|
| 55 |
+
compression: CompressionMethod = CompressionMethod.Q8_0,
|
| 56 |
+
cache_id: str | None = None,
|
| 57 |
+
parent_cache_id: str | None = None,
|
| 58 |
+
input_tokens: list[int] | None = None,
|
| 59 |
+
extra_metadata: dict[str, str] | None = None,
|
| 60 |
+
) -> dict[str, Any]:
|
| 61 |
+
"""Serialize KV cache tensors to a .eng file.
|
| 62 |
+
|
| 63 |
+
Args:
|
| 64 |
+
keys: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 65 |
+
values: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 66 |
+
agent_id: Identifier for the agent that produced this state
|
| 67 |
+
task_description: Human-readable description (used for EGR search)
|
| 68 |
+
model_id: Full model identifier
|
| 69 |
+
output_path: Path to write .eng file
|
| 70 |
+
compression: Compression method to apply
|
| 71 |
+
cache_id: Explicit cache ID (auto-generated if None)
|
| 72 |
+
parent_cache_id: ID of parent for delta chains
|
| 73 |
+
input_tokens: Token IDs that generated this state (for hash)
|
| 74 |
+
extra_metadata: Additional string key-value pairs
|
| 75 |
+
|
| 76 |
+
Returns:
|
| 77 |
+
Dict with cache_id, size_bytes, compression_ratio, path
|
| 78 |
+
"""
|
| 79 |
+
if keys.shape != values.shape:
|
| 80 |
+
raise SerializationError(
|
| 81 |
+
f"Keys/values shape mismatch: {keys.shape} vs {values.shape}"
|
| 82 |
+
)
|
| 83 |
+
if keys.dim() != 4:
|
| 84 |
+
raise SerializationError(
|
| 85 |
+
f"Expected 4D [n_layers, n_kv_heads, ctx_len, head_dim], "
|
| 86 |
+
f"got {keys.dim()}D: {keys.shape}"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
n_layers, n_kv_heads, ctx_len, head_dim = keys.shape
|
| 90 |
+
|
| 91 |
+
tensors: dict[str, torch.Tensor] = {}
|
| 92 |
+
|
| 93 |
+
if compression == CompressionMethod.INT8:
|
| 94 |
+
from kvcos.core.compression import compress_int8_tensor
|
| 95 |
+
|
| 96 |
+
k_pair = compress_int8_tensor(keys)
|
| 97 |
+
v_pair = compress_int8_tensor(values)
|
| 98 |
+
for i in range(n_layers):
|
| 99 |
+
tensors[f"layer_{i}_keys"] = k_pair.quantized[i].contiguous()
|
| 100 |
+
tensors[f"layer_{i}_keys_scale"] = k_pair.scales[i].contiguous()
|
| 101 |
+
tensors[f"layer_{i}_values"] = v_pair.quantized[i].contiguous()
|
| 102 |
+
tensors[f"layer_{i}_values_scale"] = v_pair.scales[i].contiguous()
|
| 103 |
+
# Reuse k_compressed for metadata only — actual INT8 data is
|
| 104 |
+
# already written per-layer above via k_pair/v_pair.
|
| 105 |
+
k_compressed = compress(keys, compression)
|
| 106 |
+
v_compressed = k_compressed
|
| 107 |
+
elif compression == CompressionMethod.LAYER_DELTA:
|
| 108 |
+
from kvcos.core.compression import compress_layer_delta
|
| 109 |
+
|
| 110 |
+
k_ld = compress_layer_delta(keys)
|
| 111 |
+
v_ld = compress_layer_delta(values)
|
| 112 |
+
# Layer 0: fp16 baseline
|
| 113 |
+
tensors["layer_0_keys"] = k_ld.baseline.contiguous()
|
| 114 |
+
tensors["layer_0_values"] = v_ld.baseline.contiguous()
|
| 115 |
+
# Layers 1..N: int8 deltas + fp16 scales
|
| 116 |
+
for i in range(n_layers - 1):
|
| 117 |
+
tensors[f"layer_{i+1}_keys"] = k_ld.delta_quantized[i].contiguous()
|
| 118 |
+
tensors[f"layer_{i+1}_keys_scale"] = k_ld.delta_scales[i].contiguous()
|
| 119 |
+
tensors[f"layer_{i+1}_values"] = v_ld.delta_quantized[i].contiguous()
|
| 120 |
+
tensors[f"layer_{i+1}_values_scale"] = v_ld.delta_scales[i].contiguous()
|
| 121 |
+
# Reuse k_compressed for metadata only — actual layer-delta data
|
| 122 |
+
# is already written above via k_ld/v_ld.
|
| 123 |
+
k_compressed = compress(keys, compression)
|
| 124 |
+
v_compressed = k_compressed
|
| 125 |
+
else:
|
| 126 |
+
k_compressed = compress(keys, compression)
|
| 127 |
+
v_compressed = compress(values, compression)
|
| 128 |
+
for i in range(n_layers):
|
| 129 |
+
tensors[f"layer_{i}_keys"] = k_compressed.data[i].contiguous()
|
| 130 |
+
tensors[f"layer_{i}_values"] = v_compressed.data[i].contiguous()
|
| 131 |
+
|
| 132 |
+
cid = cache_id or str(uuid.uuid4())
|
| 133 |
+
now = datetime.now(timezone.utc).isoformat()
|
| 134 |
+
|
| 135 |
+
token_hash = ""
|
| 136 |
+
if input_tokens:
|
| 137 |
+
token_bytes = b"".join(t.to_bytes(4, "little") for t in input_tokens)
|
| 138 |
+
token_hash = f"sha256:{hashlib.sha256(token_bytes).hexdigest()}"
|
| 139 |
+
|
| 140 |
+
metadata: EngramMetadata = {
|
| 141 |
+
"engram_version": ENGRAM_VERSION,
|
| 142 |
+
"cache_id": cid,
|
| 143 |
+
"compression": compression.value,
|
| 144 |
+
"model_id": model_id,
|
| 145 |
+
"model_family": infer_model_family(model_id),
|
| 146 |
+
"n_layers": str(n_layers),
|
| 147 |
+
"n_heads": str(n_kv_heads),
|
| 148 |
+
"n_kv_heads": str(n_kv_heads),
|
| 149 |
+
"head_dim": str(head_dim),
|
| 150 |
+
"context_len": str(ctx_len),
|
| 151 |
+
"agent_id": agent_id,
|
| 152 |
+
"task_description": task_description,
|
| 153 |
+
"created_at": now,
|
| 154 |
+
}
|
| 155 |
+
|
| 156 |
+
if parent_cache_id:
|
| 157 |
+
metadata["parent_cache_id"] = parent_cache_id
|
| 158 |
+
if token_hash:
|
| 159 |
+
metadata["token_hash"] = token_hash
|
| 160 |
+
for key, val in k_compressed.metadata.items():
|
| 161 |
+
metadata[f"compression_{key}"] = val # type: ignore[literal-required]
|
| 162 |
+
if extra_metadata:
|
| 163 |
+
for key, val in extra_metadata.items():
|
| 164 |
+
metadata[key] = val # type: ignore[literal-required]
|
| 165 |
+
|
| 166 |
+
output_path.parent.mkdir(parents=True, exist_ok=True)
|
| 167 |
+
|
| 168 |
+
str_metadata: dict[str, str] = {k: str(v) for k, v in metadata.items()}
|
| 169 |
+
save_file(tensors, str(output_path), metadata=str_metadata)
|
| 170 |
+
|
| 171 |
+
original_bytes = (keys.numel() + values.numel()) * keys.element_size()
|
| 172 |
+
compressed_bytes = output_path.stat().st_size
|
| 173 |
+
|
| 174 |
+
return {
|
| 175 |
+
"cache_id": cid,
|
| 176 |
+
"size_bytes": compressed_bytes,
|
| 177 |
+
"compression_ratio": original_bytes / compressed_bytes if compressed_bytes > 0 else 1.0,
|
| 178 |
+
"path": str(output_path),
|
| 179 |
+
"n_layers": n_layers,
|
| 180 |
+
"context_len": ctx_len,
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
    def deserialize(
        self,
        path: Path,
        target_compression: CompressionMethod | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, EngramMetadata]:
        """Deserialize a .eng file into KV cache tensors.

        Args:
            path: Location of the .eng (safetensors) file to read.
            target_compression: When not None, an extra decompress() pass is
                applied using the *stored* compression method (see NOTE below).

        Returns (keys, values, metadata) where tensors are
        [n_layers, n_kv_heads, ctx_len, head_dim].

        Raises:
            SerializationError: if the file is missing or a required
                per-layer tensor / INT8 scale entry is absent.
        """
        if not path.exists():
            raise SerializationError(f"Engram file not found: {path}")

        tensors = load_file(str(path))
        metadata = self._read_metadata(path)

        # Prefer the layer count recorded in metadata; fall back to scanning
        # tensor names ("layer_<i>_...") when it is absent or stored as "0".
        n_layers = int(metadata.get("n_layers", "0"))
        if n_layers == 0:
            n_layers = (
                max(int(k.split("_")[1]) for k in tensors if k.startswith("layer_")) + 1
            )

        stored_compression = metadata.get("compression", "fp16")
        is_int8 = stored_compression == CompressionMethod.INT8.value
        is_layer_delta = stored_compression == CompressionMethod.LAYER_DELTA.value

        k_layers: list[torch.Tensor] = []
        v_layers: list[torch.Tensor] = []

        if is_layer_delta:
            from kvcos.core.compression import decompress_int8_tensor

            # Layer 0: fp16 baseline
            k_layers.append(tensors["layer_0_keys"].float())
            v_layers.append(tensors["layer_0_values"].float())
            # Layers 1..N: accumulate int8 deltas
            for i in range(1, n_layers):
                k_delta = decompress_int8_tensor(
                    tensors[f"layer_{i}_keys"], tensors[f"layer_{i}_keys_scale"]
                )
                v_delta = decompress_int8_tensor(
                    tensors[f"layer_{i}_values"], tensors[f"layer_{i}_values_scale"]
                )
                # Each layer is the previous reconstructed layer plus its
                # dequantized delta, so this loop must run in layer order.
                k_layers.append(k_layers[-1] + k_delta.float())
                v_layers.append(v_layers[-1] + v_delta.float())
            keys = torch.stack([l.to(torch.float16) for l in k_layers], dim=0)
            values = torch.stack([l.to(torch.float16) for l in v_layers], dim=0)
        else:
            for i in range(n_layers):
                k_key = f"layer_{i}_keys"
                v_key = f"layer_{i}_values"
                if k_key not in tensors or v_key not in tensors:
                    raise SerializationError(f"Missing tensor for layer {i}")

                if is_int8:
                    from kvcos.core.compression import decompress_int8_tensor

                    k_scale_key = f"layer_{i}_keys_scale"
                    v_scale_key = f"layer_{i}_values_scale"
                    if k_scale_key not in tensors or v_scale_key not in tensors:
                        raise SerializationError(f"Missing INT8 scale for layer {i}")
                    k_layers.append(decompress_int8_tensor(tensors[k_key], tensors[k_scale_key]))
                    v_layers.append(decompress_int8_tensor(tensors[v_key], tensors[v_scale_key]))
                else:
                    k_layers.append(tensors[k_key])
                    v_layers.append(tensors[v_key])

            keys = torch.stack(k_layers, dim=0)
            values = torch.stack(v_layers, dim=0)

        if target_compression is not None:
            # NOTE(review): for INT8 / layer-delta files the data has already
            # been dequantized above, yet decompress() is invoked again with
            # the stored method — presumably a no-op for fp16, but possibly a
            # double-decompress for quantized formats. Confirm decompress()
            # semantics against kvcos.core.compression.
            stored = CompressionMethod(metadata.get("compression", "fp16"))
            keys = decompress(keys, stored)
            values = decompress(values, stored)

        return keys, values, metadata  # type: ignore[return-value]
|
| 259 |
+
|
| 260 |
+
def _read_metadata(self, path: Path) -> dict[str, str]:
|
| 261 |
+
"""Read only the metadata header (no tensor data loaded)."""
|
| 262 |
+
from safetensors import safe_open
|
| 263 |
+
|
| 264 |
+
metadata: dict[str, str] = {}
|
| 265 |
+
with safe_open(str(path), framework="pt") as f:
|
| 266 |
+
raw_meta = f.metadata()
|
| 267 |
+
if raw_meta:
|
| 268 |
+
metadata = dict(raw_meta)
|
| 269 |
+
return metadata
|
| 270 |
+
|
| 271 |
+
def read_metadata_only(self, path: Path) -> EngramMetadata:
|
| 272 |
+
"""Read just the metadata from a .eng file. Efficient for indexing."""
|
| 273 |
+
raw = self._read_metadata(path)
|
| 274 |
+
return raw # type: ignore[return-value]
|
kvcos/core/state_extractor.py
ADDED
|
@@ -0,0 +1,489 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Engrammatic Geometry Retrieval — State Extraction Layer
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Extracts a retrieval state vector from a KV cache tensor for MIPS-based
|
| 6 |
+
retrieval in EGR (Engrammatic Geometry Retrieval). The state vector is
|
| 7 |
+
a compact geometric fingerprint of a cognitive state — positioned in the
|
| 8 |
+
model's own pre-RoPE key manifold for geometrically consistent retrieval.
|
| 9 |
+
|
| 10 |
+
Three extraction modes:
|
| 11 |
+
|
| 12 |
+
mean_pool: Fast baseline. Mean over heads + context of key matrices
|
| 13 |
+
across extraction layers. Output: [head_dim]. No learned
|
| 14 |
+
parameters. Use for bootstrapping and smoke tests.
|
| 15 |
+
|
| 16 |
+
svd_project: Truncated SVD on pre-RoPE keys, extraction layers (D3: 8-31),
|
| 17 |
+
rank-160 for 8B models. Validated by ShadowKV (ICML 2025,
|
| 18 |
+
ByteDance) on Llama-3.1-8B and Phi-3-Mini-128K.
|
| 19 |
+
Output: [rank]. Projection is prompt-dependent — W computed
|
| 20 |
+
per cache via online SVD, not precomputed globally.
|
| 21 |
+
Reference: github.com/ByteDance-Seed/ShadowKV
|
| 22 |
+
|
| 23 |
+
xkv_project: Grouped cross-layer SVD. Groups 4 adjacent extraction layers,
|
| 24 |
+
extracts shared basis vectors across the group. Achieves
|
| 25 |
+
6.8x compression vs 2.5x single-layer SVD. K:V rank ratio
|
| 26 |
+
1:1.5 is optimal per xKV paper.
|
| 27 |
+
Reference: github.com/abdelfattah-lab/xKV
|
| 28 |
+
arXiv:2503.18893
|
| 29 |
+
|
| 30 |
+
REMOVED: sals_project — last-layer-only extraction invalidated by
|
| 31 |
+
Layer-Condensed KV Cache (ACL 2024). See D3.
|
| 32 |
+
|
| 33 |
+
D4: No L2 normalization. True MIPS. L2 norm stored as metadata for
|
| 34 |
+
optional downstream use.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
from __future__ import annotations
|
| 38 |
+
|
| 39 |
+
from dataclasses import dataclass, field
|
| 40 |
+
|
| 41 |
+
import torch
|
| 42 |
+
from einops import rearrange
|
| 43 |
+
|
| 44 |
+
from kvcos.core.types import (
|
| 45 |
+
DEFAULT_SVD_RANK,
|
| 46 |
+
ModelCacheSpec,
|
| 47 |
+
StateExtractionMode,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class ExtractionResult:
    """Result of state vector extraction from a KV cache.

    Produced by MARStateExtractor.extract() / extract_with_basis().
    state_vec feeds a MIPS index; per D4 extract() does not normalize it,
    so l2_norm is carried separately as metadata.
    """

    state_vec: torch.Tensor  # [d_out] — the retrieval vector
    l2_norm: float  # L2 norm of state_vec; stored as metadata per D4
    mode: StateExtractionMode  # extraction mode that produced this vector
    n_layers_used: int  # number of cache layers actually pooled/projected
    n_tokens: int  # context length (ctx_len) of the source cache
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@dataclass
class SVDProjection:
    """Learned SVD projection matrix for a specific cache.

    ShadowKV finding: pre-RoPE keys share low-rank subspaces WITHIN
    sequences but differ ACROSS sequences. Projection must be computed
    online per cache, not precomputed globally.

    Cached on MARStateExtractor after each svd_project call so callers
    can inspect how well the chosen rank fits this particular cache.
    """

    W: torch.Tensor  # [head_dim, rank] — right singular vectors (projection)
    singular_values: torch.Tensor  # [rank] — for diagnostics
    explained_variance_ratio: float  # fraction of total variance captured by the kept SVs
    source_shape: tuple[int, ...]  # shape of the keys used to compute this
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class MARStateExtractor:
|
| 78 |
+
"""Extracts retrieval state vectors from KV cache tensors for EGR.
|
| 79 |
+
|
| 80 |
+
Usage:
|
| 81 |
+
extractor = MARStateExtractor(mode="svd_project", rank=160)
|
| 82 |
+
result = extractor.extract(keys, spec)
|
| 83 |
+
# result.state_vec is the retrieval vector for FAISS IndexFlatIP
|
| 84 |
+
# result.l2_norm goes into .eng metadata (D4)
|
| 85 |
+
"""
|
| 86 |
+
|
| 87 |
+
# Max rows fed to SVD. 8192 rows on a 128-dim matrix runs in ~15ms
|
| 88 |
+
# vs ~2000ms for the full 786K-row matrix. Subspace quality is
|
| 89 |
+
# preserved because SVD only needs O(head_dim²) samples to recover
|
| 90 |
+
# the top singular vectors of a low-rank matrix.
|
| 91 |
+
MAX_SVD_ROWS: int = 8192
|
| 92 |
+
|
| 93 |
+
def __init__(
|
| 94 |
+
self,
|
| 95 |
+
mode: StateExtractionMode = StateExtractionMode.SVD_PROJECT,
|
| 96 |
+
rank: int = DEFAULT_SVD_RANK,
|
| 97 |
+
xkv_group_size: int = 4,
|
| 98 |
+
xkv_kv_rank_ratio: float = 1.5,
|
| 99 |
+
max_svd_rows: int | None = None,
|
| 100 |
+
layer_range: tuple[int, int] | None = None,
|
| 101 |
+
gate_start: int = 0,
|
| 102 |
+
):
|
| 103 |
+
self.mode = mode
|
| 104 |
+
self.rank = rank
|
| 105 |
+
self.xkv_group_size = xkv_group_size
|
| 106 |
+
self.xkv_kv_rank_ratio = xkv_kv_rank_ratio
|
| 107 |
+
self.max_svd_rows = max_svd_rows or self.MAX_SVD_ROWS
|
| 108 |
+
# Override spec extraction_layers when set. (8, 24) uses middle
|
| 109 |
+
# layers which encode semantic content (Tenney 2019, Huh 2024).
|
| 110 |
+
self.layer_range = layer_range
|
| 111 |
+
# Skip top gate_start singular values in SVD projection.
|
| 112 |
+
# Top SVs encode shared positional/syntactic structure;
|
| 113 |
+
# skipping them isolates semantic content (gate_start=6 optimal).
|
| 114 |
+
self.gate_start = gate_start
|
| 115 |
+
|
| 116 |
+
# Cached projection from last extract call (for inspection/reuse)
|
| 117 |
+
self._last_projection: SVDProjection | None = None
|
| 118 |
+
|
| 119 |
+
def extract(
|
| 120 |
+
self,
|
| 121 |
+
keys: torch.Tensor,
|
| 122 |
+
spec: ModelCacheSpec,
|
| 123 |
+
) -> ExtractionResult:
|
| 124 |
+
"""Extract a state vector from KV cache key tensors.
|
| 125 |
+
|
| 126 |
+
Args:
|
| 127 |
+
keys: [n_layers, n_kv_heads, ctx_len, head_dim] — the K cache.
|
| 128 |
+
Must be pre-RoPE if available. Post-RoPE works but with
|
| 129 |
+
reduced retrieval quality due to position-dependent distortion.
|
| 130 |
+
spec: Model architecture spec (provides extraction_layers).
|
| 131 |
+
|
| 132 |
+
Returns:
|
| 133 |
+
ExtractionResult with state vector and metadata.
|
| 134 |
+
"""
|
| 135 |
+
n_layers, n_kv_heads, ctx_len, head_dim = keys.shape
|
| 136 |
+
|
| 137 |
+
# Layer selection: layer_range overrides spec extraction_layers
|
| 138 |
+
if self.layer_range is not None:
|
| 139 |
+
start, end = self.layer_range
|
| 140 |
+
start = max(0, min(start, n_layers))
|
| 141 |
+
end = max(start, min(end, n_layers))
|
| 142 |
+
layer_indices = list(range(start, end))
|
| 143 |
+
else:
|
| 144 |
+
extraction_layers = spec["extraction_layers"]
|
| 145 |
+
layer_indices = [l for l in extraction_layers if l < n_layers]
|
| 146 |
+
|
| 147 |
+
if not layer_indices:
|
| 148 |
+
layer_indices = list(range(n_layers))
|
| 149 |
+
|
| 150 |
+
selected_keys = keys[layer_indices] # [n_selected, n_kv_heads, ctx_len, head_dim]
|
| 151 |
+
|
| 152 |
+
match self.mode:
|
| 153 |
+
case StateExtractionMode.MEAN_POOL:
|
| 154 |
+
state_vec = self._mean_pool(selected_keys)
|
| 155 |
+
case StateExtractionMode.SVD_PROJECT:
|
| 156 |
+
state_vec = self._svd_project(selected_keys)
|
| 157 |
+
case StateExtractionMode.XKV_PROJECT:
|
| 158 |
+
state_vec = self._xkv_project(selected_keys)
|
| 159 |
+
case _:
|
| 160 |
+
raise ValueError(f"Unknown extraction mode: {self.mode}")
|
| 161 |
+
|
| 162 |
+
# D4: No normalization. True MIPS. Store norm as metadata.
|
| 163 |
+
l2_norm = float(torch.linalg.vector_norm(state_vec).item())
|
| 164 |
+
|
| 165 |
+
return ExtractionResult(
|
| 166 |
+
state_vec=state_vec,
|
| 167 |
+
l2_norm=l2_norm,
|
| 168 |
+
mode=self.mode,
|
| 169 |
+
n_layers_used=len(layer_indices),
|
| 170 |
+
n_tokens=ctx_len,
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
def _mean_pool(self, keys: torch.Tensor) -> torch.Tensor:
|
| 174 |
+
"""Fast baseline: mean over layers, heads, and context positions.
|
| 175 |
+
|
| 176 |
+
Input: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 177 |
+
Output: [head_dim]
|
| 178 |
+
"""
|
| 179 |
+
return keys.float().mean(dim=(0, 1, 2))
|
| 180 |
+
|
| 181 |
+
def _svd_project(self, keys: torch.Tensor) -> torch.Tensor:
|
| 182 |
+
"""Truncated SVD projection on pre-RoPE keys.
|
| 183 |
+
|
| 184 |
+
ShadowKV approach: flatten all extraction layers' keys into a 2D matrix
|
| 185 |
+
[N, head_dim], compute truncated SVD, project onto top-rank singular vectors,
|
| 186 |
+
then mean-pool the projected vectors.
|
| 187 |
+
|
| 188 |
+
For large contexts (N > max_svd_rows), we subsample rows before SVD.
|
| 189 |
+
SVD only needs O(head_dim²) samples to recover the top singular vectors
|
| 190 |
+
of a low-rank matrix, so subsampling to 8K rows preserves subspace quality
|
| 191 |
+
while reducing SVD from ~2000ms to ~15ms at 4K context.
|
| 192 |
+
|
| 193 |
+
Input: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 194 |
+
Output: [rank]
|
| 195 |
+
"""
|
| 196 |
+
n_layers, n_kv_heads, ctx_len, head_dim = keys.shape
|
| 197 |
+
|
| 198 |
+
# Total rows in the flattened matrix
|
| 199 |
+
n_rows = n_layers * n_kv_heads * ctx_len
|
| 200 |
+
|
| 201 |
+
if n_rows > self.max_svd_rows:
|
| 202 |
+
# Subsample BEFORE flatten+cast to avoid allocating the full
|
| 203 |
+
# float32 matrix (saves ~30ms rearrange + 100MB at 4K context).
|
| 204 |
+
gen = torch.Generator()
|
| 205 |
+
gen.manual_seed(42)
|
| 206 |
+
indices = torch.randperm(n_rows, generator=gen)[:self.max_svd_rows]
|
| 207 |
+
flat_keys = keys.reshape(n_rows, head_dim)[indices].float()
|
| 208 |
+
svd_input = flat_keys
|
| 209 |
+
else:
|
| 210 |
+
flat_keys = rearrange(keys.float(), 'l h t d -> (l h t) d')
|
| 211 |
+
svd_input = flat_keys
|
| 212 |
+
|
| 213 |
+
# Clamp rank to not exceed matrix dimensions
|
| 214 |
+
max_rank = min(head_dim, svd_input.shape[0])
|
| 215 |
+
effective_rank = min(self.gate_start + self.rank, max_rank)
|
| 216 |
+
|
| 217 |
+
# Truncated SVD on (subsampled) matrix
|
| 218 |
+
U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
|
| 219 |
+
|
| 220 |
+
# W = right singular vectors with gating: skip top gate_start SVs
|
| 221 |
+
# to remove shared positional/syntactic structure
|
| 222 |
+
W = Vh[self.gate_start:effective_rank, :].T
|
| 223 |
+
|
| 224 |
+
# Store projection for inspection
|
| 225 |
+
total_var = (S ** 2).sum()
|
| 226 |
+
explained_var = (S[:effective_rank] ** 2).sum()
|
| 227 |
+
self._last_projection = SVDProjection(
|
| 228 |
+
W=W,
|
| 229 |
+
singular_values=S[:effective_rank],
|
| 230 |
+
explained_variance_ratio=float((explained_var / total_var).item()) if total_var > 0 else 0.0,
|
| 231 |
+
source_shape=tuple(keys.shape),
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
# Project subsampled rows and mean-pool → [rank]
|
| 235 |
+
# Using the subsample for projection too avoids the expensive
|
| 236 |
+
# 786K × 128 matmul + mean that dominates at large contexts.
|
| 237 |
+
projected = svd_input @ W
|
| 238 |
+
state_vec = projected.mean(dim=0)
|
| 239 |
+
|
| 240 |
+
return state_vec
|
| 241 |
+
|
| 242 |
+
def _xkv_project(self, keys: torch.Tensor) -> torch.Tensor:
|
| 243 |
+
"""Grouped cross-layer SVD (xKV approach).
|
| 244 |
+
|
| 245 |
+
Groups adjacent layers (default 4), computes shared SVD basis
|
| 246 |
+
per group, projects keys onto that basis, then concatenates
|
| 247 |
+
group state vectors.
|
| 248 |
+
|
| 249 |
+
This captures cross-layer structure that single-layer SVD misses.
|
| 250 |
+
Achieves 6.8x vs 2.5x for single-layer SVD on Llama-3.1-8B.
|
| 251 |
+
|
| 252 |
+
K:V rank ratio 1:1.5 is optimal per xKV paper, but since we
|
| 253 |
+
only index keys (D2: K→K retrieval), we use the K rank only.
|
| 254 |
+
|
| 255 |
+
Input: [n_layers, n_kv_heads, ctx_len, head_dim]
|
| 256 |
+
Output: [n_groups * rank_per_group]
|
| 257 |
+
"""
|
| 258 |
+
n_layers, n_kv_heads, ctx_len, head_dim = keys.shape
|
| 259 |
+
|
| 260 |
+
# Compute rank per group
|
| 261 |
+
# xKV finding: K rank is lower than V rank by factor 1:1.5
|
| 262 |
+
# For 160 total rank budget across groups, allocate per group
|
| 263 |
+
n_groups = max(1, n_layers // self.xkv_group_size)
|
| 264 |
+
rank_per_group = max(1, self.rank // n_groups)
|
| 265 |
+
rank_per_group = min(rank_per_group, head_dim)
|
| 266 |
+
|
| 267 |
+
group_vecs: list[torch.Tensor] = []
|
| 268 |
+
|
| 269 |
+
for g in range(n_groups):
|
| 270 |
+
start = g * self.xkv_group_size
|
| 271 |
+
end = min(start + self.xkv_group_size, n_layers)
|
| 272 |
+
group_keys = keys[start:end] # [group_size, n_kv_heads, ctx_len, head_dim]
|
| 273 |
+
|
| 274 |
+
# Flatten group
|
| 275 |
+
n_group_rows = group_keys.shape[0] * n_kv_heads * ctx_len
|
| 276 |
+
|
| 277 |
+
if n_group_rows > self.max_svd_rows:
|
| 278 |
+
gen = torch.Generator()
|
| 279 |
+
gen.manual_seed(42 + g)
|
| 280 |
+
indices = torch.randperm(n_group_rows, generator=gen)[:self.max_svd_rows]
|
| 281 |
+
svd_input = group_keys.reshape(n_group_rows, head_dim)[indices].float()
|
| 282 |
+
else:
|
| 283 |
+
svd_input = rearrange(group_keys.float(), 'l h t d -> (l h t) d')
|
| 284 |
+
|
| 285 |
+
effective_rank = min(rank_per_group, svd_input.shape[0], head_dim)
|
| 286 |
+
|
| 287 |
+
# Truncated SVD for this group (on subsampled data)
|
| 288 |
+
U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
|
| 289 |
+
W_group = Vh[:effective_rank, :].T # [head_dim, rank_per_group]
|
| 290 |
+
|
| 291 |
+
# Project subsampled rows and mean-pool → [rank_per_group]
|
| 292 |
+
projected = svd_input @ W_group
|
| 293 |
+
group_vec = projected.mean(dim=0)
|
| 294 |
+
group_vecs.append(group_vec)
|
| 295 |
+
|
| 296 |
+
# Handle remainder layers (if n_layers not divisible by group_size)
|
| 297 |
+
remainder_start = n_groups * self.xkv_group_size
|
| 298 |
+
if remainder_start < n_layers:
|
| 299 |
+
remainder_keys = keys[remainder_start:]
|
| 300 |
+
n_rem_rows = remainder_keys.shape[0] * n_kv_heads * ctx_len
|
| 301 |
+
|
| 302 |
+
if n_rem_rows > self.max_svd_rows:
|
| 303 |
+
gen = torch.Generator()
|
| 304 |
+
gen.manual_seed(42 + n_groups)
|
| 305 |
+
indices = torch.randperm(n_rem_rows, generator=gen)[:self.max_svd_rows]
|
| 306 |
+
svd_input = remainder_keys.reshape(n_rem_rows, head_dim)[indices].float()
|
| 307 |
+
else:
|
| 308 |
+
svd_input = rearrange(remainder_keys.float(), 'l h t d -> (l h t) d')
|
| 309 |
+
|
| 310 |
+
effective_rank = min(rank_per_group, svd_input.shape[0], head_dim)
|
| 311 |
+
U, S, Vh = torch.linalg.svd(svd_input, full_matrices=False)
|
| 312 |
+
W_rem = Vh[:effective_rank, :].T
|
| 313 |
+
projected = svd_input @ W_rem
|
| 314 |
+
group_vecs.append(projected.mean(dim=0))
|
| 315 |
+
|
| 316 |
+
# Concatenate all group vectors → [n_groups * rank_per_group + remainder]
|
| 317 |
+
state_vec = torch.cat(group_vecs, dim=0)
|
| 318 |
+
|
| 319 |
+
return state_vec
|
| 320 |
+
|
| 321 |
+
# ── Fixed Corpus Basis (FCB) ────────────────────────────────────────────
|
| 322 |
+
|
| 323 |
+
@classmethod
|
| 324 |
+
def compute_corpus_basis(
|
| 325 |
+
cls,
|
| 326 |
+
key_tensors: list[torch.Tensor],
|
| 327 |
+
layer_range: tuple[int, int],
|
| 328 |
+
gate_start: int,
|
| 329 |
+
rank: int,
|
| 330 |
+
max_rows: int = 32768,
|
| 331 |
+
seed: int = 42,
|
| 332 |
+
) -> torch.Tensor:
|
| 333 |
+
"""Compute a fixed projection matrix from a corpus of key tensors.
|
| 334 |
+
|
| 335 |
+
Returns P: [rank, head_dim] — the global semantic basis.
|
| 336 |
+
Unlike per-document SVD, this basis is document-independent.
|
| 337 |
+
All documents projected with P exist in the same coordinate system,
|
| 338 |
+
enabling stable cross-document and cross-model comparison.
|
| 339 |
+
"""
|
| 340 |
+
l_start, l_end = layer_range
|
| 341 |
+
gen = torch.Generator()
|
| 342 |
+
gen.manual_seed(seed)
|
| 343 |
+
|
| 344 |
+
all_rows: list[torch.Tensor] = []
|
| 345 |
+
per_doc_max = max(1, max_rows // len(key_tensors))
|
| 346 |
+
|
| 347 |
+
for keys in key_tensors:
|
| 348 |
+
k = keys[l_start:l_end].float()
|
| 349 |
+
n_rows = k.shape[0] * k.shape[1] * k.shape[2]
|
| 350 |
+
flat = k.reshape(n_rows, k.shape[3])
|
| 351 |
+
if flat.shape[0] > per_doc_max:
|
| 352 |
+
idx = torch.randperm(flat.shape[0], generator=gen)[:per_doc_max]
|
| 353 |
+
flat = flat[idx]
|
| 354 |
+
all_rows.append(flat)
|
| 355 |
+
|
| 356 |
+
corpus = torch.cat(all_rows, dim=0)
|
| 357 |
+
if corpus.shape[0] > max_rows:
|
| 358 |
+
idx = torch.randperm(corpus.shape[0], generator=gen)[:max_rows]
|
| 359 |
+
corpus = corpus[idx]
|
| 360 |
+
|
| 361 |
+
_, S, Vh = torch.linalg.svd(corpus, full_matrices=False)
|
| 362 |
+
P = Vh[gate_start : gate_start + rank] # [rank, head_dim]
|
| 363 |
+
return P
|
| 364 |
+
|
| 365 |
+
def extract_with_basis(
|
| 366 |
+
self,
|
| 367 |
+
keys: torch.Tensor,
|
| 368 |
+
spec: ModelCacheSpec,
|
| 369 |
+
basis: torch.Tensor,
|
| 370 |
+
) -> ExtractionResult:
|
| 371 |
+
"""Extract state vector using a pre-computed fixed corpus basis.
|
| 372 |
+
|
| 373 |
+
All vectors computed with the same basis share a coordinate system,
|
| 374 |
+
which is required for cross-model transfer via adapter.
|
| 375 |
+
|
| 376 |
+
Args:
|
| 377 |
+
keys: [n_layers, n_kv_heads, n_cells, head_dim]
|
| 378 |
+
spec: Model spec (used for layer_range fallback)
|
| 379 |
+
basis: [rank, head_dim] from compute_corpus_basis()
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
ExtractionResult with L2-normalized state vector
|
| 383 |
+
"""
|
| 384 |
+
if self.layer_range is not None:
|
| 385 |
+
l_start, l_end = self.layer_range
|
| 386 |
+
else:
|
| 387 |
+
l_start, l_end = 0, keys.shape[0]
|
| 388 |
+
l_start = max(0, min(l_start, keys.shape[0]))
|
| 389 |
+
l_end = max(l_start, min(l_end, keys.shape[0]))
|
| 390 |
+
|
| 391 |
+
k = keys[l_start:l_end].float()
|
| 392 |
+
n_rows = k.shape[0] * k.shape[1] * k.shape[2]
|
| 393 |
+
flat = k.reshape(n_rows, k.shape[3])
|
| 394 |
+
|
| 395 |
+
proj = flat @ basis.T # [N_rows, rank]
|
| 396 |
+
vec = proj.mean(dim=0) # [rank]
|
| 397 |
+
|
| 398 |
+
norm = float(torch.linalg.vector_norm(vec).item())
|
| 399 |
+
vec_normed = vec / (norm + 1e-8)
|
| 400 |
+
|
| 401 |
+
return ExtractionResult(
|
| 402 |
+
state_vec=vec_normed.to(torch.float32),
|
| 403 |
+
l2_norm=norm,
|
| 404 |
+
mode=self.mode,
|
| 405 |
+
n_layers_used=l_end - l_start,
|
| 406 |
+
n_tokens=k.shape[2],
|
| 407 |
+
)
|
| 408 |
+
|
| 409 |
+
# ── Fourier Fingerprint (Engram Absolute) ────────────────────────
|
| 410 |
+
|
| 411 |
+
@staticmethod
|
| 412 |
+
def compute_fourier_fingerprint(
|
| 413 |
+
keys: torch.Tensor,
|
| 414 |
+
freqs: tuple[int, ...] = (0, 1),
|
| 415 |
+
) -> torch.Tensor:
|
| 416 |
+
"""Compute the Fourier Absolute fingerprint from KV cache keys.
|
| 417 |
+
|
| 418 |
+
Takes the real DFT over the layer dimension, extracts the
|
| 419 |
+
amplitude at the specified frequencies, normalizes each, and
|
| 420 |
+
concatenates them into a single fingerprint vector.
|
| 421 |
+
|
| 422 |
+
This fingerprint is:
|
| 423 |
+
- Cross-model invariant (cos ~0.90 between 3B and 8B)
|
| 424 |
+
- Corpus-independent (no basis, no center, no training)
|
| 425 |
+
- Scale-stable (98% recall@1 at N=1000, decay N^-0.207)
|
| 426 |
+
|
| 427 |
+
Args:
|
| 428 |
+
keys: [n_layers, n_kv_heads, n_cells, head_dim] — full KV keys.
|
| 429 |
+
All layers are used (not sliced by layer_range).
|
| 430 |
+
freqs: Frequency indices to extract. Default (0, 1) = DC + 1st harmonic.
|
| 431 |
+
f=0 captures overall key magnitude profile.
|
| 432 |
+
f=1 captures dominant oscillation across depth.
|
| 433 |
+
|
| 434 |
+
Returns:
|
| 435 |
+
Fingerprint vector [dim * len(freqs)], L2-normalized.
|
| 436 |
+
"""
|
| 437 |
+
# Mean over cells (tokens) per layer: [n_layers, n_kv_heads * head_dim]
|
| 438 |
+
n_layers = keys.shape[0]
|
| 439 |
+
layer_means = keys.float().mean(dim=2).reshape(n_layers, -1)
|
| 440 |
+
|
| 441 |
+
# DFT over layer dimension
|
| 442 |
+
F_complex = torch.fft.rfft(layer_means, dim=0) # [n_freq, dim]
|
| 443 |
+
F_amp = F_complex.abs() # amplitude spectrum
|
| 444 |
+
|
| 445 |
+
# Extract and normalize each frequency component
|
| 446 |
+
parts = []
|
| 447 |
+
for f in freqs:
|
| 448 |
+
if f >= F_amp.shape[0]:
|
| 449 |
+
# Frequency out of range — use zeros
|
| 450 |
+
parts.append(torch.zeros(F_amp.shape[1]))
|
| 451 |
+
else:
|
| 452 |
+
v = F_amp[f]
|
| 453 |
+
parts.append(v / (v.norm() + 1e-8))
|
| 454 |
+
|
| 455 |
+
fingerprint = torch.cat(parts, dim=0)
|
| 456 |
+
return fingerprint / (fingerprint.norm() + 1e-8)
|
| 457 |
+
|
| 458 |
+
@property
|
| 459 |
+
def last_projection(self) -> SVDProjection | None:
|
| 460 |
+
"""Access the SVD projection from the last svd_project call.
|
| 461 |
+
|
| 462 |
+
Useful for diagnostics: check explained_variance_ratio to validate
|
| 463 |
+
that the rank is sufficient for this particular cache.
|
| 464 |
+
"""
|
| 465 |
+
return self._last_projection
|
| 466 |
+
|
| 467 |
+
def output_dim(self, spec: ModelCacheSpec) -> int:
|
| 468 |
+
"""Compute the output dimension of the state vector for a given spec.
|
| 469 |
+
|
| 470 |
+
This is needed to initialize the FAISS index with the correct dimension.
|
| 471 |
+
"""
|
| 472 |
+
match self.mode:
|
| 473 |
+
case StateExtractionMode.MEAN_POOL:
|
| 474 |
+
return spec["head_dim"]
|
| 475 |
+
case StateExtractionMode.SVD_PROJECT:
|
| 476 |
+
max_rank = min(self.gate_start + self.rank, spec["head_dim"])
|
| 477 |
+
return max_rank - self.gate_start
|
| 478 |
+
case StateExtractionMode.XKV_PROJECT:
|
| 479 |
+
extraction_layers = spec["extraction_layers"]
|
| 480 |
+
n_layers = len(extraction_layers)
|
| 481 |
+
n_groups = max(1, n_layers // self.xkv_group_size)
|
| 482 |
+
rank_per_group = max(1, self.rank // n_groups)
|
| 483 |
+
rank_per_group = min(rank_per_group, spec["head_dim"])
|
| 484 |
+
# Groups + possible remainder group
|
| 485 |
+
has_remainder = (n_layers % self.xkv_group_size) != 0
|
| 486 |
+
total_groups = n_groups + (1 if has_remainder else 0)
|
| 487 |
+
return total_groups * rank_per_group
|
| 488 |
+
case _:
|
| 489 |
+
raise ValueError(f"Unknown mode: {self.mode}")
|
kvcos/core/types.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Core Type Definitions
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
All enums, TypedDicts, constants, and type aliases live here.
|
| 6 |
+
Every downstream module imports from this file. No circular dependencies.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from dataclasses import dataclass
|
| 12 |
+
from enum import StrEnum
|
| 13 |
+
from typing import TypedDict
|
| 14 |
+
|
| 15 |
+
# ── Constants ─────────────────────────────────────────────────────────────────
|
| 16 |
+
|
| 17 |
+
ENGRAM_VERSION = "0.1.0"  # protocol version; re-exported as kvcos.__version__
ENG_FILE_EXTENSION = ".eng"  # ENGRAM file format extension
BLOCK_SIZE_TOKENS = 256  # 256-token blocks per arXiv:2603.04428
DEFAULT_SVD_RANK = 160  # ShadowKV default for 8B models
DEFAULT_LATENT_DIM = 512  # MLA: full KV info recoverable from 512-dim
MAX_CONTEXT_TOKENS = 131072  # 128K max supported context
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
# ── Enums ─────────────────────────────────────────────────────────────────────
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class CompressionMethod(StrEnum):
    """Supported KV cache compression methods.

    Phase 1: Q8_0, FP16
    Phase 2: POLARQUANT (TurboQuant without QJL — QJL removed per D5)

    Values are the strings written into the .eng metadata ``compression``
    field, so they must stay stable across releases.
    """

    FP16 = "fp16"  # uncompressed half precision baseline
    Q8_0 = "q8_0"  # llama.cpp GGML_TYPE_Q8_0: ~2x compression, <5% speed hit
    Q4_0 = "q4_0"  # NOT recommended at 64K+ (92% dequant slowdown)
    POLARQUANT = "polarquant_3bit"  # PolarQuant only, no QJL
    INT8 = "int8"  # Phase 2: true int8 + per-row scale, 2x on-disk compression
    LAYER_DELTA = "layer_delta"  # Phase 2: fp16 baseline + int8 inter-layer deltas
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class StorageBackend(StrEnum):
    """Supported storage backends.

    Phase 1 ships only the local filesystem backend; Redis and S3 are
    placeholders for Phase 2 (see inline markers).
    """

    LOCAL = "local"
    REDIS = "redis"  # Phase 2
    S3 = "s3"  # Phase 2
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class StateExtractionMode(StrEnum):
    """EGR (Engrammatic Geometry Retrieval) state vector extraction modes.

    mean_pool: Fast baseline. Mean over heads + context of key matrices.
    svd_project: Truncated SVD on pre-RoPE keys, layers 8-31, rank-160.
        Validated by ShadowKV (ICML 2025) on Llama-3.1-8B.
    xkv_project: Grouped cross-layer SVD, 4-layer groups, K:V rank 1:1.5.
        From xKV (arXiv:2503.18893). 6.8x compression.

    REMOVED: sals_project — last-layer-only extraction invalidated by
    Layer-Condensed KV Cache (ACL 2024). See D3.
    """

    MEAN_POOL = "mean_pool"
    SVD_PROJECT = "svd_project"
    XKV_PROJECT = "xkv_project"
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class IndexBackend(StrEnum):
    """EGR manifold index backends.

    All backends use inner-product (MIPS) similarity; they differ only in
    exactness and deployment phase.
    """

    FAISS_FLAT_IP = "faiss_flat_ip"  # Phase 1: exact MIPS
    FAISS_IVF_IP = "faiss_ivf_ip"  # Phase 2: approximate MIPS for >100K vectors
    QDRANT_DOT = "qdrant_dot"  # Phase 2: production persistent index
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class AttentionType(StrEnum):
    """KV cache attention mechanism per layer group.

    Standard models use FULL for all layers.
    ISWA models (Gemma 4) interleave FULL (global) and SLIDING (SWA) sections.
    """

    FULL = "full"  # Full-context attention (standard)
    SLIDING = "sliding"  # Sliding window attention (SWA)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
# ── Data Classes ─────────────────────────────────────────────────────────────
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
@dataclass(frozen=True)
class CacheSection:
    """One section of a multi-section KV cache.

    Standard models have a single implicit section covering all layers.
    ISWA models serialize multiple sections sequentially in the state blob,
    each with its own n_layers, n_kv_heads, and head_dim.

    Reverse-engineered from Gemma 4 26B-A4B (llama.cpp b5200+):
        Section 0 (Global): 5 layers, 2 KV heads, head_dim=512
        Section 1 (SWA): 25 layers, 8 KV heads, head_dim=256
    """

    attention_type: AttentionType
    n_layers: int
    n_kv_heads: int
    head_dim: int
    window_size: int | None = None  # SWA window size in tokens (None for full)

    @property
    def n_embd_kv(self) -> int:
        """Total KV embedding dimension for this section (heads x head_dim)."""
        return self.n_kv_heads * self.head_dim
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
# ── TypedDicts ────────────────────────────────────────────────────────────────
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class _ModelCacheSpecRequired(TypedDict):
    """Required fields for ModelCacheSpec (internal base).

    Split out so that ModelCacheSpec can mark only ``cache_sections``
    as optional via ``total=False`` inheritance.
    """

    model_id: str  # e.g. "meta-llama/Llama-3.1-8B-Instruct"
    model_family: str  # e.g. "llama"
    n_layers: int  # total transformer layers
    n_heads: int  # query heads (may differ from KV heads in GQA)
    n_kv_heads: int  # key/value heads (GQA-aware)
    head_dim: int  # dimension per head
    rope_enabled: bool  # whether model uses RoPE
    extraction_layers: tuple[int, ...]  # layers for EGR state extraction (D3)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class ModelCacheSpec(_ModelCacheSpecRequired, total=False):
    """Architecture-agnostic specification of a model's KV cache layout.

    Used to validate .eng files and ensure correct tensor shapes.

    For standard models (Llama, Phi, Qwen, Mistral):
        n_kv_heads and head_dim describe the single uniform KV cache.
        cache_sections is absent.

    For ISWA models (Gemma 4):
        cache_sections lists per-section dimensions. Each section has its
        own n_layers, n_kv_heads, and head_dim. The top-level n_kv_heads
        and head_dim reflect the dominant (largest) section.
        The state blob contains multiple sequential KV streams.
    """

    # Optional (total=False): only present for multi-section (ISWA) models.
    cache_sections: tuple[CacheSection, ...]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class EngramMetadata(TypedDict, total=False):
    """Metadata stored in .eng file header (safetensors __metadata__).

    All values are strings per safetensors spec (D7).
    Optional fields use total=False.

    NOTE(review): because the class itself is declared total=False, every
    key below is optional to the type checker — the "Required"/"Optional"
    split is convention only, enforced (if at all) elsewhere. Consider
    splitting into a required base like ModelCacheSpec does.
    """

    # Required (by convention — see class note)
    engram_version: str
    cache_id: str
    compression: str  # CompressionMethod value
    model_id: str
    model_family: str
    n_layers: str  # stringified int
    n_heads: str
    n_kv_heads: str
    head_dim: str
    context_len: str
    agent_id: str
    task_description: str
    created_at: str  # ISO 8601

    # Optional
    expires_at: str
    parent_cache_id: str
    delta_from: str
    token_hash: str  # SHA-256 of input tokens
    state_vec_norm: str  # L2 norm of state vector (D4: stored as metadata)
    extraction_mode: str  # StateExtractionMode value
    block_index: str  # block position within a multi-block cache
    total_blocks: str
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class CacheSearchResult(TypedDict):
    """Result from EGR manifold search over cached engram states.

    One entry per hit; all fields are concrete (total=True).
    """

    cache_id: str
    similarity: float  # raw inner product score (not normalized — D4)
    task_description: str
    model_id: str
    created_at: str
    context_len: int
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
class CacheStats(TypedDict):
    """Aggregate statistics for the engram store."""

    total_entries: int
    total_size_bytes: int
    avg_compression_ratio: float
    model_breakdown: dict[str, int]  # model_family → count
|
kvcos/engram/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# EIGENGRAM format package
|
| 2 |
+
from .format import EigramEncoder, EigramDecoder, EIGENGRAM_MAGIC, EIGENGRAM_VERSION
|
| 3 |
+
from .writer import write_eigengram
|
| 4 |
+
from .reader import read_eigengram, load_eigengram_index
|
kvcos/engram/__main__.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EIGENGRAM command-line interface.
|
| 3 |
+
|
| 4 |
+
Usage:
|
| 5 |
+
python -m kvcos.engram encode --model <gguf> --text "..." --out doc.eng
|
| 6 |
+
python -m kvcos.engram search --model <gguf> --query "..." index/*.eng
|
| 7 |
+
python -m kvcos.engram inspect doc.eng
|
| 8 |
+
python -m kvcos.engram list index/*.eng
|
| 9 |
+
|
| 10 |
+
Commands:
|
| 11 |
+
encode Run a document through a GGUF model and write a .eng file.
|
| 12 |
+
search Query .eng files using a text query and a model.
|
| 13 |
+
inspect Print all fields from .eng files (no model needed).
|
| 14 |
+
list Print a summary table of .eng files (no model needed).
|
| 15 |
+
"""
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import argparse
|
| 20 |
+
import gc
|
| 21 |
+
import glob
|
| 22 |
+
import os
|
| 23 |
+
import sys
|
| 24 |
+
|
| 25 |
+
import torch
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def _resolve_paths(patterns: list[str]) -> list[str]:
    """Expand glob patterns, return sorted list of .eng paths.

    Each pattern is tried as a glob first, then as a literal path.
    Patterns matching nothing print a warning on stderr and are skipped.
    Duplicates are collapsed and the result is sorted.
    """
    resolved: set[str] = set()
    for pattern in patterns:
        matches = glob.glob(pattern)
        if not matches and os.path.exists(pattern):
            # Literal path that glob did not expand (no wildcard chars).
            matches = [pattern]
        if matches:
            resolved.update(matches)
        else:
            print(f"Warning: no files matched '{pattern}'", file=sys.stderr)
    return sorted(resolved)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def cmd_encode(args: argparse.Namespace) -> None:
    """Encode a document as a .eng EIGENGRAM file.

    Takes the input from --text or --file (one is required), derives
    defaults for the output path, description, and cache_id, delegates
    to write_eigengram, and prints a short report.

    Exits with status 1 when no input is given or --file does not exist.
    """
    from kvcos.engram.writer import write_eigengram

    if args.text:
        text = args.text
    elif args.file:
        if not os.path.exists(args.file):
            print(f"Error: input file not found: {args.file}", file=sys.stderr)
            sys.exit(1)
        # FIX: the previous bare open() leaked the file handle and decoded
        # with the locale codec; close deterministically and pin UTF-8.
        with open(args.file, encoding="utf-8") as fh:
            text = fh.read().strip()
    else:
        print("Error: provide --text or --file", file=sys.stderr)
        sys.exit(1)

    # Default output path: input filename with .eng suffix, or ./output.eng.
    output_path = args.out or (
        os.path.splitext(args.file)[0] + ".eng" if args.file else "output.eng"
    )
    # Fall back to text prefixes when no explicit description/id is given.
    task_desc = args.description or text[:80]
    cache_id_val = args.id or text[:64]

    print("Encoding document...")
    print(f" Model: {args.model}")
    print(f" Text: {text[:60]}{'...' if len(text) > 60 else ''}")
    print(f" Output: {output_path}")
    print()

    result = write_eigengram(
        model_path=args.model,
        text=text,
        output_path=output_path,
        cache_id=cache_id_val,
        task_description=task_desc,
        basis_path=args.basis,
    )

    print("Done.")
    print(f" File size : {result['file_size_bytes']} bytes")
    print(f" Model ID : {result['model_id']}")
    print(f" SCS : {result['scs']:.4f}")
    print(f" Basis rank: {result['basis_rank']}")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def cmd_search(args: argparse.Namespace) -> None:
    """Search .eng files using a text query.

    Pipeline: encode the query through the GGUF model, extract a query
    fingerprint in the requested mode (fourier / perdoc / fcdb), load the
    stored fingerprints from the .eng files, and rank them by inner
    product in a ManifoldIndex.
    """
    # Heavy deps imported lazily so inspect/list stay cheap.
    from llama_cpp import Llama

    from kvcos.core.blob_parser import parse_state_blob
    from kvcos.engram.reader import load_eigengram_index
    from kvcos.core.manifold_index import ManifoldIndex

    paths = _resolve_paths(args.eng_files)
    if not paths:
        print("No .eng files found.", file=sys.stderr)
        sys.exit(1)

    fingerprint = args.fingerprint
    # weights_only=False: the basis file stores tensors in a plain dict.
    # NOTE(review): this unpickles arbitrary objects — only load trusted bases.
    saved = torch.load(args.basis, weights_only=False)
    P = saved["basis"]
    center = saved["joint_center"]
    # Hard-coded extraction hyperparameters — presumably must match the
    # settings the corpus basis was built with; verify against the encoder.
    LR = (8, 24)
    GATE = 6
    RANK = P.shape[0]

    print(f"Query: {args.query}")
    print(f"Index: {len(paths)} files")
    print(f"Fingerprint: {fingerprint}")
    print()

    # Run the query through the model once to populate its KV cache.
    llm = Llama(model_path=args.model, n_ctx=2048, n_gpu_layers=-1, verbose=False)
    meta = llm.metadata
    # GGUF metadata values are strings; defaults match Llama-3.1-8B geometry.
    n_kv = int(meta.get("llama.attention.head_count_kv", "8"))
    hd = int(meta.get("llama.embedding_length", "4096")) // int(
        meta.get("llama.attention.head_count", "32")
    )
    llm.reset()
    llm(args.query.strip(), max_tokens=1, temperature=0.0)
    # Parse the raw llama.cpp state blob into per-layer K/V tensors.
    p_q = parse_state_blob(
        bytes(llm.save_state().llama_state), n_kv_heads=n_kv, head_dim=hd
    )
    # Free the model before the linear-algebra phase.
    del llm
    gc.collect()

    if fingerprint == "fourier":
        from kvcos.core.fingerprint import compute_fourier_fingerprint

        # Use ALL layers for Fourier (not sliced)
        layer_means = p_q.keys.float().mean(dim=2).reshape(p_q.keys.shape[0], -1)
        query_vec = compute_fourier_fingerprint(layer_means, freqs=[0, 1])
        dim = query_vec.shape[0]
    elif fingerprint == "perdoc":
        # Per-document SVD: project the query keys onto their own gated basis.
        k_q = p_q.keys[LR[0] : LR[1]].float().reshape(-1, hd)
        _, _, Vh = torch.linalg.svd(k_q, full_matrices=False)
        proj_q = (k_q @ Vh[GATE : GATE + RANK].T).mean(0)
        query_vec = proj_q / (proj_q.norm() + 1e-8)
        dim = RANK
    else:  # fcdb
        # Shared corpus basis: normalized offset from the joint center,
        # projected through the precomputed basis P.
        k_q = p_q.keys[LR[0] : LR[1]].float().reshape(-1, hd)
        mean_q = k_q.mean(0)
        delta_q = mean_q - center
        delta_q = delta_q / (delta_q.norm() + 1e-8)
        query_vec = delta_q @ P.T
        query_vec = query_vec / (query_vec.norm() + 1e-8)
        dim = RANK

    # Build an in-memory index from the stored fingerprints and rank.
    vecs, entries = load_eigengram_index(paths, fingerprint=fingerprint)
    idx = ManifoldIndex(dim=dim)
    for v, e in zip(vecs, entries):
        idx.add(v, e)

    top_k = min(args.top_k, len(paths))
    results = idx.search(query_vec, top_k=top_k)

    print(f"Results (top {top_k}):")
    print(f" {'#':<3} {'sim':>7} {'cache_id':<20} description")
    print(f" {'---'} {'-------'} {'--------------------'} {'----------------------------------------'}")
    for i, r in enumerate(results):
        desc = r.get("task_description", "")[:40]
        cid = r.get("cache_id", "")[:20]
        print(f" {i + 1:<3} {r['similarity']:>+.4f} {cid:<20} {desc}")
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def cmd_inspect(args: argparse.Namespace) -> None:
    """Print all fields of .eng files in readable format.

    Requires no model: reads each file with read_eigengram and dumps its
    header fields and fingerprint norms. Unreadable files are reported
    per-file and skipped; exits 1 only when no paths resolve at all.
    """
    from kvcos.engram.reader import read_eigengram

    paths = _resolve_paths(args.eng_files)
    if not paths:
        print("No .eng files found.", file=sys.stderr)
        sys.exit(1)

    for path in paths:
        try:
            rec = read_eigengram(path)
        except Exception as e:
            # Best-effort: report the broken file and keep going.
            print(f" {path}: ERROR - {e}")
            continue

        size = os.path.getsize(path)
        print(f"{'=' * 55}")
        print(f" File: {path} ({size} bytes)")
        print(f" Format: EGR1 v{rec['version']}")
        print(f" Created: {rec['created_at']} UTC")
        print(f" Model: {rec['model_id']}")
        print(f" cache_id: {rec['cache_id']}")
        print(f" Description: {rec['task_description']}")
        print()
        print(f" Basis rank: {rec['basis_rank']}")
        print(f" N corpus: {rec['n_corpus']}")
        print(f" Layer range: {rec['layer_range']}")
        print(f" Context len: {rec['context_len']} KV rows")
        print(f" L2 norm: {rec['l2_norm']:.4f}")
        print(f" SCS: {rec['scs']:.4f}")
        print(f" Margin proof: {rec['margin_proof']:.4f}")
        print(f" Corpus hash: {rec['corpus_hash']}")
        print(f" vec_perdoc: [{rec['vec_perdoc'].shape[0]}] norm={rec['vec_perdoc'].norm():.4f}")
        print(f" vec_fcdb: [{rec['vec_fcdb'].shape[0]}] norm={rec['vec_fcdb'].norm():.4f}")
        print()
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def cmd_list(args: argparse.Namespace) -> None:
    """Print a one-line summary table of .eng files.

    Unreadable files get an ERROR row instead of aborting the listing.
    """
    from kvcos.engram.reader import read_eigengram

    eng_paths = _resolve_paths(args.eng_files)
    if not eng_paths:
        print("No .eng files found.", file=sys.stderr)
        sys.exit(1)

    header = f"{'filename':<30} {'model':<14} {'scs':>6} {'bytes':>5} description"
    print(header)
    print("-" * len(header))

    for eng_path in eng_paths:
        short_name = os.path.basename(eng_path)[:29]
        try:
            record = read_eigengram(eng_path)
            n_bytes = os.path.getsize(eng_path)
            row = (
                f"{short_name:<30} {record['model_id'][:14]:<14} "
                f"{record['scs']:>6.3f} {n_bytes:>5} "
                f"{record['task_description'][:40]}"
            )
        except Exception as exc:
            row = f"{short_name:<30} ERROR: {exc}"
        print(row)
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def main() -> None:
    """EIGENGRAM CLI entry point.

    Builds the argument parser with four sub-commands and dispatches to
    the matching cmd_* handler.
    """
    parser = argparse.ArgumentParser(
        prog="python -m kvcos.engram",
        description="EIGENGRAM CLI - encode and search KV-cache semantic certificates.",
    )
    subcommands = parser.add_subparsers(dest="command", required=True)

    encode_p = subcommands.add_parser("encode", help="Encode a document as a .eng file.")
    encode_p.add_argument("--model", required=True, help="Path to GGUF model file.")
    encode_p.add_argument("--text", help="Document text.")
    encode_p.add_argument("--file", help="Path to a text file to encode.")
    encode_p.add_argument("--out", help="Output .eng file path.")
    encode_p.add_argument("--id", help="Unique cache_id.")
    encode_p.add_argument("--description", help="Human-readable description.")
    encode_p.add_argument(
        "--basis",
        default="results/corpus_basis_fcdb_v2.pt",
        help="FCDB v2 basis path.",
    )

    search_p = subcommands.add_parser("search", help="Search .eng files with a query.")
    search_p.add_argument("--model", required=True, help="GGUF model for query encoding.")
    search_p.add_argument("--query", required=True, help="Query text.")
    search_p.add_argument(
        "--fingerprint", default="fourier", choices=["perdoc", "fcdb", "fourier"]
    )
    search_p.add_argument("--top-k", type=int, default=5, dest="top_k")
    search_p.add_argument("--basis", default="results/corpus_basis_fcdb_v2.pt")
    search_p.add_argument("eng_files", nargs="+", help=".eng file paths or globs.")

    inspect_p = subcommands.add_parser("inspect", help="Print all fields of .eng files.")
    inspect_p.add_argument("eng_files", nargs="+")

    list_p = subcommands.add_parser("list", help="Summary table of .eng files.")
    list_p.add_argument("eng_files", nargs="+")

    args = parser.parse_args()
    handlers = {
        "encode": cmd_encode,
        "search": cmd_search,
        "inspect": cmd_inspect,
        "list": cmd_list,
    }
    handlers[args.command](args)


if __name__ == "__main__":
    main()
|
kvcos/engram/chunker.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/chunker.py — Markdown-aware semantic chunker.
|
| 3 |
+
|
| 4 |
+
Splits markdown files into chunks suitable for .eng indexing.
|
| 5 |
+
Each chunk gets its own fingerprint and becomes independently
|
| 6 |
+
retrievable via HNSW.
|
| 7 |
+
|
| 8 |
+
Strategy:
|
| 9 |
+
1. Split on H1/H2 headers first (natural semantic boundaries)
|
| 10 |
+
2. If a section exceeds max_chars, split on H3/H4
|
| 11 |
+
3. If still too large, split on paragraph boundaries
|
| 12 |
+
4. Never break mid-paragraph (preserve semantic coherence)
|
| 13 |
+
|
| 14 |
+
Each chunk carries context: the file's title + parent headers
|
| 15 |
+
are prepended so the fingerprint captures the full meaning.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
import re
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
from typing import Sequence
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass(frozen=True)
class Chunk:
    """One semantic chunk from a markdown file."""
    text: str  # Chunk content (with context header prepended)
    raw_text: str  # Original content without context header
    char_start: int  # Start offset in original file
    char_end: int  # End offset in original file
    index: int  # 0-based chunk index
    headers: tuple[str, ...]  # Header hierarchy (e.g., ("# Title", "## Section"))

    @property
    def char_count(self) -> int:
        """Length of ``text`` — includes any prepended context prefix."""
        return len(self.text)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
# Regex for markdown headers (ATX style: # through ######).
# MULTILINE so ^/$ anchor at every line; group(1) = hashes, group(2) = title.
_HEADER_RE = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _header_level(line: str) -> int:
    """Return header level (1-6) or 0 if not a header.

    A header is 1-6 '#' marks followed by at least one whitespace char.
    """
    if (match := re.match(r"^(#{1,6})\s+", line)) is None:
        return 0
    return len(match.group(1))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _split_by_headers(
    content: str,
    max_level: int = 2,
) -> list[tuple[int, int, list[str]]]:
    """
    Split content into sections by header level.

    Returns list of (start_offset, end_offset, header_stack) tuples.
    max_level: split on headers of this level or lower (1=H1, 2=H2, etc.)
    Headers deeper than max_level are ignored entirely (neither boundaries
    nor stack entries). Empty sections (whitespace only) are dropped.
    """
    # ATX headers: # through ######, one per line.
    header_re = re.compile(r"^(#{1,6})\s+(.+)$", re.MULTILINE)

    def level_of(line: str) -> int:
        # Same rule as the module-level helper: count the leading '#' run.
        m = re.match(r"^(#{1,6})\s+", line)
        return len(m.group(1)) if m else 0

    out: list[tuple[int, int, list[str]]] = []
    stack: list[str] = []
    section_start = 0

    for match in header_re.finditer(content):
        depth = len(match.group(1))
        title = match.group(0).strip()

        if depth > max_level:
            continue

        if match.start() > section_start:
            # Close the section that ran up to this boundary header.
            if content[section_start:match.start()].strip():
                out.append((section_start, match.start(), list(stack)))
            section_start = match.start()

        # Pop headers at this depth or deeper, then push the new one.
        stack = [h for h in stack if level_of(h) < depth]
        stack.append(title)

    # Trailing section after the last boundary header.
    if section_start < len(content) and content[section_start:].strip():
        out.append((section_start, len(content), list(stack)))

    return out
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def _split_paragraphs(
    text: str,
    max_chars: int,
    base_offset: int = 0,
) -> list[tuple[int, int]]:
    """
    Split text into chunks at paragraph boundaries.

    Returns list of (start_offset, end_offset) tuples relative
    to the original file (offset by base_offset).

    NOTE: offsets are nominal — every paragraph is billed as its length
    plus a two-character separator, so end offsets can overshoot the true
    position when the real separator run is longer than two newlines.
    """
    spans: list[tuple[int, int]] = []
    start = 0
    length = 0

    for para in re.split(r"\n\n+", text):
        # Bill the paragraph plus a nominal two-char separator.
        needed = len(para) + 2

        # Close the open chunk when this paragraph would overflow it.
        if length and length + needed > max_chars:
            spans.append((base_offset + start, base_offset + start + length))
            start += length
            length = 0

        length += needed

    # Flush whatever is left open.
    if length:
        spans.append((base_offset + start, base_offset + start + length))

    return spans
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def chunk_markdown(
    content: str,
    max_chars: int = 2000,
    min_chars: int = 100,
    context_prefix: str = "",
) -> list[Chunk]:
    """
    Split a markdown file into semantic chunks.

    Args:
        content: Full markdown file content.
        max_chars: Target maximum chars per chunk (soft limit).
        min_chars: Minimum chars — smaller sections merge with next.
            NOTE(review): this parameter is never referenced in the
            body; small sections are merged only by the Phase 4
            max_chars packing. Confirm before relying on it.
        context_prefix: Prepended to each chunk for context
            (e.g., "Source: geodesic3.md | Project: engram").

    Returns:
        List of Chunk objects, ordered by position in file.
    """
    if not content.strip():
        return []

    # If the whole file fits in one chunk, return it directly
    if len(content) <= max_chars:
        full_text = f"{context_prefix}\n\n{content}" if context_prefix else content
        return [Chunk(
            text=full_text,
            raw_text=content,
            char_start=0,
            char_end=len(content),
            index=0,
            headers=(),
        )]

    # Phase 1: Split on H1/H2 boundaries
    sections = _split_by_headers(content, max_level=2)

    # If no headers found, treat as single block
    if not sections:
        sections = [(0, len(content), [])]

    # Phase 2: Sub-split large sections on H3/H4
    refined: list[tuple[int, int, list[str]]] = []
    for start, end, headers in sections:
        section_text = content[start:end]
        if len(section_text) > max_chars:
            subsections = _split_by_headers(section_text, max_level=4)
            if len(subsections) > 1:
                for sub_start, sub_end, sub_headers in subsections:
                    # Sub-offsets are relative to section_text; rebase them
                    # onto the full file here.
                    # NOTE(review): sub_headers may repeat the section's own
                    # H1/H2 header since the re-split sees it again —
                    # verify against _split_by_headers.
                    refined.append((
                        start + sub_start,
                        start + sub_end,
                        headers + sub_headers,
                    ))
            else:
                refined.append((start, end, headers))
        else:
            refined.append((start, end, headers))

    # Phase 3: Paragraph-split anything still over max_chars
    final_ranges: list[tuple[int, int, list[str]]] = []
    for start, end, headers in refined:
        section_text = content[start:end]
        if len(section_text) > max_chars:
            # start doubles as the base_offset so paragraph ranges are
            # absolute file offsets
            para_ranges = _split_paragraphs(section_text, max_chars, start)
            for p_start, p_end in para_ranges:
                final_ranges.append((p_start, p_end, headers))
        else:
            final_ranges.append((start, end, headers))

    # Phase 4: Greedily pack sections into chunks up to max_chars.
    # Keep merging consecutive sections while their combined size
    # stays under max_chars. This prevents over-fragmentation of
    # files with many small header sections.
    merged: list[tuple[int, int, list[str]]] = []
    for start, end, headers in final_ranges:
        chunk_text = content[start:end].strip()
        if not chunk_text:
            continue

        if merged:
            prev_start, prev_end, prev_headers = merged[-1]
            prev_len = prev_end - prev_start
            curr_len = end - start

            # Merge if combined chunk stays under max_chars.
            # The merged range keeps the FIRST section's header trail.
            if (prev_len + curr_len) <= max_chars:
                merged[-1] = (prev_start, end, prev_headers)
                continue

        merged.append((start, end, headers))

    # Phase 5: Build Chunk objects with context
    chunks: list[Chunk] = []
    for idx, (start, end, headers) in enumerate(merged):
        raw = content[start:end].strip()
        if not raw:
            continue

        # Build context header: optional file/project prefix plus the
        # "H1 > H2 > ..." breadcrumb for this chunk
        parts = []
        if context_prefix:
            parts.append(context_prefix)
        if headers:
            parts.append(" > ".join(headers))

        prefix = "\n".join(parts)
        text = f"{prefix}\n\n{raw}" if prefix else raw

        chunks.append(Chunk(
            text=text,
            raw_text=raw,
            char_start=start,
            char_end=end,
            index=idx,
            headers=tuple(headers),
        ))

    # Re-index after merging — the enumerate() above can leave gaps
    # where empty sections were skipped via `continue`.
    return [
        Chunk(
            text=c.text,
            raw_text=c.raw_text,
            char_start=c.char_start,
            char_end=c.char_end,
            index=i,
            headers=c.headers,
        )
        for i, c in enumerate(chunks)
    ]
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
def slug_from_path(path: str) -> str:
    """
    Derive a kebab-case slug from a file path.

    The directory part and the file extension are discarded; underscore
    and whitespace runs become single hyphens and the result is
    lowercased, keeping only [a-z0-9-].

    Examples:
        "geodesic3.md" → "geodesic3"
        "EIGENGRAM_SPEC.md" → "eigengram-spec"
        "coding-style.md" → "coding-style"
    """
    basename = path.rsplit("/", 1)[-1]
    stem = basename.rsplit(".", 1)[0]
    # Map underscore/whitespace runs to hyphens, then lowercase
    candidate = re.sub(r"[_\s]+", "-", stem).lower()
    # Drop every character that is not a-z, 0-9, or hyphen
    candidate = re.sub(r"[^a-z0-9-]", "", candidate)
    # Squash hyphen runs and trim leading/trailing hyphens
    return re.sub(r"-+", "-", candidate).strip("-")
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def eng_filename(
    project: str,
    slug: str,
    date: str,
    chunk_index: int | None = None,
    chunk_total: int | None = None,
    time_str: str = "",
) -> str:
    """
    Build a .eng filename following the naming convention.

    Format: <slug>[_<chunk>]_<date>[_<time>].eng

    Args:
        project: Project namespace (used for directory, not filename)
        slug: Kebab-case file identifier
        date: ISO date string (YYYY-MM-DD)
        chunk_index: 0-based chunk index (None if single chunk)
        chunk_total: Total chunks (None if single chunk)
        time_str: Optional HHmm time string

    Returns:
        Filename (not full path) like "geodesic3_001_2026-04-02.eng"
    """
    multi_chunk = (
        chunk_index is not None
        and chunk_total is not None
        and chunk_total > 1
    )
    # 1-based, zero-padded chunk number, only for multi-chunk documents
    chunk_part = [f"{chunk_index + 1:03d}"] if multi_chunk else []
    time_part = [time_str] if time_str else []
    return "_".join([slug, *chunk_part, date, *time_part]) + ".eng"
|
kvcos/engram/embedder.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/embedder.py — Unified text-to-fingerprint embedding.
|
| 3 |
+
|
| 4 |
+
Three strategies, tried in priority order:
|
| 5 |
+
1. llama_cpp: Native ENGRAM KV-cache Fourier pipeline (2048-dim)
|
| 6 |
+
2. sbert: Sentence-transformers all-MiniLM-L6-v2 (384-dim)
|
| 7 |
+
3. hash: Deterministic SHA256-seeded pseudo-fingerprint (2048-dim)
|
| 8 |
+
|
| 9 |
+
The chosen strategy is cached after first call. The fingerprint
|
| 10 |
+
source tag travels with every .eng file so retrieval knows what
|
| 11 |
+
comparison is valid.
|
| 12 |
+
|
| 13 |
+
Usage:
|
| 14 |
+
from kvcos.engram.embedder import get_fingerprint
|
| 15 |
+
fp, source = get_fingerprint("some text")
|
| 16 |
+
# fp: torch.Tensor, source: "llama_cpp"|"sbert"|"hash-fallback"
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
|
| 21 |
+
import hashlib
|
| 22 |
+
import os
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from typing import Protocol
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class Embedder(Protocol):
    """Protocol for text → fingerprint embedding.

    Structural interface: any object exposing these three members
    satisfies it without explicit inheritance (the concrete strategies
    in this module all do).
    """
    def embed(self, text: str) -> torch.Tensor:
        """Map text to a fingerprint vector of length ``dim``."""
        ...
    @property
    def source(self) -> str:
        """Short tag identifying the embedding strategy."""
        ...
    @property
    def dim(self) -> int:
        """Dimensionality of the vectors returned by ``embed``."""
        ...
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# ── Strategy 1: Native ENGRAM (llama_cpp) ────────────────────────────
|
| 40 |
+
|
| 41 |
+
class LlamaCppEmbedder:
    """KV-cache Fourier fingerprint via local GGUF model.

    Uses the full ENGRAM pipeline:
        text → LlamaCppBridge (generate → KV cache) → Fourier DFT → fingerprint

    Supports both standard and ISWA models:
        Standard (Llama): 2048-dim (8 × 128 × 2)
        ISWA (Gemma 4): 6144-dim (1024×2 + 2048×2)
    """

    def __init__(self, model_path: str) -> None:
        # Lazy imports: keep llama-cpp and the core fingerprint code off
        # the module import path so the other strategies work without them.
        from integrations.llama_cpp_bridge import LlamaCppBridge
        from kvcos.core.cache_spec import is_iswa_spec
        from kvcos.core.fingerprint import compute_fourier_fingerprint_v2, compute_iswa_fingerprint

        # CPU-only (n_gpu_layers=0), small context — fingerprinting only
        # needs the KV cache, not fast generation.
        self._bridge = LlamaCppBridge(
            model_path,
            n_ctx=2048,
            n_gpu_layers=0,
            verbose=False,
        )
        self._spec = self._bridge.load_model()
        self._is_iswa = is_iswa_spec(self._spec)
        self._compute_standard = compute_fourier_fingerprint_v2
        self._compute_iswa = compute_iswa_fingerprint

        # Fingerprint dimensionality: summed per cache section for ISWA
        # models, otherwise n_kv_heads × head_dim × 2 (the ×2 presumably
        # covers two Fourier components — confirm in kvcos.core.fingerprint).
        if self._is_iswa:
            sections = self._spec["cache_sections"]
            self._dim = sum(s.n_kv_heads * s.head_dim * 2 for s in sections)
        else:
            self._dim = self._spec["n_kv_heads"] * self._spec["head_dim"] * 2

    def embed(self, text: str) -> torch.Tensor:
        """Generate text through model, extract KV keys, compute Fourier fp."""
        # Fresh KV cache per text so fingerprints don't bleed across calls
        self._bridge.llm.reset()
        self._bridge.generate(text, max_tokens=1)

        if self._is_iswa:
            parsed = self._bridge.extract_kv_cache_iswa()
            return self._compute_iswa(parsed, freqs=[0, 1])

        parsed = self._bridge.extract_kv_cache()
        # Mean over dim=2 of the extracted keys — assumes that axis is the
        # sequence dimension; TODO confirm against LlamaCppBridge's layout.
        layer_keys = parsed.keys.float().mean(dim=2)
        return self._compute_standard(layer_keys, freqs=[0, 1])

    @property
    def source(self) -> str:
        return "llama_cpp"

    @property
    def dim(self) -> int:
        return self._dim
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
# ── Strategy 2: Sentence-transformers ────────────────────────────────
|
| 97 |
+
|
| 98 |
+
class SBertEmbedder:
    """Semantic fingerprint via sentence-transformers.

    Uses all-MiniLM-L6-v2 (80MB, 384-dim). Downloads on first use.
    Subsequent calls use the cached model (~50ms per text on CPU).
    """

    MODEL_NAME = "all-MiniLM-L6-v2"

    def __init__(self) -> None:
        import logging
        import warnings

        # Quiet the chatty HF/tokenizer/sbert/safetensors loggers before
        # sentence_transformers is imported.
        noisy_loggers = (
            "sentence_transformers",
            "transformers",
            "transformers.modeling_utils",
            "huggingface_hub",
            "huggingface_hub.utils",
            "safetensors",
        )
        for logger_name in noisy_loggers:
            logging.getLogger(logger_name).setLevel(logging.CRITICAL)

        # Silence the HF_TOKEN and load-report warnings via env knobs
        # (setdefault: never clobber a user's explicit setting).
        for var, value in (
            ("HF_HUB_DISABLE_PROGRESS_BARS", "1"),
            ("TRANSFORMERS_VERBOSITY", "error"),
            ("HF_HUB_VERBOSITY", "error"),
        ):
            os.environ.setdefault(var, value)
        warnings.filterwarnings("ignore", category=FutureWarning)

        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(self.MODEL_NAME)
        self._dim = self._model.get_sentence_embedding_dimension()

    def embed(self, text: str) -> torch.Tensor:
        """Encode text into a unit-normalized float32 tensor."""
        # encode() hands back a numpy array; convert dtype then wrap
        embedding = self._model.encode(text, normalize_embeddings=True)
        return torch.from_numpy(embedding.astype(np.float32))

    @property
    def source(self) -> str:
        return "sbert"

    @property
    def dim(self) -> int:
        return self._dim
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
# ── Strategy 3: Hash fallback ────────────────────────────────────────
|
| 144 |
+
|
| 145 |
+
class HashEmbedder:
    """Deterministic pseudo-fingerprint derived from a SHA256 digest.

    Carries no semantic signal: identical texts always map to the same
    vector, while unrelated texts land at random cosine similarity (~0).
    """

    def __init__(self, dim: int = 2048) -> None:
        self._dim = dim

    def embed(self, text: str) -> torch.Tensor:
        # Seed a PRNG from the first 32 bits of the text's SHA256 digest
        digest_prefix = hashlib.sha256(text.encode()).hexdigest()[:8]
        generator = np.random.RandomState(int(digest_prefix, 16))
        vec = generator.randn(self._dim).astype("float32")
        # Normalize to unit length; epsilon guards against a zero norm
        vec /= np.linalg.norm(vec) + 1e-8
        return torch.from_numpy(vec)

    @property
    def source(self) -> str:
        return "hash-fallback"

    @property
    def dim(self) -> int:
        return self._dim
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
# ── Singleton factory ────────────────────────────────────────────────
|
| 172 |
+
|
| 173 |
+
_cached_embedder: Embedder | None = None
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _create_embedder() -> Embedder:
    """Instantiate the best available strategy, in priority order.

    Order: llama_cpp (when ENGRAM_MODEL_PATH points at an existing
    model file), then sentence-transformers, then the hash fallback,
    which can never fail. A construction error causes fall-through to
    the next strategy rather than propagating.
    """
    # Strategy 1: native ENGRAM pipeline via a local GGUF model
    gguf_path = os.environ.get("ENGRAM_MODEL_PATH", "")
    if gguf_path and Path(gguf_path).exists():
        try:
            return LlamaCppEmbedder(gguf_path)
        except Exception:
            pass  # fall through: model failed to load

    # Strategy 2: sentence-transformers
    try:
        return SBertEmbedder()
    except Exception:
        pass  # fall through: package missing or model unavailable

    # Strategy 3: hash fallback (always works)
    return HashEmbedder()
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def get_embedder() -> Embedder:
    """Return the lazily-created, module-wide embedder singleton."""
    global _cached_embedder
    if _cached_embedder is not None:
        return _cached_embedder
    _cached_embedder = _create_embedder()
    return _cached_embedder
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def get_fingerprint(text: str) -> tuple[torch.Tensor, str]:
    """
    Compute a fingerprint for ``text`` using the best available strategy.

    Returns:
        (fingerprint_tensor, source_tag)
    """
    active = get_embedder()
    return active.embed(text), active.source
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def reset_embedder() -> None:
    """Drop the cached embedder so the next call re-selects a strategy.

    Intended for tests and for switching strategies at runtime.
    """
    global _cached_embedder
    _cached_embedder = None
|
kvcos/engram/format.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EIGENGRAM binary format codec (EGR1 v1.0).
|
| 3 |
+
|
| 4 |
+
An EIGENGRAM (.eng) file is a self-contained semantic certificate
|
| 5 |
+
for a KV-cache document. It encodes two fingerprint vectors, the
|
| 6 |
+
shared coordinate system they live in, and enough metadata to
|
| 7 |
+
reproduce the query fold-in without access to the original text
|
| 8 |
+
or model.
|
| 9 |
+
|
| 10 |
+
Design goals:
|
| 11 |
+
- Portable: pure binary, no JSON, no pickle, no protobuf.
|
| 12 |
+
- Versioned: magic bytes + version field.
|
| 13 |
+
- Self-contained: joint_center embedded for query fold-in.
|
| 14 |
+
- Compact: float16 vectors, ~800 bytes total per document.
|
| 15 |
+
|
| 16 |
+
Dual-fingerprint architecture:
|
| 17 |
+
vec_perdoc - per-document SVD projection (same-model, margin ~0.37)
|
| 18 |
+
vec_fcdb - FCDB projection (cross-model, margin ~0.013)
|
| 19 |
+
|
| 20 |
+
Binary layout (little-endian, 99-byte fixed header):
|
| 21 |
+
Offset Size Type Field
|
| 22 |
+
0 4 bytes magic = "EGR1"
|
| 23 |
+
4 1 uint8 version (currently 1)
|
| 24 |
+
5 32 ascii corpus_hash
|
| 25 |
+
37 20 ascii created_at
|
| 26 |
+
57 16 ascii model_id (null-padded)
|
| 27 |
+
73 2 uint16 basis_rank R
|
| 28 |
+
75 2 uint16 n_corpus
|
| 29 |
+
77 2 int8x2 layer_range
|
| 30 |
+
79 4 uint32 context_len
|
| 31 |
+
83 4 float32 l2_norm
|
| 32 |
+
87 4 float32 scs
|
| 33 |
+
91 4 float32 margin_proof
|
| 34 |
+
95 2 uint16 task_desc_len
|
| 35 |
+
97 2 uint16 cache_id_len
|
| 36 |
+
Variable:
|
| 37 |
+
99 R*2 float16 vec_perdoc
|
| 38 |
+
+R*2 R*2 float16 vec_fcdb
|
| 39 |
+
+2R*2 256 float16 joint_center (128 x float16)
|
| 40 |
+
+256 var utf-8 task_description
|
| 41 |
+
+var var utf-8 cache_id
|
| 42 |
+
|
| 43 |
+
Total for R=116: ~800 bytes.
|
| 44 |
+
|
| 45 |
+
Compatibility: readers MUST reject magic != "EGR1" or version mismatch.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
from __future__ import annotations
|
| 49 |
+
|
| 50 |
+
import struct
|
| 51 |
+
|
| 52 |
+
import numpy as np
|
| 53 |
+
import torch
|
| 54 |
+
|
| 55 |
+
EIGENGRAM_MAGIC = b"EGR1"
|
| 56 |
+
EIGENGRAM_VERSION = 1
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
class EigramEncoder:
    """Encode and decode EIGENGRAM binary certificates.

    A single instance handles both directions. EigramDecoder is an alias.
    Float16 storage preserves cosine similarity to > 0.999.

    NOTE(review): encode() always writes the three v1.1 fields
    (fourier_dim, local_density, eigenform_score — 8 bytes) right after
    cache_id_len, so files produced by this encoder carry a 107-byte
    fixed header rather than the 99-byte v1.0 layout described in the
    module docstring. decode() detects the extension by a total-size
    heuristic for compatibility with older v1.0 files.
    """

    def encode(
        self,
        vec_perdoc: torch.Tensor,
        vec_fcdb: torch.Tensor,
        joint_center: torch.Tensor,
        corpus_hash: str,
        model_id: str,
        basis_rank: int,
        n_corpus: int,
        layer_range: tuple[int, int],
        context_len: int,
        l2_norm: float,
        scs: float,
        margin_proof: float,
        task_description: str,
        cache_id: str,
        vec_fourier: torch.Tensor | None = None,
        local_density: int = 0,
        eigenform_score: float = 1.0,
        confusion_flag: bool = False,
        vec_fourier_v2: torch.Tensor | None = None,
    ) -> bytes:
        """Serialise all fields into an EIGENGRAM binary blob.

        Strings are truncated to fixed budgets (task_description 256
        bytes, cache_id 64 bytes); byte-level truncation can split a
        multi-byte UTF-8 character, which decode() tolerates via
        errors="replace".
        """
        from datetime import datetime, timezone

        td_b = task_description.encode("utf-8")[:256]
        ci_b = cache_id.encode("utf-8")[:64]
        now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S")

        buf = bytearray()
        buf += EIGENGRAM_MAGIC
        buf += struct.pack("<B", EIGENGRAM_VERSION)
        # Fixed-width ASCII fields, null-padded to their slot sizes
        buf += corpus_hash.encode("ascii")[:32].ljust(32, b"\x00")
        buf += now.encode("ascii")[:20].ljust(20, b"\x00")
        buf += model_id.encode("ascii")[:16].ljust(16, b"\x00")
        buf += struct.pack("<H", basis_rank)
        buf += struct.pack("<H", n_corpus)
        buf += struct.pack("<bb", layer_range[0], layer_range[1])
        buf += struct.pack("<I", context_len)
        buf += struct.pack("<f", l2_norm)
        buf += struct.pack("<f", scs)
        buf += struct.pack("<f", margin_proof)
        buf += struct.pack("<H", len(td_b))
        buf += struct.pack("<H", len(ci_b))
        # fourier_dim: 0 if no vec_fourier, else len(vec_fourier)
        fourier_dim = len(vec_fourier) if vec_fourier is not None else 0
        buf += struct.pack("<H", fourier_dim)
        buf += struct.pack("<H", local_density)
        buf += struct.pack("<f", eigenform_score)

        # Vector payloads stored as float16 (halves the size)
        buf += vec_perdoc.to(torch.float16).numpy().tobytes()
        buf += vec_fcdb.to(torch.float16).numpy().tobytes()
        # NOTE(review): decode() reads exactly 128 float16 values here; a
        # joint_center with fewer than 128 elements would write fewer
        # bytes and misalign the stream — confirm callers always pass ≥128.
        buf += joint_center[:128].to(torch.float16).numpy().tobytes()

        buf += td_b
        buf += ci_b

        # Append vec_fourier if present (backward-compatible extension)
        if vec_fourier is not None:
            buf += vec_fourier.to(torch.float16).numpy().tobytes()

        # v1.2 extension: confusion_flag + vec_fourier_v2
        # Written only when at least one is non-default, preserving
        # backward compat with readers that stop after vec_fourier.
        if confusion_flag or vec_fourier_v2 is not None:
            buf += struct.pack("<B", 1 if confusion_flag else 0)
            v2_dim = len(vec_fourier_v2) if vec_fourier_v2 is not None else 0
            buf += struct.pack("<H", v2_dim)
            if vec_fourier_v2 is not None:
                buf += vec_fourier_v2.to(torch.float16).numpy().tobytes()

        return bytes(buf)

    def decode(self, data: bytes) -> dict:
        """Deserialise an EIGENGRAM binary blob into a dict.

        Returns dict with all fields. Vectors upcast to float32.
        Raises ValueError on magic/version mismatch.
        """
        if len(data) < 4 or data[:4] != EIGENGRAM_MAGIC:
            raise ValueError(
                f"Invalid EIGENGRAM magic: {data[:4]!r} (expected {EIGENGRAM_MAGIC!r})"
            )

        off = 4
        version = struct.unpack_from("<B", data, off)[0]; off += 1
        if version != EIGENGRAM_VERSION:
            raise ValueError(
                f"Unsupported EIGENGRAM version {version} "
                f"(this reader supports v{EIGENGRAM_VERSION})"
            )

        # Fixed-width ASCII fields: strip the null padding added by encode()
        corpus_hash = data[off : off + 32].rstrip(b"\x00").decode("ascii"); off += 32
        created_at = data[off : off + 20].rstrip(b"\x00").decode("ascii"); off += 20
        model_id = data[off : off + 16].rstrip(b"\x00").decode("ascii"); off += 16

        basis_rank = struct.unpack_from("<H", data, off)[0]; off += 2
        n_corpus = struct.unpack_from("<H", data, off)[0]; off += 2
        lr0, lr1 = struct.unpack_from("<bb", data, off); off += 2
        context_len = struct.unpack_from("<I", data, off)[0]; off += 4
        l2_norm = struct.unpack_from("<f", data, off)[0]; off += 4
        scs = struct.unpack_from("<f", data, off)[0]; off += 4
        margin_proof = struct.unpack_from("<f", data, off)[0]; off += 4
        td_len = struct.unpack_from("<H", data, off)[0]; off += 2
        ci_len = struct.unpack_from("<H", data, off)[0]; off += 2

        # v1.1 extension fields: fourier_dim + local_density
        # Detect by checking if file has extra bytes beyond v1.0 layout
        fourier_dim = 0
        local_density = 0
        # v1.0 remaining payload from here: two R-element float16 vectors
        # (R*4 bytes total) + 256-byte joint_center + the two strings
        expected_old_size = off + basis_rank * 4 + 256 + td_len + ci_len
        eigenform_score = 1.0
        if len(data) > expected_old_size + 4:
            fourier_dim = struct.unpack_from("<H", data, off)[0]; off += 2
            local_density = struct.unpack_from("<H", data, off)[0]; off += 2
            eigenform_score = struct.unpack_from("<f", data, off)[0]; off += 4
            # If the file was written with fourier_dim field but is old format,
            # we already consumed 2 bytes. This is safe because old files
            # won't have extra bytes.

        R = basis_rank
        # .copy() detaches from the read-only bytes buffer before torch wraps it
        vec_perdoc = torch.from_numpy(
            np.frombuffer(data, dtype=np.float16, count=R, offset=off).copy()
        ).float(); off += R * 2

        vec_fcdb = torch.from_numpy(
            np.frombuffer(data, dtype=np.float16, count=R, offset=off).copy()
        ).float(); off += R * 2

        joint_center = torch.from_numpy(
            np.frombuffer(data, dtype=np.float16, count=128, offset=off).copy()
        ).float(); off += 128 * 2

        task_description = data[off : off + td_len].decode("utf-8", errors="replace"); off += td_len
        cache_id = data[off : off + ci_len].decode("utf-8", errors="replace"); off += ci_len

        # Read vec_fourier if present
        vec_fourier = None
        if fourier_dim > 0 and off + fourier_dim * 2 <= len(data):
            vec_fourier = torch.from_numpy(
                np.frombuffer(data, dtype=np.float16, count=fourier_dim, offset=off).copy()
            ).float()
            off += fourier_dim * 2

        # v1.2 extension: confusion_flag + vec_fourier_v2
        confusion_flag = False
        vec_fourier_v2 = None
        if off + 3 <= len(data):  # 1 byte flag + 2 byte dim minimum
            confusion_flag = bool(struct.unpack_from("<B", data, off)[0])
            off += 1
            v2_dim = struct.unpack_from("<H", data, off)[0]
            off += 2
            if v2_dim > 0 and off + v2_dim * 2 <= len(data):
                vec_fourier_v2 = torch.from_numpy(
                    np.frombuffer(data, dtype=np.float16, count=v2_dim, offset=off).copy()
                ).float()

        result = {
            "version": version,
            "corpus_hash": corpus_hash,
            "created_at": created_at,
            "model_id": model_id,
            "basis_rank": basis_rank,
            "n_corpus": n_corpus,
            "layer_range": (lr0, lr1),
            "context_len": context_len,
            "l2_norm": l2_norm,
            "scs": scs,
            "margin_proof": margin_proof,
            "vec_perdoc": vec_perdoc,
            "vec_fcdb": vec_fcdb,
            "joint_center": joint_center,
            "task_description": task_description,
            "cache_id": cache_id,
        }
        # Optional fields are added only when present, so older callers
        # that test membership keep working
        if vec_fourier is not None:
            result["vec_fourier"] = vec_fourier
        if vec_fourier_v2 is not None:
            result["vec_fourier_v2"] = vec_fourier_v2
        result["local_density"] = local_density
        result["eigenform_score"] = eigenform_score
        result["confusion_flag"] = confusion_flag
        return result
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
EigramDecoder = EigramEncoder
|
kvcos/engram/hnsw_index.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM HNSW Index — O(log N) approximate nearest neighbor retrieval.
|
| 3 |
+
|
| 4 |
+
Wraps faiss.IndexHNSWFlat for production-scale ENGRAM search.
|
| 5 |
+
Primary fingerprint: v2 layer-normalized Fourier f0+f1.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
idx = EngramIndex(dim=2048)
|
| 9 |
+
idx.add_batch(doc_ids, vectors)
|
| 10 |
+
results = idx.search(query_fp, top_k=5)
|
| 11 |
+
idx.save('index/hnsw')
|
| 12 |
+
idx2 = EngramIndex.load('index/hnsw')
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import json
|
| 18 |
+
import logging
|
| 19 |
+
import os
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
|
| 22 |
+
import faiss
|
| 23 |
+
import numpy as np
|
| 24 |
+
import torch
|
| 25 |
+
import torch.nn.functional as F
|
| 26 |
+
|
| 27 |
+
logger = logging.getLogger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class HNSWResult:
    """Single HNSW search result."""

    doc_id: str          # identifier of the matched document
    score: float         # similarity score — presumably cosine-like on
                         # normalized vectors; confirm against search()
    rank: int            # position within the returned result list
    margin: float = 0.0  # gap to the next-best result — TODO confirm semantics
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class EngramIndex:
|
| 41 |
+
"""HNSW-backed ENGRAM retrieval index.
|
| 42 |
+
|
| 43 |
+
HNSW parameters:
|
| 44 |
+
M=32: graph degree (higher = better recall, more memory)
|
| 45 |
+
efConstruction=200: build-time search width
|
| 46 |
+
efSearch=64: query-time search width
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
# HNSW tuning knobs (see class docstring): M is the graph degree;
# EF_CONSTRUCTION / EF_SEARCH are the build-time and query-time
# search widths (recall vs. cost trade-off).
M = 32
EF_CONSTRUCTION = 200
EF_SEARCH = 64
|
| 52 |
+
|
| 53 |
+
def __init__(self, dim: int = 2048):
|
| 54 |
+
self._dim = dim
|
| 55 |
+
self._index: faiss.IndexHNSWFlat | None = None
|
| 56 |
+
self._ids: list[str] = []
|
| 57 |
+
self._id_to_pos: dict[str, int] = {}
|
| 58 |
+
self._n_docs: int = 0
|
| 59 |
+
|
| 60 |
+
def add_batch(
|
| 61 |
+
self,
|
| 62 |
+
doc_ids: list[str],
|
| 63 |
+
vectors: torch.Tensor,
|
| 64 |
+
) -> None:
|
| 65 |
+
"""Build HNSW index from vectors.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
doc_ids: list of document identifiers
|
| 69 |
+
vectors: [N, dim] tensor of fingerprints
|
| 70 |
+
"""
|
| 71 |
+
matrix = F.normalize(vectors.float(), dim=-1).numpy().astype("float32")
|
| 72 |
+
self._dim = matrix.shape[1]
|
| 73 |
+
self._ids = list(doc_ids)
|
| 74 |
+
self._id_to_pos = {cid: i for i, cid in enumerate(doc_ids)}
|
| 75 |
+
self._n_docs = len(doc_ids)
|
| 76 |
+
|
| 77 |
+
self._index = faiss.IndexHNSWFlat(self._dim, self.M)
|
| 78 |
+
self._index.hnsw.efConstruction = self.EF_CONSTRUCTION
|
| 79 |
+
self._index.hnsw.efSearch = self.EF_SEARCH
|
| 80 |
+
self._index.add(matrix)
|
| 81 |
+
|
| 82 |
+
def build(
|
| 83 |
+
self,
|
| 84 |
+
eng_files: list[str],
|
| 85 |
+
fp_key: str = "vec_fourier_v2",
|
| 86 |
+
verbose: bool = True,
|
| 87 |
+
) -> None:
|
| 88 |
+
"""Build HNSW index from list of .eng file paths.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
eng_files: List of paths to .eng encoded files.
|
| 92 |
+
fp_key: Fingerprint field to index.
|
| 93 |
+
Default 'vec_fourier_v2' (S3 validated, 99.5% recall).
|
| 94 |
+
Falls back to 'vec_fourier' if v2 not present.
|
| 95 |
+
"""
|
| 96 |
+
from kvcos.engram.reader import read_eigengram
|
| 97 |
+
|
| 98 |
+
doc_ids = []
|
| 99 |
+
vecs = []
|
| 100 |
+
missing_v2 = 0
|
| 101 |
+
|
| 102 |
+
for fp in eng_files:
|
| 103 |
+
data = read_eigengram(fp)
|
| 104 |
+
cid = data.get("cache_id")
|
| 105 |
+
if not cid:
|
| 106 |
+
continue
|
| 107 |
+
|
| 108 |
+
vec = data.get(fp_key)
|
| 109 |
+
if vec is None:
|
| 110 |
+
vec = data.get("vec_fourier")
|
| 111 |
+
missing_v2 += 1
|
| 112 |
+
if vec is None:
|
| 113 |
+
continue
|
| 114 |
+
|
| 115 |
+
doc_ids.append(cid)
|
| 116 |
+
vecs.append(vec.float())
|
| 117 |
+
|
| 118 |
+
if not vecs:
|
| 119 |
+
raise ValueError(
|
| 120 |
+
f"No valid fingerprints found in {len(eng_files)} files"
|
| 121 |
+
)
|
| 122 |
+
|
| 123 |
+
if missing_v2 > 0 and verbose:
|
| 124 |
+
logger.warning(
|
| 125 |
+
"%d docs missing %s, used vec_fourier fallback",
|
| 126 |
+
missing_v2, fp_key,
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
self.add_batch(doc_ids, torch.stack(vecs))
|
| 130 |
+
|
| 131 |
+
if verbose:
|
| 132 |
+
logger.info("HNSW index built: %d docs, dim=%d", self._n_docs, self._dim)
|
| 133 |
+
logger.info("M=%d, efC=%d, efS=%d", self.M, self.EF_CONSTRUCTION, self.EF_SEARCH)
|
| 134 |
+
|
| 135 |
+
def search(
|
| 136 |
+
self,
|
| 137 |
+
query_fp: torch.Tensor,
|
| 138 |
+
top_k: int = 5,
|
| 139 |
+
) -> list[HNSWResult]:
|
| 140 |
+
"""Search the HNSW index.
|
| 141 |
+
|
| 142 |
+
Returns list of HNSWResult sorted by score descending.
|
| 143 |
+
HNSW uses L2 on normalized vectors: cosine = 1 - L2^2/2.
|
| 144 |
+
"""
|
| 145 |
+
if self._index is None:
|
| 146 |
+
raise RuntimeError("Index not built. Call add_batch() or load() first.")
|
| 147 |
+
|
| 148 |
+
qn = F.normalize(query_fp.float().unsqueeze(0), dim=-1).numpy().astype("float32")
|
| 149 |
+
D, I = self._index.search(qn, min(top_k + 1, self._n_docs))
|
| 150 |
+
|
| 151 |
+
results = []
|
| 152 |
+
for rank, (dist, idx) in enumerate(zip(D[0], I[0])):
|
| 153 |
+
if idx < 0:
|
| 154 |
+
continue
|
| 155 |
+
cosine_sim = float(1.0 - dist / 2.0)
|
| 156 |
+
results.append(HNSWResult(
|
| 157 |
+
doc_id=self._ids[idx], score=cosine_sim, rank=rank,
|
| 158 |
+
))
|
| 159 |
+
|
| 160 |
+
if len(results) >= 2:
|
| 161 |
+
results[0].margin = results[0].score - results[1].score
|
| 162 |
+
return results[:top_k]
|
| 163 |
+
|
| 164 |
+
def save(self, path: str) -> None:
|
| 165 |
+
"""Save index to disk (faiss + JSON metadata)."""
|
| 166 |
+
os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
|
| 167 |
+
faiss.write_index(self._index, path + ".faiss")
|
| 168 |
+
meta_path = path + ".meta.json"
|
| 169 |
+
with open(meta_path, "w") as f:
|
| 170 |
+
json.dump({
|
| 171 |
+
"ids": self._ids,
|
| 172 |
+
"id_to_pos": self._id_to_pos,
|
| 173 |
+
"dim": self._dim,
|
| 174 |
+
"n_docs": self._n_docs,
|
| 175 |
+
}, f, indent=2)
|
| 176 |
+
|
| 177 |
+
@classmethod
|
| 178 |
+
def load(cls, path: str) -> EngramIndex:
|
| 179 |
+
"""Load index from disk."""
|
| 180 |
+
obj = cls()
|
| 181 |
+
obj._index = faiss.read_index(path + ".faiss")
|
| 182 |
+
meta_path = path + ".meta.json"
|
| 183 |
+
with open(meta_path, "r") as f:
|
| 184 |
+
meta = json.load(f)
|
| 185 |
+
obj._ids = meta["ids"]
|
| 186 |
+
obj._id_to_pos = meta["id_to_pos"]
|
| 187 |
+
obj._dim = meta["dim"]
|
| 188 |
+
obj._n_docs = meta["n_docs"]
|
| 189 |
+
return obj
|
| 190 |
+
|
| 191 |
+
def __len__(self) -> int:
|
| 192 |
+
return self._n_docs
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_vector(self, doc_id: str) -> torch.Tensor | None:
|
| 196 |
+
"""Return stored vector for doc_id, or None if not found."""
|
| 197 |
+
pos = self._id_to_pos.get(doc_id)
|
| 198 |
+
if pos is None:
|
| 199 |
+
return None
|
| 200 |
+
vec_np = np.zeros(self._dim, dtype="float32")
|
| 201 |
+
self._index.reconstruct(pos, vec_np)
|
| 202 |
+
return torch.from_numpy(vec_np)
|
| 203 |
+
|
| 204 |
+
def __repr__(self) -> str:
|
| 205 |
+
return f"EngramIndex(n={self._n_docs}, dim={self._dim}, M={self.M})"
|
kvcos/engram/index_c.py
ADDED
|
@@ -0,0 +1,426 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/index_c.py — Confidence history index for ENGRAM.
|
| 3 |
+
|
| 4 |
+
Stores retrieval confidence records across sessions.
|
| 5 |
+
Makes the system self-aware: chronic failures are known before retrieval.
|
| 6 |
+
|
| 7 |
+
Schema:
|
| 8 |
+
retrievals: one row per geodesic_retrieve() call
|
| 9 |
+
confusion_pairs: doc pairs that confuse each other (confidence<threshold)
|
| 10 |
+
doc_stats: per-doc aggregate reliability scores
|
| 11 |
+
|
| 12 |
+
Usage:
|
| 13 |
+
ic = IndexC.open("results/index_c.db")
|
| 14 |
+
ic.record(session_id="s1", query_doc_id="doc_146", result=geodesic_result)
|
| 15 |
+
prior = ic.prior("doc_146")
|
| 16 |
+
pairs = ic.confusion_registry()
|
| 17 |
+
rmap = ic.reliability_map()
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import sqlite3
|
| 24 |
+
import time
|
| 25 |
+
from dataclasses import dataclass
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@dataclass
class ConfidenceRecord:
    """One logged retrieval (a single row of the `retrievals` table)."""

    session_id: str        # session that issued the query
    query_doc_id: str      # doc used as the query
    result_doc_id: str     # doc returned by retrieval
    confidence: str        # "high" | "medium" | "low"
    margin: float          # score gap between top-1 and top-2
    stages_used: int       # number of retrieval stages executed
    constraint_used: bool  # whether a constraint was applied
    correct: bool          # whether the retrieval was judged correct
    ts: float              # unix timestamp


@dataclass
class DocPrior:
    """Prior confidence distribution for a doc_id."""

    doc_id: str
    n_high: int
    n_medium: int
    n_low: int
    n_total: int
    # (n_high + n_medium) / n_total; 1.0 for an unseen doc.
    reliability: float
    # True when n_low / n_total exceeds IndexC.CHRONIC_FAILURE_THRESHOLD.
    is_chronic_failure: bool

    @property
    def dominant_confidence(self) -> str:
        """Most frequent confidence tier, or 'unknown' with no history."""
        if self.n_total == 0:
            return "unknown"
        counts = {
            "high": self.n_high,
            "medium": self.n_medium,
            "low": self.n_low,
        }
        # Ties resolve to the first key in insertion order: high > medium > low.
        return max(counts, key=counts.get)


@dataclass
class ConfusionPair:
    """Directed pair of docs that confused retrieval (query -> wrong result)."""

    doc_a: str          # the query doc
    doc_b: str          # the doc that was wrongly returned
    n_confusions: int   # how many times this exact direction occurred
    first_seen: float   # unix timestamps
    last_seen: float


# SQL schema, applied idempotently (IF NOT EXISTS) on every open.
SCHEMA = """
CREATE TABLE IF NOT EXISTS retrievals (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    session_id TEXT NOT NULL,
    query_doc_id TEXT NOT NULL,
    result_doc_id TEXT NOT NULL,
    confidence TEXT NOT NULL,
    margin REAL NOT NULL,
    stages_used INTEGER NOT NULL,
    constraint_used INTEGER NOT NULL,
    correct INTEGER NOT NULL,
    ts REAL NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_ret_query ON retrievals(query_doc_id);
CREATE INDEX IF NOT EXISTS idx_ret_result ON retrievals(result_doc_id);
CREATE INDEX IF NOT EXISTS idx_ret_conf ON retrievals(confidence);
CREATE INDEX IF NOT EXISTS idx_ret_sess ON retrievals(session_id);

CREATE TABLE IF NOT EXISTS confusion_pairs (
    doc_a TEXT NOT NULL,
    doc_b TEXT NOT NULL,
    n_confusions INTEGER NOT NULL DEFAULT 1,
    first_seen REAL NOT NULL,
    last_seen REAL NOT NULL,
    PRIMARY KEY (doc_a, doc_b)
);

CREATE TABLE IF NOT EXISTS doc_stats (
    doc_id TEXT PRIMARY KEY,
    n_high INTEGER NOT NULL DEFAULT 0,
    n_medium INTEGER NOT NULL DEFAULT 0,
    n_low INTEGER NOT NULL DEFAULT 0,
    reliability REAL NOT NULL DEFAULT 1.0,
    last_updated REAL NOT NULL
);
"""


class IndexC:
    """
    Confidence history index.

    Backed by SQLite — append-only, persistent across sessions.
    Provides priors for geodesic_retrieve() to pre-apply constraints
    on docs that are known chronic failures.
    """

    # Fraction of low-confidence retrievals above which a doc is "chronic".
    CHRONIC_FAILURE_THRESHOLD = 0.5

    def __init__(self, db_path: str):
        self._db_path = str(db_path)
        # isolation_level=None -> autocommit mode; WAL permits concurrent readers.
        self._conn = sqlite3.connect(
            db_path, check_same_thread=False, isolation_level=None
        )
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.executescript(SCHEMA)

    @classmethod
    def open(cls, db_path: str | Path) -> "IndexC":
        """Open (or create) the Index-C database at db_path."""
        os.makedirs(Path(db_path).parent, exist_ok=True)
        return cls(str(db_path))

    # ── WRITE ────────────────────────────────────────────────────────

    def record(
        self,
        session_id: str,
        query_doc_id: str,
        result_doc_id: str,
        confidence: str,
        margin: float,
        stages_used: int = 1,
        constraint_used: bool = False,
        correct: bool = True,
        ts: float | None = None,
    ) -> None:
        """Log one retrieval result to the index.

        Also updates doc_stats for the query doc and, on an incorrect
        retrieval, increments the confusion-pair registry.
        """
        # `ts or time.time()` would discard a legitimate ts == 0.0.
        if ts is None:
            ts = time.time()
        with self._conn:
            self._conn.execute(
                """INSERT INTO retrievals
                   (session_id, query_doc_id, result_doc_id, confidence,
                    margin, stages_used, constraint_used, correct, ts)
                   VALUES (?,?,?,?,?,?,?,?,?)""",
                (session_id, query_doc_id, result_doc_id, confidence,
                 float(margin), int(stages_used), int(constraint_used),
                 int(correct), float(ts)),
            )
            if not correct:
                self._register_confusion(query_doc_id, result_doc_id, ts)
            self._update_doc_stats(query_doc_id, confidence, ts)

    def _register_confusion(
        self, doc_a: str, doc_b: str, ts: float
    ) -> None:
        """Insert or increment a confusion pair.

        NOTE: pairs are directional — (a, b) and (b, a) are distinct rows.
        """
        existing = self._conn.execute(
            "SELECT n_confusions FROM confusion_pairs "
            "WHERE doc_a=? AND doc_b=?",
            (doc_a, doc_b),
        ).fetchone()
        if existing:
            self._conn.execute(
                "UPDATE confusion_pairs SET n_confusions=n_confusions+1, "
                "last_seen=? WHERE doc_a=? AND doc_b=?",
                (ts, doc_a, doc_b),
            )
        else:
            self._conn.execute(
                "INSERT INTO confusion_pairs "
                "(doc_a, doc_b, n_confusions, first_seen, last_seen) "
                "VALUES (?,?,1,?,?)",
                (doc_a, doc_b, ts, ts),
            )

    def _update_doc_stats(
        self, doc_id: str, confidence: str, ts: float
    ) -> None:
        """Upsert the doc_stats row for doc_id with one new observation."""
        # Unknown confidence strings are counted as "medium".
        col_map = {"high": "n_high", "medium": "n_medium", "low": "n_low"}
        col = col_map.get(confidence, "n_medium")

        existing = self._conn.execute(
            "SELECT n_high, n_medium, n_low FROM doc_stats WHERE doc_id=?",
            (doc_id,),
        ).fetchone()

        if existing:
            n_high, n_medium, n_low = existing
            if col == "n_high":
                n_high += 1
            elif col == "n_medium":
                n_medium += 1
            else:
                n_low += 1
            n_total = n_high + n_medium + n_low
            # Reliability = share of non-low observations.
            reliability = (n_high + n_medium) / n_total if n_total > 0 else 1.0
            self._conn.execute(
                "UPDATE doc_stats SET n_high=?, n_medium=?, n_low=?, "
                "reliability=?, last_updated=? WHERE doc_id=?",
                (n_high, n_medium, n_low, reliability, ts, doc_id),
            )
        else:
            vals = {"n_high": 0, "n_medium": 0, "n_low": 0}
            vals[col] = 1
            # Single observation: reliability is 1.0 unless it was "low".
            reliability = (vals["n_high"] + vals["n_medium"]) / 1
            self._conn.execute(
                "INSERT INTO doc_stats "
                "(doc_id, n_high, n_medium, n_low, reliability, last_updated) "
                "VALUES (?,?,?,?,?,?)",
                (doc_id, vals["n_high"], vals["n_medium"],
                 vals["n_low"], reliability, ts),
            )

    # ── READ ─────────────────────────────────────────────────────────

    def prior(self, doc_id: str) -> DocPrior:
        """Return prior confidence distribution for doc_id.

        An unseen doc gets a neutral prior (reliability 1.0, not chronic).
        """
        row = self._conn.execute(
            "SELECT n_high, n_medium, n_low, reliability "
            "FROM doc_stats WHERE doc_id=?",
            (doc_id,),
        ).fetchone()
        if not row:
            return DocPrior(
                doc_id=doc_id, n_high=0, n_medium=0, n_low=0,
                n_total=0, reliability=1.0, is_chronic_failure=False,
            )
        n_high, n_medium, n_low, reliability = row
        n_total = n_high + n_medium + n_low
        return DocPrior(
            doc_id=doc_id,
            n_high=n_high,
            n_medium=n_medium,
            n_low=n_low,
            n_total=n_total,
            reliability=reliability,
            is_chronic_failure=(
                n_low / n_total > self.CHRONIC_FAILURE_THRESHOLD
                if n_total > 0
                else False
            ),
        )

    def confusion_registry(
        self, min_confusions: int = 1
    ) -> list[ConfusionPair]:
        """Return known confusion pairs with >= min_confusions, worst first."""
        rows = self._conn.execute(
            "SELECT doc_a, doc_b, n_confusions, first_seen, last_seen "
            "FROM confusion_pairs WHERE n_confusions >= ? "
            "ORDER BY n_confusions DESC",
            (min_confusions,),
        ).fetchall()
        return [
            ConfusionPair(
                doc_a=r[0], doc_b=r[1], n_confusions=r[2],
                first_seen=r[3], last_seen=r[4],
            )
            for r in rows
        ]

    def reliability_map(self) -> dict[str, float]:
        """Return {doc_id: reliability_score} for all tracked docs."""
        rows = self._conn.execute(
            "SELECT doc_id, reliability FROM doc_stats ORDER BY reliability"
        ).fetchall()
        return {r[0]: float(r[1]) for r in rows}

    def session_history(self, session_id: str) -> list[ConfidenceRecord]:
        """Return all records for a session_id, oldest first."""
        rows = self._conn.execute(
            "SELECT session_id, query_doc_id, result_doc_id, confidence, "
            "margin, stages_used, constraint_used, correct, ts "
            "FROM retrievals WHERE session_id=? ORDER BY ts",
            (session_id,),
        ).fetchall()
        return [
            ConfidenceRecord(
                session_id=r[0], query_doc_id=r[1], result_doc_id=r[2],
                confidence=r[3], margin=float(r[4]), stages_used=int(r[5]),
                constraint_used=bool(r[6]), correct=bool(r[7]), ts=float(r[8]),
            )
            for r in rows
        ]

    def n_sessions(self) -> int:
        """Number of distinct sessions recorded."""
        row = self._conn.execute(
            "SELECT COUNT(DISTINCT session_id) FROM retrievals"
        ).fetchone()
        return row[0] if row else 0

    # ── RECENCY-WEIGHTED RELIABILITY ────────────────────────────────

    def weighted_reliability(
        self,
        doc_id: str,
        decay: float = 0.85,
    ) -> float:
        """
        Exponentially weighted reliability score.

        Newer retrievals have higher weight.
        decay=0.85: a failure 5 sessions ago counts 0.85^5 = 0.44
        of a failure last session.

        Returns float in [0, 1]. Returns 1.0 if no history.
        """
        rows = self._conn.execute(
            """SELECT correct FROM retrievals
               WHERE query_doc_id=?
               ORDER BY ts ASC""",
            (doc_id,),
        ).fetchall()
        if not rows:
            return 1.0
        history = [bool(r[0]) for r in rows]
        n = len(history)
        # Oldest entry gets decay^(n-1); newest gets decay^0 == 1.
        weights = [decay ** (n - 1 - i) for i in range(n)]
        total_w = sum(weights)
        score = sum(w * int(h) for w, h in zip(weights, history))
        return round(score / total_w, 6) if total_w > 0 else 1.0

    # ── INDEX GROWTH REVALIDATION ────────────────────────────────────

    def on_document_added(
        self,
        new_doc_id: str,
        new_vec,  # torch.Tensor [dim] — assumed; only passed through to hnsw_index
        hnsw_index,  # EngramIndex instance (duck-typed: search() / get_vector())
        revalidation_radius: float = 0.85,
        density_threshold: int = 3,
    ) -> list[str]:
        """
        Call after adding a document to the HNSW index.
        Recomputes local density for neighbors of new_doc_id
        whose similarity > revalidation_radius.

        Returns list of doc_ids whose stats row was touched.

        Why this matters:
            At N=200, doc_042 is in sparse space (density=1).
            At N=500, doc_201 lands near doc_042 (cosine=0.88).
            doc_042 is now in a denser region — its confidence tier
            has degraded. This method detects and records that.
        """
        updated: list[str] = []
        try:
            results = hnsw_index.search(new_vec, top_k=20)
        except Exception:
            # Best-effort: an unbuilt/empty index means nothing to revalidate.
            return updated

        for r in results:
            if r.doc_id == new_doc_id:
                continue
            if r.score < revalidation_radius:
                continue

            # Recompute density for this neighbor.
            neighbor_vec = hnsw_index.get_vector(r.doc_id)
            if neighbor_vec is None:
                continue

            all_results = hnsw_index.search(neighbor_vec, top_k=50)
            new_density = sum(
                1 for x in all_results
                if x.doc_id != r.doc_id and x.score > revalidation_radius
            )

            # `time` is imported at module level; __import__ was redundant.
            ts = time.time()
            existing = self._conn.execute(
                "SELECT n_high FROM doc_stats WHERE doc_id=?",
                (r.doc_id,),
            ).fetchone()

            if existing:
                self._conn.execute(
                    "UPDATE doc_stats SET last_updated=? WHERE doc_id=?",
                    (ts, r.doc_id),
                )
            else:
                self._conn.execute(
                    "INSERT OR IGNORE INTO doc_stats "
                    "(doc_id, n_high, n_medium, n_low, reliability, last_updated) "
                    "VALUES (?,0,0,0,1.0,?)",
                    (r.doc_id, ts),
                )

            # Density crossed the threshold: register the pair so constraint
            # activation knows these two docs now crowd each other.
            if new_density > density_threshold:
                self._register_confusion(r.doc_id, new_doc_id, ts)

            updated.append(r.doc_id)

        # No-op under autocommit (isolation_level=None); kept as a defensive flush.
        self._conn.commit()
        return updated

    def close(self) -> None:
        """Close the underlying SQLite connection."""
        self._conn.close()

    def __repr__(self) -> str:
        n = self._conn.execute(
            "SELECT COUNT(*) FROM retrievals"
        ).fetchone()[0]
        return f"IndexC(db={self._db_path!r}, n_records={n})"
|
kvcos/engram/knowledge_index.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/knowledge_index.py — HNSW index over the knowledge store.
|
| 3 |
+
|
| 4 |
+
Builds and maintains a faiss HNSW index over all .eng files in
|
| 5 |
+
~/.engram/knowledge/. Supports dynamic dimension (384 for sbert,
|
| 6 |
+
2048 for llama_cpp/hash) — determined at build time from the first
|
| 7 |
+
.eng file.
|
| 8 |
+
|
| 9 |
+
Usage:
|
| 10 |
+
# Build from all knowledge .eng files
|
| 11 |
+
kidx = KnowledgeIndex.build_from_knowledge_dir()
|
| 12 |
+
results = kidx.search("HNSW recall benchmark", k=5)
|
| 13 |
+
kidx.save()
|
| 14 |
+
|
| 15 |
+
# Load pre-built index
|
| 16 |
+
kidx = KnowledgeIndex.load()
|
| 17 |
+
results = kidx.search("testing patterns", k=3)
|
| 18 |
+
|
| 19 |
+
Index files:
|
| 20 |
+
~/.engram/index/knowledge.faiss
|
| 21 |
+
~/.engram/index/knowledge.meta
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import json
|
| 27 |
+
import logging
|
| 28 |
+
import os
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
|
| 32 |
+
import faiss
|
| 33 |
+
import numpy as np
|
| 34 |
+
import torch
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
|
| 37 |
+
from kvcos.engram.embedder import get_fingerprint
|
| 38 |
+
from kvcos.engram.format import EigramEncoder
|
| 39 |
+
|
| 40 |
+
logger = logging.getLogger(__name__)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Directory for persisted index files; override via ENGRAM_INDEX_DIR.
INDEX_DIR = Path(
    os.environ.get("ENGRAM_INDEX_DIR", "~/.engram/index")
).expanduser()

# Root of the .eng knowledge store; override via ENGRAM_KNOWLEDGE_DIR.
KNOWLEDGE_DIR = Path(
    os.environ.get("ENGRAM_KNOWLEDGE_DIR", "~/.engram/knowledge")
).expanduser()

# Basename of the index files (e.g. knowledge.faiss, knowledge.meta).
INDEX_NAME = "knowledge"

# Module-wide decoder for the binary .eng container format.
_encoder = EigramEncoder()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@dataclass(frozen=True)
class KnowledgeResult:
    """Single search result from the knowledge index."""
    # Document identifier (cache_id from the .eng file, or the file stem).
    doc_id: str
    # Cosine similarity, derived from L2 distance on normalized vectors.
    score: float
    # 0-based position in the result list.
    rank: int
    # Originating source file path, from sidecar metadata ("" if absent).
    source_path: str
    # Project name from sidecar metadata ("" if absent).
    project: str
    # Task description text (sidecar preferred over the binary's).
    content: str
    chunk_info: str  # "2/5" format
    # Markdown/section headers from sidecar metadata.
    headers: list[str]
    # Score gap to the runner-up; set on the top result only.
    margin: float = 0.0
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class KnowledgeIndex:
|
| 71 |
+
"""HNSW index over the ENGRAM knowledge store.
|
| 72 |
+
|
| 73 |
+
Parameters match EngramIndex for consistency:
|
| 74 |
+
M=32, efConstruction=200, efSearch=64
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
M = 32
|
| 78 |
+
EF_CONSTRUCTION = 200
|
| 79 |
+
EF_SEARCH = 64
|
| 80 |
+
|
| 81 |
+
    def __init__(self, dim: int = 384) -> None:
        # Vector dimensionality; overwritten from the actual data at build
        # time (384 for sbert, 2048 for llama_cpp/hash per module docstring).
        self._dim = dim
        self._index: faiss.IndexHNSWFlat | None = None
        self._meta: list[dict] = []  # per-vector metadata
        self._n_docs: int = 0
|
| 87 |
+
@classmethod
|
| 88 |
+
def build_from_knowledge_dir(
|
| 89 |
+
cls,
|
| 90 |
+
knowledge_dir: Path | None = None,
|
| 91 |
+
verbose: bool = True,
|
| 92 |
+
) -> KnowledgeIndex:
|
| 93 |
+
"""Build HNSW index from all .eng files in the knowledge directory."""
|
| 94 |
+
if knowledge_dir is None:
|
| 95 |
+
knowledge_dir = KNOWLEDGE_DIR
|
| 96 |
+
|
| 97 |
+
eng_files = sorted(knowledge_dir.rglob("*.eng"), key=os.path.getmtime)
|
| 98 |
+
eng_files = [p for p in eng_files if p.suffix == ".eng"]
|
| 99 |
+
|
| 100 |
+
if not eng_files:
|
| 101 |
+
raise ValueError(f"No .eng files found in {knowledge_dir}")
|
| 102 |
+
|
| 103 |
+
vectors: list[torch.Tensor] = []
|
| 104 |
+
metas: list[dict] = []
|
| 105 |
+
skipped = 0
|
| 106 |
+
|
| 107 |
+
for p in eng_files:
|
| 108 |
+
try:
|
| 109 |
+
data = _encoder.decode(p.read_bytes())
|
| 110 |
+
|
| 111 |
+
fp = data.get("vec_fourier_v2")
|
| 112 |
+
if fp is None:
|
| 113 |
+
fp = data.get("vec_fourier")
|
| 114 |
+
if fp is None:
|
| 115 |
+
skipped += 1
|
| 116 |
+
continue
|
| 117 |
+
|
| 118 |
+
# Load sidecar metadata
|
| 119 |
+
meta_path = Path(str(p) + ".meta.json")
|
| 120 |
+
meta = {}
|
| 121 |
+
if meta_path.exists():
|
| 122 |
+
meta = json.loads(meta_path.read_text())
|
| 123 |
+
|
| 124 |
+
# Use sidecar description if longer than binary
|
| 125 |
+
description = meta.get("task_description", "") or \
|
| 126 |
+
data.get("task_description", "")
|
| 127 |
+
|
| 128 |
+
vectors.append(fp.float())
|
| 129 |
+
metas.append({
|
| 130 |
+
"doc_id": data.get("cache_id", p.stem),
|
| 131 |
+
"source_path": meta.get("source_path", ""),
|
| 132 |
+
"project": meta.get("project", ""),
|
| 133 |
+
"content": description,
|
| 134 |
+
"chunk_index": meta.get("chunk_index", 0),
|
| 135 |
+
"chunk_total": meta.get("chunk_total", 1),
|
| 136 |
+
"headers": meta.get("headers", []),
|
| 137 |
+
"fp_source": meta.get("fp_source", "unknown"),
|
| 138 |
+
})
|
| 139 |
+
except Exception as exc:
|
| 140 |
+
logger.debug("Skipping %s: %s", p, exc)
|
| 141 |
+
skipped += 1
|
| 142 |
+
|
| 143 |
+
if not vectors:
|
| 144 |
+
raise ValueError(
|
| 145 |
+
f"No valid fingerprints in {len(eng_files)} .eng files"
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
# Stack and determine dimension from actual data
|
| 149 |
+
matrix = torch.stack(vectors)
|
| 150 |
+
dim = matrix.shape[1]
|
| 151 |
+
|
| 152 |
+
# Normalize for cosine similarity via L2
|
| 153 |
+
matrix = F.normalize(matrix, dim=-1).numpy().astype("float32")
|
| 154 |
+
|
| 155 |
+
# Build HNSW
|
| 156 |
+
obj = cls(dim=dim)
|
| 157 |
+
obj._index = faiss.IndexHNSWFlat(dim, cls.M)
|
| 158 |
+
obj._index.hnsw.efConstruction = cls.EF_CONSTRUCTION
|
| 159 |
+
obj._index.hnsw.efSearch = cls.EF_SEARCH
|
| 160 |
+
obj._index.add(matrix)
|
| 161 |
+
obj._meta = metas
|
| 162 |
+
obj._n_docs = len(metas)
|
| 163 |
+
|
| 164 |
+
if verbose:
|
| 165 |
+
projects = {m["project"] for m in metas}
|
| 166 |
+
logger.info("Knowledge HNSW: %d vectors, dim=%d", obj._n_docs, dim)
|
| 167 |
+
logger.info("Projects: %s", sorted(projects))
|
| 168 |
+
if skipped:
|
| 169 |
+
logger.warning("Skipped: %d files (no fingerprint)", skipped)
|
| 170 |
+
|
| 171 |
+
return obj
|
| 172 |
+
|
| 173 |
+
def search(
|
| 174 |
+
self,
|
| 175 |
+
query: str | torch.Tensor,
|
| 176 |
+
k: int = 5,
|
| 177 |
+
) -> list[KnowledgeResult]:
|
| 178 |
+
"""
|
| 179 |
+
Search the knowledge index.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
query: Search text (will be fingerprinted) or pre-computed tensor.
|
| 183 |
+
k: Number of results to return.
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
List of KnowledgeResult sorted by score descending.
|
| 187 |
+
"""
|
| 188 |
+
if self._index is None:
|
| 189 |
+
raise RuntimeError("Index not built. Call build_from_knowledge_dir() first.")
|
| 190 |
+
|
| 191 |
+
if isinstance(query, str):
|
| 192 |
+
query_fp, _ = get_fingerprint(query)
|
| 193 |
+
else:
|
| 194 |
+
query_fp = query
|
| 195 |
+
|
| 196 |
+
qn = F.normalize(
|
| 197 |
+
query_fp.float().unsqueeze(0), dim=-1
|
| 198 |
+
).numpy().astype("float32")
|
| 199 |
+
|
| 200 |
+
top = min(k + 1, self._n_docs)
|
| 201 |
+
D, I = self._index.search(qn, top)
|
| 202 |
+
|
| 203 |
+
results: list[KnowledgeResult] = []
|
| 204 |
+
for rank, (dist, idx) in enumerate(zip(D[0], I[0])):
|
| 205 |
+
if idx < 0 or idx >= len(self._meta):
|
| 206 |
+
continue
|
| 207 |
+
meta = self._meta[idx]
|
| 208 |
+
cosine = float(1.0 - dist / 2.0)
|
| 209 |
+
ci = meta.get("chunk_index", 0)
|
| 210 |
+
ct = meta.get("chunk_total", 1)
|
| 211 |
+
|
| 212 |
+
results.append(KnowledgeResult(
|
| 213 |
+
doc_id=meta["doc_id"],
|
| 214 |
+
score=cosine,
|
| 215 |
+
rank=rank,
|
| 216 |
+
source_path=meta.get("source_path", ""),
|
| 217 |
+
project=meta.get("project", ""),
|
| 218 |
+
content=meta.get("content", ""),
|
| 219 |
+
chunk_info=f"{ci + 1}/{ct}",
|
| 220 |
+
headers=meta.get("headers", []),
|
| 221 |
+
))
|
| 222 |
+
|
| 223 |
+
# Set margin on top result
|
| 224 |
+
if len(results) >= 2:
|
| 225 |
+
results[0] = KnowledgeResult(
|
| 226 |
+
doc_id=results[0].doc_id,
|
| 227 |
+
score=results[0].score,
|
| 228 |
+
rank=results[0].rank,
|
| 229 |
+
source_path=results[0].source_path,
|
| 230 |
+
project=results[0].project,
|
| 231 |
+
content=results[0].content,
|
| 232 |
+
chunk_info=results[0].chunk_info,
|
| 233 |
+
headers=results[0].headers,
|
| 234 |
+
margin=results[0].score - results[1].score,
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
return results[:k]
|
| 238 |
+
|
| 239 |
+
def save(self, index_dir: Path | None = None) -> Path:
|
| 240 |
+
"""Save index to disk."""
|
| 241 |
+
if index_dir is None:
|
| 242 |
+
index_dir = INDEX_DIR
|
| 243 |
+
index_dir.mkdir(parents=True, exist_ok=True)
|
| 244 |
+
|
| 245 |
+
faiss_path = index_dir / f"{INDEX_NAME}.faiss"
|
| 246 |
+
meta_path = index_dir / f"{INDEX_NAME}.meta.json"
|
| 247 |
+
|
| 248 |
+
faiss.write_index(self._index, str(faiss_path))
|
| 249 |
+
with open(meta_path, "w") as f:
|
| 250 |
+
json.dump({
|
| 251 |
+
"meta": self._meta,
|
| 252 |
+
"dim": self._dim,
|
| 253 |
+
"n_docs": self._n_docs,
|
| 254 |
+
}, f, indent=2)
|
| 255 |
+
|
| 256 |
+
return faiss_path
|
| 257 |
+
|
| 258 |
+
@classmethod
|
| 259 |
+
def load(cls, index_dir: Path | None = None) -> KnowledgeIndex:
|
| 260 |
+
"""Load pre-built index from disk."""
|
| 261 |
+
if index_dir is None:
|
| 262 |
+
index_dir = INDEX_DIR
|
| 263 |
+
|
| 264 |
+
faiss_path = index_dir / f"{INDEX_NAME}.faiss"
|
| 265 |
+
meta_path = index_dir / f"{INDEX_NAME}.meta.json"
|
| 266 |
+
|
| 267 |
+
if not faiss_path.exists():
|
| 268 |
+
raise FileNotFoundError(
|
| 269 |
+
f"No knowledge index at {faiss_path}. "
|
| 270 |
+
"Build with KnowledgeIndex.build_from_knowledge_dir()"
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
obj = cls()
|
| 274 |
+
obj._index = faiss.read_index(str(faiss_path))
|
| 275 |
+
with open(meta_path, "r") as f:
|
| 276 |
+
data = json.load(f)
|
| 277 |
+
obj._meta = data["meta"]
|
| 278 |
+
obj._dim = data["dim"]
|
| 279 |
+
obj._n_docs = data["n_docs"]
|
| 280 |
+
return obj
|
| 281 |
+
|
| 282 |
+
    def __len__(self) -> int:
        """Number of indexed fingerprint vectors (documents/chunks)."""
        return self._n_docs
|
| 284 |
+
|
| 285 |
+
def __repr__(self) -> str:
|
| 286 |
+
return (
|
| 287 |
+
f"KnowledgeIndex(n={self._n_docs}, dim={self._dim}, "
|
| 288 |
+
f"M={self.M})"
|
| 289 |
+
)
|
kvcos/engram/manifest.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/manifest.py — Knowledge index manifest registry.
|
| 3 |
+
|
| 4 |
+
Tracks which source files have been indexed into .eng files,
|
| 5 |
+
their content hashes for incremental re-indexing, and chunk
|
| 6 |
+
metadata for multi-chunk files.
|
| 7 |
+
|
| 8 |
+
Storage: JSON file at ~/.engram/manifest.json (human-readable,
|
| 9 |
+
git-friendly, easily inspectable).
|
| 10 |
+
|
| 11 |
+
Thread safety: reads are lock-free, writes use atomic rename.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
from __future__ import annotations
|
| 15 |
+
|
| 16 |
+
import hashlib
|
| 17 |
+
import json
|
| 18 |
+
import os
|
| 19 |
+
import tempfile
|
| 20 |
+
import time
|
| 21 |
+
from dataclasses import asdict, dataclass, field
|
| 22 |
+
from pathlib import Path
|
| 23 |
+
from typing import Iterator
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
@dataclass(frozen=True)
class ChunkRecord:
    """One indexed chunk from a source file.

    Frozen so records can be stored safely inside SourceRecord tuples
    and serialized/deserialized without aliasing surprises.
    """
    eng_path: str      # Absolute path to .eng file
    chunk_index: int   # 0-based chunk index within source
    chunk_total: int   # Total chunks for this source
    char_start: int    # Start offset in source content
    char_end: int      # End offset in source content
    indexed_at: float  # Unix timestamp of indexing
|
| 36 |
+
|
| 37 |
+
@dataclass(frozen=True)
class SourceRecord:
    """Registry entry for one indexed source file.

    Frozen: re-indexing replaces the record (see Manifest.register)
    rather than mutating it in place.
    """
    source_path: str   # Absolute path to original .md file
    content_hash: str  # SHA-256 of file content at index time
    project: str       # Project namespace (e.g., "engram", "_global")
    file_size: int     # Bytes at index time
    chunks: tuple[ChunkRecord, ...] = ()  # One record per indexed chunk
    indexed_at: float = 0.0               # Unix time of last (re)indexing
    last_verified: float = 0.0            # Unix time of last verification

    @property
    def eng_paths(self) -> list[str]:
        """All .eng file paths for this source (one per chunk)."""
        return [c.eng_path for c in self.chunks]
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _content_hash(content: str) -> str:
|
| 55 |
+
"""SHA-256 hex digest of string content."""
|
| 56 |
+
return hashlib.sha256(content.encode("utf-8")).hexdigest()
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def _file_hash(path: Path) -> str:
|
| 60 |
+
"""SHA-256 hex digest of file on disk."""
|
| 61 |
+
return hashlib.sha256(path.read_bytes()).hexdigest()
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class Manifest:
    """
    Knowledge index manifest — tracks source-to-.eng mappings.

    Immutable-style operations: all mutations return new state
    and write atomically to disk.

    Usage:
        m = Manifest.load()
        m = m.register(source_path, content_hash, project, file_size, chunks)
        # m is now updated and persisted to disk
    """

    def __init__(
        self,
        records: dict[str, SourceRecord],
        manifest_path: Path,
    ) -> None:
        # Defensive copy so callers cannot mutate our state through
        # the dict they handed in.
        self._records = dict(records)  # defensive copy
        self._path = manifest_path

    @classmethod
    def load(cls, manifest_path: Path | None = None) -> Manifest:
        """Load manifest from disk, or create empty if not found.

        Default location is ~/.engram/manifest.json, overridable via
        the ENGRAM_MANIFEST_PATH environment variable.
        """
        if manifest_path is None:
            manifest_path = Path(
                os.environ.get("ENGRAM_MANIFEST_PATH",
                               "~/.engram/manifest.json")
            ).expanduser()

        if manifest_path.exists():
            data = json.loads(manifest_path.read_text())
            records = {}
            for key, rec_data in data.get("sources", {}).items():
                # Chunks are stored inline as dicts; pop them out and
                # rebuild as frozen ChunkRecord tuples before building
                # the SourceRecord from the remaining fields.
                chunks = tuple(
                    ChunkRecord(**c) for c in rec_data.pop("chunks", [])
                )
                records[key] = SourceRecord(**rec_data, chunks=chunks)
            return cls(records, manifest_path)

        return cls({}, manifest_path)

    def register(
        self,
        source_path: str,
        content_hash: str,
        project: str,
        file_size: int,
        chunks: list[ChunkRecord],
    ) -> Manifest:
        """
        Register a newly indexed source file. Returns updated Manifest.

        Overwrites any existing record for the same source_path
        (re-index scenario).
        """
        now = time.time()
        record = SourceRecord(
            source_path=source_path,
            content_hash=content_hash,
            project=project,
            file_size=file_size,
            chunks=tuple(chunks),
            indexed_at=now,
            last_verified=now,
        )

        # Copy-on-write: build the successor record map.
        new_records = dict(self._records)
        new_records[source_path] = record

        new_manifest = Manifest(new_records, self._path)
        new_manifest._persist()
        return new_manifest

    def unregister(self, source_path: str) -> Manifest:
        """Remove a source from the manifest. Returns updated Manifest."""
        new_records = {
            k: v for k, v in self._records.items()
            if k != source_path
        }
        new_manifest = Manifest(new_records, self._path)
        new_manifest._persist()
        return new_manifest

    def needs_reindex(self, source_path: str, current_hash: str) -> bool:
        """Check if a source file needs re-indexing (content changed).

        Unknown sources always need indexing.
        """
        record = self._records.get(source_path)
        if record is None:
            return True
        return record.content_hash != current_hash

    def get_record(self, source_path: str) -> SourceRecord | None:
        """Look up a source record by path; None if not registered."""
        return self._records.get(source_path)

    def get_project_records(self, project: str) -> list[SourceRecord]:
        """All records for a given project namespace."""
        return [
            r for r in self._records.values()
            if r.project == project
        ]

    def all_records(self) -> Iterator[SourceRecord]:
        """Iterate over all registered source records."""
        yield from self._records.values()

    @property
    def total_sources(self) -> int:
        # Number of registered source files.
        return len(self._records)

    @property
    def total_chunks(self) -> int:
        # Total .eng chunks across all registered sources.
        return sum(len(r.chunks) for r in self._records.values())

    @property
    def projects(self) -> set[str]:
        # Distinct project namespaces present in the manifest.
        return {r.project for r in self._records.values()}

    def summary(self) -> dict:
        """Quick stats for display."""
        return {
            "total_sources": self.total_sources,
            "total_chunks": self.total_chunks,
            "projects": sorted(self.projects),
            "manifest_path": str(self._path),
        }

    def _persist(self) -> None:
        """Atomic write to disk via tempfile + rename."""
        self._path.parent.mkdir(parents=True, exist_ok=True)

        serializable = {
            "version": 1,
            "updated_at": time.time(),
            "sources": {},
        }
        for key, rec in self._records.items():
            # asdict() recurses: nested ChunkRecord tuples become
            # lists of plain dicts, matching what load() expects.
            rec_dict = asdict(rec)
            serializable["sources"][key] = rec_dict

        # Atomic write: write to temp, then rename. The temp file is
        # created in the destination directory so os.replace stays on
        # the same filesystem (rename is atomic there).
        fd, tmp_path = tempfile.mkstemp(
            dir=str(self._path.parent),
            suffix=".tmp",
        )
        try:
            with os.fdopen(fd, "w") as f:
                json.dump(serializable, f, indent=2)
            os.replace(tmp_path, str(self._path))
        except Exception:
            # Clean up temp file on failure
            try:
                os.unlink(tmp_path)
            except OSError:
                pass
            raise

    def __len__(self) -> int:
        # len(manifest) == number of registered sources.
        return self.total_sources

    def __contains__(self, source_path: str) -> bool:
        # `path in manifest` membership test by source path.
        return source_path in self._records

    def __repr__(self) -> str:
        return (
            f"Manifest({self.total_sources} sources, "
            f"{self.total_chunks} chunks, "
            f"projects={sorted(self.projects)})"
        )
|
kvcos/engram/metadata_disambiguate.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/metadata_disambiguate.py
|
| 3 |
+
|
| 4 |
+
Stage 4 retrieval: activates when the fingerprint pipeline returns LOW.
|
| 5 |
+
Uses .eng metadata fields (domain, context_len, l2_norm, task_description)
|
| 6 |
+
to break ties that the Fourier fingerprint cannot resolve.
|
| 7 |
+
|
| 8 |
+
Returns Stage4Result with confidence='low-metadata' and metadata_used=True.
|
| 9 |
+
|
| 10 |
+
Design note (VRCM source):
|
| 11 |
+
When constraint satisfaction fails on the fingerprint axis, switch to
|
| 12 |
+
orthogonal axes — metadata fields that are independent of spectral structure.
|
| 13 |
+
The medicine/biology failure exists because their f0+f1 profiles are
|
| 14 |
+
spectrally identical. Their metadata is NOT identical: context_len differs,
|
| 15 |
+
l2_norm differs, task_description keywords differ. That orthogonal signal
|
| 16 |
+
is what Stage 4 exploits.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from __future__ import annotations
|
| 20 |
+
import re
|
| 21 |
+
from dataclasses import dataclass
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class Stage4Result:
    """Winner of Stage-4 metadata disambiguation.

    Emitted only when the fingerprint pipeline returns LOW confidence;
    `confidence` is therefore always 'low-metadata'.
    """
    doc_id: str        # cache_id of the chosen candidate
    meta_score: float  # sum of domain/len/norm/keyword sub-scores
    confidence: str = 'low-metadata'
    metadata_used: bool = True
    domain_matched: bool = False  # query and winner share an exact domain tag
    score_breakdown: dict | None = None  # per-component scores; {} after __post_init__

    def __post_init__(self):
        # Avoid the shared-mutable-default pitfall: give each
        # instance its own fresh dict.
        if self.score_breakdown is None:
            self.score_breakdown = {}
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _keyword_overlap(text_a: str, text_b: str) -> float:
|
| 39 |
+
"""Jaccard overlap on lowercase word sets, excluding stopwords."""
|
| 40 |
+
STOPWORDS = {
|
| 41 |
+
'the', 'a', 'an', 'is', 'are', 'was', 'were', 'in', 'of',
|
| 42 |
+
'to', 'and', 'or', 'that', 'this', 'it', 'with', 'for',
|
| 43 |
+
'on', 'at', 'by', 'from', 'be', 'has', 'have', 'had',
|
| 44 |
+
}
|
| 45 |
+
def words(t):
|
| 46 |
+
return set(w for w in re.sub(r'[^a-z0-9 ]', ' ',
|
| 47 |
+
(t or '').lower()).split()
|
| 48 |
+
if w not in STOPWORDS and len(w) > 2)
|
| 49 |
+
a, b = words(text_a), words(text_b)
|
| 50 |
+
if not a or not b:
|
| 51 |
+
return 0.0
|
| 52 |
+
return len(a & b) / len(a | b)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def metadata_disambiguate(
    candidates: list[dict],
    query_metadata: dict,
    domain_bonus: float = 0.3,
    max_len: int = 8192,
    max_norm: float = 10.0,
) -> Stage4Result | None:
    """
    Stage 4 disambiguation using .eng metadata fields.

    Scores each candidate on four orthogonal axes — exact domain match,
    context-length proximity, L2-norm proximity, and keyword overlap of
    the (truncated) task descriptions — and returns the best-scoring one.

    Args:
        candidates: list of dicts from eng_index values.
            Each must have: cache_id, task_description,
            context_len (optional), l2_norm (optional),
            metadata dict with domain (optional).
        query_metadata: dict with same structure as one candidate.
        domain_bonus: score added for exact domain match (default 0.3).
        max_len: normalisation constant for context_len diff.
        max_norm: normalisation constant for l2_norm diff.

    Returns:
        Stage4Result for the highest meta-scoring candidate, or None
        if candidates list is empty.
    """
    if not candidates:
        return None

    best: Stage4Result | None = None

    # Query-side features; falsy values fall back to the same defaults
    # used for candidates so missing fields score a perfect match.
    q_domain = (query_metadata.get('metadata') or {}).get('domain', '')
    q_len = float(query_metadata.get('context_len') or 512)
    q_norm = float(query_metadata.get('l2_norm') or 1.0)
    q_desc = (query_metadata.get('task_description') or '')[:80]

    for cand in candidates:
        c_domain = (cand.get('metadata') or {}).get('domain', '')
        c_len = float(cand.get('context_len') or 512)
        c_norm = float(cand.get('l2_norm') or 1.0)
        c_desc = (cand.get('task_description') or '')[:80]
        c_id = cand.get('cache_id', '')

        # BUG FIX: the chained `and` expression evaluates to '' (not
        # False) when q_domain is empty, leaking a non-bool into the
        # bool-typed domain_matched field. Coerce explicitly.
        domain_match = bool(q_domain and c_domain and q_domain == c_domain)
        domain_score = domain_bonus if domain_match else 0.0
        # Proximity scores in [0, 1]: 1.0 at equality, 0.0 at/beyond
        # the normalisation constant.
        len_score = 1.0 - min(abs(q_len - c_len) / max(max_len, 1), 1.0)
        norm_score = 1.0 - min(abs(q_norm - c_norm) / max(max_norm, 1), 1.0)
        kw_score = _keyword_overlap(q_desc, c_desc)
        meta_score = domain_score + len_score + norm_score + kw_score

        r = Stage4Result(
            doc_id=c_id,
            meta_score=meta_score,
            confidence='low-metadata',
            metadata_used=True,
            domain_matched=domain_match,
            score_breakdown={
                'domain': round(domain_score, 3),
                'len': round(len_score, 3),
                'norm': round(norm_score, 3),
                'kw': round(kw_score, 3),
            },
        )
        # Strict > keeps the earliest candidate on ties (original behavior).
        if best is None or meta_score > best.meta_score:
            best = r

    return best
|
kvcos/engram/reader.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EIGENGRAM reader: .eng file -> IndexEntry + fingerprint vectors
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
|
| 9 |
+
from .format import EigramDecoder
|
| 10 |
+
from kvcos.core.manifold_index import IndexEntry
|
| 11 |
+
|
| 12 |
+
_decoder = EigramDecoder()
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def read_eigengram(path: str) -> dict:
    """Decode a .eng file into a dict of its fields.

    Raises:
        FileNotFoundError: if *path* does not exist.
    """
    target = Path(path)
    if not target.exists():
        raise FileNotFoundError(f"EIGENGRAM not found: {path}")
    return _decoder.decode(target.read_bytes())
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def load_eigengram_index(
    paths: list[str],
    fingerprint: str = "perdoc",
) -> tuple[list, list]:
    """Load multiple .eng files for ManifoldIndex.

    Args:
        paths: .eng file paths to decode.
        fingerprint: 'perdoc' (same-model) | 'fcdb' (cross-model) |
            'fourier' — selects which stored vector field to load.

    Returns:
        (vecs, entries) ready for ManifoldIndex.add().

    Raises:
        ValueError: on an unknown fingerprint name.
        FileNotFoundError: if any path is missing (via read_eigengram).
    """
    if fingerprint not in ("perdoc", "fcdb", "fourier"):
        raise ValueError(f"fingerprint must be 'perdoc', 'fcdb', or 'fourier', got '{fingerprint}'")

    vec_field = f"vec_{fingerprint}"
    vectors = []
    index_entries = []

    for eng_path in paths:
        record = read_eigengram(eng_path)
        vectors.append(record[vec_field])
        index_entries.append(
            IndexEntry(
                cache_id=record["cache_id"],
                task_description=record["task_description"],
                model_id=record["model_id"],
                created_at=record["created_at"],
                context_len=record["context_len"],
                l2_norm=record["l2_norm"],
            )
        )

    return vectors, index_entries
|
kvcos/engram/retrieval.py
ADDED
|
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM constrained retrieval — apophatic negative constraint layer.
|
| 3 |
+
|
| 4 |
+
Implements constrained_retrieve() which penalizes candidates too
|
| 5 |
+
similar to known confusion partners, resolving dense-region failures.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from __future__ import annotations
|
| 9 |
+
|
| 10 |
+
from dataclasses import dataclass, field
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class CosineResult:
    """Single retrieval result."""

    doc_id: str                # identifier of the matched document
    score: float               # final score (penalty-adjusted when constrained)
    cos_score: float           # raw cosine similarity before any penalty
    margin: float = 0.0        # top-1 minus top-2 score; set only on the best hit
    constrained: bool = False  # True when negative constraints were applied
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@dataclass
class EngramQuery:
    """Query with optional negative constraints.

    like: Fingerprint to match (positive constraint).
    unlike: Fingerprints to avoid (negative constraints).
    min_margin: Minimum acceptable score gap.
    domain_hint: Optional domain tag; not read by cosine_search or
        constrained_retrieve — presumably consumed by later stages.
    fingerprint: Which fingerprint field to use ('fourier', 'fcdb', 'perdoc').
    """

    like: torch.Tensor
    unlike: list[torch.Tensor] = field(default_factory=list)
    min_margin: float = 0.001
    domain_hint: str | None = None
    fingerprint: str = "fourier"
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def cosine_search(
    query_fp: torch.Tensor,
    index: dict[str, torch.Tensor],
    top_k: int = 5,
) -> list[CosineResult]:
    """Plain (unconstrained) cosine-similarity search.

    Args:
        query_fp: [dim] query fingerprint.
        index: mapping of doc_id -> fingerprint tensor.
        top_k: maximum number of results.

    Returns:
        Results sorted best-first; when at least two hits exist the
        top one carries the top-1/top-2 score margin.
    """
    if not index:
        return []

    ids = list(index.keys())
    stacked = torch.stack([index[doc] for doc in ids])
    q_unit = F.normalize(query_fp.unsqueeze(0).float(), dim=-1)
    m_unit = F.normalize(stacked.float(), dim=-1)
    scores = (q_unit @ m_unit.T).squeeze(0)

    k = min(top_k, len(ids))
    hits = []
    for i in scores.topk(k).indices.tolist():
        s = float(scores[i].item())
        hits.append(CosineResult(doc_id=ids[i], score=s, cos_score=s))

    if len(hits) >= 2:
        hits[0].margin = hits[0].score - hits[1].score
    return hits
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def constrained_retrieve(
    query: EngramQuery,
    index: dict[str, torch.Tensor],
    top_k: int = 5,
    neg_weight: float = 0.5,
    neg_threshold: float = 0.85,
) -> list[CosineResult]:
    """Retrieval with a negative (apophatic) constraint layer.

    Candidates resembling any `unlike` fingerprint more than
    `neg_threshold` lose `neg_weight` times the excess similarity,
    separating documents that plain cosine scores cannot tell apart.

    Algorithm:
        1. Cosine similarity of every candidate to the query.
        2. For each unlike fingerprint, similarity to each candidate.
        3. Subtract penalty: neg_weight * max(0, sim_to_unlike - threshold).
        4. Rank by the adjusted score.
    """
    if not index:
        return []

    ids = list(index.keys())
    m_unit = F.normalize(torch.stack([index[d] for d in ids]).float(), dim=-1)
    q_unit = F.normalize(query.like.unsqueeze(0).float(), dim=-1)
    base_scores = (q_unit @ m_unit.T).squeeze(0)

    final_scores = base_scores.clone()
    for neg_fp in query.unlike:
        neg_unit = F.normalize(neg_fp.unsqueeze(0).float(), dim=-1)
        neg_sim = (neg_unit @ m_unit.T).squeeze(0)
        # Only similarity above the threshold is penalized.
        final_scores = final_scores - neg_weight * torch.clamp(
            neg_sim - neg_threshold, min=0
        )

    k = min(top_k, len(ids))
    out = [
        CosineResult(
            doc_id=ids[i],
            score=float(final_scores[i].item()),
            cos_score=float(base_scores[i].item()),
            constrained=bool(query.unlike),
        )
        for i in final_scores.topk(k).indices.tolist()
    ]
    if len(out) >= 2:
        out[0].margin = out[0].score - out[1].score
    return out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
# ── TWO-STAGE GEODESIC RETRIEVAL ──────────────────────────────────────
|
| 126 |
+
|
| 127 |
+
from enum import Enum
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class RetrievalConfidence(Enum):
    """Confidence tier assigned by the geodesic retrieval pipeline."""

    HIGH = "high"      # margin > 5x threshold, single pass sufficient
    MEDIUM = "medium"  # margin > threshold, or resolved by stage-2
    LOW = "low"        # margin < threshold after stage-2 — uncertain
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@dataclass
class GeodesicResult:
    """Result from geodesic_retrieve()."""

    doc_id: str                      # winning document id
    score: float                     # similarity score of the winner
    margin: float                    # gap to the runner-up
    confidence: RetrievalConfidence  # HIGH / MEDIUM / LOW tier
    stages_used: int = 1  # 1 = single pass, 2 = two-stage, 3 = constrained
    constraint_used: bool = False  # True when the negative-constraint layer ran
    stage4_used: bool = False      # True when metadata disambiguation ran
    stage4_doc_id: str = ""        # doc chosen by stage 4, if it ran
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def geodesic_retrieve(
    query_fp: torch.Tensor,
    hnsw_index,  # EngramIndex instance
    eng_index: dict,  # {doc_id: eng_data} for constraint layer
    margin_threshold: float = 0.005,
    correction_weight: float = 0.3,
    top_k: int = 5,
) -> GeodesicResult:
    """
    Two-stage geodesic retrieval with automatic confidence scoring.

    Stage 1: HNSW approximate nearest-neighbor search.
        margin >= 5 * margin_threshold -> HIGH, return immediately.
        margin >= margin_threshold     -> MEDIUM, return immediately.

    Stage 2: Activated when the Stage-1 margin < margin_threshold.
        The query fingerprint is interpolated toward the Stage-1 top-1
        fingerprint (weight = correction_weight) and the search rerun.
        This bends the geodesic toward the probable destination without
        assuming the Stage-1 answer is correct.
        Stage-2 margin >= threshold -> MEDIUM, otherwise LOW.

    Stage 3: If the Stage-2 top result carries a confusion_flag and
        eng_index provides fingerprints, the other flagged docs are used
        as negative (unlike) constraints via constrained_retrieve().

    Args:
        query_fp: [dim] query fingerprint (v2 recommended).
        hnsw_index: Built EngramIndex instance.
        eng_index: {doc_id: eng_data} loaded from .eng files. Used for
            Stage-2 interpolation and the Stage-3 constraint layer.
            Pass {} to disable both.
        margin_threshold: Minimum margin for MEDIUM confidence.
            Default 0.005 (below S3 mean margin of 0.009).
        correction_weight: Stage-2 interpolation weight; 0.3 = 30% pull
            toward top-1. Range: 0.1 (gentle) to 0.5 (aggressive).
        top_k: Candidates per search pass.

    Returns:
        GeodesicResult with doc_id, score, margin, confidence, stages_used.
    """
    # ── Stage 1: plain HNSW search ────────────────────────────────────
    s1_results = hnsw_index.search(query_fp, top_k=top_k)
    if len(s1_results) < 2:
        # Zero or one candidate: no margin can be computed -> LOW.
        return GeodesicResult(
            doc_id=s1_results[0].doc_id if s1_results else "",
            score=s1_results[0].score if s1_results else 0.0,
            margin=0.0,
            confidence=RetrievalConfidence.LOW,
            stages_used=1,
        )

    s1_margin = s1_results[0].margin

    if s1_margin >= margin_threshold * 5:
        # Wide margin: single pass is sufficient.
        return GeodesicResult(
            doc_id=s1_results[0].doc_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.HIGH,
            stages_used=1,
        )

    if s1_margin >= margin_threshold:
        return GeodesicResult(
            doc_id=s1_results[0].doc_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.MEDIUM,
            stages_used=1,
        )

    # ── Stage 2: trajectory correction ────────────────────────────────
    top1_id = s1_results[0].doc_id
    top1_eng = eng_index.get(top1_id, {})
    top1_fp = top1_eng.get("vec_fourier_v2")
    if top1_fp is None:
        top1_fp = top1_eng.get("vec_fourier")

    if top1_fp is None:
        # No vector available for interpolation: keep the Stage-1 answer.
        return GeodesicResult(
            doc_id=top1_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.LOW,
            stages_used=1,
        )

    # Bend the geodesic toward the Stage-1 top-1 fingerprint.
    refined_fp = F.normalize(
        (1 - correction_weight) * query_fp.float()
        + correction_weight * top1_fp.float(),
        dim=-1,
    )
    s2_results = hnsw_index.search(refined_fp, top_k=top_k)
    if not s2_results:
        # BUGFIX: the corrected pass can return no candidates; previously
        # only s2_top_id was guarded and s2_results[0].score below raised
        # IndexError. Fall back to the Stage-1 answer instead.
        return GeodesicResult(
            doc_id=top1_id,
            score=s1_results[0].score,
            margin=s1_margin,
            confidence=RetrievalConfidence.LOW,
            stages_used=2,
        )

    s2_margin = s2_results[0].margin if len(s2_results) >= 2 else 0.0
    s2_top_id = s2_results[0].doc_id
    s2_top_eng = eng_index.get(s2_top_id, {})

    # ── Stage 3: negative constraint on flagged confusion pairs ───────
    # (If eng_index is empty, s2_top_eng is {} and this branch is skipped,
    # so no separate eng_index truthiness test is needed.)
    if s2_top_eng.get("confusion_flag"):
        def _pick_fp(d: dict):
            # Prefer the v2 Fourier fingerprint, fall back to v1.
            v = d.get("vec_fourier_v2")
            return v if v is not None else d.get("vec_fourier")

        confusion_fps = [
            _pick_fp(d)
            for did, d in eng_index.items()
            if d.get("confusion_flag")
            and did != s2_top_id
            and _pick_fp(d) is not None
        ]
        if confusion_fps:
            # Flat {doc_id: fingerprint} index for constrained_retrieve.
            flat_index = {
                did: _pick_fp(d)
                for did, d in eng_index.items()
                if _pick_fp(d) is not None
            }
            q_constrained = EngramQuery(
                like=refined_fp,
                unlike=confusion_fps[:3],  # top 3 confusion partners
                min_margin=margin_threshold,
            )
            s3_results = constrained_retrieve(
                q_constrained,
                flat_index,
            )
            if s3_results:
                s3_margin = s3_results[0].margin
                return GeodesicResult(
                    doc_id=s3_results[0].doc_id,
                    score=s3_results[0].score,
                    margin=s3_margin,
                    confidence=(
                        RetrievalConfidence.MEDIUM
                        if s3_margin >= margin_threshold
                        else RetrievalConfidence.LOW
                    ),
                    stages_used=3,
                    constraint_used=True,
                )

    # Stage-2 verdict: MEDIUM if the corrected pass separated the top
    # candidates, LOW otherwise.
    return GeodesicResult(
        doc_id=s2_top_id,
        score=s2_results[0].score,
        margin=s2_margin,
        confidence=(
            RetrievalConfidence.MEDIUM
            if s2_margin >= margin_threshold
            else RetrievalConfidence.LOW
        ),
        stages_used=2,
    )
|
| 325 |
+
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def geodesic_retrieve_with_prior(
    query_fp: torch.Tensor,
    hnsw_index,
    eng_index: dict,
    index_c=None,
    query_doc_id: str | None = None,
    margin_threshold: float = 0.005,
    correction_weight: float = 0.3,
    top_k: int = 5,
) -> GeodesicResult:
    """
    Prior-aware geodesic retrieval backed by IndexC history.

    Known chronic failures skip Stages 1-2 entirely: the negative
    constraint is applied immediately (stages_used=0 marks the
    preemption), avoiding two wasted HNSW passes. Docs with some LOW
    history get a proportionally raised margin threshold. Docs with no
    prior fall through to the standard 3-stage geodesic_retrieve().

    Args:
        query_fp: [dim] query fingerprint.
        hnsw_index: Built EngramIndex instance.
        eng_index: {doc_id: eng_data} from .eng files.
        index_c: IndexC instance, or None to disable prior mode.
        query_doc_id: doc_id being queried (for prior lookup).
        margin_threshold: Base margin threshold; scaled up when the
            prior shows LOW-confidence history.
        correction_weight: Stage-2 interpolation weight.
        top_k: Candidates per HNSW pass.

    Returns:
        GeodesicResult. For chronic failures: stages_used=0 (preempted).
    """
    if index_c is None or query_doc_id is None:
        # Prior mode disabled -> plain 3-stage retrieval.
        return geodesic_retrieve(
            query_fp, hnsw_index, eng_index,
            margin_threshold=margin_threshold,
            correction_weight=correction_weight,
            top_k=top_k,
        )

    prior = index_c.prior(query_doc_id)

    # Preemptive mode: this doc is a known chronic failure.
    if prior.is_chronic_failure and prior.n_total >= 2:
        registry = index_c.confusion_registry(min_confusions=1)
        # Keep the original ordering: all doc_b partners first, then all
        # doc_a partners (the [:3] slice below depends on it).
        partner_ids = [
            pair.doc_b for pair in registry if pair.doc_a == query_doc_id
        ] + [
            pair.doc_a for pair in registry if pair.doc_b == query_doc_id
        ]

        def _fingerprint(entry: dict):
            # v2 Fourier fingerprint preferred, v1 as fallback.
            fp = entry.get("vec_fourier_v2")
            return fp if fp is not None else entry.get("vec_fourier")

        unlike_fps = [
            fp
            for fp in (
                _fingerprint(eng_index.get(pid, {})) for pid in partner_ids[:3]
            )
            if fp is not None
        ]

        if unlike_fps:
            flat_index: dict[str, torch.Tensor] = {
                did: _fingerprint(entry)
                for did, entry in eng_index.items()
                if _fingerprint(entry) is not None
            }

            constrained_q = EngramQuery(
                like=query_fp,
                unlike=unlike_fps,
                min_margin=margin_threshold,
            )
            hits = constrained_retrieve(constrained_q, flat_index)
            if hits:
                best = hits[0]
                tier = (
                    RetrievalConfidence.MEDIUM
                    if best.margin >= margin_threshold
                    else RetrievalConfidence.LOW
                )
                return GeodesicResult(
                    doc_id=best.doc_id,
                    score=best.score,
                    margin=best.margin,
                    confidence=tier,
                    stages_used=0,  # Stages 1-2 were preempted
                    constraint_used=True,
                )

    # Prior shows LOW history but not yet chronic: raise the bar in
    # proportion to the LOW fraction.
    if prior.n_low > 0 and prior.n_total > 0:
        margin_threshold *= 1 + prior.n_low / prior.n_total

    return geodesic_retrieve(
        query_fp, hnsw_index, eng_index,
        margin_threshold=margin_threshold,
        correction_weight=correction_weight,
        top_k=top_k,
    )
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def geodesic_retrieve_stage4(
    query_fp: torch.Tensor,
    hnsw_index,
    eng_index: dict,
    query_metadata: dict | None = None,
    index_c=None,
    query_doc_id: str | None = None,
    margin_threshold: float = 0.005,
    correction_weight: float = 0.3,
    top_k: int = 5,
) -> GeodesicResult:
    """
    Full pipeline: prior-aware geodesic retrieval with Stage 4 fallback.

    Extends geodesic_retrieve_with_prior() with a metadata
    disambiguation layer: when the fingerprint pipeline ends at LOW
    confidence and query_metadata is provided, the top HNSW candidates
    are re-ranked by metadata_disambiguate() before giving up.

    Confidence tier:
        HIGH          -> fingerprint, 0% error rate target
        MEDIUM        -> fingerprint, 0% error rate target
        LOW           -> fingerprint failed, Stage 4 unavailable
        low-metadata  -> fingerprint failed, Stage 4 used secondary
                         signal (reported here as confidence=LOW with
                         stage4_used=True; the tier string itself is
                         available via Stage4Result for logging)

    Args:
        query_fp: [dim] query fingerprint (vec_fourier_v2).
        hnsw_index: Built EngramIndex.
        eng_index: {doc_id: eng_data} from .eng files.
        query_metadata: dict with task_description, context_len, l2_norm,
            metadata.domain -- from the query doc's .eng data.
            If None, Stage 4 is disabled.
        index_c: IndexC instance for prior lookup.
        query_doc_id: doc_id of the query source (for priors).
        margin_threshold, correction_weight, top_k: as in the base
            geodesic_retrieve().

    Returns:
        GeodesicResult. Stage 4 is a tiebreaker, not a promotion:
        confidence never rises above LOW here. Callers should check
        stage4_used=True to distinguish "failed silently" from
        "failed with a secondary signal".
    """
    # Local import: metadata_disambiguate is only needed on the LOW path.
    from kvcos.engram.metadata_disambiguate import metadata_disambiguate

    base = geodesic_retrieve_with_prior(
        query_fp, hnsw_index, eng_index,
        index_c=index_c,
        query_doc_id=query_doc_id,
        margin_threshold=margin_threshold,
        correction_weight=correction_weight,
        top_k=top_k,
    )

    # Stage 4 only activates on LOW confidence with metadata available.
    if base.confidence != RetrievalConfidence.LOW or query_metadata is None:
        return base

    # FIX: honor the caller's top_k (this search previously hard-coded 5,
    # silently ignoring the parameter; the default is unchanged).
    candidates_hnsw = hnsw_index.search(query_fp, top_k=top_k)
    candidates_meta = [
        eng_index[r.doc_id]
        for r in candidates_hnsw
        if r.doc_id in eng_index
    ]

    if not candidates_meta:
        return base

    s4 = metadata_disambiguate(candidates_meta, query_metadata)
    if s4 is None:
        return base

    return GeodesicResult(
        doc_id=s4.doc_id,
        score=base.score,
        margin=base.margin,
        confidence=RetrievalConfidence.LOW,  # still LOW: tiebreaker only
        stages_used=base.stages_used,
        constraint_used=base.constraint_used,
        stage4_used=True,
        stage4_doc_id=s4.doc_id,
    )
|
kvcos/engram/session_propagator.py
ADDED
|
@@ -0,0 +1,159 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
kvcos/engram/session_propagator.py — Session start/end for ENGRAM.
|
| 3 |
+
|
| 4 |
+
Bridges geodesic_retrieve() results and IndexC persistence.
|
| 5 |
+
Call session_start() at the top of each session to load priors.
|
| 6 |
+
Call session_end() at the bottom to persist results.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
import time
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
|
| 15 |
+
from kvcos.engram.index_c import IndexC, DocPrior
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@dataclass
class SessionSummary:
    """Aggregate outcome of one retrieval session.

    Produced by SessionPropagator.session_end() after every buffered
    result has been flushed to IndexC.
    """

    session_id: str  # identifier supplied at SessionPropagator construction
    n_total: int  # number of retrievals recorded this session
    n_correct: int  # retrievals the caller marked correct
    n_high: int  # count of "high"-confidence results
    n_medium: int  # count of "medium"-confidence results
    n_low: int  # count of "low"-confidence results
    n_preempted: int  # results resolved by priors (stages_used == 0)
    new_confusion_pairs: list[tuple[str, str]]  # pairs first seen this session
    recall: float  # n_correct / n_total (0.0 when no records)
    duration_s: float  # wall-clock seconds between start and end
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SessionPropagator:
    """
    Buffers retrieval outcomes for one session and flushes them to IndexC.

    Usage: session_start() once at the top of a session (opens IndexC and
    returns known priors), record() once per retrieval, session_end() once
    at the bottom (persists everything and returns a SessionSummary).
    """

    def __init__(self, db_path: str, session_id: str):
        self._db_path = str(db_path)      # IndexC database location
        self._session_id = session_id     # label written with every record
        self._ic: IndexC | None = None    # open handle between start/end
        self._records: list[dict] = []    # buffered retrieval outcomes
        self._start_ts: float = 0.0       # wall-clock start of the session
        self._started: bool = False       # guards session_end() ordering

    def session_start(self) -> dict[str, DocPrior]:
        """
        Open IndexC, return {doc_id: DocPrior} for all known docs.
        Call at the top of each session.
        """
        self._ic = IndexC.open(self._db_path)
        self._start_ts = time.time()
        self._started = True

        known = self._ic.reliability_map()
        return {doc_id: self._ic.prior(doc_id) for doc_id in known}

    @property
    def index_c(self) -> IndexC:
        """The live IndexC handle (only valid between start and end)."""
        if self._ic is None:
            raise RuntimeError("Call session_start() first.")
        return self._ic

    def record(
        self,
        query_doc_id: str,
        result_doc_id: str,
        confidence: str,
        margin: float,
        stages_used: int = 1,
        constraint_used: bool = False,
        correct: bool = True,
    ) -> None:
        """Buffer one retrieval result for this session."""
        self._records.append(dict(
            query_doc_id=query_doc_id,
            result_doc_id=result_doc_id,
            confidence=confidence,
            margin=float(margin),
            stages_used=int(stages_used),
            constraint_used=bool(constraint_used),
            correct=bool(correct),
            ts=time.time(),
        ))

    def session_end(self) -> SessionSummary:
        """Write all buffered records to IndexC. Return summary."""
        if not self._started or self._ic is None:
            raise RuntimeError("Call session_start() before session_end().")

        # Snapshot the confusion registry so pairs created by this
        # session's writes can be reported in the summary.
        pairs_before = {
            (pair.doc_a, pair.doc_b)
            for pair in self._ic.confusion_registry(min_confusions=1)
        }

        for entry in self._records:
            self._ic.record(
                session_id=self._session_id,
                query_doc_id=entry["query_doc_id"],
                result_doc_id=entry["result_doc_id"],
                confidence=entry["confidence"],
                margin=entry["margin"],
                stages_used=entry["stages_used"],
                constraint_used=entry["constraint_used"],
                correct=entry["correct"],
                ts=entry["ts"],
            )

        pairs_after = {
            (pair.doc_a, pair.doc_b)
            for pair in self._ic.confusion_registry(min_confusions=1)
        }
        fresh_pairs = list(pairs_after - pairs_before)

        total = len(self._records)
        correct_n = sum(1 for entry in self._records if entry["correct"])
        tier_counts = {"high": 0, "medium": 0, "low": 0}
        preempted = 0
        for entry in self._records:
            # .get() tolerates tier strings beyond the three defaults.
            tier_counts[entry["confidence"]] = (
                tier_counts.get(entry["confidence"], 0) + 1
            )
            if entry["stages_used"] == 0:
                preempted += 1

        summary = SessionSummary(
            session_id=self._session_id,
            n_total=total,
            n_correct=correct_n,
            n_high=tier_counts["high"],
            n_medium=tier_counts["medium"],
            n_low=tier_counts["low"],
            n_preempted=preempted,
            new_confusion_pairs=fresh_pairs,
            recall=correct_n / total if total > 0 else 0.0,
            duration_s=time.time() - self._start_ts,
        )

        # Reset so this propagator can be reused for another session.
        self._ic.close()
        self._ic = None
        self._started = False
        self._records = []

        return summary

    def summary_str(self, s: SessionSummary) -> str:
        """One-line human-readable rendering of a SessionSummary."""
        return (
            f"Session {s.session_id}: "
            f"{s.n_total} retrievals | "
            f"recall={s.recall:.1%} | "
            f"H={s.n_high}/M={s.n_medium}/L={s.n_low} | "
            f"preempted={s.n_preempted} | "
            f"new_pairs={len(s.new_confusion_pairs)} | "
            f"{s.duration_s:.1f}s"
        )
|
kvcos/engram/writer.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
EIGENGRAM writer: text + model -> .eng file
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import gc
|
| 8 |
+
import hashlib
|
| 9 |
+
import os
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
from llama_cpp import Llama
|
| 14 |
+
|
| 15 |
+
from kvcos.core.blob_parser import parse_state_blob
|
| 16 |
+
from .format import EigramEncoder
|
| 17 |
+
|
| 18 |
+
# Single shared encoder instance used by write_eigengram().
_encoder = EigramEncoder()

# Defaults consumed by write_eigengram():
_DEFAULT_LR = (8, 24)  # (start, end) layer slice whose keys are pooled
_DEFAULT_GATE = 6  # leading singular vectors skipped in the per-doc SVD projection
_DEFAULT_RANK = 116  # width of the per-doc SVD projection
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def _get_model_id(model_path: str) -> str:
    """Map a model file path to a short model identifier.

    Recognizes the 3B / 8B Llama variants by filename; anything else is
    identified by the first 15 characters of its basename.
    """
    base = os.path.basename(model_path)
    lowered = base.lower()  # one casefold replaces the "3B"/"3b" double test
    if "3b" in lowered:
        return "Llama-3.2-3B"
    if "8b" in lowered:
        return "Llama-3.1-8B"
    return base[:15]
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def _corpus_hash(basis_path: str) -> str:
    """Return the first 32 hex chars of the SHA-256 of the basis file."""
    digest = hashlib.sha256(Path(basis_path).read_bytes())
    return digest.hexdigest()[:32]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def write_eigengram(
    model_path: str,
    text: str,
    output_path: str,
    cache_id: str = "",
    task_description: str = "",
    layer_range: tuple[int, int] = _DEFAULT_LR,
    gate: int = _DEFAULT_GATE,
    rank_perdoc: int = _DEFAULT_RANK,
    basis_path: str = "results/corpus_basis_fcdb_v2.pt",
) -> dict:
    """Encode a document as an EIGENGRAM (.eng) file.

    Pipeline:
      1. Load the shared FCDB corpus basis from ``basis_path``.
      2. Run ``text`` through the GGUF model once (1 generated token)
         and capture the raw KV-cache state blob.
      3. Parse the blob, pool the keys of the ``layer_range`` layers,
         and derive two fingerprints -- a per-doc SVD projection and an
         FCDB-basis projection -- plus an SCS quality score.
      4. Serialize everything via the module-level encoder and write it
         to ``output_path``.

    Args:
        model_path: Path to the GGUF model file.
        text: Document text to fingerprint.
        output_path: Destination .eng file (parent dirs are created).
        cache_id: Optional identifier stored in the certificate.
        task_description: Human-readable label; defaults to the first
            100 characters of ``text``.
        layer_range: (start, end) layer slice whose keys are pooled.
        gate: Number of leading singular vectors skipped before the
            per-doc projection.
        rank_perdoc: Width of the per-doc SVD projection.
        basis_path: Saved corpus basis (.pt containing "basis",
            "joint_center", "n_docs").

    Returns:
        dict summary: output path, model/corpus identifiers, file size,
        SCS and L2-norm diagnostics, and the layer range used.
    """
    saved = torch.load(basis_path, weights_only=False)
    P_fcdb = saved["basis"]  # [basis_rank, head_dim] corpus basis
    center = saved["joint_center"]  # corpus joint center; same dim as the pooled key mean
    n_corpus = int(saved["n_docs"])
    basis_rank = P_fcdb.shape[0]

    # Run the model once to populate the KV cache. FIX: always release
    # the model (it holds the full weights in memory) even when
    # inference or state capture raises -- previously the cleanup only
    # ran on the success path.
    llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=-1, verbose=False)
    try:
        meta = llm.metadata
        n_kv = int(meta.get("llama.attention.head_count_kv", "8"))
        # head_dim = embedding_length / attention head count
        hd = int(meta.get("llama.embedding_length", "4096")) // int(
            meta.get("llama.attention.head_count", "32")
        )

        llm.reset()
        llm(text.strip(), max_tokens=1, temperature=0.0)
        state_bytes = bytes(llm.save_state().llama_state)
    finally:
        del llm
        gc.collect()

    p = parse_state_blob(state_bytes, n_kv_heads=n_kv, head_dim=hd)
    l0, l1 = layer_range
    # Flatten the selected layers' keys into rows of head_dim.
    k = p.keys[l0:l1].float().reshape(-1, hd)
    mean_v = k.mean(0)
    l2_norm = float(mean_v.norm().item())

    # Per-doc SVD fingerprint. Subsample deterministically (seed 42) so
    # the SVD stays tractable and reproducible for long contexts.
    if k.shape[0] > 8192:
        gen = torch.Generator()
        gen.manual_seed(42)
        idx = torch.randperm(k.shape[0], generator=gen)[:8192]
        svd_input = k[idx]
    else:
        svd_input = k
    _, _, Vh = torch.linalg.svd(svd_input, full_matrices=False)
    proj = (svd_input @ Vh[gate : gate + rank_perdoc].T).mean(0)
    vec_perdoc = proj / (proj.norm() + 1e-8)

    # FCDB fingerprint: normalized offset from the corpus center,
    # projected onto the shared basis and renormalized.
    delta = mean_v - center
    delta = delta / (delta.norm() + 1e-8)
    vec_fcdb = delta @ P_fcdb.T
    vec_fcdb = vec_fcdb / (vec_fcdb.norm() + 1e-8)

    # SCS: energy of delta captured by the basis projection (a ratio in
    # [0, 1] assuming orthonormal basis rows -- TODO confirm).
    scs = float(
        ((delta @ P_fcdb.T @ P_fcdb) ** 2).sum().item()
        / ((delta**2).sum().item() + 1e-12)
    )

    corpus_h = _corpus_hash(basis_path)
    model_id = _get_model_id(model_path)

    cert = _encoder.encode(
        vec_perdoc=vec_perdoc,
        vec_fcdb=vec_fcdb,
        joint_center=center,
        corpus_hash=corpus_h,
        model_id=model_id,
        basis_rank=basis_rank,
        n_corpus=n_corpus,
        layer_range=layer_range,
        context_len=int(k.shape[0]),
        l2_norm=l2_norm,
        scs=scs,
        margin_proof=0.0,
        task_description=task_description or text[:100],
        cache_id=cache_id or "",
    )

    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "wb") as f:
        f.write(cert)

    return {
        "output_path": output_path,
        "model_id": model_id,
        "corpus_hash": corpus_h,
        "basis_rank": basis_rank,
        "n_corpus": n_corpus,
        "file_size_bytes": len(cert),
        "scs": round(scs, 4),
        "l2_norm": round(l2_norm, 4),
        "layer_range": layer_range,
    }
|
kvcos/mar/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ENGRAM Protocol — MAR (Manifold Attention Retrieval)
|
| 2 |
+
|
| 3 |
+
Backward compatibility re-exports. All classes have moved to kvcos.core.
|
| 4 |
+
Import from kvcos.core directly for new code.
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
from kvcos.core.manifold_index import IndexEntry, ManifoldIndex
|
| 8 |
+
from kvcos.core.retriever import EGRRetriever, RetrievalResponse, RetrievalResult
|
| 9 |
+
from kvcos.core.state_extractor import ExtractionResult, MARStateExtractor, SVDProjection
|
| 10 |
+
|
| 11 |
+
__all__ = [
|
| 12 |
+
"IndexEntry",
|
| 13 |
+
"ManifoldIndex",
|
| 14 |
+
"EGRRetriever",
|
| 15 |
+
"RetrievalResponse",
|
| 16 |
+
"RetrievalResult",
|
| 17 |
+
"ExtractionResult",
|
| 18 |
+
"MARStateExtractor",
|
| 19 |
+
"SVDProjection",
|
| 20 |
+
]
|
kvcos/storage/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""ENGRAM Protocol — Storage backends for .eng files."""
|
| 2 |
+
|
| 3 |
+
from kvcos.storage.backends import StorageBackend
|
| 4 |
+
from kvcos.storage.local import LocalStorageBackend
|
| 5 |
+
|
| 6 |
+
__all__ = [
|
| 7 |
+
"StorageBackend",
|
| 8 |
+
"LocalStorageBackend",
|
| 9 |
+
]
|
kvcos/storage/backends.py
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Abstract Storage Backend
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
All storage backends (local, redis, S3) implement this interface.
|
| 6 |
+
Phase 1 uses local disk only.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
from __future__ import annotations
|
| 10 |
+
|
| 11 |
+
from abc import ABC, abstractmethod
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
from kvcos.core.types import CacheStats, EngramMetadata
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class StorageBackend(ABC):
    """Abstract interface for engram storage backends.

    All operations are synchronous in Phase 1. Implementations persist
    .eng files keyed by ``cache_id`` and expose metadata reads that do
    not require loading tensor data.
    """

    @abstractmethod
    def store(self, cache_id: str, data: bytes, metadata: EngramMetadata) -> str:
        """Store a .eng file from in-memory bytes. Returns storage path/key."""
        ...

    @abstractmethod
    def store_file(self, cache_id: str, source_path: Path, metadata: EngramMetadata) -> str:
        """Store a .eng file from a local path (zero-copy when possible).

        Returns the storage path/key, as in ``store``.
        """
        ...

    @abstractmethod
    def get(self, cache_id: str) -> bytes | None:
        """Retrieve a .eng file as bytes. None if not found."""
        ...

    @abstractmethod
    def get_path(self, cache_id: str) -> Path | None:
        """Get local filesystem path for a cache entry. None if not found."""
        ...

    @abstractmethod
    def get_metadata(self, cache_id: str) -> EngramMetadata | None:
        """Read only metadata (header-only, no tensor data loaded).

        None if the entry is not found.
        """
        ...

    @abstractmethod
    def delete(self, cache_id: str) -> bool:
        """Delete a cache entry. Returns True if deleted."""
        ...

    @abstractmethod
    def list_entries(
        self,
        agent_id: str | None = None,
        model_family: str | None = None,
        limit: int = 100,
    ) -> list[EngramMetadata]:
        """List cache entries with optional filters.

        ``limit`` caps the number of entries returned.
        """
        ...

    @abstractmethod
    def exists(self, cache_id: str) -> bool:
        """Check if a cache entry exists."""
        ...

    @abstractmethod
    def stats(self) -> CacheStats:
        """Get aggregate statistics for the store."""
        ...
|
kvcos/storage/local.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ENGRAM Protocol — Local Disk Storage Backend
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
Directory layout:
|
| 6 |
+
{data_dir}/{model_family}/{agent_id}/{date}/{cache_id}.eng
|
| 7 |
+
|
| 8 |
+
Phase 1 production backend. Zero infrastructure dependencies.
|
| 9 |
+
Uses safetensors header-only read for metadata operations.
|
| 10 |
+
D7: One safetensors file per 256-token block.
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
from __future__ import annotations
|
| 14 |
+
|
| 15 |
+
import logging
|
| 16 |
+
import shutil
|
| 17 |
+
from collections import defaultdict
|
| 18 |
+
from datetime import datetime, timezone
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
from kvcos.core.serializer import EngramSerializer
|
| 24 |
+
from kvcos.core.types import ENG_FILE_EXTENSION, CacheStats, EngramMetadata
|
| 25 |
+
from kvcos.storage.backends import StorageBackend
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class LocalStorageBackend(StorageBackend):
    """Local filesystem storage for .eng files.

    Files are organized as::

        {data_dir}/{model_family}/{agent_id}/{YYYY-MM-DD}/{cache_id}.eng

    An in-memory ``cache_id -> Path`` index is built at construction time
    and kept current by store/delete operations; :meth:`vacuum` prunes
    entries whose files were removed out-of-band.
    """

    def __init__(self, data_dir: Path):
        """Create the backend rooted at *data_dir* (created if missing)."""
        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self._serializer = EngramSerializer()
        self._index: dict[str, Path] = {}  # cache_id -> file path
        self._rebuild_index()

    def _rebuild_index(self) -> None:
        """Scan data directory and rebuild the in-memory path index."""
        self._index.clear()
        for eng_file in self.data_dir.rglob(f"*{ENG_FILE_EXTENSION}"):
            # Default to the filename stem; prefer the cache_id embedded in
            # the file's metadata header when it is readable.
            cache_id = eng_file.stem
            try:
                meta = self._serializer.read_metadata_only(eng_file)
                if "cache_id" in meta:
                    cache_id = meta["cache_id"]
            except Exception as e:
                logger.debug("Skipping metadata for %s: %s", eng_file.name, e)
            self._index[cache_id] = eng_file

    def _resolve_path(self, metadata: EngramMetadata) -> Path:
        """Determine the storage path from *metadata*, creating parent dirs.

        NOTE(review): ``model_family``, ``agent_id`` and ``cache_id`` are
        interpolated into the path unsanitized — confirm callers validate
        them (path-traversal risk if they may contain ``..`` or separators).
        """
        model_family = metadata.get("model_family", "unknown")
        agent_id = metadata.get("agent_id", "default")
        date_str = datetime.now(timezone.utc).strftime("%Y-%m-%d")
        cache_id = metadata.get("cache_id", "unknown")
        path = self.data_dir / model_family / agent_id / date_str / f"{cache_id}{ENG_FILE_EXTENSION}"
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

    def store(self, cache_id: str, data: bytes, metadata: EngramMetadata) -> str:
        """Atomically write *data* under *cache_id*.

        Returns:
            The absolute/relative path of the stored file, as a string.
        """
        metadata_copy = dict(metadata)
        metadata_copy["cache_id"] = cache_id
        path = self._resolve_path(metadata_copy)  # type: ignore[arg-type]

        # Write to a temp file, then replace() so readers never observe a
        # partial file. replace() — unlike rename() — also overwrites an
        # existing entry atomically on every platform, including Windows,
        # so re-storing the same cache_id does not raise FileExistsError.
        tmp_path = path.with_suffix(f"{ENG_FILE_EXTENSION}.tmp")
        try:
            tmp_path.write_bytes(data)
            tmp_path.replace(path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        self._index[cache_id] = path
        return str(path)

    def store_file(self, cache_id: str, source_path: Path, metadata: EngramMetadata) -> str:
        """Copy an existing .eng file into managed storage (atomic).

        Returns:
            The destination path, as a string.
        """
        metadata_copy = dict(metadata)
        metadata_copy["cache_id"] = cache_id
        dest_path = self._resolve_path(metadata_copy)  # type: ignore[arg-type]

        # Source already lives at the destination: just (re)register it.
        if source_path == dest_path:
            self._index[cache_id] = dest_path
            return str(dest_path)

        # Same temp-then-replace pattern as store(): atomic and
        # overwrite-safe on all platforms.
        tmp_path = dest_path.with_suffix(f"{ENG_FILE_EXTENSION}.tmp")
        try:
            shutil.copy2(str(source_path), str(tmp_path))
            tmp_path.replace(dest_path)
        except Exception:
            tmp_path.unlink(missing_ok=True)
            raise

        self._index[cache_id] = dest_path
        return str(dest_path)

    def get(self, cache_id: str) -> bytes | None:
        """Return the raw bytes for *cache_id*, or None if absent."""
        path = self._index.get(cache_id)
        if path is None or not path.exists():
            return None
        return path.read_bytes()

    def get_path(self, cache_id: str) -> Path | None:
        """Return the on-disk path for *cache_id*, or None if absent."""
        path = self._index.get(cache_id)
        if path is None or not path.exists():
            return None
        return path

    def get_metadata(self, cache_id: str) -> EngramMetadata | None:
        """Read only the metadata header for *cache_id* (no tensor payload).

        Returns None when the entry is unknown, deleted, or unreadable.
        """
        path = self._index.get(cache_id)
        if path is None or not path.exists():
            return None
        try:
            return self._serializer.read_metadata_only(path)
        except Exception as e:
            logger.warning("Failed to read metadata for %s: %s", cache_id, e)
            return None

    def delete(self, cache_id: str) -> bool:
        """Delete *cache_id*'s file and prune now-empty parent directories.

        Returns:
            True if a file was removed; False if it was unknown or already
            gone (the stale index entry is dropped either way).
        """
        path = self._index.pop(cache_id, None)
        if path is None or not path.exists():
            return False

        path.unlink()

        # Walk upward removing empty date/agent/model directories, stopping
        # at the first non-empty ancestor or at data_dir itself.
        parent = path.parent
        try:
            while parent != self.data_dir:
                if not any(parent.iterdir()):
                    parent.rmdir()
                    parent = parent.parent
                else:
                    break
        except OSError:
            pass  # best-effort cleanup; a concurrent writer may repopulate

        return True

    def list_entries(
        self,
        agent_id: str | None = None,
        model_family: str | None = None,
        limit: int = 100,
    ) -> list[EngramMetadata]:
        """Return up to *limit* entries, newest first, optionally filtered.

        All matching metadata is collected *before* sorting so the result
        is the globally newest ``limit`` entries; truncating the scan first
        would return an arbitrary (index-order) subset instead.

        Args:
            agent_id: If given, keep only entries with this agent_id.
            model_family: If given, keep only entries with this model_family.
            limit: Maximum number of entries to return.
        """
        results: list[EngramMetadata] = []

        for cache_id, path in self._index.items():
            if not path.exists():
                continue
            try:
                meta = self._serializer.read_metadata_only(path)
            except Exception as e:
                logger.debug("Skipping %s in list_entries: %s", cache_id, e)
                continue
            if agent_id and meta.get("agent_id") != agent_id:
                continue
            if model_family and meta.get("model_family") != model_family:
                continue
            results.append(meta)

        results.sort(key=lambda m: m.get("created_at", ""), reverse=True)
        return results[:limit]

    def exists(self, cache_id: str) -> bool:
        """True iff *cache_id* is indexed and its file is still on disk."""
        path = self._index.get(cache_id)
        return path is not None and path.exists()

    def stats(self) -> CacheStats:
        """Aggregate entry count, total size, and per-model-family counts."""
        total_entries = 0
        total_size = 0
        model_counts: dict[str, int] = defaultdict(int)

        for cache_id, path in self._index.items():
            if not path.exists():
                continue
            total_entries += 1
            total_size += path.stat().st_size
            try:
                meta = self._serializer.read_metadata_only(path)
                model_counts[meta.get("model_family", "unknown")] += 1
            except Exception as e:
                logger.debug("Metadata read failed for %s: %s", cache_id, e)
                model_counts["unknown"] += 1

        return CacheStats(
            total_entries=total_entries,
            total_size_bytes=total_size,
            avg_compression_ratio=0.0,  # not tracked by this backend yet
            model_breakdown=dict(model_counts),
        )

    def vacuum(self) -> int:
        """Drop index entries whose backing files no longer exist.

        Returns:
            The number of stale entries removed from the index.
        """
        stale = [cid for cid, path in self._index.items() if not path.exists()]
        for cid in stale:
            del self._index[cid]
        return len(stale)
|