Hrushi committed on
Commit
26a7647
·
verified ·
1 Parent(s): 8c75600

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -91,7 +91,7 @@ TASK_ID=rbac_auth \
91
  uv run python inference.py
92
  ```
93
 
94
- `inference.py` uses the OpenAI client. It reads credentials from `HF_TOKEN`, `OPENAI_API_KEY`, or `API_KEY` and emits structured stdout logs in the required `[START]`, `[STEP]`, `[END]` format.
95
 
96
  ## Baseline
97
 
 
91
  uv run python inference.py
92
  ```
93
 
94
+ `inference.py` uses the OpenAI client against the Hugging Face router with `Qwen/Qwen2.5-72B-Instruct`. It reads credentials from `HF_TOKEN` and emits structured stdout logs in the required `[START]`, `[STEP]`, `[END]` format.
95
 
96
  ## Baseline
97
 
env/grader.py CHANGED
@@ -309,7 +309,19 @@ def _lean_call(task_id: str, function_name: str, args: tuple[Any, ...]) -> str:
309
  return _lean_call_impl(task_id, function_name, args)
310
 
311
 
312
- def build_lean_sample_checks(task: Task, function_spec: FunctionSpec) -> list[str]:
 
 
 
 
 
 
 
 
 
 
 
 
313
  if function_spec.is_proof_required:
314
  return []
315
 
