|
|
|
|
|
|
|
|
""" |
|
|
工具调用处理模块 |
|
|
""" |
|
|
|
|
|
import json |
|
|
import re |
|
|
from typing import Dict, List, Any, Optional, Tuple |
|
|
from app.utils.logger import get_logger |
|
|
|
|
|
logger = get_logger() |
|
|
|
|
|
|
|
|
def generate_tool_prompt(tools: Optional[List[Dict[str, Any]]]) -> str: |
|
|
""" |
|
|
生成工具调用提示词 |
|
|
将 OpenAI tools 定义转换为 Markdown 格式的说明文档 |
|
|
|
|
|
Args: |
|
|
tools: OpenAI 格式的工具定义列表 |
|
|
|
|
|
Returns: |
|
|
str: Markdown 格式的工具使用说明 |
|
|
""" |
|
|
if not tools or len(tools) == 0: |
|
|
return "" |
|
|
|
|
|
tool_definitions = [] |
|
|
|
|
|
for tool in tools: |
|
|
if tool.get("type") != "function": |
|
|
continue |
|
|
|
|
|
function_spec = tool.get("function", {}) |
|
|
function_name = function_spec.get("name", "unknown") |
|
|
function_description = function_spec.get("description", "") |
|
|
parameters = function_spec.get("parameters", {}) |
|
|
|
|
|
|
|
|
tool_info = [ |
|
|
f"## {function_name}", |
|
|
f"**Purpose**: {function_description}" |
|
|
] |
|
|
|
|
|
|
|
|
parameter_properties = parameters.get("properties", {}) |
|
|
required_parameters = set(parameters.get("required", [])) |
|
|
|
|
|
if parameter_properties: |
|
|
tool_info.append("**Parameters**:") |
|
|
for param_name, param_info in parameter_properties.items(): |
|
|
param_type = param_info.get("type", "string") |
|
|
param_desc = param_info.get("description", "") |
|
|
is_required = param_name in required_parameters |
|
|
required_str = " (required)" if is_required else " (optional)" |
|
|
tool_info.append(f"- `{param_name}` ({param_type}){required_str}: {param_desc}") |
|
|
|
|
|
tool_definitions.append("\n".join(tool_info)) |
|
|
|
|
|
|
|
|
prompt = ( |
|
|
"\n\n---\n" |
|
|
"# Available Tools\n\n" |
|
|
+ "\n\n".join(tool_definitions) + |
|
|
"\n\n" |
|
|
"**Tool Invocation Format**:\n" |
|
|
"To use a tool, include a JSON block with this structure:\n" |
|
|
'{"tool_calls": [{"id": "call_ID", "type": "function", "function": {"name": "TOOL_NAME", "arguments": "JSON_STRING"}}]}\n\n' |
|
|
"**Rules**:\n" |
|
|
"- Use tool ONLY when user explicitly requests an action that matches a tool's purpose\n" |
|
|
"- For normal conversation, respond naturally WITHOUT any tool calls\n" |
|
|
"- The `arguments` must be a JSON string, not an object\n" |
|
|
"- Multiple tools can be called by adding more items to the array\n" |
|
|
"---\n\n" |
|
|
) |
|
|
|
|
|
logger.debug(f"生成工具提示词,包含 {len(tool_definitions)} 个工具定义") |
|
|
return prompt |
|
|
|
|
|
|
|
|
def process_messages_with_tools( |
|
|
messages: List[Dict[str, Any]], |
|
|
tools: Optional[List[Dict[str, Any]]], |
|
|
tool_choice: str = "auto" |
|
|
) -> List[Dict[str, Any]]: |
|
|
""" |
|
|
将工具定义注入到消息列表中 |
|
|
|
|
|
Args: |
|
|
messages: 原始消息列表 |
|
|
tools: 工具定义列表 |
|
|
tool_choice: 工具选择策略 ("auto", "none", 等) |
|
|
|
|
|
Returns: |
|
|
List[Dict]: 处理后的消息列表 |
|
|
""" |
|
|
if not tools or tool_choice == "none": |
|
|
return messages |
|
|
|
|
|
tools_prompt = generate_tool_prompt(tools) |
|
|
if not tools_prompt: |
|
|
return messages |
|
|
|
|
|
processed = [] |
|
|
has_system = any(m.get("role") == "system" for m in messages) |
|
|
|
|
|
if has_system: |
|
|
|
|
|
for msg in messages: |
|
|
if msg.get("role") == "system": |
|
|
new_msg = msg.copy() |
|
|
content = new_msg.get("content", "") |
|
|
if isinstance(content, list): |
|
|
|
|
|
content_str = " ".join([ |
|
|
item.get("text", "") if item.get("type") == "text" else "" |
|
|
for item in content |
|
|
]) |
|
|
else: |
|
|
content_str = str(content) |
|
|
new_msg["content"] = content_str + tools_prompt |
|
|
processed.append(new_msg) |
|
|
else: |
|
|
processed.append(msg) |
|
|
else: |
|
|
|
|
|
processed.append({ |
|
|
"role": "system", |
|
|
"content": f"You are a helpful assistant with access to tools.{tools_prompt}" |
|
|
}) |
|
|
processed.extend(messages) |
|
|
|
|
|
logger.debug(f"工具提示已注入到消息列表,共 {len(processed)} 条消息") |
|
|
return processed |
|
|
|
|
|
|
|
|
def parse_and_extract_tool_calls(content: str) -> Tuple[Optional[List[Dict[str, Any]]], str]: |
|
|
""" |
|
|
从响应内容中提取 tool_calls JSON |
|
|
|
|
|
Args: |
|
|
content: 模型返回的文本内容 |
|
|
|
|
|
Returns: |
|
|
Tuple[Optional[List], str]: (提取的 tool_calls 列表, 清理后的内容) |
|
|
""" |
|
|
if not content or not content.strip(): |
|
|
return None, content |
|
|
|
|
|
tool_calls = None |
|
|
cleaned_content = content |
|
|
|
|
|
|
|
|
|
|
|
json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```' |
|
|
json_blocks = re.findall(json_block_pattern, content) |
|
|
|
|
|
for json_str in json_blocks: |
|
|
try: |
|
|
parsed_data = json.loads(json_str) |
|
|
if "tool_calls" in parsed_data: |
|
|
tool_calls = parsed_data["tool_calls"] |
|
|
if tool_calls and isinstance(tool_calls, list): |
|
|
|
|
|
for tc in tool_calls: |
|
|
if tc.get("function"): |
|
|
func = tc["function"] |
|
|
if func.get("arguments"): |
|
|
if isinstance(func["arguments"], dict): |
|
|
|
|
|
func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False) |
|
|
elif not isinstance(func["arguments"], str): |
|
|
func["arguments"] = str(func["arguments"]) |
|
|
logger.debug(f"从 JSON 代码块中提取到 {len(tool_calls)} 个工具调用") |
|
|
break |
|
|
except json.JSONDecodeError: |
|
|
continue |
|
|
|
|
|
|
|
|
if not tool_calls: |
|
|
|
|
|
i = 0 |
|
|
scannable_text = content |
|
|
while i < len(scannable_text): |
|
|
if scannable_text[i] == '{': |
|
|
|
|
|
brace_count = 1 |
|
|
j = i + 1 |
|
|
in_string = False |
|
|
escape_next = False |
|
|
|
|
|
while j < len(scannable_text) and brace_count > 0: |
|
|
if escape_next: |
|
|
escape_next = False |
|
|
elif scannable_text[j] == '\\': |
|
|
escape_next = True |
|
|
elif scannable_text[j] == '"': |
|
|
in_string = not in_string |
|
|
elif not in_string: |
|
|
if scannable_text[j] == '{': |
|
|
brace_count += 1 |
|
|
elif scannable_text[j] == '}': |
|
|
brace_count -= 1 |
|
|
j += 1 |
|
|
|
|
|
if brace_count == 0: |
|
|
|
|
|
json_candidate = scannable_text[i:j] |
|
|
try: |
|
|
parsed_data = json.loads(json_candidate) |
|
|
if "tool_calls" in parsed_data: |
|
|
tool_calls = parsed_data["tool_calls"] |
|
|
if tool_calls and isinstance(tool_calls, list): |
|
|
|
|
|
for tc in tool_calls: |
|
|
if tc.get("function"): |
|
|
func = tc["function"] |
|
|
if func.get("arguments"): |
|
|
if isinstance(func["arguments"], dict): |
|
|
func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False) |
|
|
elif not isinstance(func["arguments"], str): |
|
|
func["arguments"] = str(func["arguments"]) |
|
|
logger.debug(f"从内联 JSON 中提取到 {len(tool_calls)} 个工具调用") |
|
|
break |
|
|
except json.JSONDecodeError: |
|
|
pass |
|
|
|
|
|
i = j |
|
|
else: |
|
|
i += 1 |
|
|
|
|
|
|
|
|
if tool_calls: |
|
|
cleaned_content = remove_tool_json_content(content) |
|
|
|
|
|
return tool_calls, cleaned_content |
|
|
|
|
|
|
|
|
def remove_tool_json_content(content: str) -> str: |
|
|
""" |
|
|
从响应内容中移除工具调用 JSON |
|
|
|
|
|
Args: |
|
|
content: 原始响应内容 |
|
|
|
|
|
Returns: |
|
|
str: 清理后的内容 |
|
|
""" |
|
|
if not content: |
|
|
return content |
|
|
|
|
|
|
|
|
cleaned_text = content |
|
|
|
|
|
|
|
|
def replace_json_block(match): |
|
|
json_content = match.group(1) |
|
|
try: |
|
|
parsed_data = json.loads(json_content) |
|
|
if "tool_calls" in parsed_data: |
|
|
return "" |
|
|
except json.JSONDecodeError: |
|
|
pass |
|
|
return match.group(0) |
|
|
|
|
|
json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```' |
|
|
cleaned_text = re.sub(json_block_pattern, replace_json_block, cleaned_text) |
|
|
|
|
|
|
|
|
result = [] |
|
|
i = 0 |
|
|
|
|
|
while i < len(cleaned_text): |
|
|
if cleaned_text[i] == '{': |
|
|
|
|
|
brace_count = 1 |
|
|
j = i + 1 |
|
|
in_string = False |
|
|
escape_next = False |
|
|
|
|
|
while j < len(cleaned_text) and brace_count > 0: |
|
|
if escape_next: |
|
|
escape_next = False |
|
|
elif cleaned_text[j] == '\\': |
|
|
escape_next = True |
|
|
elif cleaned_text[j] == '"': |
|
|
in_string = not in_string |
|
|
elif not in_string: |
|
|
if cleaned_text[j] == '{': |
|
|
brace_count += 1 |
|
|
elif cleaned_text[j] == '}': |
|
|
brace_count -= 1 |
|
|
j += 1 |
|
|
|
|
|
if brace_count == 0: |
|
|
|
|
|
json_candidate = cleaned_text[i:j] |
|
|
try: |
|
|
parsed = json.loads(json_candidate) |
|
|
if "tool_calls" in parsed: |
|
|
|
|
|
i = j |
|
|
continue |
|
|
except json.JSONDecodeError: |
|
|
pass |
|
|
|
|
|
|
|
|
result.append(cleaned_text[i]) |
|
|
i += 1 |
|
|
else: |
|
|
result.append(cleaned_text[i]) |
|
|
i += 1 |
|
|
|
|
|
cleaned_result = "".join(result).strip() |
|
|
|
|
|
|
|
|
cleaned_result = re.sub(r'\n{3,}', '\n\n', cleaned_result) |
|
|
|
|
|
logger.debug(f"内容清理完成,原始长度: {len(content)}, 清理后长度: {len(cleaned_result)}") |
|
|
return cleaned_result |
|
|
|
|
|
|
|
|
def content_to_string(content: Any) -> str: |
|
|
""" |
|
|
将消息内容转换为字符串 |
|
|
|
|
|
Args: |
|
|
content: 消息内容,可能是字符串或列表(多模态) |
|
|
|
|
|
Returns: |
|
|
str: 字符串格式的内容 |
|
|
""" |
|
|
if isinstance(content, str): |
|
|
return content |
|
|
elif isinstance(content, list): |
|
|
|
|
|
text_parts = [] |
|
|
for item in content: |
|
|
if isinstance(item, dict): |
|
|
if item.get("type") == "text": |
|
|
text_parts.append(item.get("text", "")) |
|
|
elif isinstance(item, str): |
|
|
text_parts.append(item) |
|
|
return " ".join(text_parts) |
|
|
else: |
|
|
return str(content) |
|
|
|