File size: 6,739 Bytes
f823a82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
class ConstraintParser:
    """Translate JSON-AST constraint rules into constraints on a solver model.

    A rule looks like::

        {"type": "hard",
         "forall": [{"i": "I"}, {"j": "J"}],
         "where": {...},                     # optional filter
         "assert": {"left": ..., "operator": "<=", "right": ...}}

    Each ``forall`` entry binds one loop variable to the items of an iterable
    named in ``global_data`` (or resolved from already-bound variables). The
    ``assert`` is applied once per full binding that passes ``where``.
    """

    # Comparison operators shared by ``assert`` blocks and boolean expressions.
    # Plain lambdas keep the class import-free; when an operand is a solver
    # expression the comparison yields the object handed to ``model.Add``.
    _COMPARISONS = {
        "==": lambda a, b: a == b,
        "!=": lambda a, b: a != b,
        "<=": lambda a, b: a <= b,
        ">=": lambda a, b: a >= b,
        "<": lambda a, b: a < b,
        ">": lambda a, b: a > b,
    }

    # Eagerly-evaluated binary operators (AND/OR are handled separately so
    # they can short-circuit).
    _BINARY_OPS = {
        "+": lambda a, b: a + b,
        "-": lambda a, b: a - b,
        "*": lambda a, b: a * b,
        **_COMPARISONS,
        "in": lambda a, b: a in b,
        "not_in": lambda a, b: a not in b,
    }

    def __init__(self, model, schedule, global_data, target_functions=None):
        """
        Args:
            model: Solver model; constraints are added with ``Add``, and
                statically unsatisfiable assertions with ``AddBoolOr([])``.
            schedule: Mapping from argument tuples to values (presumably solver
                variables); read by the built-in ``"schedule"`` target, with 0
                for missing keys.
            global_data: Named iterables that ``forall``/``over`` loops range over.
            target_functions: Optional ``name -> callable(schedule, *args)``
                used for custom ``"target"`` lookups.
        """
        self.model = model
        self.schedule = schedule
        self.global_data = global_data
        self.target_functions = target_functions or {}

    def resolve_arg(self, arg, local_vars):
        """Resolve a leaf argument against the bound loop variables.

        * numbers/bools are returned unchanged;
        * a string is replaced by the bound variable of that name, or returned
          verbatim when no such variable exists (so it acts as a literal);
        * a single-entry dict ``{prop: var_name}`` indexes the bound variable
          ``var_name``: a digit string indexes a list/tuple, otherwise ``prop``
          is used as a key into a dict.

        Raises:
            ValueError: if the argument has an unsupported shape.
            KeyError: if a dict argument names an unbound variable or missing key.
        """
        if isinstance(arg, (int, float, bool)):
            return arg
        if isinstance(arg, str):
            return local_vars.get(arg, arg)
        if isinstance(arg, dict) and arg:
            prop, var_name = next(iter(arg.items()))
            parent_obj = local_vars[var_name]

            if prop.isdigit() and isinstance(parent_obj, (list, tuple)):
                return parent_obj[int(prop)]
            if isinstance(parent_obj, dict):
                return parent_obj[prop]

        raise ValueError(f"Could not resolve argument: {arg}")

    def apply_assert(self, assert_ast, local_vars):
        """Add the constraint described by an ``assert`` block to the model.

        Args:
            assert_ast: ``{"left": expr, "operator": op, "right": expr}`` with
                ``op`` one of ``== != <= >= < >``.
            local_vars: Current loop-variable bindings.

        Raises:
            ValueError: for an unsupported operator.
        """
        operator = assert_ast["operator"]
        compare = self._COMPARISONS.get(operator)
        if compare is None:
            raise ValueError(f"Unsupported assert operator: {operator}")

        left_val = self.evaluate_expression(assert_ast["left"], local_vars)
        right_val = self.evaluate_expression(assert_ast["right"], local_vars)

        if isinstance(left_val, (int, float)) and isinstance(right_val, (int, float)):
            # Both sides folded to constants (e.g. missing schedule keys default
            # to 0, or a sum ran over an empty set). A true comparison needs no
            # constraint; a false one makes this hard constraint unsatisfiable,
            # so record it as an empty (always-false) clause instead of
            # silently dropping it.
            if not compare(left_val, right_val):
                self.model.AddBoolOr([])
            return

        self.model.Add(compare(left_val, right_val))

    def evaluate_expression(self, expr, local_vars):
        """Recursively evaluate an expression node.

        Node kinds:
          * non-dict leaves and operator-less dicts -> :meth:`resolve_arg`;
          * ``{"target": name, "args": [...]}`` -> ``schedule.get(args, 0)`` for
            ``"schedule"``, otherwise ``target_functions[name](schedule, *args)``;
          * ``{"operator": "sum", "over": [...], "where": ..., "expression": ...}``
            -> sum of ``expression`` over every loop binding passing ``where``;
          * ``{"operator": op, "left": ..., "right": ...}`` for ``+ - *``,
            comparisons, ``in``/``not_in`` and short-circuit ``AND``/``OR``.

        Raises:
            ValueError: for an unknown target function or operator.
        """
        if not isinstance(expr, dict):
            return self.resolve_arg(expr, local_vars)

        if "target" in expr:
            return self._evaluate_target(expr, local_vars)

        op = expr.get("operator")
        if not op:
            return self.resolve_arg(expr, local_vars)

        # 'sum' has no left/right operands; it must be intercepted first.
        if op == "sum":
            return self._evaluate_sum(expr, local_vars)

        left_val = self.evaluate_expression(expr.get("left"), local_vars)

        # Short-circuit like Python's and/or: the right operand is evaluated
        # only when needed, so it may depend on a guard in the left operand.
        if op == "AND":
            return left_val and self.evaluate_expression(expr.get("right"), local_vars)
        if op == "OR":
            return left_val or self.evaluate_expression(expr.get("right"), local_vars)

        apply_op = self._BINARY_OPS.get(op)
        if apply_op is None:
            raise ValueError(f"Unknown operator: {op}")
        right_val = self.evaluate_expression(expr.get("right"), local_vars)
        return apply_op(left_val, right_val)

    def _evaluate_target(self, expr, local_vars):
        """Evaluate a ``{"target": ..., "args": [...]}`` lookup node."""
        target = expr["target"]
        args = tuple(self.resolve_arg(a, local_vars) for a in expr["args"])

        if target == "schedule":
            return self.schedule.get(args, 0)
        if target in self.target_functions:
            return self.target_functions[target](self.schedule, *args)
        raise ValueError(f"Unknown target function: {target}")

    def _evaluate_sum(self, expr, local_vars):
        """Sum ``expr["expression"]`` over every ``over`` binding passing ``where``."""
        has_where = "where" in expr
        # ``where`` is checked only once every loop variable is bound; checking
        # it at outer depths would see inner names as unresolved literals.
        terms = [
            self.evaluate_expression(expr["expression"], bound)
            for bound in self._iter_bindings(expr["over"], local_vars)
            if not has_where or self.evaluate_expression(expr["where"], bound)
        ]
        return sum(terms)

    def _resolve_iterable(self, source, local_vars):
        """Return what a loop ranges over: the ``global_data`` entry when
        ``source`` names one, otherwise ``source`` resolved as an argument."""
        if isinstance(source, str) and source in self.global_data:
            return self.global_data[source]
        return self.resolve_arg(source, local_vars)

    def _iter_bindings(self, loops, local_vars, depth=0):
        """Yield one binding dict per combination of ``loops[depth:]``.

        Each loop entry is a single-item dict ``{var_name: source}``; later
        loops may range over values bound by earlier ones. ``local_vars`` is
        never mutated; each nested level works on a copy.
        """
        if depth == len(loops):
            yield local_vars
            return

        name, source = next(iter(loops[depth].items()))
        for item in self._resolve_iterable(source, local_vars):
            bound = local_vars.copy()
            bound[name] = item
            yield from self._iter_bindings(loops, bound, depth + 1)

    def execute_loops(self, loop_array, current_depth, local_vars, ast):
        """Apply ``ast["assert"]`` for every binding of
        ``loop_array[current_depth:]`` that passes the optional ``ast["where"]``."""
        for bound in self._iter_bindings(loop_array, local_vars, current_depth):
            if "where" in ast and not self.evaluate_expression(ast["where"], bound):
                continue
            self.apply_assert(ast["assert"], bound)

    def parse_and_apply(self, ast):
        """Apply one JSON-AST constraint.

        Only rules with ``"type": "hard"`` are handled; any other type is
        currently ignored.
        """
        if ast["type"] == "hard":
            self.execute_loops(ast["forall"], 0, {}, ast)