File size: 8,252 Bytes
51fd6a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
"""
scripts/budget_processor.py
============================
Inference-time hard budget enforcement for the Thinking Budget environment.

PROBLEM
-------
Reward shaping during training teaches the policy to allocate <think> tokens
intelligently — long on bugs, short on safe files.  But at inference time,
nothing prevents a model from blowing past its own learned budget if the
sampling distribution drifts.

CONTRIBUTION
------------
A `LogitsProcessor` that converts a soft, learned budget into a *hard*
inference-time constraint.  The processor:
  • counts tokens emitted between an opening <think> and its closing </think>
  • when the per-block budget is exceeded, forces the next sampled token to
    be the </think> token id, ending reasoning gracefully
  • respects an *episode-level* budget across all <think> blocks too —
    once the global pool is empty no further reasoning is allowed

This is the inference complement to the training-time reward.  Together
they form a closed loop:  reward shaping during GRPO  →  the policy learns
to allocate;  budget processor at inference  →  the user can hard-cap
compute and the policy still degrades gracefully.

USAGE
-----
    from transformers import AutoTokenizer
    from scripts.budget_processor import ThinkingBudgetProcessor

    tok = AutoTokenizer.from_pretrained(...)
    proc = ThinkingBudgetProcessor(
        tokenizer=tok,
        per_block_budget=400,    # max tokens per <think>...</think>
        episode_budget=2000,     # max total thinking tokens per generation
    )
    out = model.generate(..., logits_processor=[proc])

DEMO
----
The Gradio Space (Tab "🎚 Budget Slider") wraps this processor with a
slider [50 / 100 / 250 / 500 / 1000 / ∞] and a live F1 readout.  Tighter
budgets degrade the untrained baseline catastrophically while the trained
policy adapts — that's the screenshot.
"""
from __future__ import annotations

from typing import List, Optional

try:
    import torch
    from transformers import LogitsProcessor
    _HAS_TORCH = True
except ImportError:  # pragma: no cover - allow import without torch (CI)
    _HAS_TORCH = False
    LogitsProcessor = object  # type: ignore


# ── Tag ids cache ─────────────────────────────────────────────────────────
def _resolve_tag_ids(tokenizer, tag: str) -> List[int]:
    """
    Resolve all token ids that emit `tag` (with and without leading space,
    BPE-merged variants).  We need this because the same logical </think>
    can be tokenized as several different sequences depending on context.
    """
    candidates = {tag, " " + tag, "\n" + tag, "\n\n" + tag}
    ids: List[int] = []
    for cand in candidates:
        try:
            toks = tokenizer.encode(cand, add_special_tokens=False)
            if toks:
                ids.extend(toks)
        except Exception:
            continue
    seen = set()
    out = []
    for i in ids:
        if i not in seen:
            seen.add(i)
            out.append(i)
    return out


