"""FAISS-backed vector store using OpenAI embeddings.

Vectors are L2-normalized before insertion and search, so the inner-product
scores returned by ``IndexFlatIP`` are cosine similarities in [-1, 1].
"""

import logging
import pickle
from pathlib import Path
from threading import Lock
from typing import Any, Dict, List

import numpy as np

import faiss

try:
    from langchain_openai import OpenAIEmbeddings
except ImportError:
    # Guarded so the check in __init__ is reachable (previously the bare
    # import made "OpenAIEmbeddings is None" dead code).
    OpenAIEmbeddings = None

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# A chunk is a dict with at least a "text" key; "metadata" is optional.
DocumentChunk = Dict[str, Any]


class FAISSVectorStore:
    """Thread-safe FAISS inner-product index with a positional doc mapping.

    FAISS row ``i`` always corresponds to ``self.documents[i]``; every code
    path that adds vectors must keep the two in lockstep.
    """

    def __init__(
        self,
        dimension: int = 3072,
        index_path: str = "faiss_index",
        embedding_model: str = "text-embedding-3-large",  # 3072-dim vectors
    ):
        """Create (or reload) a store rooted at ``index_path``.

        Raises:
            ImportError: if langchain_openai is not installed.
        """
        if OpenAIEmbeddings is None:
            raise ImportError(
                "Could not import OpenAIEmbeddings from langchain. "
                "Install langchain or adapt the import to your environment."
            )
        self.dimension = dimension
        self.index_path = Path(index_path)
        self._lock = Lock()
        self.index_path.mkdir(parents=True, exist_ok=True)
        # Instantiate embeddings (may make API calls later when embedding).
        self.embeddings = OpenAIEmbeddings(model=embedding_model)
        # Positional mapping: FAISS row i <-> documents[i].
        self.documents: List[DocumentChunk] = []
        # Fresh empty index; replaced by load_index() if a saved one exists.
        self.index = faiss.IndexFlatIP(self.dimension)
        self.load_index()  # safe: returns False if nothing to load

    def _ensure_index_dim(self, d: int) -> None:
        """Ensure the FAISS index has dimension ``d``.

        An *empty* index with a different dimension is silently re-created;
        a non-empty mismatching index raises ValueError (its vectors cannot
        be re-projected).
        """
        current_d = getattr(self.index, "d", None)
        if getattr(self.index, "ntotal", 0) == 0 and current_d != d:
            logger.info("Recreating an empty index with dimension %d", d)
            self.dimension = d
            self.index = faiss.IndexFlatIP(self.dimension)
        elif current_d is not None and current_d != d:
            raise ValueError(
                f"Embedding dimension ({d}) does not match existing index dimension ({current_d})."
            )

    def add_documents(self, chunks: List[DocumentChunk], save: bool = True) -> None:
        """Embed ``chunks`` and add them to the index.

        Each chunk must be a dict with a non-empty string under 'text'.
        Chunks with empty/whitespace-only text are skipped with a warning
        and are NOT appended to ``self.documents`` — previously they were,
        which desynchronized the FAISS-row -> document mapping.

        Raises:
            ValueError: on malformed chunks or a dimension mismatch with a
                non-empty index.
        """
        with self._lock:
            if not chunks:
                logger.debug("No chunks to add.")
                return
            texts: List[str] = []
            kept_chunks: List[DocumentChunk] = []
            for i, chunk in enumerate(chunks):
                if not isinstance(chunk, dict):
                    raise ValueError(f"Chunk {i} is not a dictionary")
                if "text" not in chunk:
                    raise ValueError(f"Chunk {i} missing required 'text' field")
                if not isinstance(chunk["text"], str):
                    raise ValueError(f"Chunk {i} 'text' field must be a string")
                if not chunk["text"].strip():
                    logger.warning(f"Chunk {i} has empty text content")
                    continue
                texts.append(chunk["text"])
                kept_chunks.append(chunk)
            if not texts:
                # Every chunk was skipped; avoid calling the provider with [].
                logger.debug("All chunks had empty text; nothing to add.")
                return
            # One provider call for the whole batch.
            embeddings = self.embeddings.embed_documents(texts)
            embeddings_np = np.asarray(embeddings, dtype=np.float32)
            if embeddings_np.ndim == 1:
                # Single vector returned as 1-D -> reshape to (1, d).
                embeddings_np = embeddings_np.reshape(1, -1)
            emb_d = embeddings_np.shape[1]
            # Re-create the index if it is empty and the dimension differs.
            self._ensure_index_dim(emb_d)
            if emb_d != self.index.d:
                raise ValueError(f"Embedding dim {emb_d} != index dim {self.index.d}")
            # L2-normalize rows in place so inner product == cosine similarity.
            faiss.normalize_L2(embeddings_np)
            # faiss-python infers the vector count; no explicit n argument.
            self.index.add(embeddings_np)
            # Extend with ONLY the chunks that were actually embedded, so
            # row i in the index still maps to documents[i].
            self.documents.extend(kept_chunks)
            if save:
                self.save_index()

    def search(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Return up to ``k`` documents most similar to ``query``.

        Each result is ``{"content", "metadata", "similarity_score"}``;
        the score is a cosine similarity in [-1, 1] because both the stored
        vectors and the query are L2-normalized.

        Raises:
            ValueError: if the query embedding dimension does not match the
                (non-empty) index dimension.
        """
        with self._lock:
            # Guard: no vectors at all.
            if getattr(self.index, "ntotal", 0) == 0:
                logger.debug("Search called but index is empty.")
                return []
            # embed_query returns a flat vector; make it a (1, d) batch.
            q_np = np.asarray(self.embeddings.embed_query(query), dtype=np.float32)
            if q_np.ndim == 1:
                q_np = q_np.reshape(1, -1)
            if q_np.shape[1] != self.index.d:
                # Index is known non-empty here, so we cannot recreate it.
                raise ValueError(
                    f"Query embedding dim {q_np.shape[1]} does not match index dimension {self.index.d}"
                )
            faiss.normalize_L2(q_np)
            # Clamp k to the number of stored vectors; bail out on k <= 0.
            k = min(k, int(self.index.ntotal))
            if k <= 0:
                return []
            distances, indices = self.index.search(q_np, k)
            # distances shape (1, k); indices shape (1, k).
            results: List[Dict[str, Any]] = []
            for score, idx in zip(distances[0], indices[0]):
                if idx < 0:
                    # FAISS pads with -1 when fewer than k hits exist; skip.
                    continue
                if idx >= len(self.documents):
                    logger.warning(
                        "Index returned idx %d but documents list has length %d",
                        idx,
                        len(self.documents),
                    )
                    continue
                doc = self.documents[idx]
                results.append({
                    "content": doc.get("text"),
                    "metadata": doc.get("metadata", {}),
                    # Already cosine similarity thanks to normalization.
                    "similarity_score": float(score),
                })
            return results

    def save_index(self):
        """Persist the index and documents list to disk.

        NOTE: not self-locking; add_documents calls it while holding
        self._lock (a non-reentrant Lock), so taking it here would deadlock.
        External callers should synchronize if used concurrently.
        """
        self.index_path.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(self.index_path / "index.faiss"))
        with open(self.index_path / "documents.pkl", "wb") as f:
            pickle.dump(self.documents, f)
        logger.debug("FAISS index and documents saved to %s", self.index_path)

    def load_index(self) -> bool:
        """Load index and documents from disk.

        Returns True only if a *consistent* pair was loaded. A corrupted
        pair (zero-dim index, or document count != vector count) is deleted
        from disk AND the in-memory state is reset to a fresh empty index —
        previously the corrupted objects were left loaded in memory.
        """
        index_file = self.index_path / "index.faiss"
        docs_file = self.index_path / "documents.pkl"
        if not (index_file.exists() and docs_file.exists()):
            return False
        self.index = faiss.read_index(str(index_file))
        # SECURITY: pickle.load executes arbitrary code on untrusted input;
        # only load files this process (or a trusted peer) wrote.
        with open(docs_file, "rb") as f:
            self.documents = pickle.load(f)
        if getattr(self.index, "d", 0) == 0 or len(self.documents) != self.index.ntotal:
            logger.error("Corrupted index detected, deleting...")
            index_file.unlink()
            docs_file.unlink()
            # Reset in-memory state; self.dimension still holds the value
            # from the constructor because we update it only on success.
            self.documents = []
            self.index = faiss.IndexFlatIP(self.dimension)
            return False
        # Adopt the loaded index's dimension only once it passed validation.
        self.dimension = int(self.index.d)
        logger.info(
            "Loaded FAISS index from %s (ntotal=%d, dim=%d)",
            index_file,
            int(self.index.ntotal),
            int(self.index.d),
        )
        return True