File size: 10,110 Bytes
8c51ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51bfe3e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c51ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51bfe3e
8c51ce7
 
 
 
 
6cf7e7c
 
 
 
 
 
8c51ce7
 
51bfe3e
 
 
 
 
 
8c51ce7
 
 
 
 
 
51bfe3e
8c51ce7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
#!/usr/bin/env python3
"""
MiniF2F benchmark for the LangGraph Lean 4 proof agent.

Metrics reported
----------------
pass@k   : fraction of problems solved within k LLM attempts
           (computed for k = 1, 2, ..., max_retries)
avg_attempts_to_solve : mean attempts used on problems that were solved
avg_time_s            : mean wall-clock seconds per problem

Example
-------
# Quick smoke-test (10 problems, gemma3:12b, 3 retries)
python scripts/benchmark.py --subset 10 --model gemma3:12b --retries 3

# Full valid split with Claude (244 problems, 5 retries)
python scripts/benchmark.py --split valid --model claude-3-5-haiku-20241022 --retries 5

# Ablation: no RAG
python scripts/benchmark.py --subset 50 --no-rag --model gemma3:12b
"""

import argparse
import csv
import os
import sys
import tempfile
import time
from pathlib import Path

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))

from langgraph_agent import LangGraphAgent

# ---------------------------------------------------------------------------
# MiniF2F loading
# ---------------------------------------------------------------------------

# Candidate HuggingFace sources, tried in order by load_minif2f. Each entry is
# (dataset id, name of the row field that is checked first for the Lean
# statement).
_DATASET_CANDIDATES = [
    ("cat-searcher/minif2f-lean4",       "formal_statement"),
]

# HuggingFace split name aliases (MiniF2F uses "validation" not "valid")
_SPLIT_ALIASES = {"valid": "validation", "val": "validation"}


def _ensure_import_and_sorry(code: str) -> str:
    if "import Mathlib" not in code:
        code = "import Mathlib\n\n" + code
    # If proof body is missing or is just whitespace after :=, add sorry
    if ":= by" in code and "sorry" not in code:
        code = code.rstrip() + "\n  sorry\n"
    elif ":=" in code and "sorry" not in code and "by" not in code:
        code = code.rstrip() + " by\n  sorry\n"
    return code


def load_minif2f(split: str = "valid", max_problems: int | None = None):
    """Load MiniF2F problems from HuggingFace as ``{name, lean_code}`` dicts.

    Every entry of ``_DATASET_CANDIDATES`` is tried in order; the first one
    that loads and normalises without raising is returned.

    Args:
        split: Split name; translated through ``_SPLIT_ALIASES`` before being
            passed to ``load_dataset`` (e.g. ``"valid"`` -> ``"validation"``).
        max_problems: Keep only the first N normalised problems. A falsy value
            (``None`` or ``0``) keeps all of them.

    Returns:
        A list of dicts with the keys ``"name"`` and ``"lean_code"``. Rows in
        which none of the probed statement fields is present and non-empty are
        dropped.

    Raises:
        RuntimeError: If no candidate dataset could be loaded. The last
            underlying exception is attached as ``__cause__``.
    """
    from datasets import load_dataset

    hf_split = _SPLIT_ALIASES.get(split, split)
    last_error: Exception | None = None

    for dataset_name, stmt_field in _DATASET_CANDIDATES:
        # Probe the candidate's own field first, then common fallbacks.
        # dict.fromkeys drops duplicates while keeping the priority order.
        field_order = list(dict.fromkeys(
            [stmt_field, "lean_code", "statement", "code", "formal_statement"]
        ))
        try:
            ds = load_dataset(dataset_name, split=hf_split)
            print(f"Loaded '{dataset_name}' ({split} split): {len(ds)} problems")

            # Normalise to list[dict] with keys: name, lean_code
            rows = []
            for i, row in enumerate(ds):
                name = row.get("name") or row.get("id") or row.get("problem_name") or f"problem_{i}"
                code = None
                for f in field_order:
                    if f in row and row[f]:
                        code = _ensure_import_and_sorry(row[f])
                        break
                if code is None:
                    continue
                rows.append({"name": name, "lean_code": code})

            if max_problems:
                rows = rows[:max_problems]
            print(f"Using {len(rows)} problems after filtering.")
            return rows

        except Exception as e:
            # Best effort: report the failure and fall through to the next
            # candidate, remembering the error so it can be chained below.
            last_error = e
            print(f"  Could not load '{dataset_name}': {e}")

    raise RuntimeError(
        "Could not load MiniF2F from any known HuggingFace source.\n"
        "Try: pip install datasets  and check your internet connection."
    ) from last_error


