Explyt commited on
Commit
0acccec
·
verified ·
1 Parent(s): 95bc39e

Upload glm47_moe_tool_parser_fixed.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. glm47_moe_tool_parser_fixed.py +532 -0
glm47_moe_tool_parser_fixed.py ADDED
@@ -0,0 +1,532 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """
4
+ GLM-4 Tool Call Parser with incremental string streaming support.
5
+
6
+ This parser fixes the streaming issue reported in Issue #32829 where long string
7
+ parameters (e.g., file content with 4000+ characters of code) are buffered until
8
+ complete, causing multi-second delays before the user sees any content.
9
+
10
+ The fix streams string values incrementally as they arrive, providing a true
11
+ streaming experience for long content.
12
+ """
13
+
14
+ import ast
15
+ import json
16
+ from collections.abc import Sequence
17
+ from typing import Any
18
+
19
+ import regex as re
20
+
21
+ from vllm.entrypoints.chat_utils import make_tool_call_id
22
+ from vllm.entrypoints.openai.chat_completion.protocol import (
23
+ ChatCompletionRequest,
24
+ ChatCompletionToolsParam,
25
+ )
26
+ from vllm.entrypoints.openai.engine.protocol import (
27
+ DeltaFunctionCall,
28
+ DeltaMessage,
29
+ DeltaToolCall,
30
+ ExtractedToolCallInformation,
31
+ FunctionCall,
32
+ ToolCall,
33
+ )
34
+ from vllm.logger import init_logger
35
+ from vllm.tokenizers import TokenizerLike
36
+ from vllm.tool_parsers.abstract_tool_parser import (
37
+ ToolParser,
38
+ ToolParserManager,
39
+ )
40
+
41
+ logger = init_logger(__name__)
42
+
43
+
44
+ @ToolParserManager.register_module("glm47_fixed")
45
+ class Glm47MoeModelToolParser(ToolParser):
46
+ """Tool parser for GLM-4 models with incremental string streaming.
47
+
48
+ This parser emits tool-call deltas incrementally as arguments arrive.
49
+ For string-type parameters, content is streamed character-by-character
50
+ rather than waiting for the complete </arg_value> tag.
51
+ """
52
+
53
+ def __init__(self, tokenizer: TokenizerLike):
54
+ super().__init__(tokenizer)
55
+ # Stateful streaming fields
56
+ self.current_tool_name_sent: bool = False
57
+ self.prev_tool_call_arr: list[dict[str, Any]] = []
58
+ self.current_tool_id: int = -1
59
+ self.streamed_args_for_tool: list[str] = []
60
+
61
+ self.tool_call_start_token: str = "<tool_call>"
62
+ self.tool_call_end_token: str = "</tool_call>"
63
+ self.arg_key_start: str = "<arg_key>"
64
+ self.arg_key_end: str = "</arg_key>"
65
+ self.arg_val_start: str = "<arg_value>"
66
+ self.arg_val_end: str = "</arg_value>"
67
+
68
+ self.tool_calls_start_token = self.tool_call_start_token
69
+
70
+ self.func_call_regex = re.compile(r"<tool_call>.*?</tool_call>", re.DOTALL)
71
+
72
+ # GLM-4.7 format: <tool_call>func_name[<arg_key>...]*</tool_call>
73
+ # The function name can be followed by a newline, whitespace, or
74
+ # directly by <arg_key> tags (no separator). The arg section is
75
+ # optional so that zero-argument calls are supported.
76
+ self.func_detail_regex = re.compile(
77
+ r"<tool_call>\s*(\S+?)\s*(<arg_key>.*)?</tool_call>", re.DOTALL
78
+ )
79
+ self.func_arg_regex = re.compile(
80
+ r"<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>",
81
+ re.DOTALL,
82
+ )
83
+
84
+ if not self.model_tokenizer:
85
+ raise ValueError(
86
+ "The model tokenizer must be passed to the ToolParser "
87
+ "constructor during construction."
88
+ )
89
+
90
+ self.tool_call_start_token_id = self.vocab.get(self.tool_call_start_token)
91
+ self.tool_call_end_token_id = self.vocab.get(self.tool_call_end_token)
92
+ self._buffer: str = ""
93
+
94
+ # Streaming state for incremental tool-call streaming
95
+ self._in_tool_call: bool = False
96
+ self._current_tool_name: str | None = None
97
+ self._pending_key: str | None = None
98
+ self._streaming_string_value: bool = False
99
+ self._tool_call_ids: list[str] = []
100
+ self._args_started: list[bool] = []
101
+ self._args_closed: list[bool] = []
102
+ self._seen_keys: list[set[str]] = []
103
+
104
+ @staticmethod
105
+ def _deserialize(value: str) -> Any:
106
+ try:
107
+ return json.loads(value)
108
+ except json.JSONDecodeError:
109
+ pass
110
+
111
+ try:
112
+ return ast.literal_eval(value)
113
+ except (ValueError, SyntaxError):
114
+ pass
115
+
116
+ return value
117
+
118
+ @staticmethod
119
+ def _json_escape_string_content(s: str) -> str:
120
+ """JSON-escape string content for incremental streaming.
121
+
122
+ This escapes the content that goes INSIDE a JSON string (between quotes),
123
+ not including the surrounding quotes themselves.
124
+ """
125
+ if not s:
126
+ return ""
127
+ return json.dumps(s, ensure_ascii=False)[1:-1]
128
+
129
+ @staticmethod
130
+ def _is_string_type(
131
+ tool_name: str,
132
+ arg_name: str,
133
+ tools: list[ChatCompletionToolsParam] | None,
134
+ ) -> bool:
135
+ if tools is None:
136
+ return False
137
+ for tool in tools:
138
+ if tool.function.name != tool_name:
139
+ continue
140
+ if tool.function.parameters is None:
141
+ return False
142
+ arg_type = (
143
+ tool.function.parameters.get("properties", {})
144
+ .get(arg_name, {})
145
+ .get("type", None)
146
+ )
147
+ return arg_type == "string"
148
+ logger.debug("No tool named '%s'.", tool_name)
149
+ return False
150
+
151
+ @staticmethod
152
+ def _tools_enabled(request: ChatCompletionRequest) -> bool:
153
+ """Return whether tool parsing should be applied for this request."""
154
+ try:
155
+ tools = getattr(request, "tools", None)
156
+ tool_choice = getattr(request, "tool_choice", None)
157
+ return bool(tools) and tool_choice != "none"
158
+ except Exception:
159
+ logger.exception("Failed to determine if tools are enabled.")
160
+ return False
161
+
162
+ def adjust_request(self, request: ChatCompletionRequest) -> ChatCompletionRequest:
163
+ """Adjust request parameters for tool call token handling."""
164
+ request = super().adjust_request(request)
165
+ if request.tools and request.tool_choice != "none":
166
+ # Ensure tool call tokens (<tool_call>, </tool_call>) are not skipped
167
+ # during decoding. Even though they are not marked as special tokens,
168
+ # setting skip_special_tokens=False ensures proper handling in
169
+ # transformers 5.x where decoding behavior may have changed.
170
+ request.skip_special_tokens = False
171
+ return request
172
+
173
+ def extract_tool_calls(
174
+ self,
175
+ model_output: str,
176
+ request: ChatCompletionRequest,
177
+ ) -> ExtractedToolCallInformation:
178
+ matched_tool_calls = self.func_call_regex.findall(model_output)
179
+ logger.debug("model_output: %s", model_output)
180
+ try:
181
+ tool_calls: list[ToolCall] = []
182
+ for match in matched_tool_calls:
183
+ tc_detail = self.func_detail_regex.search(match)
184
+ if not tc_detail:
185
+ logger.warning(
186
+ "Failed to parse tool call details from: %s",
187
+ match,
188
+ )
189
+ continue
190
+ tc_name = tc_detail.group(1).strip()
191
+ tc_args = tc_detail.group(2)
192
+ pairs = self.func_arg_regex.findall(tc_args) if tc_args else []
193
+ arg_dct: dict[str, Any] = {}
194
+ for key, value in pairs:
195
+ arg_key = key.strip()
196
+ arg_val = value.strip()
197
+ if not self._is_string_type(tc_name, arg_key, request.tools):
198
+ arg_val = self._deserialize(arg_val)
199
+ logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
200
+ arg_dct[arg_key] = arg_val
201
+ tool_calls.append(
202
+ ToolCall(
203
+ type="function",
204
+ function=FunctionCall(
205
+ name=tc_name,
206
+ arguments=json.dumps(arg_dct, ensure_ascii=False),
207
+ ),
208
+ )
209
+ )
210
+ except Exception:
211
+ logger.exception("Failed to extract tool call spec")
212
+ return ExtractedToolCallInformation(
213
+ tools_called=False, tool_calls=[], content=model_output
214
+ )
215
+ else:
216
+ if len(tool_calls) > 0:
217
+ content: str | None = model_output[
218
+ : model_output.find(self.tool_calls_start_token)
219
+ ]
220
+ # Normalize empty/whitespace-only content to None
221
+ if not content or not content.strip():
222
+ content = None
223
+ return ExtractedToolCallInformation(
224
+ tools_called=True, tool_calls=tool_calls, content=content
225
+ )
226
+ return ExtractedToolCallInformation(
227
+ tools_called=False, tool_calls=[], content=model_output
228
+ )
229
+
230
+ def extract_tool_calls_streaming(
231
+ self,
232
+ previous_text: str,
233
+ current_text: str,
234
+ delta_text: str,
235
+ previous_token_ids: Sequence[int],
236
+ current_token_ids: Sequence[int],
237
+ delta_token_ids: Sequence[int],
238
+ request: ChatCompletionRequest,
239
+ ) -> DeltaMessage | None:
240
+ if not self._tools_enabled(request):
241
+ return DeltaMessage(content=delta_text) if delta_text else None
242
+
243
+ self._buffer += delta_text
244
+
245
+ while True:
246
+ if not self._in_tool_call:
247
+ start_idx = self._buffer.find(self.tool_call_start_token)
248
+ if start_idx == -1:
249
+ # Check for partial start token at end of buffer
250
+ for i in range(1, len(self.tool_call_start_token)):
251
+ if self._buffer.endswith(self.tool_call_start_token[:i]):
252
+ out = self._buffer[:-i]
253
+ self._buffer = self._buffer[-i:]
254
+ return DeltaMessage(content=out) if out else None
255
+ out = self._buffer
256
+ self._buffer = ""
257
+ return DeltaMessage(content=out) if out else None
258
+
259
+ if start_idx > 0:
260
+ out = self._buffer[:start_idx]
261
+ self._buffer = self._buffer[start_idx:]
262
+ return DeltaMessage(content=out) if out else None
263
+
264
+ self._buffer = self._buffer[len(self.tool_call_start_token) :]
265
+ self._begin_tool_call()
266
+ continue
267
+
268
+ # Parse tool name first
269
+ if not self.current_tool_name_sent:
270
+ nl = self._buffer.find("\n")
271
+ ak = self._buffer.find(self.arg_key_start)
272
+ end = self._buffer.find(self.tool_call_end_token)
273
+ candidates = [i for i in [nl, ak, end] if i != -1]
274
+ if not candidates:
275
+ return None
276
+ cut = min(candidates)
277
+ tool_name = self._buffer[:cut].strip()
278
+ if tool_name == "" and cut == end:
279
+ # Handle empty tool call like `<tool_call></tool_call>`.
280
+ # Consume the tokens and reset state to avoid infinite loop.
281
+ self._buffer = self._buffer[end + len(self.tool_call_end_token) :]
282
+ self._finish_tool_call()
283
+ self._revert_last_tool_call_state()
284
+ continue
285
+
286
+ if cut == nl:
287
+ self._buffer = self._buffer[nl + 1 :]
288
+ else:
289
+ self._buffer = self._buffer[cut:]
290
+
291
+ self._current_tool_name = tool_name
292
+ self.current_tool_name_sent = True
293
+ return self._emit_tool_name_delta(tool_name)
294
+
295
+ assert self._current_tool_name is not None
296
+
297
+ # Handle incremental string value streaming
298
+ if self._streaming_string_value:
299
+ val_end = self._buffer.find(self.arg_val_end)
300
+ if val_end != -1:
301
+ raw_content = self._buffer[:val_end]
302
+ self._buffer = self._buffer[val_end + len(self.arg_val_end) :]
303
+ self._streaming_string_value = False
304
+ self._pending_key = None
305
+
306
+ escaped = self._json_escape_string_content(raw_content)
307
+ frag = escaped + '"'
308
+ self.streamed_args_for_tool[self.current_tool_id] += frag
309
+ return self._emit_tool_args_delta(frag)
310
+ else:
311
+ # Check for partial </arg_value> at end
312
+ safe_len = len(self._buffer)
313
+ for i in range(1, len(self.arg_val_end)):
314
+ if self._buffer.endswith(self.arg_val_end[:i]):
315
+ safe_len = len(self._buffer) - i
316
+ break
317
+
318
+ if safe_len > 0:
319
+ to_emit = self._buffer[:safe_len]
320
+ self._buffer = self._buffer[safe_len:]
321
+ escaped = self._json_escape_string_content(to_emit)
322
+ if escaped:
323
+ self.streamed_args_for_tool[self.current_tool_id] += escaped
324
+ return self._emit_tool_args_delta(escaped)
325
+ return None
326
+
327
+ # If we have a pending key, parse its value
328
+ if self._pending_key is not None:
329
+ val_pos = self._buffer.find(self.arg_val_start)
330
+ if val_pos == -1:
331
+ return None
332
+ if val_pos > 0:
333
+ self._buffer = self._buffer[val_pos:]
334
+
335
+ key = (self._pending_key or "").strip()
336
+
337
+ is_string = self._is_string_type(
338
+ self._current_tool_name, key, request.tools
339
+ )
340
+
341
+ if is_string:
342
+ # String type: stream incrementally
343
+ self._buffer = self._buffer[len(self.arg_val_start) :]
344
+
345
+ if key in self._seen_keys[self.current_tool_id]:
346
+ self._pending_key = None
347
+ continue
348
+
349
+ self._seen_keys[self.current_tool_id].add(key)
350
+ key_json = json.dumps(key, ensure_ascii=False)
351
+
352
+ if not self._args_started[self.current_tool_id]:
353
+ frag = "{" + key_json + ': "'
354
+ self._args_started[self.current_tool_id] = True
355
+ else:
356
+ frag = ", " + key_json + ': "'
357
+
358
+ self.streamed_args_for_tool[self.current_tool_id] += frag
359
+ self._streaming_string_value = True
360
+ return self._emit_tool_args_delta(frag)
361
+ else:
362
+ # Non-string type: wait for complete value
363
+ val_end = self._buffer.find(self.arg_val_end)
364
+ if val_end == -1:
365
+ return None
366
+
367
+ raw_val = self._buffer[len(self.arg_val_start) : val_end].strip()
368
+ self._buffer = self._buffer[val_end + len(self.arg_val_end) :]
369
+ self._pending_key = None
370
+
371
+ frag_or_none = self._append_arg_fragment(key=key, raw_val=raw_val)
372
+ if frag_or_none:
373
+ return self._emit_tool_args_delta(frag_or_none)
374
+ continue
375
+
376
+ # Parse next arg or close
377
+ end_pos = self._buffer.find(self.tool_call_end_token)
378
+ key_pos = self._buffer.find(self.arg_key_start)
379
+ if end_pos != -1 and (key_pos == -1 or end_pos < key_pos):
380
+ self._buffer = self._buffer[end_pos + len(self.tool_call_end_token) :]
381
+ frag_or_none = self._close_args_if_needed()
382
+ # Finalize prev_tool_call_arr with complete parsed arguments
383
+ if self._current_tool_name:
384
+ try:
385
+ full_args_str = self.streamed_args_for_tool[
386
+ self.current_tool_id
387
+ ]
388
+ json.loads(full_args_str)
389
+ self.prev_tool_call_arr[self.current_tool_id] = {
390
+ "name": self._current_tool_name,
391
+ "arguments": full_args_str,
392
+ }
393
+ except (json.JSONDecodeError, IndexError) as e:
394
+ logger.warning(
395
+ "Failed to finalize tool call state for tool %d: %s",
396
+ self.current_tool_id,
397
+ e,
398
+ )
399
+ self._finish_tool_call()
400
+ return (
401
+ self._emit_tool_args_delta(frag_or_none) if frag_or_none else None
402
+ )
403
+
404
+ if key_pos == -1:
405
+ return None
406
+ if key_pos > 0:
407
+ self._buffer = self._buffer[key_pos:]
408
+ key_end = self._buffer.find(self.arg_key_end)
409
+ if key_end == -1:
410
+ return None
411
+ key = self._buffer[len(self.arg_key_start) : key_end]
412
+ self._buffer = self._buffer[key_end + len(self.arg_key_end) :]
413
+ self._pending_key = key
414
+ continue
415
+
416
+ def _ensure_tool_state(self) -> None:
417
+ while len(self._tool_call_ids) <= self.current_tool_id:
418
+ self._tool_call_ids.append(
419
+ make_tool_call_id(id_type="random", func_name=None, idx=None)
420
+ )
421
+ while len(self.streamed_args_for_tool) <= self.current_tool_id:
422
+ self.streamed_args_for_tool.append("")
423
+ while len(self.prev_tool_call_arr) <= self.current_tool_id:
424
+ self.prev_tool_call_arr.append({})
425
+ while len(self._args_started) <= self.current_tool_id:
426
+ self._args_started.append(False)
427
+ while len(self._args_closed) <= self.current_tool_id:
428
+ self._args_closed.append(False)
429
+ while len(self._seen_keys) <= self.current_tool_id:
430
+ self._seen_keys.append(set())
431
+
432
+ def _begin_tool_call(self) -> None:
433
+ if self.current_tool_id == -1:
434
+ self.current_tool_id = 0
435
+ else:
436
+ self.current_tool_id += 1
437
+ self._ensure_tool_state()
438
+ self.current_tool_name_sent = False
439
+ self._current_tool_name = None
440
+ self._pending_key = None
441
+ self._streaming_string_value = False
442
+ self._in_tool_call = True
443
+
444
+ def _finish_tool_call(self) -> None:
445
+ self._in_tool_call = False
446
+ self._current_tool_name = None
447
+ self._pending_key = None
448
+ self._streaming_string_value = False
449
+
450
+ def _revert_last_tool_call_state(self) -> None:
451
+ """Revert the state allocation for the last tool call."""
452
+ if self.current_tool_id < 0:
453
+ return
454
+ self._tool_call_ids.pop()
455
+ self.streamed_args_for_tool.pop()
456
+ self.prev_tool_call_arr.pop()
457
+ self._args_started.pop()
458
+ self._args_closed.pop()
459
+ self._seen_keys.pop()
460
+ self.current_tool_id -= 1
461
+
462
+ def _emit_tool_name_delta(self, tool_name: str) -> DeltaMessage:
463
+ self.prev_tool_call_arr[self.current_tool_id] = {
464
+ "name": self._current_tool_name,
465
+ "arguments": {},
466
+ }
467
+ return DeltaMessage(
468
+ tool_calls=[
469
+ DeltaToolCall(
470
+ index=self.current_tool_id,
471
+ id=self._tool_call_ids[self.current_tool_id],
472
+ type="function",
473
+ function=DeltaFunctionCall(
474
+ name=tool_name,
475
+ arguments="",
476
+ ).model_dump(exclude_none=True),
477
+ )
478
+ ]
479
+ )
480
+
481
+ def _emit_tool_args_delta(self, fragment: str) -> DeltaMessage:
482
+ return DeltaMessage(
483
+ tool_calls=[
484
+ DeltaToolCall(
485
+ index=self.current_tool_id,
486
+ function=DeltaFunctionCall(arguments=fragment).model_dump(
487
+ exclude_none=True
488
+ ),
489
+ )
490
+ ]
491
+ )
492
+
493
+ def _append_arg_fragment(
494
+ self,
495
+ *,
496
+ key: str,
497
+ raw_val: str,
498
+ ) -> str | None:
499
+ key = key.strip()
500
+ if not key:
501
+ return None
502
+ if key in self._seen_keys[self.current_tool_id]:
503
+ return None
504
+
505
+ # This function is only called for non-string types (already checked
506
+ # by _is_string_type in the caller), so we always deserialize.
507
+ val_obj: Any = self._deserialize(raw_val)
508
+
509
+ key_json = json.dumps(key, ensure_ascii=False)
510
+ val_json = json.dumps(val_obj, ensure_ascii=False)
511
+
512
+ if not self._args_started[self.current_tool_id]:
513
+ fragment = "{" + key_json + ": " + val_json
514
+ self._args_started[self.current_tool_id] = True
515
+ else:
516
+ fragment = "," + key_json + ": " + val_json
517
+
518
+ self._seen_keys[self.current_tool_id].add(key)
519
+ self.streamed_args_for_tool[self.current_tool_id] += fragment
520
+ return fragment
521
+
522
+ def _close_args_if_needed(self) -> str | None:
523
+ if self._args_closed[self.current_tool_id]:
524
+ return None
525
+ self._args_closed[self.current_tool_id] = True
526
+ if not self._args_started[self.current_tool_id]:
527
+ fragment = "{}"
528
+ self.streamed_args_for_tool[self.current_tool_id] = fragment
529
+ else:
530
+ fragment = "}"
531
+ self.streamed_args_for_tool[self.current_tool_id] += fragment
532
+ return fragment