|
|
import os |
|
|
from typing import List, Optional |
|
|
|
|
|
from injector import inject |
|
|
|
|
|
from taskweaver.code_interpreter.code_generator.plugin_selection import PluginSelector, SelectedPluginPool |
|
|
from taskweaver.config.module_config import ModuleConfig |
|
|
from taskweaver.llm import LLMApi |
|
|
from taskweaver.llm.util import ChatMessageType, format_chat_message |
|
|
from taskweaver.logging import TelemetryLogger |
|
|
from taskweaver.memory import Attachment, Conversation, Memory, Post, Round, RoundCompressor |
|
|
from taskweaver.memory.attachment import AttachmentType |
|
|
from taskweaver.memory.plugin import PluginEntry, PluginRegistry |
|
|
from taskweaver.misc.example import load_examples |
|
|
from taskweaver.role import PostTranslator, Role |
|
|
from taskweaver.utils import read_yaml |
|
|
|
|
|
|
|
|
class CodeGeneratorConfig(ModuleConfig):
    """Settings consumed by the code generator role.

    Every value is read through the ``_get_*`` helpers inherited from
    ``ModuleConfig`` with the default given below.
    """

    def _configure(self) -> None:
        """Name the module ``code_generator`` and populate its settings."""
        # Both bundled prompt files default to living next to this source file.
        module_dir = os.path.dirname(os.path.abspath(__file__))
        default_prompt_file = os.path.join(module_dir, "code_generator_prompt.yaml")
        default_compression_prompt = os.path.join(module_dir, "compression_prompt.yaml")

        self._set_name("code_generator")

        # Role identity and feature toggles.
        self.role_name = self._get_str("role_name", "ProgramApe")
        self.load_plugin = self._get_bool("load_plugin", True)
        self.load_example = self._get_bool("load_example", True)

        # Prompt template and example locations.
        self.prompt_file_path = self._get_path("prompt_file_path", default_prompt_file)
        self.example_base_path = self._get_path(
            "example_base_path",
            os.path.join(self.src.app_base_path, "codeinterpreter_examples"),
        )

        # Optional compression of earlier rounds into a summary.
        self.prompt_compression = self._get_bool("prompt_compression", False)
        self.compression_prompt_path = self._get_path(
            "compression_prompt_path",
            default_compression_prompt,
        )

        # Optional automatic plugin selection and how many plugins to pick per query.
        self.enable_auto_plugin_selection = self._get_bool("enable_auto_plugin_selection", False)
        self.auto_plugin_selection_topk = self._get_int("auto_plugin_selection_topk", 3)
