Spaces:
Sleeping
Sleeping
| """Deterministic step-by-step playthrough for all LeanMigrate tasks.""" | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| from collections.abc import Sequence | |
| from pathlib import Path | |
| from rich.columns import Columns | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.rule import Rule | |
| from rich.syntax import Syntax | |
| from rich.text import Text | |
# Repository root: the directory one level above this script's own directory.
# It is pushed to the front of ``sys.path`` (once) so the ``lean_migrate``
# imports below resolve when the script is run directly.
ROOT = Path(__file__).resolve().parents[1]
if sys.path.count(str(ROOT)) == 0:
    sys.path.insert(0, str(ROOT))
| from lean_migrate.env.models import ( # noqa: E402 # type: ignore | |
| AnalyzeDepsAction, | |
| InspectAction, | |
| RunTestsAction, | |
| SubmitAction, | |
| ) | |
| from lean_migrate.env.grader import clamp_open_unit # noqa: E402 # type: ignore | |
| from lean_migrate.env.target_snippets import ( # noqa: E402 # type: ignore | |
| TASK_TARGET_SNIPPETS, | |
| build_submission_bundle, | |
| ) | |
| from lean_migrate.env.tasks import get_task, list_tasks # noqa: E402 # type: ignore | |
| from lean_migrate.env.verification_ir import ( # noqa: E402 # type: ignore | |
| BehaviorSummary, | |
| VerificationIRResult, | |
| build_verification_ir, | |
| ) | |
| from lean_migrate.server.lean_migrate_environment import ( # noqa: E402 # type: ignore | |
| LeanMigrateEnvironment, | |
| ) | |
# Every task id reported by ``list_tasks``, in the order returned; this is the
# default value of the ``--tasks`` command-line option.
TASK_IDS = tuple(entry["task_id"] for entry in list_tasks())

# Single rich console shared by all printing helpers in this module.
console = Console()
| def _short_feedback(text: str | None, max_lines: int = 12) -> str: | |
| if not text: | |
| return "" | |
| lines = text.strip().splitlines() | |
| if len(lines) <= max_lines: | |
| return "\n".join(lines) | |
| return "\n".join(lines[:max_lines] + ["..."]) | |
| def _short_digest(text: str | None, width: int = 16) -> str: | |
| if not text: | |
| return "(n/a)" | |
| if len(text) <= width: | |
| return text | |
| return f"{text[:width]}..." | |
def _syntax_renderable(code: str, language: str) -> Syntax | Text:
    """Return ``code`` as a monokai-themed ``Syntax`` renderable.

    If constructing the ``Syntax`` object raises, the code is returned as
    plain ``Text`` instead so that previews are still shown.
    """
    try:
        highlighted = Syntax(code, language, theme="monokai", line_numbers=False)
    except Exception:
        # Best-effort display helper: never let highlighting abort the run.
        return Text(code)
    return highlighted
def _print_code_preview(title: str, code: str, language: str) -> None:
    """Print ``code`` inside a green-bordered panel headed by ``title``.

    The code is rendered through ``_syntax_renderable`` using ``language``.
    """
    rendered_code = _syntax_renderable(code, language)
    preview_panel = Panel(
        rendered_code,
        border_style="green",
        title=f"[bold green]{title}[/]",
    )
    console.print(preview_panel)
def _summary_text(label: str, summary: BehaviorSummary | None) -> Text:
    """Render a behavior summary as a labelled block of ``name: value`` lines.

    Returns a dim ``"(unavailable)"`` text when ``summary`` is ``None``.
    Otherwise the first line is ``label`` and each following line shows one
    summary field; the digest is abbreviated with ``_short_digest``.
    """
    if summary is None:
        return Text("(unavailable)", style="dim")

    # (display name, already-stringified value) pairs, in display order.
    entries = (
        ("task", summary.task_id),
        ("function", summary.function_name),
        ("language", summary.target_language),
        ("arity", str(summary.arity)),
        ("samples", str(summary.sample_count)),
        ("passed", str(summary.passed_count)),
        ("digest", _short_digest(summary.behavior_digest)),
    )

    block = Text()
    block.append(f"{label}\n", style="bold cyan")
    for field_name, value in entries:
        block.append(f" {field_name}: ", style="bold cyan")
        block.append(f"{value}\n")
    return block
