# lean-migrate / scripts / play_all_tasks.py
# Hrushi's picture
# Upload folder using huggingface_hub
# 361f093 verified
"""Deterministic step-by-step playthrough for all LeanMigrate tasks."""
from __future__ import annotations
import argparse
import sys
from collections.abc import Sequence
from pathlib import Path
from rich.columns import Columns
from rich.console import Console
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.text import Text
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[1]
# Put ROOT at the front of sys.path (once) -- presumably so the lean_migrate
# imports that follow resolve when this file is run directly as a script.
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from lean_migrate.env.models import ( # noqa: E402 # type: ignore
AnalyzeDepsAction,
InspectAction,
RunTestsAction,
SubmitAction,
)
from lean_migrate.env.grader import clamp_open_unit # noqa: E402 # type: ignore
from lean_migrate.env.target_snippets import ( # noqa: E402 # type: ignore
TASK_TARGET_SNIPPETS,
build_submission_bundle,
)
from lean_migrate.env.tasks import get_task, list_tasks # noqa: E402 # type: ignore
from lean_migrate.env.verification_ir import ( # noqa: E402 # type: ignore
BehaviorSummary,
VerificationIRResult,
build_verification_ir,
)
from lean_migrate.server.lean_migrate_environment import ( # noqa: E402 # type: ignore
LeanMigrateEnvironment,
)
# Ids of every task reported by list_tasks(); used as the default for --tasks.
TASK_IDS = tuple(task_info["task_id"] for task_info in list_tasks())
# Shared rich console that every helper in this script prints through.
console = Console()
def _short_feedback(text: str | None, max_lines: int = 12) -> str:
if not text:
return ""
lines = text.strip().splitlines()
if len(lines) <= max_lines:
return "\n".join(lines)
return "\n".join(lines[:max_lines] + ["..."])
def _short_digest(text: str | None, width: int = 16) -> str:
if not text:
return "(n/a)"
if len(text) <= width:
return text
return f"{text[:width]}..."
def _syntax_renderable(code: str, language: str) -> Syntax | Text:
    """Return ``code`` as a monokai-highlighted ``Syntax`` renderable.

    If constructing the ``Syntax`` object raises, the code is returned as
    plain ``Text`` instead so the preview is still printable.
    """
    try:
        highlighted = Syntax(code, language, theme="monokai", line_numbers=False)
    except Exception:
        # Best-effort display: any construction failure degrades to plain text.
        return Text(code)
    else:
        return highlighted
def _print_code_preview(title: str, code: str, language: str) -> None:
    """Print ``code`` inside a green-bordered panel headed by ``title``."""
    preview = Panel(
        _syntax_renderable(code, language),
        title=f"[bold green]{title}[/]",
        border_style="green",
    )
    console.print(preview)
def _summary_text(label: str, summary: BehaviorSummary | None) -> Text:
    """Format a behavior summary as a labelled block of ``name: value`` lines.

    Returns a dim ``"(unavailable)"`` text when ``summary`` is ``None``.
    Otherwise the first line is ``label`` and each following line shows one
    summary field, with the behavior digest abbreviated by ``_short_digest``.
    """
    if summary is None:
        return Text("(unavailable)", style="dim")

    rows = (
        ("task", summary.task_id),
        ("function", summary.function_name),
        ("language", summary.target_language),
        ("arity", str(summary.arity)),
        ("samples", str(summary.sample_count)),
        ("passed", str(summary.passed_count)),
        ("digest", _short_digest(summary.behavior_digest)),
    )

    out = Text()
    out.append(f"{label}\n", style="bold cyan")
    for name, shown in rows:
        out.append(f" {name}: ", style="bold cyan")
        out.append(f"{shown}\n")
    return out
