code-gen-assistant / src /rag /embedder.py
Rushabh147's picture
Initial deploy to HF Spaces (clean history, LFS for all binaries)
b89e6d6
Raw
History Blame Contribute Delete
3.05 kB
"""Phase 3: build / save / load the FAISS retrieval index.
The index plus the corpus DataFrame are persisted so deployment doesn't rebuild
embeddings on every start (rebuilding is the slow part).
"""
from __future__ import annotations
import sys
from pathlib import Path
import numpy as np
import pandas as pd
sys.path.append(str(Path(__file__).resolve().parents[2]))
from src.config import load_config # noqa: E402
class CodeIndex:
    """Sentence-transformer embedder paired with a FAISS inner-product index.

    Embeddings are L2-normalised both when indexed and when queried, so the
    inner product reported by ``IndexFlatIP`` is the cosine similarity.
    ``corpus`` is kept in the same positional order as the vectors in the
    index, so FAISS id ``i`` corresponds to ``corpus.iloc[i]``.
    """

    def __init__(self, embed_model: str):
        """Load the sentence-transformer model named ``embed_model``.

        The index and corpus start empty; populate them with :meth:`build`,
        or use the :meth:`load` constructor for a previously saved index.
        """
        # Imported lazily (presumably to keep importing this module cheap).
        from sentence_transformers import SentenceTransformer

        self.embed_model = embed_model
        self.embedder = SentenceTransformer(embed_model)
        self.index = None  # FAISS index; set by build() or load()
        self.corpus: pd.DataFrame | None = None

    def build(self, corpus: pd.DataFrame, text_col: str = "docstring", batch_size: int = 64):
        """Embed ``corpus[text_col]`` and build a fresh cosine-similarity index.

        Args:
            corpus: Rows to index. Its index is reset so FAISS ids equal
                row positions.
            text_col: Column whose text is embedded.
            batch_size: Encoder batch size.

        Returns:
            ``self``, so construction can be chained (``CodeIndex(m).build(df)``).

        Raises:
            ValueError: If ``corpus`` has no rows (the embedding dimension is
                taken from the encoded matrix, which needs at least one row).
        """
        if len(corpus) == 0:
            raise ValueError("Cannot build an index from an empty corpus.")
        import faiss

        rows = corpus.reset_index(drop=True)
        # Coerce missing / non-string cells so the encoder only sees text.
        texts = rows[text_col].fillna("").astype(str).tolist()
        emb = self.embedder.encode(
            texts,
            batch_size=batch_size, show_progress_bar=True,
            convert_to_numpy=True, normalize_embeddings=True,
        ).astype("float32")
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)
        # Commit state only after everything succeeded, so a failed build
        # never leaves ``corpus`` and ``index`` out of sync.
        self.corpus = rows
        self.index = index
        return self

    def retrieve(self, query: str, k: int = 3) -> pd.DataFrame:
        """Return the corpus rows most similar to ``query``, best first.

        The result keeps the corpus columns plus a ``score`` column holding
        the cosine similarity. At most ``min(k, len(corpus))`` rows are
        returned; ``k <= 0`` yields an empty frame (still with ``score``).

        Raises:
            RuntimeError: If no index has been built or loaded.
        """
        if self.index is None or self.corpus is None:
            raise RuntimeError("Index not built/loaded. Call build() or load().")
        # FAISS pads results with id -1 when k exceeds the stored vector
        # count; ``iloc[-1]`` would then silently return the last row.
        k = min(k, self.index.ntotal)
        if k <= 0:
            out = self.corpus.iloc[0:0].copy()
            out["score"] = np.empty(0, dtype="float32")
            return out
        q = self.embedder.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        ).astype("float32")
        scores, ids = self.index.search(q, k)
        hit = ids[0] >= 0  # defensive: drop any remaining padding ids
        out = self.corpus.iloc[ids[0][hit]].copy()
        out["score"] = scores[0][hit]
        return out

    def save(self, out_dir: str):
        """Persist the index, corpus and model name under ``out_dir``.

        Writes ``code.index`` (FAISS), ``corpus.parquet`` and
        ``embed_model.txt``, creating the directory if needed.

        Raises:
            RuntimeError: If there is no built/loaded index to save.
        """
        if self.index is None or self.corpus is None:
            raise RuntimeError("Index not built/loaded. Call build() or load().")
        import faiss

        out = Path(out_dir)
        out.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(out / "code.index"))
        self.corpus.to_parquet(out / "corpus.parquet", index=False)
        (out / "embed_model.txt").write_text(self.embed_model, encoding="utf-8")
        print(f"[index] saved to {out}")

    @classmethod
    def load(cls, in_dir: str) -> "CodeIndex":
        """Restore an index previously written by :meth:`save`.

        Raises:
            ValueError: If the stored vector count does not match the number
                of corpus rows (ids would map to the wrong rows).
        """
        import faiss

        src = Path(in_dir)
        embed_model = (src / "embed_model.txt").read_text(encoding="utf-8").strip()
        # Read the saved artefacts before the (slow) model load so a broken
        # directory fails fast.
        index = faiss.read_index(str(src / "code.index"))
        corpus = pd.read_parquet(src / "corpus.parquet")
        if index.ntotal != len(corpus):
            raise ValueError(
                f"[index] {src} is inconsistent: {index.ntotal} vectors "
                f"vs {len(corpus)} corpus rows"
            )
        obj = cls(embed_model)
        obj.index = index
        obj.corpus = corpus
        print(f"[index] loaded {obj.index.ntotal} vectors from {src}")
        return obj
def build_index_from_processed(cfg=None) -> CodeIndex:
    """Embed the processed training split and persist it as a retrieval index.

    Reads ``<paths.processed_dir>/train.parquet``, builds a ``CodeIndex`` with
    ``models.embed_model``, saves it to ``paths.index_dir`` and returns it.
    When ``cfg`` is falsy the project config is loaded via ``load_config()``.
    Exits the process with a hint if the parquet file has not been produced.
    """
    if not cfg:
        cfg = load_config()
    source = Path(cfg.paths.processed_dir).joinpath("train.parquet")
    if not source.exists():
        sys.exit("train.parquet missing. Run scripts/01_prepare_data.py first.")
    frame = pd.read_parquet(source)
    built = CodeIndex(cfg.models.embed_model).build(frame)
    built.save(cfg.paths.index_dir)
    return built