File size: 6,913 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import ast
import hashlib
import re
from typing import List, Optional


def check_determinism(code: str, inputs: str, executor, prev_output: Optional[str] = None, n_runs: int = 1) -> bool:
    """Return True if running *code* on *inputs* always produces the same output.

    Args:
        code: Program source handed to the executor.
        inputs: Input handed to the executor together with *code*.
        executor: Object whose ``run_code(code, inputs)`` returns a sequence
            whose first item is the program output (and, per the original
            contract, a status after it).
        prev_output: Optional output from an earlier run that every new run
            must match.
        n_runs: How many times to execute *code*.

    Returns:
        True when every observed output (including *prev_output*) has the
        same string form. False on the first mismatch, or when nothing was
        observed at all (no *prev_output* and ``n_runs <= 0``).
    """
    # Outputs are compared by their string form so any output type works;
    # the first observed output becomes the reference.
    reference = None if prev_output is None else str(prev_output)
    for _ in range(n_runs):
        output = str(executor.run_code(code, inputs)[0])
        if reference is None:
            reference = output
        elif output != reference:
            # Non-determinism is already proven; skip the remaining runs.
            return False
    return reference is not None


def contains_banned_imports(
    code: str,
    banned_keywords: List[str],
    banned_keywords_for_errors_and_exceptions: Optional[List[str]] = None,
) -> bool:
    """Check whether *code* imports a banned module or uses a banned error construct.

    Args:
        code: Python source to inspect.
        banned_keywords: Module names to reject. An import is banned when any
            dotted component of the imported module or imported name matches,
            so ``"os"`` also rejects ``import os.path`` and ``from os import path``.
        banned_keywords_for_errors_and_exceptions: Optional collection drawn
            from ``"assert"``, ``"raise"``, ``"try"`` and ``"except"``; each
            listed statement kind is rejected. Other values are ignored.

    Returns:
        True if anything banned is found, False otherwise. Source that cannot
        be parsed falls back to a whole-word regex search for every banned
        keyword (modules and error constructs alike).
    """
    banned_modules = set(banned_keywords)
    error_keywords = set(banned_keywords_for_errors_and_exceptions or ())

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: Python < 3.12 reports null bytes in the source this way.
        return any(
            re.search(rf'\b{re.escape(word)}\b', code)
            for word in (*banned_keywords, *error_keywords)
        )

    # `try ... except*` (Python 3.11+) parses to ast.TryStar instead of ast.Try.
    try_nodes = (ast.Try,) + ((ast.TryStar,) if hasattr(ast, 'TryStar') else ())
    construct_nodes = {
        'assert': (ast.Assert,),
        'raise': (ast.Raise,),
        'try': try_nodes,
        'except': (ast.ExceptHandler,),
    }
    banned_nodes = tuple(
        node_type
        for kw, node_types in construct_nodes.items()
        if kw in error_keywords
        for node_type in node_types
    )

    def is_banned(dotted_name):
        # Match on any dotted component, e.g. 'os' bans 'os.path'.
        return any(part in banned_modules for part in dotted_name.split('.'))

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(is_banned(alias.name) for alias in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for relative imports such as `from . import x`.
            if node.module and is_banned(node.module):
                return True
            if any(is_banned(alias.name) for alias in node.names):
                return True

        if banned_nodes and isinstance(node, banned_nodes):
            return True

    return False


def check_no_definitions(code: str, composite_functions: List[str]) -> bool:
    """Return True if *code* does not define any of *composite_functions* at top level.

    Both ``def`` and ``async def`` statements directly in the module body are
    checked against the given names; definitions nested inside other
    functions or classes are not inspected.

    Returns:
        False if the source cannot be parsed or a matching top-level function
        definition exists; True otherwise.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: Python < 3.12 reports null bytes in the source this way.
        return False

    function_nodes = (ast.FunctionDef, ast.AsyncFunctionDef)
    return not any(
        isinstance(node, function_nodes) and node.name in composite_functions
        for node in tree.body
    )


def check_composite_function(code: str, composite_functions: List[str]) -> bool:
    """Check that top-level ``f`` in *code* calls every composite function validly.

    The expected callees are named ``g_0`` ... ``g_{n-1}`` where ``n`` is
    ``len(composite_functions)``; only the length of *composite_functions*
    is used here.

    Returns:
        True if *code* parses, defines a top-level ``def f``, calls each
        ``g_i`` inside ``f``'s body, and every such call's arguments reference
        only ``f``'s parameters, names assigned within ``f``, or the locally
        scoped names tracked by ``CallChecker``. False otherwise.
    """
    composite_function_names = [f"g_{i}" for i in range(len(composite_functions))]

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: Python < 3.12 reports null bytes in the source this way.
        return False

    f_def = next(
        (node for node in tree.body if isinstance(node, ast.FunctionDef) and node.name == 'f'),
        None,
    )
    if f_def is None:
        return False

    # Every parameter kind binds a name usable inside f: positional-only,
    # regular, keyword-only, *args and **kwargs.
    signature = f_def.args
    parameters = {arg.arg for arg in (*signature.posonlyargs, *signature.args, *signature.kwonlyargs)}
    parameters.update(extra.arg for extra in (signature.vararg, signature.kwarg) if extra is not None)

    assigned_vars_visitor = AssignedVarsVisitor()
    for stmt in f_def.body:
        assigned_vars_visitor.visit(stmt)
    scope_vars = parameters | assigned_vars_visitor.assigned

    call_checker = CallChecker(composite_function_names, scope_vars)
    for stmt in f_def.body:
        call_checker.visit(stmt)

    return call_checker.valid and call_checker.called == set(composite_function_names)


class AssignedVarsVisitor(ast.NodeVisitor):
    """Collect the plain names bound by assignments in the visited subtree.

    Handles ``=`` (including tuple/list/starred unpacking), annotated
    assignments with a value, augmented assignments and ``:=`` expressions.
    There is no special handling for nested functions, so names assigned
    inside them are collected as well. Attribute and subscript targets are
    ignored.
    """

    def __init__(self):
        self.assigned = set()  # names seen as assignment targets

    def visit_Assign(self, node):
        for target in node.targets:
            self.collect_names(target)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        # A bare annotation (`x: int`) binds nothing; only count it with a value.
        if node.value is not None:
            self.collect_names(node.target)
        self.generic_visit(node)

    def visit_AugAssign(self, node):
        self.collect_names(node.target)
        self.generic_visit(node)

    def visit_NamedExpr(self, node):
        self.collect_names(node.target)
        self.generic_visit(node)

    def collect_names(self, node):
        """Add every plain name bound by the target expression *node*."""
        if isinstance(node, ast.Name):
            self.assigned.add(node.id)
        elif isinstance(node, (ast.Tuple, ast.List)):
            for elt in node.elts:
                self.collect_names(elt)
        elif isinstance(node, ast.Starred):
            # `a, *rest = ...` binds `rest` through a Starred wrapper.
            self.collect_names(node.value)


class CallChecker(ast.NodeVisitor):
    """Record calls to the composite functions and validate their arguments.

    After visiting, ``called`` holds the composite function names that were
    called by bare name, and ``valid`` is False if any such call passed a
    positional or keyword argument referencing a name outside the current
    scope. The current scope is ``scope_vars`` plus the parameters of
    enclosing nested functions/lambdas and qualifying comprehension targets.
    """

    def __init__(self, composite_functions, scope_vars):
        self.composite_functions = composite_functions
        self.scope_vars = scope_vars
        self.called = set()
        self.valid = True
        # Stack of extra scopes (name -> optional mapped name) pushed by nested
        # functions, lambdas and comprehensions.
        self.local_scopes = [{}]

    @staticmethod
    def _param_names(arguments):
        """Return every parameter name declared by an ``ast.arguments`` node."""
        params = [a.arg for a in (*arguments.posonlyargs, *arguments.args, *arguments.kwonlyargs)]
        for extra in (arguments.vararg, arguments.kwarg):
            if extra is not None:
                params.append(extra.arg)
        return params

    def visit_FunctionDef(self, node):
        # All parameter kinds are local names inside the nested function.
        self.local_scopes.append(dict.fromkeys(self._param_names(node.args)))
        self.generic_visit(node)
        self.local_scopes.pop()

    visit_AsyncFunctionDef = visit_FunctionDef
    visit_Lambda = visit_FunctionDef

    def _visit_comprehension(self, node, result_exprs):
        """Visit a comprehension with its loop targets in a temporary scope.

        A target only becomes in-scope when it iterates directly over a name
        from ``scope_vars``. Iterables and conditions are visited too, so
        composite calls inside them are recorded and validated.
        """
        comp_scope = {}
        self.local_scopes.append(comp_scope)
        for gen in node.generators:
            self.visit(gen.iter)
            if isinstance(gen.iter, ast.Name) and gen.iter.id in self.scope_vars:
                self.collect_names(gen.target, comp_scope)
            for comp_if in gen.ifs:
                self.visit(comp_if)
        for expr in result_exprs:
            self.visit(expr)
        self.local_scopes.pop()

    def visit_ListComp(self, node):
        self._visit_comprehension(node, (node.elt,))

    visit_SetComp = visit_ListComp
    visit_GeneratorExp = visit_ListComp

    def visit_DictComp(self, node):
        self._visit_comprehension(node, (node.key, node.value))

    def visit_Call(self, node):
        # Only direct calls by bare name are inspected; calls to local helper
        # functions are not traced into.
        if isinstance(node.func, ast.Name) and node.func.id in self.composite_functions:
            self.called.add(node.func.id)
            current_scope = self.build_current_scope()
            # Keyword values are checked too, so `g_0(x=...)` cannot bypass validation.
            arg_exprs = [*node.args, *(kw.value for kw in node.keywords)]
            for arg in arg_exprs:
                if not all(name in current_scope for name in self.get_names(arg)):
                    self.valid = False
        self.generic_visit(node)

    def build_current_scope(self):
        """Return all names visible at this point: globals plus every pushed scope."""
        scope = set(self.scope_vars)
        for local_scope in self.local_scopes:
            scope.update(local_scope.keys())
            for mapped_var in local_scope.values():
                if mapped_var:
                    scope.add(mapped_var)
        return scope

    def collect_names(self, node, scope_dict):
        """Add every plain name bound by target *node* to *scope_dict*."""
        if isinstance(node, ast.Name):
            scope_dict[node.id] = None
        elif isinstance(node, (ast.Tuple, ast.List)):
            for elt in node.elts:
                self.collect_names(elt, scope_dict)

    def get_names(self, node):
        """Return names read in *node*, excluding the composite function names."""
        return [n.id for n in ast.walk(node) if isinstance(n, ast.Name)
                and isinstance(n.ctx, ast.Load)
                and n.id not in self.composite_functions]