File size: 6,567 Bytes
a402b9b | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 | import argparse
import asyncio
import copy
import logging
import multiprocessing as mp
import os
import random
import signal
import sys
import time
from typing import List
import requests
from setproctitle import setproctitle
from sglang_router.launch_router import RouterArgs, launch_router
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import is_port_available
def setup_logger():
    """Return the shared ``"router"`` logger, configured for prefixed INFO output.

    The logger level is set to INFO on every call. A ``StreamHandler`` with the
    ``[Router (Python)]`` format is attached only when the logger has no
    handlers yet. Calling this more than once, or after another module has
    already configured the same ``"router"`` logger, therefore does not print
    every record twice.

    Returns:
        The ``logging.Logger`` named ``"router"``.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)
    # NOTE(review): the logger name "router" is shared process-wide; presumably
    # sglang_router may configure it too -- verify. Reuse whatever handler is
    # already attached rather than stacking a duplicate.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)
    return logger
# Module-wide "router" logger used by the launcher helpers in this file.
logger = setup_logger()
def run_server(server_args, dp_rank: int) -> None:
    """Run one data-parallel SGLang server in the current (child) process.

    This is the ``multiprocessing.Process`` target used by
    ``launch_server_process``. Before starting the server it moves into a new
    process group, sets the process title, and exports ``SGLANG_DP_RANK``.

    Args:
        server_args: Per-worker server arguments. ``server_args.grpc_mode``
            selects the gRPC entrypoint (``serve_grpc``) or the HTTP one
            (``launch_server``).
        dp_rank: Data-parallel rank of this worker, exported as the
            ``SGLANG_DP_RANK`` environment variable.

    Note:
        1. Without os.setpgrp(), all processes share the same PGID. When you
        press Ctrl+C, the terminal sends SIGINT to every process in that group
        simultaneously. Leaf processes can then terminate first, which breaks
        the intended shutdown order and leaves orphaned processes behind::

            Terminal (PGID=100)
            `-- Main Python Process (PGID=100)
                |-- Server Process 1 (PGID=100)
                |   |-- Scheduler 1
                |   `-- Detokenizer 1
                `-- Server Process 2 (PGID=100)
                    |-- Scheduler 2
                    `-- Detokenizer 2

        2. With os.setpgrp(), each server process becomes the leader of its
        own process group, and the children it spawns afterwards inherit that
        group. A terminal Ctrl+C no longer reaches the servers directly. The
        main process tears each group down itself via os.killpg
        (see cleanup_processes)::

            Terminal (PGID=100)
            `-- Main Python Process (PGID=200)
                |-- Server Process 1 (PGID=300)
                |   |-- Scheduler 1
                |   `-- Detokenizer 1
                `-- Server Process 2 (PGID=400)
                    |-- Scheduler 2
                    `-- Detokenizer 2
    """
    # Become the leader of a new process group (pgid == our pid); see Note above.
    # This must happen before any children are spawned so they inherit it.
    os.setpgrp()
    setproctitle("sglang::server")
    # Export the DP rank before the server entrypoint is imported below, in
    # case it is read at import/startup time.
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)
    # Launch server in the requested mode (gRPC or HTTP). The entrypoint is
    # imported lazily so only the selected one is loaded in this child.
    if server_args.grpc_mode:
        from sglang.srt.entrypoints.grpc_server import serve_grpc
        asyncio.run(serve_grpc(server_args))
    else:
        from sglang.srt.entrypoints.http_server import launch_server
        launch_server(server_args)
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Start a single-rank server for DP rank ``dp_id`` in a new child process.

    The caller's ``server_args`` is not modified. The child receives a deep
    copy with the following changes:

    * it listens on ``worker_port``;
    * ``base_gpu_id`` is set to ``dp_id * tp_size``, so each rank starts at its
      own GPU offset;
    * ``dp_size`` is forced to 1.

    Returns:
        The started ``multiprocessing.Process`` running ``run_server``.
    """
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    worker_args.dp_size = 1

    child = mp.Process(target=run_server, args=(worker_args, dp_id))
    child.start()
    return child
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers HTTP 200.

    Each probe uses a 5-second request timeout. Request errors and non-200
    replies count as "not ready yet" and are retried after a one-second pause.

    Returns:
        True as soon as a 200 response is seen; False once ``timeout``
        seconds have elapsed without one.
    """
    began = time.perf_counter()
    health_url = f"http://{host}:{port}/health"
    while time.perf_counter() - began < timeout:
        try:
            healthy = requests.get(health_url, timeout=5).status_code == 200
        except requests.exceptions.RequestException:
            healthy = False
        if healthy:
            return True
        time.sleep(1)
    return False
def find_available_ports(
    base_port: int, count: int, max_port: int = 65535
) -> List[int]:
    """Find ``count`` available ports at or above ``base_port``.

    Probing starts at ``base_port`` and advances by a random stride of
    100-1000 after every probe. The returned ports are therefore increasing
    but NOT consecutive.

    Args:
        base_port: First port number to probe.
        count: Number of available ports to return.
        max_port: Highest port number that may be probed. Defaults to 65535,
            the largest valid TCP/UDP port.

    Returns:
        ``count`` available port numbers in increasing order (empty when
        ``count <= 0``).

    Raises:
        RuntimeError: If probing passes ``max_port`` before ``count``
            available ports are found. Without this check the loop would
            never terminate.
    """
    available_ports: List[int] = []
    current_port = base_port
    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Could only find {len(available_ports)} of {count} available "
                f"ports between {base_port} and {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        # Random stride; presumably to reduce collisions with other launchers
        # probing from the same base port -- confirm intent.
        current_port += random.randint(100, 1000)
    return available_ports
def _signal_server(process: mp.Process, sig: int) -> None:
    """Send ``sig`` to the process group led by ``process``, else to the process itself."""
    try:
        os.killpg(process.pid, sig)
    except ProcessLookupError:
        # No group led by this pid. Either the child has already exited, or it
        # has not reached os.setpgrp() in run_server yet (e.g. it is still
        # importing after spawn). Signal the process directly so a server that
        # is still starting is not left running; multiprocessing skips
        # children it has already reaped.
        if sig == signal.SIGKILL:
            process.kill()
        else:
            process.terminate()


def cleanup_processes(processes: List[mp.Process]):
    """Terminate every server process together with its process group.

    Each server calls os.setpgrp() in run_server, so its pid is also the id of
    a process group that contains the children it spawned. Signalling the group
    tears down the whole tree.

    The shutdown proceeds in three steps:

    1. Send SIGTERM to every server.
    2. Wait up to 5 s for each one.
    3. Send SIGKILL to any that is still alive, then reap it.
    """
    for process in processes:
        logger.info(f"Terminating process group {process.pid}")
        _signal_server(process, signal.SIGTERM)

    # Wait for processes to terminate
    for process in processes:
        process.join(timeout=5)
        if process.is_alive():
            logger.warning(
                f"Process {process.pid} did not terminate gracefully, forcing kill"
            )
            _signal_server(process, signal.SIGKILL)
            # Reap the killed child so it does not linger as a zombie.
            process.join(timeout=5)
    logger.info("All process groups terminated")
def main():
    """Launch one SGLang server per DP rank and a router in front of them.

    The server processes are torn down in every exit path:

    * the router fails to start;
    * the router returns normally;
    * an exception occurs while launching servers;
    * SIGINT/SIGTERM/SIGQUIT is received.
    """
    # CUDA runtime isn't fork-safe, which can lead to subtle bugs or crashes
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )
    # No extra retry/CB flags here; RouterArgs.add_cli_args already defines them with router- prefix
    args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(args)
    router_args = RouterArgs.from_cli_args(args, use_router_prefix=True)

    # Find available ports for workers
    worker_ports = find_available_ports(
        args.router_dp_worker_base_port, server_args.dp_size
    )

    server_processes: List[mp.Process] = []

    def _cleanup_once():
        """Tear down all launched servers; subsequent calls are no-ops."""
        # Take ownership of the list first so the signal handler and the
        # `finally` below never tear the same processes down twice.
        procs = server_processes[:]
        server_processes.clear()
        if procs:
            cleanup_processes(procs)

    def _handle_signal(signum, frame):
        """Clean up the servers, then exit instead of resuming the router."""
        _cleanup_once()
        # Conventional "terminated by signal N" exit status.
        sys.exit(128 + signum)

    # Install handlers BEFORE launching. Each server runs in its own process
    # group (os.setpgrp in run_server), so a terminal Ctrl+C reaches only this
    # process. Without a handler in place, an interrupt during launch would
    # orphan the servers that were already started.
    for sig in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(sig, _handle_signal)

    exit_code = 0
    try:
        # Start server processes
        for i, worker_port in enumerate(worker_ports):
            logger.info(f"Launching DP server process {i} on port {worker_port}")
            server_processes.append(launch_server_process(server_args, worker_port, i))

        # Update router args with worker URLs
        # Use grpc:// protocol if server is in gRPC mode, otherwise http://
        protocol = "grpc" if server_args.grpc_mode else "http"
        router_args.worker_urls = [
            f"{protocol}://{server_args.host}:{port}" for port in worker_ports
        ]

        # Start the router
        try:
            launch_router(router_args)
        except Exception as e:
            logger.error(f"Failed to start router: {e}")
            exit_code = 1
    finally:
        # Also runs when the router returns normally or a SystemExit /
        # KeyboardInterrupt propagates, so servers are never left behind.
        _cleanup_once()

    if exit_code:
        sys.exit(exit_code)
if __name__ == "__main__":
    # Run the launcher only when executed as a script, not when imported.
    main()
|