"""Lean-aware FAISS retrieval of Mathlib premises for retrieval-augmented proving."""

from pathlib import Path
from typing import List, Optional

from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document

from byt5_embedder import ByT5PremiseEmbedder
from mathlib_corpus import MathLibCorpus

_DEFAULT_INDEX_DIR = Path(__file__).resolve().parent.parent / "data" / "mathlib_index"

# Query-time embedder. Lean-aware ByT5 encoder from LeanDojo, trained on
# (proof_state, used_premise) pairs from Mathlib. The Mathlib FAISS index is
# built from LeanDojo's pre-computed embeddings via
# `scripts/build_leandojo_index.py`.
#
# History: we previously hybrid-ensembled this with BM25 and reranked with a
# generic `ms-marco-MiniLM-L-6-v2` cross-encoder. That made sense for the old
# general-English `all-MiniLM-L6-v2` dense path. With the LeanDojo encoder the
# semantic results are already strong and domain-tuned — the generic reranker
# was actively reordering correct premises (Nat.add_comm, Nat.add_assoc, …)
# below irrelevant matches (Ackermann, ZMod, choose). We dropped both layers.


class MathLibRetriever:
    """
    Lean-aware FAISS retriever over Mathlib premises using LeanDojo's ByT5 encoder.

    Loads an IVFPQ-compressed FAISS index built from LeanDojo's pre-computed
    premise embeddings (see scripts/build_leandojo_index.py).
    """

    def __init__(
        self,
        index_dir: Optional[str] = None,
        top_k: int = 20,
        rerank_top_k: int = 5,
        nprobe: int = 32,
    ):
        """
        Configure the retriever; no index is loaded until the first query.

        Args:
            index_dir: Directory holding the persisted FAISS index. Falls back
                to the repository-relative default when omitted or empty.
            top_k: Stored on the instance; not consulted by any method in
                this class.
            rerank_top_k: Default number of documents `retrieve` returns.
            nprobe: IVF search breadth applied to the index when it is loaded.
        """
        self.index_dir = Path(index_dir) if index_dir else _DEFAULT_INDEX_DIR
        self.top_k = top_k
        # `rerank_top_k` kept for back-compat; it just caps the number returned
        # from the underlying FAISS search (we no longer rerank).
        self.rerank_top_k = rerank_top_k
        self.nprobe = nprobe
        # Legacy slot from the ensemble-retriever era; `retrieve` does not
        # read it. Kept so external code touching the attribute still works.
        self._retriever = None
        self._faiss_store: Optional[FAISS] = None
        # Ensures the "no index" notice is printed at most once per instance.
        self._missing_index_warned = False

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def build(self, mathlib_root: Optional[str] = None, max_files: Optional[int] = None) -> None:
        """
        Extract Mathlib documents, build a FAISS index, and persist it to disk.

        Call this once (via scripts/build_index.py) before first use. The
        freshly built store is also kept in memory, so `retrieve` can be
        called right away without reloading from disk.

        Args:
            mathlib_root: Forwarded to `MathLibCorpus` as the corpus root.
            max_files: Forwarded to `MathLibCorpus.extract` as the file limit.
        """
        print("Extracting Mathlib corpus…")
        corpus = MathLibCorpus(mathlib_root=mathlib_root)
        docs = corpus.extract(max_files=max_files)
        print(f" {len(docs)} declarations extracted.")

        embeddings = self._embeddings()

        print("Building FAISS index…")
        faiss_store = FAISS.from_documents(docs, embeddings)
        self.index_dir.mkdir(parents=True, exist_ok=True)
        faiss_store.save_local(str(self.index_dir))
        print(f" Index saved to {self.index_dir}")

        # `retrieve` reads `_faiss_store`, so publish the new store there.
        self._faiss_store = faiss_store

        # This class defines no `_build_retriever` (the BM25/rerank ensemble
        # was removed), so calling it unconditionally raised AttributeError
        # after the index had been saved. Only honour it if a subclass still
        # supplies one.
        legacy_builder = getattr(self, "_build_retriever", None)
        if callable(legacy_builder):
            self._retriever = legacy_builder(faiss_store, docs)

    # Hard cap on query length passed to downstream retrievers. Long queries
    # waste embedding work and can blow past tokenizer limits; truncating here
    # gives callers a predictable upper bound.
    _MAX_QUERY_CHARS = 2000

    def retrieve(self, query: str, k: Optional[int] = None) -> List[Document]:
        """
        Retrieve the most relevant Mathlib lemmas for a query.

        Args:
            query: Natural-language or Lean-syntax query (e.g., proof goals + errors).
                Queries longer than `_MAX_QUERY_CHARS` characters are truncated.
            k: Number of results to return. A falsy value (None or 0) falls
                back to self.rerank_top_k.

        Returns:
            List of Documents as returned by the FAISS similarity search.
            Empty when the query is blank or when no index exists on disk.

        Raises:
            TypeError: If `query` is not a string.
        """
        if not isinstance(query, str):
            raise TypeError(
                f"query must be a str, got {type(query).__name__}"
            )

        # Empty/whitespace-only queries are valid input but degenerate; bail
        # out early with no results rather than asking the embedding model to
        # vectorise an empty string (which produces noisy nearest neighbours).
        if not query.strip():
            return []

        # Truncate absurdly long inputs so a runaway caller can't pin the
        # embedding model.
        if len(query) > self._MAX_QUERY_CHARS:
            query = query[: self._MAX_QUERY_CHARS]

        if self._faiss_store is None:
            if not self.is_index_built():
                # Degrade gracefully: a missing index disables retrieval
                # instead of failing the caller, and we only say so once.
                if not self._missing_index_warned:
                    print(
                        f" [retriever] No FAISS index at {self.index_dir} — "
                        "skipping Mathlib RAG. The LLM will solve from its training "
                        "knowledge of Mathlib only. Run `python scripts/build_leandojo_index.py` "
                        "to enable retrieval-augmented generation."
                    )
                    self._missing_index_warned = True
                return []
            self._load()

        return self._faiss_store.similarity_search(query, k=k or self.rerank_top_k)

    def is_index_built(self) -> bool:
        """Return True when `index.faiss` exists inside `self.index_dir`."""
        return (self.index_dir / "index.faiss").exists()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _embeddings(self) -> ByT5PremiseEmbedder:
        """Create a new `ByT5PremiseEmbedder` instance."""
        return ByT5PremiseEmbedder()

    def _load(self) -> None:
        """
        Load the persisted FAISS index into `self._faiss_store` and apply `nprobe`.

        Raises:
            RuntimeError: If no `index.faiss` file exists in `self.index_dir`.
        """
        if not self.is_index_built():
            raise RuntimeError(
                f"No FAISS index found at {self.index_dir}. "
                "Run `python scripts/build_leandojo_index.py` first."
            )
        print("Loading FAISS index from disk…")
        embeddings = self._embeddings()
        # NOTE(review): `allow_dangerous_deserialization=True` lets LangChain
        # unpickle the stored docstore. Only point `index_dir` at indices this
        # project built itself.
        self._faiss_store = FAISS.load_local(
            str(self.index_dir),
            embeddings,
            allow_dangerous_deserialization=True,
        )
        # Tune IVFPQ search breadth. nprobe=32 / nlist=512 = 6% of clusters
        # searched — good recall/speed tradeoff for this index size.
        # extract_index_ivf reaches the IVF layer even if the index is later
        # wrapped (e.g. IndexIDMap); setting nprobe on a wrapper is silently
        # ignored by FAISS, which this guards against.
        import faiss

        try:
            faiss.extract_index_ivf(self._faiss_store.index).nprobe = self.nprobe
        except (RuntimeError, TypeError):
            # RuntimeError: no IVF layer (e.g. flat index) — nothing to tune.
            # TypeError: not a real faiss index (e.g. a test double).
            pass