"""play_task — LLM-driven CLI for LeanMigrate task playthrough. Usage: uv run python scripts/play_task.py list uv run python scripts/play_task.py play lru_cache expression_eval uv run python scripts/play_task.py play --all """ from __future__ import annotations import datetime import json import os import re import sys import textwrap from pathlib import Path from typing import Annotated, Any, Literal, Optional from dotenv import load_dotenv import typer from rich.columns import Columns from rich.console import Console from rich.panel import Panel from rich.rule import Rule from rich.syntax import Syntax from rich.table import Table from rich.text import Text # Load .env before anything reads os.environ (including typer envvar resolution) load_dotenv(Path(__file__).resolve().parents[1] / ".env") ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from lean_migrate.env.grader import clamp_open_unit, oracle_result # noqa: E402 from lean_migrate.env.models import ( # noqa: E402 AnalyzeDepsAction, InspectAction, LeanMigrateAction, LeanMigrateObservation, RunTestsAction, SubmitAction, ) from lean_migrate.env.tasks import FunctionSpec, Task, get_task, list_tasks # noqa: E402 from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment # noqa: E402 # ── Pydantic schema for structured LLM output ──────────────────────────────── from pydantic import BaseModel, Field # noqa: E402 class ActionRequest(BaseModel): """Structured output schema for LLM action responses.""" type: Literal["inspect", "analyze_deps", "run_tests", "submit"] function_name: str = Field(description="Name of the function this action targets.") candidate_code: str | None = Field( default=None, description="Full implementation for run_tests. Omit for other action types.", ) lean_proof: str | None = Field( default=None, description="Lean proof text. Only required for proof tasks.", ) # ── App + console ───────────────────────────────────────────────────────────── app = typer.Typer( name="play_task", help="LLM-driven LeanMigrate task runner with rich step display.", add_completion=False, ) console = Console() # ── Rich helpers ────────────────────────────────────────────────────────────── def _difficulty_color(d: str) -> str: return {"easy": "green", "medium": "yellow", "hard": "red"}.get(d, "white") def _trim(text: str, max_lines: int = 20) -> str: lines = text.strip().splitlines() if len(lines) <= max_lines: return "\n".join(lines) return ( "\n".join(lines[:max_lines]) + f"\n[dim]… ({len(lines) - max_lines} more lines)[/dim]" ) def _split_at_semicolons(s: str) -> list[str]: """Depth-aware split at ';' — won't split inside brackets.""" parts: list[str] = [] depth = 0 buf: list[str] = [] for ch in s: if ch in "([{": depth += 1 buf.append(ch) elif ch in ")]}": depth -= 1 buf.append(ch) elif ch == ";" and depth == 0: part = "".join(buf).strip() if part: parts.append(part) buf = [] else: buf.append(ch) last = "".join(buf).strip() if last: parts.append(last) return parts # Lean tactic keywords that start a new tactic at the top level of a `by` block. _LEAN_TACTIC_KWS = re.compile( r"(? str: """Best-effort Lean 4 single-line → multi-line formatter.""" # Already multi-line — just normalise indentation slightly. if "\n" in code.strip(): lines = [ln.rstrip() for ln in code.strip().splitlines()] return "\n".join(lines) s = code.strip() # Match: := by m = re.match(r"^(.*?)\s*:=\s*by\s+(.+)$", s, re.DOTALL) if m: head, tactics_str = m.group(1).strip(), m.group(2).strip() # First split at semicolons tactics = _split_at_semicolons(tactics_str) # Then split each part further at tactic keyword boundaries (depth 0 only) expanded: list[str] = [] for part in tactics: # Find positions of tactic-keyword starts (skip inside brackets) splits: list[int] = [0] depth = 0 i = 0 while i < len(part): if part[i] in "([{": depth += 1 elif part[i] in ")]}": depth -= 1 elif depth == 0 and i > 0 and part[i - 1] == " ": hit = _LEAN_TACTIC_KWS.match(part, i) if hit: splits.append(i) i += 1 for j, start in enumerate(splits): end = splits[j + 1] if j + 1 < len(splits) else len(part) chunk = part[start:end].strip() if chunk: expanded.append(chunk) if len(expanded) == 1: return f"{head} := by\n {expanded[0]}" return head + " := by\n" + "\n".join(f" {t}" for t in expanded) # Match: := (no `by`) m2 = re.match(r"^(.*?)\s*:=\s*(.+)$", s, re.DOTALL) if m2: head, body = m2.group(1).strip(), m2.group(2).strip() return f"{head} :=\n {body}" return s def _format_code(code: str, language: str) -> str: """Run a language-specific formatter. Returns original on any failure.""" import subprocess try: if language == "lean": return _format_lean(code) if language == "typescript": result = subprocess.run( ["npx", "--yes", "prettier", "--stdin-filepath", "candidate.ts"], input=code, capture_output=True, text=True, timeout=15, ) elif language == "python": result = subprocess.run( ["ruff", "format", "--stdin-filename", "candidate.py", "-"], input=code, capture_output=True, text=True, timeout=10, ) elif language == "rust": result = subprocess.run( ["rustfmt", "--edition", "2021"], input=code, capture_output=True, text=True, timeout=10, ) else: # Unknown language — return as-is (word_wrap still applies in panel) return code return result.stdout if result.returncode == 0 and result.stdout.strip() else code except Exception: return code def _code_panel(title: str, code: str, language: str, border: str = "blue") -> Panel: formatted = _format_code(code, language) try: body = Syntax(formatted, language, theme="monokai", line_numbers=True, word_wrap=True) except Exception: body = Text(formatted, no_wrap=False) return Panel(body, title=f"[bold {border}]{title}[/]", border_style=border) def _text_panel(title: str, body: Text | str, border: str = "cyan") -> Panel: if isinstance(body, str): body = Text(body) return Panel(body, title=f"[bold {border}]{title}[/]", border_style=border) def _state_text(obs: LeanMigrateObservation) -> Text: t = Text() t.append("progress : ", style="bold cyan") t.append(f"{clamp_open_unit(float(obs.progress)):.3f}\n") t.append("verified : ", style="bold cyan") t.append(", ".join(obs.verified) if obs.verified else "(none)") t.append("\n") t.append("remaining: ", style="bold cyan") t.append(", ".join(obs.remaining) if obs.remaining else "(none)") t.append("\n") t.append("failing : ", style="bold cyan") t.append(json.dumps(obs.failing) if obs.failing else "(none)") t.append("\n") t.append("done : ", style="bold cyan") t.append(str(obs.done).lower()) return t def _action_text(ad: dict[str, Any]) -> Text: t = Text() t.append("type : ", style="bold magenta") t.append(ad.get("type", "?")) t.append("\n") t.append("function : ", style="bold magenta") t.append(ad.get("function_name", "?")) return t def _obs_text(obs: LeanMigrateObservation) -> Text: t = Text() reward_val = clamp_open_unit(float(obs.reward or 0.0)) color = "green" if reward_val > 0 else ("red" if reward_val < 0 else "yellow") t.append("reward : ", style=f"bold {color}") t.append(f"{reward_val:.3f}\n", style=color) t.append("done : ", style="bold cyan") t.append(str(obs.done).lower()) return t # ── Trace helpers ───────────────────────────────────────────────────────────── def _obs_snapshot(obs: LeanMigrateObservation) -> dict[str, Any]: """Serialise the parts of an observation useful for offline analysis.""" return { "progress": float(obs.progress), "verified": list(obs.verified), "remaining": list(obs.remaining), "failing": dict(obs.failing), "done": bool(obs.done), "episode_step": int(obs.episode_step), } def _reward_snapshot(obs: LeanMigrateObservation) -> dict[str, Any] | None: rd = obs.reward_details if rd is None: return None return { "score": float(rd.score), "cumulative_score": float(rd.cumulative_score), "tests_passed": rd.tests_passed, "tests_total": rd.tests_total, "proof_compiled": rd.proof_compiled, "breakdown": dict(rd.breakdown), "feedback": rd.feedback, "lean_error": rd.lean_error, } def _step_analysis( obs_before: LeanMigrateObservation, obs_after: LeanMigrateObservation, action_dict: dict[str, Any], reward: float, ) -> dict[str, Any]: """Derived signals useful for reward-shaping / behaviour analysis.""" progress_delta = float(obs_after.progress) - float(obs_before.progress) rd = obs_after.reward_details tests_pass_rate: float | None = None if rd and rd.tests_total: tests_pass_rate = (rd.tests_passed or 0) / rd.tests_total return { "progress_delta": round(progress_delta, 6), "tests_pass_rate": tests_pass_rate, "lean_error_present": bool(rd and rd.lean_error), "proof_compiled": rd.proof_compiled if rd else None, "new_function_verified": progress_delta > 0, "reward_sign": "positive" if reward > 0 else ("negative" if reward < 0 else "zero"), "action_type": action_dict.get("type"), "used_inspection": action_dict.get("type") in {"inspect", "analyze_deps"}, } def _write_partial_trace( out_path: Path, task_id: str, episode_id: str, model: str, api_base_url: str, temperature: float, task: Any, steps: list[dict[str, Any]], rewards: list[float], system_prompt: str = "", ) -> None: """Write an in-progress snapshot so work is preserved on interruption.""" trace = { "meta": { "task_id": task_id, "episode_id": episode_id, "model": model, "api_base_url": api_base_url, "temperature": temperature, "difficulty": task.difficulty, "source_language": task.source_language, "target_language": task.target_language, "functions": [f.name for f in task.functions], "proof_functions": [f.name for f in task.functions if f.is_proof_required], "max_steps": task.max_steps, "system_prompt": system_prompt, "partial": True, }, "steps": steps, "summary": None, } out_path.write_text(json.dumps(trace, indent=2, ensure_ascii=False)) def _write_trace( trace_dir: Path, task_id: str, episode_id: str, model: str, api_base_url: str, temperature: float, task: Any, steps: list[dict[str, Any]], rewards: list[float], final_score: float, out_path: Path | None = None, system_prompt: str = "", ) -> Path: """Write the full episode trace to a JSON file and return its path.""" trace_dir.mkdir(parents=True, exist_ok=True) ts = datetime.datetime.now(datetime.timezone.utc) ts_str = ts.strftime("%Y%m%dT%H%M%SZ") # ── Summary analytics ────────────────────────────────────────────────────── action_counts: dict[str, int] = {} lean_errors = 0 negative_rewards = 0 max_consecutive_run_tests = 0 current_run_tests_streak = 0 total_code_chars = 0 total_test_runs = 0 total_tests_passed = 0 total_tests_total = 0 for s in steps: atype = s["action"]["type"] action_counts[atype] = action_counts.get(atype, 0) + 1 if s["analysis"]["lean_error_present"]: lean_errors += 1 if s["reward"] < 0: negative_rewards += 1 if atype == "run_tests": current_run_tests_streak += 1 max_consecutive_run_tests = max( max_consecutive_run_tests, current_run_tests_streak ) code = s["action"].get("candidate_code") or "" total_code_chars += len(code) total_test_runs += 1 rd = s["reward_details"] or {} total_tests_passed += rd.get("tests_passed") or 0 total_tests_total += rd.get("tests_total") or 0 else: current_run_tests_streak = 0 n_types = len(action_counts) total_actions = len(steps) # Shannon entropy of action types as a diversity score import math diversity = 0.0 if total_actions > 0 and n_types > 1: for cnt in action_counts.values(): p = cnt / total_actions if p > 0: diversity -= p * math.log2(p) diversity /= math.log2(len({"inspect", "analyze_deps", "run_tests", "submit"})) final_obs = steps[-1]["state_after"] if steps else {} summary = { "total_steps": len(steps), "final_score": round(final_score, 6), "success": final_score >= 0.99, "rewards": [round(r, 4) for r in rewards], "reward_sum": round(sum(rewards), 4), "action_type_counts": action_counts, "action_diversity_score": round(diversity, 4), "functions_verified": final_obs.get("verified", []), "functions_remaining": final_obs.get("remaining", []), "functions_failing": list((final_obs.get("failing") or {}).keys()), "max_consecutive_run_tests": max_consecutive_run_tests, "lean_errors_count": lean_errors, "negative_rewards_count": negative_rewards, "avg_tests_pass_rate": ( round(total_tests_passed / total_tests_total, 4) if total_tests_total > 0 else None ), "never_used_inspect": action_counts.get("inspect", 0) == 0, "never_used_analyze_deps": action_counts.get("analyze_deps", 0) == 0, } trace = { "meta": { "task_id": task_id, "episode_id": episode_id, "model": model, "api_base_url": api_base_url, "temperature": temperature, "timestamp": ts.isoformat(), "difficulty": task.difficulty, "source_language": task.source_language, "target_language": task.target_language, "functions": [f.name for f in task.functions], "proof_functions": [f.name for f in task.functions if f.is_proof_required], "max_steps": task.max_steps, "system_prompt": system_prompt, }, "steps": steps, "summary": summary, } if out_path is None: filename = f"{task_id}__{ts_str}__{episode_id[:8]}.json" out_path = trace_dir / filename out_path.write_text(json.dumps(trace, indent=2, ensure_ascii=False)) return out_path def _print_step( step: int, obs_before: LeanMigrateObservation, action_dict: dict[str, Any], obs_after: LeanMigrateObservation, target_language: str, ) -> None: action_type = action_dict.get("type", "?") fn_name = action_dict.get("function_name", "?") console.print( Rule(f"[bold yellow]STEP {step}[/] [white]{action_type}({fn_name})[/]") ) # ── State (before) + Action summary side-by-side console.print( Columns( [ _text_panel("STATE (before)", _state_text(obs_before), "cyan"), _text_panel("ACTION", _action_text(action_dict), "magenta"), ], expand=True, ) ) # ── Code block (if run_tests) code = action_dict.get("candidate_code") if code: lang = ( "rust" if target_language == "rust" else ("typescript" if target_language == "typescript" else "python") ) console.print(_code_panel(f"CANDIDATE CODE [{lang}]", code, lang, "blue")) # ── Lean proof (if submit with proof) proof = action_dict.get("lean_proof") if proof: console.print(_code_panel("LEAN PROOF", proof, "lean", "green")) # ── Observation (after) rd = obs_after.reward_details feedback = obs_after.last_action_feedback or "" lean_err = (rd.lean_error if rd else None) or "" reward_val = clamp_open_unit(float(obs_after.reward or 0.0)) border = "green" if reward_val > 0 else ("red" if reward_val < 0 else "yellow") obs_body = Text() obs_body.append_text(_obs_text(obs_after)) if feedback: obs_body.append("\n\nfeedback :\n", style="bold cyan") obs_body.append(_trim(feedback, 15)) if lean_err: obs_body.append("\n\nlean error:\n", style="bold red") obs_body.append(_trim(lean_err, 20)) console.print(_text_panel("OBSERVATION", obs_body, border)) # ── LLM helpers (from inference.py) ────────────────────────────────────────── def _get_simple_sample(task: Task, fn: FunctionSpec) -> str: cases = task.sample_inputs.get(fn.name, []) if not cases or fn.is_proof_required: return "" case = cases[0] expected = oracle_result(task.task_id, fn.name, case.args) args_repr = ", ".join(repr(a) for a in case.args) if len(args_repr) > 300: args_repr = args_repr[:297] + "..." return f"Sample: ({args_repr}) → {expected!r}" def _get_language_notes(task: Task) -> str: if task.target_language == "typescript": return textwrap.dedent(""" ## TypeScript notes - Use top-level `function` declarations (not arrow functions / const assignments). - Dependency code is injected automatically — write only the current function. """).strip() if task.target_language == "rust": return textwrap.dedent(""" ## Rust notes - Use serde_json and serde (already in Cargo.toml). - Write only the current function. Dependencies are injected as separate pub fn definitions. - Match the inspected Rust signature exactly, including ownership and mutability. The harness passes owned JSON-decoded values into your function, so do not add `&`, `&mut`, or slice parameters unless the inspected source explicitly shows them. - If the inspect output shows a `Source symbol`, use that exact Rust function name in your code. - Return types serializable to serde_json::Value (Vec, Option, tuples all work). - run_tests compiles via cargo. If run_tests returns "cargo not found", cargo is unavailable — submit your implementation directly without run_tests. """).strip() return "" def _get_proof_notes(task: Task) -> str: if any(fn.is_proof_required for fn in task.functions): return textwrap.dedent(""" ## Proof task notes - NEVER use `sorry` — the verifier rejects it IMMEDIATELY, gives -0.05 reward, and does not compile Lean at all. `sorry` has zero partial credit. - If you cannot write the full proof yet, DO NOT submit. Call `inspect` on the proof function to read the exact theorem and full Lean spec, then reason step by step. - Useful first-try tactics: `by decide`, `by native_decide`, `by simp [...]`, `by omega`, `by rfl`, `by constructor <;> simp`. - Use qualified ADT constructors from the spec (e.g. `Spec.MyType.mk`). """).strip() return "" def _system_prompt(task: Task) -> str: fn_sections = [] for fn in task.functions: proof_tag = " [PROOF REQUIRED — submit with lean_proof field]" if fn.is_proof_required else "" fn_sections.append( f"### {fn.name}{proof_tag}\n" f"Description: {fn.description}\n" f"Depends on: {fn.depends_on or ['(none)']}\n" ) language_notes = _get_language_notes(task) proof_notes = _get_proof_notes(task) extra = "\n\n".join(s for s in [language_notes, proof_notes] if s) return textwrap.dedent( f""" You are an expert software engineer migrating {task.source_language} code to {task.target_language} with Lean 4 verification. Task: {task.display_name} ## Workflow 1. Call analyze_deps on any function to see the full dependency graph and migration order. 2. Call inspect on each function to read its source code and Lean spec before implementing. 3. Call run_tests with your implementation and iterate until all cases pass. 4. Call submit when tests pass — the system uses the last run_tests code automatically. 5. Write only the current function's code; verified dependencies are injected automatically. 6. For Rust tasks, treat the inspected function signature as authoritative. If a compile error mentions a type mismatch, fix the signature to match the harness before changing the body. 7. If inspect output names a Rust `Source symbol`, prefer that exact symbol over the Lean label when naming the function. Example sequence for a 2-function task where B depends on A: analyze_deps(B) → see order: A first, then B inspect(A) → read source + Lean spec run_tests(A, ) → iterate until passing submit(A) inspect(B) → read source + Lean spec run_tests(B, ) → iterate until passing submit(B) ## Actions (return ONLY JSON) - inspect: {{"type": "inspect", "function_name": "foo"}} - analyze_deps: {{"type": "analyze_deps", "function_name": "foo"}} - run_tests: {{"type": "run_tests", "function_name": "foo", "candidate_code": "..."}} - submit: {{"type": "submit", "function_name": "foo"}} Proof only: {{"type": "submit", "function_name": "foo", "lean_proof": "..."}} {extra} ## Functions to migrate {"".join(fn_sections)} """ ).strip() def _llm_action( client: Any, model: str, temperature: float, task: Task, obs: LeanMigrateObservation, history: list[dict], ) -> tuple[dict[str, Any], str, str]: """Return (action_dict, user_prompt, raw_model_output).""" current_step = len(history) + 1 # ── Dynamic warnings ────────────────────────────────────────────────────── warnings: list[str] = [] recent_sorry = [ h for h in history[-8:] if "sorry" in (h.get("feedback") or "").lower() and h.get("reward", 0) < 0 ] if len(recent_sorry) >= 2: warnings.append( f"⚠ You have submitted `sorry` {len(recent_sorry)} times. " "The verifier ALWAYS rejects sorry (-0.05 each time, no exceptions). " "Write a complete proof: try `by decide`, `by native_decide`, `by simp [...]`, `by omega`." ) if task.target_language == "rust": recent_runtime_errors = sum( 1 for h in history[-4:] if any(k in (h.get("feedback") or "").lower() for k in ("cargo not found", "cannot find a runtime", "install cargo")) ) if recent_runtime_errors >= 2: warnings.append( "⚠ cargo is unavailable in this environment. " "Do NOT use run_tests — submit your implementation directly." ) # ── Build prompt sections ───────────────────────────────────────────────── state_section = ( f"\n" f"Step: {current_step} / {task.max_steps}\n" f"Progress: {obs.progress:.2f}\n" f"Verified: {', '.join(obs.verified) if obs.verified else 'None'}\n" f"Remaining: {', '.join(obs.remaining)}\n" f"Failing: {json.dumps(obs.failing)}\n" f"" ) # Full last observation — no truncation; inspect output can be large last_obs_section = "" if obs.last_action_feedback: last_obs_section = ( f"\n" f"{obs.last_action_feedback}\n" f"" ) # Only the single most recent step (not a history window) last_step_section = "" if history: h = history[-1] last_step_section = ( f"\n" f"Step {h['step']}: type={h['action'].get('type')} " f"fn={h['action'].get('function_name')} " f"reward={h['reward']:.3f}\n" f"" ) warnings_section = ( ("\n" + "\n".join(warnings) + "\n") if warnings else "" ) prompt = "\n\n".join( s for s in [ state_section, last_obs_section, last_step_section, warnings_section, "Decide the next action. Return ONLY valid JSON.", ] if s ) try: response = client.beta.chat.completions.parse( model=model, messages=[ {"role": "system", "content": _system_prompt(task)}, {"role": "user", "content": prompt}, ], temperature=temperature, response_format=ActionRequest, ) raw_output = response.choices[0].message.content or "" parsed = response.choices[0].message.parsed if parsed is None: raise ValueError("No structured output returned") return parsed.model_dump(), prompt, raw_output except Exception as exc: console.print(f"[red]LLM error:[/red] {exc}") return ( { "type": "run_tests", "function_name": obs.remaining[0] if obs.remaining else "", "candidate_code": "# llm error", }, prompt, f"", ) def _parse_action(ad: dict[str, Any]) -> LeanMigrateAction: t = ad.get("type", "run_tests") fn = ad.get("function_name", "") if t == "inspect": return InspectAction(type="inspect", function_name=fn) if t == "analyze_deps": return AnalyzeDepsAction(type="analyze_deps", function_name=fn) if t == "run_tests": return RunTestsAction( type="run_tests", function_name=fn, candidate_code=ad.get("candidate_code", ""), ) if t == "submit": return SubmitAction( type="submit", function_name=fn, target_code=ad.get("target_code"), lean_proof=ad.get("lean_proof"), ) return RunTestsAction(type="run_tests", function_name="", candidate_code="") # ── Commands ────────────────────────────────────────────────────────────────── @app.command("list") def cmd_list() -> None: """List all available LeanMigrate tasks.""" tasks_info = list_tasks() from rich import box as rich_box table = Table( title="[bold]LeanMigrate Tasks[/bold]", show_header=True, header_style="bold cyan", box=rich_box.SIMPLE_HEAD, padding=(0, 1), expand=True, ) table.add_column("task_id", style="bold yellow", no_wrap=True) table.add_column("display_name") table.add_column("diff", no_wrap=True) table.add_column("migration", no_wrap=True) table.add_column("functions") table.add_column("steps", justify="right", no_wrap=True) for info in tasks_info: task = get_task(info["task_id"]) diff = info.get("difficulty", "?") color = _difficulty_color(diff) non_proof = [f.name for f in task.functions if not f.is_proof_required] proof = [f.name for f in task.functions if f.is_proof_required] fn_str = ", ".join(non_proof) if proof: fn_str += f" [dim](+{', '.join(proof)})[/dim]" table.add_row( info["task_id"], task.display_name, f"[{color}]{diff}[/{color}]", f"{task.source_language} → {task.target_language}", fn_str, str(task.max_steps), ) console.print() console.print(table) console.print() console.print( "[dim]Play a task:[/dim] uv run python scripts/play_task.py play " ) console.print( "[dim]Play all: [/dim] uv run python scripts/play_task.py play --all" ) console.print() @app.command("play") def cmd_play( task_ids: Annotated[ Optional[list[str]], typer.Argument(help="Task IDs to play (omit to use --all)."), ] = None, all_tasks: Annotated[bool, typer.Option("--all", help="Play every task.")] = False, model: Annotated[ Optional[str], typer.Option( "--model", "-m", envvar="MODEL_NAME", help="LLM model name. [env: MODEL_NAME]", ), ] = None, api_base_url: Annotated[ Optional[str], typer.Option( "--api-base-url", envvar="API_BASE_URL", help="OpenAI-compatible API base URL. [env: API_BASE_URL]", ), ] = None, api_key: Annotated[ Optional[str], typer.Option( "--api-key", envvar="HF_TOKEN", help="API key. [env: HF_TOKEN or GEMINI_API_KEY]", ), ] = None, max_steps: Annotated[ int, typer.Option("--max-steps", envvar="MAX_STEPS", help="Max LLM steps per task."), ] = 50, temperature: Annotated[ float, typer.Option( "--temperature", envvar="TEMPERATURE", help="Sampling temperature." ), ] = 0.2, quiet: Annotated[ bool, typer.Option("--quiet", "-q", help="Suppress per-step rich output."), ] = False, trace_dir: Annotated[ Optional[Path], typer.Option( "--trace-dir", help="Directory to write per-episode JSON traces. Defaults to /traces/.", ), ] = None, ) -> None: """Play LeanMigrate tasks with an LLM, showing rich step-by-step output.""" from openai import OpenAI all_ids = [t["task_id"] for t in list_tasks()] if all_tasks: selected = all_ids elif task_ids: bad = [t for t in task_ids if t not in all_ids] if bad: console.print(f"[red]Unknown task(s):[/red] {', '.join(bad)}") console.print(f"Available: {', '.join(all_ids)}") raise typer.Exit(code=1) selected = list(task_ids) else: console.print("[red]Specify task IDs or --all.[/red]") console.print(f"Available: {', '.join(all_ids)}") raise typer.Exit(code=1) # Resolve LLM config: CLI flag → .env / os.environ → hardcoded fallback resolved_model = model or os.getenv("MODEL_NAME", "gemini-2.5-flash-preview-04-17") resolved_url = api_base_url or os.getenv( "API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/" ) resolved_key = api_key or os.getenv("HF_TOKEN") or os.getenv("GEMINI_API_KEY") if not resolved_key: console.print("[red]Error:[/red] API key required.") console.print( "Set [bold]HF_TOKEN[/bold] or [bold]GEMINI_API_KEY[/bold] in your " ".env file or pass [bold]--api-key[/bold]." ) raise typer.Exit(code=1) console.print( Panel( Text.assemble( ("model : ", "bold cyan"), resolved_model, "\n", ("base_url : ", "bold cyan"), resolved_url, "\n", ("tasks : ", "bold cyan"), ", ".join(selected), ), title="[bold white]LLM CONFIG[/]", border_style="white", ) ) resolved_trace_dir = trace_dir or (ROOT / "traces") client = OpenAI(base_url=resolved_url, api_key=resolved_key) scores: dict[str, float] = {} for task_id in selected: scores[task_id] = _play_one_task( client=client, task_id=task_id, model=resolved_model, api_base_url=resolved_url, max_steps=max_steps, temperature=temperature, quiet=quiet, trace_dir=resolved_trace_dir, ) # ── Final summary ────────────────────────────────────────────────────────── console.print(Rule("[bold white]PLAYTHROUGH SUMMARY[/]")) summary_table = Table( show_header=True, header_style="bold cyan", box=None, padding=(0, 2) ) summary_table.add_column("task_id", style="bold yellow") summary_table.add_column("score", justify="right") summary_table.add_column("result") total = 0.0 for tid, score in scores.items(): color = "green" if score >= 0.99 else ("yellow" if score > 0 else "red") result = ( "✓ complete" if score >= 0.99 else ("partial" if score > 0 else "failed") ) summary_table.add_row(tid, f"{score:.3f}", f"[{color}]{result}[/{color}]") total += score if len(scores) > 1: avg = total / len(scores) summary_table.add_section() summary_table.add_row("[bold]overall[/bold]", f"[bold]{avg:.3f}[/bold]", "") console.print(summary_table) def _play_one_task( client: Any, task_id: str, model: str, api_base_url: str, max_steps: int, temperature: float, quiet: bool, trace_dir: Path, ) -> float: task = get_task(task_id) env = LeanMigrateEnvironment() obs = env.reset(task_id=task_id) # ── Task header ──────────────────────────────────────────────────────────── diff = task.difficulty color = _difficulty_color(diff) header = Text() header.append("task : ", style="bold cyan") header.append(task_id + "\n") header.append("name : ", style="bold cyan") header.append(task.display_name + "\n") header.append("difficulty: ", style="bold cyan") header.append(diff, style=f"bold {color}") header.append("\n") header.append("migration : ", style="bold cyan") header.append(f"{task.source_language} → {task.target_language}\n") header.append("episode : ", style="bold cyan") header.append(obs.episode_id + "\n") header.append("model : ", style="bold cyan") header.append(model + "\n") header.append("max_steps: ", style="bold cyan") header.append(str(max_steps)) console.print() console.print(Rule(f"[bold white]{task.display_name}[/] [dim]({task_id})[/dim]")) console.print(_text_panel("TASK", header, "white")) rewards: list[float] = [] history: list[dict] = [] trace_steps: list[dict[str, Any]] = [] code_cache: dict[str, str] = {} steps_taken = 0 episode_system_prompt = _system_prompt(task) # Pre-determine trace path so partial writes and final write share the same file. trace_dir.mkdir(parents=True, exist_ok=True) episode_start_ts = datetime.datetime.now(datetime.timezone.utc) _partial_trace_path = ( trace_dir / f"{task_id}__{episode_start_ts.strftime('%Y%m%dT%H%M%SZ')}__{obs.episode_id[:8]}.json" ) for step in range(1, max_steps + 1): if obs.done: break obs_before = obs try: action_dict, llm_user_prompt, llm_raw_output = _llm_action( client, model, temperature, task, obs, history ) # Maintain code cache: record tested code, inject on submit fn_name = action_dict.get("function_name", "") if action_dict.get("type") == "run_tests": code_cache[fn_name] = action_dict.get("candidate_code", "") elif action_dict.get("type") == "submit": if fn_name in code_cache: action_dict["target_code"] = code_cache[fn_name] elif action_dict.get("candidate_code"): # Agent skipped run_tests (e.g. Rust without cargo) and attached code directly action_dict["target_code"] = action_dict["candidate_code"] code_cache[fn_name] = action_dict["candidate_code"] action = _parse_action(action_dict) obs = env.step(action) raw_reward = float(obs.reward or 0.0) display_reward = clamp_open_unit(raw_reward) rewards.append(raw_reward) # raw for accurate analysis steps_taken = step feedback = obs.last_action_feedback or "" rd = obs.reward_details lean_err = (rd.lean_error if rd else None) or "" if lean_err: feedback = (feedback + " | Lean: " + lean_err).strip() history.append( { "step": step, "action": action_dict, "reward": raw_reward, # raw so LLM sees genuine negative signal "feedback": feedback, } ) # ── Trace record ─────────────────────────────────────────────────────── trace_steps.append( { "step": step, "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(), "state_before": _obs_snapshot(obs_before), "action": { "type": action_dict.get("type"), "function_name": action_dict.get("function_name"), "candidate_code": action_dict.get("candidate_code"), "lean_proof": action_dict.get("lean_proof"), "target_code": action_dict.get("target_code"), }, "llm": { "user_prompt": llm_user_prompt, "raw_output": llm_raw_output, }, "reward": round(raw_reward, 6), "reward_details": _reward_snapshot(obs), "state_after": _obs_snapshot(obs), "feedback": feedback, "analysis": _step_analysis(obs_before, obs, action_dict, raw_reward), } ) # Persist partial trace after every step so interruption loses nothing. _write_partial_trace( out_path=_partial_trace_path, task_id=task_id, episode_id=obs.episode_id, model=model, api_base_url=api_base_url, temperature=temperature, task=task, steps=trace_steps, rewards=rewards, system_prompt=episode_system_prompt, ) if not quiet: _print_step(step, obs_before, action_dict, obs, task.target_language) if obs.done: break except KeyboardInterrupt: console.print( f"\n[yellow]Interrupted at step {step}. " f"Partial trace saved → {_partial_trace_path}[/yellow]" ) raise # ── Task footer ──────────────────────────────────────────────────────────── final_score = clamp_open_unit(float(obs.progress)) success = final_score >= 0.99 color = "green" if success else ("yellow" if final_score > 0 else "red") result_label = ( "COMPLETE" if success else ("PARTIAL" if final_score > 0 else "FAILED") ) footer = Text() footer.append("result : ", style="bold cyan") footer.append(result_label, style=f"bold {color}") footer.append("\n") footer.append("score : ", style="bold cyan") footer.append(f"{final_score:.3f}\n", style=color) footer.append("steps : ", style="bold cyan") footer.append(f"{steps_taken}\n") footer.append("verified : ", style="bold cyan") footer.append(", ".join(obs.verified) if obs.verified else "(none)") # ── Write trace (overwrites partial with full analytics) ─────────────────── trace_path = _write_trace( trace_dir=trace_dir, task_id=task_id, episode_id=obs.episode_id, model=model, api_base_url=api_base_url, temperature=temperature, task=task, steps=trace_steps, rewards=rewards, final_score=final_score, out_path=_partial_trace_path, system_prompt=episode_system_prompt, ) footer.append("\ntrace : ", style="bold cyan") footer.append(str(trace_path), style="dim") console.print(Rule()) console.print(_text_panel(f"RESULT {task_id}", footer, color)) print( f"[END] task={task_id} success={str(success).lower()} " f"steps={steps_taken} score={final_score:.2f} " f"rewards={','.join(f'{clamp_open_unit(r):.2f}' for r in rewards)} " f"trace={trace_path}", flush=True, ) return final_score if __name__ == "__main__": app()