@@ -317,7 +329,10 @@ def build_lean_sample_checks(task: Task, function_spec: FunctionSpec) -> list[st
317
  checks: list[str] = []
318
  for case in cases:
319
  expected_value = oracle_result(task.task_id, function_spec.name, case.args)
320
- call_expr = _lean_call(task.task_id, function_spec.name, case.args)
 
 
 
321
  expected_expr = _lean_value(task.task_id, function_spec.name, expected_value)
322
  checks.append(
323
  textwrap.dedent(
 
309
  return _lean_call_impl(task_id, function_name, args)
310
 
311
 
312
+ def _with_call_namespace(call_expr: str, call_namespace: str) -> str:
313
+ if call_namespace == "_root_":
314
+ return call_expr
315
+
316
+ root_prefix = "_root_."
317
+ if call_expr.startswith(root_prefix):
318
+ return f"{call_namespace}.{call_expr[len(root_prefix):]}"
319
+ return call_expr
320
+
321
+
322
+ def build_lean_sample_checks(
323
+ task: Task, function_spec: FunctionSpec, call_namespace: str = "_root_"
324
+ ) -> list[str]:
325
  if function_spec.is_proof_required:
326
  return []
327
 
 
329
  checks: list[str] = []
330
  for case in cases:
331
  expected_value = oracle_result(task.task_id, function_spec.name, case.args)
332
+ call_expr = _with_call_namespace(
333
+ _lean_call(task.task_id, function_spec.name, case.args),
334
+ call_namespace,
335
+ )
336
  expected_expr = _lean_value(task.task_id, function_spec.name, expected_value)
337
  checks.append(
338
  textwrap.dedent(
env/state.py CHANGED
@@ -337,6 +337,11 @@ class EpisodeState:
337
  code=ir_result.lean_code
338
  if ir_result is not None
339
  else action.target_code,
 
 
 
 
 
340
  sample_checks=[],
341
  )
342
  proof_compiled = None
 
337
  code=ir_result.lean_code
338
  if ir_result is not None
339
  else action.target_code,
340
+ symbol_name=(
341
+ f"Candidate.{action.function_name}"
342
+ if ir_result is not None and ir_result.lean_code is not None
343
+ else None
344
+ ),
345
  sample_checks=[],
346
  )
347
  proof_compiled = None
env/verification_ir.py CHANGED
@@ -541,6 +541,27 @@ def _render_lean_definition(task: Task, function_name: str) -> str:
541
  return textwrap.dedent(function_spec.lean_fragment).strip()
542
 
543
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
544
  def _render_lean_mirror(
545
  task: Task, function_spec: FunctionSpec, provenance: CodeProvenanceIR
546
  ) -> str:
@@ -558,18 +579,27 @@ def _render_lean_mirror(
558
  """
559
  ).strip()
560
 
 
561
  definition_names = dependency_closure(task, function_spec.name) + [
562
  function_spec.name
563
  ]
 
 
 
 
 
 
 
 
564
  definition_block = "\n\n".join(
565
- [f"open {task.lean_spec_module}"]
566
- + [
567
- _render_lean_definition(task, definition_name)
568
- for definition_name in definition_names
569
- ]
570
  )
571
 
572
- sample_checks = build_lean_sample_checks(task, function_spec)
 
 
573
  checks_block = "\n\n".join(sample_checks)
574
 
575
  return "\n\n".join(
 
541
  return textwrap.dedent(function_spec.lean_fragment).strip()
542
 
543
 
544
+ def _qualify_definition_references(
545
+ definition_text: str, definition_names: list[str], namespace: str
546
+ ) -> str:
547
+ if len(definition_names) == 0:
548
+ return definition_text
549
+
550
+ lines = definition_text.splitlines()
551
+ if len(lines) <= 1:
552
+ return definition_text
553
+
554
+ name_pattern = re.compile(
555
+ r"\b(" + "|".join(re.escape(name) for name in definition_names) + r")\b"
556
+ )
557
+ qualified_lines = [lines[0]]
558
+ qualified_lines.extend(
559
+ name_pattern.sub(lambda match: f"{namespace}.{match.group(1)}", line)
560
+ for line in lines[1:]
561
+ )
562
+ return "\n".join(qualified_lines)
563
+
564
+
565
  def _render_lean_mirror(
566
  task: Task, function_spec: FunctionSpec, provenance: CodeProvenanceIR
567
  ) -> str:
 
579
  """
580
  ).strip()
581
 
582
+ candidate_namespace = "Candidate"
583
  definition_names = dependency_closure(task, function_spec.name) + [
584
  function_spec.name
585
  ]
586
+ qualified_definitions = [
587
+ _qualify_definition_references(
588
+ _render_lean_definition(task, definition_name),
589
+ definition_names,
590
+ candidate_namespace,
591
+ )
592
+ for definition_name in definition_names
593
+ ]
594
  definition_block = "\n\n".join(
595
+ [f"open {task.lean_spec_module}", f"namespace {candidate_namespace}"]
596
+ + qualified_definitions
597
+ + [f"end {candidate_namespace}"]
 
 
598
  )
599
 
600
+ sample_checks = build_lean_sample_checks(
601
+ task, function_spec, call_namespace=candidate_namespace
602
+ )
603
  checks_block = "\n\n".join(sample_checks)
604
 
605
  return "\n\n".join(
inference.py CHANGED
@@ -9,12 +9,15 @@ import textwrap
9
  from pathlib import Path
10
  from typing import Optional
11
 
 
 
 
 
 
12
  ROOT = Path(__file__).resolve().parents[1]
13
  if str(ROOT) not in sys.path:
14
  sys.path.insert(0, str(ROOT))
15
 
16
- from openai import OpenAI
17
-
18
  from lean_migrate.env.models import SubmitAction
19
  from lean_migrate.env.target_snippets import (
20
  TASK_TARGET_SNIPPETS,
@@ -23,9 +26,11 @@ from lean_migrate.env.target_snippets import (
23
  from lean_migrate.env.tasks import get_task, list_tasks
24
  from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment
25
 
 
 
26
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
27
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
28
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
29
  TASK_ID = os.getenv("TASK_ID")
30
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
31
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
@@ -39,7 +44,10 @@ def log_start(task: str, env: str, model: str) -> None:
39
  def log_step(
40
  step: int, action: str, reward: float, done: bool, error: Optional[str]
41
  ) -> None:
42
- error_value = error if error else "null"
 
 
 
43
  print(
44
  f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_value}",
45
  flush=True,
@@ -200,9 +208,7 @@ async def _run_task(client: OpenAI, task_id: str) -> None:
200
 
201
  async def main() -> None:
202
  if not API_KEY:
203
- raise RuntimeError(
204
- "Set HF_TOKEN, OPENAI_API_KEY, or API_KEY before running inference.py"
205
- )
206
 
207
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
208
  task_ids = [TASK_ID] if TASK_ID else [task["task_id"] for task in list_tasks()]
 
9
  from pathlib import Path
10
  from typing import Optional
11
 
12
+ from dotenv import load_dotenv
13
+
14
+
15
+ from openai import OpenAI
16
+
17
  ROOT = Path(__file__).resolve().parents[1]
18
  if str(ROOT) not in sys.path:
19
  sys.path.insert(0, str(ROOT))
20
 
 
 
21
  from lean_migrate.env.models import SubmitAction
22
  from lean_migrate.env.target_snippets import (
23
  TASK_TARGET_SNIPPETS,
 
26
  from lean_migrate.env.tasks import get_task, list_tasks
27
  from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment
28
 
29
+ load_dotenv() # Load environment variables from .env file
30
+
31
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
32
  MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
33
+ API_KEY = os.getenv("HF_TOKEN")
34
  TASK_ID = os.getenv("TASK_ID")
35
  MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
36
  TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
 
44
  def log_step(
45
  step: int, action: str, reward: float, done: bool, error: Optional[str]
46
  ) -> None:
47
+ if error:
48
+ error_value = " ".join(error.split())
49
+ else:
50
+ error_value = "null"
51
  print(
52
  f"[STEP] step={step} action={action} reward={reward:.2f} done={str(done).lower()} error={error_value}",
53
  flush=True,
 
208
 
209
  async def main() -> None:
210
  if not API_KEY:
211
+ raise RuntimeError("Set HF_TOKEN before running inference.py")
 
 
212
 
213
  client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
214
  task_ids = [TASK_ID] if TASK_ID else [task["task_id"] for task in list_tasks()]
lean_backend/interface.py CHANGED
@@ -19,6 +19,7 @@ class LeanBackend(ABC):
19
  spec_module: str,
20
  function_name: str,
21
  code: str,
 
22
  extra_imports: list[str] | None = None,
23
  sample_checks: list[str] | None = None,
24
  ) -> LeanResult:
 
19
  spec_module: str,
20
  function_name: str,
21
  code: str,
22
+ symbol_name: str | None = None,
23
  extra_imports: list[str] | None = None,
24
  sample_checks: list[str] | None = None,
25
  ) -> LeanResult:
lean_backend/kimina_backend.py CHANGED
@@ -52,6 +52,7 @@ class KiminaBackend(LeanBackend):
52
  spec_module: str,
53
  function_name: str,
54
  code: str,
 
55
  extra_imports: list[str] | None = None,
56
  sample_checks: list[str] | None = None,
57
  ) -> LeanResult:
@@ -71,7 +72,7 @@ class KiminaBackend(LeanBackend):
71
  ]
72
  if sample_checks:
73
  sections.extend(sample_checks)
74
- sections.append(f"#check _root_.{function_name}")
75
  lean_code = "\n\n".join(section for section in sections if section.strip())
76
  return self._call_kimina(lean_code)
77
 
 
52
  spec_module: str,
53
  function_name: str,
54
  code: str,
55
+ symbol_name: str | None = None,
56
  extra_imports: list[str] | None = None,
57
  sample_checks: list[str] | None = None,
58
  ) -> LeanResult:
 
72
  ]
73
  if sample_checks:
74
  sections.extend(sample_checks)
75
+ sections.append(f"#check {symbol_name or f'_root_.{function_name}'}")
76
  lean_code = "\n\n".join(section for section in sections if section.strip())
77
  return self._call_kimina(lean_code)
78
 
lean_backend/stdin_backend.py CHANGED
@@ -100,6 +100,7 @@ class StdinBackend(LeanBackend):
100
  spec_module: str,
101
  function_name: str,
102
  code: str,
 
103
  extra_imports: list[str] | None = None,
104
  sample_checks: list[str] | None = None,
105
  ) -> LeanResult:
@@ -119,7 +120,7 @@ class StdinBackend(LeanBackend):
119
  ]
