"""Episode state machine and action dispatch.

EpisodeState is the core of the environment; one instance lives per episode.
apply() is the central dispatch that handles the four action types:

- inspect      → surface source + Lean spec, small shaping reward
- analyze_deps → surface dependency graph, small shaping reward
- run_tests    → execute candidate_code, score against oracle
- submit       → build verification IR, run Lean backend, update verified set

Verified functions are stored in ``self.verified`` (a set) and their code
snippets in ``self.verified_snippets`` (a dict). Dependencies are injected
automatically via _bundle_with_deps() before testing or submission.
"""
from __future__ import annotations

import functools
import re
import shutil
import subprocess
import tempfile
import uuid
from dataclasses import dataclass, field
from pathlib import Path

from openenv.core.env_server.types import State

from .grader import (
    clamp_open_unit,
    build_breakdown,
    run_candidate_tests,
    score_progress,
    score_step_reward,
)
from .models import (
    AnalyzeDepsAction,
    InspectAction,
    LeanMigrateAction,
    LeanMigrateObservation,
    LeanMigrateReward,
    RunTestsAction,
    SubmitAction,
)
from .submission_parsers import _camel_to_snake
from .target_snippets import dependency_closure
from .tasks import FunctionSpec, Task
from .verification_ir import build_verification_ir
from ..lean_backend.interface import LeanBackend
from ..lean_backend.kimina_backend import get_backend
# Shaping rewards returned by the EpisodeState action handlers.
# Granted by a successful ``inspect`` of a known function.
INSPECT_REWARD: float = 0.05
# Granted by a successful ``analyze_deps`` on a known function.
ANALYZE_DEPS_REWARD: float = 0.05
# Granted when ``run_tests`` reports that the candidate passed.
RUN_TESTS_SUCCESS_REWARD: float = 0.10
| def _format_rust_snippet(snippet: str) -> str: | |
| rustfmt = shutil.which("rustfmt") | |
| if rustfmt is None or not snippet.strip(): | |
| return snippet | |
| # # Rust enums used as simple tags in the verified dependency bundle should | |
| # # be Copy so downstream submissions can pass them by value without fighting | |
| # # move semantics. | |
| # if "pub enum Op" in snippet and "derive(Copy, Clone)" not in snippet: | |
| # snippet = snippet.replace( | |
| # "pub enum Op", | |
| # "#[derive(Copy, Clone)]\npub enum Op", | |
| # 1, | |
| # ) | |
| with tempfile.TemporaryDirectory() as temp_dir: | |
| snippet_path = Path(temp_dir) / "snippet.rs" | |
| snippet_path.write_text(snippet.rstrip() + "\n") | |
| try: | |
| completed = subprocess.run( | |
| [rustfmt, "--edition", "2021", str(snippet_path)], | |
| capture_output=True, | |
| text=True, | |
| timeout=5, | |
| ) | |
| except Exception: | |
| return snippet | |
| if completed.returncode != 0: | |
| return snippet | |
| try: | |
| formatted = snippet_path.read_text().strip() | |
| except OSError: | |
| return snippet | |
| return formatted or snippet | |
| def _format_verified_dependency_code(language: str, code: str) -> str: | |
| if language == "rust": | |
| return _format_rust_snippet(code) | |
| return code | |
| def _extract_source_symbol(source_fragment: str) -> str | None: | |
| for raw_line in source_fragment.splitlines(): | |
| line = raw_line.strip() | |
| if not line or line.startswith("//") or line.startswith("/*") or line.startswith("*"): | |
| continue | |
| for pattern in ( | |
| r"^(?:pub\s+)?(?:unsafe\s+)?fn\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", | |
| r"^(?:def|function)\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", | |
| r"^(?:[A-Za-z_][A-Za-z0-9_:\<\>\*&\s]*?)\b([A-Za-z_][A-Za-z0-9_]*)\s*\(", | |
| ): | |
| match = re.match(pattern, line) | |
| if match is not None: | |
| return match.group(1) | |
| return None | |
@dataclass
class EpisodeState:
    """Mutable state of one migration episode.

    One instance lives per episode (build it with :meth:`from_task`).
    :meth:`apply` routes each action to a ``_handle_*`` method, stores the
    resulting feedback and reward, and returns
    ``(observation, reward, done, info)``.

    Function names supplied by the agent are resolved with
    :meth:`_resolve_function`, so both a spec's exact name and its snake_case
    form are accepted. All per-function state is keyed by the canonical
    ``FunctionSpec.name``.
    """

    # Identifier echoed in every observation.
    episode_id: str
    # The migration task being solved.
    task: Task
    # Lean backend used to check submissions.
    backend: LeanBackend
    # Number of actions applied so far (incremented at the start of apply()).
    step_count: int = 0
    # Canonical names of functions accepted by the Lean backend.
    verified: set[str] = field(default_factory=set)
    # Lean error from the latest failed Lean verification of each function;
    # the entry is removed once that function verifies.
    failing: dict[str, str] = field(default_factory=dict)
    # Feedback, action type and reward of the most recent step.
    last_feedback: str | None = None
    last_action_type: str | None = None
    last_step_reward: float | None = None
    # Test counts reported by the most recent step (None when no tests ran).
    last_tests_passed: int | None = None
    last_tests_total: int | None = None
    # Full reward record of the most recent step.
    last_reward_details: LeanMigrateReward | None = None
    # clamp_open_unit(progress), refreshed after every step.
    cumulative_score: float = 0.0
    # Target code of verified non-proof functions; prepended to the code of
    # functions that depend on them (see _bundle_with_deps).
    verified_snippets: dict[str, str] = field(default_factory=dict)
    # Consecutive failures per function (failed run_tests or Lean rejection
    # on submit); the count is dropped on success.
    consecutive_failures: dict[str, int] = field(default_factory=dict)

    @classmethod
    def from_task(
        cls,
        task: Task,
        backend: LeanBackend | None = None,
        episode_id: str | None = None,
    ) -> "EpisodeState":
        """Create a fresh episode for *task*.

        Args:
            task: The task to solve.
            backend: Lean backend to use; ``get_backend()`` is called when
                none (or a falsy value) is given.
            episode_id: Episode identifier; a random UUID4 string is used
                when omitted or empty.
        """
        return cls(
            episode_id=episode_id or str(uuid.uuid4()),
            task=task,
            backend=backend or get_backend(),
        )

    @property
    def remaining(self) -> list[str]:
        """Unverified function names, in ``task.function_names`` order."""
        return [
            function_name
            for function_name in self.task.function_names
            if function_name not in self.verified
        ]

    @property
    def ready_to_submit(self) -> list[str]:
        """Unverified functions whose direct dependencies are all verified.

        Dependencies are read from ``task.dependency_graph``; order follows
        :attr:`remaining`.
        """
        ready_functions: list[str] = []
        for function_name in self.remaining:
            dependencies = self.task.dependency_graph.get(function_name, [])
            if all(dependency in self.verified for dependency in dependencies):
                ready_functions.append(function_name)
        return ready_functions

    @property
    def progress(self) -> float:
        """``score_progress`` of verified count over total function count."""
        return score_progress(len(self.verified), len(self.task.functions))

    @property
    def suggested_function(self) -> str | None:
        """First function in :attr:`ready_to_submit`, or ``None`` if none."""
        ready_functions = self.ready_to_submit
        return ready_functions[0] if ready_functions else None

    def to_observation(self) -> LeanMigrateObservation:
        """Snapshot the current state as an observation.

        ``done`` is always ``False`` here; :meth:`apply` overrides it (and the
        reward fields) on the observation it returns.
        """
        return LeanMigrateObservation(
            episode_id=self.episode_id,
            task_id=self.task.task_id,
            episode_step=self.step_count,
            max_steps=self.task.max_steps,
            source_language=self.task.source_language,
            target_language=self.task.target_language,
            source_files=self.task.source_files,
            verified=sorted(self.verified),
            failing=dict(self.failing),
            remaining=self.remaining,
            progress=self.progress,
            last_action_type=self.last_action_type,
            last_action_feedback=self.last_feedback,
            last_step_reward=self.last_step_reward,
            done=False,
            reward=self.last_step_reward,
            reward_details=self.last_reward_details,
        )

    def apply(
        self, action: LeanMigrateAction
    ) -> tuple[LeanMigrateObservation, LeanMigrateReward, bool, dict[str, object]]:
        """Apply one agent action and advance the episode by a step.

        Args:
            action: Dispatched on ``action.type`` (``inspect``,
                ``analyze_deps``, ``run_tests`` or ``submit``); any other
                type receives a -0.05 penalty.

        Returns:
            ``(observation, reward, done, info)``. ``done`` is true once
            ``progress`` reaches 1.0 or ``task.max_steps`` steps have been
            taken; ``info`` holds ``progress``, ``verified`` and ``step``.
        """
        self.step_count += 1
        self.last_action_type = action.type
        # Handlers that run tests repopulate these; other actions report None.
        self.last_tests_passed = None
        self.last_tests_total = None

        handlers = {
            "inspect": self._handle_inspect,
            "analyze_deps": self._handle_analyze_deps,
            "run_tests": self._handle_run_tests,
            "submit": self._handle_submit,
        }
        handler = handlers.get(action.type)
        if handler is None:
            outcome = (-0.05, "Unknown action type.", None, build_breakdown(0.0))
        else:
            outcome = handler(action)
        reward_value, feedback, proof_compiled, breakdown = outcome

        self.last_feedback = feedback
        self.last_step_reward = reward_value
        self.cumulative_score = clamp_open_unit(self.progress)
        done = self.progress >= 1.0 or self.step_count >= self.task.max_steps

        # Surface the stored Lean error (if any) for the function this action
        # targeted, using the canonical name so snake_case input also matches.
        raw_fn_name = getattr(action, "function_name", "")
        fn_spec_for_error = self._resolve_function(raw_fn_name)
        canonical_fn_name = fn_spec_for_error.name if fn_spec_for_error else raw_fn_name
        lean_error = self.failing.get(canonical_fn_name)

        reward = LeanMigrateReward(
            score=reward_value,
            cumulative_score=self.cumulative_score,
            tests_passed=self.last_tests_passed,
            tests_total=self.last_tests_total,
            proof_compiled=proof_compiled,
            breakdown=breakdown,
            feedback=feedback,
            lean_error=lean_error,
        )
        self.last_reward_details = reward
        observation = self.to_observation()
        observation.done = done
        observation.reward = reward_value
        observation.reward_details = reward
        return (
            observation,
            reward,
            done,
            {
                "progress": self.progress,
                "verified": sorted(self.verified),
                "step": self.step_count,
            },
        )

    def _record_failure(self, fn_name: str, feedback: str, reward: float) -> tuple[str, float]:
        """Increment consecutive failure count; escalate feedback and penalty after thresholds.

        From the 3rd consecutive failure a hint is appended; from the 5th a
        stronger warning is appended and the reward becomes
        ``min(reward - 0.05, -0.1)``.

        Returns:
            The (possibly extended) feedback and the (possibly lowered) reward.
        """
        n = self.consecutive_failures.get(fn_name, 0) + 1
        self.consecutive_failures[fn_name] = n
        if n >= 5:
            extra = (
                f"\n\n⚠ REPEATED FAILURE ({n}x on '{fn_name}'): "
                "Your last several attempts produced the same error. "
                "You MUST try a fundamentally different approach: "
                "use 'inspect' to re-read the spec, reconsider the function signature, "
                "or implement the logic differently."
            )
            reward = min(reward - 0.05, -0.1)
        elif n >= 3:
            extra = (
                f"\n\nHint: '{fn_name}' has failed {n} times in a row. "
                "Consider using 'inspect' to re-read the spec and trying a different approach."
            )
        else:
            extra = ""
        return feedback + extra, reward

    def _record_success(self, fn_name: str) -> None:
        """Reset consecutive failure count on any success."""
        self.consecutive_failures.pop(fn_name, None)

    def _resolve_function(self, name: str) -> FunctionSpec | None:
        """Exact match first; fall back to snake_case normalisation.

        Agents (notably those working on Rust targets) naturally use
        snake_case (e.g. ``merge_intervals``) while the spec stores camelCase
        (``mergeIntervals``). The fallback compares *name* against
        ``_camel_to_snake`` of every spec name, for any task, so both forms
        resolve to the same spec. The canonical ``spec.name`` is always used
        for downstream state keys.
        """
        spec = self.task.get_function(name)
        if spec is not None:
            return spec
        for fn in self.task.functions:
            if _camel_to_snake(fn.name) == name:
                return fn
        return None

    def _function_not_found(
        self, name: str
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Standard -0.05 penalty result for an action naming an unknown function."""
        valid = [fn.name for fn in self.task.functions]
        return (
            -0.05,
            f"Function '{name}' not found. Valid names: {valid}",
            None,
            build_breakdown(0.0),
        )

    def _bundle_with_deps(self, function_name: str, snippet: str) -> str:
        """Prepend verified dependency snippets to the given code snippet.

        Snippets of verified functions in the dependency closure are joined
        in closure order, followed by *snippet*, separated by blank lines.
        """
        parts = [
            self.verified_snippets[dep]
            for dep in dependency_closure(self.task, function_name)
            if dep in self.verified_snippets
        ]
        parts.append(snippet)
        return "\n\n".join(parts).strip()

    def _handle_inspect(
        self, action: InspectAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Show a function's info, source, Lean spec and injected dependency code.

        For proof-required functions the full Lean spec module is appended
        when the file exists.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        lang = self.task.source_language
        target_lang = self.task.target_language
        parts: list[str] = []
        parts.append(
            f"<FUNCTION_INFO>\n"
            f"Name: {function_spec.name}\n"
            f"Source symbol: {_extract_source_symbol(function_spec.source_fragment) or '(unavailable)'}\n"
            f"Description: {function_spec.description}\n"
            f"Depends on: {function_spec.depends_on or ['(none)']}\n"
            f"Proof required: {function_spec.is_proof_required}\n"
            f"</FUNCTION_INFO>"
        )
        parts.append(
            f"<SOURCE language=\"{lang}\">\n"
            f"{function_spec.source_fragment}\n"
            f"</SOURCE>"
        )
        parts.append(
            f"<LEAN_SPEC>\n"
            f"{function_spec.lean_fragment}\n"
            f"</LEAN_SPEC>"
        )
        # Show verified dependency snippets so the agent knows exactly what code
        # is auto-prepended to its submission. Without this, agents redefine symbols
        # already present in the bundle, causing duplicate-definition compile errors
        # (especially fatal in Rust which forbids duplicate definitions).
        dep_closure = dependency_closure(self.task, function_spec.name)
        verified_deps = [
            (dep, self.verified_snippets[dep])
            for dep in dep_closure
            if dep in self.verified_snippets
        ]
        if verified_deps:
            dep_blocks: list[str] = [
                "The following dependency code is automatically prepended to your "
                "submission. Do NOT redefine these symbols.\n"
            ]
            for dep_name, dep_code in verified_deps:
                formatted_dep_code = _format_verified_dependency_code(target_lang, dep_code)
                dep_blocks.append(
                    f"{dep_name} (verified):\n"
                    f"```{target_lang}\n{formatted_dep_code}\n```"
                )
            parts.append(
                "<VERIFIED_DEPS>\n"
                + "\n\n".join(dep_blocks)
                + "\n</VERIFIED_DEPS>"
            )
        # For proof tasks, append the full Lean spec file so the agent can see
        # all available public theorems (avoids hallucinating theorem names).
        if function_spec.is_proof_required and self.task.lean_spec_module:
            spec_path = (
                Path(__file__).parent.parent
                / "lean"
                / f"{self.task.lean_spec_module}.lean"
            )
            if spec_path.exists():
                # Lean sources are UTF-8 and full of Unicode (∀, →, ℕ, ...);
                # don't rely on the platform's default encoding.
                lean_content = spec_path.read_text(encoding="utf-8")
                parts.append(
                    f"<FULL_LEAN_SPEC file=\"{self.task.lean_spec_module}.lean\">\n"
                    f"```lean\n{lean_content}\n```\n"
                    f"</FULL_LEAN_SPEC>"
                )
        feedback = "\n\n".join(parts)
        return INSPECT_REWARD, feedback, None, build_breakdown(INSPECT_REWARD)

    def _handle_analyze_deps(
        self, action: AnalyzeDepsAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Report direct dependencies (with status), the graph and the topo plan."""
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        dependencies = function_spec.depends_on
        graph_lines = [
            f"  - {function_name}: {self.task.dependency_graph.get(function_name, []) or ['(none)']}"
            for function_name in self.task.topo_order
        ]
        plan_lines = [
            f"  {index + 1}. {function_name}"
            for index, function_name in enumerate(self.task.topo_order)
        ]
        dependency_lines = [
            f"  - {dependency}: {'verified' if dependency in self.verified else 'not yet verified'}"
            for dependency in dependencies
        ] or ["  (none)"]
        feedback = (
            f"Dependencies for '{action.function_name}':\n"
            + "\n".join(dependency_lines)
            + "\n\nDependency graph:\n"
            + "\n".join(graph_lines)
            + "\n\nMigration plan:\n"
            + "\n".join(plan_lines)
        )
        return ANALYZE_DEPS_REWARD, feedback, None, build_breakdown(
            ANALYZE_DEPS_REWARD
        )

    def _handle_run_tests(
        self, action: RunTestsAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Run ``candidate_code`` (bundled with verified deps) against the tests.

        A pass earns ``RUN_TESTS_SUCCESS_REWARD``; a failure costs 0.01 per
        failing test (at least one), escalated by :meth:`_record_failure`.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        bundled_code = self._bundle_with_deps(function_spec.name, action.candidate_code)
        result = run_candidate_tests(self.task, function_spec, bundled_code)
        feedback = result.feedback
        self.last_tests_passed = result.tests_passed
        self.last_tests_total = result.tests_total
        if result.passed:
            self._record_success(function_spec.name)
            reward_value = RUN_TESTS_SUCCESS_REWARD
        else:
            # Detect Rust duplicate-definition errors caused by agents re-declaring
            # symbols that are already injected from verified dependency snippets.
            dep_closure = dependency_closure(self.task, function_spec.name)
            has_verified_deps = any(dep in self.verified_snippets for dep in dep_closure)
            if has_verified_deps and "E0428" in result.stderr:
                feedback += (
                    "\nNote: verified dependency code is automatically prepended to your "
                    "submission. Do not redefine types or functions from already-verified "
                    "dependencies. Use 'inspect' to see exactly what code will be injected."
                )
            reward_value = -0.01 * max(1, result.tests_total - result.tests_passed)
            feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value)
        breakdown = build_breakdown(
            0.0,
            property_score=result.tests_passed / result.tests_total
            if result.tests_total
            else 0.0,
        )
        return reward_value, feedback, None, breakdown

    def _handle_submit(
        self, action: SubmitAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Verify a submission with the Lean backend and record the outcome.

        Direct dependencies must already be verified. Proof-required functions
        are checked with ``backend.verify_proof`` on ``action.lean_proof``.
        Other functions have ``action.target_code`` bundled with verified
        dependency code, converted by ``build_verification_ir`` and checked
        with ``backend.verify``. On success the function joins
        :attr:`verified` and, for code submissions, its formatted target code
        is stored in :attr:`verified_snippets`.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        unmet_dependencies = [
            dependency
            for dependency in function_spec.depends_on
            if dependency not in self.verified
        ]
        if unmet_dependencies:
            feedback = (
                f"Cannot submit '{action.function_name}' yet.\n"
                f"Unmet dependencies: {unmet_dependencies}\n"
                f"Verify these first: {unmet_dependencies}"
            )
            return -0.05, feedback, None, build_breakdown(0.0)

        ir_result = None
        if not function_spec.is_proof_required:
            target_code = action.target_code
            if not target_code:
                return (
                    -0.05,
                    f"'{action.function_name}' has no code to submit. Run `run_tests` first so the system can record your implementation.",
                    None,
                    build_breakdown(0.0),
                )
            bundled_target = self._bundle_with_deps(function_spec.name, target_code)
            ir_result = build_verification_ir(self.task, function_spec, bundled_target)
            if not ir_result.ready or ir_result.lean_code is None:
                run_result = ir_result.run_result
                property_score = 0.0
                if run_result is not None:
                    self.last_tests_passed = run_result.tests_passed
                    self.last_tests_total = run_result.tests_total
                    if run_result.tests_total:
                        property_score = run_result.tests_passed / run_result.tests_total
                return (
                    -0.05,
                    (
                        f"REJECTED: '{action.function_name}' failed IR validation.\n"
                        f"{ir_result.feedback}"
                    ),
                    False,
                    build_breakdown(0.0, property_score=property_score),
                )

        if function_spec.is_proof_required:
            if not action.lean_proof:
                return (
                    -0.05,
                    (
                        f"'{action.function_name}' requires a Lean proof. Provide lean_proof before submitting."
                    ),
                    False,
                    build_breakdown(0.0, proof=0.0),
                )
            result = self.backend.verify_proof(
                spec_module=self.task.lean_spec_module,
                proof_code=action.lean_proof,
            )
            proof_compiled = result.passed
            self.last_tests_passed = None
            self.last_tests_total = None
        else:
            result = self.backend.verify(
                spec_module=self.task.lean_spec_module,
                function_name=function_spec.name,
                code=ir_result.lean_code
                if ir_result is not None
                else action.target_code,
                symbol_name=(
                    f"Candidate.{function_spec.name}"
                    if ir_result is not None and ir_result.lean_code is not None
                    else None
                ),
                sample_checks=[],
            )
            proof_compiled = None
            tests_total = (
                ir_result.run_result.tests_total
                if ir_result and ir_result.run_result
                else None
            )
            self.last_tests_total = tests_total
            # A Lean pass counts every IR test as passed, a Lean failure as
            # none; without a (non-zero) test count nothing is reported.
            if tests_total:
                self.last_tests_passed = tests_total if result.passed else 0
            else:
                self.last_tests_passed = None

        if result.passed:
            self._record_success(function_spec.name)
            self.verified.add(function_spec.name)
            self.failing.pop(function_spec.name, None)
            if not function_spec.is_proof_required and action.target_code:
                self.verified_snippets[function_spec.name] = _format_verified_dependency_code(
                    self.task.target_language, action.target_code
                )
            reward_value = score_step_reward(True, len(self.task.functions))
            if function_spec.is_proof_required:
                breakdown = build_breakdown(1.0, proof=1.0)
                feedback = (
                    f"VERIFIED: '{action.function_name}' accepted by LEAN.\n"
                    f"Latency: {result.latency_ms}ms\n"
                    f"Progress: {len(self.verified)}/{len(self.task.functions)} functions verified."
                )
            else:
                breakdown = build_breakdown(
                    1.0,
                    property_score=1.0 if self.last_tests_total else 0.0,
                )
                feedback = (
                    f"VERIFIED: '{action.function_name}' accepted by IR + LEAN.\n"
                    f"{ir_result.feedback if ir_result is not None else ''}\n"
                    f"Latency: {result.latency_ms}ms\n"
                    f"Progress: {len(self.verified)}/{len(self.task.functions)} functions verified."
                )
        else:
            self.failing[function_spec.name] = result.error
            reward_value = score_step_reward(False, len(self.task.functions))
            breakdown = build_breakdown(0.0)
            if function_spec.is_proof_required:
                feedback = (
                    f"REJECTED: '{action.function_name}' failed LEAN verification.\n"
                    f"Latency: {result.latency_ms}ms\n\n"
                    f"LEAN error:\n{result.error}\n\n"
                    f"Hint: Check that your code matches the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec."
                )
            else:
                feedback = (
                    f"REJECTED: '{action.function_name}' failed LEAN verification.\n"
                    f"{ir_result.feedback if ir_result is not None else ''}\n\n"
                    f"Latency: {result.latency_ms}ms\n\n"
                    f"LEAN error:\n{result.error}\n\n"
                    f"Hint: Check that your code matches the generated Lean mirror and the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec."
                )
            feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value)
        return reward_value, feedback, proof_compiled, breakdown