lean-migrate / scripts /play_task.py
Hrushi's picture
Upload folder using huggingface_hub
bf9c466 verified
"""play_task β€” LLM-driven CLI for LeanMigrate task playthrough.
Usage:
uv run python scripts/play_task.py list
uv run python scripts/play_task.py play lru_cache expression_eval
uv run python scripts/play_task.py play --all
"""
from __future__ import annotations
import datetime
import json
import os
import re
import sys
import textwrap
from pathlib import Path
from typing import Annotated, Any, Literal, Optional
from dotenv import load_dotenv
import typer
from rich.columns import Columns
from rich.console import Console
from rich.panel import Panel
from rich.rule import Rule
from rich.syntax import Syntax
from rich.table import Table
from rich.text import Text
# Project root (the directory above scripts/); used to locate `.env` and to make
# the `lean_migrate` package importable when this file is run as a script.
ROOT = Path(__file__).resolve().parents[1]
# Load .env before anything reads os.environ (including typer envvar resolution)
load_dotenv(ROOT / ".env")
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
from lean_migrate.env.grader import clamp_open_unit, oracle_result # noqa: E402
from lean_migrate.env.models import ( # noqa: E402
AnalyzeDepsAction,
InspectAction,
LeanMigrateAction,
LeanMigrateObservation,
RunTestsAction,
SubmitAction,
)
from lean_migrate.env.tasks import FunctionSpec, Task, get_task, list_tasks # noqa: E402
from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment # noqa: E402
# ── Pydantic schema for structured LLM output ────────────────────────────────
from pydantic import BaseModel, Field # noqa: E402
class ActionRequest(BaseModel):
    """Structured output schema for LLM action responses."""

    # NOTE(review): pydantic normally copies the class docstring and each Field
    # description into the generated JSON schema, and `_llm_action` passes this
    # class as `response_format`, so these strings likely reach the model —
    # treat them as prompt text when editing.

    # Which environment action to take. Named `type` to match the action JSON
    # shown in the system prompt (it shadows the builtin only inside the class).
    type: Literal["inspect", "analyze_deps", "run_tests", "submit"]
    # Function the action applies to (required for every action type).
    function_name: str = Field(description="Name of the function this action targets.")
    # Full candidate implementation; only meaningful for run_tests, else None.
    candidate_code: str | None = Field(
        default=None,
        description="Full implementation for run_tests. Omit for other action types.",
    )
    # Lean proof text; sent with submit for proof-required functions, else None.
    lean_proof: str | None = Field(
        default=None,
        description="Lean proof text. Only required for proof tasks.",
    )
# ── App + console ─────────────────────────────────────────────────────────────
# Typer application; the `list` and `play` sub-commands are registered on it.
app = typer.Typer(
    add_completion=False,
    help="LLM-driven LeanMigrate task runner with rich step display.",
    name="play_task",
)
# Shared Rich console used for all formatted output in this script.
console = Console()
# ── Rich helpers ──────────────────────────────────────────────────────────────
def _difficulty_color(d: str) -> str:
    """Return the Rich colour name used to render a difficulty label.

    ``easy`` is green, ``medium`` yellow, ``hard`` red; any other label is white.
    """
    if d == "easy":
        return "green"
    if d == "medium":
        return "yellow"
    if d == "hard":
        return "red"
    return "white"
def _trim(text: str, max_lines: int = 20) -> str:
    """Strip *text* and cap it at *max_lines* lines for compact display.

    The result is plain text. It is appended to Rich ``Text`` objects, which
    do not interpret console markup, so the truncation marker must not carry
    markup tags (they would be printed literally).

    Args:
        text: Text to shorten; leading/trailing whitespace is removed first.
        max_lines: Maximum number of lines to keep.

    Returns:
        The stripped text, or its first *max_lines* lines followed by a
        ``… (N more line/lines)`` marker when it is longer.
    """
    lines = text.strip().splitlines()
    hidden = len(lines) - max_lines
    if hidden <= 0:
        return "\n".join(lines)
    noun = "line" if hidden == 1 else "lines"
    return "\n".join(lines[:max_lines]) + f"\n… ({hidden} more {noun})"
def _split_at_semicolons(s: str) -> list[str]:
    """Split a Lean tactic string at top-level ``;`` separators.

    Semicolons nested inside ``()``, ``[]`` or ``{}`` are kept. So is the
    ``<;>`` combinator, because ``t1 <;> t2`` is one tactic, not two. Pieces are
    stripped, and empty pieces are dropped.

    Args:
        s: Tactic text, typically everything after ``:= by``.

    Returns:
        The top-level tactic pieces in order.
    """
    parts: list[str] = []
    buf: list[str] = []
    depth = 0
    last_index = len(s) - 1
    for i, ch in enumerate(s):
        if ch in "([{":
            depth += 1
        elif ch in ")]}":
            # Clamp so a stray closer cannot disable splitting for the rest.
            depth = max(depth - 1, 0)
        elif ch == ";" and depth == 0:
            is_combinator = 0 < i < last_index and s[i - 1] == "<" and s[i + 1] == ">"
            if not is_combinator:
                piece = "".join(buf).strip()
                if piece:
                    parts.append(piece)
                buf = []
                continue
        buf.append(ch)
    tail = "".join(buf).strip()
    if tail:
        parts.append(tail)
    return parts
