File size: 7,979 Bytes
129cd69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
from __future__ import annotations

import ast
import json
import os
from io import StringIO
from sys import version_info
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type

from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr

from langchain.callbacks.manager import (
    AsyncCallbackManagerForToolRun,
    CallbackManager,
    CallbackManagerForToolRun,
)
from langchain.tools import BaseTool, Tool
from langchain.tools.e2b_data_analysis.unparse import Unparser

if TYPE_CHECKING:
    from e2b import EnvVars
    from e2b.templates.data_analysis import Artifact

# Default description given to the tool at construction time. The pieces are
# joined by implicit string concatenation into a single line (no newlines).
base_description = (
    "Evaluates python code in a sandbox environment. "
    "The environment is long running and exists across multiple executions. "
    "You must send the whole script every time and print your outputs. "
    "Script should be pure python code that can be evaluated. "
    "It should be in python format NOT markdown. "
    "The code should NOT be wrapped in backticks. "
    "All python packages including requests, matplotlib, scipy, numpy, pandas, "
    "etc are available. Create and display chart using `plt.show()`."
)


def _unparse(tree: ast.AST) -> str:
    """Unparse the AST."""
    if version_info.minor < 9:
        s = StringIO()
        Unparser(tree, file=s)
        source_code = s.getvalue()
        s.close()
    else:
        source_code = ast.unparse(tree)  # type: ignore[attr-defined]
    return source_code


def add_last_line_print(code: str) -> str:
    """Wrap a trailing bare expression in ``print(...)`` if it is not printed.

    LLM-generated code sometimes ends with just ``variable_name`` (as one
    would type in a REPL) instead of ``print(variable_name)``. This function
    inspects the AST of the code and, when the last statement is a bare
    expression that is not already a call to ``print``, replaces it with a
    ``print`` call on that expression.

    Args:
        code: Python source code to inspect.

    Returns:
        Source code regenerated from the (possibly modified) AST. Code with
        no statements at all (empty or comment-only) is regenerated as is.

    Raises:
        SyntaxError: If ``code`` cannot be parsed by ``ast.parse``.
    """
    tree = ast.parse(code)

    # Empty or comment-only input has no statements; indexing ``body[-1]``
    # would raise IndexError, so there is nothing to wrap.
    if not tree.body:
        return _unparse(tree)

    last_statement = tree.body[-1]
    if isinstance(last_statement, ast.Expr):
        expression = last_statement.value
        is_print_call = (
            isinstance(expression, ast.Call)
            and isinstance(expression.func, ast.Name)
            and expression.func.id == "print"
        )
        if not is_print_call:
            # Replace the bare expression with `print(<expression>)`.
            tree.body[-1] = ast.Expr(
                value=ast.Call(
                    func=ast.Name(id="print", ctx=ast.Load()),
                    args=[expression],
                    keywords=[],
                )
            )

    return _unparse(tree)


class UploadedFile(BaseModel):
    """Description of the uploaded path with its remote path."""

    # File name; `E2BDataAnalysisTool.upload_file` sets this to the basename
    # of the uploaded file object's `name`.
    name: str
    # Path of the file inside the sandbox, as returned by the session's
    # `upload_file` call.
    remote_path: str
    # Free-text description of the file. An empty string means "no
    # description" when the tool renders its list of uploaded files.
    description: str


class E2BDataAnalysisToolArguments(BaseModel):
    """Arguments for the E2BDataAnalysisTool."""

    # The only argument of the tool. `...` as the default marks the field as
    # required; `example` and `description` are passed to `Field` as metadata.
    python_code: str = Field(
        ...,
        example="print('Hello World')",
        description=(
            "The python script to be evaluated. "
            "The contents will be in main.py. "
            "It should not be in markdown format."
        ),
    )


