gakrchat1 / tools /tool_detector.py
extraplus's picture
Upload 3 files
3b667a5 verified
"""Tool-call detection and parsing for Autobot Instruct outputs."""
from __future__ import annotations
import json
import re
from typing import Any, Dict, Optional, Tuple
SPECIAL_TOKENS = [
"<|im_end|>",
"<|im_start|>",
"<|endoftext|>",
"<|startoftext|>",
"<|tool_list_start|>",
"<|tool_list_end|>",
]
def _strip_special_tokens(text: str) -> str:
cleaned = text
for token in SPECIAL_TOKENS:
cleaned = cleaned.replace(token, "")
return cleaned.strip()
def _normalize_tool_name(name: str) -> str:
if not name:
return ""
# Convert to lowercase, replace spaces/dashes with underscores
normalized = name.strip().lower().replace(" ", "_").replace("-", "_")
normalized_ascii = re.sub(r"[^a-z0-9_]", "", normalized)
# Handle all variations of web_search and search
if "websearch" in normalized_ascii:
return "web_search"
if "searchweb" in normalized_ascii or "search_web" in normalized_ascii:
return "web_search"
if "web" in normalized_ascii and "search" in normalized_ascii:
return "web_search"
if "search" in normalized_ascii and "web" in normalized_ascii:
return "web_search"
if normalized_ascii == "search": # Plain 'search' -> 'web_search'
return "web_search"
if "搜索" in normalized or "搜" in normalized:
return "web_search"
if normalized_ascii.startswith("web_") or normalized_ascii == "web":
return "web_search"
return normalized
def _extract_tool_payload(text: str) -> Optional[Tuple[str, str]]:
"""Only extract tool payload when explicit <|tool_call_start|> marker is present."""
start_token = "<|tool_call_start|>"
end_token = "<|tool_call_end|>"
if start_token not in text:
return None
print("[TOOL] Tool call marker detected in generated text")
start_pos = text.find(start_token) + len(start_token)
end_pos = text.find(end_token, start_pos)
if end_pos != -1:
payload = text[start_pos:end_pos].strip()
else:
payload = text[start_pos:].strip()
# Remove outer brackets if present: [web_search(...)] → web_search(...)
while (payload.startswith("[") and payload.endswith("]")) or \
(payload.startswith("(") and payload.endswith(")")):
payload = payload[1:-1].strip()
raw_tool_call = f"{start_token}{payload}{end_token}"
return payload, raw_tool_call
def _parse_args(args_text: str) -> Dict[str, Any]:
"""Parse function call arguments like: query="value" or query='value' or just "value" """
args: Dict[str, Any] = {}
if not args_text.strip():
return args
# Try named arguments: query="value", max=5, etc.
# Handles both single and double quotes
for key, quoted in re.findall(r'([a-zA-Z_]\w*)\s*=\s*(\'[^\']*\'|"[^"]*")', args_text):
# Remove quotes
args[key] = quoted[1:-1]
if args:
return args
# Try positional argument (single quoted string)
positional = re.match(r'^\s*(\'[^\']*\'|"[^"]*")\s*$', args_text)
if positional:
# Remove quotes
args["query"] = positional.group(1)[1:-1]
return args
# Fallback: treat entire args_text as query if it looks like a value
args_text = args_text.strip()
if args_text:
args["query"] = args_text
return args
def _parse_tool_call(text: str) -> Optional[Dict[str, Any]]:
extracted = _extract_tool_payload(text)
if not extracted:
return None
payload, raw_tool_call = extracted
payload = payload.strip()
print(f"[TOOL] Parsing payload preview: {payload[:120]}")
# If we have <|tool_call_start|> marker, we're 100% sure this is a tool call
# So use aggressive extraction strategies
has_marker = "<|tool_call_start|>" in raw_tool_call
if has_marker:
print("[TOOL] Tool call marker found - using aggressive extraction")
# Step 1: Try to extract query using regex (handles both quoted and unquoted values)
# Matches: query="value" or query=value (with value ending at comma/paren/space)
query_match = re.search(r'query\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|([^,)\s]+(?:\s+[^,)]+)*))', payload, re.IGNORECASE)
if query_match:
query = query_match.group(1) or query_match.group(2) or query_match.group(3)
if query:
query = query.strip()
print(f"[TOOL] Extracted query: {query}")
# Step 2: Extract or infer tool name
# Look for tool name patterns: web_search, websearch, search, search_web, etc.
tool_name = _extract_tool_name_from_payload(payload)
args = {"query": query}
# Also extract max_results if present
max_results_match = re.search(r'max_results?\s*=\s*(\d+)', payload, re.IGNORECASE)
if max_results_match:
args["max_results"] = int(max_results_match.group(1))
if tool_name:
print(f"[TOOL] Aggressive extraction successful: tool={tool_name}, query={query}")
return {
"tool_name": tool_name,
"args": args,
"raw_tool_call": raw_tool_call,
}
# Fall back to JSON parsing if no marker or aggressive extraction failed
if payload.startswith("{") and payload.endswith("}"):
try:
obj = json.loads(payload)
if isinstance(obj, dict) and "tool_call" in obj:
tool_call = obj.get("tool_call") or {}
name = _normalize_tool_name(str(tool_call.get("name", "")))
args = tool_call.get("args") or {}
if isinstance(args, dict) and name:
return {
"tool_name": name,
"args": args,
"raw_tool_call": raw_tool_call,
}
if isinstance(obj, dict) and "name" in obj:
name = _normalize_tool_name(str(obj.get("name", "")))
args = obj.get("args") if isinstance(obj.get("args"), dict) else {}
if name:
return {
"tool_name": name,
"args": args,
"raw_tool_call": raw_tool_call,
}
except Exception as json_err:
print(f"[TOOL] JSON parse failed: {json_err}, trying regex fallbacks")
# Continue to regex-based extraction if JSON fails
# Fallback: Try to extract tool_name and query using regex for malformed JSON
# Look for "tool_name": "value" pattern
tool_name_match = re.search(r'"tool_name"\s*:\s*"([^"]+)"', payload)
if not tool_name_match:
# Try alternate pattern: "search": "value" as tool name
search_match = re.search(r'"search"\s*:\s*"([^"]+)"', payload)
if search_match:
# If we find a search pattern with a value, it might be the query
# but we'll mark tool as web_search
tool_name = "web_search"
else:
tool_name = None
else:
tool_name = _normalize_tool_name(tool_name_match.group(1))
if tool_name:
# Look for "query": "value" or similar arguments
query_match = re.search(r'"query"\s*:\s*"([^"]+)"', payload)
args = {}
if query_match:
args["query"] = query_match.group(1)
else:
# Try to find any string value as fallback
first_string_match = re.search(r':\s*"([^"]+)"', payload)
if first_string_match:
args["query"] = first_string_match.group(1)
# Look for max_results/max_Results patterns
max_results_match = re.search(r'"max_[Rr]esults?"\s*:\s*(\d+)', payload)
if max_results_match:
args["max_results"] = int(max_results_match.group(1))
# Look for any other key-value pairs
for key, value in re.findall(r'"([a-zA-Z_]+)"\s*:\s*"([^"]*)"', payload):
if key.lower() not in ["tool_name", "name", "tool_cale", "search"]: # Skip redundant keys
args[key] = value
if tool_name or args:
if not tool_name:
tool_name = "web_search" # Default to web_search
print(f"[TOOL] Extracted via regex: tool_name={tool_name}, args={args}")
return {
"tool_name": tool_name,
"args": args,
"raw_tool_call": raw_tool_call,
}
# Try function call syntax: tool_name(args) or "web search"(query="...")
# Handle tool names with spaces by matching up to the first (
paren_pos = payload.find("(")
if paren_pos != -1:
tool_name_raw = payload[:paren_pos].strip()
args_text = payload[paren_pos+1:]
# Find matching closing paren
paren_close = args_text.rfind(")")
if paren_close != -1:
args_text = args_text[:paren_close].strip()
name = _normalize_tool_name(tool_name_raw)
args = _parse_args(args_text)
if name:
return {
"tool_name": name,
"args": args,
"raw_tool_call": raw_tool_call,
}
return None
def _extract_tool_name_from_payload(text: str) -> Optional[str]:
"""Extract tool name from payload text, handling various naming conventions."""
# Handle function call syntax: web_search(...) or search(...) etc.
paren_pos = text.find("(")
if paren_pos > 0:
potential_name = text[:paren_pos].strip()
# Remove brackets/quotes
potential_name = potential_name.strip("[]\"'")
normalized = _normalize_tool_name(potential_name)
if normalized:
return normalized
# Look for specific tool name patterns
# Check for variants: web_search, websearch, search_web, web search, etc.
variants = [
(r'\bweb[_\s]search\b', 'web_search'),
(r'\bwebsearch\b', 'web_search'),
(r'\bsearch[_\s]web\b', 'web_search'),
(r'\bsearch\b', 'web_search'), # Default 'search' to 'web_search'
(r'搜索', 'web_search'),
]
for pattern, tool in variants:
if re.search(pattern, text, re.IGNORECASE):
return tool
return None
def detect_tool_call(generation_result: Dict[str, Any]) -> Dict[str, Any]:
"""Return {'type':'tool'|'no_tool', 'args':...} from generator output."""
# STRICTLY use raw_text (the final response). Tool calls should appear in
# the final answer, NOT in internal thinking. Do not check raw_think.
raw_text = str(generation_result.get("raw_text", "")).strip()
if not raw_text:
# Fallback to text if raw_text is missing
raw_text = str(generation_result.get("text", "")).strip()
print(f"[TOOL] Detecting tool call from raw_text preview: {raw_text[:120]!r}")
parsed_tool = _parse_tool_call(raw_text)
if parsed_tool:
print(
f"[TOOL] Tool detected: name={parsed_tool['tool_name']}, "
f"args={parsed_tool.get('args', {})}"
)
return {
"type": "tool",
"args": {
"tool_name": parsed_tool["tool_name"],
"args": parsed_tool.get("args", {}),
"raw_tool_call": parsed_tool["raw_tool_call"],
"raw_text": raw_text,
"template_token_count": generation_result.get("template_token_count", 0),
"formatted_prompt": generation_result.get("formatted_prompt", ""),
"input_length": generation_result.get("input_length", 0),
"generated_tokens": generation_result.get("generated_tokens", 0),
},
}
print("[TOOL] No tool call detected; returning normal response payload")
no_tool_payload = {
"text": str(generation_result.get("text", _strip_special_tokens(raw_text))).strip(),
"raw_text": raw_text,
"template_token_count": generation_result.get("template_token_count", 0),
"formatted_prompt": generation_result.get("formatted_prompt", ""),
"input_length": generation_result.get("input_length", 0),
"generated_tokens": generation_result.get("generated_tokens", 0),
}
return {
"type": "no_tool",
"args": no_tool_payload,
}