# Source: leideng/QCFuse - srt/entrypoints/grpc_server.py (raw download, 38.9 kB)
"""
Standalone gRPC Server for SGLang - Fully separated from HTTP server.
Uses GrpcRequestManager for orchestration without tokenization.
"""
import asyncio
import dataclasses
import logging
import multiprocessing as mp
import os
import signal
import threading
import time
from concurrent import futures
from typing import AsyncIterator, Dict, Optional
import grpc
from google.protobuf.json_format import MessageToDict
from google.protobuf.struct_pb2 import Struct
from google.protobuf.timestamp_pb2 import Timestamp
from grpc_health.v1 import health_pb2_grpc
from grpc_reflection.v1alpha import reflection
import sglang
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST, DisaggregationMode
from sglang.srt.grpc import sglang_scheduler_pb2, sglang_scheduler_pb2_grpc
from sglang.srt.grpc.grpc_request_manager import GrpcRequestManager
from sglang.srt.grpc.health_servicer import SGLangHealthServicer
from sglang.srt.grpc.scheduler_launcher import launch_scheduler_process_only
from sglang.srt.managers.disagg_service import start_disagg_service
from sglang.srt.managers.io_struct import (
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
)
from sglang.srt.sampling.sampling_params import SamplingParams as SGLSamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import kill_process_tree
from sglang.utils import get_exception_traceback
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Number of seconds the HealthCheck RPC polls for scheduler activity before
# reporting the server as unhealthy. Overridable through the
# SGLANG_HEALTH_CHECK_TIMEOUT environment variable (default: 20).
HEALTH_CHECK_TIMEOUT = int(os.getenv("SGLANG_HEALTH_CHECK_TIMEOUT", 20))
class SGLangSchedulerServicer(sglang_scheduler_pb2_grpc.SglangSchedulerServicer):
    """
    Standalone gRPC service implementation using GrpcRequestManager.
    Fully separated from HTTP server with its own process and no shared globals.

    Every RPC converts the protobuf request into the internal tokenized
    request type, submits it to the request manager and converts the
    dictionaries that come back into protobuf responses.
    """

    def __init__(
        self,
        request_manager: GrpcRequestManager,
        server_args: ServerArgs,
        model_info: Dict,
        scheduler_info: Dict,
        health_servicer: Optional[SGLangHealthServicer] = None,
    ):
        """Initialize the standalone gRPC service.

        Args:
            request_manager: Manager used to submit and abort requests.
            server_args: Server configuration; reported by ``GetModelInfo``
                and ``GetServerInfo``.
            model_info: Model metadata returned verbatim by ``GetModelInfo``.
            scheduler_info: Scheduler-reported information; its
                ``"is_generation"`` entry selects the health-check request type.
            health_servicer: Optional health service whose
                ``set_not_serving()`` is called during ``shutdown``.
        """
        self.request_manager = request_manager
        self.server_args = server_args
        self.model_info = model_info
        self.scheduler_info = scheduler_info
        self.start_time = time.time()
        self.health_servicer = health_servicer

        # Start the request manager's event loop using auto_create_handle_loop
        self.request_manager.auto_create_handle_loop()

        logger.info("gRPC scheduler servicer initialized")

    def _is_disaggregation_enabled(self) -> bool:
        """Return True when the configured disaggregation mode is not NULL.

        The setting is compared against both the enum member and its value so
        the check is correct whether ``server_args.disaggregation_mode`` holds
        a ``DisaggregationMode`` or its plain value. Comparing a plain string
        only against the enum member would always report "enabled".
        """
        mode = self.server_args.disaggregation_mode
        return mode not in (DisaggregationMode.NULL, DisaggregationMode.NULL.value)

    @staticmethod
    def _create_error_response(
        request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Wrap an error dictionary from the request manager in a response.

        The status code is "499" when the dictionary carries an "abort" key
        and "500" otherwise.
        """
        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            error=sglang_scheduler_pb2.GenerateError(
                message=output["error"],
                http_status_code="499" if "abort" in output else "500",
            ),
        )

    async def Generate(
        self,
        request: sglang_scheduler_pb2.GenerateRequest,
        context: grpc.aio.ServicerContext,
    ) -> AsyncIterator[sglang_scheduler_pb2.GenerateResponse]:
        """Handle generation requests with streaming responses.

        Yields one ``GenerateResponse`` per output produced by the request
        manager: an error response for outputs containing "error", a
        completion response for finished (or batched) outputs, and a chunk
        response otherwise. Any exception is reported to the client as a
        single error response instead of being raised.
        """
        logger.info(f"Receive generation request: {request.request_id}")

        try:
            # Convert gRPC request to internal format
            tokenized_req = self._convert_generate_request(request)

            # Submit to request manager (automatically handles n>1)
            response_generator = self.request_manager.generate_request(
                obj=tokenized_req,
                request_id=request.request_id,
                grpc_context=context,
            )

            async for output in response_generator:
                # Handle batch responses (for n>1 non-streaming)
                if isinstance(output, list):
                    for batch_output in output:
                        if "error" in batch_output:
                            yield self._create_error_response(
                                request.request_id, batch_output
                            )
                        else:
                            # All non-error batch outputs are final responses
                            yield self._create_completion_response(
                                request.request_id, batch_output
                            )
                    continue

                # Handle single response (for streaming or n=1 non-streaming)
                if "error" in output:
                    yield self._create_error_response(request.request_id, output)
                elif output.get("finished", False):
                    yield self._create_completion_response(
                        request.request_id, output
                    )
                else:
                    yield self._create_chunk_response(request.request_id, output)

        except Exception as e:
            # Capture the traceback once so the log and the response agree.
            traceback_text = get_exception_traceback()
            logger.error(
                f"Generate failed for request {request.request_id}: {e}\n"
                f"{traceback_text}"
            )
            yield sglang_scheduler_pb2.GenerateResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.GenerateError(
                    message=str(e),
                    http_status_code="500",
                    details=traceback_text,
                ),
            )

    async def Embed(
        self,
        request: sglang_scheduler_pb2.EmbedRequest,
        _context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.EmbedResponse:
        """Handle embedding requests.

        Returns an ``EmbedResponse`` carrying either the embedding or, when
        anything fails, an ``EmbedError`` with code "INTERNAL_ERROR".
        """
        logger.info(f"Receive embedding request: {request.request_id}")

        try:
            # Convert request
            tokenized_req = self._convert_embed_request(request)

            # Submit to request manager
            future = await self.request_manager.embedding_request(
                obj=tokenized_req,
                request_id=request.request_id,
            )

            # Wait for result
            result = await future

            # Create response
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                complete=sglang_scheduler_pb2.EmbedComplete(
                    embedding=result["embedding"],
                    prompt_tokens=result.get("prompt_tokens", 0),
                    cached_tokens=0,
                    embedding_dim=len(result["embedding"]),
                ),
            )

        except Exception as e:
            # Capture the traceback once so the log and the response agree.
            traceback_text = get_exception_traceback()
            logger.error(
                f"Embed failed for request {request.request_id}: {e}\n"
                f"{traceback_text}"
            )
            return sglang_scheduler_pb2.EmbedResponse(
                request_id=request.request_id,
                error=sglang_scheduler_pb2.EmbedError(
                    message=str(e),
                    code="INTERNAL_ERROR",
                    details=traceback_text,
                ),
            )

    async def HealthCheck(
        self,
        request: sglang_scheduler_pb2.HealthCheckRequest,
        context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.HealthCheckResponse:
        """
        Check the health of the inference server by sending a special request to generate one token.
        Similar to HTTP server's /health endpoint.

        The server is reported healthy as soon as the request manager's
        ``last_receive_tstamp`` advances past the time the probe was issued,
        and unhealthy if that does not happen within ``HEALTH_CHECK_TIMEOUT``
        seconds or the manager is already shutting down.
        """
        rid = f"HEALTH_CHECK_{time.time()}"
        logger.info(f"Receive health check request: {rid}")

        if self.request_manager.gracefully_exit:
            logger.info(
                "Health check request received during shutdown. Returning unhealthy."
            )
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False, message="Server is shutting down"
            )

        # Create a special health check request
        sampling_params = SGLSamplingParams(max_new_tokens=1, temperature=0.0)
        sampling_params.normalize(tokenizer=None)

        # Create health check request
        is_generation = self.scheduler_info.get("is_generation", True)
        if is_generation:
            health_req = TokenizedGenerateReqInput(
                rid=rid,
                input_text="",
                input_ids=[0],
                sampling_params=sampling_params,
                return_logprob=False,
                logprob_start_len=-1,
                top_logprobs_num=0,
                stream=False,
                mm_inputs=None,
                token_ids_logprob=None,
            )
            # Set disaggregation params if needed
            if self._is_disaggregation_enabled():
                health_req.bootstrap_host = FAKE_BOOTSTRAP_HOST
                health_req.bootstrap_room = 0
        else:
            health_req = TokenizedEmbeddingReqInput(
                rid=rid,
                input_text="",
                input_ids=[0],
            )

        # Submit health check request
        async def run_health_check():
            try:
                async for _ in self.request_manager.generate_request(
                    obj=health_req,
                    request_id=rid,
                ):
                    # Got at least one response, server is healthy
                    return True
            except Exception as e:
                logger.warning(f"Health check failed: {e}")
                return False
            return False

        task = asyncio.create_task(run_health_check())

        # Wait for response with timeout
        tic = time.time()
        try:
            while time.time() < tic + HEALTH_CHECK_TIMEOUT:
                await asyncio.sleep(1)
                # Check if we got a response from scheduler
                if self.request_manager.last_receive_tstamp > tic:
                    return sglang_scheduler_pb2.HealthCheckResponse(
                        healthy=True, message="Health check passed"
                    )

            # Timeout - server not responding
            logger.warning(f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s")
            return sglang_scheduler_pb2.HealthCheckResponse(
                healthy=False,
                message=f"Health check timeout after {HEALTH_CHECK_TIMEOUT}s",
            )
        finally:
            # Runs on success, timeout and cancellation of this RPC alike, so
            # the probe task and its request state are never left behind.
            task.cancel()
            self.request_manager._cleanup_request_state(rid)

    async def Abort(
        self,
        request: sglang_scheduler_pb2.AbortRequest,
        _context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.AbortResponse:
        """Abort an ongoing request.

        Returns an ``AbortResponse`` whose ``success`` flag is the value
        returned by the request manager, or False if aborting raised.
        """
        logger.info(f"Receive abort request: {request.request_id}")

        try:
            success = await self.request_manager.abort_request(request.request_id)
            return sglang_scheduler_pb2.AbortResponse(
                success=success,
                message=f"Request {request.request_id} {'aborted' if success else 'not found'}",
            )
        except Exception as e:
            logger.error(
                f"Abort failed for request {request.request_id}: {e}\n"
                f"{get_exception_traceback()}"
            )
            return sglang_scheduler_pb2.AbortResponse(
                success=False,
                message=str(e),
            )

    async def GetModelInfo(
        self,
        _request: sglang_scheduler_pb2.GetModelInfoRequest,
        _context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.GetModelInfoResponse:
        """Get model information.

        ``is_generation`` comes from ``scheduler_info`` when present and is
        otherwise derived from ``server_args.is_embedding``.

        Raises:
            KeyError: If ``model_info`` lacks one of the required entries.
        """
        logger.debug("Receive model info request")

        is_generation = self.scheduler_info.get("is_generation")
        if is_generation is None:
            is_generation = not self.server_args.is_embedding

        return sglang_scheduler_pb2.GetModelInfoResponse(
            model_path=self.server_args.model_path,
            tokenizer_path=self.server_args.tokenizer_path or "",
            is_generation=is_generation,
            preferred_sampling_params=(
                self.server_args.preferred_sampling_params or ""
            ),
            weight_version=self.server_args.weight_version or "",
            served_model_name=self.server_args.served_model_name,
            max_context_length=self.model_info["max_context_length"],
            vocab_size=self.model_info["vocab_size"],
            supports_vision=self.model_info["supports_vision"],
            model_type=self.model_info["model_type"],
            eos_token_ids=self.model_info["eos_token_ids"],
            pad_token_id=self.model_info["pad_token_id"],
            bos_token_id=self.model_info["bos_token_id"],
            max_req_input_len=self.model_info["max_req_input_len"],
        )

    async def GetServerInfo(
        self,
        _request: sglang_scheduler_pb2.GetServerInfoRequest,
        _context: grpc.aio.ServicerContext,
    ) -> sglang_scheduler_pb2.GetServerInfoResponse:
        """Get server information.

        Server arguments and scheduler info are returned as protobuf
        ``Struct`` messages together with runtime state from the request
        manager, the uptime and the start timestamp.
        """
        logger.debug("Receive server info request")

        def make_serializable(obj):
            """Reduce ``obj`` to None/str/number/bool/list/dict-with-str-keys."""
            if obj is None or isinstance(obj, (str, int, float, bool)):
                return obj
            if isinstance(obj, (list, tuple, set, frozenset)):
                return [make_serializable(item) for item in obj]
            if isinstance(obj, dict):
                # Struct field names must be strings.
                return {str(k): make_serializable(v) for k, v in obj.items()}
            return str(obj)

        server_args_struct = Struct()
        server_args_struct.update(
            make_serializable(dataclasses.asdict(self.server_args))
        )

        # Convert scheduler_info to Struct; sanitised like the server args so
        # an unexpected value type cannot make the whole RPC fail.
        scheduler_info_struct = Struct()
        scheduler_info_struct.update(make_serializable(self.scheduler_info))

        # Get runtime state from request manager
        manager_state = self.request_manager.get_server_info()

        # Calculate uptime
        uptime = time.time() - self.start_time

        # Create timestamp
        start_timestamp = Timestamp()
        start_timestamp.FromSeconds(int(self.start_time))

        return sglang_scheduler_pb2.GetServerInfoResponse(
            server_args=server_args_struct,
            scheduler_info=scheduler_info_struct,
            active_requests=manager_state["active_requests"],
            is_paused=manager_state["paused"],
            last_receive_timestamp=manager_state["last_receive_time"],
            uptime_seconds=uptime,
            sglang_version=sglang.__version__,
            server_type="grpc",
            start_time=start_timestamp,
        )

    # Helper methods for request/response conversion

    def _convert_generate_request(
        self, grpc_req: sglang_scheduler_pb2.GenerateRequest
    ) -> TokenizedGenerateReqInput:
        """Convert gRPC GenerateRequest to internal format.

        Raises:
            ValueError: If the request carries no ``tokenized`` input.
        """
        # Extract tokenized input
        if not grpc_req.HasField("tokenized"):
            raise ValueError("Tokenized input must be provided")

        input_text = grpc_req.tokenized.original_text
        input_ids = list(grpc_req.tokenized.input_ids)

        # Convert sampling params
        sampling_params = self._convert_sampling_params(grpc_req.sampling_params)
        sampling_params.normalize(tokenizer=None)

        # Extract disaggregated params if present
        bootstrap_host = None
        bootstrap_port = None
        bootstrap_room = None
        if grpc_req.HasField("disaggregated_params"):
            disagg = grpc_req.disaggregated_params
            # Empty host / zero port are mapped to None.
            bootstrap_host = disagg.bootstrap_host if disagg.bootstrap_host else None
            bootstrap_port = disagg.bootstrap_port if disagg.bootstrap_port else None
            # Can be 0, don't use 'or None'
            bootstrap_room = disagg.bootstrap_room

        # Create request
        return TokenizedGenerateReqInput(
            rid=grpc_req.request_id,
            input_text=input_text,
            input_ids=input_ids,
            mm_inputs=None,  # TODO: implement mm support
            sampling_params=sampling_params,
            return_logprob=grpc_req.return_logprob,
            # NOTE(review): protobuf scalar getters do not return None, so the
            # -1 fallback looks unreachable - confirm against the .proto
            # definition before relying on it.
            logprob_start_len=(
                grpc_req.logprob_start_len
                if grpc_req.logprob_start_len is not None
                else -1
            ),
            top_logprobs_num=grpc_req.top_logprobs_num or 0,
            stream=grpc_req.stream or False,
            lora_id=grpc_req.lora_id if grpc_req.lora_id else None,
            token_ids_logprob=(
                list(grpc_req.token_ids_logprob) if grpc_req.token_ids_logprob else None
            ),
            bootstrap_host=bootstrap_host,
            bootstrap_port=bootstrap_port,
            bootstrap_room=bootstrap_room,
        )

    def _convert_embed_request(
        self, grpc_req: sglang_scheduler_pb2.EmbedRequest
    ) -> TokenizedEmbeddingReqInput:
        """Convert gRPC EmbedRequest to internal format.

        Raises:
            ValueError: If the request carries no ``tokenized`` input.
        """
        # Extract tokenized input
        if not grpc_req.HasField("tokenized"):
            raise ValueError("Tokenized input must be provided")

        input_text = grpc_req.tokenized.original_text
        input_ids = list(grpc_req.tokenized.input_ids)

        return TokenizedEmbeddingReqInput(
            rid=grpc_req.request_id,
            input_text=input_text,
            input_ids=input_ids,
        )

    def _convert_sampling_params(
        self, grpc_params: sglang_scheduler_pb2.SamplingParams
    ) -> SGLSamplingParams:
        """Convert gRPC SamplingParams to internal format.

        At most one constraint (regex, JSON schema, EBNF grammar, structural
        tag) is forwarded: the first one that is set, in that order. Unset
        optional fields and empty repeated/map fields are passed as None.
        """
        # Handle constraint types
        regex = None
        json_schema = None
        ebnf_grammar = None
        structural_tag = None

        if grpc_params.HasField("regex"):
            regex = grpc_params.regex
        elif grpc_params.HasField("json_schema"):
            json_schema = grpc_params.json_schema
        elif grpc_params.HasField("ebnf_grammar"):
            ebnf_grammar = grpc_params.ebnf_grammar
        elif grpc_params.HasField("structural_tag"):
            structural_tag = grpc_params.structural_tag

        # Handle optional parameters conversion
        custom_params = (
            MessageToDict(grpc_params.custom_params)
            if grpc_params.HasField("custom_params")
            else None
        )
        max_new_tokens = (
            grpc_params.max_new_tokens
            if grpc_params.HasField("max_new_tokens")
            else None
        )
        stream_interval = (
            grpc_params.stream_interval
            if grpc_params.HasField("stream_interval")
            else None
        )
        logit_bias = dict(grpc_params.logit_bias) if grpc_params.logit_bias else None
        stop = list(grpc_params.stop) if grpc_params.stop else None
        stop_token_ids = (
            list(grpc_params.stop_token_ids) if grpc_params.stop_token_ids else None
        )

        return SGLSamplingParams(
            temperature=grpc_params.temperature,
            top_p=grpc_params.top_p,
            top_k=grpc_params.top_k,
            min_p=grpc_params.min_p,
            frequency_penalty=grpc_params.frequency_penalty,
            presence_penalty=grpc_params.presence_penalty,
            repetition_penalty=grpc_params.repetition_penalty,
            max_new_tokens=max_new_tokens,
            min_new_tokens=grpc_params.min_new_tokens,
            stop=stop,
            stop_token_ids=stop_token_ids,
            skip_special_tokens=grpc_params.skip_special_tokens,
            spaces_between_special_tokens=grpc_params.spaces_between_special_tokens,
            no_stop_trim=grpc_params.no_stop_trim,
            regex=regex,
            json_schema=json_schema,
            ebnf=ebnf_grammar,
            structural_tag=structural_tag,
            n=grpc_params.n,
            ignore_eos=grpc_params.ignore_eos,
            stream_interval=stream_interval,
            logit_bias=logit_bias,
            custom_params=custom_params,
        )

    @staticmethod
    def _build_top_logprobs(top_logprobs_val, top_logprobs_idx) -> list:
        """Pair up per-position top-logprob values and token ids.

        Returns an empty list when either input is empty or missing.
        """
        if not (top_logprobs_val and top_logprobs_idx):
            return []
        return [
            sglang_scheduler_pb2.TopLogProbs(values=val_list, token_ids=idx_list)
            for val_list, idx_list in zip(top_logprobs_val, top_logprobs_idx)
        ]

    def _convert_output_logprobs_to_proto(
        self, logprobs_data: Dict
    ) -> Optional[sglang_scheduler_pb2.OutputLogProbs]:
        """Convert output logprobs dict to proto (no None values, plain floats).

        Returns None when ``logprobs_data`` is empty or None.
        """
        if not logprobs_data:
            return None

        return sglang_scheduler_pb2.OutputLogProbs(
            token_logprobs=logprobs_data.get("token_logprobs_val", []),
            token_ids=logprobs_data.get("token_logprobs_idx", []),
            top_logprobs=self._build_top_logprobs(
                logprobs_data.get("top_logprobs_val", []),
                logprobs_data.get("top_logprobs_idx", []),
            ),
        )

    def _convert_input_logprobs_to_proto(
        self, logprobs_data: Dict
    ) -> Optional[sglang_scheduler_pb2.InputLogProbs]:
        """Convert input logprobs dict to proto (first token is None, wrapped in InputTokenLogProb).

        Returns None when ``logprobs_data`` is empty or None.
        """
        if not logprobs_data:
            return None

        # Wrap values in InputTokenLogProb; a None entry becomes a message
        # with no value set.
        token_logprobs_wrapped = [
            (
                sglang_scheduler_pb2.InputTokenLogProb()
                if x is None
                else sglang_scheduler_pb2.InputTokenLogProb(value=x)
            )
            for x in logprobs_data.get("token_logprobs_val", [])
        ]

        return sglang_scheduler_pb2.InputLogProbs(
            token_logprobs=token_logprobs_wrapped,
            token_ids=logprobs_data.get("token_logprobs_idx", []),
            top_logprobs=self._build_top_logprobs(
                logprobs_data.get("top_logprobs_val", []),
                logprobs_data.get("top_logprobs_idx", []),
            ),
        )

    def _create_chunk_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a streaming chunk response from an intermediate output dict."""
        meta_info = output.get("meta_info", {})

        # Convert output logprobs if present
        output_logprobs_proto = self._convert_output_logprobs_to_proto(
            output.get("output_logprobs")
        )

        # Convert input logprobs if present (only in first chunk)
        input_logprobs_proto = self._convert_input_logprobs_to_proto(
            output.get("input_logprobs")
        )

        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            chunk=sglang_scheduler_pb2.GenerateStreamChunk(
                token_ids=output.get("token_ids", []),
                prompt_tokens=meta_info.get("prompt_tokens", 0),
                completion_tokens=meta_info.get("completion_tokens", 0),
                cached_tokens=meta_info.get("cached_tokens", 0),
                output_logprobs=output_logprobs_proto,
                input_logprobs=input_logprobs_proto,
                index=output.get("index", 0),
            ),
        )

    def _create_completion_response(
        self, request_id: str, output: Dict
    ) -> sglang_scheduler_pb2.GenerateResponse:
        """Create a completion response from a final output dict.

        The finish reason is "length" or "abort" when ``meta_info`` says so
        and "stop" in every other case. A matched stop token id or stop
        string is forwarded when the finish-reason dict carries "matched".
        """
        # Extract meta info and finish reason details
        meta_info = output.get("meta_info", {})
        finish_reason_data = meta_info.get("finish_reason")

        # Determine finish reason, default is stop
        finish_reason = "stop"
        if finish_reason_data:
            if isinstance(finish_reason_data, dict):
                finish_reason_type = finish_reason_data.get("type")
            else:
                # Handle legacy string format
                finish_reason_type = finish_reason_data

            if finish_reason_type == "length":
                finish_reason = "length"
            elif finish_reason_type == "abort":
                finish_reason = "abort"

        # Extract matched_stop information
        matched_stop_kwargs = {}
        if isinstance(finish_reason_data, dict) and "matched" in finish_reason_data:
            matched = finish_reason_data["matched"]
            if isinstance(matched, int):
                matched_stop_kwargs["matched_token_id"] = matched
            elif isinstance(matched, str):
                matched_stop_kwargs["matched_stop_str"] = matched

        # Convert output logprobs if present
        output_logprobs_proto = self._convert_output_logprobs_to_proto(
            output.get("output_logprobs")
        )

        # Convert input logprobs if present
        input_logprobs_proto = self._convert_input_logprobs_to_proto(
            output.get("input_logprobs")
        )

        return sglang_scheduler_pb2.GenerateResponse(
            request_id=request_id,
            complete=sglang_scheduler_pb2.GenerateComplete(
                output_ids=output.get("token_ids", []),
                finish_reason=finish_reason,
                prompt_tokens=meta_info.get("prompt_tokens", 0),
                # Fall back to the number of returned tokens when the count
                # is not reported in meta_info.
                completion_tokens=meta_info.get(
                    "completion_tokens", len(output.get("token_ids", []))
                ),
                cached_tokens=meta_info.get("cached_tokens", 0),
                output_logprobs=output_logprobs_proto,
                input_logprobs=input_logprobs_proto,
                index=output.get("index", 0),
                **matched_stop_kwargs,
            ),
        )

    async def shutdown(self):
        """Shutdown the service.

        Marks the health service (if any) as not serving, then shuts down
        the request manager.
        """
        logger.info("Shutting down gRPC service")

        # Mark health service as NOT_SERVING before shutdown
        if self.health_servicer:
            self.health_servicer.set_not_serving()

        # Shutdown request manager (handles its own tasks)
        await self.request_manager.shutdown()
