leideng/QCFuse / srt /eplb /expert_location_dispatch.py
leideng's picture
download
raw
4.16 kB
# Copyright 2023-2025 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from dataclasses import dataclass
from typing import Literal, Optional
import torch
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.server_args import get_global_server_args
@dataclass
class ExpertLocationDispatchInfo:
    """Per-layer slice of the global expert-location metadata.

    Holds the tables needed to translate logical expert ids (as produced by
    the router's top-k) into physical expert ids for a single layer. The
    ``ep_dispatch_algorithm`` field selects which table is consumed by
    ``topk_ids_logical_to_physical``.
    """

    # Dispatch strategy. These are the values handled elsewhere in this
    # module: "static" uses the rank dispatch map, while "dynamic" and "fake"
    # pick a random physical replica.
    ep_dispatch_algorithm: Literal["static", "dynamic", "fake"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int) -> Optional["ExpertLocationDispatchInfo"]:
        """Build the dispatch info for ``layer_id`` from global state.

        Each per-layer table of the global expert-location metadata is
        sliced at row ``layer_id``.

        Args:
            layer_id: Row index into the per-layer metadata tables.

        Returns:
            A new instance, or ``None`` when the server args configure no
            ``ep_dispatch_algorithm``.

        Raises:
            AssertionError: If the global expert-location metadata has not
                been initialized. This is checked even when no dispatch
                algorithm is configured.
        """
        ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
        metadata = get_global_expert_location_metadata()
        assert metadata is not None
        if ep_dispatch_algorithm is None:
            return None

        # The rank dispatch map is optional in the metadata. Only slice it
        # when present.
        rank_dispatch_map = metadata.logical_to_rank_dispatch_physical_map
        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=(
                rank_dispatch_map[layer_id, :]
                if rank_dispatch_map is not None
                else None
            ),
            partial_logical_to_all_physical_map=(
                metadata.logical_to_all_physical_map[layer_id, :]
            ),
            partial_logical_to_all_physical_map_num_valid=(
                metadata.logical_to_all_physical_map_num_valid[layer_id, :]
            ),
            num_physical_experts=metadata.num_physical_experts,
        )
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Optionally rewrite the expert-selection inputs for the "fake" algorithm.

    For any dispatch other than "fake" (including ``info is None``), both
    inputs are returned untouched.

    In "fake" mode:
      * ``router_logits`` is overwritten IN PLACE with values drawn uniformly
        from [5, 10).
      * A non-None ``correction_bias`` is replaced by a freshly allocated
        zero tensor of the same shape. The caller's tensor is not modified.

    NOTE(review): presumably "fake" exists to produce random routing for
    testing or benchmarking. Confirm with callers.

    Returns:
        ``(router_logits, correction_bias)`` after the optional transform.
    """
    fake_mode = info is not None and info.ep_dispatch_algorithm == "fake"
    if not fake_mode:
        return router_logits, correction_bias

    router_logits.uniform_(5, 10)
    if correction_bias is not None:
        correction_bias = torch.zeros_like(correction_bias)
    return router_logits, correction_bias
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical top-k expert ids into physical expert ids.

    Returns ``topk_ids`` unchanged when ``info`` is None. Otherwise the
    choice of translation depends on ``info.ep_dispatch_algorithm``:
      * "static": ``_topk_ids_logical_to_physical_static`` is used.
      * "dynamic" / "fake": ``_topk_ids_logical_to_physical_dynamic`` is used.

    Raises:
        NotImplementedError: For any other algorithm value.
    """
    if info is None:
        return topk_ids

    algorithm = info.ep_dispatch_algorithm
    if algorithm == "static":
        return _topk_ids_logical_to_physical_static(topk_ids, info)
    elif algorithm in ("dynamic", "fake"):
        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
    else:
        raise NotImplementedError(f"Unknown algorithm {info.ep_dispatch_algorithm}")
def _topk_ids_logical_to_physical_static(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Look up each logical id in the precomputed per-rank dispatch table.

    The lookup uses tensor indexing, so for a 1-D table the result has the
    same shape as ``topk_ids``.

    NOTE(review): this assumes ``info`` is not None and that
    ``partial_logical_to_rank_dispatch_physical_map`` is populated.
    Indexing ``None`` here would raise ``TypeError``.
    """
    rank_dispatch_table = info.partial_logical_to_rank_dispatch_physical_map
    physical_ids = rank_dispatch_table[topk_ids]
    return physical_ids
def _topk_ids_logical_to_physical_dynamic(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Send each logical id to a randomly chosen physical replica.

    For every element:
      1. An int32 is drawn from [0, 65536) on ``topk_ids``' device.
      2. It is reduced modulo the number of valid replicas of that logical
         expert.
      3. The result is used as a column index into
         ``partial_logical_to_all_physical_map``.

    The output has the same shape as ``topk_ids``.

    NOTE: the modulo reduction is slightly non-uniform whenever the replica
    count does not divide 65536.
    """
    input_shape = topk_ids.shape
    flat_logical = topk_ids.flatten()

    # Draw the random numbers before any gather, matching the original
    # operation order.
    random_draw = torch.randint(
        0, 65536, flat_logical.shape, dtype=torch.int32, device=topk_ids.device
    )
    replicas_available = info.partial_logical_to_all_physical_map_num_valid[flat_logical]
    replica_slot = random_draw % replicas_available

    flat_physical = info.partial_logical_to_all_physical_map[flat_logical, replica_slot]
    return flat_physical.view(input_shape)

Xet Storage Details

Size:
4.16 kB
·
Xet hash:
dd447f0014c3eaba493ac8657092d9269cb9e9e6c39c26cdbc5302dd0984b267

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.