p4r5kpftnp-cmd Claude Sonnet 4.6 commited on
Commit
3ac681e
·
1 Parent(s): 4562b5e

Add LangChain/LangGraph RAG pipeline for retrieval-augmented proof generation

Browse files

- src/mathlib_corpus.py: walks Mathlib4 .lean source files and extracts
theorem/lemma/def declarations as LangChain Documents
- src/retriever.py: MathLibRetriever — FAISS + BM25 hybrid retrieval via
EnsembleRetriever (60/40), reranked with CrossEncoder; index persisted to
data/mathlib_index/ for reuse across runs
- src/rag_chain.py: RAGProofChain — LangChain LCEL chain
(ChatPromptTemplate | OllamaLLM | StrOutputParser) that injects retrieved
lemma context into each proof-generation call
- src/langgraph_agent.py: LangGraphAgent — replaces the plain Python retry
loop with a proper state machine (verify → retrieve → generate → verify);
exposes the same solve_file() API
- src/proof_agent.py: thin backward-compatible wrapper around LangGraphAgent
- scripts/build_index.py: one-time offline script to build and save the FAISS
index from Mathlib4 source
- scripts/run_agent.py: updated to use LangGraphAgent directly
- requirements.txt: add faiss-cpu, rank-bm25, sentence-transformers,
langgraph, langchain-ollama

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

problems/simple_add.lean ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import Mathlib
2
+
3
+ theorem add_zero_simple (n : ℕ) : n + 0 = n := by
4
+ sorry
problems/test_problem.lean ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ import Mathlib
2
+
3
+ theorem square_root_two_irrational (n m : ℕ) (h : n^2 = 2 * m^2) (hm : m ≠ 0) : False := by
4
+ sorry
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ lean-interact
2
+ ollama
3
+ langchain
4
+ langchain-community
5
+ langchain-ollama
6
+ langgraph
7
+ faiss-cpu
8
+ rank-bm25
9
+ sentence-transformers
scripts/build_index.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ One-time script to build the FAISS + BM25 index from Mathlib4 source files.
4
+ Run this before using the LangGraph agent for the first time:
5
+
6
+ python scripts/build_index.py
7
+
8
+ Optional flags:
9
+ --mathlib-root PATH Path to Mathlib4 source (auto-detected if omitted)
10
+ --max-files N Limit to first N .lean files (useful for quick testing)
11
+ --index-dir PATH Where to save the index (default: data/mathlib_index)
12
+ """
13
+ import sys
14
+ import os
15
+ import argparse
16
+
17
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
18
+
19
+ from retriever import MathLibRetriever
20
+
21
+
22
+ def main():
23
+ parser = argparse.ArgumentParser(description="Build Mathlib FAISS index.")
24
+ parser.add_argument("--mathlib-root", default=None, help="Path to Mathlib4 source root")
25
+ parser.add_argument("--max-files", type=int, default=None, help="Limit number of .lean files processed")
26
+ parser.add_argument("--index-dir", default=None, help="Directory to save the index")
27
+ args = parser.parse_args()
28
+
29
+ retriever = MathLibRetriever(index_dir=args.index_dir)
30
+
31
+ if retriever.is_index_built() and not args.max_files:
32
+ print(f"Index already exists at {retriever.index_dir}. Delete it to rebuild.")
33
+ return
34
+
35
+ retriever.build(mathlib_root=args.mathlib_root, max_files=args.max_files)
36
+ print("Done. Index is ready for use.")
37
+
38
+
39
+ if __name__ == "__main__":
40
+ main()
scripts/run_agent.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ import sys
3
+ import os
4
+ import argparse
5
+
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
7
+
8
+ from langgraph_agent import LangGraphAgent
9
+
10
+
11
+ def main():
12
+ parser = argparse.ArgumentParser(description="Run the LangGraph Lean Proof Agent on a file.")
13
+ parser.add_argument("file", help="Path to the .lean file to solve")
14
+ parser.add_argument("--model", default="qwen3-vl:4b", help="Ollama model name")
15
+ parser.add_argument("--retries", type=int, default=5, help="Max retries")
16
+ parser.add_argument("--index-dir", default=None, help="Path to pre-built FAISS index directory")
17
+
18
+ args = parser.parse_args()
19
+
20
+ print(f"Starting LangGraph Proof Agent with model: {args.model}")
21
+ agent = LangGraphAgent(
22
+ model_name=args.model,
23
+ max_retries=args.retries,
24
+ index_dir=args.index_dir,
25
+ )
26
+
27
+ success = agent.solve_file(args.file)
28
+
29
+ if success:
30
+ print("\nSuccess! The proof has been verified.")
31
+ else:
32
+ print("\nFailed to verify the proof.")
33
+
34
+
35
+ if __name__ == "__main__":
36
+ main()
src/langgraph_agent.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import List, TypedDict
3
+
4
+ from langgraph.graph import END, StateGraph
5
+
6
+ from lean_verifier import LeanEnvironment
7
+ from rag_chain import RAGProofChain
8
+ from retriever import MathLibRetriever
9
+
10
+
11
+ # ---------------------------------------------------------------------------
12
+ # State
13
+ # ---------------------------------------------------------------------------
14
+
15
+ class ProofState(TypedDict):
16
+ file_path: str
17
+ lean_code: str
18
+ goals: List[str]
19
+ errors: List[str]
20
+ attempt: int
21
+ max_retries: int
22
+ status: str # "pending" | "success" | "failed"
23
+ retrieved_lemmas: list
24
+
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Nodes
28
+ # ---------------------------------------------------------------------------
29
+
30
+ def _read_file(path: str) -> str:
31
+ with open(path, "r") as f:
32
+ return f.read()
33
+
34
+
35
+ def _write_file(path: str, code: str) -> None:
36
+ with open(path, "w") as f:
37
+ f.write(code)
38
+
39
+
40
+ def _extract_lean_code(text: str) -> str:
41
+ if "```lean" in text:
42
+ return text.split("```lean")[1].split("```")[0].strip()
43
+ if "```" in text:
44
+ return text.split("```")[1].split("```")[0].strip()
45
+ return text.strip()
46
+
47
+
48
+ def make_verify_node(lean_env: LeanEnvironment):
49
+ def verify_node(state: ProofState) -> ProofState:
50
+ print(f"\n--- Attempt {state['attempt'] + 1} / {state['max_retries']} ---")
51
+ code = _read_file(state["file_path"])
52
+ result = lean_env.verify_proof(code)
53
+
54
+ new_status = "success" if result["status"] == "success" else "pending"
55
+ if new_status == "success":
56
+ print("Proof verified successfully!")
57
+ else:
58
+ print(
59
+ f"Verification failed. "
60
+ f"Errors: {len(result['errors'])}, Goals: {len(result['goals'])}"
61
+ )
62
+
63
+ return {
64
+ **state,
65
+ "lean_code": code,
66
+ "errors": result["errors"],
67
+ "goals": result["goals"],
68
+ "status": new_status,
69
+ }
70
+ return verify_node
71
+
72
+
73
+ def make_retrieve_node(retriever: MathLibRetriever):
74
+ def retrieve_node(state: ProofState) -> ProofState:
75
+ query = " ".join(state["goals"] + state["errors"])
76
+ print("Retrieving relevant Mathlib lemmas…")
77
+ lemmas = retriever.retrieve(query)
78
+ print(f" Retrieved {len(lemmas)} lemma(s).")
79
+ return {**state, "retrieved_lemmas": lemmas}
80
+ return retrieve_node
81
+
82
+
83
+ def make_generate_node(chain: RAGProofChain):
84
+ def generate_node(state: ProofState) -> ProofState:
85
+ print("Generating proof with LLM…")
86
+ raw = chain.generate(
87
+ lean_code=state["lean_code"],
88
+ goals=state["goals"],
89
+ errors=state["errors"],
90
+ retrieved_lemmas=state["retrieved_lemmas"],
91
+ )
92
+ new_code = _extract_lean_code(raw)
93
+
94
+ if not new_code or new_code.strip() == state["lean_code"].strip():
95
+ print("LLM produced no changes.")
96
+ return {**state, "attempt": state["attempt"] + 1, "status": "failed"}
97
+
98
+ _write_file(state["file_path"], new_code)
99
+ print("File updated.")
100
+ return {
101
+ **state,
102
+ "lean_code": new_code,
103
+ "attempt": state["attempt"] + 1,
104
+ }
105
+ return generate_node
106
+
107
+
108
+ # ---------------------------------------------------------------------------
109
+ # Router
110
+ # ---------------------------------------------------------------------------
111
+
112
+ def should_continue(state: ProofState) -> str:
113
+ if state["status"] == "success":
114
+ return END
115
+ if state["attempt"] >= state["max_retries"]:
116
+ return END
117
+ return "retrieve"
118
+
119
+
120
+ # ---------------------------------------------------------------------------
121
+ # Graph assembly
122
+ # ---------------------------------------------------------------------------
123
+
124
+ def build_graph(lean_env: LeanEnvironment, retriever: MathLibRetriever, chain: RAGProofChain):
125
+ g = StateGraph(ProofState)
126
+
127
+ g.add_node("verify", make_verify_node(lean_env))
128
+ g.add_node("retrieve", make_retrieve_node(retriever))
129
+ g.add_node("generate", make_generate_node(chain))
130
+
131
+ g.set_entry_point("verify")
132
+ g.add_conditional_edges("verify", should_continue, {"retrieve": "retrieve", END: END})
133
+ g.add_edge("retrieve", "generate")
134
+ g.add_edge("generate", "verify")
135
+
136
+ return g.compile()
137
+
138
+
139
+ # ---------------------------------------------------------------------------
140
+ # Public entry point
141
+ # ---------------------------------------------------------------------------
142
+
143
+ class LangGraphAgent:
144
+ def __init__(
145
+ self,
146
+ model_name: str = "qwen3-vl:4b",
147
+ max_retries: int = 5,
148
+ index_dir: str | None = None,
149
+ ):
150
+ self._lean_env = LeanEnvironment(use_mathlib=True)
151
+ self._retriever = MathLibRetriever(index_dir=index_dir)
152
+ self._chain = RAGProofChain(model_name=model_name)
153
+ self._graph = build_graph(self._lean_env, self._retriever, self._chain)
154
+ self._max_retries = max_retries
155
+
156
+ def solve_file(self, file_path: str) -> bool:
157
+ if not os.path.exists(file_path):
158
+ print(f"Error: {file_path} not found.")
159
+ return False
160
+
161
+ initial: ProofState = {
162
+ "file_path": file_path,
163
+ "lean_code": "",
164
+ "goals": [],
165
+ "errors": [],
166
+ "attempt": 0,
167
+ "max_retries": self._max_retries,
168
+ "status": "pending",
169
+ "retrieved_lemmas": [],
170
+ }
171
+
172
+ final = self._graph.invoke(initial)
173
+ return final["status"] == "success"
src/lean_verifier.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Any, List
2
+
3
+ from lean_interact import LeanREPLConfig, LeanServer, Command, TempRequireProject, LeanRequire
4
+
5
+
6
+ class LeanEnvironment:
7
+ """
8
+ Manages the Lean REPL environment for verifying Lean 4 proofs.
9
+ """
10
+
11
+ def __init__(self, use_mathlib: bool = True, lean_version: str = "v4.8.0"):
12
+ """
13
+ Initializes the Lean environment.
14
+
15
+ Args:
16
+ use_mathlib (bool): If True, configures a TempRequireProject with Mathlib.
17
+ This may take a while to build on the first run.
18
+ lean_version (str): The Lean 4 version to use. Default is v4.8.0.
19
+ """
20
+ self.lean_version = lean_version
21
+ self.use_mathlib = use_mathlib
22
+
23
+ if self.use_mathlib:
24
+ # We use TempRequireProject with mathlib as specified in lean_interact documentation
25
+ project = TempRequireProject(
26
+ lean_version=self.lean_version,
27
+ require="mathlib"
28
+ )
29
+ self.config = LeanREPLConfig(project=project)
30
+ else:
31
+ self.config = LeanREPLConfig(lean_version=self.lean_version)
32
+
33
+ self.server = LeanServer(self.config)
34
+
35
+ def verify_proof(self, lean_code: str) -> Dict[str, Any]:
36
+ """
37
+ Executes a block of Lean code and verifies if it is a correct proof.
38
+
39
+ Args:
40
+ lean_code (str): The full Lean 4 code string containing imports, theorem statement, and proof.
41
+
42
+ Returns:
43
+ dict: A dictionary containing the status, errors (if any), and goals (if open sorries remain).
44
+ """
45
+ response = self.server.run(Command(cmd=lean_code))
46
+
47
+ errors = []
48
+ goals = []
49
+
50
+ # Check for error or warning messages
51
+ if hasattr(response, 'messages') and response.messages:
52
+ for msg in response.messages:
53
+ if msg.severity in ['error', 'warning']:
54
+ # E.g., 'declaration uses 'sorry'' is a warning, but we might want to capture it
55
+ errors.append(msg.data)
56
+
57
+ # Check for open goals (sorries)
58
+ if hasattr(response, 'sorries') and response.sorries:
59
+ for sorry in response.sorries:
60
+ if sorry.goal:
61
+ goals.append(sorry.goal)
62
+
63
+ is_success = len(errors) == 0 and len(goals) == 0
64
+
65
+ return {
66
+ "status": "success" if is_success else "failure",
67
+ "errors": errors,
68
+ "goals": goals,
69
+ "env": getattr(response, "env", None)
70
+ }
src/lmm_client.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ollama
2
+ from typing import List, Dict, Any, Optional
3
+
4
+ class LMMClient:
5
+ """
6
+ Client for interacting with local LMMs via Ollama.
7
+ Focuses on Qwen3-VL:4B for high-reasoning tasks.
8
+ """
9
+
10
+ def __init__(self, model_name: str = "qwen3-vl:4b"):
11
+ self.model_name = model_name
12
+
13
+ def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str:
14
+ """
15
+ Sends a chat request to the model.
16
+ """
17
+ messages = []
18
+ if system_prompt:
19
+ messages.append({'role': 'system', 'content': system_prompt})
20
+
21
+ messages.append({'role': 'user', 'content': prompt})
22
+
23
+ response = ollama.chat(
24
+ model=self.model_name,
25
+ messages=messages
26
+ )
27
+ return response['message']['content']
28
+
29
+ def generate_proof_steps(self, lean_code: str, goals: List[str], errors: List[str]) -> str:
30
+ """
31
+ Specific helper to generate proof steps based on current Lean state.
32
+ """
33
+ system_prompt = (
34
+ "You are an expert Lean 4 proof assistant. "
35
+ "Your goal is to complete the proof by replacing 'sorry' with valid Lean 4 code. "
36
+ "Use Mathlib theorems where appropriate. "
37
+ "Respond ONLY with the corrected Lean code block."
38
+ )
39
+
40
+ prompt = f"""
41
+ Current Lean Code:
42
+ ```lean
43
+ {lean_code}
44
+ ```
45
+
46
+ Current Proof Goals:
47
+ {chr(10).join(goals)}
48
+
49
+ Lean Errors:
50
+ {chr(10).join(errors)}
51
+
52
+ Please provide the corrected Lean code. Focus on solving the current goals and fixing the errors.
53
+ """
54
+ return self.chat(prompt, system_prompt)
src/mathlib_corpus.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import glob
4
+ from pathlib import Path
5
+ from typing import List, Optional
6
+
7
+ from langchain_core.documents import Document
8
+
9
+
10
+ # Regex to capture optional docstring + declaration line
11
+ _DECL_PATTERN = re.compile(
12
+ r'(?:/--\s*(.*?)\s*-/)?\s*' # optional /-- docstring -/
13
+ r'(?:@\[.*?\]\s*)*' # optional attributes
14
+ r'(theorem|lemma|def|noncomputable def)\s+'
15
+ r'(\S+)\s*' # declaration name
16
+ r'(.*?)\s*:=', # everything up to :=
17
+ re.DOTALL,
18
+ )
19
+
20
+
21
+ def _find_mathlib_root() -> Optional[str]:
22
+ """
23
+ Returns the path to the Mathlib4 source directory, searching common locations.
24
+ """
25
+ candidates = [
26
+ # Lake package cache (used by lean-interact's TempRequireProject)
27
+ os.path.expanduser("~/.elan/toolchains"),
28
+ os.path.expanduser("~/.cache/mathlib"),
29
+ # Nix / Homebrew Lean setups
30
+ "/usr/local/lib/lean",
31
+ "/opt/homebrew/lib/lean",
32
+ ]
33
+
34
+ # Also check for a .lake/packages directory next to the current file
35
+ here = Path(__file__).resolve().parent.parent
36
+ lake_pkg = here / ".lake" / "packages" / "mathlib" / "Mathlib"
37
+ if lake_pkg.exists():
38
+ return str(lake_pkg.parent)
39
+
40
+ for root in candidates:
41
+ if not os.path.isdir(root):
42
+ continue
43
+ # Walk up to 4 levels looking for a Mathlib directory
44
+ for dirpath, dirnames, _ in os.walk(root):
45
+ depth = dirpath.replace(root, "").count(os.sep)
46
+ if depth > 4:
47
+ dirnames.clear()
48
+ continue
49
+ if "Mathlib" in dirnames:
50
+ return dirpath
51
+ return None
52
+
53
+
54
+ def _parse_lean_file(path: str) -> List[Document]:
55
+ """
56
+ Extracts theorem/lemma/def declarations from a single .lean file.
57
+ Returns a list of Documents with page_content = "<name> : <signature>".
58
+ """
59
+ try:
60
+ text = Path(path).read_text(encoding="utf-8", errors="ignore")
61
+ except OSError:
62
+ return []
63
+
64
+ docs = []
65
+ for match in _DECL_PATTERN.finditer(text):
66
+ docstring = (match.group(1) or "").strip()
67
+ kind = match.group(2)
68
+ name = match.group(3)
69
+ signature = re.sub(r'\s+', ' ', match.group(4)).strip()
70
+
71
+ content = f"{name} : {signature}"
72
+ if docstring:
73
+ content = f"{docstring}\n{content}"
74
+
75
+ line = text[: match.start()].count("\n") + 1
76
+ docs.append(Document(
77
+ page_content=content,
78
+ metadata={"kind": kind, "name": name, "file": path, "line": line},
79
+ ))
80
+ return docs
81
+
82
+
83
+ class MathLibCorpus:
84
+ """
85
+ Extracts LangChain Documents from Mathlib4 source files on disk.
86
+ """
87
+
88
+ def __init__(self, mathlib_root: Optional[str] = None):
89
+ self.mathlib_root = mathlib_root or _find_mathlib_root()
90
+
91
+ def extract(self, max_files: Optional[int] = None) -> List[Document]:
92
+ """
93
+ Walks Mathlib source files and extracts declaration Documents.
94
+
95
+ Args:
96
+ max_files: If set, stop after processing this many .lean files
97
+ (useful for quick tests).
98
+
99
+ Returns:
100
+ List of LangChain Documents, one per declaration found.
101
+ """
102
+ if not self.mathlib_root:
103
+ raise RuntimeError(
104
+ "Could not locate Mathlib4 source. "
105
+ "Pass mathlib_root explicitly or run `lake exe cache get` first."
106
+ )
107
+
108
+ pattern = os.path.join(self.mathlib_root, "**", "*.lean")
109
+ files = glob.glob(pattern, recursive=True)
110
+ if max_files:
111
+ files = files[:max_files]
112
+
113
+ docs: List[Document] = []
114
+ for path in files:
115
+ docs.extend(_parse_lean_file(path))
116
+
117
+ return docs
src/proof_agent.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from langgraph_agent import LangGraphAgent
2
+
3
+
4
+ class ProofAgent:
5
+ """Thin compatibility wrapper around LangGraphAgent."""
6
+
7
+ def __init__(self, model_name: str = "qwen3-vl:4b", max_retries: int = 5):
8
+ self._agent = LangGraphAgent(model_name=model_name, max_retries=max_retries)
9
+
10
+ def solve_file(self, file_path: str) -> bool:
11
+ return self._agent.solve_file(file_path)
src/rag_chain.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+
3
+ from langchain_core.documents import Document
4
+ from langchain_core.output_parsers import StrOutputParser
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_ollama import OllamaLLM
7
+
8
+
9
+ _SYSTEM = (
10
+ "You are an expert Lean 4 proof assistant with deep knowledge of Mathlib. "
11
+ "Your task is to complete the proof by replacing every `sorry` with valid Lean 4 tactic code. "
12
+ "Use only Mathlib theorems and tactics. "
13
+ "Respond ONLY with the corrected Lean code inside a single ```lean ... ``` block."
14
+ )
15
+
16
+ _HUMAN = """\
17
+ ## Current Lean Code
18
+ ```lean
19
+ {lean_code}
20
+ ```
21
+
22
+ ## Open Proof Goals
23
+ {goals}
24
+
25
+ ## Lean Errors
26
+ {errors}
27
+
28
+ ## Relevant Mathlib Lemmas
29
+ {retrieved_lemmas}
30
+
31
+ Provide the corrected Lean code that solves all goals and fixes all errors.
32
+ """
33
+
34
+
35
+ def _format_docs(docs: List[Document]) -> str:
36
+ if not docs:
37
+ return "(none retrieved)"
38
+ return "\n".join(
39
+ f"- `{d.metadata.get('name', '?')}`: {d.page_content}" for d in docs
40
+ )
41
+
42
+
43
+ class RAGProofChain:
44
+ """
45
+ LangChain LCEL chain: retrieved context + proof state → corrected Lean code.
46
+ """
47
+
48
+ def __init__(self, model_name: str = "qwen3-vl:4b"):
49
+ prompt = ChatPromptTemplate.from_messages([
50
+ ("system", _SYSTEM),
51
+ ("human", _HUMAN),
52
+ ])
53
+ llm = OllamaLLM(model=model_name)
54
+ self._chain = prompt | llm | StrOutputParser()
55
+
56
+ def generate(
57
+ self,
58
+ lean_code: str,
59
+ goals: List[str],
60
+ errors: List[str],
61
+ retrieved_lemmas: List[Document],
62
+ ) -> str:
63
+ """
64
+ Generate corrected Lean code given the current proof state and retrieved lemmas.
65
+ """
66
+ return self._chain.invoke({
67
+ "lean_code": lean_code,
68
+ "goals": "\n".join(goals) or "(none)",
69
+ "errors": "\n".join(errors) or "(none)",
70
+ "retrieved_lemmas": _format_docs(retrieved_lemmas),
71
+ })
src/retriever.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from pathlib import Path
3
+ from typing import List, Optional
4
+
5
+ from langchain_community.retrievers import BM25Retriever
6
+ from langchain_community.vectorstores import FAISS
7
+ from langchain_community.cross_encoders import HuggingFaceCrossEncoder
8
+ from langchain.retrievers import EnsembleRetriever
9
+ from langchain.retrievers.document_compressors import CrossEncoderReranker
10
+ from langchain.retrievers import ContextualCompressionRetriever
11
+ from langchain_huggingface import HuggingFaceEmbeddings
12
+ from langchain_core.documents import Document
13
+
14
+ from mathlib_corpus import MathLibCorpus
15
+
16
+ _DEFAULT_INDEX_DIR = Path(__file__).resolve().parent.parent / "data" / "mathlib_index"
17
+ _EMBED_MODEL = "all-MiniLM-L6-v2"
18
+ _RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
19
+
20
+
21
+ class MathLibRetriever:
22
+ """
23
+ Hybrid FAISS + BM25 retriever with CrossEncoder reranking over Mathlib lemmas.
24
+
25
+ On first use, call build() to create and persist the index.
26
+ Subsequent runs load from disk automatically.
27
+ """
28
+
29
+ def __init__(
30
+ self,
31
+ index_dir: Optional[str] = None,
32
+ top_k: int = 20,
33
+ rerank_top_k: int = 5,
34
+ ):
35
+ self.index_dir = Path(index_dir) if index_dir else _DEFAULT_INDEX_DIR
36
+ self.top_k = top_k
37
+ self.rerank_top_k = rerank_top_k
38
+ self._retriever = None
39
+
40
+ # ------------------------------------------------------------------
41
+ # Public API
42
+ # ------------------------------------------------------------------
43
+
44
+ def build(self, mathlib_root: Optional[str] = None, max_files: Optional[int] = None) -> None:
45
+ """
46
+ Extract Mathlib documents, build FAISS + BM25 indices, and persist to disk.
47
+ Call this once (via scripts/build_index.py) before first use.
48
+ """
49
+ print("Extracting Mathlib corpus…")
50
+ corpus = MathLibCorpus(mathlib_root=mathlib_root)
51
+ docs = corpus.extract(max_files=max_files)
52
+ print(f" {len(docs)} declarations extracted.")
53
+
54
+ embeddings = self._embeddings()
55
+
56
+ print("Building FAISS index…")
57
+ faiss_store = FAISS.from_documents(docs, embeddings)
58
+ self.index_dir.mkdir(parents=True, exist_ok=True)
59
+ faiss_store.save_local(str(self.index_dir))
60
+ print(f" Index saved to {self.index_dir}")
61
+
62
+ self._retriever = self._build_retriever(faiss_store, docs)
63
+
64
+ def retrieve(self, query: str, k: Optional[int] = None) -> List[Document]:
65
+ """
66
+ Retrieve and rerank the most relevant Mathlib lemmas for a query.
67
+
68
+ Args:
69
+ query: Natural-language or Lean-syntax query (e.g., proof goals + errors).
70
+ k: Number of results to return after reranking (defaults to self.rerank_top_k).
71
+
72
+ Returns:
73
+ List of Documents ranked by relevance.
74
+ """
75
+ if self._retriever is None:
76
+ self._load()
77
+ results = self._retriever.invoke(query)
78
+ return results[: k or self.rerank_top_k]
79
+
80
+ def is_index_built(self) -> bool:
81
+ return (self.index_dir / "index.faiss").exists()
82
+
83
+ # ------------------------------------------------------------------
84
+ # Internal helpers
85
+ # ------------------------------------------------------------------
86
+
87
+ def _embeddings(self) -> HuggingFaceEmbeddings:
88
+ return HuggingFaceEmbeddings(model_name=_EMBED_MODEL)
89
+
90
+ def _load(self) -> None:
91
+ if not self.is_index_built():
92
+ raise RuntimeError(
93
+ f"No FAISS index found at {self.index_dir}. "
94
+ "Run `python scripts/build_index.py` first."
95
+ )
96
+ print("Loading FAISS index from disk…")
97
+ embeddings = self._embeddings()
98
+ faiss_store = FAISS.load_local(
99
+ str(self.index_dir),
100
+ embeddings,
101
+ allow_dangerous_deserialization=True,
102
+ )
103
+
104
+ # Re-build BM25 from FAISS docstore
105
+ docs = list(faiss_store.docstore._dict.values())
106
+ self._retriever = self._build_retriever(faiss_store, docs)
107
+
108
+ def _build_retriever(self, faiss_store: FAISS, docs: List[Document]):
109
+ faiss_retriever = faiss_store.as_retriever(
110
+ search_kwargs={"k": self.top_k}
111
+ )
112
+ bm25_retriever = BM25Retriever.from_documents(docs)
113
+ bm25_retriever.k = self.top_k
114
+
115
+ ensemble = EnsembleRetriever(
116
+ retrievers=[faiss_retriever, bm25_retriever],
117
+ weights=[0.6, 0.4],
118
+ )
119
+
120
+ cross_encoder = HuggingFaceCrossEncoder(model_name=_RERANK_MODEL)
121
+ reranker = CrossEncoderReranker(model=cross_encoder, top_n=self.rerank_top_k)
122
+
123
+ return ContextualCompressionRetriever(
124
+ base_compressor=reranker,
125
+ base_retriever=ensemble,
126
+ )
tests/test_lean_verifier.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import sys
3
+ import os
4
+
5
+ # Add src to Python path
6
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
7
+
8
+ from lean_verifier import LeanEnvironment
9
+
10
+ class TestLeanVerifier(unittest.TestCase):
11
+ @classmethod
12
+ def setUpClass(cls):
13
+ # We'll use Mathlib here since our goal is to verify it works with the MVP setup
14
+ cls.lean_env = LeanEnvironment(use_mathlib=True)
15
+
16
+ def test_correct_proof(self):
17
+ lean_code = """
18
+ import Mathlib
19
+
20
+ theorem add_comm_test (n m : Nat) : n + m = m + n := by
21
+ exact Nat.add_comm n m
22
+ """
23
+ result = self.lean_env.verify_proof(lean_code)
24
+ self.assertEqual(result["status"], "success")
25
+ self.assertEqual(len(result["errors"]), 0)
26
+ self.assertEqual(len(result["goals"]), 0)
27
+
28
+ def test_incorrect_proof_type_mismatch(self):
29
+ lean_code = """
30
+ import Mathlib
31
+
32
+ theorem add_comm_test (n m : Nat) : n + m = m + n := by
33
+ exact n
34
+ """
35
+ result = self.lean_env.verify_proof(lean_code)
36
+ self.assertEqual(result["status"], "failure")
37
+ self.assertTrue(any("type mismatch" in err or "application type mismatch" in err for err in result["errors"]),
38
+ f"Expected type mismatch error, got: {result['errors']}")
39
+
40
+ def test_incomplete_proof_sorry(self):
41
+ lean_code = """
42
+ import Mathlib
43
+
44
+ theorem my_incomplete_thm (n : Nat) : n = 5 → n = 5 := by
45
+ sorry
46
+ """
47
+ result = self.lean_env.verify_proof(lean_code)
48
+ self.assertEqual(result["status"], "failure")
49
+ # Ensure it has an error indicating sorry
50
+ self.assertTrue(any("uses 'sorry'" in err for err in result["errors"]),
51
+ f"Expected sorry warning/error, got: {result['errors']}")
52
+ # Ensure it outputs the goal
53
+ self.assertEqual(len(result["goals"]), 1)
54
+ self.assertIn("⊢ n = 5 → n = 5", result["goals"][0])
55
+
56
+ if __name__ == '__main__':
57
+ unittest.main()