File size: 4,863 Bytes
6085b61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""FAISS-based vector store for code chunk retrieval.

Uses inner-product (cosine similarity on L2-normalised vectors).
Falls back to brute-force numpy search when FAISS is unavailable.
"""

from __future__ import annotations

import pickle
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np

from indexing.parser import CodeChunk

# Default index location: "code_index.faiss" in the same directory as this module.
DEFAULT_INDEX_PATH = str(Path(__file__).with_name("code_index.faiss"))


class CodeVectorStore:
    """FAISS index + metadata for code chunk similarity search.

    Embeddings are L2-normalised before indexing, so inner-product scores are
    cosine similarities. When ``faiss`` cannot be imported, the raw embeddings
    are kept in memory (and pickled alongside the metadata) and searched by
    brute force with numpy instead.

    Files on disk:
        * ``index_path`` -- the FAISS index (FAISS mode only).
        * ``index_path`` with a ``.pkl`` suffix -- a pickle holding the chunk
          metadata and, in fallback mode, the raw embeddings.
    """

    def __init__(self, index_path: str = DEFAULT_INDEX_PATH):
        """Create an empty store; no disk I/O happens here.

        Args:
            index_path: Location of the FAISS index file. The metadata pickle
                lives next to it with a ``.pkl`` suffix.
        """
        self.index_path = Path(index_path)
        self.metadata_path = self.index_path.with_suffix(".pkl")
        # all-MiniLM-L6-v2 output dim; build()/load() replace it with the
        # dimension actually present in the data.
        self.dimension = 384
        self._index = None
        self.metadata: List[CodeChunk] = []
        self._use_fallback = False
        # Raw embeddings for the numpy fallback; None while FAISS is in use.
        self._fallback_embeddings: Optional[np.ndarray] = None

    def build(self, chunks: List[CodeChunk], embeddings: np.ndarray) -> None:
        """Build the index from *chunks* and their *embeddings*, then save it.

        Args:
            chunks: Chunks to index; ``chunks[i]`` corresponds to row ``i`` of
                ``embeddings``.
            embeddings: Array of shape ``(len(chunks), dim)``. It is copied,
                so the caller's array is never modified.

        Raises:
            ValueError: If ``embeddings`` is not 2-D or its row count differs
                from ``len(chunks)``.
        """
        # Private float32, C-contiguous copy: faiss expects contiguous float32
        # arrays and normalize_L2 rewrites its argument in place.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.size == 0 and vectors.ndim != 2:
            # e.g. ``np.array([])`` for an empty corpus: give it a 2-D shape.
            vectors = vectors.reshape(0, self.dimension)
        if vectors.ndim != 2 or vectors.shape[0] != len(chunks):
            raise ValueError(
                f"embeddings must have shape ({len(chunks)}, dim), "
                f"got {vectors.shape}"
            )

        self.metadata = list(chunks)
        # Derive the dimension from the data instead of assuming 384.
        self.dimension = int(vectors.shape[1])
        try:
            import faiss
        except ImportError:
            warnings.warn("faiss not available — using brute-force numpy search")
            self._index = None
            self._use_fallback = True
            self._fallback_embeddings = vectors
        else:
            faiss.normalize_L2(vectors)
            index = faiss.IndexFlatIP(self.dimension)
            index.add(vectors)
            self._index = index
            self._use_fallback = False
            # Drop embeddings from any earlier fallback build/load so save()
            # does not persist stale vectors next to the fresh index.
            self._fallback_embeddings = None
        self.save()

    def search(
        self, query_embedding: np.ndarray, k: int = 5
    ) -> List[Tuple[CodeChunk, float]]:
        """Return up to *k* ``(chunk, cosine_similarity)`` pairs, best first.

        If nothing has been built or loaded yet, the persisted store is loaded
        from disk first. Returns an empty list when the store is empty or
        ``k <= 0``.
        """
        if k <= 0:
            return []

        # Lazy-load *before* deciding the store is empty; checking
        # ``self.metadata`` first would skip loading on a fresh instance.
        if self._index is None and self._fallback_embeddings is None:
            self.load()

        if not self.metadata:
            return []

        query = np.array(query_embedding, dtype=np.float32, order="C").reshape(1, -1)

        if self._use_fallback or self._index is None:
            return self._fallback_search(query, k)

        import faiss

        faiss.normalize_L2(query)
        scores, indices = self._index.search(query, k)
        n_chunks = len(self.metadata)
        # FAISS pads the result with -1 when k exceeds the stored vector count.
        return [
            (self.metadata[idx], float(score))
            for idx, score in zip(indices[0], scores[0])
            if 0 <= idx < n_chunks
        ]

    def _fallback_search(
        self, query: np.ndarray, k: int
    ) -> List[Tuple[CodeChunk, float]]:
        """Brute-force cosine similarity search used when FAISS is unavailable.

        Returns up to *k* ``(chunk, score)`` pairs sorted by descending score;
        empty if there are no fallback embeddings or ``k <= 0``.
        """
        emb = self._fallback_embeddings
        if emb is None or k <= 0:
            return []
        eps = 1e-12  # avoids division by zero for all-zero vectors
        query_unit = query / (np.linalg.norm(query) + eps)
        emb_unit = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + eps)
        scores = (emb_unit @ query_unit.T).ravel()
        # Never index past the metadata, even if the two ever disagree in length.
        scores = scores[: len(self.metadata)]
        # Stable sort keeps tied scores in insertion order.
        top = np.argsort(-scores, kind="stable")[:k]
        return [(self.metadata[i], float(scores[i])) for i in top]

    def save(self) -> None:
        """Persist the index and metadata to disk.

        In FAISS mode the index is written to ``index_path``; a failure there
        is reported as a warning so the metadata is still saved. In fallback
        mode any index file left by an earlier FAISS build is removed, because
        load() would otherwise pair that stale index with the new metadata.
        The metadata pickle (plus fallback embeddings, if any) is always
        written.
        """
        if self._use_fallback:
            try:
                self.index_path.unlink()
            except FileNotFoundError:
                pass
            except OSError as exc:
                warnings.warn(
                    f"could not remove stale FAISS index {self.index_path}: {exc}"
                )
        elif self._index is not None:
            try:
                import faiss

                faiss.write_index(self._index, str(self.index_path))
            except Exception as exc:  # best effort: still save the metadata
                warnings.warn(
                    f"could not write FAISS index to {self.index_path}: {exc}"
                )
        # Always save metadata and fallback embeddings
        payload = {
            "metadata": self.metadata,
            "fallback_embeddings": self._fallback_embeddings,
        }
        self.metadata_path.write_bytes(pickle.dumps(payload))

    def load(self) -> bool:
        """Load index and metadata from disk.

        Returns:
            True if a FAISS index was loaded or the store holds metadata
            afterwards; False if neither file exists or the metadata pickle
            cannot be read.
        """
        if not self.index_path.exists() and not self.metadata_path.exists():
            return False

        if self.metadata_path.exists():
            try:
                # NOTE(security): unpickling can execute arbitrary code. This is
                # only acceptable because the file is produced by save(); never
                # point a store at an untrusted .pkl file.
                payload = pickle.loads(self.metadata_path.read_bytes())
                metadata = payload.get("metadata", [])
                fb_emb = payload.get("fallback_embeddings")
            except Exception:
                return False
            # Replace (rather than merge with) state from an earlier build/load.
            self.metadata = metadata
            self._index = None
            self._fallback_embeddings = fb_emb
            self._use_fallback = fb_emb is not None
            if fb_emb is not None and np.ndim(fb_emb) == 2:
                self.dimension = int(np.shape(fb_emb)[1])

        if self.index_path.exists():
            try:
                import faiss
            except ImportError:
                faiss = None  # FAISS not installed here: rely on the pickle alone
            if faiss is not None:
                try:
                    index = faiss.read_index(str(self.index_path))
                except Exception as exc:
                    warnings.warn(
                        f"could not read FAISS index {self.index_path}: {exc}"
                    )
                else:
                    self._index = index
                    self.dimension = int(index.d)
                    self._use_fallback = False
                    return True

        return bool(self.metadata)

    def index_exists(self) -> bool:
        """Return True if both the FAISS index file and the metadata pickle exist."""
        return self.index_path.exists() and self.metadata_path.exists()