# Lean tactic keywords that start a new tactic at the top level of a `by` block.
# Alternatives are tried in the order listed. All of them except `rw [` and
# `rewrite [` must end at a word boundary, and the look-behind prevents a match
# from starting in the middle of a word.
_LEAN_TACTIC_KWS = re.compile(
    r"(?<!\w)("
    + "|".join(
        (
            r"intro\b", r"intros\b", r"apply\b", r"refine\b", r"exact\b",
            r"simp\b", r"rfl\b", r"norm_num\b", r"decide\b", r"ring\b",
            r"omega\b", r"constructor\b", r"cases\b", r"induction\b",
            r"have\b", r"obtain\b", r"rw\s*\[", r"rewrite\s*\[", r"calc\b",
            r"show\b", r"suffices\b", r"use\b", r"trivial\b", r"tauto\b",
            r"aesop\b", r"linarith\b", r"nlinarith\b", r"positivity\b",
            r"field_simp\b", r"push_neg\b", r"contrapose\b", r"by_contra\b",
            r"by_cases\b", r"rcases\b", r"ext\b", r"funext\b", r"congr\b",
            r"subst\b", r"assumption\b", r"contradiction\b", r"scalar_tac\b",
            r"progress\b", r"next\b", r"all_goals\b", r"any_goals\b",
            r"first\b", r"repeat\b", r"done\b", r"skip\b",
        )
    )
    + r")"
)
def _format_lean(code: str) -> str:
    """Best-effort Lean 4 single-line → multi-line formatter, for display only.

    - Multi-line input is returned with trailing whitespace stripped from each
      line; the author's layout is otherwise kept.
    - A single-line ``<head> := by <tactics>`` becomes one tactic per line.
      The text is split at top-level ``;`` first, then before any tactic
      keyword that begins a new word at bracket depth 0. The exception is when
      the preceding word is a combinator, or a tactic that takes the next
      tactic or term as its argument (``exact rfl``, ``t <;> simp``,
      ``all_goals simp``, ``:= by omega``, ...); splitting there would change
      the proof's meaning.
    - A single-line ``<head> := <term>`` puts the term on its own line.
    - Any other input is returned stripped.
    """
    # Words after which a tactic keyword continues the current tactic instead
    # of starting a new one.
    continuation_words = frozenset(
        {
            "by", "<;>", "exact", "apply", "refine", "try", "repeat", "first",
            "all_goals", "any_goals", "|", ":=", "=>", "from", "·",
        }
    )
    source = code.strip()
    if "\n" in source:
        # Already multi-line: only drop trailing whitespace on each line.
        return "\n".join(line.rstrip() for line in source.splitlines())

    by_match = re.match(r"^(.*?)\s*:=\s*by\s+(.+)$", source, re.DOTALL)
    if by_match:
        head = by_match.group(1).strip()
        tactics: list[str] = []
        for part in _split_at_semicolons(by_match.group(2).strip()):
            # Offsets in `part` where a new tactic begins.
            starts = [0]
            depth = 0
            for i, ch in enumerate(part):
                if ch in "([{":
                    depth += 1
                elif ch in ")]}":
                    depth = max(depth - 1, 0)
                elif (
                    depth == 0
                    and i > 0
                    and part[i - 1] == " "
                    and _LEAN_TACTIC_KWS.match(part, i)
                ):
                    previous = part[:i].split()
                    if previous and previous[-1] not in continuation_words:
                        starts.append(i)
            bounds = [*starts, len(part)]
            for begin, end in zip(bounds, bounds[1:]):
                chunk = part[begin:end].strip()
                if chunk:
                    tactics.append(chunk)
        return head + " := by\n" + "\n".join(f"  {t}" for t in tactics)

    term_match = re.match(r"^(.*?)\s*:=\s*(.+)$", source, re.DOTALL)
    if term_match:
        head, body = term_match.group(1).strip(), term_match.group(2).strip()
        return f"{head} :=\n  {body}"
    return source
def _format_code(code: str, language: str) -> str:
    """Pretty-print *code* with a formatter chosen by *language*.

    Lean uses the in-process ``_format_lean``. TypeScript, Python and Rust are
    piped through ``npx prettier``, ``ruff format`` and ``rustfmt``. The input
    is returned unchanged in all of these cases:

    - an unknown language;
    - a missing tool or a timeout;
    - a non-zero exit status or blank output;
    - any other exception.
    """
    import subprocess

    # language -> (argv, timeout in seconds)
    formatters = {
        "typescript": (
            ["npx", "--yes", "prettier", "--stdin-filepath", "candidate.ts"],
            15,
        ),
        "python": (["ruff", "format", "--stdin-filename", "candidate.py", "-"], 10),
        "rust": (["rustfmt", "--edition", "2021"], 10),
    }
    try:
        if language == "lean":
            return _format_lean(code)
        if language not in formatters:
            # Unknown language: leave it as-is (the panel still word-wraps it).
            return code
        argv, timeout = formatters[language]
        proc = subprocess.run(
            argv, input=code, capture_output=True, text=True, timeout=timeout
        )
    except Exception:
        return code
    if proc.returncode != 0 or not proc.stdout.strip():
        return code
    return proc.stdout
def _code_panel(title: str, code: str, language: str, border: str = "blue") -> Panel:
    """Return a bordered panel showing *code* with syntax highlighting.

    The code is first run through ``_format_code``. If constructing the
    ``Syntax`` renderable raises, the formatted code is shown as plain text.
    """
    pretty = _format_code(code, language)
    try:
        renderable = Syntax(
            pretty, language, theme="monokai", line_numbers=True, word_wrap=True
        )
    except Exception:
        renderable = Text(pretty, no_wrap=False)
    heading = f"[bold {border}]{title}[/]"
    return Panel(renderable, title=heading, border_style=border)
def _text_panel(title: str, body: Text | str, border: str = "cyan") -> Panel:
    """Wrap *body* in a titled panel; plain strings are converted to ``Text``."""
    content = Text(body) if isinstance(body, str) else body
    return Panel(content, title=f"[bold {border}]{title}[/]", border_style=border)
def _state_text(obs: LeanMigrateObservation) -> Text:
    """Render an observation's progress and function lists as labelled Rich Text.

    Progress is shown after ``clamp_open_unit`` with three decimals. Empty
    lists or an empty failing map render as ``(none)``. The done flag is
    lowercase.
    """

    def listing(names: Any) -> str:
        return ", ".join(names) if names else "(none)"

    rows = (
        ("progress : ", f"{clamp_open_unit(float(obs.progress)):.3f}"),
        ("verified : ", listing(obs.verified)),
        ("remaining: ", listing(obs.remaining)),
        ("failing : ", json.dumps(obs.failing) if obs.failing else "(none)"),
        ("done : ", str(obs.done).lower()),
    )
    text = Text()
    for index, (label, value) in enumerate(rows):
        if index:
            text.append("\n")
        text.append(label, style="bold cyan")
        text.append(value)
    return text
def _action_text(ad: dict[str, Any]) -> Text:
    """Render an action dict's type and target function ("?" if missing)."""
    text = Text()
    fields = (("type : ", "type"), ("function : ", "function_name"))
    for index, (label, key) in enumerate(fields):
        if index:
            text.append("\n")
        text.append(label, style="bold magenta")
        text.append(ad.get(key, "?"))
    return text
