File size: 4,345 Bytes
88b9f90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa7a04b
88b9f90
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
Create a vLLM-ready model directory for Qwen3TerminatorForCausalLM.

Downloads the base Qwen3-8B config and weights from HuggingFace (if not
already cached), then creates a model directory with:
  - config.json        (Qwen3-8B base config + terminator fields)
  - tokenizer files    (symlinked from HF cache)
  - model weights      (symlinked from HF cache)

Usage:
    # Default: uses ./terminator.pt checkpoint, creates ./model_dir
    python setup_model_dir.py

    # Custom paths and settings:
    python setup_model_dir.py \\
        --checkpoint /path/to/terminator.pt \\
        --output-dir /path/to/model_dir \\
        --threshold 0.5
"""

import argparse
import os
import sys
from pathlib import Path

from huggingface_hub import snapshot_download
from transformers import AutoConfig


def _build_parser() -> argparse.ArgumentParser:
    """Build the CLI parser (module docstring is reused as the help epilog)."""
    parser = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    parser.add_argument(
        "--base-model", default="Qwen/Qwen3-8B",
        help="HuggingFace model ID for the base model (default: Qwen/Qwen3-8B).",
    )
    # NOTE: argparse applies `type=Path` to string defaults as well, so these
    # resolve to Path objects even when the flag is omitted.
    parser.add_argument(
        "--checkpoint", type=Path, default="./terminator.pt",
        help="Path to trained terminator .pt checkpoint (default: ./terminator.pt).",
    )
    parser.add_argument(
        "--output-dir", type=Path, default="./model_dir",
        help="Destination directory (default: ./model_dir; created if missing).",
    )
    parser.add_argument(
        "--threshold", type=float, default=0.7,
        help="Terminator firing threshold (default 0.7).",
    )
    parser.add_argument(
        "--window-size", type=int, default=10,
        help="Sliding window size for majority vote (default 10).",
    )
    parser.add_argument(
        "--exit-message", type=str,
        default="\nI've run out of thinking tokens. I need to commit to a final answer.",
        help="Message forced when terminator fires (default: standard exit message). "
             "Set to empty string to disable.",
    )
    parser.add_argument(
        "--no-download", action="store_true",
        help="Fail if the base model is not already cached locally "
             "(by default, downloads from HuggingFace if needed).",
    )
    parser.add_argument(
        "--force", action="store_true",
        help="Overwrite files in existing output directory.",
    )
    return parser


def _write_patched_config(
    base_model: str,
    checkpoint: Path,
    out_dir: Path,
    threshold: float,
    window_size: int,
    exit_message: str,
) -> None:
    """Load the base model's config, inject terminator fields, save to out_dir."""
    print(f"Loading config for {base_model} from HF cache...")
    config = AutoConfig.from_pretrained(base_model)

    config.architectures = ["Qwen3TerminatorForCausalLM"]
    config.terminator_checkpoint_path = str(checkpoint)
    config.terminator_threshold = threshold
    config.terminator_window_size = window_size
    config.terminator_exit_message = exit_message

    # Remove auto_map if present from an older span-predictor config
    if hasattr(config, "auto_map"):
        del config.auto_map

    config.save_pretrained(out_dir)
    print(f"  Wrote config.json -> {out_dir / 'config.json'}")


def _link_base_files(base_model: str, out_dir: Path, *, allow_download: bool,
                     force: bool) -> int:
    """Symlink weight/tokenizer files from the HF cache snapshot into out_dir.

    Returns the number of files linked. Skips config.json (we wrote our own)
    and, unless ``force`` is set, any destination that already exists.
    """
    print(f"Locating {base_model} in HF cache...")
    base_dir = Path(snapshot_download(base_model, local_files_only=not allow_download))
    print(f"  Found: {base_dir}")

    linked = 0
    for src in sorted(base_dir.iterdir()):
        if src.name == "config.json":
            continue  # we already wrote our own

        dst = out_dir / src.name
        # is_symlink() catches broken symlinks, for which exists() is False —
        # without it a dangling link would make symlink_to() raise.
        if dst.exists() or dst.is_symlink():
            if force:
                dst.unlink()
            else:
                continue

        dst.symlink_to(src)
        print(f"  Linked {src.name}")
        linked += 1
    return linked


def main():
    """Assemble a vLLM-ready model directory (see module docstring).

    Exits with status 1 if the terminator checkpoint file does not exist.
    """
    args = _build_parser().parse_args()

    checkpoint = args.checkpoint.resolve()
    out_dir = args.output_dir.resolve()

    if not checkpoint.is_file():
        print(f"ERROR: checkpoint not found: {checkpoint}", file=sys.stderr)
        sys.exit(1)

    out_dir.mkdir(parents=True, exist_ok=True)

    # --- Build patched config.json ---
    _write_patched_config(
        args.base_model, checkpoint, out_dir,
        args.threshold, args.window_size, args.exit_message,
    )

    # --- Symlink weights and tokenizer files from HF cache ---
    linked = _link_base_files(
        args.base_model, out_dir,
        allow_download=not args.no_download, force=args.force,
    )

    print(f"\nDone. Linked {linked} files into {out_dir}")
    print("\nTo start the server:")
    print("  ./start_server.sh")
    print("\nOr manually:")
    print(f"  VLLM_MODEL={out_dir} REASONING_PARSER=qwen3 python serve.py")

# Script entry point: `python setup_model_dir.py [options]`.
if __name__ == "__main__":
    main()