File size: 7,722 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import ast
import builtins
import re
from _ast import Name
from typing import List, Optional, Tuple

from injector import inject

# Names of every callable attribute of the ``builtins`` module; calls to these
# names are accepted by the validator even in plugin-only mode.
allowed_builtins = [
    builtin_name for builtin_name, builtin_obj in builtins.__dict__.items() if callable(builtin_obj)
]


class FunctionCallValidator(ast.NodeVisitor):
    """AST visitor that records policy violations found in a code snippet.

    Two independent checks are applied while walking the tree:

    * If ``allowed_modules`` is non-empty, every ``import`` and
      ``from ... import`` statement must name a module whose top-level
      package is in that list.
    * If ``plugin_only`` is true, the snippet may only call functions whose
      name is in ``plugin_list`` or ``allowed_builtins``, may only assign the
      result of such calls, may only read names bound by such assignments,
      and may not define functions.

    Violations are appended to ``self.errors`` as human-readable strings;
    the visitor itself never raises for a policy violation.
    """

    @inject
    def __init__(
        self,
        lines: List[str],
        plugin_list: List[str],
        plugin_only: bool,
        allowed_modules: List[str],
    ):
        """Store the validation policy.

        Args:
            lines: Source lines of the code being validated; ``lines[n - 1]``
                is quoted in error messages for a node whose ``lineno`` is ``n``.
            plugin_list: Function names that may be called in plugin-only mode.
            plugin_only: Whether the plugin-only restrictions are enforced.
            allowed_modules: Importable top-level modules; an empty list
                disables the import check.
        """
        self.lines = lines
        self.plugin_list = plugin_list
        self.errors = []
        # Names bound from the result of an allowed call (plugin-only mode).
        self.plugin_return_values = []
        self.plugin_only = plugin_only
        self.allowed_modules = allowed_modules

    def visit_Call(self, node):
        """Check a call against the plugin/builtin allow-list.

        Returns ``True`` if the call is allowed, ``False`` if an error was
        recorded, and ``None`` when plugin-only mode is off. The call's
        arguments are not visited.
        """
        if not self.plugin_only:
            return None

        func = node.func
        if isinstance(func, ast.Name):
            function_name = func.id
        elif isinstance(func, ast.Attribute):
            # For ``obj.method(...)`` only the attribute name is checked.
            function_name = func.attr
        else:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} " f"=> Function call is not allowed.",
            )
            return False

        if function_name not in self.plugin_list and function_name not in allowed_builtins:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
                f"=> Function '{function_name}' is not allowed.",
            )
            return False
        return True

    def visit_Import(self, node):
        """Record an error for each ``import`` of a module outside ``allowed_modules``."""
        if not self.allowed_modules:
            return
        for alias in node.names:
            # ``import a.b.c`` is judged by its top-level package ``a``.
            module_name = alias.name.split(".")[0]
            if module_name not in self.allowed_modules:
                self.errors.append(
                    f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
                    f"=> Importing module '{module_name}' is not allowed. ",
                )

    def visit_ImportFrom(self, node):
        """Record an error for a ``from ... import`` outside ``allowed_modules``."""
        if not self.allowed_modules:
            return
        if node.module is None:
            # ``from . import x``: there is no module name, only a relative
            # level. Report it using its dotted form instead of crashing on
            # ``None``; a bare relative import never matches the allow-list.
            imported_from = "." * node.level
            module_name = imported_from
        else:
            imported_from = node.module
            module_name = node.module.split(".")[0]
        if module_name not in self.allowed_modules:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} "
                f"=>  Importing from module '{imported_from}' is not allowed.",
            )

    def visit_FunctionDef(self, node):
        """Reject function definitions in plugin-only mode; otherwise check the body."""
        if self.plugin_only:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} => Defining new functions is not allowed.",
            )
        else:
            # Descend into the body so imports nested inside a function are
            # still checked against ``allowed_modules``.
            self.generic_visit(node)

    def visit_Assign(self, node):
        """In plugin-only mode, allow only assignments of allowed call results.

        Names bound by such an assignment (a single name, or the names in a
        tuple target) are remembered in ``plugin_return_values``. Only the
        first assignment target is inspected.
        """
        if not self.plugin_only:
            return
        if isinstance(node.value, ast.Call):
            is_allowed_call = self.visit_Call(node.value)
            if not is_allowed_call:
                return
            target = node.targets[0]
            if isinstance(target, ast.Tuple):
                for elt in target.elts:
                    if isinstance(elt, ast.Name):
                        self.plugin_return_values.append(elt.id)
            elif isinstance(target, ast.Name):
                self.plugin_return_values.append(target.id)
        else:
            self.errors.append(f"Error: Unsupported assignment on line {node.lineno}.")
            self.generic_visit(node)

    def visit_Name(self, node: ast.Name):
        """In plugin-only mode, reject any name not bound from an allowed call."""
        if self.plugin_only and node.id not in self.plugin_return_values:
            self.errors.append(
                f"Error on line {node.lineno}: {self.lines[node.lineno-1]} => "
                "Only return values of plugins calls can be used.",
            )

    def generic_visit(self, node):
        """Fallback visitor.

        In plugin-only mode, node types outside the small allowed set are
        reported as errors without descending further, except tuples, whose
        elements are visited individually. In every other case the default
        child traversal is used.
        """
        if self.plugin_only and not isinstance(
            node,
            (ast.Call, ast.Assign, ast.Import, ast.ImportFrom, ast.Expr, ast.Module, ast.Name),
        ):
            if isinstance(node, ast.Tuple):
                for elt in node.elts:
                    self.visit(elt)
            else:
                error_message = (
                    f"Error on line {node.lineno}: {self.lines[node.lineno-1]} => "
                    "Codes except plugin calls are not allowed."
                )
                self.errors.append(error_message)

        else:
            super().generic_visit(node)


