| import re |
|
|
|
|
# Multi-character Lean operators that `_lex` glues back together after its
# per-character split.  `_lex` applies the replacements one after another in
# list order, so keep the order stable.
LEAN_OPERATORS = [
    ":=",
    "!=",
    "&&",
    "-.",
    "->",
    "←",
    "..",
    "...",
    "::",
    ":>",
    "<;>",
    ";;",
    "==",
    "||",
    "=>",
    "<=",
    ">=",
    "⁻¹",
    "?_",
]
|
|
|
|
# Declaration modifiers allowed in front of `theorem` / `lemma`; they are
# joined into a regex alternation by `return_theorem_to_prove_mathlib_style`.
MODIFIERS = (
    "private",
    "protected",
    "noncomputable",
    "nonrec",
    "unsafe",
    "partial",
    "scoped",
    "local",
)
|
|
|
|
| def _parse_single_attribute(text: str, start: int) -> int | None: |
| """Return the index just after the matching ']' for `@[...]` starting at `start`.""" |
| n = len(text) |
| assert text[start] == "@" and start + 1 < n and text[start + 1] == "[" |
| i = start + 2 |
| depth = 1 |
| while i < n: |
| c = text[i] |
| if c == "[": |
| depth += 1 |
| elif c == "]": |
| depth -= 1 |
| if depth == 0: |
| return i + 1 |
| i += 1 |
| return None |
|
|
|
|
def extract_and_remove_attributes(text: str):
    """Split a leading run of `@[...]` attributes off the front of `text`.

    Returns a `(block, rest)` pair.  `block` covers any leading whitespace,
    every consecutive `@[...]` attribute, and the whitespace following each
    of them; `rest` is whatever comes after.  When `text` does not start
    (after optional whitespace) with `@[`, or an attribute is never closed,
    the result is `("", text)`.
    """
    total = len(text)

    def skip_blanks(index):
        # Advance past whitespace, stopping at the end of the text.
        while index < total and text[index].isspace():
            index += 1
        return index

    cursor = skip_blanks(0)
    if not text.startswith("@[", cursor):
        return "", text

    while text.startswith("@[", cursor):
        closing = _parse_single_attribute(text, cursor)
        if closing is None:
            # Unterminated attribute: give the input back untouched.
            return "", text
        cursor = skip_blanks(closing)

    return text[:cursor], text[cursor:]
|
|
|
|
| def remove_comments(text: str) -> str: |
| """Strip `/- ... -/` block and `--` line comments; drop empty lines.""" |
| text = re.sub(r"/-.*?-/", "", text, flags=re.DOTALL) |
| cleaned_lines = [] |
| for line in text.split("\n"): |
| cleaned = line.split("--", 1)[0] |
| if cleaned.strip() == "": |
| continue |
| cleaned_lines.append(cleaned) |
| return "\n".join(cleaned_lines).strip() |
|
|
|
|
def return_theorem_to_prove_mathlib_style(text: str):
    """Locate a Mathlib-style `theorem`/`lemma` signature inside `text`.

    The search looks for optional whitespace, any number of `MODIFIERS`
    words, and then the keyword `theorem` or `lemma` followed by a word
    boundary.  From the end of that keyword the text is scanned for the
    first `:=` that is not inside `()`, `[]` or `{}`.

    Returns a `(start, end)` pair where `start` is the beginning of the
    keyword match (including matched leading whitespace and modifiers) and
    `end` is the index just past `:=`.  Returns None if no such `:=` exists.

    When the keyword-with-boundary search finds nothing, a fallback regex
    (same prefix without the word boundary, lazily extended up to a `|`) is
    tried and its span is returned, or None if that fails as well.
    """
    modifier_alternation = "|".join(MODIFIERS)
    keyword_head = (
        r"\s*"
        r"(?:(?:" + modifier_alternation + r")\s+)*"
        r"\s*"
        r"(?:theorem|lemma)"
    )

    opening = re.search(keyword_head + r"\b", text, re.DOTALL)
    if opening is None:
        fallback = re.search(r"(" + keyword_head + r".*?" + r"\s*\|)", text, re.DOTALL)
        return fallback.span() if fallback else None

    closer_to_opener = {")": "(", "]": "[", "}": "{"}
    openers = set(closer_to_opener.values())
    pending: list[str] = []  # currently open brackets, innermost last

    for position in range(opening.end(), len(text)):
        symbol = text[position]
        # Only a `:=` outside every bracket terminates the signature.
        if not pending and text.startswith(":=", position):
            return (opening.start(), position + 2)
        if symbol in openers:
            pending.append(symbol)
        elif symbol in closer_to_opener:
            # A closer that does not match the innermost opener is ignored.
            if pending and pending[-1] == closer_to_opener[symbol]:
                pending.pop()

    return None