def _print_ir_preview(
    task_id: str,
    function_name: str,
    target_language: str,
    backend_name: str,
    ir_result: VerificationIRResult,
) -> None:
    """Render the verification-IR result for one function as rich panels.

    Always prints an overview panel of label/value lines. After that, each
    only when present, it prints the IR feedback trace, a side-by-side
    submission/expected summary (both must be available), and the generated
    Lean code. The overview and trace borders are green when
    ``ir_result.ready`` is truthy and red otherwise.
    """
    border = "green" if ir_result.ready else "red"
    overview = Text()

    def add_field(label: str, value: object) -> None:
        # One overview line: styled label followed by the unstyled value.
        overview.append(label, style="bold cyan")
        overview.append(f"{value}\n")

    add_field("backend: ", backend_name)
    add_field("path: ", f"{target_language} -> IR -> Lean")
    add_field("ready: ", str(ir_result.ready).lower())

    provenance = ir_result.provenance
    if provenance is not None:
        add_field("parser: ", provenance.parser)
        add_field("signature: ", provenance.normalized_signature or "(unavailable)")
        add_field("arity: ", "(n/a)" if provenance.arity is None else provenance.arity)
        add_field("body hash: ", _short_digest(provenance.normalized_body_hash))
        add_field("source digest: ", _short_digest(provenance.source_digest))

    run = ir_result.run_result
    if run is not None:
        add_field("samples: ", f"{run.tests_passed}/{run.tests_total}")
        if run.case_results is not None:
            add_field("case traces: ", len(run.case_results))

    submitted = ir_result.submission_summary
    expected = ir_result.expected_summary
    if submitted is not None:
        add_field("behavior digest: ", _short_digest(submitted.behavior_digest))

    console.print(
        Panel(
            overview,
            title=f"[bold magenta]IR PREVIEW[/] {task_id}:{function_name}",
            border_style=border,
        )
    )

    if ir_result.feedback:
        trace_panel = Panel(
            Text(ir_result.feedback),
            title="[bold magenta]IR TRACE[/]",
            border_style=border,
        )
        console.print(trace_panel)

    # The comparison is only meaningful when both sides exist.
    if not (submitted is None or expected is None):
        submission_panel = Panel(
            _summary_text("Submission summary", submitted),
            border_style="cyan",
            title="[bold cyan]SUBMISSION[/]",
        )
        expected_panel = Panel(
            _summary_text("Expected summary", expected),
            border_style="green",
            title="[bold green]EXPECTED[/]",
        )
        console.print(Columns([submission_panel, expected_panel], expand=True))

    if ir_result.lean_code:
        _print_code_preview("GENERATED LEAN", ir_result.lean_code, "lean")
def _print_step_input(action_label: str, **fields: str | None) -> None:
    """Print the inputs of an action in a blue ``INPUT`` panel.

    Fields whose value is ``None`` are skipped. A multi-line value is shown
    under its field name with trailing whitespace removed; a single-line
    value is shown as ``name: value``. Entries are separated by a blank
    line, and a dim ``"(no inputs)"`` placeholder is shown when nothing
    remains.
    """

    def render(name: str, value: str) -> Text:
        # Multi-line values get the name on its own line above the value.
        if "\n" in value:
            return Text.assemble(
                (f"{name}\n", "bold cyan"),
                (value.rstrip(), ""),
            )
        return Text.assemble((f"{name}: ", "bold cyan"), value)

    entries = [
        render(name, value) for name, value in fields.items() if value is not None
    ]

    body = Text()
    for position, entry in enumerate(entries):
        if position:
            body.append("\n\n")
        body.append_text(entry)

    content = body if entries else Text("(no inputs)", style="dim")
    console.print(
        Panel(
            content,
            title=f"[bold blue]INPUT[/] {action_label}",
            border_style="blue",
        )
    )
