# constraint-env/parser.py
# Uploaded by DecentSanage via huggingface_hub (commit f823a82, verified).
class ConstraintParser:
    """Translate JSON constraint ASTs into constraints on a CP-SAT model.

    A hard constraint AST looks like::

        {
            "type": "hard",
            "forall": [{"<var>": <iterable source>}, ...],
            "where": <expression>,   # optional filter on each binding
            "assert": {"left": <expr>, "operator": "<cmp>", "right": <expr>},
        }

    Every ``forall`` entry binds one loop variable. For each combination of
    bindings that passes ``where``, the ``assert`` comparison is handed to
    ``model.Add``.
    """

    # Comparison operators accepted by ``assert`` blocks and expressions.
    # Lambdas over Python's own operators are used (rather than a lookup of
    # methods) so solver expression objects, which the original design relies
    # on overloading these operators, pass through unchanged.
    _COMPARISONS = {
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
        "<": lambda a, b: a < b,
        ">": lambda a, b: a > b,
    }

    # Every binary operator usable in an expression, except the
    # short-circuiting AND / OR which evaluate_expression handles itself.
    _BINARY_OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        **_COMPARISONS,
        "in": lambda a, b: a in b,
        "not_in": lambda a, b: a not in b,
    }

    def __init__(self, model, schedule, global_data, target_functions=None):
        """Store the model and the lookup tables that expressions read from.

        Args:
            model: Solver model receiving constraints via ``Add`` (and
                ``AddBoolOr`` for constraints that are statically false).
            schedule: Mapping from argument tuples to values (presumably
                solver decision variables); read by the ``"schedule"`` target,
                with missing keys treated as 0.
            global_data: Mapping of named collections that ``forall`` and
                ``sum`` loops may iterate over by name.
            target_functions: Optional mapping of target name to a callable
                invoked as ``func(schedule, *args)``.
        """
        self.model = model
        self.schedule = schedule
        self.global_data = global_data
        self.target_functions = target_functions or {}

    def resolve_arg(self, arg, local_vars):
        """Resolve a JSON argument against the current loop bindings.

        * ``int`` / ``float`` / ``bool`` literals are returned unchanged.
        * A ``str`` is looked up in ``local_vars``; unknown names are returned
          verbatim, so they behave as string literals.
        * A ``list`` (JSON array) is resolved element by element, which allows
          list literals such as the right-hand side of ``in``.
        * A dict ``{prop: var}`` reads ``local_vars[var][prop]``; a digit
          ``prop`` indexes into a list/tuple. Only the first key is used.

        Raises:
            ValueError: if the argument has no supported form.
            KeyError: if a property access names an unbound variable or a
                missing key.
        """
        if isinstance(arg, (int, float, bool)):
            return arg
        if isinstance(arg, str):
            return local_vars.get(arg, arg)
        if isinstance(arg, list):
            return [self.resolve_arg(item, local_vars) for item in arg]
        if isinstance(arg, dict) and arg:
            prop, var_name = next(iter(arg.items()))
            parent_obj = local_vars[var_name]
            if prop.isdigit() and isinstance(parent_obj, (list, tuple)):
                return parent_obj[int(prop)]
            if isinstance(parent_obj, dict):
                return parent_obj[prop]
        raise ValueError(f"Could not resolve argument: {arg}")

    def _resolve_iterable(self, iterator_source, local_vars):
        """Return what a loop variable iterates over.

        A string naming a ``global_data`` entry takes precedence; anything
        else is resolved as an argument against the bindings so far.
        """
        if isinstance(iterator_source, str) and iterator_source in self.global_data:
            return self.global_data[iterator_source]
        return self.resolve_arg(iterator_source, local_vars)

    def _iter_bindings(self, loop_array, depth, local_vars):
        """Yield ``local_vars`` extended with every combination of the loop
        variables in ``loop_array[depth:]``, outermost loop first.

        Each inner loop source is resolved with the outer variables already
        bound, so a loop may iterate over a field of an enclosing variable.
        """
        if depth == len(loop_array):
            yield local_vars
            return
        loop = loop_array[depth]
        if not loop:
            raise ValueError(f"Empty loop binding at depth {depth}")
        iterator_name, iterator_source = next(iter(loop.items()))
        for item in self._resolve_iterable(iterator_source, local_vars):
            new_vars = local_vars.copy()
            new_vars[iterator_name] = item
            yield from self._iter_bindings(loop_array, depth + 1, new_vars)

    def apply_assert(self, assert_ast, local_vars):
        """Add the constraint ``left <operator> right`` to the model.

        If both sides evaluate to plain numbers there is nothing for the
        solver to decide. A satisfied comparison is dropped. A violated one is
        recorded as an empty ``AddBoolOr`` (always false), so the model becomes
        infeasible instead of silently losing a hard constraint.

        Raises:
            ValueError: if the operator is not a supported comparison.
        """
        left_val = self.evaluate_expression(assert_ast["left"], local_vars)
        right_val = self.evaluate_expression(assert_ast["right"], local_vars)
        operator = assert_ast["operator"]
        compare = self._COMPARISONS.get(operator)
        if compare is None:
            raise ValueError(f"Unsupported assert operator: {operator}")

        constraint = compare(left_val, right_val)
        if isinstance(left_val, (int, float)) and isinstance(right_val, (int, float)):
            if not constraint:
                self.model.AddBoolOr([])
            return
        self.model.Add(constraint)

    def _call_target(self, expr, local_vars):
        """Evaluate a ``{"target": name, "args": [...]}`` node."""
        target = expr["target"]
        args = tuple(self.resolve_arg(a, local_vars) for a in expr.get("args", ()))
        if target == "schedule":
            # Argument combinations absent from the schedule contribute 0.
            return self.schedule.get(args, 0)
        if target not in self.target_functions:
            raise ValueError(f"Unknown target function: {target}")
        return self.target_functions[target](self.schedule, *args)

    def _evaluate_sum(self, expr, local_vars):
        """Sum ``expr["expression"]`` over every binding of ``expr["over"]``.

        The optional ``where`` filter is evaluated only once every summed
        variable is bound, the same way ``forall`` constraints are filtered,
        so it may refer to any loop variable. An empty sum is 0.
        """
        has_where = "where" in expr
        return sum(
            self.evaluate_expression(expr["expression"], bindings)
            for bindings in self._iter_bindings(expr["over"], 0, local_vars)
            if not has_where or self.evaluate_expression(expr["where"], bindings)
        )

    def evaluate_expression(self, expr, local_vars):
        """Recursively evaluate an expression node.

        Supported node forms:
            * non-dict values -> :meth:`resolve_arg`
            * ``{"target": name, "args": [...]}`` -> schedule lookup or a
              registered target function
            * ``{"operator": "sum", "over": [...], "expression": e,
              "where": w}`` -> sum of ``e`` over the bindings passing ``w``
            * ``{"operator": op, "left": l, "right": r}`` -> arithmetic,
              comparison, ``in`` / ``not_in``, or short-circuiting
              ``AND`` / ``OR``
            * any other dict -> property access via :meth:`resolve_arg`

        Raises:
            ValueError: for an unknown target function or operator.
        """
        if not isinstance(expr, dict):
            return self.resolve_arg(expr, local_vars)

        if "target" in expr:
            return self._call_target(expr, local_vars)

        op = expr.get("operator")
        if not op:
            # Plain property access such as {"capacity": "room"}.
            return self.resolve_arg(expr, local_vars)

        if op == "sum":
            return self._evaluate_sum(expr, local_vars)

        left_val = self.evaluate_expression(expr.get("left"), local_vars)

        # AND / OR short-circuit so the right side may depend on the left
        # one (e.g. a membership guard in front of a property lookup).
        if op == "AND":
            return left_val and self.evaluate_expression(expr.get("right"), local_vars)
        if op == "OR":
            return left_val or self.evaluate_expression(expr.get("right"), local_vars)

        func = self._BINARY_OPS.get(op)
        if func is None:
            raise ValueError(f"Unknown operator: {op}")
        right_val = self.evaluate_expression(expr.get("right"), local_vars)
        return func(left_val, right_val)

    def execute_loops(self, loop_array, current_depth, local_vars, ast):
        """Apply ``ast["assert"]`` once per binding of the ``forall`` loops.

        Loops from ``current_depth`` onward are expanded. Bindings rejected by
        the optional ``ast["where"]`` filter are skipped.
        """
        for bindings in self._iter_bindings(loop_array, current_depth, local_vars):
            if "where" in ast and not self.evaluate_expression(ast["where"], bindings):
                continue
            self.apply_assert(ast["assert"], bindings)

    def parse_and_apply(self, ast):
        """Apply one JSON constraint AST to the model (main entry point).

        Only ``"hard"`` constraints are handled; a missing ``forall`` applies
        the assertion once with no loop variables.
        NOTE(review): ASTs of any other type are silently ignored here.
        Confirm that soft constraints are handled elsewhere.
        """
        if ast["type"] == "hard":
            self.execute_loops(ast.get("forall", []), 0, {}, ast)