Lean4-helper / src /langgraph_agent.py
p4r5kpftnp-cmd
RAFT-style prompting: distractor-aware framing + usage citations
f4afb6d
Raw
History Blame Contribute Delete
11.2 kB
import os
import re
from typing import List, Literal, TypedDict

from langgraph.graph import END, StateGraph

from lean_verifier import LeanEnvironment
from rag_chain import RAGProofChain
from retriever import MathLibRetriever
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
class ProofState(TypedDict):
    """State threaded through the verify -> retrieve -> generate repair loop.

    Every node returns a new dict built as ``{**state, ...}`` with only the
    keys it updates overwritten.
    """

    file_path: str  # Lean source file being repaired; rewritten in place
    lean_code: str  # contents of file_path as last read (verify) or written (generate)
    goals: List[str]  # open goals reported by the most recent verification
    errors: List[str]  # errors reported by the most recent verification
    attempt: int  # number of generation rounds performed so far
    max_retries: int  # the loop stops once `attempt` reaches this value
    # Closed set of outcomes; a bare `str` let typos slip past type checkers.
    status: Literal["pending", "success", "failed"]
    retrieved_lemmas: list  # retriever output for the latest goals (element type set by the retriever)
    solved_at_attempt: int  # 0 = unsolved, else the attempt number that succeeded
# ---------------------------------------------------------------------------
# Nodes
# ---------------------------------------------------------------------------
def _read_file(path: str) -> str:
# Lean source uses ∀ ∃ ℕ ↑ etc. Force UTF-8 so non-UTF-8 default locales
# (e.g. C/POSIX inside minimal Docker images) don't corrupt or crash on
# read.
with open(path, "r", encoding="utf-8") as f:
return f.read()
def _write_file(path: str, code: str) -> None:
with open(path, "w", encoding="utf-8") as f:
f.write(code)
# Match an opening fence tagged `lean` (or `lean4`, `Lean`, etc.) followed by a
# newline or whitespace — but NOT something like ```leanish that would
# accidentally consume the "ish" into the code body.
_LEAN_FENCE_RE = re.compile(r"```\s*lean[0-9]*\s*\n", re.IGNORECASE)
# Citation comment the prompt asks the model to emit on its first line
# (`-- used: Nat.add_comm` / `-- used: none`). Logged as feedback on whether
# retrieved premises were actually useful — the seed data for a retrieval
# eval harness.
_USED_CITATION_RE = re.compile(r"^--\s*used:\s*(.+)$", re.MULTILINE)
def _extract_lean_code(text: str) -> str:
    """
    Extract the Lean code block from an LLM response.

    Handles:
    - a fence tagged ``lean`` / ``lean4`` / ``Lean`` (any case, optional digits)
    - an untagged fence (```` ``` ```` followed directly by a newline)
    - a fence with any other tag (e.g. ``text`` or ``Lean 4``): the tag line is
      dropped rather than treated as code
    - plain text without fences (returned as-is)

    Only the first fenced block is used. An unterminated fence runs to the
    end of the response. The result is always whitespace-stripped.
    """
    m = _LEAN_FENCE_RE.search(text)
    if m:
        rest = text[m.end():]
        return rest.split("```", 1)[0].strip()
    if "```" in text:
        body = text.split("```", 1)[1].split("```", 1)[0]
        # In Markdown, the rest of the opening-fence line is an "info string"
        # (language tag), not code. Drop it so a tag that _LEAN_FENCE_RE does
        # not recognise isn't written into the .lean file.
        # A fence with no newline before its closing ``` (```code```) has no
        # separate tag line, so it is kept whole.
        _, newline, code = body.partition("\n")
        if newline:
            return code.strip()
        return body.strip()
    return text.strip()
