hanwenzhu committed on
Commit
2fc99ca
·
verified ·
1 Parent(s): efaf8f9

Upload tactic_benchmark.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tactic_benchmark.py +171 -0
tactic_benchmark.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, Optional, Literal
2
+ import json
3
+ import subprocess
4
+ import os
5
+ import shutil
6
+ import re
7
+ import concurrent.futures
8
+ import uuid
9
+
10
+ import tqdm
11
+ import argparse
12
+
# Scratch space used for auto-generated temporary premises directories.
SCRATCH_DIR = "/data/user_data/thomaszh/tmp-premises"


def _make_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser for the tactic benchmark runner."""
    p = argparse.ArgumentParser(description="Run tactic benchmark with premises.")
    p.add_argument("--ntp_toolkit_path", type=str, default="/home/thomaszh/ntp-toolkit", help="Path to the ntp-toolkit repository containing a tactic_benchmark script.")
    p.add_argument("--decl_names_file", type=str, default="eval_decls_nov6_v412.json", help="File containing declaration names for benchmark.")
    p.add_argument("--premises_file", type=str, default=None, help="File containing retrieved premises (default: use ground truth premises).")
    p.add_argument("--out_dir", type=str, default="results_retrieved_eval_premises_nov8", help="Output directory for results.")
    p.add_argument("--timeout", type=int, default=300, help="Timeout for each benchmark run in seconds.")
    p.add_argument("--benchmark_type", type=str, default="simp_all_with_premises", help="Type of benchmark to run.")
    p.add_argument("--k", type=int, default=8, help="Number of top premises to use.")
    p.add_argument("--max_workers", type=int, default=8, help="Number of workers for running the benchmark.")
    p.add_argument("--rerank", action="store_true", help="Enable reranking of premises.")
    p.add_argument("--pred_simp_all_hint", action="store_true", help="Enable prediction of simp_all_hint (replacing with notInSimpAll).")
    p.add_argument("--temp_premises_dir", type=str, default=None, help="A temporary directory that simulates Examples/Mathlib/TrainingDataWithPremises, if premises_file is given")
    p.add_argument("--tag_suffix", type=str, default=None, help="Suffix to the output json file name")
    return p


parser = _make_arg_parser()
args = parser.parse_args()
30
+
# Unpack parsed CLI options into module-level names used by the rest of the script.
decl_names_for_benchmark_file: str = args.decl_names_file
# None means "use ground-truth premises" (see the branch below); the original
# `str` annotation was wrong since the argparse default is None.
premises_file: Optional[str] = args.premises_file
out_dir: str = args.out_dir
timeout: int = args.timeout
k: int = args.k
benchmark_type: str = args.benchmark_type
max_workers: int = args.max_workers
ntp_toolkit_path: str = args.ntp_toolkit_path
rerank: bool = args.rerank
pred_simp_all_hint: bool = args.pred_simp_all_hint
# Fall back to a unique scratch directory when no explicit temp dir is given.
temp_premises_dir: str = args.temp_premises_dir or os.path.join(SCRATCH_DIR, f"premises-{uuid.uuid4()}")
os.makedirs(out_dir, exist_ok=True)

# The output file name encodes the benchmark configuration, e.g.
# "<type>-gt.json" (ground truth) or "<type>-k8-rr-psah-<suffix>.json".
name_parts = [benchmark_type]
if premises_file is None:
    name_parts.append("gt")
elif "hammer" in benchmark_type or "premise" in benchmark_type:
    name_parts.append(f"k{k}")
if rerank:
    name_parts.append("rr")
if pred_simp_all_hint:
    name_parts.append("psah")
if args.tag_suffix is not None:
    name_parts.append(args.tag_suffix)
out_file = os.path.join(out_dir, "-".join(name_parts) + ".json")
# Load the declarations to benchmark and seed the results mapping
# (decl_name -> per-declaration info, filled in below).
with open(decl_names_for_benchmark_file) as decls_f:
    decl_names_for_benchmark = json.load(decls_f)

results = {entry["decl_name"]: {} for entry in decl_names_for_benchmark}
# Build `results` mapping declaration name to premises and hints
if premises_file is not None:
    with open(premises_file) as f:
        premises_raw = json.load(f)
    if isinstance(premises_raw, dict):
        # Dict-shaped files are keyed by similarity type; use the "dot" entries.
        premises_raw = premises_raw["dot"]
    # (before nov 20) for each decl, there are multiple states corresponding to the decl (now only one)
    # we assume the first state encountered in the file is the "root" initial state
    for ps_entry in premises_raw:
        decl_name = ps_entry["decl_name"]
        if decl_name not in results:
            # The premises file may cover declarations outside the benchmark
            # set; skip them instead of raising KeyError.
            continue
        if "premises" in results[decl_name]:
            # Only the first (root) state per declaration is used.
            continue
        premises = ps_entry["premises"]
        # take names of top k premises, ranked by rerank score if --rerank else retrieval score
        rank_key = "rerank_score" if rerank else "score"
        topk_premises = sorted(premises, key=lambda p: p[rank_key], reverse=True)[:k]
        results[decl_name]["premises"] = [p["corpus_id"] for p in topk_premises]
        results[decl_name]["hints"] = [p.get("simp_all_hint", "notInSimpAll") for p in topk_premises]
else:
    # Use ground truth premises
    for entry in decl_names_for_benchmark:
        decl_name = entry["decl_name"]
        premises = results[decl_name]["premises"] = entry["gt_premises"]
        results[decl_name]["hints"] = [entry["gt_hints"][p] for p in premises]
# Warn about declarations with no retrieved premises and drop them from the run.
# A set (not a list) keeps the final membership filter O(1) per declaration.
not_found_decl_names = set()
for decl_name in results:
    if "premises" not in results[decl_name]:
        print(f"warning: premises for {decl_name} not found")
        not_found_decl_names.add(decl_name)
decl_names_for_benchmark = [e for e in decl_names_for_benchmark if e["decl_name"] not in not_found_decl_names]
# build Examples/Mathlib/TrainingDataWithPremises-like directory but with retrieved premises
shutil.rmtree(temp_premises_dir, ignore_errors=True)
os.makedirs(temp_premises_dir, exist_ok=True)
for entry in decl_names_for_benchmark:
    decl_name = entry["decl_name"]
    module = entry["module"]
    info = results[decl_name]
    serialized_premises = []
    for premise, hint in zip(info["premises"], info["hints"]):
        # Unless --pred_simp_all_hint, every premise gets the neutral hint.
        effective_hint = hint if pred_simp_all_hint else "notInSimpAll"
        serialized_premises.append(f"({premise}, {effective_hint})")
    record = {"declName": decl_name, "declHammerRecommendation": serialized_premises}  # NOTE: upstream might change name for simp_all
    # Append: several declarations may belong to the same module file.
    with open(os.path.join(temp_premises_dir, f"{module}.jsonl"), "a") as out_f:
        out_f.write(json.dumps(record) + "\n")
def run_benchmark(entry: dict, print_emoji: bool = False) -> dict[str, Optional[str]]:
    """Run `lake exe tactic_benchmark` for a single declaration.

    Args:
        entry: Benchmark entry with at least "decl_name" and "module" keys.
        print_emoji: If True, print the result emoji when the run finishes.

    Returns:
        A dict with keys "decl_name", "module", "command", "result_emoji"
        and "result_output". "result_emoji" is None when no emoji summary
        line was found in the output (hence Optional[str], not str — the
        original annotation was too narrow); on timeout it is "⏰".
    """
    decl_name = entry["decl_name"]
    module = entry["module"]
    result_data = {"decl_name": decl_name, "module": module}
    command: list[str] = [
        "lake", "exe", "tactic_benchmark",
        module, decl_name,
        os.path.abspath(temp_premises_dir),
        benchmark_type,
        "--k", f"{k}"
    ]
    result_data["command"] = " ".join(command)
    try:
        result = subprocess.run(
            command,
            cwd=ntp_toolkit_path,
            check=False,
            text=True,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # interleave stderr so diagnostics stay with output
            timeout=timeout
        )
        # Drop Lake/Lean build-noise lines so only the benchmark's own output remains.
        result_output = "\n".join(
            line for line in result.stdout.splitlines()
            if not any(line.startswith(prefix) for prefix in ["note:", "warning:", "error:", "⚠", "✖", "✔"])
        )
        # The benchmark prints a line starting with emoji(s) summarizing the outcome.
        match = re.search(r"^([❌️💥️✅️]+) ", result.stdout, flags=re.MULTILINE)
        result_emoji = match.group(1) if match else None
    except subprocess.TimeoutExpired as e:
        # subprocess.run kills the child on timeout; record the timeout marker.
        result_emoji = "⏰"
        result_output = str(e)

    result_data["result_emoji"] = result_emoji
    result_data["result_output"] = result_output

    if print_emoji:
        print(result_emoji)
    return result_data
# Build the benchmark executable once up front so parallel workers do not
# race on a `lake build`; build output is discarded.
subprocess.run(
    ["lake", "build", "tactic_benchmark"],
    cwd=ntp_toolkit_path,
    check=True,
    stdout=subprocess.DEVNULL,
    stderr=subprocess.DEVNULL,
)

# Fan out one benchmark run per declaration; merge each outcome back into
# `results` as it completes, showing a progress bar.
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as pool:
    pending = [pool.submit(run_benchmark, entry) for entry in decl_names_for_benchmark]
    for done in tqdm.tqdm(concurrent.futures.as_completed(pending), total=len(pending)):
        outcome = done.result()
        results[outcome["decl_name"]].update(outcome)
# Persist the full results mapping as pretty-printed JSON.
with open(out_file, "w") as results_f:
    json.dump(results, results_f, indent=4)
print(f"Results saved to {out_file}")
# Sometimes timeout tactics leave zombie threads (TODO)
subprocess.run(["killall", "tactic_benchmark"], check=False)
# Remove the temporary premises directory built earlier.
shutil.rmtree(temp_premises_dir, ignore_errors=True)