Spaces:
Running
Running
File size: 11,168 Bytes
3ac681e 1808386 3ac681e 01cdf65 3ac681e 8c51ce7 3ac681e 1808386 3ac681e 1808386 3ac681e 1808386 f4afb6d 1808386 3ac681e 1808386 3ac681e 1808386 3ac681e ef4afaa 1808386 8c51ce7 1808386 8c51ce7 3ac681e 8c51ce7 3ac681e 8c51ce7 3ac681e 1dc8f6a 3ac681e df04431 3ac681e df04431 3ac681e ec7552d f4afb6d ec7552d 3ac681e 8c51ce7 3ac681e df04431 3ac681e df04431 3ac681e c2ebdf5 3ac681e 1145e92 df04431 3ac681e df04431 1145e92 df04431 3ac681e 8c51ce7 3ac681e 8c51ce7 3ac681e 8c51ce7 3ac681e 8c51ce7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 | import os
import re
from typing import List, TypedDict
from langgraph.graph import END, StateGraph
from lean_verifier import LeanEnvironment
from rag_chain import RAGProofChain
from retriever import MathLibRetriever
# ---------------------------------------------------------------------------
# State
# ---------------------------------------------------------------------------
# Shared state threaded through the verify -> retrieve -> generate loop.
# Every node returns a full copy of this dict ({**state, ...}) with its own
# updates applied on top.
ProofState = TypedDict(
    "ProofState",
    {
        "file_path": str,  # Lean file being repaired; read and rewritten in place
        "lean_code": str,  # most recent contents of file_path
        "goals": List[str],  # open goals reported by the last verification
        "errors": List[str],  # errors reported by the last verification
        "attempt": int,  # number of LLM generations performed so far
        "max_retries": int,  # generation budget; the router ends the loop once attempt reaches it
        "status": str,  # "pending" | "success" | "failed"
        "retrieved_lemmas": list,  # premises returned by the retriever for the next prompt
        "solved_at_attempt": int,  # 0 = unsolved, else the attempt number that succeeded
    },
)
# ---------------------------------------------------------------------------
# Nodes
# ---------------------------------------------------------------------------
def _read_file(path: str) -> str:
# Lean source uses ∀ ∃ ℕ ↑ etc. Force UTF-8 so non-UTF-8 default locales
# (e.g. C/POSIX inside minimal Docker images) don't corrupt or crash on
# read.
with open(path, "r", encoding="utf-8") as f:
return f.read()
def _write_file(path: str, code: str) -> None:
with open(path, "w", encoding="utf-8") as f:
f.write(code)
# Match an opening fence tagged `lean` (or `lean4`, `Lean`, etc.) followed by a
# newline or whitespace — but NOT something like ```leanish that would
# accidentally consume the "ish" into the code body.
_LEAN_FENCE_RE = re.compile(r"```\s*lean[0-9]*\s*\n", re.IGNORECASE)
# Citation comment the prompt asks the model to emit on its first line
# (`-- used: Nat.add_comm` / `-- used: none`). Logged as feedback on whether
# retrieved premises were actually useful — the seed data for a retrieval
# eval harness.
_USED_CITATION_RE = re.compile(r"^--\s*used:\s*(.+)$", re.MULTILINE)
def _extract_lean_code(text: str) -> str:
    """
    Extract the Lean code block from an LLM response.

    Handles:
      - ```lean\n...\n```      (canonical)
      - ```lean4\n...\n```     (some LLMs)
      - ```Lean\n...\n```      (case variation)
      - ```\n...\n```          (no language tag)
      - ```haskell\n...\n```   (any other single-word tag; the tag is dropped)
      - plain text without fences (returned as-is, stripped)

    A `lean`-tagged fence is preferred over any other fence in the text. An
    unclosed fence yields everything after the opening marker. The result is
    always whitespace-stripped.
    """
    m = _LEAN_FENCE_RE.search(text)
    if m:
        rest = text[m.end():]
        return rest.split("```", 1)[0].strip()
    if "```" in text:
        body = text.split("```", 1)[1].split("```", 1)[0]
        # Markdown treats the rest of the opening fence line as an info string
        # (language tag). Models sometimes label Lean as e.g. ```haskell; without
        # this, that tag would come back as the first "line of code" and break
        # compilation. Only a single whitespace-free word directly after the
        # backticks is treated as a tag, so code written on the fence line
        # itself (```theorem foo : ...) is left untouched.
        first_line, newline, remainder = body.partition("\n")
        tag = first_line.strip()
        if newline and tag and len(tag.split()) == 1:
            body = remainder
        return body.strip()
    return text.strip()
