File size: 3,146 Bytes
3d3d712
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
from typing import Optional

from injector import inject

from taskweaver.code_interpreter.code_executor import CodeExecutor
from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly
from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.role import Role


class CodeInterpreterConfig(ModuleConfig):
    """Settings for the plugin-only code interpreter role.

    All keys are looked up under the ``code_interpreter_plugin_only`` module name.
    """

    def _configure(self):
        self._set_name("code_interpreter_plugin_only")

        # Forwarded as ``use_local_uri`` to CodeExecutor.format_code_output
        # by CodeInterpreterPluginOnly.reply; defaults to False.
        local_uri = self._get_bool("use_local_uri", False)
        self.use_local_uri = local_uri

        # Retry budget setting; defaults to 3.
        retries = self._get_int("max_retry_count", 3)
        self.max_retry_count = retries


class CodeInterpreterPluginOnly(Role):
    """Role that answers by calling plugin functions instead of free-form code.

    ``CodeGeneratorPluginOnly`` selects functions, delivered as a JSON list in an
    ``AttachmentType.function`` attachment of its reply. This role renders the
    selected calls into a short Python snippet, runs it with ``CodeExecutor``,
    and stores the formatted execution output as the reply message.
    """

    @inject
    def __init__(
        self,
        generator: CodeGeneratorPluginOnly,
        executor: CodeExecutor,
        logger: TelemetryLogger,
        config: CodeInterpreterConfig,
    ):
        self.generator = generator
        self.executor = executor
        self.logger = logger
        self.config = config
        self.retry_count = 0
        # Next suffix for the ``r<N>`` result variables. It only increases, so
        # every call rendered by this instance gets a distinct result name.
        self.return_index = 0

        self.logger.info("CodeInterpreter initialized successfully.")

    def reply(
        self,
        memory: Memory,
        event_handler: callable,
        prompt_log_path: Optional[str] = None,
        use_back_up_engine: Optional[bool] = False,
    ) -> Post:
        """Ask the generator for function calls, execute them, and report the result.

        Args:
            memory: Conversation memory handed to the generator.
            event_handler: Callback invoked as ``event_handler(kind, text)``; it
                receives the generated code (``"code"``) and the final message
                (``"CodeInterpreter-> Planner"``).
            prompt_log_path: Accepted for interface compatibility; unused here.
            use_back_up_engine: Accepted for interface compatibility; unused here.

        Returns:
            The generator's post. If the generator already set a message, the
            post is returned unchanged. Otherwise its message becomes the
            formatted execution output, or a note that no function was selected.
        """
        response: Post = self.generator.reply(
            memory,
            event_handler,
        )

        # A message produced by the generator takes precedence over execution.
        if response.message is not None:
            return response

        functions = json.loads(response.get_attachment(type=AttachmentType.function)[0])
        if not functions:
            response.message = "No code is generated because no function is selected."
            event_handler("CodeInterpreter-> Planner", response.message)
            return response

        code = self._render_function_calls(functions)
        event_handler("code", code)
        exec_result = self.executor.execute_code(
            exec_id=response.id,
            code=code,
        )

        response.message = self.executor.format_code_output(
            exec_result,
            with_code=True,
            use_local_uri=self.config.use_local_uri,
        )
        event_handler("CodeInterpreter-> Planner", response.message)
        return response

    def _render_function_calls(self, functions: list) -> str:
        """Render each call as ``r<N>=name(...)``, then a line naming all results.

        Advances ``self.return_index`` by ``len(functions)``.
        """
        result_names = [f"r{self.return_index + i}" for i in range(len(functions))]
        lines = [
            f"{result_name}={self._render_call(function)}"
            for result_name, function in zip(result_names, functions)
        ]
        # Trailing bare expression listing every result, presumably so the
        # executor reports them as the snippet's output -- TODO confirm with CodeExecutor.
        lines.append(", ".join(result_names))
        self.return_index += len(functions)
        return "\n".join(lines)

    @classmethod
    def _render_call(cls, function: dict) -> str:
        """Render one ``{"name": ..., "arguments": ...}`` record as a call expression."""
        function_name = function["name"]
        # "arguments" is a JSON-encoded object of keyword arguments; an empty
        # or missing value means "no arguments" rather than a json.loads error.
        function_args = json.loads(function.get("arguments") or "{}")
        rendered_args = ", ".join(cls._render_kwarg(key, value) for key, value in function_args.items())
        return f"{function_name}({rendered_args})"

    @staticmethod
    def _render_kwarg(key: str, value) -> str:
        """Render a single ``key=value`` keyword argument as Python source."""
        if isinstance(value, str):
            # json.dumps escapes quotes, backslashes and control characters, and
            # the resulting JSON string literal is also a valid Python literal,
            # so arbitrary text cannot break out of the generated string.
            # ensure_ascii=False keeps non-ASCII text verbatim; with the default,
            # astral characters would become surrogate-pair escapes, which
            # Python reads back as two lone surrogates.
            return f"{key}={json.dumps(value, ensure_ascii=False)}"
        return f"{key}={value}"