"""
Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.

Downloads the base Qwen3-14B config and weights from HuggingFace (if not
already cached), then creates a model directory with:
- config.json (Qwen3-14B base config + terminator fields)
- tokenizer files (symlinked from HF cache)
- model weights (symlinked from HF cache)

Usage:
    # Default: uses ./terminator.pt checkpoint, creates ./model_dir
    python setup_model_dir.py

    # Custom paths and settings:
    python setup_model_dir.py \\
        --checkpoint /path/to/terminator.pt \\
        --output-dir /path/to/model_dir \\
        --threshold 0.5
"""

import argparse
import os
import sys
from pathlib import Path

from huggingface_hub import snapshot_download
from transformers import AutoConfig


def main():
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--base-model", default="Qwen/Qwen3-14B",
        help="HuggingFace model ID for the base model (default: Qwen/Qwen3-14B).",
    )
    parser.add_argument(
        "--checkpoint", type=Path, default="./terminator.pt",
        help="Path to trained terminator .pt checkpoint (default: ./terminator.pt).",
    )
    parser.add_argument(
        "--output-dir", type=Path, default="./model_dir",
        help="Destination directory (default: ./model_dir; created if missing).",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.7,
        help="Terminator firing threshold (default: 0.7).",
    )
    parser.add_argument(
        "--window-size", type=int, default=10,
        help="Sliding window size for majority vote (default: 10).",
    )
    parser.add_argument(
        "--exit-message", type=str,
        default="\nI've run out of thinking tokens. I need to commit to a final answer.",
        help="Message forced when terminator fires (default: standard exit message). "
             "Set to empty string to disable.",
    )
    parser.add_argument(
        "--no-download", action="store_true",
        help="Fail if the base model is not already cached locally "
             "(by default, downloads from HuggingFace if needed).",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Overwrite files in existing output directory.",
    )
    args = parser.parse_args()
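    # Resolve to absolute paths up front: the checkpoint path gets written into
    # config.json, and symlinks must stay valid from any working directory.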
    checkpoint = args.checkpoint.resolve()
    out_dir = args.output_dir.resolve()

    if not checkpoint.is_file():
        print(f"ERROR: checkpoint not found: {checkpoint}", file=sys.stderr)
        sys.exit(1)

    out_dir.mkdir(parents=True, exist_ok=True)
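    # Build config.json: start from the base Qwen3 config, then add the
    # terminator-specific fields consumed by Qwen3TerminatorForCausalLM.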
    print(f"Loading config for {args.base_model} from HF cache...")
    # local_files_only honors --no-download for the config fetch as well.
    config = AutoConfig.from_pretrained(args.base_model, local_files_only=args.no_download)

    config.architectures = ["Qwen3TerminatorForCausalLM"]
    config.terminator_checkpoint_path = str(checkpoint)
    config.terminator_threshold = args.threshold
    config.terminator_window_size = args.window_size
    config.terminator_exit_message = args.exit_message
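    # Drop any auto_map inherited from the base config, so the saved
    # config.json does not point at the base model's remote code.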
    if hasattr(config, "auto_map"):
        del config.auto_map

    config.save_pretrained(out_dir)
    print(f" Wrote config.json -> {out_dir / 'config.json'}")
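    # Symlink everything else (weights, tokenizer, etc.) from the HF snapshot
    # so the model directory is complete without copying the weights.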
| | print(f"Locating {args.base_model} in HF cache...") |
| | allow_download = not args.no_download |
| | base_dir = Path(snapshot_download(args.base_model, local_files_only=not allow_download)) |
| | print(f" Found: {base_dir}") |
| |
|
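    # Skip config.json: the patched one written above must not be overwritten.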
    linked = 0
    for src in sorted(base_dir.iterdir()):
        if src.name in ("config.json",):
            continue

        dst = out_dir / src.name
        if dst.exists() or dst.is_symlink():
            if args.force:
                dst.unlink()
            else:
                continue

        os.symlink(src, dst)
        print(f" Linked {src.name}")
        linked += 1
| | print(f"\nDone. Linked {linked} files into {out_dir}") |
| | print(f"\nTo start the server:") |
| | print(f" ./start_server.sh") |
| | print(f"\nOr manually:") |
| | print(f" VLLM_MODEL={out_dir} REASONING_PARSER=qwen3 python serve.py") |
| |
|
| |
|
| | if __name__ == "__main__": |
| | main() |
| |
|