Spaces:
Sleeping
Sleeping
Upload folder using huggingface_hub
Browse files
- README.md +1 -1
- env/grader.py +17 -2
- env/state.py +5 -0
- env/verification_ir.py +36 -6
- inference.py +13 -7
- lean_backend/interface.py +1 -0
- lean_backend/kimina_backend.py +2 -1
- lean_backend/stdin_backend.py +2 -1
- tests/test_env_episode.py +39 -2
README.md
CHANGED
|
@@ -91,7 +91,7 @@ TASK_ID=rbac_auth \
|
|
| 91 |
uv run python inference.py
|
| 92 |
```
|
| 93 |
|
| 94 |
-
`inference.py` uses the OpenAI client. It reads credentials from `HF_TOKEN`
|
| 95 |
|
| 96 |
## Baseline
|
| 97 |
|
|
|
|
| 91 |
uv run python inference.py
|
| 92 |
```
|
| 93 |
|
| 94 |
+
`inference.py` uses the OpenAI client against the Hugging Face router with `Qwen/Qwen2.5-72B-Instruct`. It reads credentials from `HF_TOKEN` and emits structured stdout logs in the required `[START]`, `[STEP]`, `[END]` format.
|
| 95 |
|
| 96 |
## Baseline
|
| 97 |
|
env/grader.py
CHANGED
|
@@ -309,7 +309,19 @@ def _lean_call(task_id: str, function_name: str, args: tuple[Any, ...]) -> str:
|
|
| 309 |
return _lean_call_impl(task_id, function_name, args)
|
| 310 |
|
| 311 |
|
| 312 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 313 |
if function_spec.is_proof_required:
|
| 314 |
return []
|
| 315 |
|
|
@@ -317,7 +329,10 @@ def build_lean_sample_checks(task: Task, function_spec: FunctionSpec) -> list[st
|
|
| 317 |
checks: list[str] = []
|
| 318 |
for case in cases:
|
| 319 |
expected_value = oracle_result(task.task_id, function_spec.name, case.args)
|
| 320 |
-
call_expr =
|
|
|
|
|
|
|
|
|
|
| 321 |
expected_expr = _lean_value(task.task_id, function_spec.name, expected_value)
|
| 322 |
checks.append(
|
| 323 |
textwrap.dedent(
|
|
|
|
| 309 |
return _lean_call_impl(task_id, function_name, args)
|
| 310 |
|
| 311 |
|
| 312 |
+
def _with_call_namespace(call_expr: str, call_namespace: str) -> str:
|
| 313 |
+
if call_namespace == "_root_":
|
| 314 |
+
return call_expr
|
| 315 |
+
|
| 316 |
+
root_prefix = "_root_."
|
| 317 |
+
if call_expr.startswith(root_prefix):
|
| 318 |
+
return f"{call_namespace}.{call_expr[len(root_prefix):]}"
|
| 319 |
+
return call_expr
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
def build_lean_sample_checks(
|
| 323 |
+
task: Task, function_spec: FunctionSpec, call_namespace: str = "_root_"
|
| 324 |
+
) -> list[str]:
|
| 325 |
if function_spec.is_proof_required:
|
| 326 |
return []
|
| 327 |
|
|
|
|
| 329 |
checks: list[str] = []
|
| 330 |
for case in cases:
|
| 331 |
expected_value = oracle_result(task.task_id, function_spec.name, case.args)
|
| 332 |
+
call_expr = _with_call_namespace(
|
| 333 |
+
_lean_call(task.task_id, function_spec.name, case.args),
|
| 334 |
+
call_namespace,
|
| 335 |
+
)
|
| 336 |
expected_expr = _lean_value(task.task_id, function_spec.name, expected_value)
|
| 337 |
checks.append(
|
| 338 |
textwrap.dedent(
|
env/state.py
CHANGED
|
@@ -337,6 +337,11 @@ class EpisodeState:
|
|
| 337 |
code=ir_result.lean_code
|
| 338 |
if ir_result is not None
|
| 339 |
else action.target_code,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 340 |
sample_checks=[],
|
| 341 |
)
|
| 342 |
proof_compiled = None
|
|
|
|
| 337 |
code=ir_result.lean_code
|
| 338 |
if ir_result is not None
|
| 339 |
else action.target_code,
|
| 340 |
+
symbol_name=(
|
| 341 |
+
f"Candidate.{action.function_name}"
|
| 342 |
+
if ir_result is not None and ir_result.lean_code is not None
|
| 343 |
+
else None
|
| 344 |
+
),
|
| 345 |
sample_checks=[],
|
| 346 |
)
|
| 347 |
proof_compiled = None
|
env/verification_ir.py
CHANGED
|
@@ -541,6 +541,27 @@ def _render_lean_definition(task: Task, function_name: str) -> str:
|
|
| 541 |
return textwrap.dedent(function_spec.lean_fragment).strip()
|
| 542 |
|
| 543 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 544 |
def _render_lean_mirror(
|
| 545 |
task: Task, function_spec: FunctionSpec, provenance: CodeProvenanceIR
|
| 546 |
) -> str:
|
|
@@ -558,18 +579,27 @@ def _render_lean_mirror(
|
|
| 558 |
"""
|
| 559 |
).strip()
|
| 560 |
|
|
|
|
| 561 |
definition_names = dependency_closure(task, function_spec.name) + [
|
| 562 |
function_spec.name
|
| 563 |
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 564 |
definition_block = "\n\n".join(
|
| 565 |
-
[f"open {task.lean_spec_module}"]
|
| 566 |
-
+
|
| 567 |
-
|
| 568 |
-
for definition_name in definition_names
|
| 569 |
-
]
|
| 570 |
)
|
| 571 |
|
| 572 |
-
sample_checks = build_lean_sample_checks(
|
|
|
|
|
|
|
| 573 |
checks_block = "\n\n".join(sample_checks)
|
| 574 |
|
| 575 |
return "\n\n".join(
|
|
|
|
| 541 |
return textwrap.dedent(function_spec.lean_fragment).strip()
|
| 542 |
|
| 543 |
|
| 544 |
+
def _qualify_definition_references(
|
| 545 |
+
definition_text: str, definition_names: list[str], namespace: str
|
| 546 |
+
) -> str:
|
| 547 |
+
if len(definition_names) == 0:
|
| 548 |
+
return definition_text
|
| 549 |
+
|
| 550 |
+
lines = definition_text.splitlines()
|
| 551 |
+
if len(lines) <= 1:
|
| 552 |
+
return definition_text
|
| 553 |
+
|
| 554 |
+
name_pattern = re.compile(
|
| 555 |
+
r"\b(" + "|".join(re.escape(name) for name in definition_names) + r")\b"
|
| 556 |
+
)
|
| 557 |
+
qualified_lines = [lines[0]]
|
| 558 |
+
qualified_lines.extend(
|
| 559 |
+
name_pattern.sub(lambda match: f"{namespace}.{match.group(1)}", line)
|
| 560 |
+
for line in lines[1:]
|
| 561 |
+
)
|
| 562 |
+
return "\n".join(qualified_lines)
|
| 563 |
+
|
| 564 |
+
|
| 565 |
def _render_lean_mirror(
|
| 566 |
task: Task, function_spec: FunctionSpec, provenance: CodeProvenanceIR
|
| 567 |
) -> str:
|
|
|
|
| 579 |
"""
|
| 580 |
).strip()
|
| 581 |
|
| 582 |
+
candidate_namespace = "Candidate"
|
| 583 |
definition_names = dependency_closure(task, function_spec.name) + [
|
| 584 |
function_spec.name
|
| 585 |
]
|
| 586 |
+
qualified_definitions = [
|
| 587 |
+
_qualify_definition_references(
|
| 588 |
+
_render_lean_definition(task, definition_name),
|
| 589 |
+
definition_names,
|
| 590 |
+
candidate_namespace,
|
| 591 |
+
)
|
| 592 |
+
for definition_name in definition_names
|
| 593 |
+
]
|
| 594 |
definition_block = "\n\n".join(
|
| 595 |
+
[f"open {task.lean_spec_module}", f"namespace {candidate_namespace}"]
|
| 596 |
+
+ qualified_definitions
|
| 597 |
+
+ [f"end {candidate_namespace}"]
|
|
|
|
|
|
|
| 598 |
)
|
| 599 |
|
| 600 |
+
sample_checks = build_lean_sample_checks(
|
| 601 |
+
task, function_spec, call_namespace=candidate_namespace
|
| 602 |
+
)
|
| 603 |
checks_block = "\n\n".join(sample_checks)
|
| 604 |
|
| 605 |
return "\n\n".join(
|
inference.py
CHANGED
|
@@ -9,12 +9,15 @@ import textwrap
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
ROOT = Path(__file__).resolve().parents[1]
|
| 13 |
if str(ROOT) not in sys.path:
|
| 14 |
sys.path.insert(0, str(ROOT))
|
| 15 |
|
| 16 |
-
from openai import OpenAI
|
| 17 |
-
|
| 18 |
from lean_migrate.env.models import SubmitAction
|
| 19 |
from lean_migrate.env.target_snippets import (
|
| 20 |
TASK_TARGET_SNIPPETS,
|
|
@@ -23,9 +26,11 @@ from lean_migrate.env.target_snippets import (
|
|
| 23 |
from lean_migrate.env.tasks import get_task, list_tasks
|
| 24 |
from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment
|
| 25 |
|
|
|
|
|
|
|
| 26 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 27 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 28 |
-
API_KEY = os.getenv("HF_TOKEN")
|
| 29 |
TASK_ID = os.getenv("TASK_ID")
|
| 30 |
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
| 31 |
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
|
|
@@ -39,7 +44,10 @@ def log_start(task: str, env: str, model: str) -> None:
|
|
| 39 |
def log_step(
|
| 40 |
step: int, action: str, reward: float, done: bool, error: Optional[str]
|
| 41 |
) -> None:
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
| 43 |
print(
|
| 44 |
f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_value}",
|
| 45 |
flush=True,
|
|
@@ -200,9 +208,7 @@ async def _run_task(client: OpenAI, task_id: str) -> None:
|
|
| 200 |
|
| 201 |
async def main() -> None:
|
| 202 |
if not API_KEY:
|
| 203 |
-
raise RuntimeError(
|
| 204 |
-
"Set HF_TOKEN, OPENAI_API_KEY, or API_KEY before running inference.py"
|
| 205 |
-
)
|
| 206 |
|
| 207 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 208 |
task_ids = [TASK_ID] if TASK_ID else [task["task_id"] for task in list_tasks()]
|
|
|
|
| 9 |
from pathlib import Path
|
| 10 |
from typing import Optional
|
| 11 |
|
| 12 |
+
from dotenv import load_dotenv
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
from openai import OpenAI
|
| 16 |
+
|
| 17 |
ROOT = Path(__file__).resolve().parents[1]
|
| 18 |
if str(ROOT) not in sys.path:
|
| 19 |
sys.path.insert(0, str(ROOT))
|
| 20 |
|
|
|
|
|
|
|
| 21 |
from lean_migrate.env.models import SubmitAction
|
| 22 |
from lean_migrate.env.target_snippets import (
|
| 23 |
TASK_TARGET_SNIPPETS,
|
|
|
|
| 26 |
from lean_migrate.env.tasks import get_task, list_tasks
|
| 27 |
from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment
|
| 28 |
|
| 29 |
+
load_dotenv() # Load environment variables from .env file
|
| 30 |
+
|
| 31 |
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
|
| 32 |
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
|
| 33 |
+
API_KEY = os.getenv("HF_TOKEN")
|
| 34 |
TASK_ID = os.getenv("TASK_ID")
|
| 35 |
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
|
| 36 |
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
|
|
|
|
| 44 |
def log_step(
|
| 45 |
step: int, action: str, reward: float, done: bool, error: Optional[str]
|
| 46 |
) -> None:
|
| 47 |
+
if error:
|
| 48 |
+
error_value = " ".join(error.split())
|
| 49 |
+
else:
|
| 50 |
+
error_value = "null"
|
| 51 |
print(
|
| 52 |
f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_value}",
|
| 53 |
flush=True,
|
|
|
|
| 208 |
|
| 209 |
async def main() -> None:
|
| 210 |
if not API_KEY:
|
| 211 |
+
raise RuntimeError("Set HF_TOKEN before running inference.py")
|
|
|
|
|
|
|
| 212 |
|
| 213 |
client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
|
| 214 |
task_ids = [TASK_ID] if TASK_ID else [task["task_id"] for task in list_tasks()]
|
lean_backend/interface.py
CHANGED
|
@@ -19,6 +19,7 @@ class LeanBackend(ABC):
|
|
| 19 |
spec_module: str,
|
| 20 |
function_name: str,
|
| 21 |
code: str,
|
|
|
|
| 22 |
extra_imports: list[str] | None = None,
|
| 23 |
sample_checks: list[str] | None = None,
|
| 24 |
) -> LeanResult:
|
|
|
|
| 19 |
spec_module: str,
|
| 20 |
function_name: str,
|
| 21 |
code: str,
|
| 22 |
+
symbol_name: str | None = None,
|
| 23 |
extra_imports: list[str] | None = None,
|
| 24 |
sample_checks: list[str] | None = None,
|
| 25 |
) -> LeanResult:
|
lean_backend/kimina_backend.py
CHANGED
|
@@ -52,6 +52,7 @@ class KiminaBackend(LeanBackend):
|
|
| 52 |
spec_module: str,
|
| 53 |
function_name: str,
|
| 54 |
code: str,
|
|
|
|
| 55 |
extra_imports: list[str] | None = None,
|
| 56 |
sample_checks: list[str] | None = None,
|
| 57 |
) -> LeanResult:
|
|
@@ -71,7 +72,7 @@ class KiminaBackend(LeanBackend):
|
|
| 71 |
]
|
| 72 |
if sample_checks:
|
| 73 |
sections.extend(sample_checks)
|
| 74 |
-
sections.append(f"#check _root_.{function_name}")
|
| 75 |
lean_code = "\n\n".join(section for section in sections if section.strip())
|
| 76 |
return self._call_kimina(lean_code)
|
| 77 |
|
|
|
|
| 52 |
spec_module: str,
|
| 53 |
function_name: str,
|
| 54 |
code: str,
|
| 55 |
+
symbol_name: str | None = None,
|
| 56 |
extra_imports: list[str] | None = None,
|
| 57 |
sample_checks: list[str] | None = None,
|
| 58 |
) -> LeanResult:
|
|
|
|
| 72 |
]
|
| 73 |
if sample_checks:
|
| 74 |
sections.extend(sample_checks)
|
| 75 |
+
sections.append(f"#check {symbol_name or f'_root_.{function_name}'}")
|
| 76 |
lean_code = "\n\n".join(section for section in sections if section.strip())
|
| 77 |
return self._call_kimina(lean_code)
|
| 78 |
|
lean_backend/stdin_backend.py
CHANGED
|
@@ -100,6 +100,7 @@ class StdinBackend(LeanBackend):
|
|
| 100 |
spec_module: str,
|
| 101 |
function_name: str,
|
| 102 |
code: str,
|
|
|
|
| 103 |
extra_imports: list[str] | None = None,
|
| 104 |
sample_checks: list[str] | None = None,
|
| 105 |
) -> LeanResult:
|
|
@@ -119,7 +120,7 @@ class StdinBackend(LeanBackend):
|
|
| 119 |
]
|
| 120 |
if sample_checks:
|
| 121 |
sections.extend(sample_checks)
|
| 122 |
-
sections.append(f"#check _root_.{function_name}")
|
| 123 |
lean_code = "\n\n".join(section for section in sections if section.strip())
|
| 124 |
return self._run_lean(lean_code)
|
| 125 |
|
|
|
|
| 100 |
spec_module: str,
|
| 101 |
function_name: str,
|
| 102 |
code: str,
|
| 103 |
+
symbol_name: str | None = None,
|
| 104 |
extra_imports: list[str] | None = None,
|
| 105 |
sample_checks: list[str] | None = None,
|
| 106 |
) -> LeanResult:
|
|
|
|
| 120 |
]
|
| 121 |
if sample_checks:
|
| 122 |
sections.extend(sample_checks)
|
| 123 |
+
sections.append(f"#check {symbol_name or f'_root_.{function_name}'}")
|
| 124 |
lean_code = "\n\n".join(section for section in sections if section.strip())
|
| 125 |
return self._run_lean(lean_code)
|
| 126 |
|
tests/test_env_episode.py
CHANGED
|
@@ -8,7 +8,7 @@ from lean_migrate.env.models import (
|
|
| 8 |
RunTestsAction,
|
| 9 |
SubmitAction,
|
| 10 |
)
|
| 11 |
-
from lean_migrate.env.target_snippets import TASK_TARGET_SNIPPETS
|
| 12 |
from lean_migrate.env.state import EpisodeState
|
| 13 |
from lean_migrate.env.tasks import get_task
|
| 14 |
from lean_migrate.env.verification_ir import build_verification_ir
|
|
@@ -172,8 +172,45 @@ def test_verification_ir_builds_summary() -> None:
|
|
| 172 |
assert result.provenance is not None
|
| 173 |
assert result.provenance.parse_ok
|
| 174 |
assert result.lean_code is not None
|
|
|
|
| 175 |
assert "def findRole" in result.lean_code
|
| 176 |
-
assert "example :
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
|
| 179 |
def test_verification_ir_reports_sample_mismatches() -> None:
|
|
|
|
| 8 |
RunTestsAction,
|
| 9 |
SubmitAction,
|
| 10 |
)
|
| 11 |
+
from lean_migrate.env.target_snippets import TASK_TARGET_SNIPPETS, build_submission_bundle
|
| 12 |
from lean_migrate.env.state import EpisodeState
|
| 13 |
from lean_migrate.env.tasks import get_task
|
| 14 |
from lean_migrate.env.verification_ir import build_verification_ir
|
|
|
|
| 172 |
assert result.provenance is not None
|
| 173 |
assert result.provenance.parse_ok
|
| 174 |
assert result.lean_code is not None
|
| 175 |
+
assert "namespace Candidate" in result.lean_code
|
| 176 |
assert "def findRole" in result.lean_code
|
| 177 |
+
assert "example : Candidate.findRole" in result.lean_code
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def test_submit_accepts_rbac_dependency_bundle() -> None:
    """Submit the two rbac_auth helpers, then ``canAccess`` as a bundle.

    Each helper submission must score above zero without ending the episode;
    the final bundled ``canAccess`` submission must be reported as VERIFIED,
    bring progress to 1.0 and finish the episode.
    """
    task = get_task("rbac_auth")
    state = EpisodeState.from_task(task)
    target_snippets = TASK_TARGET_SNIPPETS["rbac_auth"]
    verified_target_snippets: dict[str, str] = {}

    # Submit the helper functions first, recording each accepted snippet.
    for helper_name in ("findRole", "hasDirectPermission"):
        helper_code = target_snippets[helper_name]
        helper_action = SubmitAction(
            type="submit",
            function_name=helper_name,
            target_code=helper_code,
        )
        _, helper_reward, helper_done, _ = state.apply(helper_action)
        assert helper_reward.score > 0.0
        assert not helper_done
        verified_target_snippets[helper_name] = target_snippets[helper_name]

    # The final function is submitted together with the recorded helpers.
    bundle = build_submission_bundle(
        task,
        "canAccess",
        verified_target_snippets,
        target_snippets["canAccess"],
    )
    final_action = SubmitAction(
        type="submit",
        function_name="canAccess",
        target_code=bundle,
    )

    observation, reward, done, _ = state.apply(final_action)

    assert reward.score > 0.0
    assert reward.feedback.startswith("VERIFIED")
    assert observation.progress == 1.0
    assert done
|
| 214 |
|
| 215 |
|
| 216 |
def test_verification_ir_reports_sample_mismatches() -> None:
|