from collections.abc import Sequence
from typing import Optional, Union
import regex as re
from transformers import PreTrainedTokenizerBase
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager
# Module-level logger, created via vLLM's init_logger convention.
logger = init_logger(__name__)
@ReasoningParserManager.register_module("greenmind_14b_r1")
class GreenMind14bR1ReasoningParser(ReasoningParser):
    r"""Reasoning parser for the GreenMind-14B-R1 model.

    The model wraps its chain-of-thought in ``<think>...</think>`` and its
    final answer in ``<answer>...</answer>``.  Text between the think tags
    is reported as reasoning content; text inside the answer tags is the
    (non-reasoning) response content.

    Token-id reference for the delimiters:
        think start:  "<think>\n"               -> [13708, 766, 397]
        think end:    "\n</think>\n<answer>\n"  -> [198, 522, 26865, 397, 27, 9217, 397]
        response end: "\n</answer>"             -> [198, 522, 9217, 29]

    NOTE: the streaming matcher only tracks the trailing [522, 9217, 29]
    ("</answer>") of the response-end sequence; the leading "\n" (198) is
    emitted as ordinary response content.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Text-level delimiters for the non-streaming regex extraction.
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"
        self.response_start_expr = r"\n</think>\n<answer>"
        self.response_end_expr = r"</answer>"
        # Full match: an optional "<think>...\n</think>\n<answer>" prefix
        # followed by everything up to the closing "</answer>".
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
            rf"(.*?){self.response_end_expr}",
            re.DOTALL)
        # Fallback for truncated output where "</answer>" never appeared.
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)
        # Token-id sequences for the streaming state machine.  The "fast"
        # variants are alternative tokenizations of the same delimiter text.
        self.think_start_ids = [13708, 766, 397]  # "<think>\n"
        self.think_start_ids_fast = [27, 26865, 397]  # "<" "think" ">\n"
        # "</think>\n<answer>\n"
        self.response_start_ids = [522, 26865, 397, 27, 9217, 397]
        # "</think>\n<answer>"
        self.response_start_ids_fast = [522, 26865, 397, 27, 9217, 29]
        self.response_end_ids = [522, 9217, 29]  # "</answer>"
        # Presumably an immediately-closed empty think block
        # "<think>\n\n</think>\n<answer>\n" — TODO confirm token 1339.
        self.fast_think_ids = [
            13708, 766, 1339, 522, 26865, 397, 27, 9217, 397
        ]
        # Streaming state machine: idle -> think -> response -> idle.
        # (The previous revision also assigned "reasoning" here and kept
        # unused buffered_text/buffered_ids/all_states attributes; all were
        # dead and have been removed.)
        self.current_state = "idle"
        self.expected_sequence = self.think_start_ids
        # Alternative tokenization accepted while matching a start delimiter.
        self.expected_sequence_side = self.think_start_ids_fast
        # Position of the next token expected within expected_sequence.
        self.sequence_index = 0
        # Tokens/text tentatively matched against a delimiter; flushed as
        # normal content if the match breaks before completing.
        self.token_buffer = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once the answer section has been entered.

        ``input_ids`` is unused; the decision relies on the state advanced
        by extract_reasoning_content_streaming().
        """
        return self.current_state == "response"

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return content token ids after the end of reasoning.

        The streaming parse runs first and already classifies every token,
        and the boundary token passed to both is_reasoning_end() and this
        method is not itself content, so an empty list is returned.  (Same
        contract as other streaming reasoning parsers in vLLM, e.g.
        hunyuan, from which this behavior was adopted.)
        """
        return []

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning & response sections from a full output.

        If the sequence doesn't match what we expect, i.e., the model
        generates something else, all content is considered non-reasoning
        content.

        Args:
            model_output: Complete output of the model to be parsed.
            request: Request being processed (unused; kept for the
                ReasoningParser interface).

        Returns:
            Tuple of (reasoning_content, response_content); either element
            is None when its section is absent or empty.
        """
        full_match = self.full_match_reasoning_regex.findall(model_output)
        if full_match:
            reasoning_content, response_content = full_match[0]
            # Normalize empty captures to None.
            return reasoning_content or None, response_content or None
        # "</answer>" never appeared (e.g. generation stopped early): take
        # everything after "<answer>" as the truncated response.
        half_match = self.half_match_reasoning_regex.findall(model_output)
        if half_match:
            reasoning_content, response_content = half_match[0]
            # Defensive: strip a trailing "</answer>" should one slip
            # through the fallback capture (unreachable when the full-match
            # regex ran first, but kept for safety).
            if response_content.endswith(self.response_end_expr):
                response_content = \
                    response_content[:-len(self.response_end_expr)]
            return reasoning_content or None, response_content or None
        return None, model_output

    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
                                          sequence: Sequence[int]) -> bool:
        """Return True if ``subsequence`` occurs in order (not necessarily
        contiguously) within ``sequence``.  An empty subsequence yields
        False.
        """
        if not subsequence:
            return False
        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Classify a one-token delta via a token-id state machine.

        States cycle idle -> think -> response -> idle.  Delimiter tokens
        are buffered and never emitted; tokens inside <think> are emitted
        as reasoning_content, everything else as content.  Returns None
        while a delimiter is being matched.
        """
        think_start_sequence = self.think_start_ids
        response_start_sequence = self.response_start_ids
        response_end_sequence = self.response_end_ids
        # The streaming caller feeds exactly one token per call.
        assert len(delta_token_ids) == 1
        token = delta_token_ids[0]

        def check_token_with_sequence(token):
            # While matching a start delimiter, accept either the canonical
            # or the alternative ("fast") tokenization.
            # NOTE(review): positions are compared against both sequences
            # independently, so a hybrid of the two tokenizations would also
            # be accepted — confirm this is intended.
            if self.current_state in ("idle", "think"):
                return (token == self.expected_sequence[self.sequence_index]
                        or token ==
                        self.expected_sequence_side[self.sequence_index])
            return token == self.expected_sequence[self.sequence_index]

        def check_last_token(token):
            # True when the expected delimiter has been fully consumed.
            if self.current_state in ("idle", "think"):
                # If the final token came from the side sequence, compare
                # against the side sequence's length instead.
                if (self.sequence_index - 1 < len(self.expected_sequence_side)
                        and token ==
                        self.expected_sequence_side[self.sequence_index - 1]):
                    return self.sequence_index == len(
                        self.expected_sequence_side)
                return self.sequence_index == len(self.expected_sequence)
            return self.sequence_index == len(self.expected_sequence)

        if check_token_with_sequence(token):
            # Token extends the current delimiter match: buffer it.
            self.token_buffer.append(token)
            self.text_buffer += delta_text
            self.sequence_index += 1
            if check_last_token(token):
                # Delimiter complete: state transition
                # idle -> think -> response -> idle, arming the next
                # expected delimiter sequence.
                if self.current_state == "idle":
                    self.current_state = "think"
                    self.expected_sequence = response_start_sequence
                    self.expected_sequence_side = self.response_start_ids_fast
                elif self.current_state == "think":
                    self.current_state = "response"
                    self.expected_sequence = response_end_sequence
                elif self.current_state == "response":
                    self.current_state = "idle"
                    self.expected_sequence = think_start_sequence
                    self.expected_sequence_side = self.think_start_ids_fast
                # Reset matching state.
                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""
            # Delimiter text is never sent to the client.
            return None

        # Match broken: flush any partially-matched delimiter text together
        # with this token as ordinary output for the current state.  (When
        # nothing was buffered, text_buffer is "" and this reduces to
        # delta_text, so both cases share one path.)
        content = self.text_buffer + delta_text
        self.sequence_index = 0
        self.token_buffer = []
        self.text_buffer = ""
        if self.current_state == "think":
            return DeltaMessage(reasoning_content=content, content=None)
        return DeltaMessage(reasoning_content=None, content=content)