def _obs_text(obs: LeanMigrateObservation) -> Text:
    """Render an observation's reward and done flag as Rich Text.

    The number shown is ``clamp_open_unit(reward)`` (a missing reward counts as
    0.0). The colour follows the sign of the raw reward: green if positive, red
    if negative, yellow if zero.
    """
    raw_reward = float(obs.reward or 0.0)
    # NOTE(review): judging by its name, clamp_open_unit maps into the open
    # interval (0, 1), so a clamped value could never be negative. Choose the
    # colour from the raw value so penalties still show red.
    color = "green" if raw_reward > 0 else ("red" if raw_reward < 0 else "yellow")
    t = Text()
    t.append("reward : ", style=f"bold {color}")
    t.append(f"{clamp_open_unit(raw_reward):.3f}\n", style=color)
    t.append("done : ", style="bold cyan")
    t.append(str(obs.done).lower())
    return t
# ── Trace helpers ─────────────────────────────────────────────────────────────
def _obs_snapshot(obs: LeanMigrateObservation) -> dict[str, Any]:
    """Copy the observation fields used for offline analysis into plain types.

    Keys, in order: progress (float), verified/remaining (lists), failing
    (dict), done (bool) and episode_step (int).
    """
    snapshot: dict[str, Any] = {}
    snapshot["progress"] = float(obs.progress)
    snapshot["verified"] = list(obs.verified)
    snapshot["remaining"] = list(obs.remaining)
    snapshot["failing"] = dict(obs.failing)
    snapshot["done"] = bool(obs.done)
    snapshot["episode_step"] = int(obs.episode_step)
    return snapshot
def _reward_snapshot(obs: LeanMigrateObservation) -> dict[str, Any] | None:
    """Serialise ``obs.reward_details`` into a plain dict, or return None if absent.

    Scores are coerced to float and the breakdown is copied into a dict. All
    other fields are passed through unchanged.
    """
    details = obs.reward_details
    if details is not None:
        return {
            "score": float(details.score),
            "cumulative_score": float(details.cumulative_score),
            "tests_passed": details.tests_passed,
            "tests_total": details.tests_total,
            "proof_compiled": details.proof_compiled,
            "breakdown": dict(details.breakdown),
            "feedback": details.feedback,
            "lean_error": details.lean_error,
        }
    return None
def _step_analysis(
    obs_before: LeanMigrateObservation,
    obs_after: LeanMigrateObservation,
    action_dict: dict[str, Any],
    reward: float,
) -> dict[str, Any]:
    """Derive per-step signals for reward-shaping and behaviour analysis.

    The signals are:

    - progress delta, rounded to 6 places;
    - test pass rate, or None when no tests were counted;
    - whether a Lean error was present;
    - whether the proof compiled;
    - whether progress increased;
    - reward sign ("positive" / "negative" / "zero");
    - action type;
    - whether the action was an inspection (inspect or analyze_deps).
    """
    details = obs_after.reward_details
    delta = float(obs_after.progress) - float(obs_before.progress)

    pass_rate: float | None = None
    if details and details.tests_total:
        pass_rate = (details.tests_passed or 0) / details.tests_total

    if reward > 0:
        sign = "positive"
    elif reward < 0:
        sign = "negative"
    else:
        sign = "zero"

    kind = action_dict.get("type")
    return {
        "progress_delta": round(delta, 6),
        "tests_pass_rate": pass_rate,
        "lean_error_present": bool(details and details.lean_error),
        "proof_compiled": details.proof_compiled if details else None,
        "new_function_verified": delta > 0,
        "reward_sign": sign,
        "action_type": kind,
        "used_inspection": kind in {"inspect", "analyze_deps"},
    }
def _write_partial_trace(
    out_path: Path,
    task_id: str,
    episode_id: str,
    model: str,
    api_base_url: str,
    temperature: float,
    task: Any,
    steps: list[dict[str, Any]],
    rewards: list[float],
    system_prompt: str = "",
) -> None:
    """Write an in-progress trace snapshot so work survives an interruption.

    The snapshot has ``meta.partial = True`` and ``summary = None``.

    The JSON goes to a sibling ``.tmp`` file, which is then renamed over
    *out_path*. An interrupt during the write therefore never leaves a
    truncated trace. The file is always UTF-8: ``ensure_ascii=False`` keeps
    non-ASCII prompt and feedback text verbatim, and the platform default
    encoding may be unable to represent it.

    Args:
        out_path: Destination JSON file; missing parent directories are created.
        task_id, episode_id, model, api_base_url, temperature: Run metadata.
        task: Task object providing difficulty, languages, functions and max_steps.
        steps: Trace step records collected so far.
        rewards: Unused here; accepted for signature parity with the final writer.
        system_prompt: System prompt used for the episode.
    """
    trace = {
        "meta": {
            "task_id": task_id,
            "episode_id": episode_id,
            "model": model,
            "api_base_url": api_base_url,
            "temperature": temperature,
            "difficulty": task.difficulty,
            "source_language": task.source_language,
            "target_language": task.target_language,
            "functions": [f.name for f in task.functions],
            "proof_functions": [f.name for f in task.functions if f.is_proof_required],
            "max_steps": task.max_steps,
            "system_prompt": system_prompt,
            "partial": True,
        },
        "steps": steps,
        "summary": None,
    }
    # Serialise first, so an unserialisable step never leaves a temp file behind.
    payload = json.dumps(trace, indent=2, ensure_ascii=False)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(out_path)
