File size: 7,722 Bytes
3d3d712 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 |
import ast
import builtins
import re
from _ast import Name
from typing import List, Optional, Tuple
from injector import inject
# Names of every callable attribute of the ``builtins`` module (functions,
# classes, exception types, ...). Non-callable entries such as ``__name__``
# or ``None`` are left out.
allowed_builtins = [
    builtin_name
    for builtin_name, builtin_obj in builtins.__dict__.items()
    if callable(builtin_obj)
]
class FunctionCallValidator(ast.NodeVisitor):
    """AST visitor that records policy violations found in a code snippet.

    Violations are appended to ``self.errors`` as human-readable strings; the
    visitor never raises for a violation.

    Two independent policies are enforced:

    * ``allowed_modules`` -- when non-empty, every ``import`` / ``from ...
      import`` must name a module whose top-level package is in the list.
    * ``plugin_only`` -- when true, the snippet may only call functions named
      in ``plugin_list`` (or callable builtins), may only assign from calls,
      may only reference names previously assigned from such calls, and may
      not define functions. Any node type other than ``Call``, ``Assign``,
      ``Import``, ``ImportFrom``, ``Expr``, ``Module``, ``Name`` or ``Tuple``
      is reported.

    Args:
        lines: Source lines of the parsed code; ``lines[lineno - 1]`` is
            quoted in error messages.
        plugin_list: Function names that may be called in ``plugin_only`` mode.
        plugin_only: Enables the restrictive plugin-only policy.
        allowed_modules: Importable top-level module names; an empty list
            disables the import check.
    """

    @inject
    def __init__(
        self,
        lines: List[str],
        plugin_list: List[str],
        plugin_only: bool,
        allowed_modules: List[str],
    ):
        self.lines = lines
        self.plugin_list = plugin_list
        # Collected violation messages, in visiting order.
        self.errors = []
        # Names bound by assignments whose right-hand side is an allowed call.
        self.plugin_return_values = []
        self.plugin_only = plugin_only
        self.allowed_modules = allowed_modules

    def _source_line(self, node) -> str:
        """Return the source text for ``node``'s line, or ``""`` if unavailable."""
        index = node.lineno - 1
        if 0 <= index < len(self.lines):
            return self.lines[index]
        # The line table does not cover this node; still report the violation
        # instead of failing with IndexError while formatting the message.
        return ""

    def _report(self, node, detail: str) -> None:
        """Append ``"Error on line N: <source> => <detail>"`` to ``self.errors``."""
        self.errors.append(
            f"Error on line {node.lineno}: {self._source_line(node)} => {detail}",
        )

    def visit_Call(self, node):
        """Check a call against the plugin/builtin allow-list.

        Returns:
            ``True`` if the call is allowed, ``False`` if a violation was
            recorded, and ``None`` when ``plugin_only`` is off (no check).
        """
        if not self.plugin_only:
            return None

        callee = node.func
        if isinstance(callee, ast.Name):
            function_name = callee.id
        elif isinstance(callee, ast.Attribute):
            # Only the final attribute is checked, e.g. ``a.b.foo()`` -> ``foo``.
            function_name = callee.attr
        else:
            # e.g. calling the result of another call or a subscript.
            self._report(node, "Function call is not allowed.")
            return False

        if function_name not in self.plugin_list and function_name not in allowed_builtins:
            self._report(node, f"Function '{function_name}' is not allowed.")
            return False
        return True

    def visit_Import(self, node):
        """Report ``import x.y`` when top-level package ``x`` is not allowed."""
        if not self.allowed_modules:
            return
        for alias in node.names:
            module_name = alias.name.split(".")[0]
            if module_name not in self.allowed_modules:
                self._report(node, f"Importing module '{module_name}' is not allowed. ")

    def visit_ImportFrom(self, node):
        """Report ``from x.y import z`` when top-level package ``x`` is not allowed."""
        if not self.allowed_modules:
            return
        if node.module is None:
            # ``from . import name``: there is no module name to look up in
            # the allow-list, so the relative import is reported as-is.
            relative_target = "." * node.level
            self._report(node, f"Importing from module '{relative_target}' is not allowed.")
            return
        module_name = node.module.split(".")[0]
        if module_name not in self.allowed_modules:
            self._report(node, f"Importing from module '{node.module}' is not allowed.")

    def visit_FunctionDef(self, node):
        """Reject function definitions in plugin-only mode; otherwise descend."""
        if self.plugin_only:
            self._report(node, "Defining new functions is not allowed.")
        else:
            # Keep walking so imports inside the function body are still
            # checked against ``allowed_modules``.
            self.generic_visit(node)

    def visit_Assign(self, node):
        """Track names assigned from allowed calls (plugin-only mode)."""
        if self.plugin_only:
            if isinstance(node.value, ast.Call):
                if not self.visit_Call(node.value):
                    # The call itself was already reported.
                    return
                target = node.targets[0]
                if isinstance(target, ast.Tuple):
                    for elt in target.elts:
                        if isinstance(elt, ast.Name):
                            self.plugin_return_values.append(elt.id)
                elif isinstance(target, ast.Name):
                    self.plugin_return_values.append(target.id)
            else:
                self.errors.append(f"Error: Unsupported assignment on line {node.lineno}.")
        self.generic_visit(node)

    def visit_Name(self, node: Name):
        """In plugin-only mode, allow only names bound from plugin call results."""
        if self.plugin_only and node.id not in self.plugin_return_values:
            self._report(node, "Only return values of plugins calls can be used.")

    def generic_visit(self, node):
        """Walk children, rejecting unsupported node types in plugin-only mode."""
        permitted = (ast.Call, ast.Assign, ast.Import, ast.ImportFrom, ast.Expr, ast.Module, ast.Name)
        if self.plugin_only and not isinstance(node, permitted):
            if isinstance(node, ast.Tuple):
                # Tuples (e.g. unpacking targets) are checked element-wise.
                for elt in node.elts:
                    self.visit(elt)
            else:
                self._report(node, "Codes except plugin calls are not allowed.")
        else:
            super().generic_visit(node)
