File size: 8,466 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import os
from typing import Literal, Optional

from injector import inject

from taskweaver.code_interpreter.code_executor import CodeExecutor
from taskweaver.code_interpreter.code_generator import CodeGenerator, format_code_revision_message
from taskweaver.code_interpreter.code_generator.code_generator import format_output_revision_message
from taskweaver.code_interpreter.code_verification import code_snippet_verification, format_code_correction_message
from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Attachment, Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.role import Role


class CodeInterpreterConfig(ModuleConfig):
    """Settings of the ``code_interpreter`` module."""

    def _configure(self):
        """Register the module name and load each setting with its default."""
        # Modules that generated code may import when no override is configured.
        default_allowed_modules = [
            "pandas",
            "matplotlib",
            "numpy",
            "sklearn",
            "scipy",
            "seaborn",
            "datetime",
            "typing",
        ]

        self._set_name("code_interpreter")

        # Execution-related settings.
        self.use_local_uri = self._get_bool("use_local_uri", False)
        self.max_retry_count = self._get_int("max_retry_count", 3)

        # Verification-related settings.
        self.code_verification_on = self._get_bool("code_verification_on", False)
        self.allowed_modules = self._get_list("allowed_modules", default_allowed_modules)


def update_verification(
    response: Post,
    status: Literal["NONE", "INCORRECT", "CORRECT"] = "NONE",
    error: str = "No verification is done.",
):
    """Attach the code verification outcome to ``response``.

    Two attachments are added, in this order: the verification ``status`` and
    the accompanying ``error`` text.

    Args:
        response: The post that receives the attachments.
        status: Verification status to record.
        error: Error description (or explanatory note) to record.
    """
    verification_entries = (
        (AttachmentType.verification, status),
        (AttachmentType.code_error, error),
    )
    for attachment_type, content in verification_entries:
        response.add_attachment(Attachment.create(attachment_type, content))


def update_execution(
    response: Post,
    status: Literal["NONE", "SUCCESS", "FAILURE"] = "NONE",
    result: str = "No code is executed.",
):
    """Attach the code execution outcome to ``response``.

    Two attachments are added, in this order: the execution ``status`` and the
    execution ``result`` text.

    Args:
        response: The post that receives the attachments.
        status: Execution status to record.
        result: Execution output (or explanatory note) to record.
    """
    execution_entries = (
        (AttachmentType.execution_status, status),
        (AttachmentType.execution_result, result),
    )
    for attachment_type, content in execution_entries:
        response.add_attachment(Attachment.create(attachment_type, content))


