File size: 6,091 Bytes
8a1c0d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
"""FAISS vector store for document retrieval."""

import json
import logging
import pickle
from pathlib import Path
from typing import Optional

import faiss
import numpy as np
from pydantic import BaseModel, ConfigDict

from src.config import settings
from src.document_processor.chunker import DocumentChunk
from src.knowledge.embeddings import EmbeddingModel


class RetrievalResult(BaseModel):
    """Result from vector store retrieval.

    Attributes:
        chunk: The retrieved document chunk (carries source metadata
            used for citations).
        score: Inner-product similarity score from FAISS; higher means
            more relevant.
        rank: 1-based position of this result in the returned list.
    """

    # Pydantic v2 configuration (the rest of this module uses v2 APIs
    # such as model_dump/model_validate); arbitrary_types_allowed kept
    # from the original in case DocumentChunk is not a pydantic model.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    chunk: DocumentChunk
    score: float
    rank: int


class FAISSVectorStore:
    """FAISS-based vector store for efficient similarity search.

    Chunks are embedded on insertion and added to a flat inner-product
    FAISS index; each DocumentChunk is kept in a parallel list so search
    hits can be mapped back to their source for citations.
    """

    def __init__(
        self,
        embedding_model: Optional[EmbeddingModel] = None,
        index_path: Optional[Path] = None,
    ):
        """Initialize the vector store.

        Args:
            embedding_model: Model for generating embeddings; a default
                EmbeddingModel is constructed when omitted.
            index_path: Path to store/load the FAISS index; defaults to
                settings.faiss_index_path.
        """
        self.embedding_model = embedding_model or EmbeddingModel()
        self.index_path = Path(index_path or settings.faiss_index_path)

        # Created lazily by add_chunks()/load(); None means "empty store".
        self._index: Optional[faiss.IndexFlatIP] = None
        # Parallel to the index: row i of the index embeds self._chunks[i].
        self._chunks: list[DocumentChunk] = []
        self._is_loaded = False

    def _ensure_directory(self) -> None:
        """Ensure the index directory exists."""
        self.index_path.parent.mkdir(parents=True, exist_ok=True)

    def _create_index(self, dimension: int) -> faiss.IndexFlatIP:
        """Create a new FAISS index.

        Uses Inner Product (IP) since embeddings are normalized, which
        makes inner product equivalent to cosine similarity.
        """
        return faiss.IndexFlatIP(dimension)

    def add_chunks(self, chunks: list[DocumentChunk]) -> int:
        """Add document chunks to the vector store.

        Args:
            chunks: List of DocumentChunks to add.

        Returns:
            Number of chunks added (0 for empty input).
        """
        if not chunks:
            return 0

        # (chunk, embedding) pairs in input order.
        chunk_embeddings = self.embedding_model.embed_chunks(chunks)

        # Initialize the index lazily, once the dimension is known.
        if self._index is None:
            self._index = self._create_index(self.embedding_model.embedding_dimension)

        # FAISS expects a contiguous float32 matrix; astype(copy=False)
        # is a no-op when the embeddings already comply.
        embeddings_array = np.vstack([emb for _, emb in chunk_embeddings]).astype(
            np.float32, copy=False
        )
        self._index.add(embeddings_array)

        # Keep chunks in insertion order so index rows map back to them.
        self._chunks.extend(chunk for chunk, _ in chunk_embeddings)

        return len(chunks)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        min_score: Optional[float] = None,
    ) -> list[RetrievalResult]:
        """Search for relevant chunks.

        Args:
            query: Search query.
            top_k: Number of results to return; defaults to
                settings.retrieval_top_k when None.
            min_score: Minimum similarity score threshold; defaults to
                settings.retrieval_min_score when None.

        Returns:
            List of RetrievalResults ordered by relevance (1-based rank).
        """
        if self._index is None or self._index.ntotal == 0:
            return []

        # Explicit None checks: the previous `x or default` form silently
        # replaced legitimate falsy values such as min_score=0.0.
        if top_k is None:
            top_k = settings.retrieval_top_k
        if min_score is None:
            min_score = settings.retrieval_min_score

        # Embed the query as a (1, dim) float32 matrix for FAISS.
        query_embedding = self.embedding_model.embed_query(query)
        query_embedding = np.asarray(query_embedding, dtype=np.float32).reshape(1, -1)

        scores, indices = self._index.search(query_embedding, min(top_k, self._index.ntotal))

        # FAISS pads missing results with index -1; also drop low scores.
        # Rank counts result slots (including filtered ones), matching
        # the original behavior.
        results = []
        for rank, (score, idx) in enumerate(zip(scores[0], indices[0]), start=1):
            if idx < 0 or score < min_score:
                continue
            results.append(
                RetrievalResult(
                    chunk=self._chunks[idx],
                    score=float(score),
                    rank=rank,
                )
            )

        return results

    def save(self) -> None:
        """Persist the index and chunks to disk.

        No-op when the store is empty. Writes two sibling files:
        <index_path>.faiss (binary index) and <index_path>.chunks.json.
        """
        if self._index is None:
            return

        self._ensure_directory()

        # Save FAISS index.
        faiss.write_index(self._index, str(self.index_path.with_suffix(".faiss")))

        # Save chunks as JSON so they survive independently of FAISS.
        chunks_file = self.index_path.with_suffix(".chunks.json")
        chunks_data = [chunk.model_dump() for chunk in self._chunks]
        chunks_file.write_text(json.dumps(chunks_data, indent=2), encoding="utf-8")

    def load(self) -> bool:
        """Load the index and chunks from disk.

        Returns:
            True if both files existed and loaded successfully, False
            otherwise. On failure the in-memory store is not modified
            (loading happens into locals before any attribute is set).
        """
        index_file = self.index_path.with_suffix(".faiss")
        chunks_file = self.index_path.with_suffix(".chunks.json")

        if not index_file.exists() or not chunks_file.exists():
            return False

        try:
            index = faiss.read_index(str(index_file))
            chunks_data = json.loads(chunks_file.read_text(encoding="utf-8"))
            chunks = [DocumentChunk.model_validate(c) for c in chunks_data]
        except Exception:
            # Corrupt or incompatible files: log with traceback and
            # signal failure instead of crashing best-effort callers.
            logging.getLogger(__name__).exception(
                "Error loading index from %s", self.index_path
            )
            return False

        # Commit atomically only after both files parsed successfully.
        self._index = index
        self._chunks = chunks
        self._is_loaded = True
        return True

    def clear(self) -> None:
        """Clear the in-memory store and delete any persisted files."""
        self._index = None
        self._chunks = []
        self._is_loaded = False

        # missing_ok avoids the exists()/unlink() race of the original.
        self.index_path.with_suffix(".faiss").unlink(missing_ok=True)
        self.index_path.with_suffix(".chunks.json").unlink(missing_ok=True)

    @property
    def size(self) -> int:
        """Get the number of chunks in the store."""
        return len(self._chunks)

    def get_sources(self) -> list[str]:
        """Get list of unique source files in the store (unordered)."""
        return list({chunk.source_file for chunk in self._chunks})