File size: 10,773 Bytes
ed428ff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
import hashlib
import os

import modal
from fastapi import FastAPI
from pydantic import BaseModel

# Modal app; the web function below is registered on it via @app.function.
app = modal.App("lean-proof-agent")

# Container image for the prover. Builder steps run in the order written; the
# final step imports lean_interact (installed by pip_install) and presumably
# needs the elan toolchain on PATH, so keep it last.
image = (
    modal.Image.debian_slim()
    # curl is used by the elan installer below; git / build-essential are
    # presumably needed by lean-interact to fetch/build the Lean REPL — TODO confirm.
    .apt_install("curl", "git", "build-essential")
    # Lean REPL client, HTTP client for the LLM endpoint, and the web framework.
    .pip_install("lean-interact", "requests", "fastapi")
    # Install elan non-interactively (-y) with Lean v4.14.0 as the default toolchain.
    .run_commands(
        "curl https://elan.lean-lang.org/elan-init.sh -sSf | sh -s -- -y --default-toolchain leanprover/lean4:v4.14.0",
    )
    # Put elan's bin directory first on PATH in the image.
    .env({"PATH": "/root/.elan/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"})

    # Instantiate the default LeanREPLConfig at build time, presumably so the
    # REPL setup it performs is baked into the image rather than repeated at
    # container start. NOTE(review): the default config may target a Lean
    # version other than the v4.14.0 installed above — confirm.
    .run_commands(
        'python -c "from lean_interact import LeanREPLConfig; LeanREPLConfig()"'
    )
)





# OpenAI-style chat-completions URL of the LLM that proposes tactics.
# Set the LLAMA_ENDPOINT environment variable to point at a different server;
# an unset or empty variable falls back to the default below. This module is
# also evaluated inside the Modal container, so the variable must be set there
# (e.g. via a Modal secret or the image env) for the override to apply at runtime.
LLAMA_ENDPOINT = (
    os.environ.get("LLAMA_ENDPOINT")
    or "https://no-name13--llama-server-serve.modal.run/v1/chat/completions"
)

# FastAPI app; fastapi_app() attaches the /prove route and returns it to Modal.
web_app = FastAPI()

class ProveRequest(BaseModel):
    """Request body for ``POST /prove``.

    ``theorem`` is a Lean 4 theorem header without a proof body; the endpoint
    appends ``:= by sorry`` to it to obtain the initial proof state.
    """

    theorem: str  # e.g. "theorem foo (n : Nat) : n + 0 = n" — no ":=" / proof body
    max_steps: int = 20  # maximum number of search-loop iterations
    use_fallbacks: bool = True  # try the fixed FALLBACK_TACTICS before asking the LLM on each step
    show_reasoning: bool = True  # when True, skip cache read so the full agent loop always runs

class StepLog(BaseModel):
    """One iteration of the proof-search loop, as reported back to the client."""

    step: int  # 0-based loop iteration index
    goal: str  # goal text at the start of the step (multiple goals joined by newlines)
    candidates: list[str]  # LLM suggestions, or the single fallback tactic that was applied
    chosen: str  # tactic applied this step; "" when none was applied
    status: str  # Lean proof status of the chosen tactic, or "all_failed" / "model_unavailable"
    error: str | None = None  # last Lean error ("all_failed") or endpoint notice ("model_unavailable")

class ProveResponse(BaseModel):
    """Response body for ``POST /prove``."""

    success: bool  # True for a Completed proof, a cache hit, or the "Proved trivially!" path
    stuck: bool = False  # True only when the loop aborted via stuck-state detection
    tactics: list[str]  # tactics applied, in order; on success this is the proof script
    steps: list[StepLog]  # per-iteration trace; empty for cached / trivial / early-failure results
    message: str  # human-readable summary of the outcome


