| # Copyright 2023-2025 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| from dataclasses import dataclass | |
| from typing import Literal, Optional | |
| import torch | |
| from sglang.srt.eplb.expert_location import get_global_expert_location_metadata | |
| from sglang.srt.server_args import get_global_server_args | |
@dataclass
class ExpertLocationDispatchInfo:
    """Expert-location tensors for one layer, used to map logical to physical experts.

    Instances are normally built with :meth:`init_new`, which slices the
    global expert-location metadata at a given ``layer_id``.
    """

    # Dispatch strategy. The dispatch helpers in this file recognize
    # "static", "dynamic" and "fake"; any other value raises there.
    ep_dispatch_algorithm: Literal["static", "dynamic", "fake"]
    # (num_logical_experts,)
    partial_logical_to_rank_dispatch_physical_map: Optional[torch.Tensor]
    # (num_logical_experts, X)
    partial_logical_to_all_physical_map: torch.Tensor
    # (num_logical_experts,)
    partial_logical_to_all_physical_map_num_valid: torch.Tensor
    num_physical_experts: int

    @classmethod
    def init_new(cls, layer_id: int) -> Optional["ExpertLocationDispatchInfo"]:
        """Build the dispatch info for ``layer_id`` from the global state.

        Args:
            layer_id: Index used to select the row of every per-layer map in
                the global expert-location metadata.

        Returns:
            A new instance, or ``None`` when the server args do not set
            ``ep_dispatch_algorithm``.

        Raises:
            AssertionError: If the global expert-location metadata is not
                available. This is checked before the ``None`` early return.
        """
        ep_dispatch_algorithm = get_global_server_args().ep_dispatch_algorithm
        expert_location_metadata = get_global_expert_location_metadata()
        assert expert_location_metadata is not None

        if ep_dispatch_algorithm is None:
            return None

        # The rank-dispatch map is optional in the metadata; propagate its
        # absence instead of indexing into ``None``.
        rank_dispatch_map = (
            expert_location_metadata.logical_to_rank_dispatch_physical_map
        )
        if rank_dispatch_map is not None:
            rank_dispatch_map = rank_dispatch_map[layer_id, :]

        return cls(
            ep_dispatch_algorithm=ep_dispatch_algorithm,
            partial_logical_to_rank_dispatch_physical_map=rank_dispatch_map,
            partial_logical_to_all_physical_map=(
                expert_location_metadata.logical_to_all_physical_map[layer_id, :]
            ),
            partial_logical_to_all_physical_map_num_valid=(
                expert_location_metadata.logical_to_all_physical_map_num_valid[
                    layer_id, :
                ]
            ),
            num_physical_experts=expert_location_metadata.num_physical_experts,
        )
def transform_select_experts_inputs(
    router_logits: torch.Tensor,
    correction_bias: Optional[torch.Tensor],
    info: Optional[ExpertLocationDispatchInfo],
):
    """Optionally replace the expert-selection inputs with synthetic values.

    Only the "fake" dispatch algorithm changes anything: ``router_logits`` is
    overwritten in place with values drawn uniformly between 5 and 10, and a
    non-``None`` ``correction_bias`` is replaced by a new all-zero tensor of
    the same shape. In every other case both arguments pass through untouched.

    Returns:
        The ``(router_logits, correction_bias)`` pair to use for selection.
    """
    # Anything other than the "fake" algorithm is a no-op.
    if info is None or info.ep_dispatch_algorithm != "fake":
        return router_logits, correction_bias

    # In-place fill: the caller's tensor object is reused, not copied.
    router_logits.uniform_(5, 10)
    if correction_bias is None:
        return router_logits, None
    return router_logits, torch.zeros_like(correction_bias)
def topk_ids_logical_to_physical(
    topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo]
) -> torch.Tensor:
    """Translate logical expert ids to physical expert ids.

    Args:
        topk_ids: Tensor of logical expert ids.
        info: Dispatch info for the current layer, or ``None`` to skip the
            translation.

    Returns:
        ``topk_ids`` itself when ``info`` is ``None``; otherwise the tensor
        produced by the static or dynamic mapping helper.

    Raises:
        NotImplementedError: If ``info.ep_dispatch_algorithm`` is not one of
            "static", "dynamic" or "fake".
    """
    if info is None:
        return topk_ids

    algorithm = info.ep_dispatch_algorithm
    if algorithm == "static":
        return _topk_ids_logical_to_physical_static(topk_ids, info)
    # "fake" shares the dynamic mapping path.
    if algorithm in ("dynamic", "fake"):
        return _topk_ids_logical_to_physical_dynamic(topk_ids, info)
    raise NotImplementedError(f"Unknown algorithm {info.ep_dispatch_algorithm}")
| def _topk_ids_logical_to_physical_static( | |
| topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] | |
| ) -> torch.Tensor: | |
| return info.partial_logical_to_rank_dispatch_physical_map[topk_ids] | |
| def _topk_ids_logical_to_physical_dynamic( | |
| topk_ids: torch.Tensor, info: Optional[ExpertLocationDispatchInfo] | |
| ) -> torch.Tensor: | |
| topk_ids_original_shape = topk_ids.shape | |
| device = topk_ids.device | |
| topk_ids = topk_ids.flatten() | |
| chosen_dispatch_index = ( | |
| torch.randint(0, 65536, topk_ids.shape, dtype=torch.int32, device=device) | |
| % info.partial_logical_to_all_physical_map_num_valid[topk_ids] | |
| ) | |
| topk_ids = info.partial_logical_to_all_physical_map[topk_ids, chosen_dispatch_index] | |
| topk_ids = topk_ids.view(topk_ids_original_shape) | |
| return topk_ids | |
Xet Storage Details
- Size:
- 4.16 kB
- Xet hash:
- dd447f0014c3eaba493ac8657092d9269cb9e9e6c39c26cdbc5302dd0984b267
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.