def _sanitize_imports(code: str) -> str:
"""
LLMs often hallucinate Lean import paths. This function strips all `import`
lines from the generated code and replaces them with `import Mathlib`, which
is the correct single import for any Mathlib-based proof.
"""
lines = code.splitlines()
non_import_lines = [l for l in lines if not l.strip().startswith("import ")]
return "import Mathlib\n\n" + "\n".join(non_import_lines).lstrip()
# A declaration keyword must be followed by whitespace or end-of-line so that
# `examplelike` (false positive) doesn't match and `theorem\n` (no trailing
# space) does match. `theorem:` and `theorem(` are not valid Lean syntax
# (the declaration name must come first), so we don't allow those either.
_THEOREM_KEYWORD_RE = re.compile(r"^\s*(?:example|theorem|lemma|def)(?:\s|$)")
def _count_theorem_blocks(code: str) -> int:
return sum(
1 for line in code.splitlines()
if _THEOREM_KEYWORD_RE.match(line)
)
def make_verify_node(lean_env: LeanEnvironment):
    """Build the graph node that checks the current file with Lean.

    The node re-reads ``state["file_path"]`` from disk and passes its contents
    to ``lean_env.verify_proof``. It then records the reported errors and
    goals, and the code it read, in the returned state.

    Resulting status:
    - ``"success"`` when the verifier reports success. ``solved_at_attempt``
      is then set to ``attempt + 1``.
    - ``"failed"`` when verification fails and the retry budget is spent
      (``attempt >= max_retries``). No further generation round follows, so
      an exhausted run ends in ``"failed"`` rather than ``"pending"``.
    - ``"pending"`` otherwise; another retrieve/generate round will follow.
    """

    def verify_node(state: ProofState) -> ProofState:
        attempt = state["attempt"]
        max_retries = state["max_retries"]
        out_of_retries = attempt >= max_retries
        if out_of_retries:
            # This is the last check (the router ends the run after it).
            # Labelling it "Attempt N+1 / N" overstated the retry budget.
            print(f"\n--- Final check ({attempt} / {max_retries} attempts used) ---")
        else:
            print(f"\n--- Attempt {attempt + 1} / {max_retries} ---")
        code = _read_file(state["file_path"])
        result = lean_env.verify_proof(code)
        succeeded = result["status"] == "success"
        if succeeded:
            print("Proof verified successfully!")
            new_status = "success"
        else:
            print(
                f"Verification failed. "
                f"Errors: {len(result['errors'])}, Goals: {len(result['goals'])}"
            )
            new_status = "failed" if out_of_retries else "pending"
        solved_at = attempt + 1 if succeeded else state["solved_at_attempt"]
        return {
            **state,
            "lean_code": code,
            "errors": result["errors"],
            "goals": result["goals"],
            "status": new_status,
            "solved_at_attempt": solved_at,
        }

    return verify_node
def make_retrieve_node(retriever: MathLibRetriever):
    """Build the graph node that fetches candidate Mathlib premises for the
    current open goals and stores them under ``retrieved_lemmas``."""

    def retrieve_node(state: ProofState) -> ProofState:
        # The query is built from the goals alone, separated by blank lines.
        # The LeanDojo encoder was trained on canonical proof states
        # ("h1 : T1\nh2 : T2\n⊢ goal"), so Lean error text would be
        # off-distribution noise in the embedding. Errors still reach the LLM
        # through the generation prompt, just not through retrieval.
        # With no open goals (e.g. a pure syntax error) the query is empty.
        # The design relies on the retriever returning [] in that case, so
        # generation proceeds without premises.
        goal_query = "\n\n".join(state["goals"])
        print("Retrieving relevant Mathlib lemmas…")
        premises = retriever.retrieve(goal_query)
        print(f" Retrieved {len(premises)} lemma(s).")
        updated = dict(state)
        updated["retrieved_lemmas"] = premises
        return updated

    return retrieve_node
def make_generate_node(strong_chain: RAGProofChain, fast_chain: RAGProofChain | None = None):
    """
    Build the graph node that asks the LLM for a repaired proof and writes it
    back to ``state["file_path"]``.

    If `fast_chain` is given, attempt 0 uses it (cheaper / faster model);
    subsequent attempts escalate to `strong_chain`. This catches the easy
    proofs in a few hundred ms before paying for the bigger model.

    Every outcome increments ``attempt``. The file and ``lean_code`` are left
    unchanged in three cases:
    - the response is empty;
    - the response is identical to the current code after import
      normalisation;
    - the response has fewer declaration headers than the current code.
    """

    def generate_node(state: ProofState) -> ProofState:
        next_attempt = state["attempt"] + 1
        if fast_chain is not None and state["attempt"] == 0:
            chain = fast_chain
            print("Generating proof with LLM (fast model, first attempt)…")
        else:
            chain = strong_chain
            print("Generating proof with LLM…")
        raw = chain.generate(
            lean_code=state["lean_code"],
            goals=state["goals"],
            errors=state["errors"],
            retrieved_lemmas=state["retrieved_lemmas"],
        )
        extracted = _extract_lean_code(raw)
        citation = _USED_CITATION_RE.search(extracted)
        if citation:
            print(f" [rag] lemmas cited as used: {citation.group(1).strip()}")
        # Detect empty output *before* sanitization. _sanitize_imports always
        # prepends "import Mathlib", which would mask an empty response as a
        # non-empty payload.
        # NOTE: the generate -> verify edge is unconditional and the verify
        # node recomputes `status`. So the "failed" set below does not by
        # itself end the loop; the next round simply retries.
        if not extracted.strip():
            print("LLM produced no usable output.")
            return {**state, "attempt": next_attempt, "status": "failed"}
        new_code = _sanitize_imports(extracted)
        # new_code always starts with "import Mathlib", so it is never empty.
        # Only the no-change comparison is meaningful here.
        if new_code.strip() == state["lean_code"].strip():
            print("LLM produced no changes.")
            return {**state, "attempt": next_attempt, "status": "failed"}
        # Guard against the LLM "solving" the file by deleting statements.
        original_blocks = _count_theorem_blocks(state["lean_code"])
        generated_blocks = _count_theorem_blocks(new_code)
        if original_blocks > 0 and generated_blocks < original_blocks:
            print(
                f"LLM dropped theorem statements "
                f"({generated_blocks} of {original_blocks} preserved) — rejecting."
            )
            return {**state, "attempt": next_attempt}
        _write_file(state["file_path"], new_code)
        print("File updated.")
        return {
            **state,
            "lean_code": new_code,
            "attempt": next_attempt,
        }

    return generate_node
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
def should_continue(state: ProofState) -> str:
    """Route after verification.

    Stop when the proof verified or the generation budget is used up;
    otherwise loop back to premise retrieval.
    """
    finished = state["status"] == "success" or state["attempt"] >= state["max_retries"]
    return END if finished else "retrieve"