def _print_ir_preview(
    task_id: str,
    function_name: str,
    target_language: str,
    backend_name: str,
    ir_result: VerificationIRResult,
) -> None:
    """Print a preview of the verification IR built for one function.

    Output, in order:

    1. An ``IR PREVIEW`` panel listing backend, conversion path, readiness and,
       when present, provenance, sample counts and the behavior digest.
    2. An ``IR TRACE`` panel when ``ir_result.feedback`` is non-empty.
    3. Side-by-side ``SUBMISSION`` / ``EXPECTED`` panels when both summaries
       are available.
    4. The generated Lean code when ``ir_result.lean_code`` is non-empty.

    The first two panels are bordered green when ``ir_result.ready`` is truthy
    and red otherwise.
    """
    border = "green" if ir_result.ready else "red"
    overview = Text()

    def add_field(label: str, value: object) -> None:
        # One "label value" line: cyan bold label, plain value, newline.
        overview.append(label, style="bold cyan")
        overview.append(f"{value}\n")

    add_field("backend: ", backend_name)
    add_field("path: ", f"{target_language} -> IR -> Lean")
    add_field("ready: ", str(ir_result.ready).lower())

    provenance = ir_result.provenance
    if provenance is not None:
        add_field("parser: ", provenance.parser)
        add_field("signature: ", provenance.normalized_signature or "(unavailable)")
        add_field(
            "arity: ",
            provenance.arity if provenance.arity is not None else "(n/a)",
        )
        add_field("body hash: ", _short_digest(provenance.normalized_body_hash))
        add_field("source digest: ", _short_digest(provenance.source_digest))

    run_result = ir_result.run_result
    if run_result is not None:
        add_field("samples: ", f"{run_result.tests_passed}/{run_result.tests_total}")
        if run_result.case_results is not None:
            add_field("case traces: ", len(run_result.case_results))

    if ir_result.submission_summary is not None:
        add_field(
            "behavior digest: ",
            _short_digest(ir_result.submission_summary.behavior_digest),
        )

    console.print(
        Panel(
            overview,
            title=f"[bold magenta]IR PREVIEW[/] {task_id}:{function_name}",
            border_style=border,
        )
    )

    if ir_result.feedback:
        trace_panel = Panel(
            Text(ir_result.feedback),
            title="[bold magenta]IR TRACE[/]",
            border_style=border,
        )
        console.print(trace_panel)

    # The comparison is only shown when both sides exist.
    have_both_summaries = not (
        ir_result.submission_summary is None or ir_result.expected_summary is None
    )
    if have_both_summaries:
        submission_panel = Panel(
            _summary_text("Submission summary", ir_result.submission_summary),
            border_style="cyan",
            title="[bold cyan]SUBMISSION[/]",
        )
        expected_panel = Panel(
            _summary_text("Expected summary", ir_result.expected_summary),
            border_style="green",
            title="[bold green]EXPECTED[/]",
        )
        console.print(Columns([submission_panel, expected_panel], expand=True))

    if ir_result.lean_code:
        _print_code_preview("GENERATED LEAN", ir_result.lean_code, "lean")
def _print_step_input(action_label: str, **fields: str | None) -> None:
    """Print the inputs of an action in a blue ``INPUT`` panel.

    Each keyword argument is one input field; fields whose value is ``None``
    are omitted. Single-line values are shown as ``name: value``; multi-line
    values are shown under their name with trailing whitespace removed.
    Fields are separated by a blank line. When no field remains, a dim
    ``"(no inputs)"`` placeholder is shown.
    """
    entries = []
    for name, value in fields.items():
        if value is None:
            continue
        if "\n" in value:
            entry = Text.assemble((f"{name}\n", "bold cyan"), (value.rstrip(), ""))
        else:
            entry = Text.assemble((f"{name}: ", "bold cyan"), value)
        entries.append(entry)

    if entries:
        content = Text()
        for position, entry in enumerate(entries):
            if position > 0:
                content.append("\n\n")
            content.append_text(entry)
    else:
        content = Text("(no inputs)", style="dim")

    console.print(
        Panel(
            content,
            border_style="blue",
            title=f"[bold blue]INPUT[/] {action_label}",
        )
    )
