Spaces:
Sleeping
Sleeping
"""play_task — LLM-driven CLI for LeanMigrate task playthrough.

Usage:
    uv run python scripts/play_task.py list
    uv run python scripts/play_task.py play lru_cache expression_eval
    uv run python scripts/play_task.py play --all
"""
| from __future__ import annotations | |
| import datetime | |
| import json | |
| import os | |
| import re | |
| import sys | |
| import textwrap | |
| from pathlib import Path | |
| from typing import Annotated, Any, Literal, Optional | |
| from dotenv import load_dotenv | |
| import typer | |
| from rich.columns import Columns | |
| from rich.console import Console | |
| from rich.panel import Panel | |
| from rich.rule import Rule | |
| from rich.syntax import Syntax | |
| from rich.table import Table | |
| from rich.text import Text | |
| # Load .env before anything reads os.environ (including typer envvar resolution) | |
| load_dotenv(Path(__file__).resolve().parents[1] / ".env") | |
| ROOT = Path(__file__).resolve().parents[1] | |
| if str(ROOT) not in sys.path: | |
| sys.path.insert(0, str(ROOT)) | |
| from lean_migrate.env.grader import clamp_open_unit, oracle_result # noqa: E402 | |
| from lean_migrate.env.models import ( # noqa: E402 | |
| AnalyzeDepsAction, | |
| InspectAction, | |
| LeanMigrateAction, | |
| LeanMigrateObservation, | |
| RunTestsAction, | |
| SubmitAction, | |
| ) | |
| from lean_migrate.env.tasks import FunctionSpec, Task, get_task, list_tasks # noqa: E402 | |
| from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment # noqa: E402 | |
# ── Pydantic schema for structured LLM output ────────────────────────────────
| from pydantic import BaseModel, Field # noqa: E402 | |
class ActionRequest(BaseModel):
    """Structured output schema for LLM action responses.

    Each field corresponds to one key of the JSON action object the model is
    asked to return. ``candidate_code`` and ``lean_proof`` are optional and
    default to ``None``.
    """

    # One of the four action kinds accepted by this schema.
    type: Literal["inspect", "analyze_deps", "run_tests", "submit"]
    # Required for every action kind.
    function_name: str = Field(description="Name of the function this action targets.")
    # Source code to test; expected only with ``run_tests``.
    candidate_code: str | None = Field(
        default=None,
        description="Full implementation for run_tests. Omit for other action types.",
    )
    # Lean proof text; expected only for proof tasks.
    lean_proof: str | None = Field(
        default=None,
        description="Lean proof text. Only required for proof tasks.",
    )
| # ββ App + console βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
# Typer application object for this CLI; ``add_completion=False`` is passed
# so no shell-completion options are added.
app = typer.Typer(
    name="play_task",
    help="LLM-driven LeanMigrate task runner with rich step display.",
    add_completion=False,
)
# Module-wide rich console used by the helpers and commands below.
console = Console()
# ── Rich helpers ─────────────────────────────────────────────────────────────
| def _difficulty_color(d: str) -> str: | |
| return {"easy": "green", "medium": "yellow", "hard": "red"}.get(d, "white") | |
| def _trim(text: str, max_lines: int = 20) -> str: | |
| lines = text.strip().splitlines() | |
| if len(lines) <= max_lines: | |
| return "\n".join(lines) | |
| return ( | |
| "\n".join(lines[:max_lines]) | |
| + f"\n[dim]β¦ ({len(lines) - max_lines} more lines)[/dim]" | |
| ) | |
| def _split_at_semicolons(s: str) -> list[str]: | |
| """Depth-aware split at ';' β won't split inside brackets.""" | |
| parts: list[str] = [] | |
| depth = 0 | |
| buf: list[str] = [] | |
| for ch in s: | |
| if ch in "([{": | |
| depth += 1 | |
| buf.append(ch) | |
| elif ch in ")]}": | |
| depth -= 1 | |
| buf.append(ch) | |
| elif ch == ";" and depth == 0: | |
| part = "".join(buf).strip() | |
| if part: | |
| parts.append(part) | |
| buf = [] | |
| else: | |
| buf.append(ch) | |
| last = "".join(buf).strip() | |
| if last: | |
| parts.append(last) | |
| return parts | |
# Lean tactic keywords that start a new tactic at the top level of a `by` block.
# The leading `(?<!\w)` keeps a keyword from matching in the middle of a longer
# identifier; most alternatives end in `\b` for the same reason on the right
# side, while `rw` / `rewrite` must be followed by `[`.
_LEAN_TACTIC_KWS = re.compile(
    r"(?<!\w)("
    r"intro\b|intros\b|apply\b|refine\b|exact\b|simp\b|rfl\b|norm_num\b|"
    r"decide\b|ring\b|omega\b|constructor\b|cases\b|induction\b|have\b|"
    r"obtain\b|rw\s*\[|rewrite\s*\[|calc\b|show\b|suffices\b|use\b|"
    r"trivial\b|tauto\b|aesop\b|linarith\b|nlinarith\b|positivity\b|"
    r"field_simp\b|push_neg\b|contrapose\b|by_contra\b|by_cases\b|"
    r"rcases\b|ext\b|funext\b|congr\b|subst\b|assumption\b|"
    r"contradiction\b|scalar_tac\b|progress\b|next\b|all_goals\b|"
    r"any_goals\b|first\b|repeat\b|done\b|skip\b"
    r")"
)
| def _format_lean(code: str) -> str: | |
| """Best-effort Lean 4 single-line β multi-line formatter.""" | |
| # Already multi-line β just normalise indentation slightly. | |
| if "\n" in code.strip(): | |
| lines = [ln.rstrip() for ln in code.strip().splitlines()] | |
| return "\n".join(lines) | |
| s = code.strip() | |
| # Match: <decl head> := by <tactics> | |
| m = re.match(r"^(.*?)\s*:=\s*by\s+(.+)$", s, re.DOTALL) | |
| if m: | |
| head, tactics_str = m.group(1).strip(), m.group(2).strip() | |
| # First split at semicolons | |
| tactics = _split_at_semicolons(tactics_str) | |
| # Then split each part further at tactic keyword boundaries (depth 0 only) | |
| expanded: list[str] = [] | |
| for part in tactics: | |
| # Find positions of tactic-keyword starts (skip inside brackets) | |
| splits: list[int] = [0] | |
| depth = 0 | |
| i = 0 | |
| while i < len(part): | |
| if part[i] in "([{": | |
| depth += 1 | |
| elif part[i] in ")]}": | |
| depth -= 1 | |
| elif depth == 0 and i > 0 and part[i - 1] == " ": | |
| hit = _LEAN_TACTIC_KWS.match(part, i) | |
| if hit: | |
| splits.append(i) | |
| i += 1 | |
| for j, start in enumerate(splits): | |
| end = splits[j + 1] if j + 1 < len(splits) else len(part) | |
| chunk = part[start:end].strip() | |
| if chunk: | |
| expanded.append(chunk) | |
| if len(expanded) == 1: | |
| return f"{head} := by\n {expanded[0]}" | |
| return head + " := by\n" + "\n".join(f" {t}" for t in expanded) | |
| # Match: <decl head> := <term> (no `by`) | |
| m2 = re.match(r"^(.*?)\s*:=\s*(.+)$", s, re.DOTALL) | |
| if m2: | |
| head, body = m2.group(1).strip(), m2.group(2).strip() | |
| return f"{head} :=\n {body}" | |
| return s | |
| def _format_code(code: str, language: str) -> str: | |
| """Run a language-specific formatter. Returns original on any failure.""" | |
| import subprocess | |
| try: | |
| if language == "lean": | |
| return _format_lean(code) | |
| if language == "typescript": | |
| result = subprocess.run( | |
| ["npx", "--yes", "prettier", "--stdin-filepath", "candidate.ts"], | |
| input=code, capture_output=True, text=True, timeout=15, | |
| ) | |
| elif language == "python": | |
| result = subprocess.run( | |
| ["ruff", "format", "--stdin-filename", "candidate.py", "-"], | |
| input=code, capture_output=True, text=True, timeout=10, | |
| ) | |
| elif language == "rust": | |
| result = subprocess.run( | |
| ["rustfmt", "--edition", "2021"], | |
| input=code, capture_output=True, text=True, timeout=10, | |
| ) | |
| else: | |
| # Unknown language β return as-is (word_wrap still applies in panel) | |
| return code | |
| return result.stdout if result.returncode == 0 and result.stdout.strip() else code | |
| except Exception: | |
| return code | |
def _code_panel(title: str, code: str, language: str, border: str = "blue") -> Panel:
    """Build a bordered panel showing *code*, formatted and syntax-highlighted.

    The code is first passed through ``_format_code``. If constructing the
    ``Syntax`` renderable raises, the formatted code is shown as plain text.
    """
    pretty = _format_code(code, language)
    try:
        renderable = Syntax(
            pretty,
            language,
            theme="monokai",
            line_numbers=True,
            word_wrap=True,
        )
    except Exception:
        renderable = Text(pretty, no_wrap=False)
    heading = f"[bold {border}]{title}[/]"
    return Panel(renderable, title=heading, border_style=border)
