import ast
from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Tuple, Union

import json

if TYPE_CHECKING:
    from swift.llm.infer import Function
    from swift.llm.template import Prompt


@dataclass
class AgentKeyword:
    """Marker strings that delimit the sections of a ReAct-style agent turn.

    The three values are used together as split delimiters when a model
    response is parsed for tool calls, and as prefixes when tool calls and
    tool results are rendered back into text.
    """
    # Prefix that introduces the name of the tool being called.
    action: str = 'Action:'
    # Prefix that introduces the arguments passed to that tool.
    action_input: str = 'Action Input:'
    # Prefix that introduces the result returned by the tool.
    observation: str = 'Observation:'


@dataclass
class ToolDesc:
    """Normalized, string-only description of a single tool.

    All fields are required; there are no defaults.
    """
    # Name under which the tool is presented to the model.
    name_for_model: str
    # Human-readable name of the tool.
    name_for_human: str
    # Description text of the tool that is shown to the model.
    description_for_model: str
    # Parameter specification, already serialized to a string.
    parameters: str
    # Sentence telling the model how the arguments must be formatted.
    args_format: str


class ReactCompatMixin:
    """Mixin that parses and renders tool calls in the ReAct text format.

    A tool call is written as an ``action`` section (the tool name) followed
    by an ``action_input`` section (its arguments); tool results follow an
    ``observation`` marker. The marker strings come from ``keyword``.
    """
    # Default markers. The class-level value on ReactCompatMixin itself is also
    # used as the fallback in ``get_toolcall`` when ``self.keyword`` differs.
    keyword = AgentKeyword()

    @staticmethod
    def _split_action_action_input(response: str, keyword: AgentKeyword) -> List['Function']:
        """Extract tool calls from ``response`` using the markers in ``keyword``.

        The response is split on all keyword strings. Each ``action`` part that
        is later followed by an ``action_input`` part produces one ``Function``.
        Marker comparison is case-insensitive.

        Args:
            response: Raw text produced by the model.
            keyword: Marker strings to split on.

        Returns:
            The ``Function`` objects found, in order of appearance (possibly empty).
        """
        # Imported inside the function, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from swift.llm.template import split_str_parts_by
        from swift.llm.infer import Function
        agent_parts = split_str_parts_by(response, list(asdict(keyword).values()))
        functions = []
        # Name of the action currently waiting for its input; None when no action is open.
        action_content = None

        for part in agent_parts:
            key, content = part['key'].lower(), part['content']
            if action_content is None and key == keyword.action.lower():
                action_content = content
            elif action_content is not None and key == keyword.action_input.lower():
                functions.append(Function(name=action_content, arguments=content))
                action_content = None

        return functions

    def get_toolcall(self, response: str) -> List['Function']:
        """Parse the tool calls contained in ``response``.

        ``self.keyword`` is tried first. If that yields nothing and the
        instance uses non-default markers, the response is parsed again with
        the default markers of ``ReactCompatMixin``.
        """
        functions = self._split_action_action_input(response, self.keyword)
        if not functions and self.keyword != ReactCompatMixin.keyword:
            # Fall back to the default ReAct markers.
            functions = self._split_action_action_input(response, ReactCompatMixin.keyword)
        return functions

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Render tool results so they can follow the assistant's tool call.

        Args:
            assistant_content: Assistant text that precedes the tool results.
            tool_messages: Non-empty sequence of messages; each message's
                ``'content'`` entry is read as the tool output.

        Returns:
            A tuple ``(assistant_content, res)``. When ``assistant_content``
            contains both the action and the action-input marker, it is
            extended (if necessary) to end with the observation marker, and
            ``res`` holds each tool output terminated by a newline, with the
            observation marker inserted before every output after the first.
            Otherwise ``assistant_content`` is returned unchanged and ``res``
            is the plain list of tool outputs.

        Raises:
            AssertionError: If ``tool_messages`` is empty.
        """
        assert len(tool_messages) > 0
        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
        if with_action:
            if not assistant_content.endswith(self.keyword.observation):
                if not assistant_content.endswith('\n'):
                    assistant_content += '\n'
                assistant_content += self.keyword.observation
            res = []
            for i, tool_message in enumerate(tool_messages):
                if i > 0:
                    # The first observation marker already ends assistant_content.
                    res.append(self.keyword.observation)
                tool_content = tool_message['content']
                res.append(tool_content)
                if not tool_content.endswith('\n'):
                    res.append('\n')
        else:
            res = [tool_message['content'] for tool_message in tool_messages]
        return assistant_content, res

    @staticmethod
    def _parse_tool_call(content) -> Dict[str, Any]:
        """Parse one tool-call payload into ``{'name': ..., 'arguments': ...}``.

        Args:
            content: A JSON / Python-literal string, or an already parsed
                object, that has a ``'name'`` entry and an ``'arguments'`` or
                ``'parameters'`` entry.

        Returns:
            A dict with the tool name and its parsed arguments.

        Raises:
            TypeError: If ``content`` cannot be parsed (the parser returns
                ``None``, which is then subscripted).
            KeyError: If the parsed object has no ``'name'`` entry.
            AssertionError: If no usable arguments are found.
        """
        obj = BaseAgentTemplate._parse_json(content)
        name = obj['name']
        arguments = obj.get('arguments')
        if not arguments:
            # Only fall back to 'parameters' when that key is really present.
            # A plain `a or b` would turn a valid empty argument object
            # (e.g. {"arguments": {}}) into None and fail the assertion below.
            parameters = obj.get('parameters')
            if parameters is not None:
                arguments = parameters
        arguments = BaseAgentTemplate._parse_json(arguments)
        assert arguments is not None, f'content: {content}'
        return {'name': name, 'arguments': arguments}

    def _format_tool_calls(self, tool_call_messages) -> str:
        """Render tool-call messages as assistant text in the ReAct format.

        Each message's ``'content'`` is parsed and written as an action line
        and an action-input line; a single observation marker is appended
        after all calls.
        """
        tool_calls = []
        for message in tool_call_messages:
            tool_call = self._parse_tool_call(message['content'])
            tool_calls.append(f'{self.keyword.action} {tool_call["name"]}\n'
                              f'{self.keyword.action_input} {tool_call["arguments"]}\n')
        tool_calls.append(self.keyword.observation)
        return ''.join(tool_calls)


class BaseAgentTemplate(ReactCompatMixin, ABC):
    """Abstract base class for agent templates.

    Provides static helpers for normalizing tool definitions and parsing
    JSON-like strings; subclasses must implement ``_format_tools``.
    """

    @staticmethod
    def _get_tool_name(tool):
        """Return ``tool['name_for_model']`` if set and truthy, else ``tool.get('name')``."""
        return tool.get('name_for_model') or tool.get('name')

    @staticmethod
    def unwrap_tool(tool):
        """Return the inner function spec of a wrapped tool definition.

        A dict carrying both a ``'type'`` and a ``'function'`` key is replaced
        by its ``'function'`` value; any other dict is returned unchanged.

        Raises:
            AssertionError: If ``tool`` is not a dict.
        """
        assert isinstance(tool, dict), f'tool: {tool}'
        if 'type' in tool and 'function' in tool:
            tool = tool['function']
        return tool

    @staticmethod
    def wrap_tool(tool):
        """Wrap a bare function spec as ``{'type': 'function', 'function': tool}``.

        The dict is wrapped only when it has neither a ``'type'`` nor a
        ``'function'`` key; otherwise it is returned unchanged.

        Raises:
            AssertionError: If ``tool`` is not a dict.
        """
        assert isinstance(tool, dict), f'tool: {tool}'
        if 'type' not in tool and 'function' not in tool:
            tool = {'type': 'function', 'function': tool}
        return tool

    @staticmethod
    def _parse_tool(tool, lang: Literal['zh', 'en']) -> ToolDesc:
        """Convert a tool definition dict into a ``ToolDesc``.

        Args:
            tool: Tool definition, wrapped or bare (it is unwrapped first).
            lang: ``'zh'`` selects the Chinese argument-format sentence; any
                other value selects the English one.

        Returns:
            The populated ``ToolDesc``. ``parameters`` is kept as-is when it is
            already a string, otherwise it is dumped to JSON (an absent or
            falsy value becomes ``'{}'``).

        Raises:
            AssertionError: If the tool is not a dict, or has no name or no
                description.
        """
        tool = BaseAgentTemplate.unwrap_tool(tool)
        name_for_model = BaseAgentTemplate._get_tool_name(tool)
        # The human-facing name defaults to the model-facing one.
        name_for_human = tool.get('name_for_human') or name_for_model

        description = tool.get('description') or tool.get('description_for_model')
        parameters = tool.get('parameters') or {}
        parameters = parameters if isinstance(parameters, str) else json.dumps(parameters, ensure_ascii=False)
        args_format = '此工具的输入应为JSON对象。' if lang == 'zh' else 'Format the arguments as a JSON object.'
        tool_desc = ToolDesc(
            name_for_model=name_for_model,
            name_for_human=name_for_human,
            description_for_model=description,
            parameters=parameters,
            args_format=args_format)
        # Validated after construction so the failure message can show the whole ToolDesc.
        assert name_for_model is not None and description is not None, f'tool_desc: {tool_desc}'
        return tool_desc

    @staticmethod
    def _parse_json(json_str: str) -> Optional[Any]:
        """Parse ``json_str`` as JSON, falling back to a Python literal.

        Args:
            json_str: Text to parse. Non-string values are returned unchanged.

        Returns:
            The parsed object, or ``None`` when the text is neither valid JSON
            nor a valid Python literal.
        """
        if not isinstance(json_str, str):
            return json_str
        try:
            res = json.loads(json_str)
        except json.JSONDecodeError:
            # Fall back to Python literal syntax (e.g. single-quoted dicts).
            try:
                res = ast.literal_eval(json_str)
            except Exception:
                # Deliberate best effort: unparsable input is reported as None.
                return None
        return res

    @abstractmethod
    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Render ``tools`` into prompt text; must be implemented by subclasses.

        Args:
            tools: Tool definitions, each a string or a dict.
            system: System prompt text.
            user_message: Optional user message; how it is used is up to the subclass.
        """