Spaces:
Sleeping
Sleeping
| """Phase 3 eval: cross-modal retrieval quality (recall@k, MRR). | |
| Design: N held-out (docstring, code) pairs form a closed candidate pool. | |
| Each query docstring is ranked against all N code candidates; the paired | |
| code is the positive and the other N-1 are distractors. This tests the | |
| embedder's ability to bridge natural-language intent → code, without the | |
| confound of looking up exact code that is already in the FAISS index. | |
| ⚠️ Leakage caveat: CodeSearchNet's func_code_string includes the Python | |
| docstring verbatim inside the function body (the triple-quoted string right | |
| after `def`). Embedding the raw code therefore lets the embedder trivially | |
| find the match via lexical overlap — recall@1 ≈ 0.96 is an artefact, NOT a | |
| measure of true code understanding. | |
| Call with strip_code_docstrings=True to remove the first triple-quoted string | |
| (the docstring) and all # comments from each candidate before embedding. That | |
| number (~0.3-0.5 recall@1) reflects the embedder's actual semantic matching | |
| ability. | |
| Usage (standalone): | |
| python scripts/retrieval_only_eval.py | |
| """ | |
| from __future__ import annotations | |
| import re | |
| import sys | |
| from pathlib import Path | |
| import numpy as np | |
| import pandas as pd | |
| sys.path.append(str(Path(__file__).resolve().parents[2])) | |
| # Matches the first triple-quoted string in a Python function body. | |
| _TRIPLE_QUOTE_RE = re.compile(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'') | |
| _COMMENT_RE = re.compile(r'#[^\n]*') | |
| def _strip_code_docstring(code: str) -> str: | |
| """Remove the first triple-quoted docstring and all # comments from Python code.""" | |
| code = _TRIPLE_QUOTE_RE.sub('', code, count=1) | |
| code = _COMMENT_RE.sub('', code) | |
| return code | |
def evaluate_cross_modal(
    embedder,
    pairs: pd.DataFrame,
    k_values: tuple[int, ...] = (1, 5, 10),
    batch_size: int = 64,
    strip_code_docstrings: bool = False,
) -> dict:
    """Cross-modal retrieval eval: docstring queries → code candidates.

    Row ``i`` of ``pairs`` supplies query ``i`` (its docstring) and candidate
    ``i`` (its code). Every query is scored against all N candidates; the
    candidate from the same row is the positive, the rest are distractors.

    Args:
        embedder: SentenceTransformer (or anything with .encode()).
        pairs: DataFrame with 'docstring' and 'code' columns (N rows).
        k_values: Recall cut-offs to report.
        batch_size: Encoding batch size.
        strip_code_docstrings: If True, pass each candidate through
            ``_strip_code_docstring`` before embedding, so that the query
            text is not present verbatim inside its own positive.

    Returns:
        dict with keys ``mrr``, ``recall@k`` (for each k), ``n_pairs`` and
        ``stripped``. Metrics are rounded to 4 decimals. For an empty
        ``pairs`` frame the embedder is not called and all metrics are 0.0
        with ``n_pairs == 0``.

    Ranking note:
        The rank of the positive is the number of candidates whose score is
        not strictly lower than the positive's score. Ties (and NaN scores)
        therefore count against the positive, so an embedder that maps
        everything to the same vector gets rank N rather than an arbitrary,
        sort-order-dependent rank.
    """
    n = len(pairs)
    result: dict = {
        "mrr": 0.0,
        "n_pairs": n,
        "stripped": strip_code_docstrings,
    }
    if n == 0:
        # Nothing to rank: avoid encoding an empty batch and dividing by zero.
        result.update({f"recall@{k}": 0.0 for k in k_values})
        return result

    queries = pairs["docstring"].tolist()
    candidates = pairs["code"].tolist()
    if strip_code_docstrings:
        candidates = [_strip_code_docstring(c) for c in candidates]

    print(f"[eval] encoding {n} docstrings as queries ...")
    q_emb = np.asarray(
        embedder.encode(
            queries,
            batch_size=batch_size, show_progress_bar=True,
            convert_to_numpy=True, normalize_embeddings=True,
        ),
        dtype="float32",
    )

    print(f"[eval] encoding {n} code candidates"
          f"{' (docstrings stripped)' if strip_code_docstrings else ''} ...")
    c_emb = np.asarray(
        embedder.encode(
            candidates,
            batch_size=batch_size, show_progress_bar=True,
            convert_to_numpy=True, normalize_embeddings=True,
        ),
        dtype="float32",
    )

    # Similarity matrix (N × N). Normalised embeddings are requested above, so
    # the inner product equals cosine similarity provided the embedder
    # honours normalize_embeddings.
    sim = q_emb @ c_emb.T

    # Score of each query's own positive sits on the diagonal.
    positive = sim.diagonal()[:, None]
    # 1-indexed rank = N minus the candidates that score strictly lower.
    # The positive never scores lower than itself, so every rank is >= 1.
    ranks = n - (sim < positive).sum(axis=1)

    result["mrr"] = round(float(np.mean(1.0 / ranks)), 4)
    for k in k_values:
        hits = int((ranks <= k).sum())
        result[f"recall@{k}"] = round(hits / n, 4)
    return result