|
|
import ast |
|
|
import builtins |
|
|
import re |
|
|
from _ast import Name |
|
|
from typing import List, Optional, Tuple |
|
|
|
|
|
from injector import inject |
|
|
|
|
|
# Names of every callable object found in the ``builtins`` namespace
# (functions such as ``print``/``len`` and classes such as ``int``/``ValueError``),
# kept in the namespace's own insertion order.
# NOTE(review): this also admits ``eval``, ``exec``, ``open``, ``compile`` and
# ``__import__``, so treating the whole list as always-allowed does not block
# them -- confirm that is intended.
allowed_builtins = [entry for entry, value in builtins.__dict__.items() if callable(value)]
|
|
|
|
|
|
|
|
class FunctionCallValidator(ast.NodeVisitor):
    """AST visitor that records code-policy violations in ``self.errors``.

    Two independent checks are applied while walking the tree:

    * When ``allowed_modules`` is non-empty, every ``import x`` / ``from x import y``
      whose top-level package is not listed is reported.
    * When ``plugin_only`` is true, the code may only call functions named in
      ``plugin_list`` or callable builtins, assign the results of such calls to
      plain names, and reuse those names. Function definitions, other
      assignments and every other statement/expression kind are reported.

    Error messages quote ``self.lines[node.lineno - 1]``, so ``lines`` must line
    up one-to-one with the lines of the source that was parsed.
    """

    @inject
    def __init__(
        self,
        lines: List[str],
        plugin_list: List[str],
        plugin_only: bool,
        allowed_modules: List[str],
    ):
        self.lines = lines
        self.plugin_list = plugin_list
        # Human-readable violation messages, in the order they were found.
        self.errors: List[str] = []
        # Names bound by assignments whose value is an allowed call (plugin_only mode only).
        self.plugin_return_values: List[str] = []
        self.plugin_only = plugin_only
        self.allowed_modules = allowed_modules

    def visit_Call(self, node):
        """Check a call expression against the plugin-only policy.

        Returns:
            ``True`` if the call is allowed, ``False`` if an error was recorded,
            and ``None`` when ``plugin_only`` is off (nothing is checked).
            Call arguments are deliberately not visited.
        """
        if not self.plugin_only:
            return None

        if isinstance(node.func, ast.Name):
            # Plain call: ``foo(...)``.
            function_name = node.func.id
        elif isinstance(node.func, ast.Attribute):
            # Method/attribute call: ``obj.foo(...)``; only the attribute name is checked.
            function_name = node.func.attr
        else:
            # e.g. ``f()()`` or ``fs[0]()`` -- the callee has no simple name.
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} " f"=> Function call is not allowed.",
            )
            return False

        if function_name not in self.plugin_list and function_name not in allowed_builtins:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
                f"=> Function '{function_name}' is not allowed.",
            )
            return False
        return True

    def visit_Import(self, node):
        """Report ``import a.b.c`` when top-level package ``a`` is not allowed."""
        if not self.allowed_modules:
            # An empty allow-list means every module may be imported.
            return
        for alias in node.names:
            module_name = alias.name.split(".")[0]
            if module_name not in self.allowed_modules:
                self.errors.append(
                    f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
                    f"=> Importing module '{module_name}' is not allowed. ",
                )

    def visit_ImportFrom(self, node):
        """Report ``from a.b import c`` when top-level package ``a`` is not allowed."""
        if not self.allowed_modules:
            # An empty allow-list means every module may be imported.
            return
        # ``from . import x`` has ``node.module is None`` (only ``node.level`` is set);
        # fall back to the leading dots so the check cannot crash on it.
        full_name = node.module if node.module is not None else "." * node.level
        module_name = full_name.split(".")[0]
        if module_name not in self.allowed_modules:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
                f"=> Importing from module '{full_name}' is not allowed.",
            )

    def visit_FunctionDef(self, node):
        """Forbid ``def`` in plugin_only mode; otherwise keep checking the body."""
        if self.plugin_only:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} => Defining new functions is not allowed.",
            )
        else:
            # Descend so imports inside the function body are still checked
            # against ``allowed_modules`` (as already happens for classes/async defs).
            self.generic_visit(node)

    def visit_Assign(self, node):
        """Check an assignment; in plugin_only mode only ``name(s) = allowed_call(...)`` is accepted."""
        if self.plugin_only:
            if isinstance(node.value, ast.Call):
                if not self.visit_Call(node.value):
                    # The call was already reported; don't descend into it again.
                    return
                # Only the first target of a chained ``a = b = f()`` is recorded.
                target = node.targets[0]
                if isinstance(target, ast.Tuple):
                    self.plugin_return_values.extend(
                        elt.id for elt in target.elts if isinstance(elt, ast.Name)
                    )
                elif isinstance(target, ast.Name):
                    self.plugin_return_values.append(target.id)
            else:
                self.errors.append(f"Error: Unsupported assignment on line {node.lineno}.")
        self.generic_visit(node)

    def visit_Name(self, node: Name):
        """In plugin_only mode, allow only names previously bound to an allowed call's result."""
        if self.plugin_only and node.id not in self.plugin_return_values:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} => "
                "Only return values of plugins calls can be used.",
            )

    def generic_visit(self, node):
        """Fallback for node types without a dedicated ``visit_*`` method.

        In plugin_only mode, any node type outside the small allowed set is an
        error, except tuples, whose elements are checked individually. Otherwise
        the default traversal continues into the children.
        """
        if self.plugin_only and not isinstance(
            node,
            (ast.Call, ast.Assign, ast.Import, ast.ImportFrom, ast.Expr, ast.Module, ast.Name),
        ):
            if isinstance(node, ast.Tuple):
                for elt in node.elts:
                    self.visit(elt)
            else:
                self.errors.append(
                    f"Error on line {node.lineno}: {self.lines[node.lineno-1]} => "
                    "Codes except plugin calls are not allowed.",
                )
        else:
            super().generic_visit(node)
