viethq5 commited on
Commit
2eab44b
·
verified ·
1 Parent(s): 44ae95a

Upload greenmind_14b_r1_reasoning_parser.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. greenmind_14b_r1_reasoning_parser.py +231 -0
greenmind_14b_r1_reasoning_parser.py ADDED
@@ -0,0 +1,231 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Sequence
2
+ from typing import Optional, Union
3
+
4
+ import regex as re
5
+ from transformers import PreTrainedTokenizerBase
6
+
7
+ from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
8
+ DeltaMessage)
9
+ from vllm.logger import init_logger
10
+ from vllm.reasoning import ReasoningParser, ReasoningParserManager
11
+
12
+ logger = init_logger(__name__)
13
+
14
+
15
@ReasoningParserManager.register_module("greenmind_14b_r1")
class GreenMind14bR1ReasoningParser(ReasoningParser):
    """Reasoning parser for the GreenMind-14B-R1 model.

    The model wraps its chain-of-thought in "<think>\n...\n</think>\n" and
    its final answer in "<answer>\n...\n</answer>".  Everything between the
    think markers is extracted as reasoning content; everything inside the
    answer markers is the response content.

    Marker token-id sequences (per the model's tokenizer):
        think start:    "<think>\n":              [13708, 766, 397]
        response start: "\n</think>\n<answer>\n": [198, 522, 26865, 397, 27, 9217, 397]
        response end:   "\n</answer>":            [198, 522, 9217, 29]
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        super().__init__(tokenizer)
        # Text-level marker expressions, reused as regex fragments below.
        self.think_start_expr = r"<think>\n"
        self.think_end_expr = r"\n</think>\n"

        self.response_start_expr = r"\n</think>\n<answer>\n"
        self.response_end_expr = r"\n</answer>"

        # Complete output: optional think section followed by the answer
        # body and the closing </answer> marker.
        self.full_match_reasoning_regex = re.compile(
            rf"(?:{self.think_start_expr}(.*?){self.response_start_expr})?"
            rf"(.*?){self.response_end_expr}",
            re.DOTALL)

        # Truncated output: the closing </answer> marker never arrived.
        self.half_match_reasoning_regex = re.compile(
            rf"{self.think_start_expr}(.*?){self.response_start_expr}(.*)",
            re.DOTALL)

        # Marker token-id sequences.  The "fast" variants cover an
        # alternative tokenization of the same marker text.
        self.think_start_ids = [13708, 766, 397]
        self.think_start_ids_fast = [13708, 766, 29]
        self.response_start_ids = [198, 522, 26865, 397, 27, 9217, 397]
        self.response_start_ids_fast = [522, 26865, 397, 27, 9217, 397]
        self.response_end_ids = [198, 522, 9217, 29]
        self.fast_think_ids = [
            13708, 766, 1339, 522, 26865, 397, 27, 9217, 397
        ]

        # When the state changes, send out all the text buffered in the
        # previous state.  NOTE(review): these two buffers appear unused in
        # this file (token_buffer/text_buffer are used instead); kept for
        # backward compatibility with any external readers.
        self.buffered_text = []
        self.buffered_ids = []

        # Streaming state machine: idle -> think -> response -> idle.
        # (Fixed: a dead `current_state = "reasoning"` assignment was
        # immediately overwritten, and all_states did not list the states
        # the machine actually uses.)
        self.current_state = "idle"
        self.all_states = ["idle", "think", "response"]
        self.expected_sequence = self.think_start_ids
        # Only the think/response start markers have two possible
        # tokenizations, hence the side sequence.
        self.expected_sequence_side = self.think_start_ids_fast
        self.sequence_index = 0
        # Tokens/text buffered while a marker match is in progress; flushed
        # to the client if the match breaks.
        self.token_buffer = []
        self.text_buffer = ""
67
+
68
+ def is_reasoning_end(self, input_ids: list[int]) -> bool:
69
+ return self.current_state == "response"
70
+
71
+ def extract_content_ids(self, input_ids: list[int]) -> list[int]:
72
+ # for hunyuan streaming reason parsing, the stream parse
73
+ # will call first, and the same token will be called in
74
+ # is_reasoning_end and extract_content_ids
75
+ # this id is not part of content, so just return [] here.
76
+ return []
77
+
78
+ def extract_reasoning_content(
79
+ self, model_output: str, request: ChatCompletionRequest
80
+ ) -> tuple[Optional[str], Optional[str]]:
81
+ """Extract the reasoning content & content sections, respectively.
82
+ If the sequence doesn't match what we expect, i.e., the model generates
83
+ something else, all content is considered non-reasoning content.
84
+
85
+ Args:
86
+ model_output (str): Output of the model to be parsed.
87
+ request (ChatCompletionRequest): Request being processed.
88
+
89
+ Returns:
90
+ tuple[Optional[str], Optional[str]]: Tuple pair containing the
91
+ reasoning content and non-reasoning content.
92
+ """
93
+
94
+ re_match = self.full_match_reasoning_regex.findall(model_output)
95
+ if re_match:
96
+ reasoning_content, response_content = re_match[0]
97
+ if len(reasoning_content) == 0:
98
+ reasoning_content = None
99
+ if len(response_content) == 0:
100
+ response_content = None
101
+ return reasoning_content, response_content
102
+
103
+ fallback_regex = self.half_match_reasoning_regex
104
+ fallback_match = fallback_regex.findall(model_output)
105
+ if fallback_match:
106
+ reasoning_content, response_content = fallback_match[0]
107
+
108
+ if response_content.endswith(self.response_end_expr):
109
+ response_content = response_content[:-len(self.
110
+ response_end_expr)]
111
+
112
+ if len(reasoning_content) == 0:
113
+ reasoning_content = None
114
+ if len(response_content) == 0:
115
+ response_content = None
116
+
117
+ return reasoning_content, response_content
118
+
119
+ return None, model_output
120
+
121
+ def _is_strict_increasing_subsequence(self, subsequence: Sequence[int],
122
+ sequence: Sequence[int]) -> bool:
123
+ if not subsequence:
124
+ return False
125
+
126
+ sub_idx = 0
127
+ for num in sequence:
128
+ if sub_idx < len(subsequence) and num == subsequence[sub_idx]:
129
+ sub_idx += 1
130
+ return sub_idx == len(subsequence)
131
+
132
+ def extract_reasoning_content_streaming(
133
+ self,
134
+ previous_text: str,
135
+ current_text: str,
136
+ delta_text: str,
137
+ previous_token_ids: Sequence[int],
138
+ current_token_ids: Sequence[int],
139
+ delta_token_ids: Sequence[int],
140
+ ) -> Union[DeltaMessage, None]:
141
+ """Extract content using token ID sequence state machine"""
142
+ # Define sequences
143
+ think_start_sequence = self.think_start_ids
144
+ response_start_sequence = self.response_start_ids
145
+ response_end_sequence = self.response_end_ids
146
+
147
+ assert (len(delta_token_ids) == 1)
148
+ # Process each token in the delta
149
+ token = delta_token_ids[0]
150
+
151
+ def check_token_with_sequence(token):
152
+ if self.current_state == "idle" or self.current_state == "think":
153
+ return (token == self.expected_sequence[self.sequence_index]
154
+ or token == \
155
+ self.expected_sequence_side[self.sequence_index])
156
+ else:
157
+ return token == self.expected_sequence[self.sequence_index]
158
+
159
+ def check_last_token(token):
160
+ if self.current_state == "idle" or self.current_state == "think":
161
+ # only return true if it's judge using a side sequence.
162
+ if (self.sequence_index - 1 < len(self.expected_sequence_side)
163
+ and token
164
+ == self.expected_sequence_side[self.sequence_index -
165
+ 1]):
166
+ return self.sequence_index == len(
167
+ self.expected_sequence_side)
168
+ else:
169
+ return self.sequence_index == len(self.expected_sequence)
170
+ else:
171
+ return self.sequence_index == len(self.expected_sequence)
172
+
173
+ # Check if token matches expected sequence
174
+ token_in_state_seq = check_token_with_sequence(token)
175
+
176
+ if token_in_state_seq:
177
+ # Store matching token
178
+ self.token_buffer.append(token)
179
+ self.text_buffer += delta_text
180
+ self.sequence_index += 1
181
+ ## state change from idle->think->response->idle
182
+
183
+ # Check if sequence fully matched
184
+ if check_last_token(token):
185
+ # State transition
186
+ if self.current_state == "idle":
187
+ self.current_state = "think"
188
+ self.expected_sequence = response_start_sequence
189
+ self.expected_sequence_side = self.response_start_ids_fast
190
+ elif self.current_state == "think":
191
+ self.current_state = "response"
192
+ self.expected_sequence = response_end_sequence
193
+ elif self.current_state == "response":
194
+ self.current_state = "idle"
195
+ self.expected_sequence = think_start_sequence
196
+ self.expected_sequence_side = self.think_start_ids_fast
197
+
198
+ # Reset matching state
199
+ self.sequence_index = 0
200
+ self.token_buffer = []
201
+ self.text_buffer = ""
202
+ # Do not send content for state transition texts.
203
+ else:
204
+ # Sequence broken - handle buffered content
205
+ if self.token_buffer and len(self.token_buffer) > 0:
206
+ # Send buffered tokens
207
+ buffered_content = self.text_buffer + delta_text
208
+ # Reset matching state
209
+ self.sequence_index = 0
210
+ self.token_buffer = []
211
+ self.text_buffer = ""
212
+
213
+ # Return content based on current state
214
+ if self.current_state == "think":
215
+ return DeltaMessage(reasoning_content=buffered_content,
216
+ content=None)
217
+ else:
218
+ return DeltaMessage(reasoning_content=None,
219
+ content=buffered_content)
220
+ else:
221
+ # No buffered content, send normally
222
+ if self.current_state == "think":
223
+ return DeltaMessage(reasoning_content=delta_text,
224
+ content=None)
225
+ else:
226
+ return DeltaMessage(reasoning_content=None,
227
+ content=delta_text)
228
+
229
+ # If no content to send in this delta
230
+ return None
231
+