Spaces:
Sleeping
Sleeping
File size: 6,739 Bytes
class ConstraintParser:
    """Apply constraints described as JSON-style ASTs (nested dicts) to a model.

    The parser walks the ``forall`` loops of a constraint, filters the
    resulting variable bindings with the optional ``where`` expression and
    hands every surviving ``assert`` comparison to ``model.Add``.

    Attributes:
        model: Object exposing ``Add(constraint)`` (a CP-SAT model, according
            to the original documentation of ``apply_assert``).
        schedule: Mapping looked up with ``.get(args_tuple, 0)`` whenever an
            expression targets ``"schedule"``.
        global_data: Mapping from collection name to an iterable; loop sources
            are looked up here before the local variables.
        target_functions: Mapping from target name to a callable invoked as
            ``func(schedule, *args)``.
    """

    # Comparison operators, shared by ``apply_assert`` and
    # ``evaluate_expression`` so both accept exactly the same spellings.
    _COMPARISONS = {
        "==": lambda left, right: left == right,
        "!=": lambda left, right: left != right,
        "<=": lambda left, right: left <= right,
        ">=": lambda left, right: left >= right,
        "<": lambda left, right: left < right,
        ">": lambda left, right: left > right,
    }

    # Every binary operator understood by ``evaluate_expression``.
    _BINARY_OPS = {
        "+": lambda left, right: left + right,
        "-": lambda left, right: left - right,
        "*": lambda left, right: left * right,
        **_COMPARISONS,
        # ``and`` / ``or`` return one of their operands, not a strict bool.
        "AND": lambda left, right: left and right,
        "OR": lambda left, right: left or right,
        "in": lambda left, right: left in right,
        "not_in": lambda left, right: left not in right,
    }

    def __init__(self, model, schedule, global_data, target_functions=None):
        """Store the collaborators; ``target_functions`` defaults to ``{}``."""
        self.model = model
        self.schedule = schedule
        self.global_data = global_data
        self.target_functions = target_functions or {}

    def resolve_arg(self, arg, local_vars):
        """Resolve a leaf argument of the AST against the local variables.

        * Numbers (including booleans) are returned unchanged.
        * A string is replaced by the local variable of that name, or kept
          as a literal when no such variable exists.
        * A dict ``{prop: var_name}`` reads ``prop`` from the local variable
          ``var_name``: an all-digit ``prop`` indexes a list/tuple, otherwise
          ``prop`` is used as a key of a dict. Only the first item of the
          dict is considered.

        Raises:
            KeyError: ``var_name`` is not a local variable, or ``prop`` is not
                a key of the referenced dict.
            ValueError: The argument has any other shape.
        """
        # bool is a subclass of int, so it is covered by this check as well.
        if isinstance(arg, (int, float)):
            return arg
        if isinstance(arg, str):
            return local_vars.get(arg, arg)
        if isinstance(arg, dict) and arg:
            prop, var_name = next(iter(arg.items()))
            parent_obj = local_vars[var_name]
            if prop.isdigit() and isinstance(parent_obj, (list, tuple)):
                return parent_obj[int(prop)]
            if isinstance(parent_obj, dict):
                return parent_obj[prop]
        raise ValueError(f"Could not resolve argument: {arg}")

    def apply_assert(self, assert_ast, local_vars):
        """Translate an ``assert`` node into a ``model.Add(...)`` call.

        Args:
            assert_ast: Dict with ``left``, ``right`` and ``operator`` keys.
            local_vars: Variable bindings of the current loop iteration.

        Raises:
            ValueError: The operator is not one of the supported comparisons.
                Previously such an assert was dropped without any constraint
                being added.
        """
        left_val = self.evaluate_expression(assert_ast["left"], local_vars)
        right_val = self.evaluate_expression(assert_ast["right"], local_vars)
        operator = assert_ast["operator"]

        # Both sides are plain numbers: nothing is handed to the model and
        # the comparison itself is not checked.
        if isinstance(left_val, (int, float)) and isinstance(right_val, (int, float)):
            return

        compare = self._COMPARISONS.get(operator)
        if compare is None:
            raise ValueError(f"Unsupported assert operator: {operator}")
        self.model.Add(compare(left_val, right_val))

    def evaluate_expression(self, expr, local_vars):
        """Recursively evaluate an expression node.

        Supported node shapes:

        * non-dict values and dicts without ``target``/``operator`` are
          resolved with ``resolve_arg``;
        * ``{"target": name, "args": [...]}`` reads the schedule or calls a
          registered target function;
        * ``{"operator": "sum", "over": [...], "expression": ...}`` with an
          optional ``where`` filter;
        * ``{"operator": op, "left": ..., "right": ...}`` for the binary
          operators listed in ``_BINARY_OPS``. Both operands are always
          evaluated, including for ``AND`` / ``OR``.

        Raises:
            ValueError: Unknown target, unknown operator, or an argument that
                cannot be resolved.
        """
        if not isinstance(expr, dict):
            return self.resolve_arg(expr, local_vars)
        if "target" in expr:
            return self._call_target(expr, local_vars)

        op = expr.get("operator")
        if not op:
            return self.resolve_arg(expr, local_vars)
        # 'sum' has no left/right operands, so it is handled before them.
        if op == "sum":
            return self._evaluate_sum(expr, local_vars)

        combine = self._BINARY_OPS.get(op)
        if combine is None:
            # Falling through to resolve_arg would misread the operator node
            # as a property lookup and fail with an unrelated error.
            raise ValueError(f"Unknown operator: {op}")
        left_val = self.evaluate_expression(expr.get("left"), local_vars)
        right_val = self.evaluate_expression(expr.get("right"), local_vars)
        return combine(left_val, right_val)

    def execute_loops(self, loop_array, current_depth, local_vars, ast):
        """Expand the ``forall`` loops and apply the assert to each binding.

        Args:
            loop_array: List of single-item dicts ``{iterator_name: source}``.
            current_depth: Index of the first loop still to be expanded.
            local_vars: Bindings established so far (never mutated).
            ast: Constraint node providing ``assert`` and optional ``where``.
        """
        for bound_vars in self._iter_bindings(loop_array, current_depth, local_vars):
            if "where" in ast and not self.evaluate_expression(ast["where"], bound_vars):
                continue
            self.apply_assert(ast["assert"], bound_vars)

    def parse_and_apply(self, ast):
        """Main entry point: apply one constraint AST to the model.

        Only constraints whose ``type`` is ``"hard"`` are processed; any
        other type is ignored.
        """
        if ast["type"] == "hard":
            self.execute_loops(ast["forall"], 0, {}, ast)

    def _call_target(self, expr, local_vars):
        """Evaluate a ``target`` node: schedule lookup or custom function."""
        args = tuple(self.resolve_arg(a, local_vars) for a in expr["args"])
        target = expr["target"]
        if target == "schedule":
            return self.schedule.get(args, 0)
        if target in self.target_functions:
            return self.target_functions[target](self.schedule, *args)
        raise ValueError(f"Unknown target function: {target}")

    def _evaluate_sum(self, expr, local_vars):
        """Sum ``expr["expression"]`` over every binding of ``expr["over"]``.

        The ``where`` filter is evaluated once per complete binding, exactly
        as ``execute_loops`` does. Evaluating it at every nesting level would
        run it while inner iterators are still unbound.
        """
        has_filter = "where" in expr
        terms = []
        for bound_vars in self._iter_bindings(expr["over"], 0, local_vars):
            if has_filter and not self.evaluate_expression(expr["where"], bound_vars):
                continue
            terms.append(self.evaluate_expression(expr["expression"], bound_vars))
        return sum(terms)

    def _iter_bindings(self, loop_array, depth, local_vars):
        """Yield one variable dict per combination of the nested loops."""
        if depth == len(loop_array):
            yield local_vars
            return
        iterator_name, iterator_source = next(iter(loop_array[depth].items()))
        for item in self._resolve_iterable(iterator_source, local_vars):
            # Copy so sibling iterations never see each other's bindings.
            new_vars = local_vars.copy()
            new_vars[iterator_name] = item
            yield from self._iter_bindings(loop_array, depth + 1, new_vars)

    def _resolve_iterable(self, source, local_vars):
        """Return the collection a loop iterates over.

        ``global_data`` takes precedence over local variables. A name found
        in neither raises ``ValueError`` instead of being iterated character
        by character.
        """
        if isinstance(source, str):
            if source in self.global_data:
                return self.global_data[source]
            if source not in local_vars:
                raise ValueError(f"Unknown iterable: {source}")
        return self.resolve_arg(source, local_vars)