|
|
import os
from typing import Callable, Literal, Optional

from injector import inject

from taskweaver.code_interpreter.code_executor import CodeExecutor
from taskweaver.code_interpreter.code_generator import CodeGenerator, format_code_revision_message
from taskweaver.code_interpreter.code_generator.code_generator import format_output_revision_message
from taskweaver.code_interpreter.code_verification import code_snippet_verification, format_code_correction_message
from taskweaver.config.module_config import ModuleConfig
from taskweaver.logging import TelemetryLogger
from taskweaver.memory import Attachment, Memory, Post
from taskweaver.memory.attachment import AttachmentType
from taskweaver.role import Role
|
|
|
|
|
|
|
|
class CodeInterpreterConfig(ModuleConfig):
    """Configuration for the code interpreter role, registered under the name ``code_interpreter``."""

    def _configure(self):
        self._set_name("code_interpreter")

        # Passed to the executor's output formatting; when true, relative artifact
        # paths are also joined with the executor's working directory.
        self.use_local_uri = self._get_bool("use_local_uri", False)
        # How many consecutive self-revision rounds are allowed before reporting to the Planner.
        self.max_retry_count = self._get_int("max_retry_count", 3)

        # Verification settings forwarded to the generator and to code_snippet_verification.
        self.code_verification_on = self._get_bool("code_verification_on", False)
        default_allowed_modules = [
            "pandas",
            "matplotlib",
            "numpy",
            "sklearn",
            "scipy",
            "seaborn",
            "datetime",
            "typing",
        ]
        self.allowed_modules = self._get_list("allowed_modules", default_allowed_modules)
|
|
|
|
|
|
|
|
def update_verification(
    response: Post,
    status: Literal["NONE", "INCORRECT", "CORRECT"] = "NONE",
    error: str = "No verification is done.",
):
    """Attach the code-verification outcome to ``response``.

    Two attachments are appended, in this order: the verification ``status`` and the
    accompanying ``error`` text (a placeholder message when nothing was checked).
    """
    for attachment_type, content in (
        (AttachmentType.verification, status),
        (AttachmentType.code_error, error),
    ):
        response.add_attachment(Attachment.create(attachment_type, content))
|
|
|
|
|
|
|
|
def update_execution(
    response: Post,
    status: Literal["NONE", "SUCCESS", "FAILURE"] = "NONE",
    result: str = "No code is executed.",
):
    """Attach the code-execution outcome to ``response``.

    Appends an execution-status attachment followed by an execution-result attachment
    carrying ``result`` (the formatted output, or a placeholder when nothing ran).
    """
    status_attachment = Attachment.create(AttachmentType.execution_status, status)
    response.add_attachment(status_attachment)

    result_attachment = Attachment.create(AttachmentType.execution_result, result)
    response.add_attachment(result_attachment)