def _sanitize_imports(code: str) -> str:
"""
LLMs often hallucinate Lean import paths. This function strips all `import`
lines from the generated code and replaces them with `import Mathlib`, which
is the correct single import for any Mathlib-based proof.
"""
lines = code.splitlines()
non_import_lines = [l for l in lines if not l.strip().startswith("import ")]
return "import Mathlib\n\n" + "\n".join(non_import_lines).lstrip()
# A declaration keyword must be followed by whitespace or end-of-line so that
# `examplelike` (false positive) doesn't match and `theorem\n` (no trailing
# space) does match. `theorem:` and `theorem(` are not valid Lean syntax
# (the declaration name must come first), so we don't allow those either.
_THEOREM_KEYWORD_RE = re.compile(r"^\s*(?:example|theorem|lemma|def)(?:\s|$)")
def _count_theorem_blocks(code: str) -> int:
    """Count the lines of `code` that start a declaration matched by `_THEOREM_KEYWORD_RE`.

    Used as a cheap guard: a candidate proof with fewer such lines than the
    current file is assumed to have dropped a statement.
    """
    total = 0
    for line in code.splitlines():
        if _THEOREM_KEYWORD_RE.match(line):
            total += 1
    return total
def make_verify_node(lean_env: LeanEnvironment):
    """
    Build the graph node that checks the current file contents with Lean.

    The node re-reads `state["file_path"]` from disk (the generate node writes
    accepted proofs there), passes the text to `lean_env.verify_proof`, and
    records the returned errors and goals.

    Args:
        lean_env: environment whose `verify_proof(code)` returns a dict with
            at least "status", "errors" and "goals" keys (as read below).

    Returns:
        A `verify_node(state) -> state` callable suitable for `StateGraph.add_node`.
    """

    def verify_node(state: ProofState) -> ProofState:
        # 1-based verification round. Round 1 checks the file as given; round
        # k + 1 checks the output of the k-th LLM generation. The router allows
        # up to max_retries generations, so there are at most max_retries + 1
        # rounds. Printing "/ max_retries" would show e.g. "Attempt 6 / 5".
        round_no = state["attempt"] + 1
        print(f"\n--- Attempt {round_no} / {state['max_retries'] + 1} ---")
        code = _read_file(state["file_path"])
        result = lean_env.verify_proof(code)

        solved = result["status"] == "success"
        if solved:
            print("Proof verified successfully!")
        else:
            print(
                f"Verification failed. "
                f"Errors: {len(result['errors'])}, Goals: {len(result['goals'])}"
            )

        return {
            **state,
            "lean_code": code,
            "errors": result["errors"],
            "goals": result["goals"],
            # Any non-success result resets to "pending" so the loop can continue.
            "status": "success" if solved else "pending",
            # Record the round that succeeded (1 = the file was already proven
            # before any generation); 0 stays reserved for "unsolved".
            "solved_at_attempt": round_no if solved else state["solved_at_attempt"],
        }

    return verify_node
def make_retrieve_node(retriever: MathLibRetriever):
    """
    Build the graph node that fetches Mathlib premises for the current goals.

    Args:
        retriever: object whose `retrieve(query)` returns the lemmas to offer
            the generator.

    Returns:
        A `retrieve_node(state) -> state` callable that stores the lemmas in
        `state["retrieved_lemmas"]`.
    """

    def retrieve_node(state: ProofState) -> ProofState:
        # Query with goals only, separated by blank lines. The LeanDojo encoder
        # was trained on canonical proof states ("h1 : T1\nh2 : T2\n⊢ goal"),
        # so Lean error text would be off-distribution noise in the embedding.
        # Errors still reach the LLM via the generation prompt, just not retrieval.
        query = "\n\n".join(state["goals"])
        print("Retrieving relevant Mathlib lemmas…")
        if query.strip():
            lemmas = retriever.retrieve(query)
        else:
            # No open goals (e.g. a pure syntax error). Skip the retriever
            # instead of relying on it to treat an empty query sensibly, and
            # let generation proceed without premises.
            lemmas = []
        print(f" Retrieved {len(lemmas)} lemma(s).")
        return {**state, "retrieved_lemmas": lemmas}

    return retrieve_node