120
  if sample_checks:
121
  sections.extend(sample_checks)
122
- sections.append(f"#check _root_.{function_name}")
123
  lean_code = "\n\n".join(section for section in sections if section.strip())
124
  return self._run_lean(lean_code)
125
 
 
100
  spec_module: str,
101
  function_name: str,
102
  code: str,
103
+ symbol_name: str | None = None,
104
  extra_imports: list[str] | None = None,
105
  sample_checks: list[str] | None = None,
106
  ) -> LeanResult:
 
120
  ]
121
  if sample_checks:
122
  sections.extend(sample_checks)
123
+ sections.append(f"#check {symbol_name or f'_root_.{function_name}'}")
124
  lean_code = "\n\n".join(section for section in sections if section.strip())
125
  return self._run_lean(lean_code)
126
 
tests/test_env_episode.py CHANGED
@@ -8,7 +8,7 @@ from lean_migrate.env.models import (
8
  RunTestsAction,
9
  SubmitAction,
10
  )
11
- from lean_migrate.env.target_snippets import TASK_TARGET_SNIPPETS
12
  from lean_migrate.env.state import EpisodeState
13
  from lean_migrate.env.tasks import get_task
14
  from lean_migrate.env.verification_ir import build_verification_ir
@@ -172,8 +172,45 @@ def test_verification_ir_builds_summary() -> None:
172
  assert result.provenance is not None