def _summarize_trace_steps(
    steps: list[dict[str, Any]], rewards: list[float], final_score: float
) -> dict[str, Any]:
    """Compute episode-level analytics from recorded trace steps."""
    import math

    action_counts: dict[str, int] = {}
    lean_errors = 0
    negative_rewards = 0
    max_run_tests_streak = 0
    run_tests_streak = 0
    total_code_chars = 0
    total_test_runs = 0
    total_tests_passed = 0
    total_tests_total = 0
    for step in steps:
        atype = step["action"]["type"]
        action_counts[atype] = action_counts.get(atype, 0) + 1
        if step["analysis"]["lean_error_present"]:
            lean_errors += 1
        if step["reward"] < 0:
            negative_rewards += 1
        if atype != "run_tests":
            run_tests_streak = 0
            continue
        run_tests_streak += 1
        max_run_tests_streak = max(max_run_tests_streak, run_tests_streak)
        total_code_chars += len(step["action"].get("candidate_code") or "")
        total_test_runs += 1
        details = step["reward_details"] or {}
        total_tests_passed += details.get("tests_passed") or 0
        total_tests_total += details.get("tests_total") or 0

    # Shannon entropy of the action-type distribution, normalised by the
    # maximum entropy over the four action types (log2(4) = 2).
    total_actions = len(steps)
    diversity = 0.0
    if total_actions > 0 and len(action_counts) > 1:
        for count in action_counts.values():
            p = count / total_actions
            diversity -= p * math.log2(p)
        diversity /= math.log2(len({"inspect", "analyze_deps", "run_tests", "submit"}))

    final_obs = steps[-1]["state_after"] if steps else {}
    return {
        "total_steps": len(steps),
        "final_score": round(final_score, 6),
        "success": final_score >= 0.99,
        "rewards": [round(r, 4) for r in rewards],
        "reward_sum": round(sum(rewards), 4),
        "action_type_counts": action_counts,
        "action_diversity_score": round(diversity, 4),
        "functions_verified": final_obs.get("verified", []),
        "functions_remaining": final_obs.get("remaining", []),
        "functions_failing": list((final_obs.get("failing") or {}).keys()),
        "max_consecutive_run_tests": max_run_tests_streak,
        "lean_errors_count": lean_errors,
        "negative_rewards_count": negative_rewards,
        # Pooled over all run_tests steps: total passed / total cases.
        "avg_tests_pass_rate": (
            round(total_tests_passed / total_tests_total, 4)
            if total_tests_total > 0
            else None
        ),
        "never_used_inspect": action_counts.get("inspect", 0) == 0,
        "never_used_analyze_deps": action_counts.get("analyze_deps", 0) == 0,
        "total_test_runs": total_test_runs,
        "total_candidate_code_chars": total_code_chars,
    }


def _write_trace(
    trace_dir: Path,
    task_id: str,
    episode_id: str,
    model: str,
    api_base_url: str,
    temperature: float,
    task: Any,
    steps: list[dict[str, Any]],
    rewards: list[float],
    final_score: float,
    out_path: Path | None = None,
    system_prompt: str = "",
) -> Path:
    """Write the full episode trace (meta, steps, summary) to JSON and return its path.

    When *out_path* is None, the file is named
    ``<task_id>__<UTC timestamp>__<episode_id[:8]>.json`` inside *trace_dir*
    (which is created if needed). As with the partial trace, the file is
    written as UTF-8 through a temporary sibling that is then renamed over the
    target, so an interrupted write cannot corrupt an existing trace.
    ``meta.partial`` is False, which distinguishes this file from the
    in-progress snapshot.
    """
    trace_dir.mkdir(parents=True, exist_ok=True)
    ts = datetime.datetime.now(datetime.timezone.utc)
    trace = {
        "meta": {
            "task_id": task_id,
            "episode_id": episode_id,
            "model": model,
            "api_base_url": api_base_url,
            "temperature": temperature,
            "timestamp": ts.isoformat(),
            "difficulty": task.difficulty,
            "source_language": task.source_language,
            "target_language": task.target_language,
            "functions": [f.name for f in task.functions],
            "proof_functions": [f.name for f in task.functions if f.is_proof_required],
            "max_steps": task.max_steps,
            "system_prompt": system_prompt,
            "partial": False,
        },
        "steps": steps,
        "summary": _summarize_trace_steps(steps, rewards, final_score),
    }
    if out_path is None:
        ts_str = ts.strftime("%Y%m%dT%H%M%SZ")
        out_path = trace_dir / f"{task_id}__{ts_str}__{episode_id[:8]}.json"
    payload = json.dumps(trace, indent=2, ensure_ascii=False)
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    tmp_path.write_text(payload, encoding="utf-8")
    tmp_path.replace(out_path)
    return out_path
def _print_step(
    step: int,
    obs_before: LeanMigrateObservation,
    action_dict: dict[str, Any],
    obs_after: LeanMigrateObservation,
    target_language: str,
) -> None:
    """Render one environment step to the console.

    Prints, in order:

    - a header rule;
    - the pre-step state beside the action summary;
    - the candidate code and the Lean proof, when present;
    - an observation panel with the reward, done flag, trimmed feedback and
      any Lean error.
    """
    action_type = action_dict.get("type", "?")
    fn_name = action_dict.get("function_name", "?")
    console.print(
        Rule(f"[bold yellow]STEP {step}[/] [white]{action_type}({fn_name})[/]")
    )

    # ── State (before) + Action summary side-by-side
    console.print(
        Columns(
            [
                _text_panel("STATE (before)", _state_text(obs_before), "cyan"),
                _text_panel("ACTION", _action_text(action_dict), "magenta"),
            ],
            expand=True,
        )
    )

    # ── Code block (if run_tests)
    code = action_dict.get("candidate_code")
    if code:
        # Rust and TypeScript are highlighted as themselves; everything else as Python.
        lang = target_language if target_language in ("rust", "typescript") else "python"
        # The panel title is parsed as Rich markup, where a bare "[rust]" would
        # be read as a style tag and not displayed; escape the bracket.
        console.print(_code_panel(f"CANDIDATE CODE \\[{lang}]", code, lang, "blue"))

    # ── Lean proof (if submit with proof)
    proof = action_dict.get("lean_proof")
    if proof:
        console.print(_code_panel("LEAN PROOF", proof, "lean", "green"))

    # ── Observation (after)
    rd = obs_after.reward_details
    feedback = obs_after.last_action_feedback or ""
    lean_err = (rd.lean_error if rd else None) or ""
    raw_reward = float(obs_after.reward or 0.0)
    # Colour by the raw sign: the clamped display value may not preserve it
    # (clamp_open_unit presumably maps into (0, 1) — confirm).
    border = "green" if raw_reward > 0 else ("red" if raw_reward < 0 else "yellow")

    obs_body = Text()
    obs_body.append_text(_obs_text(obs_after))
    if feedback:
        obs_body.append("\n\nfeedback :\n", style="bold cyan")
        obs_body.append(_trim(feedback, 15))
    if lean_err:
        obs_body.append("\n\nlean error:\n", style="bold red")
        obs_body.append(_trim(lean_err, 20))
    console.print(_text_panel("OBSERVATION", obs_body, border))
# ── LLM helpers (from inference.py) ──────────────────────────────────────────
def _get_simple_sample(task: Task, fn: FunctionSpec) -> str:
    """Return a one-line example call for *fn* with the oracle's expected result.

    Only the first sample case is used. The argument list is truncated to 300
    characters. Returns "" for proof functions or when *fn* has no sample
    inputs. (The arrow was previously mis-encoded as "β†’".)
    """
    cases = task.sample_inputs.get(fn.name, [])
    if not cases or fn.is_proof_required:
        return ""
    case = cases[0]
    expected = oracle_result(task.task_id, fn.name, case.args)
    args_repr = ", ".join(repr(a) for a in case.args)
    if len(args_repr) > 300:
        args_repr = args_repr[:297] + "..."
    return f"Sample: ({args_repr}) → {expected!r}"
