Spaces:
Running
Add LangChain/LangGraph RAG pipeline for retrieval-augmented proof generation
Browse files- src/mathlib_corpus.py: walks Mathlib4 .lean source files and extracts
theorem/lemma/def declarations as LangChain Documents
- src/retriever.py: MathLibRetriever — FAISS + BM25 hybrid retrieval via
EnsembleRetriever (60/40), reranked with CrossEncoder; index persisted to
data/mathlib_index/ for reuse across runs
- src/rag_chain.py: RAGProofChain — LangChain LCEL chain
(ChatPromptTemplate | OllamaLLM | StrOutputParser) that injects retrieved
lemma context into each proof-generation call
- src/langgraph_agent.py: LangGraphAgent — replaces the plain Python retry
loop with a proper state machine (verify → retrieve → generate → verify);
exposes the same solve_file() API
- src/proof_agent.py: thin backward-compatible wrapper around LangGraphAgent
- scripts/build_index.py: one-time offline script to build and save the FAISS
index from Mathlib4 source
- scripts/run_agent.py: updated to use LangGraphAgent directly
- requirements.txt: add faiss-cpu, rank-bm25, sentence-transformers,
langgraph, langchain-ollama
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- problems/simple_add.lean +4 -0
- problems/test_problem.lean +4 -0
- requirements.txt +9 -0
- scripts/build_index.py +40 -0
- scripts/run_agent.py +36 -0
- src/langgraph_agent.py +173 -0
- src/lean_verifier.py +70 -0
- src/lmm_client.py +54 -0
- src/mathlib_corpus.py +117 -0
- src/proof_agent.py +11 -0
- src/rag_chain.py +71 -0
- src/retriever.py +126 -0
- tests/test_lean_verifier.py +57 -0
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import Mathlib
|
| 2 |
+
|
| 3 |
+
theorem add_zero_simple (n : ℕ) : n + 0 = n := by
|
| 4 |
+
sorry
|
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import Mathlib
|
| 2 |
+
|
| 3 |
+
theorem square_root_two_irrational (n m : ℕ) (h : n^2 = 2 * m^2) (hm : m ≠ 0) : False := by
|
| 4 |
+
sorry
|
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
lean-interact
|
| 2 |
+
ollama
|
| 3 |
+
langchain
|
| 4 |
+
langchain-community
|
| 5 |
+
langchain-ollama
|
| 6 |
+
langgraph
|
| 7 |
+
faiss-cpu
|
| 8 |
+
rank-bm25
|
| 9 |
+
sentence-transformers
|
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
One-time script to build the FAISS + BM25 index from Mathlib4 source files.
|
| 4 |
+
Run this before using the LangGraph agent for the first time:
|
| 5 |
+
|
| 6 |
+
python scripts/build_index.py
|
| 7 |
+
|
| 8 |
+
Optional flags:
|
| 9 |
+
--mathlib-root PATH Path to Mathlib4 source (auto-detected if omitted)
|
| 10 |
+
--max-files N Limit to first N .lean files (useful for quick testing)
|
| 11 |
+
--index-dir PATH Where to save the index (default: data/mathlib_index)
|
| 12 |
+
"""
|
| 13 |
+
import sys
|
| 14 |
+
import os
|
| 15 |
+
import argparse
|
| 16 |
+
|
| 17 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
|
| 18 |
+
|
| 19 |
+
from retriever import MathLibRetriever
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def main():
|
| 23 |
+
parser = argparse.ArgumentParser(description="Build Mathlib FAISS index.")
|
| 24 |
+
parser.add_argument("--mathlib-root", default=None, help="Path to Mathlib4 source root")
|
| 25 |
+
parser.add_argument("--max-files", type=int, default=None, help="Limit number of .lean files processed")
|
| 26 |
+
parser.add_argument("--index-dir", default=None, help="Directory to save the index")
|
| 27 |
+
args = parser.parse_args()
|
| 28 |
+
|
| 29 |
+
retriever = MathLibRetriever(index_dir=args.index_dir)
|
| 30 |
+
|
| 31 |
+
if retriever.is_index_built() and not args.max_files:
|
| 32 |
+
print(f"Index already exists at {retriever.index_dir}. Delete it to rebuild.")
|
| 33 |
+
return
|
| 34 |
+
|
| 35 |
+
retriever.build(mathlib_root=args.mathlib_root, max_files=args.max_files)
|
| 36 |
+
print("Done. Index is ready for use.")
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
if __name__ == "__main__":
|
| 40 |
+
main()
|
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
import argparse
|
| 5 |
+
|
| 6 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
|
| 7 |
+
|
| 8 |
+
from langgraph_agent import LangGraphAgent
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def main():
|
| 12 |
+
parser = argparse.ArgumentParser(description="Run the LangGraph Lean Proof Agent on a file.")
|
| 13 |
+
parser.add_argument("file", help="Path to the .lean file to solve")
|
| 14 |
+
parser.add_argument("--model", default="qwen3-vl:4b", help="Ollama model name")
|
| 15 |
+
parser.add_argument("--retries", type=int, default=5, help="Max retries")
|
| 16 |
+
parser.add_argument("--index-dir", default=None, help="Path to pre-built FAISS index directory")
|
| 17 |
+
|
| 18 |
+
args = parser.parse_args()
|
| 19 |
+
|
| 20 |
+
print(f"Starting LangGraph Proof Agent with model: {args.model}")
|
| 21 |
+
agent = LangGraphAgent(
|
| 22 |
+
model_name=args.model,
|
| 23 |
+
max_retries=args.retries,
|
| 24 |
+
index_dir=args.index_dir,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
success = agent.solve_file(args.file)
|
| 28 |
+
|
| 29 |
+
if success:
|
| 30 |
+
print("\nSuccess! The proof has been verified.")
|
| 31 |
+
else:
|
| 32 |
+
print("\nFailed to verify the proof.")
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
main()
|
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List, TypedDict
|
| 3 |
+
|
| 4 |
+
from langgraph.graph import END, StateGraph
|
| 5 |
+
|
| 6 |
+
from lean_verifier import LeanEnvironment
|
| 7 |
+
from rag_chain import RAGProofChain
|
| 8 |
+
from retriever import MathLibRetriever
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# ---------------------------------------------------------------------------
|
| 12 |
+
# State
|
| 13 |
+
# ---------------------------------------------------------------------------
|
| 14 |
+
|
| 15 |
+
class ProofState(TypedDict):
|
| 16 |
+
file_path: str
|
| 17 |
+
lean_code: str
|
| 18 |
+
goals: List[str]
|
| 19 |
+
errors: List[str]
|
| 20 |
+
attempt: int
|
| 21 |
+
max_retries: int
|
| 22 |
+
status: str # "pending" | "success" | "failed"
|
| 23 |
+
retrieved_lemmas: list
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Nodes
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
def _read_file(path: str) -> str:
|
| 31 |
+
with open(path, "r") as f:
|
| 32 |
+
return f.read()
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _write_file(path: str, code: str) -> None:
|
| 36 |
+
with open(path, "w") as f:
|
| 37 |
+
f.write(code)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def _extract_lean_code(text: str) -> str:
|
| 41 |
+
if "```lean" in text:
|
| 42 |
+
return text.split("```lean")[1].split("```")[0].strip()
|
| 43 |
+
if "```" in text:
|
| 44 |
+
return text.split("```")[1].split("```")[0].strip()
|
| 45 |
+
return text.strip()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def make_verify_node(lean_env: LeanEnvironment):
|
| 49 |
+
def verify_node(state: ProofState) -> ProofState:
|
| 50 |
+
print(f"\n--- Attempt {state['attempt'] + 1} / {state['max_retries']} ---")
|
| 51 |
+
code = _read_file(state["file_path"])
|
| 52 |
+
result = lean_env.verify_proof(code)
|
| 53 |
+
|
| 54 |
+
new_status = "success" if result["status"] == "success" else "pending"
|
| 55 |
+
if new_status == "success":
|
| 56 |
+
print("Proof verified successfully!")
|
| 57 |
+
else:
|
| 58 |
+
print(
|
| 59 |
+
f"Verification failed. "
|
| 60 |
+
f"Errors: {len(result['errors'])}, Goals: {len(result['goals'])}"
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
return {
|
| 64 |
+
**state,
|
| 65 |
+
"lean_code": code,
|
| 66 |
+
"errors": result["errors"],
|
| 67 |
+
"goals": result["goals"],
|
| 68 |
+
"status": new_status,
|
| 69 |
+
}
|
| 70 |
+
return verify_node
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def make_retrieve_node(retriever: MathLibRetriever):
|
| 74 |
+
def retrieve_node(state: ProofState) -> ProofState:
|
| 75 |
+
query = " ".join(state["goals"] + state["errors"])
|
| 76 |
+
print("Retrieving relevant Mathlib lemmas…")
|
| 77 |
+
lemmas = retriever.retrieve(query)
|
| 78 |
+
print(f" Retrieved {len(lemmas)} lemma(s).")
|
| 79 |
+
return {**state, "retrieved_lemmas": lemmas}
|
| 80 |
+
return retrieve_node
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def make_generate_node(chain: RAGProofChain):
|
| 84 |
+
def generate_node(state: ProofState) -> ProofState:
|
| 85 |
+
print("Generating proof with LLM…")
|
| 86 |
+
raw = chain.generate(
|
| 87 |
+
lean_code=state["lean_code"],
|
| 88 |
+
goals=state["goals"],
|
| 89 |
+
errors=state["errors"],
|
| 90 |
+
retrieved_lemmas=state["retrieved_lemmas"],
|
| 91 |
+
)
|
| 92 |
+
new_code = _extract_lean_code(raw)
|
| 93 |
+
|
| 94 |
+
if not new_code or new_code.strip() == state["lean_code"].strip():
|
| 95 |
+
print("LLM produced no changes.")
|
| 96 |
+
return {**state, "attempt": state["attempt"] + 1, "status": "failed"}
|
| 97 |
+
|
| 98 |
+
_write_file(state["file_path"], new_code)
|
| 99 |
+
print("File updated.")
|
| 100 |
+
return {
|
| 101 |
+
**state,
|
| 102 |
+
"lean_code": new_code,
|
| 103 |
+
"attempt": state["attempt"] + 1,
|
| 104 |
+
}
|
| 105 |
+
return generate_node
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# ---------------------------------------------------------------------------
|
| 109 |
+
# Router
|
| 110 |
+
# ---------------------------------------------------------------------------
|
| 111 |
+
|
| 112 |
+
def should_continue(state: ProofState) -> str:
|
| 113 |
+
if state["status"] == "success":
|
| 114 |
+
return END
|
| 115 |
+
if state["attempt"] >= state["max_retries"]:
|
| 116 |
+
return END
|
| 117 |
+
return "retrieve"
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ---------------------------------------------------------------------------
|
| 121 |
+
# Graph assembly
|
| 122 |
+
# ---------------------------------------------------------------------------
|
| 123 |
+
|
| 124 |
+
def build_graph(lean_env: LeanEnvironment, retriever: MathLibRetriever, chain: RAGProofChain):
|
| 125 |
+
g = StateGraph(ProofState)
|
| 126 |
+
|
| 127 |
+
g.add_node("verify", make_verify_node(lean_env))
|
| 128 |
+
g.add_node("retrieve", make_retrieve_node(retriever))
|
| 129 |
+
g.add_node("generate", make_generate_node(chain))
|
| 130 |
+
|
| 131 |
+
g.set_entry_point("verify")
|
| 132 |
+
g.add_conditional_edges("verify", should_continue, {"retrieve": "retrieve", END: END})
|
| 133 |
+
g.add_edge("retrieve", "generate")
|
| 134 |
+
g.add_edge("generate", "verify")
|
| 135 |
+
|
| 136 |
+
return g.compile()
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
# ---------------------------------------------------------------------------
|
| 140 |
+
# Public entry point
|
| 141 |
+
# ---------------------------------------------------------------------------
|
| 142 |
+
|
| 143 |
+
class LangGraphAgent:
|
| 144 |
+
def __init__(
|
| 145 |
+
self,
|
| 146 |
+
model_name: str = "qwen3-vl:4b",
|
| 147 |
+
max_retries: int = 5,
|
| 148 |
+
index_dir: str | None = None,
|
| 149 |
+
):
|
| 150 |
+
self._lean_env = LeanEnvironment(use_mathlib=True)
|
| 151 |
+
self._retriever = MathLibRetriever(index_dir=index_dir)
|
| 152 |
+
self._chain = RAGProofChain(model_name=model_name)
|
| 153 |
+
self._graph = build_graph(self._lean_env, self._retriever, self._chain)
|
| 154 |
+
self._max_retries = max_retries
|
| 155 |
+
|
| 156 |
+
def solve_file(self, file_path: str) -> bool:
|
| 157 |
+
if not os.path.exists(file_path):
|
| 158 |
+
print(f"Error: {file_path} not found.")
|
| 159 |
+
return False
|
| 160 |
+
|
| 161 |
+
initial: ProofState = {
|
| 162 |
+
"file_path": file_path,
|
| 163 |
+
"lean_code": "",
|
| 164 |
+
"goals": [],
|
| 165 |
+
"errors": [],
|
| 166 |
+
"attempt": 0,
|
| 167 |
+
"max_retries": self._max_retries,
|
| 168 |
+
"status": "pending",
|
| 169 |
+
"retrieved_lemmas": [],
|
| 170 |
+
}
|
| 171 |
+
|
| 172 |
+
final = self._graph.invoke(initial)
|
| 173 |
+
return final["status"] == "success"
|
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Dict, Any, List
|
| 2 |
+
|
| 3 |
+
from lean_interact import LeanREPLConfig, LeanServer, Command, TempRequireProject, LeanRequire
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LeanEnvironment:
|
| 7 |
+
"""
|
| 8 |
+
Manages the Lean REPL environment for verifying Lean 4 proofs.
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
def __init__(self, use_mathlib: bool = True, lean_version: str = "v4.8.0"):
|
| 12 |
+
"""
|
| 13 |
+
Initializes the Lean environment.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
use_mathlib (bool): If True, configures a TempRequireProject with Mathlib.
|
| 17 |
+
This may take a while to build on the first run.
|
| 18 |
+
lean_version (str): The Lean 4 version to use. Default is v4.8.0.
|
| 19 |
+
"""
|
| 20 |
+
self.lean_version = lean_version
|
| 21 |
+
self.use_mathlib = use_mathlib
|
| 22 |
+
|
| 23 |
+
if self.use_mathlib:
|
| 24 |
+
# We use TempRequireProject with mathlib as specified in lean_interact documentation
|
| 25 |
+
project = TempRequireProject(
|
| 26 |
+
lean_version=self.lean_version,
|
| 27 |
+
require="mathlib"
|
| 28 |
+
)
|
| 29 |
+
self.config = LeanREPLConfig(project=project)
|
| 30 |
+
else:
|
| 31 |
+
self.config = LeanREPLConfig(lean_version=self.lean_version)
|
| 32 |
+
|
| 33 |
+
self.server = LeanServer(self.config)
|
| 34 |
+
|
| 35 |
+
def verify_proof(self, lean_code: str) -> Dict[str, Any]:
|
| 36 |
+
"""
|
| 37 |
+
Executes a block of Lean code and verifies if it is a correct proof.
|
| 38 |
+
|
| 39 |
+
Args:
|
| 40 |
+
lean_code (str): The full Lean 4 code string containing imports, theorem statement, and proof.
|
| 41 |
+
|
| 42 |
+
Returns:
|
| 43 |
+
dict: A dictionary containing the status, errors (if any), and goals (if open sorries remain).
|
| 44 |
+
"""
|
| 45 |
+
response = self.server.run(Command(cmd=lean_code))
|
| 46 |
+
|
| 47 |
+
errors = []
|
| 48 |
+
goals = []
|
| 49 |
+
|
| 50 |
+
# Check for error or warning messages
|
| 51 |
+
if hasattr(response, 'messages') and response.messages:
|
| 52 |
+
for msg in response.messages:
|
| 53 |
+
if msg.severity in ['error', 'warning']:
|
| 54 |
+
# E.g., 'declaration uses 'sorry'' is a warning, but we might want to capture it
|
| 55 |
+
errors.append(msg.data)
|
| 56 |
+
|
| 57 |
+
# Check for open goals (sorries)
|
| 58 |
+
if hasattr(response, 'sorries') and response.sorries:
|
| 59 |
+
for sorry in response.sorries:
|
| 60 |
+
if sorry.goal:
|
| 61 |
+
goals.append(sorry.goal)
|
| 62 |
+
|
| 63 |
+
is_success = len(errors) == 0 and len(goals) == 0
|
| 64 |
+
|
| 65 |
+
return {
|
| 66 |
+
"status": "success" if is_success else "failure",
|
| 67 |
+
"errors": errors,
|
| 68 |
+
"goals": goals,
|
| 69 |
+
"env": getattr(response, "env", None)
|
| 70 |
+
}
|
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import ollama
|
| 2 |
+
from typing import List, Dict, Any, Optional
|
| 3 |
+
|
| 4 |
+
class LMMClient:
|
| 5 |
+
"""
|
| 6 |
+
Client for interacting with local LMMs via Ollama.
|
| 7 |
+
Focuses on Qwen3-VL:4B for high-reasoning tasks.
|
| 8 |
+
"""
|
| 9 |
+
|
| 10 |
+
def __init__(self, model_name: str = "qwen3-vl:4b"):
|
| 11 |
+
self.model_name = model_name
|
| 12 |
+
|
| 13 |
+
def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str:
|
| 14 |
+
"""
|
| 15 |
+
Sends a chat request to the model.
|
| 16 |
+
"""
|
| 17 |
+
messages = []
|
| 18 |
+
if system_prompt:
|
| 19 |
+
messages.append({'role': 'system', 'content': system_prompt})
|
| 20 |
+
|
| 21 |
+
messages.append({'role': 'user', 'content': prompt})
|
| 22 |
+
|
| 23 |
+
response = ollama.chat(
|
| 24 |
+
model=self.model_name,
|
| 25 |
+
messages=messages
|
| 26 |
+
)
|
| 27 |
+
return response['message']['content']
|
| 28 |
+
|
| 29 |
+
def generate_proof_steps(self, lean_code: str, goals: List[str], errors: List[str]) -> str:
|
| 30 |
+
"""
|
| 31 |
+
Specific helper to generate proof steps based on current Lean state.
|
| 32 |
+
"""
|
| 33 |
+
system_prompt = (
|
| 34 |
+
"You are an expert Lean 4 proof assistant. "
|
| 35 |
+
"Your goal is to complete the proof by replacing 'sorry' with valid Lean 4 code. "
|
| 36 |
+
"Use Mathlib theorems where appropriate. "
|
| 37 |
+
"Respond ONLY with the corrected Lean code block."
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
prompt = f"""
|
| 41 |
+
Current Lean Code:
|
| 42 |
+
```lean
|
| 43 |
+
{lean_code}
|
| 44 |
+
```
|
| 45 |
+
|
| 46 |
+
Current Proof Goals:
|
| 47 |
+
{chr(10).join(goals)}
|
| 48 |
+
|
| 49 |
+
Lean Errors:
|
| 50 |
+
{chr(10).join(errors)}
|
| 51 |
+
|
| 52 |
+
Please provide the corrected Lean code. Focus on solving the current goals and fixing the errors.
|
| 53 |
+
"""
|
| 54 |
+
return self.chat(prompt, system_prompt)
|
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import glob
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import List, Optional
|
| 6 |
+
|
| 7 |
+
from langchain_core.documents import Document
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
# Regex to capture optional docstring + declaration line
|
| 11 |
+
_DECL_PATTERN = re.compile(
|
| 12 |
+
r'(?:/--\s*(.*?)\s*-/)?\s*' # optional /-- docstring -/
|
| 13 |
+
r'(?:@\[.*?\]\s*)*' # optional attributes
|
| 14 |
+
r'(theorem|lemma|def|noncomputable def)\s+'
|
| 15 |
+
r'(\S+)\s*' # declaration name
|
| 16 |
+
r'(.*?)\s*:=', # everything up to :=
|
| 17 |
+
re.DOTALL,
|
| 18 |
+
)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _find_mathlib_root() -> Optional[str]:
|
| 22 |
+
"""
|
| 23 |
+
Returns the path to the Mathlib4 source directory, searching common locations.
|
| 24 |
+
"""
|
| 25 |
+
candidates = [
|
| 26 |
+
# Lake package cache (used by lean-interact's TempRequireProject)
|
| 27 |
+
os.path.expanduser("~/.elan/toolchains"),
|
| 28 |
+
os.path.expanduser("~/.cache/mathlib"),
|
| 29 |
+
# Nix / Homebrew Lean setups
|
| 30 |
+
"/usr/local/lib/lean",
|
| 31 |
+
"/opt/homebrew/lib/lean",
|
| 32 |
+
]
|
| 33 |
+
|
| 34 |
+
# Also check for a .lake/packages directory next to the current file
|
| 35 |
+
here = Path(__file__).resolve().parent.parent
|
| 36 |
+
lake_pkg = here / ".lake" / "packages" / "mathlib" / "Mathlib"
|
| 37 |
+
if lake_pkg.exists():
|
| 38 |
+
return str(lake_pkg.parent)
|
| 39 |
+
|
| 40 |
+
for root in candidates:
|
| 41 |
+
if not os.path.isdir(root):
|
| 42 |
+
continue
|
| 43 |
+
# Walk up to 4 levels looking for a Mathlib directory
|
| 44 |
+
for dirpath, dirnames, _ in os.walk(root):
|
| 45 |
+
depth = dirpath.replace(root, "").count(os.sep)
|
| 46 |
+
if depth > 4:
|
| 47 |
+
dirnames.clear()
|
| 48 |
+
continue
|
| 49 |
+
if "Mathlib" in dirnames:
|
| 50 |
+
return dirpath
|
| 51 |
+
return None
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def _parse_lean_file(path: str) -> List[Document]:
|
| 55 |
+
"""
|
| 56 |
+
Extracts theorem/lemma/def declarations from a single .lean file.
|
| 57 |
+
Returns a list of Documents with page_content = "<name> : <signature>".
|
| 58 |
+
"""
|
| 59 |
+
try:
|
| 60 |
+
text = Path(path).read_text(encoding="utf-8", errors="ignore")
|
| 61 |
+
except OSError:
|
| 62 |
+
return []
|
| 63 |
+
|
| 64 |
+
docs = []
|
| 65 |
+
for match in _DECL_PATTERN.finditer(text):
|
| 66 |
+
docstring = (match.group(1) or "").strip()
|
| 67 |
+
kind = match.group(2)
|
| 68 |
+
name = match.group(3)
|
| 69 |
+
signature = re.sub(r'\s+', ' ', match.group(4)).strip()
|
| 70 |
+
|
| 71 |
+
content = f"{name} : {signature}"
|
| 72 |
+
if docstring:
|
| 73 |
+
content = f"{docstring}\n{content}"
|
| 74 |
+
|
| 75 |
+
line = text[: match.start()].count("\n") + 1
|
| 76 |
+
docs.append(Document(
|
| 77 |
+
page_content=content,
|
| 78 |
+
metadata={"kind": kind, "name": name, "file": path, "line": line},
|
| 79 |
+
))
|
| 80 |
+
return docs
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class MathLibCorpus:
|
| 84 |
+
"""
|
| 85 |
+
Extracts LangChain Documents from Mathlib4 source files on disk.
|
| 86 |
+
"""
|
| 87 |
+
|
| 88 |
+
def __init__(self, mathlib_root: Optional[str] = None):
|
| 89 |
+
self.mathlib_root = mathlib_root or _find_mathlib_root()
|
| 90 |
+
|
| 91 |
+
def extract(self, max_files: Optional[int] = None) -> List[Document]:
|
| 92 |
+
"""
|
| 93 |
+
Walks Mathlib source files and extracts declaration Documents.
|
| 94 |
+
|
| 95 |
+
Args:
|
| 96 |
+
max_files: If set, stop after processing this many .lean files
|
| 97 |
+
(useful for quick tests).
|
| 98 |
+
|
| 99 |
+
Returns:
|
| 100 |
+
List of LangChain Documents, one per declaration found.
|
| 101 |
+
"""
|
| 102 |
+
if not self.mathlib_root:
|
| 103 |
+
raise RuntimeError(
|
| 104 |
+
"Could not locate Mathlib4 source. "
|
| 105 |
+
"Pass mathlib_root explicitly or run `lake exe cache get` first."
|
| 106 |
+
)
|
| 107 |
+
|
| 108 |
+
pattern = os.path.join(self.mathlib_root, "**", "*.lean")
|
| 109 |
+
files = glob.glob(pattern, recursive=True)
|
| 110 |
+
if max_files:
|
| 111 |
+
files = files[:max_files]
|
| 112 |
+
|
| 113 |
+
docs: List[Document] = []
|
| 114 |
+
for path in files:
|
| 115 |
+
docs.extend(_parse_lean_file(path))
|
| 116 |
+
|
| 117 |
+
return docs
|
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from langgraph_agent import LangGraphAgent
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class ProofAgent:
|
| 5 |
+
"""Thin compatibility wrapper around LangGraphAgent."""
|
| 6 |
+
|
| 7 |
+
def __init__(self, model_name: str = "qwen3-vl:4b", max_retries: int = 5):
|
| 8 |
+
self._agent = LangGraphAgent(model_name=model_name, max_retries=max_retries)
|
| 9 |
+
|
| 10 |
+
def solve_file(self, file_path: str) -> bool:
|
| 11 |
+
return self._agent.solve_file(file_path)
|
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from langchain_core.documents import Document
|
| 4 |
+
from langchain_core.output_parsers import StrOutputParser
|
| 5 |
+
from langchain_core.prompts import ChatPromptTemplate
|
| 6 |
+
from langchain_ollama import OllamaLLM
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
_SYSTEM = (
|
| 10 |
+
"You are an expert Lean 4 proof assistant with deep knowledge of Mathlib. "
|
| 11 |
+
"Your task is to complete the proof by replacing every `sorry` with valid Lean 4 tactic code. "
|
| 12 |
+
"Use only Mathlib theorems and tactics. "
|
| 13 |
+
"Respond ONLY with the corrected Lean code inside a single ```lean ... ``` block."
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
_HUMAN = """\
|
| 17 |
+
## Current Lean Code
|
| 18 |
+
```lean
|
| 19 |
+
{lean_code}
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
## Open Proof Goals
|
| 23 |
+
{goals}
|
| 24 |
+
|
| 25 |
+
## Lean Errors
|
| 26 |
+
{errors}
|
| 27 |
+
|
| 28 |
+
## Relevant Mathlib Lemmas
|
| 29 |
+
{retrieved_lemmas}
|
| 30 |
+
|
| 31 |
+
Provide the corrected Lean code that solves all goals and fixes all errors.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _format_docs(docs: List[Document]) -> str:
|
| 36 |
+
if not docs:
|
| 37 |
+
return "(none retrieved)"
|
| 38 |
+
return "\n".join(
|
| 39 |
+
f"- `{d.metadata.get('name', '?')}`: {d.page_content}" for d in docs
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class RAGProofChain:
|
| 44 |
+
"""
|
| 45 |
+
LangChain LCEL chain: retrieved context + proof state → corrected Lean code.
|
| 46 |
+
"""
|
| 47 |
+
|
| 48 |
+
def __init__(self, model_name: str = "qwen3-vl:4b"):
|
| 49 |
+
prompt = ChatPromptTemplate.from_messages([
|
| 50 |
+
("system", _SYSTEM),
|
| 51 |
+
("human", _HUMAN),
|
| 52 |
+
])
|
| 53 |
+
llm = OllamaLLM(model=model_name)
|
| 54 |
+
self._chain = prompt | llm | StrOutputParser()
|
| 55 |
+
|
| 56 |
+
def generate(
|
| 57 |
+
self,
|
| 58 |
+
lean_code: str,
|
| 59 |
+
goals: List[str],
|
| 60 |
+
errors: List[str],
|
| 61 |
+
retrieved_lemmas: List[Document],
|
| 62 |
+
) -> str:
|
| 63 |
+
"""
|
| 64 |
+
Generate corrected Lean code given the current proof state and retrieved lemmas.
|
| 65 |
+
"""
|
| 66 |
+
return self._chain.invoke({
|
| 67 |
+
"lean_code": lean_code,
|
| 68 |
+
"goals": "\n".join(goals) or "(none)",
|
| 69 |
+
"errors": "\n".join(errors) or "(none)",
|
| 70 |
+
"retrieved_lemmas": _format_docs(retrieved_lemmas),
|
| 71 |
+
})
|
|
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
|
| 5 |
+
from langchain_community.retrievers import BM25Retriever
|
| 6 |
+
from langchain_community.vectorstores import FAISS
|
| 7 |
+
from langchain_community.cross_encoders import HuggingFaceCrossEncoder
|
| 8 |
+
from langchain.retrievers import EnsembleRetriever
|
| 9 |
+
from langchain.retrievers.document_compressors import CrossEncoderReranker
|
| 10 |
+
from langchain.retrievers import ContextualCompressionRetriever
|
| 11 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 12 |
+
from langchain_core.documents import Document
|
| 13 |
+
|
| 14 |
+
from mathlib_corpus import MathLibCorpus
|
| 15 |
+
|
| 16 |
+
_DEFAULT_INDEX_DIR = Path(__file__).resolve().parent.parent / "data" / "mathlib_index"
|
| 17 |
+
_EMBED_MODEL = "all-MiniLM-L6-v2"
|
| 18 |
+
_RERANK_MODEL = "cross-encoder/ms-marco-MiniLM-L-6-v2"
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class MathLibRetriever:
|
| 22 |
+
"""
|
| 23 |
+
Hybrid FAISS + BM25 retriever with CrossEncoder reranking over Mathlib lemmas.
|
| 24 |
+
|
| 25 |
+
On first use, call build() to create and persist the index.
|
| 26 |
+
Subsequent runs load from disk automatically.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
index_dir: Optional[str] = None,
|
| 32 |
+
top_k: int = 20,
|
| 33 |
+
rerank_top_k: int = 5,
|
| 34 |
+
):
|
| 35 |
+
self.index_dir = Path(index_dir) if index_dir else _DEFAULT_INDEX_DIR
|
| 36 |
+
self.top_k = top_k
|
| 37 |
+
self.rerank_top_k = rerank_top_k
|
| 38 |
+
self._retriever = None
|
| 39 |
+
|
| 40 |
+
# ------------------------------------------------------------------
|
| 41 |
+
# Public API
|
| 42 |
+
# ------------------------------------------------------------------
|
| 43 |
+
|
| 44 |
+
def build(self, mathlib_root: Optional[str] = None, max_files: Optional[int] = None) -> None:
|
| 45 |
+
"""
|
| 46 |
+
Extract Mathlib documents, build FAISS + BM25 indices, and persist to disk.
|
| 47 |
+
Call this once (via scripts/build_index.py) before first use.
|
| 48 |
+
"""
|
| 49 |
+
print("Extracting Mathlib corpus…")
|
| 50 |
+
corpus = MathLibCorpus(mathlib_root=mathlib_root)
|
| 51 |
+
docs = corpus.extract(max_files=max_files)
|
| 52 |
+
print(f" {len(docs)} declarations extracted.")
|
| 53 |
+
|
| 54 |
+
embeddings = self._embeddings()
|
| 55 |
+
|
| 56 |
+
print("Building FAISS index…")
|
| 57 |
+
faiss_store = FAISS.from_documents(docs, embeddings)
|
| 58 |
+
self.index_dir.mkdir(parents=True, exist_ok=True)
|
| 59 |
+
faiss_store.save_local(str(self.index_dir))
|
| 60 |
+
print(f" Index saved to {self.index_dir}")
|
| 61 |
+
|
| 62 |
+
self._retriever = self._build_retriever(faiss_store, docs)
|
| 63 |
+
|
| 64 |
+
def retrieve(self, query: str, k: Optional[int] = None) -> List[Document]:
|
| 65 |
+
"""
|
| 66 |
+
Retrieve and rerank the most relevant Mathlib lemmas for a query.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
query: Natural-language or Lean-syntax query (e.g., proof goals + errors).
|
| 70 |
+
k: Number of results to return after reranking (defaults to self.rerank_top_k).
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
List of Documents ranked by relevance.
|
| 74 |
+
"""
|
| 75 |
+
if self._retriever is None:
|
| 76 |
+
self._load()
|
| 77 |
+
results = self._retriever.invoke(query)
|
| 78 |
+
return results[: k or self.rerank_top_k]
|
| 79 |
+
|
| 80 |
+
def is_index_built(self) -> bool:
|
| 81 |
+
return (self.index_dir / "index.faiss").exists()
|
| 82 |
+
|
| 83 |
+
# ------------------------------------------------------------------
|
| 84 |
+
# Internal helpers
|
| 85 |
+
# ------------------------------------------------------------------
|
| 86 |
+
|
| 87 |
+
def _embeddings(self) -> HuggingFaceEmbeddings:
|
| 88 |
+
return HuggingFaceEmbeddings(model_name=_EMBED_MODEL)
|
| 89 |
+
|
| 90 |
+
def _load(self) -> None:
|
| 91 |
+
if not self.is_index_built():
|
| 92 |
+
raise RuntimeError(
|
| 93 |
+
f"No FAISS index found at {self.index_dir}. "
|
| 94 |
+
"Run `python scripts/build_index.py` first."
|
| 95 |
+
)
|
| 96 |
+
print("Loading FAISS index from disk…")
|
| 97 |
+
embeddings = self._embeddings()
|
| 98 |
+
faiss_store = FAISS.load_local(
|
| 99 |
+
str(self.index_dir),
|
| 100 |
+
embeddings,
|
| 101 |
+
allow_dangerous_deserialization=True,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Re-build BM25 from FAISS docstore
|
| 105 |
+
docs = list(faiss_store.docstore._dict.values())
|
| 106 |
+
self._retriever = self._build_retriever(faiss_store, docs)
|
| 107 |
+
|
| 108 |
+
def _build_retriever(self, faiss_store: FAISS, docs: List[Document]):
|
| 109 |
+
faiss_retriever = faiss_store.as_retriever(
|
| 110 |
+
search_kwargs={"k": self.top_k}
|
| 111 |
+
)
|
| 112 |
+
bm25_retriever = BM25Retriever.from_documents(docs)
|
| 113 |
+
bm25_retriever.k = self.top_k
|
| 114 |
+
|
| 115 |
+
ensemble = EnsembleRetriever(
|
| 116 |
+
retrievers=[faiss_retriever, bm25_retriever],
|
| 117 |
+
weights=[0.6, 0.4],
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
cross_encoder = HuggingFaceCrossEncoder(model_name=_RERANK_MODEL)
|
| 121 |
+
reranker = CrossEncoderReranker(model=cross_encoder, top_n=self.rerank_top_k)
|
| 122 |
+
|
| 123 |
+
return ContextualCompressionRetriever(
|
| 124 |
+
base_compressor=reranker,
|
| 125 |
+
base_retriever=ensemble,
|
| 126 |
+
)
|
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import unittest
|
| 2 |
+
import sys
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
# Add src to Python path
|
| 6 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))
|
| 7 |
+
|
| 8 |
+
from lean_verifier import LeanEnvironment
|
| 9 |
+
|
| 10 |
+
class TestLeanVerifier(unittest.TestCase):
|
| 11 |
+
@classmethod
|
| 12 |
+
def setUpClass(cls):
|
| 13 |
+
# We'll use Mathlib here since our goal is to verify it works with the MVP setup
|
| 14 |
+
cls.lean_env = LeanEnvironment(use_mathlib=True)
|
| 15 |
+
|
| 16 |
+
def test_correct_proof(self):
|
| 17 |
+
lean_code = """
|
| 18 |
+
import Mathlib
|
| 19 |
+
|
| 20 |
+
theorem add_comm_test (n m : Nat) : n + m = m + n := by
|
| 21 |
+
exact Nat.add_comm n m
|
| 22 |
+
"""
|
| 23 |
+
result = self.lean_env.verify_proof(lean_code)
|
| 24 |
+
self.assertEqual(result["status"], "success")
|
| 25 |
+
self.assertEqual(len(result["errors"]), 0)
|
| 26 |
+
self.assertEqual(len(result["goals"]), 0)
|
| 27 |
+
|
| 28 |
+
def test_incorrect_proof_type_mismatch(self):
|
| 29 |
+
lean_code = """
|
| 30 |
+
import Mathlib
|
| 31 |
+
|
| 32 |
+
theorem add_comm_test (n m : Nat) : n + m = m + n := by
|
| 33 |
+
exact n
|
| 34 |
+
"""
|
| 35 |
+
result = self.lean_env.verify_proof(lean_code)
|
| 36 |
+
self.assertEqual(result["status"], "failure")
|
| 37 |
+
self.assertTrue(any("type mismatch" in err or "application type mismatch" in err for err in result["errors"]),
|
| 38 |
+
f"Expected type mismatch error, got: {result['errors']}")
|
| 39 |
+
|
| 40 |
+
def test_incomplete_proof_sorry(self):
|
| 41 |
+
lean_code = """
|
| 42 |
+
import Mathlib
|
| 43 |
+
|
| 44 |
+
theorem my_incomplete_thm (n : Nat) : n = 5 → n = 5 := by
|
| 45 |
+
sorry
|
| 46 |
+
"""
|
| 47 |
+
result = self.lean_env.verify_proof(lean_code)
|
| 48 |
+
self.assertEqual(result["status"], "failure")
|
| 49 |
+
# Ensure it has an error indicating sorry
|
| 50 |
+
self.assertTrue(any("uses 'sorry'" in err for err in result["errors"]),
|
| 51 |
+
f"Expected sorry warning/error, got: {result['errors']}")
|
| 52 |
+
# Ensure it outputs the goal
|
| 53 |
+
self.assertEqual(len(result["goals"]), 1)
|
| 54 |
+
self.assertIn("⊢ n = 5 → n = 5", result["goals"][0])
|
| 55 |
+
|
| 56 |
+
if __name__ == '__main__':
|
| 57 |
+
unittest.main()
|