# hrbot/src/knowledge/vector_store.py
# Last updated by Sonu Prasad (commit 8a1c0d1)
"""FAISS vector store for document retrieval."""
import json
import logging
import pickle
from pathlib import Path
from typing import Optional

import faiss
import numpy as np
from pydantic import BaseModel

from src.config import settings
from src.document_processor.chunker import DocumentChunk
from src.knowledge.embeddings import EmbeddingModel
class RetrievalResult(BaseModel):
    """Result from vector store retrieval.

    Attributes:
        chunk: The retrieved document chunk (a non-pydantic project type,
            hence ``arbitrary_types_allowed``).
        score: Similarity score from the FAISS inner-product search.
        rank: 1-based position of this result in the ranked result list.
    """

    chunk: DocumentChunk
    score: float
    rank: int

    # Pydantic v2 configuration style, consistent with the model_dump() /
    # model_validate() calls used elsewhere in this module; the inner
    # `class Config` form is the deprecated v1 idiom.
    model_config = {"arbitrary_types_allowed": True}
class FAISSVectorStore:
    """FAISS-based vector store for efficient similarity search.

    Stores document chunks together with their embeddings and provides
    fast retrieval with source tracking for citations.  Invariant: FAISS
    row ``i`` corresponds to ``self._chunks[i]``, so the index and the
    chunk list must always be mutated together.
    """

    def __init__(
        self,
        embedding_model: Optional[EmbeddingModel] = None,
        index_path: Optional[Path] = None,
    ):
        """Initialize the vector store.

        Args:
            embedding_model: Model for generating embeddings; a default
                EmbeddingModel is constructed when omitted.
            index_path: Base path used to store/load the FAISS index;
                falls back to ``settings.faiss_index_path``.
        """
        self.embedding_model = embedding_model or EmbeddingModel()
        self.index_path = Path(index_path or settings.faiss_index_path)
        self._index: Optional[faiss.IndexFlatIP] = None  # created lazily
        self._chunks: list[DocumentChunk] = []
        self._is_loaded = False

    @property
    def _index_file(self) -> Path:
        """On-disk location of the serialized FAISS index."""
        return self.index_path.with_suffix(".faiss")

    @property
    def _chunks_file(self) -> Path:
        """On-disk location of the JSON-serialized chunk metadata."""
        return self.index_path.with_suffix(".chunks.json")

    def _ensure_directory(self) -> None:
        """Ensure the index directory exists."""
        self.index_path.parent.mkdir(parents=True, exist_ok=True)

    def _create_index(self, dimension: int) -> faiss.IndexFlatIP:
        """Create a new FAISS index of the given dimensionality.

        Uses Inner Product (IP), which assumes the embedding model emits
        L2-normalized vectors (IP == cosine similarity in that case) —
        presumably guaranteed by EmbeddingModel; confirm there.
        """
        return faiss.IndexFlatIP(dimension)

    def add_chunks(self, chunks: list[DocumentChunk]) -> int:
        """Add document chunks to the vector store.

        Args:
            chunks: List of DocumentChunks to add.

        Returns:
            Number of chunks added (0 for an empty input).
        """
        if not chunks:
            return 0
        # Materialize explicitly: the pairs are iterated twice below, so a
        # generator return from embed_chunks would otherwise silently leave
        # self._chunks out of sync with the index.
        chunk_embeddings = list(self.embedding_model.embed_chunks(chunks))
        if self._index is None:
            # Lazily create the index using the model's output dimension.
            self._index = self._create_index(self.embedding_model.embedding_dimension)
        embeddings_array = np.vstack([emb for _, emb in chunk_embeddings])
        self._index.add(embeddings_array)
        # Keep the chunk list aligned with FAISS row order (class invariant).
        self._chunks.extend(chunk for chunk, _ in chunk_embeddings)
        return len(chunks)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        min_score: Optional[float] = None,
    ) -> list[RetrievalResult]:
        """Search for relevant chunks.

        Args:
            query: Search query.
            top_k: Number of results to return; defaults to
                ``settings.retrieval_top_k`` when None.
            min_score: Minimum similarity score threshold; defaults to
                ``settings.retrieval_min_score`` when None.  An explicit
                0.0 is honored (no threshold) rather than being treated
                as "use the default".

        Returns:
            List of RetrievalResults ordered by relevance (best first).
        """
        if self._index is None or self._index.ntotal == 0:
            return []
        # `is None` checks, not truthiness: `min_score or default` would
        # clobber a deliberate 0.0 (and top_k=0) with the settings value.
        if top_k is None:
            top_k = settings.retrieval_top_k
        if min_score is None:
            min_score = settings.retrieval_min_score
        if top_k <= 0:
            return []

        query_embedding = self.embedding_model.embed_query(query).reshape(1, -1)
        scores, indices = self._index.search(
            query_embedding, min(top_k, self._index.ntotal)
        )

        results: list[RetrievalResult] = []
        for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
            # FAISS pads missing results with index -1; ranks of filtered
            # entries are intentionally skipped, preserving gap semantics.
            if idx < 0 or score < min_score:
                continue
            results.append(
                RetrievalResult(chunk=self._chunks[idx], score=float(score), rank=rank)
            )
        return results

    def save(self) -> None:
        """Persist the FAISS index and chunk metadata to disk.

        No-op when nothing has been indexed yet.
        """
        if self._index is None:
            return
        self._ensure_directory()
        faiss.write_index(self._index, str(self._index_file))
        chunks_data = [chunk.model_dump() for chunk in self._chunks]
        self._chunks_file.write_text(
            json.dumps(chunks_data, indent=2, ensure_ascii=False), encoding="utf-8"
        )

    def load(self) -> bool:
        """Load the index and chunks from disk.

        Returns:
            True if loaded successfully, False otherwise.  On failure the
            store is reset to empty rather than left half-loaded (the
            previous code could keep a freshly read index while the chunk
            list failed to parse, breaking the row/chunk invariant).
        """
        if not self._index_file.exists() or not self._chunks_file.exists():
            return False
        try:
            self._index = faiss.read_index(str(self._index_file))
            chunks_data = json.loads(self._chunks_file.read_text(encoding="utf-8"))
            self._chunks = [DocumentChunk.model_validate(c) for c in chunks_data]
        except Exception:
            # Corrupt or incompatible files: log with traceback (instead of
            # print) and restore a consistent empty state.
            logging.getLogger(__name__).exception(
                "Error loading vector index from %s", self.index_path
            )
            self._index = None
            self._chunks = []
            self._is_loaded = False
            return False
        self._is_loaded = True
        return True

    def clear(self) -> None:
        """Clear the in-memory index/chunks and delete any persisted files."""
        self._index = None
        self._chunks = []
        self._is_loaded = False
        # missing_ok avoids the exists()/unlink() race of the previous code.
        self._index_file.unlink(missing_ok=True)
        self._chunks_file.unlink(missing_ok=True)

    @property
    def size(self) -> int:
        """Number of chunks currently in the store."""
        return len(self._chunks)

    def get_sources(self) -> list[str]:
        """Unique source files in the store, in first-seen order.

        dict.fromkeys dedupes deterministically; set() iteration order is
        arbitrary across runs.
        """
        return list(dict.fromkeys(chunk.source_file for chunk in self._chunks))