|
|
|
|
def _lex(snippet: str) -> str:
    """Split a Lean snippet into space-separated tokens, line by line.

    Within a line, a run of alphanumerics, `.`, `_` and `'` forms a single
    token, a space only separates tokens, and any other character is a token
    on its own.  Tokens are joined with single spaces, after which every
    operator in `LEAN_OPERATORS` that the per-character split pulled apart
    (e.g. `: =`) is collapsed back to its original spelling.  The processed
    lines are joined with newlines.
    """
    collapse = {" ".join(op): op for op in LEAN_OPERATORS}
    rendered = []
    for raw_line in snippet.splitlines():
        pieces: list[str] = []
        pending = ""
        for char in raw_line:
            if char.isalnum() or char in "._'":
                pending += char
                continue
            # Any other character ends the word being accumulated.
            if pending:
                pieces.append(pending)
                pending = ""
            if char != " ":
                pieces.append(char)
        if pending:
            pieces.append(pending)

        joined = " ".join(pieces)
        for spaced_form, operator in collapse.items():
            joined = joined.replace(spaced_form, operator)
        rendered.append(joined)
    return "\n".join(rendered)
|
|
|
|
def extract_signature(statement_and_proof: str) -> str | None:
    """Extract the theorem signature from a statement-plus-proof string.

    Comments and leading attributes are removed first, then the signature
    span is located.  A trailing `:=` is cut off and the result is stripped
    of surrounding whitespace.

    Returns None when no `theorem`/`lemma` signature is found or when any
    step raises.
    """
    try:
        cleaned = remove_comments(statement_and_proof)
        body = extract_and_remove_attributes(cleaned)[1]
        bounds = return_theorem_to_prove_mathlib_style(body)
        if bounds is None:
            return None
        begin, end = bounds
        header = body[begin:end].rstrip()
        if header.endswith(":="):
            header = header[:-2]
        return header.strip()
    except Exception:
        # Best-effort parser: any failure means "no signature".
        return None
|
|
|
|
def normalize_signature(sig: str) -> str:
    """Collapse every whitespace run to one space and trim both ends.

    A falsy `sig` (for example None or "") normalizes to the empty string.
    """
    source = sig if sig else ""
    collapsed = re.sub(r"\s+", " ", source)
    return collapsed.strip()
|
|
|
|
def proof_length(statement_and_proof: str) -> int:
    """Token count of the proof body (everything after the signature span).

    Comments and leading attributes are removed, the signature span is
    located, and the remaining text is tokenized with `_lex`.  The count is
    the number of whitespace-separated tokens in that output, so it does not
    depend on where line breaks fall (a proof starting on the line after
    `:=` counts the same as one starting on the same line).

    Returns 10**9 on any parse failure so it's clearly distinguishable from a
    legitimate short proof.
    """
    try:
        text = remove_comments(statement_and_proof)
        _, text = extract_and_remove_attributes(text)
        span = return_theorem_to_prove_mathlib_style(text)
        if span is None:
            return 10**9
        proof = text[span[1]:]
        # Separator-less split(): "".split(" ") would yield [''] and count an
        # empty line (e.g. the remainder of the `:=` line) as one token.
        return len(_lex(proof).split())
    except Exception:
        # Best-effort scorer: any failure maps to the sentinel value.
        return 10**9
|
|