def _get_language_notes(task: Task) -> str:
    """Return target-language guidance for the system prompt.

    TypeScript and Rust targets get a markdown notes section; any other target
    returns "". The mis-encoded em dashes ("β€”") in the prompt text are
    repaired.
    """
    if task.target_language == "typescript":
        return textwrap.dedent(
            """
            ## TypeScript notes
            - Use top-level `function` declarations (not arrow functions / const assignments).
            - Dependency code is injected automatically — write only the current function.
            """
        ).strip()
    if task.target_language == "rust":
        return textwrap.dedent(
            """
            ## Rust notes
            - Use serde_json and serde (already in Cargo.toml).
            - Write only the current function. Dependencies are injected as separate pub fn definitions.
            - Match the inspected Rust signature exactly, including ownership and mutability. The harness passes owned JSON-decoded values into your function, so do not add `&`, `&mut`, or slice parameters unless the inspected source explicitly shows them.
            - If the inspect output shows a `Source symbol`, use that exact Rust function name in your code.
            - Return types serializable to serde_json::Value (Vec, Option, tuples all work).
            - run_tests compiles via cargo. If run_tests returns "cargo not found", cargo is
              unavailable — submit your implementation directly without run_tests.
            """
        ).strip()
    return ""
def _get_proof_notes(task: Task) -> str:
    """Return Lean proof guidance if any task function requires a proof, else "".

    The mis-encoded em dash ("β€”") in the prompt text is repaired.
    """
    if any(fn.is_proof_required for fn in task.functions):
        return textwrap.dedent(
            """
            ## Proof task notes
            - NEVER use `sorry` — the verifier rejects it IMMEDIATELY, gives -0.05 reward,
              and does not compile Lean at all. `sorry` has zero partial credit.
            - If you cannot write the full proof yet, DO NOT submit. Call `inspect` on the
              proof function to read the exact theorem and full Lean spec, then reason step by step.
            - Useful first-try tactics: `by decide`, `by native_decide`, `by simp [...]`,
              `by omega`, `by rfl`, `by constructor <;> simp`.
            - Use qualified ADT constructors from the spec (e.g. `Spec.MyType.mk`).
            """
        ).strip()
    return ""
def _system_prompt(task: Task) -> str:
    """Build the system prompt: workflow, action JSON formats, notes and functions.

    Contains one ``###`` section per task function (description, dependencies,
    proof flag), preceded by any language notes and proof notes.
    """
    fn_sections = []
    for fn in task.functions:
        proof_tag = (
            " [PROOF REQUIRED — submit with lean_proof field]"
            if fn.is_proof_required
            else ""
        )
        fn_sections.append(
            f"### {fn.name}{proof_tag}\n"
            f"Description: {fn.description}\n"
            f"Depends on: {fn.depends_on or ['(none)']}\n"
        )
    extra = "\n\n".join(
        section
        for section in (_get_language_notes(task), _get_proof_notes(task))
        if section
    )
    # Dedent the static template *before* interpolating. The inserted sections
    # are flush-left, so dedenting already-formatted text finds no common
    # indentation and leaks this function's indentation into the prompt.
    # `{{`/`}}` are literal braces for str.format.
    template = textwrap.dedent(
        """
        You are an expert software engineer migrating {source_language} code to
        {target_language} with Lean 4 verification.

        Task: {display_name}

        ## Workflow
        1. Call analyze_deps on any function to see the full dependency graph and migration order.
        2. Call inspect on each function to read its source code and Lean spec before implementing.
        3. Call run_tests with your implementation and iterate until all cases pass.
        4. Call submit when tests pass — the system uses the last run_tests code automatically.
        5. Write only the current function's code; verified dependencies are injected automatically.
        6. For Rust tasks, treat the inspected function signature as authoritative. If a compile error mentions a type mismatch, fix the signature to match the harness before changing the body.
        7. If inspect output names a Rust `Source symbol`, prefer that exact symbol over the Lean label when naming the function.

        Example sequence for a 2-function task where B depends on A:
          analyze_deps(B) → see order: A first, then B
          inspect(A) → read source + Lean spec
          run_tests(A, <code>) → iterate until passing
          submit(A)
          inspect(B) → read source + Lean spec
          run_tests(B, <code>) → iterate until passing
          submit(B)

        ## Actions (return ONLY JSON)
        - inspect: {{"type": "inspect", "function_name": "foo"}}
        - analyze_deps: {{"type": "analyze_deps", "function_name": "foo"}}
        - run_tests: {{"type": "run_tests", "function_name": "foo", "candidate_code": "..."}}
        - submit: {{"type": "submit", "function_name": "foo"}}
          Proof only: {{"type": "submit", "function_name": "foo", "lean_proof": "..."}}

        {extra}## Functions to migrate
        {functions}
        """
    )
    return template.format(
        source_language=task.source_language,
        target_language=task.target_language,
        display_name=task.display_name,
        # Notes are followed by a blank line; no notes means no empty gap.
        extra=f"{extra}\n\n" if extra else "",
        functions="".join(fn_sections),
    ).strip()
