File size: 3,166 Bytes
e8d999b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import json
from copy import deepcopy
from typing import Any, Dict, List
from flow_modules.aiflows.OpenAIChatFlowModule import OpenAIChatAtomicFlow

from dataclasses import dataclass


@dataclass(init=True, repr=True, eq=True)
class Command:
    """One command the controller model may choose to invoke.

    Instances are rendered into the numbered command manual that is placed
    in the controller's system prompt.
    """

    # Identifier of the command, shown to the model before the colon.
    name: str
    # Free-text explanation of what the command does.
    description: str
    # Names of the arguments; each becomes a key in the JSON input schema.
    input_args: List[str]


class ControllerAtomicFlow(OpenAIChatAtomicFlow):
    """Chat flow that offers the model a fixed set of commands and parses its JSON reply.

    The available commands are rendered into a numbered manual and bound to
    the ``commands`` variable of the system-message prompt template. ``run``
    expects the model's reply to be JSON and returns the decoded value.
    """

    def __init__(self, commands: List[Command], **kwargs):
        """Initialise the base flow, then pre-fill the ``commands`` prompt slot.

        Args:
            commands: Commands the model may choose from, rendered with
                ``_build_commands_manual``.
            **kwargs: Forwarded unchanged to ``OpenAIChatAtomicFlow.__init__``.
        """
        super().__init__(**kwargs)
        # Bind the rendered manual to the template's ``commands`` variable.
        # NOTE(review): assumes a LangChain-style ``partial`` that returns a
        # new template with that variable fixed -- confirm against the base.
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands)
        )

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        """Render ``commands`` as a numbered manual, one newline-terminated line each.

        Each line reads
        ``<n>. <name>: <description> Input arguments (given in the JSON schema): <schema>``
        where numbering starts at 1 and ``<schema>`` is a JSON object mapping
        every input argument to the placeholder ``YOUR_<ARG_IN_UPPER_CASE>``.
        Returns ``""`` when ``commands`` is empty.
        """
        lines = []
        for number, command in enumerate(commands, start=1):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args}
            )
            lines.append(
                f"{number}. {command.name}: {command.description} Input arguments (given in the JSON schema): {command_input_json_schema}\n"
            )
        # join instead of repeated ``+=`` to avoid quadratic string building.
        return "".join(lines)

    @classmethod
    def instantiate_from_config(cls, config):
        """Build a flow from a config mapping without modifying ``config``.

        ``config`` is deep-copied first. Its ``commands`` entry must map each
        command name to a mapping with ``description`` and ``input_args``
        keys. Prompt-related keyword arguments come from ``_set_up_prompts``,
        which is not defined in this class (presumably inherited).
        """
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~ Set up commands ~~~
        kwargs["commands"] = [
            Command(name, command_conf["description"], command_conf["input_args"])
            for name, command_conf in flow_config["commands"].items()
        ]

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        """Query the model and return its reply decoded from JSON.

        A response-format hint is appended to the ``goal`` and
        ``human_feedback`` entries when they are present. If the first reply
        is not valid JSON, the model is queried once more. The retry uses an
        empty ``observation`` and replaces ``human_feedback`` with an
        explanation of the parse failure.

        Args:
            input_data: Inputs for the base flow's ``run``. The mapping is not
                modified; the hint is applied to a shallow copy.

        Returns:
            The value produced by ``json.loads`` on the model's (stripped)
            ``api_output``.

        Raises:
            json.decoder.JSONDecodeError: If the retried reply is also not
                valid JSON.
        """
        hint_for_model = """
        Make sure your response is in the following format:
              Response Format:
              {
              "thought": "thought",
              "reasoning": "reasoning",
              "plan": "- short bulleted\n- list that conveys\n- long-term plan",
              "criticism": "constructive self-criticism",
              "speak": "thoughts summary to say to user",
              "command": "the python function you would like to call",
              "command_args": {
                  "arg name": "value"
                  }
              }
        """
        # Work on a shallow copy. Appending to the caller's dict in place
        # would leak the hint back to the caller and stack another copy of
        # it every time the same dict is passed in again.
        input_data = input_data.copy()
        for key in ("goal", "human_feedback"):
            if key in input_data:
                input_data[key] += hint_for_model
        api_output = super().run(input_data)["api_output"].strip()

        try:
            return json.loads(api_output)
        except json.decoder.JSONDecodeError:
            # Retry once, telling the model why its previous reply was rejected.
            new_input_data = input_data.copy()
            new_input_data['observation'] = ""
            new_input_data['human_feedback'] = "The previous respond cannot be parsed with json.loads, it could be the backslashes used for escaping single quotes in the string arguments of the Python code are not properly escaped themselves within the JSON context."
            new_api_output = super().run(new_input_data)["api_output"].strip()
            return json.loads(new_api_output)