async def serve_grpc(
    server_args: ServerArgs,
    model_info: Optional[Dict] = None,
):
    """Start the standalone gRPC server with integrated scheduler.

    Launches the scheduler process(es), wires the request manager, the
    health service and the SGLang servicer into a ``grpc.aio`` server, runs
    warmup on a background thread and then blocks until SIGTERM or SIGINT.
    On exit the servicer, the gRPC server and the scheduler processes are
    shut down in that order.

    Args:
        server_args: Server configuration (host, port, disaggregation mode, ...).
        model_info: Optional model metadata; when None it is derived from the
            scheduler info with fallback defaults.
    """
    # Start bootstrap server BEFORE launching scheduler processes (only in PREFILL mode)
    # This ensures the bootstrap server is ready when prefill schedulers try to register
    bootstrap_server = None
    if server_args.disaggregation_mode == "prefill":
        bootstrap_server = start_disagg_service(server_args)
        if bootstrap_server:
            logger.info(
                f"Bootstrap server started for disaggregation mode on {server_args.host}:{server_args.disaggregation_bootstrap_port}"
            )

    # Launch only the scheduler process(es) (no tokenizer/detokenizer needed for gRPC)
    logger.info("Launching scheduler process(es)...")
    scheduler_info, port_args, scheduler_procs = launch_scheduler_process_only(
        server_args=server_args,
    )

    # Update model info from scheduler info
    if model_info is None:
        model_info = _default_model_info(server_args, scheduler_info)

    # Create request manager with the correct port args.
    # The bootstrap server started above (None outside prefill mode) is
    # handed over to the request manager.
    request_manager = GrpcRequestManager(
        server_args=server_args,
        port_args=port_args,
        bootstrap_server=bootstrap_server,
    )

    # Create gRPC server
    server = grpc.aio.server(
        futures.ThreadPoolExecutor(max_workers=10),
        options=[
            ("grpc.max_send_message_length", 1024 * 1024 * 256),
            ("grpc.max_receive_message_length", 1024 * 1024 * 256),
        ],
    )

    # Create standard health service (for Kubernetes probes)
    health_servicer = SGLangHealthServicer(
        request_manager=request_manager,
        scheduler_info=scheduler_info,
    )
    health_pb2_grpc.add_HealthServicer_to_server(health_servicer, server)

    # Add SGLang service
    servicer = SGLangSchedulerServicer(
        request_manager=request_manager,
        server_args=server_args,
        model_info=model_info,
        scheduler_info=scheduler_info,
        health_servicer=health_servicer,
    )
    sglang_scheduler_pb2_grpc.add_SglangSchedulerServicer_to_server(servicer, server)

    # Enable reflection
    SERVICE_NAMES = (
        sglang_scheduler_pb2.DESCRIPTOR.services_by_name["SglangScheduler"].full_name,
        "grpc.health.v1.Health",
        reflection.SERVICE_NAME,
    )
    reflection.enable_server_reflection(SERVICE_NAMES, server)

    # Start server
    listen_addr = f"{server_args.host}:{server_args.port}"
    server.add_insecure_port(listen_addr)

    await server.start()
    logger.info(f"gRPC server listening on {listen_addr}")

    # Start warmup in a separate thread. It is a daemon thread so that a
    # warmup still blocked in a long RPC after the bounded join below cannot
    # keep the interpreter alive once shutdown has completed.
    warmup_thread = threading.Thread(
        target=_wait_and_warmup_grpc,
        args=(server_args, None, health_servicer),
        daemon=True,
    )
    warmup_thread.start()

    # Handle shutdown signals
    loop = asyncio.get_running_loop()
    stop_event = asyncio.Event()

    def signal_handler():
        logger.info("Received shutdown signal")
        stop_event.set()

    for sig in (signal.SIGTERM, signal.SIGINT):
        loop.add_signal_handler(sig, signal_handler)

    try:
        await stop_event.wait()
    finally:
        logger.info("Shutting down gRPC server")

        # Shutdown request manager first - this closes ZMQ sockets and stops background tasks
        await servicer.shutdown()

        # Stop the gRPC server
        await server.stop(5.0)

        # Wait for warmup thread to finish
        if warmup_thread.is_alive():
            logger.info("Waiting for warmup thread to finish...")
            warmup_thread.join(timeout=5.0)

        # Terminate scheduler processes before exiting to avoid atexit hang
        # The scheduler processes have SIGINT ignored, so they won't get KeyboardInterrupt
        _terminate_scheduler_processes(scheduler_procs)


