| import argparse |
| import asyncio |
| import copy |
| import logging |
| import multiprocessing as mp |
| import os |
| import random |
| import signal |
| import sys |
| import time |
| from typing import List |
|
|
| import requests |
| from setproctitle import setproctitle |
| from sglang_router.launch_router import RouterArgs, launch_router |
|
|
| from sglang.srt.server_args import ServerArgs |
| from sglang.srt.utils import is_port_available |
|
|
|
|
def setup_logger():
    """Return the "router" logger, configured for INFO-level stream output.

    The logger is looked up by name, so every call returns the same object.
    The stream handler is attached only when the logger has no handler yet;
    otherwise each additional call would add another handler and every
    message would be emitted once per call.

    Returns:
        logging.Logger: the configured "router" logger.
    """
    logger = logging.getLogger("router")
    logger.setLevel(logging.INFO)

    # Guard against duplicate handlers: logging.getLogger() hands back the
    # same instance for the same name, and addHandler() is cumulative.
    if not logger.handlers:
        formatter = logging.Formatter(
            "[Router (Python)] %(asctime)s - %(levelname)s - %(message)s - %(filename)s:%(lineno)d",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        handler = logging.StreamHandler()
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger


# Module-level logger used by the functions in this file.
logger = setup_logger()
|
|
|
|
| |
def run_server(server_args, dp_rank):
    """Entry point executed inside each spawned server process.

    Why the process moves itself into a new process group:

    Without os.setpgrp() this process keeps the process group it inherited
    from its parent. Pressing Ctrl+C makes the terminal deliver SIGINT to
    every member of the foreground group at once, so leaf processes may die
    before their parents, which breaks the intended cleanup order and can
    leave orphaned processes behind:

        Terminal (PGID=100)
        └── Main Python Process (PGID=100)
            ├── Server Process 1 (PGID=100)
            │   ├── Scheduler 1
            │   └── Detokenizer 1
            └── Server Process 2 (PGID=100)
                ├── Scheduler 2
                └── Detokenizer 2

    With os.setpgrp() each server process leads a group of its own, so the
    whole subtree can be signalled through that group id:

        Server Process 1 (PGID=300)
        ├── Scheduler 1
        └── Detokenizer 1
        Server Process 2 (PGID=400)
        ├── Scheduler 2
        └── Detokenizer 2

    Args:
        server_args: server configuration; ``grpc_mode`` selects the
            entrypoint and the object is passed through to it unchanged.
        dp_rank: value exported (as a string) in ``SGLANG_DP_RANK``.
    """
    # Become a process-group leader before anything else is started.
    os.setpgrp()
    setproctitle("sglang::server")
    os.environ["SGLANG_DP_RANK"] = str(dp_rank)

    # Entrypoints are imported lazily so only the selected one is loaded.
    if not server_args.grpc_mode:
        from sglang.srt.entrypoints.http_server import launch_server

        launch_server(server_args)
        return

    from sglang.srt.entrypoints.grpc_server import serve_grpc

    asyncio.run(serve_grpc(server_args))
|
|
|
|
def launch_server_process(
    server_args: ServerArgs, worker_port: int, dp_id: int
) -> mp.Process:
    """Start one server worker in its own process and return the handle.

    The caller's ``server_args`` is left untouched; the worker receives a
    deep copy with its port, base GPU id and DP size overridden.

    Args:
        server_args: template configuration shared by all workers.
        worker_port: port assigned to this worker.
        dp_id: index of this worker; also forwarded to ``run_server``.

    Returns:
        The already-started ``multiprocessing.Process``.
    """
    worker_args = copy.deepcopy(server_args)
    worker_args.port = worker_port
    # Worker ``dp_id`` gets base GPU id ``dp_id * tp_size``.
    worker_args.base_gpu_id = dp_id * worker_args.tp_size
    worker_args.dp_size = 1

    worker = mp.Process(target=run_server, args=(worker_args, dp_id))
    worker.start()
    return worker
|
|
|
|
def wait_for_server_health(host: str, port: int, timeout: int = 300) -> bool:
    """Poll ``http://{host}:{port}/health`` until it answers with HTTP 200.

    Args:
        host: host name or address of the server.
        port: port of the server.
        timeout: overall time budget in seconds, checked before each attempt.

    Returns:
        True as soon as a request returns status 200; False once the time
        budget has been used up.
    """
    health_url = f"http://{host}:{port}/health"
    began = time.perf_counter()

    while True:
        if time.perf_counter() - began >= timeout:
            return False
        try:
            if requests.get(health_url, timeout=5).status_code == 200:
                return True
        except requests.exceptions.RequestException:
            # Request errors are treated as "not healthy yet"; keep polling.
            pass
        time.sleep(1)
|
|
|
|
def find_available_ports(base_port: int, count: int) -> List[int]:
    """Find ``count`` available ports, probing upward from ``base_port``.

    The returned ports are NOT consecutive: after every probe (successful or
    not) the candidate advances by a random stride of 100 to 1000.

    Args:
        base_port: first port number to probe.
        count: number of ports to return; zero or less yields an empty list.

    Returns:
        The available ports found, in increasing order.

    Raises:
        RuntimeError: if the candidate passes 65535 before ``count`` ports
            were found. Without this check the search could never finish,
            because no port above that limit can be bound.
    """
    max_port = 65535  # highest valid TCP port number
    available_ports = []
    current_port = base_port

    while len(available_ports) < count:
        if current_port > max_port:
            raise RuntimeError(
                f"Could not find {count} available ports starting from "
                f"{base_port}: found {len(available_ports)} before exceeding "
                f"port {max_port}"
            )
        if is_port_available(current_port):
            available_ports.append(current_port)
        current_port += random.randint(100, 1000)

    return available_ports
|
|
|
|
def cleanup_processes(processes: List[mp.Process]):
    """Stop every worker's process group: SIGTERM first, SIGKILL if needed.

    Each process's pid is used as its process-group id. All groups are sent
    SIGTERM up front; then each process is given up to 5 seconds to exit and
    its group is sent SIGKILL if it is still alive.

    Args:
        processes: the worker processes to shut down.
    """

    def _signal_group(pgid, signum):
        # A group that no longer exists has nothing left to signal.
        try:
            os.killpg(pgid, signum)
        except ProcessLookupError:
            pass

    # Phase 1: ask every group to terminate before waiting on any of them.
    for proc in processes:
        logger.info(f"Terminating process group {proc.pid}")
        _signal_group(proc.pid, signal.SIGTERM)

    # Phase 2: wait for each process and force-kill the stragglers.
    for proc in processes:
        proc.join(timeout=5)
        if not proc.is_alive():
            continue
        logger.warning(
            f"Process {proc.pid} did not terminate gracefully, forcing kill"
        )
        _signal_group(proc.pid, signal.SIGKILL)

    logger.info("All process groups terminated")
|
|
|
|
def main():
    """Parse CLI arguments, start one server per DP rank, then run the router.

    Steps, in order:
      1. select the "spawn" multiprocessing start method;
      2. build a parser from the server args, the router args (prefixed,
         without host/port) and ``--router-dp-worker-base-port``;
      3. pick one port per data-parallel worker and launch the workers;
      4. register SIGINT/SIGTERM/SIGQUIT handlers that clean up the workers;
      5. point the router at the worker URLs and launch it.

    If launching the router raises, the workers are cleaned up and the
    process exits with status 1.
    """
    mp.set_start_method("spawn")

    parser = argparse.ArgumentParser(
        description="Launch SGLang router and server processes"
    )
    ServerArgs.add_cli_args(parser)
    RouterArgs.add_cli_args(parser, use_router_prefix=True, exclude_host_port=True)
    parser.add_argument(
        "--router-dp-worker-base-port",
        type=int,
        default=31000,
        help="Base port number for data parallel workers",
    )

    cli_args = parser.parse_args()
    server_args = ServerArgs.from_cli_args(cli_args)
    router_args = RouterArgs.from_cli_args(cli_args, use_router_prefix=True)

    # One port per data-parallel worker.
    worker_ports = find_available_ports(
        cli_args.router_dp_worker_base_port, server_args.dp_size
    )

    server_processes = []
    for dp_rank, worker_port in enumerate(worker_ports):
        logger.info(f"Launching DP server process {dp_rank} on port {worker_port}")
        server_processes.append(
            launch_server_process(server_args, worker_port, dp_rank)
        )

    def _on_signal(signum, frame):
        # Only tears down the workers; it does not itself exit this process.
        cleanup_processes(server_processes)

    for signum in (signal.SIGINT, signal.SIGTERM, signal.SIGQUIT):
        signal.signal(signum, _on_signal)

    # The router reaches each worker through the scheme matching its mode.
    scheme = "grpc" if server_args.grpc_mode else "http"
    router_args.worker_urls = [
        f"{scheme}://{server_args.host}:{p}" for p in worker_ports
    ]

    try:
        launch_router(router_args)
    except Exception as exc:
        logger.error(f"Failed to start router: {exc}")
        cleanup_processes(server_processes)
        sys.exit(1)


if __name__ == "__main__":
    main()
|
|