ControllerAtomicFlowModule / ControllerAtomicFlow.py
Tachi67's picture
New version of atomic flow
3d30823
raw
history blame
5 kB
import importlib
import importlib.util
import json
import os.path
from copy import deepcopy
from typing import Any, Dict, List
from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow
from dataclasses import dataclass
@dataclass()
class Command:
    """Description of one command the controller can offer to the model.

    Instances are rendered into a numbered "commands manual" that is placed
    in the controller's system prompt.
    """

    # Name shown to the model for this command.
    name: str
    # Free-text explanation of what the command does.
    description: str
    # Names of the arguments the command takes; each becomes a key in the
    # JSON template shown to the model.
    input_args: List[str]
class ControllerAtomicFlow(OpenAIChatAtomicFlow):
    """Chat flow that asks the model to choose the next command to run.

    On every ``run`` the system prompt is rebuilt from the original template
    plus the signatures of the functions currently defined in the code-library
    file and the current contents of the plan file. The model's reply is then
    parsed as JSON.
    """

    def __init__(
            self,
            commands: List[Command],
            plan_file_location: str,
            code_file_location: str,
            **kwargs,
    ):
        """Initialise the controller.

        Args:
            commands: Commands rendered into the ``commands`` variable of the
                system prompt template.
            plan_file_location: Path of the plan file. If it is an existing
                directory, ``plan.txt`` inside that directory is used.
            code_file_location: Path of the Python file whose functions are
                listed for the model on every run.
            **kwargs: Forwarded unchanged to ``OpenAIChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        if os.path.isdir(plan_file_location):
            plan_file_location = os.path.join(plan_file_location, "plan.txt")
        self.plan_file_location = plan_file_location
        self.code_file_location = code_file_location
        # The command list does not change, so fill that template variable once.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands)
        )
        # Appended to ``goal`` / ``human_feedback`` in ``run`` so the model
        # answers in the JSON shape that ``run`` parses.
        self.hint_for_model = """
Make sure your response is in the following format:
Response Format:
{
"thought": "thought",
"reasoning": "reasoning",
"criticism": "constructive self-criticism",
"speak": "thoughts summary to say to user",
"command": "the python function you would like to call",
"command_args": {
"arg name": "value"
}
}
"""
        # Pristine template text: ``run`` appends fresh context to this on each
        # call instead of accumulating onto the previous call's prompt.
        self.original_system_template = self.system_message_prompt_template.template

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a numbered list, one command per line.

        Each line looks like::

            1. name: description Input arguments (given in the JSON schema): {"arg": "YOUR_ARG"}

        Returns an empty string when ``commands`` is empty.
        """
        return "".join(
            f"{i + 1}. {command.name}: {command.description} Input arguments (given in the JSON schema): "
            f"{json.dumps({input_arg: f'YOUR_{input_arg.upper()}' for input_arg in command.input_args})}\n"
            for i, command in enumerate(commands)
        )

    @classmethod
    def instantiate_from_config(cls, config):
        """Build the flow from a config mapping.

        ``config`` is deep-copied before use. In addition to whatever the
        inherited ``_set_up_prompts`` / ``_set_up_backend`` helpers read, it
        must contain ``commands``: a mapping from command name to a dict with
        ``description`` and ``input_args`` keys.
        """
        flow_config = deepcopy(config)
        kwargs = {"flow_config": flow_config}
        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))
        kwargs.update(cls._set_up_backend(flow_config))
        # ~~~ Set up commands ~~~
        kwargs["commands"] = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]
        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _get_library_function_signatures(self):
        """List ``name: signature`` for each function in the code library.

        The library file is imported afresh on every call. Every Python
        function bound in its namespace is listed, including ones it imports
        from elsewhere. Lines are ordered by name, as ``inspect.getmembers``
        returns them.

        Returns:
            One ``"name: (signature)\\n"`` line per function, or
            ``'There is no function available yet.'`` if the library file
            does not exist.

        Raises:
            ValueError: If no module loader matches ``code_file_location``
                (e.g. the path does not have a Python source suffix).
            Exception: Any error raised while executing the library
                propagates unchanged.
        """
        import inspect

        spec = importlib.util.spec_from_file_location("code_library", self.code_file_location)
        if spec is None or spec.loader is None:
            # spec_from_file_location returns None when no loader matches the
            # path; fail clearly instead of with an AttributeError later.
            raise ValueError(
                f"Cannot load a Python module from {self.code_file_location!r}"
            )
        module = importlib.util.module_from_spec(spec)
        try:
            # Reading the source happens here, so a missing file surfaces here.
            spec.loader.exec_module(module)
        except FileNotFoundError:
            return 'There is no function available yet.'
        return "".join(
            f"{name}: {inspect.signature(obj)}\n"
            for name, obj in inspect.getmembers(module, inspect.isfunction)
        )

    def _get_plan(self):
        """Return the plan file's contents, or ``"There is no plan yet"`` if it is missing."""
        try:
            with open(self.plan_file_location, 'r') as file:
                return file.read()
        except FileNotFoundError:
            return "There is no plan yet"

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the model for the next command and return its parsed JSON reply.

        ``goal`` and ``human_feedback`` (when present) get ``hint_for_model``
        appended. The system prompt template is rebuilt from the original
        template plus the current library function signatures and plan. The
        caller's ``input_data`` is not modified.

        If the reply is not valid JSON, the model is queried once more with an
        empty ``observation`` and a ``human_feedback`` message asking for JSON.

        Args:
            input_data: Inputs forwarded to the parent chat flow.

        Returns:
            The model's ``api_output`` (whitespace-stripped) decoded with
            ``json.loads``.

        Raises:
            json.JSONDecodeError: If the retried reply is still not valid JSON.
        """
        # Copy so the caller's dict is left untouched; otherwise a reused dict
        # would gain one more copy of the hint on every call.
        input_data = input_data.copy()
        for key in ("goal", "human_feedback"):
            if key in input_data:
                input_data[key] += self.hint_for_model

        plan_to_append = self._get_plan()
        function_signatures_to_append = self._get_library_function_signatures()
        # Rebuild from the pristine template so context from earlier calls
        # does not pile up in the prompt.
        self.system_message_prompt_template.template = (
            self.original_system_template + "\n"
            + f"Here are the available functions at {self.code_file_location}\n"
            + function_signatures_to_append + "\n"
            + "Here is the step-by-step plan to achieve the goal:\n"
            + plan_to_append + "\n"
        )

        api_output = super().run(input_data)["api_output"].strip()
        try:
            return json.loads(api_output)
        except json.decoder.JSONDecodeError:
            # Single retry that tells the model why its previous answer was rejected.
            new_input_data = input_data.copy()
            new_input_data['observation'] = ""
            new_input_data['human_feedback'] = "The previous respond cannot be parsed with json.loads, it could be the backslashes used for escaping single quotes in the string arguments of the Python code are not properly escaped themselves within the JSON context. Make sure your next response is in JSON format."
            new_api_output = super().run(new_input_data)["api_output"].strip()
            return json.loads(new_api_output)