lean-migrate / inference.py
Hrushi's picture
Upload folder using huggingface_hub
1f506a9 verified
"""Autonomous inference runner for LeanMigrate (v2)."""
from __future__ import annotations

import asyncio
import json
import math
import os
import sys
import textwrap
from pathlib import Path
from typing import Any, Literal, Optional

from dotenv import load_dotenv
from openai import OpenAI
from pydantic import BaseModel, Field
# Repository root: the directory two levels above this file. Putting it at the
# front of sys.path lets the `lean_migrate.*` imports resolve when this file is
# executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
if all(entry != str(ROOT) for entry in sys.path):
    sys.path.insert(0, str(ROOT))
from lean_migrate.env.models import (
AnalyzeDepsAction,
InspectAction,
LeanMigrateAction,
LeanMigrateObservation,
RunTestsAction,
SubmitAction,
)
from lean_migrate.env.tasks import Task, get_task, list_tasks
from lean_migrate.server.lean_migrate_environment import LeanMigrateEnvironment
# Load variables from a .env file (if one is found) into the process
# environment before the settings below are read.
load_dotenv()

# OpenAI-compatible chat endpoint and model. The defaults point at Google's
# Gemini OpenAI-compatibility endpoint; set API_BASE_URL / MODEL_NAME to use a
# different provider (for example the HuggingFace router).
API_BASE_URL = os.getenv("API_BASE_URL", "https://generativelanguage.googleapis.com/v1beta/openai/")
MODEL_NAME = os.getenv("MODEL_NAME", "gemini-3.1-flash-lite-preview")
# API credential: HF_TOKEN is used when it is set and non-empty, otherwise
# GEMINI_API_KEY. main() refuses to start when neither is available.
API_KEY = os.getenv("HF_TOKEN") or os.getenv("GEMINI_API_KEY")
# Run only this task when set; otherwise main() runs every task from list_tasks().
TASK_ID = os.getenv("TASK_ID")
# Upper bound on agent steps per task (loop bound in _run_task) and the LLM
# sampling temperature.
MAX_STEPS = int(os.getenv("MAX_STEPS", "50"))
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.2"))
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line that opens a task's stdout log."""
    fields = (("task", task), ("env", env), ("model", model))
    line = "[START] " + " ".join(f"{key}={val}" for key, val in fields)
    print(line, flush=True)
def _clamp01(value: float) -> float:
"""Clamp a reward/score to [0.0, 1.0] for stdout log compliance."""
return max(0.0, min(1.0, value))
def _clamp_open01_for_precision(value: float, decimals: int=3) -> float:
"""Clamp a value to the open interval (0, 1), preserving strictness after formatting.
For example, with decimals=3 this guarantees the rendered value is in [0.001, 0.999].
"""
quantum = 10 ** (-decimals)
lower = quantum
upper = 1.0 - quantum
return max(lower, min(upper, value))
def log_step(
    step: int, action: dict, reward: float, done: bool, feedback: Optional[str]
) -> None:
    """Print one single-line ``[STEP]`` record to stdout.

    Args:
        step: Step number as tracked by the caller.
        action: Action dict. The ``candidate_code``, ``target_code`` and
            ``lean_proof`` keys are dropped before JSON encoding to keep the
            line short.
        reward: Step reward, clamped to [0, 1] and printed with two decimals.
        done: Episode-finished flag, printed in lower case.
        feedback: Logged as the ``error`` field with every whitespace run
            collapsed to one space, or ``null`` when empty/None.
    """
    code_keys = ("candidate_code", "target_code", "lean_proof")
    action_json = json.dumps(
        {key: val for key, val in action.items() if key not in code_keys}
    )
    # Hackathon spec: the field is named `error` and must be one line or null.
    if feedback:
        error_text = " ".join(feedback.split())
    else:
        error_text = "null"
    parts = [
        f"step={step}",
        f"action={action_json}",
        f"reward={_clamp01(reward):.2f}",
        f"done={str(done).lower()}",
        f"error={error_text}",
    ]
    print("[STEP] " + " ".join(parts), flush=True)
