"""Small vector index for nearest-neighbor exploration.""" from __future__ import annotations import math from dataclasses import dataclass from pathlib import Path @dataclass(frozen=True, slots=True) class Neighbor: """A nearest-neighbor result from a vector search.""" item_id: str score: float class VectorIndex: """In-memory cosine-similarity vector index.""" def __init__(self) -> None: """Create an empty vector index.""" self._vectors: dict[str, tuple[float, ...]] = {} def add(self, item_id: str, vector: tuple[float, ...]) -> None: """Add or replace an item vector.""" if not vector: raise ValueError("Vector must contain at least one value.") if not all(math.isfinite(value) for value in vector): raise ValueError("Vector values must be finite.") self._vectors[item_id] = vector def search( self, query_vector: tuple[float, ...], *, limit: int = 5, exclude_id: str | None = None ) -> list[Neighbor]: """Return the closest vectors by cosine similarity.""" if limit < 1: raise ValueError("Limit must be at least 1.") neighbors = [ Neighbor(item_id=item_id, score=_cosine_similarity(query_vector, vector)) for item_id, vector in self._vectors.items() if item_id != exclude_id ] return sorted(neighbors, key=lambda neighbor: neighbor.score, reverse=True)[:limit] def save_tsv(self, path: Path) -> None: """Persist the index as a simple tab-separated text file.""" path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", encoding="utf-8") as index_file: for item_id, vector in sorted(self._vectors.items()): values = "\t".join(str(value) for value in vector) index_file.write(f"{item_id}\t{values}\n") @classmethod def load_tsv(cls, path: Path) -> VectorIndex: """Load an index produced by :meth:`save_tsv`.""" index = cls() with path.open("r", encoding="utf-8") as index_file: for line_number, line in enumerate(index_file, start=1): fields = line.rstrip("\n").split("\t") if len(fields) < 2: msg = f"Line {line_number} must contain an id and vector values." raise ValueError(msg) index.add(fields[0], tuple(float(value) for value in fields[1:])) return index def __len__(self) -> int: """Return the number of indexed vectors.""" return len(self._vectors) def _cosine_similarity(left: tuple[float, ...], right: tuple[float, ...]) -> float: if len(left) != len(right): raise ValueError("Vectors must have the same dimensions.") left_norm = math.sqrt(sum(value * value for value in left)) right_norm = math.sqrt(sum(value * value for value in right)) if left_norm == 0.0 or right_norm == 0.0: return 0.0 dot_product = sum( left_value * right_value for left_value, right_value in zip(left, right, strict=True) ) return dot_product / (left_norm * right_norm) def cosine_similarity(left: tuple[float, ...], right: tuple[float, ...]) -> float: """Public wrapper for cosine similarity (useful in notebooks/examples).""" return _cosine_similarity(left, right)