# GAIA-Inspired-Multi-Agent-System / tools/code_interpreter_tools.py
# Hosting-page residue kept for provenance: "Prasanthkumar's picture",
# commit "Update tools/code_interpreter_tools.py" (702bd40, verified).
# ========================== #
# 📦 Imports and Setup
# ========================== #
import os
import io
import sys
import uuid
import base64
import traceback
import contextlib
import tempfile
import subprocess
import sqlite3
import logging
from typing import Dict, Any
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from langchain_core.tools import tool
# ========================== #
# 📋 Logging Setup
# ========================== #
def setup_logger(log_file="execution.log"):
    """Return the shared ``"CodeInterpreter"`` logger, attaching a file handler once.

    Args:
        log_file: Path of the log file. Only honoured on the first call: once the
            logger has a handler, later calls return it unchanged so that
            repeated imports/calls do not produce duplicate log lines.

    Returns:
        The ``logging.Logger`` named ``"CodeInterpreter"``, set to ``INFO`` level.
    """
    logger = logging.getLogger("CodeInterpreter")
    logger.setLevel(logging.INFO)
    if not logger.handlers:
        # Logged messages embed arbitrary exception text from executed code, so
        # pin UTF-8 instead of relying on the platform's default encoding.
        handler = logging.FileHandler(log_file, encoding="utf-8")
        formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
