Supan23 commited on
Commit
41dad88
·
verified ·
1 Parent(s): 8fc58fb

Update code_interpreter.py

Browse files
Files changed (1) hide show
  1. code_interpreter.py +267 -181
code_interpreter.py CHANGED
@@ -1,181 +1,267 @@
1
- import os
2
- import io
3
- import sys
4
- import uuid
5
- import base64
6
- import traceback
7
- import contextlib
8
- import tempfile
9
- import subprocess
10
- import sqlite3
11
- from typing import Dict, List, Any, Optional, Union
12
- import numpy as np
13
- import pandas as pd
14
- import matplotlib.pyplot as plt
15
- from PIL import Image
16
-
17
- class CodeInterpreter:
18
- def __init__(self, allowed_modules=None, max_execution_time=30, working_directory=None):
19
- """Initialize the code interpreter with safety measures."""
20
- self.allowed_modules = allowed_modules or [
21
- "numpy", "pandas", "matplotlib", "scipy", "sklearn",
22
- "math", "random", "statistics", "datetime", "collections",
23
- "itertools", "functools", "operator", "re", "json",
24
- "sympy", "networkx", "nltk", "PIL", "pytesseract",
25
- "cmath", "uuid", "tempfile", "requests", "urllib"
26
- ]
27
- self.max_execution_time = max_execution_time
28
- self.working_directory = working_directory or os.path.join(os.getcwd())
29
- if not os.path.exists(self.working_directory):
30
- os.makedirs(self.working_directory)
31
- self.globals = {
32
- "__builtins__": __builtins__,
33
- "np": np,
34
- "pd": pd,
35
- "plt": plt,
36
- "Image": Image,
37
- }
38
- self.temp_sqlite_db = os.path.join(tempfile.gettempdir(), "code_exec.db")
39
-
40
- def execute_code(self, code: str, language: str = "python") -> Dict[str, Any]:
41
- """Execute the provided code in the selected programming language."""
42
- language = language.lower()
43
- execution_id = str(uuid.uuid4())
44
- result = {
45
- "execution_id": execution_id,
46
- "status": "error",
47
- "stdout": "",
48
- "stderr": "",
49
- "result": None,
50
- "plots": [],
51
- "dataframes": []
52
- }
53
-
54
- try:
55
- if language == "python":
56
- return self._execute_python(code, execution_id)
57
- elif language == "bash":
58
- return self._execute_bash(code, execution_id)
59
- elif language == "sql":
60
- return self._execute_sql(code, execution_id)
61
- else:
62
- result["stderr"] = f"Unsupported language: {language}"
63
- except Exception as e:
64
- result["stderr"] = str(e)
65
- return result
66
-
67
- def _execute_python(self, code: str, execution_id: str) -> dict:
68
- output_buffer = io.StringIO()
69
- error_buffer = io.StringIO()
70
- result = {
71
- "execution_id": execution_id,
72
- "status": "error",
73
- "stdout": "",
74
- "stderr": "",
75
- "result": None,
76
- "plots": [],
77
- "dataframes": []
78
- }
79
-
80
- try:
81
- exec_dir = os.path.join(self.working_directory, execution_id)
82
- os.makedirs(exec_dir, exist_ok=True)
83
- plt.switch_backend('Agg')
84
-
85
- with contextlib.redirect_stdout(output_buffer), contextlib.redirect_stderr(error_buffer):
86
- exec_result = exec(code, self.globals)
87
-
88
- # Handle plots
89
- if plt.get_fignums():
90
- for i, fig_num in enumerate(plt.get_fignums()):
91
- fig = plt.figure(fig_num)
92
- img_path = os.path.join(exec_dir, f"plot_{i}.png")
93
- fig.savefig(img_path)
94
- with open(img_path, "rb") as img_file:
95
- img_data = base64.b64encode(img_file.read()).decode('utf-8')
96
- result["plots"].append({
97
- "figure_number": fig_num,
98
- "data": img_data
99
- })
100
-
101
- # Handle dataframes
102
- for var_name, var_value in self.globals.items():
103
- if isinstance(var_value, pd.DataFrame) and len(var_value) > 0:
104
- result["dataframes"].append({
105
- "name": var_name,
106
- "head": var_value.head().to_dict(),
107
- "shape": var_value.shape,
108
- "dtypes": str(var_value.dtypes)
109
- })
110
-
111
- result["status"] = "success"
112
- result["stdout"] = output_buffer.getvalue()
113
- result["result"] = exec_result
114
-
115
- except Exception as e:
116
- result["status"] = "error"
117
- result["stderr"] = f"{error_buffer.getvalue()}\n{traceback.format_exc()}"
118
-
119
- return result
120
-
121
- def _execute_bash(self, code: str, execution_id: str) -> dict:
122
- try:
123
- completed = subprocess.run(
124
- code, shell=True, capture_output=True, text=True, timeout=self.max_execution_time
125
- )
126
- return {
127
- "execution_id": execution_id,
128
- "status": "success" if completed.returncode == 0 else "error",
129
- "stdout": completed.stdout,
130
- "stderr": completed.stderr,
131
- "result": None,
132
- "plots": [],
133
- "dataframes": []
134
- }
135
- except subprocess.TimeoutExpired:
136
- return {
137
- "execution_id": execution_id,
138
- "status": "error",
139
- "stdout": "",
140
- "stderr": "Execution timed out.",
141
- "result": None,
142
- "plots": [],
143
- "dataframes": []
144
- }
145
-
146
- def _execute_sql(self, code: str, execution_id: str) -> dict:
147
- result = {
148
- "execution_id": execution_id,
149
- "status": "error",
150
- "stdout": "",
151
- "stderr": "",
152
- "result": None,
153
- "plots": [],
154
- "dataframes": []
155
- }
156
-
157
- try:
158
- conn = sqlite3.connect(self.temp_sqlite_db)
159
- cur = conn.cursor()
160
- cur.execute(code)
161
-
162
- if code.strip().lower().startswith("select"):
163
- columns = [description[0] for description in cur.description]
164
- rows = cur.fetchall()
165
- df = pd.DataFrame(rows, columns=columns)
166
- result["dataframes"].append({
167
- "name": "query_result",
168
- "head": df.head().to_dict(),
169
- "shape": df.shape,
170
- "dtypes": str(df.dtypes)
171
- })
172
- else:
173
- conn.commit()
174
-
175
- result["status"] = "success"
176
- result["stdout"] = "Query executed successfully."
177
- except Exception as e:
178
- result["stderr"] = str(e)
179
- finally:
180
- conn.close()
181
- return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import sys
4
+ import uuid
5
+ import base64
6
+ import traceback
7
+ import contextlib
8
+ import tempfile
9
+ import subprocess
10
+ import sqlite3
11
+ from typing import Dict, List, Any, Optional, Union
12
+
13
+ import numpy as np
14
+ import pandas as pd
15
+ import matplotlib.pyplot as plt
16
+ from PIL import Image
17
+
18
+ class CodeInterpreter:
19
+ def __init__(self, allowed_modules=None, max_execution_time=30, working_directory=None):
20
+ """Initialize the code interpreter with safety measures."""
21
+ self.allowed_modules = allowed_modules or [
22
+ "numpy", "pandas", "matplotlib", "scipy", "sklearn",
23
+ "math", "random", "statistics", "datetime", "collections",
24
+ "itertools", "functools", "operator", "re", "json",
25
+ "sympy", "networkx", "nltk", "PIL", "pytesseract",
26
+ "cmath", "uuid", "tempfile", "requests", "urllib"
27
+ ]
28
+
29
+ self.max_execution_time = max_execution_time
30
+ self.working_directory = working_directory or os.path.join(os.getcwd())
31
+ if not os.path.exists(self.working_directory):
32
+ os.makedirs(self.working_directory)
33
+
34
+ self.globals = {
35
+ "__builtins__": __builtins__,
36
+ "np": np,
37
+ "pd": pd,
38
+ "plt": plt,
39
+ "Image": Image,
40
+ }
41
+
42
+ self.temp_sqlite_db = os.path.join(tempfile.gettempdir(), "code_exec.db")
43
+
44
+ def execute_code(self, code: str, language: str = "python") -> Dict[str, Any]:
45
+ """Execute the provided code in the selected programming language."""
46
+ language = language.lower()
47
+ execution_id = str(uuid.uuid4())
48
+
49
+ result = {
50
+ "execution_id": execution_id,
51
+ "status": "error",
52
+ "stdout": "",
53
+ "stderr": "",
54
+ "result": None,
55
+ "plots": [],
56
+ "dataframes": []
57
+ }
58
+
59
+ try:
60
+ if language == "python":
61
+ return self._execute_python(code, execution_id)
62
+ elif language == "bash":
63
+ return self._execute_bash(code, execution_id)
64
+ elif language == "sql":
65
+ return self._execute_sql(code, execution_id)
66
+ elif language == "c":
67
+ return self._execute_c(code, execution_id)
68
+ elif language == "java":
69
+ return self._execute_java(code, execution_id)
70
+ else:
71
+ result["stderr"] = f"Unsupported language: {language}"
72
+ except Exception as e:
73
+ result["stderr"] = str(e)
74
+ return result
75
+
76
+ def _execute_python(self, code: str, execution_id: str) -> dict:
77
+ output_buffer = io.StringIO()
78
+ error_buffer = io.StringIO()
79
+ result = {
80
+ "execution_id": execution_id,
81
+ "status": "error",
82
+ "stdout": "",
83
+ "stderr": "",
84
+ "result": None,
85
+ "plots": [],
86
+ "dataframes": []
87
+ }
88
+ try:
89
+ exec_dir = os.path.join(self.working_directory, execution_id)
90
+ os.makedirs(exec_dir, exist_ok=True)
91
+ plt.switch_backend('Agg')
92
+ with contextlib.redirect_stdout(output_buffer), contextlib.redirect_stderr(error_buffer):
93
+ exec_result = exec(code, self.globals)
94
+ if plt.get_fignums():
95
+ for i, fig_num in enumerate(plt.get_fignums()):
96
+ fig = plt.figure(fig_num)
97
+ img_path = os.path.join(exec_dir, f"plot_{i}.png")
98
+ fig.savefig(img_path)
99
+ with open(img_path, "rb") as img_file:
100
+ img_data = base64.b64encode(img_file.read()).decode('utf-8')
101
+ result["plots"].append({
102
+ "figure_number": fig_num,
103
+ "data": img_data
104
+ })
105
+ for var_name, var_value in self.globals.items():
106
+ if isinstance(var_value, pd.DataFrame) and len(var_value) > 0:
107
+ result["dataframes"].append({
108
+ "name": var_name,
109
+ "head": var_value.head().to_dict(),
110
+ "shape": var_value.shape,
111
+ "dtypes": str(var_value.dtypes)
112
+ })
113
+ result["status"] = "success"
114
+ result["stdout"] = output_buffer.getvalue()
115
+ result["result"] = exec_result
116
+ except Exception as e:
117
+ result["status"] = "error"
118
+ result["stderr"] = f"{error_buffer.getvalue()}\n{traceback.format_exc()}"
119
+ return result
120
+
121
+ def _execute_bash(self, code: str, execution_id: str) -> dict:
122
+ try:
123
+ completed = subprocess.run(
124
+ code, shell=True, capture_output=True, text=True, timeout=self.max_execution_time
125
+ )
126
+ return {
127
+ "execution_id": execution_id,
128
+ "status": "success" if completed.returncode == 0 else "error",
129
+ "stdout": completed.stdout,
130
+ "stderr": completed.stderr,
131
+ "result": None,
132
+ "plots": [],
133
+ "dataframes": []
134
+ }
135
+ except subprocess.TimeoutExpired:
136
+ return {
137
+ "execution_id": execution_id,
138
+ "status": "error",
139
+ "stdout": "",
140
+ "stderr": "Execution timed out.",
141
+ "result": None,
142
+ "plots": [],
143
+ "dataframes": []
144
+ }
145
+
146
+ def _execute_sql(self, code: str, execution_id: str) -> dict:
147
+ result = {
148
+ "execution_id": execution_id,
149
+ "status": "error",
150
+ "stdout": "",
151
+ "stderr": "",
152
+ "result": None,
153
+ "plots": [],
154
+ "dataframes": []
155
+ }
156
+ try:
157
+ conn = sqlite3.connect(self.temp_sqlite_db)
158
+ cur = conn.cursor()
159
+ cur.execute(code)
160
+ if code.strip().lower().startswith("select"):
161
+ columns = [description[0] for description in cur.description]
162
+ rows = cur.fetchall()
163
+ df = pd.DataFrame(rows, columns=columns)
164
+ result["dataframes"].append({
165
+ "name": "query_result",
166
+ "head": df.head().to_dict(),
167
+ "shape": df.shape,
168
+ "dtypes": str(df.dtypes)
169
+ })
170
+ else:
171
+ conn.commit()
172
+ result["status"] = "success"
173
+ result["stdout"] = "Query executed successfully."
174
+ except Exception as e:
175
+ result["stderr"] = str(e)
176
+ finally:
177
+ conn.close()
178
+ return result
179
+
180
+ def _execute_c(self, code: str, execution_id: str) -> dict:
181
+ temp_dir = tempfile.mkdtemp()
182
+ source_path = os.path.join(temp_dir, "program.c")
183
+ binary_path = os.path.join(temp_dir, "program")
184
+ try:
185
+ with open(source_path, "w") as f:
186
+ f.write(code)
187
+ compile_proc = subprocess.run(
188
+ ["gcc", source_path, "-o", binary_path],
189
+ capture_output=True, text=True, timeout=self.max_execution_time
190
+ )
191
+ if compile_proc.returncode != 0:
192
+ return {
193
+ "execution_id": execution_id,
194
+ "status": "error",
195
+ "stdout": compile_proc.stdout,
196
+ "stderr": compile_proc.stderr,
197
+ "result": None,
198
+ "plots": [],
199
+ "dataframes": []
200
+ }
201
+ run_proc = subprocess.run(
202
+ [binary_path],
203
+ capture_output=True, text=True, timeout=self.max_execution_time
204
+ )
205
+ return {
206
+ "execution_id": execution_id,
207
+ "status": "success" if run_proc.returncode == 0 else "error",
208
+ "stdout": run_proc.stdout,
209
+ "stderr": run_proc.stderr,
210
+ "result": None,
211
+ "plots": [],
212
+ "dataframes": []
213
+ }
214
+ except Exception as e:
215
+ return {
216
+ "execution_id": execution_id,
217
+ "status": "error",
218
+ "stdout": "",
219
+ "stderr": str(e),
220
+ "result": None,
221
+ "plots": [],
222
+ "dataframes": []
223
+ }
224
+
225
+ def _execute_java(self, code: str, execution_id: str) -> dict:
226
+ temp_dir = tempfile.mkdtemp()
227
+ source_path = os.path.join(temp_dir, "Main.java")
228
+ try:
229
+ with open(source_path, "w") as f:
230
+ f.write(code)
231
+ compile_proc = subprocess.run(
232
+ ["javac", source_path],
233
+ capture_output=True, text=True, timeout=self.max_execution_time
234
+ )
235
+ if compile_proc.returncode != 0:
236
+ return {
237
+ "execution_id": execution_id,
238
+ "status": "error",
239
+ "stdout": compile_proc.stdout,
240
+ "stderr": compile_proc.stderr,
241
+ "result": None,
242
+ "plots": [],
243
+ "dataframes": []
244
+ }
245
+ run_proc = subprocess.run(
246
+ ["java", "-cp", temp_dir, "Main"],
247
+ capture_output=True, text=True, timeout=self.max_execution_time
248
+ )
249
+ return {
250
+ "execution_id": execution_id,
251
+ "status": "success" if run_proc.returncode == 0 else "error",
252
+ "stdout": run_proc.stdout,
253
+ "stderr": run_proc.stderr,
254
+ "result": None,
255
+ "plots": [],
256
+ "dataframes": []
257
+ }
258
+ except Exception as e:
259
+ return {
260
+ "execution_id": execution_id,
261
+ "status": "error",
262
+ "stdout": "",
263
+ "stderr": str(e),
264
+ "result": None,
265
+ "plots": [],
266
+ "dataframes": []
267
+ }