File size: 5,879 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
import json
if TYPE_CHECKING:
from swift.llm.infer import Function
from swift.llm.template import Prompt
@dataclass
class AgentKeyword:
    """Marker strings that delimit the parts of a ReAct-style agent turn.

    ``ReactCompatMixin`` splits model output on these markers when parsing
    tool calls and emits them when rendering tool calls and tool responses.
    """

    # Marker that precedes the name of the tool being called.
    action: str = 'Action:'
    # Marker that precedes the arguments passed to that tool.
    action_input: str = 'Action Input:'
    # Marker that precedes the content returned by the tool.
    observation: str = 'Observation:'
@dataclass
class ToolDesc:
    """Normalized description of one tool, as built by ``BaseAgentTemplate._parse_tool``."""

    # Tool name taken from the tool's 'name_for_model' or 'name' entry.
    name_for_model: str
    # Display name; ``_parse_tool`` falls back to ``name_for_model`` when absent.
    name_for_human: str
    # Tool description taken from 'description' or 'description_for_model'.
    description_for_model: str
    # Parameter schema as a string (``_parse_tool`` JSON-encodes non-string schemas).
    parameters: str
    # Sentence telling the model how to format the tool arguments.
    args_format: str
class ReactCompatMixin:
    """ReAct-style ("Action / Action Input / Observation") tool-call helpers.

    The mixin parses tool calls out of model output and renders tool calls and
    tool responses back into text, using the marker strings held in
    ``keyword``. When an instance's ``keyword`` differs from the default
    defined here, parsing falls back to the default markers if the instance's
    own markers yield no tool call.
    """

    # Default markers; also the fallback used by ``get_toolcall``.
    keyword = AgentKeyword()

    @staticmethod
    def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']:
        """Extract tool calls from ``response`` using the markers in ``keyword``.

        The response is split on the three marker strings. An action part that
        is later followed by an action-input part yields one ``Function`` whose
        name is the action content and whose arguments are the action-input
        content. The key reported for each part is compared with the markers
        case-insensitively.

        Args:
            response: Text to parse.
            keyword: Marker strings to split on.

        Returns:
            The parsed ``Function`` objects, in order of appearance; empty if
            no complete action / action-input pair is found.
        """
        # Imported at call time; presumably to avoid an import cycle -- TODO confirm.
        from swift.llm.template import split_str_parts_by
        from swift.llm.infer import Function

        action_key = keyword.action.lower()
        action_input_key = keyword.action_input.lower()
        agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))

        functions = []
        action_content = None  # content of an action still waiting for its input
        for part in agent_parts:
            key, content = part['key'].lower(), part['content']
            if action_content is None:
                if key == action_key:
                    action_content = content
            elif key == action_input_key:
                functions.append(Function(name=action_content, arguments=content))
                action_content = None
        return functions

    def get_toolcall(self, response: str) -> List['Function']:
        """Parse the tool calls contained in ``response``.

        Args:
            response: Model output to parse.

        Returns:
            The parsed ``Function`` objects. The instance's own markers are
            tried first; if they produce nothing and differ from the default
            markers, the default markers are tried as well.
        """
        functions = self._split_action_action_input(response, self.keyword)
        if not functions and self.keyword != ReactCompatMixin.keyword:
            # compat react
            functions = self._split_action_action_input(response, ReactCompatMixin.keyword)
        return functions

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Render tool responses that follow an assistant message.

        Args:
            assistant_content: The preceding assistant text.
            tool_messages: Non-empty sequence of mappings with a 'content' entry.

        Returns:
            A ``(assistant_content, parts)`` pair. If the assistant text
            contains both the action and action-input markers, it is extended
            to end with the observation marker and ``parts`` holds each tool
            content terminated by a newline, with the observation marker
            between consecutive responses. Otherwise the assistant text is
            returned unchanged and ``parts`` is the list of raw tool contents.

        Raises:
            AssertionError: If ``tool_messages`` is empty.
        """
        assert len(tool_messages) > 0
        keyword = self.keyword
        with_action = keyword.action in assistant_content and keyword.action_input in assistant_content
        if not with_action:
            return assistant_content, [tool_message['content'] for tool_message in tool_messages]

        # Make sure the assistant text ends with the observation marker on its own line.
        if not assistant_content.endswith(keyword.observation):
            if not assistant_content.endswith('\n'):
                assistant_content += '\n'
            assistant_content += keyword.observation

        res = []
        for i, tool_message in enumerate(tool_messages):
            if i > 0:
                # The first response reuses the marker that ends the assistant text.
                res.append(keyword.observation)
            tool_content = tool_message['content']
            res.append(tool_content)
            if not tool_content.endswith('\n'):
                res.append('\n')
        return assistant_content, res

    @staticmethod
    def _parse_tool_call(content) -> Dict[str, Any]:
        """Parse one serialized tool call into its name and arguments.

        Args:
            content: The tool call, either already parsed or as a JSON /
                Python-literal string with a 'name' entry and an 'arguments'
                (or 'parameters') entry.

        Returns:
            ``{'name': ..., 'arguments': ...}`` with the arguments parsed by
            ``BaseAgentTemplate._parse_json``.

        Raises:
            TypeError: If ``content`` is a string that cannot be parsed
                (``_parse_json`` then returns ``None``).
            KeyError: If the parsed object has no 'name' entry.
            AssertionError: If no usable arguments are found.
        """
        obj = BaseAgentTemplate._parse_json(content)
        name = obj['name']
        arguments = obj.get('arguments') or obj.get('parameters')
        if arguments is None and isinstance(obj.get('arguments'), dict):
            # ``or`` also discards an explicit empty dict. A call such as
            # {"name": "now", "arguments": {}} is a valid zero-argument call and
            # must not trip the assertion below (its string form "{}" never did).
            arguments = obj['arguments']
        arguments = BaseAgentTemplate._parse_json(arguments)
        assert arguments is not None, f'content: {content}'
        return {'name': name, 'arguments': arguments}

    def _format_tool_calls(self, tool_call_messages) -> str:
        """Render tool-call messages as assistant text.

        Args:
            tool_call_messages: Mappings whose 'content' entry is a serialized
                tool call accepted by ``_parse_tool_call``.

        Returns:
            One "action / action input" pair of lines per tool call, followed
            by the observation marker.
        """
        keyword = self.keyword
        tool_calls = []
        for message in tool_call_messages:
            tool_call = self._parse_tool_call(message['content'])
            # NOTE(review): the parsed arguments are interpolated with str(), so a
            # dict is rendered as a Python repr (single quotes) rather than JSON.
            # ``_parse_json`` can read that form back through ast.literal_eval;
            # confirm this output format is intended before changing it.
            tool_calls.append(f'{keyword.action} {tool_call["name"]}\n'
                              f'{keyword.action_input} {tool_call["arguments"]}\n')
        tool_calls.append(keyword.observation)
        return ''.join(tool_calls)
class BaseAgentTemplate(ReactCompatMixin, ABC):
    """Abstract base for agent templates.

    Provides helpers for normalizing tool definitions and for lenient parsing
    of serialized data; subclasses implement ``_format_tools``.
    """

    @staticmethod
    def _get_tool_name(tool):
        """Return the tool's 'name_for_model' entry, falling back to 'name'."""
        return tool.get('name_for_model') or tool.get('name')

    @staticmethod
    def unwrap_tool(tool):
        """Return the inner definition of a ``{'type': ..., 'function': ...}`` wrapper.

        A dict that does not contain both keys is returned unchanged.

        Raises:
            AssertionError: If ``tool`` is not a dict.
        """
        assert isinstance(tool, dict), f'tool: {tool}'
        if 'type' in tool and 'function' in tool:
            tool = tool['function']
        return tool

    @staticmethod
    def wrap_tool(tool):
        """Wrap a bare tool definition as ``{'type': 'function', 'function': tool}``.

        A dict that already has a 'type' or a 'function' key is returned
        unchanged.

        Raises:
            AssertionError: If ``tool`` is not a dict.
        """
        assert isinstance(tool, dict), f'tool: {tool}'
        if 'type' not in tool and 'function' not in tool:
            tool = {'type': 'function', 'function': tool}
        return tool

    @staticmethod
    def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
        """Build a ``ToolDesc`` from a tool definition.

        Args:
            tool: Tool definition dict, optionally wrapped (see ``unwrap_tool``).
            lang: Selects the language of the argument-format sentence: Chinese
                for 'zh', English otherwise.

        Returns:
            The normalized description. A non-string 'parameters' entry is
            JSON-encoded; a missing or empty one becomes ``'{}'``.

        Raises:
            AssertionError: If ``tool`` is not a dict, or if it provides no
                name or no description.
        """
        tool = BaseAgentTemplate.unwrap_tool(tool)
        name_for_model = BaseAgentTemplate._get_tool_name(tool)
        name_for_human = tool.get('name_for_human') or name_for_model
        description = tool.get('description') or tool.get('description_for_model')
        parameters = tool.get('parameters') or {}
        if not isinstance(parameters, str):
            parameters = json.dumps(parameters, ensure_ascii=False)
        args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.'
        tool_desc = ToolDesc(
            name_for_model=name_for_model,
            name_for_human=name_for_human,
            description_for_model=description,
            parameters=parameters,
            args_format=args_format)
        # Validated after construction so the failure message can show every field.
        assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}'
        return tool_desc

    @staticmethod
    def _parse_json(json_str: str) -> Optional[Any]:
        """Leniently parse ``json_str`` as JSON, then as a Python literal.

        Args:
            json_str: Text to parse. Non-string values are returned unchanged.

        Returns:
            The parsed value, or ``None`` when the text is neither valid JSON
            nor a valid Python literal.
        """
        if not isinstance(json_str, str):
            return json_str
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass
        try:
            # literal_eval only evaluates literals, so e.g. single-quoted dicts
            # are accepted without executing arbitrary code.
            return ast.literal_eval(json_str)
        except Exception:
            # Deliberate best effort: callers treat None as "could not parse".
            return None

    @abstractmethod
    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Render ``tools`` as text; the concrete format is defined by subclasses."""
        pass
|