def _text_panel(title: str, body: Text | str, border: str = "cyan") -> Panel:
    """Wrap *body* in a panel with a bold, coloured *title*.

    A plain string body is converted to a ``Text`` first.
    """
    content = Text(body) if isinstance(body, str) else body
    heading = f"[bold {border}]{title}[/]"
    return Panel(content, title=heading, border_style=border)
def _state_text(obs: LeanMigrateObservation) -> Text:
    """Render the progress fields of *obs* as labelled lines.

    Rows are progress (clamped and shown to three decimals), verified,
    remaining, failing (as JSON) and done; empty collections are shown as
    ``(none)``.
    """

    def _names(items: Any) -> str:
        return ", ".join(items) if items else "(none)"

    rows = [
        ("progress : ", f"{clamp_open_unit(float(obs.progress)):.3f}"),
        ("verified : ", _names(obs.verified)),
        ("remaining: ", _names(obs.remaining)),
        ("failing : ", json.dumps(obs.failing) if obs.failing else "(none)"),
        ("done : ", str(obs.done).lower()),
    ]
    out = Text()
    for index, (label, value) in enumerate(rows):
        if index:
            out.append("\n")
        out.append(label, style="bold cyan")
        out.append(value)
    return out
def _action_text(ad: dict[str, Any]) -> Text:
    """Render the type and target function of action dict *ad*.

    Missing keys are shown as ``?``.
    """
    rows = (("type : ", "type"), ("function : ", "function_name"))
    out = Text()
    for index, (label, key) in enumerate(rows):
        if index:
            out.append("\n")
        out.append(label, style="bold magenta")
        out.append(ad.get(key, "?"))
    return out
def _obs_text(obs: LeanMigrateObservation) -> Text:
    """Render the reward and done flag of *obs*.

    The reward (a missing reward counts as 0.0) is passed through
    ``clamp_open_unit`` and coloured green, red or yellow for positive,
    negative or zero values.
    """
    reward_val = clamp_open_unit(float(obs.reward or 0.0))
    if reward_val > 0:
        color = "green"
    elif reward_val < 0:
        color = "red"
    else:
        color = "yellow"

    out = Text()
    out.append("reward : ", style=f"bold {color}")
    out.append(f"{reward_val:.3f}\n", style=color)
    out.append("done : ", style="bold cyan")
    out.append(str(obs.done).lower())
    return out
