# Source: tskwvr/taskweaver/code_interpreter/code_interpreter.py
# Uploaded by TRaw ("Upload 297 files"), commit 3d3d712
import os
from typing import Callable, Literal, Optional

from injector import inject

from taskweaver.code_interpreter.code_executor import CodeExecutor
from taskweaver.code_interpreter.code_generator import CodeGenerator, format_code_revision_message
from taskweaver.code_interpreter.code_generator.code_generator import format_output_revision_message
from taskweaver.code_interpreter.code_verification import code_snippet_verification, format_code_correction_message
from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Attachment, Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.role import Role
class CodeInterpreterConfig(ModuleConfig):
    """Settings for the ``code_interpreter`` module."""

    def _configure(self):
        """Register the module name and read every code-interpreter setting.

        Each value is read through the ``ModuleConfig`` ``_get_*`` helpers with the
        default shown below.
        """
        self._set_name("code_interpreter")

        # General behaviour: whether artifact/output paths use the executor's local working
        # directory, and how many consecutive self-correction rounds CodeInterpreter may run.
        self.use_local_uri = self._get_bool("use_local_uri", False)
        self.max_retry_count = self._get_int("max_retry_count", 3)

        # for verification
        verification_default = False
        default_allowed_modules = [
            "pandas",
            "matplotlib",
            "numpy",
            "sklearn",
            "scipy",
            "seaborn",
            "datetime",
            "typing",
        ]
        self.code_verification_on = self._get_bool("code_verification_on", verification_default)
        self.allowed_modules = self._get_list("allowed_modules", default_allowed_modules)
def update_verification(
    response: Post,
    status: Literal["NONE", "INCORRECT", "CORRECT"] = "NONE",
    error: str = "No verification is done.",
):
    """Record the code-verification outcome on ``response``.

    Appends two attachments, in this order: the verification ``status`` and the
    ``error`` text (a ``code_error`` attachment) explaining it.
    """
    outcome = (
        (AttachmentType.verification, status),
        (AttachmentType.code_error, error),
    )
    for attachment_type, content in outcome:
        response.add_attachment(Attachment.create(attachment_type, content))
def update_execution(
    response: Post,
    status: Literal["NONE", "SUCCESS", "FAILURE"] = "NONE",
    result: str = "No code is executed.",
):
    """Record the code-execution outcome on ``response``.

    Appends an ``execution_status`` attachment holding ``status``, followed by an
    ``execution_result`` attachment holding ``result``.
    """
    status_attachment = Attachment.create(AttachmentType.execution_status, status)
    response.add_attachment(status_attachment)
    result_attachment = Attachment.create(AttachmentType.execution_result, result)
    response.add_attachment(result_attachment)
