| """ |
| Full server warmup to pre-warm Triton autotuning and CUDA graph capture. |
| |
| On cold H200 nodes (new nodes or after container recreation), CUDA graph capture |
| triggers Triton autotuning which takes ~330s per server launch. This script |
| launches actual servers with CUDA graphs enabled to cache the autotuned kernels, |
| so subsequent test launches are fast (~30-60s). |
| |
| Uses marker files to skip warmup on already-warm nodes. Marker files are |
| invalidated when Python, Triton, or PyTorch versions change. |
| |
| Usage: |
| python3 scripts/ci/cuda/warmup_server.py \ |
| deepseek-ai/DeepSeek-V3-0324:8 \ |
| inclusionAI/Ring-2.5-1T:8 |
| """ |
|
|
| import hashlib |
| import json |
| import os |
| import signal |
| import subprocess |
| import sys |
| import tempfile |
| import time |
| from pathlib import Path |
|
|
| |
| sys.path.insert(0, os.path.dirname(__file__)) |
| from warmup_deep_gemm import get_architecture_key, get_config_json |
|
|
| MARKER_DIR = os.path.join(os.path.expanduser("~"), ".cache", "sglang", "warmup_markers") |
| HEALTH_POLL_INTERVAL = 10 |
| SERVER_STARTUP_TIMEOUT = 900 |
| DEFAULT_PORT = 39876 |
|
|
|
|
| def get_version_key(): |
| """Hash of Python + Triton + PyTorch versions to invalidate markers on upgrades.""" |
| parts = [sys.version] |
| try: |
| import triton |
|
|
| parts.append(f"triton={triton.__version__}") |
| except ImportError: |
| parts.append("triton=none") |
| try: |
| import torch |
|
|
| parts.append(f"torch={torch.__version__}") |
| except ImportError: |
| parts.append("torch=none") |
| return hashlib.sha256("|".join(parts).encode()).hexdigest()[:12] |
|
|
|
|
| def get_marker_path(model, tp): |
| """Get the marker file path for a model:tp pair.""" |
| version_key = get_version_key() |
| safe_model = model.replace("/", "--") |
| return os.path.join( |
| MARKER_DIR, f"server_warmup_{safe_model}_tp{tp}_{version_key}.done" |
| ) |
|
|
|
|
| def check_marker(model, tp): |
| """Check if warmup marker exists (node already warm).""" |
| marker = get_marker_path(model, tp) |
| return os.path.exists(marker) |
|
|
|
|
| def write_marker(model, tp): |
| """Write warmup marker after successful warmup.""" |
| marker = get_marker_path(model, tp) |
| os.makedirs(os.path.dirname(marker), exist_ok=True) |
| Path(marker).write_text( |
| json.dumps( |
| { |
| "model": model, |
| "tp": tp, |
| "version_key": get_version_key(), |
| "timestamp": time.time(), |
| } |
| ) |
| ) |
| print(f" Wrote marker: {marker}") |
|
|
|
|
| def kill_server(proc): |
| """Kill server process tree.""" |
| if proc.poll() is not None: |
| return |
| try: |
| os.killpg(os.getpgid(proc.pid), signal.SIGTERM) |
| except (ProcessLookupError, OSError): |
| pass |
| try: |
| proc.wait(timeout=15) |
| except subprocess.TimeoutExpired: |
| try: |
| os.killpg(os.getpgid(proc.pid), signal.SIGKILL) |
| except (ProcessLookupError, OSError): |
| pass |
| try: |
| proc.wait(timeout=5) |
| except subprocess.TimeoutExpired: |
| pass |
|
|
|
|
| def wait_for_server(base_url, proc, timeout): |
| """Poll /health_generate until server is ready or timeout.""" |
| import requests |
|
|
| start = time.time() |
| while time.time() - start < timeout: |
| ret = proc.poll() |
| if ret is not None: |
| return False, f"Server exited with code {ret}" |
| try: |
| resp = requests.get(f"{base_url}/health_generate", timeout=5) |
| if resp.status_code == 200: |
| return True, None |
| except requests.RequestException: |
| pass |
| time.sleep(HEALTH_POLL_INTERVAL) |
| return False, "Timed out waiting for server" |
|
|
|
|
| def send_generate_request(base_url): |
| """Send one /generate request to exercise the full inference path.""" |
| import requests |
|
|
| payload = { |
| "input_ids": [0, 1, 2, 3], |
| "sampling_params": { |
| "max_new_tokens": 8, |
| "temperature": 0, |
| }, |
| } |
| try: |
| resp = requests.post(f"{base_url}/generate", json=payload, timeout=120) |
| if resp.status_code == 200: |
| print(" Generate request succeeded") |
| else: |
| print(f" Warning: generate request returned {resp.status_code}") |
| except requests.RequestException as e: |
| print(f" Warning: generate request failed: {e}") |
|
|
|
|
| def warmup_one_model(model, tp, port): |
| """Launch server, wait for ready, send one request, then kill.""" |
| base_url = f"http://127.0.0.1:{port}" |
|
|
| cmd = [ |
| sys.executable, |
| "-m", |
| "sglang.launch_server", |
| "--model-path", |
| model, |
| "--tp", |
| str(tp), |
| "--host", |
| "127.0.0.1", |
| "--port", |
| str(port), |
| "--trust-remote-code", |
| "--model-loader-extra-config", |
| '{"enable_multithread_load": true, "num_threads": 64}', |
| ] |
|
|
| |
| |
| log_file = tempfile.NamedTemporaryFile( |
| mode="w", prefix="warmup_server_", suffix=".log", delete=False |
| ) |
| log_path = log_file.name |
|
|
| print(f" Launching server: {' '.join(cmd)}") |
| print(f" Server log: {log_path}") |
| proc = subprocess.Popen( |
| cmd, |
| stdout=log_file, |
| stderr=subprocess.STDOUT, |
| preexec_fn=os.setsid, |
| ) |
|
|
| try: |
| |
| print( |
| f" Waiting for server (timeout={SERVER_STARTUP_TIMEOUT}s, " |
| f"polling every {HEALTH_POLL_INTERVAL}s)..." |
| ) |
| ok, err = wait_for_server(base_url, proc, SERVER_STARTUP_TIMEOUT) |
| if not ok: |
| print(f" Warning: server not ready: {err}") |
| |
| try: |
| log_file.flush() |
| with open(log_path) as f: |
| lines = f.readlines() |
| for line in lines[-20:]: |
| print(f" | {line.rstrip()}") |
| except Exception: |
| pass |
| return False |
|
|
| print(" Server ready, sending generate request...") |
| send_generate_request(base_url) |
| return True |
|
|
| finally: |
| print(" Killing server...") |
| kill_server(proc) |
| log_file.close() |
| try: |
| os.unlink(log_path) |
| except OSError: |
| pass |
|
|
|
|
| def main(): |
| if len(sys.argv) < 2 or sys.argv[1] in ("-h", "--help"): |
| print("Usage: warmup_server.py model1:tp1 [model2:tp2 ...]") |
| print( |
| "\nLaunches full servers with CUDA graphs enabled to pre-warm" |
| " Triton autotuning." |
| ) |
| print("Skips instantly on warm nodes (marker file exists).") |
| sys.exit(0) |
|
|
| |
| model_tp_pairs = [] |
| for arg in sys.argv[1:]: |
| if ":" not in arg: |
| print(f"Error: expected model:tp format, got '{arg}'") |
| sys.exit(1) |
| model, tp_str = arg.rsplit(":", 1) |
| model_tp_pairs.append((model, int(tp_str))) |
|
|
| print(f"=== Server CUDA Graph Warmup ({len(model_tp_pairs)} model(s)) ===") |
| print(f" Marker dir: {MARKER_DIR}") |
| print(f" Version key: {get_version_key()}\n") |
|
|
| |
| seen_keys = {} |
| to_warmup = [] |
|
|
| for model, tp in model_tp_pairs: |
| |
| if check_marker(model, tp): |
| print(f" SKIP {model} (tp={tp}): already warm (marker exists)") |
| continue |
|
|
| |
| config = get_config_json(model) |
| if config is not None: |
| key = get_architecture_key(config, tp) |
| if key in seen_keys: |
| print( |
| f" DEDUP {model} (tp={tp}): same architecture as {seen_keys[key]}" |
| ) |
| continue |
| seen_keys[key] = model |
|
|
| to_warmup.append((model, tp)) |
| print(f" QUEUE {model} (tp={tp}): needs warmup") |
|
|
| if not to_warmup: |
| print("\nAll models already warm. Done.") |
| return |
|
|
| print(f"\n{len(to_warmup)} model(s) to warm up.\n") |
|
|
| port = DEFAULT_PORT |
| for i, (model, tp) in enumerate(to_warmup, 1): |
| print(f"\n{'=' * 60}") |
| print(f"[{i}/{len(to_warmup)}] {model} (tp={tp})") |
| print(f"{'=' * 60}") |
|
|
| t0 = time.time() |
| success = warmup_one_model(model, tp, port) |
| elapsed = time.time() - t0 |
|
|
| if success: |
| print(f" Completed in {elapsed:.0f}s") |
| write_marker(model, tp) |
| |
| config = get_config_json(model) |
| if config is not None: |
| key = get_architecture_key(config, tp) |
| for other_model, other_tp in model_tp_pairs: |
| if (other_model, other_tp) == (model, tp): |
| continue |
| other_config = get_config_json(other_model) |
| if other_config is not None: |
| other_key = get_architecture_key(other_config, other_tp) |
| if other_key == key and not check_marker(other_model, other_tp): |
| write_marker(other_model, other_tp) |
| print( |
| f" Also marked {other_model} (tp={other_tp}) as warm (same arch)" |
| ) |
| else: |
| print( |
| f" Warning: warmup failed after {elapsed:.0f}s (non-fatal, tests will still work)" |
| ) |
|
|
| |
| port += 100 |
|
|
| print("\nServer CUDA graph warmup complete.") |
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|