@app.function(
    image=image,
    timeout=300,
    min_containers=1,
)
@modal.asgi_app()
def fastapi_app():
    """Start a Lean REPL and return ``web_app`` with the ``/prove`` route attached.

    Modal calls this to obtain the ASGI app. The nested helpers close over one
    ``LeanServer`` instance, a fixed list of fallback tactics, and an
    in-memory cache mapping the MD5 of a theorem's initial goal text to the
    tactic list that proved it.

    NOTE(review): every request shares the single ``server`` below; this
    assumes the container handles one request at a time — confirm the Modal
    concurrency settings before enabling concurrent inputs.
    """
    from lean_interact import LeanServer, LeanREPLConfig, Command, ProofStep
    from lean_interact.interface import LeanError
    import requests

    config = LeanREPLConfig()
    server = LeanServer(config)

    # Tried in order on every step, before the LLM is consulted (when enabled).
    FALLBACK_TACTICS = ["rfl", "norm_num", "simp", "omega", "contradiction", "assumption"]
    # Number of LLM samples requested per step.
    NUM_CANDIDATES = 3
    # md5(initial goal text) -> tactics that proved it. Never evicted.
    lemma_cache: dict[str, list[str]] = {}
    # Code-fence language tags the model sometimes emits despite the prompt.
    CODE_FENCE_TAGS = ("lean", "lean4")

    def clean_tactic(raw):
        """Extract the tactic from a model reply.

        Strips whitespace and surrounding backticks; if the remaining text
        starts with a bare fence tag line (e.g. a fenced reply tagged
        ``lean``), that line is dropped as well.
        """
        text = raw.strip().strip("`").strip()
        first_line, newline, rest = text.partition("\n")
        if newline and first_line.strip().lower() in CODE_FENCE_TAGS:
            text = rest.strip()
        return text

    def ask_model(goal_state, last_error=None, num_candidates=3):
        """Sample up to ``num_candidates`` distinct next-tactic suggestions.

        ``last_error`` (the previous failure, if any) is included in the
        prompt. Any request or parsing failure stops sampling early and the
        distinct tactics collected so far are returned (possibly ``[]``).
        """
        error_context = ""
        if last_error:
            error_context = (
                f"\nThe previous tactic failed with this error:\n{last_error}\n"
                f"IMPORTANT: if the error says 'major premise type is not an inductive type', "
                f"it means you must use `intro` to bring variables into context BEFORE using `induction`.\n"
            )
        prompt = (
            f"You are a Lean 4 theorem prover. Given this proof state:\n\n{goal_state}\n"
            f"{error_context}\n"
            f"Suggest the next single tactic. Output ONLY the tactic, no backticks, no explanation.\n"
            f"RULES:\n"
            f"- do NOT use `omega`, `decide`, `tauto`\n"
            f"- do NOT use `apply Nat.add_comm` (using a named library lemma as a shortcut)\n"
            f"- ALLOWED closing tactics — use these freely when they fit:\n"
            f"  `exact h` or `exact ⟨h1, h2⟩` (provide a proof term directly)\n"
            f"  `contradiction` (when context contains P and ¬P)\n"
            f"  `assumption` (when goal matches a hypothesis exactly)\n"
            f"  `absurd h1 h2` (derive False from h1 : P and h2 : ¬P)\n"
            f"- always use fresh, distinct variable names when introducing (e.g. `intro n`, `intro P`, `intro Q`) — never reuse a name already present in the context\n"
            f"- if the goal starts with `∀`, always use `intro` first before anything else\n"
            f"- when using induction, always provide full case syntax:\n"
            f"  induction n with\n  | zero => simp\n  | succ n ih => simp [ih]"
        )

        tactics = []
        for _ in range(num_candidates):
            try:
                payload = {
                    "model": "any",
                    "messages": [{"role": "user", "content": prompt}],
                    "max_tokens": 200,
                    "stream": False,
                    "temperature": 0.8,
                    "chat_template_kwargs": {"enable_thinking": False}
                }
                resp = requests.post(LLAMA_ENDPOINT, json=payload, timeout=30)
                resp.raise_for_status()
                tactic = clean_tactic(resp.json()["choices"][0]["message"]["content"])
                if tactic and tactic not in tactics:
                    tactics.append(tactic)
            except Exception:
                # Deliberately best-effort: a cold/unreachable endpoint should fail
                # fast instead of waiting out num_candidates timeouts; the caller
                # treats [] as "model unavailable".
                break
        return tactics

    def try_tactic(tactic, proof_state_id):
        """Run ``tactic`` on proof state ``proof_state_id``.

        Returns ``(result, None)`` on success or ``(None, error_message)``
        when Lean rejects the tactic.
        """
        result = server.run(ProofStep(tactic=tactic, proof_state=proof_state_id))
        if isinstance(result, LeanError):
            return None, result.message
        return result, None

    def is_usable(result):
        """Return True if a successful tactic result can be kept in the search.

        Rejects proofs containing ``sorry`` and non-``Completed`` results that
        leave no goals (e.g. an error status): the loop could not continue
        from those and would end with a misleading "Failed within max steps".
        """
        status = result.proof_status
        return "sorry" not in status and (status == "Completed" or bool(result.goals))

    def try_fallbacks(proof_state_id, enabled):
        """Return ``(result, tactic)`` for the first usable fallback tactic.

        Returns ``(None, None)`` when fallbacks are disabled or none applies.
        """
        if not enabled:
            return None, None
        for tactic in FALLBACK_TACTICS:
            result, _ = try_tactic(tactic, proof_state_id)
            if result is not None and is_usable(result):
                return result, tactic
        return None, None

    @web_app.post("/prove", response_model=ProveResponse)
    def prove(req: ProveRequest):
        """Search for a tactic proof of ``req.theorem``.

        Each step first tries the fallback tactics (if enabled), then asks the
        LLM for candidates and keeps the one leaving the fewest goals. The
        search ends on a ``Completed`` proof, on stuck-state detection, or
        after ``req.max_steps`` iterations.
        """
        steps = []
        response = server.run(Command(cmd=f"{req.theorem} := by sorry"))

        # A REPL-level failure comes back as LeanError, which has no `.sorries`;
        # report it instead of crashing the request with an AttributeError.
        if isinstance(response, LeanError):
            return ProveResponse(
                success=False, tactics=[], steps=[],
                message=f"Could not get initial proof state: {response.message}",
            )

        if not response.sorries:
            if any(m.data == "Goals accomplished!" for m in response.messages):
                return ProveResponse(success=True, tactics=[], steps=[], message="Proved trivially!")
            return ProveResponse(success=False, tactics=[], steps=[], message="Could not get initial proof state")

        proof_state_id = response.sorries[0].proof_state
        current_goals = [response.sorries[0].goal]

        # Used only as a cache key, not for security.
        goal_hash = hashlib.md5(current_goals[0].encode(), usedforsecurity=False).hexdigest()
        if not req.show_reasoning and goal_hash in lemma_cache:
            cached = lemma_cache[goal_hash]
            return ProveResponse(
                success=True, tactics=cached, steps=[],
                message=f"Proved from cache ({len(cached)} tactic(s))!"
            )

        tactics = []
        last_error = None
        visited = set()  # (proof_state_id, tactic) pairs already tried
        llm_ever_responded = False
        consecutive_failures = 0  # all_failed steps in a row
        goal_seen: dict[str, int] = {}  # goal text → times seen
        STUCK_THRESHOLD = 3

        for step in range(req.max_steps):
            goal_text = "\n".join(current_goals)
            if not goal_text.strip():
                break

            # Stuck-state detection: same goal returning, or consecutive dead ends
            goal_seen[goal_text] = goal_seen.get(goal_text, 0) + 1
            if goal_seen[goal_text] >= STUCK_THRESHOLD or consecutive_failures >= STUCK_THRESHOLD:
                return ProveResponse(
                    success=False, stuck=True, tactics=tactics, steps=steps,
                    message=(
                        "Search stuck — the same goal state recurred with no progress. "
                        "This theorem is likely not provable in the current theory: "
                        "it may require classical logic (Law of Excluded Middle), "
                        "or a tactic the agent is constrained from using."
                    ),
                )

            result, fallback_tactic = try_fallbacks(proof_state_id, req.use_fallbacks)
            if result is not None:
                tactics.append(fallback_tactic)
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=[fallback_tactic], chosen=fallback_tactic,
                    status=result.proof_status
                ))
                if result.proof_status == "Completed":
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps,
                                        message=f"Proved in {step+1} steps!")
                proof_state_id = result.proof_state
                current_goals = result.goals
                last_error = None
                consecutive_failures = 0
                continue

            candidates = ask_model(goal_text, last_error=last_error, num_candidates=NUM_CANDIDATES)
            if not candidates:
                # Endpoint cold/unreachable: this step says nothing about the goal,
                # so take back its visit. Otherwise repeated retries would trip
                # stuck detection and misreport the theorem as unprovable.
                goal_seen[goal_text] -= 1
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=[], chosen="",
                    status="model_unavailable", error="LLM endpoint cold/unavailable — retrying"
                ))
                continue  # don't count toward stuck-state; just burn a step and retry

            llm_ever_responded = True
            best_result = None
            best_tactic = None
            step_error = None

            for tactic in candidates:
                key = (proof_state_id, tactic)
                if key in visited:
                    continue
                visited.add(key)
                result, error = try_tactic(tactic, proof_state_id)
                if result is None:
                    last_error = error
                    step_error = error
                    continue
                if "sorry" in result.proof_status:
                    last_error = "That tactic left sorry holes. Provide the full proof of each case inline."
                    continue
                if result.proof_status == "Completed":
                    tactics.append(tactic)
                    steps.append(StepLog(
                        step=step, goal=goal_text,
                        candidates=candidates, chosen=tactic,
                        status="Completed"
                    ))
                    lemma_cache[goal_hash] = list(tactics)
                    return ProveResponse(success=True, tactics=tactics, steps=steps,
                                        message=f"Proved in {step+1} steps!")
                if not result.goals:
                    # Not Completed, yet no goals remain (e.g. an error status):
                    # there is nothing to continue from, so treat it as a failure.
                    last_error = result.proof_status
                    step_error = result.proof_status
                    continue
                # Prefer the candidate that leaves the fewest open goals.
                if best_result is None or len(result.goals) < len(best_result.goals):
                    best_result = result
                    best_tactic = tactic

            if best_result is not None:
                tactics.append(best_tactic)
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=candidates, chosen=best_tactic,
                    status=best_result.proof_status
                ))
                proof_state_id = best_result.proof_state
                current_goals = best_result.goals
                last_error = None
                consecutive_failures = 0
            else:
                consecutive_failures += 1
                steps.append(StepLog(
                    step=step, goal=goal_text,
                    candidates=candidates, chosen="",
                    status="all_failed", error=step_error
                ))

        fail_msg = (
            "LLM endpoint warming up — fallback tactics only. Failed within max steps."
            if not llm_ever_responded
            else "Failed within max steps"
        )
        return ProveResponse(success=False, tactics=tactics, steps=steps, message=fail_msg)

    return web_app