llaa33219 commited on
Commit
3f88742
·
verified ·
1 Parent(s): 91ebbd3

Upload solar_open_logits_processor.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. solar_open_logits_processor.py +763 -0
solar_open_logits_processor.py ADDED
@@ -0,0 +1,763 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2025 Upstage AI.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+ from enum import Enum
18
+ from typing import TYPE_CHECKING
19
+
20
+ import torch
21
+
22
+ from vllm.sampling_params import SamplingParams
23
+ from vllm.v1.sample.logits_processor import (
24
+ AdapterLogitsProcessor,
25
+ RequestLogitsProcessor,
26
+ )
27
+
28
+ if TYPE_CHECKING:
29
+ from vllm.config import VllmConfig
30
+
31
# Hardcoded token IDs for Solar tokenizer
# NOTE(review): these IDs are tokenizer-specific — assumed to match the Solar
# Open tokenizer's special-token vocabulary; verify against the tokenizer config.

# Special token IDs for chat template
BEGIN_TOKEN_ID = 20  # <|begin|>
END_TOKEN_ID = 21  # <|end|>
THINK_TOKEN_ID = 22  # <|think|>
CONTENT_TOKEN_ID = 23  # <|content|>
FLUSH_TOKEN_ID = 24  # <|flush|> (eos token)
ASSISTANT_TOKEN_ID = 163444  # assistant
'''
'assistant' is not a special token exactly, but is treated as one in the logits
processing.
'''

# Tool call related tokens
CALLS_TOKEN_ID = 25  # <|calls|> (eos token for tool calls)
TOOL_CALLS_TOKEN_ID = 30  # <|tool_calls|>
TOOL_CALL_BEGIN_TOKEN_ID = 31  # <|tool_call:begin|>
TOOL_CALL_END_TOKEN_ID = 32  # <|tool_call:end|>
TOOL_CALL_NAME_TOKEN_ID = 33  # <|tool_call:name|>
TOOL_CALL_ARGS_TOKEN_ID = 34  # <|tool_call:args|>

# =============================================================================
# Dynamic Reasoning Budget Configuration
# =============================================================================
# budget = min(max_budget, max(min_budget, max_tokens * ratio / 100))
# Priority: max_budget > min_budget > ratio
#
# Available environment variables:
#   HIGH effort:
#     SOLAR_REASONING_BUDGET_HIGH_MAX (default: 32768) - max_budget
#     SOLAR_REASONING_BUDGET_HIGH_MIN (default: 8192) - min_budget
#     SOLAR_REASONING_BUDGET_HIGH_RATIO (default: 60) - % of max_tokens
#
#   MEDIUM effort:
#     SOLAR_REASONING_BUDGET_MEDIUM_MAX (default: 16384) - max_budget
#     SOLAR_REASONING_BUDGET_MEDIUM_MIN (default: 4096) - min_budget
#     SOLAR_REASONING_BUDGET_MEDIUM_RATIO (default: 30) - % of max_tokens
#
#   Tool call:
#     SOLAR_TOOL_CALL_ID_BUDGET (default: 10) - Max tokens for tool call ID
# =============================================================================

# Effort level assumed when the request does not specify one.
DEFAULT_REASONING_EFFORT = "high"

# HIGH effort settings (1k = 1024 tokens)
DEFAULT_REASONING_BUDGET_HIGH_MAX = 32 * 1024
DEFAULT_REASONING_BUDGET_HIGH_MIN = 8 * 1024
DEFAULT_REASONING_BUDGET_HIGH_RATIO = 60

# MEDIUM effort settings
DEFAULT_REASONING_BUDGET_MEDIUM_MAX = 16 * 1024
DEFAULT_REASONING_BUDGET_MEDIUM_MIN = 4 * 1024
DEFAULT_REASONING_BUDGET_MEDIUM_RATIO = 30

# Tool call settings
DEFAULT_TOOL_CALL_ID_BUDGET = 10

