k2thinkService / src /tool_handler.py
youbiaokachi's picture
Upload 10 files
1a06196 verified
"""
工具处理模块
处理工具调用相关的所有逻辑
"""
import json
import re
import time
import logging
from typing import List, Dict, Optional, Union
from src.constants import (
ToolConstants, ContentConstants, LogMessages,
TimeConstants
)
from src.exceptions import ToolProcessingError
logger = logging.getLogger(__name__)
class ToolHandler:
    """Prompt-based tool-calling adapter.

    Injects tool definitions into the outgoing chat messages as plain prompt
    text, and parses tool calls back out of the model's free-text response
    (fenced JSON, inline JSON, or a natural-language "call function" form).
    """

    # Fenced ```json ... ``` block that wraps a JSON object.
    TOOL_CALL_FENCE_PATTERN = re.compile(r"```json\s*(\{.*?\})\s*```", re.DOTALL)
    # Natural-language form: "调用函数: <name> 参数: {...}".
    FUNCTION_CALL_PATTERN = re.compile(
        r"调用函数\s*[::]\s*([\w\-\.]+)\s*(?:参数|arguments)[::]\s*(\{.*?\})",
        re.DOTALL
    )

    # System prompt used when none exists or the existing one is too long.
    _DEFAULT_SYSTEM_PROMPT = "你是一个有用的助手。"
    # Shared decoder; raw_decode() parses one JSON value at a given offset.
    _JSON_DECODER = json.JSONDecoder()

    def __init__(self, config):
        """Store the configuration values the handler reads.

        Args:
            config: Object exposing ``SCAN_LIMIT``, ``SYSTEM_MESSAGE_LENGTH``
                and ``TOOL_SUPPORT`` attributes.
        """
        self.config = config
        self.scan_limit = config.SCAN_LIMIT
        self.system_message_length = config.SYSTEM_MESSAGE_LENGTH
        self.tool_support = config.TOOL_SUPPORT

    def generate_tool_prompt(self, tools: List[Dict]) -> str:
        """Build a compact prompt fragment describing the available tools.

        Only entries whose ``type`` equals ``ToolConstants.FUNCTION_TYPE`` are
        included. Required parameters are marked with a trailing ``*``.

        Args:
            tools: Tool definitions (dicts with ``type`` and ``function``).

        Returns:
            The prompt text, or ``""`` when there is nothing to describe.
        """
        if not tools:
            return ""

        tool_definitions = []
        for tool in tools:
            if tool.get("type") != ToolConstants.FUNCTION_TYPE:
                continue

            function_spec = tool.get("function", {}) or {}
            function_name = function_spec.get("name", "unknown")
            function_description = function_spec.get("description", "")
            parameters = function_spec.get("parameters", {}) or {}

            tool_info = f"{function_name}: {function_description}"

            parameter_properties = parameters.get("properties", {}) or {}
            required_parameters = set(parameters.get("required", []) or [])
            if parameter_properties:
                param_list = []
                for param_name, param_details in parameter_properties.items():
                    param_desc = (param_details or {}).get("description", "")
                    is_required = param_name in required_parameters
                    param_list.append(f"{param_name}{'*' if is_required else ''}: {param_desc}")
                tool_info += f" Parameters: {', '.join(param_list)}"

            tool_definitions.append(tool_info)

        if not tool_definitions:
            return ""

        return (
            f"\n\nAvailable tools: {'; '.join(tool_definitions)}. "
            "To use a tool, respond with JSON: "
            '{"tool_calls":[{"id":"call_xxx","type":"function","function":{"name":"tool_name","arguments":"{\\"param\\":\\"value\\"}"}}]}'
        )

    def process_messages_with_tools(
        self,
        messages: List[Dict],
        tools: Optional[List[Dict]] = None,
        tool_choice: Optional[Union[str, Dict]] = None
    ) -> List[Dict]:
        """Return a copy of ``messages`` with the tool prompt injected.

        When tools are absent, tool support is disabled, or ``tool_choice`` is
        ``"none"``, shallow copies of the messages are returned unchanged.
        Otherwise the tool prompt is appended to the first system message (or
        a new system message is prepended), an optional tool-choice hint is
        appended to a trailing user message, ``tool``/``function`` messages
        are rewritten as assistant text, and every content is flattened to a
        string.

        Args:
            messages: Chat messages (dicts with ``role`` and ``content``).
            tools: Tool definitions, or ``None``.
            tool_choice: ``"none"``, ``"required"``, a dict naming a function,
                or ``None``.

        Returns:
            A new list of message dicts; the input dicts are not mutated.
        """
        if not tools or not self.tool_support or tool_choice == "none":
            return [dict(m) for m in messages]

        tools_prompt = self.generate_tool_prompt(tools)

        # Cap the prompt so an oversized tool list is not rejected upstream.
        if len(tools_prompt) > ToolConstants.MAX_TOOL_PROMPT_LENGTH:
            logger.warning(LogMessages.TOOL_PROMPT_TOO_LONG.format(len(tools_prompt)))
            tools_prompt = tools_prompt[:ToolConstants.MAX_TOOL_PROMPT_LENGTH] + ToolConstants.TOOL_PROMPT_TRUNCATE_SUFFIX

        processed = []
        if any(m.get("role") == "system" for m in messages):
            # Only the first system message receives the tool prompt; later
            # ones are still length-checked but get nothing appended.
            pending_prompt = tools_prompt
            for m in messages:
                if m.get("role") != "system":
                    processed.append(dict(m))
                    continue
                mm = dict(m)
                new_content = self._content_to_string(mm.get("content", "")) + pending_prompt
                if len(new_content) > self.system_message_length:
                    logger.warning(LogMessages.SYSTEM_MESSAGE_TOO_LONG.format(len(new_content)))
                    mm["content"] = self._DEFAULT_SYSTEM_PROMPT + pending_prompt
                else:
                    mm["content"] = new_content
                processed.append(mm)
                pending_prompt = ""
        elif tools_prompt.strip():
            # No system message yet: add one, but only if there is a prompt.
            processed = [{"role": "system", "content": self._DEFAULT_SYSTEM_PROMPT + tools_prompt}]
            processed.extend(dict(m) for m in messages)
        else:
            processed = [dict(m) for m in messages]

        # Translate tool_choice into a short hint on the trailing user message.
        hint = ""
        if tool_choice == "required":
            hint = "\n请使用工具来处理这个请求。"
        elif isinstance(tool_choice, dict) and tool_choice.get("type") == ToolConstants.FUNCTION_TYPE:
            fname = (tool_choice.get("function") or {}).get("name")
            if fname:
                hint = f"\n请使用 {fname} 工具。"
        if hint and processed and processed[-1].get("role") == "user":
            last = processed[-1]
            last["content"] = self._content_to_string(last.get("content", "")) + hint

        final_msgs = []
        for m in processed:
            if m.get("role") in ("tool", "function"):
                final_msgs.append(self._render_tool_result(m))
            else:
                # Regular message: make sure the content is a plain string.
                final_msg = dict(m)
                final_msg["content"] = self._content_to_string(final_msg.get("content", ""))
                final_msgs.append(final_msg)
        return final_msgs

    def _render_tool_result(self, message: Dict) -> Dict:
        """Convert a ``tool``/``function`` message into an assistant text message."""
        tool_name = message.get("name", "unknown")
        raw_content = message.get("content", "")
        if isinstance(raw_content, dict):
            # Serialize dict results as JSON; this must happen before the
            # generic string conversion, which would produce a Python repr.
            try:
                tool_content = json.dumps(raw_content, ensure_ascii=False)
            except (TypeError, ValueError):
                tool_content = self._content_to_string(raw_content)
        else:
            tool_content = self._content_to_string(raw_content)

        if tool_content.strip():
            content = f"工具 {tool_name} 结果: {tool_content}"
        else:
            # Empty result: report completion instead of a dangling label.
            content = f"工具 {tool_name} 执行完成"
        return {"role": "assistant", "content": content}

    def extract_tool_invocations(self, text: str) -> Optional[List[Dict]]:
        """Extract tool calls from a model response.

        Three strategies are tried in order on the first ``scan_limit``
        characters: fenced JSON blocks, inline JSON objects, and the
        natural-language "调用函数" form.

        Args:
            text: The raw response text.

        Returns:
            A list of tool-call dicts whose ``function.arguments`` are JSON
            strings, or ``None`` when no tool call is found.
        """
        if not text:
            return None

        # Bound the amount of text scanned.
        scannable_text = text[:self.scan_limit]

        # Attempt 1: fenced ```json blocks.
        for json_block in self.TOOL_CALL_FENCE_PATTERN.findall(scannable_text):
            try:
                parsed_data = json.loads(json_block)
            except json.JSONDecodeError:
                continue
            tool_calls = self._tool_calls_from(parsed_data)
            if tool_calls:
                return tool_calls

        # Attempt 2: JSON objects embedded inline in the text.
        tool_calls = self._extract_inline_json_tool_calls(scannable_text)
        if tool_calls:
            return tool_calls

        # Attempt 3: natural-language function call.
        natural_lang_match = self.FUNCTION_CALL_PATTERN.search(scannable_text)
        if natural_lang_match:
            function_name = natural_lang_match.group(1).strip()
            # The regex only locates where the arguments start; decode from
            # there so nested objects are not cut at the first '}'.
            args_start = natural_lang_match.start(2)
            decoded = self._decode_object_at(scannable_text, args_start)
            if decoded is None:
                return None
            arguments_str = scannable_text[args_start:decoded[1]]
            return [
                {
                    "id": f"{ToolConstants.CALL_ID_PREFIX}{int(time.time() * TimeConstants.MICROSECONDS_MULTIPLIER)}",
                    "type": ToolConstants.FUNCTION_TYPE,
                    "function": {"name": function_name, "arguments": arguments_str},
                }
            ]

        return None

    def remove_tool_json_content(self, text: str) -> str:
        """Strip tool-call JSON (fenced or inline) from a response.

        JSON objects that do not contain a ``tool_calls`` key, and text that
        is not valid JSON, are left untouched.

        Args:
            text: The raw response text.

        Returns:
            The text without tool-call JSON, with outer whitespace stripped.
        """
        if not text:
            return ""

        def remove_tool_call_block(match: re.Match) -> str:
            try:
                parsed_data = json.loads(match.group(1))
            except json.JSONDecodeError:
                return match.group(0)
            if isinstance(parsed_data, dict) and "tool_calls" in parsed_data:
                return ""
            return match.group(0)

        # Step 1: drop fenced tool-call blocks.
        cleaned_text = self.TOOL_CALL_FENCE_PATTERN.sub(remove_tool_call_block, text)

        # Step 2: drop inline tool-call objects, copying everything else.
        kept = []
        pos = 0
        while pos < len(cleaned_text):
            brace = cleaned_text.find("{", pos)
            if brace == -1:
                kept.append(cleaned_text[pos:])
                break
            kept.append(cleaned_text[pos:brace])
            decoded = self._decode_object_at(cleaned_text, brace)
            if decoded is not None and "tool_calls" in decoded[0]:
                # A tool call: skip the whole object.
                pos = decoded[1]
            else:
                # Not a tool call (or not JSON): keep the brace and move on,
                # so nested objects are still examined individually.
                kept.append("{")
                pos = brace + 1
        return "".join(kept).strip()

    def _decode_object_at(self, text: str, start: int) -> Optional[tuple]:
        """Decode the JSON object that begins at ``text[start]``.

        Returns:
            ``(obj, end)`` where ``obj`` is the decoded dict and ``end`` is the
            index just past its closing brace, or ``None`` when the text at
            ``start`` is not a valid JSON object.
        """
        try:
            obj, end = self._JSON_DECODER.raw_decode(text, start)
        except (ValueError, RecursionError):
            return None
        if not isinstance(obj, dict):
            return None
        return obj, end

    def _tool_calls_from(self, parsed_data) -> Optional[List[Dict]]:
        """Return the normalized, non-empty ``tool_calls`` list of a decoded object."""
        if not isinstance(parsed_data, dict):
            return None
        tool_calls = parsed_data.get("tool_calls")
        if not tool_calls or not isinstance(tool_calls, list):
            return None
        self._normalize_tool_calls(tool_calls)
        return tool_calls or None

    def _extract_inline_json_tool_calls(self, text: str) -> Optional[List[Dict]]:
        """Find the first inline JSON object that carries tool calls.

        Every ``{`` is tried as the start of a JSON object, including ones
        nested inside an object that was already examined.
        """
        pos = text.find("{")
        while pos != -1:
            decoded = self._decode_object_at(text, pos)
            if decoded is not None:
                tool_calls = self._tool_calls_from(decoded[0])
                if tool_calls:
                    return tool_calls
            pos = text.find("{", pos + 1)
        return None

    def _normalize_tool_calls(self, tool_calls: List[Dict]) -> None:
        """Normalize tool calls in place so ``function.arguments`` is a string.

        Entries that are not dicts are dropped from the list, since they come
        from untrusted model output and cannot be valid tool calls. Non-string
        arguments are serialized to JSON.
        """
        tool_calls[:] = [tc for tc in tool_calls if isinstance(tc, dict)]
        for tc in tool_calls:
            func = tc.get("function")
            if not isinstance(func, dict) or "arguments" not in func:
                continue
            if not isinstance(func["arguments"], str):
                func["arguments"] = json.dumps(func["arguments"], ensure_ascii=False)

    def _content_to_string(self, content) -> str:
        """Flatten message content of various shapes into a single string.

        ``None`` becomes ``""``, strings pass through, and lists are joined
        with single spaces from their text-bearing parts (objects with a
        ``text`` attribute, text/image dict parts, plain strings, and other
        objects that have a ``__dict__``). Anything else goes through ``str``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content

        if isinstance(content, list):
            parts = []
            for p in content:
                if hasattr(p, 'text'):  # ContentPart-like object
                    text_value = getattr(p, 'text', None)
                    if text_value:
                        parts.append(text_value)
                elif isinstance(p, dict):
                    part_type = p.get("type")
                    if part_type == ContentConstants.TEXT_TYPE:
                        # A null "text" must not reach str.join().
                        text_value = p.get("text", "")
                        parts.append("" if text_value is None else str(text_value))
                    elif part_type == ContentConstants.IMAGE_URL_TYPE:
                        # Images are represented by a descriptive placeholder.
                        parts.append(ContentConstants.IMAGE_PLACEHOLDER)
                elif isinstance(p, str):
                    parts.append(p)
                elif hasattr(p, '__dict__'):
                    # Arbitrary object without a text attribute.
                    try:
                        parts.append(str(p))
                    except Exception:
                        continue
            return " ".join(parts)

        try:
            return str(content)
        except Exception:
            return ""