def format_code_correction_message() -> str:
    """Return the fixed follow-up prompt used after verification found errors.

    The text asks for the code to be rewritten and retried, or for the
    problem to be explained if it cannot be fixed.
    """
    sentences = (
        "The generated code has been verified and some errors are found. ",
        "If you think you can fix the problem by rewriting the code, ",
        "please do it and try again.\n",
        "Otherwise, please explain the problem to me.",
    )
    return "".join(sentences)
def separate_magics_and_code(input_code: str) -> Tuple[List[str], str, List[str]]:
    """Split a notebook-style snippet into magics, Python code and install commands.

    Blank lines and lines whose first non-space character is ``#`` are dropped
    from every output. Each remaining line is classified as follows:

    * a line magic (``%name``) or shell command (``!...``) containing
      ``pip install`` or ``conda install`` goes to the install commands;
    * any other line magic or shell command goes to the magics;
    * a cell magic (``%%name``) goes to the magics, and so does every
      remaining line after it: the cell magic consumes the rest of the snippet;
    * everything else is Python code.

    Args:
        input_code: The raw snippet text.

    Returns:
        ``(magics, python_code, package_install_commands)`` where
        ``python_code`` is the Python lines joined with ``"\\n"``.
    """
    line_magic_pattern = re.compile(r"^\s*%\s*[a-zA-Z_]\w*")
    cell_magic_pattern = re.compile(r"^\s*%%\s*[a-zA-Z_]\w*")
    shell_command_pattern = re.compile(r"^\s*!")

    magics = []
    python_code = []
    package_install_commands = []
    inside_cell_magic = False

    for line in input_code.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue

        if inside_cell_magic:
            # Blank lines never reach this point, so nothing ends a cell
            # magic: all following lines belong to it.
            magics.append(line)
            continue

        if line_magic_pattern.match(line) or shell_command_pattern.match(line):
            # NOTE(review): this is a plain substring test, so any magic or
            # shell line that merely contains "pip install" / "conda install"
            # is classified as a package installation command.
            if "pip install" in line or "conda install" in line:
                package_install_commands.append(line)
            else:
                magics.append(line)
        elif cell_magic_pattern.match(line):
            inside_cell_magic = True
            magics.append(line)
        else:
            python_code.append(line)

    return magics, "\n".join(python_code), package_install_commands
def code_snippet_verification(
    code_snippet: str,
    plugin_list: List[str],
    code_verification_on: bool = False,
    plugin_only: bool = False,
    allowed_modules: Optional[List[str]] = None,
) -> Optional[List[str]]:
    """Verify a generated code snippet against the configured restrictions.

    Args:
        code_snippet: The raw snippet; it may contain magics and shell lines.
        plugin_list: Function names callable when ``plugin_only`` is set.
        code_verification_on: Master switch; when false nothing is checked.
        plugin_only: Restrict the snippet to plugin calls only.
        allowed_modules: Importable top-level modules. ``None`` or an empty
            list disables the import check.

    Returns:
        ``None`` when verification is switched off. Otherwise a list of error
        messages, empty if the snippet passed. If the code cannot be parsed,
        the list contains only the parse error. Line numbers in the messages
        refer to the Python code returned by ``separate_magics_and_code``,
        not to the raw snippet.
    """
    if not code_verification_on:
        return None
    if allowed_modules is None:
        # Avoid a shared mutable default; an empty list means "no restriction".
        allowed_modules = []

    errors = []
    magics, python_code, _ = separate_magics_and_code(code_snippet)
    if magics:
        errors.append(f"Magic commands except package install are not allowed. Details: {magics}")

    try:
        tree = ast.parse(python_code)
    except (SyntaxError, ValueError) as e:
        # ValueError is raised for source containing null bytes on some
        # Python versions; report it the same way as a syntax error.
        return [f"Syntax error: {e}"]

    # The line table must be exactly the text that was parsed, so that
    # ``node.lineno`` indexes the right line in error messages.
    source_lines = python_code.splitlines()

    validator = FunctionCallValidator(source_lines, plugin_list, plugin_only, allowed_modules)
    validator.visit(tree)
    errors.extend(validator.errors)
    return errors
|