lean-migrate / scripts /export_lean_business_logic.py
Hrushi's picture
Upload folder using huggingface_hub
8c75600 verified
"""Export the Lean business-logic mirror for a LeanMigrate task submission."""
from __future__ import annotations
import argparse
import sys
import textwrap
from collections.abc import Sequence
from pathlib import Path
# Repository root (the parent of this script's directory). It is pushed onto
# sys.path so the ``lean_migrate`` imports below resolve when the script is
# run directly from a checkout.
ROOT = Path(__file__).resolve().parents[1]
_ROOT_ENTRY = str(ROOT)
if _ROOT_ENTRY not in sys.path:
    sys.path.insert(0, _ROOT_ENTRY)
from lean_migrate.env.target_snippets import ( # noqa: E402
TASK_TARGET_SNIPPETS,
build_submission_bundle,
)
from lean_migrate.env.tasks import get_task, list_tasks # noqa: E402
from lean_migrate.env.verification_ir import build_verification_ir # noqa: E402
# Accepted ``--task-id`` values, in the order ``list_tasks()`` reports them.
TASK_IDS = tuple(entry["task_id"] for entry in list_tasks())
def _default_source_code(task_id: str, function_name: str) -> str:
    """Return the bundled source used when no submission is supplied.

    Proof-only functions return their spec's ``source_fragment`` directly.
    Every other function is assembled by ``build_submission_bundle`` from the
    task's entry in ``TASK_TARGET_SNIPPETS``.

    Raises:
        ValueError: if ``function_name`` is not a function of the task, or if
            no bundled snippet exists for it.
    """
    task = get_task(task_id)
    function_spec = task.get_function(function_name)
    if function_spec is None:
        raise ValueError(f"Unknown function '{function_name}' for task '{task_id}'.")
    if function_spec.is_proof_required:
        # Proof obligations carry their own fragment; no snippet lookup needed,
        # so a task without snippets must not block them.
        return function_spec.source_fragment
    try:
        snippets = TASK_TARGET_SNIPPETS[task_id]
        function_snippet = snippets[function_name]
    except KeyError as error:
        # Report a missing bundle as ValueError (which the CLI handles) rather
        # than letting a bare KeyError escape as a traceback.
        raise ValueError(
            f"No bundled source snippet for function '{function_name}' in task '{task_id}'."
        ) from error
    return build_submission_bundle(task, function_name, snippets, function_snippet)
def _read_source_code(
task_id: str, function_name: str, source_file: Path | None, source_code: str | None
) -> str:
if source_code is not None:
return source_code
if source_file is None:
return _default_source_code(task_id, function_name)
return source_file.read_text()
def _render_section_header(task_id: str, function_name: str, proof_only: bool) -> str:
if proof_only:
purpose = "This block is Lean because the function is a proof obligation, not runtime code."
else:
purpose = "This block is Lean because the runtime submission is being mirrored into a proof-checkable model."
return textwrap.dedent(
f"""
-- ============================================================
-- Task: {task_id}
-- Function: {function_name}
-- {purpose}
-- ============================================================
"""
).strip()
def export_lean_business_logic(
    task_id: str,
    function_name: str,
    source_code: str | None = None,
    include_proof: bool = False,
) -> str:
    """Render the Lean text exported for one task function.

    Args:
        task_id: Task to look up with ``get_task``.
        function_name: Function within that task.
        source_code: Submitted source to mirror. When None, the bundled
            default from ``_default_source_code`` is used. It is not consulted
            for proof-only functions, which export their spec's
            ``lean_fragment`` instead.
        include_proof: Must be True to export a proof-only function.

    Returns:
        The section header and the Lean block separated by a blank line, with
        surrounding whitespace stripped.

    Raises:
        ValueError: unknown function, a proof-only function without
            ``include_proof``, or verification IR that is not ready.
    """
    task = get_task(task_id)
    function_spec = task.get_function(function_name)
    if function_spec is None:
        raise ValueError(f"Unknown function '{function_name}' for task '{task_id}'.")
    section_header = _render_section_header(
        task_id, function_name, function_spec.is_proof_required
    )
    if function_spec.is_proof_required:
        if not include_proof:
            raise ValueError(
                f"Function '{function_name}' in task '{task_id}' is proof-only; pass --include-proof to export it."
            )
        proof_block = textwrap.dedent(function_spec.lean_fragment).strip()
        return "\n\n".join([section_header, proof_block]).strip()
    # Resolve the default source only for runtime functions; proof-only exports
    # never use it, and building it can fail on missing snippets.
    if source_code is None:
        source_code = _default_source_code(task_id, function_name)
    result = build_verification_ir(task, function_spec, source_code)
    if not result.ready or result.lean_code is None:
        # Fall back to a descriptive message so the CLI never prints a blank error.
        raise ValueError(
            result.feedback
            or f"Verification IR for function '{function_name}' in task '{task_id}' is not ready."
        )
    return "\n\n".join([section_header, result.lean_code]).strip()
def main(argv: Sequence[str] | None = None) -> int:
    """Command-line entry point for the Lean exporter.

    Parses ``argv`` (argparse falls back to ``sys.argv[1:]`` when it is None),
    resolves the source via ``_read_source_code`` and writes the exported Lean
    text to stdout, always ending with a newline. An ``OSError`` or
    ``ValueError`` raised while reading or exporting is printed to stderr.

    Returns:
        ``0`` on success, ``1`` if reading or exporting failed.
    """
    cli = argparse.ArgumentParser(
        description="Export the Lean mirror for one of the shipped LeanMigrate task functions."
    )
    cli.add_argument("--task-id", required=True, choices=TASK_IDS, help="Task id to export.")
    cli.add_argument("--function-name", required=True, help="Function name to export.")
    cli.add_argument(
        "--include-proof",
        action="store_true",
        help="Allow proof-only Lean fragments to be exported.",
    )
    # At most one explicit source; with neither, the bundled task source is used.
    sources = cli.add_mutually_exclusive_group(required=False)
    sources.add_argument(
        "--source-file",
        type=Path,
        help="Path to the submitted source code file. Defaults to the bundled task source.",
    )
    sources.add_argument(
        "--source-code",
        help="Inline source code for the submitted function. Defaults to the bundled task source.",
    )
    options = cli.parse_args(argv)

    try:
        lean_text = export_lean_business_logic(
            task_id=options.task_id,
            function_name=options.function_name,
            source_code=_read_source_code(
                options.task_id,
                options.function_name,
                options.source_file,
                options.source_code,
            ),
            include_proof=options.include_proof,
        )
    except (OSError, ValueError) as error:
        print(str(error), file=sys.stderr)
        return 1

    sys.stdout.write(lean_text if lean_text.endswith("\n") else lean_text + "\n")
    return 0
# Script entry: the exit status is main()'s return code.
if __name__ == "__main__":
    sys.exit(main())