def _print_observation(action_label: str, observation) -> None:
    """Print the outcome of one environment step.

    Shows an ``ACTION`` panel with the clamped reward, clamped progress and
    the ``done`` flag, followed by a ``FEEDBACK`` panel when the
    observation's reward details carry feedback text.

    Args:
        action_label: Label shown in the panel title, e.g. ``"inspect(f)"``.
        observation: Step result; ``reward``, ``progress``, ``done`` and
            ``reward_details`` are read from it.

    The panel is green when the feedback starts with ``"VERIFIED"`` and
    yellow otherwise, including when reward details or feedback are absent.
    """
    details = observation.reward_details
    # Feedback is read defensively once and reused for both the status colour
    # and the feedback panel: a details object whose feedback is None must not
    # crash the colour check with an AttributeError on ``.startswith``.
    raw_feedback = getattr(details, "feedback", None)
    is_verified = (
        bool(details)
        and isinstance(raw_feedback, str)
        and raw_feedback.startswith("VERIFIED")
    )
    status_style = "green" if is_verified else "yellow"

    reward = clamp_open_unit(float(observation.reward or 0.0))
    summary = Text.assemble(
        ("reward=", "bold cyan"),
        f"{reward:.3f} ",
        ("progress=", "bold cyan"),
        f"{clamp_open_unit(float(observation.progress)):.3f} ",
        ("done=", "bold cyan"),
        str(observation.done).lower(),
    )
    console.print(
        Panel(
            summary,
            title=f"[bold {status_style}]ACTION[/] {action_label}",
            border_style=status_style,
        )
    )

    feedback = _short_feedback(raw_feedback)
    if not feedback:
        return

    # Heuristic: feedback that looks like source code is syntax-highlighted.
    # Text starting with def/from/import is treated as Python; any other
    # code-like text falls back to JavaScript highlighting.
    leading = feedback.lstrip()
    looks_like_code = "def " in feedback or leading.startswith(
        ("function ", "from ", "import ")
    )
    if looks_like_code:
        syntax_language = (
            "python"
            if leading.startswith(("def ", "from ", "import "))
            else "javascript"
        )
        feedback_renderable = Syntax(
            feedback, syntax_language, theme="monokai", line_numbers=False
        )
    else:
        feedback_renderable = Text(feedback)
    console.print(
        Panel(
            feedback_renderable,
            title="[bold magenta]FEEDBACK[/]",
            border_style="magenta",
        )
    )
def _assert_verified(task_id: str, function_name: str, observation) -> None:
    """Raise ``RuntimeError`` unless ``function_name`` is in ``observation.verified``.

    The error message names the task and function and includes the repr of
    the observation's ``verified`` collection.
    """
    if function_name in observation.verified:
        return
    raise RuntimeError(
        f"submit failed for {task_id}:{function_name}; "
        f"verified={observation.verified!r}"
    )
