"""
Full server warmup to pre-warm Triton autotuning and CUDA graph capture.
On cold H200 nodes (new nodes or after container recreation), CUDA graph capture
triggers Triton autotuning which takes ~330s per server launch. This script
launches actual servers with CUDA graphs enabled to cache the autotuned kernels,
so subsequent test launches are fast (~30-60s).
Uses marker files to skip warmup on already-warm nodes. Marker files are
invalidated when Python, Triton, or PyTorch versions change.
Usage:
python3 scripts/ci/cuda/warmup_server.py \
deepseek-ai/DeepSeek-V3-0324:8 \
inclusionAI/Ring-2.5-1T:8
"""
import hashlib
import json
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path

# Reuse helpers from warmup_deep_gemm (same directory)
sys.path.insert(0, os.path.dirname(__file__))
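# (Assumed behavior, inferred from how they are used below: get_config_json
# returns the model's parsed config.json or None, and get_architecture_key
# derives a dedup key from architecture-relevant config fields plus TP size.)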
from warmup_deep_gemm import get_architecture_key, get_config_json

MARKER_DIR = os.path.join(os.path.expanduser("~"), ".cache", "sglang", "warmup_markers")
HEALTH_POLL_INTERVAL = 10 # seconds between health checks
SERVER_STARTUP_TIMEOUT = 900 # 15 min max to wait for server ready
DEFAULT_PORT = 39876  # starting port; main() bumps it per model to avoid bind conflicts


def get_version_key():
"""Hash of Python + Triton + PyTorch versions to invalidate markers on upgrades."""
parts = [sys.version]
try:
import triton
parts.append(f"triton={triton.__version__}")
except ImportError:
parts.append("triton=none")
try:
import torch
parts.append(f"torch={torch.__version__}")
except ImportError:
parts.append("torch=none")
    return hashlib.sha256("|".join(parts).encode()).hexdigest()[:12]


def get_marker_path(model, tp):
"""Get the marker file path for a model:tp pair."""
version_key = get_version_key()
safe_model = model.replace("/", "--")
return os.path.join(
MARKER_DIR, f"server_warmup_{safe_model}_tp{tp}_{version_key}.done"
    )


def check_marker(model, tp):
"""Check if warmup marker exists (node already warm)."""
marker = get_marker_path(model, tp)
    return os.path.exists(marker)


def write_marker(model, tp):
"""Write warmup marker after successful warmup."""
marker = get_marker_path(model, tp)
os.makedirs(os.path.dirname(marker), exist_ok=True)
Path(marker).write_text(
json.dumps(
{
"model": model,
"tp": tp,
"version_key": get_version_key(),
"timestamp": time.time(),
}
)
)
print(f" Wrote marker: {marker}")
def kill_server(proc):
"""Kill server process tree."""
if proc.poll() is not None:
return
try:
os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
except (ProcessLookupError, OSError):
pass
try:
proc.wait(timeout=15)
except subprocess.TimeoutExpired:
try:
os.killpg(os.getpgid(proc.pid), signal.SIGKILL)
except (ProcessLookupError, OSError):
pass
try:
proc.wait(timeout=5)
except subprocess.TimeoutExpired:
            pass


def wait_for_server(base_url, proc, timeout):
"""Poll /health_generate until server is ready or timeout."""
import requests
start = time.time()
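    # Check the process first so a crashed server fails fast with its exit
    # code instead of silently burning the whole timeout.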
while time.time() - start < timeout:
ret = proc.poll()
if ret is not None:
return False, f"Server exited with code {ret}"
try:
resp = requests.get(f"{base_url}/health_generate", timeout=5)
if resp.status_code == 200:
return True, None
except requests.RequestException:
pass
time.sleep(HEALTH_POLL_INTERVAL)
return False, "Timed out waiting for server"
def send_generate_request(base_url):
"""Send one /generate request to exercise the full inference path."""
import requests
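    # A tiny deterministic request (4 prompt tokens, 8 new tokens, greedy
    # sampling) should be enough to exercise prefill and decode once.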
payload = {
"input_ids": [0, 1, 2, 3],
"sampling_params": {
"max_new_tokens": 8,
"temperature": 0,
},
}
try:
resp = requests.post(f"{base_url}/generate", json=payload, timeout=120)
if resp.status_code == 200:
print(" Generate request succeeded")
else:
print(f" Warning: generate request returned {resp.status_code}")
except requests.RequestException as e:
print(f" Warning: generate request failed: {e}")
def warmup_one_model(model, tp, port):
"""Launch server, wait for ready, send one request, then kill."""
base_url = f"http://127.0.0.1:{port}"
cmd = [
sys.executable,
"-m",
"sglang.launch_server",
"--model-path",
model,
"--tp",
str(tp),
"--host",
"127.0.0.1",
"--port",
str(port),
"--trust-remote-code",
"--model-loader-extra-config",
'{"enable_multithread_load": true, "num_threads": 64}',
]
# Use a temp file for server output to avoid pipe buffer deadlock
# (server logs can exceed the 64KB pipe buffer during CUDA graph capture)
log_file = tempfile.NamedTemporaryFile(
mode="w", prefix="warmup_server_", suffix=".log", delete=False
)
log_path = log_file.name
print(f" Launching server: {' '.join(cmd)}")
print(f" Server log: {log_path}")
proc = subprocess.Popen(
cmd,
stdout=log_file,
stderr=subprocess.STDOUT,
preexec_fn=os.setsid,
)
try:
# Wait for server to be ready (includes CUDA graph capture)
print(
f" Waiting for server (timeout={SERVER_STARTUP_TIMEOUT}s, "
f"polling every {HEALTH_POLL_INTERVAL}s)..."
)
ok, err = wait_for_server(base_url, proc, SERVER_STARTUP_TIMEOUT)
if not ok:
print(f" Warning: server not ready: {err}")
# Dump last lines of server log for debugging
try:
                # The server writes to the log fd directly; reopen the path to
                # read whatever it has emitted so far.
                with open(log_path) as f:
lines = f.readlines()
for line in lines[-20:]:
print(f" | {line.rstrip()}")
except Exception:
pass
return False
print(" Server ready, sending generate request...")
send_generate_request(base_url)
return True
finally:
print(" Killing server...")
kill_server(proc)
log_file.close()
try:
os.unlink(log_path)
except OSError:
            pass


def main():
if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"):
print("Usage: warmup_server.py model1:tp1 [model2:tp2 ...]")
print(
"\nLaunches full servers with CUDA graphs enabled to pre-warm"
" Triton autotuning."
)
print("Skips instantly on warm nodes (marker file exists).")
        sys.exit(0 if len(sys.argv) >= 2 else 1)  # 0 for explicit -h/--help, 1 for missing args
# Parse model:tp pairs
model_tp_pairs = []
for arg in sys.argv[1:]:
if ":" not in arg:
print(f"Error: expected model:tp format, got '{arg}'")
sys.exit(1)
model, tp_str = arg.rsplit(":", 1)
model_tp_pairs.append((model, int(tp_str)))
print(f"=== Server CUDA Graph Warmup ({len(model_tp_pairs)} model(s)) ===")
print(f" Marker dir: {MARKER_DIR}")
print(f" Version key: {get_version_key()}\n")
# Deduplicate by architecture and check markers
seen_keys = {}
to_warmup = []
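    # seen_keys maps architecture key -> first model queued with it, so later
    # models sharing that architecture (and TP size) can piggyback on its warmup.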
for model, tp in model_tp_pairs:
# Check marker first (fast path)
if check_marker(model, tp):
print(f" SKIP {model} (tp={tp}): already warm (marker exists)")
continue
        # Architecture dedup: models sharing an architecture key hit the same
        # autotuned kernels, so warming one of them warms the rest.
config = get_config_json(model)
if config is not None:
key = get_architecture_key(config, tp)
if key in seen_keys:
print(
f" DEDUP {model} (tp={tp}): same architecture as {seen_keys[key]}"
)
continue
seen_keys[key] = model
to_warmup.append((model, tp))
print(f" QUEUE {model} (tp={tp}): needs warmup")
if not to_warmup:
print("\nAll models already warm. Done.")
return
print(f"\n{len(to_warmup)} model(s) to warm up.\n")
port = DEFAULT_PORT
for i, (model, tp) in enumerate(to_warmup, 1):
print(f"\n{'=' * 60}")
print(f"[{i}/{len(to_warmup)}] {model} (tp={tp})")
print(f"{'=' * 60}")
t0 = time.time()
success = warmup_one_model(model, tp, port)
elapsed = time.time() - t0
if success:
print(f" Completed in {elapsed:.0f}s")
write_marker(model, tp)
# Also write markers for dedup'd models that share this architecture
config = get_config_json(model)
if config is not None:
key = get_architecture_key(config, tp)
for other_model, other_tp in model_tp_pairs:
if (other_model, other_tp) == (model, tp):
continue
other_config = get_config_json(other_model)
if other_config is not None:
other_key = get_architecture_key(other_config, other_tp)
if other_key == key and not check_marker(other_model, other_tp):
write_marker(other_model, other_tp)
print(
f" Also marked {other_model} (tp={other_tp}) as warm (same arch)"
)
else:
            print(
                f" Warning: warmup failed after {elapsed:.0f}s "
                "(non-fatal; tests will still run, just without a warm cache)"
            )
# Use a different port for the next model to avoid bind conflicts
port += 100
print("\nServer CUDA graph warmup complete.")
if __name__ == "__main__":
main()