from __future__ import annotations import json from dataclasses import dataclass from typing import Any, Callable @dataclass(frozen=True) class FunctionSemantics: oracle: Callable[[tuple[Any, ...]], Any] lean_value: Callable[[Any], str] lean_call: Callable[[tuple[Any, ...]], str] def _lean_string(value: str) -> str: return json.dumps(value) def _lean_bool(value: bool) -> str: return "true" if value else "false" def _lean_int(value: int) -> str: return str(int(value)) def _lean_list(items: list[str]) -> str: return "[" + ", ".join(items) + "]" def _lean_permission(permission: dict[str, Any]) -> str: return ( "{ resource := " + _lean_string(str(permission["resource"])) + ", action := " + _lean_string(str(permission["action"])) + " }" ) def _lean_role(role: dict[str, Any]) -> str: permissions = _lean_list( [_lean_permission(permission) for permission in role.get("permissions", [])] ) inherits = _lean_list([_lean_string(name) for name in role.get("inherits", [])]) return ( "{ name := " + _lean_string(str(role["name"])) + ", permissions := " + permissions + ", inherits := " + inherits + " }" ) def _lean_role_list(roles: list[dict[str, Any]]) -> str: return _lean_list([_lean_role(role) for role in roles]) def _lean_option_role(role: dict[str, Any] | None) -> str: if role is None: return "none" return "some (" + _lean_role(role) + " : AuthSpec.Role)" def _lean_item(item: dict[str, Any]) -> str: return ( "{ sku := " + _lean_string(str(item["sku"])) + ", quantity := " + _lean_int(int(item["quantity"])) + ", unitPrice := " + _lean_int(int(item["unitPrice"])) + " }" ) def _lean_coupon(coupon: dict[str, Any]) -> str: return ( "{ code := " + _lean_string(str(coupon["code"])) + ", discountPercent := " + _lean_int(int(coupon["discountPercent"])) + " }" ) def _lean_order(order: dict[str, Any]) -> str: items = _lean_list([_lean_item(item) for item in order.get("items", [])]) coupons = _lean_list([_lean_coupon(coupon) for coupon in order.get("coupons", [])]) return ( "{ items := " + 
items + ", coupons := " + coupons + ", regionId := " + _lean_string(str(order.get("regionId", ""))) + ", loyaltyPoints := " + _lean_int(int(order.get("loyaltyPoints", 0))) + " }" ) def _lean_saga_state(state: str) -> str: mapping = { "Idle": ".Idle", "Reserved": ".Reserved", "Authorized": ".Authorized", "Captured": ".Captured", "Settled": ".Settled", "Compensating": ".Compensating", "Compensated": ".Compensated", "Failed": ".Failed", } return mapping[state] def _lean_saga_event(event: str) -> str: mapping = { "Reserve": ".Reserve", "Authorize": ".Authorize", "Capture": ".Capture", "Settle": ".Settle", "CompensateReserve": ".CompensateReserve", "CompensateAuthorize": ".CompensateAuthorize", "CompensateCapture": ".CompensateCapture", "Fail": ".Fail", } return mapping[event] def _oracle_rbac_find_role(args: tuple[Any, ...]) -> dict[str, Any] | None: roles, name = args return next((role for role in roles if role["name"] == name), None) def _oracle_rbac_has_direct_permission(args: tuple[Any, ...]) -> bool: role, resource, action = args return any( permission["resource"] == resource and permission["action"] == action for permission in role.get("permissions", []) ) def _oracle_rbac_can_access(args: tuple[Any, ...]) -> bool: roles, role_name, resource, action, *rest = args depth = int(rest[0]) if rest else 5 if depth == 0: return False role = _oracle_rbac_find_role((roles, role_name)) if role is None: return False if _oracle_rbac_has_direct_permission((role, resource, action)): return True return any( _oracle_rbac_can_access((roles, parent_name, resource, action, depth - 1)) for parent_name in role.get("inherits", []) ) def _oracle_pricing_tax_rate_bps(args: tuple[Any, ...]) -> int: (region_id,) = args return { "US-CA": 875, "US-TX": 625, "US-NY": 800, "UK": 2000, }.get(region_id, 0) def _oracle_pricing_subtotal(args: tuple[Any, ...]) -> int: (order,) = args return sum( int(item["unitPrice"]) * int(item["quantity"]) for item in order.get("items", []) ) def 
_oracle_pricing_coupon_discount(args: tuple[Any, ...]) -> int: (order,) = args subtotal = _oracle_pricing_subtotal((order,)) raw_discount = sum( (subtotal * int(coupon["discountPercent"])) // 100 for coupon in order.get("coupons", []) ) return min(raw_discount, subtotal // 2) def _oracle_pricing_loyalty_discount(args: tuple[Any, ...]) -> int: (order,) = args subtotal = _oracle_pricing_subtotal((order,)) return min(int(order.get("loyaltyPoints", 0)), subtotal // 10) def _oracle_pricing_final_price(args: tuple[Any, ...]) -> int: (order,) = args subtotal = _oracle_pricing_subtotal((order,)) total_discount = _oracle_pricing_coupon_discount( (order,) ) + _oracle_pricing_loyalty_discount((order,)) after_discount = subtotal - total_discount tax = ( after_discount * _oracle_pricing_tax_rate_bps((order.get("regionId", ""),)) ) // 10000 return after_discount + tax def _oracle_saga_transition(args: tuple[Any, ...]) -> str: state, event = args if event == "Fail": return "Failed" transitions = { ("Idle", "Reserve"): "Reserved", ("Reserved", "Authorize"): "Authorized", ("Authorized", "Capture"): "Captured", ("Captured", "Settle"): "Settled", ("Reserved", "CompensateReserve"): "Compensated", ("Authorized", "CompensateAuthorize"): "Compensating", ("Compensating", "CompensateReserve"): "Compensated", ("Captured", "CompensateCapture"): "Compensating", } return transitions.get((state, event), state) def _oracle_saga_run(args: tuple[Any, ...]) -> str: (events,) = args state = "Idle" for event in events: state = _oracle_saga_transition((state, event)) return state def _oracle_saga_is_charged(args: tuple[Any, ...]) -> bool: (state,) = args return state in {"Captured", "Settled"} def _lean_value_rbac_find_role(value: Any) -> str: return _lean_option_role(value) def _lean_value_rbac_bool(value: Any) -> str: return _lean_bool(bool(value)) def _lean_value_pricing_number(value: Any) -> str: return _lean_int(int(value)) def _lean_value_saga_state(value: Any) -> str: return 
_lean_saga_state(str(value)) def _lean_value_saga_bool(value: Any) -> str: return _lean_bool(bool(value)) def _lean_call_rbac_find_role(args: tuple[Any, ...]) -> str: roles, name = args return f"_root_.findRole {_lean_role_list(roles)} {_lean_string(str(name))}" def _lean_call_rbac_has_direct_permission(args: tuple[Any, ...]) -> str: role, resource, action = args return ( f"_root_.hasDirectPermission {_lean_role(role)} " f"{_lean_string(str(resource))} {_lean_string(str(action))}" ) def _lean_call_rbac_can_access(args: tuple[Any, ...]) -> str: roles, role_name, resource, action, *rest = args depth = int(rest[0]) if rest else 5 return ( f"_root_.canAccess {_lean_role_list(roles)} " f"{_lean_string(str(role_name))} {_lean_string(str(resource))} " f"{_lean_string(str(action))} {_lean_int(depth)}" ) def _lean_call_pricing_subtotal(args: tuple[Any, ...]) -> str: (order,) = args return f"_root_.subtotal {_lean_order(order)}" def _lean_call_pricing_tax_rate_bps(args: tuple[Any, ...]) -> str: (region_id,) = args return f"_root_.taxRateBps {_lean_string(str(region_id))}" def _lean_call_pricing_coupon_discount(args: tuple[Any, ...]) -> str: (order,) = args return f"_root_.couponDiscount {_lean_order(order)}" def _lean_call_pricing_loyalty_discount(args: tuple[Any, ...]) -> str: (order,) = args return f"_root_.loyaltyDiscount {_lean_order(order)}" def _lean_call_pricing_final_price(args: tuple[Any, ...]) -> str: (order,) = args return f"_root_.finalPrice {_lean_order(order)}" def _lean_call_saga_transition(args: tuple[Any, ...]) -> str: state, event = args return f"_root_.transition {_lean_saga_state(str(state))} {_lean_saga_event(str(event))}" def _lean_call_saga_run(args: tuple[Any, ...]) -> str: (events,) = args return f"_root_.runSaga {_lean_list([_lean_saga_event(str(event)) for event in events])}" def _lean_call_saga_is_charged(args: tuple[Any, ...]) -> str: (state,) = args return f"_root_.isCharged {_lean_saga_state(str(state))}" # ── Path helpers 
────────────────────────────────────────────────────────────── def _lean_str_list(segments: list[str]) -> str: return "[" + ", ".join(_lean_string(s) for s in segments) + "]" def _oracle_path_normalize(args: tuple[Any, ...]) -> list[str]: (segments,) = args acc: list[str] = [] for seg in segments: if seg == ".": pass elif seg == "..": if acc: acc.pop() else: acc.append(seg) return acc def _oracle_path_join(args: tuple[Any, ...]) -> list[str]: base, rel = args return _oracle_path_normalize((list(base) + list(rel),)) def _oracle_path_depth(args: tuple[Any, ...]) -> int: (segments,) = args return len(segments) def _lean_call_path_normalize(args: tuple[Any, ...]) -> str: (segments,) = args return f"_root_.normalizePath {_lean_str_list(list(segments))}" def _lean_call_path_join(args: tuple[Any, ...]) -> str: base, rel = args return f"_root_.joinPaths {_lean_str_list(list(base))} {_lean_str_list(list(rel))}" def _lean_call_path_depth(args: tuple[Any, ...]) -> str: (segments,) = args return f"_root_.pathDepth {_lean_str_list(list(segments))}" def _lean_value_str_list(value: Any) -> str: return _lean_str_list(list(value)) def _lean_value_nat(value: Any) -> str: return str(int(value)) # ── Expression helpers ──────────────────────────────────────────────────────── def _lean_op(op: str) -> str: return f".{op}" def _lean_expr(expr: dict[str, Any]) -> str: if expr["tag"] == "Lit": n = int(expr["n"]) if n < 0: return f"(.Lit ({n}))" return f"(.Lit {n})" op = expr["op"] l = _lean_expr(expr["l"]) r = _lean_expr(expr["r"]) return f"(.BinOp .{op} {l} {r})" def _oracle_expr_eval_bin_op(args: tuple[Any, ...]) -> int | None: op, a, b = args a, b = int(a), int(b) if op == "Add": return a + b if op == "Sub": return a - b if op == "Mul": return a * b if op == "Div": return None if b == 0 else int(a / b) # truncate toward zero return None def _oracle_expr_eval_expr(args: tuple[Any, ...]) -> int | None: (expr,) = args def _eval(e: dict[str, Any]) -> int | None: if e["tag"] == "Lit": return 
int(e["n"]) op = e["op"] lv = _eval(e["l"]) rv = _eval(e["r"]) if lv is None or rv is None: return None return _oracle_expr_eval_bin_op((op, lv, rv)) return _eval(expr) def _lean_value_option_int(value: Any) -> str: if value is None: return "none" return f"some ({int(value)})" def _lean_call_eval_bin_op(args: tuple[Any, ...]) -> str: op, a, b = args a, b = int(a), int(b) a_str = f"({a})" if a < 0 else str(a) b_str = f"({b})" if b < 0 else str(b) return f"_root_.evalBinOp .{op} {a_str} {b_str}" def _lean_call_eval_expr(args: tuple[Any, ...]) -> str: (expr,) = args return f"_root_.evalExpr {_lean_expr(expr)}" # ── LRU helpers ─────────────────────────────────────────────────────────────── def _lean_nat_pair(p: tuple[int, int]) -> str: return f"({int(p[0])}, {int(p[1])})" def _lean_nat_pair_list(cache: list[tuple[int, int]]) -> str: return "[" + ", ".join(_lean_nat_pair(p) for p in cache) + "]" def _oracle_lru_evict(args: tuple[Any, ...]) -> list[tuple[int, int]]: cache, cap = args return list(cache)[:int(cap)] def _oracle_lru_put(args: tuple[Any, ...]) -> list[tuple[int, int]]: cache, cap, key, val = args cap, key, val = int(cap), int(key), int(val) filtered = [(k, v) for k, v in cache if int(k) != key] return ([(key, val)] + filtered)[:cap] def _oracle_lru_get(args: tuple[Any, ...]) -> tuple[int | None, list[tuple[int, int]]]: cache, key = args key = int(key) found = next((v for k, v in cache if int(k) == key), None) if found is None: return (None, list(cache)) new_cache = [(key, found)] + [(k, v) for k, v in cache if int(k) != key] return (found, new_cache) def _lean_value_lru_cache(value: Any) -> str: return _lean_nat_pair_list([(int(k), int(v)) for k, v in value]) def _lean_value_lru_get(value: Any) -> str: opt, cache = value opt_str = "none" if opt is None else f"some {int(opt)}" cache_str = _lean_nat_pair_list([(int(k), int(v)) for k, v in cache]) return f"({opt_str}, {cache_str})" def _lean_call_lru_evict(args: tuple[Any, ...]) -> str: cache, cap = args return 
f"_root_.lruEvict {_lean_nat_pair_list([(int(k),int(v)) for k,v in cache])} {int(cap)}" def _lean_call_lru_put(args: tuple[Any, ...]) -> str: cache, cap, key, val = args return ( f"_root_.lruPut {_lean_nat_pair_list([(int(k),int(v)) for k,v in cache])} " f"{int(cap)} {int(key)} {int(val)}" ) def _lean_call_lru_get(args: tuple[Any, ...]) -> str: cache, key = args return f"_root_.lruGet {_lean_nat_pair_list([(int(k),int(v)) for k,v in cache])} {int(key)}" # ── Shortest-path helpers ───────────────────────────────────────────────────── def _lean_edge_list(edges: list[tuple[int, int, int]]) -> str: parts = [f"({int(fr)}, {int(to)}, {int(w)})" for fr, to, w in edges] return "[" + ", ".join(parts) + "]" def _oracle_dijkstra(args: tuple[Any, ...]) -> list[int | None]: edges, n, src = args import heapq n, src = int(n), int(src) dist: list[int | None] = [None] * n if src >= n: return dist dist[src] = 0 pq: list[tuple[int, int]] = [(0, src)] while pq: d, u = heapq.heappop(pq) if dist[u] is not None and d > dist[u]: continue for fr, to, w in edges: fr, to, w = int(fr), int(to), int(w) if fr == u and to < n: nd = d + w if dist[to] is None or nd < dist[to]: dist[to] = nd heapq.heappush(pq, (nd, to)) return dist def _oracle_shortest_dist(args: tuple[Any, ...]) -> int | None: edges, n, src, dst = args dists = _oracle_dijkstra((edges, n, src)) dst = int(dst) if dst >= len(dists): return None return dists[dst] def _lean_value_dist_list(value: Any) -> str: parts = ["none" if d is None else f"some {int(d)}" for d in value] return "[" + ", ".join(parts) + "]" def _lean_value_option_nat(value: Any) -> str: if value is None: return "none" return f"some {int(value)}" def _lean_call_dijkstra(args: tuple[Any, ...]) -> str: edges, n, src = args return f"_root_.dijkstra {_lean_edge_list(edges)} {int(n)} {int(src)}" def _lean_call_shortest_dist(args: tuple[Any, ...]) -> str: edges, n, src, dst = args return ( f"_root_.shortestDist {_lean_edge_list(edges)} " f"{int(n)} {int(src)} {int(dst)}" ) 
# ── Interval helpers ─────────────────────────────────────────────────────────


def _lean_int_pair(iv: tuple[int, int]) -> str:
    """Render an interval as ``(a, b)``, parenthesising negative endpoints."""
    a, b = int(iv[0]), int(iv[1])
    a_str = f"({a})" if a < 0 else str(a)
    b_str = f"({b})" if b < 0 else str(b)
    return f"({a_str}, {b_str})"


def _lean_int_pair_list(ivs: list[tuple[int, int]]) -> str:
    """Render a list of intervals as a list literal of pairs."""
    return "[" + ", ".join(_lean_int_pair(iv) for iv in ivs) + "]"


def _oracle_merge_intervals(args: tuple[Any, ...]) -> list[tuple[int, int]]:
    """Merge overlapping or touching intervals (``args = (ivs,)``).

    Intervals are sorted by ``(start, end)``; an interval whose start is less
    than or equal to the end of the previous merged interval is absorbed into
    it. Every returned endpoint is an ``int``.
    """
    (ivs,) = args
    if not ivs:
        return []
    # Normalise every interval up front so the seed interval is coerced the
    # same way as the ones merged into it.
    ordered = sorted((int(start), int(end)) for start, end in ivs)
    merged: list[tuple[int, int]] = [ordered[0]]
    for start, end in ordered[1:]:
        last_start, last_end = merged[-1]
        if start <= last_end:
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))
    return merged


