File size: 12,816 Bytes
bc5d1b8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
"""Reasoning parser plugin for Domyn-Small ``<think>...</think>`` outputs.

Loaded into vLLM with ``--reasoning-parser-plugin <path>`` and selected via
``--reasoning-parser think_block``. The parser splits each model output on
the literal ``</think>`` marker: everything before it is reasoning,
everything after is final content.

See :class:`ThinkBlockReasoningParser` for the streaming state machine and
how per-request thinking-on/off is discovered.
"""

from __future__ import annotations

from collections.abc import Iterable, Sequence
from typing import TYPE_CHECKING

from vllm.reasoning import ReasoningParser, ReasoningParserManager

if TYPE_CHECKING:
    from vllm.entrypoints.openai.chat_completion.protocol import ChatCompletionRequest
    from vllm.entrypoints.openai.engine.protocol import DeltaMessage
    from vllm.entrypoints.openai.responses.protocol import ResponsesRequest

# Literal markers emitted by the Domyn-Small chat template. `<think>` is
# pre-emitted by the prompt, so model output never starts with it; only `</think>`
# actually has to be detected at runtime.
# START is still exposed via `reasoning_start_str` and stripped defensively by
# `extract_reasoning` if an output happens to begin with it; END is the split
# point used by every extraction path in this module.
START = "<think>"
END = "</think>"


def _max_suffix_prefix(s: str, marker: str) -> str:
    """Longest non-empty suffix of ``s`` that is also a prefix of ``marker``.

    Used to decide how many trailing bytes of the streaming buffer must be
    held back — if those bytes could still grow into ``marker`` on the next
    delta, releasing them now would fragment the marker across deltas (e.g.
    emitting ``</thi`` and then ``nk>``).
    """
    for i in range(min(len(marker) - 1, len(s)), 0, -1):
        if s.endswith(marker[:i]):
            return s[-i:]
    return ""


