Spaces:
Runtime error
Runtime error
File size: 7,979 Bytes
129cd69 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 | from __future__ import annotations
import ast
import json
import os
from io import StringIO
from sys import version_info
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr
from langchain.callbacks.manager import (
AsyncCallbackManagerForToolRun,
CallbackManager,
CallbackManagerForToolRun,
)
from langchain.tools import BaseTool, Tool
from langchain.tools.e2b_data_analysis.unparse import Unparser
if TYPE_CHECKING:
from e2b import EnvVars
from e2b.templates.data_analysis import Artifact
# Tool description shown to the LLM. It is built with implicit string
# concatenation, so the result is a single line with no embedded newlines.
base_description = (
    "Evaluates python code in a sandbox environment. "
    "The environment is long running and exists across multiple executions. "
    "You must send the whole script every time and print your outputs. "
    "Script should be pure python code that can be evaluated. "
    "It should be in python format NOT markdown. "
    "The code should NOT be wrapped in backticks. "
    "All python packages including requests, matplotlib, scipy, numpy, pandas, "
    "etc are available. Create and display chart using `plt.show()`."
)
def _unparse(tree: ast.AST) -> str:
"""Unparse the AST."""
if version_info.minor < 9:
s = StringIO()
Unparser(tree, file=s)
source_code = s.getvalue()
s.close()
else:
source_code = ast.unparse(tree) # type: ignore[attr-defined]
return source_code
def add_last_line_print(code: str) -> str:
    """Wrap the last line in ``print(...)`` if it is a bare expression.

    Sometimes LLM-generated code doesn't contain ``print(variable_name)``.
    Instead, the LLM tries to show the variable by writing only
    ``variable_name``, as you would in a REPL. This function parses the code.
    If the final top-level statement is an expression that is not already a
    ``print(...)`` call, it wraps that expression in ``print(...)``.

    Args:
        code: Python source to inspect.

    Returns:
        The (possibly modified) source regenerated from its AST. The code is
        always round-tripped through the AST, so comments and the original
        formatting are not preserved.

    Raises:
        SyntaxError: If ``code`` is not valid Python.
    """
    tree = ast.parse(code)
    # Empty or comment-only input has no statements, so there is nothing to
    # print. Indexing ``tree.body[-1]`` here would raise IndexError.
    if not tree.body:
        return _unparse(tree)

    last = tree.body[-1]
    if isinstance(last, ast.Expr):
        value = last.value
        already_printed = (
            isinstance(value, ast.Call)
            and isinstance(value.func, ast.Name)
            and value.func.id == "print"
        )
        if not already_printed:
            tree.body[-1] = ast.Expr(
                value=ast.Call(
                    func=ast.Name(id="print", ctx=ast.Load()),
                    args=[value],
                    keywords=[],
                )
            )
    return _unparse(tree)
class UploadedFile(BaseModel):
    """Record of a file uploaded to the sandbox.

    Created by ``E2BDataAnalysisTool.upload_file``. The tool lists these
    records in its description so the LLM knows which sandbox paths exist.
    """

    # Base name (``os.path.basename``) of the local file that was uploaded.
    name: str
    # Path of the file inside the sandbox, as returned by the E2B session.
    remote_path: str
    # Caller-supplied description. An empty string means "no description",
    # and only the path is listed.
    description: str
class E2BDataAnalysisToolArguments(BaseModel):
    """Arguments for the E2BDataAnalysisTool.

    Used as ``E2BDataAnalysisTool.args_schema``: the tool takes a single
    ``python_code`` string.
    """

    # Required field (``...``) holding the source to run in the sandbox.
    # ``description`` and the extra ``example`` kwarg are field metadata that
    # describe to the caller what to send.
    python_code: str = Field(
        ...,
        example="print('Hello World')",
        description=(
            "The python script to be evaluated. "
            "The contents will be in main.py. "
            "It should not be in markdown format."
        ),
    )