def format_code_correction_message() -> str:
    """Return the fixed message sent back when verification found errors."""
    sentences = (
        "The generated code has been verified and some errors are found. ",
        "If you think you can fix the problem by rewriting the code, ",
        "please do it and try again.\n",
        "Otherwise, please explain the problem to me.",
    )
    return "".join(sentences)


def separate_magics_and_code(input_code: str) -> Tuple[List[str], str, List[str]]:
    """Split a snippet into magic lines, plain Python code, and install commands.

    Blank lines and lines whose first non-space character is ``#`` are dropped
    entirely. Of the remaining lines:

    * a line magic (``%name``) or shell command (``!...``) containing
      ``pip install`` or ``conda install`` goes to the install-command list;
    * any other line magic or shell command goes to the magics list;
    * a cell magic (``%%name``) goes to the magics list, and so does every
      line after it (blank lines are skipped before they could end the cell);
    * everything else is Python code.

    Returns:
        ``(magics, python_code, package_install_commands)`` where
        ``python_code`` is the code lines joined with ``"\\n"``.
    """
    line_magic_re = re.compile(r"^\s*%\s*[a-zA-Z_]\w*")
    cell_magic_re = re.compile(r"^\s*%%\s*[a-zA-Z_]\w*")
    shell_re = re.compile(r"^\s*!")

    magic_lines: List[str] = []
    code_lines: List[str] = []
    install_lines: List[str] = []
    in_cell_magic = False

    for raw in input_code.splitlines():
        stripped = raw.strip()
        if stripped == "" or stripped.startswith("#"):
            continue

        if in_cell_magic:
            # Nothing resets the flag: once a cell magic starts, the rest of
            # the input belongs to it.
            magic_lines.append(raw)
            continue

        if line_magic_re.match(raw) or shell_re.match(raw):
            is_install = "pip install" in raw or "conda install" in raw
            bucket = install_lines if is_install else magic_lines
            bucket.append(raw)
            continue

        if cell_magic_re.match(raw):
            in_cell_magic = True
            magic_lines.append(raw)
            continue

        code_lines.append(raw)

    return magic_lines, "\n".join(code_lines), install_lines


def code_snippet_verification(
    code_snippet: str,
    plugin_list: List[str],
    code_verification_on: bool = False,
    plugin_only: bool = False,
    allowed_modules: Optional[List[str]] = None,
) -> Optional[List[str]]:
    """Verify a code snippet against the magic, import, and plugin-call policies.

    Args:
        code_snippet: The code to check; may contain magic and shell lines.
        plugin_list: Function names that may be called in plugin-only mode.
        code_verification_on: When false, no verification is done.
        plugin_only: Whether to enforce the plugin-only restrictions.
        allowed_modules: Top-level modules that may be imported. ``None`` or
            an empty list disables the import check.

    Returns:
        ``None`` when verification is off. Otherwise a list of error
        messages (empty if the snippet is acceptable). If the Python part of
        the snippet does not parse, the list contains only the syntax error.
    """
    if not code_verification_on:
        return None
    # A ``None`` sentinel replaces the former shared mutable ``[]`` default.
    if allowed_modules is None:
        allowed_modules = []

    errors = []
    magics, python_code, _ = separate_magics_and_code(code_snippet)
    if magics:
        errors.append(f"Magic commands except package install are not allowed. Details: {magics}")

    try:
        tree = ast.parse(python_code)
    except SyntaxError as e:
        return [f"Syntax error: {e}"]

    # Keep the quoted lines aligned with the parsed code by applying the same
    # blank/comment-line filter to the code that was parsed.
    processed_lines = [
        line for line in python_code.splitlines() if line.strip() and not line.strip().startswith("#")
    ]
    validator = FunctionCallValidator(processed_lines, plugin_list, plugin_only, allowed_modules)
    validator.visit(tree)
    errors.extend(validator.errors)
    return errors