File size: 5,319 Bytes
bf9c466
 
 
 
 
 
 
 
16f1328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf9c466
 
 
 
 
16f1328
 
 
 
 
bf9c466
 
 
16f1328
 
 
 
bf9c466
16f1328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26a7647
16f1328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26a7647
16f1328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""StdinBackend: local Lean verification via `lean --stdin`.

Expects the .olean files for all spec modules to exist, normally produced by
running `cd lean && lake build`.  If they are missing, _ensure_lake_build()
runs that build automatically on first use; in Docker, where the build already
ran at image build time, the check finds the artifacts and is a no-op.
"""

from __future__ import annotations

import os
import shutil
import subprocess
import textwrap
import time
from pathlib import Path

from .interface import LeanBackend, LeanResult

def _path_list_exists(path_list: str) -> bool:
    """Return True if any entry of an ``os.pathsep``-separated path list exists.

    ``LEAN_PATH`` is a search path (like ``PATH``) and may hold several
    directories.  Empty entries are ignored, so an empty string returns False.
    """
    return any(Path(entry).exists() for entry in path_list.split(os.pathsep) if entry)


REPO_ROOT = Path(__file__).resolve().parents[1]
# Prefer binaries on PATH; otherwise assume the default elan install location.
DEFAULT_LEAN_BIN = Path(shutil.which("lean") or Path.home() / ".elan/bin/lean")
DEFAULT_LEAN_CWD = REPO_ROOT / "lean"
DEFAULT_LEAN_PATH = DEFAULT_LEAN_CWD / ".lake" / "build" / "lib"
DEFAULT_LAKE_BIN = Path(shutil.which("lake") or Path.home() / ".elan/bin/lake")

# Environment variables of the same name override the defaults above.
LEAN_BIN = Path(os.environ.get("LEAN_BIN", str(DEFAULT_LEAN_BIN)))
LEAN_CWD = Path(os.environ.get("LEAN_CWD", str(DEFAULT_LEAN_CWD)))
LEAN_PATH = os.environ.get("LEAN_PATH", str(DEFAULT_LEAN_PATH))
LAKE_BIN = Path(os.environ.get("LAKE_BIN", str(DEFAULT_LAKE_BIN)))

# Overrides that point at nothing silently fall back to the defaults.
if not LEAN_CWD.exists():
    LEAN_CWD = DEFAULT_LEAN_CWD
# LEAN_PATH may list several directories; keep it when any of them exists
# instead of testing the whole joined string as one path.
if not _path_list_exists(LEAN_PATH):
    LEAN_PATH = str(DEFAULT_LEAN_PATH)

# Set to True once _ensure_lake_build() has found or produced the spec .oleans.
_LEAN_BUILD_READY = False


def _get_spec_modules() -> list[str]:
    """Collect the distinct Lean spec modules referenced by the task registry.

    Each module appears once, in the order its first task is encountered.
    The registry is imported inside the function, i.e. on first call rather
    than when this module is imported.
    """
    from lean_migrate.env.tasks import _TASKS

    unique: dict[str, None] = {}
    for task in _TASKS.values():
        unique.setdefault(task.lean_spec_module, None)
    return [*unique]


def _olean_path(root: Path, module: str) -> Path:
    """Return where Lean stores the compiled ``.olean`` for *module* under *root*.

    Lean maps module name components to directories, so ``A.B.C`` lives at
    ``<root>/A/B/C.olean`` rather than ``<root>/A.B.C.olean``.
    """
    *directories, leaf = module.split(".")
    return root.joinpath(*directories, f"{leaf}.olean")


def _olean_exists(module: str, search_path: str) -> bool:
    """Return True if *module*'s ``.olean`` exists in any *search_path* entry.

    *search_path* follows the ``LEAN_PATH`` convention: directories joined with
    ``os.pathsep``.  Empty entries are ignored.
    """
    return any(
        _olean_path(Path(entry), module).exists()
        for entry in search_path.split(os.pathsep)
        if entry
    )


def _ensure_lake_build() -> None:
    """Make sure every task's spec module is built, running ``lake build`` if not.

    The result is cached in the module-level ``_LEAN_BUILD_READY`` flag, so
    after one successful check or build later calls return immediately.

    Raises:
        RuntimeError: if ``lake build`` exits with a non-zero status; the
            message includes its stdout and stderr.
    """
    global _LEAN_BUILD_READY
    if _LEAN_BUILD_READY:
        return

    modules = _get_spec_modules()
    # NOTE(review): callers import these as `lean.<module>`; confirm the build
    # layout places the .olean files where this lookup expects them.
    if all(_olean_exists(module, LEAN_PATH) for module in modules):
        _LEAN_BUILD_READY = True
        return

    process = subprocess.run(
        [str(LAKE_BIN), "build", *modules],
        capture_output=True,
        text=True,
        cwd=str(LEAN_CWD),
        env={**os.environ, "LEAN_PATH": LEAN_PATH},
    )
    if process.returncode != 0:
        raise RuntimeError(
            "Lean build failed before verification.\n"
            f"stdout:\n{process.stdout}\n"
            f"stderr:\n{process.stderr}"
        )

    _LEAN_BUILD_READY = True


class StdinBackend(LeanBackend):
    """Lean backend that pipes generated source into ``lean --stdin``.

    Every check runs a fresh ``lean`` process in ``LEAN_CWD`` with ``LEAN_PATH``
    exported, after _ensure_lake_build() has made sure the spec modules exist.
    """

    # Warning Lean emits for a declaration that depends on `sorry`, including
    # indirectly (e.g. through the `admit` tactic), which a textual scan misses.
    # NOTE(review): wording matches current Lean 4 releases; confirm on upgrade.
    _SORRY_WARNING = "declaration uses 'sorry'"

    @staticmethod
    def _qualified_module(module: str) -> str:
        """Return *module* with the ``lean.`` import prefix, adding it if absent."""
        return module if module.startswith("lean.") else f"lean.{module}"

    def _run_lean(
        self,
        code: str,
        timeout: int = 15,
        *,
        reject_sorry: bool = False,
    ) -> LeanResult:
        """Check *code* with ``lean --stdin`` and summarise the outcome.

        Args:
            code: Complete Lean source, including its ``import`` lines.
            timeout: Seconds to wait before the Lean process is killed.
            reject_sorry: Also fail when Lean reports that a declaration
                depends on ``sorry``.

        Returns:
            A LeanResult whose ``error`` is empty on success.  Timeouts and
            failures to launch Lean are returned as failed results.

        Raises:
            RuntimeError: propagated from _ensure_lake_build() when the
                prerequisite ``lake build`` fails.
        """
        _ensure_lake_build()
        start_time = time.monotonic()

        def elapsed_ms() -> int:
            return int((time.monotonic() - start_time) * 1000)

        try:
            process = subprocess.run(
                [str(LEAN_BIN), "--stdin"],
                input=code,
                capture_output=True,
                text=True,
                timeout=timeout,
                cwd=str(LEAN_CWD),
                env={**os.environ, "LEAN_PATH": LEAN_PATH},
            )
        except subprocess.TimeoutExpired:
            return LeanResult(
                passed=False,
                error=f"LEAN verification timed out after {timeout} seconds.",
                latency_ms=elapsed_ms(),
                raw_stderr="",
            )
        except Exception as error:
            # Backend boundary: e.g. a missing LEAN_BIN (OSError from
            # subprocess.run) becomes a failed result rather than a crash.
            return LeanResult(
                passed=False,
                error=f"Backend error: {error}",
                latency_ms=elapsed_ms(),
                raw_stderr="",
            )

        latency = elapsed_ms()
        passed = process.returncode == 0 and "error:" not in process.stderr.lower()
        error_output = process.stderr.strip() or process.stdout.strip()
        if passed and reject_sorry and self._SORRY_WARNING in process.stdout + process.stderr:
            passed = False
            error_output = "Proof depends on 'sorry' (e.g. via 'admit') and cannot be accepted."
        return LeanResult(
            passed=passed,
            error="" if passed else error_output,
            latency_ms=latency,
            raw_stderr=process.stderr,
        )

    def verify(
        self,
        spec_module: str,
        function_name: str,
        code: str,
        symbol_name: str | None = None,
        extra_imports: list[str] | None = None,
        sample_checks: list[str] | None = None,
    ) -> LeanResult:
        """Type-check *code* against its spec module and ``#check`` the result.

        Args:
            spec_module: Spec module to import (``lean.`` is prepended if absent).
            function_name: Checked as ``_root_.<function_name>`` when
                *symbol_name* is not given.
            code: Lean source placed after the imports.
            symbol_name: Explicit symbol for the final ``#check``.
            extra_imports: Further modules to import, prefixed like *spec_module*.
            sample_checks: Extra Lean commands appended after *code*.
        """
        modules = [spec_module, *(extra_imports or [])]
        import_block = "\n".join(
            f"import {self._qualified_module(module)}" for module in modules
        )
        check_target = symbol_name or f"_root_.{function_name}"
        sections = [import_block, code, *(sample_checks or []), f"#check {check_target}"]
        # Blank or whitespace-only sections are dropped before joining.
        lean_code = "\n\n".join(section for section in sections if section.strip())
        return self._run_lean(lean_code)

    def verify_proof(self, spec_module: str, proof_code: str) -> LeanResult:
        """Check *proof_code* in the context of *spec_module*.

        The spec module is imported (``lean.``-prefixed if needed) and its
        namespace opened.  Proofs mentioning ``sorry`` are rejected without
        running Lean; proofs Lean reports as depending on ``sorry`` also fail.
        """
        # Cheap textual pre-filter (also catches terms such as `sorryAx`).
        if "sorry" in proof_code:
            return LeanResult(
                passed=False,
                error="Proof contains 'sorry' and cannot be accepted.",
                latency_ms=0,
                raw_stderr="",
            )

        # Assemble line by line: interpolating a multi-line proof into an
        # indented template and dedenting the result lets the proof's own
        # indentation decide the common prefix, mangling both parts.
        lean_code = "\n".join(
            [
                f"import {self._qualified_module(spec_module)}",
                f"open {spec_module}",
                "",
                textwrap.dedent(proof_code).strip(),
            ]
        )
        return self._run_lean(lean_code, timeout=30, reject_sorry=True)