# SuoMoto.AI — utils/vector_store.py
# (uploaded by cryogenic22; commit 4339a4c, verified)
# utils/vector_store.py
import faiss
import numpy as np
from typing import List, Dict, Optional
import pickle
import os
from pathlib import Path
class VectorStore:
    """FAISS-backed vector store that persists its index and documents to disk.

    Keeps three parallel structures: a FAISS ``IndexFlatL2`` of embeddings,
    a list of chunk texts, and a list of per-chunk metadata dicts. All three
    are persisted under ``/data/faiss`` (Hugging Face Spaces persistent
    storage) and reloaded on construction.
    """

    def __init__(self):
        # Absolute path required for HF Spaces persistent storage.
        self.persist_directory = "/data/faiss"
        self.index = None      # faiss index; created lazily on first add
        self.documents = []    # chunk texts, parallel to self.metadata
        self.metadata = []     # per-chunk metadata dicts
        # Ensure directories exist before any load/save.
        self._create_data_directories()
        # Try to load existing index and data.
        self._load_or_create_index()

    def _create_data_directories(self):
        """Create the persistent data directories if they do not exist."""
        Path("/data").mkdir(parents=True, exist_ok=True)
        Path(self.persist_directory).mkdir(parents=True, exist_ok=True)
        Path("/data/uploads").mkdir(parents=True, exist_ok=True)

    def _load_or_create_index(self):
        """Load a previously saved index and document store, or start empty."""
        index_path = os.path.join(self.persist_directory, "faiss.index")
        data_path = os.path.join(self.persist_directory, "documents.pkl")
        try:
            if os.path.exists(index_path) and os.path.exists(data_path):
                print(f"Loading existing index from {index_path}")
                self.index = faiss.read_index(index_path)
                # NOTE: pickle is only acceptable here because this file is
                # written by _save_index ourselves; never load pickles from
                # untrusted sources.
                with open(data_path, 'rb') as f:
                    data = pickle.load(f)
                self.documents = data['documents']
                self.metadata = data['metadata']
                print(f"Loaded {len(self.documents)} documents from existing index")
            else:
                print("No existing index found, creating new one")
                self.index = None  # created when first vectors are added
                self.documents = []
                self.metadata = []
        except Exception as e:
            # Corrupt or partial files: fall back to an empty store
            # rather than crashing at startup.
            print(f"Error loading index: {e}")
            self.index = None
            self.documents = []
            self.metadata = []

    def _save_index(self):
        """Persist the FAISS index plus documents/metadata to disk (best effort)."""
        if self.index is None:
            return
        index_path = os.path.join(self.persist_directory, "faiss.index")
        data_path = os.path.join(self.persist_directory, "documents.pkl")
        try:
            faiss.write_index(self.index, index_path)
            with open(data_path, 'wb') as f:
                pickle.dump({
                    'documents': self.documents,
                    'metadata': self.metadata,
                }, f)
        except Exception as e:
            # Best-effort persistence: keep serving from memory on failure.
            print(f"Error saving index: {e}")

    def add_documents(self, chunks: List[Dict], metadata: Optional[Dict] = None):
        """Add embedded document chunks to the store and persist it.

        Args:
            chunks: dicts each containing an "embeddings" vector and a
                "text" string. No-op if empty.
            metadata: optional dict merged into every chunk's metadata.
        """
        if not chunks:
            return
        # Build the matrix as float32 once (FAISS requires float32).
        vectors = np.asarray(
            [chunk["embeddings"] for chunk in chunks], dtype=np.float32
        )
        # Lazily create the index with the dimensionality of the first batch.
        if self.index is None:
            self.index = faiss.IndexFlatL2(vectors.shape[1])
        self.index.add(vectors)
        for chunk in chunks:
            chunk_metadata = {
                "chunk_id": len(self.documents),
                "text_length": len(chunk["text"]),
            }
            if metadata:
                chunk_metadata.update(metadata)
            self.documents.append(chunk["text"])
            self.metadata.append(chunk_metadata)
        self._save_index()

    def search(self, query_vector: np.ndarray, n_results: int = 5) -> List[Dict]:
        """Return up to ``n_results`` nearest documents by L2 distance.

        Args:
            query_vector: 1-D or (1, dim) embedding of the query.
            n_results: maximum number of neighbours to return.

        Returns:
            List of dicts with "text", "metadata" and "distance" keys,
            sorted nearest-first. Empty list if the store is empty.
        """
        if self.index is None or self.index.ntotal == 0:
            return []
        # FAISS expects a 2-D (n_queries, dim) float32 array.
        if len(query_vector.shape) == 1:
            query_vector = query_vector.reshape(1, -1)
        distances, indices = self.index.search(
            query_vector.astype(np.float32), n_results
        )
        results = []
        for dist, idx in zip(distances[0], indices[0]):
            # BUG FIX: FAISS pads `indices` with -1 when fewer than
            # n_results vectors exist; the old `idx < len(...)` check let
            # -1 through, which Python wraps to the LAST document. Require
            # 0 <= idx to drop the padding entries.
            if 0 <= idx < len(self.documents):
                results.append({
                    "text": self.documents[idx],
                    "metadata": self.metadata[idx],
                    "distance": float(dist),
                })
        return results

    def get_all_documents(self) -> List[Dict]:
        """Return every stored document paired with its metadata."""
        return [
            {"text": doc, "metadata": meta}
            for doc, meta in zip(self.documents, self.metadata)
        ]