def _oracle_max_schedule(args: tuple[Any, ...]) -> list[tuple[int, int]]:
    """Greedily pick non-overlapping intervals in order of earliest end.

    ``args`` is ``(ivs,)``. Intervals are sorted by ``(end, start)`` and one is
    kept whenever its start is not before the end of the last kept interval.
    """
    (ivs,) = args
    if not ivs:
        return []
    by_end = sorted(ivs, key=lambda x: (int(x[1]), int(x[0])))
    chosen: list[tuple[int, int]] = []
    last_end: int | None = None
    for start, end in by_end:
        start, end = int(start), int(end)
        if last_end is None or start >= last_end:
            chosen.append((start, end))
            last_end = end
    return chosen


def _lean_value_int_pair_list(value: Any) -> str:
    """Render an iterable of intervals as a list literal of pairs."""
    return _lean_int_pair_list([(int(a), int(b)) for a, b in value])


def _lean_call_merge_intervals(args: tuple[Any, ...]) -> str:
    """Render ``_root_.mergeIntervals <intervals>``."""
    (ivs,) = args
    return f"_root_.mergeIntervals {_lean_value_int_pair_list(ivs)}"


def _lean_call_max_schedule(args: tuple[Any, ...]) -> str:
    """Render ``_root_.maxSchedule <intervals>``."""
    (ivs,) = args
    return f"_root_.maxSchedule {_lean_value_int_pair_list(ivs)}"