def _run_task(task_id: str, verbose: bool = True) -> float:
    """Play one task from reset to completion and return its final progress.

    A fresh ``LeanMigrateEnvironment`` is reset for ``task_id``. Then, for
    each function in ``task.topo_order``, the script steps through inspect,
    analyze_deps, run_tests (for functions that are not proof-required) and
    submit, using the snippets from ``TASK_TARGET_SNIPPETS`` or the
    function's ``lean_fragment``.

    Args:
        task_id: Id of the task to play.
        verbose: When true, print inputs, observations and previews for
            every step; when false, nothing is printed by this function.

    Returns:
        ``clamp_open_unit(float(observation.progress))`` of the last
        observation.

    Raises:
        RuntimeError: via ``_assert_verified`` when a submitted function is
            not reported as verified.
    """
    task = get_task(task_id)
    env = LeanMigrateEnvironment()
    observation = env.reset(task_id=task_id)
    # Backend class name, shown in the IR preview; read from the env's private state.
    backend_name = type(env._state.backend).__name__
    # Snippets of functions already verified in this episode, keyed by name;
    # passed to build_submission_bundle for later functions.
    verified_target_snippets: dict[str, str] = {}
    target_snippets = TASK_TARGET_SNIPPETS[task_id]
    if verbose:
        console.print(Rule(f"[bold white]{task.display_name}[/] [dim]({task_id})[/]"))
        console.print(
            Panel(
                Text.assemble(
                    ("Episode: ", "bold cyan"),
                    observation.episode_id,
                    "\n",
                    ("Order: ", "bold cyan"),
                    ", ".join(task.topo_order),
                ),
                border_style="white",
                title="[bold white]TASK[/]",
            )
        )
    for function_name in task.topo_order:
        function_spec = task.get_function(function_name)
        # Names in the order without a matching spec are skipped.
        if function_spec is None:
            continue
        if verbose:
            console.print(Rule(f"[bold yellow]STEP[/] [white]{function_name}[/]"))
            _print_step_input(f"inspect({function_name})", function_name=function_name)
        observation = env.step(
            InspectAction(type="inspect", function_name=function_name)
        )
        if verbose:
            _print_observation(f"inspect({function_name})", observation)
        observation = env.step(
            AnalyzeDepsAction(type="analyze_deps", function_name=function_name)
        )
        if verbose:
            _print_observation(f"analyze_deps({function_name})", observation)
        # Stays None for proof-required functions, which submit a Lean proof instead.
        target_code = None
        # NOTE(review): the run_tests step is assumed to belong to this branch
        # (non-proof functions only) -- confirm against the upstream file.
        if not function_spec.is_proof_required:
            target_code = build_submission_bundle(
                task,
                function_name,
                verified_target_snippets,
                target_snippets[function_name],
            )
            if verbose:
                _print_step_input(
                    f"run_tests({function_name})",
                    function_name=function_name,
                    candidate_code=target_code,
                )
            observation = env.step(
                RunTestsAction(
                    type="run_tests",
                    function_name=function_name,
                    candidate_code=target_code,
                )
            )
            if verbose:
                _print_observation(f"run_tests({function_name})", observation)
        if verbose:
            _print_step_input(
                f"submit_proof({function_name})"
                if function_spec.is_proof_required
                else f"submit_target_code({function_name})",
                function_name=function_name,
                target_code=target_code,
                lean_proof=function_spec.lean_fragment
                if function_spec.is_proof_required
                else None,
            )
            if function_spec.is_proof_required:
                _print_code_preview(
                    "LEAN PROOF PREVIEW",
                    function_spec.lean_fragment or "",
                    "lean",
                )
            else:
                # Display only: the IR is built here for the preview and is
                # not passed to the submit action below.
                ir_result = build_verification_ir(
                    task, function_spec, target_code or ""
                )
                _print_ir_preview(
                    task_id,
                    function_name,
                    task.target_language,
                    backend_name,
                    ir_result,
                )
        submit_action = SubmitAction(
            type="submit",
            function_name=function_name,
            target_code=target_code,
            lean_proof=function_spec.lean_fragment
            if function_spec.is_proof_required
            else None,
        )
        observation = env.step(submit_action)
        if verbose:
            _print_observation(
                f"submit_proof({function_name})"
                if function_spec.is_proof_required
                else f"submit({function_name})",
                observation,
            )
        # Abort the playthrough as soon as a submission is not verified.
        _assert_verified(task_id, function_name, observation)
        if not function_spec.is_proof_required:
            # Record the bare snippet (not the bundle) for later bundles.
            verified_target_snippets[function_name] = target_snippets[function_name]
    if verbose:
        console.print(
            Panel(
                Text.assemble(
                    ("Completed: ", "bold green"),
                    task_id,
                    "\n",
                    ("progress=", "bold cyan"),
                    f"{clamp_open_unit(float(observation.progress)):.3f}",
                ),
                border_style="green",
            )
        )
    return clamp_open_unit(float(observation.progress))
def main(argv: Sequence[str] | None = None) -> None:
    """Command-line entry point: play the selected tasks and print scores.

    ``--tasks`` selects the task ids to play in order (default: all of
    ``TASK_IDS``); ``--quiet`` turns off the per-step output of each task.
    After every task has run, a results section lists each task's score
    and the mean score, which is ``0.0`` when no task was played.
    """
    parser = argparse.ArgumentParser(
        description="Play all LeanMigrate tasks one step at a time."
    )
    parser.add_argument(
        "--tasks",
        nargs="+",
        default=TASK_IDS,
        help="Task ids to play in order.",
    )
    parser.add_argument(
        "--quiet",
        action="store_true",
        help="Reduce output to per-task summaries.",
    )
    options = parser.parse_args(argv)

    show_steps = not options.quiet
    scores: dict[str, float] = {
        task_id: _run_task(task_id, verbose=show_steps) for task_id in options.tasks
    }

    if scores:
        overall = sum(scores.values()) / len(scores)
    else:
        overall = 0.0

    console.print(Rule("[bold white]PLAYTHROUGH RESULTS[/]"))
    for task_id, score in scores.items():
        console.print(f"[bold cyan]{task_id:<20}[/] {score:.3f}")
    console.print(f"[bold cyan]{'overall':<20}[/] {overall:.3f}")
# Run the playthrough only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()