zz1358m committed
Commit a54c8b6 · verified · 1 parent: a58dbfa

Upload 2 files

Files changed (2):
  1. README.md +135 -1
  2. requirements.txt +206 -0
README.md CHANGED
@@ -1,3 +1,137 @@
  ---
- license: mit
  ---
+ <!-- <p align="center" width="100%">
+ <img src="./docs/static/images/logo_resize.png" width="80%">
+ </p> -->
+
+ <div align="center">
+ <h1 align="center"> SofT-GRPO: Reinforcing the LLM Soft-Thinking Policy with Gumbel Reparameterization
+ </h1>
+ </div>
+
+ <p align="center">
+ <img src="assets/mainprocess.png">
+ </p>
+
+
+ - **Authors**: [Zhi Zheng](https://zz1358m.github.io/zhizheng.github.io/), [Wee Sun Lee](https://scholar.google.com/citations?user=8PCrLgwAAAAJ&hl=en)
+ - **Institutes**: School of Computing, National University of Singapore, Singapore
+ - **Resources**: [📖[Paper]()] [[🏠Twitter]()] [[🤗Huggingface](https://huggingface.co/zz1358m/SofT-GRPO-master)]
+
+
+ ## 📧 Feedback Welcome
+
+ We greatly appreciate feedback and questions about the current status of this work.
+
+ Please feel free to contact Zhi Zheng at [zhi.zheng@u.nus.edu](mailto:zhi.zheng@u.nus.edu).
+
+
+ ## 💡 Highlights
+
+ - 🔥 **The First Powerful RLVR Algorithm for Soft-Thinking Reasoning:** We introduce **SofT-GRPO**, a novel policy optimization algorithm designed to reinforce the soft-thinking reasoning paradigm in LLMs.
+
+ - ⚙️ **Gumbel-Softmax Noise in Rollout:** SofT-GRPO injects Gumbel-Softmax noise into the group rollout process, actively generating diverse yet valid soft-thinking reasoning paths.
+
+ - ⚙️ **Gumbel Reparameterization:** We propose a gradient estimation approach based on Gumbel reparameterization, enabling precise attribution of improvements to the LLM's output probability distributions during policy optimization.
+
+ - 🔥 **Comprehensive Experiments and High Effectiveness:** Across LLMs with 1.5B–7B parameters on five benchmarks, SofT-GRPO consistently outperforms discrete-token GRPO baselines, especially at higher sample counts (Pass@16 and Pass@32).
+
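The Gumbel-Softmax noise in the rollout can be illustrated with a minimal NumPy sketch of the standard Gumbel-Softmax trick (an illustration only, not the repository's implementation; the temperature `tau` and toy logits are made-up values):

```python
import numpy as np

def gumbel_softmax(logits, tau=0.5, rng=None):
    """Sample a relaxed (soft) token distribution via the Gumbel-Softmax trick.

    Adding i.i.d. Gumbel(0, 1) noise to the logits and applying a
    temperature-scaled softmax yields a differentiable sample that
    approaches a one-hot draw from softmax(logits) as tau -> 0.
    """
    rng = rng or np.random.default_rng()
    u = rng.uniform(1e-10, 1.0, size=logits.shape)
    g = -np.log(-np.log(u))          # Gumbel(0, 1) noise
    z = (logits + g) / tau
    z = z - z.max()                  # stabilize the softmax
    e = np.exp(z)
    return e / e.sum()

logits = np.array([2.0, 1.0, 0.1])   # toy next-token logits
p = gumbel_softmax(logits, tau=0.5)
# p lies on the probability simplex: non-negative entries summing to 1.
# Different Gumbel draws give different but valid soft reasoning steps.
```

Because the noise enters through a deterministic, differentiable function of the logits, gradients can be propagated back to the model's output distribution, which is what the Gumbel reparameterization in the next bullet exploits.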
+ ## 📜 News
+
+ **[2025/9/24]** [Code](), [Weights](), and [Paper](https://arxiv.org/pdf/2509.20317) are released!
+
+ ## 👨‍💻 Todo
+
+ - [x] SGLang & verl code modifications (e.g., activate overlap for efficiency).
+
+ ## 🛠️ Usage
+
+ ### 1. Clone the repository
+ ```bash
+ git clone https://github.com/zz1358m/SofT-GRPO-master
+ cd SofT-GRPO-master
+ ```
+
+ ### 2. Install dependencies
+ ##### Option 1: For inference only
+ ```bash
+ conda create -n st python=3.11 -y && conda activate st
+ pip install --upgrade pip
+ pip install torch transformers accelerate jsonlines math_verify openai torch_memory_saver
+ pip install flash_attn --no-build-isolation # may take ~20 min; try `pip install flash_attn==2.7.3 --no-build-isolation` if you hit an undefined-symbol error
+
+ cd Soft-Thinking+noise+loss-main/sglang_soft_thinking_pkg
+ pip install -e "python[all]"
+ cd ../..
+ ```
+
+ ##### Option 2: For inference & SofT-GRPO fine-tuning
+ ```bash
+ pip install -r requirements.txt
+ ```
+ or build verl-0.4.x after completing Option 1:
+ ```bash
+ cd verl-0.4.x
+ pip3 install --no-deps -e .
+ cd ..
+ ```
+
+
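After either option, a quick sanity check can confirm the key packages are importable (a generic snippet; the package list is an assumption based on the install commands above):

```python
import importlib.util

# Packages the install steps above are expected to provide.
packages = ["torch", "transformers", "accelerate", "sglang", "verl"]

missing = [p for p in packages if importlib.util.find_spec(p) is None]
if missing:
    print("Missing packages:", ", ".join(missing))
else:
    print("All key packages found.")
```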
  ---
+
+ ### 3. Evaluating SofT-GRPO fine-tuned LLMs with the soft-thinking pattern
+
+ #### Step 1: Download the SofT-GRPO and GRPO weights from [[🤗Huggingface](https://huggingface.co/zz1358m/SofT-GRPO-master)]
+
+ #### Step 2: Evaluate GRPO under the discrete-token CoT pattern.
+ ```bash
+ ./Soft-Thinking+noise+loss-main/run_sample_discrete-token_grpo.sh
+ ```
+
+ #### Step 3: Evaluate GRPO under the soft-thinking reasoning pattern.
+ ```bash
+ ./Soft-Thinking+noise+loss-main/run_sample_gumbel_grpo.sh
+ ```
+
+ #### Step 4: Evaluate SofT-GRPO under the soft-thinking reasoning pattern.
+ ```bash
+ ./Soft-Thinking+noise+loss-main/run_sample_gumbel.sh
+ ```
+
+
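The difference between the two evaluation patterns lies in how the next reasoning-step input is formed: discrete-token CoT commits to a single sampled token, while soft thinking feeds the model a probability-weighted mixture of token embeddings. A minimal NumPy sketch of that mixing step (illustrative only; the embedding table `E` and distribution `p` are toy values):

```python
import numpy as np

rng = np.random.default_rng(0)
V, d = 5, 4                      # toy vocabulary size and embedding dim
E = rng.normal(size=(V, d))      # token embedding table

p = np.array([0.6, 0.2, 0.1, 0.07, 0.03])   # next-token distribution

# Discrete-token CoT: commit to one token's embedding.
hard_input = E[p.argmax()]

# Soft thinking: feed the expectation of the embedding under p instead.
soft_input = p @ E               # sum_i p[i] * E[i]

assert hard_input.shape == soft_input.shape == (d,)
```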
  ---
+
+ ### 4. Training with SofT-GRPO
+
+ #### Option 1: Train SofT-GRPO on DeepSeek-R1-Distill-Qwen-1.5B
+ ```bash
+ ./SofT-GRPO-deepscaler-8k.sh # change the LLM path and dataset path accordingly
+ ```
+
+ #### Option 2: Train SofT-GRPO on DeepSeek-R1-Distill-Qwen-7B
+ ```bash
+ ./SofT-GRPO-deepscaler-8k-qwen7.sh # change the LLM path and dataset path accordingly
+ ```
+
+ #### Option 3: Train SofT-GRPO on Llama-3.2-3B-Instruct
+ ```bash
+ ./SofT-GRPO-deepscaler-8k-llama3.sh # change the LLM path and dataset path accordingly
+ ```
+
+
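Like standard GRPO, SofT-GRPO scores each rollout relative to its group. A sketch of the usual group-relative advantage (this is the standard GRPO formulation, not necessarily the exact computation in the training scripts above):

```python
import numpy as np

def group_relative_advantage(rewards, eps=1e-6):
    """Standard GRPO advantage: normalize each rollout's reward by the
    mean and standard deviation of its own rollout group."""
    r = np.asarray(rewards, dtype=float)
    return (r - r.mean()) / (r.std() + eps)

# One prompt, a group of 4 rollouts with verifiable 0/1 rewards (toy values):
adv = group_relative_advantage([1.0, 0.0, 1.0, 0.0])
# Correct rollouts receive positive advantage, incorrect ones negative.
```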
+
+ ## ✒️ Citation
+
+ If you find our work helpful for your research, please consider giving a star ⭐ and a citation 📝
+
+ ```bibtex
+ ```
+
+ ## ❤️ Acknowledgments
+
+ - [Soft-Thinking](https://github.com/eric-ai-lab/Soft-Thinking): The codebase we built upon. Thanks for their wonderful work!
+ - [verl-0.4.x](https://github.com/volcengine/verl/tree/v0.4.x): Our work is also based on this codebase.
+ - [SIM-CoT](https://github.com/InternLM/SIM-CoT): We use their template for this README!
requirements.txt ADDED
@@ -0,0 +1,206 @@
+ # Editable install with no version control (sglang==0.4.6.post1)
+ -e ./Soft-Thinking+noise+loss-main/sglang_soft_thinking_pkg/python
+ # Editable install with no version control (verl==0.4.0)
+ -e ./verl-0.4.x
+ absl-py==2.3.1
+ accelerate==1.10.1
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.12.15
+ aiosignal==1.4.0
+ airportsdata==20250909
+ annotated-types==0.7.0
+ anthropic==0.68.1
+ antlr4-python3-runtime==4.9.3
+ anyio==4.11.0
+ asttokens==3.0.0
+ attrs==25.3.0
+ blobfile==3.0.0
+ build==1.3.0
+ cachetools==6.2.0
+ certifi==2025.8.3
+ cffi==2.0.0
+ charset-normalizer==3.4.3
+ click==8.3.0
+ cloudpickle==3.1.1
+ codetiming==1.4.0
+ compressed-tensors==0.11.0
+ contourpy==1.3.3
+ cuda-bindings==13.0.1
+ cuda-pathfinder==1.2.3
+ cuda-python==13.0.1
+ cycler==0.12.1
+ datasets==4.1.1
+ decorator==5.2.1
+ decord==0.6.0
+ dill==0.4.0
+ diskcache==5.6.3
+ distro==1.9.0
+ docstring_parser==0.17.0
+ einops==0.8.1
+ executing==2.2.1
+ expecttest==0.3.0
+ fastapi==0.117.1
+ fastuuid==0.13.5
+ filelock==3.19.1
+ flash-attn==2.7.3
+ flashinfer-python==0.2.3
+ fonttools==4.60.1
+ frozendict==2.4.6
+ frozenlist==1.7.0
+ fsspec==2025.9.0
+ gitdb==4.0.12
+ GitPython==3.1.45
+ grpcio==1.75.1
+ h11==0.16.0
+ hf-xet==1.1.10
+ hf_transfer==0.1.9
+ httpcore==1.0.9
+ httpx==0.28.1
+ huggingface-hub==0.35.1
+ hydra-core==1.3.2
+ idna==3.10
+ importlib_metadata==8.7.0
+ iniconfig==2.1.0
+ interegular==0.3.3
+ ipython==9.5.0
+ ipython_pygments_lexers==1.1.1
+ jedi==0.19.2
+ Jinja2==3.1.6
+ jiter==0.11.0
+ joblib==1.5.2
+ jsonlines==4.0.0
+ jsonschema==4.25.1
+ jsonschema-specifications==2025.9.1
+ kiwisolver==1.4.9
+ lark==1.3.0
+ latex2sympy2_extended==1.10.2
+ litellm==1.77.5
+ llguidance==0.7.30
+ lxml==6.0.2
+ Markdown==3.9
+ MarkupSafe==3.0.3
+ math-verify==0.8.0
+ matplotlib==3.10.6
+ matplotlib-inline==0.1.7
+ modelscope==1.30.0
+ mpmath==1.3.0
+ msgpack==1.1.1
+ msgspec==0.19.0
+ multidict==6.6.4
+ multiprocess==0.70.16
+ nanobind==2.9.2
+ nest-asyncio==1.6.0
+ networkx==3.5
+ ninja==1.13.0
+ numpy==2.3.3
+ nvidia-cublas-cu12==12.4.5.8
+ nvidia-cuda-cupti-cu12==12.4.127
+ nvidia-cuda-nvrtc-cu12==12.4.127
+ nvidia-cuda-runtime-cu12==12.4.127
+ nvidia-cudnn-cu12==9.1.0.70
+ nvidia-cudnn-frontend==1.14.1
+ nvidia-cufft-cu12==11.2.1.3
+ nvidia-cufile-cu12==1.13.1.3
+ nvidia-curand-cu12==10.3.5.147
+ nvidia-cusolver-cu12==11.6.1.9
+ nvidia-cusparse-cu12==12.3.1.170
+ nvidia-cusparselt-cu12==0.6.2
+ nvidia-cutlass-dsl==4.2.1
+ nvidia-ml-py==13.580.82
+ nvidia-nccl-cu12==2.21.5
+ nvidia-nvjitlink-cu12==12.4.127
+ nvidia-nvtx-cu12==12.4.127
+ omegaconf==2.3.0
+ openai==1.109.1
+ openai-harmony==0.0.4
+ orjson==3.11.3
+ outlines==0.1.11
+ outlines_core==0.1.26
+ packaging==25.0
+ pandas==2.3.2
+ parso==0.8.5
+ partial-json-parser==0.2.1.1.post6
+ peft==0.17.1
+ pexpect==4.9.0
+ pillow==11.3.0
+ platformdirs==4.4.0
+ pluggy==1.6.0
+ prometheus_client==0.23.1
+ prompt_toolkit==3.0.52
+ propcache==0.3.2
+ protobuf==6.32.1
+ psutil==7.1.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pyarrow==21.0.0
+ pybase64==1.4.2
+ pybind11==3.0.1
+ pycountry==24.6.1
+ pycparser==2.23
+ pycryptodomex==3.23.0
+ pydantic==2.11.9
+ pydantic_core==2.33.2
+ Pygments==2.19.2
+ pylatexenc==2.10
+ pynvml==13.0.1
+ pyparsing==3.2.5
+ pyproject_hooks==1.2.0
+ pytest==8.4.2
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.1
+ python-multipart==0.0.20
+ pytz==2025.2
+ pyvers==0.1.0
+ PyYAML==6.0.3
+ pyzmq==27.1.0
+ ray==2.49.2
+ referencing==0.36.2
+ regex==2025.9.18
+ requests==2.32.5
+ rpds-py==0.27.1
+ safetensors==0.6.2
+ scikit-learn==1.7.2
+ scipy==1.16.2
+ sentence-transformers==5.1.1
+ sentencepiece==0.2.1
+ sentry-sdk==2.39.0
+ setproctitle==1.3.7
+ sgl-kernel==0.1.1
+ six==1.17.0
+ smmap==5.0.2
+ sniffio==1.3.1
+ soundfile==0.13.1
+ stack-data==0.6.3
+ starlette==0.48.0
+ sympy==1.13.1
+ tabulate==0.9.0
+ tensorboard==2.20.0
+ tensorboard-data-server==0.7.2
+ tensordict==0.10.0
+ threadpoolctl==3.6.0
+ tiktoken==0.11.0
+ timm==1.0.16
+ tokenizers==0.21.4
+ torch==2.6.0
+ torch_memory_saver==0.0.8
+ torchao==0.9.0
+ torchaudio==2.8.0
+ torchdata==0.11.0
+ torchvision==0.21.0
+ tqdm==4.67.1
+ traitlets==5.14.3
+ transformers==4.51.1
+ triton==3.2.0
+ typing-inspection==0.4.1
+ typing_extensions==4.15.0
+ tzdata==2025.2
+ urllib3==2.5.0
+ uvicorn==0.37.0
+ uvloop==0.21.0
+ wandb==0.22.0
+ wcwidth==0.2.14
+ Werkzeug==3.1.3
+ xgrammar==0.1.17
+ xxhash==3.5.0
+ yarl==1.20.1
+ zipp==3.23.0