|
|
|
|
|
|
|
|
class CodeGenerator(Role):
    """Role that assembles the code-generation prompt and converts the LLM reply into a ``Post``.

    The prompt is built from the YAML prompt file named in the config, optional
    example conversations, an optional summary of compressed rounds and the
    current conversation rounds.
    """

    @inject
    def __init__(
        self,
        config: CodeGeneratorConfig,
        plugin_registry: PluginRegistry,
        logger: TelemetryLogger,
        llm_api: LLMApi,
        round_compressor: RoundCompressor,
    ):
        """Load the prompt templates and, if enabled, prepare automatic plugin selection.

        Args:
            config: Settings for this role (paths, feature toggles, role name).
            plugin_registry: Source of the initial plugin pool (``get_list()``).
            logger: Logger used for info messages and prompt dumps.
            llm_api: LLM client used for chat completion and by the plugin selector.
            round_compressor: Used to compress rounds when ``prompt_compression`` is on.
        """
        self.config = config
        self.logger = logger
        self.llm_api = llm_api

        self.role_name = self.config.role_name

        self.post_translator = PostTranslator(logger)
        self.prompt_data = read_yaml(self.config.prompt_file_path)

        self.instruction_template = self.prompt_data["content"]

        self.conversation_head_template = self.prompt_data["conversation_head"]
        self.user_message_head_template = self.prompt_data["user_message_head"]
        self.plugin_pool = plugin_registry.get_list()
        self.query_requirements_template = self.prompt_data["requirements"]

        # Examples are loaded lazily on the first call to compose_prompt().
        self.examples: Optional[List[Conversation]] = None
        self.code_verification_on: bool = False
        self.allowed_modules: List[str] = []

        self.instruction = self.instruction_template.format(
            ROLE_NAME=self.role_name,
        )

        self.round_compressor: RoundCompressor = round_compressor
        self.compression_template = read_yaml(self.config.compression_prompt_path)["content"]

        # plugin_selector / selected_plugin_pool only exist when auto selection is enabled.
        if self.config.enable_auto_plugin_selection:
            self.plugin_selector = PluginSelector(plugin_registry, self.llm_api)
            self.plugin_selector.generate_plugin_embeddings()
            logger.info("Plugin embeddings generated")
            self.selected_plugin_pool = SelectedPluginPool()

    def configure_verification(
        self,
        code_verification_on: bool,
        allowed_modules: Optional[List[str]] = None,
    ) -> None:
        """Store the verification switch and the list of importable modules.

        Args:
            code_verification_on: Whether verification requirements are added to the prompt.
            allowed_modules: Modules the generated code may import; ``None`` is stored as ``[]``.
        """
        self.allowed_modules = allowed_modules if allowed_modules is not None else []
        self.code_verification_on = code_verification_on

    def compose_verification_requirements(
        self,
        plugin_list: List[PluginEntry],
    ) -> str:
        """Return the import restrictions as prompt text.

        Args:
            plugin_list: Accepted for interface compatibility; not used by this method.

        Returns:
            An empty string when verification is off; otherwise one bullet line that
            either lists the allowed modules or states that no module may be imported.
        """
        if not self.code_verification_on:
            return ""

        requirements: List[str] = []
        if self.allowed_modules:
            requirements.append(
                f"- {self.role_name} can only import the following Python modules: "
                + ", ".join(f"{module}" for module in self.allowed_modules),
            )
        else:
            requirements.append(f"- {self.role_name} cannot import any Python modules.")
        return "\n".join(requirements)

    def compose_prompt(
        self,
        rounds: List[Round],
        plugins: List[PluginEntry],
    ) -> List[ChatMessageType]:
        """Build the full chat prompt: system instruction, examples, then the conversation.

        Args:
            rounds: Conversation rounds to render.
            plugins: Plugins to describe in the conversation head.

        Returns:
            The list of chat messages to send to the LLM.
        """
        chat_history = [format_chat_message(role="system", message=self.instruction)]

        if self.examples is None:
            self.examples = self.load_examples()
        for example in self.examples:
            chat_history.extend(
                self.compose_conversation(example.rounds, example.plugins, add_requirements=False),
            )

        summary = None
        if self.config.prompt_compression:
            # The compressor returns a summary plus the rounds that remain to be rendered.
            summary, rounds = self.round_compressor.compress_rounds(
                rounds,
                rounds_formatter=lambda _rounds: str(
                    self.compose_conversation(_rounds, plugins, add_requirements=False),
                ),
                use_back_up_engine=True,
                prompt_template=self.compression_template,
            )

        chat_history.extend(
            self.compose_conversation(
                rounds,
                add_requirements=True,
                summary=summary,
                plugins=plugins,
            ),
        )
        return chat_history

    def format_attachment(self, attachment: Attachment) -> str:
        """Return the attachment content, substituting the role name into thoughts.

        Thought attachments are treated as templates containing ``{ROLE_NAME}``.
        Thoughts may also contain literal braces (for example dict or JSON
        snippets), which make ``str.format`` raise; in that case only the
        ``{ROLE_NAME}`` placeholder is replaced and the rest is left untouched.

        Args:
            attachment: The attachment to render.

        Returns:
            The rendered content.
        """
        if attachment.type != AttachmentType.thought:
            return attachment.content

        try:
            return attachment.content.format(ROLE_NAME=self.role_name)
        except (KeyError, IndexError, ValueError, AttributeError):
            # Content is not a valid format template; substitute the placeholder literally.
            return attachment.content.replace("{ROLE_NAME}", self.role_name)

    def compose_conversation(
        self,
        rounds: List[Round],
        plugins: List[PluginEntry],
        add_requirements: bool = False,
        summary: Optional[str] = None,
    ) -> List[ChatMessageType]:
        """Render rounds into alternating assistant/user chat messages.

        Args:
            rounds: Rounds whose posts are rendered in order.
            plugins: Plugins described in the conversation head of the first post.
            add_requirements: When true, the requirements template is appended to the
                user message of the final post.
            summary: Summary text for the conversation head; rendered as ``"None"`` if absent.

        Returns:
            The chat messages for these rounds (without the system message).

        Raises:
            ValueError: If a post's sender/receiver pair is not one of
                Planner->CodeInterpreter, CodeInterpreter->CodeInterpreter or
                CodeInterpreter->Planner.
        """
        chat_history: List[ChatMessageType] = []
        # Attachment types handed to the post translator as ``ignored_types``; the
        # feedback-related ones are rendered through format_code_feedback() instead.
        ignored_types = [
            AttachmentType.revise_message,
            AttachmentType.verification,
            AttachmentType.code_error,
            AttachmentType.execution_status,
            AttachmentType.execution_result,
        ]

        is_first_post = True
        last_post: Optional[Post] = None
        for round_index, conversation_round in enumerate(rounds):
            for post_index, post in enumerate(conversation_round.post_list):
                user_message = ""
                assistant_message = ""
                is_final_post = (
                    round_index == len(rounds) - 1 and post_index == len(conversation_round.post_list) - 1
                )
                if is_first_post:
                    # The very first post carries the conversation head (summary + plugins).
                    user_message = (
                        self.conversation_head_template.format(
                            SUMMARY="None" if summary is None else summary,
                            PLUGINS="None" if len(plugins) == 0 else self.format_plugins(plugins),
                            ROLE_NAME=self.role_name,
                        )
                        + "\n"
                    )
                    is_first_post = False

                if post.send_from == "Planner" and post.send_to == "CodeInterpreter":
                    user_query = conversation_round.user_query
                    plan = next(iter(post.get_attachment(AttachmentType.plan)), None)
                    enrichment = ""
                    if plan is not None:
                        enrichment = (
                            f"To complete this request: {user_query}\n\n"
                            f"I have drawn up a plan: \n{plan}\n\n"
                            f"Please proceed with this step of this plan:"
                        )

                    # Feedback refers to the previous post only if it came from the CodeInterpreter.
                    user_feedback = "None"
                    if last_post is not None and last_post.send_from == "CodeInterpreter":
                        user_feedback = format_code_feedback(last_post)

                    user_message += self.user_message_head_template.format(
                        FEEDBACK=user_feedback,
                        MESSAGE=f"{enrichment}{post.message}",
                    )
                elif post.send_from == post.send_to == "CodeInterpreter":
                    # Self-addressed post: the feedback and the first revise_message
                    # attachment form the user turn that follows the assistant turn.
                    user_message += self.user_message_head_template.format(
                        FEEDBACK=format_code_feedback(post),
                        MESSAGE=f"{post.get_attachment(AttachmentType.revise_message)[0]}",
                    )

                    assistant_message = self.post_translator.post_to_raw_text(
                        post=post,
                        content_formatter=self.format_attachment,
                        if_format_message=False,
                        if_format_send_to=False,
                        ignored_types=ignored_types,
                    )
                elif post.send_from == "CodeInterpreter" and post.send_to == "Planner":
                    if is_final_post:
                        # Close the conversation with a user turn carrying the feedback
                        # for this last assistant message.
                        user_message += self.user_message_head_template.format(
                            FEEDBACK=format_code_feedback(post),
                            MESSAGE="This is the feedback.",
                        )

                    assistant_message = self.post_translator.post_to_raw_text(
                        post=post,
                        content_formatter=self.format_attachment,
                        if_format_message=False,
                        if_format_send_to=False,
                        ignored_types=ignored_types,
                    )
                else:
                    raise ValueError(f"Invalid post: {post}")
                last_post = post

                # Within one post the assistant turn is emitted before the user turn.
                if len(assistant_message) > 0:
                    chat_history.append(
                        format_chat_message(
                            role="assistant",
                            message=assistant_message,
                        ),
                    )
                if len(user_message) > 0:
                    if is_final_post and add_requirements:
                        user_message += "\n" + self.query_requirements_template.format(
                            CODE_GENERATION_REQUIREMENTS=self.compose_verification_requirements(plugins),
                            ROLE_NAME=self.role_name,
                        )
                    chat_history.append(
                        format_chat_message(role="user", message=user_message),
                    )

        return chat_history

    def select_plugins_for_prompt(
        self,
        user_query: str,
    ) -> List[PluginEntry]:
        """Select plugins for the query, add them to the selected pool and return the pool.

        Only usable when automatic plugin selection is enabled, because the selector
        and the pool are created in ``__init__`` under that setting.

        Args:
            user_query: Query passed to the plugin selector.

        Returns:
            The plugins currently held by the selected plugin pool.
        """
        selected_plugins = self.plugin_selector.plugin_select(
            user_query,
            self.config.auto_plugin_selection_topk,
        )
        self.selected_plugin_pool.add_selected_plugins(selected_plugins)
        self.logger.info(f"Selected plugins: {[p.name for p in selected_plugins]}")
        self.logger.info(
            f"Selected plugin pool: {[p.name for p in self.selected_plugin_pool.get_plugins()]}",
        )

        return self.selected_plugin_pool.get_plugins()

    def reply(
        self,
        memory: Memory,
        event_handler: callable,
        prompt_log_path: Optional[str] = None,
        use_back_up_engine: bool = False,
    ) -> Post:
        """Generate the next CodeInterpreter post from the conversation memory.

        Args:
            memory: Conversation memory; rounds for role ``CodeInterpreter`` are read from it.
            event_handler: Passed through to the post translator.
            prompt_log_path: When given, the composed prompt is dumped to this path.
            use_back_up_engine: Passed to the LLM API as ``use_backup_engine``.

        Returns:
            The post parsed from the LLM output, addressed to ``Planner``.
        """
        rounds = memory.get_role_rounds(
            role="CodeInterpreter",
            include_failure_rounds=False,
        )

        # The query of the most recent round drives plugin selection.
        user_query = rounds[-1].user_query

        if self.config.enable_auto_plugin_selection:
            self.plugin_pool = self.select_plugins_for_prompt(user_query)

        prompt = self.compose_prompt(rounds, self.plugin_pool)

        def early_stop(_type: AttachmentType, value: str) -> bool:
            # Handed to the post translator as its ``early_stop`` predicate.
            return _type in [AttachmentType.text, AttachmentType.python, AttachmentType.sample]

        response = self.post_translator.raw_text_to_post(
            llm_output=self.llm_api.chat_completion(
                prompt,
                use_backup_engine=use_back_up_engine,
            )["content"],
            send_from="CodeInterpreter",
            event_handler=event_handler,
            early_stop=early_stop,
        )
        response.send_to = "Planner"
        generated_code = ""
        # Only the first sample/text/python attachment is considered.
        for attachment in response.attachment_list:
            if attachment.type in [AttachmentType.sample, AttachmentType.text]:
                response.message = attachment.content
                break
            elif attachment.type == AttachmentType.python:
                generated_code = attachment.content
                break

        if self.config.enable_auto_plugin_selection:
            # Let the pool filter its plugins against the generated code.
            self.selected_plugin_pool.filter_unused_plugins(code=generated_code)

        if prompt_log_path is not None:
            self.logger.dump_log_file(prompt, prompt_log_path)

        return response

    def format_plugins(
        self,
        plugin_list: List[PluginEntry],
    ) -> str:
        """Join each plugin's ``format_prompt()`` output with newlines.

        Returns an empty string when ``load_plugin`` is disabled in the config.
        """
        if self.config.load_plugin:
            return "\n".join(
                [plugin.format_prompt() for plugin in plugin_list],
            )
        return ""

    def load_examples(
        self,
    ) -> List[Conversation]:
        """Load example conversations from ``example_base_path``.

        Returns an empty list when ``load_example`` is disabled in the config.
        """
        if self.config.load_example:
            return load_examples(
                folder=self.config.example_base_path,
            )
        return []

    def get_plugin_pool(self) -> List[PluginEntry]:
        """Return the plugin pool currently used for prompt composition."""
        return self.plugin_pool
