| """Tool-call detection and parsing for Autobot Instruct outputs.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
| from typing import Any, Dict, Optional, Tuple |
|
|
# Chat-template control tokens that _strip_special_tokens() deletes from model
# output; detect_tool_call() uses that stripped text as the fallback response
# text when no tool call is found.
# NOTE(review): "<|tool_call_start|>" / "<|tool_call_end|>" are not listed here,
# so they are not removed by _strip_special_tokens() — confirm that is intended.
SPECIAL_TOKENS = [
    "<|im_end|>",
    "<|im_start|>",
    "<|endoftext|>",
    "<|startoftext|>",
    "<|tool_list_start|>",
    "<|tool_list_end|>",
]
|
|
|
|
def _strip_special_tokens(text: str) -> str:
    """Return *text* with every SPECIAL_TOKENS occurrence deleted and ends trimmed.

    Tokens are removed one kind at a time, in the order they appear in
    SPECIAL_TOKENS; surrounding whitespace is trimmed only after all removals.
    """
    for special in SPECIAL_TOKENS:
        text = text.replace(special, "")
    return text.strip()
|
|
|
|
| def _normalize_tool_name(name: str) -> str: |
| if not name: |
| return "" |
| |
| normalized = name.strip().lower().replace(" ", "_").replace("-", "_") |
| normalized_ascii = re.sub(r"[^a-z0-9_]", "", normalized) |
| |
| |
| if "websearch" in normalized_ascii: |
| return "web_search" |
| if "searchweb" in normalized_ascii or "search_web" in normalized_ascii: |
| return "web_search" |
| if "web" in normalized_ascii and "search" in normalized_ascii: |
| return "web_search" |
| if "search" in normalized_ascii and "web" in normalized_ascii: |
| return "web_search" |
| if normalized_ascii == "search": |
| return "web_search" |
| if "搜索" in normalized or "搜" in normalized: |
| return "web_search" |
| if normalized_ascii.startswith("web_") or normalized_ascii == "web": |
| return "web_search" |
| |
| return normalized |
|
|
|
|
| def _extract_tool_payload(text: str) -> Optional[Tuple[str, str]]: |
| """Only extract tool payload when explicit <|tool_call_start|> marker is present.""" |
| start_token = "<|tool_call_start|>" |
| end_token = "<|tool_call_end|>" |
|
|
| if start_token not in text: |
| return None |
|
|
| print("[TOOL] Tool call marker detected in generated text") |
| start_pos = text.find(start_token) + len(start_token) |
| end_pos = text.find(end_token, start_pos) |
| if end_pos != -1: |
| payload = text[start_pos:end_pos].strip() |
| else: |
| payload = text[start_pos:].strip() |
|
|
| |
| while (payload.startswith("[") and payload.endswith("]")) or \ |
| (payload.startswith("(") and payload.endswith(")")): |
| payload = payload[1:-1].strip() |
|
|
| raw_tool_call = f"{start_token}{payload}{end_token}" |
| return payload, raw_tool_call |
|
|
|
|
| def _parse_args(args_text: str) -> Dict[str, Any]: |
| """Parse function call arguments like: query="value" or query='value' or just "value" """ |
| args: Dict[str, Any] = {} |
| |
| if not args_text.strip(): |
| return args |
|
|
| |
| |
| for key, quoted in re.findall(r'([a-zA-Z_]\w*)\s*=\s*(\'[^\']*\'|"[^"]*")', args_text): |
| |
| args[key] = quoted[1:-1] |
|
|
| if args: |
| return args |
|
|
| |
| positional = re.match(r'^\s*(\'[^\']*\'|"[^"]*")\s*$', args_text) |
| if positional: |
| |
| args["query"] = positional.group(1)[1:-1] |
| return args |
| |
| |
| args_text = args_text.strip() |
| if args_text: |
| args["query"] = args_text |
|
|
| return args |
|
|
|
|
def _parse_keyword_query_call(payload: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Strategy 1 ("aggressive"): find a ``query=...`` keyword anywhere in *payload*.

    Returns ``(tool_name, args)`` only when both a query and a tool name can be
    recovered; ``max_results=N`` is added to ``args`` as an int when present.
    Returns ``None`` otherwise so the next strategy is tried.
    """
    query_match = re.search(
        r'query\s*=\s*(?:"([^"]+)"|\'([^\']+)\'|([^,)\s]+(?:\s+[^,)]+)*))',
        payload,
        re.IGNORECASE,
    )
    if not query_match:
        return None
    query = query_match.group(1) or query_match.group(2) or query_match.group(3)
    if not query:
        return None

    query = query.strip()
    print(f"[TOOL] Extracted query: {query}")

    tool_name = _extract_tool_name_from_payload(payload)
    if not tool_name:
        return None

    args: Dict[str, Any] = {"query": query}
    max_results_match = re.search(r'max_results?\s*=\s*(\d+)', payload, re.IGNORECASE)
    if max_results_match:
        args["max_results"] = int(max_results_match.group(1))

    print(f"[TOOL] Aggressive extraction successful: tool={tool_name}, query={query}")
    return tool_name, args


