|
|
|
|
|
import re |
|
|
from typing import TYPE_CHECKING, List, Tuple, Union |
|
|
|
|
|
import json |
|
|
|
|
|
from .base import BaseAgentTemplate |
|
|
|
|
|
if TYPE_CHECKING: |
|
|
from swift.llm.infer import Function |
|
|
from swift.llm.template import Prompt |
|
|
|
|
|
|
|
|
class Llama3AgentTemplate(BaseAgentTemplate):
    """Agent template using JSON tool calls and header-delimited tool turns.

    Tool calls are read from (and written to) assistant text as JSON objects
    of the form ``{"name": ..., "parameters": {...}}``. Tool results are
    rendered as separate ``tool`` turns wrapped in the header/end-of-turn
    tokens defined below. Subclasses may override the four token attributes
    to target a model family with different special tokens.
    """

    # Marker that may terminate a response containing a tool call.
    eom_token = '<|eom_id|>'
    # Tokens that open and close a role header (e.g. ``tool``, ``assistant``).
    start_token = '<|start_header_id|>'
    end_token = '<|end_header_id|>'
    # Marker appended after each completed turn.
    eot_token = '<|eot_id|>'

    def get_toolcall(self, response: str) -> List['Function']:
        """Extract tool calls from a model response.

        Args:
            response: Raw assistant text, optionally ending with ``eom_token``.

        Returns:
            One ``Function`` per JSON object that carries both a ``name`` and
            a ``parameters`` key. If no such object is found, the result of
            the base-class ``get_toolcall`` is returned instead.
        """
        # Imported lazily; the module-level import is only for type checking.
        from swift.llm.infer import Function

        if response.endswith(self.eom_token):
            response = response[:-len(self.eom_token)]

        functions = []
        decoder = json.JSONDecoder()
        # End offset of the last accepted call; matches starting before it lie
        # inside an object that has already been decoded and must be skipped.
        consumed = 0
        pattern = r'{[^{]*?"name":.*?"parameters":\s*?{.*?}\s*?}'
        for match in re.finditer(pattern, response, re.DOTALL):
            if match.start() < consumed:
                continue
            parsed = self._parse_json(match.group())
            end = match.end()
            if not isinstance(parsed, dict):
                # The non-greedy pattern stops at the first "}}", so calls
                # whose parameters contain nested objects are captured one or
                # more braces short. Re-decode from the match start with a
                # real JSON decoder, which tracks nesting correctly.
                # NOTE(review): assumes _parse_json returns a non-dict (rather
                # than raising) on malformed input — confirm against the base.
                try:
                    parsed, length = decoder.raw_decode(response[match.start():])
                except ValueError:
                    continue
                end = match.start() + length
            if isinstance(parsed, dict) and 'name' in parsed and 'parameters' in parsed:
                functions.append(Function(name=parsed['name'], arguments=parsed['parameters']))
                consumed = end

        if not functions:
            # No JSON-style call found: defer to the base-class parser.
            return super().get_toolcall(response)
        return functions

    def _format_tool_responses(
        self,
        assistant_content: str,
        tool_messages,
    ) -> Tuple[str, 'Prompt']:
        """Render tool results as ``tool`` turns followed by an assistant header.

        Args:
            assistant_content: Text of the assistant turn that issued the calls.
            tool_messages: Iterable of mappings; each one's ``'content'`` is
                inserted verbatim as a tool turn body.

        Returns:
            A tuple of ``assistant_content`` (unchanged) and the list of prompt
            pieces to append after it. When the assistant text contains both
            the action and action-input keywords, the base-class formatting is
            returned instead.
        """
        with_action = self.keyword.action in assistant_content and self.keyword.action_input in assistant_content
        if with_action:
            return super()._format_tool_responses(assistant_content, tool_messages)

        # Close the assistant turn, emit one turn per tool result, then open
        # a fresh assistant header for the model to continue from.
        res = [self.eot_token]
        for tool_message in tool_messages:
            tool_content = tool_message['content']
            res.append(f'{self.start_token}tool{self.end_token}\n\n{tool_content}{self.eot_token}')
        res.append(f'{self.start_token}assistant{self.end_token}\n\n')
        return assistant_content, res

    def _format_tools(self, tools: List[Union[str, dict]], system: str, user_message=None) -> str:
        """Prepend the tool descriptions and calling instructions to the user turn.

        Args:
            tools: Tool descriptions; each is serialised with ``json.dumps``.
            system: System prompt, returned unchanged.
            user_message: Mapping with a ``'content'`` entry. It is modified in
                place: its content is replaced by instructions, the tool
                descriptions and the original content, in that order.

        Returns:
            ``system``, unmodified.

        Raises:
            AssertionError: If ``user_message`` is ``None``.
        """
        assert user_message is not None
        user_content = user_message['content']
        tool_descs = [json.dumps(tool, ensure_ascii=False, indent=4) for tool in tools]
        instructions = (
            'Given the following functions, please respond with a JSON for a function call '
            'with its proper arguments that best answers the given prompt.\n\n'
            'Respond in the format {"name": function name, "parameters": dictionary of argument name '
            'and its value}. Do not use variables.\n\n')
        user_message['content'] = instructions + '\n\n'.join(tool_descs) + f'\n\n{user_content}'
        return system

    def _format_tool_calls(self, tool_call_messages) -> str:
        """Serialise tool-call messages as newline-separated JSON objects.

        Args:
            tool_call_messages: Iterable of mappings whose ``'content'`` is
                passed to ``_parse_tool_call``.

        Returns:
            One JSON object per line, with the ``arguments`` key renamed to
            ``parameters``.
        """
        tool_calls = []
        for message in tool_call_messages:
            tool_call = self._parse_tool_call(message['content'])
            # This format names the argument mapping "parameters".
            tool_call['parameters'] = tool_call.pop('arguments')
            tool_calls.append(json.dumps(tool_call, ensure_ascii=False))
        return '\n'.join(tool_calls)
|
|
|
|
|
|
|
|
class Llama4AgentTemplate(Llama3AgentTemplate):
    """Variant of ``Llama3AgentTemplate`` with a different set of special tokens.

    All behaviour is inherited; only the token attributes are overridden, and
    a ``toolcall_pattern`` attribute is defined.
    """

    # Marker that may terminate a response containing a tool call.
    eom_token = '<|eom|>'
    # Tokens that open and close a role header.
    start_token = '<|header_start|>'
    end_token = '<|header_end|>'
    # Marker appended after each completed turn.
    eot_token = '<|eot|>'
    # Regex capturing the text that precedes an end-of-message marker.
    # NOTE(review): not referenced in this file — presumably read by the
    # base class; confirm against BaseAgentTemplate.
    toolcall_pattern = r'(.+?)<\|eom\|>'
|
|
|