Spaces:
Sleeping
Sleeping
File size: 4,863 Bytes
6085b61 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 | """FAISS-based vector store for code chunk retrieval.
Uses inner-product (cosine similarity on L2-normalised vectors).
Falls back to brute-force numpy search when FAISS is unavailable.
"""
from __future__ import annotations
import pickle
import warnings
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
from indexing.parser import CodeChunk
# Default index location: a "code_index.faiss" file in this module's directory.
DEFAULT_INDEX_PATH = str(Path(__file__).parent.joinpath("code_index.faiss"))
class CodeVectorStore:
    """FAISS index + metadata for code chunk similarity search.

    Vectors are L2-normalised and searched by inner product, so scores are
    cosine similarities. When ``faiss`` cannot be imported, the embeddings are
    kept in memory and searched by brute force with numpy instead.

    On disk the store consists of the FAISS index at ``index_path`` (written
    only in FAISS mode) and a pickle at ``index_path`` with a ``.pkl`` suffix.
    The pickle holds the chunk metadata, any fallback embeddings and the name
    of the backend that produced them.
    """

    def __init__(self, index_path: str = DEFAULT_INDEX_PATH):
        self.index_path = Path(index_path)
        self.metadata_path = self.index_path.with_suffix(".pkl")
        # all-MiniLM-L6-v2 output dim; build() and load() overwrite it from
        # the actual data so other embedding models work too.
        self.dimension = 384
        self._index = None
        self.metadata: List[CodeChunk] = []
        self._use_fallback = False
        # Embeddings searched by the numpy fallback; None in FAISS mode.
        self._fallback_embeddings: Optional[np.ndarray] = None

    def build(self, chunks: List[CodeChunk], embeddings: np.ndarray) -> None:
        """Build the index from ``chunks`` and their embeddings, then save it.

        Args:
            chunks: Items returned by :meth:`search`; ``chunks[i]`` corresponds
                to row ``i`` of ``embeddings``.
            embeddings: 2-D array of shape ``(len(chunks), dim)``. It is
                copied, never modified in place.

        Raises:
            ValueError: if ``embeddings`` is not 2-D or its row count differs
                from ``len(chunks)``.
        """
        # Contiguous float32 copy: FAISS requires that layout, and
        # normalize_L2 works in place, so the caller's array must not be used.
        vectors = np.array(embeddings, dtype=np.float32, order="C")
        if vectors.ndim != 2:
            raise ValueError(
                f"embeddings must be 2-D (n, dim), got shape {vectors.shape}"
            )
        if vectors.shape[0] != len(chunks):
            raise ValueError(
                f"got {len(chunks)} chunks but {vectors.shape[0]} embedding rows"
            )
        self.metadata = list(chunks)
        self.dimension = int(vectors.shape[1])
        try:
            import faiss
        except ImportError:
            warnings.warn("faiss not available — using brute-force numpy search")
            self._index = None
            self._use_fallback = True
            self._fallback_embeddings = vectors
        else:
            faiss.normalize_L2(vectors)
            index = faiss.IndexFlatIP(self.dimension)
            index.add(vectors)
            self._index = index
            self._use_fallback = False
            # Drop vectors left over from an earlier fallback build/load so
            # save() does not persist them next to the new index.
            self._fallback_embeddings = None
        self.save()

    def search(
        self, query_embedding: np.ndarray, k: int = 5
    ) -> List[Tuple[CodeChunk, float]]:
        """Return up to ``k`` ``(chunk, cosine_similarity)`` pairs, best first.

        If nothing has been built or loaded in this instance yet, the store is
        loaded from disk first. Returns an empty list when ``k <= 0`` or when
        no searchable index is available.
        """
        if k <= 0:
            return []
        # Lazy-load must run before the metadata check; otherwise a fresh
        # instance (empty metadata) would never read the store from disk.
        if self._index is None and not self._use_fallback and not self.load():
            return []
        if not self.metadata:
            return []
        query = np.array(query_embedding, dtype=np.float32, order="C").reshape(1, -1)
        if self._use_fallback or self._index is None:
            return self._fallback_search(query, k)
        import faiss
        faiss.normalize_L2(query)
        scores, ids = self._index.search(query, k)
        # FAISS pads with id -1 when the index holds fewer than k vectors.
        return [
            (self.metadata[i], float(s))
            for i, s in zip(ids[0], scores[0])
            if 0 <= i < len(self.metadata)
        ]

    def _fallback_search(
        self, query: np.ndarray, k: int
    ) -> List[Tuple[CodeChunk, float]]:
        """Brute-force cosine similarity when FAISS is unavailable.

        ``query`` is a single ``(1, dim)`` row. Both it and the stored rows are
        normalised here, so the stored embeddings need not be pre-normalised.
        """
        if self._fallback_embeddings is None or k <= 0:
            return []
        # The 1e-12 epsilon avoids division by zero for all-zero vectors.
        query_norm = query / (np.linalg.norm(query) + 1e-12)
        emb = self._fallback_embeddings
        emb_norm = emb / (np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12)
        scores = (emb_norm @ query_norm.T).ravel()
        # Stable sort: ties are broken by insertion order, deterministically.
        top = np.argsort(-scores, kind="stable")[: min(k, len(scores))]
        return [
            (self.metadata[i], float(scores[i]))
            for i in top
            if i < len(self.metadata)
        ]

    def save(self) -> None:
        """Persist the index and metadata to disk.

        In FAISS mode the index is written to ``index_path``. A failure there
        produces a warning rather than an exception. The metadata pickle is
        always written and records which backend produced the store.
        """
        self.index_path.parent.mkdir(parents=True, exist_ok=True)
        if not self._use_fallback and self._index is not None:
            try:
                import faiss
                faiss.write_index(self._index, str(self.index_path))
            except Exception as exc:  # best-effort, but never silent
                warnings.warn(
                    f"could not write FAISS index to {self.index_path}: {exc}"
                )
        payload = {
            "metadata": self.metadata,
            "fallback_embeddings": self._fallback_embeddings,
            # Lets load() ignore a stale .faiss file left by an earlier build.
            "backend": "numpy" if self._use_fallback else "faiss",
        }
        self.metadata_path.write_bytes(pickle.dumps(payload))

    def load(self) -> bool:
        """Load the index and metadata from disk.

        Returns:
            True if a searchable store was loaded, False otherwise. On False
            the in-memory state is left unchanged.
        """
        if not self.metadata_path.exists():
            # Without metadata, index hits cannot be mapped back to chunks.
            return False
        try:
            # NOTE(security): unpickling can execute arbitrary code; only load
            # files written by this class, never ones from untrusted sources.
            payload = pickle.loads(self.metadata_path.read_bytes())
            metadata = payload.get("metadata", [])
            fallback = payload.get("fallback_embeddings")
            # "backend" is absent in files written by older versions; those
            # keep the old behaviour of preferring the FAISS index file.
            backend = payload.get("backend")
        except Exception as exc:
            warnings.warn(f"could not read {self.metadata_path}: {exc}")
            return False

        index = None
        if backend != "numpy" and self.index_path.exists():
            index = self._read_faiss_index()

        if index is not None:
            self._index = index
            self._use_fallback = False
            self.dimension = int(index.d)
        elif fallback is not None:
            self._index = None
            self._use_fallback = True
            if np.ndim(fallback) == 2:
                self.dimension = int(np.shape(fallback)[1])
        else:
            # Neither a readable FAISS index nor fallback vectors: unusable.
            return False
        self.metadata = metadata
        self._fallback_embeddings = fallback
        return True

    def _read_faiss_index(self):
        """Read the FAISS index file; warn and return None if that fails."""
        try:
            import faiss
            return faiss.read_index(str(self.index_path))
        except Exception as exc:  # includes ImportError when faiss is missing
            warnings.warn(f"could not read FAISS index {self.index_path}: {exc}")
            return None

    def index_exists(self) -> bool:
        """Return True if both the FAISS index file and the metadata pickle exist.

        NOTE: save() writes no index file in numpy-fallback mode, so this
        returns False for such a store even though load() can restore it.
        """
        return self.index_path.exists() and self.metadata_path.exists()
|