from __future__ import annotations import json from dataclasses import dataclass from datetime import datetime from pathlib import Path from typing import List, Tuple import faiss import numpy as np @dataclass class VectorIndexConfig: index_type: str = "IndexFlatIP" embedding_dim: int = 384 class VectorIndex: def __init__(self, embeddings: np.ndarray, config: VectorIndexConfig | None = None): cfg = config or VectorIndexConfig(embedding_dim=embeddings.shape[1]) if cfg.index_type == "IndexFlatIP": self.index = faiss.IndexFlatIP(cfg.embedding_dim) elif cfg.index_type == "IndexHNSWFlat": self.index = faiss.IndexHNSWFlat(cfg.embedding_dim, 32) else: raise ValueError(f"Unsupported index_type {cfg.index_type}") self.index.add(embeddings) self.config = cfg def search(self, query_vector: np.ndarray, k: int = 10) -> Tuple[np.ndarray, np.ndarray]: scores, idx = self.index.search(query_vector.astype(np.float32)[None, :], k) return scores[0], idx[0] def save(self, path: str, metadata: dict) -> None: Path(path).parent.mkdir(parents=True, exist_ok=True) faiss.write_index(self.index, path) meta = { **metadata, "index_type": self.config.index_type, "embedding_dim": self.config.embedding_dim, "saved_at": datetime.utcnow().isoformat(), } with open(Path(path).with_suffix(".json"), "w") as f: json.dump(meta, f, indent=2) @classmethod def load(cls, path: str): index = faiss.read_index(path) obj = cls.__new__(cls) obj.index = index obj.config = VectorIndexConfig(index.d) return obj