File size: 5,879 Bytes
7feac49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import json

if TYPE_CHECKING:
    from swift.llm.infer import Function
    from swift.llm.template import Prompt


@dataclass
class AgentKeyword:
    """Marker strings that delimit the sections of a ReAct-style agent response.

    The values are used both to split a model response into tool calls and to
    render tool calls / tool results back into text.
    """
    # Marker preceding the tool name.
    action: str = 'Action:'
    # Marker preceding the tool arguments.
    action_input: str = 'Action Input:'
    # Marker preceding the tool result.
    observation: str = 'Observation:'


@dataclass
class ToolDesc:
    """Normalized, prompt-ready description of a single tool."""
    # Tool name shown to the model.
    name_for_model: str
    # Tool name shown to humans.
    name_for_human: str
    # Description text of the tool.
    description_for_model: str
    # Parameter schema as a string.
    parameters: str
    # Sentence instructing how the arguments must be formatted.
    args_format: str


class ReactCompatMixin:
    """Mixin that parses and renders tool calls in the ReAct text format.

    The format is driven by ``keyword`` (an ``AgentKeyword``): an action marker
    followed by the tool name, an action-input marker followed by the arguments,
    and an observation marker followed by the tool result. Subclasses may
    override ``keyword``; parsing falls back to the default markers defined here.
    """
    # Default ReAct markers, shared at class level.
    keyword = AgentKeyword()

    @staticmethod
    def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']:
        """Extract ``Function`` calls from ``response`` using the given markers.

        Args:
            response: Raw text to scan.
            keyword: Markers used to split the text.

        Returns:
            One ``Function`` per action part that is later followed by an
            action-input part. An action with no following action-input is dropped.
        """
        from swift.llm.template import split_str_parts_by
        from swift.llm.infer import Function
        agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))
        # Loop-invariant: markers are compared case-insensitively.
        action_key = keyword.action.lower()
        action_input_key = keyword.action_input.lower()
        functions = []
        action_content = None  # name of the action still waiting for its input

        for part in agent_parts:
            key, content = part['key'].lower(), part['content']
            if action_content is None and key == action_key:
                action_content = content
            elif action_content is not None and key == action_input_key:
                functions.append(Function(name=action_content, arguments=content))
                action_content = None

        return functions

    def get_toolcall(self, response: str) -> List['Function']:
        """Parse tool calls from ``response``.

        Tries this instance's ``keyword`` first. If nothing is found and the
        instance uses custom markers, retries with the default ReAct markers.
        """
        functions = self._split_action_action_input(response, self.keyword)
        if not functions and self.keyword != ReactCompatMixin.keyword:
            # compat react
            functions = self._split_action_action_input(response, ReactCompatMixin.keyword)
        return functions

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Render tool results to follow an assistant turn.

        Args:
            assistant_content: Text of the preceding assistant message.
            tool_messages: Non-empty sequence of mappings, each with a string
                ``'content'`` entry holding one tool result.

        Returns:
            ``(assistant_content, res)``. If the assistant text contains both the
            action and action-input markers, it is extended to end with the
            observation marker and ``res`` holds the newline-terminated results,
            separated by observation markers. Otherwise the text is unchanged and
            ``res`` is the list of raw result contents.
        """
        assert len(tool_messages) > 0
        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
        if with_action:
            if not assistant_content.endswith(self.keyword.observation):
                if not assistant_content.endswith('\n'):
                    assistant_content += '\n'
                assistant_content += self.keyword.observation
            res = []
            for i, tool_message in enumerate(tool_messages):
                if i > 0:
                    # The first result follows the marker already at the end of assistant_content.
                    res.append(self.keyword.observation)
                tool_content = tool_message['content']
                res.append(tool_content)
                if not tool_content.endswith('\n'):
                    res.append('\n')
        else:
            res = [tool_message['content'] for tool_message in tool_messages]
        return assistant_content, res

    @staticmethod
    def _parse_tool_call(content) -> Dict[str, Any]:
        """Parse one tool-call payload into ``{'name': ..., 'arguments': ...}``.

        Args:
            content: A mapping, or a JSON / Python-literal string describing one.
                It must have ``'name'`` and either ``'arguments'`` or ``'parameters'``.

        Returns:
            Dict with the tool ``name`` and the parsed ``arguments``.

        Raises:
            TypeError: If ``content`` is a string that cannot be parsed.
            KeyError: If ``'name'`` is missing.
            AssertionError: If no usable arguments are found.
        """
        obj = BaseAgentTemplate._parse_json(content)
        name = obj['name']
        arguments = obj.get('arguments') or obj.get('parameters')
        if arguments is None:
            # An explicit empty ``arguments`` (e.g. ``{}`` for a tool that takes no
            # arguments) is still a valid call, so keep it instead of failing below.
            arguments = obj.get('arguments')
        arguments = BaseAgentTemplate._parse_json(arguments)
        assert arguments is not None, f'content: {content}'
        return {'name': name, 'arguments': arguments}

    def _format_tool_calls(self, tool_call_messages) -> str:
        """Render tool-call messages as assistant text.

        Args:
            tool_call_messages: Iterable of mappings whose ``'content'`` entry is
                accepted by ``_parse_tool_call``.

        Returns:
            One action line and one action-input line per call, followed by a
            single trailing observation marker.
        """
        # -> assistant_content
        tool_calls = []
        for message in tool_call_messages:
            tool_call = self._parse_tool_call(message['content'])
            tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n'
                              f'{self.keyword.action_input} {tool_call["arguments"]}\n')
        tool_calls.append(self.keyword.observation)
        return ''.join(tool_calls)


class BaseAgentTemplate(ReactCompatMixin, ABC):
    """Abstract base for agent templates.

    Provides static helpers for normalizing tool definitions and leniently
    parsing JSON. Subclasses implement ``_format_tools``.
    """

    @staticmethod
    def _get_tool_name(tool):
        """Return the tool's ``name_for_model`` if truthy, otherwise its ``name``."""
        model_name = tool.get('name_for_model')
        return model_name if model_name else tool.get('name')

    @staticmethod
    def unwrap_tool(tool):
        """Return the inner ``'function'`` entry of a wrapped tool.

        A dict without both ``'type'`` and ``'function'`` keys is returned unchanged.
        """
        assert isinstance(tool, dict), f'tool: {tool}'
        is_wrapped = 'type' in tool and 'function' in tool
        return tool['function'] if is_wrapped else tool

    @staticmethod
    def wrap_tool(tool):
        """Wrap a bare tool dict as ``{'type': 'function', 'function': tool}``.

        A dict that already has a ``'type'`` or ``'function'`` key is returned unchanged.
        """
        assert isinstance(tool, dict), f'tool: {tool}'
        if 'type' in tool or 'function' in tool:
            return tool
        return {'type': 'function', 'function': tool}

    @staticmethod
    def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
        """Build a ``ToolDesc`` from a tool definition, wrapped or bare.

        Args:
            tool: Tool definition dict.
            lang: ``'zh'`` selects the Chinese argument-format sentence; any other
                value selects the English one.

        Returns:
            The populated ``ToolDesc``; non-string parameters are JSON-encoded.

        Raises:
            AssertionError: If the tool has no name or no description.
        """
        spec = BaseAgentTemplate.unwrap_tool(tool)
        model_name = BaseAgentTemplate._get_tool_name(spec)
        human_name = spec.get('name_for_human') or model_name
        description = spec.get('description') or spec.get('description_for_model')

        raw_parameters = spec.get('parameters') or {}
        if isinstance(raw_parameters, str):
            parameters = raw_parameters
        else:
            parameters = json.dumps(raw_parameters, ensure_ascii=False)

        if lang == 'zh':
            args_format = '此工具的输入应为JSON对象。'
        else:
            args_format = 'Format the arguments as a JSON object.'

        # Built before validation so the assertion message can show the whole description.
        tool_desc = ToolDesc(
            name_for_model=model_name,
            name_for_human=human_name,
            description_for_model=description,
            parameters=parameters,
            args_format=args_format)
        assert model_name is not None and description is not None, f'tool_desc: {tool_desc}'
        return tool_desc

    @staticmethod
    def _parse_json(json_str: str) -> Optional[Any]:
        """Leniently decode ``json_str``.

        Non-string input is returned unchanged. Strings are decoded as JSON,
        then as a Python literal if that fails.

        Returns:
            The decoded value, or ``None`` if both attempts fail.
        """
        if not isinstance(json_str, str):
            return json_str
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass
        # Fallback for Python-literal text such as single-quoted dicts.
        try:
            return ast.literal_eval(json_str)
        except Exception:
            return None

    @abstractmethod
    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Format ``tools`` with ``system`` and an optional ``user_message``.

        Must be implemented by subclasses; the annotated return type is ``str``.
        """