def _print_observation(action_label: str, observation) -> None:
    """Print the outcome of one environment step.

    First prints an ``ACTION`` panel with the clamped reward, clamped progress
    and the ``done`` flag. The panel is green when the reward feedback starts
    with ``"VERIFIED"`` and yellow otherwise. If there is feedback text, a
    ``FEEDBACK`` panel follows, truncated by ``_short_feedback`` and
    syntax-highlighted when it looks like source code.

    Args:
        action_label: Human-readable label of the action that was executed.
        observation: Step result; ``reward``, ``progress``, ``done`` and
            ``reward_details`` are read from it.
    """
    details = observation.reward_details
    # Read the feedback once, defensively: ``reward_details`` may be missing
    # and its ``feedback`` may be ``None``. Calling ``.startswith`` on it
    # directly would raise in exactly the cases the FEEDBACK panel tolerates.
    raw_feedback = getattr(details, "feedback", None)
    is_verified = (
        bool(details)
        and isinstance(raw_feedback, str)
        and raw_feedback.startswith("VERIFIED")
    )
    status_style = "green" if is_verified else "yellow"

    reward = clamp_open_unit(float(observation.reward or 0.0))
    progress = clamp_open_unit(float(observation.progress))
    summary = Text.assemble(
        ("reward=", "bold cyan"),
        f"{reward:.3f} ",
        ("progress=", "bold cyan"),
        f"{progress:.3f} ",
        ("done=", "bold cyan"),
        str(observation.done).lower(),
    )
    console.print(
        Panel(
            summary,
            title=f"[bold {status_style}]ACTION[/] {action_label}",
            border_style=status_style,
        )
    )

    feedback = _short_feedback(raw_feedback)
    if not feedback:
        return

    # Heuristic language detection for highlighting. Text that opens with
    # ``function `` is treated as JavaScript; text that contains ``def `` or
    # opens with ``from ``/``import `` is treated as Python. A ``def `` that is
    # not at the very start (e.g. after a decorator or comment) is therefore
    # highlighted as Python rather than falling through to JavaScript.
    leading = feedback.lstrip()
    if leading.startswith("function "):
        syntax_language = "javascript"
    elif "def " in feedback or leading.startswith(("from ", "import ")):
        syntax_language = "python"
    else:
        syntax_language = None

    if syntax_language is None:
        feedback_renderable = Text(feedback)
    else:
        feedback_renderable = Syntax(
            feedback, syntax_language, theme="monokai", line_numbers=False
        )
    console.print(
        Panel(
            feedback_renderable,
            title="[bold magenta]FEEDBACK[/]",
            border_style="magenta",
        )
    )
