"""Episode state machine and action dispatch. EpisodeState is the core of the environment. One instance lives per episode. apply() is the central dispatch that handles the four action types: - inspect → surface source + Lean spec, small shaping reward - analyze_deps → surface dependency graph, small shaping reward - run_tests → execute candidate_code, score against oracle - submit → build verification IR, run Lean backend, update verified set Verified functions are stored in self.verified (set) and their code snippets in self.verified_snippets (dict). Dependencies are injected automatically via _bundle_with_deps() before testing or submission. """ from __future__ import annotations import functools import re import shutil import subprocess import tempfile import uuid from dataclasses import dataclass, field from pathlib import Path from openenv.core.env_server.types import State from .grader import ( clamp_open_unit, build_breakdown, run_candidate_tests, score_progress, score_step_reward, ) from .models import ( AnalyzeDepsAction, InspectAction, LeanMigrateAction, LeanMigrateObservation, LeanMigrateReward, RunTestsAction, SubmitAction, ) from .submission_parsers import _camel_to_snake from .target_snippets import dependency_closure from .tasks import FunctionSpec, Task from .verification_ir import build_verification_ir from ..lean_backend.interface import LeanBackend from ..lean_backend.kimina_backend import get_backend INSPECT_REWARD = 0.05 ANALYZE_DEPS_REWARD = 0.05 RUN_TESTS_SUCCESS_REWARD = 0.10 @functools.lru_cache(maxsize=128) def _format_rust_snippet(snippet: str) -> str: rustfmt = shutil.which("rustfmt") if rustfmt is None or not snippet.strip(): return snippet # # Rust enums used as simple tags in the verified dependency bundle should # # be Copy so downstream submissions can pass them by value without fighting # # move semantics. # if "pub enum Op" in snippet and "derive(Copy, Clone)" not in snippet: # snippet = snippet.replace( # "pub enum Op", # "#[derive(Copy, Clone)]\npub enum Op", # 1, # ) with tempfile.TemporaryDirectory() as temp_dir: snippet_path = Path(temp_dir) / "snippet.rs" snippet_path.write_text(snippet.rstrip() + "\n") try: completed = subprocess.run( [rustfmt, "--edition", "2021", str(snippet_path)], capture_output=True, text=True, timeout=5, ) except Exception: return snippet if completed.returncode != 0: return snippet try: formatted = snippet_path.read_text().strip() except OSError: return snippet return formatted or snippet def _format_verified_dependency_code(language: str, code: str) -> str: if language == "rust": return _format_rust_snippet(code) return code def _extract_source_symbol(source_fragment: str) -> str | None: for raw_line in source_fragment.splitlines(): line = raw_line.strip() if not line or line.startswith("//") or line.startswith("/*") or line.startswith("*"): continue for pattern in ( r"^(?:pub\s+)?(?:unsafe\s+)?fn\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", r"^(?:def|function)\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(", r"^(?:[A-Za-z_][A-Za-z0-9_:\<\>\*&\s]*?)\b([A-Za-z_][A-Za-z0-9_]*)\s*\(", ): match = re.match(pattern, line) if match is not None: return match.group(1) return None @dataclass class EpisodeState: episode_id: str task: Task backend: LeanBackend step_count: int = 0 verified: set[str] = field(default_factory=set) failing: dict[str, str] = field(default_factory=dict) last_feedback: str | None = None last_action_type: str | None = None last_step_reward: float | None = None last_tests_passed: int | None = None last_tests_total: int | None = None last_reward_details: LeanMigrateReward | None = None cumulative_score: float = 0.0 verified_snippets: dict[str, str] = field(default_factory=dict) consecutive_failures: dict[str, int] = field(default_factory=dict) @classmethod def from_task( cls, task: Task, backend: LeanBackend | None = None, episode_id: str | None = None, ) -> "EpisodeState": return cls( episode_id=episode_id or str(uuid.uuid4()), task=task, backend=backend or get_backend(), ) @property def remaining(self) -> list[str]: return [ function_name for function_name in self.task.function_names if function_name not in self.verified ] @property def ready_to_submit(self) -> list[str]: ready_functions: list[str] = [] for function_name in self.remaining: dependencies = self.task.dependency_graph.get(function_name, []) if all(dependency in self.verified for dependency in dependencies): ready_functions.append(function_name) return ready_functions @property def progress(self) -> float: return score_progress(len(self.verified), len(self.task.functions)) @property def suggested_function(self) -> str | None: ready_functions = self.ready_to_submit return ready_functions[0] if ready_functions else None def to_observation(self) -> LeanMigrateObservation: return LeanMigrateObservation( episode_id=self.episode_id, task_id=self.task.task_id, episode_step=self.step_count, max_steps=self.task.max_steps, source_language=self.task.source_language, target_language=self.task.target_language, source_files=self.task.source_files, verified=sorted(self.verified), failing=dict(self.failing), remaining=self.remaining, progress=self.progress, last_action_type=self.last_action_type, last_action_feedback=self.last_feedback, last_step_reward=self.last_step_reward, done=False, reward=self.last_step_reward, reward_details=self.last_reward_details, ) def apply( self, action: LeanMigrateAction ) -> tuple[LeanMigrateObservation, LeanMigrateReward, bool, dict[str, object]]: self.step_count += 1 self.last_action_type = action.type self.last_tests_passed = None self.last_tests_total = None if action.type == "inspect": reward_value, feedback, proof_compiled, breakdown = self._handle_inspect( action ) elif action.type == "analyze_deps": reward_value, feedback, proof_compiled, breakdown = ( self._handle_analyze_deps(action) ) elif action.type == "run_tests": reward_value, feedback, proof_compiled, breakdown = self._handle_run_tests( action ) elif action.type == "submit": reward_value, feedback, proof_compiled, breakdown = self._handle_submit( action ) else: reward_value, feedback, proof_compiled, breakdown = ( -0.05, "Unknown action type.", None, build_breakdown(0.0), ) self.last_feedback = feedback self.last_step_reward = reward_value self.cumulative_score = clamp_open_unit(self.progress) done = self.progress >= 1.0 or self.step_count >= self.task.max_steps raw_fn_name = getattr(action, "function_name", "") fn_spec_for_error = self._resolve_function(raw_fn_name) canonical_fn_name = fn_spec_for_error.name if fn_spec_for_error else raw_fn_name lean_error = self.failing.get(canonical_fn_name) reward = LeanMigrateReward( score=reward_value, cumulative_score=self.cumulative_score, tests_passed=self.last_tests_passed, tests_total=self.last_tests_total, proof_compiled=proof_compiled, breakdown=breakdown, feedback=feedback, lean_error=lean_error, ) self.last_reward_details = reward observation = self.to_observation() observation.done = done observation.reward = reward_value observation.reward_details = reward return ( observation, reward, done, { "progress": self.progress, "verified": sorted(self.verified), "step": self.step_count, }, ) def _record_failure(self, fn_name: str, feedback: str, reward: float) -> tuple[str, float]: """Increment consecutive failure count; escalate feedback and penalty after thresholds.""" n = self.consecutive_failures.get(fn_name, 0) + 1 self.consecutive_failures[fn_name] = n if n >= 5: extra = ( f"\n\n⚠ REPEATED FAILURE ({n}x on '{fn_name}'): " "Your last several attempts produced the same error. " "You MUST try a fundamentally different approach: " "use 'inspect' to re-read the spec, reconsider the function signature, " "or implement the logic differently." ) reward = min(reward - 0.05, -0.1) elif n >= 3: extra = ( f"\n\nHint: '{fn_name}' has failed {n} times in a row. " "Consider using 'inspect' to re-read the spec and trying a different approach." ) else: extra = "" return feedback + extra, reward def _record_success(self, fn_name: str) -> None: """Reset consecutive failure count on any success.""" self.consecutive_failures.pop(fn_name, None) def _resolve_function(self, name: str) -> FunctionSpec | None: """Exact match first; fall back to snake_case normalisation for Rust tasks. Agents working on Rust targets naturally use snake_case (e.g. ``merge_intervals``) while the spec stores camelCase (``mergeIntervals``). This normalises the lookup so both forms resolve to the same spec. The canonical spec.name is always used for downstream state keys. """ spec = self.task.get_function(name) if spec is not None: return spec for fn in self.task.functions: if _camel_to_snake(fn.name) == name: return fn return None def _bundle_with_deps(self, function_name: str, snippet: str) -> str: """Prepend verified dependency snippets to the given code snippet.""" parts = [ self.verified_snippets[dep] for dep in dependency_closure(self.task, function_name) if dep in self.verified_snippets ] parts.append(snippet) return "\n\n".join(parts).strip() def _handle_inspect( self, action: InspectAction ) -> tuple[float, str, bool | None, dict[str, float]]: function_spec = self._resolve_function(action.function_name) if function_spec is None: valid = [fn.name for fn in self.task.functions] return ( -0.05, f"Function '{action.function_name}' not found. Valid names: {valid}", None, build_breakdown(0.0), ) lang = self.task.source_language target_lang = self.task.target_language parts: list[str] = [] parts.append( f"\n" f"Name: {function_spec.name}\n" f"Source symbol: {_extract_source_symbol(function_spec.source_fragment) or '(unavailable)'}\n" f"Description: {function_spec.description}\n" f"Depends on: {function_spec.depends_on or ['(none)']}\n" f"Proof required: {function_spec.is_proof_required}\n" f"" ) parts.append( f"\n" f"{function_spec.source_fragment}\n" f"" ) parts.append( f"\n" f"{function_spec.lean_fragment}\n" f"" ) # Show verified dependency snippets so the agent knows exactly what code # is auto-prepended to its submission. Without this, agents redefine symbols # already present in the bundle, causing duplicate-definition compile errors # (especially fatal in Rust which forbids duplicate definitions). dep_closure = dependency_closure(self.task, function_spec.name) verified_deps = [ (dep, self.verified_snippets[dep]) for dep in dep_closure if dep in self.verified_snippets ] if verified_deps: dep_blocks: list[str] = [ "The following dependency code is automatically prepended to your " "submission. Do NOT redefine these symbols.\n" ] for dep_name, dep_code in verified_deps: formatted_dep_code = _format_verified_dependency_code(target_lang, dep_code) dep_blocks.append( f"{dep_name} (verified):\n" f"```{target_lang}\n{formatted_dep_code}\n```" ) parts.append( "\n" + "\n\n".join(dep_blocks) + "\n" ) # For proof tasks, append the full Lean spec file so the agent can see # all available public theorems (avoids hallucinating theorem names). if function_spec.is_proof_required and self.task.lean_spec_module: spec_path = ( Path(__file__).parent.parent / "lean" / f"{self.task.lean_spec_module}.lean" ) if spec_path.exists(): lean_content = spec_path.read_text() parts.append( f"\n" f"```lean\n{lean_content}\n```\n" f"" ) feedback = "\n\n".join(parts) return INSPECT_REWARD, feedback, None, build_breakdown(INSPECT_REWARD) def _handle_analyze_deps( self, action: AnalyzeDepsAction ) -> tuple[float, str, bool | None, dict[str, float]]: function_spec = self._resolve_function(action.function_name) if function_spec is None: valid = [fn.name for fn in self.task.functions] return ( -0.05, f"Function '{action.function_name}' not found. Valid names: {valid}", None, build_breakdown(0.0), ) dependencies = function_spec.depends_on graph_lines = [ f" - {function_name}: {self.task.dependency_graph.get(function_name, []) or ['(none)']}" for function_name in self.task.topo_order ] plan_lines = [ f" {index + 1}. {function_name}" for index, function_name in enumerate(self.task.topo_order) ] dependency_lines = [ f" - {dependency}: {'verified' if dependency in self.verified else 'not yet verified'}" for dependency in dependencies ] or [" (none)"] feedback = ( f"Dependencies for '{action.function_name}':\n" + "\n".join(dependency_lines) + "\n\nDependency graph:\n" + "\n".join(graph_lines) + "\n\nMigration plan:\n" + "\n".join(plan_lines) ) return ANALYZE_DEPS_REWARD, feedback, None, build_breakdown( ANALYZE_DEPS_REWARD ) def _handle_run_tests( self, action: RunTestsAction ) -> tuple[float, str, bool | None, dict[str, float]]: function_spec = self._resolve_function(action.function_name) if function_spec is None: valid = [fn.name for fn in self.task.functions] return ( -0.05, f"Function '{action.function_name}' not found. Valid names: {valid}", None, build_breakdown(0.0), ) bundled_code = self._bundle_with_deps(function_spec.name, action.candidate_code) result = run_candidate_tests(self.task, function_spec, bundled_code) feedback = result.feedback self.last_tests_passed = result.tests_passed self.last_tests_total = result.tests_total if result.passed: self._record_success(function_spec.name) reward_value = RUN_TESTS_SUCCESS_REWARD else: # Detect Rust duplicate-definition errors caused by agents re-declaring # symbols that are already injected from verified dependency snippets. dep_closure = dependency_closure(self.task, function_spec.name) has_verified_deps = any(dep in self.verified_snippets for dep in dep_closure) if has_verified_deps and "E0428" in result.stderr: feedback += ( "\nNote: verified dependency code is automatically prepended to your " "submission. Do not redefine types or functions from already-verified " "dependencies. Use 'inspect' to see exactly what code will be injected." ) reward_value = -0.01 * max(1, result.tests_total - result.tests_passed) feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value) breakdown = build_breakdown( 0.0, property_score=result.tests_passed / result.tests_total if result.tests_total else 0.0, ) return reward_value, feedback, None, breakdown def _handle_submit( self, action: SubmitAction ) -> tuple[float, str, bool | None, dict[str, float]]: function_spec = self._resolve_function(action.function_name) if function_spec is None: valid = [fn.name for fn in self.task.functions] return ( -0.05, f"Function '{action.function_name}' not found. Valid names: {valid}", None, build_breakdown(0.0), ) unmet_dependencies = [ dependency for dependency in function_spec.depends_on if dependency not in self.verified ] if unmet_dependencies: feedback = ( f"Cannot submit '{action.function_name}' yet.\n" f"Unmet dependencies: {unmet_dependencies}\n" f"Verify these first: {unmet_dependencies}" ) return -0.05, feedback, None, build_breakdown(0.0) ir_result = None if not function_spec.is_proof_required: target_code = action.target_code if not target_code: return ( -0.05, f"'{action.function_name}' has no code to submit. Run `run_tests` first so the system can record your implementation.", None, build_breakdown(0.0), ) bundled_target = self._bundle_with_deps(function_spec.name, target_code) ir_result = build_verification_ir(self.task, function_spec, bundled_target) if not ir_result.ready or ir_result.lean_code is None: if ir_result.run_result is not None: self.last_tests_passed = ir_result.run_result.tests_passed self.last_tests_total = ir_result.run_result.tests_total property_score = 0.0 if ( ir_result.run_result is not None and ir_result.run_result.tests_total ): property_score = ( ir_result.run_result.tests_passed / ir_result.run_result.tests_total ) return ( -0.05, ( f"REJECTED: '{action.function_name}' failed IR validation.\n" f"{ir_result.feedback}" ), False, build_breakdown(0.0, property_score=property_score), ) if function_spec.is_proof_required: if not action.lean_proof: return ( -0.05, ( f"'{action.function_name}' requires a Lean proof. Provide lean_proof before submitting." ), False, build_breakdown(0.0, proof=0.0), ) result = self.backend.verify_proof( spec_module=self.task.lean_spec_module, proof_code=action.lean_proof, ) proof_compiled = result.passed self.last_tests_passed = None self.last_tests_total = None else: result = self.backend.verify( spec_module=self.task.lean_spec_module, function_name=function_spec.name, code=ir_result.lean_code if ir_result is not None else action.target_code, symbol_name=( f"Candidate.{function_spec.name}" if ir_result is not None and ir_result.lean_code is not None else None ), sample_checks=[], ) proof_compiled = None self.last_tests_total = ( ir_result.run_result.tests_total if ir_result and ir_result.run_result else None ) self.last_tests_passed = ( self.last_tests_total if result.passed and self.last_tests_total else 0 if self.last_tests_total else None ) if result.passed: self._record_success(function_spec.name) self.verified.add(function_spec.name) self.failing.pop(function_spec.name, None) if not function_spec.is_proof_required and action.target_code: self.verified_snippets[function_spec.name] = _format_verified_dependency_code( self.task.target_language, action.target_code ) reward_value = score_step_reward(True, len(self.task.functions)) if function_spec.is_proof_required: breakdown = build_breakdown(1.0, proof=1.0) feedback = ( f"VERIFIED: '{action.function_name}' accepted by LEAN.\n" f"Latency: {result.latency_ms}ms\n" f"Progress: {len(self.verified)}/{len(self.task.functions)} functions verified." ) else: breakdown = build_breakdown( 1.0, property_score=1.0 if self.last_tests_total else 0.0, ) feedback = ( f"VERIFIED: '{action.function_name}' accepted by IR + LEAN.\n" f"{ir_result.feedback if ir_result is not None else ''}\n" f"Latency: {result.latency_ms}ms\n" f"Progress: {len(self.verified)}/{len(self.task.functions)} functions verified." ) else: self.failing[function_spec.name] = result.error reward_value = score_step_reward(False, len(self.task.functions)) breakdown = build_breakdown(0.0) if function_spec.is_proof_required: feedback = ( f"REJECTED: '{action.function_name}' failed LEAN verification.\n" f"Latency: {result.latency_ms}ms\n\n" f"LEAN error:\n{result.error}\n\n" f"Hint: Check that your code matches the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec." ) else: feedback = ( f"REJECTED: '{action.function_name}' failed LEAN verification.\n" f"{ir_result.feedback if ir_result is not None else ''}\n\n" f"Latency: {result.latency_ms}ms\n\n" f"LEAN error:\n{result.error}\n\n" f"Hint: Check that your code matches the generated Lean mirror and the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec." ) feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value) return reward_value, feedback, proof_compiled, breakdown