def load_local_problems(problems_dir: str, max_problems: int | None = None):
    """Load `.lean` files from a directory as a list of {name, lean_code} dicts."""
    root = Path(problems_dir)
    if not root.is_dir():
        raise RuntimeError(f"Problems directory not found: {problems_dir}")

    files = sorted(root.glob("*.lean"))
    if max_problems:
        files = files[:max_problems]

    rows = []
    for path in files:
        code = path.read_text(encoding="utf-8")
        if "sorry" not in code:
            # Skip files that are already complete proofs.
            continue
        rows.append({"name": path.stem, "lean_code": code})

    print(f"Loaded {len(rows)} local problem(s) with sorry placeholders.")
    return rows


# ---------------------------------------------------------------------------
# pass@k estimator
# ---------------------------------------------------------------------------

def pass_at_k(results: list[dict], k: int) -> float:
    """Return the share of problems that were solved in at most *k* attempts.

    A record counts when its ``"success"`` value is truthy and its
    ``"solved_at_attempt"`` is ``<= k``. The denominator is the total number
    of records; an empty list yields ``0.0``.
    """
    total = len(results)
    if total == 0:
        return 0.0

    hits = 0
    for record in results:
        # Check success first: the attempt number is only compared for
        # records that were actually solved.
        if not record["success"]:
            continue
        if record["solved_at_attempt"] <= k:
            hits += 1
    return hits / total


# ---------------------------------------------------------------------------
# Single-problem runner
# ---------------------------------------------------------------------------

def run_one(agent: LangGraphAgent, name: str, lean_code: str, verbose: bool) -> dict:
    """Run the agent on one problem and return its benchmark record.

    The Lean source is written to a temporary ``.lean`` file, the file path is
    passed to ``agent.solve_file_detailed``, and the file is removed again
    afterwards, whether the call returned or raised.

    Args:
        agent: Agent whose ``solve_file_detailed(path)`` result is a mapping
            with the keys ``"success"``, ``"solved_at_attempt"`` and
            ``"total_attempts"``.
        name: Problem identifier; used in the record, the temp-file prefix
            and the verbose output.
        lean_code: Lean source text written to the temporary file.
        verbose: When true, print a one-line PASS/FAIL summary.

    Returns:
        A dict with the keys ``name``, ``success``, ``solved_at_attempt``,
        ``total_attempts`` and ``time_s`` (seconds, rounded to 2 decimals).
    """
    # Restrict the prefix to filename-safe characters so that a problem name
    # containing e.g. a path separator cannot make temp-file creation fail.
    safe_stem = "".join(
        ch if ch.isalnum() or ch in "._-" else "_" for ch in str(name)[:20]
    )
    # Lean sources routinely contain non-ASCII symbols; write them as UTF-8
    # explicitly instead of relying on the platform's locale encoding.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=".lean",
        prefix=f"bench_{safe_stem}_",
        delete=False,
        encoding="utf-8",
    ) as f:
        f.write(lean_code)
        tmp = f.name

    try:
        # perf_counter is monotonic, so system clock adjustments cannot skew
        # the measured duration.
        t0 = time.perf_counter()
        detail = agent.solve_file_detailed(tmp)
        elapsed = round(time.perf_counter() - t0, 2)
    finally:
        # Always remove the temporary file; a failure to delete it is ignored
        # because it must not mask the solve result or exception.
        try:
            os.unlink(tmp)
        except OSError:
            pass

    result = {
        "name": name,
        "success": detail["success"],
        "solved_at_attempt": detail["solved_at_attempt"],
        "total_attempts": detail["total_attempts"],
        "time_s": elapsed,
    }

    if verbose:
        status = "PASS" if result["success"] else "FAIL"
        # A falsy solved_at_attempt is rendered as "-".
        print(
            f"  [{status}] {name:<50} "
            f"attempt={result['solved_at_attempt'] or '-':>2}  "
            f"time={elapsed:>6.1f}s"
        )
    return result


# ---------------------------------------------------------------------------
# Summary
# ---------------------------------------------------------------------------

def print_summary(results: list[dict], max_retries: int, model: str, no_rag: bool):
    """Print the final benchmark report to stdout.

    The report lists the run configuration, pass@k for every
    ``k`` in ``1..max_retries``, the mean ``solved_at_attempt`` over solved
    problems (only when at least one problem was solved) and the mean
    ``time_s`` over all problems (only when there is at least one result).

    Args:
        results: Per-problem records with the keys ``success``,
            ``solved_at_attempt`` and ``time_s``.
        max_retries: Largest k for which pass@k is reported.
        model: Model identifier shown in the header.
        no_rag: Whether RAG was disabled for the run.
    """
    n = len(results)
    solved = [r for r in results if r["success"]]

    print("\n" + "=" * 60)
    print("BENCHMARK SUMMARY")
    print("=" * 60)
    print(f"  Model       : {model}")
    print(f"  RAG         : {'disabled' if no_rag else 'enabled'}")
    print(f"  Problems    : {n}")
    print(f"  Max retries : {max_retries}")
    print()
    print(f"  {'Metric':<25} {'Value':>10}")
    print(f"  {'-'*25} {'-'*10}")
    for k in range(1, max_retries + 1):
        pct = pass_at_k(results, k) * 100
        print(f"  {'pass@' + str(k):<25} {pct:>9.1f}%")
    print()
    if solved:
        avg_att = sum(r["solved_at_attempt"] for r in solved) / len(solved)
        print(f"  {'avg attempts (solved)':<25} {avg_att:>10.2f}")
    if n:
        # Time is averaged over every problem, so it is reported even when
        # nothing was solved.
        avg_t = sum(r["time_s"] for r in results) / n
        print(f"  {'avg time/problem (s)':<25} {avg_t:>10.1f}")
    print("=" * 60)


