# lean-migrate / env / verification_semantics.py
# Uploaded by Hrushi using huggingface_hub (commit bf9c466, verified).
from __future__ import annotations
import json
from dataclasses import dataclass
from typing import Any, Callable
@dataclass(frozen=True)
class FunctionSemantics:
    """The three callables needed to verify one translated Lean function.

    ``oracle`` computes the reference result in Python from the call's
    argument tuple, ``lean_value`` renders such a result as Lean source text,
    and ``lean_call`` renders the Lean expression that invokes the function
    under test with the same arguments. Instances are immutable.
    """

    # Python reference implementation: argument tuple -> expected result.
    oracle: Callable[[tuple[Any, ...]], Any]
    # Render an oracle result as a Lean term.
    lean_value: Callable[[Any], str]
    # Render the Lean call expression for an argument tuple.
    lean_call: Callable[[tuple[Any, ...]], str]
def _lean_string(value: str) -> str:
    r"""Render *value* as a double-quoted Lean 4 string literal.

    The output matches ``json.dumps(value)`` wherever JSON's escapes are also
    valid Lean escapes: ``\"``, ``\\``, ``\n``, ``\t``, ``\r`` and the
    four-digit ``\uHHHH`` form for other control characters, DEL and
    non-ASCII BMP characters. Two cases differ because Lean's lexer has no
    equivalent JSON escape:

    * backspace / form feed become ``\u0008`` / ``\u000c`` (not ``\b``/``\f``);
    * characters above U+FFFF are written verbatim instead of as a UTF-16
      surrogate pair, since Lean's ``\u`` escape denotes one code point and a
      lone surrogate is not a valid ``Char``.

    Non-``str`` inputs are coerced with ``str`` first.
    """
    text = value if isinstance(value, str) else str(value)
    pieces = ['"']
    for char in text:
        code = ord(char)
        if char == '"':
            pieces.append('\\"')
        elif char == "\\":
            pieces.append("\\\\")
        elif char == "\n":
            pieces.append("\\n")
        elif char == "\t":
            pieces.append("\\t")
        elif char == "\r":
            pieces.append("\\r")
        elif " " <= char <= "~":
            # Printable ASCII is emitted as-is.
            pieces.append(char)
        elif code <= 0xFFFF:
            # Control chars, DEL and non-ASCII BMP chars: 4-digit escape,
            # lowercase hex exactly as json.dumps writes it.
            pieces.append(f"\\u{code:04x}")
        else:
            # No Lean escape covers astral code points; emit the character
            # itself (assumes the generated Lean text is written as UTF-8).
            pieces.append(char)
    pieces.append('"')
    return "".join(pieces)
def _lean_bool(value: bool) -> str:
    """Render the truthiness of *value* as the Lean literal ``true``/``false``."""
    return ("false", "true")[bool(value)]
def _lean_int(value: int) -> str:
    """Render *value* (coerced with ``int``) as a Lean integer literal.

    Negative numbers are parenthesised, as ``_lean_int_pair`` and
    ``_lean_call_eval_bin_op`` already do. The result is spliced into
    juxtaposed applications such as ``_root_.canAccess ... 5``, where a bare
    ``-1`` would parse as binary subtraction rather than an argument.
    Non-negative values are rendered with plain ``str``.
    """
    number = int(value)
    return f"({number})" if number < 0 else str(number)
def _lean_list(items: list[str]) -> str:
    """Wrap already-rendered Lean terms in list-literal brackets."""
    body = ", ".join(items)
    return f"[{body}]"
def _lean_permission(permission: dict[str, Any]) -> str:
    """Render a permission dict as a Lean structure literal.

    Reads ``"resource"`` and ``"action"``, each coerced with ``str``.
    """
    resource = _lean_string(str(permission["resource"]))
    action = _lean_string(str(permission["action"]))
    return f"{{ resource := {resource}, action := {action} }}"


def _lean_role(role: dict[str, Any]) -> str:
    """Render a role dict as a Lean structure literal.

    ``"permissions"`` and ``"inherits"`` default to empty lists when absent;
    ``"name"`` is required and coerced with ``str``.
    """
    permission_terms = [
        _lean_permission(entry) for entry in role.get("permissions", [])
    ]
    parent_terms = [_lean_string(parent) for parent in role.get("inherits", [])]
    name = _lean_string(str(role["name"]))
    return (
        f"{{ name := {name}, "
        f"permissions := {_lean_list(permission_terms)}, "
        f"inherits := {_lean_list(parent_terms)} }}"
    )


def _lean_role_list(roles: list[dict[str, Any]]) -> str:
    """Render a sequence of role dicts as a Lean list literal."""
    return _lean_list(list(map(_lean_role, roles)))


def _lean_option_role(role: dict[str, Any] | None) -> str:
    """Render an optional role as ``none`` or ``some (<role> : AuthSpec.Role)``."""
    if role is not None:
        return f"some ({_lean_role(role)} : AuthSpec.Role)"
    return "none"
def _lean_pricing_structure(fields: list[tuple[str, str]]) -> str:
    """Render ``(field name, Lean term)`` pairs as a Lean structure literal."""
    return "{ " + ", ".join(f"{name} := {term}" for name, term in fields) + " }"


