Upload folder using huggingface_hub
- README.md +276 -3
- client.py +203 -0
- inference_hf.py +383 -0
- serve.py +95 -0
- setup.sh +125 -0
- setup_model_dir.py +128 -0
- start_server.sh +48 -0
- terminator.pt +3 -0
- vllm_terminator/__init__.py +19 -0
- vllm_terminator/model.py +553 -0
- vllm_terminator/terminator_head.py +135 -0
README.md
CHANGED
@@ -1,3 +1,276 @@
---
language:
- en
license: apache-2.0
library_name: vllm
tags:
- reasoning
- chain-of-thought
- efficiency
- inference-optimization
- qwen3
base_model: Qwen/Qwen3-14B
base_model_relation: finetune
pipeline_tag: text-generation
---

# Terminator-Qwen3-14B

**Terminator** is a lightweight neural module that predicts when a reasoning language model has reached its final answer during chain-of-thought (CoT) generation. When the Terminator detects the model has committed to an answer, it truncates the remaining reasoning and forces the model to begin its response, thereby delivering the same answer with significantly less computation.

This repository contains everything needed to run **Terminator-Qwen3-14B**:

- Trained Terminator checkpoint (1 extra transformer layer + prediction head)
- vLLM plugin code (`vllm_terminator/`) for high-performance serving
- Server launcher and streaming client
- Standalone HuggingFace inference script (no server required)
- Automated setup script

**Note**: Terminator currently supports **single-GPU, single-sequence inference only**.

---

## Quick Start

```bash
# 1. Clone the repository (requires Git LFS: https://git-lfs.com)
git lfs install
git clone https://huggingface.co/acnagle/Terminator-Qwen3-14B
cd Terminator-Qwen3-14B

# 2. Run automated setup (creates conda env, installs vllm, downloads base model)
./setup.sh

# 3. Start the server
./start_server.sh

# 4. In another terminal, chat with the model
python client.py --interactive
```

---

## Requirements

- **GPU**: Single NVIDIA GPU with at least ~40GB VRAM (e.g., A100 40GB)
- **CUDA**: A compatible CUDA driver (12.9 or newer recommended)
- **Python**: 3.12
- **OS**: Linux (recommended) or any OS supported by vLLM

---

## Installation

### Option A: Automated Setup

The `setup.sh` script handles everything:

```bash
./setup.sh
```

This will:
1. Create a conda environment called `terminator` with Python 3.12
2. Install [uv](https://docs.astral.sh/uv/), [vLLM](https://docs.vllm.ai/), and [openai](https://pypi.org/project/openai/)
3. Download Qwen3-14B base model weights (~28GB) from HuggingFace
4. Create the model directory (`model_dir/`)

### Option B: Manual Setup

**1. Create a Python environment**

Using conda or micromamba:

```bash
conda create -n terminator python=3.12 -y
conda activate terminator
```

**2. Install uv**

```bash
pip install --upgrade uv
```

Or see the [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/).

**3. Install vLLM**

```bash
uv pip install vllm --torch-backend=auto
```

See the [vLLM installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/) for alternative installation methods (ROCm, CPU, etc.).

**4. Install openai (for the client)**

```bash
uv pip install openai
```

**5. Set up the model directory**

This downloads the base Qwen3-14B weights and creates a vLLM-ready model directory:

```bash
python setup_model_dir.py
```

The script accepts optional arguments:

| Argument | Default | Description |
|----------|---------|-------------|
| `--checkpoint` | `./terminator.pt` | Path to the Terminator checkpoint |
| `--output-dir` | `./model_dir` | Output model directory |
| `--threshold` | `0.7` | Prediction threshold for Terminator activation |
| `--window-size` | `10` | Sliding window size for majority vote |
| `--exit-message` | *(built-in message)* | Message injected when Terminator fires |

---

## Starting the Server

```bash
./start_server.sh
```

Or with custom configuration:

```bash
VLLM_GPU_UTIL=0.70 VLLM_MAX_MODEL_LEN=8192 ./start_server.sh
```

The server exposes an **OpenAI-compatible API** on the configured port (default: 8000).

### Configuration

Set these environment variables before running `start_server.sh` or `serve.py`:

| Variable | Default | Description |
|----------|---------|-------------|
| `VLLM_GPU_UTIL` | `0.90` | Fraction of GPU memory to use for the model |
| `VLLM_MAX_MODEL_LEN` | *(auto)* | Maximum context length in tokens |
| `VLLM_PORT` | `8000` | Server port |
| `VLLM_ENFORCE_EAGER` | `0` | Set to `1` to disable CUDA graphs |
| `VLLM_API_KEY` | *(none)* | Require this API key from clients |
| `VLLM_SERVED_NAME` | `Terminator-Qwen3-14B` | Model name reported by the API |

---

## Standalone Inference (No Server)

**Recommendation:** For the best performance, use the vLLM server described above. vLLM uses KV caching, CUDA graphs, and optimized kernels, making it **significantly faster** than HuggingFace-native inference. For quick tests and demos where spinning up a server is inconvenient, use the HuggingFace-native inference script:

```bash
python inference_hf.py --prompt "What is the sum of the first 100 natural numbers?"
```

This loads the model directly via HuggingFace `transformers` and runs token-by-token generation with the Terminator head. Thinking content is streamed in dimmed text; the final answer is shown in bold.

| Argument | Default | Description |
|----------|---------|-------------|
| `--prompt` | *(required)* | Input prompt |
| `--model` | `Qwen/Qwen3-14B` | HuggingFace model name or path |
| `--checkpoint` | `./terminator.pt` | Path to the Terminator checkpoint |
| `--threshold` | `0.7` | Prediction threshold |
| `--window-size` | `10` | Sliding window size for majority vote |
| `--exit-message` | *(built-in message)* | Message injected when Terminator fires (empty string to disable) |
| `--max-tokens` | `32768` | Maximum tokens to generate |
| `--temperature` | `0.6` | Sampling temperature |

---

## Using the Client (vLLM Server)

### Single Prompt

```bash
python client.py --prompt "What is the sum of the first 100 natural numbers?"
```

### Interactive Mode

```bash
python client.py --interactive
```

This starts a multi-turn conversation with the model. Thinking content is displayed in dimmed text; the final answer is shown in bold.

### Client Options

| Argument | Default | Description |
|----------|---------|-------------|
| `--base-url` | `http://localhost:8000/v1` | Server URL |
| `--max-tokens` | *(server default)* | Maximum tokens to generate |
| `--temperature` | `0.6` | Sampling temperature |

### Using the API Directly

The server is OpenAI-compatible, so you can use any OpenAI client library. Replace `localhost` with your server's address if connecting remotely:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.chat.completions.create(
    model="Terminator-Qwen3-14B",
    messages=[{"role": "user", "content": "What is 25 * 37?"}],
    temperature=0.6,
    extra_body={"chat_template_kwargs": {"enable_thinking": True}},
)

# Thinking content (chain-of-thought)
print(response.choices[0].message.reasoning_content)

# Final answer
print(response.choices[0].message.content)
```

---

## How Terminator Works

Terminator is a single transformer layer followed by a prediction head, trained on top of a frozen Qwen3-14B base model. The transformer layer (initialized as a copy of the base model's final layer, then fine-tuned) takes the hidden states from the LLM and processes them before the prediction head, which outputs a per-token binary prediction: *has the model reached its final answer?*

During generation, Terminator maintains a **sliding window** of the most recent predictions. When a majority of predictions in the window exceed the threshold (default: 0.7), the model is considered to have reached its final answer. At that point:

1. A short **exit message** is injected into the reasoning (e.g., *"I've run out of thinking tokens. I need to commit to a final answer."*) to help the model transition smoothly.
2. The `</think>` token is forced, ending the reasoning phase.
3. The model generates its final answer normally.

This allows the model to skip potentially thousands of redundant reasoning tokens while preserving answer quality.
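The trigger logic above can be sketched in a few lines of Python. This is an illustrative sketch only; the function names and toy token IDs are hypothetical, not the shipped implementation:

```python
from collections import deque

def should_terminate(probs, threshold=0.7, window_size=10):
    """Sliding-window majority vote over per-token probabilities:
    fire once more than half of the last `window_size` predictions
    exceed `threshold`."""
    window = deque(maxlen=window_size)
    for p in probs:
        window.append(p)
        if len(window) == window_size:
            votes = sum(1 for q in window if q > threshold)
            if votes / window_size > 0.5:
                return True
    return False

def inject_exit(token_ids, exit_message_ids, think_end_id):
    """When the vote fires, append the exit message and force </think>
    so generation continues in answering mode."""
    return token_ids + exit_message_ids + [think_end_id]

# Confident predictions at the end of the stream trigger the exit.
probs = [0.2] * 30 + [0.9] * 10
assert should_terminate(probs)
print(inject_exit([11, 12, 13], [7, 8], 99))  # [11, 12, 13, 7, 8, 99]
```

In the served model, the probabilities come from the sigmoid output of the Terminator head, and the injected tokens come from the configured `--exit-message` string.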

---

## File Structure

```
Terminator-Qwen3-14B/
├── README.md                  This file
├── terminator.pt              Trained Terminator checkpoint
├── vllm_terminator/           vLLM plugin package
│   ├── __init__.py            Registers the model architecture with vLLM
│   ├── model.py               Qwen3TerminatorForCausalLM model class
│   └── terminator_head.py     FFN classifier and checkpoint loading
├── inference_hf.py            Standalone HuggingFace inference (no server)
├── serve.py                   vLLM server launcher
├── setup_model_dir.py         Model directory setup (downloads base weights)
├── client.py                  Streaming chat client (connects to vLLM server)
├── setup.sh                   Automated setup script
└── start_server.sh            Server launcher with sensible defaults
```

---

## Citation

*Coming soon.*

---

## License

This project builds on [Qwen3-14B](https://huggingface.co/Qwen/Qwen3-14B) by the Qwen team. Please refer to the Qwen3 license for base model usage terms.
client.py
ADDED
@@ -0,0 +1,203 @@
#!/usr/bin/env python3
"""
Client for the Terminator vLLM server.

Supports single-prompt and multi-turn conversation modes with streaming
output. Thinking content is displayed in dimmed text; answer content in
normal text.

Usage:
    # Single prompt
    python client.py --prompt "What is the sum of the first 100 natural numbers?"

    # Interactive multi-turn conversation
    python client.py --interactive

    # Custom server URL and max tokens
    python client.py --base-url http://localhost:8001/v1 --max-tokens 8192 --prompt "Hello"
"""

import argparse
import sys

from openai import OpenAI


# ANSI escape codes
DIM = "\033[2m"
BOLD = "\033[1m"
RESET = "\033[0m"

BANNER_LINES = [
    r"████████╗███████╗██████╗ ███╗ ███╗██╗███╗ ██╗ █████╗ ████████╗ ██████╗ ██████╗ ",
    r"╚══██╔══╝██╔════╝██╔══██╗████╗ ████║██║████╗ ██║██╔══██╗╚══██╔══╝██╔═══██╗██╔══██╗",
    r" ██║ █████╗ ██████╔╝██╔████╔██║██║██╔██╗ ██║███████║ ██║ ██║ ██║██████╔╝",
    r" ██║ ██╔══╝ ██╔══██╗██║╚██╔╝██║██║██║╚██╗██║██╔══██║ ██║ ██║ ██║██╔══██╗",
    r" ██║ ███████╗██║ ██║██║ ╚═╝ ██║██║██║ ╚████║██║ ██║ ██║ ╚██████╔╝██║ ██║",
    r" ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝",
]

# Dark red -> light red gradient (one color per row)
_GRADIENT_RGB = [
    (140, 0, 0),
    (165, 15, 15),
    (190, 35, 35),
    (215, 55, 55),
    (235, 70, 70),
    (255, 90, 90),
]


def print_banner() -> None:
    for line, (r, g, b) in zip(BANNER_LINES, _GRADIENT_RGB):
        print(f"\033[38;2;{r};{g};{b}m{line}{RESET}")


def detect_model(client: OpenAI) -> str:
    """Auto-detect the served model name from the server."""
    try:
        models = client.models.list()
        if not models.data:
            print("ERROR: No models available on the server.", file=sys.stderr)
            sys.exit(1)
        return models.data[0].id
    except Exception as e:
        print(f"ERROR: Could not connect to server: {e}", file=sys.stderr)
        sys.exit(1)


def stream_response(
    client: OpenAI,
    model: str,
    messages: list[dict],
    max_tokens: int | None,
    temperature: float,
) -> str:
    """Stream a chat completion response.

    Thinking content is printed in dim text, answer content in normal text.
    Returns the assistant's answer content (for conversation history).
    """
    kwargs = dict(
        model=model,
        messages=messages,
        temperature=temperature,
        stream=True,
        extra_body={"chat_template_kwargs": {"enable_thinking": True}},
    )
    if max_tokens is not None:
        kwargs["max_tokens"] = max_tokens

    stream = client.chat.completions.create(**kwargs)

    in_thinking = False
    in_answer = False
    full_content = ""

    try:
        for chunk in stream:
            if not chunk.choices:
                continue
            delta = chunk.choices[0].delta

            reasoning = getattr(delta, "reasoning_content", None)
            if reasoning:
                if not in_thinking:
                    sys.stdout.write(f"\n{DIM}Thinking...\n")
                    in_thinking = True
                sys.stdout.write(reasoning)
                sys.stdout.flush()

            if delta.content:
                if not in_answer:
                    if in_thinking:
                        sys.stdout.write(RESET)
                    sys.stdout.write(f"\n{BOLD}Answer:{RESET}\n")
                    in_answer = True
                sys.stdout.write(delta.content)
                sys.stdout.flush()
                full_content += delta.content
    except KeyboardInterrupt:
        pass
    finally:
        sys.stdout.write(RESET)
        sys.stdout.flush()

    print()
    return full_content


def run_single(client, model, prompt, max_tokens, temperature):
    """Run a single prompt and exit."""
    messages = [{"role": "user", "content": prompt}]
    stream_response(client, model, messages, max_tokens, temperature)


def run_interactive(client, model, max_tokens, temperature):
    """Interactive multi-turn conversation loop."""
    messages = []
    print()
    print_banner()
    print()
    print(f"  Connected to {BOLD}{model}{RESET}")
    print(f"  Type your message and press Enter. Type {BOLD}quit{RESET} or Ctrl+C to exit.")
    print(f"  {DIM}Note: input is single-line only — compose your full message before pressing Enter.{RESET}")
    print(f"  {DIM}      Please ensure that copied text is formatted as a single line before pasting.{RESET}")
    print()

    while True:
        try:
            user_input = input(f"{BOLD}>>>{RESET} ")
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye!")
            break

        if user_input.strip().lower() in ("quit", "exit", "q"):
            print("Goodbye!")
            break

        if not user_input.strip():
            continue

        messages.append({"role": "user", "content": user_input})
        content = stream_response(client, model, messages, max_tokens, temperature)
        messages.append({"role": "assistant", "content": content})
        print()


def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    mode = parser.add_mutually_exclusive_group(required=True)
    mode.add_argument("--prompt", type=str, help="Single prompt to send")
    mode.add_argument(
        "--interactive", action="store_true",
        help="Start an interactive multi-turn conversation",
    )
    parser.add_argument(
        "--base-url", default="http://localhost:8000/v1",
        help="vLLM server URL (default: http://localhost:8000/v1)",
    )
    parser.add_argument(
        "--max-tokens", type=int, default=None,
        help="Maximum tokens to generate (default: server decides based on context length)",
    )
    parser.add_argument(
        "--temperature", type=float, default=0.6,
        help="Sampling temperature (default: 0.6)",
    )
    args = parser.parse_args()

    client = OpenAI(base_url=args.base_url, api_key="EMPTY")
    model = detect_model(client)

    if args.prompt:
        run_single(client, model, args.prompt, args.max_tokens, args.temperature)
    else:
        run_interactive(client, model, args.max_tokens, args.temperature)


if __name__ == "__main__":
    main()
inference_hf.py
ADDED
@@ -0,0 +1,383 @@
#!/usr/bin/env python3
"""
HuggingFace-native inference for Terminator-Qwen3-14B.

Loads the frozen Qwen3 base model + trained Terminator head (FFN + optional
extra transformer layers) directly via HuggingFace transformers.

Generates chain-of-thought reasoning token-by-token. The Terminator FFN
predicts when the final answer has been reached; when a sliding-window
majority vote exceeds the threshold, an exit message is injected and the
model transitions to answering mode.

Usage:
    python inference_hf.py --prompt "What is the sum of the first 100 natural numbers?"

    python inference_hf.py \\
        --prompt "Solve x^2 - 5x + 6 = 0" \\
        --model Qwen/Qwen3-14B \\
        --checkpoint terminator.pt \\
        --threshold 0.7 --window-size 10
"""

import argparse
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
from transformers.generation.logits_process import LogitsProcessorList

# ---------------------------------------------------------------------------
# Imports from the project
# ---------------------------------------------------------------------------

# Local: TerminatorFFN + checkpoint loader
_script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(_script_dir))
from vllm_terminator.terminator_head import load_terminator_checkpoint

# Parent dir: ExtraTransformerLayers from terminator_utils
_repo_root = _script_dir.parent
sys.path.insert(0, str(_repo_root))
from terminator_utils import ExtraTransformerLayers

# ---------------------------------------------------------------------------
# ANSI escape codes
# ---------------------------------------------------------------------------
DIM = "\033[2m"
BOLD = "\033[1m"
RESET = "\033[0m"


def load_model_and_tokenizer(model_name, device):
    """Load base Qwen3 model and tokenizer."""
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    think_token_id = tokenizer.convert_tokens_to_ids("<think>")
    think_end_token_id = tokenizer.convert_tokens_to_ids("</think>")
    if think_token_id == tokenizer.unk_token_id or think_end_token_id == tokenizer.unk_token_id:
        raise ValueError(
            f"<think>/</think> tokens not in tokenizer! "
            f"IDs: {think_token_id}, {think_end_token_id}"
        )

    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map={"": device},
        trust_remote_code=True,
    )
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    print(
        f"Model loaded: {model.config.num_hidden_layers} layers, "
        f"hidden size {model.config.hidden_size}"
    )
    return model, tokenizer, think_token_id, think_end_token_id


def build_extra_layers(base_model, checkpoint_config, extra_layers_state_dict, device):
    """Reconstruct extra transformer layers from checkpoint state dict."""
    num_extra_layers = checkpoint_config.get("num_extra_layers", 0)
    if num_extra_layers == 0 or extra_layers_state_dict is None:
        return None

    print(f"Reconstructing {num_extra_layers} extra transformer layer(s)...")
    base_layer_class = base_model.model.layers[0].__class__
    model_config = base_model.config
    rotary_emb = getattr(base_model.model, "rotary_emb", None)

    extra_layers = ExtraTransformerLayers(
        base_layer_class, num_extra_layers, model_config, rotary_emb=rotary_emb
    ).to(device)
    extra_layers.load_state_dict(extra_layers_state_dict)
    extra_layers.eval()

    param_count = sum(p.numel() for p in extra_layers.parameters())
    print(f"Extra layers loaded ({param_count:,} parameters)")
    return extra_layers


def generate_with_terminator(
    prompt,
    model,
    tokenizer,
    ffn,
    extra_layers,
    layer_idx,
    think_token_id,
    think_end_token_id,
    threshold,
    window_size,
    exit_message,
    max_tokens,
    temperature,
    device,
):
    """Generate a response with Terminator early-exit logic.

    Follows the same generation pattern as inference_terminator.py:mode1_generate().
    Streams thinking tokens to the terminal as they are produced.
    """
    # Format prompt via chat template
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # Tokenize and append <think>
    prompt_ids = tokenizer(
        prompt_text, add_special_tokens=False, return_tensors="pt"
    )["input_ids"].to(device).long()

    input_ids = torch.cat(
        [prompt_ids, torch.tensor([[think_token_id]], dtype=torch.long, device=device)],
        dim=1,
    )

    # Sampling processors
    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature=temperature),
        TopKLogitsWarper(top_k=20),
        TopPLogitsWarper(top_p=0.95),
    ])

    # Sliding-window state
    predictions_list = []
    reasoning_tokens = []
    early_exit = False

    # Start streaming thinking output
    sys.stdout.write(f"\n{DIM}Thinking...\n")
    sys.stdout.flush()

    for step in range(max_tokens):
        attention_mask = torch.ones_like(input_ids)

        # Hook to capture hidden states from the target layer
        captured = {}

        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                captured["hidden"] = output[0].detach()
            else:
                captured["hidden"] = output.detach()

        target_layer = model.model.layers[layer_idx]
        handle = target_layer.register_forward_hook(hook_fn)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )

        handle.remove()

        hidden_states = captured["hidden"]  # [1, seq_len, hidden_size]

        # Make prediction once we have at least one thinking token
        if len(reasoning_tokens) > 0:
            if extra_layers is not None:
                h = hidden_states.float()
                h = extra_layers(h, attention_mask=attention_mask)
                last_h = h[:, -1:, :]
                logits_pred = ffn(last_h.float())
            else:
                last_h = hidden_states[:, -1:, :]
                logits_pred = ffn(last_h.float())

            pred = torch.sigmoid(logits_pred)
            predictions_list.append(pred[0, 0].item())

            # Sliding-window majority vote
            if len(predictions_list) >= window_size:
                window = predictions_list[-window_size:]
                n_above = sum(1 for p in window if p > threshold)
                if n_above / window_size > 0.5:
                    early_exit = True
                    break

        # Sample next token — LogitsProcessorList expects 2D [batch, vocab]
        next_logits = outputs.logits[:, -1, :]  # [1, vocab_size]
        next_logits = logits_processor(input_ids, next_logits)
        probs = F.softmax(next_logits, dim=-1)
|
| 216 |
+
next_token = torch.multinomial(probs, num_samples=1) # [1, 1]
|
| 217 |
+
|
| 218 |
+
# Natural </think>
|
| 219 |
+
if next_token.item() == think_end_token_id:
|
| 220 |
+
break
|
| 221 |
+
|
| 222 |
+
input_ids = torch.cat([input_ids, next_token], dim=1)
|
| 223 |
+
reasoning_tokens.append(next_token.item())
|
| 224 |
+
|
| 225 |
+
# Stream the token
|
| 226 |
+
token_text = tokenizer.decode([next_token.item()], skip_special_tokens=False)
|
| 227 |
+
sys.stdout.write(token_text)
|
| 228 |
+
sys.stdout.flush()
|
| 229 |
+
|
| 230 |
+
# End thinking section
|
| 231 |
+
if early_exit and exit_message:
|
| 232 |
+
sys.stdout.write(exit_message)
|
| 233 |
+
sys.stdout.write(f"{RESET}\n")
|
| 234 |
+
sys.stdout.flush()
|
| 235 |
+
|
| 236 |
+
# Build input for final answer generation
|
| 237 |
+
if early_exit and exit_message:
|
| 238 |
+
exit_ids = tokenizer(
|
| 239 |
+
exit_message, add_special_tokens=False, return_tensors="pt"
|
| 240 |
+
)["input_ids"].to(device).long()
|
| 241 |
+
input_ids = torch.cat(
|
| 242 |
+
[input_ids, exit_ids,
|
| 243 |
+
torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
|
| 244 |
+
dim=1,
|
| 245 |
+
)
|
| 246 |
+
else:
|
| 247 |
+
input_ids = torch.cat(
|
| 248 |
+
[input_ids,
|
| 249 |
+
torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
|
| 250 |
+
dim=1,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# Generate final answer
|
| 254 |
+
attention_mask = torch.ones_like(input_ids)
|
| 255 |
+
with torch.no_grad():
|
| 256 |
+
final_outputs = model.generate(
|
| 257 |
+
input_ids=input_ids,
|
| 258 |
+
attention_mask=attention_mask,
|
| 259 |
+
max_new_tokens=max_tokens,
|
| 260 |
+
do_sample=True,
|
| 261 |
+
temperature=temperature,
|
| 262 |
+
top_p=0.95,
|
| 263 |
+
top_k=20,
|
| 264 |
+
pad_token_id=tokenizer.pad_token_id,
|
| 265 |
+
eos_token_id=tokenizer.eos_token_id,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Extract answer (everything after last </think>)
|
| 269 |
+
full_seq = final_outputs[0]
|
| 270 |
+
end_positions = (full_seq == think_end_token_id).nonzero(as_tuple=True)[0]
|
| 271 |
+
if len(end_positions) > 0:
|
| 272 |
+
answer_tokens = full_seq[end_positions[-1].item() + 1 :]
|
| 273 |
+
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
|
| 274 |
+
else:
|
| 275 |
+
answer = ""
|
| 276 |
+
|
| 277 |
+
# Print answer
|
| 278 |
+
sys.stdout.write(f"{BOLD}Answer:{RESET}\n{answer}\n")
|
| 279 |
+
sys.stdout.flush()
|
| 280 |
+
|
| 281 |
+
# Summary
|
| 282 |
+
n_reasoning = len(reasoning_tokens)
|
| 283 |
+
exit_reason = "predictor" if early_exit else "natural_end"
|
| 284 |
+
print(
|
| 285 |
+
f"\n{DIM}[{exit_reason} | "
|
| 286 |
+
f"{n_reasoning} thinking tokens | "
|
| 287 |
+
f"{len(predictions_list)} predictions]{RESET}"
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def main():
|
| 292 |
+
parser = argparse.ArgumentParser(
|
| 293 |
+
description=__doc__,
|
| 294 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 295 |
+
)
|
| 296 |
+
parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
|
| 297 |
+
parser.add_argument(
|
| 298 |
+
"--model", type=str, default="Qwen/Qwen3-14B", help="HuggingFace model name"
|
| 299 |
+
)
|
| 300 |
+
parser.add_argument(
|
| 301 |
+
"--checkpoint",
|
| 302 |
+
type=str,
|
| 303 |
+
default=None,
|
| 304 |
+
help="Path to terminator .pt checkpoint (default: ./terminator.pt)",
|
| 305 |
+
)
|
| 306 |
+
parser.add_argument(
|
| 307 |
+
"--threshold", type=float, default=0.7, help="Per-prediction binarization threshold"
|
| 308 |
+
)
|
| 309 |
+
parser.add_argument(
|
| 310 |
+
"--window-size", type=int, default=10, help="Sliding-window size for majority vote"
|
| 311 |
+
)
|
| 312 |
+
parser.add_argument(
|
| 313 |
+
"--exit-message",
|
| 314 |
+
type=str,
|
| 315 |
+
default="\nI've run out of thinking tokens. I need to commit to a final answer.",
|
| 316 |
+
help="Message injected when terminator fires (empty string to disable)",
|
| 317 |
+
)
|
| 318 |
+
parser.add_argument(
|
| 319 |
+
"--max-tokens", type=int, default=32768, help="Max tokens to generate"
|
| 320 |
+
)
|
| 321 |
+
parser.add_argument(
|
| 322 |
+
"--temperature", type=float, default=0.6, help="Sampling temperature"
|
| 323 |
+
)
|
| 324 |
+
parser.add_argument(
|
| 325 |
+
"--device", type=str, default="cuda", help="Device (default: cuda)"
|
| 326 |
+
)
|
| 327 |
+
args = parser.parse_args()
|
| 328 |
+
|
| 329 |
+
# Resolve checkpoint path
|
| 330 |
+
if args.checkpoint is None:
|
| 331 |
+
args.checkpoint = str(_script_dir / "terminator.pt")
|
| 332 |
+
|
| 333 |
+
if not Path(args.checkpoint).exists():
|
| 334 |
+
print(f"ERROR: Checkpoint not found: {args.checkpoint}", file=sys.stderr)
|
| 335 |
+
sys.exit(1)
|
| 336 |
+
|
| 337 |
+
# Handle empty exit message
|
| 338 |
+
if args.exit_message == "":
|
| 339 |
+
args.exit_message = None
|
| 340 |
+
|
| 341 |
+
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
|
| 342 |
+
|
| 343 |
+
# Load base model
|
| 344 |
+
model, tokenizer, think_id, think_end_id = load_model_and_tokenizer(
|
| 345 |
+
args.model, device
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
# Load terminator checkpoint
|
| 349 |
+
rms_eps = getattr(model.config, "rms_norm_eps", 1e-6)
|
| 350 |
+
ffn, ckpt_config, layer_idx, num_extra_layers, extra_sd = load_terminator_checkpoint(
|
| 351 |
+
args.checkpoint, rms_norm_eps=rms_eps, device=device
|
| 352 |
+
)
|
| 353 |
+
ffn_params = sum(p.numel() for p in ffn.parameters())
|
| 354 |
+
print(
|
| 355 |
+
f"Terminator FFN loaded (layer_idx={layer_idx}, "
|
| 356 |
+
f"threshold={args.threshold}, window={args.window_size}, "
|
| 357 |
+
f"params={ffn_params:,})"
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
# Extra layers
|
| 361 |
+
extra_layers = build_extra_layers(model, ckpt_config, extra_sd, device)
|
| 362 |
+
|
| 363 |
+
# Generate
|
| 364 |
+
generate_with_terminator(
|
| 365 |
+
prompt=args.prompt,
|
| 366 |
+
model=model,
|
| 367 |
+
tokenizer=tokenizer,
|
| 368 |
+
ffn=ffn,
|
| 369 |
+
extra_layers=extra_layers,
|
| 370 |
+
layer_idx=layer_idx,
|
| 371 |
+
think_token_id=think_id,
|
| 372 |
+
think_end_token_id=think_end_id,
|
| 373 |
+
threshold=args.threshold,
|
| 374 |
+
window_size=args.window_size,
|
| 375 |
+
exit_message=args.exit_message,
|
| 376 |
+
max_tokens=args.max_tokens,
|
| 377 |
+
temperature=args.temperature,
|
| 378 |
+
device=device,
|
| 379 |
+
)
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
if __name__ == "__main__":
|
| 383 |
+
main()
|
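The early-exit rule in `generate_with_terminator` — fire once more than half of the last `window_size` per-token probabilities exceed `threshold` — can be isolated as a small pure function. A sketch using the script's default values (0.7 and 10); `should_exit` is an illustrative name, not part of the script:

```python
def should_exit(predictions, threshold=0.7, window_size=10):
    """Majority vote over the last `window_size` per-token probabilities.

    Mirrors the inline check in generate_with_terminator: only fire once a
    full window is available, and only if more than half of it exceeds the
    threshold.
    """
    if len(predictions) < window_size:
        return False
    window = predictions[-window_size:]
    n_above = sum(1 for p in window if p > threshold)
    return n_above / window_size > 0.5


# A single spike does not fire; a sustained run of high probabilities does.
print(should_exit([0.9] * 3))              # → False (too few predictions yet)
print(should_exit([0.1] * 9 + [0.95]))     # → False (only 1/10 above threshold)
print(should_exit([0.2] * 4 + [0.9] * 6))  # → True  (6/10 above threshold)
```

The majority requirement is what makes the trigger robust to one-off high predictions during reasoning.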
serve.py ADDED
@@ -0,0 +1,95 @@
#!/usr/bin/env python3
"""
vLLM API server launcher for Qwen3TerminatorForCausalLM.

Imports vllm_terminator BEFORE vLLM initialises, which registers
Qwen3TerminatorForCausalLM with vLLM's ModelRegistry.

NOTE: Terminator currently supports single-GPU, single-sequence inference only.
Tensor parallelism and concurrent sequences are not supported.

Environment variables:
    VLLM_MODEL          — path to terminator model directory (required)
    VLLM_PORT           — port (default 8000)
    VLLM_GPU_UTIL       — GPU memory fraction (default 0.90)
    VLLM_MAX_MODEL_LEN  — max context length
    VLLM_DTYPE          — dtype (default "auto")
    VLLM_API_KEY        — require this API key from clients
    VLLM_SERVED_NAME    — override served model name
    VLLM_HOST           — bind address (default 0.0.0.0)
    NO_PREFIX_CACHING   — set to 1 to disable prefix caching
    VLLM_ENFORCE_EAGER  — set to 1 to disable CUDA graphs (default 0)
    REASONING_PARSER    — set to "qwen3" to enable <think>/</think> parsing
                          (splits reasoning_content from content in API responses)

Example:
    VLLM_MODEL=./model_dir python serve.py
"""

import os
import runpy
import sys

# -----------------------------------------------------------------------
# CRITICAL: import vllm_terminator HERE, before any vLLM code runs.
# This registers Qwen3TerminatorForCausalLM with vLLM's ModelRegistry.
# -----------------------------------------------------------------------
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import vllm_terminator  # noqa: F401 (registers the model as a side effect)


def env(name, default=None, required=False):
    v = os.environ.get(name, default)
    if required and (v is None or v == ""):
        print(f"Missing required env var: {name}", file=sys.stderr)
        sys.exit(2)
    return v


def main():
    model = env("VLLM_MODEL", required=True)
    host = env("VLLM_HOST", "0.0.0.0")
    port = env("VLLM_PORT", "8000")
    max_len = env("VLLM_MAX_MODEL_LEN", None)
    gpu_util = env("VLLM_GPU_UTIL", "0.90")
    served_name = env("VLLM_SERVED_NAME", None)
    dtype = env("VLLM_DTYPE", "auto")
    api_key = env("VLLM_API_KEY", None)
    no_prefix_caching = env("NO_PREFIX_CACHING", "0")
    enforce_eager = env("VLLM_ENFORCE_EAGER", "0")
    reasoning_parser = env("REASONING_PARSER", None)

    argv = [
        "vllm.entrypoints.openai.api_server",
        "--model", model,
        "--host", host,
        "--port", str(port),
        "--dtype", dtype,
        "--gpu-memory-utilization", str(gpu_util),
        "--tensor-parallel-size", "1",
        "--max-num-seqs", "1",
    ]

    if served_name:
        argv += ["--served-model-name", served_name]
    if max_len:
        argv += ["--max-model-len", str(max_len)]
    if api_key:
        argv += ["--api-key", api_key]
    if no_prefix_caching == "1":
        argv += ["--enable-prefix-caching", "False"]
    if enforce_eager == "1":
        argv += ["--enforce-eager"]
    if reasoning_parser:
        argv += ["--reasoning-parser", reasoning_parser]

    print("Launching vLLM Terminator server with:\n  " + " ".join(argv[1:]), flush=True)

    # Replace sys.argv so vLLM's argparse sees these arguments, then run the
    # server module in-process (so vllm_terminator registration persists).
    sys.argv = argv
    runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__")


if __name__ == "__main__":
    main()
setup.sh ADDED
@@ -0,0 +1,125 @@
#!/usr/bin/env bash
set -euo pipefail

# ==========================================================================
# Terminator-Qwen3-14B — Automated Setup
#
# This script:
#   1. Creates a conda environment with Python 3.12
#   2. Installs uv, vllm, and openai
#   3. Downloads Qwen3-14B base model weights and creates the model directory
#
# Prerequisites:
#   - NVIDIA GPU with sufficient VRAM (minimum ~40GB for Qwen3-14B)
#   - CUDA drivers installed
#   - conda or micromamba installed
#
# Usage:
#   ./setup.sh
# ==========================================================================

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
ENV_NAME="${TERMINATOR_ENV_NAME:-terminator}"

echo ""
echo "====================================="
echo "  Terminator-Qwen3-14B Setup"
echo "====================================="
echo ""

# ------------------------------------------------------------------
# Step 1: Create conda environment
# ------------------------------------------------------------------

# Detect conda or micromamba
if command -v micromamba &>/dev/null; then
    CONDA_CMD="micromamba"
elif command -v conda &>/dev/null; then
    CONDA_CMD="conda"
else
    echo "ERROR: Neither conda nor micromamba found."
    echo ""
    echo "Install micromamba:"
    echo '  "${SHELL}" <(curl -L micro.mamba.pm/install.sh)'
    echo ""
    echo "Or install conda:"
    echo "  https://docs.conda.io/en/latest/miniconda.html"
    exit 1
fi

echo "[1/3] Setting up Python environment..."

# Check if environment already exists
if $CONDA_CMD env list 2>/dev/null | grep -q "^${ENV_NAME} \|/${ENV_NAME}\$"; then
    echo "  Environment '${ENV_NAME}' already exists. Activating..."
else
    echo "  Creating environment '${ENV_NAME}' with Python 3.12..."
    $CONDA_CMD create -n "${ENV_NAME}" python=3.12 -y
fi

# Helper: run a command inside the conda/micromamba environment
run_in_env() {
    if [ "$CONDA_CMD" = "micromamba" ]; then
        micromamba run -n "${ENV_NAME}" "$@"
    else
        # For conda, activate in a subshell
        (eval "$(conda shell.bash hook 2>/dev/null)" && conda activate "${ENV_NAME}" && "$@")
    fi
}

echo "  Python: $(run_in_env python --version)"

# ------------------------------------------------------------------
# Step 2: Install packages
# ------------------------------------------------------------------

echo ""
echo "[2/3] Installing packages..."

echo "  Installing uv..."
run_in_env pip install --upgrade uv --quiet

echo "  Installing vllm (this may take a few minutes)..."
run_in_env uv pip install vllm --torch-backend=auto

echo "  Installing openai (for client)..."
run_in_env uv pip install openai

echo "  Installing accelerate (for HF inference)..."
run_in_env uv pip install accelerate

echo "  Done."

# ------------------------------------------------------------------
# Step 3: Set up model directory
# ------------------------------------------------------------------

echo ""
echo "[3/3] Setting up model directory..."
echo "  This downloads Qwen3-14B base weights (~28GB) from HuggingFace."
echo "  (Skipped if already cached.)"
echo ""

cd "$SCRIPT_DIR"
run_in_env python setup_model_dir.py

# ------------------------------------------------------------------
# Done
# ------------------------------------------------------------------

echo ""
echo "====================================="
echo "  Setup Complete!"
echo "====================================="
echo ""
echo "To start the server:"
echo "  $CONDA_CMD activate ${ENV_NAME}"
echo "  cd $SCRIPT_DIR"
echo "  ./start_server.sh"
echo ""
echo "Then in another terminal:"
echo "  $CONDA_CMD activate ${ENV_NAME}"
echo "  cd $SCRIPT_DIR"
echo "  python client.py --interactive"
echo ""
echo "See README.md for configuration options (GPU memory, context length, etc.)"
setup_model_dir.py ADDED
@@ -0,0 +1,128 @@
#!/usr/bin/env python3
"""
Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.

Downloads the base Qwen3-14B config and weights from HuggingFace (if not
already cached), then creates a model directory with:
  - config.json (Qwen3-14B base config + terminator fields)
  - tokenizer files (symlinked from HF cache)
  - model weights (symlinked from HF cache)

Usage:
    # Default: uses ./terminator.pt checkpoint, creates ./model_dir
    python setup_model_dir.py

    # Custom paths and settings:
    python setup_model_dir.py \\
        --checkpoint /path/to/terminator.pt \\
        --output-dir /path/to/model_dir \\
        --threshold 0.5
"""

import argparse
import os
import sys
from pathlib import Path

from huggingface_hub import snapshot_download
from transformers import AutoConfig


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--base-model", default="Qwen/Qwen3-14B",
        help="HuggingFace model ID for the base model (default: Qwen/Qwen3-14B).",
    )
    parser.add_argument(
        "--checkpoint", type=Path, default="./terminator.pt",
        help="Path to trained terminator .pt checkpoint (default: ./terminator.pt).",
    )
    parser.add_argument(
        "--output-dir", type=Path, default="./model_dir",
        help="Destination directory (default: ./model_dir; created if missing).",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.7,
        help="Terminator firing threshold (default 0.7).",
    )
    parser.add_argument(
        "--window-size", type=int, default=10,
        help="Sliding window size for majority vote (default 10).",
    )
    parser.add_argument(
        "--exit-message", type=str,
        default="\nI've run out of thinking tokens. I need to commit to a final answer.",
        help="Message forced when terminator fires (default: standard exit message). "
             "Set to empty string to disable.",
    )
    parser.add_argument(
        "--no-download", action="store_true",
        help="Fail if the base model is not already cached locally "
             "(by default, downloads from HuggingFace if needed).",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Overwrite files in existing output directory.",
    )
    args = parser.parse_args()

    checkpoint = args.checkpoint.resolve()
    out_dir = args.output_dir.resolve()

    if not checkpoint.is_file():
        print(f"ERROR: checkpoint not found: {checkpoint}", file=sys.stderr)
        sys.exit(1)

    out_dir.mkdir(parents=True, exist_ok=True)

    # --- Build patched config.json ---
    print(f"Loading config for {args.base_model} from HF cache...")
    config = AutoConfig.from_pretrained(args.base_model)

    config.architectures = ["Qwen3TerminatorForCausalLM"]
    config.terminator_checkpoint_path = str(checkpoint)
    config.terminator_threshold = args.threshold
    config.terminator_window_size = args.window_size
    config.terminator_exit_message = args.exit_message

    # Remove auto_map if present from an older span-predictor config
    if hasattr(config, "auto_map"):
        del config.auto_map

    config.save_pretrained(out_dir)
    print(f"  Wrote config.json -> {out_dir / 'config.json'}")

    # --- Symlink weights and tokenizer files from HF cache ---
    print(f"Locating {args.base_model} in HF cache...")
    allow_download = not args.no_download
    base_dir = Path(snapshot_download(args.base_model, local_files_only=not allow_download))
    print(f"  Found: {base_dir}")

    linked = 0
    for src in sorted(base_dir.iterdir()):
        if src.name in ("config.json",):
            continue  # we already wrote our own

        dst = out_dir / src.name
        if dst.exists() or dst.is_symlink():
            if args.force:
                dst.unlink()
            else:
                continue

        os.symlink(src, dst)
        print(f"  Linked {src.name}")
        linked += 1

    print(f"\nDone. Linked {linked} files into {out_dir}")
    print("\nTo start the server:")
    print("  ./start_server.sh")
    print("\nOr manually:")
    print(f"  VLLM_MODEL={out_dir} REASONING_PARSER=qwen3 python serve.py")


if __name__ == "__main__":
    main()
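The config patch in setup_model_dir.py boils down to five fields written on top of the base Qwen3-14B config. A minimal sketch of those additions (field names and defaults taken from the script; the checkpoint path below is a placeholder, not a real path):

```python
import json

# Fields setup_model_dir.py adds to the base Qwen3-14B config.json.
# "terminator_checkpoint_path" here is illustrative only.
terminator_fields = {
    "architectures": ["Qwen3TerminatorForCausalLM"],
    "terminator_checkpoint_path": "/path/to/terminator.pt",
    "terminator_threshold": 0.7,
    "terminator_window_size": 10,
    "terminator_exit_message": (
        "\nI've run out of thinking tokens. I need to commit to a final answer."
    ),
}

print(json.dumps(terminator_fields, indent=2))
```

vLLM reads these fields when it instantiates Qwen3TerminatorForCausalLM, so changing the threshold or window size only requires rewriting config.json, not retraining.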
start_server.sh ADDED
@@ -0,0 +1,48 @@
#!/usr/bin/env bash
set -euo pipefail

# ==========================================================================
# Terminator-Qwen3-14B — Server Launcher
#
# Starts the vLLM server with the Terminator model.
# Run setup.sh first to create the model directory.
#
# Configuration (set as environment variables before running):
#
#   VLLM_GPU_UTIL        GPU memory fraction to use (default: 0.90)
#
#   VLLM_MAX_MODEL_LEN   Maximum context length in tokens (default: server picks)
#
#   VLLM_PORT            Server port (default: 8000)
#
#   VLLM_ENFORCE_EAGER   Set to 1 to disable CUDA graphs (default: 0)
#                        Use if you encounter CUDA graph compilation errors.
#                        NOTE: VLLM_ENFORCE_EAGER=1 will result in slower responses
#
#   VLLM_API_KEY         Require this API key from clients (default: none)
#
# Usage:
#   ./start_server.sh
#   or to manually override default environment variables:
#   VLLM_GPU_UTIL=0.70 VLLM_MAX_MODEL_LEN=8192 ./start_server.sh
# ==========================================================================

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
MODEL_DIR="${SCRIPT_DIR}/model_dir"

if [ ! -d "$MODEL_DIR" ]; then
    echo "ERROR: Model directory not found at: $MODEL_DIR" >&2
    echo "" >&2
    echo "Run setup first:" >&2
    echo "  ./setup.sh" >&2
    echo "" >&2
    echo "Or manually:" >&2
    echo "  python setup_model_dir.py" >&2
    exit 1
fi

export VLLM_MODEL="$MODEL_DIR"
export REASONING_PARSER="${REASONING_PARSER:-qwen3}"
export VLLM_SERVED_NAME="${VLLM_SERVED_NAME:-Terminator-Qwen3-14B}"

exec python "$SCRIPT_DIR/serve.py"
terminator.pt ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:94deb40f18fe3d3e642b0dd1b07d0833b1363d389c0ed01df2946b533b715c97
size 1321295486
vllm_terminator/__init__.py ADDED
@@ -0,0 +1,19 @@
"""
vLLM plugin that registers the Qwen3 Terminator model.

Usage:
    import vllm_terminator  # registers the model architecture

Then set ``"architectures": ["Qwen3TerminatorForCausalLM"]`` in the
HuggingFace config.json alongside::

    "terminator_checkpoint_path": "/path/to/layer_-1.pt",
    "terminator_threshold": 0.7
"""

from vllm import ModelRegistry

ModelRegistry.register_model(
    "Qwen3TerminatorForCausalLM",
    "vllm_terminator.model:Qwen3TerminatorForCausalLM",
)
vllm_terminator/model.py ADDED
@@ -0,0 +1,553 @@
"""
vLLM-compatible Qwen3 model with Terminator FFN for early reasoning truncation.

The Terminator predicts when chain-of-thought reasoning has reached the final
answer. When a majority of recent predictions (sliding window) exceed a
threshold, the model forces generation of a configurable exit message followed
by </think> to truncate reasoning early. The exit message helps the model
transition smoothly from thinking to answering mode.

Supports optional extra transformer layers between the base model and the FFN
head. These layers get their own KV cache via vLLM's auto-discovery mechanism.

Constraints:
- layer_idx = -1 (last layer)
- sliding_window strategy (majority vote over last window_size predictions)
- lag_size = 0
- batch_size = 1
"""

from collections import deque
from collections.abc import Iterable
from itertools import islice

import torch
from torch import nn
from transformers import AutoTokenizer

from vllm.config import VllmConfig
from vllm.distributed import get_pp_group
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.models.interfaces import SupportsPP
from vllm.model_executor.models.qwen2 import Qwen2Model
from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.utils import (
    AutoWeightsLoader,
    PPMissingLayer,
    maybe_prefix,
)
from vllm.sequence import IntermediateTensors

from .terminator_head import TerminatorFFN

logger = init_logger(__name__)


class Qwen3TerminatorModel(Qwen2Model):
    """Qwen3 backbone that captures pre-norm hidden states for the Terminator FFN."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(
            vllm_config=vllm_config,
            prefix=prefix,
            decoder_layer_type=Qwen3DecoderLayer,
        )
        self._pre_norm_hidden: torch.Tensor | None = None

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_input_ids(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]

        aux_hidden_states = []
        for idx, layer in enumerate(
            islice(self.layers, self.start_layer, self.end_layer)
        ):
            if idx in self.aux_hidden_state_layers:
                aux_hidden_states.append(hidden_states + residual)
            hidden_states, residual = layer(positions, hidden_states, residual)

        if not get_pp_group().is_last_rank:
            return IntermediateTensors(
                {"hidden_states": hidden_states, "residual": residual}
            )

        # Capture pre-norm hidden states (matches training's forward hook output).
        # In the fused residual pattern, pre-norm state = hidden_states + residual.
        self._pre_norm_hidden = hidden_states + residual

        hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) > 0:
            return hidden_states, aux_hidden_states

        return hidden_states


class Qwen3TerminatorForCausalLM(nn.Module, SupportsPP):
    """
    Qwen3 causal LM with an attached Terminator FFN that can force </think>
    generation when the model predicts the final answer has been reached.
    """

    packed_modules_mapping = {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    }

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.quant_config = quant_config

        # --- Base Qwen3 model ---
        self.model = Qwen3TerminatorModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

        if get_pp_group().is_last_rank:
            if config.tie_word_embeddings:
                self.lm_head = self.model.embed_tokens
            else:
                self.lm_head = ParallelLMHead(
                    config.vocab_size,
                    config.hidden_size,
                    quant_config=quant_config,
                    prefix=maybe_prefix(prefix, "lm_head"),
                )
        else:
            self.lm_head = PPMissingLayer()

        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # --- Terminator FFN ---
        terminator_checkpoint_path = getattr(
            config, "terminator_checkpoint_path", None
        )
        self._terminator_threshold = getattr(
            config, "terminator_threshold", 0.7
        )
        self._terminator_window_size = getattr(
            config, "terminator_window_size", 10
        )

        if terminator_checkpoint_path:
            self._terminator_checkpoint_path = terminator_checkpoint_path
            self._terminator_enabled = True

            # Load checkpoint metadata to construct FFN with correct architecture
            checkpoint = torch.load(
                terminator_checkpoint_path, map_location="cpu", weights_only=False
            )
            terminator_config = checkpoint["config"]
            self._terminator_layer_idx = checkpoint["layer_idx"]
            self._terminator_state_dict = checkpoint["state_dict"]

            self.terminator_ffn = TerminatorFFN(
                hidden_size=terminator_config["hidden_size"],
                num_hidden_layers=terminator_config.get("ffn_layers", 1),
                activation=terminator_config.get("ffn_activation", "gelu"),
                intermediate_size=terminator_config.get(
                    "ffn_intermediate_size", None
                ),
                dropout=0.0,
                rms_norm_eps=getattr(config, "rms_norm_eps", 1e-6),
            )

            logger.info(
                "Terminator FFN created (layer_idx=%d, threshold=%.2f, "
                "window_size=%d, params=%d)",
                self._terminator_layer_idx,
                self._terminator_threshold,
                self._terminator_window_size,
                sum(p.numel() for p in self.terminator_ffn.parameters()),
            )

            # --- Extra transformer layers ---
            self._num_extra_layers = terminator_config.get(
                "num_extra_layers", 0
            )
            self._extra_layers_state_dict = checkpoint.get(
                "extra_layers_state_dict", None
            )

            if self._num_extra_layers > 0 and self._extra_layers_state_dict is not None:
                cache_config = vllm_config.cache_config
                # Use indices starting after base layers to avoid
                # extract_layer_index() collisions in KV cache binding.
                num_base_layers = config.num_hidden_layers
                self.terminator_extra_layers = nn.ModuleList([
                    Qwen3DecoderLayer(
                        config=config,
                        cache_config=cache_config,
                        quant_config=quant_config,
                        prefix=f"terminator_extra_layers.{num_base_layers + i}",
                    )
                    for i in range(self._num_extra_layers)
                ])
                logger.info(
                    "Terminator extra layers created (n=%d, params=%d)",
                    self._num_extra_layers,
                    sum(
                        p.numel()
                        for p in self.terminator_extra_layers.parameters()
                    ),
                )
            else:
                self._num_extra_layers = 0
                self._extra_layers_state_dict = None
        else:
            self._terminator_enabled = False
            self._num_extra_layers = 0
            self._extra_layers_state_dict = None
            logger.info(
                "No terminator_checkpoint_path in config; "
                "terminator disabled (running as standard Qwen3)"
            )

        # --- Think token IDs ---
        tokenizer = AutoTokenizer.from_pretrained(
            vllm_config.model_config.tokenizer,
            trust_remote_code=vllm_config.model_config.trust_remote_code,
        )
        self._think_token_id = tokenizer.convert_tokens_to_ids("<think>")
        self._think_end_token_id = tokenizer.convert_tokens_to_ids("</think>")
        logger.info(
            "<think>=%d, </think>=%d",
            self._think_token_id,
            self._think_end_token_id,
        )

        # --- Exit message ---
        # Pre-tokenize the exit message + </think> so we can force one token
        # per step when the terminator fires.
        _default_exit_msg = (
            "\nI've run out of thinking tokens."
            " I need to commit to a final answer."
        )
        exit_msg = getattr(config, "terminator_exit_message", _default_exit_msg)
        if exit_msg:
            msg_ids = tokenizer.encode(exit_msg, add_special_tokens=False)
            self._exit_sequence: list[int] = msg_ids + [self._think_end_token_id]
        else:
            self._exit_sequence = [self._think_end_token_id]
        logger.info(
            "Exit sequence: %d tokens (message=%r)",
            len(self._exit_sequence),
            exit_msg if exit_msg else "<none>",
        )

        # --- Per-request state (batch_size=1) ---
        self._is_thinking = False
        self._pred_buffer: torch.Tensor | None = None  # lazily allocated in forward()
        self._prev_output_token_id: int | None = None  # argmax from previous compute_logits()
        self._prediction_history: deque[float] = deque(
            maxlen=self._terminator_window_size if self._terminator_enabled else 1
        )
        self._forcing_exit: bool = False
        self._forcing_idx: int = 0

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.embed_input_ids(input_ids)

    # ------------------------------------------------------------------
    # Thinking state tracking
    # ------------------------------------------------------------------

    def _update_thinking_state(self, input_ids: torch.Tensor) -> None:
        """Set ``_is_thinking`` based on the last think token in *input_ids*."""
        think_pos = (input_ids == self._think_token_id).nonzero(as_tuple=True)[0]
        end_pos = (input_ids == self._think_end_token_id).nonzero(as_tuple=True)[0]

        last_think = think_pos[-1].item() if len(think_pos) > 0 else -1
        last_end = end_pos[-1].item() if len(end_pos) > 0 else -1

        if last_think > last_end:
            self._is_thinking = True
        elif last_end > last_think:
            self._is_thinking = False

    # ------------------------------------------------------------------
    # Forward / compute_logits
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )

        if self._terminator_enabled:
            # Update thinking state during eager execution only (prefill).
            # .nonzero() / .item() are not CUDA-graph-capture-safe.
            # Reset all per-request state so predictions, exit-forcing,
            # and output token tracking don't leak between requests.
            if not torch.cuda.is_current_stream_capturing():
                self._update_thinking_state(input_ids)
                self._prev_output_token_id = None
                self._prediction_history.clear()
                self._forcing_exit = False
                self._forcing_idx = 0

            # Run extra layers + FFN unconditionally — all ops (RMSNorm,
            # Linear, Attention with KV cache, sigmoid) are CUDA-graph-safe.
            # The prediction is written in-place to a pre-allocated buffer
            # that persists across graph replays.
            pre_norm = self.model._pre_norm_hidden

            if self._num_extra_layers > 0:
                # Run extra transformer layers on ALL tokens so KV cache
                # is populated during prefill and updated during decode.
                extra_hidden = pre_norm
                extra_residual = None
                for layer in self.terminator_extra_layers:
                    extra_hidden, extra_residual = layer(
                        positions, extra_hidden, extra_residual
                    )
                # Reconstruct pre-norm state for the FFN:
                # hidden + residual gives the un-normed output,
                # matching what the FFN's own RMSNorm expects.
                ffn_input = extra_hidden + extra_residual
            else:
                ffn_input = pre_norm

            logit = self.terminator_ffn(ffn_input[-1:])
            pred = torch.sigmoid(logit)
            if self._pred_buffer is None:
                # First eager forward (warmup) — allocate the buffer.
                # Subsequent calls (including graph capture) reuse it.
                self._pred_buffer = torch.zeros_like(pred)
            self._pred_buffer.copy_(pred)

        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        logits = self.logits_processor(self.lm_head, hidden_states)

        # compute_logits() is called outside the CUDA graph, so .item()
        # and data-dependent branching are safe here.
        if self._terminator_enabled:
            # --- Exit message forcing ---
            # When the terminator has fired, we walk through the pre-tokenized
            # exit sequence one token per step, skipping all prediction logic.
            if self._forcing_exit:
                token_id = self._exit_sequence[self._forcing_idx]
                logits.fill_(float("-inf"))
                logits[:, token_id] = 0.0
                self._forcing_idx += 1
                if self._forcing_idx >= len(self._exit_sequence):
                    # Done forcing — last token was </think>.
                    self._forcing_exit = False
                    self._is_thinking = False
                    logger.debug(
                        "Exit sequence complete (%d tokens forced)",
                        len(self._exit_sequence),
                    )
                self._prev_output_token_id = token_id
                return logits

            # Track thinking state from the previous step's output token.
            # During CUDA graph replay forward() doesn't execute, so
            # _update_thinking_state() never sees generated <think> tokens.
            # Instead we infer the state from the argmax of the previous
            # step's logits (exact for greedy; heuristic for sampling).
            if self._prev_output_token_id == self._think_token_id:
                if not self._is_thinking:
                    self._prediction_history.clear()
                self._is_thinking = True
            elif self._prev_output_token_id == self._think_end_token_id:
                self._is_thinking = False

            if self._is_thinking and self._pred_buffer is not None:
                pred = self._pred_buffer.item()
                self._prediction_history.append(pred)

                if len(self._prediction_history) >= self._terminator_window_size:
                    n_above = sum(
                        1 for p in self._prediction_history
                        if p > self._terminator_threshold
                    )
                    vote = n_above / self._terminator_window_size
                    if vote > 0.5:
                        # Majority of sliding window exceeds threshold —
                        # enter exit-message forcing mode.
                        logger.debug(
                            "Terminator FIRING: pred=%.3f, "
                            "window=[%s] (%d/%d above %.2f, vote=%.2f)",
                            pred,
                            ", ".join(f"{p:.3f}" for p in self._prediction_history),
                            n_above,
                            self._terminator_window_size,
                            self._terminator_threshold,
                            vote,
                        )
                        # Force the first token of the exit sequence now,
                        # and set up state so subsequent calls continue.
                        self._forcing_exit = True
                        token_id = self._exit_sequence[0]
                        logits.fill_(float("-inf"))
                        logits[:, token_id] = 0.0
                        self._forcing_idx = 1
                        if self._forcing_idx >= len(self._exit_sequence):
                            self._forcing_exit = False
                            self._is_thinking = False
                        self._prev_output_token_id = token_id
                        return logits
                    else:
                        logger.debug(
                            "Terminator: pred=%.3f, "
                            "window=[%s] (%d/%d above %.2f, vote=%.2f)",
                            pred,
                            ", ".join(f"{p:.3f}" for p in self._prediction_history),
                            n_above,
                            self._terminator_window_size,
                            self._terminator_threshold,
                            vote,
                        )
                else:
                    logger.debug(
                        "Terminator: pred=%.3f, filling window (%d/%d)",
                        pred,
                        len(self._prediction_history),
                        self._terminator_window_size,
                    )

            # Record argmax for next step's thinking-state tracking.
            self._prev_output_token_id = logits[0].argmax().item()

        return logits

    # ------------------------------------------------------------------
    # Weight loading
    # ------------------------------------------------------------------

    # Mapping from HF checkpoint weight names to vLLM fused names.
    # Mirrors the stacked_params_mapping in Qwen2Model.load_weights().
    _extra_layer_stacked_mapping = [
        # (vllm_param, hf_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
        ("gate_up_proj", "gate_proj", 0),
        ("gate_up_proj", "up_proj", 1),
    ]

    def _load_extra_layers_weights(self, loaded: set[str]) -> None:
        """Load extra transformer layer weights from the checkpoint.

        The checkpoint stores HF-format keys
        (``layers.0.self_attn.q_proj.weight``) which must be mapped to
        vLLM's fused names (``terminator_extra_layers.0.self_attn.qkv_proj``).
        This mirrors the ``stacked_params_mapping`` approach used by
        ``Qwen2Model.load_weights()``.
        """
        if self._extra_layers_state_dict is None:
            return

        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for ckpt_name, tensor in self._extra_layers_state_dict.items():
            # Remap checkpoint prefix to model module path.
            name = ckpt_name.replace("layers.", "terminator_extra_layers.", 1)

            if "rotary_emb.inv_freq" in name:
                continue

            # Check stacked (fused) projection mapping.
            for param_name, weight_name, shard_id in self._extra_layer_stacked_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                if weight_loader == default_weight_loader:
                    weight_loader(param, tensor)
                else:
                    weight_loader(param, tensor, shard_id)
                loaded.add(name)
                break
            else:
                # Direct (non-fused) parameter — norms, o_proj, down_proj.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name not in params_dict:
                    logger.warning(
                        "Skipping extra-layer weight %s (no matching param)",
                        name,
                    )
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", default_weight_loader
                )
                weight_loader(param, tensor)
                loaded.add(name)

        del self._extra_layers_state_dict
        self._extra_layers_state_dict = None
        logger.info(
            "Terminator extra layers weights loaded from checkpoint",
        )

    def load_weights(
        self, weights: Iterable[tuple[str, torch.Tensor]]
    ) -> set[str]:
        skip = ["terminator_ffn.", "terminator_extra_layers."]
        if self.config.tie_word_embeddings:
            skip.append("lm_head.")

        loader = AutoWeightsLoader(self, skip_prefixes=skip)
        loaded = loader.load_weights(weights)

        # Load terminator FFN and extra layers from the separate .pt
        # checkpoint (not from the HF safetensors).
        if self._terminator_enabled:
            self.terminator_ffn.load_state_dict(self._terminator_state_dict)
            del self._terminator_state_dict  # free memory
            logger.info(
                "Terminator FFN weights loaded from %s",
                self._terminator_checkpoint_path,
            )
            # Tell vLLM these weights have been handled so it doesn't
            # complain about uninitialized parameters.
            for name in self.terminator_ffn.state_dict():
                loaded.add(f"terminator_ffn.{name}")

            self._load_extra_layers_weights(loaded)

        return loaded
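The firing rule in `compute_logits()` depends only on the last `window_size` sigmoid predictions, so it can be sketched in isolation. The snippet below is a hypothetical standalone reimplementation (`should_fire` is not part of this diff) using the same deque-based sliding window and the default threshold 0.7 and window size 10 from the config handling above:

```python
from collections import deque

def should_fire(preds, threshold=0.7, window_size=10):
    """Return True once a majority of the last `window_size`
    predictions exceed `threshold` (the terminator's firing rule)."""
    history = deque(maxlen=window_size)
    for p in preds:
        history.append(p)
        if len(history) >= window_size:
            n_above = sum(1 for x in history if x > threshold)
            if n_above / window_size > 0.5:
                return True
    return False

# Window full, but 0/10 predictions above the threshold: no fire.
print(should_fire([0.1] * 10))             # False
# Last full window holds 6/10 values above 0.7, vote 0.6 > 0.5: fires.
print(should_fire([0.1] * 4 + [0.9] * 6))  # True
```

Note that the vote requires a strict majority of the window, not a single high prediction, which makes the truncation decision robust to one-step spikes in the FFN output.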
vllm_terminator/terminator_head.py
ADDED
@@ -0,0 +1,135 @@
"""
Terminator FFN head for vLLM integration.

Mirrors LayerFFN from terminator_utils.py but uses a standalone RMSNorm
compatible with checkpoints trained using HuggingFace's Qwen2RMSNorm.
"""

from pathlib import Path
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn


class SimpleRMSNorm(nn.Module):
    """RMSNorm with a `weight` parameter matching Qwen2RMSNorm's state_dict."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


class TerminatorFFN(nn.Module):
    """
    Feed-forward network for per-position binary classification.

    Architecture mirrors LayerFFN from terminator_utils.py:
    - Pre-normalization with RMSNorm
    - Linear projection(s) to scalar logit
    - No sigmoid (outputs raw logits)

    State dict keys match training checkpoints exactly:
    - norm.weight
    - network.weight, network.bias (1-layer case)
    - network.0.weight, network.0.bias, ... (multi-layer case)
    """

    def __init__(
        self,
        hidden_size: int,
        num_hidden_layers: int = 1,
        activation: str = 'gelu',
        intermediate_size: Optional[int] = None,
        dropout: float = 0.0,
        rms_norm_eps: float = 1e-6,
    ):
        super().__init__()

        self.hidden_size = hidden_size
        self.norm = SimpleRMSNorm(hidden_size, eps=rms_norm_eps)

        if num_hidden_layers == 1:
            self.network = nn.Linear(hidden_size, 1)
        else:
            if intermediate_size is None:
                intermediate_size = hidden_size * 2

            layers = []
            layers.append(nn.Linear(hidden_size, intermediate_size))

            act_fn = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh}
            if activation not in act_fn:
                raise ValueError(f"Unknown activation: {activation}")

            layers.append(act_fn[activation]())
            layers.append(nn.Dropout(dropout))

            for _ in range(num_hidden_layers - 2):
                layers.append(nn.Linear(intermediate_size, intermediate_size))
                layers.append(act_fn[activation]())
                layers.append(nn.Dropout(dropout))

            layers.append(nn.Linear(intermediate_size, 1))
            self.network = nn.Sequential(*layers)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden_states: [num_tokens, hidden_size] or [batch, seq_len, hidden_size]

        Returns:
            logits: [num_tokens] or [batch, seq_len] raw logits
        """
        hidden_states = self.norm(hidden_states)
        output = self.network(hidden_states)
        return output.squeeze(-1)


def load_terminator_checkpoint(
    checkpoint_path: str,
    rms_norm_eps: float = 1e-6,
    device: torch.device = torch.device("cpu"),
) -> Tuple[TerminatorFFN, Dict[str, Any], int, int, Optional[Dict[str, Any]]]:
    """
    Load a trained terminator checkpoint and construct the FFN.

    Args:
        checkpoint_path: Path to layer_*.pt checkpoint file
        rms_norm_eps: Epsilon for RMSNorm (from base model config)
        device: Device to load onto

    Returns:
        (ffn, config, layer_idx, num_extra_layers, extra_layers_state_dict)
    """
    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

    config = checkpoint["config"]
    layer_idx = checkpoint["layer_idx"]
    hidden_size = config["hidden_size"]

    ffn = TerminatorFFN(
        hidden_size=hidden_size,
        num_hidden_layers=config.get("ffn_layers", 1),
        activation=config.get("ffn_activation", "gelu"),
        intermediate_size=config.get("ffn_intermediate_size", None),
        dropout=0.0,  # No dropout at inference
        rms_norm_eps=rms_norm_eps,
    )

    ffn.load_state_dict(checkpoint["state_dict"])
    ffn.to(device)
    ffn.eval()

    num_extra_layers = config.get("num_extra_layers", 0)
    extra_layers_state_dict = checkpoint.get("extra_layers_state_dict", None)

    return ffn, config, layer_idx, num_extra_layers, extra_layers_state_dict
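For intuition, the normalization SimpleRMSNorm applies before the classifier can be reproduced on a single vector in plain Python (a minimal sketch under the same formula; `rms_norm` is a hypothetical helper name, not part of this diff):

```python
import math

def rms_norm(vec, weight=None, eps=1e-6):
    """Plain-Python equivalent of SimpleRMSNorm.forward for one vector:
    scale by the reciprocal root-mean-square, then by the learned weight."""
    if weight is None:
        weight = [1.0] * len(vec)  # a fresh norm initializes weight to ones
    variance = sum(x * x for x in vec) / len(vec)
    inv_rms = 1.0 / math.sqrt(variance + eps)
    return [w * x * inv_rms for w, x in zip(weight, vec)]

out = rms_norm([3.0, 4.0])
# RMS of [3, 4] is sqrt(12.5) ~ 3.536, so the result is ~[0.849, 1.131].
print([round(v, 3) for v in out])
```

After this rescaling, every input vector has unit root-mean-square regardless of the hidden-state magnitude, which is why the FFN can consume pre-norm hidden states directly.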