def _default_model_info(server_args: ServerArgs, scheduler_info: Dict) -> Dict:
    """Build model metadata from scheduler info, with fallback defaults."""
    return {
        "model_name": server_args.model_path,
        "max_context_length": scheduler_info.get(
            "max_total_num_tokens", server_args.context_length or 8192
        ),
        "vocab_size": scheduler_info.get("vocab_size", 128256),
        "supports_vision": scheduler_info.get("supports_vision", False),
        "model_type": scheduler_info.get("model_type", "transformer"),
        "max_req_input_len": scheduler_info.get("max_req_input_len", 8192),
        "eos_token_ids": scheduler_info.get("eos_token_ids", []),
        "pad_token_id": scheduler_info.get("pad_token_id", 0),
        "bos_token_id": scheduler_info.get("bos_token_id", 1),
    }


def _terminate_scheduler_processes(scheduler_procs) -> None:
    """Terminate every live scheduler process, killing those that linger."""
    for i, proc in enumerate(scheduler_procs):
        if not proc.is_alive():
            continue
        logger.info(f"Terminating scheduler process {i}...")
        proc.terminate()
        proc.join(timeout=2.0)
        if proc.is_alive():
            logger.warning(f"Scheduler process {i} did not terminate, killing...")
            proc.kill()
            proc.join(timeout=1.0)

    logger.info("All scheduler processes terminated")