# Pre-computed constant to avoid repeated string parsing
NEG_INF = float("-inf")
91
+
92
+
93
+ def is_reasoning_request(params: SamplingParams) -> bool:
94
+ """Check if the request is a reasoning request based on reasoning_effort."""
95
+ return (params.reasoning_effort is None) or (params.reasoning_effort in ("medium", "high"))
96
+
97
+
98
+ def is_structured_outputs(params: SamplingParams) -> bool:
99
+ """Check if the request has structured outputs constraints."""
100
+ return (
101
+ params.structured_outputs is not None
102
+ and not params.structured_outputs.all_constraints_none()
103
+ )
104
+
105
+
106
+ class GenerationState(Enum):
107
+ """Enum representing the current state of response generation."""
108
+
109
+ # Initial state - no tokens generated yet
110
+ INITIAL = "initial"
111
+
112
+ # New message states (after think_end)
113
+ NEW_MESSAGE_BEGIN = "new_message_begin" # <|begin|> token was just generated
114
+ NEW_MESSAGE_ASSISTANT = "new_message_assistant" # assistant token after <|begin|>
115
+
116
+ # Think mode states
117
+ THINK_BEGIN = "think_begin" # <|think|> token was just generated
118
+ THINK_IN_PROGRESS = "think_in_progress" # Generating think content
119
+ THINK_END = "think_end" # <|end|> after think content
120
+ THINK_FLUSH = "think_flush" # <|flush|> after think content
121
+
122
+ # Content states
123
+ CONTENT_BEGIN = "content_begin" # <|content|> token was just generated
124
+ CONTENT_IN_PROGRESS = "content_in_progress" # Generating content
125
+ CONTENT_END = "content_end" # <|end|> or <|flush|> after content
126
+ CONTENT_FLUSH = "content_flush" # <|flush|> after content
127
+
128
+ # Tool call states
129
+ # Flow: <|tool_calls|> -> (<|tool_call:begin|> -> id -> <|tool_call:name|> -> name -> <|tool_call:args|> -> args -> <|tool_call:end|>)+ -> <|calls|>
130
+ # Note: Think message can appear before <|tool_calls|>
131
+ TOOL_CALLS_BEGIN = "tool_calls_begin" # <|tool_calls|> token was just generated
132
+ TOOL_CALL_BEGIN = "tool_call_begin" # <|tool_call:begin|> token was just generated
133
+ TOOL_CALL_ID_IN_PROGRESS = "tool_call_id_in_progress" # Generating tool call ID
134
+ TOOL_CALL_NAME_BEGIN = "tool_call_name_begin" # <|tool_call:name|> token was just generated
135
+ TOOL_CALL_NAME_IN_PROGRESS = "tool_call_name_in_progress" # Generating tool name
136
+ TOOL_CALL_ARGS_BEGIN = "tool_call_args_begin" # <|tool_call:args|> token was just generated
137
+ TOOL_CALL_ARGS_IN_PROGRESS = "tool_call_args_in_progress" # Generating tool arguments (JSON)
138
+ TOOL_CALL_END = "tool_call_end" # <|tool_call:end|> token was just generated (can start another tool call or end)
139
+ CALLS = "calls" # <|calls|> token was just generated (eos token for tool calls)
140
+
141
+
142
def get_generation_state(
    output_token_ids: list[int],
    begin_token_id: int = BEGIN_TOKEN_ID,
    end_token_id: int = END_TOKEN_ID,
    flush_token_id: int = FLUSH_TOKEN_ID,
    think_token_id: int = THINK_TOKEN_ID,
    content_token_id: int = CONTENT_TOKEN_ID,
    tool_calls_token_id: int = TOOL_CALLS_TOKEN_ID,
    tool_call_begin_token_id: int = TOOL_CALL_BEGIN_TOKEN_ID,
    tool_call_name_token_id: int = TOOL_CALL_NAME_TOKEN_ID,
    tool_call_args_token_id: int = TOOL_CALL_ARGS_TOKEN_ID,
    tool_call_end_token_id: int = TOOL_CALL_END_TOKEN_ID,
    calls_token_id: int = CALLS_TOKEN_ID,
    assistant_token_id: int = ASSISTANT_TOKEN_ID,
) -> GenerationState:
    """Determine the current generation state from the generated token IDs.

    Replays the token sequence through a small state machine mirroring the
    Solar Open chat template:

    - think mode: <|think|>{{think-tokens}}<|end|><|begin|>assistant<|content|>{{content-tokens}}<|flush|>
    - tool mode: <|begin|>assistant<|tool_calls|><|tool_call:begin|>{{id}}<|tool_call:name|>{{name}}<|tool_call:args|>{{args}}<|tool_call:end|><|calls|>
    - tool mode (with think): <|think|>{{think-tokens}}<|end|><|begin|>assistant<|tool_calls|>...<|calls|>
    - no-think mode: <|content|>{{content-tokens}}<|flush|>

    Args:
        output_token_ids: List of token IDs generated so far.
        begin_token_id, end_token_id, flush_token_id, think_token_id,
        content_token_id, tool_calls_token_id, tool_call_begin_token_id,
        tool_call_name_token_id, tool_call_args_token_id,
        tool_call_end_token_id, calls_token_id, assistant_token_id:
            Overridable IDs of the template's special tokens (assumed
            pairwise distinct).

    Returns:
        GenerationState indicating the current phase of generation.
    """
    if not output_token_ids:
        return GenerationState.INITIAL

    # Special tokens whose resulting state does not depend on context.
    direct = {
        think_token_id: GenerationState.THINK_BEGIN,
        content_token_id: GenerationState.CONTENT_BEGIN,
        tool_calls_token_id: GenerationState.TOOL_CALLS_BEGIN,
        tool_call_begin_token_id: GenerationState.TOOL_CALL_BEGIN,
        tool_call_name_token_id: GenerationState.TOOL_CALL_NAME_BEGIN,
        tool_call_args_token_id: GenerationState.TOOL_CALL_ARGS_BEGIN,
        tool_call_end_token_id: GenerationState.TOOL_CALL_END,
        calls_token_id: GenerationState.CALLS,
        begin_token_id: GenerationState.NEW_MESSAGE_BEGIN,
    }
    # A *_BEGIN state advances to its *_IN_PROGRESS state on the first
    # regular token; *_IN_PROGRESS states absorb further regular tokens.
    on_regular = {
        GenerationState.THINK_BEGIN: GenerationState.THINK_IN_PROGRESS,
        GenerationState.CONTENT_BEGIN: GenerationState.CONTENT_IN_PROGRESS,
        GenerationState.TOOL_CALL_BEGIN: GenerationState.TOOL_CALL_ID_IN_PROGRESS,
        GenerationState.TOOL_CALL_NAME_BEGIN: GenerationState.TOOL_CALL_NAME_IN_PROGRESS,
        GenerationState.TOOL_CALL_ARGS_BEGIN: GenerationState.TOOL_CALL_ARGS_IN_PROGRESS,
    }

    state = GenerationState.INITIAL
    in_think = False
    in_content = False

    for tok in output_token_ids:
        if tok in direct:
            state = direct[tok]
            # Track which block (think/content) we are inside so that a
            # later <|end|>/<|flush|> can be attributed correctly.
            if tok == think_token_id:
                in_think, in_content = True, False
            elif tok == content_token_id:
                in_think, in_content = False, True
            elif tok == tool_calls_token_id:
                in_think = in_content = False
        elif tok == assistant_token_id:
            # 'assistant' is only meaningful immediately after <|begin|>;
            # anywhere else it leaves the state untouched.
            if state == GenerationState.NEW_MESSAGE_BEGIN:
                state = GenerationState.NEW_MESSAGE_ASSISTANT
        elif tok == end_token_id:
            if in_think:
                state, in_think = GenerationState.THINK_END, False
            elif in_content:
                state, in_content = GenerationState.CONTENT_END, False
        elif tok == flush_token_id:
            if in_think:
                state, in_think = GenerationState.THINK_FLUSH, False
            elif in_content:
                state, in_content = GenerationState.CONTENT_FLUSH, False
        else:
            # Regular token: advance a *_BEGIN state, otherwise stay put.
            state = on_regular.get(state, state)

    return state
