# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from typing import TYPE_CHECKING, List, Tuple, Union

import json

from .base import BaseAgentTemplate

if TYPE_CHECKING:
    from swift.llm.infer import Function
    from swift.llm.template import Prompt

# Captures the payload of each <tool_call>...</tool_call> span. DOTALL lets a
# payload run across several lines; the lazy quantifier keeps adjacent calls apart.
_TOOL_CALL_PATTERN = re.compile(r'<tool_call>(.+?)</tool_call>', re.DOTALL)


class HermesAgentTemplate(BaseAgentTemplate):
    """Agent template using Hermes-style XML tags for tool use.

    Tool signatures are advertised inside ``<tools>``, the model's calls are
    wrapped in ``<tool_call>`` and tool outputs are fed back inside
    ``<tool_response>``. Parsing and formatting in this class use the same tag
    names so that what ``_format_tool_calls`` writes can be read back by
    ``get_toolcall``.
    """

    def get_toolcall(self, response: str) -> List['Function']:
        """Extract the function calls contained in a model response.

        Args:
            response: Raw text produced by the model.

        Returns:
            One ``Function`` per ``<tool_call>`` payload that parses (via
            ``self._parse_json``) to a dict holding both ``'name'`` and
            ``'arguments'``. When no payload qualifies, the result of the
            parent class's ``get_toolcall`` is returned instead.
        """
        # Imported lazily, as in the original code (presumably to avoid an
        # import cycle with swift.llm.infer -- TODO confirm).
        from swift.llm.infer import Function

        functions = []
        for raw_call in _TOOL_CALL_PATTERN.findall(response):
            parsed = self._parse_json(raw_call)
            # Payloads that are not a complete call object are skipped silently.
            if isinstance(parsed, dict) and 'name' in parsed and 'arguments' in parsed:
                functions.append(Function(name=parsed['name'], arguments=parsed['arguments']))
        if not functions:
            # compat react_en
            return super().get_toolcall(response)
        return functions

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Render tool outputs as the next user turn.

        Args:
            assistant_content: Text of the assistant turn that issued the calls.
            tool_messages: Iterable of mappings; only the ``'content'`` key of
                each entry is read.

        Returns:
            ``(assistant_content, prompt)`` where ``prompt`` is a copy of the
            chat separator followed by the prompt pieces, with ``{{QUERY}}``
            replaced by the newline-joined ``<tool_response>`` blocks. If the
            assistant text contains both ``self.keyword.action`` and
            ``self.keyword.action_input``, the parent implementation's result
            is returned instead.
        """
        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
        if with_action:
            return super()._format_tool_responses(assistant_content, tool_messages)

        if hasattr(self, 'template_meta'):
            prompt = self.template_meta.prompt
            chat_sep = self.template_meta.chat_sep
        else:
            # Hard-coded fallback used when no template_meta attribute is set.
            prompt = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
            chat_sep = ['<|im_end|>\n']

        # Copy so that the shared separator list is never mutated.
        res = chat_sep.copy()
        total_tool = '\n'.join(
            f"<tool_response>\n{tool_message['content']}\n</tool_response>" for tool_message in tool_messages)
        for context in prompt:
            # Non-string prompt pieces are appended untouched.
            if isinstance(context, str):
                context = context.replace('{{QUERY}}', total_tool)
            res.append(context)
        return assistant_content, res

    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Build the system prompt that advertises the available tools.

        Args:
            tools: Tool descriptions; each is passed through ``self.wrap_tool``
                and serialised with ``json.dumps(..., ensure_ascii=False)``.
            system: System text placed in front of the tool section.
            user_message: Not used by this implementation.

        Returns:
            ``system`` followed by the tool section: one JSON description per
            line inside ``<tools>``, then the instructions for emitting calls.
        """
        tool_descs = [json.dumps(self.wrap_tool(tool), ensure_ascii=False) for tool in tools]
        return f"""{system}

# Tools

You may call one or more functions to assist with the user query.

You are provided with function signatures within <tools></tools> XML tags:
<tools>
""" + '\n'.join(tool_descs) + """
</tools>

For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
<tool_call>
{"name": <function-name>, "arguments": <args-json-object>}
</tool_call>"""

    def _format_tool_calls(self, tool_call_messages):
        """Serialise tool-call messages into ``<tool_call>`` blocks.

        Args:
            tool_call_messages: Iterable of mappings; the ``'content'`` of each
                is passed to ``self._parse_tool_call`` and the result is dumped
                as JSON.

        Returns:
            The ``<tool_call>`` blocks joined by newlines.
        """
        tool_calls = []
        for message in tool_call_messages:
            tool_call = self._parse_tool_call(message['content'])
            tool_calls.append(f'<tool_call>\n{json.dumps(tool_call, ensure_ascii=False)}\n</tool_call>')
        return '\n'.join(tool_calls)