# Terminator-Qwen3-8B / setup_model_dir.py
# NOTE(review): the following lines are HuggingFace Hub page residue captured
# during extraction, preserved here as comments so the file stays valid Python:
#   uploaded by acnagle via huggingface_hub, commit aa7a04b (verified)
#!/usr/bin/env python3
"""
Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.
Downloads the base Qwen3-8B config and weights from HuggingFace (if not
already cached), then creates a model directory with:
- config.json (Qwen3-8B base config + terminator fields)
- tokenizer files (symlinked from HF cache)
- model weights (symlinked from HF cache)
Usage:
# Default: uses ./terminator.pt checkpoint, creates ./model_dir
python setup_model_dir.py
# Custom paths and settings:
python setup_model_dir.py \\
--checkpoint /path/to/terminator.pt \\
--output-dir /path/to/model_dir \\
--threshold 0.5
"""
import argparse
import os
import sys
from pathlib import Path
from huggingface_hub import snapshot_download
from transformers import AutoConfig
def _parse_args():
    """Build and run the CLI argument parser; see the module docstring for usage."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--base-model", default="Qwen/Qwen3-8B",
        help="HuggingFace model ID for the base model (default: Qwen/Qwen3-8B).",
    )
    parser.add_argument(
        "--checkpoint", type=Path, default="./terminator.pt",
        help="Path to trained terminator .pt checkpoint (default: ./terminator.pt).",
    )
    parser.add_argument(
        "--output-dir", type=Path, default="./model_dir",
        help="Destination directory (default: ./model_dir; created if missing).",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.7,
        help="Terminator firing threshold (default 0.7).",
    )
    parser.add_argument(
        "--window-size", type=int, default=10,
        help="Sliding window size for majority vote (default 10).",
    )
    parser.add_argument(
        "--exit-message", type=str,
        default="\nI've run out of thinking tokens. I need to commit to a final answer.",
        help="Message forced when terminator fires (default: standard exit message). "
             "Set to empty string to disable.",
    )
    parser.add_argument(
        "--no-download", action="store_true",
        help="Fail if the base model is not already cached locally "
             "(by default, downloads from HuggingFace if needed).",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Overwrite files in existing output directory.",
    )
    return parser.parse_args()


def _write_config(args, checkpoint, out_dir, local_only):
    """Write config.json into out_dir: base-model config plus the terminator fields.

    Args:
        args: parsed CLI namespace (base_model, threshold, window_size, exit_message).
        checkpoint: resolved Path to the terminator .pt checkpoint.
        out_dir: resolved destination directory (already created).
        local_only: if True, never hit the network — fail unless the config is cached.
    """
    print(f"Loading config for {args.base_model} from HF cache...")
    # Honor --no-download here too: without local_files_only, from_pretrained
    # would fetch the config from the Hub even when the flag was given.
    config = AutoConfig.from_pretrained(args.base_model, local_files_only=local_only)
    config.architectures = ["Qwen3TerminatorForCausalLM"]
    config.terminator_checkpoint_path = str(checkpoint)
    config.terminator_threshold = args.threshold
    config.terminator_window_size = args.window_size
    config.terminator_exit_message = args.exit_message
    # Remove auto_map if present from an older span-predictor config
    if hasattr(config, "auto_map"):
        del config.auto_map
    config.save_pretrained(out_dir)
    print(f" Wrote config.json -> {out_dir / 'config.json'}")


def _link_base_files(base_dir, out_dir, *, force):
    """Symlink every snapshot file except config.json into out_dir.

    Existing destination entries are skipped unless force is True, in which
    case they are unlinked and re-linked. Returns the number of links created.
    """
    linked = 0
    for src in sorted(base_dir.iterdir()):
        if src.name == "config.json":
            continue  # we already wrote our own
        dst = out_dir / src.name
        # is_symlink() catches broken symlinks, which exists() reports as absent
        if dst.exists() or dst.is_symlink():
            if not force:
                continue
            dst.unlink()
        dst.symlink_to(src)
        print(f" Linked {src.name}")
        linked += 1
    return linked


def main():
    """Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.

    Validates the checkpoint path, writes a patched config.json, then symlinks
    the base model's weights and tokenizer files from the HF cache. Exits with
    status 1 if the checkpoint is missing.
    """
    args = _parse_args()
    checkpoint = args.checkpoint.resolve()
    out_dir = args.output_dir.resolve()

    if not checkpoint.is_file():
        print(f"ERROR: checkpoint not found: {checkpoint}", file=sys.stderr)
        sys.exit(1)

    out_dir.mkdir(parents=True, exist_ok=True)

    # --no-download means every HF access must be served from the local cache.
    local_only = args.no_download

    # --- Build patched config.json ---
    _write_config(args, checkpoint, out_dir, local_only)

    # --- Symlink weights and tokenizer files from HF cache ---
    print(f"Locating {args.base_model} in HF cache...")
    base_dir = Path(snapshot_download(args.base_model, local_files_only=local_only))
    print(f" Found: {base_dir}")

    linked = _link_base_files(base_dir, out_dir, force=args.force)

    print(f"\nDone. Linked {linked} files into {out_dir}")
    print("\nTo start the server:")
    print(" ./start_server.sh")
    print("\nOr manually:")
    print(f" VLLM_MODEL={out_dir} REASONING_PARSER=qwen3 python serve.py")


if __name__ == "__main__":
    main()