metropolis-chess / app /memory /vector_store.py
Forkei's picture
feat(block13): agent vs character match flow (commits 3–8)
b0d952a
"""Interface over memory embeddings.
Currently a thin wrapper around the `embedding` column on `Memory`: cosine
similarity is computed in process with numpy. The corpus is small
(~40-50 memories per character, fewer than 5 preset characters), so
in-memory scoring is sub-millisecond.
Phase 3 can swap the implementation for sqlite-vec without touching
callers — the interface (`upsert`, `search`, `get_embedding`) is stable.
"""
from __future__ import annotations
import logging
import math
from dataclasses import dataclass
import numpy as np
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.models.memory import Memory, MemoryScope
# Module-level logger named after this module (stdlib `logging` convention).
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class VectorHit:
    """A single vector-search result: which memory matched, and how well.

    Frozen, so instances are immutable and (via the generated ``__eq__`` /
    ``__hash__``) usable as dict keys or set members.
    """

    memory_id: str  # id of the matched Memory row
    score: float  # cosine similarity to the query; larger means closer
def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity between two embedding vectors.

    Accepts any 1-D sequence of numbers, including numpy arrays. A plain
    truthiness test (``if not a``) would raise "truth value of an array is
    ambiguous" for multi-element arrays, so emptiness is checked via size.

    Returns 0.0 when either input is ``None``/empty or has zero norm, so a
    missing embedding scores as "unrelated" instead of erroring. The math
    runs in float32, the same precision ``search`` uses.

    Raises:
        ValueError: if the two vectors have different shapes (e.g. they were
            produced by different embedding models).
    """
    if a is None or b is None:
        return 0.0
    va = np.asarray(a, dtype=np.float32)
    vb = np.asarray(b, dtype=np.float32)
    if va.size == 0 or vb.size == 0:
        return 0.0
    if va.shape != vb.shape:
        # numpy would also raise ValueError here, but with an opaque
        # "shapes not aligned" message; name the actual problem instead.
        raise ValueError(
            f"cosine_similarity: dimension mismatch {va.shape} vs {vb.shape}"
        )
    na = float(np.linalg.norm(va))
    nb = float(np.linalg.norm(vb))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return float(np.dot(va, vb) / (na * nb))
def upsert(session: Session, *, memory_id: str, embedding: list[float]) -> None:
    """Write ``embedding`` onto the Memory row ``memory_id`` and flush.

    The vector is stored as a list of plain Python floats. Callers may pass a
    numpy array (or a list of numpy scalars); ``list(arr)`` alone would keep
    ``np.float32`` elements, which the stdlib ``json`` encoder rejects. That
    matters for a JSON column such as the one ``ensure_embedding_column``
    creates.

    Caller commits.

    Raises:
        ValueError: if no Memory row has id ``memory_id``.
    """
    row = session.get(Memory, memory_id)
    if row is None:
        raise ValueError(f"Memory {memory_id} not found")
    row.embedding = [float(x) for x in embedding]
    session.flush()
def search(
    session: Session,
    *,
    query_embedding: list[float],
    k: int,
    character_id: str | None = None,
    agent_id: str | None = None,
    scope: MemoryScope | None = None,
    player_id: str | None = None,
    include_null_player: bool = True,
) -> list[VectorHit]:
    """Return the top-``k`` memories ranked by cosine similarity to a query.

    Pass either ``character_id`` (for character-scoped memories) or
    ``agent_id`` (for agent-scoped memories, Block 13+). At least one is
    required. If both are given, ``agent_id`` takes precedence and
    ``character_id`` is ignored.

    Filters:
      - ``scope``: restrict to a specific MemoryScope if given.
      - ``player_id``: restrict to memories attached to this player. When
        ``include_null_player=True`` (default), memories whose ``player_id``
        is NULL are also kept (character_lore, cross_player-style general
        memories).

    Memories without an embedding are skipped silently; the caller should
    run the backfill script to cover them. Memories whose embedding length
    differs from the query's (e.g. written by a different embedding model)
    are skipped with a logged warning rather than aborting the whole search.

    Returns:
        At most ``k`` hits, highest score first. Ties keep the order rows
        came back from the database. The result is empty when ``k <= 0``,
        the query vector has zero norm, or nothing matches.

    Raises:
        ValueError: if neither ``character_id`` nor ``agent_id`` is given,
            or if ``query_embedding`` is not a 1-D vector while embedded
            candidates exist.
    """
    if agent_id is not None:
        stmt = select(Memory).where(Memory.agent_id == agent_id)
    elif character_id is not None:
        stmt = select(Memory).where(Memory.character_id == character_id)
    else:
        raise ValueError("search() requires either character_id or agent_id")

    # Cheap early exits before touching the database. Without the k guard, a
    # negative k would slice as ``ranked[:-n]`` and return nearly every hit.
    if k <= 0:
        return []
    q = np.asarray(query_embedding, dtype=np.float32)
    q_norm = float(np.linalg.norm(q))
    if q_norm == 0.0:
        return []

    if scope is not None:
        stmt = stmt.where(Memory.scope == scope)
    if player_id is not None:
        if include_null_player:
            stmt = stmt.where((Memory.player_id == player_id) | (Memory.player_id.is_(None)))
        else:
            stmt = stmt.where(Memory.player_id == player_id)

    rows = list(session.execute(stmt).scalars())

    # Keep only rows that have been embedded.
    embedded = [(m.id, m.embedding) for m in rows if m.embedding]
    if not embedded:
        return []

    if q.ndim != 1:
        raise ValueError("query_embedding must be a 1-D vector")
    dim = q.shape[0]

    # Drop embeddings of the wrong dimension. A ragged matrix would make
    # np.asarray / the matmul below raise and fail the entire search.
    candidates: list[tuple[str, list[float]]] = []
    skipped = 0
    for memory_id, emb in embedded:
        if len(emb) == dim:
            candidates.append((memory_id, list(emb)))
        else:
            skipped += 1
    if skipped:
        logger.warning(
            "search: skipped %d memories whose embedding dimension != query dimension %d",
            skipped,
            dim,
        )
    if not candidates:
        return []

    # Vectorized cosine over the candidate set.
    mat = np.asarray([emb for _, emb in candidates], dtype=np.float32)
    norms = np.linalg.norm(mat, axis=1)
    # Zero-norm rows have a zero dot product, so dividing by 1.0 scores them 0.
    safe_norms = np.where(norms == 0.0, 1.0, norms)
    sims = (mat @ q) / (safe_norms * q_norm)

    # sorted() is stable, so equal scores keep database order.
    ranked = sorted(
        (
            VectorHit(memory_id=mid, score=float(sim))
            for (mid, _), sim in zip(candidates, sims)
        ),
        key=lambda h: h.score,
        reverse=True,
    )
    return ranked[:k]
def get_embedding(session: Session, memory_id: str) -> list[float] | None:
    """Fetch the stored embedding for ``memory_id`` as a fresh list.

    Returns None both when no Memory row has that id and when the row exists
    but its embedding is missing or empty.
    """
    memory = session.get(Memory, memory_id)
    stored = None if memory is None else memory.embedding
    if not stored:
        return None
    return list(stored)
def ensure_embedding_column(bind) -> None:
    """Idempotently add the ``embedding`` column to ``memories`` on existing DBs.

    ``Base.metadata.create_all`` only creates missing tables, not missing
    columns. Phase 2b adds a column to a pre-existing table; this helper
    bridges the gap without requiring a proper migration tool in 2b.

    Only SQLite is handled; other dialects log a warning and return. If the
    ``memories`` table does not exist yet, this is a no-op. Presumably the
    table (with its column) will then come from ``create_all``; confirm at
    call sites.

    Args:
        bind: presumably a SQLAlchemy Engine, i.e. anything whose ``begin()``
            returns a context manager yielding a connection with ``dialect``
            and ``exec_driver_sql``. TODO confirm.
    """
    with bind.begin() as conn:
        dialect = conn.dialect.name
        if dialect != "sqlite":
            # Other backends will need proper migrations; warn and skip.
            logger.warning(
                "ensure_embedding_column: dialect=%s not SQLite — skipping; "
                "ensure the `memories.embedding` column exists via your migration tool.",
                dialect,
            )
            return
        cols = conn.exec_driver_sql("PRAGMA table_info(memories)").fetchall()
        if not cols:
            # PRAGMA table_info yields no rows for a missing table; running
            # ALTER TABLE now would fail with "no such table".
            logger.info("ensure_embedding_column: table `memories` not found — skipping.")
            return
        # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
        names = {row[1] for row in cols}
        if "embedding" not in names:
            logger.info("Adding memories.embedding column (Phase 2b migration).")
            conn.exec_driver_sql("ALTER TABLE memories ADD COLUMN embedding JSON")