# ---------------------------------------------------------------------------
# Graph assembly
# ---------------------------------------------------------------------------
def build_graph(
    lean_env: LeanEnvironment,
    retriever: MathLibRetriever,
    chain: RAGProofChain,
    fast_chain: RAGProofChain | None = None,
):
    """Assemble and compile the verify -> retrieve -> generate repair loop.

    Execution starts at ``verify``. ``should_continue`` then either ends the
    run or routes to ``retrieve``. ``retrieve`` always flows into
    ``generate``, which always flows back into ``verify``.

    ``chain`` is the strong model; ``fast_chain`` is the optional
    first-attempt model.
    """
    workflow = StateGraph(ProofState)
    nodes = (
        ("verify", make_verify_node(lean_env)),
        ("retrieve", make_retrieve_node(retriever)),
        ("generate", make_generate_node(chain, fast_chain)),
    )
    for node_name, node_fn in nodes:
        workflow.add_node(node_name, node_fn)
    workflow.set_entry_point("verify")
    workflow.add_conditional_edges(
        "verify",
        should_continue,
        {"retrieve": "retrieve", END: END},
    )
    for source, target in (("retrieve", "generate"), ("generate", "verify")):
        workflow.add_edge(source, target)
    return workflow.compile()
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
class LangGraphAgent:
    """Public entry point: repair a Lean proof file with the LangGraph loop."""

    def __init__(
        self,
        model_name: str = "llama-3.3-70b-versatile",
        max_retries: int = 5,
        index_dir: str | None = None,
        api_key: str | None = None,
        fast_model: str | None = None,
        lean_env: LeanEnvironment | None = None,
        retriever: MathLibRetriever | None = None,
    ):
        """
        Args:
            model_name: model for the main ("strong") RAGProofChain.
            max_retries: maximum number of LLM generation rounds per file.
            index_dir: passed to MathLibRetriever when `retriever` is None.
            api_key: passed to every RAGProofChain this agent builds.
            fast_model: optional cheaper model used for the first attempt;
                ignored when it equals `model_name`.
            lean_env: pre-built LeanEnvironment to reuse instead of creating one.
            retriever: pre-built MathLibRetriever to reuse instead of creating one.
        """
        # Reuse pre-built heavyweight components when given. This lets a hosting
        # process (e.g. the Gradio app) cache the Lean REPL and FAISS index
        # once at startup instead of rebuilding them on every solve_proof call.
        self._lean_env = lean_env if lean_env is not None else LeanEnvironment(use_mathlib=True)
        self._retriever = retriever if retriever is not None else MathLibRetriever(index_dir=index_dir)
        self._chain = RAGProofChain(model_name=model_name, api_key=api_key)
        self._fast_chain = (
            RAGProofChain(model_name=fast_model, api_key=api_key)
            if fast_model and fast_model != model_name
            else None
        )
        self._graph = build_graph(self._lean_env, self._retriever, self._chain, self._fast_chain)
        self._max_retries = max_retries

    def solve_file(self, file_path: str) -> bool:
        """Try to repair `file_path` in place; return True if it verifies."""
        return self.solve_file_detailed(file_path)["success"]

    def solve_file_detailed(self, file_path: str) -> dict:
        """Returns {"success": bool, "solved_at_attempt": int, "total_attempts": int}.

        A missing file is reported on stdout and yields an all-zero,
        unsuccessful result without running the graph.
        """
        if not os.path.exists(file_path):
            print(f"Error: {file_path} not found.")
            return {"success": False, "solved_at_attempt": 0, "total_attempts": 0}
        initial: ProofState = {
            "file_path": file_path,
            "lean_code": "",
            "goals": [],
            "errors": [],
            "attempt": 0,
            "max_retries": self._max_retries,
            "status": "pending",
            "retrieved_lemmas": [],
            "solved_at_attempt": 0,
        }
        # A full run executes one initial verify plus three nodes (retrieve,
        # generate, verify) per retry. LangGraph aborts with
        # GraphRecursionError once a run exceeds its recursion_limit
        # (default 25), which made max_retries of about 8 or more crash
        # instead of finishing. Size the limit to the budget, with headroom,
        # and never below the default.
        recursion_limit = max(25, 3 * self._max_retries + 6)
        final = self._graph.invoke(initial, config={"recursion_limit": recursion_limit})
        return {
            "success": final["status"] == "success",
            "solved_at_attempt": final["solved_at_attempt"],
            "total_attempts": final["attempt"],
        }