# ── Trace helpers ────────────────────────────────────────────────────────────
| def _obs_snapshot(obs: LeanMigrateObservation) -> dict[str, Any]: | |
| """Serialise the parts of an observation useful for offline analysis.""" | |
| return { | |
| "progress": float(obs.progress), | |
| "verified": list(obs.verified), | |
| "remaining": list(obs.remaining), | |
| "failing": dict(obs.failing), | |
| "done": bool(obs.done), | |
| "episode_step": int(obs.episode_step), | |
| } | |
| def _reward_snapshot(obs: LeanMigrateObservation) -> dict[str, Any] | None: | |
| rd = obs.reward_details | |
| if rd is None: | |
| return None | |
| return { | |
| "score": float(rd.score), | |
| "cumulative_score": float(rd.cumulative_score), | |
| "tests_passed": rd.tests_passed, | |
| "tests_total": rd.tests_total, | |
| "proof_compiled": rd.proof_compiled, | |
| "breakdown": dict(rd.breakdown), | |
| "feedback": rd.feedback, | |
| "lean_error": rd.lean_error, | |
| } | |
| def _step_analysis( | |
| obs_before: LeanMigrateObservation, | |
| obs_after: LeanMigrateObservation, | |
| action_dict: dict[str, Any], | |
| reward: float, | |
| ) -> dict[str, Any]: | |
| """Derived signals useful for reward-shaping / behaviour analysis.""" | |
| progress_delta = float(obs_after.progress) - float(obs_before.progress) | |
| rd = obs_after.reward_details | |
| tests_pass_rate: float | None = None | |
| if rd and rd.tests_total: | |
| tests_pass_rate = (rd.tests_passed or 0) / rd.tests_total | |
| return { | |
| "progress_delta": round(progress_delta, 6), | |
| "tests_pass_rate": tests_pass_rate, | |
| "lean_error_present": bool(rd and rd.lean_error), | |
| "proof_compiled": rd.proof_compiled if rd else None, | |
| "new_function_verified": progress_delta > 0, | |
| "reward_sign": "positive" if reward > 0 else ("negative" if reward < 0 else "zero"), | |
| "action_type": action_dict.get("type"), | |
| "used_inspection": action_dict.get("type") in {"inspect", "analyze_deps"}, | |
| } | |
| def _write_partial_trace( | |
| out_path: Path, | |
| task_id: str, | |
| episode_id: str, | |
| model: str, | |
| api_base_url: str, | |
| temperature: float, | |
| task: Any, | |
| steps: list[dict[str, Any]], | |
| rewards: list[float], | |
| system_prompt: str = "", | |
| ) -> None: | |
| """Write an in-progress snapshot so work is preserved on interruption.""" | |
| trace = { | |
| "meta": { | |
| "task_id": task_id, | |
| "episode_id": episode_id, | |
| "model": model, | |
| "api_base_url": api_base_url, | |
| "temperature": temperature, | |
| "difficulty": task.difficulty, | |
| "source_language": task.source_language, | |
| "target_language": task.target_language, | |
| "functions": [f.name for f in task.functions], | |
| "proof_functions": [f.name for f in task.functions if f.is_proof_required], | |
| "max_steps": task.max_steps, | |
| "system_prompt": system_prompt, | |
| "partial": True, | |
| }, | |
| "steps": steps, | |
| "summary": None, | |
| } | |
| out_path.write_text(json.dumps(trace, indent=2, ensure_ascii=False)) | |
| def _write_trace( | |
| trace_dir: Path, | |
| task_id: str, | |
| episode_id: str, | |
| model: str, | |
| api_base_url: str, | |
| temperature: float, | |
| task: Any, | |
| steps: list[dict[str, Any]], | |
| rewards: list[float], | |
| final_score: float, | |
| out_path: Path | None = None, | |
| system_prompt: str = "", | |
| ) -> Path: | |
| """Write the full episode trace to a JSON file and return its path.""" | |
| trace_dir.mkdir(parents=True, exist_ok=True) | |
| ts = datetime.datetime.now(datetime.timezone.utc) | |
| ts_str = ts.strftime("%Y%m%dT%H%M%SZ") | |
| # ββ Summary analytics ββββββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| action_counts: dict[str, int] = {} | |
| lean_errors = 0 | |
| negative_rewards = 0 | |
| max_consecutive_run_tests = 0 | |
| current_run_tests_streak = 0 | |
| total_code_chars = 0 | |
| total_test_runs = 0 | |
| total_tests_passed = 0 | |
| total_tests_total = 0 | |
| for s in steps: | |
| atype = s["action"]["type"] | |
| action_counts[atype] = action_counts.get(atype, 0) + 1 | |
| if s["analysis"]["lean_error_present"]: | |
| lean_errors += 1 | |
| if s["reward"] < 0: | |
| negative_rewards += 1 | |
| if atype == "run_tests": | |
| current_run_tests_streak += 1 | |
| max_consecutive_run_tests = max( | |
| max_consecutive_run_tests, current_run_tests_streak | |
| ) | |
| code = s["action"].get("candidate_code") or "" | |
| total_code_chars += len(code) | |
| total_test_runs += 1 | |
| rd = s["reward_details"] or {} | |
| total_tests_passed += rd.get("tests_passed") or 0 | |
| total_tests_total += rd.get("tests_total") or 0 | |
| else: | |
| current_run_tests_streak = 0 | |
| n_types = len(action_counts) | |
| total_actions = len(steps) | |
| # Shannon entropy of action types as a diversity score | |
| import math | |
| diversity = 0.0 | |
| if total_actions > 0 and n_types > 1: | |
| for cnt in action_counts.values(): | |
| p = cnt / total_actions | |
| if p > 0: | |
| diversity -= p * math.log2(p) | |
| diversity /= math.log2(len({"inspect", "analyze_deps", "run_tests", "submit"})) | |
| final_obs = steps[-1]["state_after"] if steps else {} | |
| summary = { | |
| "total_steps": len(steps), | |
| "final_score": round(final_score, 6), | |
| "success": final_score >= 0.99, | |
| "rewards": [round(r, 4) for r in rewards], | |
| "reward_sum": round(sum(rewards), 4), | |
| "action_type_counts": action_counts, | |
| "action_diversity_score": round(diversity, 4), | |
| "functions_verified": final_obs.get("verified", []), | |
| "functions_remaining": final_obs.get("remaining", []), | |
| "functions_failing": list((final_obs.get("failing") or {}).keys()), | |
| "max_consecutive_run_tests": max_consecutive_run_tests, | |
| "lean_errors_count": lean_errors, | |
| "negative_rewards_count": negative_rewards, | |
| "avg_tests_pass_rate": ( | |
| round(total_tests_passed / total_tests_total, 4) | |
| if total_tests_total > 0 | |
| else None | |
| ), | |
| "never_used_inspect": action_counts.get("inspect", 0) == 0, | |
| "never_used_analyze_deps": action_counts.get("analyze_deps", 0) == 0, | |
| } | |
| trace = { | |
| "meta": { | |
| "task_id": task_id, | |
| "episode_id": episode_id, | |
| "model": model, | |
| "api_base_url": api_base_url, | |
| "temperature": temperature, | |
| "timestamp": ts.isoformat(), | |
| "difficulty": task.difficulty, | |
| "source_language": task.source_language, | |
| "target_language": task.target_language, | |
| "functions": [f.name for f in task.functions], | |
| "proof_functions": [f.name for f in task.functions if f.is_proof_required], | |
| "max_steps": task.max_steps, | |
| "system_prompt": system_prompt, | |
| }, | |
| "steps": steps, | |
| "summary": summary, | |
| } | |
| if out_path is None: | |
| filename = f"{task_id}__{ts_str}__{episode_id[:8]}.json" | |
| out_path = trace_dir / filename | |
| out_path.write_text(json.dumps(trace, indent=2, ensure_ascii=False)) | |
| return out_path | |
def _print_step(
    step: int,
    obs_before: LeanMigrateObservation,
    action_dict: dict[str, Any],
    obs_after: LeanMigrateObservation,
    target_language: str,
) -> None:
    """Print one step: state before, action, code / proof, and observation.

    Args:
        step: Step number shown in the rule heading.
        obs_before: Observation the action was chosen from.
        action_dict: The action as a plain dict.
        obs_after: Observation returned after the action.
        target_language: Task target language; selects the syntax highlighter
            for candidate code (anything other than rust / typescript is
            highlighted as python).
    """
    kind = action_dict.get("type", "?")
    target_fn = action_dict.get("function_name", "?")
    console.print(Rule(f"[bold yellow]STEP {step}[/] [white]{kind}({target_fn})[/]"))

    # State before the action and the action summary, side by side.
    before_panel = _text_panel("STATE (before)", _state_text(obs_before), "cyan")
    action_panel = _text_panel("ACTION", _action_text(action_dict), "magenta")
    console.print(Columns([before_panel, action_panel], expand=True))

    # Candidate code, when the action carries any.
    candidate = action_dict.get("candidate_code")
    if candidate:
        if target_language == "rust":
            lang = "rust"
        elif target_language == "typescript":
            lang = "typescript"
        else:
            lang = "python"
        console.print(_code_panel(f"CANDIDATE CODE [{lang}]", candidate, lang, "blue"))

    # Lean proof, when the action carries one.
    lean_proof = action_dict.get("lean_proof")
    if lean_proof:
        console.print(_code_panel("LEAN PROOF", lean_proof, "lean", "green"))

    # Observation after the action; the border colour follows the reward sign.
    details = obs_after.reward_details
    feedback = obs_after.last_action_feedback or ""
    lean_err = (details.lean_error if details else None) or ""
    reward_val = clamp_open_unit(float(obs_after.reward or 0.0))
    if reward_val > 0:
        border = "green"
    elif reward_val < 0:
        border = "red"
    else:
        border = "yellow"

    summary = Text()
    summary.append_text(_obs_text(obs_after))
    if feedback:
        summary.append("\n\nfeedback :\n", style="bold cyan")
        summary.append(_trim(feedback, 15))
    if lean_err:
        summary.append("\n\nlean error:\n", style="bold red")
        summary.append(_trim(lean_err, 20))
    console.print(_text_panel("OBSERVATION", summary, border))
