mikeljl's picture
Switch to real putnam_small benchmark + statement-match guard
f946490
import re
# Lean operators that `_lex` glues back together after its per-character
# split has spread them out with spaces (e.g. ": =" becomes ":=" again).
LEAN_OPERATORS = [
    ":=",
    "!=",
    "&&",
    "-.",
    "->",
    "←",
    "..",
    "...",
    "::",
    ":>",
    "<;>",
    ";;",
    "==",
    "||",
    "=>",
    "<=",
    ">=",
    "⁻¹",
    "?_",
]
# Declaration modifiers that may precede `theorem` / `lemma`; joined into a
# regex alternation when locating the declaration header.
MODIFIERS = (
    "private",
    "protected",
    "noncomputable",
    "nonrec",
    "unsafe",
    "partial",
    "scoped",
    "local",
)
def _parse_single_attribute(text: str, start: int) -> int | None:
"""Return the index just after the matching ']' for `@[...]` starting at `start`."""
n = len(text)
assert text[start] == "@" and start + 1 < n and text[start + 1] == "["
i = start + 2
depth = 1
while i < n:
c = text[i]
if c == "[":
depth += 1
elif c == "]":
depth -= 1
if depth == 0:
return i + 1
i += 1
return None
def extract_and_remove_attributes(text: str):
    """Split a leading run of `@[...]` attributes off the front of `text`.

    Returns a pair `(block, rest)`. `block` covers any leading whitespace,
    every consecutive attribute, and the whitespace that follows the last
    one; `rest` is the remainder. When `text` does not start (after
    whitespace) with `@[`, or any attribute in the run is unterminated, the
    result is `("", text)`.
    """
    size = len(text)
    if size == 0:
        return "", text

    def skip_whitespace(index: int) -> int:
        # Advance past any run of whitespace starting at `index`.
        while index < size and text[index].isspace():
            index += 1
        return index

    def at_attribute(index: int) -> bool:
        # True when an "@[" opener sits exactly at `index`.
        return index + 1 < size and text[index] == "@" and text[index + 1] == "["

    cursor = skip_whitespace(0)
    if not at_attribute(cursor):
        return "", text

    while at_attribute(cursor):
        closing = _parse_single_attribute(text, cursor)
        if closing is None:
            # An unterminated attribute voids the whole block.
            return "", text
        cursor = skip_whitespace(closing)

    return text[:cursor], text[cursor:]
def remove_comments(text: str) -> str:
    """Strip Lean comments from `text` and drop lines that end up blank.

    Block comments (`/- ... -/`, including `/-- ... -/` doc comments) are
    removed first, with nesting honoured: `/- a /- b -/ c -/` is a single
    comment. An opener that is never closed is left in place verbatim.
    Afterwards everything from the first `--` on each line is discarded,
    blank lines are dropped, and the result is stripped.

    Limitations: `--` and `/-` are not recognised as being inside string
    literals, and because block comments are handled first, a `/-` that
    appears inside a `--` line comment still opens a block comment.
    """
    size = len(text)
    kept_parts = []
    depth = 0        # current block-comment nesting level
    code_start = 0   # start of the code run currently being collected
    opener_at = 0    # position of the outermost unmatched "/-"
    i = 0
    while i < size:
        pair = text[i:i + 2]
        if pair == "/-":
            if depth == 0:
                kept_parts.append(text[code_start:i])
                opener_at = i
            depth += 1
            i += 2
        elif pair == "-/" and depth > 0:
            depth -= 1
            i += 2
            if depth == 0:
                code_start = i
        else:
            i += 1
    # An unterminated comment is kept as-is rather than swallowing the rest.
    kept_parts.append(text[opener_at:] if depth > 0 else text[code_start:])
    without_blocks = "".join(kept_parts)

    cleaned_lines = []
    for line in without_blocks.split("\n"):
        code = line.split("--", 1)[0]
        if code.strip():
            cleaned_lines.append(code)
    return "\n".join(cleaned_lines).strip()
