#!/usr/bin/env python3
"""
serve_vllm.py — VENUE ONLY (Prime Intellect, CUDA GPU). DOES NOT RUN ON THE MAC.
This is a thin, documented wrapper that prints (and optionally execs) the exact
`vllm serve` command for three configs:
  1. baseline — Laguna XS.2 alone (the speed floor).
  2. dflash   — Laguna XS.2 + the DFlash speculator (the speed we're claiming).
  3. quant    — a quantized Laguna checkpoint (FP8/INT4/NVFP4) + FP8 KV cache.
                This is the FALLBACK lane (see FALLBACK_QUANT.md): if DFlash hits
                a vLLM-version/draft-model snag at the venue, a quantized weights
                checkpoint still tells a clean single-GPU story (smaller footprint,
                FP8 KV cache ~doubles concurrent trajectories per the [TR]).
baseline vs dflash are IDENTICAL except for --speculative-config — flip one flag,
get faster tokens, same greedy output. quant is a different lever (shrink each
pass instead of cutting passes); the two can stack, but the fallback keeps it
simple with quant alone.
Grounding (cite at the demo):
- DFlash config shape is from the HF model card
  huggingface.co/poolside/Laguna-XS.2-speculator.dflash:
    --speculative-config '{"model":"poolside/Laguna-XS.2-speculator.dflash",
                           "num_speculative_tokens":7,"method":"dflash"}'
- num_speculative_tokens = 7 is the card's value (this is gamma, the draft length).
- vLLM >= 0.21.0 and VLLM_USE_DEEP_GEMM=0 per the card.
- parsers --tool-call-parser poolside_v1 / --reasoning-parser poolside_v1 per the card.
VERIFY AT ONBOARDING: exact vLLM version on the PI image, whether
--trust-remote-code is required, and whether `method` is spelled "dflash"
in the build you get. The card is authoritative; confirm against `vllm serve --help`.
Usage (on Prime Intellect):
    python scripts/serve_vllm.py --mode baseline --print   # show the command
    python scripts/serve_vllm.py --mode dflash   --run     # actually serve
"""
from __future__ import annotations
import argparse
import json
import os
import shlex
import subprocess
import sys
# Base checkpoint and its DFlash draft model; each can be overridden via env.
MODEL = os.getenv("LAGUNA_MODEL", "poolside/Laguna-XS.2")
SPECULATOR = os.getenv("LAGUNA_SPECULATOR", "poolside/Laguna-XS.2-speculator.dflash")

# Gamma (the draft length), taken from the DFlash model card.
NUM_SPECULATIVE_TOKENS = 7

# Fallback-lane quantized checkpoints, keyed by quant format. Per the [TR], XS.2
# ships FP8 (W8A8), INT4 (W4A16/AWQ) and NVFP4 quants in the HF collection. The
# repo names below are documented placeholders, NOT confirmed pre-event:
# VERIFY AT ONBOARDING against huggingface.co/collections/poolside/laguna-xs2,
# or override each one through its environment variable.
QUANT_MODELS = dict(
    fp8=os.getenv("LAGUNA_FP8_MODEL", "poolside/Laguna-XS.2-FP8"),
    int4=os.getenv("LAGUNA_INT4_MODEL", "poolside/Laguna-XS.2-INT4"),
    nvfp4=os.getenv("LAGUNA_NVFP4_MODEL", "poolside/Laguna-XS.2-NVFP4"),
)
def build_cmd(mode: str, max_model_len: int, tp: int, quant: str) -> list[str]:
    """Assemble the `vllm serve` argv for one serving configuration.

    Args:
        mode: "baseline", "dflash" or "quant". No validation happens here;
            any other value produces the same argv as "baseline".
        max_model_len: Value passed to --max-model-len.
        tp: Value passed to --tensor-parallel-size.
        quant: Key into QUANT_MODELS; only looked up when mode is "quant".

    Returns:
        The argv list, beginning with "vllm", "serve" and the model name.

    Raises:
        KeyError: If mode is "quant" and quant is not a key of QUANT_MODELS.
    """
    checkpoint = QUANT_MODELS[quant] if mode == "quant" else MODEL

    # enable_thinking: the Laguna chat template defaults this to FALSE. It stays
    # false here so rollouts/decode are non-thinking (fewer tokens, faster) and
    # the greedy A/B stays clean. Set LAGUNA_ENABLE_THINKING=true to override.
    # NOTE: the hosted pinference endpoint IGNORES this flag (verified — see
    # autoresearch/findings.md); it only takes effect on a self-served vLLM.
    thinking_flag = os.environ.get("LAGUNA_ENABLE_THINKING", "false")
    template_kwargs = json.dumps({"enable_thinking": thinking_flag.lower() == "true"})

    argv = ["vllm", "serve", checkpoint]
    argv.extend(["--tensor-parallel-size", str(tp)])
    argv.extend(["--max-model-len", str(max_model_len)])
    argv.extend(["--served-model-name", "laguna"])
    # Poolside-specific parsers, as given on the model card.
    argv.extend(["--tool-call-parser", "poolside_v1"])
    argv.extend(["--reasoning-parser", "poolside_v1"])
    argv.append("--enable-auto-tool-choice")
    argv.extend(["--default-chat-template-kwargs", template_kwargs])

    if mode == "dflash":
        # The only difference from baseline: attach the DFlash draft model.
        speculative = {
            "model": SPECULATOR,
            "num_speculative_tokens": NUM_SPECULATIVE_TOKENS,
            "method": "dflash",
        }
        argv.extend(["--speculative-config", json.dumps(speculative)])
    elif mode == "quant":
        # FP8 KV cache is the high-leverage single-GPU win ([TR]: ~2x concurrent
        # trajectories). Weight quantization is auto-detected from the
        # checkpoint config, so no extra weight flag is passed.
        argv.extend(["--kv-cache-dtype", "fp8"])

    return argv