# Module-level logger used throughout this file. It is created at import time
# with the default log_file, i.e. "execution.log" relative to the current
# working directory.
logger = setup_logger()
# =================================================================== #
# Code interpreter tools for Python, Java, C, SQL (SQLite) and Bash
# =================================================================== #
class Code_Interpreter:
    """Run code snippets in Python, Java, C, SQL (SQLite) or Bash and collect the results.

    Every handler returns a dict with the same keys: ``execution_id``,
    ``status`` (``"success"`` or ``"error"``), ``stdout``, ``stderr``,
    ``result``, ``plots`` and ``dataframes``.

    NOTE: this class is NOT a sandbox. Python snippets run through ``exec`` in
    the host process with full builtins, and Bash snippets run through the
    system shell. ``allowed_modules`` is stored but not enforced, and
    ``max_execution_time`` only bounds the subprocess-based handlers
    (Java, C, Bash), not in-process Python or SQL execution.
    """

    # Substrings that make ``execute_code`` reject a Python snippet. This is a
    # best-effort tripwire only and is trivial to bypass.
    _UNSAFE_PYTHON_MARKERS = ("os.remove", "shutil.rmtree", "open('/etc", "__import__")

    def __init__(
        self,
        allowed_modules=None,
        max_execution_time=30,
        working_directory=None,
    ):
        """Configure the interpreter.

        Args:
            allowed_modules: Module names considered allowed; defaults to a
                built-in list. Stored for reference, not enforced.
            max_execution_time: Timeout in seconds passed to every subprocess.
            working_directory: Directory under which per-execution folders are
                created for Python plots; defaults to the current directory and
                is created if missing.
        """
        self.allowed_modules = allowed_modules or [
            "numpy", "pandas", "matplotlib", "scipy", "sklearn", "math", "random", "statistics",
            "datetime", "collections", "itertools", "functools", "operator", "re", "json", "sympy",
            "networkx", "nltk", "PIL", "pytesseract", "cmath", "uuid", "tempfile", "requests", "urllib"
        ]
        self.max_execution_time = max_execution_time
        self.working_directory = working_directory or os.getcwd()
        os.makedirs(self.working_directory, exist_ok=True)
        # Namespace shared by all Python executions of this instance, so
        # variables defined by one snippet stay visible to later ones.
        self.globals = {"__builtins__": __builtins__, "np": np, "pd": pd, "plt": plt, "Image": Image}
        self.temp_sqlite_db = os.path.join(tempfile.gettempdir(), "code_exec.db")

    @staticmethod
    def _new_result(execution_id, status="error", stdout="", stderr=""):
        """Build the result dict shared by all handlers (fresh lists on every call)."""
        return {
            "execution_id": execution_id,
            "status": status,
            "stdout": stdout,
            "stderr": stderr,
            "result": None,
            "plots": [],
            "dataframes": [],
        }

    def _result_from_process(self, execution_id, proc):
        """Convert a finished ``subprocess.CompletedProcess`` into a result dict."""
        status = "success" if proc.returncode == 0 else "error"
        return self._new_result(execution_id, status, proc.stdout, proc.stderr)

    def execute_code(self, code: str, language: str = "python") -> Dict[str, Any]:
        """Dispatch execution to the appropriate language handler.

        Args:
            code: Source code (or SQL statement / shell command) to run.
            language: One of ``python``, ``java``, ``c``, ``sql`` or ``bash``
                (case-insensitive, surrounding whitespace ignored).

        Returns:
            The handler's result dict. Unsupported languages, rejected Python
            snippets and exceptions escaping a handler are reported as an
            ``"error"`` result with the reason in ``stderr``; nothing is raised.
        """
        lang = (language or "").strip().lower()
        execution_id = str(uuid.uuid4())
        logger.info(f"[{execution_id}] Executing code in language: {lang}")
        result = self._new_result(execution_id)

        handlers = {
            "python": self._execute_python,
            "java": self._execute_java,
            "c": self._execute_c,
            "sql": self._execute_sql,
            "bash": self._execute_bash,
        }
        handler = handlers.get(lang)
        if handler is None:
            # Tell the caller why nothing ran instead of returning an empty error.
            result["stderr"] = (
                f"Unsupported language: {language!r}. "
                f"Supported languages: {', '.join(handlers)}."
            )
            logger.error(f"[{execution_id}] {result['stderr']}")
            return result

        try:
            if lang == "python" and any(marker in code for marker in self._UNSAFE_PYTHON_MARKERS):
                raise ValueError("Unsafe code detected.")
            return handler(code, execution_id)
        except Exception as e:
            result["stderr"] = str(e)
            logger.error(f"[{execution_id}] Execution error: {e}", exc_info=True)
        return result

    def _execute_python(self, code: str, execution_id: str) -> dict:
        """Execute Python code in-process, capturing stdout/stderr, plots and DataFrames.

        SECURITY: the snippet runs via ``exec`` with full builtins inside this
        process; it is not isolated and ``max_execution_time`` is not applied.

        Returns:
            Result dict. ``plots`` holds base64-encoded PNGs of all open
            matplotlib figures; ``dataframes`` summarises every non-empty
            DataFrame in the shared namespace. ``result`` is always ``None``
            because ``exec`` does not return a value.
        """
        output_buffer = io.StringIO()
        error_buffer = io.StringIO()
        result = self._new_result(execution_id)
        try:
            exec_dir = os.path.join(self.working_directory, execution_id)
            os.makedirs(exec_dir, exist_ok=True)
            plt.switch_backend('Agg')
            with contextlib.redirect_stdout(output_buffer), contextlib.redirect_stderr(error_buffer):
                exec(code, self.globals)

            # Capture plots
            for i, fig_num in enumerate(plt.get_fignums()):
                fig = plt.figure(fig_num)
                img_path = os.path.join(exec_dir, f"plot_{i}.png")
                fig.savefig(img_path)
                with open(img_path, "rb") as img_file:
                    img_data = base64.b64encode(img_file.read()).decode('utf-8')
                result["plots"].append({"figure_number": fig_num, "data": img_data})

            # Capture dataframes (snapshot the items so the namespace can't change mid-loop)
            for var_name, var_value in list(self.globals.items()):
                if isinstance(var_value, pd.DataFrame) and len(var_value) > 0:
                    result["dataframes"].append({
                        "name": var_name,
                        "head": var_value.head().to_dict(),
                        "shape": var_value.shape,
                        "dtypes": str(var_value.dtypes)
                    })

            result["status"] = "success"
            # Anything the snippet wrote to stderr (e.g. warnings) is still useful.
            result["stderr"] = error_buffer.getvalue()
            logger.info(f"[{execution_id}] Python code executed successfully.")
        except (Exception, SystemExit) as e:
            # SystemExit is caught too so that a snippet calling sys.exit()
            # cannot terminate the host process.
            result["status"] = "error"
            result["stderr"] = error_buffer.getvalue() + "\n" + traceback.format_exc()
            logger.error(f"[{execution_id}] Python execution failed: {e}", exc_info=True)
        finally:
            # Keep whatever was printed before a failure as well.
            result["stdout"] = output_buffer.getvalue()
            # Close figures explicitly so they are neither leaked nor reported
            # again by a later execution; do not rely on switch_backend() for it.
            plt.close("all")
        return result

    def _execute_java(self, code: str, execution_id: str) -> dict:
        """Compile and run a Java program; the source is saved as ``Main.java``.

        The snippet must therefore declare its entry point in a class named
        ``Main``. Compilation and execution are each limited to
        ``max_execution_time`` seconds. Any exception (missing ``javac``,
        timeout, ...) is reported in ``stderr`` of an ``"error"`` result.
        """
        try:
            # TemporaryDirectory removes sources and class files afterwards.
            with tempfile.TemporaryDirectory() as temp_dir:
                source_path = os.path.join(temp_dir, "Main.java")
                with open(source_path, "w") as f:
                    f.write(code)
                compile_proc = subprocess.run(
                    ["javac", source_path],
                    capture_output=True, text=True, timeout=self.max_execution_time,
                )
                if compile_proc.returncode != 0:
                    return self._new_result(
                        execution_id, "error", compile_proc.stdout, compile_proc.stderr
                    )
                run_proc = subprocess.run(
                    ["java", "-cp", temp_dir, "Main"],
                    capture_output=True, text=True, timeout=self.max_execution_time,
                )
                return self._result_from_process(execution_id, run_proc)
        except Exception as e:
            return self._new_result(execution_id, stderr=str(e))

    def _execute_c(self, code: str, execution_id: str) -> dict:
        """Compile a C program with ``gcc`` and run the resulting binary.

        Compilation and execution are each limited to ``max_execution_time``
        seconds. Any exception (missing ``gcc``, timeout, ...) is reported in
        ``stderr`` of an ``"error"`` result.
        """
        try:
            # TemporaryDirectory removes the source and binary afterwards.
            with tempfile.TemporaryDirectory() as temp_dir:
                source_path = os.path.join(temp_dir, "program.c")
                binary_path = os.path.join(temp_dir, "program")
                with open(source_path, "w") as f:
                    f.write(code)
                compile_proc = subprocess.run(
                    ["gcc", source_path, "-o", binary_path],
                    capture_output=True, text=True, timeout=self.max_execution_time,
                )
                if compile_proc.returncode != 0:
                    return self._new_result(
                        execution_id, "error", compile_proc.stdout, compile_proc.stderr
                    )
                run_proc = subprocess.run(
                    [binary_path],
                    capture_output=True, text=True, timeout=self.max_execution_time,
                )
                return self._result_from_process(execution_id, run_proc)
        except Exception as e:
            return self._new_result(execution_id, stderr=str(e))

    def _execute_sql(self, code: str, execution_id: str) -> dict:
        """Run a single SQL statement against the temporary SQLite database.

        Statements that produce a result set (``SELECT``, ``WITH ... SELECT``,
        ``PRAGMA``, ``... RETURNING``) are summarised under ``dataframes`` as
        ``query_result``; changes are committed before the connection closes.
        Failures are reported in ``stderr`` of an ``"error"`` result.
        """
        result = self._new_result(execution_id)
        try:
            # closing() guarantees the connection is released, and a failing
            # connect() is reported instead of tripping over an unbound name.
            with contextlib.closing(sqlite3.connect(self.temp_sqlite_db)) as conn:
                cur = conn.cursor()
                cur.execute(code)
                # description is None exactly when the statement returns no rows,
                # which is more reliable than checking for a "select" prefix.
                if cur.description is not None:
                    columns = [desc[0] for desc in cur.description]
                    rows = cur.fetchall()
                    df = pd.DataFrame(rows, columns=columns)
                    result["dataframes"].append({
                        "name": "query_result",
                        "head": df.head().to_dict(),
                        "shape": df.shape,
                        "dtypes": str(df.dtypes)
                    })
                conn.commit()
            result["status"] = "success"
            result["stdout"] = "Query executed successfully."
        except Exception as e:
            result["stderr"] = str(e)
            logger.error(f"[{execution_id}] SQL execution failed: {e}", exc_info=True)
        return result

    def _execute_bash(self, code: str, execution_id: str) -> dict:
        """Run a shell command line, limited to ``max_execution_time`` seconds.

        SECURITY: ``code`` is handed to the system shell unmodified
        (``shell=True``); that is the purpose of this handler, so callers must
        only pass commands they are prepared to run on the host.
        """
        try:
            completed = subprocess.run(
                code, shell=True, capture_output=True, text=True,
                timeout=self.max_execution_time,
            )
        except subprocess.TimeoutExpired:
            return self._new_result(execution_id, stderr="Execution timed out.")
        return self._result_from_process(execution_id, completed)
