File size: 12,386 Bytes
47258ea |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
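工具调用处理模块 — tool-call handling: builds the tool-usage prompt, injects it into chat messages, and extracts/strips `tool_calls` JSON from model responses.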
"""
import json
import re
from typing import Dict, List, Any, Optional, Tuple
from app.utils.logger import get_logger
# Module-wide logger obtained from the project's logging helper (app.utils.logger.get_logger).
logger = get_logger()
def generate_tool_prompt(tools: Optional[List[Dict[str, Any]]]) -> str:
    """
    Build the tool-usage prompt.

    Converts OpenAI-style ``tools`` definitions into a Markdown document that
    lists each function tool, its purpose and its parameters, followed by the
    expected invocation format and usage rules.

    Args:
        tools: OpenAI-format tool definitions. Entries that are not dicts or
            whose ``type`` is not ``"function"`` are ignored. Explicit
            ``null`` values for ``function``, ``parameters``, ``properties``
            or ``required`` are treated as empty.

    Returns:
        str: The Markdown tool documentation, or ``""`` when there is no
        function tool to describe.
    """
    if not tools:
        return ""

    tool_definitions = []
    for tool in tools:
        if not isinstance(tool, dict) or tool.get("type") != "function":
            continue

        # `or {}` also covers keys that are present but null, which a .get() default does not.
        function_spec = tool.get("function") or {}
        if not isinstance(function_spec, dict):
            function_spec = {}
        function_name = function_spec.get("name", "unknown")
        function_description = function_spec.get("description", "")
        parameters = function_spec.get("parameters") or {}
        if not isinstance(parameters, dict):
            parameters = {}

        # Structured definition for this tool
        tool_info = [
            f"## {function_name}",
            f"**Purpose**: {function_description}",
        ]

        # Parameter details
        parameter_properties = parameters.get("properties") or {}
        required_parameters = set(parameters.get("required") or [])
        if parameter_properties:
            tool_info.append("**Parameters**:")
            for param_name, param_info in parameter_properties.items():
                if not isinstance(param_info, dict):
                    param_info = {}
                param_type = param_info.get("type", "string")
                param_desc = param_info.get("description", "")
                required_str = " (required)" if param_name in required_parameters else " (optional)"
                tool_info.append(f"- `{param_name}` ({param_type}){required_str}: {param_desc}")

        tool_definitions.append("\n".join(tool_info))

    # No usable function tool: return "" so callers that treat an empty prompt
    # as "nothing to inject" skip injection instead of adding an empty header.
    if not tool_definitions:
        return ""

    # Assemble the full prompt
    prompt = (
        "\n\n---\n"
        "# Available Tools\n\n"
        + "\n\n".join(tool_definitions)
        + "\n\n"
        "**Tool Invocation Format**:\n"
        "To use a tool, include a JSON block with this structure:\n"
        '{"tool_calls": [{"id": "call_ID", "type": "function", "function": {"name": "TOOL_NAME", "arguments": "JSON_STRING"}}]}\n\n'
        "**Rules**:\n"
        "- Use tool ONLY when user explicitly requests an action that matches a tool's purpose\n"
        "- For normal conversation, respond naturally WITHOUT any tool calls\n"
        "- The `arguments` must be a JSON string, not an object\n"
        "- Multiple tools can be called by adding more items to the array\n"
        "---\n\n"
    )
    logger.debug(f"生成工具提示词,包含 {len(tool_definitions)} 个工具定义")
    return prompt
def process_messages_with_tools(
    messages: List[Dict[str, Any]],
    tools: Optional[List[Dict[str, Any]]],
    tool_choice: str = "auto"
) -> List[Dict[str, Any]]:
    """
    Inject tool definitions into a chat message list.

    The prompt produced by ``generate_tool_prompt`` is appended to the FIRST
    ``system`` message. Any later system messages are passed through
    unchanged, so the prompt appears exactly once. When there is no system
    message, a new one carrying the prompt is prepended.

    Args:
        messages: Original chat messages. The list and its dicts are not
            mutated; the modified system message is a shallow copy.
        tools: OpenAI-format tool definitions.
        tool_choice: Tool selection strategy. ``"none"`` disables injection;
            any other value injects the prompt.

    Returns:
        List[Dict]: The processed message list, or ``messages`` itself when
        there is nothing to inject. Multimodal (list) system content is
        flattened to its text parts by ``content_to_string``.
    """
    if not tools or tool_choice == "none":
        return messages

    tools_prompt = generate_tool_prompt(tools)
    if not tools_prompt:
        return messages

    processed: List[Dict[str, Any]] = []
    injected = False
    for msg in messages:
        if not injected and msg.get("role") == "system":
            new_msg = msg.copy()
            content = new_msg.get("content")
            # A null/missing content must not turn into the literal text "None".
            content_str = "" if content is None else content_to_string(content)
            new_msg["content"] = content_str + tools_prompt
            processed.append(new_msg)
            injected = True
        else:
            processed.append(msg)

    if not injected:
        # No system message: create one in front of the conversation.
        processed.insert(0, {
            "role": "system",
            "content": f"You are a helpful assistant with access to tools.{tools_prompt}"
        })

    logger.debug(f"工具提示已注入到消息列表,共 {len(processed)} 条消息")
    return processed
def _normalize_tool_calls(raw_calls: Any) -> Optional[List[Dict[str, Any]]]:
    """
    Validate a parsed ``tool_calls`` value and coerce its ``arguments`` to strings.

    Args:
        raw_calls: The value found under a ``tool_calls`` key.

    Returns:
        Optional[List[Dict]]: The dict entries of ``raw_calls``. In each entry,
        ``function.arguments`` is coerced to a string: dicts and lists are
        JSON-encoded, other non-string values go through ``str()``, and a
        missing or null value becomes ``"{}"``. Returns ``None`` when
        ``raw_calls`` is not a list or holds no dict entries.
    """
    if not isinstance(raw_calls, list):
        return None
    calls = [tc for tc in raw_calls if isinstance(tc, dict)]
    if not calls:
        return None
    for tc in calls:
        func = tc.get("function")
        if not isinstance(func, dict):
            continue
        arguments = func.get("arguments")
        if arguments is None:
            # `arguments` must be a JSON string; an absent value means "no arguments".
            func["arguments"] = "{}"
        elif isinstance(arguments, (dict, list)):
            # Also covers empty {} / [], which a truthiness check would skip.
            func["arguments"] = json.dumps(arguments, ensure_ascii=False)
        elif not isinstance(arguments, str):
            func["arguments"] = str(arguments)
    return calls


def _find_json_object_end(text: str, start: int) -> int:
    """
    Return the index just past the ``}`` that balances ``text[start] == '{'``.

    Braces inside double-quoted strings (with backslash escapes) are ignored.
    Returns -1 when the brace is never balanced before the end of ``text``.
    """
    depth = 1
    in_string = False
    escape_next = False
    j = start + 1
    length = len(text)
    while j < length:
        ch = text[j]
        j += 1
        if escape_next:
            escape_next = False
        elif ch == '\\':
            escape_next = True
        elif ch == '"':
            in_string = not in_string
        elif not in_string:
            if ch == '{':
                depth += 1
            elif ch == '}':
                depth -= 1
                if depth == 0:
                    return j
    return -1


def parse_and_extract_tool_calls(content: str) -> Tuple[Optional[List[Dict[str, Any]]], str]:
    """
    Extract ``tool_calls`` JSON from a model response.

    Two methods are tried in order. First, fenced code blocks (```json ... ```
    or ``` ... ```) are searched. If that finds nothing, any brace-balanced
    inline JSON object in the text is tried. A candidate is accepted only when
    its ``tool_calls`` value is a list containing at least one dict. Each
    call's ``function.arguments`` is coerced to a string.

    Args:
        content: Text returned by the model.

    Returns:
        Tuple[Optional[List], str]: ``(tool_calls, cleaned_content)``. When no
        valid tool calls are found, this is ``(None, content)`` with
        ``content`` unchanged. Otherwise ``cleaned_content`` is the result of
        ``remove_tool_json_content(content)``.
    """
    if not content or not content.strip():
        return None, content

    tool_calls: Optional[List[Dict[str, Any]]] = None

    # Method 1: fenced JSON code blocks (```json ... ``` or ``` ... ```).
    json_block_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```'
    for json_str in re.findall(json_block_pattern, content):
        try:
            parsed_data = json.loads(json_str)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed_data, dict) and "tool_calls" in parsed_data:
            tool_calls = _normalize_tool_calls(parsed_data["tool_calls"])
            if tool_calls:
                logger.debug(f"从 JSON 代码块中提取到 {len(tool_calls)} 个工具调用")
                break

    # Method 2: scan the raw text for brace-balanced JSON objects.
    if not tool_calls:
        search_from = 0
        while True:
            start = content.find("{", search_from)
            if start == -1:
                break
            end = _find_json_object_end(content, start)
            if end == -1:
                # Unbalanced '{' (e.g. prose like "use { to ..."): try the next
                # brace instead of abandoning the rest of the text.
                search_from = start + 1
                continue
            try:
                parsed_data = json.loads(content[start:end])
            except json.JSONDecodeError:
                # Balanced but not JSON; a real object may still be nested inside.
                search_from = start + 1
                continue
            if isinstance(parsed_data, dict) and "tool_calls" in parsed_data:
                tool_calls = _normalize_tool_calls(parsed_data["tool_calls"])
                if tool_calls:
                    logger.debug(f"从内联 JSON 中提取到 {len(tool_calls)} 个工具调用")
                    break
            # Valid JSON object without usable tool calls: skip past it.
            search_from = end

    if not tool_calls:
        return None, content

    # Strip every tool_calls JSON (fenced or inline) from the visible text.
    return tool_calls, remove_tool_json_content(content)
def remove_tool_json_content(content: str) -> str:
    """
    Strip tool-call JSON from a model response.

    Works in two passes:

    1. Fenced code blocks (```json ... ``` or ``` ... ```) whose JSON object
       has a top-level ``tool_calls`` key are deleted, fences included.
       Blocks that are not valid JSON, or that lack the key, are kept verbatim.
    2. In the remaining text, every brace-balanced ``{...}`` span that parses
       as a JSON object with a top-level ``tool_calls`` key is deleted.
       Braces inside double-quoted strings do not count toward the balance.

    Args:
        content: Raw model response text.

    Returns:
        str: The cleaned text. It is stripped of surrounding whitespace, and
        runs of three or more newlines are collapsed to two. Falsy input
        (``""`` / ``None``) is returned unchanged.
    """
    if not content:
        return content

    def drop_fenced_tool_json(match):
        # The whole fence goes only when its body is JSON carrying "tool_calls".
        try:
            data = json.loads(match.group(1))
        except json.JSONDecodeError:
            return match.group(0)
        return "" if "tool_calls" in data else match.group(0)

    def closing_index(text, open_pos):
        # Index just past the '}' that balances text[open_pos], or None if unbalanced.
        depth = 1
        quoted = False
        escaped = False
        pos = open_pos + 1
        while pos < len(text) and depth > 0:
            ch = text[pos]
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == '"':
                quoted = not quoted
            elif not quoted:
                if ch == '{':
                    depth += 1
                elif ch == '}':
                    depth -= 1
            pos += 1
        return pos if depth == 0 else None

    def is_tool_json(snippet):
        try:
            return "tool_calls" in json.loads(snippet)
        except json.JSONDecodeError:
            return False

    fence_pattern = r'```(?:json)?\s*\n?(\{[\s\S]*?\})\s*\n?```'
    text = re.sub(fence_pattern, drop_fenced_tool_json, content)

    kept = []
    pos = 0
    while pos < len(text):
        if text[pos] == '{':
            stop = closing_index(text, pos)
            if stop is not None and is_tool_json(text[pos:stop]):
                # Skip the entire tool-call object.
                pos = stop
                continue
        # Ordinary character, or a '{' that does not open tool-call JSON:
        # keep it and keep scanning (nested objects are still examined).
        kept.append(text[pos])
        pos += 1

    cleaned_result = re.sub(r'\n{3,}', '\n\n', "".join(kept).strip())
    logger.debug(f"内容清理完成,原始长度: {len(content)}, 清理后长度: {len(cleaned_result)}")
    return cleaned_result
def content_to_string(content: Any) -> str:
    """
    Convert message content to a plain string.

    Args:
        content: Message content. This can be a string; a multimodal list of
            parts (dicts such as ``{"type": "text", "text": ...}``, or bare
            strings); ``None``; or any other value.

    Returns:
        str: The result depends on the input type:

        - a string is returned unchanged;
        - for a list, the text parts are joined with single spaces and
          non-text parts (e.g. images) are skipped;
        - ``None`` gives ``""``;
        - any other value gives ``str(content)``.
    """
    if content is None:
        # Avoid turning a null/missing content into the literal text "None".
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        # Multimodal content: keep only the text parts.
        text_parts = []
        for item in content:
            if isinstance(item, dict):
                if item.get("type") == "text":
                    text = item.get("text")
                    # An explicit "text": null would otherwise make join() raise.
                    text_parts.append("" if text is None else str(text))
            elif isinstance(item, str):
                text_parts.append(item)
        return " ".join(text_parts)
    return str(content)
|