# Registry keyed by (task_id, function_name): the oracle, result renderer and
# call renderer used for each verified function.
FUNCTION_SEMANTICS: dict[tuple[str, str], FunctionSemantics] = {
    ("rbac_auth", "findRole"): FunctionSemantics(
        oracle=_oracle_rbac_find_role,
        lean_value=_lean_value_rbac_find_role,
        lean_call=_lean_call_rbac_find_role,
    ),
    ("rbac_auth", "hasDirectPermission"): FunctionSemantics(
        oracle=_oracle_rbac_has_direct_permission,
        lean_value=_lean_value_rbac_bool,
        lean_call=_lean_call_rbac_has_direct_permission,
    ),
    ("rbac_auth", "canAccess"): FunctionSemantics(
        oracle=_oracle_rbac_can_access,
        lean_value=_lean_value_rbac_bool,
        lean_call=_lean_call_rbac_can_access,
    ),
    ("pricing_engine", "subtotal"): FunctionSemantics(
        oracle=_oracle_pricing_subtotal,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_subtotal,
    ),
    ("pricing_engine", "taxRateBps"): FunctionSemantics(
        oracle=_oracle_pricing_tax_rate_bps,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_tax_rate_bps,
    ),
    ("pricing_engine", "couponDiscount"): FunctionSemantics(
        oracle=_oracle_pricing_coupon_discount,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_coupon_discount,
    ),
    ("pricing_engine", "loyaltyDiscount"): FunctionSemantics(
        oracle=_oracle_pricing_loyalty_discount,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_loyalty_discount,
    ),
    ("pricing_engine", "finalPrice"): FunctionSemantics(
        oracle=_oracle_pricing_final_price,
        lean_value=_lean_value_pricing_number,
        lean_call=_lean_call_pricing_final_price,
    ),
    ("payment_saga", "transition"): FunctionSemantics(
        oracle=_oracle_saga_transition,
        lean_value=_lean_value_saga_state,
        lean_call=_lean_call_saga_transition,
    ),
    ("payment_saga", "runSaga"): FunctionSemantics(
        oracle=_oracle_saga_run,
        lean_value=_lean_value_saga_state,
        lean_call=_lean_call_saga_run,
    ),
    ("payment_saga", "isCharged"): FunctionSemantics(
        oracle=_oracle_saga_is_charged,
        lean_value=_lean_value_saga_bool,
        lean_call=_lean_call_saga_is_charged,
    ),
    # Path canonicalizer
    ("path_canonicalizer", "normalizePath"): FunctionSemantics(
        oracle=_oracle_path_normalize,
        lean_value=_lean_value_str_list,
        lean_call=_lean_call_path_normalize,
    ),
    ("path_canonicalizer", "joinPaths"): FunctionSemantics(
        oracle=_oracle_path_join,
        lean_value=_lean_value_str_list,
        lean_call=_lean_call_path_join,
    ),
    ("path_canonicalizer", "pathDepth"): FunctionSemantics(
        oracle=_oracle_path_depth,
        lean_value=_lean_value_nat,
        lean_call=_lean_call_path_depth,
    ),
    # Expression evaluator
    ("expression_eval", "evalBinOp"): FunctionSemantics(
        oracle=_oracle_expr_eval_bin_op,
        lean_value=_lean_value_option_int,
        lean_call=_lean_call_eval_bin_op,
    ),
    ("expression_eval", "evalExpr"): FunctionSemantics(
        oracle=_oracle_expr_eval_expr,
        lean_value=_lean_value_option_int,
        lean_call=_lean_call_eval_expr,
    ),
    # LRU cache
    ("lru_cache", "lruEvict"): FunctionSemantics(
        oracle=_oracle_lru_evict,
        lean_value=_lean_value_lru_cache,
        lean_call=_lean_call_lru_evict,
    ),
    ("lru_cache", "lruPut"): FunctionSemantics(
        oracle=_oracle_lru_put,
        lean_value=_lean_value_lru_cache,
        lean_call=_lean_call_lru_put,
    ),
    ("lru_cache", "lruGet"): FunctionSemantics(
        oracle=_oracle_lru_get,
        lean_value=_lean_value_lru_get,
        lean_call=_lean_call_lru_get,
    ),
    # Shortest path
    ("shortest_path", "dijkstra"): FunctionSemantics(
        oracle=_oracle_dijkstra,
        lean_value=_lean_value_dist_list,
        lean_call=_lean_call_dijkstra,
    ),
    ("shortest_path", "shortestDist"): FunctionSemantics(
        oracle=_oracle_shortest_dist,
        lean_value=_lean_value_option_nat,
        lean_call=_lean_call_shortest_dist,
    ),
    # Interval scheduler
    ("interval_scheduler", "mergeIntervals"): FunctionSemantics(
        oracle=_oracle_merge_intervals,
        lean_value=_lean_value_int_pair_list,
        lean_call=_lean_call_merge_intervals,
    ),
    ("interval_scheduler", "maxSchedule"): FunctionSemantics(
        oracle=_oracle_max_schedule,
        lean_value=_lean_value_int_pair_list,
        lean_call=_lean_call_max_schedule,
    ),
}