def log_end(success: bool, steps: int, score: float, rewards: list[float]) -> None:
    """Print the closing ``[END]`` record for a task.

    The score is kept strictly inside (0, 1) and printed with three decimals,
    as the validator requires. Each per-step reward is clamped to [0, 1] and
    printed with two decimals, comma-separated.
    """
    bounded_score = _clamp_open01_for_precision(score, decimals=3)
    rendered_rewards = ",".join(format(_clamp01(r), ".2f") for r in rewards)
    fields = [
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={bounded_score:.3f}",
        f"rewards={rendered_rewards}",
    ]
    print("[END] " + " ".join(fields), flush=True)
class ActionRequest(BaseModel):
    """Structured output schema for LLM action responses.

    For submit: do NOT include target_code — the system automatically uses the
    last candidate_code you provided via run_tests. If you could not run tests
    (for example because the test runtime is unavailable), put the
    implementation in candidate_code on the submit action instead. Only include
    lean_proof for proof tasks (payment_saga).
    """

    type: Literal["inspect", "analyze_deps", "run_tests", "submit"]
    function_name: str = Field(description="Name of the function this action targets.")
    # The runner copies candidate_code into target_code when it appears on a
    # submit action, so the schema must not forbid it there.
    candidate_code: str | None = Field(
        default=None,
        description=(
            "Full implementation for run_tests. On submit, include it only if "
            "you skipped run_tests (e.g. the test runtime is unavailable); "
            "otherwise omit it and the last run_tests code is used. "
            "Omit for inspect and analyze_deps."
        ),
    )
    lean_proof: str | None = Field(
        default=None,
        description="Lean proof text. Only required for proof tasks (payment_saga).",
    )