|
|
|
|
|
|
|
|
def format_code_correction_message() -> str:
    """Return the fixed notice asking for the code to be rewritten or the problem explained."""
    sentences = (
        "The generated code has been verified and some errors are found. ",
        "If you think you can fix the problem by rewriting the code, ",
        "please do it and try again.\n",
        "Otherwise, please explain the problem to me.",
    )
    return "".join(sentences)
|
|
|
|
|
|
|
|
def separate_magics_and_code(input_code: str) -> Tuple[List[str], str, List[str]]:
    """Split a notebook-style snippet into magics, plain Python and package installs.

    Blank lines and lines whose first non-space character is ``#`` are dropped
    from every output.

    Args:
        input_code: The raw snippet, possibly containing IPython line magics
            (``%name``), cell magics (``%%name``) and shell escapes (``!cmd``).

    Returns:
        A tuple ``(magics, python_code, package_install_commands)``:
        ``magics`` holds line magics / shell escapes plus a cell magic line and
        every line after it; ``python_code`` is the remaining lines joined with
        newlines; ``package_install_commands`` holds line magics / shell escapes
        containing ``pip install`` or ``conda install``.
    """
    line_magic_pattern = re.compile(r"^\s*%\s*[a-zA-Z_]\w*")
    cell_magic_pattern = re.compile(r"^\s*%%\s*[a-zA-Z_]\w*")
    shell_command_pattern = re.compile(r"^\s*!")

    magics: List[str] = []
    python_code: List[str] = []
    package_install_commands: List[str] = []
    inside_cell_magic = False

    for line in input_code.splitlines():
        stripped = line.strip()
        if not stripped or stripped.startswith("#"):
            continue

        if inside_cell_magic:
            # Blank lines were skipped above, so nothing ends a cell magic: every
            # remaining line is treated as its body (a cell magic owns the rest
            # of the cell in IPython as well).
            magics.append(line)
            continue

        if line_magic_pattern.match(line) or shell_command_pattern.match(line):
            # Package installation is the one kind of magic/shell command kept apart.
            if "pip install" in line or "conda install" in line:
                package_install_commands.append(line)
            else:
                magics.append(line)
        elif cell_magic_pattern.match(line):
            inside_cell_magic = True
            magics.append(line)
        else:
            python_code.append(line)

    return magics, "\n".join(python_code), package_install_commands
|
|
|
|
|
|
|
|
def code_snippet_verification(
    code_snippet: str,
    plugin_list: List[str],
    code_verification_on: bool = False,
    plugin_only: bool = False,
    allowed_modules: Optional[List[str]] = None,
) -> Optional[List[str]]:
    """Statically check a code snippet and return the list of policy violations.

    Args:
        code_snippet: Code that may mix Python with IPython magics / shell escapes.
        plugin_list: Function names that may be called in ``plugin_only`` mode.
        code_verification_on: When false, no checking is done and ``None`` is returned.
        plugin_only: Restrict the code to plugin/builtin calls and their results.
        allowed_modules: Top-level packages that may be imported; ``None`` or an
            empty list allows every module.

    Returns:
        ``None`` if verification is off, otherwise a (possibly empty) list of
        error messages. A parse failure adds a ``"Syntax error: ..."`` entry
        after any magic-command error and skips the AST checks.
    """
    if not code_verification_on:
        return None
    if allowed_modules is None:
        # Avoid a shared mutable default; an empty list means "no restriction".
        allowed_modules = []

    errors: List[str] = []
    magics, python_code, _ = separate_magics_and_code(code_snippet)
    if magics:
        errors.append(f"Magic commands except package install are not allowed. Details: {magics}")

    try:
        tree = ast.parse(python_code)
    except (SyntaxError, ValueError) as e:
        # ValueError: Python < 3.12 rejects source containing null bytes this way.
        # Keep the errors collected so far instead of discarding them.
        errors.append(f"Syntax error: {e}")
        return errors

    # Same filtering as the snippet splitter, so validator line numbers index this list.
    processed_lines = [
        line for line in python_code.splitlines() if line.strip() and not line.strip().startswith("#")
    ]
    validator = FunctionCallValidator(processed_lines, plugin_list, plugin_only, allowed_modules)
    validator.visit(tree)
    errors.extend(validator.errors)
    return errors
|
|
|