File size: 6,913 Bytes
24c2665 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 |
import ast
import hashlib
import re
from typing import List, Optional
def check_determinism(code: str, inputs: str, executor, prev_output: Optional[str] = None, n_runs: int = 1):
    """Return True if running *code* on *inputs* always produces the same output.

    Args:
        code: Program source passed to ``executor.run_code``.
        inputs: Inputs passed to ``executor.run_code``.
        executor: Object whose ``run_code(code, inputs)`` returns a sequence
            whose first element is the program output (e.g. ``(output, status)``).
        prev_output: Optional output from an earlier run; it takes part in the
            comparison when not None.
        n_runs: Number of fresh executions to perform.

    Returns:
        True iff exactly one distinct output (compared via ``str()``) was seen
        across *prev_output* and the executions. False if nothing was observed
        at all (``n_runs <= 0`` and no *prev_output*). Execution stops as soon
        as a second distinct output appears.
    """
    seen_digests = set()

    def record(output) -> None:
        # Keep fixed-size MD5 digests instead of the raw outputs.
        seen_digests.add(hashlib.md5(str(output).encode()).hexdigest())

    if prev_output is not None:
        record(prev_output)
    for _ in range(n_runs):
        record(executor.run_code(code, inputs)[0])
        if len(seen_digests) > 1:
            # Non-determinism already proven; skip the remaining executions.
            return False
    return len(seen_digests) == 1
def contains_banned_imports(code: str, banned_keywords: List[str], banned_keywords_for_errors_and_exceptions: Optional[List[str]] = None) -> bool:
    """Check if code imports any banned modules, or uses banned error-handling constructs.

    Args:
        code: Python source to inspect.
        banned_keywords: Names that must not be imported. ``import a.b.c`` and
            ``from a.b import c`` are flagged when any dotted component
            (``a``, ``b`` or ``c``) appears in this list.
        banned_keywords_for_errors_and_exceptions: Any of ``'assert'``,
            ``'raise'``, ``'try'`` and ``'except'``. When present, the
            corresponding statement kind is flagged wherever it appears.
            Defaults to none.

    Returns:
        True if a banned import or construct is found. If *code* cannot be
        parsed, falls back to a whole-word regex search for *banned_keywords*
        only (the error-handling keywords are not checked in that case).
    """
    error_keywords = banned_keywords_for_errors_and_exceptions or []
    # ``except*`` blocks (Python 3.11+) parse to ast.TryStar rather than ast.Try.
    try_node_types = (ast.Try,) + ((ast.TryStar,) if hasattr(ast, "TryStar") else ())

    def is_banned(dotted_name: str) -> bool:
        # Banned if any dotted component of the name is listed.
        return any(part in banned_keywords for part in dotted_name.split('.'))

    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # Unparseable source (ValueError covers null bytes before Python 3.12):
        # fall back to a whole-word text search for the banned names.
        return any(re.search(rf'\b{re.escape(banned)}\b', code) for banned in banned_keywords)

    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            if any(is_banned(alias.name) for alias in node.names):
                return True
        elif isinstance(node, ast.ImportFrom):
            # node.module is None for purely relative imports (``from . import x``).
            if node.module and is_banned(node.module):
                return True
            if any(is_banned(alias.name) for alias in node.names):
                return True

        if not error_keywords:
            continue
        if isinstance(node, ast.Assert) and 'assert' in error_keywords:
            return True
        if isinstance(node, ast.Raise) and 'raise' in error_keywords:
            return True
        if isinstance(node, try_node_types) and 'try' in error_keywords:
            return True
        if isinstance(node, ast.ExceptHandler) and 'except' in error_keywords:
            return True
    return False