def get_function_semantics(task_id: str, function_name: str) -> FunctionSemantics:
    """Look up the semantics registered for ``(task_id, function_name)``.

    Raises:
        ValueError: if the pair is not registered; the message lists every
            available ``task.function`` pair in sorted order, and the original
            ``KeyError`` is chained as the cause.
    """
    try:
        return FUNCTION_SEMANTICS[(task_id, function_name)]
    except KeyError as error:
        available = ", ".join(
            f"{current_task}.{current_function}"
            for current_task, current_function in sorted(FUNCTION_SEMANTICS)
        )
        raise ValueError(
            f"No verification semantics available for {task_id}.{function_name}. "
            f"Available: {available}"
        ) from error


def oracle_result(task_id: str, function_name: str, args: tuple[Any, ...]) -> Any:
    """Run the registered oracle on *args* and return its result."""
    return get_function_semantics(task_id, function_name).oracle(args)


def lean_value(task_id: str, function_name: str, value: Any) -> str:
    """Render *value* with the registered result renderer."""
    return get_function_semantics(task_id, function_name).lean_value(value)


def lean_call(task_id: str, function_name: str, args: tuple[Any, ...]) -> str:
    """Render the call expression for *args* with the registered call renderer."""
    return get_function_semantics(task_id, function_name).lean_call(args)


__all__ = [
    "FunctionSemantics",
    "FUNCTION_SEMANTICS",
    "get_function_semantics",
    "lean_call",
    "lean_value",
    "oracle_result",
]