def _get_language_notes(task: Task) -> str:
if task.target_language == "typescript":
return textwrap.dedent("""
## TypeScript notes
- tsx is available: use TypeScript syntax for parameter types and interfaces.
- Use top-level `function` declarations — the verifier extracts function declarations with tree-sitter.
✓ `function findRole(roles: Role[], name: string): Role | null { ... }`
✗ `const findRole = (roles: Role[], name: string) => { ... }`
- Return type annotations and `export` are allowed; tree-sitter handles them safely.
- Keep the submission focused on the current function — verified dependencies are auto-injected.
""").strip()
if task.target_language == "rust":
return textwrap.dedent("""
## Rust notes
- Use serde_json and serde (already in Cargo.toml).
- Write only the current function. Dependencies are injected as separate pub fn definitions.
- Return types serializable to serde_json::Value (Vec, Option, tuples all work).
- run_tests compiles via cargo. If run_tests returns "cargo not found", cargo is
unavailable — submit your implementation directly without run_tests.
""").strip()
return ""
def _get_proof_notes(task: Task) -> str:
if any(fn.is_proof_required for fn in task.functions):
return textwrap.dedent("""
## Proof task notes
- NEVER use `sorry` — the verifier rejects it IMMEDIATELY, gives -0.05 reward,
and does not compile Lean at all. `sorry` has zero partial credit.
- If you cannot write the full proof yet, DO NOT submit. Call `inspect` on the
proof function to read the exact theorem and full Lean spec, then reason step by step.
- Useful first-try tactics: `by decide`, `by native_decide`, `by simp [...]`,
`by omega`, `by rfl`, `by constructor <;> simp`.
- Use qualified ADT constructors from the spec (e.g. `SagaEvent.Reserve`,
`SagaState.Settled`), not raw strings like `"Reserve"` or `"Settled"`.
""").strip()
return ""
def _get_system_prompt(task_id: str) -> str:
    """Build the system prompt for *task_id*.

    The prompt covers the recommended workflow, the JSON action formats, any
    language/proof notes for the task, and one section per function (name,
    proof-required tag, description, dependencies).
    """
    task = get_task(task_id)
    fn_sections = []
    for fn in task.functions:
        proof_tag = " [PROOF REQUIRED — submit with lean_proof field]" if fn.is_proof_required else ""
        fn_sections.append(
            f"### {fn.name}{proof_tag}\n"
            f"Description: {fn.description}\n"
            f"Depends on: {fn.depends_on or ['(none)']}\n"
        )
    extra_sections = "\n\n".join(
        s for s in (_get_language_notes(task), _get_proof_notes(task)) if s
    )
    # Dedent the static template BEFORE filling it in. The interpolated
    # sections are multi-line strings whose lines start at column 0. Running
    # textwrap.dedent on the already-formatted text would find no common
    # margin and leave every template line indented.
    # `{{` / `}}` render as literal braces under str.format, just as in an
    # f-string. Substituted values are not re-parsed, so braces in task
    # descriptions are safe.
    template = textwrap.dedent(
        """
        You are an expert software engineer migrating {source_language} code to {target_language} with Lean verification.
        Task: {display_name}

        ## Recommended workflow
        1. Call analyze_deps on any function to see the dependency graph and migration order.
        2. Call inspect on each function to read its source code and Lean spec before implementing.
        3. Call run_tests with your implementation and iterate until all cases pass.
        4. Call submit when tests pass — the system uses the last run_tests code automatically.
        5. Write only the current function; verified dependencies are injected automatically.

        Example sequence for a 2-function task where B depends on A:
          analyze_deps(B) → see order: A first, then B
          inspect(A) → read source + Lean spec
          run_tests(A, <code>) → iterate until passing
          submit(A)
          inspect(B) → read source + Lean spec
          run_tests(B, <code>) → iterate until passing
          submit(B)

        ## Available actions (return ONLY a JSON object)
        - inspect: get source code + Lean spec for a function
          {{"type": "inspect", "function_name": "foo"}}
        - analyze_deps: get dependency graph and migration order
          {{"type": "analyze_deps", "function_name": "foo"}}
        - run_tests: test your implementation (write only this function)
          {{"type": "run_tests", "function_name": "foo", "candidate_code": "def foo(): ..."}}
        - submit: submit for Lean verification using the last tested code (no code field needed)
          {{"type": "submit", "function_name": "foo"}}
          If you could not run tests (e.g. the test runtime is unavailable), attach the code instead:
          {{"type": "submit", "function_name": "foo", "candidate_code": "def foo(): ..."}}
          For proof tasks only: {{"type": "submit", "function_name": "foo", "lean_proof": "..."}}

        {extra_sections}

        ## Functions to migrate
        {functions}
        """
    ).strip()
    return template.format(
        source_language=task.source_language,
        target_language=task.target_language,
        display_name=task.display_name,
        extra_sections=extra_sections,
        functions="".join(fn_sections),
    ).strip()
