"""FAISS vector store for document retrieval.""" import json import pickle from pathlib import Path from typing import Optional import faiss import numpy as np from pydantic import BaseModel from src.config import settings from src.document_processor.chunker import DocumentChunk from src.knowledge.embeddings import EmbeddingModel class RetrievalResult(BaseModel): """Result from vector store retrieval.""" chunk: DocumentChunk score: float rank: int class Config: arbitrary_types_allowed = True class FAISSVectorStore: """FAISS-based vector store for efficient similarity search. Stores document chunks with their embeddings and provides fast retrieval with source tracking for citations. """ def __init__( self, embedding_model: Optional[EmbeddingModel] = None, index_path: Optional[Path] = None, ): """Initialize the vector store. Args: embedding_model: Model for generating embeddings. index_path: Path to store/load the FAISS index. """ self.embedding_model = embedding_model or EmbeddingModel() self.index_path = Path(index_path or settings.faiss_index_path) self._index: Optional[faiss.IndexFlatIP] = None self._chunks: list[DocumentChunk] = [] self._is_loaded = False def _ensure_directory(self) -> None: """Ensure the index directory exists.""" self.index_path.parent.mkdir(parents=True, exist_ok=True) def _create_index(self, dimension: int) -> faiss.IndexFlatIP: """Create a new FAISS index. Uses Inner Product (IP) since embeddings are normalized. """ return faiss.IndexFlatIP(dimension) def add_chunks(self, chunks: list[DocumentChunk]) -> int: """Add document chunks to the vector store. Args: chunks: List of DocumentChunks to add. Returns: Number of chunks added. """ if not chunks: return 0 # Generate embeddings chunk_embeddings = self.embedding_model.embed_chunks(chunks) # Initialize index if needed if self._index is None: dimension = self.embedding_model.embedding_dimension self._index = self._create_index(dimension) # Add to index embeddings_array = np.vstack([emb for _, emb in chunk_embeddings]) self._index.add(embeddings_array) # Store chunks for retrieval for chunk, _ in chunk_embeddings: self._chunks.append(chunk) return len(chunks) def search( self, query: str, top_k: int = None, min_score: float = None, ) -> list[RetrievalResult]: """Search for relevant chunks. Args: query: Search query. top_k: Number of results to return. min_score: Minimum similarity score threshold. Returns: List of RetrievalResults ordered by relevance. """ if self._index is None or self._index.ntotal == 0: return [] top_k = top_k or settings.retrieval_top_k min_score = min_score or settings.retrieval_min_score # Embed query query_embedding = self.embedding_model.embed_query(query) query_embedding = query_embedding.reshape(1, -1) # Search scores, indices = self._index.search(query_embedding, min(top_k, self._index.ntotal)) # Build results results = [] for rank, (score, idx) in enumerate(zip(scores[0], indices[0])): if idx < 0 or score < min_score: continue chunk = self._chunks[idx] results.append( RetrievalResult( chunk=chunk, score=float(score), rank=rank + 1, ) ) return results def save(self) -> None: """Save the index and chunks to disk.""" if self._index is None: return self._ensure_directory() # Save FAISS index index_file = self.index_path.with_suffix(".faiss") faiss.write_index(self._index, str(index_file)) # Save chunks as JSON chunks_file = self.index_path.with_suffix(".chunks.json") chunks_data = [chunk.model_dump() for chunk in self._chunks] chunks_file.write_text(json.dumps(chunks_data, indent=2), encoding="utf-8") def load(self) -> bool: """Load the index and chunks from disk. Returns: True if loaded successfully, False otherwise. """ index_file = self.index_path.with_suffix(".faiss") chunks_file = self.index_path.with_suffix(".chunks.json") if not index_file.exists() or not chunks_file.exists(): return False try: # Load FAISS index self._index = faiss.read_index(str(index_file)) # Load chunks chunks_data = json.loads(chunks_file.read_text(encoding="utf-8")) self._chunks = [DocumentChunk.model_validate(c) for c in chunks_data] self._is_loaded = True return True except Exception as e: print(f"Error loading index: {e}") return False def clear(self) -> None: """Clear the index and all stored chunks.""" self._index = None self._chunks = [] self._is_loaded = False # Remove files if they exist index_file = self.index_path.with_suffix(".faiss") chunks_file = self.index_path.with_suffix(".chunks.json") if index_file.exists(): index_file.unlink() if chunks_file.exists(): chunks_file.unlink() @property def size(self) -> int: """Get the number of chunks in the store.""" return len(self._chunks) def get_sources(self) -> list[str]: """Get list of unique source files in the store.""" return list(set(chunk.source_file for chunk in self._chunks))