def make_generate_node(strong_chain: RAGProofChain, fast_chain: RAGProofChain | None = None):
    """
    Build the graph node that asks an LLM for a revised proof and saves it.

    Model routing: when `fast_chain` is provided, the first generation
    (state["attempt"] == 0) uses it. It is a cheaper/faster model that can
    close easy goals before the larger model is paid for. Every later
    generation escalates to `strong_chain`.

    Every call increments state["attempt"], whether or not the output is
    accepted. The candidate is rejected and the file left untouched when:
      * the extracted code is empty or whitespace-only (status -> "failed");
      * after import sanitisation it equals the current code (status -> "failed");
      * the current code has at least one declaration line (per
        `_count_theorem_blocks`) and the candidate has fewer.
    Otherwise the candidate is written to state["file_path"] and stored in
    state["lean_code"].
    """
    use_fast_first = fast_chain is not None

    def generate_node(state: ProofState) -> ProofState:
        next_attempt = state["attempt"] + 1

        if use_fast_first and state["attempt"] == 0:
            chain = fast_chain
            banner = "Generating proof with LLM (fast model, first attempt)…"
        else:
            chain = strong_chain
            banner = "Generating proof with LLM…"
        print(banner)

        response = chain.generate(
            lean_code=state["lean_code"],
            goals=state["goals"],
            errors=state["errors"],
            retrieved_lemmas=state["retrieved_lemmas"],
        )
        candidate = _extract_lean_code(response)

        # Log which retrieved premises the model says it relied on.
        cited = _USED_CITATION_RE.search(candidate)
        if cited is not None:
            print(f" [rag] lemmas cited as used: {cited.group(1).strip()}")

        # Check for an empty reply on the raw extraction: _sanitize_imports
        # always prepends "import Mathlib", which would disguise it as content.
        if not candidate.strip():
            print("LLM produced no usable output.")
            return {**state, "attempt": next_attempt, "status": "failed"}

        rewritten = _sanitize_imports(candidate)
        if not rewritten or rewritten.strip() == state["lean_code"].strip():
            print("LLM produced no changes.")
            return {**state, "attempt": next_attempt, "status": "failed"}

        # NOTE(review): the baseline is state["lean_code"], i.e. the most
        # recently verified file, not the original. If an accepted generation
        # added helper lemmas, a later one that removes them is rejected here.
        before = _count_theorem_blocks(state["lean_code"])
        after = _count_theorem_blocks(rewritten)
        if before > 0 and after < before:
            print(
                f"LLM dropped theorem statements "
                f"({after} of {before} preserved) — rejecting."
            )
            return {**state, "attempt": next_attempt}

        _write_file(state["file_path"], rewritten)
        print("File updated.")
        return {**state, "lean_code": rewritten, "attempt": next_attempt}

    return generate_node
# ---------------------------------------------------------------------------
# Router
# ---------------------------------------------------------------------------
def should_continue(state: ProofState) -> str:
    """
    Route the edge leaving the "verify" node.

    Returns END once the proof verifies or the generation budget is used up
    (attempt >= max_retries). Otherwise returns "retrieve" to start another
    repair cycle.
    """
    finished = state["status"] == "success" or state["attempt"] >= state["max_retries"]
    return END if finished else "retrieve"
