lean-migrate / env /state.py
Hrushi's picture
Upload folder using huggingface_hub
bf9c466 verified
"""Episode state machine and action dispatch.
EpisodeState is the core of the environment. One instance lives per episode.
apply() is the central dispatch that handles the four action types:
- inspect → surface source + Lean spec, small shaping reward
- analyze_deps → surface dependency graph, small shaping reward
- run_tests → execute candidate_code, score against oracle
- submit → build verification IR, run Lean backend, update verified set
Verified functions are stored in self.verified (set) and their code snippets
in self.verified_snippets (dict). Dependencies are injected automatically
via _bundle_with_deps() before testing or submission.
"""
from __future__ import annotations
import functools
import re
import shutil
import subprocess
import tempfile
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from openenv.core.env_server.types import State
from .grader import (
clamp_open_unit,
build_breakdown,
run_candidate_tests,
score_progress,
score_step_reward,
)
from .models import (
AnalyzeDepsAction,
InspectAction,
LeanMigrateAction,
LeanMigrateObservation,
LeanMigrateReward,
RunTestsAction,
SubmitAction,
)
from .submission_parsers import _camel_to_snake
from .target_snippets import dependency_closure
from .tasks import FunctionSpec, Task
from .verification_ir import build_verification_ir
from ..lean_backend.interface import LeanBackend
from ..lean_backend.kimina_backend import get_backend
# Shaping bonus returned by a successful ``inspect`` action.
INSPECT_REWARD = 0.05
# Shaping bonus returned by a successful ``analyze_deps`` action.
ANALYZE_DEPS_REWARD = 0.05
# Reward for a ``run_tests`` action whose candidate passes its tests.
RUN_TESTS_SUCCESS_REWARD = 0.1
@functools.lru_cache(maxsize=128)
def _format_rust_snippet(snippet: str) -> str:
rustfmt = shutil.which("rustfmt")
if rustfmt is None or not snippet.strip():
return snippet
# # Rust enums used as simple tags in the verified dependency bundle should
# # be Copy so downstream submissions can pass them by value without fighting
# # move semantics.
# if "pub enum Op" in snippet and "derive(Copy, Clone)" not in snippet:
# snippet = snippet.replace(
# "pub enum Op",
# "#[derive(Copy, Clone)]\npub enum Op",
# 1,
# )
with tempfile.TemporaryDirectory() as temp_dir:
snippet_path = Path(temp_dir) / "snippet.rs"
snippet_path.write_text(snippet.rstrip() + "\n")
try:
completed = subprocess.run(
[rustfmt, "--edition", "2021", str(snippet_path)],
capture_output=True,
text=True,
timeout=5,
)
except Exception:
return snippet
if completed.returncode != 0:
return snippet
try:
formatted = snippet_path.read_text().strip()
except OSError:
return snippet
return formatted or snippet
def _format_verified_dependency_code(language: str, code: str) -> str:
if language == "rust":
return _format_rust_snippet(code)
return code
def _extract_source_symbol(source_fragment: str) -> str | None:
for raw_line in source_fragment.splitlines():
line = raw_line.strip()
if not line or line.startswith("//") or line.startswith("/*") or line.startswith("*"):
continue
for pattern in (
r"^(?:pub\s+)?(?:unsafe\s+)?fn\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
r"^(?:def|function)\s+([A-Za-z_][A-Za-z0-9_]*)\s*\(",
r"^(?:[A-Za-z_][A-Za-z0-9_:\<\>\*&\s]*?)\b([A-Za-z_][A-Za-z0-9_]*)\s*\(",
):
match = re.match(pattern, line)
if match is not None:
return match.group(1)
return None
@dataclass
class EpisodeState:
    """Mutable per-episode state plus the action dispatcher for one task.

    One instance is created per episode, usually via :meth:`from_task`.
    :meth:`apply` routes each action to a ``_handle_*`` method. It records
    the outcome on the instance and returns ``(observation, reward, done, info)``.
    All state dictionaries below are keyed by the canonical ``FunctionSpec.name``.
    """

    episode_id: str
    task: Task
    backend: LeanBackend
    step_count: int = 0
    # Functions accepted by the Lean backend via ``submit``.
    verified: set[str] = field(default_factory=set)
    # Function name -> Lean error text from its most recent failed ``submit``.
    failing: dict[str, str] = field(default_factory=dict)
    last_feedback: str | None = None
    last_action_type: str | None = None
    last_step_reward: float | None = None
    last_tests_passed: int | None = None
    last_tests_total: int | None = None
    last_reward_details: LeanMigrateReward | None = None
    cumulative_score: float = 0.0
    # Function name -> accepted (formatted) target code. These snippets are
    # prepended to later submissions that depend on the function.
    verified_snippets: dict[str, str] = field(default_factory=dict)
    # Function name -> number of failures in a row; cleared on any success.
    consecutive_failures: dict[str, int] = field(default_factory=dict)

    @classmethod
    def from_task(
        cls,
        task: Task,
        backend: LeanBackend | None = None,
        episode_id: str | None = None,
    ) -> "EpisodeState":
        """Start a fresh episode for ``task``.

        If ``episode_id`` is omitted or empty, a random UUID4 string is used.
        If ``backend`` is omitted, ``get_backend()`` supplies one.
        """
        return cls(
            episode_id=episode_id or str(uuid.uuid4()),
            task=task,
            backend=backend or get_backend(),
        )

    @property
    def remaining(self) -> list[str]:
        """Function names not yet verified, in ``task.function_names`` order."""
        return [
            function_name
            for function_name in self.task.function_names
            if function_name not in self.verified
        ]

    @property
    def ready_to_submit(self) -> list[str]:
        """Unverified functions whose ``dependency_graph`` entries are all verified."""
        ready_functions: list[str] = []
        for function_name in self.remaining:
            dependencies = self.task.dependency_graph.get(function_name, [])
            if all(dependency in self.verified for dependency in dependencies):
                ready_functions.append(function_name)
        return ready_functions

    @property
    def progress(self) -> float:
        """Progress score from ``score_progress(len(verified), len(task.functions))``."""
        return score_progress(len(self.verified), len(self.task.functions))

    @property
    def suggested_function(self) -> str | None:
        """First function in ``ready_to_submit``, or ``None`` if none is ready."""
        ready_functions = self.ready_to_submit
        return ready_functions[0] if ready_functions else None

    def to_observation(self) -> LeanMigrateObservation:
        """Snapshot the current state as an observation.

        ``done`` is always ``False`` here. :meth:`apply` overwrites
        ``done``, ``reward`` and ``reward_details`` on the returned object.
        """
        return LeanMigrateObservation(
            episode_id=self.episode_id,
            task_id=self.task.task_id,
            episode_step=self.step_count,
            max_steps=self.task.max_steps,
            source_language=self.task.source_language,
            target_language=self.task.target_language,
            source_files=self.task.source_files,
            verified=sorted(self.verified),
            failing=dict(self.failing),
            remaining=self.remaining,
            progress=self.progress,
            last_action_type=self.last_action_type,
            last_action_feedback=self.last_feedback,
            last_step_reward=self.last_step_reward,
            done=False,
            reward=self.last_step_reward,
            reward_details=self.last_reward_details,
        )

    def apply(
        self, action: LeanMigrateAction
    ) -> tuple[LeanMigrateObservation, LeanMigrateReward, bool, dict[str, object]]:
        """Execute one agent action and advance the episode by one step.

        Dispatches on ``action.type`` to the matching ``_handle_*`` method.
        Unknown types score -0.05. The outcome is recorded on ``self``.

        Returns:
            ``(observation, reward, done, info)``. ``done`` is true once
            ``progress >= 1.0`` or ``step_count`` reaches ``task.max_steps``.
            ``info`` holds ``progress``, the sorted ``verified`` list and ``step``.
        """
        self.step_count += 1
        self.last_action_type = action.type
        # Only run_tests/submit set test counts; clear the previous step's values.
        self.last_tests_passed = None
        self.last_tests_total = None

        handlers = {
            "inspect": self._handle_inspect,
            "analyze_deps": self._handle_analyze_deps,
            "run_tests": self._handle_run_tests,
            "submit": self._handle_submit,
        }
        handler = handlers.get(action.type)
        if handler is None:
            reward_value, feedback, proof_compiled, breakdown = (
                -0.05,
                "Unknown action type.",
                None,
                build_breakdown(0.0),
            )
        else:
            reward_value, feedback, proof_compiled, breakdown = handler(action)

        self.last_feedback = feedback
        self.last_step_reward = reward_value
        self.cumulative_score = clamp_open_unit(self.progress)
        done = self.progress >= 1.0 or self.step_count >= self.task.max_steps

        # Surface the latest Lean error for the targeted function, if any. The
        # name is canonicalised so a snake_case alias finds the same entry.
        raw_fn_name = getattr(action, "function_name", "")
        fn_spec_for_error = self._resolve_function(raw_fn_name)
        canonical_fn_name = fn_spec_for_error.name if fn_spec_for_error else raw_fn_name
        lean_error = self.failing.get(canonical_fn_name)

        reward = LeanMigrateReward(
            score=reward_value,
            cumulative_score=self.cumulative_score,
            tests_passed=self.last_tests_passed,
            tests_total=self.last_tests_total,
            proof_compiled=proof_compiled,
            breakdown=breakdown,
            feedback=feedback,
            lean_error=lean_error,
        )
        self.last_reward_details = reward
        observation = self.to_observation()
        observation.done = done
        observation.reward = reward_value
        observation.reward_details = reward
        return (
            observation,
            reward,
            done,
            {
                "progress": self.progress,
                "verified": sorted(self.verified),
                "step": self.step_count,
            },
        )

    def _record_failure(self, fn_name: str, feedback: str, reward: float) -> tuple[str, float]:
        """Increment consecutive failure count; escalate feedback and penalty after thresholds.

        At 3 or more failures in a row a hint is appended. At 5 or more, a
        stronger warning is appended and the reward becomes
        ``min(reward - 0.05, -0.1)``.

        Returns:
            The possibly extended ``(feedback, reward)`` pair.
        """
        n = self.consecutive_failures.get(fn_name, 0) + 1
        self.consecutive_failures[fn_name] = n
        if n >= 5:
            extra = (
                f"\n\n⚠ REPEATED FAILURE ({n}x on '{fn_name}'): "
                "Your last several attempts produced the same error. "
                "You MUST try a fundamentally different approach: "
                "use 'inspect' to re-read the spec, reconsider the function signature, "
                "or implement the logic differently."
            )
            reward = min(reward - 0.05, -0.1)
        elif n >= 3:
            extra = (
                f"\n\nHint: '{fn_name}' has failed {n} times in a row. "
                "Consider using 'inspect' to re-read the spec and trying a different approach."
            )
        else:
            extra = ""
        return feedback + extra, reward

    def _record_success(self, fn_name: str) -> None:
        """Reset consecutive failure count on any success."""
        self.consecutive_failures.pop(fn_name, None)

    def _resolve_function(self, name: str) -> FunctionSpec | None:
        """Exact match first; fall back to snake_case normalisation for Rust tasks.

        Agents working on Rust targets naturally use snake_case (e.g.
        ``merge_intervals``), while the spec stores camelCase
        (``mergeIntervals``). This normalises the lookup so both forms
        resolve to the same spec. The canonical ``spec.name`` is always used
        for downstream state keys.
        """
        spec = self.task.get_function(name)
        if spec is not None:
            return spec
        for fn in self.task.functions:
            if _camel_to_snake(fn.name) == name:
                return fn
        return None

    def _bundle_with_deps(self, function_name: str, snippet: str) -> str:
        """Prepend verified dependency snippets to the given code snippet.

        Dependencies come from ``dependency_closure`` in the order it yields
        them. Only those with a stored snippet are included. Parts are joined
        by blank lines and the result is stripped.
        """
        parts = [
            self.verified_snippets[dep]
            for dep in dependency_closure(self.task, function_name)
            if dep in self.verified_snippets
        ]
        parts.append(snippet)
        return "\n\n".join(parts).strip()

    def _function_not_found(
        self, function_name: str
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Standard -0.05 handler result for an action naming an unknown function."""
        valid = [fn.name for fn in self.task.functions]
        return (
            -0.05,
            f"Function '{function_name}' not found. Valid names: {valid}",
            None,
            build_breakdown(0.0),
        )

    def _handle_inspect(
        self, action: InspectAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Describe one function to the agent.

        The feedback contains:
        - the function's metadata;
        - its source fragment and its Lean spec fragment;
        - the verified dependency code that will be auto-prepended to submissions;
        - for proof functions, the full Lean spec module when its file is readable.

        Scores ``INSPECT_REWARD``; an unknown function scores -0.05.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        lang = self.task.source_language
        target_lang = self.task.target_language
        parts: list[str] = []
        parts.append(
            f"<FUNCTION_INFO>\n"
            f"Name: {function_spec.name}\n"
            f"Source symbol: {_extract_source_symbol(function_spec.source_fragment) or '(unavailable)'}\n"
            f"Description: {function_spec.description}\n"
            f"Depends on: {function_spec.depends_on or ['(none)']}\n"
            f"Proof required: {function_spec.is_proof_required}\n"
            f"</FUNCTION_INFO>"
        )
        parts.append(
            f"<SOURCE language=\"{lang}\">\n"
            f"{function_spec.source_fragment}\n"
            f"</SOURCE>"
        )
        parts.append(
            f"<LEAN_SPEC>\n"
            f"{function_spec.lean_fragment}\n"
            f"</LEAN_SPEC>"
        )
        # Show verified dependency snippets so the agent knows exactly what code
        # is auto-prepended to its submission. Without this, agents redefine
        # symbols already present in the bundle. That causes duplicate-definition
        # compile errors, which are fatal in Rust.
        dep_closure = dependency_closure(self.task, function_spec.name)
        verified_deps = [
            (dep, self.verified_snippets[dep])
            for dep in dep_closure
            if dep in self.verified_snippets
        ]
        if verified_deps:
            dep_blocks: list[str] = [
                "The following dependency code is automatically prepended to your "
                "submission. Do NOT redefine these symbols.\n"
            ]
            for dep_name, dep_code in verified_deps:
                formatted_dep_code = _format_verified_dependency_code(target_lang, dep_code)
                dep_blocks.append(
                    f"{dep_name} (verified):\n"
                    f"```{target_lang}\n{formatted_dep_code}\n```"
                )
            parts.append(
                "<VERIFIED_DEPS>\n"
                + "\n\n".join(dep_blocks)
                + "\n</VERIFIED_DEPS>"
            )
        # For proof tasks, append the full Lean spec file so the agent can see
        # all available public theorems (avoids hallucinating theorem names).
        if function_spec.is_proof_required and self.task.lean_spec_module:
            spec_path = (
                Path(__file__).parent.parent
                / "lean"
                / f"{self.task.lean_spec_module}.lean"
            )
            try:
                # Lean source files are UTF-8 and may contain non-ASCII notation
                # (e.g. ∀, →), so never rely on the locale default encoding.
                lean_content: str | None = spec_path.read_text(encoding="utf-8")
            except OSError:
                # Missing or unreadable spec file: omit the section.
                lean_content = None
            if lean_content is not None:
                parts.append(
                    f"<FULL_LEAN_SPEC file=\"{self.task.lean_spec_module}.lean\">\n"
                    f"```lean\n{lean_content}\n```\n"
                    f"</FULL_LEAN_SPEC>"
                )
        feedback = "\n\n".join(parts)
        return INSPECT_REWARD, feedback, None, build_breakdown(INSPECT_REWARD)

    def _handle_analyze_deps(
        self, action: AnalyzeDepsAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Report a function's direct dependencies and the task's dependency graph.

        The feedback covers three things:
        - each direct dependency's verification status;
        - the whole dependency graph;
        - a migration plan in ``task.topo_order``.

        Scores ``ANALYZE_DEPS_REWARD``; an unknown function scores -0.05.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        dependencies = function_spec.depends_on
        graph_lines = [
            f" - {function_name}: {self.task.dependency_graph.get(function_name, []) or ['(none)']}"
            for function_name in self.task.topo_order
        ]
        plan_lines = [
            f" {index + 1}. {function_name}"
            for index, function_name in enumerate(self.task.topo_order)
        ]
        dependency_lines = [
            f" - {dependency}: {'verified' if dependency in self.verified else 'not yet verified'}"
            for dependency in dependencies
        ] or [" (none)"]
        feedback = (
            f"Dependencies for '{action.function_name}':\n"
            + "\n".join(dependency_lines)
            + "\n\nDependency graph:\n"
            + "\n".join(graph_lines)
            + "\n\nMigration plan:\n"
            + "\n".join(plan_lines)
        )
        return ANALYZE_DEPS_REWARD, feedback, None, build_breakdown(
            ANALYZE_DEPS_REWARD
        )

    def _handle_run_tests(
        self, action: RunTestsAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Run the candidate, bundled with verified dependencies, against the tests.

        A pass scores ``RUN_TESTS_SUCCESS_REWARD`` and clears the failure
        streak. A failure scores ``-0.01 * max(1, failed_count)``, escalated
        by ``_record_failure``. The breakdown's ``property_score`` is the pass
        ratio, or 0.0 when there are no tests.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        bundled_code = self._bundle_with_deps(function_spec.name, action.candidate_code)
        result = run_candidate_tests(self.task, function_spec, bundled_code)
        feedback = result.feedback
        self.last_tests_passed = result.tests_passed
        self.last_tests_total = result.tests_total
        if result.passed:
            self._record_success(function_spec.name)
            reward_value = RUN_TESTS_SUCCESS_REWARD
        else:
            # Detect Rust duplicate-definition errors (rustc E0428) caused by
            # agents re-declaring symbols already injected from verified
            # dependency snippets. ``or ""`` tolerates a missing stderr.
            dep_closure = dependency_closure(self.task, function_spec.name)
            has_verified_deps = any(dep in self.verified_snippets for dep in dep_closure)
            if has_verified_deps and "E0428" in (result.stderr or ""):
                feedback += (
                    "\nNote: verified dependency code is automatically prepended to your "
                    "submission. Do not redefine types or functions from already-verified "
                    "dependencies. Use 'inspect' to see exactly what code will be injected."
                )
            reward_value = -0.01 * max(1, result.tests_total - result.tests_passed)
            feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value)
        breakdown = build_breakdown(
            0.0,
            property_score=result.tests_passed / result.tests_total
            if result.tests_total
            else 0.0,
        )
        return reward_value, feedback, None, breakdown

    def _handle_submit(
        self, action: SubmitAction
    ) -> tuple[float, str, bool | None, dict[str, float]]:
        """Verify a submission with the Lean backend and update verified state.

        The submission is refused with -0.05 in these cases:
        - the function is unknown;
        - a direct ``depends_on`` entry is unverified;
        - required input is missing (``target_code`` for non-proof functions,
          ``lean_proof`` for proof functions);
        - a non-proof submission fails ``build_verification_ir``.

        Non-proof code is bundled with verified dependencies and lowered to
        Lean through the IR before ``backend.verify``. Proofs go to
        ``backend.verify_proof``.

        On success the function is marked verified and its failure record is
        cleared. For non-proof functions, the formatted code is stored for
        bundling into dependants. On failure the Lean error is recorded in
        ``failing``.
        """
        function_spec = self._resolve_function(action.function_name)
        if function_spec is None:
            return self._function_not_found(action.function_name)
        unmet_dependencies = [
            dependency
            for dependency in function_spec.depends_on
            if dependency not in self.verified
        ]
        if unmet_dependencies:
            feedback = (
                f"Cannot submit '{action.function_name}' yet.\n"
                f"Unmet dependencies: {unmet_dependencies}\n"
                f"Verify these first: {unmet_dependencies}"
            )
            return -0.05, feedback, None, build_breakdown(0.0)
        # NOTE(review): nothing prevents re-submitting an already-verified
        # function. A second accepted submission is rewarded again through
        # score_step_reward(True, ...). Confirm this is intended.
        ir_result = None
        if not function_spec.is_proof_required:
            target_code = action.target_code
            if not target_code:
                return (
                    -0.05,
                    f"'{action.function_name}' has no code to submit. Run `run_tests` first so the system can record your implementation.",
                    None,
                    build_breakdown(0.0),
                )
            bundled_target = self._bundle_with_deps(function_spec.name, target_code)
            ir_result = build_verification_ir(self.task, function_spec, bundled_target)
            if not ir_result.ready or ir_result.lean_code is None:
                run_result = ir_result.run_result
                property_score = 0.0
                if run_result is not None:
                    self.last_tests_passed = run_result.tests_passed
                    self.last_tests_total = run_result.tests_total
                    if run_result.tests_total:
                        property_score = run_result.tests_passed / run_result.tests_total
                return (
                    -0.05,
                    (
                        f"REJECTED: '{action.function_name}' failed IR validation.\n"
                        f"{ir_result.feedback}"
                    ),
                    False,
                    build_breakdown(0.0, property_score=property_score),
                )
        if function_spec.is_proof_required:
            if not action.lean_proof:
                return (
                    -0.05,
                    (
                        f"'{action.function_name}' requires a Lean proof. Provide lean_proof before submitting."
                    ),
                    False,
                    build_breakdown(0.0, proof=0.0),
                )
            result = self.backend.verify_proof(
                spec_module=self.task.lean_spec_module,
                proof_code=action.lean_proof,
            )
            proof_compiled = result.passed
            self.last_tests_passed = None
            self.last_tests_total = None
        else:
            result = self.backend.verify(
                spec_module=self.task.lean_spec_module,
                function_name=function_spec.name,
                code=ir_result.lean_code
                if ir_result is not None
                else action.target_code,
                symbol_name=(
                    f"Candidate.{function_spec.name}"
                    if ir_result is not None and ir_result.lean_code is not None
                    else None
                ),
                sample_checks=[],
            )
            proof_compiled = None
            self.last_tests_total = (
                ir_result.run_result.tests_total
                if ir_result and ir_result.run_result
                else None
            )
            # Report all tests as passed on acceptance and 0 on rejection.
            # Report None when the IR ran no tests.
            tests_total = self.last_tests_total
            if not tests_total:
                self.last_tests_passed = None
            else:
                self.last_tests_passed = tests_total if result.passed else 0
        if result.passed:
            self._record_success(function_spec.name)
            self.verified.add(function_spec.name)
            self.failing.pop(function_spec.name, None)
            if not function_spec.is_proof_required and action.target_code:
                self.verified_snippets[function_spec.name] = _format_verified_dependency_code(
                    self.task.target_language, action.target_code
                )
            reward_value = score_step_reward(True, len(self.task.functions))
            if function_spec.is_proof_required:
                breakdown = build_breakdown(1.0, proof=1.0)
                feedback = (
                    f"VERIFIED: '{action.function_name}' accepted by LEAN.\n"
                    f"Latency: {result.latency_ms}ms\n"
                    f"Progress: {len(self.verified)}/{len(self.task.functions)} functions verified."
                )
            else:
                breakdown = build_breakdown(
                    1.0,
                    property_score=1.0 if self.last_tests_total else 0.0,
                )
                feedback = (
                    f"VERIFIED: '{action.function_name}' accepted by IR + LEAN.\n"
                    f"{ir_result.feedback if ir_result is not None else ''}\n"
                    f"Latency: {result.latency_ms}ms\n"
                    f"Progress: {len(self.verified)}/{len(self.task.functions)} functions verified."
                )
        else:
            self.failing[function_spec.name] = result.error
            reward_value = score_step_reward(False, len(self.task.functions))
            breakdown = build_breakdown(0.0)
            if function_spec.is_proof_required:
                feedback = (
                    f"REJECTED: '{action.function_name}' failed LEAN verification.\n"
                    f"Latency: {result.latency_ms}ms\n\n"
                    f"LEAN error:\n{result.error}\n\n"
                    f"Hint: Check that your code matches the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec."
                )
            else:
                feedback = (
                    f"REJECTED: '{action.function_name}' failed LEAN verification.\n"
                    f"{ir_result.feedback if ir_result is not None else ''}\n\n"
                    f"Latency: {result.latency_ms}ms\n\n"
                    f"LEAN error:\n{result.error}\n\n"
                    f"Hint: Check that your code matches the generated Lean mirror and the LEAN spec exactly. Use 'inspect {action.function_name}' to review the spec."
                )
            feedback, reward_value = self._record_failure(function_spec.name, feedback, reward_value)
        return reward_value, feedback, proof_compiled, breakdown