File size: 11,168 Bytes
3ac681e
1808386
3ac681e
 
 
01cdf65
3ac681e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c51ce7
3ac681e
 
 
 
 
 
 
1808386
 
 
 
3ac681e
 
 
 
1808386
3ac681e
 
 
1808386
 
 
 
 
f4afb6d
 
 
 
 
 
1808386
3ac681e
1808386
 
 
 
 
 
 
 
 
 
 
 
 
 
3ac681e
1808386
3ac681e
 
 
ef4afaa
 
 
 
 
 
 
 
 
 
 
1808386
 
 
 
 
8c51ce7
 
 
 
 
1808386
8c51ce7
 
 
3ac681e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8c51ce7
3ac681e
 
 
 
 
 
8c51ce7
3ac681e
 
 
 
 
 
1dc8f6a
 
 
 
 
 
 
3ac681e
 
 
 
 
 
 
df04431
 
 
 
 
 
3ac681e
df04431
 
 
 
 
 
3ac681e
 
 
 
 
 
ec7552d
f4afb6d
 
 
ec7552d
 
 
 
 
 
 
3ac681e
 
 
 
 
8c51ce7
 
 
 
 
 
 
 
 
3ac681e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
df04431
 
 
 
 
 
3ac681e
 
 
 
df04431
3ac681e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2ebdf5
3ac681e
 
1145e92
df04431
 
 
3ac681e
df04431
 
 
 
 
1145e92
df04431
 
 
 
 
 
3ac681e
 
 
8c51ce7
 
 
 
3ac681e
 
8c51ce7
3ac681e
 
 
 
 
 
 
 
 
 
8c51ce7
3ac681e
 
 
8c51ce7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
import os
import re
from typing import List, TypedDict

from langgraph.graph import END, StateGraph

from lean_verifier import LeanEnvironment
from rag_chain import RAGProofChain
from retriever import MathLibRetriever

# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------

class ProofState(TypedDict):
    """State dict passed between the verify, retrieve and generate graph nodes."""

    # Path of the Lean file: re-read on every verify pass and overwritten by
    # the generate node when it accepts new code.
    file_path: str
    # Lean source as last read by the verify node, or as last written by the
    # generate node.
    lean_code: str
    # `goals` / `errors` entries of the most recent verification result.
    goals: List[str]
    errors: List[str]
    # Number of generate passes so far; the generate node increments it on
    # every path, including rejected outputs.
    attempt: int
    # Upper bound on `attempt`; the router stops once attempt >= max_retries.
    max_retries: int
    # "success" and "pending" are set by the verify node; "failed" is set by
    # the generate node when the LLM output is empty or unchanged.
    status: str          # "pending" | "success" | "failed"
    # Whatever the retriever returned for the current goals.
    retrieved_lemmas: list
    # Recorded by the verify node as `attempt + 1` at the moment verification
    # succeeds.
    solved_at_attempt: int  # 0 = unsolved, else the attempt number that succeeded


# ---------------------------------------------------------------------------
# Nodes
# ---------------------------------------------------------------------------

def _read_file(path: str) -> str:
    # Lean source uses ∀ ∃ ℕ ↑ etc. Force UTF-8 so non-UTF-8 default locales
    # (e.g. C/POSIX inside minimal Docker images) don't corrupt or crash on
    # read.
    with open(path, "r", encoding="utf-8") as f:
        return f.read()


def _write_file(path: str, code: str) -> None:
    with open(path, "w", encoding="utf-8") as f:
        f.write(code)


# Match an opening fence tagged `lean` (or `lean4`, `Lean`, etc.) followed by a
# newline or whitespace — but NOT something like ```leanish that would
# accidentally consume the "ish" into the code body.
_LEAN_FENCE_RE = re.compile(r"```\s*lean[0-9]*\s*\n", re.IGNORECASE)

# Citation comment the prompt asks the model to emit on its first line
# (`-- used: Nat.add_comm` / `-- used: none`). Logged as feedback on whether
# retrieved premises were actually useful — the seed data for a retrieval
# eval harness.
_USED_CITATION_RE = re.compile(r"^--\s*used:\s*(.+)$", re.MULTILINE)


