File size: 6,340 Bytes
0a55ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8beaa8b
 
 
 
 
 
 
0a55ff6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#!/usr/bin/env python3
"""
serve_vllm.py — VENUE ONLY (Prime Intellect, CUDA GPU). DOES NOT RUN ON THE MAC.

This is a thin, documented wrapper that prints (and optionally execs) the exact
`vllm serve` command for three configs:

  1. baseline  — Laguna XS.2 alone (the speed floor).
  2. dflash    — Laguna XS.2 + the DFlash speculator (the speed we're claiming).
  3. quant     — a quantized Laguna checkpoint (FP8/INT4/NVFP4) + FP8 KV cache.
                 This is the FALLBACK lane (see FALLBACK_QUANT.md): if DFlash hits
                 a vLLM-version/draft-model snag at the venue, a quantized weights
                 checkpoint still tells a clean single-GPU story (smaller footprint,
                 FP8 KV cache ~doubles concurrent trajectories per the [TR]).

baseline vs dflash are IDENTICAL except for --speculative-config — flip one flag,
get faster tokens, same greedy output. quant is a different lever (shrink each
pass instead of cutting passes); the two can stack, but the fallback keeps it
simple with quant alone.

Grounding (cite at the demo):
  - DFlash config shape is from the HF model card
    huggingface.co/poolside/Laguna-XS.2-speculator.dflash:
        --speculative-config '{"model":"poolside/Laguna-XS.2-speculator.dflash",
                               "num_speculative_tokens":7,"method":"dflash"}'
  - num_speculative_tokens = 7 is the card's value (this is gamma, the draft length).
  - vLLM >= 0.21.0 and VLLM_USE_DEEP_GEMM=0 per the card.
  - parsers --tool-call-parser poolside_v1 / --reasoning-parser poolside_v1 per the card.

VERIFY AT ONBOARDING: exact vLLM version on the PI image, whether
--trust-remote-code is required, and whether `method` is spelled "dflash"
in the build you get. The card is authoritative; confirm against `vllm serve --help`.

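Usage (on Prime Intellect; --print also works on a Mac for inspection):
  python scripts/serve_vllm.py --mode baseline --print     # show the command
  python scripts/serve_vllm.py --mode dflash --run         # actually serve
  python scripts/serve_vllm.py --mode quant --quant fp8 --run   # fallback lane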
"""
from __future__ import annotations

import argparse
import json
import os
import shlex
import subprocess
import sys

# Target checkpoint and its DFlash draft model. Both can be overridden from the
# environment (LAGUNA_MODEL / LAGUNA_SPECULATOR) without editing this file.
MODEL = os.getenv("LAGUNA_MODEL", "poolside/Laguna-XS.2")
SPECULATOR = os.getenv("LAGUNA_SPECULATOR", "poolside/Laguna-XS.2-speculator.dflash")

# gamma, i.e. how many tokens the draft model proposes per step. Value taken
# from the DFlash model card.
NUM_SPECULATIVE_TOKENS = 7

# Fallback-lane checkpoints, keyed by quant format. Per the [TR], XS.2 ships
# FP8 (W8A8), INT4 (W4A16/AWQ) and NVFP4 quants in the HF collection. The repo
# names below are documented PLACEHOLDERS, not confirmed pre-event: VERIFY AT
# ONBOARDING against huggingface.co/collections/poolside/laguna-xs2, or point
# the matching environment variable at the real repo.
QUANT_MODELS = {
    fmt: os.getenv(env_var, placeholder)
    for fmt, env_var, placeholder in (
        ("fp8", "LAGUNA_FP8_MODEL", "poolside/Laguna-XS.2-FP8"),
        ("int4", "LAGUNA_INT4_MODEL", "poolside/Laguna-XS.2-INT4"),
        ("nvfp4", "LAGUNA_NVFP4_MODEL", "poolside/Laguna-XS.2-NVFP4"),
    )
}


def build_cmd(mode: str, max_model_len: int, tp: int, quant: str) -> list[str]:
    """Assemble the argv list for `vllm serve` in the requested mode.

    Args:
        mode: "dflash" appends --speculative-config; "quant" serves the
            checkpoint from QUANT_MODELS and appends --kv-cache-dtype fp8;
            any other value yields the plain command for MODEL.
        max_model_len: Value passed to --max-model-len.
        tp: Value passed to --tensor-parallel-size.
        quant: Key into QUANT_MODELS; only looked up when mode is "quant".

    Returns:
        The command as a list of arguments, starting with "vllm".

    Raises:
        KeyError: If mode is "quant" and quant is not a key of QUANT_MODELS.
    """
    serving_quant = mode == "quant"
    if serving_quant:
        checkpoint = QUANT_MODELS[quant]
    else:
        checkpoint = MODEL

    # enable_thinking: the Laguna chat template defaults this FALSE. Keep it false so
    # rollouts/decode are non-thinking (fewer tokens, faster) and the greedy A/B stays clean.
    # NOTE: the hosted pinference endpoint IGNORES this flag (verified — see
    # autoresearch/findings.md); it only takes effect on a self-served vLLM like this one.
    # Override with LAGUNA_ENABLE_THINKING=true.
    thinking_on = os.environ.get("LAGUNA_ENABLE_THINKING", "false").lower() == "true"
    template_kwargs = json.dumps({"enable_thinking": thinking_on})

    argv = ["vllm", "serve", checkpoint]
    argv.extend(("--tensor-parallel-size", str(tp)))
    argv.extend(("--max-model-len", str(max_model_len)))
    argv.extend(("--served-model-name", "laguna"))
    # Poolside-specific parsers (from the model card).
    argv.extend(("--tool-call-parser", "poolside_v1"))
    argv.extend(("--reasoning-parser", "poolside_v1"))
    argv.append("--enable-auto-tool-choice")
    argv.extend(("--default-chat-template-kwargs", template_kwargs))

    if mode == "dflash":
        # The only difference from baseline: attach the draft model.
        draft_config = json.dumps({
            "model": SPECULATOR,
            "num_speculative_tokens": NUM_SPECULATIVE_TOKENS,
            "method": "dflash",
        })
        argv.extend(("--speculative-config", draft_config))
    elif serving_quant:
        # FP8 KV cache is the high-leverage single-GPU win ([TR]: ~2x concurrent
        # trajectories). Weight quant is auto-detected from the checkpoint config.
        argv.extend(("--kv-cache-dtype", "fp8"))

    return argv


def main() -> None:
    """Parse CLI flags, print the `vllm serve` command, and optionally exec it.

    The command is always echoed to stdout, prefixed with VLLM_USE_DEEP_GEMM=0,
    so it can be copy-pasted. With --print nothing else happens. With --run the
    current process is replaced by vllm via os.execvpe, with
    VLLM_USE_DEEP_GEMM=0 set in its environment.

    Exits with status 2 when --run is requested on macOS, and with status 127
    when the vllm executable cannot be found on PATH.
    """
    p = argparse.ArgumentParser(description="Print/run the vLLM serve command for Laguna (baseline / dflash / quant).")
    p.add_argument("--mode", choices=["baseline", "dflash", "quant"], required=True)
    p.add_argument("--quant", choices=["fp8", "int4", "nvfp4"], default="fp8",
                   help="Quant format for --mode quant (the fallback lane). Default fp8.")
    p.add_argument("--max-model-len", type=int, default=16384,
                   help="Card example uses 16384; raise toward 131072/262144 if VRAM allows. Verify at onboarding.")
    p.add_argument("--tensor-parallel-size", type=int, default=1,
                   help="Single GPU = 1. The whole hook is one-GPU serving.")
    g = p.add_mutually_exclusive_group(required=True)
    g.add_argument("--print", action="store_true", help="Print the command only.")
    g.add_argument("--run", action="store_true", help="Actually exec vllm serve (venue only).")
    args = p.parse_args()

    cmd = build_cmd(args.mode, args.max_model_len, args.tensor_parallel_size, args.quant)
    env_prefix = "VLLM_USE_DEEP_GEMM=0"
    printable = f"{env_prefix} " + " ".join(shlex.quote(c) for c in cmd)
    print(printable)

    if not args.run:
        return

    # --print is allowed on a Mac for inspection; only --run is blocked, so the
    # refusal is reported here rather than before argument parsing.
    if sys.platform == "darwin":
        print("[serve_vllm] REFUSING TO RUN: this is a Mac. vLLM needs CUDA.\n"
              "             Run this on Prime Intellect. Use --print to inspect the command here.",
              file=sys.stderr)
        print("[serve_vllm] --run blocked on Mac.", file=sys.stderr)
        sys.exit(2)

    env = dict(os.environ)
    env["VLLM_USE_DEEP_GEMM"] = "0"  # per the model card

    # exec replaces the process image without flushing Python's file buffers.
    # Flush first, or the command printed above is lost whenever stdout is a
    # pipe or a file rather than a terminal.
    sys.stdout.flush()
    sys.stderr.flush()
    try:
        os.execvpe(cmd[0], cmd, env)
    except FileNotFoundError:
        print(f"[serve_vllm] '{cmd[0]}' not found on PATH; is vLLM installed in this environment?",
              file=sys.stderr)
        sys.exit(127)


# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()