File size: 9,633 Bytes
2eab44b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a28854c
 
2eab44b
 
 
 
 
 
 
 
 
a061951
 
5fad179
 
a28854c
2eab44b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
from collections.abc import Sequence
from typing import Optional, Union

import regex as re
from transformers import PreTrainedTokenizerBase

from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              DeltaMessage)
from vllm.logger import init_logger
from vllm.reasoning import ReasoningParser, ReasoningParserManager

logger = init_logger(__name__)


@ReasoningParserManager.register_module("greenmind_14b_r1")
class GreenMind14bR1ReasoningParser(ReasoningParser):
    """
    Reasoning parser for the GreenMind-14B-R1 model.

    The model wraps its reasoning in ``<think>`` ... ``</think>`` and its
    final answer in ``<answer>`` ... ``</answer>``. Text between the think
    markers is returned as reasoning content; text after ``<answer>`` is
    returned as regular content.

    Marker token-id sequences used by the streaming state machine:

    * think start ``"<think>\\n"``: [13708, 766, 397], or the alternative
      tokenization [27, 26865, 397]
    * answer start ``"</think>\\n<answer>\\n"``:
      [522, 26865, 397, 27, 9217, 397], or ``"</think>\\n<answer>"``:
      [522, 26865, 397, 27, 9217, 29]
    * answer end ``"</answer>"``: [522, 9217, 29]

    Streaming matching starts at ``</think>``; a newline token emitted right
    before it is streamed as reasoning text.

    NOTE(review): the ids are hard-coded for the GreenMind tokenizer
    vocabulary; they are not derived from ``tokenizer`` at runtime.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Regex fragments for the non-streaming parser.
        self.think_start_expr = r"<think>\n"
        # Currently unused; kept for compatibility.
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>"
        self.response_end_expr = r"</answer>"

        # Optional "<think>\n...\n</think>\n<answer>" prefix, followed by
        # everything up to the first "</answer>".
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
            rf"(.*?){self.response_end_expr}",
            re.DOTALL)

        # Fallback used when "</answer>" never appears (e.g. truncated
        # output): reasoning, then everything after "<answer>".
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)

        # Token-id sequences for the streaming state machine.
        # <think>\n
        self.think_start_ids = [13708, 766, 397]
        # < think >\n (alternative tokenization)
        self.think_start_ids_fast = [27, 26865, 397]
        # </think>\n<answer>\n
        self.response_start_ids = [522, 26865, 397, 27, 9217, 397]
        # </think>\n<answer>
        self.response_start_ids_fast = [522, 26865, 397, 27, 9217, 29]
        # </answer>
        self.response_end_ids = [522, 9217, 29]
        # Currently unused; kept for compatibility.
        self.fast_think_ids = [
            13708, 766, 1339, 522, 26865, 397, 27, 9217, 397
        ]

        # Currently unused; kept for compatibility.
        self.buffered_text = []
        self.buffered_ids = []

        # Streaming states cycle "idle" -> "think" -> "response" -> "idle".
        self.all_states = ["idle", "think", "response"]
        self.current_state = "idle"
        # Marker expected next in the current state. In "idle" and "think"
        # an alternative tokenization (the "side" sequence) is also accepted.
        self.expected_sequence = self.think_start_ids
        self.expected_sequence_side = self.think_start_ids_fast
        # Progress of the in-flight marker match: number of tokens matched,
        # the matched token ids, and their text (held back from the client).
        self.sequence_index = 0
        self.token_buffer = []
        self.text_buffer = ""

    def is_reasoning_end(self, input_ids: list[int]) -> bool:
        """Return True once streaming has entered the "response" state.

        ``input_ids`` is ignored; the answer comes from the state updated by
        ``extract_reasoning_content_streaming``.
        """
        return self.current_state == "response"

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return no content ids.

        The streaming parser runs first and routes every token itself; the
        same token is then passed to ``is_reasoning_end`` and this method,
        so it must not be reported again as content here.
        """
        return []

    def extract_reasoning_content(
            self, model_output: str, request: ChatCompletionRequest
    ) -> tuple[Optional[str], Optional[str]]:
        """Extract the reasoning content & content sections, respectively.

        If the sequence doesn't match what we expect, i.e., the model
        generates something else, all content is considered non-reasoning
        content. Empty sections are returned as ``None``.

        Args:
            model_output (str): Output of the model to be parsed.
            request (ChatCompletionRequest): Request being processed
                (unused).

        Returns:
            tuple[Optional[str], Optional[str]]: Tuple pair containing the
            reasoning content and non-reasoning content.
        """
        full_match = self.full_match_reasoning_regex.search(model_output)
        if full_match:
            # Group 1 is None when the optional think prefix did not match.
            reasoning_content, response_content = full_match.groups()
            return reasoning_content or None, response_content or None

        half_match = self.half_match_reasoning_regex.search(model_output)
        if half_match:
            reasoning_content, response_content = half_match.groups()
            # Defensive: the full regex already claims any "</answer>".
            if response_content.endswith(self.response_end_expr):
                response_content = response_content[:-len(self.
                                                          response_end_expr)]
            return reasoning_content or None, response_content or None

        return None, model_output

    def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
                                          sequence: Sequence[int]) -> bool:
        """Return True if ``subsequence`` occurs in order (not necessarily
        contiguously) within ``sequence``; an empty subsequence is False."""
        if not subsequence:
            return False

        sub_idx = 0
        for num in sequence:
            if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
                sub_idx += 1
        return sub_idx == len(subsequence)

    def _match_continues(self, token: int) -> bool:
        """Return True if ``token`` is the next token of the expected marker
        (main sequence, or the side sequence in "idle"/"think")."""
        idx = self.sequence_index
        main = self.expected_sequence
        if idx < len(main) and token == main[idx]:
            return True
        if self.current_state in ("idle", "think"):
            side = self.expected_sequence_side
            return idx < len(side) and token == side[idx]
        return False

    def _match_completed(self, token: int) -> bool:
        """Return True if the just-consumed ``token`` finished a marker.

        In "idle"/"think" a final token that matches the side sequence is
        judged against the side sequence's length.
        """
        idx = self.sequence_index
        if self.current_state in ("idle", "think"):
            side = self.expected_sequence_side
            if idx - 1 < len(side) and token == side[idx - 1]:
                return idx == len(side)
        return idx == len(self.expected_sequence)

    def _reset_match(self) -> None:
        """Forget any in-flight marker match."""
        self.sequence_index = 0
        self.token_buffer = []
        self.text_buffer = ""

    def _advance_match(self, token: int, delta_text: str) -> None:
        """Buffer a matching marker token; on a complete marker, switch to
        the next state and discard the marker text."""
        self.token_buffer.append(token)
        self.text_buffer += delta_text
        self.sequence_index += 1
        if not self._match_completed(token):
            return

        if self.current_state == "idle":
            self.current_state = "think"
            self.expected_sequence = self.response_start_ids
            self.expected_sequence_side = self.response_start_ids_fast
        elif self.current_state == "think":
            self.current_state = "response"
            self.expected_sequence = self.response_end_ids
        elif self.current_state == "response":
            self.current_state = "idle"
            self.expected_sequence = self.think_start_ids
            self.expected_sequence_side = self.think_start_ids_fast
        # Marker text is never sent to the client.
        self._reset_match()

    def _delta_for_state(self, text: str) -> DeltaMessage:
        """Wrap ``text`` as reasoning in "think", otherwise as content."""
        if self.current_state == "think":
            return DeltaMessage(reasoning_content=text, content=None)
        return DeltaMessage(reasoning_content=None, content=text)

    def extract_reasoning_content_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> Union[DeltaMessage, None]:
        """Route one streamed token using a token-id state machine.

        Tokens that match the marker expected in the current state are held
        back; a completed marker switches state and is never emitted. If a
        partial match breaks, the held-back text is flushed in the current
        state, and the breaking token is re-checked as the start of a fresh
        marker so that e.g. a repeated first marker token ("<<think>") does
        not hide the marker.

        Only ``delta_text`` and ``delta_token_ids`` are used; exactly one
        token per call is required (asserted).

        NOTE(review): text still held back when generation stops mid-marker
        is never emitted.

        Returns:
            A DeltaMessage with reasoning or content text, or None while
            marker tokens are being held back.
        """
        assert (len(delta_token_ids) == 1)
        token = delta_token_ids[0]

        if self._match_continues(token):
            self._advance_match(token, delta_text)
            return None

        if not self.token_buffer:
            # Ordinary token outside any marker.
            return self._delta_for_state(delta_text)

        # A partial marker match just broke: flush it as ordinary text.
        flushed = self.text_buffer
        self._reset_match()
        if not self._match_continues(token):
            return self._delta_for_state(flushed + delta_text)

        # The breaking token starts a new marker: emit only the stale buffer
        # (built in the pre-transition state), then keep buffering.
        message = self._delta_for_state(flushed)
        self._advance_match(token, delta_text)
        return message