def _extract_lean_code(text: str) -> str:
    """
    Extract the Lean code block from an LLM response.

    Handles:
      - ```lean\n...\n```         (canonical)
      - ```lean4\n...\n```        (some LLMs)
      - ```Lean\n...\n```         (case variation)
      - ```\n...\n```             (no language tag)
      - plain text without fences (returned as-is)
    """
    m = _LEAN_FENCE_RE.search(text)
    if m:
        rest = text[m.end():]
        return rest.split("```", 1)[0].strip()
    if "```" in text:
        return text.split("```", 1)[1].split("```", 1)[0].strip()
    return text.strip()


def _sanitize_imports(code: str) -> str:
    """
    LLMs often hallucinate Lean import paths. This function strips all `import`
    lines from the generated code and replaces them with `import Mathlib`, which
    is the correct single import for any Mathlib-based proof.
    """
    lines = code.splitlines()
    non_import_lines = [l for l in lines if not l.strip().startswith("import ")]
    return "import Mathlib\n\n" + "\n".join(non_import_lines).lstrip()


# A declaration keyword must be followed by whitespace or end-of-line so that
# `examplelike` (false positive) doesn't match and `theorem\n` (no trailing
# space) does match. `theorem:` and `theorem(` are not valid Lean syntax
# (the declaration name must come first), so we don't allow those either.
_THEOREM_KEYWORD_RE = re.compile(r"^\s*(?:example|theorem|lemma|def)(?:\s|$)")


def _count_theorem_blocks(code: str) -> int:
    return sum(
        1 for line in code.splitlines()
        if _THEOREM_KEYWORD_RE.match(line)
    )


def make_verify_node(lean_env: LeanEnvironment):
    """
    Build the graph node that checks the file on disk with `lean_env`.

    The node re-reads `state["file_path"]`, passes the text to
    `lean_env.verify_proof`, and returns a copy of the state with
    `lean_code`, `errors`, `goals`, `status` and `solved_at_attempt`
    refreshed. `status` becomes "success" when the verifier reports success
    and "pending" otherwise, replacing whatever status was there before.

    The number shown in the banner and stored in `solved_at_attempt` is
    `state["attempt"] + 1`.
    """
    def verify_node(state: ProofState) -> ProofState:
        attempt_no = state["attempt"] + 1
        print(f"\n--- Attempt {attempt_no} / {state['max_retries']} ---")

        code = _read_file(state["file_path"])
        result = lean_env.verify_proof(code)

        if result["status"] == "success":
            print("Proof verified successfully!")
            new_status = "success"
            solved_at = attempt_no
        else:
            print(
                f"Verification failed. "
                f"Errors: {len(result['errors'])}, Goals: {len(result['goals'])}"
            )
            new_status = "pending"
            # Leave any previously recorded value untouched.
            solved_at = state["solved_at_attempt"]

        refreshed = {
            "lean_code": code,
            "errors": result["errors"],
            "goals": result["goals"],
            "status": new_status,
            "solved_at_attempt": solved_at,
        }
        return {**state, **refreshed}
    return verify_node


def make_retrieve_node(retriever: MathLibRetriever):
    """
    Build the graph node that looks up lemmas for the current goals.

    The node returns a copy of the state whose `retrieved_lemmas` is whatever
    `retriever.retrieve` returned for the query; nothing else is changed.
    """
    def retrieve_node(state: ProofState) -> ProofState:
        # The query is built from the open goals only, separated by a blank
        # line. The LeanDojo encoder was trained on canonical proof states
        # ("h1 : T1\nh2 : T2\n⊢ goal"), so Lean error text would only be
        # off-distribution noise in the embedding. Errors still reach the LLM
        # through the generation prompt — just not through retrieval.
        # With no open goals (e.g. a pure syntax error) the query is the empty
        # string; the retriever is expected to return [] for it so that
        # generation proceeds without premises — NOTE(review): confirm against
        # MathLibRetriever.retrieve.
        goal_query = "\n\n".join(state["goals"])
        print("Retrieving relevant Mathlib lemmas…")
        hits = retriever.retrieve(goal_query)
        print(f"  Retrieved {len(hits)} lemma(s).")

        updated = dict(state)
        updated["retrieved_lemmas"] = hits
        return updated
    return retrieve_node