def _lean_item(item: dict[str, Any]) -> str:
    """Render an order line item as a Lean structure literal.

    Reads ``"sku"`` (coerced with ``str``) and ``"quantity"`` /
    ``"unitPrice"`` (coerced with ``int``).
    """
    return _lean_pricing_structure(
        [
            ("sku", _lean_string(str(item["sku"]))),
            ("quantity", _lean_int(int(item["quantity"]))),
            ("unitPrice", _lean_int(int(item["unitPrice"]))),
        ]
    )


def _lean_coupon(coupon: dict[str, Any]) -> str:
    """Render a coupon dict (``"code"``, ``"discountPercent"``) as a Lean literal."""
    return _lean_pricing_structure(
        [
            ("code", _lean_string(str(coupon["code"]))),
            ("discountPercent", _lean_int(int(coupon["discountPercent"]))),
        ]
    )


def _lean_order(order: dict[str, Any]) -> str:
    """Render an order dict as a Lean structure literal.

    Missing keys fall back to: no items, no coupons, ``regionId`` ``""`` and
    ``loyaltyPoints`` ``0``.
    """
    item_terms = [_lean_item(entry) for entry in order.get("items", [])]
    coupon_terms = [_lean_coupon(entry) for entry in order.get("coupons", [])]
    return _lean_pricing_structure(
        [
            ("items", _lean_list(item_terms)),
            ("coupons", _lean_list(coupon_terms)),
            ("regionId", _lean_string(str(order.get("regionId", "")))),
            ("loyaltyPoints", _lean_int(int(order.get("loyaltyPoints", 0)))),
        ]
    )
# Saga state / event names that have a Lean constructor of the same name.
_SAGA_STATE_NAMES = frozenset(
    {
        "Idle",
        "Reserved",
        "Authorized",
        "Captured",
        "Settled",
        "Compensating",
        "Compensated",
        "Failed",
    }
)
_SAGA_EVENT_NAMES = frozenset(
    {
        "Reserve",
        "Authorize",
        "Capture",
        "Settle",
        "CompensateReserve",
        "CompensateAuthorize",
        "CompensateCapture",
        "Fail",
    }
)


def _lean_saga_state(state: str) -> str:
    """Render a saga state name as Lean dot-constructor syntax (``.Idle``).

    Raises:
        KeyError: if *state* is not a known saga state.
    """
    if state not in _SAGA_STATE_NAMES:
        raise KeyError(state)
    return f".{state}"


def _lean_saga_event(event: str) -> str:
    """Render a saga event name as Lean dot-constructor syntax (``.Reserve``).

    Raises:
        KeyError: if *event* is not a known saga event.
    """
    if event not in _SAGA_EVENT_NAMES:
        raise KeyError(event)
    return f".{event}"
def _oracle_rbac_find_role(args: tuple[Any, ...]) -> dict[str, Any] | None:
    """Return the first role whose ``"name"`` equals the requested name.

    ``args`` is ``(roles, name)``; ``None`` is returned when no role matches.
    """
    roles, name = args
    return next((role for role in roles if role["name"] == name), None)


def _oracle_rbac_has_direct_permission(args: tuple[Any, ...]) -> bool:
    """Check only the role's own ``"permissions"`` (not inherited ones).

    ``args`` is ``(role, resource, action)``; a permission matches when both
    its ``"resource"`` and ``"action"`` equal the requested values.
    """
    role, resource, action = args
    return any(
        permission["resource"] == resource and permission["action"] == action
        for permission in role.get("permissions", [])
    )


def _oracle_rbac_can_access(args: tuple[Any, ...]) -> bool:
    """Decide whether a role grants an action, directly or via inheritance.

    ``args`` is ``(roles, role_name, resource, action)`` with an optional
    fifth element ``depth`` (default 5) bounding how many ``"inherits"`` hops
    are followed. Unknown role names and exhausted depth both deny access.
    """
    roles, role_name, resource, action, *rest = args
    depth = int(rest[0]) if rest else 5
    # `<= 0` rather than `== 0`: decrementing a negative depth never reaches
    # zero, so a cyclic "inherits" chain would recurse until RecursionError.
    if depth <= 0:
        return False
    role = _oracle_rbac_find_role((roles, role_name))
    if role is None:
        return False
    if _oracle_rbac_has_direct_permission((role, resource, action)):
        return True
    # any() short-circuits on the first parent that grants access.
    return any(
        _oracle_rbac_can_access((roles, parent, resource, action, depth - 1))
        for parent in role.get("inherits", [])
    )
# Tax rate per region in basis points (final price divides by 10000);
# regions not listed are taxed at 0.
_PRICING_TAX_RATE_BPS: dict[str, int] = {
    "US-CA": 875,
    "US-TX": 625,
    "US-NY": 800,
    "UK": 2000,
}


def _oracle_pricing_tax_rate_bps(args: tuple[Any, ...]) -> int:
    """Look up the tax rate (basis points) for ``args == (region_id,)``."""
    (region_id,) = args
    return _PRICING_TAX_RATE_BPS.get(region_id, 0)


def _oracle_pricing_subtotal(args: tuple[Any, ...]) -> int:
    """Sum ``unitPrice * quantity`` over the order's ``"items"``."""
    (order,) = args
    total = 0
    for line in order.get("items", []):
        total += int(line["unitPrice"]) * int(line["quantity"])
    return total