class CodeInterpreter(Role):
    """Role that asks the generator for code, verifies it, and executes it.

    The instance keeps a ``retry_count`` of consecutive self-revision attempts.
    Whenever a reply is handed back to the Planner the counter is reset, so
    each new request starts with the full ``max_retry_count`` budget.
    """

    @inject
    def __init__(
        self,
        generator: CodeGenerator,
        executor: CodeExecutor,
        logger: TelemetryLogger,
        config: CodeInterpreterConfig,
    ):
        """Store the collaborators and pass verification settings to the generator.

        Args:
            generator: Produces the reply post (a message or a code attachment).
            executor: Executes the generated code and formats its output.
            logger: Logger used for informational messages.
            config: Module configuration (retry limit, verification settings,
                local-URI flag).
        """
        self.config = config

        self.generator = generator
        self.generator.configure_verification(
            code_verification_on=self.config.code_verification_on,
            allowed_modules=self.config.allowed_modules,
        )

        self.executor = executor
        self.logger = logger
        # Number of self-revision rounds used so far for the current request.
        self.retry_count = 0

        self.logger.info("CodeInterpreter initialized successfully.")

    def _hand_back_to_planner(self, response: Post, event_handler: callable) -> None:
        """Reset the retry counter and announce ``response.message`` to the Planner."""
        self.retry_count = 0
        event_handler("CodeInterpreter->Planner", response.message)

    def _retry_or_hand_back(
        self,
        response: Post,
        event_handler: callable,
        make_revise_message: callable,
    ) -> None:
        """Request a self-revision if retries remain, otherwise hand back to the Planner.

        Args:
            response: The post being prepared; it is mutated in place.
            event_handler: Callback invoked with an event type and a message.
            make_revise_message: Zero-argument callable returning the revision
                prompt. It is only called when a retry is actually requested.
                NOTE(review): it is called once and the result reused for both
                the attachment and the event; this assumes the formatter is
                deterministic.
        """
        if self.retry_count < self.config.max_retry_count:
            revise_message = make_revise_message()
            response.add_attachment(
                Attachment.create(
                    AttachmentType.revise_message,
                    revise_message,
                ),
            )
            # Address the post to this role so the next reply is a revision.
            response.send_to = "CodeInterpreter"
            event_handler(
                "CodeInterpreter->CodeInterpreter",
                revise_message,
            )
            self.retry_count += 1
        else:
            self._hand_back_to_planner(response, event_handler)

    def reply(
        self,
        memory: Memory,
        event_handler: callable,
        prompt_log_path: Optional[str] = None,
        use_back_up_engine: Optional[bool] = False,
    ) -> Post:
        """Generate a reply, then verify and execute any generated code.

        The flow is:

        1. Ask the generator for a post. If it already carries a message, the
           post is returned as is (with "NONE" verification/execution status).
        2. If no python attachment is present, either request a revision or,
           when retries are exhausted, report the failure to the Planner.
        3. Verify the code. On verification errors, either request a revision
           or report the errors to the Planner; the code is not executed.
        4. Execute the code, attach status, result and artifact paths, and
           either return the outcome to the Planner or request a revision when
           execution failed and retries remain.

        Args:
            memory: Conversation memory forwarded to the generator.
            event_handler: Callback invoked as ``event_handler(type, message)``
                for every notable step.
            prompt_log_path: Optional path forwarded to the generator.
            use_back_up_engine: Flag forwarded to the generator.

        Returns:
            The post produced by the generator, enriched with verification and
            execution attachments.
        """
        response: Post = self.generator.reply(
            memory,
            event_handler,
            prompt_log_path,
            use_back_up_engine,
        )
        if response.message is not None:
            update_verification(response, "NONE", "No code verification is performed.")
            update_execution(response, "NONE", "No code is executed.")
            # The reply goes back to the Planner, so the retry budget must be
            # reset here as well; otherwise retries consumed earlier would leak
            # into the next request.
            self._hand_back_to_planner(response, event_handler)
            return response

        code = next((a for a in response.attachment_list if a.type == AttachmentType.python), None)

        if code is None:
            # no code is generated is usually due to the failure of parsing the llm output
            update_verification(response, "NONE", "No code verification is performed.")
            update_execution(response, "NONE", "No code is executed due to code generation failure.")
            response.message = "Failed to generate code."
            self._retry_or_hand_back(response, event_handler, format_output_revision_message)
            return response

        self.logger.info(f"Code to be verified: {code.content}")
        code_verify_errors = code_snippet_verification(
            code.content,
            [plugin.name for plugin in self.generator.get_plugin_pool()],
            self.config.code_verification_on,
            plugin_only=False,
            allowed_modules=self.config.allowed_modules,
        )

        if code_verify_errors is None:
            # None means verification did not run at all.
            event_handler("verification", "NONE")
            update_verification(response, "NONE", "No code verification is performed.")
        elif len(code_verify_errors) > 0:
            self.logger.info(
                f"Code verification finished with {len(code_verify_errors)} errors.",
            )
            code_error = "\n".join(code_verify_errors)
            event_handler("verification", f"INCORRECT: {code_error}")
            update_verification(response, "INCORRECT", code_error)
            response.message = code_error
            self._retry_or_hand_back(response, event_handler, format_code_correction_message)

            # add execution status and result
            update_execution(response, "NONE", "No code is executed due to code verification failure.")
            return response
        else:
            event_handler("verification", "CORRECT")
            update_verification(response, "CORRECT", "No error is found.")

        self.logger.info(f"Code to be executed: {code.content}")

        exec_result = self.executor.execute_code(
            exec_id=response.id,
            code=code.content,
        )
        execution_status = "SUCCESS" if exec_result.is_success else "FAILURE"
        event_handler("status", execution_status)
        code_output = self.executor.format_code_output(
            exec_result,
            with_code=False,
            use_local_uri=self.config.use_local_uri,
        )

        event_handler("result", code_output)
        update_execution(
            response,
            status=execution_status,
            result=code_output,
        )

        # add artifact paths
        # Relative file names are joined with the executor's working directory
        # only when local URIs are enabled; otherwise they are kept unchanged.
        response.add_attachment(
            Attachment.create(
                AttachmentType.artifact_paths,
                [
                    (
                        a.file_name
                        if os.path.isabs(a.file_name) or not self.config.use_local_uri
                        else os.path.join(self.executor.execution_cwd, a.file_name)
                    )
                    for a in exec_result.artifact
                ],
            ),
        )

        response.message = self.executor.format_code_output(
            exec_result,
            with_code=True,  # the message to be sent to the user should contain the code
            use_local_uri=self.config.use_local_uri,
        )

        if exec_result.is_success:
            self._hand_back_to_planner(response, event_handler)
        else:
            # A failed execution is retried while budget remains; once it is
            # exhausted the failure output is handed back to the Planner.
            self._retry_or_hand_back(response, event_handler, format_code_revision_message)
        return response