File size: 4,100 Bytes
b89e6d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Phase 3 eval: cross-modal retrieval quality (recall@k, MRR).

Design: N held-out (docstring, code) pairs form a closed candidate pool.
Each query docstring is ranked against all N code candidates; the paired
code is the positive and the other N-1 are distractors.  This tests the
embedder's ability to bridge natural-language intent → code, without the
confound of looking up exact code that is already in the FAISS index.

⚠️  Leakage caveat: CodeSearchNet's func_code_string includes the Python
docstring verbatim inside the function body (the triple-quoted string right
after `def`).  Embedding the raw code therefore lets the embedder trivially
find the match via lexical overlap — recall@1 ≈ 0.96 is an artefact, NOT a
measure of true code understanding.

Call with strip_code_docstrings=True to remove the leading docstring and
`#` comments from candidate code before embedding.  The resulting number
(~0.3-0.5 recall@1) is a far better proxy for the embedder's semantic
matching ability.

Usage (standalone):
    python scripts/retrieval_only_eval.py
"""
from __future__ import annotations

import io
import re
import sys
import tokenize
from pathlib import Path

import numpy as np
import pandas as pd

sys.path.append(str(Path(__file__).resolve().parents[2]))

# Lexical fallback patterns, used only when a snippet cannot be tokenized.
# _TRIPLE_QUOTE_RE matches ANY triple-quoted string; callers limit it to the
# first occurrence with count=1.  _COMMENT_RE is purely lexical, so it also
# clips a "#" that sits inside a string literal.
_TRIPLE_QUOTE_RE = re.compile(r'"""[\s\S]*?"""|\'\'\'[\s\S]*?\'\'\'')
_COMMENT_RE = re.compile(r'#[^\n]*')


def _docstring_and_comment_spans(code: str) -> list[tuple[int, int]]:
    """Return sorted (start, end) character spans of the docstring and comments.

    The docstring is the string-literal statement (any quote style or prefix,
    including implicitly concatenated literals) that opens the body of the
    first ``def``/``class`` in *code*.  Comments are genuine COMMENT tokens,
    so a ``#`` inside a string literal is not treated as a comment.

    Raises:
        tokenize.TokenError / SyntaxError: if *code* cannot be tokenized.
        ValueError: if a token's reported position does not match the source
            text (safety net against tokenizer position quirks).
    """
    # Absolute offset of each physical line's first character, split exactly
    # as tokenize reads them (StringIO.readline breaks on "\n" only).
    line_starts = [0]
    for line in io.StringIO(code).readlines():
        line_starts.append(line_starts[-1] + len(line))

    def span(tok: tokenize.TokenInfo) -> tuple[int, int]:
        start = line_starts[tok.start[0] - 1] + tok.start[1]
        end = line_starts[tok.end[0] - 1] + tok.end[1]
        if code[start:end] != tok.string:
            raise ValueError("token position does not match source text")
        return start, end

    spans: list[tuple[int, int]] = []
    # Docstring state machine:
    #   "find_def" -> waiting for the first `def` / `class` keyword
    #   "header"   -> inside that header, waiting for its logical NEWLINE
    #   "indent"   -> header done; a docstring is only possible after INDENT
    #   "body"     -> collecting leading string literal(s) of the body
    #   "done"     -> docstring handled (removed or known to be absent)
    state = "find_def"
    pending: list[tuple[int, int]] = []
    for tok in tokenize.generate_tokens(io.StringIO(code).readline):
        if tok.type == tokenize.COMMENT:
            spans.append(span(tok))
            continue
        if tok.type == tokenize.NL or state == "done":
            continue
        if state == "find_def":
            if tok.type == tokenize.NAME and tok.string in ("def", "class"):
                state = "header"
        elif state == "header":
            if tok.type == tokenize.NEWLINE:
                state = "indent"
        elif state == "indent":
            # A one-line `def f(): ...` has no INDENT and thus no docstring here.
            state = "body" if tok.type == tokenize.INDENT else "done"
        elif state == "body":
            if tok.type == tokenize.STRING:
                pending.append(span(tok))
            else:
                # Only a statement made purely of string literals is a
                # docstring; `"abc".upper()` or `q = """..."""` is not.
                if tok.type == tokenize.NEWLINE and pending:
                    spans.extend(pending)
                state = "done"
    return sorted(spans)


def _strip_code_docstring(code: str) -> str:
    """Remove the leading docstring and all # comments from Python code.

    The snippet is tokenized so that only real comments are dropped (a ``#``
    inside a literal such as ``"#fff"`` survives), and only a real docstring
    is removed — in any quote style, with any string prefix — rather than
    whichever triple-quoted literal happens to come first.  Whitespace around
    removed text is kept, so the line structure is preserved.

    If *code* cannot be tokenized (e.g. truncated source), fall back to the
    lexical regexes: drop the first triple-quoted string and everything from
    ``#`` to end of line.
    """
    try:
        spans = _docstring_and_comment_spans(code)
    except (tokenize.TokenError, SyntaxError, ValueError):
        # SyntaxError also covers IndentationError (inconsistent dedent).
        code = _TRIPLE_QUOTE_RE.sub('', code, count=1)
        return _COMMENT_RE.sub('', code)

    pieces: list[str] = []
    cursor = 0
    for start, end in spans:
        pieces.append(code[cursor:start])
        cursor = end
    pieces.append(code[cursor:])
    return ''.join(pieces)


def evaluate_cross_modal(
    embedder,
    pairs: pd.DataFrame,
    k_values: tuple[int, ...] = (1, 5, 10),
    batch_size: int = 64,
    strip_code_docstrings: bool = False,
) -> dict:
    """Cross-modal retrieval eval: docstring queries → code candidates.

    Row i's docstring is the query and row i's code is its only positive; the
    other N-1 code snippets in *pairs* act as distractors.

    Args:
        embedder:               Object with a SentenceTransformer-style
                                ``.encode(texts, batch_size=..., show_progress_bar=...,
                                convert_to_numpy=..., normalize_embeddings=...)``
                                returning a 2-D array with one row per text.
        pairs:                  DataFrame with 'docstring' and 'code' columns (N rows).
        k_values:               Recall cut-offs to report.
        batch_size:             Encoding batch size.
        strip_code_docstrings:  If True, remove the leading docstring and #
                                comments from candidate code before embedding.
                                Use this for a leakage-free signal; see module
                                docstring for why the raw number is inflated.

    Returns:
        dict with keys ``mrr``, ``recall@k`` (for each k), ``n_pairs`` and
        ``stripped``; metric values are rounded to 4 decimals.

    Raises:
        ValueError: if *pairs* is empty (no metric is defined).

    Ranking: the positive's rank is the number of candidates whose score is
    not strictly lower than the positive's score (itself included).  Without
    ties this is the usual 1-indexed rank; with ties it is pessimistic, so a
    collapsed embedder that scores every candidate equally gets rank N rather
    than an arbitrary, possibly perfect, rank.
    """
    n = len(pairs)
    if n == 0:
        raise ValueError("pairs is empty; retrieval metrics are undefined")

    candidates = pairs["code"].tolist()
    if strip_code_docstrings:
        candidates = [_strip_code_docstring(c) for c in candidates]

    print(f"[eval] encoding {n} docstrings as queries ...")
    q_emb = embedder.encode(
        pairs["docstring"].tolist(),
        batch_size=batch_size, show_progress_bar=True,
        convert_to_numpy=True, normalize_embeddings=True,
    ).astype("float32")

    print(f"[eval] encoding {n} code candidates"
          f"{' (docstrings stripped)' if strip_code_docstrings else ''} ...")
    c_emb = embedder.encode(
        candidates,
        batch_size=batch_size, show_progress_bar=True,
        convert_to_numpy=True, normalize_embeddings=True,
    ).astype("float32")

    # Cosine similarity matrix (N × N); both sides are requested L2-normalised,
    # so inner product == cosine similarity.
    sim = q_emb @ c_emb.T  # shape (N, N)
    positive = np.diag(sim)  # score of each query's paired code

    # ~(sim < positive) rather than (sim >= positive): a NaN score then counts
    # against the positive instead of silently vanishing.  The diagonal always
    # counts (x < x is False), so every rank is >= 1.
    ranks = np.count_nonzero(~(sim < positive[:, None]), axis=1)

    result: dict = {
        "mrr": round(float(np.mean(1.0 / ranks)), 4),
        "n_pairs": n,
        "stripped": strip_code_docstrings,
    }
    for k in k_values:
        result[f"recall@{k}"] = round(int(np.count_nonzero(ranks <= k)) / n, 4)
    return result