@ReasoningParserManager.register_module("think_block")
class ThinkBlockReasoningParser(ReasoningParser):
    """Splits model output on the literal ``</think>`` marker.

    **Streaming.** Olmo3-style buffered state machine: incoming text is
    accumulated in :attr:`_buffer` and only released when the marker is
    either confirmed (split point reached) or ruled out (the buffer tail
    can no longer be a prefix of ``</think>``). This guarantees the marker
    is never fragmented across deltas. Once the first marker has been
    consumed the parser stops looking for markers altogether: every later
    delta is forwarded verbatim as content, matching the non-streaming
    path, which also splits on the first marker only.

    **Per-request lane.** The initial lane (``"reasoning"`` vs
    ``"content"``) is set from the request itself: it is ``"reasoning"``
    if ``chat_template_kwargs.enable_thinking`` (or ``.thinking``) is
    truthy, or if any system message contains the literal
    ``"thinking on"`` directive — mirroring the chat template's own
    detection — and ``"content"`` otherwise.

    **Request discovery.** vLLM instantiates the parser per request from
    inside ``create_chat_completion(self, request, ...)``, but does not
    pass the request to the constructor. We recover it by walking the call
    stack at ``__init__`` time, inspecting only each frame's *function
    arguments* (so we don't accidentally match request-shaped objects in
    module globals or unrelated locals). If no request is found we fall
    back to ``thinking=off``, which keeps tool-call streaming working out
    of the box.
    """

    # Number of trailing tokens decoded by `is_reasoning_end_streaming`.
    _TAIL_WINDOW_TOKENS = 64

    def __init__(self, tokenizer, *args, **kwargs) -> None:
        """Create a parser and, if possible, configure it for the caller's request.

        Args:
            tokenizer: Tokenizer forwarded to the base ``ReasoningParser``.
            *args: Ignored; accepted for signature compatibility.
            **kwargs: Ignored; accepted for signature compatibility.
        """
        # Base ReasoningParser only accepts `tokenizer`; swallow any extras so
        # the registration signature stays compatible across vLLM versions.
        super().__init__(tokenizer)
        # Text received but not yet released to the client (streaming only).
        self._buffer: str = ""
        # Current lane for streaming output: "reasoning" while inside
        # <think>...</think>, "content" otherwise. Locked to "content" once
        # `</think>` is observed.
        self._state: str = "content"
        # Set when the first `</think>` has been consumed by the streaming
        # path. From then on no marker detection or hold-back is performed.
        self._marker_seen: bool = False
        # Tracks whether we have applied per-request configuration yet —
        # stack-walking covers the streaming path; `extract_reasoning` also
        # configures on the first non-streaming call as a safety net.
        self._configured: bool = False

        request = self._find_request_in_stack()
        if request is not None:
            self._configure_for_request(request)

    @staticmethod
    def _looks_like_request(obj) -> bool:
        """Duck-typed check for ChatCompletionRequest / ResponsesRequest.

        An object qualifies when it has a ``messages`` attribute plus either
        ``chat_template_kwargs`` or ``stream``.

        Avoids importing vLLM's protocol module, which differs across forks
        and isn't guaranteed to be importable at plugin load time.
        """
        return hasattr(obj, "messages") and (
            hasattr(obj, "chat_template_kwargs") or hasattr(obj, "stream")
        )

    @classmethod
    def _find_request_in_stack(cls, max_depth: int = 12):
        """Locate the in-flight request by scanning caller-frame arguments.

        Walks a bounded number of caller frames via ``sys._getframe`` /
        ``frame.f_back`` and inspects only each frame's *function
        arguments* — never its full locals. This matches vLLM's
        ``create_chat_completion(self, request, ...)`` signature and avoids
        matching request-shaped objects that happen to live in module
        globals or unrelated locals (e.g. test fixtures).

        We deliberately avoid :func:`inspect.stack`, which reads source
        files via ``linecache`` and builds ``FrameInfo`` objects for the
        whole stack on every call — measurable overhead per request under
        high concurrency, since parser construction is per-request and
        runs under the GIL on the serving event loop.

        Args:
            max_depth: Maximum number of caller frames to inspect.

        Returns:
            The first argument value accepted by :meth:`_looks_like_request`,
            or ``None`` if no inspected frame has one.
        """
        import sys
        try:
            frame = sys._getframe(1)
        except Exception:
            # Best effort: without frame access we simply report no request.
            return None
        depth = 0
        while frame is not None and depth < max_depth:
            code = frame.f_code
            # Positional and keyword-only parameter names come first in
            # `co_varnames`; `*args` / `**kwargs` names are not included.
            n_args = code.co_argcount + code.co_kwonlyargcount
            for name in code.co_varnames[:n_args]:
                value = frame.f_locals.get(name)
                if cls._looks_like_request(value):
                    return value
            frame = frame.f_back
            depth += 1
        return None

    def _configure_for_request(self, request) -> None:
        """Set initial streaming lane from the request's thinking flag."""
        self._state = "reasoning" if self._thinking_was_enabled(request) else "content"
        self._configured = True

    def _decode(self, ids: Sequence[int]) -> str:
        """Decode ``ids`` to text, returning ``""`` if the tokenizer raises."""
        # `skip_special_tokens=False` is required: `</think>` may be tokenized
        # as (or contain) special tokens that the default decode would strip,
        # which would silently break marker detection.
        try:
            return self.model_tokenizer.decode(list(ids), skip_special_tokens=False)
        except Exception:
            # Deliberate best effort: a decode failure is treated as
            # "no marker found" by every caller.
            return ""

    @property
    def reasoning_start_str(self) -> str | None:
        """Literal string that opens a reasoning block."""
        return START

    @property
    def reasoning_end_str(self) -> str | None:
        """Literal string that closes a reasoning block."""
        return END

    def is_reasoning_end(self, input_ids: Sequence[int]) -> bool:
        """Return whether the decoded ``input_ids`` contain ``</think>``."""
        return END in self._decode(input_ids)

    def is_reasoning_end_streaming(
        self, input_ids: Sequence[int], delta_ids: Iterable[int]
    ) -> bool:
        """Return whether ``</think>`` occurs in the tail of ``input_ids``.

        ``delta_ids`` is not consulted; only the trailing window of
        ``input_ids`` is decoded.
        """
        # Decode a 64-token tail window so the marker is detected even when
        # it straddles the previous-vs-delta token boundary (BPE may split
        # `</think>` across multiple tokens, especially around punctuation).
        # Slice before decoding so the full id sequence is never copied on
        # each delta.
        tail = input_ids[-self._TAIL_WINDOW_TOKENS:]
        return END in self._decode(tail)

    def extract_content_ids(self, input_ids: list[int]) -> list[int]:
        """Return token ids for the text after the last ``</think>``.

        The ids are produced by decoding ``input_ids`` and re-encoding the
        text that follows the marker. Returns ``[]`` when there is no
        marker or when re-encoding fails.
        """
        text = self._decode(input_ids)
        idx = text.rfind(END)
        if idx < 0:
            return []
        try:
            return self.model_tokenizer.encode(
                text[idx + len(END):], add_special_tokens=False
            )
        except Exception:
            return []

    def count_reasoning_tokens(self, token_ids: Sequence[int]) -> int:
        """Count tokens in the text preceding the first ``</think>``.

        When no marker is present the whole decoded text is counted. The
        count comes from re-encoding that text; ``0`` is returned if
        encoding fails.
        """
        text = self._decode(token_ids)
        idx = text.find(END)
        prefix = text if idx < 0 else text[:idx]
        try:
            return len(self.model_tokenizer.encode(prefix, add_special_tokens=False))
        except Exception:
            return 0

    def extract_reasoning(
        self,
        model_output: str,
        request: "ChatCompletionRequest | ResponsesRequest",
    ) -> tuple[str | None, str | None]:
        """Split a full (non-streaming) output into ``(reasoning, content)``.

        A leading ``<think>`` is dropped if present, and the split happens
        at the first ``</think>``. Surrounding newlines are stripped from
        the reasoning and leading newlines from the content; an empty part
        is reported as ``None``.

        Returns ``(None, content)`` when the request has thinking disabled
        and the output contains no marker — the chat template pre-emits
        ``<think></think>`` in the prompt in that case, so a marker-less
        output is purely the answer.
        """
        # Configure streaming state as a side effect: a fork's serving layer
        # may call this before streaming starts, and we don't want the
        # streaming path to fall back to the `thinking=off` default if the
        # request actually had thinking enabled.
        if not self._configured:
            self._configure_for_request(request)

        s = model_output
        if s.startswith(START):
            s = s[len(START):]
        if END in s:
            reasoning, _, content = s.partition(END)
            return (reasoning.strip("\n") or None, content.lstrip("\n") or None)
        # No `</think>` in output: only treat the text as truncated reasoning
        # if we have positive evidence that thinking was enabled — otherwise
        # it is the final answer.
        if self._thinking_was_enabled(request):
            return (s.strip("\n") or None, None)
        return (None, s.lstrip("\n") or None)

    @staticmethod
    def _thinking_was_enabled(request) -> bool:
        """Whether ``request`` asked for reasoning to be emitted.

        Mirrors the chat template's own detection so the parser stays in
        lockstep with prompt construction: enabled iff
        ``chat_template_kwargs.enable_thinking`` (or ``.thinking``) is
        truthy, or any system message contains the literal ``"thinking on"``
        directive (case-insensitive).

        System message content may be a plain string or a list/tuple of
        content parts; for parts, a ``text`` key or attribute (or a bare
        string part) is inspected.
        """
        kwargs = getattr(request, "chat_template_kwargs", None) or {}
        if kwargs.get("enable_thinking") or kwargs.get("thinking"):
            return True
        messages = getattr(request, "messages", None) or []
        for message in messages:
            is_dict = isinstance(message, dict)
            role = message.get("role") if is_dict else getattr(message, "role", None)
            if role != "system":
                continue
            content = (
                message.get("content") if is_dict else getattr(message, "content", None)
            )
            if isinstance(content, str):
                texts = [content]
            elif isinstance(content, (list, tuple)):
                # Only concrete sequences are walked: iterating an arbitrary
                # iterable here could exhaust a one-shot iterator that the
                # serving layer still needs.
                texts = []
                for part in content:
                    if isinstance(part, str):
                        text = part
                    elif isinstance(part, dict):
                        text = part.get("text")
                    else:
                        text = getattr(part, "text", None)
                    if isinstance(text, str):
                        texts.append(text)
            else:
                texts = []
            if any("thinking on" in text.lower() for text in texts):
                return True
        return False

    def extract_reasoning_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
    ) -> "DeltaMessage | None":
        """Emit one ``DeltaMessage`` per delta, routed to reasoning or content.

        Only ``delta_text`` is used; the other arguments are accepted for
        interface compatibility.

        The first ``</think>`` marker is never emitted to the client.
        Trailing characters of the buffer that *could* still grow into the
        marker on the next delta are held back, so the marker is never
        fragmented across deltas (e.g. ``</thi`` ... ``nk>``). When the
        marker is observed, pre-marker text goes to the current lane and
        post-marker text goes to ``content``; the lane is then locked to
        ``content`` and every later delta is forwarded verbatim, with no
        further marker detection or hold-back.

        Note: a held-back tail is only released by a later call; this class
        has no flush hook, so a stream that ends before the first marker
        with a partial-marker tail (e.g. a trailing ``<``) does not emit it.

        Returns:
            A ``DeltaMessage`` with ``reasoning`` and/or ``content`` set, or
            ``None`` when nothing can be released for this delta.
        """
        from vllm.entrypoints.openai.engine.protocol import DeltaMessage

        # Case 0 — marker already consumed: the lane is locked to content.
        # Forward the text untouched so that later literal `</think>` text
        # or a trailing `<` in the answer is neither dropped nor withheld,
        # consistent with `extract_reasoning` splitting on the first marker.
        if self._marker_seen:
            pending = self._buffer + delta_text
            self._buffer = ""
            if not pending:
                return None
            return DeltaMessage(content=pending)

        self._buffer += delta_text

        # Case 1 — marker fully present in the buffer: split and switch lane.
        # The pre-marker chunk stays on the *current* lane (reasoning if we
        # were inside <think>, content otherwise); the post-marker chunk
        # always goes to content; the lane is locked to content afterwards.
        idx = self._buffer.find(END)
        if idx >= 0:
            pre = self._buffer[:idx]
            post = self._buffer[idx + len(END):]
            self._buffer = ""
            pre_lane = self._state
            self._state = "content"
            self._marker_seen = True
            if not pre and not post:
                return None
            fields: dict = {}
            if pre:
                fields[pre_lane] = pre
            if post:
                # `.get` covers the edge case where pre_lane is already
                # "content" and both pre and post are non-empty — they get
                # concatenated into a single content delta.
                fields["content"] = fields.get("content", "") + post
            return DeltaMessage(**fields)

        # Case 2 — no marker yet: release everything except a possible
        # partial-marker tail, which we retain for the next delta.
        held = _max_suffix_prefix(self._buffer, END)
        safe_end = len(self._buffer) - len(held)
        if safe_end == 0:
            return None
        chunk = self._buffer[:safe_end]
        self._buffer = self._buffer[safe_end:]
        return DeltaMessage(**{self._state: chunk})