class E2BDataAnalysisTool(BaseTool):
    """Tool for running python code in a sandboxed environment for data analysis."""

    name = "e2b_data_analysis"
    args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
    # E2B `DataAnalysis` session object, created in `__init__`.
    session: Any
    # Tool description; starts as `base_description` and gets the list of
    # uploaded files appended while any files are uploaded.
    description: str
    # Files uploaded through `upload_file` that have not been removed yet.
    _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)

    def __init__(
        self,
        api_key: Optional[str] = None,
        cwd: Optional[str] = None,
        env_vars: Optional[EnvVars] = None,
        on_stdout: Optional[Callable[[str], Any]] = None,
        on_stderr: Optional[Callable[[str], Any]] = None,
        on_artifact: Optional[Callable[[Artifact], Any]] = None,
        on_exit: Optional[Callable[[int], Any]] = None,
        **kwargs: Any,
    ):
        """Create the tool and open an E2B ``DataAnalysis`` session.

        Args:
            api_key: E2B API key, forwarded to ``DataAnalysis``.
            cwd: Forwarded to ``DataAnalysis`` as ``cwd``.
            env_vars: Forwarded to ``DataAnalysis`` as ``env_vars``.
            on_stdout: Forwarded to ``DataAnalysis`` as ``on_stdout``.
            on_stderr: Forwarded to ``DataAnalysis`` as ``on_stderr``.
            on_artifact: Forwarded to ``DataAnalysis`` as ``on_artifact``.
            on_exit: Forwarded to ``DataAnalysis`` as ``on_exit``.
            **kwargs: Extra keyword arguments for the ``BaseTool`` constructor.

        Raises:
            ImportError: If the ``e2b`` package is not installed.
        """
        try:
            from e2b import DataAnalysis
        except ImportError as e:
            raise ImportError(
                "Unable to import e2b, please install with `pip install e2b`."
            ) from e

        # If no API key is provided, E2B will try to read it from the environment
        # variable E2B_API_KEY
        super().__init__(description=base_description, **kwargs)
        self.session = DataAnalysis(
            api_key=api_key,
            cwd=cwd,
            env_vars=env_vars,
            on_stdout=on_stdout,
            on_stderr=on_stderr,
            on_exit=on_exit,
            on_artifact=on_artifact,
        )

    def _sync_description(self, previous_listing: str) -> None:
        """Replace the file listing at the end of ``description``.

        Appending the listing on every change would leave older listings
        (including removed files) in the description, so the listing that was
        appended previously is stripped before the current one is added.

        Args:
            previous_listing: Value of ``uploaded_files_description`` before
                the list of uploaded files was modified.
        """
        prefix = self.description
        stale_suffix = "\n" + previous_listing
        if previous_listing and prefix.endswith(stale_suffix):
            prefix = prefix[: -len(stale_suffix)]

        listing = self.uploaded_files_description
        self.description = f"{prefix}\n{listing}" if listing else prefix

    def close(self) -> None:
        """Close the cloud sandbox."""
        previous_listing = self.uploaded_files_description
        self._uploaded_files = []
        # The tracked files are forgotten, so stop advertising them as well.
        self._sync_description(previous_listing)
        self.session.close()

    @property
    def uploaded_files_description(self) -> str:
        """Human-readable listing of the uploaded files.

        Returns:
            An empty string when no files are tracked; otherwise a header line
            followed by one entry per file with its remote path and, when it
            is not empty, its description.
        """
        if not self._uploaded_files:
            return ""
        lines = ["The following files available in the sandbox:"]

        for f in self._uploaded_files:
            if f.description == "":
                lines.append(f"- path: `{f.remote_path}`")
            else:
                lines.append(
                    f"- path: `{f.remote_path}` \n description: `{f.description}`"
                )
        return "\n".join(lines)

    def _run(
        self,
        python_code: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        callbacks: Optional[CallbackManager] = None,
    ) -> str:
        """Run python code in the sandbox session.

        The code is first passed through ``add_last_line_print`` so that a
        trailing bare expression gets printed.

        Args:
            python_code: The python script to run.
            run_manager: Not used by this implementation.
            callbacks: Optional callback manager; its ``metadata`` is checked
                for an ``on_artifact`` handler.

        Returns:
            A JSON string with the keys ``stdout``, ``stderr`` and
            ``artifacts`` (the names of the artifacts returned by the run).
        """
        python_code = add_last_line_print(python_code)

        # NOTE(review): this reads `on_artifact` as an *attribute* of
        # `callbacks.metadata`; if metadata is a plain dict this always yields
        # None - confirm whether a key lookup was intended.
        on_artifact = (
            getattr(callbacks.metadata, "on_artifact", None)
            if callbacks is not None
            else None
        )

        stdout, stderr, artifacts = self.session.run_python(
            python_code, on_artifact=on_artifact
        )

        out = {
            "stdout": stdout,
            "stderr": stderr,
            "artifacts": [artifact.name for artifact in artifacts],
        }
        return json.dumps(out)

    async def _arun(
        self,
        python_code: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported; always raises.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("e2b_data_analysis does not support async")

    def run_command(
        self,
        cmd: str,
    ) -> dict:
        """Run shell command in the sandbox.

        Args:
            cmd: The command to start in the session.

        Returns:
            A dict with the keys ``stdout``, ``stderr`` and ``exit_code``
            taken from the finished process output.
        """
        proc = self.session.process.start(cmd)
        # Block until the process has finished so its output is complete.
        output = proc.wait()
        return {
            "stdout": output.stdout,
            "stderr": output.stderr,
            "exit_code": output.exit_code,
        }

    def install_python_packages(self, package_names: str | List[str]) -> None:
        """Install python packages in the sandbox.

        Args:
            package_names: A package name or a list of package names,
                forwarded to the session unchanged.
        """
        self.session.install_python_packages(package_names)

    def install_system_packages(self, package_names: str | List[str]) -> None:
        """Install system packages (via apt) in the sandbox.

        Args:
            package_names: A package name or a list of package names,
                forwarded to the session unchanged.
        """
        self.session.install_system_packages(package_names)

    def download_file(self, remote_path: str) -> bytes:
        """Download file from the sandbox.

        Args:
            remote_path: Path of the file inside the sandbox.

        Returns:
            Whatever the session's ``download_file`` returns for that path.
        """
        return self.session.download_file(remote_path)

    def upload_file(self, file: IO, description: str) -> UploadedFile:
        """Upload file to the sandbox.

        The file is uploaded to the '/home/user/<filename>' path. The file is
        recorded in the list of uploaded files and the tool description is
        refreshed so that it lists the files currently tracked.

        Args:
            file: Open file object to upload; its ``name`` attribute is used
                to derive the stored file name.
            description: Description of the file; may be an empty string.

        Returns:
            The ``UploadedFile`` record created for the upload.
        """
        remote_path = self.session.upload_file(file)

        f = UploadedFile(
            name=os.path.basename(file.name),
            remote_path=remote_path,
            description=description,
        )
        previous_listing = self.uploaded_files_description
        self._uploaded_files.append(f)
        # Swap the old listing for the new one instead of appending another
        # copy, which would duplicate every earlier entry.
        self._sync_description(previous_listing)
        return f

    def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
        """Remove uploaded file from the sandbox.

        The file is deleted through the session's filesystem, dropped from
        the list of uploaded files (matched by remote path), and removed from
        the listing in the tool description.

        Args:
            uploaded_file: Record of the file to remove.
        """
        self.session.filesystem.remove(uploaded_file.remote_path)
        previous_listing = self.uploaded_files_description
        self._uploaded_files = [
            f
            for f in self._uploaded_files
            if f.remote_path != uploaded_file.remote_path
        ]
        # Rebuild the listing so the removed file is no longer advertised.
        self._sync_description(previous_listing)

    def as_tool(self) -> Tool:
        """Wrap this tool's ``_run`` method in a plain ``Tool``.

        The name, description and argument schema are read from this
        instance at the time of the call.

        Returns:
            The ``Tool`` built by ``Tool.from_function``.
        """
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )