File size: 3,125 Bytes
7feac49 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
import json
from .base import BaseAgentTemplate
if TYPE_CHECKING:
from swift.llm.infer import Function
from swift.llm.template import Prompt
class GLM4AgentTemplate(BaseAgentTemplate):
is_glm4_0414 = False
@staticmethod
def _find_function_call(single_content: str) -> Optional['Function']:
from swift.llm.infer import Function
single_content = single_content.replace('<|observation|>', '')
pattern = re.compile(r'([^\n`]*?)\n({.*?})(?=\w*\n|$)', re.DOTALL)
matches = pattern.findall(single_content)
if not matches:
return
name, arguments = matches[0]
return Function(name=name, arguments=arguments)
def get_toolcall(self, response: str) -> List['Function']:
toolcall_list = response.split('<|assistant|>')
functions = []
for toolcall in toolcall_list:
function = self._find_function_call(toolcall)
if function:
functions.append(function)
if len(functions) == 0:
# compat react_en
return super().get_toolcall(response)
return functions
def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
tool_descs = []
for tool in tools:
tool = self.unwrap_tool(tool)
name = self._get_tool_name(tool)
tool_descs.append(f'## {name}\n\n{json.dumps(tool, ensure_ascii=False, indent=4)}\n'
'在调用上述函数时,请使用 Json 格式表示调用的参数。')
glm4_system = '你是一个名为 GLM-4 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,你的任务是针对用户的问题和要求提供适当的答复和支持。\n\n' # noqa
return ('' if self.is_glm4_0414 else glm4_system) + """# 可用工具
""" + '\n'.join(tool_descs)
def _format_tool_responses(
self,
assistant_content: str,
tool_messages,
) -> Tuple[str, 'Prompt']:
with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
if with_action:
return super()._format_tool_responses(assistant_content, tool_messages)
res = ['\n']
for i, tool_message in enumerate(tool_messages):
tool_content = tool_message['content']
if i > 0:
res.append('<|observation|>\n')
res.append(tool_content)
res.append('<|assistant|>\n')
return assistant_content, res
def _format_tool_calls(self, tool_call_messages) -> str:
tool_calls = []
for message in tool_call_messages:
tool_call = self._parse_tool_call(message['content'])
tool_calls.append(f'{tool_call["name"]}\n{tool_call["arguments"]}')
return '<|assistant|>'.join(tool_calls) + '<|observation|>'
class GLM4_0414AgentTemplate(GLM4AgentTemplate):
is_glm4_0414 = True
|