# src/vector_store.py
import os
import pickle
from typing import List, Dict, Any, Optional
import numpy as np
# Vector store
import faiss
from sentence_transformers import SentenceTransformer
from config import Config
class VectorStore:
"""FAISS-based vector store for document embeddings"""
    def __init__(self, embedding_model: SentenceTransformer, config: Optional[Config] = None):
self.config = config or Config()
self.embedding_model = embedding_model
# Get embedding dimension
self.dimension = embedding_model.get_sentence_embedding_dimension()
# Initialize FAISS index
self.index = faiss.IndexFlatIP(self.dimension) # Inner product for cosine similarity
# Storage for chunks and metadata
self.chunks = []
self.metadata = []
self.file_map = {} # Map file_id to chunk indices
print(f"βœ… Vector store initialized with dimension: {self.dimension}")
def add_documents(self, chunks: List[str], file_id: str, filename: str):
"""Add documents to vector store"""
if not chunks:
print("Warning: No chunks to add")
return
print(f"πŸ“ Adding {len(chunks)} chunks from {filename}")
try:
# Generate embeddings
embeddings = self.embedding_model.encode(
chunks,
convert_to_numpy=True,
show_progress_bar=len(chunks) > 10
)
# Ensure embeddings are float32
embeddings = embeddings.astype(np.float32)
# Normalize embeddings for cosine similarity with inner product
faiss.normalize_L2(embeddings)
# Add to FAISS index
start_idx = len(self.chunks)
self.index.add(embeddings)
# Store chunks and metadata
chunk_indices = []
for i, chunk in enumerate(chunks):
chunk_idx = start_idx + i
chunk_indices.append(chunk_idx)
self.chunks.append(chunk)
self.metadata.append({
'file_id': file_id,
'filename': filename,
'chunk_index': i,
'global_index': chunk_idx,
'text': chunk,
'embedding_added': True
})
# Update file mapping
if file_id not in self.file_map:
self.file_map[file_id] = []
self.file_map[file_id].extend(chunk_indices)
print(f"βœ… Successfully added {len(chunks)} chunks. Total chunks: {len(self.chunks)}")
except Exception as e:
print(f"❌ Error adding documents: {e}")
raise
def search(self, query: str, k: int = 5, file_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Search for similar documents"""
if len(self.chunks) == 0:
print("Warning: No documents in vector store")
return []
try:
# Generate query embedding
query_embedding = self.embedding_model.encode([query], convert_to_numpy=True)
query_embedding = query_embedding.astype(np.float32)
# Normalize for cosine similarity
faiss.normalize_L2(query_embedding)
            # Search; over-fetch candidates when filtering by file_id (5x is an arbitrary
            # factor) so the post-filter can still return up to k results
            search_k = min(k * 5 if file_id else k, len(self.chunks))
            scores, indices = self.index.search(query_embedding, search_k)
            results = []
            for score, idx in zip(scores[0], indices[0]):
                if idx != -1 and idx < len(self.chunks):  # Valid index
                    # Skip chunks whose file has been removed
                    if self.metadata[idx].get('removed', False):
                        continue
                    # Filter by file_id if specified
                    if file_id and self.metadata[idx]['file_id'] != file_id:
                        continue
result = {
'text': self.chunks[idx],
'metadata': self.metadata[idx].copy(),
'score': float(score),
'similarity': float(score) # Alias for compatibility
}
results.append(result)
# Sort by score (highest first)
results.sort(key=lambda x: x['score'], reverse=True)
print(f"πŸ” Found {len(results)} results for query: '{query[:50]}...'")
return results[:k] # Return top k results
except Exception as e:
print(f"❌ Search error: {e}")
return []
def get_document_stats(self) -> Dict[str, Any]:
"""Get statistics about stored documents"""
stats = {
'total_chunks': len(self.chunks),
'total_files': len(self.file_map),
'index_size': self.index.ntotal,
'dimension': self.dimension
}
# File-wise statistics
file_stats = {}
for file_id, chunk_indices in self.file_map.items():
filename = self.metadata[chunk_indices[0]]['filename'] if chunk_indices else 'unknown'
file_stats[file_id] = {
'filename': filename,
'chunk_count': len(chunk_indices),
'chunk_indices': chunk_indices
}
stats['files'] = file_stats
return stats
def remove_file(self, file_id: str) -> bool:
"""Remove all chunks for a specific file"""
if file_id not in self.file_map:
print(f"Warning: File {file_id} not found in vector store")
return False
try:
# Get chunk indices for this file
chunk_indices = self.file_map[file_id]
# Remove from file map
del self.file_map[file_id]
            # Mark chunks as removed; deleting vectors from the flat index would shift
            # the remaining positions and break the stored index-to-chunk mapping
for idx in chunk_indices:
if idx < len(self.metadata):
self.metadata[idx]['removed'] = True
print(f"βœ… Marked {len(chunk_indices)} chunks as removed for file {file_id}")
return True
except Exception as e:
print(f"❌ Error removing file {file_id}: {e}")
return False
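    # Sketch: one possible way to physically drop chunks that remove_file() only marks
    # as removed, by copying the surviving vectors into a fresh flat index. It assumes
    # the flat index can hand back its stored vectors via reconstruct(); treat this as
    # an illustrative example rather than a tested part of the store.
    def rebuild_index(self):
        """Rebuild the FAISS index without chunks marked as removed (illustrative sketch)"""
        kept = [i for i, meta in enumerate(self.metadata) if not meta.get('removed', False)]
        new_index = faiss.IndexFlatIP(self.dimension)
        if kept:
            # IndexFlat stores raw vectors, so they can be copied without re-embedding
            vectors = np.vstack([self.index.reconstruct(i) for i in kept]).astype(np.float32)
            new_index.add(vectors)
        new_chunks, new_metadata, new_file_map = [], [], {}
        for new_idx, old_idx in enumerate(kept):
            meta = self.metadata[old_idx].copy()
            meta['global_index'] = new_idx
            new_chunks.append(self.chunks[old_idx])
            new_metadata.append(meta)
            new_file_map.setdefault(meta['file_id'], []).append(new_idx)
        self.index = new_index
        self.chunks = new_chunks
        self.metadata = new_metadata
        self.file_map = new_file_map
        print(f"βœ… Rebuilt index with {len(kept)} chunks")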
def save(self, path: str):
"""Save vector store to disk"""
try:
os.makedirs(path, exist_ok=True)
# Save FAISS index
faiss.write_index(self.index, os.path.join(path, "index.faiss"))
# Save chunks and metadata
data = {
'chunks': self.chunks,
'metadata': self.metadata,
'file_map': self.file_map,
'dimension': self.dimension
}
with open(os.path.join(path, "data.pkl"), 'wb') as f:
pickle.dump(data, f)
print(f"βœ… Vector store saved to {path}")
except Exception as e:
print(f"❌ Error saving vector store: {e}")
raise
def load(self, path: str) -> bool:
"""Load vector store from disk"""
try:
index_path = os.path.join(path, "index.faiss")
data_path = os.path.join(path, "data.pkl")
if not (os.path.exists(index_path) and os.path.exists(data_path)):
print(f"Vector store files not found in {path}")
return False
# Load FAISS index
self.index = faiss.read_index(index_path)
# Load chunks and metadata
with open(data_path, 'rb') as f:
data = pickle.load(f)
self.chunks = data.get('chunks', [])
self.metadata = data.get('metadata', [])
self.file_map = data.get('file_map', {})
# Verify dimension consistency
saved_dimension = data.get('dimension', self.dimension)
if saved_dimension != self.dimension:
print(f"Warning: Dimension mismatch. Expected: {self.dimension}, Got: {saved_dimension}")
print(f"βœ… Vector store loaded from {path}. {len(self.chunks)} chunks, {len(self.file_map)} files")
return True
except Exception as e:
print(f"❌ Error loading vector store: {e}")
return False
def reset(self):
"""Reset vector store (clear all data)"""
try:
# Reinitialize FAISS index
self.index = faiss.IndexFlatIP(self.dimension)
# Clear data
self.chunks = []
self.metadata = []
self.file_map = {}
print("βœ… Vector store reset successfully")
except Exception as e:
print(f"❌ Error resetting vector store: {e}")
raise
def get_chunk_by_index(self, index: int) -> Optional[Dict[str, Any]]:
"""Get chunk by global index"""
if 0 <= index < len(self.chunks):
return {
'text': self.chunks[index],
'metadata': self.metadata[index]
}
return None
def search_by_file(self, file_id: str, query: str = "", k: int = 10) -> List[Dict[str, Any]]:
"""Get all chunks for a specific file, optionally filtered by query"""
if file_id not in self.file_map:
return []
chunk_indices = self.file_map[file_id]
results = []
for idx in chunk_indices:
if idx < len(self.chunks):
# Skip removed chunks
if self.metadata[idx].get('removed', False):
continue
result = {
'text': self.chunks[idx],
'metadata': self.metadata[idx].copy(),
'score': 1.0, # No scoring for file-based retrieval
'global_index': idx
}
results.append(result)
# If query provided, filter results
if query:
# Simple text matching (can be enhanced with embedding similarity)
query_lower = query.lower()
filtered_results = []
for result in results:
if query_lower in result['text'].lower():
filtered_results.append(result)
results = filtered_results
return results[:k]
def optimize_index(self):
"""Optimize FAISS index (placeholder for future enhancements)"""
# For now, just print stats
stats = self.get_document_stats()
print(f"πŸ“Š Index stats: {stats['total_chunks']} chunks, {stats['total_files']} files")
        # Possible future enhancements:
        # - Remove deleted chunks and rebuild the index (see the rebuild_index sketch above)
        # - Switch to more efficient index types (IVF, HNSW)
        # - Compress embeddings
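# Minimal usage sketch. The model name below is an assumption for illustration
# (any sentence-transformers model works); swap in whatever the project's Config specifies.
if __name__ == "__main__":
    model = SentenceTransformer("all-MiniLM-L6-v2")
    store = VectorStore(model)
    store.add_documents(
        chunks=[
            "FAISS is a library for efficient similarity search over dense vectors.",
            "Sentence-transformers models map text to fixed-size embeddings.",
        ],
        file_id="demo-1",
        filename="demo.txt",
    )
    for hit in store.search("What is FAISS used for?", k=2):
        print(f"{hit['score']:.3f}  {hit['text']}")
    # Persist and reload (round-trip through ./vector_store_demo)
    store.save("vector_store_demo")
    restored = VectorStore(model)
    restored.load("vector_store_demo")
    print(restored.get_document_stats()['total_chunks'], "chunks after reload")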