def _model_action(
    client: OpenAI,
    task_id: str,
    observation: LeanMigrateObservation,
    history: list[dict],
) -> dict[str, Any]:
    """Ask the LLM for the next action and return it as a plain dict.

    The user prompt summarises the current observation and the last six
    history entries (type/function/reward/feedback only, never code). It adds
    escalation warnings when recent feedback shows repeated ``sorry``
    rejections, or, for Rust tasks, a missing cargo toolchain.

    Args:
        client: OpenAI-compatible client used for the structured-output call.
        task_id: Task identifier passed to ``get_task``.
        observation: Latest environment observation.
        history: Step records with ``step``, ``action``, ``reward`` and
            ``feedback`` keys, as built by the run loop.

    Returns:
        ``ActionRequest.model_dump()`` on success. If the LLM call or
        structured parsing fails, the error is reported on stderr and an
        ``inspect`` action on the first remaining function (or ``""``) is
        returned. A ``run_tests`` placeholder would be recorded by the run
        loop's code cache and could later be submitted as junk.
    """
    task = get_task(task_id)

    # Trim history to type+fn only — full action dicts include candidate_code which is huge.
    history_str = "".join(
        f"Step {h['step']}: type={h['action'].get('type')} fn={h['action'].get('function_name')} "
        f"reward={h['reward']:.2f} feedback={h['feedback']}\n"
        for h in history[-6:]
    )

    # Dynamic escalation: sorry pattern
    recent_sorry = [
        h for h in history[-8:]
        if "sorry" in (h.get("feedback") or "").lower() and h.get("reward", 0) < 0
    ]
    sorry_warning = ""
    if len(recent_sorry) >= 2:
        sorry_warning = (
            f"\n⚠ WARNING: You have submitted `sorry` {len(recent_sorry)} times recently. "
            "The verifier ALWAYS rejects sorry (-0.05 each time, no exceptions). "
            "You MUST write a complete proof without sorry. "
            "Try: `by decide`, `by native_decide`, `by simp [...]`, `by omega`.\n"
        )

    # Dynamic escalation: Rust cargo not available
    runtime_warning = ""
    if task.target_language == "rust":
        recent_runtime_errors = sum(
            1 for h in history[-4:]
            if any(k in (h.get("feedback") or "").lower()
                   for k in ("cargo not found", "cannot find a runtime", "install cargo"))
        )
        if recent_runtime_errors >= 2:
            runtime_warning = (
                "\n⚠ WARNING: cargo is unavailable in this environment. "
                "Do NOT use run_tests — submit your implementation directly.\n"
            )

    # Last observation verbatim (capped to avoid token explosion)
    last_obs_section = ""
    if observation.last_action_feedback:
        truncated = observation.last_action_feedback[:600]
        last_obs_section = f"\nLast observation:\n{truncated}\n"

    # Dedent the static template BEFORE interpolating. History and feedback
    # are multi-line values starting at column 0; dedenting the formatted text
    # would find no common margin and keep the template lines indented.
    template = textwrap.dedent(
        """
        Current State:
        - Step: {current_step} / {max_steps}
        - Progress: {progress:.2f}
        - Verified: {verified}
        - Remaining: {remaining}
        - Failing: {failing}
        {last_obs_section}
        Recent History:
        {history}
        {warnings}
        Decide the next action. Return ONLY the JSON.
        """
    ).strip()
    prompt = template.format(
        current_step=len(history) + 1,
        max_steps=task.max_steps,
        progress=observation.progress,
        verified=", ".join(observation.verified) if observation.verified else "None",
        remaining=", ".join(observation.remaining),
        failing=json.dumps(observation.failing),
        last_obs_section=last_obs_section,
        history=history_str or "None",
        warnings=sorry_warning + runtime_warning,
    )

    try:
        response = client.beta.chat.completions.parse(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": _get_system_prompt(task_id)},
                {"role": "user", "content": prompt},
            ],
            temperature=TEMPERATURE,
            response_format=ActionRequest,
        )
        parsed = response.choices[0].message.parsed
        if parsed is None:
            raise ValueError("No structured output returned")
        return parsed.model_dump()
    except Exception as e:  # best-effort boundary: one bad LLM call must not end the run
        # stderr, not stdout: stdout carries the validated [START]/[STEP]/[END] log.
        print(f"Error calling LLM: {e}", file=sys.stderr, flush=True)
        return {
            "type": "inspect",
            "function_name": observation.remaining[0] if observation.remaining else "",
        }
def _parse_action(action_dict: dict[str, Any]) -> LeanMigrateAction:
    """Convert an LLM action dict into the matching typed environment action.

    ``ActionRequest.model_dump()`` always emits every key, using ``None`` for
    fields the model omitted. ``dict.get(key, "")`` does not replace an
    explicit ``None``, so string fields are normalised with ``or ""``.
    ``target_code`` and ``lean_proof`` are passed through unchanged, so
    ``None`` still means "not provided". An unknown ``type`` falls back to an
    empty ``run_tests`` action.
    """
    action_type = action_dict.get("type", "run_tests")
    function_name = action_dict.get("function_name") or ""
    if action_type == "inspect":
        return InspectAction(type="inspect", function_name=function_name)
    if action_type == "analyze_deps":
        return AnalyzeDepsAction(type="analyze_deps", function_name=function_name)
    if action_type == "run_tests":
        return RunTestsAction(
            type="run_tests",
            function_name=function_name,
            candidate_code=action_dict.get("candidate_code") or "",
        )
    if action_type == "submit":
        return SubmitAction(
            type="submit",
            function_name=function_name,
            target_code=action_dict.get("target_code"),
            lean_proof=action_dict.get("lean_proof"),
        )
    # Default fallback for unrecognised action types.
    return RunTestsAction(type="run_tests", function_name="", candidate_code="")
