DB_Chatbot / rag /vector_store.py
Vanshcc's picture
Upload 34 files
f9ad313 verified
"""
FAISS Vector Store for RAG.
Manages the FAISS index for semantic search over database text content.
"""
import logging
import pickle
import os
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
try:
import faiss
except ImportError:
faiss = None
from .document_processor import Document
from .embeddings import get_embedding_provider, EmbeddingProvider
logger = logging.getLogger(__name__)
class VectorStore:
"""FAISS-based vector store for semantic search."""
def __init__(
self,
embedding_provider: Optional[EmbeddingProvider] = None,
index_path: str = "./faiss_index"
):
if faiss is None:
raise ImportError("faiss-cpu is required. Install with: pip install faiss-cpu")
self.embedding_provider = embedding_provider or get_embedding_provider()
self.index_path = index_path
self.dimension = self.embedding_provider.dimension
self.index: Optional[faiss.IndexFlatIP] = None
self.documents: List[Document] = []
self.id_to_idx: Dict[str, int] = {}
self._initialize_index()
def _initialize_index(self):
"""Initialize or load the FAISS index."""
index_file = os.path.join(self.index_path, "index.faiss")
docs_file = os.path.join(self.index_path, "documents.pkl")
if os.path.exists(index_file) and os.path.exists(docs_file):
try:
# Check file size - if 0 something is wrong
if os.path.getsize(index_file) > 0:
self.index = faiss.read_index(index_file)
with open(docs_file, 'rb') as f:
self.documents = pickle.load(f)
self.id_to_idx = {doc.id: i for i, doc in enumerate(self.documents)}
# Verify index dimension matches expected
if self.index.d != self.dimension:
logger.warning(f"Index dimension mismatch: {self.index.d} != {self.dimension}. Resetting.")
raise ValueError("Dimension mismatch")
logger.info(f"Loaded index with {len(self.documents)} documents")
return
except (Exception, RuntimeError) as e:
logger.warning(f"Failed to load index (might be corrupted or memory error): {e}")
# If loading fails, we should probably backup the broken files or just overwrite
if os.path.exists(index_file):
try:
os.rename(index_file, index_file + ".bak")
os.rename(docs_file, docs_file + ".bak")
except:
pass
# Create new index (Inner Product for cosine similarity with normalized vectors)
self.index = faiss.IndexFlatIP(self.dimension)
self.documents = []
self.id_to_idx = {}
logger.info(f"Created new FAISS index with dimension {self.dimension}")
def add_documents(self, documents: List[Document], batch_size: int = 100):
"""Add documents to the vector store."""
if not documents:
return
new_docs = [doc for doc in documents if doc.id not in self.id_to_idx]
if not new_docs:
logger.info("No new documents to add")
return
logger.info(f"Adding {len(new_docs)} documents to index")
for i in range(0, len(new_docs), batch_size):
batch = new_docs[i:i + batch_size]
texts = [doc.content for doc in batch]
embeddings = self.embedding_provider.embed_texts(texts)
# Normalize for cosine similarity
faiss.normalize_L2(embeddings)
start_idx = len(self.documents)
self.index.add(embeddings)
for j, doc in enumerate(batch):
self.documents.append(doc)
self.id_to_idx[doc.id] = start_idx + j
logger.info(f"Index now contains {len(self.documents)} documents")
def search(
self, query: str, top_k: int = 5, threshold: float = 0.0
) -> List[Tuple[Document, float]]:
"""Search for similar documents."""
if not self.documents:
return []
query_embedding = self.embedding_provider.embed_text(query)
query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
faiss.normalize_L2(query_embedding)
k = min(top_k, len(self.documents))
scores, indices = self.index.search(query_embedding, k)
results = []
for score, idx in zip(scores[0], indices[0]):
if idx >= 0 and score >= threshold:
results.append((self.documents[idx], float(score)))
return results
def save(self):
"""Save the index to disk."""
os.makedirs(self.index_path, exist_ok=True)
index_file = os.path.join(self.index_path, "index.faiss")
docs_file = os.path.join(self.index_path, "documents.pkl")
faiss.write_index(self.index, index_file)
with open(docs_file, 'wb') as f:
pickle.dump(self.documents, f)
logger.info(f"Saved index with {len(self.documents)} documents")
def clear(self):
"""Clear the index."""
self.index = faiss.IndexFlatIP(self.dimension)
self.documents = []
self.id_to_idx = {}
# Delete files
index_file = os.path.join(self.index_path, "index.faiss")
docs_file = os.path.join(self.index_path, "documents.pkl")
for f in [index_file, docs_file]:
if os.path.exists(f):
os.remove(f)
logger.info("Index cleared")
def __len__(self) -> int:
return len(self.documents)
_vector_store: Optional[VectorStore] = None
def get_vector_store() -> VectorStore:
global _vector_store
if _vector_store is None:
_vector_store = VectorStore()
return _vector_store