"""Episode state machine and action dispatch.
EpisodeState is the core of the environment. One instance lives per episode.
apply() is the central dispatch that handles the four action types:
- inspect → surface source + Lean spec, small shaping reward
- analyze_deps → surface dependency graph, small shaping reward
- run_tests → execute candidate_code, score against oracle
- submit → build verification IR, run Lean backend, update verified set
Verified functions are stored in self.verified (set) and their code snippets
in self.verified_snippets (dict). Dependencies are injected automatically
via _bundle_with_deps() before testing or submission.
"""
from __future__ import annotations
import functools
import re
import shutil
import subprocess
import tempfile
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from openenv.core.env_server.types import State
from .grader import (
clamp_open_unit,
build_breakdown,
run_candidate_tests,
score_progress,
score_step_reward,
)
from .models import (
AnalyzeDepsAction,
InspectAction,
LeanMigrateAction,
LeanMigrateObservation,
LeanMigrateReward,
RunTestsAction,
SubmitAction,
)
from .submission_parsers import _camel_to_snake
from .target_snippets import dependency_closure
from .tasks import FunctionSpec, Task
from .verification_ir import build_verification_ir
from ..lean_backend.interface import LeanBackend
from ..lean_backend.kimina_backend import get_backend
# Shaping reward returned by a successful `inspect` action.
INSPECT_REWARD = 0.05
# Shaping reward returned by a successful `analyze_deps` action.
ANALYZE_DEPS_REWARD = 0.05
# Reward returned by `run_tests` when the candidate passes its tests.
RUN_TESTS_SUCCESS_REWARD = 0.10
@functools.lru_cache(maxsize=128)
def _format_rust_snippet(snippet: str) -> str:
    """Return *snippet* formatted by ``rustfmt``, or unchanged on any failure.

    Formatting is strictly best-effort: when ``rustfmt`` is not on ``PATH``,
    the snippet is blank, the temporary file cannot be written or read, the
    subprocess fails or times out, or ``rustfmt`` exits non-zero, the original
    text is returned untouched.

    Results (including fallbacks) are memoised per snippet text by
    ``lru_cache``.

    Args:
        snippet: Rust source text to format.

    Returns:
        The formatted source with surrounding whitespace stripped, or the
        original ``snippet`` when formatting is unavailable or fails.
    """
    rustfmt = shutil.which("rustfmt")
    if rustfmt is None or not snippet.strip():
        return snippet

    try:
        with tempfile.TemporaryDirectory() as temp_dir:
            snippet_path = Path(temp_dir) / "snippet.rs"
            # Rust source files are UTF-8; never rely on the locale encoding.
            snippet_path.write_text(snippet.rstrip() + "\n", encoding="utf-8")
            # rustfmt rewrites the file in place; its output streams are
            # captured only to keep them off the console.
            completed = subprocess.run(
                [rustfmt, "--edition", "2021", str(snippet_path)],
                capture_output=True,
                text=True,
                timeout=5,
            )
            if completed.returncode != 0:
                return snippet
            formatted = snippet_path.read_text(encoding="utf-8").strip()
    except (OSError, subprocess.SubprocessError, UnicodeError):
        # Temp-file I/O failure, rustfmt spawn failure/timeout, or undecodable
        # output: fall back to the unformatted snippet.
        return snippet

    return formatted or snippet
def _format_verified_dependency_code(language: str, code: str) -> str:
    """Pretty-print a verified dependency snippet for the given language.

    Only ``"rust"`` has a formatter wired in; code in any other language is
    returned exactly as received.
    """
    is_rust = language == "rust"
    return _format_rust_snippet(code) if is_rust else code
