import json
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Dict, List

from flow_modules.aiflows.ChatFlowModule import ChatAtomicFlow


@dataclass
class Command:
    name: str
    description: str
    input_args: List[str]
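

# Illustrative sketch (hypothetical helper, not used by Controller_JarvisFlow):
# how a `commands` config mapping is turned into the numbered manual that
# __init__ injects into the system prompt. It mirrors instantiate_from_config
# plus _build_commands_manual and assumes each entry carries the "description"
# and "input_args" keys those methods read.
def _example_commands_manual(commands_config: dict) -> str:
    import json  # local import keeps the sketch self-contained

    lines = []
    for i, (name, conf) in enumerate(commands_config.items()):
        # every input argument gets an upper-cased placeholder, e.g. "YOUR_QUESTION"
        schema = json.dumps({arg: f"YOUR_{arg.upper()}" for arg in conf["input_args"]})
        lines.append(f"{i + 1}. {name}: {conf['description']} "
                     f"Input arguments (given in the JSON schema): {schema}\n")
    return "".join(lines)

# Usage (hypothetical command):
#   _example_commands_manual({"ask_user": {"description": "Ask the user a question.",
#                                          "input_args": ["question"]}})
#   -> '1. ask_user: Ask the user a question. Input arguments (given in the JSON schema): {"question": "YOUR_QUESTION"}\n'
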

# TODO: controller should be generalized
class Controller_JarvisFlow(ChatAtomicFlow):
    def __init__(
            self,
            commands: List[Command],
            **kwargs):
        super().__init__(**kwargs)
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            commands=self._build_commands_manual(commands),
            plan="no plans yet",
            plan_file_location="no plan file location yet",
            logs="no logs yet",
        )
        self.hint_for_model = """
        Make sure your response is in the following format:
              Response Format:
              {
              "command": "call one of the subordinates",
              "command_args": {
                  "arg name": "value"
                  }
              }
        """

    def _get_content_file_location(self, input_data, content_name):
        # get the location of the file that contains the content: plan, logs, code_library
        assert "memory_files" in input_data, "memory_files not passed to Jarvis/Controller"
        assert content_name in input_data["memory_files"], f"{content_name} not in memory files"
        return input_data["memory_files"][content_name]

    def _get_content(self, input_data, content_name):
        # get the content itself (plan, logs, code_library) as passed in input_data;
        # fall back to a placeholder when it is empty
        assert content_name in input_data, f"{content_name} not passed to Jarvis/Controller"
        content = input_data[content_name]
        if len(content) == 0:
            content = f'No {content_name} yet'
        return content

    @staticmethod
    def _build_commands_manual(commands: List[Command]) -> str:
        ret = ""
        for i, command in enumerate(commands):
            command_input_json_schema = json.dumps(
                {input_arg: f"YOUR_{input_arg.upper()}" for input_arg in command.input_args})
            ret += f"{i + 1}. {command.name}: {command.description} Input arguments (given in the JSON schema): {command_input_json_schema}\n"
        return ret
    @classmethod
    def instantiate_from_config(cls, config):
        flow_config = deepcopy(config)

        kwargs = {"flow_config": flow_config}

        # ~~~ Set up prompts ~~~
        kwargs.update(cls._set_up_prompts(flow_config))

        # ~~~Set up backend ~~~
        kwargs.update(cls._set_up_backend(flow_config))

        # ~~~ Set up commands ~~~
        commands = flow_config["commands"]
        commands = [
            Command(name, command_conf["description"], command_conf["input_args"]) for name, command_conf in
            commands.items()
        ]
        kwargs.update({"commands": commands})

        # ~~~ Instantiate flow ~~~
        return cls(**kwargs)

    def _update_prompts_and_input(self, input_data: Dict[str, Any]):
        if 'goal' in input_data:
            input_data['goal'] += self.hint_for_model
        if 'result' in input_data:
            input_data['result'] += self.hint_for_model
        plan_file_location = self._get_content_file_location(input_data, "plan")
        plan_content = self._get_content(input_data, "plan")
        logs_content = self._get_content(input_data, "logs")
        self.system_message_prompt_template = self.system_message_prompt_template.partial(
            plan_file_location=plan_file_location,
            plan=plan_content,
            logs=logs_content
        )

    def run(self, input_data: Dict[str, Any]) -> Dict[str, Any]:
        self._update_prompts_and_input(input_data)

        # ~~~ If the conversation is already initialized, append the updated system prompt to the chat history ~~~
        if self._is_conversation_initialized():
            updated_system_message_content = self._get_message(self.system_message_prompt_template, input_data)
            self._state_update_add_chat_message(content=updated_system_message_content,
                                                role=self.flow_config["system_name"])

        # ~~~ Query the model and parse its reply as JSON ~~~
        api_output = super().run(input_data)["api_output"].strip()
        try:
            return json.loads(api_output)
        except json.JSONDecodeError:
            # ~~~ Retry once: show the model its unparsable reply and restate the expected format ~~~
            new_goal = (
                f"Here is your previous response: {api_output}\n"
                "It cannot be parsed with json.loads; please fix this issue."
                + self.hint_for_model
            )
            new_input_data = input_data.copy()
            new_input_data['goal'] = new_goal
            new_api_output = super().run(new_input_data)["api_output"].strip()
            # a second parse failure propagates to the caller as json.JSONDecodeError
            return json.loads(new_api_output)
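

# Illustrative sketch (hypothetical helper, not used by Controller_JarvisFlow):
# the parse-then-retry-once pattern from run(), decoupled from the flow so it can
# be exercised with any callable that maps a prompt string to a raw model reply.
# `ask_model` and the retry wording are assumptions made for this example.
def _example_parse_json_with_one_retry(ask_model, prompt: str):
    import json  # local import keeps the sketch self-contained

    raw = ask_model(prompt).strip()
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        # feed the unparsable reply back once; a second failure propagates
        retry_prompt = (f"Here is your previous response: {raw}\n"
                        "It cannot be parsed with json.loads; please fix this issue.")
        return json.loads(ask_model(retry_prompt).strip())

# Usage with a fake model that first answers with truncated JSON, then valid JSON:
#   replies = iter(['{"command": "plan"', '{"command": "plan", "command_args": {}}'])
#   _example_parse_json_with_one_retry(lambda _prompt: next(replies), "goal")
#   -> {'command': 'plan', 'command_args': {}}
# Capping the retry at one keeps a persistently malformed model from looping forever.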