def write_csv(results: list[dict], path: str):
    """Write the per-problem records to *path* as CSV and announce the file.

    The file gets a header row followed by one row per record, with the
    columns name, success, solved_at_attempt, total_attempts and time_s.
    """
    columns = ["name", "success", "solved_at_attempt", "total_attempts", "time_s"]
    # newline="" leaves line-ending handling to the csv module.
    with open(path, "w", newline="") as handle:
        writer = csv.DictWriter(handle, fieldnames=columns)
        writer.writeheader()
        for record in results:
            writer.writerow(record)
    print(f"\nResults written to: {path}")


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------

def _build_cli_parser() -> argparse.ArgumentParser:
    """Assemble the argument parser that defines the benchmark's CLI flags."""
    option_specs = [
        ("--split", {"default": "valid",
                     "help": "Dataset split: valid (=validation) | test"}),
        ("--subset", {"type": int, "default": None,
                      "help": "Use only first N problems"}),
        ("--model", {"default": "llama-3.3-70b-versatile",
                     "help": "Groq / Claude model ID"}),
        ("--retries", {"type": int, "default": 5,
                       "help": "Max LLM attempts per problem"}),
        ("--no-rag", {"action": "store_true",
                      "help": "Disable RAG retrieval (ablation)"}),
        ("--index-dir", {"default": None,
                         "help": "Path to pre-built FAISS index"}),
        ("--output", {"default": "benchmark_results.csv",
                      "help": "CSV output path"}),
        ("--verbose", {"action": "store_true",
                       "help": "Print per-problem results"}),
        ("--api-key", {"default": None,
                       "help": "API key for the chosen provider (Anthropic for Claude models). "
                               "Falls back to ANTHROPIC_API_KEY / GROQ_API_KEY env."}),
        ("--problems-dir", {"default": None,
                            "help": "Use local .lean files in this directory instead of MiniF2F. "
                                    "Each file is one problem."}),
    ]

    cli = argparse.ArgumentParser(
        description="Run the Lean proof agent on MiniF2F and report pass@k metrics."
    )
    # Registration order is kept because it determines the order in --help.
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli


def main():
    """Parse the CLI, run the agent over the chosen problems and report.

    Problems come from ``--problems-dir`` when given, otherwise from MiniF2F.
    Each problem is run through ``run_one``; a rolling pass@1 / pass@retries
    line is printed after every tenth problem. At the end the summary is
    printed and the per-problem records are written to the ``--output`` CSV.
    """
    opts = _build_cli_parser().parse_args()

    if not opts.problems_dir:
        print(f"Loading MiniF2F ({opts.split} split)…")
        problems = load_minif2f(split=opts.split, max_problems=opts.subset)
    else:
        print(f"Loading local problems from {opts.problems_dir}…")
        problems = load_local_problems(opts.problems_dir, max_problems=opts.subset)

    print(f"Initialising agent (model={opts.model}, retries={opts.retries})…")
    agent = LangGraphAgent(
        model_name=opts.model,
        max_retries=opts.retries,
        index_dir=opts.index_dir,
        api_key=opts.api_key,
    )

    if opts.no_rag:
        # Ablation: replace the retriever's lookup with one that finds nothing.
        agent._retriever.retrieve = lambda query: []

    total = len(problems)
    records = []
    print(f"\nRunning {total} problems…\n")
    for done, problem in enumerate(problems, start=1):
        print(f"[{done:>3}/{total}] {problem['name'][:60]}")
        records.append(
            run_one(agent, problem["name"], problem["lean_code"], verbose=opts.verbose)
        )

        # Rolling summary only after every tenth problem.
        if done % 10 != 0:
            continue
        first_try = pass_at_k(records, 1) * 100
        within_budget = pass_at_k(records, opts.retries) * 100
        print(f"  → Rolling pass@1={first_try:.1f}%  pass@{opts.retries}={within_budget:.1f}%  ({done}/{total} done)\n")

    print_summary(records, opts.retries, opts.model, opts.no_rag)
    write_csv(records, opts.output)


# Run the benchmark CLI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()