def return_theorem_to_prove_mathlib_style(text: str):
    """Return the `(start, end)` span of a theorem/lemma header in `text`.

    The span begins where the header regex matches (optional whitespace and
    modifiers, then `theorem` or `lemma`) and ends just past the first `:=`
    that is not nested inside `()`, `[]` or `{}`. Returns None when no such
    `:=` follows the header.

    NOTE: the `|` fallback below only runs when the primary regex finds no
    match at all, i.e. when every `theorem`/`lemma` occurrence is directly
    followed by a word character. A header with a proper word boundary but
    no top-level `:=` yields None rather than reaching the fallback.
    """
    modifier_alternation = "|".join(MODIFIERS)
    header_regex = (
        r"\s*"
        r"(?:(?:" + modifier_alternation + r")\s+)*"
        r"\s*"
        r"(?:theorem|lemma)\b"
    )
    header = re.search(header_regex, text, re.DOTALL)

    if header is None:
        # Fallback: keyword without a word boundary, lazily followed by "|".
        loose_header = (
            r"\s*"
            r"(?:(?:" + modifier_alternation + r")\s+)*"
            r"\s*"
            r"(?:theorem|lemma)"
            r".*?"
        )
        bar_match = re.search(r"(" + loose_header + r"\s*\|)", text, re.DOTALL)
        return bar_match.span() if bar_match else None

    closer_to_opener = {")": "(", "]": "[", "}": "{"}
    openers = frozenset(closer_to_opener.values())
    pending = []  # currently open brackets, innermost last

    for offset in range(header.end(), len(text)):
        ch = text[offset]
        # Only a `:=` outside every bracket ends the signature.
        if not pending and text.startswith(":=", offset):
            return (header.start(), offset + 2)
        if ch in openers:
            pending.append(ch)
        elif ch in closer_to_opener and pending and pending[-1] == closer_to_opener[ch]:
            # Mismatched closers are ignored rather than treated as errors.
            pending.pop()
    return None
def _lex(snippet: str) -> str:
    """Return `snippet` with each line's tokens separated by single spaces.

    A token is either a maximal run of alphanumerics and `.`, `_`, `'`, or
    any other single non-space character. Space characters only delimit
    tokens. After joining, every operator in `LEAN_OPERATORS` that the
    per-character split pulled apart (e.g. `: =`) is collapsed back to its
    original spelling, in list order.
    """
    collapse = {" ".join(op): op for op in LEAN_OPERATORS}

    def split_line(line: str) -> list:
        # Break one line into word tokens and single-symbol tokens.
        pieces = []
        word = []
        for ch in line:
            if ch.isalnum() or ch in "._'":
                word.append(ch)
                continue
            if word:
                pieces.append("".join(word))
                word = []
            if ch != " ":
                pieces.append(ch)
        if word:
            pieces.append("".join(word))
        return pieces

    rendered = []
    for raw_line in snippet.splitlines():
        joined = " ".join(split_line(raw_line))
        for spaced_op, op in collapse.items():
            joined = joined.replace(spaced_op, op)
        rendered.append(joined)
    return "\n".join(rendered)
def extract_signature(statement_and_proof: str) -> str | None:
    """Return the theorem signature with any trailing `:=` removed.

    Comments and a leading attribute block are stripped before the
    `theorem`/`lemma` header span is located. The result is
    whitespace-trimmed. Returns None when no header span is found or when
    any parsing step raises.
    """
    try:
        cleaned = remove_comments(statement_and_proof)
        body = extract_and_remove_attributes(cleaned)[1]
        bounds = return_theorem_to_prove_mathlib_style(body)
        if bounds is None:
            return None
        begin, end = bounds
        return body[begin:end].rstrip().removesuffix(":=").strip()
    except Exception:
        # Best-effort: a parse failure means "no signature", not an error.
        return None
def normalize_signature(sig: str) -> str:
    """Collapse every whitespace run in `sig` to one space and trim the ends.

    A falsy `sig` (such as None or "") normalizes to the empty string.
    """
    source = sig if sig else ""
    collapsed = re.sub(r"\s+", " ", source)
    return collapsed.strip()
def proof_length(statement_and_proof: str) -> int:
    """Count the tokens in the proof body (the text after the header span).

    Comments and a leading attribute block are stripped, the header span is
    located, and the remaining text is tokenized with `_lex`. Lines that lex
    to nothing contribute zero tokens, so the count does not depend on
    whether a newline directly follows `:=`.

    Returns 10**9 when no header span is found or any step raises, so a
    parse failure is clearly distinguishable from a legitimate short proof.
    """
    try:
        text = remove_comments(statement_and_proof)
        _, text = extract_and_remove_attributes(text)
        span = return_theorem_to_prove_mathlib_style(text)
        if span is None:
            return 10**9
        proof = text[span[1]:]
        tokenized = _lex(proof)
        # "".split(" ") yields [""], so empty pieces must be filtered out or
        # every blank line would be counted as one token.
        return sum(
            1
            for line in tokenized.splitlines()
            for token in line.split(" ")
            if token
        )
    except Exception:
        # Best-effort metric: report the sentinel instead of propagating.
        return 10**9