Spaces:
Running
Running
| """Basic FAISS RAG retriever for code snippets.""" | |
| from __future__ import annotations | |
| import json | |
| from pathlib import Path | |
| from typing import List | |
| import faiss | |
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| class CodeRAG: | |
| """Loads snippets from data and serves top-k retrieval.""" | |
| def __init__(self, data_path: str = "data/sample_snippets.json"): | |
| project_root = Path(__file__).resolve().parents[1] | |
| self.data_path = project_root / data_path | |
| self.snippets = self._load_data() | |
| self.vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=1) | |
| self.snippet_vectors = self._build_vectors(self.snippets) | |
| self.index = self._build_index(self.snippet_vectors) | |
| def _load_data(self) -> List[str]: | |
| if not self.data_path.exists(): | |
| return [] | |
| payload = json.loads(self.data_path.read_text(encoding="utf-8")) | |
| return [item["snippet"] for item in payload] | |
| def _build_vectors(self, snippets: List[str]) -> np.ndarray | None: | |
| if not snippets: | |
| return None | |
| matrix = self.vectorizer.fit_transform(snippets) | |
| vectors = matrix.toarray().astype(np.float32) | |
| # Normalize for cosine-like similarity with IndexFlatIP. | |
| norms = np.linalg.norm(vectors, axis=1, keepdims=True) | |
| norms[norms == 0.0] = 1.0 | |
| return vectors / norms | |
| def _build_index(self, vectors: np.ndarray | None): | |
| if vectors is None: | |
| return None | |
| index = faiss.IndexFlatIP(vectors.shape[1]) | |
| index.add(vectors) | |
| return index | |
| def retrieve(self, query: str, top_k: int = 2) -> str: | |
| if not self.snippets or self.index is None or self.snippet_vectors is None: | |
| return "" | |
| query_vec = self.vectorizer.transform([query]).toarray().astype(np.float32) | |
| qnorm = np.linalg.norm(query_vec, axis=1, keepdims=True) | |
| qnorm[qnorm == 0.0] = 1.0 | |
| query_vec = query_vec / qnorm | |
| _, idx = self.index.search(query_vec, min(top_k, len(self.snippets))) | |
| selected = [self.snippets[i] for i in idx[0] if i >= 0] | |
| return "\n\n".join(selected) | |