from collections.abc import Sequence
from typing import Optional, Union

import regex as re
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("greenmind_14b_r1")
class GreenMind14bR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the GreenMind-14B-R1 model.

    The GreenMind-14B-R1 model uses the </think> token to denote the end of
    reasoning text. This parser extracts all content before </think> as
    reasoning content.

    think start: "<think>\n": [13708, 766, 397]
    think ends: "\n</think>\n<answer>\n": [198, 522, 26865, 397, 27, 9217, 397]
    response ends: "</answer>": [522, 9217, 29]
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>"
        self.response_end_expr = r"</answer>"
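
        # Full match: an optional "<think>...</think>" block followed by an
        # "<answer>...</answer>" body whose closing tag is present.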
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
            rf"(.*?){self.response_end_expr}", re.DOTALL)
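
        # Partial match: the think block is complete, but the closing
        # </answer> tag may still be missing.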
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)
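
        # Token-id spellings of the delimiters; the *_fast variants cover
        # alternative tokenizations the model can emit for the same text.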
        self.think_start_ids = [13708, 766, 397]
        self.think_start_ids_fast = [27, 26865, 397]
        self.response_start_ids = [522, 26865, 397, 27, 9217, 397]
        self.response_start_ids_fast = [522, 26865, 397, 27, 9217, 29]
        self.response_end_ids = [522, 9217, 29]
        self.fast_think_ids = [
            13708, 766, 1339, 522, 26865, 397, 27, 9217, 397
        ]

        # Streaming state machine: idle -> think -> response -> idle.
        self.current_state = "idle"
        self.all_states = ["idle", "think", "response"]
        self.expected_sequence = self.think_start_ids
        self.expected_sequence_side = self.think_start_ids_fast
        self.sequence_index = 0
        self.token_buffer = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
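        # Reasoning is over once the state machine has reached "response".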
        return self.current_state == "response"

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
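        # This parser does not track content token IDs separately, so there
        # is nothing to extract from input_ids.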
        return []

    def extract_reasoning_content(
        self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning content & content sections, respectively.

        If the sequence doesn't match what we expect, i.e., the model
        generates something else, all content is considered non-reasoning
        content.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed.

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        re_match = self.full_match_reasoning_regex.findall(model_output)
        if re_match:
            reasoning_content, response_content = re_match[0]
            if len(reasoning_content) == 0:
                reasoning_content = None
            if len(response_content) == 0:
                response_content = None
            return reasoning_content, response_content
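
        # Fall back to a partial match when the closing </answer> tag has
        # not been generated, e.g. because the output was truncated.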
        fallback_regex = self.half_match_reasoning_regex
        fallback_match = fallback_regex.findall(model_output)
        if fallback_match:
            reasoning_content, response_content = fallback_match[0]

            if response_content.endswith(self.response_end_expr):
                response_content = response_content[
                    :-len(self.response_end_expr)]

            if len(reasoning_content) == 0:
                reasoning_content = None
            if len(response_content) == 0:
                response_content = None

            return reasoning_content, response_content

        return None, model_output

    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
                                          sequence: Sequence[int]) -> bool:
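        # True iff `subsequence` occurs in `sequence` in order, though not
        # necessarily contiguously.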
        if not subsequence:
            return False

        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Extract content using a token-ID sequence state machine."""
        think_start_sequence = self.think_start_ids
        response_start_sequence = self.response_start_ids
        response_end_sequence = self.response_end_ids
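
        # This parser consumes exactly one new token per streaming step.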
        assert len(delta_token_ids) == 1

        token = delta_token_ids[0]
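
        # Helpers: match the incoming token against the delimiter expected
        # in the current state, and detect when that delimiter is complete.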
        def check_token_with_sequence(token):
            if self.current_state in ("idle", "think"):
                # Accept either the canonical or the alternative ("side")
                # tokenization of the delimiter.
                return (token == self.expected_sequence[self.sequence_index]
                        or token
                        == self.expected_sequence_side[self.sequence_index])
            return token == self.expected_sequence[self.sequence_index]

        def check_last_token(token):
            if self.current_state in ("idle", "think"):
                # If the token just consumed matched the "side" spelling,
                # compare against that sequence's length instead.
                if (self.sequence_index - 1 < len(self.expected_sequence_side)
                        and token
                        == self.expected_sequence_side[self.sequence_index -
                                                       1]):
                    return self.sequence_index == len(
                        self.expected_sequence_side)
                return self.sequence_index == len(self.expected_sequence)
            return self.sequence_index == len(self.expected_sequence)

        token_in_state_seq = check_token_with_sequence(token)

        if token_in_state_seq:
            # The token extends a possible delimiter: hold it back until we
            # know whether the delimiter completes or breaks off.
            self.token_buffer.append(token)
            self.text_buffer += delta_text
            self.sequence_index += 1
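
            # A full delimiter was consumed: advance the state machine
            # (idle -> think -> response -> idle) and reset the matcher.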
            if check_last_token(token):
                if self.current_state == "idle":
                    self.current_state = "think"
                    self.expected_sequence = response_start_sequence
                    self.expected_sequence_side = self.response_start_ids_fast
                elif self.current_state == "think":
                    self.current_state = "response"
                    self.expected_sequence = response_end_sequence
                elif self.current_state == "response":
                    self.current_state = "idle"
                    self.expected_sequence = think_start_sequence
                    self.expected_sequence_side = self.think_start_ids_fast

                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""
        else:
            if self.token_buffer:
                # A partial delimiter match broke off: flush the held-back
                # text together with the current delta.
                buffered_content = self.text_buffer + delta_text

                self.sequence_index = 0
                self.token_buffer = []
                self.text_buffer = ""

                if self.current_state == "think":
                    return DeltaMessage(reasoning_content=buffered_content,
                                        content=None)
                return DeltaMessage(reasoning_content=None,
                                    content=buffered_content)
            else:
                # No pending delimiter match: pass the delta straight
                # through, routed by the current state.
                if self.current_state == "think":
                    return DeltaMessage(reasoning_content=delta_text,
                                        content=None)
                return DeltaMessage(reasoning_content=None,
                                    content=delta_text)

        return None