WenminDeng committed on
Commit
493f80f
·
verified ·
1 Parent(s): c7913c1

Upload 5 files

Browse files
TeleChat3-Coder_vllm_tool_parser/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__/
2
+ *.egg-info/
TeleChat3-Coder_vllm_tool_parser/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### 针对 VLLM 框架的 TeleCoder3 模型补丁
2
+
3
+ #### 目前支持以下功能
4
+ - 工具调用解析
5
+ - 支持 Interleaved Thinking 模式:chat 接口的 message 支持 reasoning_content 字段,并拼接到模型上下文中
6
+
7
+ #### 使用教程
8
+ 1. 安装插件(只需要执行一次)
9
+ ```shell
10
+ pip install -e .
11
+ ```
12
+ 2. 如果需要工具解析,启动vllm服务的时候加上参数`--enable-auto-tool-choice --tool-call-parser telechat3`
TeleChat3-Coder_vllm_tool_parser/setup.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from setuptools import setup

# NOTE(review): the original declared
#   find_packages(include=["telechat3_tool_parser.py", "telechat3_reasoning.py"])
# but find_packages() only discovers importable *packages* (directories
# containing __init__.py) and its include patterns match package names,
# never ".py" file names -- so the list resolved to [] and neither module
# was installed. These are two top-level modules, which must be declared
# via py_modules instead.
setup(
    name="telechat3_vllm_patch",
    version="1.1.0",
    # Install the two plugin modules so the vllm.general_plugins entry
    # points below can be resolved after `pip install -e .`.
    py_modules=["telechat3_tool_parser", "telechat3_reasoning"],
    entry_points={
        # vLLM discovers this plugin group at startup and calls each
        # registered function once.
        "vllm.general_plugins": [
            "telechat3_tool_parser = telechat3_tool_parser:register_tool_parser",
            "telechat3_reasoning = telechat3_reasoning:register_reasoning",
        ]
    },
)
TeleChat3-Coder_vllm_tool_parser/telechat3_reasoning.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import vllm
2
+ from vllm.entrypoints.chat_utils import (
3
+ ChatCompletionMessageParam,
4
+ ConversationMessage,
5
+ BaseMultiModalItemTracker,
6
+ _ChatTemplateContentFormat,
7
+ _parse_chat_message_content,
8
+ )
9
+ from vllm.logger import init_logger
10
+
11
+ logger = init_logger(__name__)
12
+
13
+
14
def _telechat3_parse_chat_message_content(
    message: ChatCompletionMessageParam,
    mm_tracker: BaseMultiModalItemTracker,
    content_format: _ChatTemplateContentFormat,
) -> list[ConversationMessage]:
    """Drop-in replacement for vLLM's ``_parse_chat_message_content``.

    Delegates to the stock parser, then carries the incoming message's
    ``reasoning_content`` field (if any) onto the first parsed conversation
    message so the chat template can splice it back into the model context.
    """
    parsed = _parse_chat_message_content(message, mm_tracker, content_format)
    reasoning = message.get("reasoning_content")

    # Only attach when the parser produced at least one message and the
    # field is non-empty.
    if parsed and reasoning:
        logger.info("add reasoning content to input prompt.")
        parsed[0]["reasoning_content"] = reasoning

    return parsed
27
+
28
+
29
def register_reasoning():
    """Monkey-patch vLLM so chat messages keep their reasoning content.

    Installed as a ``vllm.general_plugins`` entry point; vLLM calls this
    once at startup, replacing ``_parse_chat_message_content`` in
    ``vllm.entrypoints.chat_utils`` with the TeleChat3 variant above.
    """
    setattr(
        vllm.entrypoints.chat_utils,
        "_parse_chat_message_content",
        _telechat3_parse_chat_message_content,
    )
TeleChat3-Coder_vllm_tool_parser/telechat3_tool_parser.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import ast
2
+ import json
3
+ import regex as re
4
+ from collections.abc import Sequence
5
+ from typing import List, Any
6
+
7
+ from transformers import PreTrainedTokenizerBase
8
+ from vllm.entrypoints.openai.protocol import (
9
+ ChatCompletionRequest,
10
+ ChatCompletionToolsParam,
11
+ DeltaFunctionCall,
12
+ DeltaMessage,
13
+ DeltaToolCall,
14
+ ExtractedToolCallInformation,
15
+ FunctionCall,
16
+ ToolCall,
17
+ )
18
+ from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import (
19
+ ToolParser,
20
+ ToolParserManager,
21
+ )
22
+
23
+ from vllm.logger import init_logger
24
+
25
+ logger = init_logger(__name__)
26
+
27
+
28
def _is_string_type(
    tool_name: str, arg_name: str, tools: List[ChatCompletionToolsParam] | None
):
    """Return True iff *arg_name* of tool *tool_name* is declared with JSON
    schema type "string".

    Returns False when *tools* is None, the tool is unknown, the matching
    tool has no parameter schema, or the argument's declared type is
    anything other than "string". Only the first tool whose name matches
    is consulted.
    """
    if tools is None:
        return False
    for candidate in tools:
        if candidate.function.name != tool_name:
            continue
        schema = candidate.function.parameters
        if schema is None:
            return False
        declared = (
            schema.get("properties", {}).get(arg_name, {}).get("type", None)
        )
        return declared == "string"
    logger.debug("No tool named '%s'.", tool_name)
    return False
45
+
46
+
47
+ def _deserialize(value: str) -> Any:
48
+ try:
49
+ return json.loads(value)
50
+ except Exception:
51
+ pass
52
+
53
+ try:
54
+ return ast.literal_eval(value)
55
+ except Exception:
56
+ pass
57
+ return value
58
+
59
+
60
@ToolParserManager.register_module("telechat3")
class TeleChat3ModelToolParser(ToolParser):
    """
    Tool call parser for TeleChat3-36B models.
    Used when --enable-auto-tool-choice --tool-call-parser telechat3

    The model emits tool calls as
    ``<tool_call>NAME <param_key>K</param_key><param_value>V</param_value> ... </tool_call>``;
    this parser converts them into OpenAI-style ToolCall objects for both
    the non-streaming and streaming chat completion paths.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)

        # State used when parsing tool calls in streaming mode: index of
        # the tool call currently being emitted (-1 until the first one).
        self.current_tool_id: int = -1

        self.tool_start_token = "<tool_call>"
        self.tool_end_token = "</tool_call>"

        # One <tool_call>...</tool_call> span: group 1 is the tool name,
        # group 2 (optional; '' when absent in findall) is the raw
        # "<param_key>..." argument tail.
        self.func_detail_regex = re.compile(
            r"<tool_call>(.*?)(<param_key>.*?)?</tool_call>", re.DOTALL
        )
        # One key/value pair; key and value may be separated by whitespace
        # or a literal backslash-n sequence emitted by the model.
        self.func_arg_regex = re.compile(
            r"<param_key>(.*?)</param_key>(?:\\n|\s)*<param_value>(.*?)</param_value>",
            re.DOTALL,
        )
        # Accumulates streamed delta text until a full tool call is seen.
        self._buffer = ""

    def extract_tool_calls(self, model_output: str, request: ChatCompletionRequest):
        """Extract all complete tool calls from *model_output* (non-streaming).

        Argument values whose schema type is not "string" are deserialized
        from their textual form (JSON first, then Python literal). On any
        parsing error the raw output is returned as plain content with no
        tool calls; otherwise the text preceding the first tool call is
        returned as content.
        """
        matched_tool_calls = self.func_detail_regex.findall(model_output)
        logger.debug("model_output: %s", model_output)

        tool_calls = []
        try:
            for match in matched_tool_calls:
                tc_name = match[0].strip()
                arg_dict = {}
                if len(match) > 1:
                    for key, value in self.func_arg_regex.findall(match[1]):
                        arg_key = key.strip()
                        arg_val = value.strip()
                        # BUGFIX: look up the schema with the *stripped* key
                        # (arg_key) rather than the raw capture -- a key
                        # surrounded by whitespace never matches a property
                        # name, which silently skipped deserialization of
                        # non-string arguments.
                        if not _is_string_type(tc_name, arg_key, request.tools):
                            arg_val = _deserialize(arg_val)
                        logger.debug("arg_key = %s, arg_val = %s", arg_key, arg_val)
                        arg_dict[arg_key] = arg_val
                tool_calls.append(
                    ToolCall(
                        type="function",
                        function=FunctionCall(
                            name=tc_name,
                            arguments=json.dumps(arg_dict, ensure_ascii=False),
                        ),
                    )
                )
        except Exception:
            logger.exception("Failed to extract tool call spec")
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )
        else:
            if len(tool_calls) > 0:
                # Everything before the first <tool_call> token is normal
                # assistant content.
                content = model_output[: model_output.find(self.tool_start_token)]
                return ExtractedToolCallInformation(
                    tools_called=True, tool_calls=tool_calls, content=content
                )
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> DeltaMessage | None:
        """Incremental variant of :meth:`extract_tool_calls`.

        Buffers delta text until a complete ``<tool_call>...</tool_call>``
        span is available, then emits it as a DeltaToolCall; plain text
        outside tool calls is flushed through as content.
        """
        self._buffer += delta_text
        cur_text = self._buffer
        start_idx = cur_text.find(self.tool_start_token)
        if start_idx == -1:
            # No tool call in sight: flush the buffer as plain content.
            # NOTE(review): a start token split across deltas (e.g. "<tool_")
            # would be flushed here and never parsed; presumably
            # "<tool_call>" is a single special token that cannot split --
            # confirm against the tokenizer.
            self._buffer = ""
            return DeltaMessage(content=cur_text)
        logger.debug("cur_text = %s", cur_text)
        end_idx = cur_text.find(self.tool_end_token)
        if end_idx != -1:
            # A complete tool call is buffered: reuse the batch parser on
            # the closed span.
            extracted_tool_calls = self.extract_tool_calls(
                cur_text[: end_idx + len(self.tool_end_token)], request
            )
            if len(extracted_tool_calls.tool_calls) == 0:
                logger.warning("Failed to extract any tool calls.")
                return None
            self.current_tool_id += 1
            tool_call = extracted_tool_calls.tool_calls[0]
            delta = DeltaMessage(
                content=extracted_tool_calls.content,
                tool_calls=[
                    DeltaToolCall(
                        index=self.current_tool_id,
                        id=tool_call.id,
                        type=tool_call.type,
                        function=DeltaFunctionCall(
                            name=tool_call.function.name,
                            arguments=tool_call.function.arguments,
                        ),
                    )
                ],
            )
            # Keep whatever follows the closing tag for the next delta.
            self._buffer = cur_text[end_idx + len(self.tool_end_token) :]
            return delta
        # Start seen but no end yet: emit the plain text before the start
        # token and keep the (possibly partial) tool call buffered.
        self._buffer = cur_text[start_idx:]
        return DeltaMessage(content=cur_text[:start_idx])
173
+
174
+
175
def register_tool_parser():
    """Entry-point hook for the ``vllm.general_plugins`` plugin group.

    Intentionally a no-op: importing this module already registers the
    parser via the ``@ToolParserManager.register_module`` decorator above.
    """