def _execute_grpc_server_warmup(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
):
    """Execute warmup for gRPC server by checking health and sending test request.

    Polls ``GetModelInfo`` once per second (up to 120 attempts) until the
    server answers, then sends one generation or embedding request depending
    on the reported model kind.

    Args:
        server_args: Server configuration; supplies host, port and the
            disaggregation mode.
        pipe_finish_writer: Optional pipe that receives the error message on
            a fatal failure.

    Returns:
        True when the warmup request completed without an error response,
        False otherwise. On a fatal failure (server unreachable, RPC raised,
        unexpected exception) the current process tree is killed before
        returning False.
    """
    channel = None

    def fail_fatally(error_msg: str) -> bool:
        """Log and report ``error_msg``, close the channel, kill the process tree."""
        logger.error(error_msg)
        if pipe_finish_writer is not None:
            pipe_finish_writer.send(error_msg)
        if channel is not None:
            try:
                channel.close()
            except Exception:
                # Best effort: the process tree is torn down right below.
                pass
        kill_process_tree(os.getpid())
        return False

    try:
        # Connect to the gRPC server
        grpc_url = f"{server_args.host}:{server_args.port}"
        channel = grpc.insecure_channel(
            grpc_url,
            options=[
                ("grpc.max_send_message_length", 1024 * 1024 * 256),
                ("grpc.max_receive_message_length", 1024 * 1024 * 256),
            ],
        )
        stub = sglang_scheduler_pb2_grpc.SglangSchedulerStub(channel)

        # Wait until the server is launched (poll GetModelInfo)
        model_info_response = None
        last_error = None
        for _ in range(120):
            time.sleep(1)
            try:
                model_info_response = stub.GetModelInfo(
                    sglang_scheduler_pb2.GetModelInfoRequest(), timeout=5
                )
                break
            except Exception as e:
                last_error = str(e)

        if model_info_response is None:
            return fail_fatally(
                f"gRPC server warmup failed: Could not connect to server after 120 seconds. Last error: {last_error}"
            )

        # Send a warmup request
        logger.info("Sending warmup request to gRPC server...")

        # The model info tells whether this is a generation or embedding model
        if model_info_response.is_generation:
            warmup_request_kwargs = {
                "request_id": f"WARMUP_{time.time()}",
                "tokenized": sglang_scheduler_pb2.TokenizedInput(
                    input_ids=[
                        123,
                        456,
                        789,
                        234,
                        567,
                        890,
                        345,
                    ],  # Random-looking but safe token IDs
                    original_text="warmup request",
                ),
                "sampling_params": sglang_scheduler_pb2.SamplingParams(
                    temperature=0.0,
                    max_new_tokens=8,
                ),
                "stream": False,
            }

            # Set disaggregation params if needed. The mode is compared with
            # both the enum member and its value so the check is correct
            # whether server_args stores the enum or its plain value.
            if server_args.disaggregation_mode not in (
                DisaggregationMode.NULL,
                DisaggregationMode.NULL.value,
            ):
                warmup_request_kwargs["disaggregated_params"] = (
                    sglang_scheduler_pb2.DisaggregatedParams(
                        bootstrap_host=FAKE_BOOTSTRAP_HOST,
                        bootstrap_room=0,
                    )
                )

            warmup_request = sglang_scheduler_pb2.GenerateRequest(
                **warmup_request_kwargs
            )

            # Send the warmup request
            try:
                responses = list(stub.Generate(warmup_request, timeout=600))
            except Exception as e:
                return fail_fatally(f"gRPC warmup request failed: {e}")

            # Check if we got a valid response
            if responses and not responses[-1].HasField("error"):
                logger.info("gRPC warmup request completed successfully")
                success = True
            else:
                error_msg = (
                    responses[-1].error.message if responses else "No response"
                )
                logger.warning(f"gRPC warmup request returned error: {error_msg}")
                success = False
        else:
            # For embedding models
            warmup_request = sglang_scheduler_pb2.EmbedRequest(
                request_id=f"WARMUP_{time.time()}",
                tokenized=sglang_scheduler_pb2.TokenizedInput(
                    input_ids=[10, 11, 12],
                    original_text="test embedding",
                ),
            )

            try:
                embed_response = stub.Embed(warmup_request, timeout=600)
            except Exception as e:
                return fail_fatally(f"gRPC warmup request failed: {e}")

            if not embed_response.HasField("error"):
                logger.info("gRPC warmup request completed successfully")
                success = True
            else:
                logger.warning(
                    f"gRPC warmup request returned error: {embed_response.error.message}"
                )
                success = False

        channel.close()
        return success

    except Exception as e:
        return fail_fatally(
            f"gRPC warmup failed with exception: {e}\n{get_exception_traceback()}"
        )