def _parse_json_tool_call(payload: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Strategy 2: parse a well-formed JSON object payload.

    Accepts ``{"tool_call": {"name": ..., "args": {...}}}`` and
    ``{"name": ..., "args": {...}}`` (a non-dict ``args`` becomes ``{}`` in the
    second form). Returns ``None`` when the payload is not a JSON object or no
    usable tool name is found.
    """
    if not (payload.startswith("{") and payload.endswith("}")):
        return None
    try:
        obj = json.loads(payload)
    except (ValueError, RecursionError) as json_err:
        # json.JSONDecodeError subclasses ValueError; RecursionError covers
        # pathologically nested input. Either way, let the regex strategy try.
        print(f"[TOOL] JSON parse failed: {json_err}, trying regex fallbacks")
        return None
    if not isinstance(obj, dict):
        return None

    if "tool_call" in obj:
        tool_call = obj.get("tool_call") or {}
        # A non-dict "tool_call" (e.g. a bare string) is skipped rather than
        # raising, so the top-level "name" form below still gets a chance.
        if isinstance(tool_call, dict):
            name = _normalize_tool_name(str(tool_call.get("name", "")))
            args = tool_call.get("args") or {}
            if isinstance(args, dict) and name:
                return name, args

    if "name" in obj:
        name = _normalize_tool_name(str(obj.get("name", "")))
        args = obj.get("args") if isinstance(obj.get("args"), dict) else {}
        if name:
            return name, args

    return None


def _parse_regex_tool_call(payload: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Strategy 3: scrape key/value pairs from JSON-like text that did not parse.

    The tool name comes from a ``"tool_name": "..."`` pair, or is
    ``"web_search"`` when a ``"search": "..."`` pair is present. Returns
    ``None`` when neither yields a tool name.
    """
    tool_name_match = re.search(r'"tool_name"\s*:\s*"([^"]+)"', payload)
    if tool_name_match:
        tool_name = _normalize_tool_name(tool_name_match.group(1))
    elif re.search(r'"search"\s*:\s*"([^"]+)"', payload):
        tool_name = "web_search"
    else:
        tool_name = None
    if not tool_name:
        return None

    args: Dict[str, Any] = {}
    query_match = re.search(r'"query"\s*:\s*"([^"]+)"', payload)
    if query_match:
        args["query"] = query_match.group(1)
    else:
        # No explicit "query" key: use the first quoted string value instead.
        # NOTE(review): when "tool_name" comes first, this picks up the tool
        # name itself as the query.
        first_string_match = re.search(r':\s*"([^"]+)"', payload)
        if first_string_match:
            args["query"] = first_string_match.group(1)

    max_results_match = re.search(r'"max_[Rr]esults?"\s*:\s*(\d+)', payload)
    if max_results_match:
        args["max_results"] = int(max_results_match.group(1))

    # Copy every other string-valued pair; keys that identify the tool are skipped.
    for key, value in re.findall(r'"([a-zA-Z_]+)"\s*:\s*"([^"]*)"', payload):
        if key.lower() not in ("tool_name", "name", "tool_call", "search"):
            args[key] = value

    print(f"[TOOL] Extracted via regex: tool_name={tool_name}, args={args}")
    return tool_name, args


def _parse_function_style_call(payload: str) -> Optional[Tuple[str, Dict[str, Any]]]:
    """Strategy 4: parse ``name(arguments)`` syntax, delegating arguments to _parse_args."""
    paren_pos = payload.find("(")
    if paren_pos == -1:
        return None

    name = _normalize_tool_name(payload[:paren_pos].strip())
    if not name:
        return None

    args_text = payload[paren_pos + 1:]
    # Cut at the LAST ")" so parentheses inside argument values survive.
    paren_close = args_text.rfind(")")
    if paren_close != -1:
        args_text = args_text[:paren_close].strip()

    return name, _parse_args(args_text)


def _parse_tool_call(text: str) -> Optional[Dict[str, Any]]:
    """Parse the tool call embedded in generated *text*.

    The payload between ``<|tool_call_start|>`` / ``<|tool_call_end|>`` is
    handed to four strategies in order, and the first that succeeds wins:
    ``query=`` keyword scraping, strict JSON, JSON-like regex scraping, and
    ``name(args)`` function syntax.

    Returns:
        ``{"tool_name": str, "args": dict, "raw_tool_call": str}``, or ``None``
        when there is no start marker or no strategy recognizes the payload.
    """
    extracted = _extract_tool_payload(text)
    if not extracted:
        return None

    payload, raw_tool_call = extracted
    payload = payload.strip()
    print(f"[TOOL] Parsing payload preview: {payload[:120]}")
    # _extract_tool_payload only succeeds when the start marker is present, so
    # the keyword ("aggressive") strategy always runs first.
    print("[TOOL] Tool call marker found - using aggressive extraction")

    strategies = (
        _parse_keyword_query_call,
        _parse_json_tool_call,
        _parse_regex_tool_call,
        _parse_function_style_call,
    )
    for strategy in strategies:
        parsed = strategy(payload)
        if parsed is not None:
            tool_name, args = parsed
            return {
                "tool_name": tool_name,
                "args": args,
                "raw_tool_call": raw_tool_call,
            }

    return None
|
|
|
|
| def _extract_tool_name_from_payload(text: str) -> Optional[str]: |
| """Extract tool name from payload text, handling various naming conventions.""" |
| |
| paren_pos = text.find("(") |
| if paren_pos > 0: |
| potential_name = text[:paren_pos].strip() |
| |
| potential_name = potential_name.strip("[]\"'") |
| normalized = _normalize_tool_name(potential_name) |
| if normalized: |
| return normalized |
| |
| |
| |
| variants = [ |
| (r'\bweb[_\s]search\b', 'web_search'), |
| (r'\bwebsearch\b', 'web_search'), |
| (r'\bsearch[_\s]web\b', 'web_search'), |
| (r'\bsearch\b', 'web_search'), |
| (r'搜索', 'web_search'), |
| ] |
| |
| for pattern, tool in variants: |
| if re.search(pattern, text, re.IGNORECASE): |
| return tool |
| |
| return None |
|
|
|
|
def detect_tool_call(generation_result: Dict[str, Any]) -> Dict[str, Any]:
    """Classify a generator result as a tool call or a plain response.

    Args:
        generation_result: Generator output. Reads ``raw_text`` (falling back
            to ``text`` when missing, None or blank), ``text``, and the
            bookkeeping fields ``template_token_count``, ``formatted_prompt``,
            ``input_length`` and ``generated_tokens`` (defaulting to 0 / ""
            when absent).

    Returns:
        ``{"type": "tool", "args": {...}}`` holding ``tool_name``, ``args``,
        ``raw_tool_call``, ``raw_text`` and the bookkeeping fields when a tool
        call is parsed; otherwise ``{"type": "no_tool", "args": {...}}`` holding
        ``text``, ``raw_text`` and the bookkeeping fields.
    """

    def _text_field(key: str) -> str:
        # A present-but-None value is treated like a missing one; str(None)
        # would otherwise leak the literal text "None" into the output.
        value = generation_result.get(key)
        return "" if value is None else str(value).strip()

    raw_text = _text_field("raw_text") or _text_field("text")
    print(f"[TOOL] Detecting tool call from raw_text preview: {raw_text[:120]!r}")

    # Bookkeeping fields copied into both result shapes.
    metadata = {
        "template_token_count": generation_result.get("template_token_count", 0),
        "formatted_prompt": generation_result.get("formatted_prompt", ""),
        "input_length": generation_result.get("input_length", 0),
        "generated_tokens": generation_result.get("generated_tokens", 0),
    }

    parsed_tool = _parse_tool_call(raw_text)
    if parsed_tool:
        print(
            f"[TOOL] Tool detected: name={parsed_tool['tool_name']}, "
            f"args={parsed_tool.get('args', {})}"
        )
        return {
            "type": "tool",
            "args": {
                "tool_name": parsed_tool["tool_name"],
                "args": parsed_tool.get("args", {}),
                "raw_tool_call": parsed_tool["raw_tool_call"],
                "raw_text": raw_text,
                **metadata,
            },
        }

    print("[TOOL] No tool call detected; returning normal response payload")
    text = generation_result.get("text")
    if text is None:
        # No usable "text" field: fall back to the raw output minus template tokens.
        text = _strip_special_tokens(raw_text)
    return {
        "type": "no_tool",
        "args": {
            "text": str(text).strip(),
            "raw_text": raw_text,
            **metadata,
        },
    }
|
|