def make_generate_node(strong_chain: RAGProofChain, fast_chain: RAGProofChain | None = None):
    """
    Build the graph node that asks an LLM for a new proof and writes it out.

    If `fast_chain` is given, it handles the generation at `attempt == 0`
    (cheaper / faster model) and every later attempt escalates to
    `strong_chain`, so easy proofs are caught before paying for the bigger
    model. Without `fast_chain`, `strong_chain` is used throughout.

    Every path through the node increments `attempt`. The output is rejected,
    and the file left untouched, when it is empty, when it equals the current
    code after import sanitization, or when it contains fewer declaration
    lines than the current code. The first two rejections also set `status`
    to "failed"; the graph still routes to the verify node afterwards, which
    assigns `status` again.
    """
    def generate_node(state: ProofState) -> ProofState:
        if fast_chain is not None and state["attempt"] == 0:
            chain, banner = fast_chain, "Generating proof with LLM (fast model, first attempt)…"
        else:
            chain, banner = strong_chain, "Generating proof with LLM…"
        print(banner)

        raw = chain.generate(
            lean_code=state["lean_code"],
            goals=state["goals"],
            errors=state["errors"],
            retrieved_lemmas=state["retrieved_lemmas"],
        )
        candidate = _extract_lean_code(raw)

        cited = _USED_CITATION_RE.search(candidate)
        if cited is not None:
            print(f"  [rag] lemmas cited as used: {cited.group(1).strip()}")

        next_attempt = state["attempt"] + 1

        # Emptiness has to be judged on the raw extraction: _sanitize_imports
        # always prepends "import Mathlib", which would disguise an empty
        # response as a non-empty payload.
        if not candidate or not candidate.strip():
            print("LLM produced no usable output.")
            return {**state, "attempt": next_attempt, "status": "failed"}

        new_code = _sanitize_imports(candidate)
        if not new_code or new_code.strip() == state["lean_code"].strip():
            print("LLM produced no changes.")
            return {**state, "attempt": next_attempt, "status": "failed"}

        # Guard against the model "solving" the file by deleting statements.
        expected = _count_theorem_blocks(state["lean_code"])
        found = _count_theorem_blocks(new_code)
        if expected > 0 and found < expected:
            print(
                f"LLM dropped theorem statements "
                f"({found} of {expected} preserved) — rejecting."
            )
            return {**state, "attempt": next_attempt}

        _write_file(state["file_path"], new_code)
        print("File updated.")
        return {**state, "lean_code": new_code, "attempt": next_attempt}
    return generate_node


# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------

def should_continue(state: ProofState) -> str:
    """
    Route after a verify pass: END when the proof succeeded or the attempt
    budget (`attempt >= max_retries`) is spent, otherwise "retrieve".
    """
    finished = state["status"] == "success" or state["attempt"] >= state["max_retries"]
    return END if finished else "retrieve"


# ---------------------------------------------------------------------------
# Graph assembly
# ---------------------------------------------------------------------------

def build_graph(
    lean_env: LeanEnvironment,
    retriever: MathLibRetriever,
    chain: RAGProofChain,
    fast_chain: RAGProofChain | None = None,
):
    """
    Assemble and compile the proof-repair state graph.

    Topology: the entry point is "verify"; `should_continue` then routes
    either to END or to "retrieve", which feeds "generate", which loops back
    to "verify".
    """
    graph = StateGraph(ProofState)

    nodes = {
        "verify": make_verify_node(lean_env),
        "retrieve": make_retrieve_node(retriever),
        "generate": make_generate_node(chain, fast_chain),
    }
    for node_name, node_fn in nodes.items():
        graph.add_node(node_name, node_fn)

    graph.set_entry_point("verify")
    routes = {"retrieve": "retrieve", END: END}
    graph.add_conditional_edges("verify", should_continue, routes)
    graph.add_edge("retrieve", "generate")
    graph.add_edge("generate", "verify")

    return graph.compile()


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------