def main() -> None:
    """Parse CLI flags, print the `vllm serve` command, and optionally exec it.

    The command is always printed to stdout, prefixed with
    ``VLLM_USE_DEEP_GEMM=0``. With ``--run`` the current process is then
    replaced by ``vllm serve`` with that variable set in its environment.
    On macOS a notice is written to stderr up front and ``--run`` exits with
    status 2; ``--print`` still works there for inspection.

    Raises:
        SystemExit: Status 2 for ``--run`` on macOS (argparse also exits with
            2 on usage errors); status 127 when the ``vllm`` executable
            cannot be found on PATH.
    """
    on_mac = sys.platform == "darwin"
    if on_mac:
        print("[serve_vllm] REFUSING TO RUN: this is a Mac. vLLM needs CUDA.\n"
              " Run this on Prime Intellect. Use --print to inspect the command here.",
              file=sys.stderr)
        # Still allow --print on Mac for inspection; --run is blocked below.

    p = argparse.ArgumentParser(description="Print/run the vLLM serve command for Laguna (baseline / dflash / quant).")
    p.add_argument("--mode", choices=["baseline", "dflash", "quant"], required=True)
    p.add_argument("--quant", choices=["fp8", "int4", "nvfp4"], default="fp8",
                   help="Quant format for --mode quant (the fallback lane). Default fp8.")
    p.add_argument("--max-model-len", type=int, default=16384,
                   help="Card example uses 16384; raise toward 131072/262144 if VRAM allows. Verify at onboarding.")
    p.add_argument("--tensor-parallel-size", type=int, default=1,
                   help="Single GPU = 1. The whole hook is one-GPU serving.")
    g = p.add_mutually_exclusive_group(required=True)
    g.add_argument("--print", action="store_true", help="Print the command only.")
    g.add_argument("--run", action="store_true", help="Actually exec vllm serve (venue only).")
    args = p.parse_args()

    cmd = build_cmd(args.mode, args.max_model_len, args.tensor_parallel_size, args.quant)
    env_prefix = "VLLM_USE_DEEP_GEMM=0"
    # shlex.quote keeps the JSON arguments copy-pasteable into a shell.
    printable = f"{env_prefix} " + " ".join(shlex.quote(c) for c in cmd)
    print(printable)

    if not args.run:
        return

    if on_mac:
        print("[serve_vllm] --run blocked on Mac.", file=sys.stderr)
        sys.exit(2)

    env = dict(os.environ)
    env["VLLM_USE_DEEP_GEMM"] = "0"  # per the model card

    # exec replaces this process immediately and does not flush Python's
    # buffered streams; without this the command printed above can be lost
    # when stdout is piped or redirected to a log file.
    sys.stdout.flush()
    sys.stderr.flush()
    try:
        os.execvpe(cmd[0], cmd, env)
    except FileNotFoundError:
        # Give a readable message instead of a traceback when vLLM is absent.
        print(f"[serve_vllm] '{cmd[0]}' not found on PATH; is vLLM installed in this environment?",
              file=sys.stderr)
        sys.exit(127)
# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()