# ================================== #
# LangChain tool
# ================================== #
# Single interpreter instance created at import time with default settings.
# Its `globals` namespace is reused for every Python execution, so variables
# defined by one snippet remain available to later ones.
interpreter = Code_Interpreter()
@tool
def execute_code_multilang(code: str, language: str = "python") -> str:
    """
    Execute code in multiple languages (Python, Bash, SQL, C, Java) and return results.
    Args:
    code (str): the source code to execute
    language (str): the language of the code
    """
    # NOTE: the docstring above is kept verbatim on purpose; the @tool decorator
    # exposes it as the tool description, so editing it changes runtime behaviour.
    #
    # Returns a Markdown-formatted report: a success/failure headline followed by
    # stdout, stderr, DataFrame previews and a plot count where available.
    result = interpreter.execute_code(code, language)
    response = []
    if result["status"] == "success":
        response.append(f"✅ Code executed successfully in **{language.upper()}**")
        if result.get("stdout"):
            response.append("\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```")
        if result.get("stderr"):
            response.append("\n**Standard Error (if any):**\n```\n" + result["stderr"].strip() + "\n```")
        for df in result.get("dataframes") or []:
            # "head" is a plain dict of the first rows; rebuild a frame for display.
            preview = pd.DataFrame(df["head"])
            response.append(f"\n**DataFrame `{df['name']}` (Shape: {df['shape']})**\n```\n{preview}\n```")
        if result.get("plots"):
            response.append(f"\n🖼️ {len(result['plots'])} plot(s) generated (encoded)")
    else:
        response.append(f"❌ Code execution failed in **{language.upper()}**")
        # A program can print useful output before exiting non-zero; do not drop it.
        if result.get("stdout"):
            response.append("\n**Standard Output:**\n```\n" + result["stdout"].strip() + "\n```")
        if result.get("stderr"):
            response.append("\n**Error Log:**\n```\n" + result["stderr"].strip() + "\n```")
    return "\n".join(response)