def _oracle_pricing_coupon_discount(args: tuple[Any, ...]) -> int:
    """Total coupon discount, capped at half the subtotal.

    Each coupon contributes ``subtotal * discountPercent // 100`` (floor
    division, computed against the undiscounted subtotal).
    """
    (order,) = args
    subtotal = _oracle_pricing_subtotal((order,))
    uncapped = 0
    for coupon in order.get("coupons", []):
        uncapped += subtotal * int(coupon["discountPercent"]) // 100
    return min(uncapped, subtotal // 2)


def _oracle_pricing_loyalty_discount(args: tuple[Any, ...]) -> int:
    """Loyalty points redeemed one-for-one, capped at a tenth of the subtotal."""
    (order,) = args
    ceiling = _oracle_pricing_subtotal((order,)) // 10
    return min(int(order.get("loyaltyPoints", 0)), ceiling)


def _oracle_pricing_final_price(args: tuple[Any, ...]) -> int:
    """Subtotal minus coupon and loyalty discounts, plus floor-divided tax."""
    (order,) = args
    subtotal = _oracle_pricing_subtotal((order,))
    discounts = _oracle_pricing_coupon_discount((order,))
    discounts += _oracle_pricing_loyalty_discount((order,))
    taxable = subtotal - discounts
    rate_bps = _oracle_pricing_tax_rate_bps((order.get("regionId", ""),))
    return taxable + taxable * rate_bps // 10000
# (state, event) -> next state; pairs not listed leave the state unchanged.
_SAGA_TRANSITIONS: dict[tuple[str, str], str] = {
    ("Idle", "Reserve"): "Reserved",
    ("Reserved", "Authorize"): "Authorized",
    ("Authorized", "Capture"): "Captured",
    ("Captured", "Settle"): "Settled",
    ("Reserved", "CompensateReserve"): "Compensated",
    ("Authorized", "CompensateAuthorize"): "Compensating",
    ("Compensating", "CompensateReserve"): "Compensated",
    ("Captured", "CompensateCapture"): "Compensating",
}
# States in which the payment has been captured.
_SAGA_CHARGED_STATES = frozenset({"Captured", "Settled"})


def _oracle_saga_transition(args: tuple[Any, ...]) -> str:
    """Advance the payment saga by one event; ``args`` is ``(state, event)``.

    ``"Fail"`` moves any state to ``"Failed"``; any other pair missing from
    the transition table is ignored (the state is returned unchanged).
    """
    state, event = args
    if event != "Fail":
        return _SAGA_TRANSITIONS.get((state, event), state)
    return "Failed"


def _oracle_saga_run(args: tuple[Any, ...]) -> str:
    """Fold ``args == (events,)`` through the transition function from ``"Idle"``."""
    (events,) = args
    current = "Idle"
    for step in events:
        current = _oracle_saga_transition((current, step))
    return current


def _oracle_saga_is_charged(args: tuple[Any, ...]) -> bool:
    """True when ``args == (state,)`` is ``"Captured"`` or ``"Settled"``."""
    (state,) = args
    return state in _SAGA_CHARGED_STATES
def _lean_value_rbac_find_role(value: Any) -> str:
    """Render a ``findRole`` oracle result (role dict or ``None``)."""
    return _lean_option_role(value)


def _lean_value_rbac_bool(value: Any) -> str:
    """Render an RBAC predicate result as ``true``/``false``."""
    return _lean_bool(True if value else False)


def _lean_value_pricing_number(value: Any) -> str:
    """Render a pricing oracle result as a Lean integer literal."""
    amount = int(value)
    return _lean_int(amount)


def _lean_value_saga_state(value: Any) -> str:
    """Render a saga state name as its Lean constructor (``.Idle`` etc.)."""
    state_name = str(value)
    return _lean_saga_state(state_name)


def _lean_value_saga_bool(value: Any) -> str:
    """Render a saga predicate result as ``true``/``false``."""
    return _lean_bool(True if value else False)
def _lean_call_rbac_find_role(args: tuple[Any, ...]) -> str:
    """Lean call ``findRole <roles> <name>`` for ``args == (roles, name)``."""
    roles, name = args
    rendered_roles = _lean_role_list(roles)
    rendered_name = _lean_string(str(name))
    return " ".join(["_root_.findRole", rendered_roles, rendered_name])


def _lean_call_rbac_has_direct_permission(args: tuple[Any, ...]) -> str:
    """Lean call for ``args == (role, resource, action)``."""
    role, resource, action = args
    parts = [
        "_root_.hasDirectPermission",
        _lean_role(role),
        _lean_string(str(resource)),
        _lean_string(str(action)),
    ]
    return " ".join(parts)


def _lean_call_rbac_can_access(args: tuple[Any, ...]) -> str:
    """Lean call for ``args == (roles, role_name, resource, action[, depth])``.

    ``depth`` defaults to 5 when omitted, matching the Python oracle.
    """
    roles, role_name, resource, action, *rest = args
    depth = int(rest[0]) if rest else 5
    parts = [
        "_root_.canAccess",
        _lean_role_list(roles),
        _lean_string(str(role_name)),
        _lean_string(str(resource)),
        _lean_string(str(action)),
        _lean_int(depth),
    ]
    return " ".join(parts)


def _lean_call_pricing_subtotal(args: tuple[Any, ...]) -> str:
    """Lean call ``subtotal <order>`` for ``args == (order,)``."""
    (order,) = args
    return "_root_.subtotal " + _lean_order(order)


def _lean_call_pricing_tax_rate_bps(args: tuple[Any, ...]) -> str:
    """Lean call ``taxRateBps <region>`` for ``args == (region_id,)``."""
    (region_id,) = args
    return "_root_.taxRateBps " + _lean_string(str(region_id))


def _lean_call_pricing_coupon_discount(args: tuple[Any, ...]) -> str:
    """Lean call ``couponDiscount <order>`` for ``args == (order,)``."""
    (order,) = args
    return "_root_.couponDiscount " + _lean_order(order)


def _lean_call_pricing_loyalty_discount(args: tuple[Any, ...]) -> str:
    """Lean call ``loyaltyDiscount <order>`` for ``args == (order,)``."""
    (order,) = args
    return "_root_.loyaltyDiscount " + _lean_order(order)


def _lean_call_pricing_final_price(args: tuple[Any, ...]) -> str:
    """Lean call ``finalPrice <order>`` for ``args == (order,)``."""
    (order,) = args
    return "_root_.finalPrice " + _lean_order(order)


def _lean_call_saga_transition(args: tuple[Any, ...]) -> str:
    """Lean call ``transition <state> <event>`` for ``args == (state, event)``."""
    state, event = args
    rendered_state = _lean_saga_state(str(state))
    rendered_event = _lean_saga_event(str(event))
    return " ".join(["_root_.transition", rendered_state, rendered_event])


def _lean_call_saga_run(args: tuple[Any, ...]) -> str:
    """Lean call ``runSaga [<events>]`` for ``args == (events,)``."""
    (events,) = args
    rendered = [_lean_saga_event(str(event)) for event in events]
    return "_root_.runSaga " + _lean_list(rendered)


def _lean_call_saga_is_charged(args: tuple[Any, ...]) -> str:
    """Lean call ``isCharged <state>`` for ``args == (state,)``."""
    (state,) = args
    return "_root_.isCharged " + _lean_saga_state(str(state))
# ── Path helpers ──────────────────────────────────────────────────────────────
def _lean_str_list(segments: list[str]) -> str:
    """Render path segments as a Lean list of string literals."""
    quoted = [_lean_string(segment) for segment in segments]
    return "[" + ", ".join(quoted) + "]"
def _oracle_path_normalize(args: tuple[Any, ...]) -> list[str]:
    """Resolve ``"."`` and ``".."`` in ``args == (segments,)``.

    ``"."`` is dropped, ``".."`` removes the previous kept segment (and is a
    no-op at the root), and every other segment, including ``""``, is kept.
    """
    (segments,) = args
    stack: list[str] = []
    for segment in segments:
        if segment == "..":
            if stack:
                stack.pop()
        elif segment != ".":
            stack.append(segment)
    return stack


def _oracle_path_join(args: tuple[Any, ...]) -> list[str]:
    """Concatenate ``args == (base, rel)`` and normalise the result."""
    base, rel = args
    return _oracle_path_normalize(([*base, *rel],))


def _oracle_path_depth(args: tuple[Any, ...]) -> int:
    """Number of segments in ``args == (segments,)``."""
    (segments,) = args
    return len(segments)
def _lean_call_path_normalize(args: tuple[Any, ...]) -> str:
    """Lean call ``normalizePath [<segments>]`` for ``args == (segments,)``."""
    (segments,) = args
    return "_root_.normalizePath " + _lean_str_list(list(segments))


def _lean_call_path_join(args: tuple[Any, ...]) -> str:
    """Lean call ``joinPaths [<base>] [<rel>]`` for ``args == (base, rel)``."""
    base, rel = args
    rendered_base = _lean_str_list(list(base))
    rendered_rel = _lean_str_list(list(rel))
    return " ".join(("_root_.joinPaths", rendered_base, rendered_rel))


def _lean_call_path_depth(args: tuple[Any, ...]) -> str:
    """Lean call ``pathDepth [<segments>]`` for ``args == (segments,)``."""
    (segments,) = args
    return "_root_.pathDepth " + _lean_str_list(list(segments))


def _lean_value_str_list(value: Any) -> str:
    """Render a list of path segments as a Lean ``[String]`` literal."""
    return _lean_str_list(list(value))


def _lean_value_nat(value: Any) -> str:
    """Render *value* (coerced with ``int``) as a bare numeral."""
    return f"{int(value)}"
# ── Expression helpers ────────────────────────────────────────────────────────
def _lean_op(op: str) -> str:
    """Render an operator name as Lean dot-constructor syntax (``.Add``)."""
    return f".{op}"


def _lean_expr(expr: dict[str, Any]) -> str:
    """Render an expression-tree dict as a parenthesised Lean term.

    ``"Lit"`` nodes become ``(.Lit n)``, with negative ``n`` wrapped in its
    own parentheses. Any other tag is treated as a binary node with
    ``"op"``, ``"l"`` and ``"r"`` and becomes ``(.BinOp .op l r)``.
    """
    if expr["tag"] == "Lit":
        value = int(expr["n"])
        literal = f"({value})" if value < 0 else str(value)
        return f"(.Lit {literal})"
    operator = _lean_op(expr["op"])
    left = _lean_expr(expr["l"])
    right = _lean_expr(expr["r"])
    return f"(.BinOp {operator} {left} {right})"
def _oracle_expr_eval_bin_op(args: tuple[Any, ...]) -> int | None:
    """Apply one binary operator to two integers.

    ``args`` is ``(op, a, b)`` with ``op`` one of ``"Add"``, ``"Sub"``,
    ``"Mul"``, ``"Div"``. ``"Div"`` truncates toward zero. Division by zero
    and unknown operators yield ``None``.
    """
    op, a, b = args
    a, b = int(a), int(b)
    if op == "Add":
        return a + b
    if op == "Sub":
        return a - b
    if op == "Mul":
        return a * b
    if op == "Div":
        if b == 0:
            return None
        # Exact truncating division in pure integer arithmetic; going through
        # a float (int(a / b)) loses precision once |a / b| exceeds 2**53.
        quotient = abs(a) // abs(b)
        return quotient if (a < 0) == (b < 0) else -quotient
    return None


def _oracle_expr_eval_expr(args: tuple[Any, ...]) -> int | None:
    """Evaluate an expression tree bottom-up; ``args`` is ``(expr,)``.

    A node tagged ``"Lit"`` yields ``int(node["n"])``. Any other node is a
    binary node with ``"op"``, ``"l"`` and ``"r"``. ``None`` from either
    operand (division by zero, unknown operator) propagates to the result.
    """
    (expr,) = args

    def _eval(e: dict[str, Any]) -> int | None:
        if e["tag"] == "Lit":
            return int(e["n"])
        op = e["op"]
        lv = _eval(e["l"])
        rv = _eval(e["r"])
        if lv is None or rv is None:
            return None
        return _oracle_expr_eval_bin_op((op, lv, rv))

    return _eval(expr)
def _lean_value_option_int(value: Any) -> str:
    """Render an optional integer as ``none`` or ``some (n)``."""
    return "none" if value is None else f"some ({int(value)})"


def _lean_call_eval_bin_op(args: tuple[Any, ...]) -> str:
    """Lean call ``evalBinOp .op a b`` for ``args == (op, a, b)``.

    Negative operands are parenthesised so they stay single arguments.
    """
    op, lhs, rhs = args
    operands = []
    for raw in (lhs, rhs):
        number = int(raw)
        operands.append(str(number) if number >= 0 else f"({number})")
    return f"_root_.evalBinOp .{op} {operands[0]} {operands[1]}"
def _lean_call_eval_expr(args: tuple[Any, ...]) -> str:
    """Lean call ``evalExpr <expr>`` for ``args == (expr,)``."""
    (expr,) = args
    return "_root_.evalExpr " + _lean_expr(expr)
# ── LRU helpers ───────────────────────────────────────────────────────────────
def _lean_nat_pair(p: tuple[int, int]) -> str:
    """Render the first two entries of *p* (coerced with ``int``) as ``(k, v)``."""
    key, val = int(p[0]), int(p[1])
    return f"({key}, {val})"


def _lean_nat_pair_list(cache: list[tuple[int, int]]) -> str:
    """Render cache entries as a Lean list of pairs."""
    rendered = [_lean_nat_pair(entry) for entry in cache]
    return "[" + ", ".join(rendered) + "]"
def _oracle_lru_evict(args: tuple[Any, ...]) -> list[tuple[int, int]]:
    """Keep the first ``cap`` entries of ``args == (cache, cap)``.

    Put and get move touched keys to the front, so the front holds the most
    recently used entries.
    """
    cache, cap = args
    limit = int(cap)
    return [*cache][:limit]


def _oracle_lru_put(args: tuple[Any, ...]) -> list[tuple[int, int]]:
    """Insert or refresh ``key -> val`` at the front, then truncate to ``cap``.

    ``args`` is ``(cache, cap, key, val)``; any older entry for ``key`` is
    dropped.
    """
    cache, cap, key, val = args
    cap, key, val = int(cap), int(key), int(val)
    refreshed = [(key, val)]
    refreshed.extend((k, v) for k, v in cache if int(k) != key)
    return refreshed[:cap]


def _oracle_lru_get(args: tuple[Any, ...]) -> tuple[int | None, list[tuple[int, int]]]:
    """Look up ``key`` in ``args == (cache, key)``.

    On a hit, returns ``(value, cache)`` with the entry moved to the front.
    On a miss (or a stored ``None`` value), returns ``(None, copy of cache)``.
    """
    cache, key = args
    target = int(key)
    hit = None
    for k, v in cache:
        if int(k) == target:
            hit = v
            break
    if hit is None:
        return (None, list(cache))
    reordered = [(target, hit)]
    reordered += [(k, v) for k, v in cache if int(k) != target]
    return (hit, reordered)
def _lean_nat_pairs_from(entries: Any) -> str:
    """Coerce ``(key, value)`` entries with ``int`` and render a Lean pair list."""
    return _lean_nat_pair_list([(int(k), int(v)) for k, v in entries])


def _lean_value_lru_cache(value: Any) -> str:
    """Render a cache (list of ``(key, value)``) as a Lean pair list."""
    return _lean_nat_pairs_from(value)


def _lean_value_lru_get(value: Any) -> str:
    """Render an ``lruGet`` result ``(opt, cache)`` as a Lean pair."""
    opt, cache = value
    opt_str = "none" if opt is None else f"some {int(opt)}"
    return f"({opt_str}, {_lean_nat_pairs_from(cache)})"


def _lean_call_lru_evict(args: tuple[Any, ...]) -> str:
    """Lean call ``lruEvict <cache> <cap>`` for ``args == (cache, cap)``."""
    cache, cap = args
    return " ".join(["_root_.lruEvict", _lean_nat_pairs_from(cache), str(int(cap))])


def _lean_call_lru_put(args: tuple[Any, ...]) -> str:
    """Lean call ``lruPut <cache> <cap> <key> <val>``."""
    cache, cap, key, val = args
    parts = [
        "_root_.lruPut",
        _lean_nat_pairs_from(cache),
        str(int(cap)),
        str(int(key)),
        str(int(val)),
    ]
    return " ".join(parts)


def _lean_call_lru_get(args: tuple[Any, ...]) -> str:
    """Lean call ``lruGet <cache> <key>`` for ``args == (cache, key)``."""
    cache, key = args
    return " ".join(["_root_.lruGet", _lean_nat_pairs_from(cache), str(int(key))])
# ── Shortest-path helpers ─────────────────────────────────────────────────────
def _lean_edge_list(edges: list[tuple[int, int, int]]) -> str:
    """Render ``(from, to, weight)`` triples as a Lean list of 3-tuples."""
    rendered = []
    for source, target, weight in edges:
        rendered.append(f"({int(source)}, {int(target)}, {int(weight)})")
    return "[" + ", ".join(rendered) + "]"
def _oracle_dijkstra(args: tuple[Any, ...]) -> list[int | None]:
    """Single-source shortest distances over a directed weighted graph.

    ``args`` is ``(edges, n, src)``, where ``edges`` are ``(from, to, weight)``
    triples over nodes ``0 .. n-1``. Returns a list of length ``n`` with each
    node's distance from ``src``, or ``None`` if it is unreachable. Edges with
    an endpoint outside ``0 .. n-1`` are ignored. An out-of-range ``src``
    (including a negative one) leaves every entry ``None``.

    NOTE(review): assumes non-negative weights. A stale-entry check lets a
    node be re-relaxed, so negative edges without a negative cycle still
    converge, but a negative cycle would loop forever.
    """
    import heapq

    edges, n, src = args
    n, src = int(n), int(src)
    dist: list[int | None] = [None] * n
    if not 0 <= src < n:
        return dist
    # Build adjacency once instead of rescanning every edge on every pop.
    # Per-node edge order matches the input order, so relaxation order is
    # unchanged.
    adjacency: list[list[tuple[int, int]]] = [[] for _ in range(n)]
    for fr, to, w in edges:
        fr, to, w = int(fr), int(to), int(w)
        # Explicit lower bounds: Python would silently treat a negative index
        # as counting from the end of the list.
        if 0 <= fr < n and 0 <= to < n:
            adjacency[fr].append((to, w))
    dist[src] = 0
    pq: list[tuple[int, int]] = [(0, src)]
    while pq:
        d, u = heapq.heappop(pq)
        if dist[u] is not None and d > dist[u]:
            continue  # stale queue entry; a shorter path was already found
        for to, w in adjacency[u]:
            nd = d + w
            if dist[to] is None or nd < dist[to]:
                dist[to] = nd
                heapq.heappush(pq, (nd, to))
    return dist


def _oracle_shortest_dist(args: tuple[Any, ...]) -> int | None:
    """Distance from ``src`` to ``dst`` for ``args == (edges, n, src, dst)``.

    Returns ``None`` when ``dst`` is unreachable or outside ``0 .. n-1``.
    """
    edges, n, src, dst = args
    dists = _oracle_dijkstra((edges, n, src))
    dst = int(dst)
    # Reject negative indices too, rather than reading from the list's end.
    if not 0 <= dst < len(dists):
        return None
    return dists[dst]
def _lean_value_dist_list(value: Any) -> str:
    """Render a list of optional distances as a Lean ``[Option Nat]`` literal."""
    rendered = [_lean_value_option_nat(entry) for entry in value]
    return "[" + ", ".join(rendered) + "]"


def _lean_value_option_nat(value: Any) -> str:
    """Render an optional number as ``none`` or ``some n``."""
    if value is not None:
        return f"some {int(value)}"
    return "none"
def _lean_call_dijkstra(args: tuple[Any, ...]) -> str:
    """Lean call ``dijkstra <edges> <n> <src>`` for ``args == (edges, n, src)``."""
    edges, n, src = args
    parts = ["_root_.dijkstra", _lean_edge_list(edges), str(int(n)), str(int(src))]
    return " ".join(parts)


def _lean_call_shortest_dist(args: tuple[Any, ...]) -> str:
    """Lean call ``shortestDist <edges> <n> <src> <dst>``."""
    edges, n, src, dst = args
    parts = [
        "_root_.shortestDist",
        _lean_edge_list(edges),
        str(int(n)),
        str(int(src)),
        str(int(dst)),
    ]
    return " ".join(parts)
# ── Interval helpers ──────────────────────────────────────────────────────────
def _lean_int_pair(iv: tuple[int, int]) -> str:
    """Render the first two entries of *iv* as a Lean ``(Int, Int)`` pair.

    Negative components are parenthesised.
    """
    components = (int(iv[0]), int(iv[1]))
    rendered = [str(x) if x >= 0 else f"({x})" for x in components]
    return "(" + ", ".join(rendered) + ")"


def _lean_int_pair_list(ivs: list[tuple[int, int]]) -> str:
    """Render intervals as a Lean list of integer pairs."""
    return "[" + ", ".join(map(_lean_int_pair, ivs)) + "]"
def _oracle_merge_intervals(args: tuple[Any, ...]) -> list[tuple[int, int]]:
    """Merge overlapping or touching closed intervals.

    ``args`` is ``(ivs,)``, where each interval is a ``(start, end)`` pair
    coerced with ``int``. Intervals are sorted by ``(start, end)``. An
    interval whose start is ``<=`` the current merged end (so touching
    intervals merge too) is folded into it. Returns ``(int, int)`` tuples.
    """
    (ivs,) = args
    # Normalise every interval, including the first, to an (int, int) tuple
    # so comparisons never mix raw input values with ints.
    normalised = sorted((int(start), int(end)) for start, end in ivs)
    merged: list[tuple[int, int]] = []
    for start, end in normalised:
        if merged and start <= merged[-1][1]:
            last_start, last_end = merged[-1]
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged
def _oracle_max_schedule(args: tuple[Any, ...]) -> list[tuple[int, int]]:
    """Greedy earliest-finish selection of non-overlapping intervals.

    ``args`` is ``(ivs,)``. Intervals are visited in ``(end, start)`` order.
    One is taken when its start is ``>=`` the end of the last taken interval,
    so intervals that merely touch are compatible.
    """
    (ivs,) = args
    if not ivs:
        return []
    chosen: list[tuple[int, int]] = []
    frontier = None
    for start, end in sorted(ivs, key=lambda iv: (int(iv[1]), int(iv[0]))):
        start, end = int(start), int(end)
        if frontier is not None and start < frontier:
            continue  # overlaps the previously chosen interval
        chosen.append((start, end))
        frontier = end
    return chosen
def _lean_value_int_pair_list(value: Any) -> str:
    """Render an interval-oracle result as a Lean list of integer pairs."""
    return _lean_int_pair_list([(int(lo), int(hi)) for lo, hi in value])


def _lean_call_merge_intervals(args: tuple[Any, ...]) -> str:
    """Lean call ``mergeIntervals <ivs>`` for ``args == (ivs,)``."""
    (ivs,) = args
    pairs = [(int(lo), int(hi)) for lo, hi in ivs]
    return "_root_.mergeIntervals " + _lean_int_pair_list(pairs)


def _lean_call_max_schedule(args: tuple[Any, ...]) -> str:
    """Lean call ``maxSchedule <ivs>`` for ``args == (ivs,)``."""
    (ivs,) = args
    pairs = [(int(lo), int(hi)) for lo, hi in ivs]
    return "_root_.maxSchedule " + _lean_int_pair_list(pairs)
# Registry of verification semantics, keyed by (task_id, function_name).
# get_function_semantics and the oracle_result / lean_value / lean_call
# wrappers below look entries up here; the key pair is what callers pass.
FUNCTION_SEMANTICS: dict[tuple[str, str], FunctionSemantics] = {
    # RBAC authorization
    ("rbac_auth", "findRole"): FunctionSemantics(
        oracle=_oracle_rbac_find_role,
        lean_value=_lean_value_rbac_find_role,
        lean_call=_lean_call_rbac_find_role,
    ),
    ("rbac_auth", "hasDirectPermission"): FunctionSemantics(
        oracle=_oracle_rbac_has_direct_permission,
        lean_value=_lean_value_rbac_bool,
        lean_call=_lean_call_rbac_has_direct_permission,
    ),
    ("rbac_auth", "canAccess"): FunctionSemantics(
        oracle=_oracle_rbac_can_access,
        lean_value=_lean_value_rbac_bool,
        lean_call=_lean_call_rbac_can_access,
    ),
    # Pricing engine
    ("pricing_engine", "subtotal"): FunctionSemantics(
        oracle=_oracle_pricing_subtotal,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_subtotal,
    ),
    ("pricing_engine", "taxRateBps"): FunctionSemantics(
        oracle=_oracle_pricing_tax_rate_bps,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_tax_rate_bps,
    ),
    ("pricing_engine", "couponDiscount"): FunctionSemantics(
        oracle=_oracle_pricing_coupon_discount,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_coupon_discount,
    ),
    ("pricing_engine", "loyaltyDiscount"): FunctionSemantics(
        oracle=_oracle_pricing_loyalty_discount,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_loyalty_discount,
    ),
    ("pricing_engine", "finalPrice"): FunctionSemantics(
        oracle=_oracle_pricing_final_price,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_final_price,
    ),
    # Payment saga state machine
    ("payment_saga", "transition"): FunctionSemantics(
        oracle=_oracle_saga_transition,
        lean_value=_lean_value_saga_state,
        lean_call=_lean_call_saga_transition,
    ),
    ("payment_saga", "runSaga"): FunctionSemantics(
        oracle=_oracle_saga_run,
        lean_value=_lean_value_saga_state,
        lean_call=_lean_call_saga_run,
    ),
    ("payment_saga", "isCharged"): FunctionSemantics(
        oracle=_oracle_saga_is_charged,
        lean_value=_lean_value_saga_bool,
        lean_call=_lean_call_saga_is_charged,
    ),
    # Path canonicalizer
    ("path_canonicalizer", "normalizePath"): FunctionSemantics(
        oracle=_oracle_path_normalize,
        lean_value=_lean_value_str_list,
        lean_call=_lean_call_path_normalize,
    ),
    ("path_canonicalizer", "joinPaths"): FunctionSemantics(
        oracle=_oracle_path_join,
        lean_value=_lean_value_str_list,
        lean_call=_lean_call_path_join,
    ),
    ("path_canonicalizer", "pathDepth"): FunctionSemantics(
        oracle=_oracle_path_depth,
        lean_value=_lean_value_nat,
        lean_call=_lean_call_path_depth,
    ),
    # Expression evaluator
    ("expression_eval", "evalBinOp"): FunctionSemantics(
        oracle=_oracle_expr_eval_bin_op,
        lean_value=_lean_value_option_int,
        lean_call=_lean_call_eval_bin_op,
    ),
    ("expression_eval", "evalExpr"): FunctionSemantics(
        oracle=_oracle_expr_eval_expr,
        lean_value=_lean_value_option_int,
        lean_call=_lean_call_eval_expr,
    ),
    # LRU cache
    ("lru_cache", "lruEvict"): FunctionSemantics(
        oracle=_oracle_lru_evict,
        lean_value=_lean_value_lru_cache,
        lean_call=_lean_call_lru_evict,
    ),
    ("lru_cache", "lruPut"): FunctionSemantics(
        oracle=_oracle_lru_put,
        lean_value=_lean_value_lru_cache,
        lean_call=_lean_call_lru_put,
    ),
    # lruGet returns (optional value, updated cache), hence its own renderer.
    ("lru_cache", "lruGet"): FunctionSemantics(
        oracle=_oracle_lru_get,
        lean_value=_lean_value_lru_get,
        lean_call=_lean_call_lru_get,
    ),
    # Shortest path
    ("shortest_path", "dijkstra"): FunctionSemantics(
        oracle=_oracle_dijkstra,
        lean_value=_lean_value_dist_list,
        lean_call=_lean_call_dijkstra,
    ),
    ("shortest_path", "shortestDist"): FunctionSemantics(
        oracle=_oracle_shortest_dist,
        lean_value=_lean_value_option_nat,
        lean_call=_lean_call_shortest_dist,
    ),
    # Interval scheduler
    ("interval_scheduler", "mergeIntervals"): FunctionSemantics(
        oracle=_oracle_merge_intervals,
        lean_value=_lean_value_int_pair_list,
        lean_call=_lean_call_merge_intervals,
    ),
    ("interval_scheduler", "maxSchedule"): FunctionSemantics(
        oracle=_oracle_max_schedule,
        lean_value=_lean_value_int_pair_list,
        lean_call=_lean_call_max_schedule,
    ),
}
def get_function_semantics(task_id: str, function_name: str) -> FunctionSemantics:
    """Return the semantics registered for ``task_id.function_name``.

    Raises:
        ValueError: if no entry exists in ``FUNCTION_SEMANTICS``. The message
            lists every registered ``task.function`` pair (sorted), and the
            underlying ``KeyError`` is chained as the cause.
    """
    lookup_key = (task_id, function_name)
    try:
        return FUNCTION_SEMANTICS[lookup_key]
    except KeyError as missing:
        registered = [f"{task}.{func}" for task, func in sorted(FUNCTION_SEMANTICS)]
        available = ", ".join(registered)
        message = (
            f"No verification semantics available for {task_id}.{function_name}. "
            f"Available: {available}"
        )
        raise ValueError(message) from missing
def oracle_result(task_id: str, function_name: str, args: tuple[Any, ...]) -> Any:
    """Evaluate the Python oracle for ``task_id.function_name`` on ``args``.

    Raises ``ValueError`` (via ``get_function_semantics``) for unknown pairs.
    """
    semantics = get_function_semantics(task_id, function_name)
    return semantics.oracle(args)


def lean_value(task_id: str, function_name: str, value: Any) -> str:
    """Render an oracle result for ``task_id.function_name`` as Lean text."""
    semantics = get_function_semantics(task_id, function_name)
    return semantics.lean_value(value)


def lean_call(task_id: str, function_name: str, args: tuple[Any, ...]) -> str:
    """Render the Lean call expression for ``task_id.function_name`` on ``args``."""
    semantics = get_function_semantics(task_id, function_name)
    return semantics.lean_call(args)
# Public API: the registry, its record type, and the lookup/dispatch helpers.
# The underscore-prefixed encoders and oracles above are internal.
__all__ = [
    "FunctionSemantics",
    "FUNCTION_SEMANTICS",
    "get_function_semantics",
    "lean_call",
    "lean_value",
    "oracle_result",
]