File size: 9,014 Bytes
bdc2878
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
"""Tool call parser — extract structured tool calls from model text output.

Tries multiple formats in priority order:
  1. <tool_calls> XML  (canonical format we inject)
  2. JSON envelope {"tool_calls": [...]}
  3. JSON array  [{"name": ..., "input": ...}]
  4. Alternative XML tags (<function_call>, <invoke>)

Returns a list of ParsedToolCall dataclasses.
"""

from __future__ import annotations

import json
import os
import re
import time
from dataclasses import dataclass, field
from typing import Any


# ---------------------------------------------------------------------------
# Data model
# ---------------------------------------------------------------------------

@dataclass
class ParsedToolCall:
    """A single tool invocation extracted from model output."""

    call_id: str
    name: str
    # Argument payload as text: string inputs are stored as given, any other
    # value is serialized to compact JSON by ``make``.
    arguments: str

    @staticmethod
    def make(name: str, arguments: Any) -> "ParsedToolCall":
        """Build a call with a freshly generated ``call_id``.

        Args:
            name: Tool name to record on the call.
            arguments: Either an already-serialized string (kept verbatim) or
                any value to be dumped as compact JSON. Values that cannot be
                serialized fall back to ``"{}"``.

        Returns:
            A new ``ParsedToolCall`` whose id is ``call_`` followed by the
            current time in milliseconds and six random hex digits.
        """
        if not isinstance(arguments, str):
            try:
                arguments = json.dumps(
                    arguments, ensure_ascii=False, separators=(",", ":")
                )
            except (TypeError, ValueError):
                # Unserializable payload: degrade to an empty JSON object.
                arguments = "{}"

        millis = int(time.time() * 1000)
        suffix = os.urandom(3).hex()
        return ParsedToolCall(
            call_id=f"call_{millis}{suffix}",
            name=name,
            arguments=arguments,
        )


@dataclass
class ParseResult:
    """Outcome of one ``parse_tool_calls`` run."""

    # Tool calls that were extracted; stays empty when nothing could be parsed.
    calls: list[ParsedToolCall] = field(default_factory=list)
    # Set to True as soon as tool-call markers are detected in the text,
    # before any parser runs, so it can be True while ``calls`` is empty.
    saw_tool_syntax: bool = False   # detected XML/JSON envelope even if parsing failed


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------

def parse_tool_calls(
    text: str,
    available_tools: list[str] | None = None,
) -> ParseResult:
    """Extract structured tool calls from model-generated text.

    The parsers are tried in priority order (canonical XML, JSON envelope,
    bare JSON array, alternative XML tags); the first one that yields any
    calls wins and the remaining parsers are not run.

    Args:
        text: Full or partial model output text.
        available_tools: When non-empty, only calls whose name appears in
            this list are kept (case-sensitive). ``None`` or an empty list
            disables the filter.

    Returns:
        A ``ParseResult``. ``saw_tool_syntax`` is True whenever tool-call
        markers were detected, even if no call could be extracted.
    """
    outcome = ParseResult()

    # Blank input, or text without any tool-call marker: nothing to parse.
    if not text or not text.strip() or not _has_tool_syntax(text):
        return outcome
    outcome.saw_tool_syntax = True

    found: list[ParsedToolCall] = []
    for parser in (
        _parse_xml_tool_calls,
        _parse_json_envelope,
        _parse_json_array,
        _parse_alt_xml,
    ):
        found = parser(text)
        if found:
            break

    if found and available_tools:
        found = [call for call in found if call.name in available_tools]

    outcome.calls = found or []
    return outcome


# ---------------------------------------------------------------------------
# Syntax detection
# ---------------------------------------------------------------------------

_TOOL_SYNTAX_PATTERNS = re.compile(
    r"<tool_calls|<tool_call|<function_call|<invoke\s|"
    r'"tool_calls"\s*:|\btool_calls\b',
    re.IGNORECASE,
)

def _has_tool_syntax(text: str) -> bool:
    return bool(_TOOL_SYNTAX_PATTERNS.search(text))


# ---------------------------------------------------------------------------
# Parser 1: <tool_calls> XML (canonical)
# ---------------------------------------------------------------------------

_XML_ROOT_RE    = re.compile(r"<tool_calls\s*>(.*?)</tool_calls\s*>", re.DOTALL | re.IGNORECASE)
_XML_CALL_RE    = re.compile(r"<tool_call\s*>(.*?)</tool_call\s*>",   re.DOTALL | re.IGNORECASE)
_XML_NAME_RE    = re.compile(r"<tool_name\s*>(.*?)</tool_name\s*>",   re.DOTALL | re.IGNORECASE)
_XML_PARAMS_RE  = re.compile(r"<parameters\s*>(.*?)</parameters\s*>", re.DOTALL | re.IGNORECASE)


