Commit 9907df0 (verified) · Parent(s): 0061d01 · committed by acnagle

Upload folder using huggingface_hub
README.md CHANGED
@@ -1,3 +1,276 @@
- ---
- license: apache-2.0
- ---
+ ---
+ language:
+ - en
+ license: apache-2.0
+ library_name: vllm
+ tags:
+ - reasoning
+ - chain-of-thought
+ - efficiency
+ - inference-optimization
+ - qwen3
+ base_model: Qwen/Qwen3-14B
+ base_model_relation: finetune
+ pipeline_tag: text-generation
+ ---
+
+ # Terminator-Qwen3-14B
+
+ **Terminator** is a lightweight neural module that predicts when a reasoning language model has reached its final answer during chain-of-thought (CoT) generation. When the Terminator detects the model has committed to an answer, it truncates the remaining reasoning and forces the model to begin its response, thereby delivering the same answer with significantly less computation.
+
+ This repository contains everything needed to run **Terminator-Qwen3-14B**:
+
+ - Trained Terminator checkpoint (1 extra transformer layer + prediction head)
+ - vLLM plugin code (`vllm_terminator/`) for high-performance serving
+ - Server launcher and streaming client
+ - Standalone HuggingFace inference script (no server required)
+ - Automated setup script
+
+ **Note**: Terminator currently supports **single-GPU, single-sequence inference only**.
+
+ ---
+
+ ## Quick Start
+
+ ```bash
+ # 1. Clone the repository (requires Git LFS: https://git-lfs.com)
+ git lfs install
+ git clone https://huggingface.co/acnagle/Terminator-Qwen3-14B
+ cd Terminator-Qwen3-14B
+
+ # 2. Run automated setup (creates conda env, installs vllm, downloads base model)
+ ./setup.sh
+
+ # 3. Start the server
+ ./start_server.sh
+
+ # 4. In another terminal, chat with the model
+ python client.py --interactive
+ ```
+
+ ---
+
+ ## Requirements
+
+ - **GPU**: Single NVIDIA GPU with at least ~40GB VRAM (e.g., A100 40GB)
+ - **CUDA**: Compatible CUDA driver (version 12.9 or newer recommended)
+ - **Python**: 3.12
+ - **OS**: Linux (recommended) or any OS supported by vLLM
+
+ ---
+
+ ## Installation
+
+ ### Option A: Automated Setup
+
+ The `setup.sh` script handles everything:
+
+ ```bash
+ ./setup.sh
+ ```
+
+ This will:
+ 1. Create a conda environment called `terminator` with Python 3.12
+ 2. Install [uv](https://docs.astral.sh/uv/), [vLLM](https://docs.vllm.ai/), and [openai](https://pypi.org/project/openai/)
+ 3. Download Qwen3-14B base model weights (~28GB) from HuggingFace
+ 4. Create the model directory (`model_dir/`)
+
+ ### Option B: Manual Setup
+
+ **1. Create a Python environment**
+
+ Using conda or micromamba:
+
+ ```bash
+ conda create -n terminator python=3.12 -y
+ conda activate terminator
+ ```
+
+ **2. Install uv**
+
+ ```bash
+ pip install --upgrade uv
+ ```
+
+ Or see the [uv installation guide](https://docs.astral.sh/uv/getting-started/installation/).
+
+ **3. Install vLLM**
+
+ ```bash
+ uv pip install vllm --torch-backend=auto
+ ```
+
+ See the [vLLM installation guide](https://docs.vllm.ai/en/latest/getting_started/installation/) for alternative installation methods (ROCm, CPU, etc.).
+
+ **4. Install openai (for the client)**
+
+ ```bash
+ uv pip install openai
+ ```
+
+ **5. Set up the model directory**
+
+ This downloads the base Qwen3-14B weights and creates a vLLM-ready model directory:
+
+ ```bash
+ python setup_model_dir.py
+ ```
+
+ The script accepts optional arguments:
+
+ | Argument | Default | Description |
+ |----------|---------|-------------|
+ | `--checkpoint` | `./terminator.pt` | Path to the Terminator checkpoint |
+ | `--output-dir` | `./model_dir` | Output model directory |
+ | `--threshold` | `0.7` | Prediction threshold for Terminator activation |
+ | `--window-size` | `10` | Sliding window size for majority vote |
+ | `--exit-message` | *(built-in message)* | Message injected when Terminator fires |
+
+ ---
+
+ ## Starting the Server
+
+ ```bash
+ ./start_server.sh
+ ```
+
+ Or with custom configuration:
+
+ ```bash
+ VLLM_GPU_UTIL=0.70 VLLM_MAX_MODEL_LEN=8192 ./start_server.sh
+ ```
+
+ The server exposes an **OpenAI-compatible API** on the configured port (default: 8000).
+
+ ### Configuration
+
+ Set these environment variables before running `start_server.sh` or `serve.py`:
+
+ | Variable | Default | Description |
+ |----------|---------|-------------|
+ | `VLLM_GPU_UTIL` | `0.90` | Fraction of GPU memory to use for the model |
+ | `VLLM_MAX_MODEL_LEN` | *(auto)* | Maximum context length in tokens |
+ | `VLLM_PORT` | `8000` | Server port |
+ | `VLLM_ENFORCE_EAGER` | `0` | Set to `1` to disable CUDA graphs |
+ | `VLLM_API_KEY` | *(none)* | Require this API key from clients |
+ | `VLLM_SERVED_NAME` | `Terminator-Qwen3-14B` | Model name reported by the API |
+
+ ---
+
+ ## Standalone Inference (No Server)
+
+ **Recommendation:** For the best performance, use the vLLM server described above. vLLM uses KV caching, CUDA graphs, and optimized kernels, making it **significantly faster** than HuggingFace-native inference. The script below is provided for quick testing and demos where spinning up a server is inconvenient.
+
+ To run without a vLLM server, use the HuggingFace-native inference script:
+
+ ```bash
+ python inference_hf.py --prompt "What is the sum of the first 100 natural numbers?"
+ ```
+
+ This loads the model directly via HuggingFace `transformers` and runs token-by-token generation with the Terminator head. Thinking content is streamed in dimmed text; the final answer is shown in bold.
+
+ | Argument | Default | Description |
+ |----------|---------|-------------|
+ | `--prompt` | *(required)* | Input prompt |
+ | `--model` | `Qwen/Qwen3-14B` | HuggingFace model name or path |
+ | `--checkpoint` | `./terminator.pt` | Path to the Terminator checkpoint |
+ | `--threshold` | `0.7` | Prediction threshold |
+ | `--window-size` | `10` | Sliding window size for majority vote |
+ | `--exit-message` | *(built-in message)* | Message injected when Terminator fires (empty string to disable) |
+ | `--max-tokens` | `32768` | Maximum tokens to generate |
+ | `--temperature` | `0.6` | Sampling temperature |
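The script samples with temperature plus top-k (20) and top-p (0.95) filtering from `transformers`. As a rough illustration of what the top-p (nucleus) rule does to a distribution, here is a plain-Python sketch; it is an editorial example, not the library implementation:

```python
def top_p_filter(probs, p=0.95):
    """Keep the smallest set of tokens whose cumulative probability reaches p."""
    ranked = sorted(enumerate(probs), key=lambda kv: kv[1], reverse=True)
    kept, total = [], 0.0
    for idx, prob in ranked:
        kept.append(idx)
        total += prob
        if total >= p:
            break
    # Renormalize the surviving tokens into a proper distribution
    mass = sum(probs[i] for i in kept)
    return {i: probs[i] / mass for i in kept}

# Example: a peaked 4-token distribution with p=0.9
dist = top_p_filter([0.5, 0.3, 0.15, 0.05], p=0.9)
print(sorted(dist))  # the lowest-probability token is pruned
```

Lowering `p` prunes more of the tail, making generation more deterministic; the script's default of 0.95 keeps most of the probability mass.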
+
+ ---
+
+ ## Using the Client (vLLM Server)
+
+ ### Single Prompt
+
+ ```bash
+ python client.py --prompt "What is the sum of the first 100 natural numbers?"
+ ```
+
+ ### Interactive Mode
+
+ ```bash
+ python client.py --interactive
+ ```
+
+ This starts a multi-turn conversation with the model. Thinking content is displayed in dimmed text; the final answer is shown in bold.
+
+ ### Client Options
+
+ | Argument | Default | Description |
+ |----------|---------|-------------|
+ | `--base-url` | `http://localhost:8000/v1` | Server URL |
+ | `--max-tokens` | *(server default)* | Maximum tokens to generate |
+ | `--temperature` | `0.6` | Sampling temperature |
+
+ ### Using the API Directly
+
+ The server is OpenAI-compatible. You can use any OpenAI client library. Replace `localhost` with your server's address if connecting remotely:
+
+ ```python
+ from openai import OpenAI
+
+ client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
+
+ response = client.chat.completions.create(
+     model="Terminator-Qwen3-14B",
+     messages=[{"role": "user", "content": "What is 25 * 37?"}],
+     temperature=0.6,
+     extra_body={"chat_template_kwargs": {"enable_thinking": True}},
+ )
+
+ # Thinking content (chain-of-thought)
+ print(response.choices[0].message.reasoning_content)
+
+ # Final answer
+ print(response.choices[0].message.content)
+ ```
+
+ ---
+
+ ## How Terminator Works
+
+ Terminator is a single transformer layer followed by a prediction head, trained on top of a frozen Qwen3-14B base model. The transformer layer (initialized as a copy of the base model's final layer, then fine-tuned) takes the hidden states from the LLM and processes them before the prediction head, which outputs a per-token binary prediction: *has the model reached its final answer?*
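Schematically, the prediction head reduces each processed hidden state to a single logit, and a sigmoid turns that logit into a stop probability. The toy below is a dependency-free sketch of that reduction with made-up weights; the real head is the trained FFN in `terminator_head.py`:

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def predict_stop(hidden, weights, bias=0.0):
    """Per-token binary prediction: P(model has reached its final answer).

    `hidden` is a toy stand-in for the processed hidden state, and
    `weights`/`bias` are hypothetical, not the trained parameters.
    """
    logit = sum(h * w for h, w in zip(hidden, weights)) + bias
    return sigmoid(logit)

# Toy 4-dim hidden state and weight vector
p = predict_stop([0.2, -0.1, 0.4, 0.3], [1.0, 0.5, -0.2, 0.8])
print(round(p, 3))
```

A probability above the activation threshold (default 0.7) counts as one "stop" vote for the sliding-window logic described next.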
+
+ During generation, Terminator maintains a **sliding window** of the most recent predictions. When a majority of the predictions in the window exceed the threshold (default: 0.7), the model is considered to have reached its final answer. At that point:
+
+ 1. A short **exit message** is injected into the reasoning (e.g., *"I've run out of thinking tokens. I need to commit to a final answer."*) to help the model transition smoothly.
+ 2. The `</think>` token is forced, ending the reasoning phase.
+ 3. The model generates its final answer normally.
+
+ This allows the model to skip potentially thousands of redundant reasoning tokens while preserving answer quality.
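The stopping rule above can be simulated without the model: given a stream of per-token stop probabilities, the controller fires as soon as more than half of the last `window_size` predictions exceed the threshold. A minimal sketch of that logic, using the defaults from the tables above:

```python
from collections import deque

def find_exit_step(predictions, threshold=0.7, window_size=10):
    """Return the step at which Terminator would fire, or None if it never does."""
    window = deque(maxlen=window_size)
    for step, p in enumerate(predictions):
        window.append(p)
        if len(window) == window_size:
            above = sum(1 for q in window if q > threshold)
            if above / window_size > 0.5:
                return step  # majority of the window is confident: stop here
    return None  # reasoning ended naturally

# 30 uncertain predictions followed by confident ones:
# the window needs 6 of 10 confident votes, so it fires a few tokens
# after the predictor becomes confident, not on the first spike.
preds = [0.1] * 30 + [0.95] * 10
print(find_exit_step(preds))
```

Raising `--threshold` or `--window-size` makes termination more conservative (fewer but later exits); lowering them trades some answer-quality safety margin for larger token savings.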
+
+ ---
+
+ ## File Structure
+
+ ```
+ Terminator-Qwen3-14B/
+ ├── README.md              This file
+ ├── terminator.pt          Trained Terminator checkpoint
+ ├── vllm_terminator/       vLLM plugin package
+ │   ├── __init__.py        Registers the model architecture with vLLM
+ │   ├── model.py           Qwen3TerminatorForCausalLM model class
+ │   └── terminator_head.py FFN classifier and checkpoint loading
+ ├── inference_hf.py        Standalone HuggingFace inference (no server)
+ ├── serve.py               vLLM server launcher
+ ├── setup_model_dir.py     Model directory setup (downloads base weights)
+ ├── client.py              Streaming chat client (connects to vLLM server)
+ ├── setup.sh               Automated setup script
+ └── start_server.sh        Server launcher with sensible defaults
+ ```
+
+ ---
+
+ ## Citation
+
+ *Coming soon.*
+
+ ---
+
+ ## License
+
+ This project builds on [Qwen3-14B](https://huggingface.co/Qwen/Qwen3-14B) by the Qwen team. Please refer to the Qwen3 license for base model usage terms.
client.py ADDED
@@ -0,0 +1,203 @@
+ #!/usr/bin/env python3
+ """
+ Client for the Terminator vLLM server.
+
+ Supports single-prompt and multi-turn conversation modes with streaming
+ output. Thinking content is displayed in dimmed text; answer content in
+ normal text.
+
+ Usage:
+     # Single prompt
+     python client.py --prompt "What is the sum of the first 100 natural numbers?"
+
+     # Interactive multi-turn conversation
+     python client.py --interactive
+
+     # Custom server URL and max tokens
+     python client.py --base-url http://localhost:8001/v1 --max-tokens 8192 --prompt "Hello"
+ """
+
+ import argparse
+ import sys
+
+ from openai import OpenAI
+
+
+ # ANSI escape codes
+ DIM = "\033[2m"
+ BOLD = "\033[1m"
+ RESET = "\033[0m"
+
+ BANNER_LINES = [
+     r"████████╗███████╗██████╗ ███╗ ███╗██╗███╗ ██╗ █████╗ ████████╗ ██████╗ ██████╗ ",
+     r"╚══██╔══╝██╔════╝██╔══██╗████╗ ████║██║████╗ ██║██╔══██╗╚══██╔══╝██╔═══██╗██╔══██╗",
+     r" ██║ █████╗ ██████╔╝██╔████╔██║██║██╔██╗ ██║███████║ ██║ ██║ ██║██████╔╝",
+     r" ██║ ██╔══╝ ██╔══██╗██║╚██╔╝██║██║██║╚██╗██║██╔══██║ ██║ ██║ ██║██╔══██╗",
+     r" ██║ ███████╗██║ ██║██║ ╚═╝ ██║██║██║ ╚████║██║ ██║ ██║ ╚██████╔╝██║ ██║",
+     r" ╚═╝ ╚══════╝╚═╝ ╚═╝╚═╝ ╚═╝╚═╝╚═╝ ╚═══╝╚═╝ ╚═╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝",
+ ]
+
+ # Dark red -> light red gradient (one color per row)
+ _GRADIENT_RGB = [
+     (140, 0, 0),
+     (165, 15, 15),
+     (190, 35, 35),
+     (215, 55, 55),
+     (235, 70, 70),
+     (255, 90, 90),
+ ]
+
+
+ def print_banner() -> None:
+     for line, (r, g, b) in zip(BANNER_LINES, _GRADIENT_RGB):
+         print(f"\033[38;2;{r};{g};{b}m{line}{RESET}")
+
+
+ def detect_model(client: OpenAI) -> str:
+     """Auto-detect the served model name from the server."""
+     try:
+         models = client.models.list()
+         if not models.data:
+             print("ERROR: No models available on the server.", file=sys.stderr)
+             sys.exit(1)
+         return models.data[0].id
+     except Exception as e:
+         print(f"ERROR: Could not connect to server: {e}", file=sys.stderr)
+         sys.exit(1)
+
+
+ def stream_response(
+     client: OpenAI,
+     model: str,
+     messages: list[dict],
+     max_tokens: int | None,
+     temperature: float,
+ ) -> str:
+     """Stream a chat completion response.
+
+     Thinking content is printed in dim text, answer content in normal text.
+     Returns the assistant's answer content (for conversation history).
+     """
+     kwargs = dict(
+         model=model,
+         messages=messages,
+         temperature=temperature,
+         stream=True,
+         extra_body={"chat_template_kwargs": {"enable_thinking": True}},
+     )
+     if max_tokens is not None:
+         kwargs["max_tokens"] = max_tokens
+
+     stream = client.chat.completions.create(**kwargs)
+
+     in_thinking = False
+     in_answer = False
+     full_content = ""
+
+     try:
+         for chunk in stream:
+             if not chunk.choices:
+                 continue
+             delta = chunk.choices[0].delta
+
+             reasoning = getattr(delta, "reasoning_content", None)
+             if reasoning:
+                 if not in_thinking:
+                     sys.stdout.write(f"\n{DIM}Thinking...\n")
+                     in_thinking = True
+                 sys.stdout.write(reasoning)
+                 sys.stdout.flush()
+
+             if delta.content:
+                 if not in_answer:
+                     if in_thinking:
+                         sys.stdout.write(RESET)
+                     sys.stdout.write(f"\n{BOLD}Answer:{RESET}\n")
+                     in_answer = True
+                 sys.stdout.write(delta.content)
+                 sys.stdout.flush()
+                 full_content += delta.content
+     except KeyboardInterrupt:
+         pass
+     finally:
+         sys.stdout.write(RESET)
+         sys.stdout.flush()
+
+     print()
+     return full_content
+
+
+ def run_single(client, model, prompt, max_tokens, temperature):
+     """Run a single prompt and exit."""
+     messages = [{"role": "user", "content": prompt}]
+     stream_response(client, model, messages, max_tokens, temperature)
+
+
+ def run_interactive(client, model, max_tokens, temperature):
+     """Interactive multi-turn conversation loop."""
+     messages = []
+     print()
+     print_banner()
+     print()
+     print(f"  Connected to {BOLD}{model}{RESET}")
+     print(f"  Type your message and press Enter. Type {BOLD}quit{RESET} or Ctrl+C to exit.")
+     print(f"  {DIM}Note: input is single-line only — compose your full message before pressing Enter.{RESET}")
+     print(f"  {DIM}      Please ensure that copied text is formatted as a single line before pasting.{RESET}")
+     print()
+
+     while True:
+         try:
+             user_input = input(f"{BOLD}>>>{RESET} ")
+         except (KeyboardInterrupt, EOFError):
+             print("\nGoodbye!")
+             break
+
+         if user_input.strip().lower() in ("quit", "exit", "q"):
+             print("Goodbye!")
+             break
+
+         if not user_input.strip():
+             continue
+
+         messages.append({"role": "user", "content": user_input})
+         content = stream_response(client, model, messages, max_tokens, temperature)
+         messages.append({"role": "assistant", "content": content})
+         print()
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description=__doc__,
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+     )
+     mode = parser.add_mutually_exclusive_group(required=True)
+     mode.add_argument("--prompt", type=str, help="Single prompt to send")
+     mode.add_argument(
+         "--interactive", action="store_true",
+         help="Start an interactive multi-turn conversation",
+     )
+     parser.add_argument(
+         "--base-url", default="http://localhost:8000/v1",
+         help="vLLM server URL (default: http://localhost:8000/v1)",
+     )
+     parser.add_argument(
+         "--max-tokens", type=int, default=None,
+         help="Maximum tokens to generate (default: server decides based on context length)",
+     )
+     parser.add_argument(
+         "--temperature", type=float, default=0.6,
+         help="Sampling temperature (default: 0.6)",
+     )
+     args = parser.parse_args()
+
+     client = OpenAI(base_url=args.base_url, api_key="EMPTY")
+     model = detect_model(client)
+
+     if args.prompt:
+         run_single(client, model, args.prompt, args.max_tokens, args.temperature)
+     else:
+         run_interactive(client, model, args.max_tokens, args.temperature)
+
+
+ if __name__ == "__main__":
+     main()
inference_hf.py ADDED
@@ -0,0 +1,383 @@
+ #!/usr/bin/env python3
+ """
+ HuggingFace-native inference for Terminator-Qwen3-14B.
+
+ Loads the frozen Qwen3 base model + trained Terminator head (FFN + optional
+ extra transformer layers) directly via HuggingFace transformers.
+
+ Generates chain-of-thought reasoning token-by-token. The Terminator FFN
+ predicts when the final answer has been reached; when a sliding-window
+ majority vote exceeds the threshold, an exit message is injected and the
+ model transitions to answering mode.
+
+ Usage:
+     python inference_hf.py --prompt "What is the sum of the first 100 natural numbers?"
+
+     python inference_hf.py \\
+         --prompt "Solve x^2 - 5x + 6 = 0" \\
+         --model Qwen/Qwen3-14B \\
+         --checkpoint terminator.pt \\
+         --threshold 0.7 --window-size 10
+ """
+
+ import argparse
+ import os
+ import sys
+ from pathlib import Path
+
+ import torch
+ import torch.nn.functional as F
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from transformers import TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
+ from transformers.generation.logits_process import LogitsProcessorList
+
+ # ---------------------------------------------------------------------------
+ # Imports from the project
+ # ---------------------------------------------------------------------------
+
+ # Local: TerminatorFFN + checkpoint loader
+ _script_dir = Path(__file__).resolve().parent
+ sys.path.insert(0, str(_script_dir))
+ from vllm_terminator.terminator_head import load_terminator_checkpoint
+
+ # Parent dir: ExtraTransformerLayers from terminator_utils
+ _repo_root = _script_dir.parent
+ sys.path.insert(0, str(_repo_root))
+ from terminator_utils import ExtraTransformerLayers
+
+ # ---------------------------------------------------------------------------
+ # ANSI escape codes
+ # ---------------------------------------------------------------------------
+ DIM = "\033[2m"
+ BOLD = "\033[1m"
+ RESET = "\033[0m"
+
+
+ def load_model_and_tokenizer(model_name, device):
+     """Load base Qwen3 model and tokenizer."""
+     print(f"Loading tokenizer: {model_name}")
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     if tokenizer.pad_token is None:
+         tokenizer.pad_token = tokenizer.eos_token
+
+     think_token_id = tokenizer.convert_tokens_to_ids("<think>")
+     think_end_token_id = tokenizer.convert_tokens_to_ids("</think>")
+     if think_token_id == tokenizer.unk_token_id or think_end_token_id == tokenizer.unk_token_id:
+         raise ValueError(
+             f"<think>/</think> tokens not in tokenizer! "
+             f"IDs: {think_token_id}, {think_end_token_id}"
+         )
+
+     print(f"Loading model: {model_name}")
+     model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.bfloat16,
+         device_map={"": device},
+         trust_remote_code=True,
+     )
+     for param in model.parameters():
+         param.requires_grad = False
+     model.eval()
+
+     print(
+         f"Model loaded: {model.config.num_hidden_layers} layers, "
+         f"hidden size {model.config.hidden_size}"
+     )
+     return model, tokenizer, think_token_id, think_end_token_id
+
+
+ def build_extra_layers(base_model, checkpoint_config, extra_layers_state_dict, device):
+     """Reconstruct extra transformer layers from checkpoint state dict."""
+     num_extra_layers = checkpoint_config.get("num_extra_layers", 0)
+     if num_extra_layers == 0 or extra_layers_state_dict is None:
+         return None
+
+     print(f"Reconstructing {num_extra_layers} extra transformer layer(s)...")
+     base_layer_class = base_model.model.layers[0].__class__
+     model_config = base_model.config
+     rotary_emb = getattr(base_model.model, "rotary_emb", None)
+
+     extra_layers = ExtraTransformerLayers(
+         base_layer_class, num_extra_layers, model_config, rotary_emb=rotary_emb
+     ).to(device)
+     extra_layers.load_state_dict(extra_layers_state_dict)
+     extra_layers.eval()
+
+     param_count = sum(p.numel() for p in extra_layers.parameters())
+     print(f"Extra layers loaded ({param_count:,} parameters)")
+     return extra_layers
+
+
+ def generate_with_terminator(
+     prompt,
+     model,
+     tokenizer,
+     ffn,
+     extra_layers,
+     layer_idx,
+     think_token_id,
+     think_end_token_id,
+     threshold,
+     window_size,
+     exit_message,
+     max_tokens,
+     temperature,
+     device,
+ ):
+     """Generate a response with Terminator early-exit logic.
+
+     Follows the same generation pattern as inference_terminator.py:mode1_generate().
+     Streams thinking tokens to the terminal as they are produced.
+     """
+     # Format prompt via chat template
+     messages = [{"role": "user", "content": prompt}]
+     prompt_text = tokenizer.apply_chat_template(
+         messages, tokenize=False, add_generation_prompt=True
+     )
+
+     # Tokenize and append <think>
+     prompt_ids = tokenizer(
+         prompt_text, add_special_tokens=False, return_tensors="pt"
+     )["input_ids"].to(device).long()
+
+     input_ids = torch.cat(
+         [prompt_ids, torch.tensor([[think_token_id]], dtype=torch.long, device=device)],
+         dim=1,
+     )
+
+     # Sampling processors
+     logits_processor = LogitsProcessorList([
+         TemperatureLogitsWarper(temperature=temperature),
+         TopKLogitsWarper(top_k=20),
+         TopPLogitsWarper(top_p=0.95),
+     ])
+
+     # Sliding-window state
+     predictions_list = []
+     reasoning_tokens = []
+     early_exit = False
+
+     # Start streaming thinking output
+     sys.stdout.write(f"\n{DIM}Thinking...\n")
+     sys.stdout.flush()
+
+     for step in range(max_tokens):
+         attention_mask = torch.ones_like(input_ids)
+
+         # Hook to capture hidden states from the target layer
+         captured = {}
+
+         def hook_fn(module, input, output):
+             if isinstance(output, tuple):
+                 captured["hidden"] = output[0].detach()
+             else:
+                 captured["hidden"] = output.detach()
+
+         target_layer = model.model.layers[layer_idx]
+         handle = target_layer.register_forward_hook(hook_fn)
+
+         with torch.no_grad():
+             outputs = model(
+                 input_ids=input_ids,
+                 attention_mask=attention_mask,
+                 use_cache=False,
+             )
+
+         handle.remove()
+
+         hidden_states = captured["hidden"]  # [1, seq_len, hidden_size]
+
+         # Make prediction once we have at least one thinking token
+         if len(reasoning_tokens) > 0:
+             if extra_layers is not None:
+                 h = hidden_states.float()
+                 h = extra_layers(h, attention_mask=attention_mask)
+                 last_h = h[:, -1:, :]
+                 logits_pred = ffn(last_h.float())
+             else:
+                 last_h = hidden_states[:, -1:, :]
+                 logits_pred = ffn(last_h.float())
+
+             pred = torch.sigmoid(logits_pred)
+             predictions_list.append(pred[0, 0].item())
+
+             # Sliding-window majority vote
+             if len(predictions_list) >= window_size:
+                 window = predictions_list[-window_size:]
+                 n_above = sum(1 for p in window if p > threshold)
+                 if n_above / window_size > 0.5:
+                     early_exit = True
+                     break
+
+         # Sample next token — LogitsProcessorList expects 2D [batch, vocab]
+         next_logits = outputs.logits[:, -1, :]  # [1, vocab_size]
+         next_logits = logits_processor(input_ids, next_logits)
+         probs = F.softmax(next_logits, dim=-1)
+         next_token = torch.multinomial(probs, num_samples=1)  # [1, 1]
+
+         # Natural </think>
+         if next_token.item() == think_end_token_id:
+             break
+
+         input_ids = torch.cat([input_ids, next_token], dim=1)
+         reasoning_tokens.append(next_token.item())
+
+         # Stream the token
+         token_text = tokenizer.decode([next_token.item()], skip_special_tokens=False)
+         sys.stdout.write(token_text)
+         sys.stdout.flush()
+
+     # End thinking section
+     if early_exit and exit_message:
+         sys.stdout.write(exit_message)
+     sys.stdout.write(f"{RESET}\n")
+     sys.stdout.flush()
+
+     # Build input for final answer generation
+     if early_exit and exit_message:
+         exit_ids = tokenizer(
+             exit_message, add_special_tokens=False, return_tensors="pt"
+         )["input_ids"].to(device).long()
+         input_ids = torch.cat(
+             [input_ids, exit_ids,
+              torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
+             dim=1,
+         )
+     else:
+         input_ids = torch.cat(
+             [input_ids,
+              torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
+             dim=1,
+         )
+
+     # Generate final answer
+     attention_mask = torch.ones_like(input_ids)
+     with torch.no_grad():
+         final_outputs = model.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=max_tokens,
+             do_sample=True,
+             temperature=temperature,
+             top_p=0.95,
+             top_k=20,
+             pad_token_id=tokenizer.pad_token_id,
+             eos_token_id=tokenizer.eos_token_id,
+         )
+
+     # Extract answer (everything after last </think>)
+     full_seq = final_outputs[0]
+     end_positions = (full_seq == think_end_token_id).nonzero(as_tuple=True)[0]
+     if len(end_positions) > 0:
+         answer_tokens = full_seq[end_positions[-1].item() + 1 :]
+         answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
+     else:
+         answer = ""
+
+     # Print answer
+     sys.stdout.write(f"{BOLD}Answer:{RESET}\n{answer}\n")
+     sys.stdout.flush()
+
+     # Summary
+     n_reasoning = len(reasoning_tokens)
+     exit_reason = "predictor" if early_exit else "natural_end"
+     print(
+         f"\n{DIM}[{exit_reason} | "
+         f"{n_reasoning} thinking tokens | "
+         f"{len(predictions_list)} predictions]{RESET}"
+     )
+
+
+ def main():
+     parser = argparse.ArgumentParser(
+         description=__doc__,
+         formatter_class=argparse.RawDescriptionHelpFormatter,
+     )
+     parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
+     parser.add_argument(
+         "--model", type=str, default="Qwen/Qwen3-14B", help="HuggingFace model name"
+     )
+     parser.add_argument(
+         "--checkpoint",
+         type=str,
+         default=None,
+         help="Path to terminator .pt checkpoint (default: ./terminator.pt)",
+     )
+     parser.add_argument(
+         "--threshold", type=float, default=0.7, help="Per-prediction binarization threshold"
+     )
+     parser.add_argument(
+         "--window-size", type=int, default=10, help="Sliding-window size for majority vote"
+     )
+     parser.add_argument(
+         "--exit-message",
+         type=str,
+         default="\nI've run out of thinking tokens. I need to commit to a final answer.",
+         help="Message injected when terminator fires (empty string to disable)",
+     )
+     parser.add_argument(
+         "--max-tokens", type=int, default=32768, help="Max tokens to generate"
+     )
+     parser.add_argument(
+         "--temperature", type=float, default=0.6, help="Sampling temperature"
+     )
+     parser.add_argument(
+         "--device", type=str, default="cuda", help="Device (default: cuda)"
+     )
+     args = parser.parse_args()
+
+     # Resolve checkpoint path
+     if args.checkpoint is None:
+         args.checkpoint = str(_script_dir / "terminator.pt")
+
+     if not Path(args.checkpoint).exists():
+         print(f"ERROR: Checkpoint not found: {args.checkpoint}", file=sys.stderr)
+         sys.exit(1)
+
+     # Handle empty exit message
+     if args.exit_message == "":
+         args.exit_message = None
+
+     device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+
+     # Load base model
+     model, tokenizer, think_id, think_end_id = load_model_and_tokenizer(
+         args.model, device
+     )
+
+     # Load terminator checkpoint
+     rms_eps = getattr(model.config, "rms_norm_eps", 1e-6)
+     ffn, ckpt_config, layer_idx, num_extra_layers, extra_sd = load_terminator_checkpoint(
+         args.checkpoint, rms_norm_eps=rms_eps, device=device
+     )
+     ffn_params = sum(p.numel() for p in ffn.parameters())
+     print(
+         f"Terminator FFN loaded (layer_idx={layer_idx}, "
+         f"threshold={args.threshold}, window={args.window_size}, "
+         f"params={ffn_params:,})"
+     )
+
+     # Extra layers
+     extra_layers = build_extra_layers(model, ckpt_config, extra_sd, device)
+
+     # Generate
+     generate_with_terminator(
+         prompt=args.prompt,
+         model=model,
+         tokenizer=tokenizer,
+         ffn=ffn,
+         extra_layers=extra_layers,
+         layer_idx=layer_idx,
+         think_token_id=think_id,
+         think_end_token_id=think_end_id,
+         threshold=args.threshold,
+         window_size=args.window_size,
+         exit_message=args.exit_message,
+         max_tokens=args.max_tokens,
+         temperature=args.temperature,
+         device=device,
+     )
+
+
+ if __name__ == "__main__":
+     main()
serve.py ADDED
@@ -0,0 +1,95 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ vLLM API server launcher for Qwen3TerminatorForCausalLM.
4
+
5
+ Imports vllm_terminator BEFORE vLLM initialises, which registers
6
+ Qwen3TerminatorForCausalLM with vLLM's ModelRegistry.
7
+
8
+ NOTE: Terminator currently supports single-GPU, single-sequence inference only.
9
+ Tensor parallelism and concurrent sequences are not supported.
10
+
11
+ Environment variables:
12
+ VLLM_MODEL — path to terminator model directory (required)
13
+ VLLM_PORT — port (default 8000)
14
+ VLLM_GPU_UTIL — GPU memory fraction (default 0.90)
15
+ VLLM_MAX_MODEL_LEN — max context length
16
+ VLLM_DTYPE — dtype (default "auto")
17
+ VLLM_API_KEY — require this API key from clients
18
+ VLLM_SERVED_NAME — override served model name
19
+ VLLM_HOST — bind address (default 0.0.0.0)
20
+ NO_PREFIX_CACHING — set to 1 to disable prefix caching
21
+ VLLM_ENFORCE_EAGER — set to 1 to disable CUDA graphs (default 0)
22
+ REASONING_PARSER — set to "qwen3" to enable <think>/</think> parsing
23
+ (splits reasoning_content from content in API responses)
24
+
25
+ Example:
26
+ VLLM_MODEL=./model_dir python serve.py
27
+ """
28
+
29
+ import os
30
+ import runpy
31
+ import sys
32
+
33
+ # -----------------------------------------------------------------------
34
+ # CRITICAL: import vllm_terminator HERE, before any vLLM code runs.
35
+ # This registers Qwen3TerminatorForCausalLM with vLLM's ModelRegistry.
36
+ # -----------------------------------------------------------------------
37
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
38
+ import vllm_terminator # noqa: F401 (registers the model as a side effect)
39
+
40
+
41
+ def env(name, default=None, required=False):
42
+ v = os.environ.get(name, default)
43
+ if required and (v is None or v == ""):
44
+ print(f"Missing required env var: {name}", file=sys.stderr)
45
+ sys.exit(2)
46
+ return v
47
+
48
+
49
+ def main():
50
+ model = env("VLLM_MODEL", required=True)
51
+ host = env("VLLM_HOST", "0.0.0.0")
52
+ port = env("VLLM_PORT", "8000")
53
+ max_len = env("VLLM_MAX_MODEL_LEN", None)
54
+ gpu_util = env("VLLM_GPU_UTIL", "0.90")
55
+ served_name = env("VLLM_SERVED_NAME", None)
56
+ dtype = env("VLLM_DTYPE", "auto")
57
+ api_key = env("VLLM_API_KEY", None)
58
+ no_prefix_caching = env("NO_PREFIX_CACHING", "0")
59
+ enforce_eager = env("VLLM_ENFORCE_EAGER", "0")
60
+ reasoning_parser = env("REASONING_PARSER", None)
61
+
62
+ argv = [
63
+ "vllm.entrypoints.openai.api_server",
64
+ "--model", model,
65
+ "--host", host,
66
+ "--port", str(port),
67
+ "--dtype", dtype,
68
+ "--gpu-memory-utilization", str(gpu_util),
69
+ "--tensor-parallel-size", "1",
70
+ "--max-num-seqs", "1",
71
+ ]
72
+
73
+ if served_name:
74
+ argv += ["--served-model-name", served_name]
75
+ if max_len:
76
+ argv += ["--max-model-len", str(max_len)]
77
+ if api_key:
78
+ argv += ["--api-key", api_key]
79
+ if no_prefix_caching == "1":
80
+ argv += ["--enable-prefix-caching", "False"]
81
+ if enforce_eager == "1":
82
+ argv += ["--enforce-eager"]
83
+ if reasoning_parser:
84
+ argv += ["--reasoning-parser", reasoning_parser]
85
+
86
+ print("Launching vLLM Terminator server with:\n " + " ".join(argv[1:]), flush=True)
87
+
88
+ # Replace sys.argv so vLLM's argparse sees these arguments, then run the
89
+ # server module in-process (so vllm_terminator registration persists).
90
+ sys.argv = argv
91
+ runpy.run_module("vllm.entrypoints.openai.api_server", run_name="__main__")
92
+
93
+
94
+ if __name__ == "__main__":
95
+ main()
setup.sh ADDED
@@ -0,0 +1,125 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ # ==========================================================================
5
+ # Terminator-Qwen3-14B — Automated Setup
6
+ #
7
+ # This script:
8
+ # 1. Creates a conda environment with Python 3.12
9
+ # 2. Installs uv, vllm, and openai
10
+ # 3. Downloads Qwen3-14B base model weights and creates the model directory
11
+ #
12
+ # Prerequisites:
13
+ # - NVIDIA GPU with sufficient VRAM (minimum ~40GB for Qwen3-14B)
14
+ # - CUDA drivers installed
15
+ # - conda or micromamba installed
16
+ #
17
+ # Usage:
18
+ # ./setup.sh
19
+ # ==========================================================================
20
+
21
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
22
+ ENV_NAME="${TERMINATOR_ENV_NAME:-terminator}"
23
+
24
+ echo ""
25
+ echo "====================================="
26
+ echo " Terminator-Qwen3-14B Setup"
27
+ echo "====================================="
28
+ echo ""
29
+
30
+ # ------------------------------------------------------------------
31
+ # Step 1: Create conda environment
32
+ # ------------------------------------------------------------------
33
+
34
+ # Detect conda or micromamba
35
+ if command -v micromamba &>/dev/null; then
36
+ CONDA_CMD="micromamba"
37
+ elif command -v conda &>/dev/null; then
38
+ CONDA_CMD="conda"
39
+ else
40
+ echo "ERROR: Neither conda nor micromamba found."
41
+ echo ""
42
+ echo "Install micromamba:"
43
+ echo ' "${SHELL}" <(curl -L micro.mamba.pm/install.sh)'
44
+ echo ""
45
+ echo "Or install conda:"
46
+ echo " https://docs.conda.io/en/latest/miniconda.html"
47
+ exit 1
48
+ fi
49
+
50
+ echo "[1/3] Setting up Python environment..."
51
+
52
+ # Check if environment already exists
53
+ if $CONDA_CMD env list 2>/dev/null | grep -q "^${ENV_NAME} \|/${ENV_NAME}\$"; then
54
+ echo " Environment '${ENV_NAME}' already exists. Activating..."
55
+ else
56
+ echo " Creating environment '${ENV_NAME}' with Python 3.12..."
57
+ $CONDA_CMD create -n "${ENV_NAME}" python=3.12 -y
58
+ fi
59
+
60
+ # Helper: run a command inside the conda/micromamba environment
61
+ run_in_env() {
62
+ if [ "$CONDA_CMD" = "micromamba" ]; then
63
+ micromamba run -n "${ENV_NAME}" "$@"
64
+ else
65
+ # For conda, activate in a subshell
66
+ (eval "$(conda shell.bash hook 2>/dev/null)" && conda activate "${ENV_NAME}" && "$@")
67
+ fi
68
+ }
69
+
70
+ echo " Python: $(run_in_env python --version)"
71
+
72
+ # ------------------------------------------------------------------
73
+ # Step 2: Install packages
74
+ # ------------------------------------------------------------------
75
+
76
+ echo ""
77
+ echo "[2/3] Installing packages..."
78
+
79
+ echo " Installing uv..."
80
+ run_in_env pip install --upgrade uv --quiet
81
+
82
+ echo " Installing vllm (this may take a few minutes)..."
83
+ run_in_env uv pip install vllm --torch-backend=auto
84
+
85
+ echo " Installing openai (for client)..."
86
+ run_in_env uv pip install openai
87
+
88
+ echo " Installing accelerate (for HF inference)..."
89
+ run_in_env uv pip install accelerate
90
+
91
+ echo " Done."
92
+
93
+ # ------------------------------------------------------------------
94
+ # Step 3: Set up model directory
95
+ # ------------------------------------------------------------------
96
+
97
+ echo ""
98
+ echo "[3/3] Setting up model directory..."
99
+ echo " This downloads Qwen3-14B base weights (~28GB) from HuggingFace."
100
+ echo " (Skipped if already cached.)"
101
+ echo ""
102
+
103
+ cd "$SCRIPT_DIR"
104
+ run_in_env python setup_model_dir.py
105
+
106
+ # ------------------------------------------------------------------
107
+ # Done
108
+ # ------------------------------------------------------------------
109
+
110
+ echo ""
111
+ echo "====================================="
112
+ echo " Setup Complete!"
113
+ echo "====================================="
114
+ echo ""
115
+ echo "To start the server:"
116
+ echo " $CONDA_CMD activate ${ENV_NAME}"
117
+ echo " cd $SCRIPT_DIR"
118
+ echo " ./start_server.sh"
119
+ echo ""
120
+ echo "Then in another terminal:"
121
+ echo " $CONDA_CMD activate ${ENV_NAME}"
122
+ echo " cd $SCRIPT_DIR"
123
+ echo " python client.py --interactive"
124
+ echo ""
125
+ echo "See README.md for configuration options (GPU memory, context length, etc.)"
setup_model_dir.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.
4
+
5
+ Downloads the base Qwen3-14B config and weights from HuggingFace (if not
6
+ already cached), then creates a model directory with:
7
+ - config.json (Qwen3-14B base config + terminator fields)
8
+ - tokenizer files (symlinked from HF cache)
9
+ - model weights (symlinked from HF cache)
10
+
11
+ Usage:
12
+ # Default: uses ./terminator.pt checkpoint, creates ./model_dir
13
+ python setup_model_dir.py
14
+
15
+ # Custom paths and settings:
16
+ python setup_model_dir.py \\
17
+ --checkpoint /path/to/terminator.pt \\
18
+ --output-dir /path/to/model_dir \\
19
+ --threshold 0.5
20
+ """
21
+
22
+ import argparse
23
+ import os
24
+ import sys
25
+ from pathlib import Path
26
+
27
+ from huggingface_hub import snapshot_download
28
+ from transformers import AutoConfig
29
+
30
+
31
+ def main():
32
+ parser = argparse.ArgumentParser(
33
+ description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
34
+ )
35
+ parser.add_argument(
36
+ "--base-model", default="Qwen/Qwen3-14B",
37
+ help="HuggingFace model ID for the base model (default: Qwen/Qwen3-14B).",
38
+ )
39
+ parser.add_argument(
40
+ "--checkpoint", type=Path, default="./terminator.pt",
41
+ help="Path to trained terminator .pt checkpoint (default: ./terminator.pt).",
42
+ )
43
+ parser.add_argument(
44
+ "--output-dir", type=Path, default="./model_dir",
45
+ help="Destination directory (default: ./model_dir; created if missing).",
46
+ )
47
+ parser.add_argument(
48
+ "--threshold", type=float, default=0.7,
49
+ help="Terminator firing threshold (default 0.7).",
50
+ )
51
+ parser.add_argument(
52
+ "--window-size", type=int, default=10,
53
+ help="Sliding window size for majority vote (default 10).",
54
+ )
55
+ parser.add_argument(
56
+ "--exit-message", type=str,
57
+ default="\nI've run out of thinking tokens. I need to commit to a final answer.",
58
+ help="Message forced when terminator fires (default: standard exit message). "
59
+ "Set to empty string to disable.",
60
+ )
61
+ parser.add_argument(
62
+ "--no-download", action="store_true",
63
+ help="Fail if the base model is not already cached locally "
64
+ "(by default, downloads from HuggingFace if needed).",
65
+ )
66
+ parser.add_argument(
67
+ "--force", action="store_true",
68
+ help="Overwrite files in existing output directory.",
69
+ )
70
+ args = parser.parse_args()
71
+
72
+ checkpoint = args.checkpoint.resolve()
73
+ out_dir = args.output_dir.resolve()
74
+
75
+ if not checkpoint.is_file():
76
+ print(f"ERROR: checkpoint not found: {checkpoint}", file=sys.stderr)
77
+ sys.exit(1)
78
+
79
+ out_dir.mkdir(parents=True, exist_ok=True)
80
+
81
+ # --- Build patched config.json ---
82
+ print(f"Loading config for {args.base_model} from HF cache...")
83
+ config = AutoConfig.from_pretrained(args.base_model)
84
+
85
+ config.architectures = ["Qwen3TerminatorForCausalLM"]
86
+ config.terminator_checkpoint_path = str(checkpoint)
87
+ config.terminator_threshold = args.threshold
88
+ config.terminator_window_size = args.window_size
89
+ config.terminator_exit_message = args.exit_message
90
+
91
+ # Remove auto_map if present from an older span-predictor config
92
+ if hasattr(config, "auto_map"):
93
+ del config.auto_map
94
+
95
+ config.save_pretrained(out_dir)
96
+ print(f" Wrote config.json -> {out_dir / 'config.json'}")
97
+
98
+ # --- Symlink weights and tokenizer files from HF cache ---
99
+ print(f"Locating {args.base_model} in HF cache...")
100
+ allow_download = not args.no_download
101
+ base_dir = Path(snapshot_download(args.base_model, local_files_only=not allow_download))
102
+ print(f" Found: {base_dir}")
103
+
104
+ linked = 0
105
+ for src in sorted(base_dir.iterdir()):
106
+ if src.name in ("config.json",):
107
+ continue # we already wrote our own
108
+
109
+ dst = out_dir / src.name
110
+ if dst.exists() or dst.is_symlink():
111
+ if args.force:
112
+ dst.unlink()
113
+ else:
114
+ continue
115
+
116
+ os.symlink(src, dst)
117
+ print(f" Linked {src.name}")
118
+ linked += 1
119
+
120
+ print(f"\nDone. Linked {linked} files into {out_dir}")
121
+ print(f"\nTo start the server:")
122
+ print(f" ./start_server.sh")
123
+ print(f"\nOr manually:")
124
+ print(f" VLLM_MODEL={out_dir} REASONING_PARSER=qwen3 python serve.py")
125
+
126
+
127
+ if __name__ == "__main__":
128
+ main()
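For quick reference, the terminator-specific fields this script patches into config.json look roughly like the following (a sketch using the script's defaults; the checkpoint path is illustrative, and the real file also carries all base Qwen3-14B fields):

```python
# Sketch of the extra config.json fields written by setup_model_dir.py.
# Default values shown; "terminator_checkpoint_path" is illustrative.
terminator_fields = {
    "architectures": ["Qwen3TerminatorForCausalLM"],
    "terminator_checkpoint_path": "/abs/path/to/terminator.pt",
    "terminator_threshold": 0.7,    # --threshold
    "terminator_window_size": 10,   # --window-size
    "terminator_exit_message": (
        "\nI've run out of thinking tokens."
        " I need to commit to a final answer."
    ),
}
```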
start_server.sh ADDED
@@ -0,0 +1,48 @@
1
+ #!/usr/bin/env bash
2
+ set -euo pipefail
3
+
4
+ # ==========================================================================
5
+ # Terminator-Qwen3-14B — Server Launcher
6
+ #
7
+ # Starts the vLLM server with the Terminator model.
8
+ # Run setup.sh first to create the model directory.
9
+ #
10
+ # Configuration (set as environment variables before running):
11
+ #
12
+ # VLLM_GPU_UTIL GPU memory fraction to use (default: 0.90)
13
+ #
14
+ # VLLM_MAX_MODEL_LEN Maximum context length in tokens (default: server picks)
15
+ #
16
+ # VLLM_PORT Server port (default: 8000)
17
+ #
18
+ # VLLM_ENFORCE_EAGER Set to 1 to disable CUDA graphs (default: 0)
19
+ # Use if you encounter CUDA graph compilation errors.
20
+ # NOTE: VLLM_ENFORCE_EAGER=1 will result in slower responses (CUDA graphs disabled)
21
+ #
22
+ # VLLM_API_KEY Require this API key from clients (default: none)
23
+ #
24
+ # Usage:
25
+ # ./start_server.sh
26
+ # or to manually override default environment variables:
27
+ # VLLM_GPU_UTIL=0.70 VLLM_MAX_MODEL_LEN=8192 ./start_server.sh
28
+ # ==========================================================================
29
+
30
+ SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
31
+ MODEL_DIR="${SCRIPT_DIR}/model_dir"
32
+
33
+ if [ ! -d "$MODEL_DIR" ]; then
34
+ echo "ERROR: Model directory not found at: $MODEL_DIR" >&2
35
+ echo "" >&2
36
+ echo "Run setup first:" >&2
37
+ echo " ./setup.sh" >&2
38
+ echo "" >&2
39
+ echo "Or manually:" >&2
40
+ echo " python setup_model_dir.py" >&2
41
+ exit 1
42
+ fi
43
+
44
+ export VLLM_MODEL="$MODEL_DIR"
45
+ export REASONING_PARSER="${REASONING_PARSER:-qwen3}"
46
+ export VLLM_SERVED_NAME="${VLLM_SERVED_NAME:-Terminator-Qwen3-14B}"
47
+
48
+ exec python "$SCRIPT_DIR/serve.py"
terminator.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:94deb40f18fe3d3e642b0dd1b07d0833b1363d389c0ed01df2946b533b715c97
3
+ size 1321295486
vllm_terminator/__init__.py ADDED
@@ -0,0 +1,19 @@
1
+ """
2
+ vLLM plugin that registers the Qwen3 Terminator model.
3
+
4
+ Usage:
5
+ import vllm_terminator # registers the model architecture
6
+
7
+ Then set ``"architectures": ["Qwen3TerminatorForCausalLM"]`` in the
8
+ HuggingFace config.json alongside::
9
+
10
+ "terminator_checkpoint_path": "/path/to/layer_-1.pt",
11
+ "terminator_threshold": 0.7
12
+ """
13
+
14
+ from vllm import ModelRegistry
15
+
16
+ ModelRegistry.register_model(
17
+ "Qwen3TerminatorForCausalLM",
18
+ "vllm_terminator.model:Qwen3TerminatorForCausalLM",
19
+ )
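The registered model's firing rule — a majority vote over a sliding window of sigmoid predictions — can be sketched in isolation (assumed defaults: threshold 0.7, window size 10; the real logic lives in `Qwen3TerminatorForCausalLM.compute_logits()`):

```python
from collections import deque

def should_fire(history, threshold=0.7, window_size=10):
    """Return True when a majority of the last window_size predictions
    exceed the threshold -- mirrors the terminator's firing condition."""
    if len(history) < window_size:
        return False  # window not yet full
    n_above = sum(1 for p in history if p > threshold)
    return n_above / window_size > 0.5

history = deque(maxlen=10)
for p in [0.2] * 5 + [0.9] * 10:  # predictions climb above the threshold
    history.append(p)

print(should_fire(history))  # window now holds ten 0.9s -> fires
```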
vllm_terminator/model.py ADDED
@@ -0,0 +1,553 @@
1
+ """
2
+ vLLM-compatible Qwen3 model with Terminator FFN for early reasoning truncation.
3
+
4
+ The Terminator predicts when chain-of-thought reasoning has reached the final
5
+ answer. When a majority of recent predictions (sliding window) exceed a
6
+ threshold, the model forces generation of a configurable exit message followed
7
+ by </think> to truncate reasoning early. The exit message helps the model
8
+ transition smoothly from thinking to answering mode.
9
+
10
+ Supports optional extra transformer layers between the base model and the FFN
11
+ head. These layers get their own KV cache via vLLM's auto-discovery mechanism.
12
+
13
+ Constraints:
14
+ - layer_idx = -1 (last layer)
15
+ - sliding_window strategy (majority vote over last window_size predictions)
16
+ - lag_size = 0
17
+ - batch_size = 1
18
+ """
19
+
20
+ from collections import deque
21
+ from collections.abc import Iterable
22
+ from itertools import islice
23
+
24
+ import torch
25
+ from torch import nn
26
+ from transformers import AutoTokenizer
27
+
28
+ from vllm.config import VllmConfig
29
+ from vllm.distributed import get_pp_group
30
+ from vllm.logger import init_logger
31
+ from vllm.model_executor.layers.logits_processor import LogitsProcessor
32
+ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
33
+ from vllm.model_executor.models.interfaces import SupportsPP
34
+ from vllm.model_executor.models.qwen2 import Qwen2Model
35
+ from vllm.model_executor.models.qwen3 import Qwen3DecoderLayer
36
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
37
+ from vllm.model_executor.models.utils import (
38
+ AutoWeightsLoader,
39
+ PPMissingLayer,
40
+ maybe_prefix,
41
+ )
42
+ from vllm.sequence import IntermediateTensors
43
+
44
+ from .terminator_head import TerminatorFFN
45
+
46
+ logger = init_logger(__name__)
47
+
48
+
49
+ class Qwen3TerminatorModel(Qwen2Model):
50
+ """Qwen3 backbone that captures pre-norm hidden states for the Terminator FFN."""
51
+
52
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
53
+ super().__init__(
54
+ vllm_config=vllm_config,
55
+ prefix=prefix,
56
+ decoder_layer_type=Qwen3DecoderLayer,
57
+ )
58
+ self._pre_norm_hidden: torch.Tensor | None = None
59
+
60
+ def forward(
61
+ self,
62
+ input_ids: torch.Tensor,
63
+ positions: torch.Tensor,
64
+ intermediate_tensors: IntermediateTensors | None = None,
65
+ inputs_embeds: torch.Tensor | None = None,
66
+ ) -> torch.Tensor | IntermediateTensors:
67
+ if get_pp_group().is_first_rank:
68
+ if inputs_embeds is not None:
69
+ hidden_states = inputs_embeds
70
+ else:
71
+ hidden_states = self.embed_input_ids(input_ids)
72
+ residual = None
73
+ else:
74
+ assert intermediate_tensors is not None
75
+ hidden_states = intermediate_tensors["hidden_states"]
76
+ residual = intermediate_tensors["residual"]
77
+
78
+ aux_hidden_states = []
79
+ for idx, layer in enumerate(
80
+ islice(self.layers, self.start_layer, self.end_layer)
81
+ ):
82
+ if idx in self.aux_hidden_state_layers:
83
+ aux_hidden_states.append(hidden_states + residual)
84
+ hidden_states, residual = layer(positions, hidden_states, residual)
85
+
86
+ if not get_pp_group().is_last_rank:
87
+ return IntermediateTensors(
88
+ {"hidden_states": hidden_states, "residual": residual}
89
+ )
90
+
91
+ # Capture pre-norm hidden states (matches training's forward hook output).
92
+ # In the fused residual pattern, pre-norm state = hidden_states + residual.
93
+ self._pre_norm_hidden = hidden_states + residual
94
+
95
+ hidden_states, _ = self.norm(hidden_states, residual)
96
+
97
+ if len(aux_hidden_states) > 0:
98
+ return hidden_states, aux_hidden_states
99
+
100
+ return hidden_states
101
+
102
+
103
+ class Qwen3TerminatorForCausalLM(nn.Module, SupportsPP):
104
+ """
105
+ Qwen3 causal LM with an attached Terminator FFN that can force </think>
106
+ generation when the model predicts the final answer has been reached.
107
+ """
108
+
109
+ packed_modules_mapping = {
110
+ "qkv_proj": ["q_proj", "k_proj", "v_proj"],
111
+ "gate_up_proj": ["gate_proj", "up_proj"],
112
+ }
113
+
114
+ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
115
+ super().__init__()
116
+ config = vllm_config.model_config.hf_config
117
+ quant_config = vllm_config.quant_config
118
+
119
+ self.config = config
120
+ self.quant_config = quant_config
121
+
122
+ # --- Base Qwen3 model ---
123
+ self.model = Qwen3TerminatorModel(
124
+ vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
125
+ )
126
+
127
+ if get_pp_group().is_last_rank:
128
+ if config.tie_word_embeddings:
129
+ self.lm_head = self.model.embed_tokens
130
+ else:
131
+ self.lm_head = ParallelLMHead(
132
+ config.vocab_size,
133
+ config.hidden_size,
134
+ quant_config=quant_config,
135
+ prefix=maybe_prefix(prefix, "lm_head"),
136
+ )
137
+ else:
138
+ self.lm_head = PPMissingLayer()
139
+
140
+ self.logits_processor = LogitsProcessor(config.vocab_size)
141
+ self.make_empty_intermediate_tensors = (
142
+ self.model.make_empty_intermediate_tensors
143
+ )
144
+
145
+ # --- Terminator FFN ---
146
+ terminator_checkpoint_path = getattr(
147
+ config, "terminator_checkpoint_path", None
148
+ )
149
+ self._terminator_threshold = getattr(
150
+ config, "terminator_threshold", 0.7
151
+ )
152
+ self._terminator_window_size = getattr(
153
+ config, "terminator_window_size", 10
154
+ )
155
+
156
+ if terminator_checkpoint_path:
157
+ self._terminator_checkpoint_path = terminator_checkpoint_path
158
+ self._terminator_enabled = True
159
+
160
+ # Load checkpoint metadata to construct FFN with correct architecture
161
+ checkpoint = torch.load(
162
+ terminator_checkpoint_path, map_location="cpu", weights_only=False
163
+ )
164
+ terminator_config = checkpoint["config"]
165
+ self._terminator_layer_idx = checkpoint["layer_idx"]
166
+ self._terminator_state_dict = checkpoint["state_dict"]
167
+
168
+ self.terminator_ffn = TerminatorFFN(
169
+ hidden_size=terminator_config["hidden_size"],
170
+ num_hidden_layers=terminator_config.get("ffn_layers", 1),
171
+ activation=terminator_config.get("ffn_activation", "gelu"),
172
+ intermediate_size=terminator_config.get(
173
+ "ffn_intermediate_size", None
174
+ ),
175
+ dropout=0.0,
176
+ rms_norm_eps=getattr(config, "rms_norm_eps", 1e-6),
177
+ )
178
+
179
+ logger.info(
180
+ "Terminator FFN created (layer_idx=%d, threshold=%.2f, "
181
+ "window_size=%d, params=%d)",
182
+ self._terminator_layer_idx,
183
+ self._terminator_threshold,
184
+ self._terminator_window_size,
185
+ sum(p.numel() for p in self.terminator_ffn.parameters()),
186
+ )
187
+
188
+ # --- Extra transformer layers ---
189
+ self._num_extra_layers = terminator_config.get(
190
+ "num_extra_layers", 0
191
+ )
192
+ self._extra_layers_state_dict = checkpoint.get(
193
+ "extra_layers_state_dict", None
194
+ )
195
+
196
+ if self._num_extra_layers > 0 and self._extra_layers_state_dict is not None:
197
+ cache_config = vllm_config.cache_config
198
+ # Use indices starting after base layers to avoid
199
+ # extract_layer_index() collisions in KV cache binding.
200
+ num_base_layers = config.num_hidden_layers
201
+ self.terminator_extra_layers = nn.ModuleList([
202
+ Qwen3DecoderLayer(
203
+ config=config,
204
+ cache_config=cache_config,
205
+ quant_config=quant_config,
206
+ prefix=f"terminator_extra_layers.{num_base_layers + i}",
207
+ )
208
+ for i in range(self._num_extra_layers)
209
+ ])
210
+ logger.info(
211
+ "Terminator extra layers created (n=%d, params=%d)",
212
+ self._num_extra_layers,
213
+ sum(
214
+ p.numel()
215
+ for p in self.terminator_extra_layers.parameters()
216
+ ),
217
+ )
218
+ else:
219
+ self._num_extra_layers = 0
220
+ self._extra_layers_state_dict = None
221
+ else:
222
+ self._terminator_enabled = False
223
+ self._num_extra_layers = 0
224
+ self._extra_layers_state_dict = None
225
+ logger.info(
226
+ "No terminator_checkpoint_path in config; "
227
+ "terminator disabled (running as standard Qwen3)"
228
+ )
229
+
230
+ # --- Think token IDs ---
231
+ tokenizer = AutoTokenizer.from_pretrained(
232
+ vllm_config.model_config.tokenizer,
233
+ trust_remote_code=vllm_config.model_config.trust_remote_code,
234
+ )
235
+ self._think_token_id = tokenizer.convert_tokens_to_ids("<think>")
236
+ self._think_end_token_id = tokenizer.convert_tokens_to_ids("</think>")
237
+ logger.info(
238
+ "<think>=%d, </think>=%d",
239
+ self._think_token_id,
240
+ self._think_end_token_id,
241
+ )
242
+
243
+ # --- Exit message ---
244
+ # Pre-tokenize the exit message + </think> so we can force one token
245
+ # per step when the terminator fires.
246
+ _default_exit_msg = (
247
+ "\nI've run out of thinking tokens."
248
+ " I need to commit to a final answer."
249
+ )
250
+ exit_msg = getattr(config, "terminator_exit_message", _default_exit_msg)
251
+ if exit_msg:
252
+ msg_ids = tokenizer.encode(exit_msg, add_special_tokens=False)
253
+ self._exit_sequence: list[int] = msg_ids + [self._think_end_token_id]
254
+ else:
255
+ self._exit_sequence = [self._think_end_token_id]
256
+ logger.info(
257
+ "Exit sequence: %d tokens (message=%r)",
258
+ len(self._exit_sequence),
259
+ exit_msg if exit_msg else "<none>",
260
+ )
261
+
262
+ # --- Per-request state (batch_size=1) ---
263
+ self._is_thinking = False
264
+ self._pred_buffer: torch.Tensor | None = None # lazily allocated in forward()
265
+ self._prev_output_token_id: int | None = None # argmax from previous compute_logits()
266
+ self._prediction_history: deque[float] = deque(
267
+ maxlen=self._terminator_window_size if self._terminator_enabled else 1
268
+ )
269
+ self._forcing_exit: bool = False
270
+ self._forcing_idx: int = 0
271
+
272
+ def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
273
+ return self.model.embed_input_ids(input_ids)
274
+
275
+ # ------------------------------------------------------------------
276
+ # Thinking state tracking
277
+ # ------------------------------------------------------------------
278
+
279
+ def _update_thinking_state(self, input_ids: torch.Tensor) -> None:
280
+ """Set ``_is_thinking`` based on the last think token in *input_ids*."""
281
+ think_pos = (input_ids == self._think_token_id).nonzero(as_tuple=True)[0]
282
+ end_pos = (input_ids == self._think_end_token_id).nonzero(as_tuple=True)[0]
283
+
284
+ last_think = think_pos[-1].item() if len(think_pos) > 0 else -1
285
+ last_end = end_pos[-1].item() if len(end_pos) > 0 else -1
286
+
287
+ if last_think > last_end:
288
+ self._is_thinking = True
289
+ elif last_end > last_think:
290
+ self._is_thinking = False
291
+
292
+ # ------------------------------------------------------------------
293
+ # Forward / compute_logits
294
+ # ------------------------------------------------------------------
295
+
296
+ def forward(
297
+ self,
298
+ input_ids: torch.Tensor,
299
+ positions: torch.Tensor,
300
+ intermediate_tensors: IntermediateTensors | None = None,
301
+ inputs_embeds: torch.Tensor | None = None,
302
+ ) -> torch.Tensor | IntermediateTensors:
303
+ hidden_states = self.model(
304
+ input_ids, positions, intermediate_tensors, inputs_embeds
305
+ )
306
+
307
+ if self._terminator_enabled:
308
+ # Update thinking state during eager execution only (prefill).
309
+ # .nonzero() / .item() are not CUDA-graph-capture-safe.
310
+ # Reset all per-request state so predictions, exit-forcing,
311
+ # and output token tracking don't leak between requests.
312
+ if not torch.cuda.is_current_stream_capturing():
313
+ self._update_thinking_state(input_ids)
314
+ self._prev_output_token_id = None
315
+ self._prediction_history.clear()
316
+ self._forcing_exit = False
317
+ self._forcing_idx = 0
318
+
319
+ # Run extra layers + FFN unconditionally — all ops (RMSNorm,
320
+ # Linear, Attention with KV cache, sigmoid) are CUDA-graph-safe.
321
+ # The prediction is written in-place to a pre-allocated buffer
322
+ # that persists across graph replays.
323
+ pre_norm = self.model._pre_norm_hidden
324
+
325
+ if self._num_extra_layers > 0:
326
+ # Run extra transformer layers on ALL tokens so KV cache
327
+ # is populated during prefill and updated during decode.
328
+ extra_hidden = pre_norm
329
+ extra_residual = None
330
+ for layer in self.terminator_extra_layers:
331
+ extra_hidden, extra_residual = layer(
332
+ positions, extra_hidden, extra_residual
333
+ )
334
+ # Reconstruct pre-norm state for the FFN:
335
+ # hidden + residual gives the un-normed output,
336
+ # matching what the FFN's own RMSNorm expects.
337
+ ffn_input = extra_hidden + extra_residual
338
+ else:
339
+ ffn_input = pre_norm
340
+
341
+ logit = self.terminator_ffn(ffn_input[-1:])
342
+ pred = torch.sigmoid(logit)
343
+ if self._pred_buffer is None:
344
+ # First eager forward (warmup) — allocate the buffer.
345
+ # Subsequent calls (including graph capture) reuse it.
346
+ self._pred_buffer = torch.zeros_like(pred)
347
+ self._pred_buffer.copy_(pred)
348
+
349
+ return hidden_states
350
+
351
+ def compute_logits(
352
+ self,
353
+ hidden_states: torch.Tensor,
354
+ ) -> torch.Tensor | None:
355
+ logits = self.logits_processor(self.lm_head, hidden_states)
356
+
357
+ # compute_logits() is called outside the CUDA graph, so .item()
358
+ # and data-dependent branching are safe here.
359
+ if self._terminator_enabled:
360
+ # --- Exit message forcing ---
361
+ # When the terminator has fired, we walk through the pre-tokenized
362
+ # exit sequence one token per step, skipping all prediction logic.
363
+ if self._forcing_exit:
364
+ token_id = self._exit_sequence[self._forcing_idx]
365
+ logits.fill_(float("-inf"))
366
+ logits[:, token_id] = 0.0
367
+ self._forcing_idx += 1
368
+ if self._forcing_idx >= len(self._exit_sequence):
369
+ # Done forcing — last token was </think>.
370
+ self._forcing_exit = False
371
+ self._is_thinking = False
372
+ logger.debug(
373
+ "Exit sequence complete (%d tokens forced)",
374
+ len(self._exit_sequence),
375
+ )
376
+ self._prev_output_token_id = token_id
377
+ return logits
378
+
379
+ # Track thinking state from the previous step's output token.
380
+ # During CUDA graph replay forward() doesn't execute, so
381
+ # _update_thinking_state() never sees generated <think> tokens.
382
+ # Instead we infer the state from the argmax of the previous
383
+ # step's logits (exact for greedy; heuristic for sampling).
384
+ if self._prev_output_token_id == self._think_token_id:
385
+ if not self._is_thinking:
386
+ self._prediction_history.clear()
387
+ self._is_thinking = True
388
+ elif self._prev_output_token_id == self._think_end_token_id:
389
+ self._is_thinking = False
390
+
391
+        if self._is_thinking and self._pred_buffer is not None:
+            pred = self._pred_buffer.item()
+            self._prediction_history.append(pred)
+
+            if len(self._prediction_history) >= self._terminator_window_size:
+                n_above = sum(
+                    1 for p in self._prediction_history
+                    if p > self._terminator_threshold
+                )
+                vote = n_above / self._terminator_window_size
+                if vote > 0.5:
+                    # Majority of sliding window exceeds threshold —
+                    # enter exit-message forcing mode.
+                    logger.debug(
+                        "Terminator FIRING: pred=%.3f, "
+                        "window=[%s] (%d/%d above %.2f, vote=%.2f)",
+                        pred,
+                        ", ".join(f"{p:.3f}" for p in self._prediction_history),
+                        n_above,
+                        self._terminator_window_size,
+                        self._terminator_threshold,
+                        vote,
+                    )
+                    # Force the first token of the exit sequence now,
+                    # and set up state so subsequent calls continue.
+                    self._forcing_exit = True
+                    self._forcing_idx = 0
+                    token_id = self._exit_sequence[0]
+                    logits.fill_(float("-inf"))
+                    logits[:, token_id] = 0.0
+                    self._forcing_idx = 1
+                    if self._forcing_idx >= len(self._exit_sequence):
+                        self._forcing_exit = False
+                        self._is_thinking = False
+                    self._prev_output_token_id = token_id
+                    return logits
+                else:
+                    logger.debug(
+                        "Terminator: pred=%.3f, "
+                        "window=[%s] (%d/%d above %.2f, vote=%.2f)",
+                        pred,
+                        ", ".join(f"{p:.3f}" for p in self._prediction_history),
+                        n_above,
+                        self._terminator_window_size,
+                        self._terminator_threshold,
+                        vote,
+                    )
+            else:
+                logger.debug(
+                    "Terminator: pred=%.3f, filling window (%d/%d)",
+                    pred,
+                    len(self._prediction_history),
+                    self._terminator_window_size,
+                )
+
+        # Record argmax for next step's thinking-state tracking.
+        self._prev_output_token_id = logits[0].argmax().item()
+
+        return logits
+
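The firing rule above is a sliding-window majority vote over recent terminator predictions. It can be sketched standalone with a bounded `deque`, which evicts the oldest prediction automatically (the `WINDOW` and `THRESHOLD` values here are arbitrary stand-ins for the server's configured `_terminator_window_size` and `_terminator_threshold`):

```python
from collections import deque

WINDOW, THRESHOLD = 4, 0.5  # illustrative values, not the shipped defaults

history = deque(maxlen=WINDOW)  # oldest prediction falls off automatically

def should_fire(pred):
    """Return True once a strict majority of the window exceeds THRESHOLD."""
    history.append(pred)
    if len(history) < WINDOW:
        return False  # still filling the window
    n_above = sum(1 for p in history if p > THRESHOLD)
    return n_above / WINDOW > 0.5

fired = [should_fire(p) for p in [0.1, 0.6, 0.7, 0.8, 0.9]]
print(fired)  # → [False, False, False, True, True]
```

The window both debounces single noisy predictions and delays firing until the signal has been sustained for several decode steps.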
+    # ------------------------------------------------------------------
+    # Weight loading
+    # ------------------------------------------------------------------
+
+    # Mapping from HF checkpoint weight names to vLLM fused names.
+    # Mirrors the stacked_params_mapping in Qwen2Model.load_weights().
+    _extra_layer_stacked_mapping = [
+        # (vllm_param, hf_name, shard_id)
+        ("qkv_proj", "q_proj", "q"),
+        ("qkv_proj", "k_proj", "k"),
+        ("qkv_proj", "v_proj", "v"),
+        ("gate_up_proj", "gate_proj", 0),
+        ("gate_up_proj", "up_proj", 1),
+    ]
+
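The effect of this mapping is a pure string transformation from HF checkpoint keys to vLLM's fused parameter names. A hypothetical standalone sketch (the real loader additionally routes each tensor to a shard of the fused weight; shard IDs are dropped here for brevity):

```python
# Stand-in for _extra_layer_stacked_mapping, shard IDs omitted.
STACKED = [
    ("qkv_proj", "q_proj"), ("qkv_proj", "k_proj"), ("qkv_proj", "v_proj"),
    ("gate_up_proj", "gate_proj"), ("gate_up_proj", "up_proj"),
]

def to_vllm_name(hf_name):
    """Map an HF checkpoint key onto the fused vLLM parameter name."""
    name = hf_name.replace("layers.", "terminator_extra_layers.", 1)
    for fused, part in STACKED:
        if part in name:
            return name.replace(part, fused)
    return name  # non-fused params (norms, o_proj, down_proj) pass through

print(to_vllm_name("layers.0.self_attn.q_proj.weight"))
# → terminator_extra_layers.0.self_attn.qkv_proj.weight
```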
+    def _load_extra_layers_weights(self, loaded: set[str]) -> None:
+        """Load extra transformer layer weights from the checkpoint.
+
+        The checkpoint stores HF-format keys
+        (``layers.0.self_attn.q_proj.weight``) which must be mapped to
+        vLLM's fused names (``terminator_extra_layers.0.self_attn.qkv_proj``).
+        This mirrors the ``stacked_params_mapping`` approach used by
+        ``Qwen2Model.load_weights()``.
+        """
+        if self._extra_layers_state_dict is None:
+            return
+
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+
+        for ckpt_name, tensor in self._extra_layers_state_dict.items():
+            # Remap checkpoint prefix to model module path.
+            name = ckpt_name.replace("layers.", "terminator_extra_layers.", 1)
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            # Check stacked (fused) projection mapping.
+            for param_name, weight_name, shard_id in self._extra_layer_stacked_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                if weight_loader == default_weight_loader:
+                    weight_loader(param, tensor)
+                else:
+                    weight_loader(param, tensor, shard_id)
+                loaded.add(name)
+                break
+            else:
+                # Direct (non-fused) parameter — norms, o_proj, down_proj.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    logger.warning(
+                        "Skipping extra-layer weight %s (no matching param)",
+                        name,
+                    )
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", default_weight_loader
+                )
+                weight_loader(param, tensor)
+                loaded.add(name)
+
+        del self._extra_layers_state_dict
+        self._extra_layers_state_dict = None
+        logger.info(
+            "Terminator extra layers weights loaded from checkpoint",
+        )
+
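The inner loop above uses Python's `for`/`else` clause: the `else` branch runs only when the loop completes without `break`, which is exactly the "no fused mapping matched, treat as a direct parameter" fallthrough. A tiny illustrative sketch of just that control flow (names hypothetical):

```python
def classify(name):
    """Label a weight name as 'fused' or 'direct', mirroring the
    break/else fallthrough in the loader above."""
    for frag in ("q_proj", "k_proj", "v_proj", "gate_proj", "up_proj"):
        if frag in name:
            kind = "fused"
            break
    else:
        # No break fired: the whole mapping was exhausted without a match.
        kind = "direct"
    return kind

print(classify("layers.0.self_attn.q_proj.weight"))  # → fused
print(classify("layers.0.input_layernorm.weight"))   # → direct
```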
+    def load_weights(
+        self, weights: Iterable[tuple[str, torch.Tensor]]
+    ) -> set[str]:
+        skip = ["terminator_ffn.", "terminator_extra_layers."]
+        if self.config.tie_word_embeddings:
+            skip.append("lm_head.")
+
+        loader = AutoWeightsLoader(self, skip_prefixes=skip)
+        loaded = loader.load_weights(weights)
+
+        # Load terminator FFN and extra layers from the separate .pt
+        # checkpoint (not from the HF safetensors).
+        if self._terminator_enabled:
+            self.terminator_ffn.load_state_dict(self._terminator_state_dict)
+            del self._terminator_state_dict  # free memory
+            logger.info(
+                "Terminator FFN weights loaded from %s",
+                self._terminator_checkpoint_path,
+            )
+            # Tell vLLM these weights have been handled so it doesn't
+            # complain about uninitialized parameters.
+            for name in self.terminator_ffn.state_dict():
+                loaded.add(f"terminator_ffn.{name}")
+
+            self._load_extra_layers_weights(loaded)
+
+        return loaded
vllm_terminator/terminator_head.py ADDED
@@ -0,0 +1,135 @@
+"""
+Terminator FFN head for vLLM integration.
+
+Mirrors LayerFFN from terminator_utils.py but uses a standalone RMSNorm
+compatible with checkpoints trained using HuggingFace's Qwen2RMSNorm.
+"""
+
+from pathlib import Path
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+
+class SimpleRMSNorm(nn.Module):
+    """RMSNorm with a `weight` parameter matching Qwen2RMSNorm's state_dict."""
+
+    def __init__(self, hidden_size: int, eps: float = 1e-6):
+        super().__init__()
+        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.variance_epsilon = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        input_dtype = hidden_states.dtype
+        hidden_states = hidden_states.to(torch.float32)
+        variance = hidden_states.pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+        return self.weight * hidden_states.to(input_dtype)
+
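The normalization above divides each vector by its root mean square (plus a small epsilon for stability), so the output has unit RMS before the learned `weight` rescaling. A framework-free numerical check with arbitrary illustrative values:

```python
import math

def rms_norm(x, weight, eps=1e-6):
    """Plain-Python mirror of SimpleRMSNorm.forward on a single vector."""
    variance = sum(v * v for v in x) / len(x)   # mean of squares
    scale = 1.0 / math.sqrt(variance + eps)     # torch.rsqrt equivalent
    return [w * v * scale for w, v in zip(weight, x)]

out = rms_norm([3.0, 4.0], [1.0, 1.0])
# The RMS of [3, 4] is sqrt(12.5); with unit weights the output's RMS is 1
# (up to the epsilon term).
rms_out = math.sqrt(sum(v * v for v in out) / len(out))
print(round(rms_out, 6))  # → 1.0
```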
+
+class TerminatorFFN(nn.Module):
+    """
+    Feed-forward network for per-position binary classification.
+
+    Architecture mirrors LayerFFN from terminator_utils.py:
+    - Pre-normalization with RMSNorm
+    - Linear projection(s) to scalar logit
+    - No sigmoid (outputs raw logits)
+
+    State dict keys match training checkpoints exactly:
+    - norm.weight
+    - network.weight, network.bias (1-layer case)
+    - network.0.weight, network.0.bias, ... (multi-layer case)
+    """
+
+    def __init__(
+        self,
+        hidden_size: int,
+        num_hidden_layers: int = 1,
+        activation: str = 'gelu',
+        intermediate_size: Optional[int] = None,
+        dropout: float = 0.0,
+        rms_norm_eps: float = 1e-6,
+    ):
+        super().__init__()
+
+        self.hidden_size = hidden_size
+        self.norm = SimpleRMSNorm(hidden_size, eps=rms_norm_eps)
+
+        if num_hidden_layers == 1:
+            self.network = nn.Linear(hidden_size, 1)
+        else:
+            if intermediate_size is None:
+                intermediate_size = hidden_size * 2
+
+            layers = []
+            layers.append(nn.Linear(hidden_size, intermediate_size))
+
+            act_fn = {'relu': nn.ReLU, 'gelu': nn.GELU, 'tanh': nn.Tanh}
+            if activation not in act_fn:
+                raise ValueError(f"Unknown activation: {activation}")
+
+            layers.append(act_fn[activation]())
+            layers.append(nn.Dropout(dropout))
+
+            for _ in range(num_hidden_layers - 2):
+                layers.append(nn.Linear(intermediate_size, intermediate_size))
+                layers.append(act_fn[activation]())
+                layers.append(nn.Dropout(dropout))
+
+            layers.append(nn.Linear(intermediate_size, 1))
+            self.network = nn.Sequential(*layers)
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states: [num_tokens, hidden_size] or [batch, seq_len, hidden_size]
+
+        Returns:
+            logits: [num_tokens] or [batch, seq_len] raw logits
+        """
+        hidden_states = self.norm(hidden_states)
+        output = self.network(hidden_states)
+        return output.squeeze(-1)
+
+
+def load_terminator_checkpoint(
+    checkpoint_path: str,
+    rms_norm_eps: float = 1e-6,
+    device: torch.device = torch.device("cpu"),
+) -> Tuple[TerminatorFFN, Dict[str, Any], int, int, Optional[Dict[str, Any]]]:
+    """
+    Load a trained terminator checkpoint and construct the FFN.
+
+    Args:
+        checkpoint_path: Path to layer_*.pt checkpoint file
+        rms_norm_eps: Epsilon for RMSNorm (from base model config)
+        device: Device to load onto
+
+    Returns:
+        (ffn, config, layer_idx, num_extra_layers, extra_layers_state_dict)
+    """
+    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
+
+    config = checkpoint["config"]
+    layer_idx = checkpoint["layer_idx"]
+    hidden_size = config["hidden_size"]
+
+    ffn = TerminatorFFN(
+        hidden_size=hidden_size,
+        num_hidden_layers=config.get("ffn_layers", 1),
+        activation=config.get("ffn_activation", "gelu"),
+        intermediate_size=config.get("ffn_intermediate_size", None),
+        dropout=0.0,  # No dropout at inference
+        rms_norm_eps=rms_norm_eps,
+    )
+
+    ffn.load_state_dict(checkpoint["state_dict"])
+    ffn.to(device)
+    ffn.eval()
+
+    num_extra_layers = config.get("num_extra_layers", 0)
+    extra_layers_state_dict = checkpoint.get("extra_layers_state_dict", None)
+
+    return ffn, config, layer_idx, num_extra_layers, extra_layers_state_dict