class ConstraintParser:
    """Apply JSON-encoded constraint ASTs to a CP-SAT style model.

    The parser walks a constraint description (nested dicts/lists as
    produced by ``json.loads``), evaluates its expressions against the
    current loop variables and registers the resulting comparisons on
    ``model`` through ``model.Add``.

    Attributes:
        model: Object exposing ``Add(constraint)``.
        schedule: Mapping keyed by argument tuples; missing keys read as 0.
        global_data: Mapping of named collections usable as loop sources.
        target_functions: Mapping of target name -> callable invoked as
            ``func(schedule, *args)``.
    """

    # Binary operators understood by ``evaluate_expression``.  Both operands
    # are evaluated before the operator is applied, so "AND"/"OR" never
    # short-circuit the evaluation of their sub-expressions.
    _BINARY_OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        "<": lambda a, b: a < b,
        ">": lambda a, b: a > b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "AND": lambda a, b: a and b,
        "OR": lambda a, b: a or b,
        "in": lambda a, b: a in b,
        "not_in": lambda a, b: a not in b,
    }

    # Subset of operators that may appear in an 'assert' block.
    _ASSERT_OPS = ("==", "!=", "<=", ">=", "<", ">")

    def __init__(self, model, schedule, global_data, target_functions=None):
        """Store the collaborators used while applying constraints.

        Args:
            model: Object exposing ``Add(constraint)``.
            schedule: Mapping looked up with a tuple of resolved arguments.
            global_data: Mapping of collection name -> iterable.
            target_functions: Optional mapping of custom target name ->
                callable; ``None`` (or any falsy value) means no custom
                targets.
        """
        self.model = model
        self.schedule = schedule
        self.global_data = global_data
        self.target_functions = target_functions or {}

    def resolve_arg(self, arg, local_vars):
        """Resolve a leaf argument against the current loop variables.

        Resolution rules:
            * ``int`` / ``float`` / ``bool`` are returned unchanged.
            * A ``str`` is looked up in ``local_vars``; when no variable of
              that name exists, the string itself is returned as a literal.
            * A ``dict`` is read as ``{prop: var_name}`` (only its first item
              is used): a digit-only ``prop`` indexes a list/tuple variable,
              otherwise ``prop`` is used as the key of a dict variable.

        Raises:
            KeyError: If a ``{prop: var_name}`` argument names an unbound
                variable (or a key missing from the dict variable).
            ValueError: If ``arg`` matches none of the rules above.
        """
        if isinstance(arg, (int, float, bool)):
            return arg
        if isinstance(arg, str):
            return local_vars.get(arg, arg)
        if isinstance(arg, dict):
            prop, var_name = list(arg.items())[0]
            parent_obj = local_vars[var_name]
            if prop.isdigit() and isinstance(parent_obj, (list, tuple)):
                return parent_obj[int(prop)]
            if isinstance(parent_obj, dict):
                return parent_obj[prop]
        raise ValueError(f"Could not resolve argument: {arg}")

    def apply_assert(self, assert_ast, local_vars):
        """Translate an 'assert' block into a ``model.Add(...)`` call.

        Args:
            assert_ast: Dict with ``"left"``, ``"operator"`` and ``"right"``.
            local_vars: Current loop-variable bindings.

        Raises:
            ValueError: If the operator is not one of
                ``==, !=, <=, >=, <, >``.
        """
        left_val = self.evaluate_expression(assert_ast["left"], local_vars)
        right_val = self.evaluate_expression(assert_ast["right"], local_vars)
        op = assert_ast["operator"]

        # Nothing is added to the model when both sides are plain numbers.
        # NOTE(review): this also skips constant comparisons that are false;
        # confirm that is the intended handling.
        if isinstance(left_val, (int, float)) and isinstance(right_val, (int, float)):
            return

        # An unrecognised operator used to drop the constraint silently.
        if op not in self._ASSERT_OPS:
            raise ValueError(f"Unsupported assert operator: {op!r}")
        self.model.Add(self._BINARY_OPS[op](left_val, right_val))

    def evaluate_expression(self, expr, local_vars):
        """Recursively evaluate an expression node.

        Supported node shapes:
            * Non-dict values are handed to ``resolve_arg``.
            * ``{"target": name, "args": [...]}`` reads ``schedule`` (for the
              target ``"schedule"``, defaulting to 0) or calls the matching
              entry of ``target_functions`` as ``func(schedule, *args)``.
            * ``{"operator": "sum", "over": [...], "expression": ...}`` with
              an optional ``"where"`` filter.
            * ``{"operator": op, "left": ..., "right": ...}`` for the binary
              operators in ``_BINARY_OPS``.
            * Any other dict (no truthy ``"operator"``) goes to
              ``resolve_arg``.

        Raises:
            ValueError: For an unknown target or an unknown operator.
        """
        if not isinstance(expr, dict):
            return self.resolve_arg(expr, local_vars)

        if "target" in expr:
            args = tuple(self.resolve_arg(a, local_vars) for a in expr["args"])
            target = expr["target"]
            if target == "schedule":
                return self.schedule.get(args, 0)
            if target in self.target_functions:
                return self.target_functions[target](self.schedule, *args)
            raise ValueError(f"Unknown target function: {expr['target']}")

        op = expr.get("operator")
        if not op:
            return self.resolve_arg(expr, local_vars)

        # 'sum' has no left/right operands, so handle it before the binary
        # operators.
        if op == "sum":
            has_filter = "where" in expr
            terms = []
            for bound in self._iter_bindings(expr["over"], 0, local_vars):
                # The filter runs only once every loop variable is bound, as
                # in execute_loops; evaluating it on partial bindings made
                # unbound names resolve to their own string and wrongly
                # discarded terms.
                if has_filter and not self.evaluate_expression(expr["where"], bound):
                    continue
                terms.append(self.evaluate_expression(expr["expression"], bound))
            return sum(terms)

        try:
            combine = self._BINARY_OPS[op]
        except (KeyError, TypeError):
            raise ValueError(f"Unknown operator: {op!r}") from None

        left_val = self.evaluate_expression(expr.get("left"), local_vars)
        right_val = self.evaluate_expression(expr.get("right"), local_vars)
        return combine(left_val, right_val)

    def _resolve_iterable(self, iterator_source, local_vars):
        """Return the collection a loop iterates over.

        A string naming an entry of ``global_data`` selects that entry;
        anything else is resolved through ``resolve_arg``.
        """
        if isinstance(iterator_source, str) and iterator_source in self.global_data:
            return self.global_data[iterator_source]
        return self.resolve_arg(iterator_source, local_vars)

    def _iter_bindings(self, loop_array, depth, local_vars):
        """Yield one variable dict per combination of the nested loops.

        Each entry of ``loop_array`` is a ``{iterator_name: source}`` dict
        (only its first item is used).  Loops from ``depth`` onwards are
        expanded depth-first; every yielded dict is a copy extended with the
        loop variables, so ``local_vars`` itself is never mutated.
        """
        if depth == len(loop_array):
            yield local_vars
            return
        iterator_name, iterator_source = list(loop_array[depth].items())[0]
        for item in self._resolve_iterable(iterator_source, local_vars):
            bound = local_vars.copy()
            bound[iterator_name] = item
            yield from self._iter_bindings(loop_array, depth + 1, bound)

    def execute_loops(self, loop_array, current_depth, local_vars, ast):
        """Expand the 'forall' loops and apply the assert for each binding.

        Args:
            loop_array: List of ``{iterator_name: source}`` dicts.
            current_depth: Index of the first loop still to be expanded.
            local_vars: Bindings established by the enclosing loops.
            ast: Constraint dict providing ``"assert"`` and, optionally, a
                ``"where"`` filter evaluated once all loops are bound.
        """
        has_filter = "where" in ast
        for bound in self._iter_bindings(loop_array, current_depth, local_vars):
            if has_filter and not self.evaluate_expression(ast["where"], bound):
                continue
            self.apply_assert(ast["assert"], bound)

    def parse_and_apply(self, ast):
        """Main entry point: apply one JSON AST constraint to the model.

        Only constraints whose ``"type"`` is ``"hard"`` are processed; any
        other type is ignored by this method.
        """
        if ast["type"] == "hard":
            self.execute_loops(ast["forall"], 0, {}, ast)