def _parse_xml_tool_calls(text: str) -> list[ParsedToolCall]:
    root_m = _XML_ROOT_RE.search(text)
    if not root_m:
        return []
    calls: list[ParsedToolCall] = []
    for call_m in _XML_CALL_RE.finditer(root_m.group(1)):
        inner = call_m.group(1)
        name_m   = _XML_NAME_RE.search(inner)
        params_m = _XML_PARAMS_RE.search(inner)
        if not name_m:
            continue
        name   = name_m.group(1).strip()
        params = params_m.group(1).strip() if params_m else "{}"
        parsed_args = _parse_json_tolerant(params)
        if parsed_args is None:
            continue
        calls.append(ParsedToolCall.make(name, parsed_args))
    return calls


# ---------------------------------------------------------------------------
# Parser 2: {"tool_calls": [...]} JSON envelope
# ---------------------------------------------------------------------------

def _parse_json_envelope(text: str) -> list[ParsedToolCall]:
    # Only attempt if the text literally contains "tool_calls" key
    if '"tool_calls"' not in text:
        return []
    obj = _extract_outermost_json_obj(text)
    if not isinstance(obj, dict):
        return []
    raw_calls = obj.get("tool_calls")
    if not isinstance(raw_calls, list):
        return []
    return _extract_from_call_list(raw_calls)


_JSON_DECODER = json.JSONDecoder()


def _extract_outermost_json_obj(text: str) -> Any:
    """Find and parse the first top-level JSON object in *text*.

    Uses JSONDecoder.raw_decode which handles the object boundary correctly
    without a manual bracket-depth tracker.
    """
    start = text.find("{")
    if start == -1:
        return None
    try:
        obj, _ = _JSON_DECODER.raw_decode(text, start)
        return obj
    except (json.JSONDecodeError, ValueError):
        # Attempt repair on the substring from first '{' onward
        end = text.rfind("}") + 1
        return _try_repair_json(text[start:end]) if end > start else None


# ---------------------------------------------------------------------------
# Parser 3: bare JSON array [{"name":..., "input":...}]
# ---------------------------------------------------------------------------

_JSON_ARR_RE = re.compile(r"\[[\s\S]+\]", re.DOTALL)

def _parse_json_array(text: str) -> list[ParsedToolCall]:
    m = _JSON_ARR_RE.search(text)
    if not m:
        return []
    try:
        arr = json.loads(m.group(0))
    except (json.JSONDecodeError, ValueError):
        return []
    if not isinstance(arr, list):
        return []
    return _extract_from_call_list(arr)


def _extract_from_call_list(items: list[Any]) -> list[ParsedToolCall]:
    calls: list[ParsedToolCall] = []
    for item in items:
        if not isinstance(item, dict):
            continue
        name = (item.get("name") or item.get("tool_name") or "").strip()
        args = item.get("input") or item.get("arguments") or item.get("parameters") or {}
        if not name:
            continue
        calls.append(ParsedToolCall.make(name, args))
    return calls


# ---------------------------------------------------------------------------
# Parser 4: alternative XML tags (<function_call>, <invoke name="...">)
# ---------------------------------------------------------------------------

_FC_RE      = re.compile(r"<function_call\s*>(.*?)</function_call\s*>", re.DOTALL | re.IGNORECASE)
_INVOKE_RE  = re.compile(r'<invoke\s+name=["\']?(\w+)["\']?\s*>(.*?)</invoke\s*>',  re.DOTALL | re.IGNORECASE)
_FC_NAME_RE = re.compile(r"<name\s*>(.*?)</name\s*>",                  re.DOTALL | re.IGNORECASE)
_FC_ARGS_RE = re.compile(r"<arguments\s*>(.*?)</arguments\s*>",        re.DOTALL | re.IGNORECASE)


def _parse_alt_xml(text: str) -> list[ParsedToolCall]:
    calls: list[ParsedToolCall] = []

    # <function_call><name>...</name><arguments>...</arguments></function_call>
    for m in _FC_RE.finditer(text):
        inner  = m.group(1)
        name_m = _FC_NAME_RE.search(inner)
        args_m = _FC_ARGS_RE.search(inner)
        if not name_m:
            continue
        name = name_m.group(1).strip()
        args = _parse_json_tolerant(args_m.group(1).strip() if args_m else "{}")
        if args is None:
            continue
        calls.append(ParsedToolCall.make(name, args))

    # <invoke name="tool_name">...</invoke>
    for m in _INVOKE_RE.finditer(text):
        name  = m.group(1).strip()
        inner = m.group(2)
        args  = _parse_json_tolerant(inner.strip())
        if args is None:
            args = {}
        calls.append(ParsedToolCall.make(name, args))

    return calls


# ---------------------------------------------------------------------------
# JSON helpers
# ---------------------------------------------------------------------------

def _parse_json_tolerant(s: str) -> Any:
    """Try to parse JSON; attempt light repair on failure."""
    if not s:
        return {}
    try:
        return json.loads(s)
    except (json.JSONDecodeError, ValueError):
        repaired = _try_repair_json(s)
        return repaired


def _try_repair_json(s: str) -> Any:
    """Very lightweight JSON repair: fix unescaped newlines inside strings."""
    try:
        # Replace literal newlines inside strings (common model output issue)
        fixed = re.sub(r'(?<!\\)\n', r'\\n', s)
        return json.loads(fixed)
    except (json.JSONDecodeError, ValueError):
        return None