def _llm_action(
    client: Any,
    model: str,
    temperature: float,
    task: Task,
    obs: LeanMigrateObservation,
    history: list[dict],
) -> tuple[dict[str, Any], str, str]:
    """Ask the LLM for the next action using structured (ActionRequest) output.

    The user prompt contains:

    - the current state;
    - the full last observation;
    - a one-line summary of the previous step;
    - warnings derived from recent history: repeated rejected ``sorry``
      submissions, or repeated "cargo unavailable" feedback on Rust tasks.

    Returns:
        ``(action_dict, user_prompt, raw_model_output)``. If the API call or
        parsing fails, the error is printed and a placeholder ``run_tests``
        action for the first remaining function is returned, with
        ``raw_model_output`` set to ``"<error: ...>"``.
    """
    current_step = len(history) + 1

    # ── Dynamic warnings ──────────────────────────────────────────────────────
    warnings: list[str] = []
    recent_sorry = [
        h for h in history[-8:]
        if "sorry" in (h.get("feedback") or "").lower() and h.get("reward", 0) < 0
    ]
    if len(recent_sorry) >= 2:
        warnings.append(
            f"⚠ You have submitted `sorry` {len(recent_sorry)} times. "
            "The verifier ALWAYS rejects sorry (-0.05 each time, no exceptions). "
            "Write a complete proof: try `by decide`, `by native_decide`, `by simp [...]`, `by omega`."
        )
    if task.target_language == "rust":
        cargo_markers = ("cargo not found", "cannot find a runtime", "install cargo")
        recent_runtime_errors = sum(
            1 for h in history[-4:]
            if any(m in (h.get("feedback") or "").lower() for m in cargo_markers)
        )
        if recent_runtime_errors >= 2:
            warnings.append(
                "⚠ cargo is unavailable in this environment. "
                "Do NOT use run_tests — submit your implementation directly."
            )

    # ── Build prompt sections ─────────────────────────────────────────────────
    state_section = (
        "<STATE>\n"
        f"Step: {current_step} / {task.max_steps}\n"
        f"Progress: {obs.progress:.2f}\n"
        f"Verified: {', '.join(obs.verified) if obs.verified else 'None'}\n"
        f"Remaining: {', '.join(obs.remaining) if obs.remaining else 'None'}\n"
        f"Failing: {json.dumps(obs.failing)}\n"
        "</STATE>"
    )
    # Full last observation — no truncation; inspect output can be large.
    last_obs_section = ""
    if obs.last_action_feedback:
        last_obs_section = (
            f"<LAST_OBSERVATION>\n{obs.last_action_feedback}\n</LAST_OBSERVATION>"
        )
    # Only the single most recent step (not a history window).
    last_step_section = ""
    if history:
        h = history[-1]
        last_step_section = (
            "<LAST_STEP>\n"
            f"Step {h['step']}: type={h['action'].get('type')} "
            f"fn={h['action'].get('function_name')} "
            f"reward={h['reward']:.3f}\n"
            "</LAST_STEP>"
        )
    warnings_section = (
        ("<WARNINGS>\n" + "\n".join(warnings) + "\n</WARNINGS>") if warnings else ""
    )
    prompt = "\n\n".join(
        s for s in [
            state_section,
            last_obs_section,
            last_step_section,
            warnings_section,
            "Decide the next action. Return ONLY valid JSON.",
        ]
        if s
    )

    try:
        response = client.beta.chat.completions.parse(
            model=model,
            messages=[
                {"role": "system", "content": _system_prompt(task)},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
            response_format=ActionRequest,
        )
        message = response.choices[0].message
        raw_output = message.content or ""
        parsed = message.parsed
        if parsed is None:
            raise ValueError("No structured output returned")
        return parsed.model_dump(), prompt, raw_output
    except Exception as exc:
        # Boundary for the API call: report and fall back instead of aborting.
        # Build the message as Text so brackets in the exception message are
        # not parsed as Rich markup.
        console.print(Text.assemble(("LLM error:", "red"), f" {exc}"))
        return (
            {
                "type": "run_tests",
                "function_name": obs.remaining[0] if obs.remaining else "",
                "candidate_code": "# llm error",
            },
            prompt,
            f"<error: {exc}>",
        )
def _parse_action(ad: dict[str, Any]) -> LeanMigrateAction:
    """Convert an LLM action dict into the matching typed environment action.

    ``ActionRequest.model_dump()`` emits every field, so optional fields the
    model left out arrive as ``None`` rather than missing. ``dict.get(key, "")``
    would therefore still yield ``None``. The ``or ""`` fallbacks keep string
    fields from receiving ``None`` (presumably ``str``-typed in the action
    models — confirm in lean_migrate.env.models). Unknown types fall back to
    an empty ``run_tests`` action.
    """
    action_type = ad.get("type", "run_tests")
    fn = ad.get("function_name") or ""
    if action_type == "inspect":
        return InspectAction(type="inspect", function_name=fn)
    if action_type == "analyze_deps":
        return AnalyzeDepsAction(type="analyze_deps", function_name=fn)
    if action_type == "run_tests":
        return RunTestsAction(
            type="run_tests",
            function_name=fn,
            candidate_code=ad.get("candidate_code") or "",
        )
    if action_type == "submit":
        return SubmitAction(
            type="submit",
            function_name=fn,
            target_code=ad.get("target_code"),
            lean_proof=ad.get("lean_proof"),
        )
    return RunTestsAction(type="run_tests", function_name="", candidate_code="")
# ── Commands ──────────────────────────────────────────────────────────────────
@app.command("list")
def cmd_list() -> None:
    """List all available LeanMigrate tasks."""
    # Renders one table row per task: id, name, coloured difficulty, migration
    # direction, functions (proof functions dimmed in parentheses), max steps.
    tasks_info = list_tasks()
    from rich import box as rich_box

    table = Table(
        title="[bold]LeanMigrate Tasks[/bold]",
        show_header=True,
        header_style="bold cyan",
        box=rich_box.SIMPLE_HEAD,
        padding=(0, 1),
        expand=True,
    )
    table.add_column("task_id", style="bold yellow", no_wrap=True)
    table.add_column("display_name")
    table.add_column("diff", no_wrap=True)
    table.add_column("migration", no_wrap=True)
    table.add_column("functions")
    table.add_column("steps", justify="right", no_wrap=True)
    for info in tasks_info:
        task = get_task(info["task_id"])
        diff = info.get("difficulty", "?")
        color = _difficulty_color(diff)
        non_proof = [f.name for f in task.functions if not f.is_proof_required]
        proof = [f.name for f in task.functions if f.is_proof_required]
        fn_str = ", ".join(non_proof)
        if proof:
            fn_str += f" [dim](+{', '.join(proof)})[/dim]"
        table.add_row(
            info["task_id"],
            task.display_name,
            f"[{color}]{diff}[/{color}]",
            # Real U+2192 arrow (was mis-encoded as "β†’").
            f"{task.source_language} → {task.target_language}",
            fn_str,
            str(task.max_steps),
        )
    console.print()
    console.print(table)
    console.print()
    console.print(
        "[dim]Play a task:[/dim] uv run python scripts/play_task.py play <task_id>"
    )
    console.print(
        "[dim]Play all: [/dim] uv run python scripts/play_task.py play --all"
    )
    console.print()
@app.command("play")
def cmd_play(
    task_ids: Annotated[
        Optional[list[str]],
        typer.Argument(help="Task IDs to play (omit to use --all)."),
    ] = None,
    all_tasks: Annotated[bool, typer.Option("--all", help="Play every task.")] = False,
    model: Annotated[
        Optional[str],
        typer.Option(
            "--model",
            "-m",
            envvar="MODEL_NAME",
            help="LLM model name. [env: MODEL_NAME]",
        ),
    ] = None,
    api_base_url: Annotated[
        Optional[str],
        typer.Option(
            "--api-base-url",
            envvar="API_BASE_URL",
            help="OpenAI-compatible API base URL. [env: API_BASE_URL]",
        ),
    ] = None,
    api_key: Annotated[
        Optional[str],
        typer.Option(
            "--api-key",
            envvar="HF_TOKEN",
            help="API key. [env: HF_TOKEN or GEMINI_API_KEY]",
        ),
    ] = None,
    max_steps: Annotated[
        int,
        typer.Option("--max-steps", envvar="MAX_STEPS", help="Max LLM steps per task."),
    ] = 50,
    temperature: Annotated[
        float,
        typer.Option(
            "--temperature", envvar="TEMPERATURE", help="Sampling temperature."
        ),
    ] = 0.2,
    quiet: Annotated[
        bool,
        typer.Option("--quiet", "-q", help="Suppress per-step rich output."),
    ] = False,
    trace_dir: Annotated[
        Optional[Path],
        typer.Option(
            "--trace-dir",
            help="Directory to write per-episode JSON traces. Defaults to <project>/traces/.",
        ),
    ] = None,
) -> None:
    """Play LeanMigrate tasks with an LLM, showing rich step-by-step output."""
    # Validates the task selection, resolves LLM settings, plays each task in
    # turn and prints a score table (plus an overall average for >1 task).
    # Exits with code 1 on unknown task IDs, no selection, or no API key.
    from openai import OpenAI

    all_ids = [t["task_id"] for t in list_tasks()]
    if all_tasks:
        selected = all_ids
    elif task_ids:
        bad = [t for t in task_ids if t not in all_ids]
        if bad:
            console.print(f"[red]Unknown task(s):[/red] {', '.join(bad)}")
            console.print(f"Available: {', '.join(all_ids)}")
            raise typer.Exit(code=1)
        selected = list(task_ids)
    else:
        console.print("[red]Specify task IDs or --all.[/red]")
        console.print(f"Available: {', '.join(all_ids)}")
        raise typer.Exit(code=1)

    # Resolve LLM config: CLI flag → .env / os.environ → hardcoded fallback
    resolved_model = model or os.getenv("MODEL_NAME", "gemini-2.5-flash-preview-04-17")
    resolved_url = api_base_url or os.getenv(
        "API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/"
    )
    resolved_key = api_key or os.getenv("HF_TOKEN") or os.getenv("GEMINI_API_KEY")
    if not resolved_key:
        console.print("[red]Error:[/red] API key required.")
        console.print(
            "Set [bold]HF_TOKEN[/bold] or [bold]GEMINI_API_KEY[/bold] in your "
            ".env file or pass [bold]--api-key[/bold]."
        )
        raise typer.Exit(code=1)

    console.print(
        Panel(
            Text.assemble(
                ("model : ", "bold cyan"),
                resolved_model,
                "\n",
                ("base_url : ", "bold cyan"),
                resolved_url,
                "\n",
                ("tasks : ", "bold cyan"),
                ", ".join(selected),
            ),
            title="[bold white]LLM CONFIG[/]",
            border_style="white",
        )
    )

    resolved_trace_dir = trace_dir or (ROOT / "traces")
    client = OpenAI(base_url=resolved_url, api_key=resolved_key)
    scores: dict[str, float] = {}
    for task_id in selected:
        scores[task_id] = _play_one_task(
            client=client,
            task_id=task_id,
            model=resolved_model,
            api_base_url=resolved_url,
            max_steps=max_steps,
            temperature=temperature,
            quiet=quiet,
            trace_dir=resolved_trace_dir,
        )

    # ── Final summary ──────────────────────────────────────────────────────────
    console.print(Rule("[bold white]PLAYTHROUGH SUMMARY[/]"))
    summary_table = Table(
        show_header=True, header_style="bold cyan", box=None, padding=(0, 2)
    )
    summary_table.add_column("task_id", style="bold yellow")
    summary_table.add_column("score", justify="right")
    summary_table.add_column("result")
    total = 0.0
    for tid, score in scores.items():
        color = "green" if score >= 0.99 else ("yellow" if score > 0 else "red")
        # Real U+2713 check mark (was mis-encoded as "βœ“").
        result = (
            "✓ complete" if score >= 0.99 else ("partial" if score > 0 else "failed")
        )
        summary_table.add_row(tid, f"{score:.3f}", f"[{color}]{result}[/{color}]")
        total += score
    if len(scores) > 1:
        avg = total / len(scores)
        summary_table.add_section()
        summary_table.add_row("[bold]overall[/bold]", f"[bold]{avg:.3f}[/bold]", "")
    console.print(summary_table)
def _play_one_task(
    client: Any,
    task_id: str,
    model: str,
    api_base_url: str,
    max_steps: int,
    temperature: float,
    quiet: bool,
    trace_dir: Path,
) -> float:
    """Run one LLM-driven episode of *task_id* and return its final score.

    Each step does four things:

    - asks the LLM for an action;
    - executes it in a fresh ``LeanMigrateEnvironment``;
    - records a trace entry, rewriting the partial trace file after every step;
    - prints a rich step panel, unless *quiet*.

    The loop ends when the environment reports ``done`` or after *max_steps*
    steps. The full trace with summary analytics then overwrites the partial
    file, a result panel is shown, and an ``[END]`` line is printed to stdout.

    Code tested with ``run_tests`` is cached per function and attached as
    ``target_code`` on ``submit``. Placeholder actions from a failed LLM call
    (raw output ``"<error: ..."``) and empty candidates are not cached, so they
    cannot replace the last real implementation.

    Returns:
        ``clamp_open_unit(final progress)``.

    Raises:
        KeyboardInterrupt: re-raised after reporting the partial trace location.
    """
    task = get_task(task_id)
    env = LeanMigrateEnvironment()
    obs = env.reset(task_id=task_id)

    # ── Task header ────────────────────────────────────────────────────────────
    diff = task.difficulty
    color = _difficulty_color(diff)
    header = Text()
    header.append("task : ", style="bold cyan")
    header.append(task_id + "\n")
    header.append("name : ", style="bold cyan")
    header.append(task.display_name + "\n")
    header.append("difficulty: ", style="bold cyan")
    header.append(diff, style=f"bold {color}")
    header.append("\n")
    header.append("migration : ", style="bold cyan")
    header.append(f"{task.source_language} → {task.target_language}\n")
    header.append("episode : ", style="bold cyan")
    header.append(obs.episode_id + "\n")
    header.append("model : ", style="bold cyan")
    header.append(model + "\n")
    header.append("max_steps: ", style="bold cyan")
    header.append(str(max_steps))
    console.print()
    console.print(Rule(f"[bold white]{task.display_name}[/] [dim]({task_id})[/dim]"))
    console.print(_text_panel("TASK", header, "white"))

    rewards: list[float] = []
    history: list[dict] = []
    trace_steps: list[dict[str, Any]] = []
    code_cache: dict[str, str] = {}
    steps_taken = 0
    episode_system_prompt = _system_prompt(task)

    # Pre-determine the trace path so partial writes and the final write share a file.
    trace_dir.mkdir(parents=True, exist_ok=True)
    episode_start_ts = datetime.datetime.now(datetime.timezone.utc)
    partial_trace_path = (
        trace_dir
        / f"{task_id}__{episode_start_ts.strftime('%Y%m%dT%H%M%SZ')}__{obs.episode_id[:8]}.json"
    )

    for step in range(1, max_steps + 1):
        if obs.done:
            break
        obs_before = obs
        try:
            action_dict, llm_user_prompt, llm_raw_output = _llm_action(
                client, model, temperature, task, obs, history
            )

            # Maintain the code cache: record tested code, inject it on submit.
            fn_name = action_dict.get("function_name", "")
            candidate = action_dict.get("candidate_code")
            llm_failed = llm_raw_output.startswith("<error:")
            if action_dict.get("type") == "run_tests":
                if candidate and not llm_failed:
                    code_cache[fn_name] = candidate
            elif action_dict.get("type") == "submit":
                if fn_name in code_cache:
                    action_dict["target_code"] = code_cache[fn_name]
                elif candidate:
                    # Agent skipped run_tests (e.g. Rust without cargo) and attached code directly.
                    action_dict["target_code"] = candidate
                    code_cache[fn_name] = candidate

            action = _parse_action(action_dict)
            obs = env.step(action)
            raw_reward = float(obs.reward or 0.0)
            rewards.append(raw_reward)  # raw for accurate analysis
            steps_taken = step

            feedback = obs.last_action_feedback or ""
            rd = obs.reward_details
            lean_err = (rd.lean_error if rd else None) or ""
            if lean_err:
                feedback = (feedback + " | Lean: " + lean_err).strip()
            history.append(
                {
                    "step": step,
                    "action": action_dict,
                    "reward": raw_reward,  # raw so the LLM sees genuine negative signal
                    "feedback": feedback,
                }
            )

            # ── Trace record ───────────────────────────────────────────────────
            trace_steps.append(
                {
                    "step": step,
                    "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                    "state_before": _obs_snapshot(obs_before),
                    "action": {
                        "type": action_dict.get("type"),
                        "function_name": action_dict.get("function_name"),
                        "candidate_code": action_dict.get("candidate_code"),
                        "lean_proof": action_dict.get("lean_proof"),
                        "target_code": action_dict.get("target_code"),
                    },
                    "llm": {
                        "user_prompt": llm_user_prompt,
                        "raw_output": llm_raw_output,
                    },
                    "reward": round(raw_reward, 6),
                    "reward_details": _reward_snapshot(obs),
                    "state_after": _obs_snapshot(obs),
                    "feedback": feedback,
                    "analysis": _step_analysis(obs_before, obs, action_dict, raw_reward),
                }
            )
            # Persist a partial trace after every step so an interruption loses nothing.
            _write_partial_trace(
                out_path=partial_trace_path,
                task_id=task_id,
                episode_id=obs.episode_id,
                model=model,
                api_base_url=api_base_url,
                temperature=temperature,
                task=task,
                steps=trace_steps,
                rewards=rewards,
                system_prompt=episode_system_prompt,
            )

            if not quiet:
                _print_step(step, obs_before, action_dict, obs, task.target_language)
            if obs.done:
                break
        except KeyboardInterrupt:
            if trace_steps:
                console.print(
                    f"\n[yellow]Interrupted at step {step}. "
                    f"Partial trace saved → {partial_trace_path}[/yellow]"
                )
            else:
                console.print(
                    f"\n[yellow]Interrupted at step {step} before any step "
                    "completed; no trace written.[/yellow]"
                )
            raise

    # ── Task footer ────────────────────────────────────────────────────────────
    final_score = clamp_open_unit(float(obs.progress))
    success = final_score >= 0.99
    color = "green" if success else ("yellow" if final_score > 0 else "red")
    result_label = (
        "COMPLETE" if success else ("PARTIAL" if final_score > 0 else "FAILED")
    )
    footer = Text()
    footer.append("result : ", style="bold cyan")
    footer.append(result_label, style=f"bold {color}")
    footer.append("\n")
    footer.append("score : ", style="bold cyan")
    footer.append(f"{final_score:.3f}\n", style=color)
    footer.append("steps : ", style="bold cyan")
    footer.append(f"{steps_taken}\n")
    footer.append("verified : ", style="bold cyan")
    footer.append(", ".join(obs.verified) if obs.verified else "(none)")

    # ── Write trace (overwrites partial with full analytics) ───────────────────
    trace_path = _write_trace(
        trace_dir=trace_dir,
        task_id=task_id,
        episode_id=obs.episode_id,
        model=model,
        api_base_url=api_base_url,
        temperature=temperature,
        task=task,
        steps=trace_steps,
        rewards=rewards,
        final_score=final_score,
        out_path=partial_trace_path,
        system_prompt=episode_system_prompt,
    )
    footer.append("\ntrace : ", style="bold cyan")
    footer.append(str(trace_path), style="dim")

    console.print(Rule())
    console.print(_text_panel(f"RESULT {task_id}", footer, color))
    print(
        f"[END] task={task_id} success={str(success).lower()} "
        f"steps={steps_taken} score={final_score:.2f} "
        f"rewards={','.join(f'{clamp_open_unit(r):.2f}' for r in rewards)} "
        f"trace={trace_path}",
        flush=True,
    )
    return final_score
if __name__ == "__main__":
    # Script entry point: hand control to the Typer app (`list` / `play`).
    app()