Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
raw
history blame
5.88 kB
# Copyright (c) Alibaba, Inc. and its affiliates.
import ast
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union
import json
if TYPE_CHECKING:
from swift.llm.infer import Function
from swift.llm.template import Prompt
@dataclass
class AgentKeyword:
action: str = 'Action:'
action_input: str = 'Action Input:'
observation: str = 'Observation:'
@dataclass
class ToolDesc:
name_for_model: str
name_for_human: str
description_for_model: str
parameters: str
args_format: str
class ReactCompatMixin:
keyword = AgentKeyword()
@staticmethod
def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']:
from swift.llm.template import split_str_parts_by
from swift.llm.infer import Function
agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))
functions = []
action_content = None
for part in agent_parts:
key, content = part['key'].lower(), part['content']
if action_content is None and key == keyword.action.lower():
action_content = content
elif action_content is not None and key == keyword.action_input.lower():
functions.append(Function(name=action_content, arguments=content))
action_content = None
return functions
def get_toolcall(self, response: str) -> List['Function']:
functions = self._split_action_action_input(response, self.keyword)
if len(functions) == 0 and self.keyword != ReactCompatMixin.keyword:
# compat react
functions = self._split_action_action_input(response, ReactCompatMixin.keyword)
return functions
def _format_tool_responses(
self,
assistant_content: str,
tool_messages,
) -> Tuple[str, 'Prompt']:
assert len(tool_messages) > 0
with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
if with_action:
if not assistant_content.endswith(self.keyword.observation):
if not assistant_content.endswith('\n'):
assistant_content += '\n'
assistant_content += self.keyword.observation
res = []
for i, tool_message in enumerate(tool_messages):
if i > 0:
res.append(self.keyword.observation)
tool_content = tool_message['content']
res.append(tool_content)
if not tool_content.endswith('\n'):
res.append('\n')
else:
res = []
for tool_message in tool_messages:
res.append(tool_message['content'])
return assistant_content, res
@staticmethod
def _parse_tool_call(content) -> Dict[str, Any]:
obj = BaseAgentTemplate._parse_json(content)
name = obj['name']
arguments = obj.get('arguments') or obj.get('parameters')
arguments = BaseAgentTemplate._parse_json(arguments)
assert arguments is not None, f'content: {content}'
return {'name': name, 'arguments': arguments}
def _format_tool_calls(self, tool_call_messages) -> str:
# -> assistant_content
tool_calls = []
for message in tool_call_messages:
tool_call = self._parse_tool_call(message['content'])
tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n'
f'{self.keyword.action_input} {tool_call["arguments"]}\n')
tool_calls.append(self.keyword.observation)
return ''.join(tool_calls)
class BaseAgentTemplate(ReactCompatMixin, ABC):
@staticmethod
def _get_tool_name(tool):
return tool.get('name_for_model') or tool.get('name')
@staticmethod
def unwrap_tool(tool):
assert isinstance(tool, dict), f'tool: {tool}'
if 'type' in tool and 'function' in tool:
tool = tool['function']
return tool
@staticmethod
def wrap_tool(tool):
assert isinstance(tool, dict), f'tool: {tool}'
if 'type' not in tool and 'function' not in tool:
tool = {'type': 'function', 'function': tool}
return tool
@staticmethod
def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
tool = BaseAgentTemplate.unwrap_tool(tool)
name_for_model = BaseAgentTemplate._get_tool_name(tool)
name_for_human = tool.get('name_for_human') or name_for_model
description = tool.get('description') or tool.get('description_for_model')
parameters = tool.get('parameters') or {}
parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False)
args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.'
tool_desc = ToolDesc(
name_for_model=name_for_model,
name_for_human=name_for_human,
description_for_model=description,
parameters=parameters,
args_format=args_format)
assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}'
return tool_desc
@staticmethod
def _parse_json(json_str: str) -> Optional[Any]:
if not isinstance(json_str, str):
return json_str
try:
res = json.loads(json_str)
except json.JSONDecodeError:
try:
res = ast.literal_eval(json_str)
except Exception:
return
return res
@abstractmethod
def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
pass