def check_no_definitions(code: str, composite_functions: List[str]) -> bool:
    """Return True if *code* defines none of *composite_functions* at module level.

    Only top-level ``def`` and ``async def`` statements are inspected. Nested
    definitions and other bindings (assignments, classes, imports) are not
    considered.

    Returns False if *code* cannot be parsed.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: null bytes in the source on Python < 3.12.
        return False
    function_def_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    return not any(
        isinstance(node, function_def_types) and node.name in composite_functions
        for node in tree.body
    )
def check_composite_function(code: str, composite_functions: List[str]) -> bool:
    """Check that the top-level ``f`` in *code* calls every composite helper correctly.

    The helpers are referred to as ``g_0`` ... ``g_{k-1}`` with
    ``k = len(composite_functions)``. Only the length of *composite_functions*
    is used here, not its contents.

    Returns True only when *code* parses, contains a top-level ``def f``,
    every ``g_i`` is called somewhere inside ``f``'s body, and ``CallChecker``
    accepts the arguments of those calls. The names passed to ``CallChecker``
    as in scope are ``f``'s parameters plus the variables that
    ``AssignedVarsVisitor`` finds assigned in ``f``'s body.
    """
    composite_function_names = [f"g_{i}" for i in range(len(composite_functions))]
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        # ValueError: null bytes in the source on Python < 3.12.
        return False
    f_def = next(
        (node for node in tree.body if isinstance(node, ast.FunctionDef) and node.name == 'f'),
        None,
    )
    if f_def is None:
        return False

    # Every parameter kind binds a name in f: positional-only, regular,
    # keyword-only, *args and **kwargs.
    signature = f_def.args
    parameters = {
        arg.arg
        for arg in getattr(signature, "posonlyargs", []) + signature.args + signature.kwonlyargs
    }
    parameters.update(arg.arg for arg in (signature.vararg, signature.kwarg) if arg is not None)

    assigned_vars_visitor = AssignedVarsVisitor()
    for stmt in f_def.body:
        assigned_vars_visitor.visit(stmt)
    scope_vars = parameters | assigned_vars_visitor.assigned

    call_checker = CallChecker(composite_function_names, scope_vars)
    for stmt in f_def.body:
        call_checker.visit(stmt)
    return call_checker.called == set(composite_function_names) and call_checker.valid
class AssignedVarsVisitor(ast.NodeVisitor):
    """Collect plain variable names bound by assignment-like statements.

    After visiting, ``assigned`` holds every simple name bound by:
    ``=`` (including tuple/list/starred unpacking), annotated assignments
    that have a value, ``for``/``async for`` targets, ``with ... as`` targets,
    and ``:=`` expressions. Attribute and subscript targets bind no name and
    are ignored. Nested function bodies are walked too (via ``generic_visit``),
    so names bound there are included.
    """

    def __init__(self):
        self.assigned = set()

    def visit_Assign(self, node):
        for target in node.targets:
            self.collect_names(target)
        self.generic_visit(node)

    def visit_AnnAssign(self, node):
        # A bare annotation (``x: int``) declares but does not bind a value.
        if node.value is not None:
            self.collect_names(node.target)
        self.generic_visit(node)

    def visit_For(self, node):
        self.collect_names(node.target)
        self.generic_visit(node)

    visit_AsyncFor = visit_For

    def visit_With(self, node):
        for item in node.items:
            if item.optional_vars is not None:
                self.collect_names(item.optional_vars)
        self.generic_visit(node)

    visit_AsyncWith = visit_With

    def visit_NamedExpr(self, node):
        self.collect_names(node.target)
        self.generic_visit(node)

    def collect_names(self, node):
        """Add every simple name bound by assignment target *node*."""
        if isinstance(node, ast.Name):
            self.assigned.add(node.id)
        elif isinstance(node, (ast.Tuple, ast.List)):
            for elt in node.elts:
                self.collect_names(elt)
        elif isinstance(node, ast.Starred):
            # ``a, *rest = xs`` binds ``rest`` through a Starred wrapper.
            self.collect_names(node.value)
class CallChecker(ast.NodeVisitor):
    """Validate the positional arguments of calls to the composite functions.

    Visit statements with ``visit``. Afterwards:

    * ``called`` is the set of names from ``composite_functions`` that were
      called somewhere in the visited code;
    * ``valid`` is False if any such call passed a positional argument that
      references a name outside the current scope.

    The current scope is ``scope_vars``, plus the parameters of enclosing
    nested function definitions, plus list-comprehension targets whose
    iterable is directly a name in ``scope_vars``. Names of the composite
    functions themselves are ignored when checking arguments, so nested
    compositions such as ``g_0(g_1(x))`` are allowed.
    """

    def __init__(self, composite_functions, scope_vars):
        self.composite_functions = composite_functions
        self.scope_vars = scope_vars
        self.called = set()
        self.valid = True
        # Stack of {name: mapped_var_or_None}. Keys are in scope, and so is any
        # non-None value (nothing in this class sets one).
        self.local_scopes = [{}]

    @staticmethod
    def _param_names(arguments):
        """Return every parameter name declared by an ``ast.arguments`` node."""
        names = [
            arg.arg
            for arg in getattr(arguments, "posonlyargs", []) + arguments.args + arguments.kwonlyargs
        ]
        names.extend(arg.arg for arg in (arguments.vararg, arguments.kwarg) if arg is not None)
        return names

    def visit_FunctionDef(self, node):
        # A nested function's parameters (of every kind) are in scope inside its body.
        self.local_scopes.append({name: None for name in self._param_names(node.args)})
        self.generic_visit(node)
        self.local_scopes.pop()

    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ListComp(self, node):
        comp_scope = {}
        for gen in node.generators:
            # Only targets that iterate directly over an in-scope variable
            # become usable names inside the comprehension.
            if isinstance(gen.iter, ast.Name) and gen.iter.id in self.scope_vars:
                self.collect_names(gen.target, comp_scope)
        # The first iterable is evaluated in the enclosing scope, so calls in it
        # (e.g. ``[y for y in g_0(x)]``) are checked before pushing comp_scope.
        self.visit(node.generators[0].iter)
        self.local_scopes.append(comp_scope)
        for gen in node.generators[1:]:
            self.visit(gen.iter)
        self.visit(node.elt)
        for gen in node.generators:
            for comp_if in gen.ifs:
                self.visit(comp_if)
        self.local_scopes.pop()

    def visit_Call(self, node):
        if isinstance(node.func, ast.Name) and node.func.id in self.composite_functions:
            self.called.add(node.func.id)
            current_scope = self.build_current_scope()
            # NOTE: only positional arguments are validated; keyword arguments are not.
            for arg in node.args:
                if not all(name in current_scope for name in self.get_names(arg)):
                    self.valid = False
        self.generic_visit(node)

    def build_current_scope(self):
        """Return ``scope_vars`` plus every name visible through ``local_scopes``."""
        scope = set(self.scope_vars)
        for local_scope in self.local_scopes:
            scope.update(local_scope.keys())
            for mapped_var in local_scope.values():
                if mapped_var:
                    scope.add(mapped_var)
        return scope

    def collect_names(self, node, scope_dict):
        """Add every simple name bound by target *node* to *scope_dict*."""
        if isinstance(node, ast.Name):
            scope_dict[node.id] = None
        elif isinstance(node, (ast.Tuple, ast.List)):
            for elt in node.elts:
                self.collect_names(elt, scope_dict)
        elif isinstance(node, ast.Starred):
            # ``for a, *b in xs`` binds ``b`` through a Starred wrapper.
            self.collect_names(node.value, scope_dict)

    def get_names(self, node):
        """Return the loaded variable names in *node*, excluding composite function names."""
        return [n.id for n in ast.walk(node) if isinstance(n, ast.Name)
                and isinstance(n.ctx, ast.Load)
                and n.id not in self.composite_functions]
|