|
|
import json |
|
|
from copy import deepcopy |
|
|
from typing import Any, Dict |
|
|
from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow |
|
|
|
|
|
|
|
|
|
|
|
class CodeGeneratorAtomicFlow(ChatAtomicFlow):
    """This class wraps around the Chat API to generate code from a goal. One thing worth noting is that we need to
    make sure the code generator does not write repetitive code that is present in the library, so we need to inject
    the function signatures in the library into the system prompt.

    *Input Interface Non Initialized*:
    - `goal`
    - `code_library`
    - `memory_files`

    *Input Interface Initialized*:
    - `goal`
    - `code_library`
    - `memory_files`

    *Output Interface*:
    - `code`
    - `language_of_code`

    *Configuration Parameters*:
    - Also refer to ChatAtomicFlow (https://huggingface.co/aiflows/ChatFlowModule/blob/main/ChatAtomicFlow.py)
    - `input_interface_non_initialized`: The input interface when the conversation is not initialized.
    - `input_interface_initialized`: The input interface when the conversation is initialized.
    - `output_interface`: The output interface.
    - `backend`: The backend to use for the Chat API.
    - `system_message_prompt_template`: The template for the system message prompt.
    - `human_message_prompt_template`: The template for the human message prompt.
    - `init_human_message_prompt_template`: The initial human message prompt.

    *Class Attributes*:
    - `max_json_parse_attempts`: Maximum number of model calls `run` makes while waiting for a reply that
      `json.loads` parses into a JSON object. ``None`` (the default) retries indefinitely; set a positive int
      (on a subclass or an instance) to stop and raise ``ValueError`` instead.
    """

    # None preserves the original behaviour: keep asking until the model returns a parsable JSON object.
    max_json_parse_attempts = None

    # Goal sent back to the model when its previous reply was not a JSON object.
    _JSON_RETRY_GOAL = (
        "The previous response cannot be parsed with json.loads. It could be that the backslashes used for "
        "escaping single quotes in the string arguments of the code are not properly escaped themselves within "
        "the JSON context. Next time, do not provide any comments or code blocks. Make sure your next response "
        "is purely json parsable."
    )

    def __init__(self, **kwargs):
        """Initialize the CodeGeneratorAtomicFlow.

        The system prompt template is pre-filled with placeholder values for `code_library_file_location` and
        `code_library`; `run` substitutes the real values on every call.

        :param kwargs: Keyword arguments forwarded to ``ChatAtomicFlow.__init__``.
        :type kwargs: Any
        """
        super().__init__(**kwargs)
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            code_library_file_location="no location yet",
            code_library="no code yet",
        )
        # Appended to the goal so the model replies in a json.loads-parsable format. The example must itself be
        # valid JSON: a trailing comma after the last member would invite the model to copy it and fail parsing.
        self.hint_for_model = """
        Make sure your response is in the following format:
        Response Format:
        {
            "language_of_code": "language of the code",
            "code": "String of the code and docstrings corresponding to the goal"
        }
        """

    @classmethod
    def instantiate_from_config(cls, config):
        """Instantiate a CodeGeneratorAtomicFlow from a configuration.

        The configuration is deep-copied so the caller's dict is never modified; prompts and backend are built
        from the copy via the inherited ``_set_up_prompts`` / ``_set_up_backend`` helpers.

        :param config: Configuration dictionary.
        :type config: Dict[str, Any]
        :return: Instantiated CodeGeneratorAtomicFlow.
        :rtype: CodeGeneratorAtomicFlow
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))

        return cls(**kwargs)

    def _get_code_library_file(self, input_data: Dict[str, Any]):
        """Get the code library file location from the input data.

        :param input_data: Input data; must contain ``memory_files["code_library"]``.
        :type input_data: Dict[str, Any]
        :return: Code library file location.
        :rtype: str
        :raises AssertionError: If memory_files is not in input_data.
        :raises AssertionError: If code_library is not in memory_files.
        """
        # Raise explicitly instead of using `assert`, so the checks still run under `python -O`.
        # AssertionError is kept because it is the exception type this method documents.
        if "memory_files" not in input_data:
            raise AssertionError("memory_files not passed to CodeGeneratorAtomicFlow")
        if "code_library" not in input_data["memory_files"]:
            raise AssertionError("code_library not in memory_files")
        return input_data["memory_files"]["code_library"]

    def _get_code_library_content(self, input_data: Dict[str, Any]):
        """Get the code library content from the input data.

        :param input_data: Input data; must contain ``code_library``.
        :type input_data: Dict[str, Any]
        :return: Code library content, or ``"No code yet"`` when the library is empty or None.
        :rtype: str
        :raises AssertionError: If code_library is not in input_data.
        """
        # Explicit raise (not `assert`) so validation survives `python -O`.
        if "code_library" not in input_data:
            raise AssertionError("code_library not passed to CodeGeneratorAtomicFlow")
        code_library = input_data["code_library"]
        # `not` covers both an empty library and None (which len() would crash on).
        if not code_library:
            return "No code yet"
        return code_library

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        """Update the prompts and input data.

        Appends the response-format hint to ``input_data["goal"]`` (in place, if a goal is present) and
        re-partials the system prompt with the current code library location and content.

        :param input_data: Input data; modified in place.
        :type input_data: Dict[str, Any]
        """
        if "goal" in input_data:
            input_data["goal"] += self.hint_for_model
        code_library_file_location = self._get_code_library_file(input_data)
        code_library = self._get_code_library_content(input_data)
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            code_library_file_location=code_library_file_location,
            code_library=code_library,
        )

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Run the flow.

        Calls the chat model and parses its reply with ``json.loads``. If the reply is not a JSON object, the
        model is called again with a corrective goal, until it succeeds or ``max_json_parse_attempts`` is reached.

        :param input_data: Input data. It is not modified; a private copy is used.
        :type input_data: Dict[str, Any]
        :return: Output data (the parsed JSON object returned by the model).
        :rtype: Dict[str, Any]
        :raises ValueError: If ``max_json_parse_attempts`` is set and no call produced a JSON object.
        """
        # Work on a copy so the caller's dict does not get the format hint appended to its `goal`
        # (which would also accumulate if the same dict were passed to run() again).
        input_data = input_data.copy()
        self._update_prompts_and_input(input_data)

        attempts = 0
        while True:
            api_output = super().run(input_data)["api_output"].strip()
            attempts += 1
            try:
                response = json.loads(api_output)
            except json.JSONDecodeError:  # json.decoder.JSONDecodeError is the same class
                response = None

            # Only a JSON object can satisfy the output interface (`code`, `language_of_code`).
            if isinstance(response, dict):
                return response

            limit = self.max_json_parse_attempts
            if limit is not None and attempts >= limit:
                raise ValueError(
                    f"CodeGeneratorAtomicFlow: model output could not be parsed as a JSON object "
                    f"after {attempts} attempt(s)"
                )

            new_input_data = input_data.copy()
            new_input_data["goal"] = self._JSON_RETRY_GOAL
            input_data = new_input_data
|
|
|