"""Deterministic step-by-step playthrough for all LeanMigrate tasks.""" from __future__ import annotations import argparse import sys from collections.abc import Sequence from pathlib import Path from rich.columns import Columns from rich.console import Console from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax from rich.text import Text ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from lean_migrate.env.models import ( # noqa: E402 # type: ignore AnalyzeDepsAction, InspectAction, RunTestsAction, SubmitAction, ) from lean_migrate.env.grader import clamp_open_unit # noqa: E402 # type: ignore from lean_migrate.env.target_snippets import ( # noqa: E402 # type: ignore TASK_TARGET_SNIPPETS, build_submission_bundle, ) from lean_migrate.env.tasks import get_task, list_tasks # noqa: E402 # type: ignore from lean_migrate.env.verification_ir import ( # noqa: E402 # type: ignore BehaviorSummary, VerificationIRResult, build_verification_ir, ) from lean_migrate.server.lean_migrate_environment import ( # noqa: E402 # type: ignore LeanMigrateEnvironment, ) TASK_IDS = tuple(task_info["task_id"] for task_info in list_tasks()) console = Console() def _short_feedback(text: str | None, max_lines: int = 12) -> str: if not text: return "" lines = text.strip().splitlines() if len(lines) <= max_lines: return "\n".join(lines) return "\n".join(lines[:max_lines] + ["..."]) def _short_digest(text: str | None, width: int = 16) -> str: if not text: return "(n/a)" if len(text) <= width: return text return f"{text[:width]}..." 
def _syntax_renderable(code: str, language: str) -> Syntax | Text:
    """Return *code* highlighted as *language*, or as plain text on failure."""
    try:
        highlighted = Syntax(code, language, theme="monokai", line_numbers=False)
    except Exception:
        # Best effort: if building the Syntax object raises, show raw text.
        return Text(code)
    return highlighted


def _print_code_preview(title: str, code: str, language: str) -> None:
    """Print *code* in a green-bordered panel titled *title*."""
    preview = Panel(
        _syntax_renderable(code, language),
        title=f"[bold green]{title}[/]",
        border_style="green",
    )
    console.print(preview)


def _summary_text(label: str, summary: BehaviorSummary | None) -> Text:
    """Render *summary* as a labelled ``field: value`` listing.

    Returns a dim ``(unavailable)`` placeholder when *summary* is ``None``.
    """
    if summary is None:
        return Text("(unavailable)", style="dim")

    entries = (
        ("task", summary.task_id),
        ("function", summary.function_name),
        ("language", summary.target_language),
        ("arity", str(summary.arity)),
        ("samples", str(summary.sample_count)),
        ("passed", str(summary.passed_count)),
        ("digest", _short_digest(summary.behavior_digest)),
    )

    listing = Text()
    listing.append(f"{label}\n", style="bold cyan")
    for name, shown in entries:
        listing.append(f"  {name}: ", style="bold cyan")
        listing.append(f"{shown}\n")
    return listing


def _print_ir_preview(
    task_id: str,
    function_name: str,
    target_language: str,
    backend_name: str,
    ir_result: VerificationIRResult,
) -> None:
    """Print the verification-IR result for one function.

    Emits, in order: a summary panel (backend, path, readiness, and whatever
    provenance / run / digest data is present), the IR feedback trace if any,
    the submission and expected summaries side by side when both exist, and
    the generated Lean code if any. Panels are bordered green when
    ``ir_result.ready`` is truthy and red otherwise.
    """
    border = "green" if ir_result.ready else "red"
    summary = Text()

    def add_field(name: str, shown: object) -> None:
        # One "name: value" line; only the name is styled.
        summary.append(f"{name}: ", style="bold cyan")
        summary.append(f"{shown}\n")

    add_field("backend", backend_name)
    add_field("path", f"{target_language} -> IR -> Lean")
    add_field("ready", str(ir_result.ready).lower())

    provenance = ir_result.provenance
    if provenance is not None:
        add_field("parser", provenance.parser)
        add_field("signature", provenance.normalized_signature or "(unavailable)")
        arity = provenance.arity
        add_field("arity", arity if arity is not None else "(n/a)")
        add_field("body hash", _short_digest(provenance.normalized_body_hash))
        add_field("source digest", _short_digest(provenance.source_digest))

    run_result = ir_result.run_result
    if run_result is not None:
        add_field("samples", f"{run_result.tests_passed}/{run_result.tests_total}")
        if run_result.case_results is not None:
            add_field("case traces", len(run_result.case_results))

    submission = ir_result.submission_summary
    if submission is not None:
        add_field("behavior digest", _short_digest(submission.behavior_digest))

    console.print(
        Panel(
            summary,
            title=f"[bold magenta]IR PREVIEW[/] {task_id}:{function_name}",
            border_style=border,
        )
    )

    if ir_result.feedback:
        console.print(
            Panel(
                Text(ir_result.feedback),
                title="[bold magenta]IR TRACE[/]",
                border_style=border,
            )
        )

    expected = ir_result.expected_summary
    if submission is not None and expected is not None:
        submission_panel = Panel(
            _summary_text("Submission summary", submission),
            border_style="cyan",
            title="[bold cyan]SUBMISSION[/]",
        )
        expected_panel = Panel(
            _summary_text("Expected summary", expected),
            border_style="green",
            title="[bold green]EXPECTED[/]",
        )
        console.print(Columns([submission_panel, expected_panel], expand=True))

    if ir_result.lean_code:
        _print_code_preview("GENERATED LEAN", ir_result.lean_code, "lean")


def _print_step_input(action_label: str, **fields: str | None) -> None:
    """Print the inputs of an action in a blue ``INPUT`` panel.

    Fields whose value is ``None`` are skipped. Single-line values render as
    ``name: value``; multi-line values render under their name with trailing
    whitespace removed. Entries are separated by a blank line. With no
    printable fields, a dim ``(no inputs)`` placeholder is shown.
    """
    rows = []
    for name, value in fields.items():
        if value is None:
            continue
        if "\n" in value:
            row = Text.assemble((f"{name}\n", "bold cyan"), (value.rstrip(), ""))
        else:
            row = Text.assemble((f"{name}: ", "bold cyan"), value)
        rows.append(row)

    if rows:
        body = Text()
        for position, row in enumerate(rows):
            if position:
                # Blank line between consecutive entries, none after the last.
                body.append("\n\n")
            body.append_text(row)
    else:
        body = Text("(no inputs)", style="dim")

    console.print(
        Panel(
            body,
            title=f"[bold blue]INPUT[/] {action_label}",
            border_style="blue",
        )
    )
def _feedback_language(feedback: str) -> str | None:
    """Guess the syntax-highlighting language for *feedback*.

    Returns ``"javascript"`` when the text starts with ``function ``,
    ``"python"`` when it contains ``def `` anywhere or starts with ``from `` /
    ``import ``, and ``None`` when it should be shown as plain text.
    """
    stripped = feedback.lstrip()
    if stripped.startswith("function "):
        return "javascript"
    # "def " may sit below a decorator, comment or "async", so search the
    # whole text instead of only checking the prefix.
    if "def " in feedback or stripped.startswith(("from ", "import ")):
        return "python"
    return None


def _print_observation(action_label: str, observation) -> None:
    """Print the outcome of one environment step.

    Shows an ``ACTION`` panel with the clamped reward, clamped progress and
    the done flag, followed by a ``FEEDBACK`` panel when the observation's
    reward details carry feedback text. The action panel is green when that
    feedback starts with ``VERIFIED`` and yellow otherwise.

    Args:
        action_label: Human-readable label of the action that was stepped.
        observation: Step result; this function reads ``reward``,
            ``progress``, ``done`` and ``reward_details.feedback``.
    """
    # Read the feedback once and tolerate missing reward details or missing
    # feedback for both the status colour and the feedback panel.
    raw_feedback = getattr(observation.reward_details, "feedback", None)
    verified = bool(raw_feedback) and raw_feedback.startswith("VERIFIED")
    status_style = "green" if verified else "yellow"

    reward = clamp_open_unit(float(observation.reward or 0.0))
    progress = clamp_open_unit(float(observation.progress))
    summary = Text.assemble(
        ("reward=", "bold cyan"),
        f"{reward:.3f}  ",
        ("progress=", "bold cyan"),
        f"{progress:.3f}  ",
        ("done=", "bold cyan"),
        str(observation.done).lower(),
    )
    console.print(
        Panel(
            summary,
            title=f"[bold {status_style}]ACTION[/] {action_label}",
            border_style=status_style,
        )
    )

    feedback = _short_feedback(raw_feedback)
    if not feedback:
        return

    language = _feedback_language(feedback)
    if language is None:
        feedback_renderable = Text(feedback)
    else:
        feedback_renderable = Syntax(
            feedback, language, theme="monokai", line_numbers=False
        )
    console.print(
        Panel(
            feedback_renderable,
            title="[bold magenta]FEEDBACK[/]",
            border_style="magenta",
        )
    )


def _assert_verified(task_id: str, function_name: str, observation) -> None:
    """Raise ``RuntimeError`` unless *function_name* is in ``observation.verified``."""
    if function_name not in observation.verified:
        raise RuntimeError(
            f"submit failed for {task_id}:{function_name}; "
            f"verified={observation.verified!r}"
        )


def _run_task(task_id: str, verbose: bool = True) -> float:
    """Play one task from reset to the final submit.

    For every function in ``task.topo_order`` the environment is stepped
    through inspect, analyze_deps, run_tests (only for functions that do not
    require a proof) and submit. Each submit must leave the function in
    ``observation.verified``.

    Args:
        task_id: Id of the task to play.
        verbose: When true, print inputs, observations and previews for
            every step; when false, print nothing.

    Returns:
        The progress of the last observation, passed through
        ``clamp_open_unit``.

    Raises:
        RuntimeError: If a submitted function is not reported as verified.
    """
    task = get_task(task_id)
    env = LeanMigrateEnvironment()
    observation = env.reset(task_id=task_id)
    # Private state is read only to show the backend class name in previews.
    backend_name = type(env._state.backend).__name__
    verified_target_snippets: dict[str, str] = {}
    target_snippets = TASK_TARGET_SNIPPETS[task_id]

    if verbose:
        console.print(Rule(f"[bold white]{task.display_name}[/] [dim]({task_id})[/]"))
        console.print(
            Panel(
                Text.assemble(
                    ("Episode: ", "bold cyan"),
                    observation.episode_id,
                    "\n",
                    ("Order: ", "bold cyan"),
                    ", ".join(task.topo_order),
                ),
                border_style="white",
                title="[bold white]TASK[/]",
            )
        )

    for function_name in task.topo_order:
        function_spec = task.get_function(function_name)
        if function_spec is None:
            continue
        needs_proof = function_spec.is_proof_required

        if verbose:
            console.print(Rule(f"[bold yellow]STEP[/] [white]{function_name}[/]"))
            _print_step_input(f"inspect({function_name})", function_name=function_name)
        observation = env.step(
            InspectAction(type="inspect", function_name=function_name)
        )
        if verbose:
            _print_observation(f"inspect({function_name})", observation)

        observation = env.step(
            AnalyzeDepsAction(type="analyze_deps", function_name=function_name)
        )
        if verbose:
            _print_observation(f"analyze_deps({function_name})", observation)

        target_code = None
        if not needs_proof:
            # Bundle this function's snippet together with the snippets that
            # were already verified earlier in the loop.
            target_code = build_submission_bundle(
                task,
                function_name,
                verified_target_snippets,
                target_snippets[function_name],
            )
            if verbose:
                _print_step_input(
                    f"run_tests({function_name})",
                    function_name=function_name,
                    candidate_code=target_code,
                )
            observation = env.step(
                RunTestsAction(
                    type="run_tests",
                    function_name=function_name,
                    candidate_code=target_code,
                )
            )
            if verbose:
                _print_observation(f"run_tests({function_name})", observation)

        lean_proof = function_spec.lean_fragment if needs_proof else None

        if verbose:
            _print_step_input(
                f"submit_proof({function_name})"
                if needs_proof
                else f"submit_target_code({function_name})",
                function_name=function_name,
                target_code=target_code,
                lean_proof=lean_proof,
            )
            if needs_proof:
                _print_code_preview(
                    "LEAN PROOF PREVIEW",
                    function_spec.lean_fragment or "",
                    "lean",
                )
            else:
                # The IR is built here only for display.
                ir_result = build_verification_ir(
                    task, function_spec, target_code or ""
                )
                _print_ir_preview(
                    task_id,
                    function_name,
                    task.target_language,
                    backend_name,
                    ir_result,
                )

        observation = env.step(
            SubmitAction(
                type="submit",
                function_name=function_name,
                target_code=target_code,
                lean_proof=lean_proof,
            )
        )
        if verbose:
            _print_observation(
                f"submit_proof({function_name})"
                if needs_proof
                else f"submit({function_name})",
                observation,
            )
        _assert_verified(task_id, function_name, observation)

        if not needs_proof:
            # Remember the snippet so later bundles can include it.
            verified_target_snippets[function_name] = target_snippets[function_name]

    final_progress = clamp_open_unit(float(observation.progress))
    if verbose:
        console.print(
            Panel(
                Text.assemble(
                    ("Completed: ", "bold green"),
                    task_id,
                    "\n",
                    ("progress=", "bold cyan"),
                    f"{final_progress:.3f}",
                ),
                border_style="green",
            )
        )
    return final_progress


def main(argv: Sequence[str] | None = None) -> None:
    """Run the playthrough for the selected tasks and print a score table.

    Args:
        argv: Command-line arguments; ``None`` lets argparse read
            ``sys.argv``. ``--tasks`` selects task ids (default: all of
            ``TASK_IDS``) and ``--quiet`` suppresses per-step output.
    """
    parser = argparse.ArgumentParser(
        description="Play all LeanMigrate tasks one step at a time."
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=TASK_IDS,
        help="Task ids to play in order.",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Reduce output to per-task summaries.",
    )
    args = parser.parse_args(argv)

    scores: dict[str, float] = {}
    for task_id in args.tasks:
        scores[task_id] = _run_task(task_id, verbose=not args.quiet)

    console.print(Rule("[bold white]PLAYTHROUGH RESULTS[/]"))
    overall = sum(scores.values()) / len(scores) if scores else 0.0
    for task_id, score in scores.items():
        console.print(f"[bold cyan]{task_id:<20}[/] {score:.3f}")
    console.print(f"[bold cyan]{'overall':<20}[/] {overall:.3f}")


if __name__ == "__main__":
    main()