#!/usr/bin/env python3
"""
HuggingFace-native inference for Terminator-Qwen3-8B.
Loads the frozen Qwen3 base model + trained Terminator head (FFN + optional
extra transformer layers) directly via HuggingFace transformers.
Generates chain-of-thought reasoning token-by-token. The Terminator FFN
predicts when the final answer has been reached; when a sliding-window
majority vote exceeds the threshold, an exit message is injected and the
model transitions to answering mode.
Usage:
python inference_hf.py --prompt "What is the sum of the first 100 natural numbers?"
python inference_hf.py \\
--prompt "Solve x^2 - 5x + 6 = 0" \\
--model Qwen/Qwen3-8B \\
--checkpoint terminator.pt \\
--threshold 0.7 --window-size 10
"""
import argparse
import sys
from pathlib import Path
import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
from transformers.generation.logits_process import LogitsProcessorList
# ---------------------------------------------------------------------------
# Imports from the project
# ---------------------------------------------------------------------------
# Local: TerminatorFFN + checkpoint loader
_script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(_script_dir))
from vllm_terminator.terminator_head import load_terminator_checkpoint
# Parent dir: ExtraTransformerLayers from terminator_utils
_repo_root = _script_dir.parent
sys.path.insert(0, str(_repo_root))
from terminator_utils import ExtraTransformerLayers
# ---------------------------------------------------------------------------
# ANSI escape codes
# ---------------------------------------------------------------------------
DIM = "\033[2m"
BOLD = "\033[1m"
RESET = "\033[0m"
def load_model_and_tokenizer(model_name, device):
"""Load base Qwen3 model and tokenizer."""
print(f"Loading tokenizer: {model_name}")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
think_token_id = tokenizer.convert_tokens_to_ids("<think>")
think_end_token_id = tokenizer.convert_tokens_to_ids("</think>")
    if (
        think_token_id is None
        or think_end_token_id is None
        or think_token_id == tokenizer.unk_token_id
        or think_end_token_id == tokenizer.unk_token_id
    ):
        raise ValueError(
            f"<think>/</think> tokens not in tokenizer! "
            f"IDs: {think_token_id}, {think_end_token_id}"
        )
print(f"Loading model: {model_name}")
model = AutoModelForCausalLM.from_pretrained(
model_name,
torch_dtype=torch.bfloat16,
device_map={"": device},
trust_remote_code=True,
)
for param in model.parameters():
param.requires_grad = False
model.eval()
print(
f"Model loaded: {model.config.num_hidden_layers} layers, "
f"hidden size {model.config.hidden_size}"
)
return model, tokenizer, think_token_id, think_end_token_id
def build_extra_layers(base_model, checkpoint_config, extra_layers_state_dict, device):
"""Reconstruct extra transformer layers from checkpoint state dict."""
num_extra_layers = checkpoint_config.get("num_extra_layers", 0)
if num_extra_layers == 0 or extra_layers_state_dict is None:
return None
print(f"Reconstructing {num_extra_layers} extra transformer layer(s)...")
base_layer_class = base_model.model.layers[0].__class__
model_config = base_model.config
rotary_emb = getattr(base_model.model, "rotary_emb", None)
extra_layers = ExtraTransformerLayers(
base_layer_class, num_extra_layers, model_config, rotary_emb=rotary_emb
).to(device)
extra_layers.load_state_dict(extra_layers_state_dict)
extra_layers.eval()
param_count = sum(p.numel() for p in extra_layers.parameters())
print(f"Extra layers loaded ({param_count:,} parameters)")
return extra_layers
def generate_with_terminator(
prompt,
model,
tokenizer,
ffn,
extra_layers,
layer_idx,
think_token_id,
think_end_token_id,
threshold,
window_size,
exit_message,
max_tokens,
temperature,
device,
):
"""Generate a response with Terminator early-exit logic.
Follows the same generation pattern as inference_terminator.py:mode1_generate().
Streams thinking tokens to the terminal as they are produced.
"""
# Format prompt via chat template
messages = [{"role": "user", "content": prompt}]
prompt_text = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)
# Tokenize and append <think>
prompt_ids = tokenizer(
prompt_text, add_special_tokens=False, return_tensors="pt"
)["input_ids"].to(device).long()
input_ids = torch.cat(
[prompt_ids, torch.tensor([[think_token_id]], dtype=torch.long, device=device)],
dim=1,
)
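    # Appending <think> forces the model straight into thinking mode. (With
    # Qwen3's chat template the model would normally emit <think> on its own;
    # forcing it here makes the start of the reasoning span deterministic.)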
# Sampling processors
logits_processor = LogitsProcessorList([
TemperatureLogitsWarper(temperature=temperature),
TopKLogitsWarper(top_k=20),
TopPLogitsWarper(top_p=0.95),
])
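    # These match Qwen3's recommended thinking-mode sampling settings
    # (temperature 0.6, top-k 20, top-p 0.95). LogitsProcessorList applies
    # the warpers in list order: temperature, then top-k, then top-p.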
# Sliding-window state
predictions_list = []
reasoning_tokens = []
early_exit = False
# Start streaming thinking output
sys.stdout.write(f"\n{DIM}Thinking...\n")
sys.stdout.flush()
for step in range(max_tokens):
attention_mask = torch.ones_like(input_ids)
# Hook to capture hidden states from the target layer
captured = {}
def hook_fn(module, input, output):
if isinstance(output, tuple):
captured["hidden"] = output[0].detach()
else:
captured["hidden"] = output.detach()
target_layer = model.model.layers[layer_idx]
handle = target_layer.register_forward_hook(hook_fn)
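        # Note: with use_cache=False the full sequence is re-encoded at every
        # step. Simple and correct, but O(seq_len) extra work per token; a
        # KV-cache variant would only need the hidden state of the new token.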
with torch.no_grad():
outputs = model(
input_ids=input_ids,
attention_mask=attention_mask,
use_cache=False,
)
handle.remove()
hidden_states = captured["hidden"] # [1, seq_len, hidden_size]
# Make prediction once we have at least one thinking token
if len(reasoning_tokens) > 0:
            h = hidden_states.float()
            if extra_layers is not None:
                h = extra_layers(h, attention_mask=attention_mask)
            last_h = h[:, -1:, :]  # hidden state at the newest position
            logits_pred = ffn(last_h)
pred = torch.sigmoid(logits_pred)
predictions_list.append(pred[0, 0].item())
# Sliding-window majority vote
if len(predictions_list) >= window_size:
window = predictions_list[-window_size:]
n_above = sum(1 for p in window if p > threshold)
if n_above / window_size > 0.5:
early_exit = True
break
# Sample next token — LogitsProcessorList expects 2D [batch, vocab]
next_logits = outputs.logits[:, -1, :] # [1, vocab_size]
next_logits = logits_processor(input_ids, next_logits)
probs = F.softmax(next_logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1) # [1, 1]
# Natural </think>
if next_token.item() == think_end_token_id:
break
input_ids = torch.cat([input_ids, next_token], dim=1)
reasoning_tokens.append(next_token.item())
# Stream the token
token_text = tokenizer.decode([next_token.item()], skip_special_tokens=False)
sys.stdout.write(token_text)
sys.stdout.flush()
    # End thinking section (always close the DIM span, even on a natural end)
    if early_exit and exit_message:
        sys.stdout.write(exit_message)
    sys.stdout.write(f"{RESET}\n")
    sys.stdout.flush()
# Build input for final answer generation
if early_exit and exit_message:
exit_ids = tokenizer(
exit_message, add_special_tokens=False, return_tensors="pt"
)["input_ids"].to(device).long()
input_ids = torch.cat(
[input_ids, exit_ids,
torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
dim=1,
)
else:
input_ids = torch.cat(
[input_ids,
torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
dim=1,
)
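    # The sequence now reads:
    #   [chat prompt] <think> [reasoning tokens] [exit message?] </think>
    # so model.generate() below continues in answering mode after </think>.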
# Generate final answer
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
final_outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=max_tokens,
do_sample=True,
temperature=temperature,
top_p=0.95,
top_k=20,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
)
# Extract answer (everything after last </think>)
full_seq = final_outputs[0]
end_positions = (full_seq == think_end_token_id).nonzero(as_tuple=True)[0]
if len(end_positions) > 0:
answer_tokens = full_seq[end_positions[-1].item() + 1 :]
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
else:
answer = ""
# Print answer
sys.stdout.write(f"{BOLD}Answer:{RESET}\n{answer}\n")
sys.stdout.flush()
# Summary
n_reasoning = len(reasoning_tokens)
exit_reason = "predictor" if early_exit else "natural_end"
print(
f"\n{DIM}[{exit_reason} | "
f"{n_reasoning} thinking tokens | "
f"{len(predictions_list)} predictions]{RESET}"
)
def main():
parser = argparse.ArgumentParser(
description=__doc__,
formatter_class=argparse.RawDescriptionHelpFormatter,
)
parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
parser.add_argument(
"--model", type=str, default="Qwen/Qwen3-8B", help="HuggingFace model name"
)
parser.add_argument(
"--checkpoint",
type=str,
default=None,
help="Path to terminator .pt checkpoint (default: ./terminator.pt)",
)
parser.add_argument(
"--threshold", type=float, default=0.7, help="Per-prediction binarization threshold"
)
parser.add_argument(
"--window-size", type=int, default=10, help="Sliding-window size for majority vote"
)
parser.add_argument(
"--exit-message",
type=str,
default="\nI've run out of thinking tokens. I need to commit to a final answer.",
help="Message injected when terminator fires (empty string to disable)",
)
parser.add_argument(
"--max-tokens", type=int, default=32768, help="Max tokens to generate"
)
parser.add_argument(
"--temperature", type=float, default=0.6, help="Sampling temperature"
)
parser.add_argument(
"--device", type=str, default="cuda", help="Device (default: cuda)"
)
args = parser.parse_args()
# Resolve checkpoint path
if args.checkpoint is None:
args.checkpoint = str(_script_dir / "terminator.pt")
if not Path(args.checkpoint).exists():
print(f"ERROR: Checkpoint not found: {args.checkpoint}", file=sys.stderr)
sys.exit(1)
# Handle empty exit message
if args.exit_message == "":
args.exit_message = None
device = torch.device(args.device if torch.cuda.is_available() else "cpu")
# Load base model
model, tokenizer, think_id, think_end_id = load_model_and_tokenizer(
args.model, device
)
# Load terminator checkpoint
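    # Returns: the FFN head, the checkpoint's config dict, the index of the
    # base-model layer whose hidden states feed the head, the number of extra
    # transformer layers, and their state dict (None if trained without them).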
rms_eps = getattr(model.config, "rms_norm_eps", 1e-6)
ffn, ckpt_config, layer_idx, num_extra_layers, extra_sd = load_terminator_checkpoint(
args.checkpoint, rms_norm_eps=rms_eps, device=device
)
ffn_params = sum(p.numel() for p in ffn.parameters())
print(
f"Terminator FFN loaded (layer_idx={layer_idx}, "
f"threshold={args.threshold}, window={args.window_size}, "
f"params={ffn_params:,})"
)
# Extra layers
extra_layers = build_extra_layers(model, ckpt_config, extra_sd, device)
# Generate
generate_with_terminator(
prompt=args.prompt,
model=model,
tokenizer=tokenizer,
ffn=ffn,
extra_layers=extra_layers,
layer_idx=layer_idx,
think_token_id=think_id,
think_end_token_id=think_end_id,
threshold=args.threshold,
window_size=args.window_size,
exit_message=args.exit_message,
max_tokens=args.max_tokens,
temperature=args.temperature,
device=device,
)
if __name__ == "__main__":
main()