|
|
|
|
|
|
|
|
def format_code_revision_message() -> str:
    """Return the fixed message asking for a rewrite after a failed code execution."""
    message_lines = (
        "The execution of the previous generated code has failed. "
        "If you think you can fix the problem by rewriting the code, "
        "please generate code and run it again.",
        "Otherwise, please explain the problem to me.",
    )
    return "\n".join(message_lines)
|
|
|
|
|
|
|
|
def format_output_revision_message() -> str:
    """Return the fixed message asking for a reply in the expected JSON output format."""
    message_lines = (
        "Your previous message is not following the output format. "
        "You must generate the output as a JSON object with the following format:",
        '{"response": [{"type":"this is the type", "content": "this is the content"}, ...]}',
        "You need at least have an element with type 'python' and content being the code to be executed.",
        "Don't surround the JSON with ```json and ```, just send the JSON object directly.",
        "Please try again.",
    )
    return "\n".join(message_lines)
|
|
|
|
|
|
|
|
def format_code_feedback(post: Post) -> str:
    """Render the verification and execution attachments of a post as feedback text.

    Attachments are processed in order. A recognised verification or execution
    status emits its section header and becomes the current status. A
    ``code_error`` is included only while the current verification status is
    ``INCORRECT``; an ``execution_result`` is included unless the current
    execution status is ``NONE``.

    Args:
        post: The post whose ``attachment_list`` is scanned.

    Returns:
        The concatenated feedback text (empty if nothing matched).
    """
    verification_headers = {
        "CORRECT": "## Verification\nI have verified that your code is CORRECT.\n",
        "NONE": "## Verification\nNo code verification.\n",
        "INCORRECT": "## Verification\nYour code is INCORRECT with the following error:\n",
    }
    execution_headers = {
        "NONE": "## Execution\nNo code execution.\n",
        "SUCCESS": "## Execution\nYour code has been executed successfully with the following result:\n",
        "FAILURE": "## Execution\nYour code has failed to execute with the following error:\n",
    }

    sections: List[str] = []
    last_verification = ""
    last_execution = ""
    for item in post.attachment_list:
        kind = item.type
        body = item.content
        if kind == AttachmentType.verification:
            # Unrecognised status values are skipped without touching the state.
            if body in verification_headers:
                sections.append(verification_headers[body])
                last_verification = body
        elif kind == AttachmentType.code_error:
            if last_verification == "INCORRECT":
                sections.append(f"{body}\n")
        elif kind == AttachmentType.execution_status:
            if body in execution_headers:
                sections.append(execution_headers[body])
                last_execution = body
        elif kind == AttachmentType.execution_result:
            if last_execution != "NONE":
                sections.append(f"{body}\n")
    return "".join(sections)
|
|
|