class ThinkingBudgetProcessor(LogitsProcessor):
    """
    Logits processor that enforces hard <think> budgets at inference time.

    On every call, each row of `input_ids` is replayed through a two-state
    machine:
        outside  no <think> block is open
        inside   a <think> block is open — tokens are being counted

    When a row is inside a block and either the per-block or the episode
    budget is used up, that row's next-token scores are replaced so that only
    `preferred_close_id` (a </think> token id) has a finite score.

    The replay is done from scratch on each call instead of being carried
    over between calls.  The previous implementation keyed its state on
    `id(input_ids)`; an object id is only unique while the object is alive,
    so new tensors kept adding entries and a recycled id could pick up stale
    counters.  Replaying costs O(sequence length) per row per call but makes
    the result depend only on the ids actually passed in.
    """

    def __init__(
        self,
        tokenizer,
        per_block_budget: int = 400,
        episode_budget: Optional[int] = None,
        verbose: bool = False,
    ):
        """
        Args:
            tokenizer: tokenizer used to resolve the <think> / </think> ids.
            per_block_budget: max counted tokens per <think> block; values
                below 1 are clamped to 1.
            episode_budget: max counted tokens across all <think> blocks of a
                sequence.  `None` disables the episode limit.  `0` means "no
                thinking allowed" (negative values are clamped to 0); it is
                no longer silently treated as unlimited.
            verbose: print a line whenever a </think> is forced.

        Raises:
            RuntimeError: torch / transformers could not be imported.
            ValueError: no token id could be resolved for one of the tags.
        """
        if not _HAS_TORCH:
            raise RuntimeError("torch + transformers required for ThinkingBudgetProcessor")
        self.per_block_budget = max(1, int(per_block_budget))
        # `is not None` rather than truthiness so an explicit 0 stays a limit.
        self.episode_budget = (
            None if episode_budget is None else max(0, int(episode_budget))
        )
        self.verbose = verbose

        self.open_ids = _resolve_tag_ids(tokenizer, "<think>")
        self.close_ids = _resolve_tag_ids(tokenizer, "</think>")
        if not self.open_ids or not self.close_ids:
            raise ValueError("tokenizer does not contain <think> / </think> tokens")
        self.preferred_close_id = self.close_ids[0]

        # Snapshot of the state computed for each batch row during the most
        # recent __call__ (key: batch row index).  Rebuilt on every call.
        self._state: dict = {}

    def _get_seq_state(self, key) -> dict:
        """Return the state dict stored under `key`, creating a zeroed one if absent."""
        if key not in self._state:
            self._state[key] = {
                "in_block": False,
                "block_used": 0,
                "episode_used": 0,
            }
        return self._state[key]

    def _scan_last_token(self, last_id: int, st: dict) -> None:
        """
        Advance the state machine `st` by one token id.

        An opening id enters a block and resets the per-block counter, a
        closing id leaves the block, and any other id seen while inside a
        block is charged to both the block and the episode counters.
        """
        if last_id in self.open_ids:
            st["in_block"] = True
            st["block_used"] = 0
        elif last_id in self.close_ids:
            st["in_block"] = False
        elif st["in_block"]:
            st["block_used"] += 1
            st["episode_used"] += 1

    def _budget_exceeded(self, st: dict) -> bool:
        """Return True when the per-block or (if set) the episode budget is used up."""
        if st["block_used"] >= self.per_block_budget:
            return True
        if self.episode_budget is not None and st["episode_used"] >= self.episode_budget:
            return True
        return False

    def __call__(self, input_ids: "torch.Tensor", scores: "torch.Tensor") -> "torch.Tensor":
        """
        Force </think> on every row whose thinking budget is exhausted.

        Args:
            input_ids: token ids seen so far, indexed as (batch, seq).
            scores: next-token scores, indexed as (batch, vocab).

        Returns:
            `scores`, with the rows that were forced overwritten in place
            (-inf everywhere except `preferred_close_id`, which is 0.0).
        """
        # Start from a clean slate so the result is a pure function of the
        # ids passed in, independent of earlier calls or tensor identity.
        self._state.clear()
        for b in range(input_ids.shape[0]):
            st = self._get_seq_state(b)
            for tid in input_ids[b].tolist():
                self._scan_last_token(int(tid), st)

            if st["in_block"] and self._budget_exceeded(st):
                # Force the next token to be </think>
                forced = torch.full_like(scores[b], float("-inf"))
                forced[self.preferred_close_id] = 0.0
                scores[b] = forced
                if self.verbose:
                    print(f"[budget] seq {b}: forced </think> at "
                          f"block_used={st['block_used']} "
                          f"episode_used={st['episode_used']}")
        return scores

    def reset(self) -> None:
        """Clear the stored per-row snapshots.  Safe to call between generation runs."""
        self._state.clear()


# ── Lightweight character-level fallback (when no tokenizer is available) ──
def enforce_character_budget(
    text: str,
    per_block_budget: int = 400,
    episode_budget: Optional[int] = None,
) -> str:
    """
    Apply a character-level thinking budget to already-generated text.

    Every complete `<think>...</think>` block is cut to at most
    `per_block_budget` characters and, when `episode_budget` is given, to
    whatever is left of that shared pool.  A block that had to be cut gets
    its trailing whitespace stripped and a "[truncated by budget]" marker
    appended before the closing tag.  Text outside the blocks, and any
    `<think>` that is never closed, is passed through unchanged.

    Args:
        text: recorded model output.
        per_block_budget: max characters kept per block; negative values
            are treated as 0.
        episode_budget: max characters kept across all blocks, or `None`
            for no episode-level limit.

    Returns:
        The text with over-budget blocks truncated.
    """
    import re

    marker = "  …[truncated by budget]"
    pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)

    pieces = []
    spent = 0
    cursor = 0
    for m in pattern.finditer(text):
        pieces.append(text[cursor:m.start()])
        block = m.group(1)

        cap = per_block_budget
        if episode_budget is not None:
            cap = min(cap, episode_budget - spent)
        # Clamp: a negative cap would make block[:cap] slice from the END
        # of the block instead of keeping nothing.
        cap = max(0, cap)

        if len(block) <= cap:
            content, suffix = block, ""
        else:
            content, suffix = block[:cap].rstrip(), marker
        pieces.append(f"<think>{content}{suffix}</think>")

        # Charge only the reasoning text that was kept; the marker is
        # bookkeeping and must not eat into the episode budget.
        spent += len(content)
        cursor = m.end()
    pieces.append(text[cursor:])
    return "".join(pieces)


if __name__ == "__main__":
    # Smoke test for the character-level fallback: the first block (1000
    # chars) is over the 200-char per-block cap, so the output must carry
    # the truncation marker.
    sample = "Pre <think>" + "x" * 1000 + "</think> mid <think>" + "y" * 50 + "</think> end"
    out = enforce_character_budget(sample, per_block_budget=200, episode_budget=300)
    print(out[:400])
    assert "[truncated by budget]" in out
    print("✅ character-budget smoke test passed")