"""FAISS vector store for document retrieval."""
import json
import logging
import pickle
from pathlib import Path
from typing import Optional

import faiss
import numpy as np
from pydantic import BaseModel, ConfigDict

from src.config import settings
from src.document_processor.chunker import DocumentChunk
from src.knowledge.embeddings import EmbeddingModel
class RetrievalResult(BaseModel):
    """Result from vector store retrieval.

    Attributes:
        chunk: The retrieved document chunk.
        score: Similarity score from the FAISS search.
        rank: 1-based position of this hit in the result list.
    """

    # Pydantic v2 config (the rest of this module already uses v2 APIs such
    # as model_dump/model_validate); replaces the deprecated inner Config class.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    chunk: DocumentChunk
    score: float
    rank: int
class FAISSVectorStore:
    """FAISS-based vector store for efficient similarity search.

    Stores document chunks with their embeddings and provides
    fast retrieval with source tracking for citations.
    """

    def __init__(
        self,
        embedding_model: Optional[EmbeddingModel] = None,
        index_path: Optional[Path] = None,
    ):
        """Initialize the vector store.

        Args:
            embedding_model: Model for generating embeddings. A default
                EmbeddingModel is created when omitted.
            index_path: Path to store/load the FAISS index. Defaults to
                settings.faiss_index_path.
        """
        self.embedding_model = embedding_model or EmbeddingModel()
        self.index_path = Path(index_path or settings.faiss_index_path)
        self._index: Optional[faiss.IndexFlatIP] = None
        self._chunks: list[DocumentChunk] = []
        self._is_loaded = False

    def _ensure_directory(self) -> None:
        """Ensure the index directory exists."""
        self.index_path.parent.mkdir(parents=True, exist_ok=True)

    def _index_file(self) -> Path:
        """Path of the serialized FAISS index on disk."""
        return self.index_path.with_suffix(".faiss")

    def _chunks_file(self) -> Path:
        """Path of the serialized chunk metadata on disk."""
        return self.index_path.with_suffix(".chunks.json")

    def _create_index(self, dimension: int) -> faiss.IndexFlatIP:
        """Create a new FAISS index.

        Uses Inner Product (IP) since embeddings are normalized, which
        makes inner product equivalent to cosine similarity.
        """
        return faiss.IndexFlatIP(dimension)

    def add_chunks(self, chunks: list[DocumentChunk]) -> int:
        """Add document chunks to the vector store.

        Args:
            chunks: List of DocumentChunks to add.

        Returns:
            Number of chunks added (0 for empty input).
        """
        if not chunks:
            return 0

        # embed_chunks is expected to yield (chunk, embedding) pairs.
        chunk_embeddings = self.embedding_model.embed_chunks(chunks)

        # Lazily create the index once the embedding dimension is known.
        if self._index is None:
            self._index = self._create_index(self.embedding_model.embedding_dimension)

        # Add all embeddings to FAISS in a single batched call.
        embeddings_array = np.vstack([emb for _, emb in chunk_embeddings])
        self._index.add(embeddings_array)

        # Keep chunks aligned with their FAISS row positions for retrieval.
        self._chunks.extend(chunk for chunk, _ in chunk_embeddings)
        return len(chunks)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        min_score: Optional[float] = None,
    ) -> list[RetrievalResult]:
        """Search for relevant chunks.

        Args:
            query: Search query.
            top_k: Number of results to return. Defaults to
                settings.retrieval_top_k.
            min_score: Minimum similarity score threshold. Defaults to
                settings.retrieval_min_score.

        Returns:
            List of RetrievalResults ordered by relevance.
        """
        if self._index is None or self._index.ntotal == 0:
            return []

        # Explicit None checks so a caller-supplied 0 / 0.0 is honored
        # (a bare `or` would silently fall back to the settings defaults).
        if top_k is None:
            top_k = settings.retrieval_top_k
        if min_score is None:
            min_score = settings.retrieval_min_score

        # FAISS expects a 2D (n_queries, dim) array.
        query_embedding = self.embedding_model.embed_query(query).reshape(1, -1)

        # Never request more neighbors than the index contains.
        scores, indices = self._index.search(query_embedding, min(top_k, self._index.ntotal))

        # Build results. `rank` reflects the raw FAISS ordering, so ranks
        # may be non-contiguous when low-score hits are filtered out.
        results = []
        for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
            # FAISS pads with index -1 when fewer than top_k neighbors exist.
            if idx < 0 or score < min_score:
                continue
            results.append(
                RetrievalResult(
                    chunk=self._chunks[idx],
                    score=float(score),
                    rank=rank,
                )
            )
        return results

    def save(self) -> None:
        """Save the index and chunks to disk (no-op when index is unset)."""
        if self._index is None:
            return

        self._ensure_directory()

        # FAISS index in its native binary format.
        faiss.write_index(self._index, str(self._index_file()))

        # Chunks as JSON so they stay human-readable and portable.
        chunks_data = [chunk.model_dump() for chunk in self._chunks]
        self._chunks_file().write_text(
            json.dumps(chunks_data, indent=2), encoding="utf-8"
        )

    def load(self) -> bool:
        """Load the index and chunks from disk.

        Returns:
            True if loaded successfully, False otherwise.
        """
        index_file = self._index_file()
        chunks_file = self._chunks_file()
        if not index_file.exists() or not chunks_file.exists():
            return False

        try:
            # Load into locals first so a mid-load failure cannot leave the
            # store holding an index that is out of sync with its chunks.
            index = faiss.read_index(str(index_file))
            chunks_data = json.loads(chunks_file.read_text(encoding="utf-8"))
            chunks = [DocumentChunk.model_validate(c) for c in chunks_data]
        except Exception:
            # Deliberately best-effort: log the full traceback and report
            # failure instead of raising to the caller.
            logging.getLogger(__name__).exception(
                "Error loading index from %s", index_file
            )
            return False

        self._index = index
        self._chunks = chunks
        self._is_loaded = True
        return True

    def clear(self) -> None:
        """Clear the in-memory index/chunks and remove persisted files."""
        self._index = None
        self._chunks = []
        self._is_loaded = False

        # Remove on-disk artifacts, tolerating their absence.
        self._index_file().unlink(missing_ok=True)
        self._chunks_file().unlink(missing_ok=True)

    @property
    def size(self) -> int:
        """Get the number of chunks in the store."""
        return len(self._chunks)

    def get_sources(self) -> list[str]:
        """Get list of unique source files in the store (unordered)."""
        return list({chunk.source_file for chunk in self._chunks})
|