class E2BDataAnalysisTool(BaseTool):
    """Tool for running python code in a sandboxed environment for data analysis.

    Wraps an e2b ``DataAnalysis`` session. Files uploaded through
    ``upload_file`` are listed in the tool ``description`` so the LLM knows
    which sandbox paths it can read.
    """

    # Explicit annotation makes this a typed override of the inherited field.
    # Pydantic v2 rejects non-annotated field overrides; v1 accepts either form.
    name: str = "e2b_data_analysis"
    args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
    # The e2b ``DataAnalysis`` session created in ``__init__``.
    session: Any
    description: str
    _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)
    # Description without any file listing. The listing is rebuilt from
    # scratch on every change, so it never accumulates duplicates or stale
    # entries.
    _base_description: str = PrivateAttr(default=base_description)

    def __init__(
        self,
        api_key: Optional[str] = None,
        cwd: Optional[str] = None,
        env_vars: Optional[EnvVars] = None,
        on_stdout: Optional[Callable[[str], Any]] = None,
        on_stderr: Optional[Callable[[str], Any]] = None,
        on_artifact: Optional[Callable[[Artifact], Any]] = None,
        on_exit: Optional[Callable[[int], Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Create the tool and open an e2b ``DataAnalysis`` sandbox session.

        Args:
            api_key: E2B API key. If ``None``, E2B will try to read it from the
                ``E2B_API_KEY`` environment variable.
            cwd, env_vars, on_stdout, on_stderr, on_artifact, on_exit:
                Forwarded unchanged to ``e2b.DataAnalysis``.
            **kwargs: Forwarded to ``BaseTool``. Must not contain
                ``description``, which is always initialised from
                ``base_description``.

        Raises:
            ImportError: If the ``e2b`` package is not installed.
        """
        try:
            from e2b import DataAnalysis
        except ImportError as e:
            raise ImportError(
                "Unable to import e2b, please install with `pip install e2b`."
            ) from e

        super().__init__(description=base_description, **kwargs)
        # If no API key is provided, E2B will try to read it from the environment
        # variable E2B_API_KEY
        self.session = DataAnalysis(
            api_key=api_key,
            cwd=cwd,
            env_vars=env_vars,
            on_stdout=on_stdout,
            on_stderr=on_stderr,
            on_exit=on_exit,
            on_artifact=on_artifact,
        )

    def close(self) -> None:
        """Close the cloud sandbox and forget all uploaded files.

        The description is reset so it no longer advertises files from the
        closed sandbox.
        """
        self._uploaded_files = []
        self._refresh_description()
        self.session.close()

    @property
    def uploaded_files_description(self) -> str:
        """Return a bullet list of uploaded files, or ``""`` if there are none."""
        if not self._uploaded_files:
            return ""
        lines = ["The following files available in the sandbox:"]
        for f in self._uploaded_files:
            if f.description == "":
                lines.append(f"- path: `{f.remote_path}`")
            else:
                lines.append(
                    f"- path: `{f.remote_path}` \n description: `{f.description}`"
                )
        return "\n".join(lines)

    def _refresh_description(self) -> None:
        """Rebuild ``description`` as the base text plus the current file list."""
        listing = self.uploaded_files_description
        self.description = (
            f"{self._base_description}\n{listing}"
            if listing
            else self._base_description
        )

    def _run(
        self,
        python_code: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        callbacks: Optional[CallbackManager] = None,
    ) -> str:
        """Execute ``python_code`` in the sandbox and return a JSON summary.

        A trailing bare expression is wrapped in ``print`` (via
        ``add_last_line_print``) so its value appears in stdout.

        Args:
            python_code: Python source to run.
            run_manager: Not used by this method.
            callbacks: Optional callback manager. If its ``metadata`` has an
                ``"on_artifact"`` entry, that callable is passed to the sandbox
                for this run.

        Returns:
            A JSON string with keys ``stdout``, ``stderr`` and ``artifacts``
            (the list of artifact names).
        """
        python_code = add_last_line_print(python_code)

        # NOTE: in langchain, ``CallbackManager.metadata`` is a plain dict, and
        # ``getattr`` on a dict never finds the key. Look the hook up by key,
        # and keep the attribute lookup for any non-dict metadata object.
        metadata = getattr(callbacks, "metadata", None)
        if isinstance(metadata, dict):
            on_artifact = metadata.get("on_artifact")
        else:
            on_artifact = getattr(metadata, "on_artifact", None)

        stdout, stderr, artifacts = self.session.run_python(
            python_code, on_artifact=on_artifact
        )
        out = {
            "stdout": stdout,
            "stderr": stderr,
            "artifacts": [artifact.name for artifact in artifacts],
        }
        return json.dumps(out)

    async def _arun(
        self,
        python_code: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported; always raises NotImplementedError."""
        raise NotImplementedError("e2b_data_analysis does not support async")

    def run_command(
        self,
        cmd: str,
    ) -> dict:
        """Run a shell command in the sandbox and wait for it to finish.

        Returns:
            A dict with the process ``stdout``, ``stderr`` and ``exit_code``.
        """
        proc = self.session.process.start(cmd)
        output = proc.wait()
        return {
            "stdout": output.stdout,
            "stderr": output.stderr,
            "exit_code": output.exit_code,
        }

    def install_python_packages(self, package_names: str | List[str]) -> None:
        """Install python packages in the sandbox."""
        self.session.install_python_packages(package_names)

    def install_system_packages(self, package_names: str | List[str]) -> None:
        """Install system packages (via apt) in the sandbox."""
        self.session.install_system_packages(package_names)

    def download_file(self, remote_path: str) -> bytes:
        """Download file from the sandbox."""
        return self.session.download_file(remote_path)

    def upload_file(self, file: IO, description: str) -> UploadedFile:
        """Upload file to the sandbox.

        The file is uploaded to the '/home/user/<filename>' path. The tool
        description is then rebuilt to list every currently uploaded file.

        Args:
            file: Open file object. Its ``name`` attribute supplies the
                recorded base name.
            description: Optional human-readable note shown next to the path.

        Returns:
            The ``UploadedFile`` record that was added.
        """
        remote_path = self.session.upload_file(file)
        f = UploadedFile(
            name=os.path.basename(file.name),
            remote_path=remote_path,
            description=description,
        )
        self._uploaded_files.append(f)
        self._refresh_description()
        return f

    def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
        """Remove uploaded file from the sandbox and from the description."""
        self.session.filesystem.remove(uploaded_file.remote_path)
        self._uploaded_files = [
            f
            for f in self._uploaded_files
            if f.remote_path != uploaded_file.remote_path
        ]
        self._refresh_description()

    def as_tool(self) -> Tool:
        """Wrap ``_run`` in a plain ``Tool`` with this tool's metadata.

        The description string is copied at call time, so uploads made after
        this call are not reflected in the returned ``Tool``.
        """
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )
|