#!/usr/bin/env python3
"""
Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.
Downloads the base Qwen3-8B config and weights from HuggingFace (if not
already cached), then creates a model directory with:
- config.json (Qwen3-8B base config + terminator fields)
- tokenizer files (symlinked from HF cache)
- model weights (symlinked from HF cache)
Usage:
# Default: uses ./terminator.pt checkpoint, creates ./model_dir
python setup_model_dir.py
# Custom paths and settings:
python setup_model_dir.py \\
--checkpoint /path/to/terminator.pt \\
--output-dir /path/to/model_dir \\
--threshold 0.5
"""
import argparse
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
from transformers import AutoConfig
def _parse_args():
    """Define and parse the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--base-model", default="Qwen/Qwen3-8B",
        help="HuggingFace model ID for the base model (default: Qwen/Qwen3-8B).",
    )
    parser.add_argument(
        "--checkpoint", type=Path, default="./terminator.pt",
        help="Path to trained terminator .pt checkpoint (default: ./terminator.pt).",
    )
    parser.add_argument(
        "--output-dir", type=Path, default="./model_dir",
        help="Destination directory (default: ./model_dir; created if missing).",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.7,
        help="Terminator firing threshold (default 0.7).",
    )
    parser.add_argument(
        "--window-size", type=int, default=10,
        help="Sliding window size for majority vote (default 10).",
    )
    parser.add_argument(
        "--exit-message", type=str,
        default="\nI've run out of thinking tokens. I need to commit to a final answer.",
        help="Message forced when terminator fires (default: standard exit message). "
             "Set to empty string to disable.",
    )
    parser.add_argument(
        "--no-download", action="store_true",
        help="Fail if the base model is not already cached locally "
             "(by default, downloads from HuggingFace if needed).",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Overwrite files in existing output directory.",
    )
    return parser.parse_args()


def _write_patched_config(args, checkpoint, out_dir):
    """Write config.json into *out_dir*: base model config + terminator fields.

    Loads the base config from the HF cache (downloading if needed), sets the
    custom architecture and the terminator_* fields read by the serving side,
    and saves the result.
    """
    print(f"Loading config for {args.base_model} from HF cache...")
    config = AutoConfig.from_pretrained(args.base_model)
    config.architectures = ["Qwen3TerminatorForCausalLM"]
    config.terminator_checkpoint_path = str(checkpoint)
    config.terminator_threshold = args.threshold
    config.terminator_window_size = args.window_size
    config.terminator_exit_message = args.exit_message
    # Remove auto_map if present from an older span-predictor config; a stale
    # mapping would shadow the architectures entry we just set.
    if hasattr(config, "auto_map"):
        del config.auto_map
    config.save_pretrained(out_dir)
    print(f" Wrote config.json -> {out_dir / 'config.json'}")


def _link_snapshot_files(base_dir, out_dir, *, force):
    """Symlink every snapshot file except config.json from *base_dir* into *out_dir*.

    Existing destination entries are left alone unless *force* is true, in
    which case they are replaced. Returns the number of links created.
    """
    linked = 0
    for src in sorted(base_dir.iterdir()):
        if src.name == "config.json":
            continue  # we already wrote our own patched config
        dst = out_dir / src.name
        # exists() follows symlinks, so is_symlink() is also needed to catch
        # a dangling link left behind by a previous run.
        if dst.exists() or dst.is_symlink():
            if not force:
                continue
            dst.unlink()
        dst.symlink_to(src)
        print(f" Linked {src.name}")
        linked += 1
    return linked


def main():
    """Assemble a vLLM-ready model directory for Qwen3TerminatorForCausalLM.

    Validates the checkpoint path, writes a patched config.json, then
    symlinks the base model's weights and tokenizer files from the local
    HuggingFace cache into the output directory. Exits with status 1 if the
    checkpoint file does not exist.
    """
    args = _parse_args()
    checkpoint = args.checkpoint.resolve()
    out_dir = args.output_dir.resolve()

    if not checkpoint.is_file():
        print(f"ERROR: checkpoint not found: {checkpoint}", file=sys.stderr)
        sys.exit(1)
    out_dir.mkdir(parents=True, exist_ok=True)

    _write_patched_config(args, checkpoint, out_dir)

    # --- Symlink weights and tokenizer files from HF cache ---
    print(f"Locating {args.base_model} in HF cache...")
    # --no-download maps directly onto local_files_only: cache-only lookup.
    base_dir = Path(snapshot_download(args.base_model, local_files_only=args.no_download))
    print(f" Found: {base_dir}")

    linked = _link_snapshot_files(base_dir, out_dir, force=args.force)

    print(f"\nDone. Linked {linked} files into {out_dir}")
    print("\nTo start the server:")
    print(" ./start_server.sh")
    print("\nOr manually:")
    print(f" VLLM_MODEL={out_dir} REASONING_PARSER=qwen3 python serve.py")


if __name__ == "__main__":
    main()