"""
HuggingFace-native inference for Terminator-Qwen3-8B.

Loads the frozen Qwen3 base model + trained Terminator head (FFN + optional
extra transformer layers) directly via HuggingFace transformers.

Generates chain-of-thought reasoning token-by-token. The Terminator FFN
predicts when the final answer has been reached; when a sliding-window
majority vote exceeds the threshold, an exit message is injected and the
model transitions to answering mode.
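
For example, with --threshold 0.7 and --window-size 10, thinking stops once
more than 5 of the last 10 per-token exit probabilities exceed 0.7.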

Usage:
    python inference_hf.py --prompt "What is the sum of the first 100 natural numbers?"

    python inference_hf.py \\
        --prompt "Solve x^2 - 5x + 6 = 0" \\
        --model Qwen/Qwen3-8B \\
        --checkpoint terminator.pt \\
        --threshold 0.7 --window-size 10
"""

import argparse
import sys
from pathlib import Path

import torch
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import TopKLogitsWarper, TopPLogitsWarper, TemperatureLogitsWarper
from transformers.generation.logits_process import LogitsProcessorList

# Make the local vllm_terminator package importable.
_script_dir = Path(__file__).resolve().parent
sys.path.insert(0, str(_script_dir))
from vllm_terminator.terminator_head import load_terminator_checkpoint

# terminator_utils lives one level up, in the repository root.
_repo_root = _script_dir.parent
sys.path.insert(0, str(_repo_root))
from terminator_utils import ExtraTransformerLayers

# ANSI escape codes for terminal styling.
DIM = "\033[2m"
BOLD = "\033[1m"
RESET = "\033[0m"


def load_model_and_tokenizer(model_name, device):
    """Load base Qwen3 model and tokenizer."""
    print(f"Loading tokenizer: {model_name}")
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Resolve the <think>/</think> token ids that delimit the reasoning span.
    think_token_id = tokenizer.convert_tokens_to_ids("<think>")
    think_end_token_id = tokenizer.convert_tokens_to_ids("</think>")
    if (
        think_token_id is None
        or think_end_token_id is None
        or think_token_id == tokenizer.unk_token_id
        or think_end_token_id == tokenizer.unk_token_id
    ):
        raise ValueError(
            f"<think>/</think> tokens not in tokenizer! "
            f"IDs: {think_token_id}, {think_end_token_id}"
        )

    print(f"Loading model: {model_name}")
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map={"": device},
        trust_remote_code=True,
    )
    # The base model stays frozen; only the Terminator head carries trained weights.
    for param in model.parameters():
        param.requires_grad = False
    model.eval()

    print(
        f"Model loaded: {model.config.num_hidden_layers} layers, "
        f"hidden size {model.config.hidden_size}"
    )
    return model, tokenizer, think_token_id, think_end_token_id


def build_extra_layers(base_model, checkpoint_config, extra_layers_state_dict, device):
    """Reconstruct extra transformer layers from checkpoint state dict."""
    num_extra_layers = checkpoint_config.get("num_extra_layers", 0)
    if num_extra_layers == 0 or extra_layers_state_dict is None:
        return None

    print(f"Reconstructing {num_extra_layers} extra transformer layer(s)...")
    # Reuse the base model's own decoder-layer class, config, and rotary
    # embedding so the extra layers match the frozen backbone architecturally.
    base_layer_class = base_model.model.layers[0].__class__
    model_config = base_model.config
    rotary_emb = getattr(base_model.model, "rotary_emb", None)

    extra_layers = ExtraTransformerLayers(
        base_layer_class, num_extra_layers, model_config, rotary_emb=rotary_emb
    ).to(device)
    extra_layers.load_state_dict(extra_layers_state_dict)
    extra_layers.eval()

    param_count = sum(p.numel() for p in extra_layers.parameters())
    print(f"Extra layers loaded ({param_count:,} parameters)")
    return extra_layers


def generate_with_terminator(
    prompt,
    model,
    tokenizer,
    ffn,
    extra_layers,
    layer_idx,
    think_token_id,
    think_end_token_id,
    threshold,
    window_size,
    exit_message,
    max_tokens,
    temperature,
    device,
):
    """Generate a response with Terminator early-exit logic.

    Follows the same generation pattern as inference_terminator.py:mode1_generate().
    Streams thinking tokens to the terminal as they are produced.
    """
    messages = [{"role": "user", "content": prompt}]
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
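    # apply_chat_template renders the conversation in Qwen's chat format and,
    # with add_generation_prompt=True, appends the assistant header so
    # generation continues as the assistant.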

    prompt_ids = tokenizer(
        prompt_text, add_special_tokens=False, return_tensors="pt"
    )["input_ids"].to(device).long()

    input_ids = torch.cat(
        [prompt_ids, torch.tensor([[think_token_id]], dtype=torch.long, device=device)],
        dim=1,
    )
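    # Appending <think> manually forces thinking mode; this assumes the chat
    # template's generation prompt does not already emit a <think> token.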

    logits_processor = LogitsProcessorList([
        TemperatureLogitsWarper(temperature=temperature),
        TopKLogitsWarper(top_k=20),
        TopPLogitsWarper(top_p=0.95),
    ])
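    # top_k=20 and top_p=0.95 (with the default temperature of 0.6) match the
    # sampling settings recommended for Qwen3's thinking mode.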

    predictions_list = []
    reasoning_tokens = []
    early_exit = False

    sys.stdout.write(f"\n{DIM}Thinking...\n")
    sys.stdout.flush()

    for step in range(max_tokens):
        attention_mask = torch.ones_like(input_ids)

        # Capture the probed layer's hidden states via a forward hook.
        captured = {}

        def hook_fn(module, input, output):
            if isinstance(output, tuple):
                captured["hidden"] = output[0].detach()
            else:
                captured["hidden"] = output.detach()

        target_layer = model.model.layers[layer_idx]
        handle = target_layer.register_forward_hook(hook_fn)

        with torch.no_grad():
            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                use_cache=False,
            )

        handle.remove()
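        # Note: with use_cache=False the full prefix is re-encoded on every
        # step, which is quadratic in sequence length overall; simple, but far
        # slower than a KV-cache implementation would be.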

        hidden_states = captured["hidden"]

        # Score the last position with the Terminator head, optionally passing
        # the hidden state through the extra transformer layers first.
        if len(reasoning_tokens) > 0:
            if extra_layers is not None:
                h = extra_layers(hidden_states.float(), attention_mask=attention_mask)
                last_h = h[:, -1:, :]
            else:
                last_h = hidden_states[:, -1:, :]
            logits_pred = ffn(last_h.float())

            pred = torch.sigmoid(logits_pred)
            predictions_list.append(pred[0, 0].item())
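            # Each entry is the head's probability that the reasoning so far
            # already determines the final answer.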

            # Sliding-window majority vote: stop thinking once more than half
            # of the last window_size probabilities exceed the threshold.
            if len(predictions_list) >= window_size:
                window = predictions_list[-window_size:]
                n_above = sum(1 for p in window if p > threshold)
                if n_above / window_size > 0.5:
                    early_exit = True
                    break

        # Sample the next thinking token.
        next_logits = outputs.logits[:, -1, :]
        next_logits = logits_processor(input_ids, next_logits)
        probs = F.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        # The model closed its reasoning naturally with </think>.
        if next_token.item() == think_end_token_id:
            break

        input_ids = torch.cat([input_ids, next_token], dim=1)
        reasoning_tokens.append(next_token.item())

        # Stream the thinking token to the terminal as it is produced.
        token_text = tokenizer.decode([next_token.item()], skip_special_tokens=False)
        sys.stdout.write(token_text)
        sys.stdout.flush()

    # Close the dim "thinking" styling; if the terminator fired, surface the
    # injected exit message in the stream as well.
    if early_exit and exit_message:
        sys.stdout.write(exit_message)
    sys.stdout.write(f"{RESET}\n")
    sys.stdout.flush()

    # Append the exit message (if any) and </think> so the model switches to
    # answering mode.
    if early_exit and exit_message:
        exit_ids = tokenizer(
            exit_message, add_special_tokens=False, return_tensors="pt"
        )["input_ids"].to(device).long()
        input_ids = torch.cat(
            [input_ids, exit_ids,
             torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
            dim=1,
        )
    else:
        input_ids = torch.cat(
            [input_ids,
             torch.tensor([[think_end_token_id]], dtype=torch.long, device=device)],
            dim=1,
        )

    # Generate the final answer, continuing from the reasoning + </think>.
    attention_mask = torch.ones_like(input_ids)
    with torch.no_grad():
        final_outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=0.95,
            top_k=20,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )
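    # Unlike the manual loop above, model.generate uses its internal KV cache,
    # so the answer phase decodes at normal speed.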

    # Decode everything after the last </think> as the final answer.
    full_seq = final_outputs[0]
    end_positions = (full_seq == think_end_token_id).nonzero(as_tuple=True)[0]
    if len(end_positions) > 0:
        answer_tokens = full_seq[end_positions[-1].item() + 1 :]
        answer = tokenizer.decode(answer_tokens, skip_special_tokens=True)
    else:
        answer = ""

    sys.stdout.write(f"{BOLD}Answer:{RESET}\n{answer}\n")
    sys.stdout.flush()

    # Summarize how generation ended and how much thinking was done.
    n_reasoning = len(reasoning_tokens)
    exit_reason = "predictor" if early_exit else "natural_end"
    print(
        f"\n{DIM}[{exit_reason} | "
        f"{n_reasoning} thinking tokens | "
        f"{len(predictions_list)} predictions]{RESET}"
    )


def main():
    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("--prompt", type=str, required=True, help="Input prompt")
    parser.add_argument(
        "--model", type=str, default="Qwen/Qwen3-8B", help="HuggingFace model name"
    )
    parser.add_argument(
        "--checkpoint",
        type=str,
        default=None,
        help="Path to terminator .pt checkpoint (default: ./terminator.pt)",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.7, help="Per-prediction binarization threshold"
    )
    parser.add_argument(
        "--window-size", type=int, default=10, help="Sliding-window size for majority vote"
    )
    parser.add_argument(
        "--exit-message",
        type=str,
        default="\nI've run out of thinking tokens. I need to commit to a final answer.",
        help="Message injected when terminator fires (empty string to disable)",
    )
    parser.add_argument(
        "--max-tokens", type=int, default=32768, help="Max tokens to generate"
    )
    parser.add_argument(
        "--temperature", type=float, default=0.6, help="Sampling temperature"
    )
    parser.add_argument(
        "--device", type=str, default="cuda", help="Device (default: cuda)"
    )
    args = parser.parse_args()

    if args.checkpoint is None:
        args.checkpoint = str(_script_dir / "terminator.pt")

    if not Path(args.checkpoint).exists():
        print(f"ERROR: Checkpoint not found: {args.checkpoint}", file=sys.stderr)
        sys.exit(1)

    # An empty exit message disables injection entirely.
    if args.exit_message == "":
        args.exit_message = None

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    model, tokenizer, think_id, think_end_id = load_model_and_tokenizer(
        args.model, device
    )

    # Load the trained Terminator head; rms_norm_eps is taken from the base
    # model config so the head's normalization matches the backbone.
    rms_eps = getattr(model.config, "rms_norm_eps", 1e-6)
    ffn, ckpt_config, layer_idx, num_extra_layers, extra_sd = load_terminator_checkpoint(
        args.checkpoint, rms_norm_eps=rms_eps, device=device
    )
    ffn_params = sum(p.numel() for p in ffn.parameters())
    print(
        f"Terminator FFN loaded (layer_idx={layer_idx}, "
        f"threshold={args.threshold}, window={args.window_size}, "
        f"params={ffn_params:,})"
    )

    extra_layers = build_extra_layers(model, ckpt_config, extra_sd, device)

    generate_with_terminator(
        prompt=args.prompt,
        model=model,
        tokenizer=tokenizer,
        ffn=ffn,
        extra_layers=extra_layers,
        layer_idx=layer_idx,
        think_token_id=think_id,
        think_end_token_id=think_end_id,
        threshold=args.threshold,
        window_size=args.window_size,
        exit_message=args.exit_message,
        max_tokens=args.max_tokens,
        temperature=args.temperature,
        device=device,
    )


if __name__ == "__main__":
    main()
|
|