# ── LLM helpers (from inference.py) ──────────────────────────────────────────
| def _get_simple_sample(task: Task, fn: FunctionSpec) -> str: | |
| cases = task.sample_inputs.get(fn.name, []) | |
| if not cases or fn.is_proof_required: | |
| return "" | |
| case = cases[0] | |
| expected = oracle_result(task.task_id, fn.name, case.args) | |
| args_repr = ", ".join(repr(a) for a in case.args) | |
| if len(args_repr) > 300: | |
| args_repr = args_repr[:297] + "..." | |
| return f"Sample: ({args_repr}) β {expected!r}" | |
| def _get_language_notes(task: Task) -> str: | |
| if task.target_language == "typescript": | |
| return textwrap.dedent(""" | |
| ## TypeScript notes | |
| - Use top-level `function` declarations (not arrow functions / const assignments). | |
| - Dependency code is injected automatically β write only the current function. | |
| """).strip() | |
| if task.target_language == "rust": | |
| return textwrap.dedent(""" | |
| ## Rust notes | |
| - Use serde_json and serde (already in Cargo.toml). | |
| - Write only the current function. Dependencies are injected as separate pub fn definitions. | |
| - Match the inspected Rust signature exactly, including ownership and mutability. The harness passes owned JSON-decoded values into your function, so do not add `&`, `&mut`, or slice parameters unless the inspected source explicitly shows them. | |
| - If the inspect output shows a `Source symbol`, use that exact Rust function name in your code. | |
| - Return types serializable to serde_json::Value (Vec, Option, tuples all work). | |
| - run_tests compiles via cargo. If run_tests returns "cargo not found", cargo is | |
| unavailable β submit your implementation directly without run_tests. | |
| """).strip() | |
| return "" | |
| def _get_proof_notes(task: Task) -> str: | |
| if any(fn.is_proof_required for fn in task.functions): | |
| return textwrap.dedent(""" | |
| ## Proof task notes | |
| - NEVER use `sorry` β the verifier rejects it IMMEDIATELY, gives -0.05 reward, | |
| and does not compile Lean at all. `sorry` has zero partial credit. | |
| - If you cannot write the full proof yet, DO NOT submit. Call `inspect` on the | |
| proof function to read the exact theorem and full Lean spec, then reason step by step. | |
| - Useful first-try tactics: `by decide`, `by native_decide`, `by simp [...]`, | |
| `by omega`, `by rfl`, `by constructor <;> simp`. | |
| - Use qualified ADT constructors from the spec (e.g. `Spec.MyType.mk`). | |
| """).strip() | |
| return "" | |
def _system_prompt(task: Task) -> str:
    """Build the system prompt describing *task* to the LLM.

    The prompt contains the workflow, the JSON action formats, any
    language / proof notes from ``_get_language_notes`` and
    ``_get_proof_notes``, and one section per function of the task.
    """
    fn_sections = []
    for fn in task.functions:
        proof_tag = (
            " [PROOF REQUIRED \u2014 submit with lean_proof field]"
            if fn.is_proof_required
            else ""
        )
        fn_sections.append(
            f"### {fn.name}{proof_tag}\n"
            f"Description: {fn.description}\n"
            f"Depends on: {fn.depends_on or ['(none)']}\n"
        )
    notes = [_get_language_notes(task), _get_proof_notes(task)]
    extra = "\n\n".join(n for n in notes if n)

    # Dedent the static template BEFORE substituting values: the notes and
    # function sections are multi-line and start at column 0, which would
    # otherwise remove the common indentation dedent relies on.
    template = textwrap.dedent(
        """
        You are an expert software engineer migrating {source_language} code to
        {target_language} with Lean 4 verification.
        Task: {display_name}
        ## Workflow
        1. Call analyze_deps on any function to see the full dependency graph and migration order.
        2. Call inspect on each function to read its source code and Lean spec before implementing.
        3. Call run_tests with your implementation and iterate until all cases pass.
        4. Call submit when tests pass \u2014 the system uses the last run_tests code automatically.
        5. Write only the current function's code; verified dependencies are injected automatically.
        6. For Rust tasks, treat the inspected function signature as authoritative. If a compile error mentions a type mismatch, fix the signature to match the harness before changing the body.
        7. If inspect output names a Rust `Source symbol`, prefer that exact symbol over the Lean label when naming the function.
        Example sequence for a 2-function task where B depends on A:
        analyze_deps(B) \u2192 see order: A first, then B
        inspect(A) \u2192 read source + Lean spec
        run_tests(A, <code>) \u2192 iterate until passing
        submit(A)
        inspect(B) \u2192 read source + Lean spec
        run_tests(B, <code>) \u2192 iterate until passing
        submit(B)
        ## Actions (return ONLY JSON)
        - inspect: {{"type": "inspect", "function_name": "foo"}}
        - analyze_deps: {{"type": "analyze_deps", "function_name": "foo"}}
        - run_tests: {{"type": "run_tests", "function_name": "foo", "candidate_code": "..."}}
        - submit: {{"type": "submit", "function_name": "foo"}}
        Proof only: {{"type": "submit", "function_name": "foo", "lean_proof": "..."}}
        {extra}
        ## Functions to migrate
        {functions}
        """
    )
    return template.format(
        source_language=task.source_language,
        target_language=task.target_language,
        display_name=task.display_name,
        extra=extra,
        functions="".join(fn_sections),
    ).strip()
