# File size: 3,146 bytes (commit 3d3d712)
import json
from typing import Callable, Optional

from injector import inject

from taskweaver.code_interpreter.code_executor import CodeExecutor
from taskweaver.code_interpreter.code_generator import CodeGeneratorPluginOnly
from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.role import Role
class CodeInterpreterConfig(ModuleConfig):
    """Settings for the plugin-only code interpreter role."""

    def _configure(self) -> None:
        # Every key below is read under this module name.
        self._set_name("code_interpreter_plugin_only")

        # Forwarded to the executor's output formatter as `use_local_uri`.
        self.use_local_uri = self._get_bool("use_local_uri", False)

        # Retry budget; not consumed by the role class in this file.
        self.max_retry_count = self._get_int("max_retry_count", 3)
class CodeInterpreterPluginOnly(Role):
    """Role that answers by invoking plugins selected by the generator.

    The generator either replies with a plain message (returned unchanged) or
    attaches a JSON-encoded list of function calls. Each call is rendered as a
    line of Python assigning the plugin's result to ``r<N>``. The lines are run
    through the executor, and the formatted execution output becomes the reply
    message.
    """

    @inject
    def __init__(
        self,
        generator: CodeGeneratorPluginOnly,
        executor: CodeExecutor,
        logger: TelemetryLogger,
        config: CodeInterpreterConfig,
    ):
        self.generator = generator
        self.executor = executor
        self.logger = logger
        self.config = config
        # Not read anywhere in this class; kept for external readers.
        self.retry_count = 0
        # Next free suffix for result variables (r0, r1, ...). It keeps growing
        # across replies so earlier results keep distinct names; presumably the
        # executor session persists between calls — confirm.
        self.return_index = 0
        self.logger.info("CodeInterpreter initialized successfully.")

    @staticmethod
    def _selected_functions(response: Post) -> list:
        """Return the decoded list of function calls attached to ``response``.

        A missing function attachment is treated as "no function selected"
        rather than raising ``IndexError``. Assumes the attachment content is a
        JSON string encoding a list of ``{"name", "arguments"}`` dicts.
        """
        attachments = response.get_attachment(type=AttachmentType.function)
        if not attachments:
            return []
        return json.loads(attachments[0])

    @staticmethod
    def _parse_arguments(raw_arguments) -> dict:
        """Decode a call's ``arguments`` field into a keyword-argument dict.

        Accepts a JSON object string, an already-decoded dict, or a missing /
        blank value, which means the function takes no arguments.
        """
        if isinstance(raw_arguments, dict):
            return raw_arguments
        if raw_arguments is None or (isinstance(raw_arguments, str) and not raw_arguments.strip()):
            return {}
        return json.loads(raw_arguments)

    @staticmethod
    def _format_call(var_name: str, function_name: str, function_args: dict) -> str:
        """Render ``<var_name>=<function_name>(key=value, ...)`` as Python source.

        Values are rendered with ``repr`` so that strings containing quotes,
        backslashes or newlines still form a valid literal. For the other
        JSON-decoded types (int, float, bool, None, list, dict), ``repr``
        matches ``str``.
        """
        # NOTE(review): function_name and the keys come straight from the
        # model output and are not validated as identifiers before being
        # spliced into executable code.
        rendered_args = ", ".join(f"{key}={value!r}" for key, value in function_args.items())
        return f"{var_name}={function_name}({rendered_args})"

    def reply(
        self,
        memory: Memory,
        event_handler: Callable[..., None],
        prompt_log_path: Optional[str] = None,
        use_back_up_engine: Optional[bool] = False,
    ) -> Post:
        """Generate plugin calls for the conversation in ``memory`` and run them.

        Args:
            memory: Conversation memory handed to the generator.
            event_handler: Passed to the generator. This method also calls it
                with ``("code", <generated code>)`` and
                ``("CodeInterpreter-> Planner", <reply message>)``.
            prompt_log_path: Accepted for interface compatibility; unused here.
            use_back_up_engine: Accepted for interface compatibility; unused here.

        Returns:
            The generator's post. If the generator already set a message, the
            post is returned unchanged. Otherwise ``message`` is set to the
            formatted execution output, or to a fixed notice when no function
            was selected.
        """
        response: Post = self.generator.reply(memory, event_handler)
        # The generator answered directly; there is nothing to execute.
        if response.message is not None:
            return response

        functions = self._selected_functions(response)
        if not functions:
            response.message = "No code is generated because no function is selected."
            event_handler("CodeInterpreter-> Planner", response.message)
            return response

        result_names = [f"r{self.return_index + offset}" for offset in range(len(functions))]
        code = [
            self._format_call(name, call["name"], self._parse_arguments(call.get("arguments")))
            for name, call in zip(result_names, functions)
        ]
        # Trailing bare expression listing all results; presumably the executor
        # echoes the value of a final expression (notebook-style) — confirm.
        code.append(", ".join(result_names))
        self.return_index += len(functions)

        code_text = "\n".join(code)
        event_handler("code", code_text)
        exec_result = self.executor.execute_code(
            exec_id=response.id,
            code=code_text,
        )
        response.message = self.executor.format_code_output(
            exec_result,
            with_code=True,
            use_local_uri=self.config.use_local_uri,
        )
        event_handler("CodeInterpreter-> Planner", response.message)
        return response
|