272
+
273
+
274
# Pre-computed list of all special token IDs for batch indexing
# (indexing a logits tensor with this list masks every special token at once).
_ALL_SPECIAL_TOKEN_IDS = [
    BEGIN_TOKEN_ID,
    END_TOKEN_ID,
    THINK_TOKEN_ID,
    CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID,
    CALLS_TOKEN_ID,
    TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID,
    TOOL_CALL_END_TOKEN_ID,
    TOOL_CALL_NAME_TOKEN_ID,
    TOOL_CALL_ARGS_TOKEN_ID,
]

# Pre-computed lists for state-specific batch indexing (excluding allowed tokens).
# Each list is "all special tokens minus the one(s) the state may emit".
_SPECIAL_EXCEPT_END = [  # For THINK states (allow END)
    BEGIN_TOKEN_ID, FLUSH_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    TOOL_CALLS_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID,
    TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

_SPECIAL_EXCEPT_CONTENT_TOOLCALLS = [  # For NEW_MESSAGE_ASSISTANT (allow CONTENT, TOOL_CALLS)
    THINK_TOKEN_ID, BEGIN_TOKEN_ID, END_TOKEN_ID, FLUSH_TOKEN_ID,
    CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID,
    TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

_SPECIAL_EXCEPT_FLUSH = [  # For CONTENT states (allow FLUSH)
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    TOOL_CALLS_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALL_BEGIN_TOKEN_ID,
    TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

_SPECIAL_EXCEPT_TOOLCALL_NAME = [  # For TOOL_CALL_ID_IN_PROGRESS (allow TOOL_CALL_NAME)
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]

_SPECIAL_EXCEPT_TOOLCALL_ARGS = [  # For TOOL_CALL_NAME_IN_PROGRESS (allow TOOL_CALL_ARGS)
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_END_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID,
]

_SPECIAL_EXCEPT_TOOLCALL_END = [  # For TOOL_CALL_ARGS_IN_PROGRESS (allow TOOL_CALL_END)
    BEGIN_TOKEN_ID, END_TOKEN_ID, THINK_TOKEN_ID, CONTENT_TOKEN_ID,
    FLUSH_TOKEN_ID, CALLS_TOKEN_ID, TOOL_CALLS_TOKEN_ID,
    TOOL_CALL_BEGIN_TOKEN_ID, TOOL_CALL_NAME_TOKEN_ID, TOOL_CALL_ARGS_TOKEN_ID,
]
325
+
326
+
327
def _forbid_all_special_tokens(logits: torch.Tensor) -> None:
    """Set all special token logits to -inf.

    Mutates ``logits`` in place so that only regular (non-template)
    tokens remain sampleable.
    """
    logits[_ALL_SPECIAL_TOKEN_IDS] = NEG_INF
330
+
331
+
332
class SolarOpenTemplateEnforcer:
    """Request-level logits processor that enforces Solar Open chat template.

    Enforces the following generation rules:
    - think mode: <|think|>{{tokens}}<|end|><|begin|>assistant<|content|>{{tokens}}<|flush|>
    - tool mode: <|tool_calls|><|tool_call:begin|>{{id}}<|tool_call:name|>{{name}}<|tool_call:args|>{{args}}<|tool_call:end|><|calls|>
    - tool+think mode: <|think|>{{tokens}}<|end|><|begin|>assistant<|tool_calls|>...<|calls|>
    - no-think mode: <|content|>{{tokens}}<|flush|>

    Key constraints:
    - Think message can only appear first
    - Think message must be followed by another message
    - Content and tool messages cannot coexist
    - Maximum 2 messages (think + content/tool, or just content/tool)

    Performance optimization:
    - Uses incremental state tracking to avoid full token sequence scan on each call
    - Maintains local counters for budget tracking
    - Uses pre-computed constants to avoid repeated object creation
    """

    # Pre-computed frozenset for reasoning state check (avoids set creation per call)
    _REASONING_STATES = frozenset({
        GenerationState.INITIAL,
        GenerationState.THINK_BEGIN,
        GenerationState.THINK_IN_PROGRESS,
    })

    def __init__(
        self,
        is_reasoning_request: bool,
        is_structured_outputs: bool,
        reasoning_budget: int | None = None,
        tool_call_id_budget: int = DEFAULT_TOOL_CALL_ID_BUDGET,
    ):
        """Initialize the enforcer for a single request.

        Args:
            is_reasoning_request: Whether the request must start with a
                <|think|> block.
            is_structured_outputs: Whether structured-output constraints are
                active (limits enforcement to the reasoning phase).
            reasoning_budget: Max tokens allowed inside the think block, or
                None for no budget limit.
            tool_call_id_budget: Max tokens allowed for a tool call ID.
        """
        self._is_reasoning_request = is_reasoning_request
        self._is_structured_outputs = is_structured_outputs
        self._reasoning_budget = reasoning_budget
        self._tool_call_id_budget = tool_call_id_budget

        # Incremental state tracking
        self._state: GenerationState = GenerationState.INITIAL
        self._last_processed_len: int = 0  # tokens already consumed by the tracker
        self._in_think: bool = False
        self._in_content: bool = False

        # Budget counters
        self._think_token_count: int = 0
        self._tool_call_id_token_count: int = 0

    def _reset_state(self) -> None:
        """Reset all incremental state to initial values.

        Called when defensive reprocessing is needed (e.g., token sequence inconsistency).
        """
        self._state = GenerationState.INITIAL
        self._last_processed_len = 0
        self._in_think = False
        self._in_content = False
        self._think_token_count = 0
        self._tool_call_id_token_count = 0

    def _process_token(self, token_id: int) -> None:
        """Process a single token and update internal state incrementally.

        Mirrors the transition logic of ``get_generation_state`` but also
        maintains the think/tool-call-ID budget counters.

        Args:
            token_id: The token ID to process.
        """
        if token_id == THINK_TOKEN_ID:
            self._state = GenerationState.THINK_BEGIN
            self._in_think = True
            self._in_content = False
            self._think_token_count = 0  # Reset counter for new think block

        elif token_id == CONTENT_TOKEN_ID:
            self._state = GenerationState.CONTENT_BEGIN
            self._in_content = True
            self._in_think = False

        elif token_id == TOOL_CALLS_TOKEN_ID:
            self._state = GenerationState.TOOL_CALLS_BEGIN
            self._in_think = False
            self._in_content = False

        elif token_id == TOOL_CALL_BEGIN_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_BEGIN
            self._tool_call_id_token_count = 0  # Reset counter for new tool call

        elif token_id == TOOL_CALL_NAME_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_NAME_BEGIN

        elif token_id == TOOL_CALL_ARGS_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_ARGS_BEGIN

        elif token_id == TOOL_CALL_END_TOKEN_ID:
            self._state = GenerationState.TOOL_CALL_END

        elif token_id == CALLS_TOKEN_ID:
            self._state = GenerationState.CALLS

        elif token_id == BEGIN_TOKEN_ID:
            self._state = GenerationState.NEW_MESSAGE_BEGIN

        elif token_id == ASSISTANT_TOKEN_ID:
            # 'assistant' only matters directly after <|begin|>; elsewhere
            # it is an ordinary token and leaves the state untouched.
            if self._state == GenerationState.NEW_MESSAGE_BEGIN:
                self._state = GenerationState.NEW_MESSAGE_ASSISTANT

        elif token_id == END_TOKEN_ID:
            if self._in_think:
                self._state = GenerationState.THINK_END
                self._in_think = False
            elif self._in_content:
                self._state = GenerationState.CONTENT_END
                self._in_content = False

        elif token_id == FLUSH_TOKEN_ID:
            if self._in_think:
                self._state = GenerationState.THINK_FLUSH
                self._in_think = False
            elif self._in_content:
                self._state = GenerationState.CONTENT_FLUSH
                self._in_content = False

        else:
            # Regular token - update state and counters based on current context
            if self._state == GenerationState.THINK_BEGIN:
                self._state = GenerationState.THINK_IN_PROGRESS
                self._think_token_count += 1
            elif self._state == GenerationState.THINK_IN_PROGRESS:
                self._think_token_count += 1
            elif self._state == GenerationState.CONTENT_BEGIN:
                self._state = GenerationState.CONTENT_IN_PROGRESS
            elif self._state == GenerationState.CONTENT_IN_PROGRESS:
                pass  # Stay in content_in_progress
            elif self._state == GenerationState.TOOL_CALL_BEGIN:
                self._state = GenerationState.TOOL_CALL_ID_IN_PROGRESS
                self._tool_call_id_token_count += 1
            elif self._state == GenerationState.TOOL_CALL_ID_IN_PROGRESS:
                self._tool_call_id_token_count += 1
            elif self._state == GenerationState.TOOL_CALL_NAME_BEGIN:
                self._state = GenerationState.TOOL_CALL_NAME_IN_PROGRESS
            elif self._state == GenerationState.TOOL_CALL_NAME_IN_PROGRESS:
                pass  # Stay in tool_call_name_in_progress
            elif self._state == GenerationState.TOOL_CALL_ARGS_BEGIN:
                self._state = GenerationState.TOOL_CALL_ARGS_IN_PROGRESS
            elif self._state == GenerationState.TOOL_CALL_ARGS_IN_PROGRESS:
                pass  # Stay in tool_call_args_in_progress

    def _update_state_incremental(self, output_token_ids: list[int]) -> None:
        """Update internal state by processing only new tokens.

        Args:
            output_token_ids: Full list of output token IDs.
        """
        current_len = len(output_token_ids)

        # Defensive check: if token sequence is shorter than expected, reset and reprocess
        if current_len < self._last_processed_len:
            self._reset_state()

        # Process only new tokens
        for i in range(self._last_processed_len, current_len):
            self._process_token(output_token_ids[i])

        self._last_processed_len = current_len

    @staticmethod
    def _count_think_tokens(output_token_ids: list[int]) -> int:
        """Count the number of tokens generated after <|think|> token.

        Returns 0 if <|think|> token is not found (defensive).
        Note: This static method is kept for backward compatibility and testing.
        The incremental version uses _think_token_count instead.
        """
        try:
            think_index = output_token_ids.index(THINK_TOKEN_ID)
            return len(output_token_ids) - think_index - 1
        except ValueError:
            return 0

    @staticmethod
    def _count_tool_call_id_tokens(output_token_ids: list[int]) -> int:
        """Count the number of tokens generated after the last <|tool_call:begin|> token.

        Returns 0 if <|tool_call:begin|> token is not found (defensive).
        Note: This static method is kept for backward compatibility and testing.
        The incremental version uses _tool_call_id_token_count instead.
        """
        # Find the last occurrence of <|tool_call:begin|> for multi-tool-call support
        try:
            # Reverse search for the last <|tool_call:begin|>
            reversed_index = output_token_ids[::-1].index(TOOL_CALL_BEGIN_TOKEN_ID)
            last_begin_index = len(output_token_ids) - 1 - reversed_index
            return len(output_token_ids) - last_begin_index - 1
        except ValueError:
            return 0

    def __call__(
        self,
        output_token_ids: list[int],
        logits: torch.Tensor,
    ) -> torch.Tensor:
        """Apply template constraints to ``logits`` and return them.

        Mutates ``logits`` in place (masking/forcing tokens according to
        the current generation state) and returns the same tensor.

        Args:
            output_token_ids: All token IDs generated so far for this request.
            logits: Logits for the next token; assumed indexable by token ID
                along its first dimension.

        Returns:
            The (possibly modified) ``logits`` tensor.
        """
        # Update state incrementally (only process new tokens)
        self._update_state_incremental(output_token_ids)
        state = self._state

        # Handle structured outputs mode
        if self._is_structured_outputs:
            if not self._is_reasoning_request:
                # Non-reasoning request with structured outputs: no logit control
                return logits
            else:
                # Reasoning request with structured outputs:
                # Control logits only during reasoning phase
                if state not in self._REASONING_STATES:
                    # Reasoning finished, let structured outputs handle it
                    return logits

        if state == GenerationState.INITIAL:
            if self._is_reasoning_request:
                # Force: <|think|> only (reasoning request must start with think)
                think_logit = logits[THINK_TOKEN_ID].clone()
                logits.fill_(NEG_INF)
                logits[THINK_TOKEN_ID] = think_logit
            else:
                # Allow: <|content|>, <|tool_calls|> only
                content_logit = logits[CONTENT_TOKEN_ID].clone()
                tool_calls_logit = logits[TOOL_CALLS_TOKEN_ID].clone()
                logits.fill_(NEG_INF)
                logits[CONTENT_TOKEN_ID] = content_logit
                logits[TOOL_CALLS_TOKEN_ID] = tool_calls_logit

        elif state in (GenerationState.THINK_BEGIN, GenerationState.THINK_IN_PROGRESS):
            # Check if reasoning budget is exceeded (using incremental counter)
            if (
                self._reasoning_budget is not None
                and state == GenerationState.THINK_IN_PROGRESS
            ):
                if self._think_token_count >= self._reasoning_budget:
                    # Force <|end|> token to terminate reasoning
                    logits.fill_(NEG_INF)
                    logits[END_TOKEN_ID] = 0.0
                    return logits

            # Transform: <|flush|> -> <|end|>
            # Think must be followed by another message, so prevent early termination
            logits[END_TOKEN_ID] = torch.maximum(logits[END_TOKEN_ID], logits[FLUSH_TOKEN_ID])
            # Forbid all special tokens except <|end|>
            logits[_SPECIAL_EXCEPT_END] = NEG_INF

        elif state == GenerationState.THINK_END:
            # Force: <|begin|> only
            # Think must be followed by another message
            logits.fill_(NEG_INF)
            logits[BEGIN_TOKEN_ID] = 0.0

        elif state == GenerationState.NEW_MESSAGE_BEGIN:
            # Force: assistant token only
            logits.fill_(NEG_INF)
            logits[ASSISTANT_TOKEN_ID] = 0.0

        elif state == GenerationState.NEW_MESSAGE_ASSISTANT:
            # Allow: <|content|>, <|tool_calls|>, regular tokens
            # Forbid: all other special tokens
            logits[_SPECIAL_EXCEPT_CONTENT_TOOLCALLS] = NEG_INF

        elif state in (GenerationState.CONTENT_BEGIN, GenerationState.CONTENT_IN_PROGRESS):
            # Transform: <|end|> -> <|flush|>
            # Content cannot be followed by another message
            logits[FLUSH_TOKEN_ID] = torch.maximum(logits[FLUSH_TOKEN_ID], logits[END_TOKEN_ID])
            # Forbid all special tokens except <|flush|>
            logits[_SPECIAL_EXCEPT_FLUSH] = NEG_INF

        elif state == GenerationState.TOOL_CALLS_BEGIN:
            # Force: <|tool_call:begin|> only
            tool_call_begin_logit = logits[TOOL_CALL_BEGIN_TOKEN_ID].clone()
            logits.fill_(NEG_INF)
            logits[TOOL_CALL_BEGIN_TOKEN_ID] = tool_call_begin_logit

        elif state == GenerationState.TOOL_CALL_BEGIN:
            # Allow: regular tokens only (ID generation)
            # Forbid: all special tokens
            _forbid_all_special_tokens(logits)

        elif state == GenerationState.TOOL_CALL_ID_IN_PROGRESS:
            # Check if tool call ID budget is exceeded (using incremental counter)
            if self._tool_call_id_token_count >= self._tool_call_id_budget:
                # Force <|tool_call:name|> token to terminate ID generation
                logits.fill_(NEG_INF)
                logits[TOOL_CALL_NAME_TOKEN_ID] = 0.0
                return logits

            # Allow: <|tool_call:name|>, regular tokens
            # Forbid: all other special tokens
            logits[_SPECIAL_EXCEPT_TOOLCALL_NAME] = NEG_INF

        elif state == GenerationState.TOOL_CALL_NAME_BEGIN:
            # Allow: regular tokens only (function name generation)
            # Forbid: all special tokens
            _forbid_all_special_tokens(logits)

        elif state == GenerationState.TOOL_CALL_NAME_IN_PROGRESS:
            # Allow: <|tool_call:args|>, regular tokens
            # Forbid: all other special tokens
            logits[_SPECIAL_EXCEPT_TOOLCALL_ARGS] = NEG_INF

        elif state == GenerationState.TOOL_CALL_ARGS_BEGIN:
            # Allow: regular tokens only (JSON args generation)
            # Forbid: all special tokens
            _forbid_all_special_tokens(logits)

        elif state == GenerationState.TOOL_CALL_ARGS_IN_PROGRESS:
            # Allow: <|tool_call:end|>, regular tokens
            # Forbid: all other special tokens
            logits[_SPECIAL_EXCEPT_TOOLCALL_END] = NEG_INF

        elif state == GenerationState.TOOL_CALL_END:
            # Allow: <|tool_call:begin|> (next tool call), <|calls|> (end)
            # Forbid: all other special tokens
            tool_call_begin_logit = logits[TOOL_CALL_BEGIN_TOKEN_ID].clone()
            calls_logit = logits[CALLS_TOKEN_ID].clone()
            logits.fill_(NEG_INF)
            logits[TOOL_CALL_BEGIN_TOKEN_ID] = tool_call_begin_logit
            logits[CALLS_TOKEN_ID] = calls_logit

        # CALLS state: no processing needed (EOS)

        return logits
660
+
661
class SolarOpenTemplateLogitsProcessor(AdapterLogitsProcessor):
    """
    Logits processor that enforces Solar Open chat template.

    This processor manages the generation flow according to the
    Solar Open chat template by tracking generation states. It creates a
    per-request :class:`SolarOpenTemplateEnforcer`, sized with a dynamic
    reasoning budget derived from the request's reasoning effort and
    ``max_tokens`` (overridable via the ``SOLAR_REASONING_BUDGET_*`` and
    ``SOLAR_TOOL_CALL_ID_BUDGET`` environment variables documented in the
    module header).
    """

    def __init__(
        self,
        vllm_config: "VllmConfig",
        device: torch.device,
        is_pin_memory: bool,
    ):
        super().__init__(vllm_config, device, is_pin_memory)

        # Dynamic reasoning budget settings for HIGH effort
        self._high_max = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_HIGH_MAX", DEFAULT_REASONING_BUDGET_HIGH_MAX
        )
        self._high_min = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_HIGH_MIN", DEFAULT_REASONING_BUDGET_HIGH_MIN
        )
        self._high_ratio = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_HIGH_RATIO", DEFAULT_REASONING_BUDGET_HIGH_RATIO
        )

        # Dynamic reasoning budget settings for MEDIUM effort
        self._medium_max = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_MEDIUM_MAX", DEFAULT_REASONING_BUDGET_MEDIUM_MAX
        )
        self._medium_min = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_MEDIUM_MIN", DEFAULT_REASONING_BUDGET_MEDIUM_MIN
        )
        self._medium_ratio = self._parse_env_int(
            "SOLAR_REASONING_BUDGET_MEDIUM_RATIO", DEFAULT_REASONING_BUDGET_MEDIUM_RATIO
        )

        # Max number of tokens allowed for a tool-call ID segment
        self._tool_call_id_budget: int = self._parse_env_int(
            "SOLAR_TOOL_CALL_ID_BUDGET", DEFAULT_TOOL_CALL_ID_BUDGET
        )

    @staticmethod
    def _parse_env_int(env_var: str, default: int) -> int:
        """Parse environment variable as integer, return default if not set or invalid."""
        value = os.environ.get(env_var)
        if value is None:
            return default
        try:
            return int(value)
        except ValueError:
            return default

    def _calculate_reasoning_budget(self, effort: str, max_tokens: int | None) -> int:
        """Calculate dynamic reasoning budget based on effort level and max_tokens.

        Priority (higher priority conditions are applied first):
        1. max_budget: Upper limit for reasoning tokens
        2. min_budget: Lower limit for reasoning tokens
        3. ratio: Percentage of max_tokens allocated for reasoning (e.g., 60 means 60%)

        budget = min(max_budget, max(min_budget, max_tokens * ratio / 100))

        Args:
            effort: Reasoning effort level ("high" or "medium"; unknown
                values fall back to the "high" settings).
            max_tokens: The request's token limit, or None when the request
                sets no explicit limit.

        Returns:
            The reasoning token budget.
        """
        if effort == "medium":
            max_budget = self._medium_max
            min_budget = self._medium_min
            ratio = self._medium_ratio
        else:
            # "high", and fallback for unknown effort levels
            max_budget = self._high_max
            min_budget = self._high_min
            ratio = self._high_ratio

        # Bug fix: SamplingParams.max_tokens may be None (no explicit limit),
        # which previously raised TypeError on the multiplication below.
        # With no limit, the ratio term is effectively unbounded, so the
        # budget formula collapses to max_budget.
        if max_tokens is None:
            return max_budget

        # Calculate ratio-based budget (ratio is percentage, e.g., 60 means 60%)
        ratio_budget = max_tokens * ratio // 100

        # Apply priority: max > min > ratio
        return min(max_budget, max(min_budget, ratio_budget))

    def is_argmax_invariant(self) -> bool:
        """This processor can change argmax result by forcing specific tokens."""
        return False

    def new_req_logits_processor(
        self,
        params: SamplingParams,
    ) -> RequestLogitsProcessor | None:
        """Build the per-request template enforcer for ``params``."""
        reasoning_effort = params.reasoning_effort or DEFAULT_REASONING_EFFORT
        reasoning_budget = self._calculate_reasoning_budget(
            reasoning_effort, params.max_tokens
        )
        return SolarOpenTemplateEnforcer(
            is_reasoning_request=is_reasoning_request(params),
            is_structured_outputs=is_structured_outputs(params),
            reasoning_budget=reasoning_budget,
            tool_call_id_budget=self._tool_call_id_budget,
        )
763
+