| def _assert_verified(task_id: str, function_name: str, observation) -> None: | |
| if function_name not in observation.verified: | |
| raise RuntimeError( | |
| f"submit failed for {task_id}:{function_name}; " | |
| f"verified={observation.verified!r}" | |
| ) | |
def _run_task(task_id: str, verbose: bool = True) -> float:
    """Play one task from reset to the last submission and return its progress.

    For every function in ``task.topo_order`` that the task can resolve, the
    environment is stepped through ``inspect`` and ``analyze_deps``. Functions
    that do not require a proof additionally get a submission bundle built
    from the stored target snippets and a ``run_tests`` step. Every function
    is then submitted (Lean fragment for proof functions, target code
    otherwise) and must show up as verified.

    Args:
        task_id: Id of the task to play.
        verbose: When true, print inputs, observations and previews for each
            step; when false, nothing is printed by this function.

    Returns:
        The progress of the last observation, passed through
        ``clamp_open_unit``.

    Raises:
        RuntimeError: From ``_assert_verified`` when a submitted function is
            not reported as verified.
    """
    task = get_task(task_id)
    env = LeanMigrateEnvironment()
    observation = env.reset(task_id=task_id)
    backend_name = type(env._state.backend).__name__
    verified_target_snippets: dict[str, str] = {}
    target_snippets = TASK_TARGET_SNIPPETS[task_id]

    if verbose:
        console.print(Rule(f"[bold white]{task.display_name}[/] [dim]({task_id})[/]"))
        task_header = Text.assemble(
            ("Episode: ", "bold cyan"),
            observation.episode_id,
            "\n",
            ("Order: ", "bold cyan"),
            ", ".join(task.topo_order),
        )
        console.print(
            Panel(task_header, title="[bold white]TASK[/]", border_style="white")
        )

    for function_name in task.topo_order:
        function_spec = task.get_function(function_name)
        if function_spec is None:
            continue
        needs_proof = function_spec.is_proof_required

        # --- inspect ---------------------------------------------------
        if verbose:
            console.print(Rule(f"[bold yellow]STEP[/] [white]{function_name}[/]"))
            _print_step_input(f"inspect({function_name})", function_name=function_name)
        observation = env.step(
            InspectAction(type="inspect", function_name=function_name)
        )
        if verbose:
            _print_observation(f"inspect({function_name})", observation)

        # --- analyze_deps ----------------------------------------------
        observation = env.step(
            AnalyzeDepsAction(type="analyze_deps", function_name=function_name)
        )
        if verbose:
            _print_observation(f"analyze_deps({function_name})", observation)

        # --- run_tests (target-code functions only) --------------------
        target_code = None
        if not needs_proof:
            # Bundle this function's snippet with the snippets verified so far.
            target_code = build_submission_bundle(
                task,
                function_name,
                verified_target_snippets,
                target_snippets[function_name],
            )
            if verbose:
                _print_step_input(
                    f"run_tests({function_name})",
                    function_name=function_name,
                    candidate_code=target_code,
                )
            observation = env.step(
                RunTestsAction(
                    type="run_tests",
                    function_name=function_name,
                    candidate_code=target_code,
                )
            )
            if verbose:
                _print_observation(f"run_tests({function_name})", observation)

        # --- submit ----------------------------------------------------
        lean_proof = function_spec.lean_fragment if needs_proof else None
        if verbose:
            if needs_proof:
                input_label = f"submit_proof({function_name})"
            else:
                input_label = f"submit_target_code({function_name})"
            _print_step_input(
                input_label,
                function_name=function_name,
                target_code=target_code,
                lean_proof=lean_proof,
            )
            if needs_proof:
                _print_code_preview("LEAN PROOF PREVIEW", lean_proof or "", "lean")
            else:
                ir_result = build_verification_ir(
                    task, function_spec, target_code or ""
                )
                _print_ir_preview(
                    task_id,
                    function_name,
                    task.target_language,
                    backend_name,
                    ir_result,
                )

        observation = env.step(
            SubmitAction(
                type="submit",
                function_name=function_name,
                target_code=target_code,
                lean_proof=lean_proof,
            )
        )
        if verbose:
            if needs_proof:
                result_label = f"submit_proof({function_name})"
            else:
                result_label = f"submit({function_name})"
            _print_observation(result_label, observation)

        _assert_verified(task_id, function_name, observation)
        if not needs_proof:
            # Remember the verified snippet so later bundles can include it.
            verified_target_snippets[function_name] = target_snippets[function_name]

    final_progress = clamp_open_unit(float(observation.progress))
    if verbose:
        console.print(
            Panel(
                Text.assemble(
                    ("Completed: ", "bold green"),
                    task_id,
                    "\n",
                    ("progress=", "bold cyan"),
                    f"{final_progress:.3f}",
                ),
                border_style="green",
            )
        )
    return clamp_open_unit(float(observation.progress))
def main(argv: Sequence[str] | None = None) -> None:
    """Command-line entry point: play the selected tasks and print the scores.

    Args:
        argv: Argument list to parse; ``None`` lets ``argparse`` read the
            process arguments.

    Options:
        ``--tasks``: one or more task ids to play in order (default: all ids
        in ``TASK_IDS``).
        ``--quiet``: suppress per-step output and show only the final table.

    After all tasks have run, prints one score line per task followed by the
    mean score (``0.000`` when no task was played).
    """
    parser = argparse.ArgumentParser(
        description="Play all LeanMigrate tasks one step at a time."
    )
    parser.add_argument(
        "--tasks",
        default=TASK_IDS,
        nargs="+",
        help="Task ids to play in order.",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Reduce output to per-task summaries.",
    )
    options = parser.parse_args(argv)

    show_steps = not options.quiet
    scores: dict[str, float] = {
        task_id: _run_task(task_id, verbose=show_steps) for task_id in options.tasks
    }

    console.print(Rule("[bold white]PLAYTHROUGH RESULTS[/]"))
    if scores:
        overall = sum(scores.values()) / len(scores)
    else:
        overall = 0.0
    for task_id, score in scores.items():
        console.print(f"[bold cyan]{task_id:<20}[/] {score:.3f}")
    console.print(f"[bold cyan]{'overall':<20}[/] {overall:.3f}")


if __name__ == "__main__":
    main()