code-gen-assistant / src /eval /retrieval_eval.py
Rushabh147's picture
Initial deploy to HF Spaces (clean history, LFS for all binaries)
b89e6d6
Raw
History Blame Contribute Delete
4.1 kB
"""Phase 3 eval: cross-modal retrieval quality (recall@k, MRR).
Design: N held-out (docstring, code) pairs form a closed candidate pool.
Each query docstring is ranked against all N code candidates; the paired
code is the positive and the other N-1 are distractors. This tests the
embedder's ability to bridge natural-language intent → code, without the
confound of looking up exact code that is already in the FAISS index.
⚠️ Leakage caveat: CodeSearchNet's func_code_string includes the Python
docstring verbatim inside the function body (the triple-quoted string right
after `def`). Embedding the raw code therefore lets the embedder trivially
find the match via lexical overlap — recall@1 ≈ 0.96 is an artefact, NOT a
measure of true code understanding.
Call with strip_code_docstrings=True to remove triple-quoted strings and
# comments from candidate code before embedding. That number (~0.3-0.5
recall@1) reflects the embedder's actual semantic matching ability.
Usage (standalone):
python scripts/retrieval_only_eval.py
"""
from __future__ import annotations
import re
import sys
from pathlib import Path
import numpy as np
import pandas as pd
# Append the directory three levels above this file to sys.path.
# NOTE(review): presumably this is the repository root, added so that
# project-local packages import when the file is run directly — confirm.
sys.path.append(str(Path(__file__).resolve().parents[2]))
# Matches any triple-quoted string. Kept for backward compatibility; the
# "first one only" behaviour is applied by _strip_code_docstring, not here.
_TRIPLE_QUOTE_RE = re.compile(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'')
# Matches a '#' through end of line. Kept for backward compatibility; on its
# own it cannot tell a real comment from a '#' inside a string literal.
_COMMENT_RE = re.compile(r'#[^\n]*')
# Single-pass scanner: at each position it recognises (in priority order) a
# triple-quoted string, an ordinary single-line string literal, or a comment.
# Recognising ordinary strings lets the scanner step over them, so a '#' or a
# triple quote that lives inside a string is never mistaken for the real thing.
_CODE_TOKEN_RE = re.compile(
    r'(?P<triple>"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\')'
    r'|(?P<string>"(?:\\[\s\S]|[^"\\\n])*"|\'(?:\\[\s\S]|[^\'\\\n])*\')'
    r'|(?P<comment>#[^\n]*)'
)


def _strip_code_docstring(code: str) -> str:
    """Remove the first triple-quoted string and all # comments from Python code.

    The code is scanned left to right in a single pass so that string
    literals are respected:

    * the first triple-quoted string found outside a comment or ordinary
      string is deleted (in a function body this is normally the docstring);
    * every ``#`` comment is deleted up to, but not including, its newline;
    * ordinary string literals and any later triple-quoted strings are kept
      verbatim, even when they contain ``#`` characters.

    Args:
        code: Python source text.

    Returns:
        The source with the docstring and comments removed. Surrounding
        whitespace and newlines are left in place.
    """
    docstring_removed = False

    def _replace(match: re.Match) -> str:
        nonlocal docstring_removed
        token_kind = match.lastgroup
        if token_kind == "comment":
            return ''
        if token_kind == "triple" and not docstring_removed:
            # Only the first triple-quoted string is treated as the docstring.
            docstring_removed = True
            return ''
        # Ordinary strings and later triple-quoted strings are real code.
        return match.group(0)

    return _CODE_TOKEN_RE.sub(_replace, code)
def evaluate_cross_modal(
    embedder,
    pairs: pd.DataFrame,
    k_values: tuple[int, ...] = (1, 5, 10),
    batch_size: int = 64,
    strip_code_docstrings: bool = False,
) -> dict:
    """Cross-modal retrieval eval: docstring queries → code candidates.

    Row ``i`` of ``pairs`` supplies query ``i`` (its docstring) and candidate
    ``i`` (its code). Every query is scored against all N candidates; the
    candidate from the same row is the positive, the rest are distractors.

    Args:
        embedder: Object with an ``encode()`` method that accepts a list of
            strings plus the ``batch_size``, ``show_progress_bar``,
            ``convert_to_numpy`` and ``normalize_embeddings`` keyword
            arguments and returns an array with one row per input
            (e.g. a SentenceTransformer).
        pairs: DataFrame with 'docstring' and 'code' columns (N rows).
        k_values: Recall cut-offs to report. Duplicate values are harmless.
        batch_size: Encoding batch size.
        strip_code_docstrings: If True, remove the triple-quoted docstring and
            ``#`` comments from candidate code before embedding, so the query
            text is not present verbatim inside its own positive candidate.

    Returns:
        dict with keys ``mrr``, ``n_pairs``, ``stripped`` and one
        ``recall@k`` entry per value in ``k_values``. Metrics are rounded to
        four decimal places.

    Raises:
        ValueError: If ``pairs`` has no rows (the metrics would be undefined).
    """
    n = len(pairs)
    if n == 0:
        # Fail early with a clear message instead of a ZeroDivisionError
        # after the (possibly expensive) encoding step.
        raise ValueError("evaluate_cross_modal requires at least one (docstring, code) pair")

    candidates = pairs["code"].tolist()
    if strip_code_docstrings:
        candidates = [_strip_code_docstring(c) for c in candidates]

    print(f"[eval] encoding {n} docstrings as queries ...")
    q_emb = embedder.encode(
        pairs["docstring"].tolist(),
        batch_size=batch_size, show_progress_bar=True,
        convert_to_numpy=True, normalize_embeddings=True,
    ).astype("float32")

    print(f"[eval] encoding {n} code candidates"
          f"{' (docstrings stripped)' if strip_code_docstrings else ''} ...")
    c_emb = embedder.encode(
        candidates,
        batch_size=batch_size, show_progress_bar=True,
        convert_to_numpy=True, normalize_embeddings=True,
    ).astype("float32")

    # Similarity matrix (N × N). The embedder is asked for L2-normalised
    # vectors, in which case the inner product equals cosine similarity.
    sim = q_emb @ c_emb.T

    # The positive for query i sits on the diagonal. Its 1-indexed rank is one
    # plus the number of candidates placed ahead of it: those scoring strictly
    # higher, plus exact ties at a higher column index. The tie rule makes the
    # result deterministic and reproduces what a stable ascending sort read
    # back-to-front would give.
    positive = sim.diagonal()[:, None]
    strictly_better = (sim > positive).sum(axis=1)
    tied_ahead = np.triu(sim == positive, k=1).sum(axis=1)
    ranks = 1 + strictly_better + tied_ahead

    result: dict = {
        "mrr": round(float(np.mean(1.0 / ranks)), 4),
        "n_pairs": n,
        "stripped": strip_code_docstrings,
    }
    for k in k_values:
        # Computed per cut-off from the ranks, so a repeated k cannot
        # double-count hits.
        result[f"recall@{k}"] = round(int((ranks <= k).sum()) / n, 4)
    return result