class CodeInterpreter(Role):
    """Role that asks a code generator for Python code, optionally verifies it, and runs it.

    Every call to :meth:`reply` returns one ``Post``. On a recoverable failure (no code
    parsed, verification errors, or failed execution) the post is routed back to this role
    (``send_to = "CodeInterpreter"``) with a ``revise_message`` attachment, asking the
    generator to correct itself. The failure kinds share one counter, ``self.retry_count``.
    After ``config.max_retry_count`` consecutive retries, or on success, the result is
    reported to the Planner through the ``"CodeInterpreter->Planner"`` event, and the
    counter is reset.
    """

    @inject
    def __init__(
        self,
        generator: CodeGenerator,
        executor: CodeExecutor,
        logger: TelemetryLogger,
        config: CodeInterpreterConfig,
    ):
        self.config = config
        self.generator = generator
        self.generator.configure_verification(
            code_verification_on=self.config.code_verification_on,
            allowed_modules=self.config.allowed_modules,
        )
        self.executor = executor
        self.logger = logger
        # Consecutive self-correction rounds used so far; reset whenever a result is
        # reported to the Planner.
        self.retry_count = 0
        self.logger.info("CodeInterpreter initialized successfully.")

    def reply(
        self,
        memory: Memory,
        event_handler: Callable[[str, str], None],
        prompt_log_path: Optional[str] = None,
        use_back_up_engine: Optional[bool] = False,
    ) -> Post:
        """Generate, verify and execute code for the current round.

        :param memory: Conversation memory, forwarded to the code generator.
        :param event_handler: Callback invoked as ``event_handler(event_type, message)`` to
            report verification/execution progress and routing decisions. It is also
            forwarded to the generator.
        :param prompt_log_path: Optional prompt-log path, forwarded to the generator.
        :param use_back_up_engine: Forwarded to the generator.
        :return: The generator's post, with verification and execution attachments added
            and ``message`` set. When a retry is requested, the post also carries a
            ``revise_message`` attachment and ``send_to == "CodeInterpreter"``.
        """
        response: Post = self.generator.reply(
            memory,
            event_handler,
            prompt_log_path,
            use_back_up_engine,
        )

        if response.message is not None:
            # The generator already produced a message, so there is nothing to verify or run.
            update_verification(response, "NONE", "No code verification is performed.")
            update_execution(response, "NONE", "No code is executed.")
            # Reset the retry counter like every other path that reports to the Planner;
            # otherwise retries used in an earlier round would carry over to the next one.
            self._report_to_planner(response, event_handler)
            return response

        code = next((a for a in response.attachment_list if a.type == AttachmentType.python), None)

        if code is None:
            # no code is generated is usually due to the failure of parsing the llm output
            update_verification(response, "NONE", "No code verification is performed.")
            update_execution(response, "NONE", "No code is executed due to code generation failure.")
            response.message = "Failed to generate code."
            self._retry_or_report(response, event_handler, format_output_revision_message)
            return response

        self.logger.info(f"Code to be verified: {code.content}")
        code_verify_errors = code_snippet_verification(
            code.content,
            [plugin.name for plugin in self.generator.get_plugin_pool()],
            self.config.code_verification_on,
            plugin_only=False,
            allowed_modules=self.config.allowed_modules,
        )

        if code_verify_errors is None:
            # None means verification was not performed at all.
            event_handler("verification", "NONE")
            update_verification(response, "NONE", "No code verification is performed.")
        elif len(code_verify_errors) > 0:
            self.logger.info(
                f"Code verification finished with {len(code_verify_errors)} errors.",
            )
            code_error = "\n".join(code_verify_errors)
            event_handler("verification", f"INCORRECT: {code_error}")
            update_verification(response, "INCORRECT", code_error)
            response.message = code_error
            self._retry_or_report(response, event_handler, format_code_correction_message)
            # add execution status and result
            update_execution(response, "NONE", "No code is executed due to code verification failure.")
            return response
        else:
            event_handler("verification", "CORRECT")
            update_verification(response, "CORRECT", "No error is found.")

        self.logger.info(f"Code to be executed: {code.content}")
        exec_result = self.executor.execute_code(
            exec_id=response.id,
            code=code.content,
        )
        execution_status = "SUCCESS" if exec_result.is_success else "FAILURE"
        event_handler("status", execution_status)

        code_output = self.executor.format_code_output(
            exec_result,
            with_code=False,
            use_local_uri=self.config.use_local_uri,
        )
        event_handler("result", code_output)
        update_execution(
            response,
            status=execution_status,
            result=code_output,
        )

        # add artifact paths
        response.add_attachment(
            Attachment.create(
                AttachmentType.artifact_paths,
                [self._resolve_artifact_path(a.file_name) for a in exec_result.artifact],
            ),
        )

        response.message = self.executor.format_code_output(
            exec_result,
            with_code=True,  # the message to be sent to the user should contain the code
            use_local_uri=self.config.use_local_uri,
        )

        if exec_result.is_success:
            self._report_to_planner(response, event_handler)
        else:
            self._retry_or_report(response, event_handler, format_code_revision_message)
        return response

    def _resolve_artifact_path(self, file_name: str) -> str:
        """Return the path under which an execution artifact is reported.

        Absolute paths are returned unchanged. Relative paths are joined onto the
        executor's ``execution_cwd`` only when ``use_local_uri`` is enabled.
        """
        if os.path.isabs(file_name) or not self.config.use_local_uri:
            return file_name
        return os.path.join(self.executor.execution_cwd, file_name)

    def _report_to_planner(self, response: Post, event_handler: Callable[[str, str], None]) -> None:
        """End the self-correction loop and emit ``response.message`` to the Planner."""
        self.retry_count = 0
        event_handler("CodeInterpreter->Planner", response.message)

    def _retry_or_report(
        self,
        response: Post,
        event_handler: Callable[[str, str], None],
        make_revise_message: Callable[[], str],
    ) -> None:
        """Request another generation attempt, or report to the Planner once retries run out.

        While fewer than ``max_retry_count`` retries have been used, the post is routed back
        to this role with a ``revise_message`` attachment built by ``make_revise_message``.
        The message is built once and only when a retry actually happens. Otherwise the
        result is reported to the Planner and the counter is reset.
        """
        if self.retry_count >= self.config.max_retry_count:
            self._report_to_planner(response, event_handler)
            return
        revise_message = make_revise_message()
        response.add_attachment(
            Attachment.create(
                AttachmentType.revise_message,
                revise_message,
            ),
        )
        response.send_to = "CodeInterpreter"
        event_handler(
            "CodeInterpreter->CodeInterpreter",
            revise_message,
        )
        self.retry_count += 1