"""
FAISS Vector Store for RAG.
Manages the FAISS index for semantic search over database text content.
"""
import logging
import pickle
import os
from typing import List, Dict, Any, Optional, Tuple
import numpy as np
try:
import faiss
except ImportError:
faiss = None
from .document_processor import Document
from .embeddings import get_embedding_provider, EmbeddingProvider
logger = logging.getLogger(__name__)
class VectorStore:
    """FAISS-based vector store for semantic search over ``Document`` chunks.

    Embeddings are L2-normalized and stored in an ``IndexFlatIP`` index, so
    inner-product scores returned by :meth:`search` are cosine similarities
    in ``[-1, 1]``. The index and its document list are persisted side by
    side as ``index.faiss`` and ``documents.pkl`` under ``index_path``.
    """

    def __init__(
        self,
        embedding_provider: Optional[EmbeddingProvider] = None,
        index_path: str = "./faiss_index"
    ):
        """Create (or load) a vector store.

        Args:
            embedding_provider: Source of text embeddings; defaults to the
                provider returned by ``get_embedding_provider()``.
            index_path: Directory holding ``index.faiss`` / ``documents.pkl``.

        Raises:
            ImportError: If the ``faiss`` package is not installed.
        """
        if faiss is None:
            raise ImportError("faiss-cpu is required. Install with: pip install faiss-cpu")
        self.embedding_provider = embedding_provider or get_embedding_provider()
        self.index_path = index_path
        # Embedding width; a persisted index must match this to be reused.
        self.dimension = self.embedding_provider.dimension
        self.index: Optional[faiss.IndexFlatIP] = None
        self.documents: List[Document] = []
        # Maps Document.id -> position in self.documents (and in the index).
        self.id_to_idx: Dict[str, int] = {}
        self._initialize_index()

    def _initialize_index(self):
        """Load the persisted index from disk, or create a fresh empty one.

        A load is attempted only when both files exist and the index file is
        non-empty (a zero-byte file means a previous save was interrupted).
        Any failure — corruption, unpickling errors, or a dimension mismatch
        after switching embedding models — falls back to a brand-new empty
        index; the unreadable files are kept as ``*.bak`` for post-mortem.
        """
        index_file = os.path.join(self.index_path, "index.faiss")
        docs_file = os.path.join(self.index_path, "documents.pkl")
        if os.path.exists(index_file) and os.path.exists(docs_file):
            try:
                # Zero-byte index file => interrupted save; skip straight to reset.
                if os.path.getsize(index_file) > 0:
                    self.index = faiss.read_index(index_file)
                    with open(docs_file, 'rb') as f:
                        self.documents = pickle.load(f)
                    self.id_to_idx = {doc.id: i for i, doc in enumerate(self.documents)}
                    # Reject an index built with a different embedding model.
                    if self.index.d != self.dimension:
                        logger.warning(f"Index dimension mismatch: {self.index.d} != {self.dimension}. Resetting.")
                        raise ValueError("Dimension mismatch")
                    logger.info(f"Loaded index with {len(self.documents)} documents")
                    return
            except Exception as e:
                logger.warning(f"Failed to load index (might be corrupted or memory error): {e}")
                # Back up each broken file independently; os.replace overwrites
                # a stale .bak instead of raising like os.rename would.
                for path in (index_file, docs_file):
                    if os.path.exists(path):
                        try:
                            os.replace(path, path + ".bak")
                        except OSError as backup_err:
                            logger.warning(f"Could not back up {path}: {backup_err}")
        # Create new index (Inner Product for cosine similarity with normalized vectors)
        self.index = faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.id_to_idx = {}
        logger.info(f"Created new FAISS index with dimension {self.dimension}")

    def add_documents(self, documents: List[Document], batch_size: int = 100):
        """Embed and index documents, skipping ids that are already present.

        Args:
            documents: Document chunks to add; deduplicated by ``doc.id``.
            batch_size: Number of documents embedded per provider call.
        """
        if not documents:
            return
        new_docs = [doc for doc in documents if doc.id not in self.id_to_idx]
        if not new_docs:
            logger.info("No new documents to add")
            return
        logger.info(f"Adding {len(new_docs)} documents to index")
        for i in range(0, len(new_docs), batch_size):
            batch = new_docs[i:i + batch_size]
            texts = [doc.content for doc in batch]
            embeddings = self.embedding_provider.embed_texts(texts)
            # FAISS requires a contiguous float32 array; providers may return
            # float64 arrays or plain lists (search() already converts — keep
            # the two paths consistent).
            embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
            # Normalize for cosine similarity
            faiss.normalize_L2(embeddings)
            start_idx = len(self.documents)
            self.index.add(embeddings)
            for j, doc in enumerate(batch):
                self.documents.append(doc)
                self.id_to_idx[doc.id] = start_idx + j
        logger.info(f"Index now contains {len(self.documents)} documents")

    def search(
        self, query: str, top_k: int = 5, threshold: float = 0.0
    ) -> List[Tuple[Document, float]]:
        """Return up to ``top_k`` ``(document, score)`` pairs for ``query``.

        Scores are cosine similarities; results below ``threshold`` are
        dropped. Returns an empty list when the store is empty.
        """
        if not self.documents:
            return []
        query_embedding = self.embedding_provider.embed_text(query)
        query_embedding = query_embedding.reshape(1, -1).astype(np.float32)
        faiss.normalize_L2(query_embedding)
        # Never ask FAISS for more neighbors than we have documents.
        k = min(top_k, len(self.documents))
        scores, indices = self.index.search(query_embedding, k)
        results = []
        for score, idx in zip(scores[0], indices[0]):
            # FAISS pads missing neighbors with index -1; filter those out.
            if idx >= 0 and score >= threshold:
                results.append((self.documents[idx], float(score)))
        return results

    def save(self):
        """Persist the index and document list to ``self.index_path``."""
        os.makedirs(self.index_path, exist_ok=True)
        index_file = os.path.join(self.index_path, "index.faiss")
        docs_file = os.path.join(self.index_path, "documents.pkl")
        faiss.write_index(self.index, index_file)
        with open(docs_file, 'wb') as f:
            pickle.dump(self.documents, f)
        logger.info(f"Saved index with {len(self.documents)} documents")

    def clear(self):
        """Reset to an empty index and delete the persisted files."""
        self.index = faiss.IndexFlatIP(self.dimension)
        self.documents = []
        self.id_to_idx = {}
        # Delete files
        index_file = os.path.join(self.index_path, "index.faiss")
        docs_file = os.path.join(self.index_path, "documents.pkl")
        for f in [index_file, docs_file]:
            if os.path.exists(f):
                os.remove(f)
        logger.info("Index cleared")

    def __len__(self) -> int:
        """Number of documents currently indexed."""
        return len(self.documents)
# Process-wide singleton, created lazily on first access.
_vector_store: Optional[VectorStore] = None


def get_vector_store() -> VectorStore:
    """Return the shared :class:`VectorStore`, constructing it on first use."""
    global _vector_store
    store = _vector_store
    if store is None:
        store = VectorStore()
        _vector_store = store
    return store