class LangGraphAgent:
    """Facade that wires the Lean verifier, retriever and LLM chains into a
    compiled graph and runs it against a single Lean file."""

    # LangGraph aborts a run once it has executed `recursion_limit`
    # super-steps; its documented default is 25.
    _DEFAULT_RECURSION_LIMIT = 25
    # One retry cycle is retrieve -> generate -> verify.
    _STEPS_PER_RETRY = 3

    def __init__(
        self,
        model_name: str = "llama-3.3-70b-versatile",
        max_retries: int = 5,
        index_dir: str | None = None,
        api_key: str | None = None,
        fast_model: str | None = None,
        lean_env: LeanEnvironment | None = None,
        retriever: MathLibRetriever | None = None,
    ):
        """
        Args:
            model_name: Model used by the main (strong) generation chain.
            max_retries: Stored in the graph state as the attempt budget.
            index_dir: Passed to `MathLibRetriever` when no `retriever` is
                supplied.
            api_key: Passed to every `RAGProofChain` that is created.
            fast_model: Optional model for the first generation attempt. A
                separate chain is only built when this is set and differs
                from `model_name`.
            lean_env: Pre-built Lean environment to reuse.
            retriever: Pre-built retriever to reuse.
        """
        # Reuse pre-built heavyweight components when given. This lets a hosting
        # process (e.g. the Gradio app) cache the Lean REPL and FAISS index
        # once at startup instead of rebuilding them on every solve_proof call.
        self._lean_env = lean_env if lean_env is not None else LeanEnvironment(use_mathlib=True)
        self._retriever = retriever if retriever is not None else MathLibRetriever(index_dir=index_dir)
        self._chain = RAGProofChain(model_name=model_name, api_key=api_key)
        self._fast_chain = (
            RAGProofChain(model_name=fast_model, api_key=api_key)
            if fast_model and fast_model != model_name
            else None
        )
        self._graph = build_graph(self._lean_env, self._retriever, self._chain, self._fast_chain)
        self._max_retries = max_retries

    def solve_file(self, file_path: str) -> bool:
        """Run the agent on `file_path` and return only the success flag."""
        return self.solve_file_detailed(file_path)["success"]

    def solve_file_detailed(self, file_path: str) -> dict:
        """Returns {"success": bool, "solved_at_attempt": int, "total_attempts": int}.

        When `file_path` is not an existing regular file, an error is printed
        and a failure result with zero attempts is returned without running
        the graph.
        """
        # isfile rather than exists: a directory path would otherwise get
        # through this check and fail later when the verify node opens it.
        if not os.path.isfile(file_path):
            print(f"Error: {file_path} not found.")
            return {"success": False, "solved_at_attempt": 0, "total_attempts": 0}

        initial: ProofState = {
            "file_path": file_path,
            "lean_code": "",
            "goals": [],
            "errors": [],
            "attempt": 0,
            "max_retries": self._max_retries,
            "status": "pending",
            "retrieved_lemmas": [],
            "solved_at_attempt": 0,
        }

        # A full run is one initial verify plus one retrieve/generate/verify
        # cycle per retry. With the default limit, a budget of 8 or more
        # retries could be cut short by a recursion error, so size the limit
        # from the budget (with some headroom) and never go below the default.
        steps_needed = 1 + self._STEPS_PER_RETRY * max(self._max_retries, 0)
        recursion_limit = max(self._DEFAULT_RECURSION_LIMIT, steps_needed + 5)

        final = self._graph.invoke(initial, config={"recursion_limit": recursion_limit})
        return {
            "success": final["status"] == "success",
            "solved_at_attempt": final["solved_at_attempt"],
            "total_attempts": final["attempt"],
        }