def _wait_and_warmup_grpc(
    server_args: ServerArgs,
    pipe_finish_writer: Optional[mp.connection.Connection],
    health_servicer: Optional[SGLangHealthServicer] = None,
):
    """Wait for gRPC server to be ready and execute warmup.

    Warmup runs unless ``server_args.skip_server_warmup`` is set. If it
    fails, the function returns without marking the server ready. Otherwise
    the health servicer (when given) is switched to serving and "ready" is
    sent on ``pipe_finish_writer`` (when given).
    """
    if server_args.skip_server_warmup:
        logger.info("Skipping gRPC server warmup (skip_server_warmup=True)")
    else:
        warmed_up = _execute_grpc_server_warmup(server_args, pipe_finish_writer)
        if not warmed_up:
            # Warmup failed: leave the health status untouched.
            return

    # Mark health service as SERVING after warmup completes
    if health_servicer:
        health_servicer.set_serving()
        logger.info("Health service marked as SERVING")

    logger.info("The server is fired up and ready to roll!")

    if pipe_finish_writer is None:
        return
    pipe_finish_writer.send("ready")

# Xet Storage Details
# Size: 38.9 kB
# Xet hash: f7d0064ea5416f33ac7e5ae7010d927bc251d9aba1ff73133b9979a806a8af1c
# Xet efficiently stores files, intelligently splitting them into unique
# chunks and accelerating uploads and downloads.