File size: 3,047 Bytes
b89e6d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Phase 3: build / save / load the FAISS retrieval index.

The index plus the corpus DataFrame are persisted so deployment doesn't rebuild
embeddings on every start (rebuilding is the slow part).
"""
from __future__ import annotations

import sys
from pathlib import Path

import numpy as np
import pandas as pd

sys.path.append(str(Path(__file__).resolve().parents[2]))
from src.config import load_config  # noqa: E402


class CodeIndex:
    """Sentence-transformer embedder paired with a FAISS inner-product index.

    Embeddings are L2-normalized both when the corpus is indexed and when a
    query is encoded. The inner-product scores of the ``IndexFlatIP`` index
    built here are therefore cosine similarities.

    Attributes:
        embed_model: Model name/path given to ``SentenceTransformer``; it is
            persisted by ``save()`` so ``load()`` re-creates the same embedder.
        embedder: The loaded ``SentenceTransformer`` instance.
        index: The FAISS index, or ``None`` until ``build()``/``load()``.
        corpus: Rows aligned with the index's vector ids (row position ``i``
            corresponds to vector id ``i``), or ``None`` until built/loaded.
    """

    def __init__(self, embed_model: str):
        # Imported lazily (presumably to keep importing this module cheap).
        from sentence_transformers import SentenceTransformer

        self.embed_model = embed_model
        self.embedder = SentenceTransformer(embed_model)
        self.index = None
        self.corpus: pd.DataFrame | None = None

    def build(self, corpus: pd.DataFrame, text_col: str = "docstring", batch_size: int = 64):
        """Embed ``corpus[text_col]`` and (re)build the index from scratch.

        Args:
            corpus: Rows to index. Its index is reset so that row position
                equals FAISS vector id.
            text_col: Column whose text is embedded.
            batch_size: Encoding batch size forwarded to the embedder.

        Returns:
            ``self``, so ``CodeIndex(...).build(df)`` can be chained.

        Raises:
            ValueError: if ``corpus`` has no rows. The index dimension is
                taken from the embeddings, so there would be nothing to size
                it from.
        """
        if len(corpus) == 0:
            raise ValueError("cannot build an index from an empty corpus")

        import faiss

        # Work on locals and commit at the end. If encoding fails, the object
        # keeps its previous, mutually consistent corpus/index pair.
        rows = corpus.reset_index(drop=True)
        emb = self.embedder.encode(
            rows[text_col].tolist(),
            batch_size=batch_size, show_progress_bar=True,
            convert_to_numpy=True, normalize_embeddings=True,
        ).astype("float32")
        index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)

        self.corpus, self.index = rows, index
        return self

    @staticmethod
    def _hits_to_frame(corpus: pd.DataFrame, scores: np.ndarray, ids: np.ndarray) -> pd.DataFrame:
        """Return corpus rows for one query's hits, dropping FAISS's ``-1`` padding ids."""
        ids = np.asarray(ids)
        keep = ids >= 0  # FAISS uses id -1 for "no result"; iloc[-1] would pick the last row.
        out = corpus.iloc[ids[keep]].copy()
        out["score"] = np.asarray(scores)[keep]
        return out

    def retrieve(self, query: str, k: int = 3) -> pd.DataFrame:
        """Return up to ``k`` corpus rows most similar to ``query``.

        Args:
            query: Free text, embedded with the same model as the corpus.
            k: Maximum number of hits. It is clamped to the number of indexed
                vectors, and ``k <= 0`` yields an empty frame.

        Returns:
            A copy of the matching corpus rows, in the order FAISS returns
            them, with an added ``score`` column (inner product of normalized
            embeddings).

        Raises:
            RuntimeError: if neither ``build()`` nor ``load()`` has been called.
        """
        if self.index is None or self.corpus is None:
            raise RuntimeError("Index not built/loaded. Call build() or load().")

        # Asking FAISS for more neighbours than it stores pads the result
        # with -1 ids, so never request more than ntotal.
        k = min(k, self.index.ntotal)
        if k <= 0:
            return self._hits_to_frame(
                self.corpus, np.empty(0, dtype="float32"), np.empty(0, dtype="int64")
            )

        q = self.embedder.encode(
            [query], convert_to_numpy=True, normalize_embeddings=True
        ).astype("float32")
        scores, ids = self.index.search(q, k)
        return self._hits_to_frame(self.corpus, scores[0], ids[0])

    def save(self, out_dir: str):
        """Persist the index, corpus and embedder name under ``out_dir``.

        Writes ``code.index`` (FAISS), ``corpus.parquet`` and
        ``embed_model.txt``, creating the directory if needed.

        Raises:
            RuntimeError: if there is nothing to save yet. This is checked
                before anything is written, so no partial output is left.
        """
        if self.index is None or self.corpus is None:
            raise RuntimeError("Index not built/loaded. Call build() or load().")

        import faiss

        out = Path(out_dir)
        out.mkdir(parents=True, exist_ok=True)
        faiss.write_index(self.index, str(out / "code.index"))
        self.corpus.to_parquet(out / "corpus.parquet", index=False)
        (out / "embed_model.txt").write_text(self.embed_model, encoding="utf-8")
        print(f"[index] saved to {out}")

    @classmethod
    def load(cls, in_dir: str) -> "CodeIndex":
        """Re-create a ``CodeIndex`` previously written by ``save()``.

        All three files are read before the (slow) embedder is instantiated,
        so a missing or mismatched file fails fast.

        Raises:
            FileNotFoundError: if any of the saved files is missing.
            ValueError: if the corpus row count differs from the number of
                vectors in the index. Rows and vector ids would then be
                misaligned.
        """
        import faiss

        in_dir = Path(in_dir)
        embed_model = (in_dir / "embed_model.txt").read_text(encoding="utf-8").strip()
        index = faiss.read_index(str(in_dir / "code.index"))
        corpus = pd.read_parquet(in_dir / "corpus.parquet")
        if len(corpus) != index.ntotal:
            raise ValueError(
                f"corpus has {len(corpus)} rows but index has {index.ntotal} vectors in {in_dir}"
            )

        obj = cls(embed_model)
        obj.index = index
        obj.corpus = corpus
        print(f"[index] loaded {obj.index.ntotal} vectors from {in_dir}")
        return obj


def build_index_from_processed(cfg=None) -> CodeIndex:
    """Build, persist and return a ``CodeIndex`` over the processed training split.

    Reads ``<cfg.paths.processed_dir>/train.parquet`` and embeds it with
    ``cfg.models.embed_model``. The result is saved under
    ``cfg.paths.index_dir``.

    Args:
        cfg: Project config object. When it is falsy, ``load_config()`` is used.

    Returns:
        The freshly built (and already saved) index.

    Raises:
        SystemExit: if ``train.parquet`` has not been produced yet.
    """
    if not cfg:
        cfg = load_config()

    train_file = Path(cfg.paths.processed_dir) / "train.parquet"
    if not train_file.exists():
        sys.exit("train.parquet missing. Run scripts/01_prepare_data.py first.")

    # Read the data before loading the embedder so a bad file fails before the slow model load.
    training_rows = pd.read_parquet(train_file)
    code_index = CodeIndex(cfg.models.embed_model).build(training_rows)
    code_index.save(cfg.paths.index_dir)
    return code_index