# ---------------------------------------------------------------------------
# Graph assembly
# ---------------------------------------------------------------------------
def build_graph(
    lean_env: LeanEnvironment,
    retriever: MathLibRetriever,
    chain: RAGProofChain,
    fast_chain: RAGProofChain | None = None,
):
    """
    Assemble and compile the proof-repair loop.

    Topology: verify -> (END | retrieve) -> generate -> verify. The
    conditional edge after "verify" is decided by `should_continue`.

    Returns:
        The compiled LangGraph graph, ready for `.invoke(initial_state)`.
    """
    workflow = StateGraph(ProofState)
    nodes = (
        ("verify", make_verify_node(lean_env)),
        ("retrieve", make_retrieve_node(retriever)),
        ("generate", make_generate_node(chain, fast_chain)),
    )
    for node_name, node_fn in nodes:
        workflow.add_node(node_name, node_fn)

    workflow.set_entry_point("verify")
    workflow.add_conditional_edges("verify", should_continue, {"retrieve": "retrieve", END: END})
    workflow.add_edge("retrieve", "generate")
    workflow.add_edge("generate", "verify")
    return workflow.compile()
# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------
class LangGraphAgent:
    """
    High-level driver that repairs a Lean proof file with a
    verify -> retrieve -> generate loop.

    Heavyweight components (Lean REPL, retriever index) can be injected so a
    long-lived host process builds them once and reuses them across calls.
    """

    def __init__(
        self,
        model_name: str = "llama-3.3-70b-versatile",
        max_retries: int = 5,
        index_dir: str | None = None,
        api_key: str | None = None,
        fast_model: str | None = None,
        lean_env: LeanEnvironment | None = None,
        retriever: MathLibRetriever | None = None,
    ):
        """
        Args:
            model_name: model for the main (strong) generation chain.
            max_retries: maximum number of LLM generations per file.
            index_dir: forwarded to MathLibRetriever when `retriever` is None.
            api_key: forwarded to every RAGProofChain this agent creates.
            fast_model: optional cheaper model used for the first generation;
                ignored when empty or equal to `model_name`.
            lean_env: pre-built Lean environment to reuse; otherwise one is
                created with use_mathlib=True.
            retriever: pre-built retriever to reuse.
        """
        # Reuse pre-built heavyweight components when given. This lets a hosting
        # process (e.g. the Gradio app) cache the Lean REPL and FAISS index
        # once at startup instead of rebuilding them on every solve_proof call.
        self._lean_env = lean_env if lean_env is not None else LeanEnvironment(use_mathlib=True)
        self._retriever = retriever if retriever is not None else MathLibRetriever(index_dir=index_dir)
        self._chain = RAGProofChain(model_name=model_name, api_key=api_key)
        self._fast_chain = (
            RAGProofChain(model_name=fast_model, api_key=api_key)
            if fast_model and fast_model != model_name
            else None
        )
        self._graph = build_graph(self._lean_env, self._retriever, self._chain, self._fast_chain)
        self._max_retries = max_retries

    def solve_file(self, file_path: str) -> bool:
        """Return True iff the proof in `file_path` ends up verified."""
        return self.solve_file_detailed(file_path)["success"]

    def solve_file_detailed(self, file_path: str) -> dict:
        """
        Run the repair loop on `file_path`; accepted proofs are written back in place.

        Returns:
            {"success": bool, "solved_at_attempt": int, "total_attempts": int}.
            "total_attempts" is the number of LLM generations performed;
            "solved_at_attempt" is the value recorded by the graph's verify
            step (0 when unsolved). If `file_path` is not an existing regular
            file, an error is printed and all counters are 0.
        """
        # isfile (not exists): a directory would pass exists() and then crash
        # when the verify node tries to read it.
        if not os.path.isfile(file_path):
            print(f"Error: {file_path} not found.")
            return {"success": False, "solved_at_attempt": 0, "total_attempts": 0}
        initial: ProofState = {
            "file_path": file_path,
            "lean_code": "",
            "goals": [],
            "errors": [],
            "attempt": 0,
            "max_retries": self._max_retries,
            "status": "pending",
            "retrieved_lemmas": [],
            "solved_at_attempt": 0,
        }
        # One run costs 1 superstep (initial verify) plus 3 per retry (retrieve,
        # generate, verify). LangGraph's default recursion_limit of 25 would abort
        # larger max_retries values (roughly >= 8) with GraphRecursionError, so
        # size the limit to the budget, with headroom, never below the default.
        recursion_limit = max(25, 3 * self._max_retries + 5)
        final = self._graph.invoke(initial, config={"recursion_limit": recursion_limit})
        return {
            "success": final["status"] == "success",
            "solved_at_attempt": final["solved_at_attempt"],
            "total_attempts": final["attempt"],
        }
|