| """Small vector index for nearest-neighbor exploration.""" |
|
|
| from __future__ import annotations |
|
|
| import math |
| from dataclasses import dataclass |
| from pathlib import Path |
|
|
|
|
| @dataclass(frozen=True, slots=True) |
| class Neighbor: |
| """A nearest-neighbor result from a vector search.""" |
|
|
| item_id: str |
| score: float |
|
|
|
|
class VectorIndex:
    """In-memory cosine-similarity vector index.

    Vectors are kept in a plain dict keyed by item id. :meth:`search` scores
    every stored vector against the query, so lookups are a linear scan.
    """

    def __init__(self) -> None:
        """Create an empty vector index."""
        self._vectors: dict[str, tuple[float, ...]] = {}

    def add(self, item_id: str, vector: tuple[float, ...]) -> None:
        """Add or replace an item vector.

        Args:
            item_id: Key under which the vector is stored; an existing entry
                with the same id is overwritten.
            vector: Non-empty sequence of finite floats.

        Raises:
            ValueError: If ``vector`` is empty or contains NaN/infinity.
        """
        if not vector:
            raise ValueError("Vector must contain at least one value.")
        if not all(math.isfinite(value) for value in vector):
            raise ValueError("Vector values must be finite.")
        # Store an immutable snapshot so a caller that passed a mutable
        # sequence cannot change (or invalidate) the indexed values later.
        self._vectors[item_id] = tuple(vector)

    def search(
        self, query_vector: tuple[float, ...], *, limit: int = 5, exclude_id: str | None = None
    ) -> list[Neighbor]:
        """Return the closest vectors by cosine similarity.

        Args:
            query_vector: Vector to compare against every stored vector.
            limit: Maximum number of neighbors to return (at least 1).
            exclude_id: Optional item id to leave out of the results, e.g.
                the id of the item the query vector came from.

        Returns:
            At most ``limit`` neighbors, highest score first.

        Raises:
            ValueError: If ``limit`` is below 1, if ``query_vector`` contains
                NaN/infinity, or if a compared vector has a different number
                of dimensions than the query.
        """
        if limit < 1:
            raise ValueError("Limit must be at least 1.")
        # A NaN/inf query yields NaN scores, and sorting NaN gives an
        # arbitrary order; reject it up front with the same rule as add().
        if not all(math.isfinite(value) for value in query_vector):
            raise ValueError("Vector values must be finite.")
        neighbors = [
            Neighbor(item_id=item_id, score=_cosine_similarity(query_vector, vector))
            for item_id, vector in self._vectors.items()
            if item_id != exclude_id
        ]
        return sorted(neighbors, key=lambda neighbor: neighbor.score, reverse=True)[:limit]

    def save_tsv(self, path: Path) -> None:
        """Persist the index as a simple tab-separated text file.

        Each line holds the item id followed by the vector values, separated
        by tabs; lines are written in sorted id order. Missing parent
        directories are created.

        Raises:
            ValueError: If an item id contains a tab or line break, because
                such an id cannot be read back by :meth:`load_tsv`.
        """
        # Validate before touching the filesystem so a bad id never truncates
        # an existing file. "\r" is included because text-mode reads translate
        # it to "\n", which would split the record on load.
        for item_id in self._vectors:
            if any(separator in item_id for separator in "\t\r\n"):
                msg = (
                    f"Item id {item_id!r} contains a tab or line break "
                    "and cannot be stored as TSV."
                )
                raise ValueError(msg)

        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", encoding="utf-8") as index_file:
            for item_id, vector in sorted(self._vectors.items()):
                values = "\t".join(str(value) for value in vector)
                index_file.write(f"{item_id}\t{values}\n")

    @classmethod
    def load_tsv(cls, path: Path) -> VectorIndex:
        """Load an index produced by :meth:`save_tsv`.

        Raises:
            ValueError: If a line lacks vector values, holds a value that is
                not a number, or holds a non-finite value. The message names
                the offending line.
        """
        index = cls()
        with path.open("r", encoding="utf-8") as index_file:
            for line_number, line in enumerate(index_file, start=1):
                item_id, *raw_values = line.rstrip("\n").split("\t")
                if not raw_values:
                    msg = f"Line {line_number} must contain an id and vector values."
                    raise ValueError(msg)
                try:
                    index.add(item_id, tuple(float(value) for value in raw_values))
                except ValueError as error:
                    # Attach the line number; float()/add() errors alone give
                    # no hint about where the file is malformed.
                    msg = f"Line {line_number} contains an invalid vector: {error}"
                    raise ValueError(msg) from error
        return index

    def __len__(self) -> int:
        """Return the number of indexed vectors."""
        return len(self._vectors)
|
|
|
|
| def _cosine_similarity(left: tuple[float, ...], right: tuple[float, ...]) -> float: |
| if len(left) != len(right): |
| raise ValueError("Vectors must have the same dimensions.") |
| left_norm = math.sqrt(sum(value * value for value in left)) |
| right_norm = math.sqrt(sum(value * value for value in right)) |
| if left_norm == 0.0 or right_norm == 0.0: |
| return 0.0 |
| dot_product = sum( |
| left_value * right_value for left_value, right_value in zip(left, right, strict=True) |
| ) |
| return dot_product / (left_norm * right_norm) |
|
|
|
|
def cosine_similarity(left: tuple[float, ...], right: tuple[float, ...]) -> float:
    """Compute cosine similarity between two vectors.

    Public entry point (handy in notebooks and examples) that simply
    forwards to the module's private implementation and returns its result.
    """
    similarity = _cosine_similarity(left, right)
    return similarity
|
|