def _extract_source_symbol(source_fragment: str) -> str | None:
for raw_line in source_fragment.splitlines():
line = raw_line.strip()
if not line or line.startswith("//") or line.startswith("/*") or line.startswith("*"):
continue
for pattern in (
r"^(?:pub\s+)?(?:unsafe\s+)?fn\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
r"^(?:def|function)\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
r"^(?:[A-Za-z_][A-Za-z0-9_:\<\>\*&\s]*?)\b([A-Za-z_][A-Za-z0-9_]*)\s*\(",
):
match = re.match(pattern, line)
if match is not None:
return match.group(1)
return None
@dataclass
class EpisodeState:
    """Mutable state of a single migration episode plus its action dispatcher.

    One instance lives per episode. ``apply()`` routes each incoming action to
    a ``_handle_*`` method, records the outcome on the instance, and returns
    the resulting observation/reward tuple.

    Every handler returns ``(reward_value, feedback, proof_compiled,
    breakdown)`` where ``proof_compiled`` is ``None`` unless a submission was
    evaluated.
    """

    episode_id: str
    task: Task
    backend: LeanBackend
    # Number of actions applied so far.
    step_count: int = 0
    # Canonical names of functions accepted by the backend.
    verified: set[str] = field(default_factory=set)
    # Canonical function name -> error of its latest rejected backend check.
    failing: dict[str, str] = field(default_factory=dict)
    last_feedback: str | None = None
    last_action_type: str | None = None
    last_step_reward: float | None = None
    last_tests_passed: int | None = None
    last_tests_total: int | None = None
    last_reward_details: LeanMigrateReward | None = None
    cumulative_score: float = 0.0
    # Canonical function name -> accepted target code, reused as dependency
    # code for later functions.
    verified_snippets: dict[str, str] = field(default_factory=dict)
    # Canonical function name -> failures since that function's last success.
    consecutive_failures: dict[str, int] = field(default_factory=dict)

    @classmethod
    def from_task(
        cls,
        task: Task,
        backend: LeanBackend | None = None,
        episode_id: str | None = None,
    ) -> "EpisodeState":
        """Create a fresh episode for *task*.

        A falsy ``backend`` is replaced by ``get_backend()`` and a falsy
        ``episode_id`` by a random UUID4 string.
        """
        return cls(
            episode_id=episode_id or str(uuid.uuid4()),
            task=task,
            backend=backend or get_backend(),
        )

    @property
    def remaining(self) -> list[str]:
        """Task function names not yet verified, in task order."""
        return [
            function_name
            for function_name in self.task.function_names
            if function_name not in self.verified
        ]

    @property
    def ready_to_submit(self) -> list[str]:
        """Unverified functions whose dependency-graph entries are all verified."""
        graph = self.task.dependency_graph
        return [
            function_name
            for function_name in self.remaining
            if all(
                dependency in self.verified
                for dependency in graph.get(function_name, [])
            )
        ]

    @property
    def progress(self) -> float:
        """Progress score computed from the verified and total function counts."""
        return score_progress(len(self.verified), len(self.task.functions))

    @property
    def suggested_function(self) -> str | None:
        """First function that is ready to submit, or ``None`` if there is none."""
        ready_functions = self.ready_to_submit
        return ready_functions[0] if ready_functions else None

    def to_observation(self) -> LeanMigrateObservation:
        """Snapshot the current state as an observation (``done`` is ``False``)."""
        return LeanMigrateObservation(
            episode_id=self.episode_id,
            task_id=self.task.task_id,
            episode_step=self.step_count,
            max_steps=self.task.max_steps,
            source_language=self.task.source_language,
            target_language=self.task.target_language,
            source_files=self.task.source_files,
            verified=sorted(self.verified),
            failing=dict(self.failing),
            remaining=self.remaining,
            progress=self.progress,
            last_action_type=self.last_action_type,
            last_action_feedback=self.last_feedback,
            last_step_reward=self.last_step_reward,
            done=False,
            reward=self.last_step_reward,
            reward_details=self.last_reward_details,
        )

    def apply(
        self, action: LeanMigrateAction
    ) -> tuple[LeanMigrateObservation, LeanMigrateReward, bool, dict[str, object]]:
        """Apply one action and return ``(observation, reward, done, info)``.

        The step counter is incremented and the per-step test counters are
        cleared before dispatch. An unrecognised ``action.type`` is penalised
        with ``-0.05``. The episode is done once progress reaches ``1.0`` or
        ``step_count`` reaches ``task.max_steps``.
        """
        self.step_count += 1
        self.last_action_type = action.type
        self.last_tests_passed = None
        self.last_tests_total = None

        handlers = {
            "inspect": self._handle_inspect,
            "analyze_deps": self._handle_analyze_deps,
            "run_tests": self._handle_run_tests,
            "submit": self._handle_submit,
        }
        handler = handlers.get(action.type)
        if handler is None:
            reward_value, feedback, proof_compiled, breakdown = (
                -0.05,
                "Unknown action type.",
                None,
                build_breakdown(0.0),
            )
        else:
            reward_value, feedback, proof_compiled, breakdown = handler(action)

        self.last_feedback = feedback
        self.last_step_reward = reward_value
        self.cumulative_score = clamp_open_unit(self.progress)
        done = self.progress >= 1.0 or self.step_count >= self.task.max_steps

        # Report the stored backend error (if any) for the function this
        # action targeted, looked up under its canonical name.
        raw_fn_name = getattr(action, "function_name", "")
        fn_spec_for_error = self._resolve_function(raw_fn_name)
        canonical_fn_name = fn_spec_for_error.name if fn_spec_for_error else raw_fn_name
        lean_error = self.failing.get(canonical_fn_name)

        reward = LeanMigrateReward(
            score=reward_value,
            cumulative_score=self.cumulative_score,
            tests_passed=self.last_tests_passed,
            tests_total=self.last_tests_total,
            proof_compiled=proof_compiled,
            breakdown=breakdown,
            feedback=feedback,
            lean_error=lean_error,
        )
        self.last_reward_details = reward

        observation = self.to_observation()
        observation.done = done
        observation.reward = reward_value
        observation.reward_details = reward
        info = {
            "progress": self.progress,
            "verified": sorted(self.verified),
            "step": self.step_count,
        }
        return observation, reward, done, info

    def _record_failure(self, fn_name: str, feedback: str, reward: float) -> tuple[str, float]:
        """Count a failure for *fn_name* and escalate feedback/penalty.

        From the 3rd consecutive failure a hint is appended to the feedback.
        From the 5th a stronger warning is appended and the reward is lowered
        to ``min(reward - 0.05, -0.1)``.

        Returns:
            The (possibly extended) feedback and the (possibly lowered) reward.
        """
        n = self.consecutive_failures.get(fn_name, 0) + 1
        self.consecutive_failures[fn_name] = n
        if n >= 5:
            extra = (
                f"\n\n⚠ REPEATED FAILURE ({n}x on '{fn_name}'): "
                "Your last several attempts produced the same error. "
                "You MUST try a fundamentally different approach: "
                "use 'inspect' to re-read the spec, reconsider the function signature, "
                "or implement the logic differently."
            )
            reward = min(reward - 0.05, -0.1)
        elif n >= 3:
            extra = (
                f"\n\nHint: '{fn_name}' has failed {n} times in a row. "
                "Consider using 'inspect' to re-read the spec and trying a different approach."
            )
        else:
            extra = ""
        return feedback + extra, reward

    def _record_success(self, fn_name: str) -> None:
        """Reset consecutive failure count on any success."""
        self.consecutive_failures.pop(fn_name, None)

    def _resolve_function(self, name: str) -> FunctionSpec | None:
        """Exact match first; fall back to snake_case normalisation for Rust tasks.

        Agents working on Rust targets naturally use snake_case (e.g.
        ``merge_intervals``) while the spec stores camelCase (``mergeIntervals``).
        This normalises the lookup so both forms resolve to the same spec.
        The canonical spec.name is always used for downstream state keys.
        """
        spec = self.task.get_function(name)
        if spec is not None:
            return spec
        for fn in self.task.functions:
            if _camel_to_snake(fn.name) == name:
                return fn
        return None

    def _unknown_function_result(
        self, requested_name: str
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Build the shared ``-0.05`` rejection for an unresolvable function name."""
        valid = [fn.name for fn in self.task.functions]
        return (
            -0.05,
            f"Function '{requested_name}' not found. Valid names: {valid}",
            None,
            build_breakdown(0.0),
        )

    def _bundle_with_deps(self, function_name: str, snippet: str) -> str:
        """Prepend verified dependency snippets to the given code snippet."""
        parts = [
            self.verified_snippets[dep]
            for dep in dependency_closure(self.task, function_name)
            if dep in self.verified_snippets
        ]
        parts.append(snippet)
        return "\n\n".join(parts).strip()

    def _handle_inspect(
        self, action: InspectAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Surface metadata, source, Lean spec and injected dependencies.

        Returns ``INSPECT_REWARD`` on success, or the ``-0.05`` rejection when
        the function name cannot be resolved.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._unknown_function_result(action.function_name)

        target_lang = self.task.target_language
        source_symbol = (
            _extract_source_symbol(function_spec.source_fragment) or "(unavailable)"
        )
        # Sections are joined with blank lines below.
        parts: list[str] = [
            f"\n"
            f"Name: {function_spec.name}\n"
            f"Source symbol: {source_symbol}\n"
            f"Description: {function_spec.description}\n"
            f"Depends on: {function_spec.depends_on or ['(none)']}\n"
            f"Proof required: {function_spec.is_proof_required}\n",
            f"\n{function_spec.source_fragment}\n",
            f"\n{function_spec.lean_fragment}\n",
        ]

        # Show verified dependency snippets so the agent knows exactly what code
        # is auto-prepended to its submission. Without this, agents redefine symbols
        # already present in the bundle, causing duplicate-definition compile errors
        # (especially fatal in Rust which forbids duplicate definitions).
        verified_deps = [
            (dep, self.verified_snippets[dep])
            for dep in dependency_closure(self.task, function_spec.name)
            if dep in self.verified_snippets
        ]
        if verified_deps:
            dep_blocks: list[str] = [
                "The following dependency code is automatically prepended to your "
                "submission. Do NOT redefine these symbols.\n"
            ]
            for dep_name, dep_code in verified_deps:
                formatted_dep_code = _format_verified_dependency_code(target_lang, dep_code)
                dep_blocks.append(
                    f"{dep_name} (verified):\n"
                    f"```{target_lang}\n{formatted_dep_code}\n```"
                )
            parts.append("\n" + "\n\n".join(dep_blocks) + "\n")

        # For proof tasks, append the full Lean spec file so the agent can see
        # all available public theorems (avoids hallucinating theorem names).
        if function_spec.is_proof_required and self.task.lean_spec_module:
            spec_path = (
                Path(__file__).parent.parent
                / "lean"
                / f"{self.task.lean_spec_module}.lean"
            )
            if spec_path.exists():
                # Lean sources routinely contain non-ASCII symbols; decode as
                # UTF-8 rather than the platform default encoding.
                lean_content = spec_path.read_text(encoding="utf-8")
                parts.append(f"\n```lean\n{lean_content}\n```\n")

        feedback = "\n\n".join(parts)
        return INSPECT_REWARD, feedback, None, build_breakdown(INSPECT_REWARD)

    def _handle_analyze_deps(
        self, action: AnalyzeDepsAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Describe direct dependencies, the full graph and the migration order.

        Returns ``ANALYZE_DEPS_REWARD`` on success, or the ``-0.05`` rejection
        when the function name cannot be resolved.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._unknown_function_result(action.function_name)

        graph_lines = [
            f" - {function_name}: {self.task.dependency_graph.get(function_name, []) or ['(none)']}"
            for function_name in self.task.topo_order
        ]
        plan_lines = [
            f" {index + 1}. {function_name}"
            for index, function_name in enumerate(self.task.topo_order)
        ]
        dependency_lines = [
            f" - {dependency}: {'verified' if dependency in self.verified else 'not yet verified'}"
            for dependency in function_spec.depends_on
        ] or [" (none)"]
        feedback = (
            f"Dependencies for '{action.function_name}':\n"
            + "\n".join(dependency_lines)
            + "\n\nDependency graph:\n"
            + "\n".join(graph_lines)
            + "\n\nMigration plan:\n"
            + "\n".join(plan_lines)
        )
        return ANALYZE_DEPS_REWARD, feedback, None, build_breakdown(
            ANALYZE_DEPS_REWARD
        )

    def _handle_run_tests(
        self, action: RunTestsAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Run the candidate code (bundled with verified deps) against the tests.

        A pass yields ``RUN_TESTS_SUCCESS_REWARD``. A failure yields ``-0.01``
        per failing test (at least ``-0.01``) and is then passed through
        ``_record_failure`` for escalation.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._unknown_function_result(action.function_name)

        bundled_code = self._bundle_with_deps(function_spec.name, action.candidate_code)
        result = run_candidate_tests(self.task, function_spec, bundled_code)
        feedback = result.feedback
        self.last_tests_passed = result.tests_passed
        self.last_tests_total = result.tests_total

        if result.passed:
            self._record_success(function_spec.name)
            reward_value = RUN_TESTS_SUCCESS_REWARD
        else:
            # Detect Rust duplicate-definition errors caused by agents re-declaring
            # symbols that are already injected from verified dependency snippets.
            dep_closure = dependency_closure(self.task, function_spec.name)
            has_verified_deps = any(dep in self.verified_snippets for dep in dep_closure)
            # Guard against a missing stderr so the membership test cannot raise.
            if has_verified_deps and "E0428" in (result.stderr or ""):
                feedback += (
                    "\nNote: verified dependency code is automatically prepended to your "
                    "submission. Do not redefine types or functions from already-verified "
                    "dependencies. Use 'inspect' to see exactly what code will be injected."
                )
            reward_value = -0.01 * max(1, result.tests_total - result.tests_passed)
            feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value)

        breakdown = build_breakdown(
            0.0,
            property_score=result.tests_passed / result.tests_total
            if result.tests_total
            else 0.0,
        )
        return reward_value, feedback, None, breakdown

    def _handle_submit(
        self, action: SubmitAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Verify a submission and update the verified/failing bookkeeping.

        Proof-required functions are checked with ``backend.verify_proof`` on
        ``action.lean_proof``. All other functions first go through
        ``build_verification_ir`` on the dependency-bundled ``target_code`` and
        the generated Lean code is then checked with ``backend.verify``.

        Unknown names, unmet dependencies, missing code/proof and IR failures
        are rejected with ``-0.05`` before the backend is called.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._unknown_function_result(action.function_name)

        unmet_dependencies = [
            dependency
            for dependency in function_spec.depends_on
            if dependency not in self.verified
        ]
        if unmet_dependencies:
            feedback = (
                f"Cannot submit '{action.function_name}' yet.\n"
                f"Unmet dependencies: {unmet_dependencies}\n"
                f"Verify these first: {unmet_dependencies}"
            )
            return -0.05, feedback, None, build_breakdown(0.0)

        # NOTE(review): nothing here checks whether the function is already in
        # self.verified, so a repeated successful submit is rewarded again --
        # confirm this is intended.
        total_functions = len(self.task.functions)
        ir_result = None

        if function_spec.is_proof_required:
            if not action.lean_proof:
                return (
                    -0.05,
                    (
                        f"'{action.function_name}' requires a Lean proof. Provide lean_proof before submitting."
                    ),
                    False,
                    build_breakdown(0.0, proof=0.0),
                )
            result = self.backend.verify_proof(
                spec_module=self.task.lean_spec_module,
                proof_code=action.lean_proof,
            )
            proof_compiled = result.passed
            self.last_tests_passed = None
            self.last_tests_total = None
        else:
            target_code = action.target_code
            if not target_code:
                return (
                    -0.05,
                    f"'{action.function_name}' has no code to submit. Run `run_tests` first so the system can record your implementation.",
                    None,
                    build_breakdown(0.0),
                )
            bundled_target = self._bundle_with_deps(function_spec.name, target_code)
            ir_result = build_verification_ir(self.task, function_spec, bundled_target)
            if not ir_result.ready or ir_result.lean_code is None:
                if ir_result.run_result is not None:
                    self.last_tests_passed = ir_result.run_result.tests_passed
                    self.last_tests_total = ir_result.run_result.tests_total
                property_score = 0.0
                if (
                    ir_result.run_result is not None
                    and ir_result.run_result.tests_total
                ):
                    property_score = (
                        ir_result.run_result.tests_passed
                        / ir_result.run_result.tests_total
                    )
                return (
                    -0.05,
                    (
                        f"REJECTED: '{action.function_name}' failed IR validation.\n"
                        f"{ir_result.feedback}"
                    ),
                    False,
                    build_breakdown(0.0, property_score=property_score),
                )

            # The IR is ready and carries Lean code at this point, so the
            # generated code is always what gets verified.
            result = self.backend.verify(
                spec_module=self.task.lean_spec_module,
                function_name=function_spec.name,
                code=ir_result.lean_code,
                symbol_name=f"Candidate.{function_spec.name}",
                sample_checks=[],
            )
            proof_compiled = None
            self.last_tests_total = (
                ir_result.run_result.tests_total
                if ir_result and ir_result.run_result
                else None
            )
            # With a non-zero total, report all-or-nothing according to the
            # backend verdict; otherwise leave the passed count unset.
            if self.last_tests_total:
                self.last_tests_passed = self.last_tests_total if result.passed else 0
            else:
                self.last_tests_passed = None

        if result.passed:
            self._record_success(function_spec.name)
            self.verified.add(function_spec.name)
            self.failing.pop(function_spec.name, None)
            if not function_spec.is_proof_required and action.target_code:
                # Keep the accepted code so it can be prepended for dependants.
                self.verified_snippets[function_spec.name] = _format_verified_dependency_code(
                    self.task.target_language, action.target_code
                )
            reward_value = score_step_reward(True, total_functions)
            if function_spec.is_proof_required:
                breakdown = build_breakdown(1.0, proof=1.0)
                feedback = (
                    f"VERIFIED: '{action.function_name}' accepted by LEAN.\n"
                    f"Latency: {result.latency_ms}ms\n"
                    f"Progress: {len(self.verified)}/{total_functions} functions verified."
                )
            else:
                breakdown = build_breakdown(
                    1.0,
                    property_score=1.0 if self.last_tests_total else 0.0,
                )
                feedback = (
                    f"VERIFIED: '{action.function_name}' accepted by IR + LEAN.\n"
                    f"{ir_result.feedback}\n"
                    f"Latency: {result.latency_ms}ms\n"
                    f"Progress: {len(self.verified)}/{total_functions} functions verified."
                )
        else:
            self.failing[function_spec.name] = result.error
            reward_value = score_step_reward(False, total_functions)
            breakdown = build_breakdown(0.0)
            if function_spec.is_proof_required:
                feedback = (
                    f"REJECTED: '{action.function_name}' failed LEAN verification.\n"
                    f"Latency: {result.latency_ms}ms\n\n"
                    f"LEAN error:\n{result.error}\n\n"
                    f"Hint: Check that your code matches the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec."
                )
            else:
                feedback = (
                    f"REJECTED: '{action.function_name}' failed LEAN verification.\n"
                    f"{ir_result.feedback}\n\n"
                    f"Latency: {result.latency_ms}ms\n\n"
                    f"LEAN error:\n{result.error}\n\n"
                    f"Hint: Check that your code matches the generated Lean mirror and the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec."
                )
            feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value)

        return reward_value, feedback, proof_compiled, breakdown