def _llm_action(
    client: Any,
    model: str,
    temperature: float,
    task: Task,
    obs: LeanMigrateObservation,
    history: list[dict],
) -> tuple[dict[str, Any], str, str]:
    """Ask the LLM for the next action.

    Args:
        client: OpenAI-style client exposing
            ``beta.chat.completions.parse``.
        model: Model name passed to the API.
        temperature: Sampling temperature passed to the API.
        task: Task being played; used for the system prompt and step limit.
        obs: Current observation.
        history: Earlier step records; ``step``, ``action``, ``reward`` and
            ``feedback`` are read.

    Returns:
        ``(action_dict, user_prompt, raw_model_output)``. If the request or
        parsing raises, the error is printed and a placeholder ``run_tests``
        action for the first remaining function is returned, with the raw
        output set to ``<error: ...>``.
    """
    current_step = len(history) + 1

    # ── Dynamic warnings ─────────────────────────────────────────────────────
    warnings: list[str] = []
    recent_sorry = [
        h
        for h in history[-8:]
        if "sorry" in (h.get("feedback") or "").lower() and h.get("reward", 0) < 0
    ]
    if len(recent_sorry) >= 2:
        warnings.append(
            f"\u26a0 You have submitted `sorry` {len(recent_sorry)} times. "
            "The verifier ALWAYS rejects sorry (-0.05 each time, no exceptions). "
            "Write a complete proof: try `by decide`, `by native_decide`, `by simp [...]`, `by omega`."
        )
    if task.target_language == "rust":
        cargo_markers = ("cargo not found", "cannot find a runtime", "install cargo")
        recent_runtime_errors = sum(
            1
            for h in history[-4:]
            if any(k in (h.get("feedback") or "").lower() for k in cargo_markers)
        )
        if recent_runtime_errors >= 2:
            warnings.append(
                "\u26a0 cargo is unavailable in this environment. "
                "Do NOT use run_tests \u2014 submit your implementation directly."
            )

    # ── Build prompt sections ────────────────────────────────────────────────
    state_section = (
        f"<STATE>\n"
        f"Step: {current_step} / {task.max_steps}\n"
        f"Progress: {obs.progress:.2f}\n"
        f"Verified: {', '.join(obs.verified) if obs.verified else 'None'}\n"
        f"Remaining: {', '.join(obs.remaining)}\n"
        f"Failing: {json.dumps(obs.failing)}\n"
        f"</STATE>"
    )
    # Full last observation, no truncation: inspect output can be large.
    last_obs_section = ""
    if obs.last_action_feedback:
        last_obs_section = (
            f"<LAST_OBSERVATION>\n"
            f"{obs.last_action_feedback}\n"
            f"</LAST_OBSERVATION>"
        )
    # Only the single most recent step (not a history window).
    last_step_section = ""
    if history:
        h = history[-1]
        last_step_section = (
            f"<LAST_STEP>\n"
            f"Step {h['step']}: type={h['action'].get('type')} "
            f"fn={h['action'].get('function_name')} "
            f"reward={h['reward']:.3f}\n"
            f"</LAST_STEP>"
        )
    warnings_section = (
        ("<WARNINGS>\n" + "\n".join(warnings) + "\n</WARNINGS>") if warnings else ""
    )
    sections = [
        state_section,
        last_obs_section,
        last_step_section,
        warnings_section,
        "Decide the next action. Return ONLY valid JSON.",
    ]
    prompt = "\n\n".join(s for s in sections if s)

    try:
        response = client.beta.chat.completions.parse(
            model=model,
            messages=[
                {"role": "system", "content": _system_prompt(task)},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
            response_format=ActionRequest,
        )
        raw_output = response.choices[0].message.content or ""
        parsed = response.choices[0].message.parsed
        if parsed is None:
            raise ValueError("No structured output returned")
        return parsed.model_dump(), prompt, raw_output
    except Exception as exc:
        # Deliberate best-effort boundary: an LLM failure must not abort the
        # episode, so report it and return a placeholder action instead.
        console.print(f"[red]LLM error:[/red] {exc}")
        return (
            {
                "type": "run_tests",
                "function_name": obs.remaining[0] if obs.remaining else "",
                "candidate_code": "# llm error",
            },
            prompt,
            f"<error: {exc}>",
        )
def _parse_action(ad: dict[str, Any]) -> LeanMigrateAction:
    """Convert an action dict into the matching environment action object.

    Args:
        ad: Action dict with ``type`` and ``function_name`` plus optional
            ``candidate_code``, ``target_code`` and ``lean_proof``.

    Returns:
        The typed action. A missing ``type`` is treated as ``run_tests``;
        an unrecognised ``type`` yields an empty ``run_tests`` action.
    """
    kind = ad.get("type", "run_tests")
    fn = ad.get("function_name", "")
    if kind == "inspect":
        return InspectAction(type="inspect", function_name=fn)
    if kind == "analyze_deps":
        return AnalyzeDepsAction(type="analyze_deps", function_name=fn)
    if kind == "run_tests":
        return RunTestsAction(
            type="run_tests",
            function_name=fn,
            # ActionRequest.model_dump() always includes the key, with None
            # when the model omitted the code, so a .get() default alone
            # would let None through. Normalise to an empty string.
            candidate_code=ad.get("candidate_code") or "",
        )
    if kind == "submit":
        return SubmitAction(
            type="submit",
            function_name=fn,
            target_code=ad.get("target_code"),
            lean_proof=ad.get("lean_proof"),
        )
    return RunTestsAction(type="run_tests", function_name="", candidate_code="")
# ── Commands ─────────────────────────────────────────────────────────────────
def cmd_list() -> None:
    """List all available LeanMigrate tasks."""
    tasks_info = list_tasks()
    from rich import box as rich_box

    table = Table(
        title="[bold]LeanMigrate Tasks[/bold]",
        show_header=True,
        header_style="bold cyan",
        box=rich_box.SIMPLE_HEAD,
        padding=(0, 1),
        expand=True,
    )
    table.add_column("task_id", style="bold yellow", no_wrap=True)
    table.add_column("display_name")
    table.add_column("diff", no_wrap=True)
    table.add_column("migration", no_wrap=True)
    table.add_column("functions")
    table.add_column("steps", justify="right", no_wrap=True)

    for info in tasks_info:
        task = get_task(info["task_id"])
        diff = info.get("difficulty", "?")
        color = _difficulty_color(diff)
        # Regular functions are listed first; proof functions follow, dimmed
        # and in parentheses.
        non_proof = [f.name for f in task.functions if not f.is_proof_required]
        proof = [f.name for f in task.functions if f.is_proof_required]
        fn_str = ", ".join(non_proof)
        if proof:
            fn_str += f" [dim](+{', '.join(proof)})[/dim]"
        table.add_row(
            info["task_id"],
            task.display_name,
            f"[{color}]{diff}[/{color}]",
            f"{task.source_language} \u2192 {task.target_language}",
            fn_str,
            str(task.max_steps),
        )

    console.print()
    console.print(table)
    console.print()
    console.print(
        "[dim]Play a task:[/dim] uv run python scripts/play_task.py play <task_id>"
    )
    console.print(
        "[dim]Play all: [/dim] uv run python scripts/play_task.py play --all"
    )
    console.print()
| def cmd_play( | |
| task_ids: Annotated[ | |
| Optional[list[str]], | |
| typer.Argument(help="Task IDs to play (omit to use --all)."), | |
| ] = None, | |
| all_tasks: Annotated[bool, typer.Option("--all", help="Play every task.")] = False, | |
| model: Annotated[ | |
| Optional[str], | |
| typer.Option( | |
| "--model", | |
| "-m", | |
| envvar="MODEL_NAME", | |
| help="LLM model name. [env: MODEL_NAME]", | |
| ), | |
| ] = None, | |
| api_base_url: Annotated[ | |
| Optional[str], | |
| typer.Option( | |
| "--api-base-url", | |
| envvar="API_BASE_URL", | |
| help="OpenAI-compatible API base URL. [env: API_BASE_URL]", | |
| ), | |
| ] = None, | |
| api_key: Annotated[ | |
| Optional[str], | |
| typer.Option( | |
| "--api-key", | |
| envvar="HF_TOKEN", | |
| help="API key. [env: HF_TOKEN or GEMINI_API_KEY]", | |
| ), | |
| ] = None, | |
| max_steps: Annotated[ | |
| int, | |
| typer.Option("--max-steps", envvar="MAX_STEPS", help="Max LLM steps per task."), | |
| ] = 50, | |
| temperature: Annotated[ | |
| float, | |
| typer.Option( | |
| "--temperature", envvar="TEMPERATURE", help="Sampling temperature." | |
| ), | |
| ] = 0.2, | |
| quiet: Annotated[ | |
| bool, | |
| typer.Option("--quiet", "-q", help="Suppress per-step rich output."), | |
| ] = False, | |
| trace_dir: Annotated[ | |
| Optional[Path], | |
| typer.Option( | |
| "--trace-dir", | |
| help="Directory to write per-episode JSON traces. Defaults to <project>/traces/.", | |
| ), | |
| ] = None, | |
| ) -> None: | |
| """Play LeanMigrate tasks with an LLM, showing rich step-by-step output.""" | |
| from openai import OpenAI | |
| all_ids = [t["task_id"] for t in list_tasks()] | |
| if all_tasks: | |
| selected = all_ids | |
| elif task_ids: | |
| bad = [t for t in task_ids if t not in all_ids] | |
| if bad: | |
| console.print(f"[red]Unknown task(s):[/red] {', '.join(bad)}") | |
| console.print(f"Available: {', '.join(all_ids)}") | |
| raise typer.Exit(code=1) | |
| selected = list(task_ids) | |
| else: | |
| console.print("[red]Specify task IDs or --all.[/red]") | |
| console.print(f"Available: {', '.join(all_ids)}") | |
| raise typer.Exit(code=1) | |
| # Resolve LLM config: CLI flag β .env / os.environ β hardcoded fallback | |
| resolved_model = model or os.getenv("MODEL_NAME", "gemini-2.5-flash-preview-04-17") | |
| resolved_url = api_base_url or os.getenv( | |
| "API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/" | |
| ) | |
| resolved_key = api_key or os.getenv("HF_TOKEN") or os.getenv("GEMINI_API_KEY") | |
| if not resolved_key: | |
| console.print("[red]Error:[/red] API key required.") | |
| console.print( | |
| "Set [bold]HF_TOKEN[/bold] or [bold]GEMINI_API_KEY[/bold] in your " | |
| ".env file or pass [bold]--api-key[/bold]." | |
| ) | |
| raise typer.Exit(code=1) | |
| console.print( | |
| Panel( | |
| Text.assemble( | |
| ("model : ", "bold cyan"), | |
| resolved_model, | |
| "\n", | |
| ("base_url : ", "bold cyan"), | |
| resolved_url, | |
| "\n", | |
| ("tasks : ", "bold cyan"), | |
| ", ".join(selected), | |
| ), | |
| title="[bold white]LLM CONFIG[/]", | |
| border_style="white", | |
| ) | |
| ) | |
| resolved_trace_dir = trace_dir or (ROOT / "traces") | |
| client = OpenAI(base_url=resolved_url, api_key=resolved_key) | |
| scores: dict[str, float] = {} | |
| for task_id in selected: | |
| scores[task_id] = _play_one_task( | |
| client=client, | |
| task_id=task_id, | |
| model=resolved_model, | |
| api_base_url=resolved_url, | |
| max_steps=max_steps, | |
| temperature=temperature, | |
| quiet=quiet, | |
| trace_dir=resolved_trace_dir, | |
| ) | |
| # ── Final summary ────────────────────────────────────────────────────────── | |
| console.print(Rule("[bold white]PLAYTHROUGH SUMMARY[/]")) | |
| summary_table = Table( | |
| show_header=True, header_style="bold cyan", box=None, padding=(0, 2) | |
| ) | |
| summary_table.add_column("task_id", style="bold yellow") | |
| summary_table.add_column("score", justify="right") | |
| summary_table.add_column("result") | |
| total = 0.0 | |
| for tid, score in scores.items(): | |
| color = "green" if score >= 0.99 else ("yellow" if score > 0 else "red") | |
| result = ( | |
| "β complete" if score >= 0.99 else ("partial" if score > 0 else "failed") | |
| ) | |
| summary_table.add_row(tid, f"{score:.3f}", f"[{color}]{result}[/{color}]") | |
| total += score | |
| if len(scores) > 1: | |
| avg = total / len(scores) | |
| summary_table.add_section() | |
| summary_table.add_row("[bold]overall[/bold]", f"[bold]{avg:.3f}[/bold]", "") | |
| console.print(summary_table) | |
def _play_one_task(
    client: Any,
    task_id: str,
    model: str,
    api_base_url: str,
    max_steps: int,
    temperature: float,
    quiet: bool,
    trace_dir: Path,
) -> float:
    """Play a single task episode with the LLM and return its final score.

    The episode runs for at most ``max_steps`` steps, or until the environment
    reports ``done``.  Each step asks ``_llm_action`` for an action, applies it
    with ``env.step``, appends the outcome to the prompt history and to the
    trace, and rewrites the partial trace file so that an interruption keeps
    every completed step.

    Args:
        client: LLM client object handed through to ``_llm_action``.
        task_id: Identifier of the task to load and to reset the environment with.
        model: Model name used for the LLM calls and recorded in the trace.
        api_base_url: Endpoint URL; in this function it is only recorded in the trace.
        max_steps: Upper bound on the number of environment steps.
        temperature: Sampling temperature passed to ``_llm_action`` and the trace.
        quiet: When true, skip the per-step ``_print_step`` output.  The task
            header, the result panel and the ``[END]`` line are still printed.
        trace_dir: Directory for the episode trace; created if it does not exist.

    Returns:
        ``clamp_open_unit(float(obs.progress))`` of the last observation.

    Raises:
        KeyboardInterrupt: Re-raised after printing where the episode stopped.
    """
    task = get_task(task_id)
    env = LeanMigrateEnvironment()
    obs = env.reset(task_id=task_id)

    # ── Task header ──────────────────────────────────────────────────────────
    difficulty = task.difficulty
    difficulty_color = _difficulty_color(difficulty)
    header = Text()
    header.append("task : ", style="bold cyan")
    header.append(task_id + "\n")
    header.append("name : ", style="bold cyan")
    header.append(task.display_name + "\n")
    header.append("difficulty: ", style="bold cyan")
    header.append(difficulty, style=f"bold {difficulty_color}")
    header.append("\n")
    header.append("migration : ", style="bold cyan")
    header.append(f"{task.source_language} → {task.target_language}\n")
    header.append("episode : ", style="bold cyan")
    header.append(obs.episode_id + "\n")
    header.append("model : ", style="bold cyan")
    header.append(model + "\n")
    header.append("max_steps: ", style="bold cyan")
    header.append(str(max_steps))
    console.print()
    console.print(Rule(f"[bold white]{task.display_name}[/] [dim]({task_id})[/dim]"))
    console.print(_text_panel("TASK", header, "white"))

    rewards: list[float] = []  # raw (unclamped) reward of every completed step
    history: list[dict] = []  # compact per-step records fed back to the LLM
    trace_steps: list[dict[str, Any]] = []  # full per-step records for the trace
    code_cache: dict[str, str] = {}  # function name -> last non-empty code seen
    steps_taken = 0
    episode_system_prompt = _system_prompt(task)

    # Fix the trace path up front so partial writes and the final write share
    # one file.
    trace_dir.mkdir(parents=True, exist_ok=True)
    episode_start_ts = datetime.datetime.now(datetime.timezone.utc)
    stamp = episode_start_ts.strftime("%Y%m%dT%H%M%SZ")
    trace_file = trace_dir / f"{task_id}__{stamp}__{obs.episode_id[:8]}.json"

    for step in range(1, max_steps + 1):
        if obs.done:
            break
        obs_before = obs
        # The whole step sits inside the try so Ctrl-C is reported with the
        # step number wherever it lands (LLM call, env step, trace write).
        try:
            action_dict, llm_user_prompt, llm_raw_output = _llm_action(
                client, model, temperature, task, obs, history
            )

            # Maintain the code cache: remember tested code, inject it on submit.
            fn_name = action_dict.get("function_name", "")
            action_type = action_dict.get("type")
            attached_code = action_dict.get("candidate_code")
            if action_type == "run_tests":
                # Only cache real code: an empty/None candidate must not
                # shadow code that a later submit attaches directly.
                if attached_code:
                    code_cache[fn_name] = attached_code
            elif action_type == "submit":
                cached_code = code_cache.get(fn_name)
                if cached_code:
                    action_dict["target_code"] = cached_code
                elif attached_code:
                    # Agent skipped run_tests (e.g. Rust without cargo) and
                    # attached the code directly to the submit action.
                    action_dict["target_code"] = attached_code
                    code_cache[fn_name] = attached_code

            action = _parse_action(action_dict)
            obs = env.step(action)

            raw_reward = float(obs.reward or 0.0)
            rewards.append(raw_reward)  # raw for accurate analysis
            steps_taken = step

            feedback = obs.last_action_feedback or ""
            rd = obs.reward_details
            lean_err = (rd.lean_error if rd else None) or ""
            if lean_err:
                feedback = (feedback + " | Lean: " + lean_err).strip()

            history.append(
                {
                    "step": step,
                    "action": action_dict,
                    "reward": raw_reward,  # raw so LLM sees genuine negative signal
                    "feedback": feedback,
                }
            )

            # ── Trace record ─────────────────────────────────────────────────
            trace_steps.append(
                {
                    "step": step,
                    "timestamp": datetime.datetime.now(datetime.timezone.utc).isoformat(),
                    "state_before": _obs_snapshot(obs_before),
                    "action": {
                        "type": action_dict.get("type"),
                        "function_name": action_dict.get("function_name"),
                        "candidate_code": action_dict.get("candidate_code"),
                        "lean_proof": action_dict.get("lean_proof"),
                        "target_code": action_dict.get("target_code"),
                    },
                    "llm": {
                        "user_prompt": llm_user_prompt,
                        "raw_output": llm_raw_output,
                    },
                    "reward": round(raw_reward, 6),
                    "reward_details": _reward_snapshot(obs),
                    "state_after": _obs_snapshot(obs),
                    "feedback": feedback,
                    "analysis": _step_analysis(obs_before, obs, action_dict, raw_reward),
                }
            )

            # Persist the partial trace after every step so interruption loses
            # nothing that has already completed.
            _write_partial_trace(
                out_path=trace_file,
                task_id=task_id,
                episode_id=obs.episode_id,
                model=model,
                api_base_url=api_base_url,
                temperature=temperature,
                task=task,
                steps=trace_steps,
                rewards=rewards,
                system_prompt=episode_system_prompt,
            )

            if not quiet:
                _print_step(step, obs_before, action_dict, obs, task.target_language)
            if obs.done:
                break
        except KeyboardInterrupt:
            # Only claim a saved trace when the file is actually on disk; the
            # partial trace is first written after a step has completed.
            if trace_file.exists():
                trace_note = f"Partial trace saved → {trace_file}"
            else:
                trace_note = "No trace written yet."
            console.print(
                f"\n[yellow]Interrupted at step {step}. {trace_note}[/yellow]"
            )
            raise

    # ── Task footer ──────────────────────────────────────────────────────────
    final_score = clamp_open_unit(float(obs.progress))
    success = final_score >= 0.99
    result_color = "green" if success else ("yellow" if final_score > 0 else "red")
    result_label = (
        "COMPLETE" if success else ("PARTIAL" if final_score > 0 else "FAILED")
    )
    footer = Text()
    footer.append("result : ", style="bold cyan")
    footer.append(result_label, style=f"bold {result_color}")
    footer.append("\n")
    footer.append("score : ", style="bold cyan")
    footer.append(f"{final_score:.3f}\n", style=result_color)
    footer.append("steps : ", style="bold cyan")
    footer.append(f"{steps_taken}\n")
    footer.append("verified : ", style="bold cyan")
    footer.append(", ".join(obs.verified) if obs.verified else "(none)")

    # ── Write trace (overwrites partial with full analytics) ─────────────────
    trace_path = _write_trace(
        trace_dir=trace_dir,
        task_id=task_id,
        episode_id=obs.episode_id,
        model=model,
        api_base_url=api_base_url,
        temperature=temperature,
        task=task,
        steps=trace_steps,
        rewards=rewards,
        final_score=final_score,
        out_path=trace_file,
        system_prompt=episode_system_prompt,
    )
    footer.append("\ntrace : ", style="bold cyan")
    footer.append(str(trace_path), style="dim")
    console.print(Rule())
    console.print(_text_panel(f"RESULT {task_id}", footer, result_color))

    # Plain (non-rich) single-line summary on stdout; rewards are clamped here
    # for display only, the trace keeps the raw values.
    rewards_field = ",".join(f"{clamp_open_unit(r):.2f}" for r in rewards)
    print(
        f"[END] task={task_id} success={str(success).lower()} "
        f"steps={steps_taken} score={final_score:.2f} "
        f"rewards={rewards_field} "
        f"trace={trace_path}",
        flush=True,
    )
    return final_score
# Script entry point: run the CLI only when this file is executed directly,
# not when it is imported. `app` is defined outside this excerpt — presumably
# the Typer application object; confirm against the top of the file.
if __name__ == "__main__":
    app()