|
|
|
|
|
|
|
|
class CodeInterpreter(Role):
    """Role that asks a code generator for code, verifies it, executes it and reports back.

    Failures (no code generated, verification errors, execution errors) route the reply
    back to this role with a revision request, at most ``config.max_retry_count``
    consecutive times. Otherwise the reply goes to the Planner and the retry counter is
    reset.
    """

    @inject
    def __init__(
        self,
        generator: CodeGenerator,
        executor: CodeExecutor,
        logger: TelemetryLogger,
        config: CodeInterpreterConfig,
    ):
        self.config = config

        self.generator = generator
        self.generator.configure_verification(
            code_verification_on=self.config.code_verification_on,
            allowed_modules=self.config.allowed_modules,
        )

        self.executor = executor
        self.logger = logger
        # Consecutive self-revision rounds; reset whenever control returns to the Planner.
        self.retry_count = 0

        self.logger.info("CodeInterpreter initialized successfully.")

    def reply(
        self,
        memory: Memory,
        event_handler: Callable[[str, str], None],
        prompt_log_path: Optional[str] = None,
        use_back_up_engine: Optional[bool] = False,
    ) -> Post:
        """Generate code for the current round, verify and execute it, and build the reply.

        Args:
            memory: Conversation memory, forwarded to ``CodeGenerator.reply``.
            event_handler: Callback invoked as ``event_handler(event_type, message)`` for
                progress events (``"verification"``, ``"status"``, ``"result"``) and
                routing events (``"CodeInterpreter->Planner"``,
                ``"CodeInterpreter->CodeInterpreter"``); also forwarded to the generator.
            prompt_log_path: Forwarded to ``CodeGenerator.reply``.
            use_back_up_engine: Forwarded to ``CodeGenerator.reply``.

        Returns:
            The generator's ``Post`` with verification and execution attachments added
            (plus artifact paths when code ran). ``send_to`` is set to
            ``"CodeInterpreter"`` when a self-revision is requested.
        """
        response: Post = self.generator.reply(
            memory,
            event_handler,
            prompt_log_path,
            use_back_up_engine,
        )

        if response.message is not None:
            # The generator answered in text instead of code: nothing to verify or run.
            update_verification(response, "NONE", "No code verification is performed.")
            update_execution(response, "NONE", "No code is executed.")
            # Like every other Planner-bound path, reset the retry budget so a stale
            # count from earlier failed rounds does not leak into the next request.
            self._report_to_planner(response, event_handler)
            return response

        code = next((a for a in response.attachment_list if a.type == AttachmentType.python), None)

        if code is None:
            update_verification(response, "NONE", "No code verification is performed.")
            update_execution(response, "NONE", "No code is executed due to code generation failure.")
            response.message = "Failed to generate code."
            self._retry_or_report(response, format_output_revision_message, event_handler)
            return response

        self.logger.info(f"Code to be verified: {code.content}")
        code_verify_errors = code_snippet_verification(
            code.content,
            [plugin.name for plugin in self.generator.get_plugin_pool()],
            self.config.code_verification_on,
            plugin_only=False,
            allowed_modules=self.config.allowed_modules,
        )

        if code_verify_errors is None:
            # Verification is switched off; proceed straight to execution.
            event_handler("verification", "NONE")
            update_verification(response, "NONE", "No code verification is performed.")
        elif len(code_verify_errors) > 0:
            self.logger.info(
                f"Code verification finished with {len(code_verify_errors)} errors.",
            )
            code_error = "\n".join(code_verify_errors)
            event_handler("verification", f"INCORRECT: {code_error}")
            update_verification(response, "INCORRECT", code_error)
            response.message = code_error
            self._retry_or_report(response, format_code_correction_message, event_handler)
            update_execution(response, "NONE", "No code is executed due to code verification failure.")
            return response
        else:
            event_handler("verification", "CORRECT")
            update_verification(response, "CORRECT", "No error is found.")

        self.logger.info(f"Code to be executed: {code.content}")
        exec_result = self.executor.execute_code(
            exec_id=response.id,
            code=code.content,
        )
        exec_status = "SUCCESS" if exec_result.is_success else "FAILURE"
        event_handler("status", exec_status)

        code_output = self.executor.format_code_output(
            exec_result,
            with_code=False,
            use_local_uri=self.config.use_local_uri,
        )
        event_handler("result", code_output)
        update_execution(response, status=exec_status, result=code_output)

        # Absolute artifact paths are kept as-is; relative ones are resolved against the
        # executor's working directory only when local URIs are requested.
        response.add_attachment(
            Attachment.create(
                AttachmentType.artifact_paths,
                [
                    (
                        a.file_name
                        if os.path.isabs(a.file_name) or not self.config.use_local_uri
                        else os.path.join(self.executor.execution_cwd, a.file_name)
                    )
                    for a in exec_result.artifact
                ],
            ),
        )

        response.message = self.executor.format_code_output(
            exec_result,
            with_code=True,
            use_local_uri=self.config.use_local_uri,
        )

        if exec_result.is_success:
            self._report_to_planner(response, event_handler)
        else:
            self._retry_or_report(response, format_code_revision_message, event_handler)
        return response

    def _report_to_planner(
        self,
        response: Post,
        event_handler: Callable[[str, str], None],
    ) -> None:
        """Hand ``response`` back to the Planner and reset the retry counter."""
        self.retry_count = 0
        event_handler("CodeInterpreter->Planner", response.message)

    def _retry_or_report(
        self,
        response: Post,
        revise_message_factory: Callable[[], str],
        event_handler: Callable[[str, str], None],
    ) -> None:
        """Request a self-revision if retries remain; otherwise report to the Planner.

        ``revise_message_factory`` is only called when a revision is actually requested,
        and only once, so the attachment and the emitted event carry the same text.
        """
        if self.retry_count < self.config.max_retry_count:
            revise_message = revise_message_factory()
            response.add_attachment(
                Attachment.create(
                    AttachmentType.revise_message,
                    revise_message,
                ),
            )
            response.send_to = "CodeInterpreter"
            event_handler("CodeInterpreter->CodeInterpreter", revise_message)
            self.retry_count += 1
        else:
            self._report_to_planner(response, event_handler)
|
|
|