|
|
""" |
|
|
XML Tool Call Parser Module |
|
|
|
|
|
This module provides a reliable XML tool call parsing system that supports |
|
|
the Cursor-style format with structured function_calls blocks. |
|
|
""" |
|
|
|
|
|
import re |
|
|
import xml.etree.ElementTree as ET |
|
|
from typing import List, Dict, Any, Optional, Tuple |
|
|
from dataclasses import dataclass |
|
|
import json |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@dataclass |
|
|
class XMLToolCall: |
|
|
"""Represents a parsed XML tool call.""" |
|
|
function_name: str |
|
|
parameters: Dict[str, Any] |
|
|
raw_xml: str |
|
|
parsing_details: Dict[str, Any] |
|
|
|
|
|
|
|
|
class XMLToolParser: |
|
|
""" |
|
|
Parser for XML tool calls using the Cursor-style format: |
|
|
|
|
|
<function_calls> |
|
|
<invoke name="function_name"> |
|
|
<parameter name="param_name">param_value</parameter> |
|
|
... |
|
|
</invoke> |
|
|
</function_calls> |
|
|
""" |
|
|
|
|
|
|
|
|
FUNCTION_CALLS_PATTERN = re.compile( |
|
|
r'<function_calls>(.*?)</function_calls>', |
|
|
re.DOTALL | re.IGNORECASE |
|
|
) |
|
|
|
|
|
INVOKE_PATTERN = re.compile( |
|
|
r'<invoke\s+name=["\']([^"\']+)["\']>(.*?)</invoke>', |
|
|
re.DOTALL | re.IGNORECASE |
|
|
) |
|
|
|
|
|
PARAMETER_PATTERN = re.compile( |
|
|
r'<parameter\s+name=["\']([^"\']+)["\']>(.*?)</parameter>', |
|
|
re.DOTALL | re.IGNORECASE |
|
|
) |
|
|
|
|
|
def __init__(self, strict_mode: bool = False): |
|
|
""" |
|
|
Initialize the XML tool parser. |
|
|
|
|
|
Args: |
|
|
strict_mode: If True, only accept the exact format. If False, |
|
|
also try to parse legacy formats for backwards compatibility. |
|
|
""" |
|
|
self.strict_mode = strict_mode |
|
|
|
|
|
def parse_content(self, content: str) -> List[XMLToolCall]: |
|
|
""" |
|
|
Parse XML tool calls from content. |
|
|
|
|
|
Args: |
|
|
content: The text content potentially containing XML tool calls |
|
|
|
|
|
Returns: |
|
|
List of parsed XMLToolCall objects |
|
|
""" |
|
|
tool_calls = [] |
|
|
|
|
|
|
|
|
function_calls_matches = self.FUNCTION_CALLS_PATTERN.findall(content) |
|
|
|
|
|
for fc_content in function_calls_matches: |
|
|
|
|
|
invoke_matches = self.INVOKE_PATTERN.findall(fc_content) |
|
|
|
|
|
for function_name, invoke_content in invoke_matches: |
|
|
try: |
|
|
tool_call = self._parse_invoke_block( |
|
|
function_name, |
|
|
invoke_content, |
|
|
fc_content |
|
|
) |
|
|
if tool_call: |
|
|
tool_calls.append(tool_call) |
|
|
except Exception as e: |
|
|
logger.error(f"Error parsing invoke block for {function_name}: {e}") |
|
|
|
|
|
|
|
|
if not self.strict_mode and not tool_calls: |
|
|
tool_calls.extend(self._parse_legacy_format(content)) |
|
|
|
|
|
return tool_calls |
|
|
|
|
|
def _parse_invoke_block( |
|
|
self, |
|
|
function_name: str, |
|
|
invoke_content: str, |
|
|
full_block: str |
|
|
) -> Optional[XMLToolCall]: |
|
|
"""Parse a single invoke block into an XMLToolCall.""" |
|
|
parameters = {} |
|
|
parsing_details = { |
|
|
"format": "v2", |
|
|
"function_name": function_name, |
|
|
"raw_parameters": {} |
|
|
} |
|
|
|
|
|
|
|
|
param_matches = self.PARAMETER_PATTERN.findall(invoke_content) |
|
|
|
|
|
for param_name, param_value in param_matches: |
|
|
|
|
|
param_value = param_value.strip() |
|
|
|
|
|
|
|
|
parsed_value = self._parse_parameter_value(param_value) |
|
|
|
|
|
parameters[param_name] = parsed_value |
|
|
parsing_details["raw_parameters"][param_name] = param_value |
|
|
|
|
|
|
|
|
invoke_pattern = re.compile( |
|
|
rf'<invoke\s+name=["\']{re.escape(function_name)}["\']>.*?</invoke>', |
|
|
re.DOTALL | re.IGNORECASE |
|
|
) |
|
|
raw_xml_match = invoke_pattern.search(full_block) |
|
|
raw_xml = raw_xml_match.group(0) if raw_xml_match else f"<invoke name=\"{function_name}\">...</invoke>" |
|
|
|
|
|
return XMLToolCall( |
|
|
function_name=function_name, |
|
|
parameters=parameters, |
|
|
raw_xml=raw_xml, |
|
|
parsing_details=parsing_details |
|
|
) |
|
|
|
|
|
def _parse_parameter_value(self, value: str) -> Any: |
|
|
""" |
|
|
Parse a parameter value, attempting to convert to appropriate type. |
|
|
|
|
|
Args: |
|
|
value: The string value to parse |
|
|
|
|
|
Returns: |
|
|
Parsed value (could be dict, list, bool, int, float, or str) |
|
|
""" |
|
|
value = value.strip() |
|
|
|
|
|
|
|
|
if value.startswith(('{', '[')): |
|
|
try: |
|
|
return json.loads(value) |
|
|
except json.JSONDecodeError: |
|
|
pass |
|
|
|
|
|
|
|
|
if value.lower() in ('true', 'false'): |
|
|
return value.lower() == 'true' |
|
|
|
|
|
|
|
|
try: |
|
|
if '.' in value: |
|
|
return float(value) |
|
|
else: |
|
|
return int(value) |
|
|
except ValueError: |
|
|
pass |
|
|
|
|
|
|
|
|
return value |
|
|
|
|
|
def _parse_legacy_format(self, content: str) -> List[XMLToolCall]: |
|
|
""" |
|
|
Parse legacy XML tool formats for backwards compatibility. |
|
|
This handles formats like <tool_name>...</tool_name> or |
|
|
<tool_name param="value">...</tool_name> |
|
|
""" |
|
|
tool_calls = [] |
|
|
|
|
|
|
|
|
tag_pattern = re.compile(r'<([a-zA-Z][\w\-]*)((?:\s+[\w\-]+=["\'][^"\']*["\'])*)\s*>(.*?)</\1>', re.DOTALL) |
|
|
|
|
|
for match in tag_pattern.finditer(content): |
|
|
tag_name = match.group(1) |
|
|
attributes_str = match.group(2) |
|
|
inner_content = match.group(3) |
|
|
|
|
|
|
|
|
if tag_name in ('function_calls', 'invoke', 'parameter'): |
|
|
continue |
|
|
|
|
|
parameters = {} |
|
|
parsing_details = { |
|
|
"format": "legacy", |
|
|
"tag_name": tag_name, |
|
|
"attributes": {}, |
|
|
"inner_content": inner_content.strip() |
|
|
} |
|
|
|
|
|
|
|
|
if attributes_str: |
|
|
attr_pattern = re.compile(r'([\w\-]+)=["\']([^"\']*)["\']') |
|
|
for attr_match in attr_pattern.finditer(attributes_str): |
|
|
attr_name = attr_match.group(1) |
|
|
attr_value = attr_match.group(2) |
|
|
parameters[attr_name] = self._parse_parameter_value(attr_value) |
|
|
parsing_details["attributes"][attr_name] = attr_value |
|
|
|
|
|
|
|
|
if inner_content.strip() and not parameters: |
|
|
parameters['content'] = inner_content.strip() |
|
|
|
|
|
|
|
|
function_name = tag_name.replace('-', '_') |
|
|
|
|
|
tool_calls.append(XMLToolCall( |
|
|
function_name=function_name, |
|
|
parameters=parameters, |
|
|
raw_xml=match.group(0), |
|
|
parsing_details=parsing_details |
|
|
)) |
|
|
|
|
|
return tool_calls |
|
|
|
|
|
def format_tool_call(self, function_name: str, parameters: Dict[str, Any]) -> str: |
|
|
""" |
|
|
Format a tool call in the Cursor-style XML format. |
|
|
|
|
|
Args: |
|
|
function_name: Name of the function to call |
|
|
parameters: Dictionary of parameters |
|
|
|
|
|
Returns: |
|
|
Formatted XML string |
|
|
""" |
|
|
lines = ['<function_calls>', '<invoke name="{}">'.format(function_name)] |
|
|
|
|
|
for param_name, param_value in parameters.items(): |
|
|
|
|
|
if isinstance(param_value, (dict, list)): |
|
|
value_str = json.dumps(param_value) |
|
|
elif isinstance(param_value, bool): |
|
|
value_str = str(param_value).lower() |
|
|
else: |
|
|
value_str = str(param_value) |
|
|
|
|
|
lines.append('<parameter name="{}">{}</parameter>'.format( |
|
|
param_name, value_str |
|
|
)) |
|
|
|
|
|
lines.extend(['</invoke>', '</function_calls>']) |
|
|
return '\n'.join(lines) |
|
|
|
|
|
def validate_tool_call(self, tool_call: XMLToolCall, expected_params: Optional[Dict[str, type]] = None) -> Tuple[bool, Optional[str]]: |
|
|
""" |
|
|
Validate a tool call against expected parameters. |
|
|
|
|
|
Args: |
|
|
tool_call: The XMLToolCall to validate |
|
|
expected_params: Optional dict of parameter names to expected types |
|
|
|
|
|
Returns: |
|
|
Tuple of (is_valid, error_message) |
|
|
""" |
|
|
if not tool_call.function_name: |
|
|
return False, "Function name is required" |
|
|
|
|
|
if expected_params: |
|
|
for param_name, expected_type in expected_params.items(): |
|
|
if param_name not in tool_call.parameters: |
|
|
return False, f"Missing required parameter: {param_name}" |
|
|
|
|
|
param_value = tool_call.parameters[param_name] |
|
|
if not isinstance(param_value, expected_type): |
|
|
return False, f"Parameter {param_name} should be of type {expected_type.__name__}" |
|
|
|
|
|
return True, None |
|
|
|
|
|
|
|
|
|
|
|
def parse_xml_tool_calls(content: str, strict_mode: bool = False) -> List[XMLToolCall]: |
|
|
""" |
|
|
Parse XML tool calls from content. |
|
|
|
|
|
Args: |
|
|
content: The text content potentially containing XML tool calls |
|
|
strict_mode: If True, only accept the Cursor-style format |
|
|
|
|
|
Returns: |
|
|
List of parsed XMLToolCall objects |
|
|
""" |
|
|
parser = XMLToolParser(strict_mode=strict_mode) |
|
|
return parser.parse_content(content) |