# BudgetThinker: Empowering Budget-aware LLM Reasoning with Control Tokens ๐Ÿš€

## Table of Contents

- [About](#about) ๐Ÿ“
- [Install](#install) โš™๏ธ
- [Preparation](#preparation) ๐Ÿ“š
- [Training](#training) ๐Ÿ‹๏ธโ€โ™‚๏ธ
- [Evaluation](#evaluation) ๐Ÿ“Š

## About
This repository contains the code implementation for the paper:

[BudgetThinker: Empowering Budget-aware LLM Reasoning with Control Tokens](https://www.arxiv.org/abs/2508.17196) ๐Ÿš€

Our training data can be downloaded from the following link:

[Dataset-BudgetThinker](https://huggingface.co/datasets/Xin-Rui/Dataset-BudgetThinker/tree/main) ๐Ÿ“ฅ

The trained model (based on DeepSeek-R1-Distill-Qwen-1.5B) can be obtained from the following link:

[BudgetThinker-1.5b](https://huggingface.co/Xin-Rui/BudgetThinker-1.5b/tree/main) ๐Ÿ“ฆ

## Install

### Clone This Repo ๐Ÿ“‹

### SFT Stage: LLaMA-Factory

```bash
git clone git@github.com:hiyouga/LLaMA-Factory.git
```

After cloning the repository, follow the instructions in the [Installation Guide](https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/installation.html) to configure the necessary dependencies. ๐Ÿ”ง

### Modify Environments' Code ๐Ÿ› ๏ธ

You need to patch one file in the transformers library inside the environment used for the LLaMA-Factory project. Locate the installed transformers package and replace its `loss/loss_utils.py` with the version provided in this repo. For example, using our path:

```bash
/home/user/anaconda3/envs/llama-fac/lib/python3.11/site-packages/transformers/loss/loss_utils.py

โ†•๏ธ

to_replace/transformers/loss/loss_utils.py
```

> Note: The version of the transformers library corresponding to this code is 4.46.1.

The patched code lets you adjust the loss weight of special tokens during training via an environment variable:

```bash
export special_token_loss=F   # disable loss on special tokens (weight = 0)
export special_token_loss=T   # enable loss on special tokens (default weight = 1)
export special_token_loss=Tn  # set the special-token loss weight to the float n
# For example, export special_token_loss=T10 sets the special-token loss weight to 10
```
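As a rough sketch of how such a flag could be parsed inside the patched loss code (the function name `special_token_weight` is illustrative, not taken from the repo):

```python
import os

def special_token_weight() -> float:
    """Map the special_token_loss env var to a loss weight, following the rules above."""
    flag = os.getenv("special_token_loss", "T")
    if flag == "F":
        return 0.0              # special tokens excluded from the loss
    if flag == "T":
        return 1.0              # default weight
    if flag.startswith("T"):
        return float(flag[1:])  # e.g. "T10" -> weight 10.0
    raise ValueError(f"unrecognized special_token_loss value: {flag!r}")
```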

### RL Stage: EasyR1 ๐ŸŽฏ

The modified project code is included in the `./easyr1` directory. For environment configuration, please refer to the [EasyR1](https://github.com/hiyouga/EasyR1) documentation.

### Eval Stage: Qwen2.5-Math ๐Ÿ“ˆ

The modified project code is included in the `./evaluation` directory. For environment configuration, please refer to the [Qwen2.5-Math](https://github.com/QwenLM/Qwen2.5-Math) documentation.

### Modify Environments' Code ๐Ÿ› ๏ธ

The environments used for `./easyr1` and `./evaluation` also need a patch: the vllm source must be modified to support inserting special tokens during inference.

#### Method 1: Direct Replacement (Limited to vllm Version 0.7.3) ๐Ÿ”
Locate the `worker/model_runner.py` file in the vllm library and replace it:

```bash
/home/user/anaconda3/envs/easyr1/lib/python3.11/site-packages/vllm/worker/model_runner.py
& 
/home/user/anaconda3/envs/QMath/lib/python3.11/site-packages/vllm/worker/model_runner.py

โ†•๏ธ

to_replace/vllm/worker/model_runner.py
```

> Note: The version of the vllm library corresponding to this code is 0.7.3.

#### Method 2: Direct Modification ๐Ÿ“

Focus on the `execute_model` function in the `...vllm/worker/model_runner.py` file. The original version is as follows:

```python

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        ... more code ...
        ... more code ...

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
```

Modify the code as follows:

```python

    @torch.inference_mode()
    def execute_model(
        self,
        model_input: ModelInputForGPUWithSamplingMetadata,
        kv_caches: List[torch.Tensor],
        intermediate_tensors: Optional[IntermediateTensors] = None,
        num_steps: int = 1,
    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
        if num_steps > 1:
            raise ValueError("num_steps > 1 is not supported in ModelRunner")

        ... more code ...
        ... more code ...

        # Compute the logits in the last pipeline stage.
        if not get_pp_group().is_last_rank:
            return hidden_or_intermediate_states

        logits = self.model.compute_logits(hidden_or_intermediate_states,
                                           model_input.sampling_metadata)

        if not self.is_driver_worker:
            return []

        # Sample the next token.
        output: SamplerOutput = self.model.sample(
            logits=logits,
            sampling_metadata=model_input.sampling_metadata,
        )

        #! >>>>>>>>>>> add remaining tokens to output <<<<<<<<<<<<
        import os
        if os.getenv("remaining", "remaing") == "remaing":
            special_tokens = [151665+i for i in range(400)]
            for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
                prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
                output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
                # iterate over output_token_ids_till_now (a tuple) in reverse to find the last special token
                last_special_token_idx, last_special_token = None, None
                for idx in range(len(output_token_ids_till_now)-1, -1, -1):
                    token_id = output_token_ids_till_now[idx]
                    if token_id in special_tokens:
                        last_special_token_idx = idx
                        last_special_token = token_id
                        break
                if last_special_token == 151665:  # already emitted the final <remaining> special token; stop inserting
                    continue
                if last_special_token_idx is not None:
                    distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
                    if distance_to_last_special_token == 50:
                        # replace the sampled token with the next <remaining> special token
                        output.outputs[seq_id].samples[0].output_token = last_special_token - 1
                        former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
                        output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
                        # drop the original logprob key unless it already equals the injected token id
                        if former_key != last_special_token - 1:
                            del output.outputs[seq_id].samples[0].logprobs[former_key]
                else:  # there has not been any special token in the output
                    last_special_token = None
                    for prompt_token_id in prompt_token_ids:
                        if prompt_token_id in special_tokens:
                            last_special_token = prompt_token_id
                            break
                    if last_special_token is not None:
                        if len(output_token_ids_till_now) == 50:
                            # replace the sampled token with the next <remaining> special token
                            output.outputs[seq_id].samples[0].output_token = last_special_token - 1
                            former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
                            output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
                            # drop the original logprob key unless it already equals the injected token id
                            if former_key != last_special_token - 1:
                                del output.outputs[seq_id].samples[0].logprobs[former_key]

        elif "ratio" in os.getenv("remaining", "remaing"):
            N = int(os.getenv("remaining", "remaing").replace("ratio", ""))
            assert os.getenv("budget") is not None
            budget = int(os.environ["budget"])
            delta = budget // N + 1

            special_tokens = [151665+i for i in range(N-1)]
            for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
                prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
                output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
                # iterate over output_token_ids_till_now (a tuple) in reverse to find the last special token
                last_special_token_idx, last_special_token = None, None
                for idx in range(len(output_token_ids_till_now)-1, -1, -1):
                    token_id = output_token_ids_till_now[idx]
                    if token_id in special_tokens:
                        last_special_token_idx = idx
                        last_special_token = token_id
                        break
                if last_special_token == 151665:  # already emitted the final <remaining> special token; stop inserting
                    continue
                if last_special_token_idx is not None:
                    distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
                    if distance_to_last_special_token == delta:
                        # replace the sampled token with the next <remaining> special token
                        output.outputs[seq_id].samples[0].output_token = last_special_token - 1
                        former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
                        output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
                        # drop the original logprob key unless it already equals the injected token id
                        if former_key != last_special_token - 1:
                            del output.outputs[seq_id].samples[0].logprobs[former_key]
                else:  # there has not been any special token in the output
                    last_special_token = 151671 + 1  # manually set to the <remaining>7/8</remaining> token id + 1, otherwise output would always start from 6/8
                    if last_special_token is not None:
                        if len(output_token_ids_till_now) == delta:
                            # replace the sampled token with the next <remaining> special token
                            output.outputs[seq_id].samples[0].output_token = last_special_token - 1
                            former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
                            output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
                            # drop the original logprob key unless it already equals the injected token id
                            if former_key != last_special_token - 1:
                                del output.outputs[seq_id].samples[0].logprobs[former_key]
            

        elif os.getenv("remaining", "remaing") == "remaining250":
            special_tokens = [151665+i for i in range(40)]
            for seq_id in range(len(model_input.sampling_metadata.seq_groups)):
                prompt_token_ids = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).prompt_token_ids
                output_token_ids_till_now = next(iter(model_input.sampling_metadata.seq_groups[seq_id].seq_data.values())).output_token_ids
                # iterate over output_token_ids_till_now (a tuple) in reverse to find the last special token
                last_special_token_idx, last_special_token = None, None
                for idx in range(len(output_token_ids_till_now)-1, -1, -1):
                    token_id = output_token_ids_till_now[idx]
                    if token_id in special_tokens:
                        last_special_token_idx = idx
                        last_special_token = token_id
                        break
                if last_special_token == 151665:  # already emitted the final <remaining> special token; stop inserting
                    continue
                if last_special_token_idx is not None:
                    distance_to_last_special_token = len(output_token_ids_till_now) - last_special_token_idx - 1
                    if distance_to_last_special_token == 250:
                        # replace the sampled token with the next <remaining> special token
                        output.outputs[seq_id].samples[0].output_token = last_special_token - 1
                        former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
                        output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
                        # drop the original logprob key unless it already equals the injected token id
                        if former_key != last_special_token - 1:
                            del output.outputs[seq_id].samples[0].logprobs[former_key]
                else:  # there has not been any special token in the output
                    last_special_token = None
                    for prompt_token_id in prompt_token_ids:
                        if prompt_token_id in special_tokens:
                            last_special_token = prompt_token_id
                            break
                    if last_special_token is not None:
                        if len(output_token_ids_till_now) == 250:
                            # replace the sampled token with the next <remaining> special token
                            output.outputs[seq_id].samples[0].output_token = last_special_token - 1
                            former_key = list(output.outputs[seq_id].samples[0].logprobs.keys())[0]
                            output.outputs[seq_id].samples[0].logprobs[last_special_token - 1] = list(output.outputs[seq_id].samples[0].logprobs.values())[0]
                            # drop the original logprob key unless it already equals the injected token id
                            if former_key != last_special_token - 1:
                                del output.outputs[seq_id].samples[0].logprobs[former_key]
        
        else:
            pass
        #! >>>>>>>>>>> add remaining tokens to output <<<<<<<<<<<<


        if self.return_hidden_states:
            # we only need to pass hidden states of most recent token
            assert model_input.sampling_metadata is not None
            indices = model_input.sampling_metadata.selected_token_indices
            if model_input.is_prompt:
                hidden_states = hidden_or_intermediate_states.index_select(
                    0, indices)
            elif decode_meta.use_cuda_graph:
                hidden_states = hidden_or_intermediate_states[:len(indices)]
            else:
                hidden_states = hidden_or_intermediate_states

            output.hidden_states = hidden_states

        return [output]
```
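The id arithmetic in the patch assumes the `<remaining>` markers occupy a contiguous id block starting at 151665, where the smallest id is the final marker (`<remaining>1/8</remaining>`) and subtracting 1 from a marker's id yields the next marker toward the budget. A small sketch of that assumed layout for the 8-part mode:

```python
BASE_ID = 151665  # first special-token id assumed by the patch above

def remaining_marker(i: int, parts: int = 8) -> str:
    # token string for the <remaining i/parts> marker, in the prompt's format
    return f"\n<remaining>{i}/{parts}</remaining>\n"

# id -> marker: 151665 is <remaining>1/8</remaining> (the last marker emitted),
# 151671 is <remaining>7/8</remaining> (the first); id - 1 steps to the next one.
id_to_marker = {BASE_ID + k: remaining_marker(k + 1) for k in range(7)}
```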


## Preparation ๐Ÿ“–

### Model Preparation ๐Ÿ› ๏ธ

```bash
cd ./Preparation
```

Modify the `ori_model_path` and `new_model_path` variables in `Preparation/add_special_tokens.py` to embed special tokens into the new model.

```python
    ori_model_path = '/path/to/your/ori/model'
    new_model_path = '/path/to/your/new/model'
```

### Data Preparation ๐Ÿ“ฅ

Our training data can be downloaded from the following link:

[Dataset-BudgetThinker](https://huggingface.co/datasets/Xin-Rui/Dataset-BudgetThinker/tree/main)

After downloading the SFT-Data, register it in the `dataset_info.json` file of LLaMA-Factory with the registration name `8ratio_SFT_below10000`.

#### Data Format

**NOTICE!** โš ๏ธ

The data format must remain identical across the SFT and RL stages.

The data must strictly follow the example below (in particular, the prompt format in `prompt` must match exactly):
```json
"prompt":"Return your final response within \\boxed{}. 
xxxxxx
\n(Complete thinking within 1600 tokens or fewer, 7 special tokens ( \n<remaining>7/8</remaining>\n , \n<remaining>6/8</remaining>\n , \n<remaining>5/8</remaining>\n , \n<remaining>4/8</remaining>\n , \n<remaining>3/8</remaining>\n , \n<remaining>2/8</remaining>\n , \n<remaining>1/8</remaining>\n ) will split the thinking process into 8 parts.)"

"answer":"<think>
xxxxx
</think>\n**Final Answer**\\boxed{}"
```

The data format is the same as the one used in the paper. For more details, please refer to the paper.
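If you generate prompts programmatically, the budget instruction appended after each question can be rebuilt roughly as follows (the helper name and exact whitespace are assumptions; compare against the released dataset before use):

```python
def budget_instruction(budget: int, parts: int = 8) -> str:
    # rebuild the "(Complete thinking within ... tokens ...)" suffix shown above
    markers = " , ".join(
        f"\n<remaining>{k}/{parts}</remaining>\n" for k in range(parts - 1, 0, -1)
    )
    return (
        f"\n(Complete thinking within {budget} tokens or fewer, "
        f"{parts - 1} special tokens ( {markers} ) "
        f"will split the thinking process into {parts} parts.)"
    )
```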

## Training ๐Ÿ‹๏ธโ€โ™‚๏ธ

### SFT Training

```bash
cd ./LLaMA-Factory
```

Use DeepSpeed to accelerate training.
For the detailed script, refer to `LLaMA-Factory/examples/deepseed_train.sh`.

### RL Training

```bash
cd ./easyr1
```

After configuring the `model_path` parameter in the `easyr1/examples/8ratio_v1.sh` and `easyr1/examples/8ratio_v1.yaml` files, you can run the following command:

```bash
bash examples/8ratio_v1.sh
```

#### Parameter Introduction

The script involves three environment variables: `stage`, `steady`, and `remaining`.
- `stage`: 1 or 2, selecting one-stage or two-stage inference during training.

    Stage 1 outputs the chain of thought normally.

    Stage 2 interrupts generation once the chain of thought reaches the budget, inserts `</think>\n**Final Answer**` at that position as the ending prompt, and then continues generating.

- `steady`: the name of the current training run. For example, with `8ratio_v1`, it is best to modify all occurrences of this string in both the `.sh` and `.yaml` files; it determines where checkpoints and logs are written, as well as the budget settings for the run. For more details, refer to `easyr1/verl/utils/dataset.py`.

- `remaining`: the vllm inference mode. `8ratio` uses the default method (splitting the chain of thought into 8 parts); `default` makes vllm perform normal inference without inserting any special tokens.
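For the ratio mode, the patched `model_runner.py` computes `delta = budget // N + 1` and injects a `<remaining>` marker every `delta` generated tokens; a quick sanity check of that schedule (the helper name is illustrative):

```python
def marker_positions(budget: int, n_parts: int = 8):
    # token positions at which <remaining> markers are injected,
    # mirroring delta = budget // N + 1 from the vllm patch
    delta = budget // n_parts + 1
    return [delta * k for k in range(1, n_parts)]
```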

## Evaluation ๐Ÿ“Š

First, modify the `MODEL_NAME_OR_PATH` parameter in the `evaluation/remaining_eval/Eval.sh` script, and then run the following command:

```bash
cd ./evaluation

bash remaining_eval/Eval.sh
```

### Parameter Introduction

The following parameters/environment variables need to be set in the script:

- `remaining`/`stage`: same as described above.

- `tip`: the prompt template placed before the question. When using the `8ratio` inference mode, `tip` must also be set to `8ratio`. Alternatively, `tip` can be set to `prompt_v1` or `prompt_v2`, two different natural-language prompts.

- `MODEL_NAME_OR_PATH`: the path to the model. Use a recognizable model name as the second-to-last folder in the path; the code reads this name as the current evaluation model and stores logs in a folder of that name. For example: `/path1/path2/Model_Name/models`
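Given the example path above, the model name the evaluation code picks up is simply the second-to-last path component; for instance (hypothetical helper):

```python
from pathlib import Path

def model_name_from_path(model_path: str) -> str:
    # second-to-last folder, e.g. /path1/path2/Model_Name/models -> Model_Name
    return Path(model_path).parts[-2]
```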