def _apply_code_cache(action_dict: dict[str, Any], code_cache: dict[str, str]) -> None:
    """Record run_tests code and attach the right code to submit actions (mutates both args).

    - ``run_tests``: remember its ``candidate_code`` (``None`` stored as ``""``).
    - ``submit``: code attached directly on the submit wins, for example when
      run_tests was skipped because cargo is unavailable. Otherwise the last
      cached run_tests code for that function becomes ``target_code``.
    """
    fn_name = action_dict.get("function_name", "")
    action_type = action_dict.get("type")
    candidate = action_dict.get("candidate_code")
    if action_type == "run_tests":
        code_cache[fn_name] = candidate or ""
    elif action_type == "submit":
        if candidate:
            action_dict["target_code"] = candidate
            code_cache[fn_name] = candidate
        elif fn_name in code_cache:
            action_dict["target_code"] = code_cache[fn_name]


async def _run_task(client: OpenAI, task_id: str) -> None:
    """Run one episode of *task_id* and emit its [START]/[STEP]/[END] log lines.

    Each step asks the model for an action, fills in cached code for submit
    actions, applies the action to the environment and records the reward and
    feedback. Any Lean error from ``reward_details`` is appended to the
    feedback. The loop stops when the observation reports ``done`` or after
    MAX_STEPS steps. The [END] line is always printed (``finally``), with
    ``observation.progress`` as the score.
    """
    env = LeanMigrateEnvironment()
    observation = env.reset(task_id=task_id)
    log_start(task=task_id, env="lean_migrate", model=MODEL_NAME)
    rewards: list[float] = []
    steps_taken = 0
    success = False
    history: list[dict] = []
    # function_name -> most recent code sent via run_tests (or attached on submit)
    code_cache: dict[str, str] = {}
    try:
        for step in range(1, MAX_STEPS + 1):
            if observation.done:
                break
            action_dict = _model_action(client, task_id, observation, history)
            _apply_code_cache(action_dict, code_cache)
            action = _parse_action(action_dict)
            observation = env.step(action)
            reward_value = float(observation.reward or 0.0)
            rewards.append(reward_value)
            steps_taken = step
            feedback = observation.last_action_feedback
            if observation.reward_details and observation.reward_details.lean_error:
                feedback = (
                    (feedback or "")
                    + " | Lean Error: "
                    + observation.reward_details.lean_error
                )
            log_step(
                step=step,
                action=action_dict,
                reward=reward_value,
                done=observation.done,
                feedback=feedback,
            )
            history.append(
                {
                    "step": step,
                    "action": action_dict,
                    "reward": reward_value,
                    "feedback": feedback,
                }
            )
            if observation.done:
                break
        success = observation.progress >= 1.0
    finally:
        log_end(
            success=success,
            steps=steps_taken,
            score=float(observation.progress),
            rewards=rewards,
        )
async def main() -> None:
    """Run TASK_ID if it is set, otherwise every task from ``list_tasks()``, sequentially.

    Raises:
        RuntimeError: if no API key is configured (neither HF_TOKEN nor
            GEMINI_API_KEY is set).
    """
    if not API_KEY:
        # API_KEY falls back from HF_TOKEN to GEMINI_API_KEY, so name both.
        raise RuntimeError(
            "Set HF_TOKEN (or GEMINI_API_KEY) before running inference.py"
        )
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    task_ids = [TASK_ID] if TASK_ID else [task["task_id"] for task in list_tasks()]
    for task_id in task_ids:
        await _run_task(client, task_id)


if __name__ == "__main__":
    asyncio.run(main())