173
  assert result.provenance.parse_ok
174
  assert result.lean_code is not None
 
175
  assert "def findRole" in result.lean_code
176
- assert "example : _root_.findRole" in result.lean_code
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
 
179
  def test_verification_ir_reports_sample_mismatches() -> None:
 
8
  RunTestsAction,
9
  SubmitAction,
10
  )
11
+ from lean_migrate.env.target_snippets import TASK_TARGET_SNIPPETS, build_submission_bundle
12
  from lean_migrate.env.state import EpisodeState
13
  from lean_migrate.env.tasks import get_task
14
  from lean_migrate.env.verification_ir import build_verification_ir
 
172
  assert result.provenance is not None
173
  assert result.provenance.parse_ok
174
  assert result.lean_code is not None
175
+ assert "namespace Candidate" in result.lean_code
176
  assert "def findRole" in result.lean_code
177
+ assert "example : Candidate.findRole" in result.lean_code
178
+
179
+
180
+ def test_submit_accepts_rbac_dependency_bundle() -> None:
181
+ task = get_task("rbac_auth")
182
+ state = EpisodeState.from_task(task)
183
+ target_snippets = TASK_TARGET_SNIPPETS["rbac_auth"]
184
+ verified_target_snippets: dict[str, str] = {}
185
+
186
+ for function_name in ["findRole", "hasDirectPermission"]:
187
+ action = SubmitAction(
188
+ type="submit",
189
+ function_name=function_name,
190
+ target_code=target_snippets[function_name],
191
+ )
192
+ _, reward, done, _ = state.apply(action)
193
+ assert reward.score > 0.0
194
+ assert not done
195
+ verified_target_snippets[function_name] = target_snippets[function_name]
196
+
197
+ action = SubmitAction(
198
+ type="submit",
199
+ function_name="canAccess",
200
+ target_code=build_submission_bundle(
201
+ task,
202
+ "canAccess",
203
+ verified_target_snippets,
204
+ target_snippets["canAccess"],
205
+ ),
206
+ )
207
+
208
+ observation, reward, done, _ = state.apply(action)
209
+
210
+ assert reward.score > 0.0
211
+ assert reward.feedback.startswith("VERIFIED")
212
+ assert observation.progress == 1.0
213
+ assert done
214
 
215
 
216
  def test_verification_ir_reports_sample_mismatches() -> None: