kovsbo commited on
Commit
9539d7b
·
verified ·
1 Parent(s): 9847d73

Upload 2 files

Browse files
openpipe_dual_chat_template.jinja ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {{- bos_token }}
2
+ {%- set template_variant = template_variant | default("official") %}
3
+
4
+ {%- if template_variant == "llama31instruct" %}
5
+ {%- if not tools is defined %}
6
+ {%- set tools = none %}
7
+ {%- endif %}
8
+ {%- if messages[0]['role'] in ['system', 'developer'] %}
9
+ {%- set system_message = messages[0]['content']|trim %}
10
+ {%- set messages = messages[1:] %}
11
+ {%- else %}
12
+ {%- set system_message = "" %}
13
+ {%- endif %}
14
+ {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
15
+ {%- if tools is not none and tools|length > 0 %}
16
+ {{- "Environment: ipython\n" }}
17
+ {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
18
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
19
+ {{- "Do not use variables.\n\n" }}
20
+ {%- for t in tools %}
21
+ {{- t | tojson }}
22
+ {{- "\n\n" }}
23
+ {%- endfor %}
24
+ {%- endif %}
25
+ {{- system_message }}
26
+ {{- "<|eot_id|>" }}
27
+ {%- for message in messages %}
28
+ {%- if not (message.role == 'tool' or 'tool_calls' in message) %}
29
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
30
+ {%- elif 'tool_calls' in message %}
31
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
32
+ {%- for tool_call in message.tool_calls -%}
33
+ {{- '<|start_tool_call|>' }}
34
+ {{- '{"name": "' + tool_call.function.name + '", ' }}
35
+ {{- '"parameters": ' }}
36
+ {{- tool_call.function.arguments | tojson }}
37
+ {{- "}" }}
38
+ {{- '<|end_tool_call|>' }}
39
+ {%- if not loop.last %}, {% endif %}
40
+ {%- endfor -%}
41
+ {{- "<|eot_id|>" }}
42
+ {%- elif message.role == "tool" %}
43
+ {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
44
+ {%- if message.content is mapping or message.content is iterable %}
45
+ {{- message.content | tojson }}
46
+ {%- else %}
47
+ {{- message.content }}
48
+ {%- endif %}
49
+ {{- "<|eot_id|>" }}
50
+ {%- endif %}
51
+ {%- endfor %}
52
+ {%- if add_generation_prompt %}
53
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
54
+ {%- endif %}
55
+ {%- elif template_variant == "pipeline3" %}
56
+ {%- if not tools is defined %}
57
+ {%- set tools = none %}
58
+ {%- endif %}
59
+ {%- set functions = tools | map(attribute="function") | map(attribute="name") | list if tools is not none else none %}
60
+ {{- "### Instruction:\n" }}
61
+ {%- if functions is not none %}
62
+ {{- {"messages": messages, "functions": functions} | tojson }}
63
+ {%- else %}
64
+ {{- {"messages": messages} | tojson }}
65
+ {%- endif %}
66
+ {{- "\n\n### Response:\n" }}
67
+ {%- else %}
68
+ {%- if custom_tools is defined %}
69
+ {%- set tools = custom_tools %}
70
+ {%- endif %}
71
+ {%- if not tools_in_user_message is defined %}
72
+ {%- set tools_in_user_message = true %}
73
+ {%- endif %}
74
+ {%- if not date_string is defined %}
75
+ {%- set date_string = "26 Jul 2024" %}
76
+ {%- endif %}
77
+ {%- if not tools is defined %}
78
+ {%- set tools = none %}
79
+ {%- endif %}
80
+ {%- if messages[0]['role'] == 'system' %}
81
+ {%- set system_message = messages[0]['content']|trim %}
82
+ {%- set messages = messages[1:] %}
83
+ {%- else %}
84
+ {%- set system_message = "" %}
85
+ {%- endif %}
86
+ {{- "<|start_header_id|>system<|end_header_id|>\n\n" }}
87
+ {%- if builtin_tools is defined or tools is not none %}
88
+ {{- "Environment: ipython\n" }}
89
+ {%- endif %}
90
+ {%- if builtin_tools is defined %}
91
+ {{- "Tools: " + builtin_tools | reject('equalto', 'code_interpreter') | join(", ") + "\n\n" }}
92
+ {%- endif %}
93
+ {{- "Cutting Knowledge Date: December 2023\n" }}
94
+ {{- "Today Date: " + date_string + "\n\n" }}
95
+ {%- if tools is not none and not tools_in_user_message %}
96
+ {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }}
97
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
98
+ {{- "Do not use variables.\n\n" }}
99
+ {%- for t in tools %}
100
+ {{- t | tojson(indent=4) }}
101
+ {{- "\n\n" }}
102
+ {%- endfor %}
103
+ {%- endif %}
104
+ {{- system_message }}
105
+ {{- "<|eot_id|>" }}
106
+ {%- if tools_in_user_message and not tools is none %}
107
+ {%- if messages | length != 0 %}
108
+ {%- set first_user_message = messages[0]['content']|trim %}
109
+ {%- set messages = messages[1:] %}
110
+ {%- else %}
111
+ {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }}
112
+ {%- endif %}
113
+ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}}
114
+ {{- "Given the following functions, please respond with a JSON for a function call " }}
115
+ {{- "with its proper arguments that best answers the given prompt.\n\n" }}
116
+ {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }}
117
+ {{- "Do not use variables.\n\n" }}
118
+ {%- for t in tools %}
119
+ {{- t | tojson(indent=4) }}
120
+ {{- "\n\n" }}
121
+ {%- endfor %}
122
+ {{- first_user_message + "<|eot_id|>" }}
123
+ {%- endif %}
124
+ {%- for message in messages %}
125
+ {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %}
126
+ {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
127
+ {%- elif 'tool_calls' in message %}
128
+ {%- if not message.tool_calls|length == 1 %}
129
+ {{- raise_exception("This model only supports single tool-calls at once!") }}
130
+ {%- endif %}
131
+ {%- set tool_call = message.tool_calls[0].function %}
132
+ {%- if builtin_tools is defined and tool_call.name in builtin_tools %}
133
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
134
+ {{- "<|python_tag|>" + tool_call.name + ".call(" }}
135
+ {%- for arg_name, arg_val in tool_call.arguments | items %}
136
+ {{- arg_name + '="' + arg_val + '"' }}
137
+ {%- if not loop.last %}
138
+ {{- ", " }}
139
+ {%- endif %}
140
+ {%- endfor %}
141
+ {{- ")" }}
142
+ {%- else %}
143
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}}
144
+ {{- '{"name": "' + tool_call.name + '", ' }}
145
+ {{- '"parameters": ' }}
146
+ {{- tool_call.arguments | tojson }}
147
+ {{- "}" }}
148
+ {%- endif %}
149
+ {%- if builtin_tools is defined %}
150
+ {{- "<|eom_id|>" }}
151
+ {%- else %}
152
+ {{- "<|eot_id|>" }}
153
+ {%- endif %}
154
+ {%- elif message.role == "tool" or message.role == "ipython" %}
155
+ {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }}
156
+ {%- if message.content is mapping or message.content is iterable %}
157
+ {{- message.content | tojson }}
158
+ {%- else %}
159
+ {{- message.content }}
160
+ {%- endif %}
161
+ {{- "<|eot_id|>" }}
162
+ {%- endif %}
163
+ {%- endfor %}
164
+ {%- if add_generation_prompt %}
165
+ {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
166
+ {%- endif %}
167
+ {%- endif %}
openpipe_llama_dual.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ from collections.abc import Sequence
3
+ from typing import Any, Optional
4
+
5
+ from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
6
+ from vllm.entrypoints.openai.engine.protocol import (
7
+ DeltaFunctionCall,
8
+ DeltaMessage,
9
+ DeltaToolCall,
10
+ ExtractedToolCallInformation,
11
+ FunctionCall,
12
+ ToolCall,
13
+ )
14
+ from vllm.tokenizers import TokenizerLike
15
+ from vllm.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
16
+
17
+
18
+ @ToolParserManager.register_module(["openpipe_llama_dual"])
19
+ class OpenPipeLlamaDualParser(ToolParser):
20
+ """Parse official JSON, llama31 tool markers, and pipeline3 function tags."""
21
+
22
+ LEGACY_START = "<|start_tool_call|>"
23
+ LEGACY_END = "<|end_tool_call|>"
24
+ FUNCTION_CALL_TAG = "<function>"
25
+ FUNCTION_ARGS_TAG = "<arguments>"
26
+ VARIANT_LLAMA31 = "llama31instruct"
27
+ VARIANT_PIPELINE3 = "pipeline3"
28
+ VARIANT_OFFICIAL = "official"
29
+
30
+ def __init__(self, tokenizer: TokenizerLike):
31
+ super().__init__(tokenizer)
32
+ self.tokenizer = tokenizer
33
+
34
+ def _get_template_variant(self, request: ChatCompletionRequest) -> Optional[str]:
35
+ kwargs = getattr(request, "chat_template_kwargs", None)
36
+ if kwargs is None:
37
+ return None
38
+ if isinstance(kwargs, dict):
39
+ value = kwargs.get("template_variant")
40
+ return value if isinstance(value, str) else None
41
+ value = getattr(kwargs, "template_variant", None)
42
+ return value if isinstance(value, str) else None
43
+
44
+ def _normalize_tool_call(self, payload: dict[str, Any]) -> Optional[dict[str, Any]]:
45
+ if "name" in payload and "parameters" in payload:
46
+ return {
47
+ "name": payload["name"],
48
+ "arguments": payload["parameters"],
49
+ }
50
+ if "function" in payload and isinstance(payload["function"], dict):
51
+ function = payload["function"]
52
+ if "name" in function and "arguments" in function:
53
+ return {
54
+ "name": function["name"],
55
+ "arguments": function["arguments"],
56
+ }
57
+ return None
58
+
59
+ def _extract_legacy_tool_calls(self, text: str) -> list[dict[str, Any]]:
60
+ tool_calls = []
61
+ current_index = 0
62
+
63
+ while True:
64
+ start_index = text.find(self.LEGACY_START, current_index)
65
+ if start_index == -1:
66
+ break
67
+
68
+ end_index = text.find(self.LEGACY_END, start_index)
69
+ if end_index == -1:
70
+ break
71
+
72
+ tool_call_json = text[start_index + len(self.LEGACY_START) : end_index].strip()
73
+ payload = json.loads(tool_call_json)
74
+ normalized = self._normalize_tool_call(payload)
75
+ if normalized:
76
+ tool_calls.append(normalized)
77
+ current_index = end_index + len(self.LEGACY_END)
78
+
79
+ return tool_calls
80
+
81
+ def _extract_function_tag_tool_calls(self, text: str) -> list[dict[str, Any]]:
82
+ tool_calls = []
83
+ current_index = 0
84
+
85
+ while True:
86
+ function_start = text.find(self.FUNCTION_CALL_TAG, current_index)
87
+ if function_start == -1:
88
+ break
89
+
90
+ name_start = function_start + len(self.FUNCTION_CALL_TAG)
91
+ args_tag_index = text.find(self.FUNCTION_ARGS_TAG, name_start)
92
+ if args_tag_index == -1:
93
+ break
94
+
95
+ function_name = text[name_start:args_tag_index].strip()
96
+ if not function_name:
97
+ break
98
+
99
+ arguments_start = args_tag_index + len(self.FUNCTION_ARGS_TAG)
100
+ next_function_index = text.find(self.FUNCTION_CALL_TAG, arguments_start)
101
+ if next_function_index == -1:
102
+ arguments_raw = text[arguments_start:].strip()
103
+ current_index = len(text)
104
+ else:
105
+ arguments_raw = text[arguments_start:next_function_index].strip()
106
+ current_index = next_function_index
107
+
108
+ if not arguments_raw:
109
+ arguments: Any = ""
110
+ else:
111
+ try:
112
+ arguments = json.loads(arguments_raw)
113
+ except Exception:
114
+ arguments = arguments_raw
115
+
116
+ tool_calls.append(
117
+ {
118
+ "name": function_name,
119
+ "arguments": arguments,
120
+ }
121
+ )
122
+
123
+ return tool_calls
124
+
125
+ def _extract_official_tool_call(self, text: str) -> Optional[dict[str, Any]]:
126
+ stripped = text.strip()
127
+ if not stripped.startswith("{") or not stripped.endswith("}"):
128
+ return None
129
+ payload = json.loads(stripped)
130
+ return self._normalize_tool_call(payload)
131
+
132
+ def _build_delta_tool_call(self, tool_call: dict[str, Any], index: int = 0) -> DeltaMessage:
133
+ arguments = tool_call["arguments"]
134
+ return DeltaMessage(
135
+ tool_calls=[
136
+ DeltaToolCall(
137
+ index=index,
138
+ id=f"call_{tool_call['name']}",
139
+ type="function",
140
+ function=DeltaFunctionCall(
141
+ name=tool_call["name"],
142
+ arguments=json.dumps(arguments, ensure_ascii=False)
143
+ if isinstance(arguments, (dict, list))
144
+ else arguments,
145
+ ),
146
+ )
147
+ ]
148
+ )
149
+
150
+ def _build_tool_calls_response(
151
+ self,
152
+ tool_calls: list[dict[str, Any]],
153
+ ) -> ExtractedToolCallInformation:
154
+ return ExtractedToolCallInformation(
155
+ tools_called=True,
156
+ tool_calls=[
157
+ ToolCall(
158
+ id=f"call_{index + 1}",
159
+ type="function",
160
+ function=FunctionCall(
161
+ name=tool_call["name"],
162
+ arguments=json.dumps(
163
+ tool_call["arguments"], ensure_ascii=False
164
+ )
165
+ if isinstance(tool_call["arguments"], (dict, list))
166
+ else tool_call["arguments"],
167
+ ),
168
+ )
169
+ for index, tool_call in enumerate(tool_calls)
170
+ ],
171
+ content=None,
172
+ )
173
+
174
+ def _looks_like_partial_official_json(self, text: str) -> bool:
175
+ stripped = text.strip()
176
+ if not stripped.startswith("{"):
177
+ return False
178
+ if stripped.endswith("}"):
179
+ return False
180
+ return (
181
+ '"name"' in stripped
182
+ or '"parameters"' in stripped
183
+ or '"function"' in stripped
184
+ )
185
+
186
+ def extract_tool_calls_streaming(
187
+ self,
188
+ previous_text: str,
189
+ current_text: str,
190
+ delta_text: str,
191
+ previous_token_ids: Sequence[int],
192
+ current_token_ids: Sequence[int],
193
+ delta_token_ids: Sequence[int],
194
+ request: ChatCompletionRequest,
195
+ ) -> DeltaMessage | None:
196
+ variant = self._get_template_variant(request)
197
+
198
+ try:
199
+ if (
200
+ variant == self.VARIANT_LLAMA31
201
+ or self.LEGACY_START in current_text
202
+ ):
203
+ if self.LEGACY_START in current_text and self.LEGACY_END in current_text:
204
+ tool_calls = self._extract_legacy_tool_calls(current_text)
205
+ if tool_calls:
206
+ return self._build_delta_tool_call(
207
+ tool_calls[-1], index=len(tool_calls) - 1
208
+ )
209
+ if self.LEGACY_START in current_text:
210
+ return None
211
+ return DeltaMessage(content=delta_text)
212
+
213
+ if variant == self.VARIANT_PIPELINE3 or self.FUNCTION_CALL_TAG in current_text:
214
+ tool_calls = self._extract_function_tag_tool_calls(current_text)
215
+ if tool_calls:
216
+ return self._build_delta_tool_call(
217
+ tool_calls[-1], index=len(tool_calls) - 1
218
+ )
219
+ return None
220
+
221
+ official_tool_call = self._extract_official_tool_call(current_text)
222
+ if official_tool_call:
223
+ return self._build_delta_tool_call(official_tool_call)
224
+ if variant == self.VARIANT_OFFICIAL and self._looks_like_partial_official_json(
225
+ current_text
226
+ ):
227
+ return None
228
+ except Exception:
229
+ return DeltaMessage(content=delta_text)
230
+
231
+ return DeltaMessage(content=delta_text)
232
+
233
+ def extract_tool_calls(
234
+ self,
235
+ model_output: str,
236
+ request: ChatCompletionRequest,
237
+ ) -> ExtractedToolCallInformation:
238
+ variant = self._get_template_variant(request)
239
+
240
+ try:
241
+ if (
242
+ variant == self.VARIANT_LLAMA31
243
+ or self.LEGACY_START in model_output
244
+ ):
245
+ tool_calls = self._extract_legacy_tool_calls(model_output)
246
+ if tool_calls:
247
+ return self._build_tool_calls_response(tool_calls)
248
+
249
+ if variant == self.VARIANT_PIPELINE3 or self.FUNCTION_CALL_TAG in model_output:
250
+ tool_calls = self._extract_function_tag_tool_calls(model_output)
251
+ if tool_calls:
252
+ return self._build_tool_calls_response(tool_calls)
253
+
254
+ official_tool_call = self._extract_official_tool_call(model_output)
255
+ if official_tool_call:
256
+ return self._build_tool_calls_response([official_tool_call])
257
+ except Exception:
258
+ pass
259
+
260
+ return ExtractedToolCallInformation(
261
+ tools_called=False,
262
+ tool_calls=[],
263
+ content=model_output,
264
+ )