from typing import Any, Callable, Dict, Optional, Tuple, Union

import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator

import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
                               dispose_tensor, get_fused_moe_state,
                               npu_stream_switch, npu_wait_tensor)
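# This module implements W8A8 dynamic (per-token activation) quantization for
# Ascend NPUs: a quantized linear method and a fused-MoE method. The fused-MoE
# method selects one of several expert-parallel strategies at runtime
# (all-gather EP, MC2 dispatch/combine, all-gather/naive-multicast, or
# all-to-all) based on the FusedMoEState returned by get_fused_moe_state.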
|
|
|
|
|
|
|
|
def apply_mlp(hidden_states: torch.Tensor,
              w1: torch.Tensor,
              w1_scale: torch.Tensor,
              w2: torch.Tensor,
              w2_scale: torch.Tensor,
              group_list: torch.Tensor,
              dynamic_scale: Optional[torch.Tensor] = None,
              group_list_type: int = 1) -> torch.Tensor:
    """
    Apply the expert MLP: gate_up_proj -> swiglu -> down_proj.

    Args:
        hidden_states: input hidden states with shape (num_tokens, hidden_size).
        w1: expert weights1 with shape
            (num_experts, hidden_size, intermediate_size * 2).
        w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2).
        w2: expert weights2 with shape
            (num_experts, intermediate_size, hidden_size).
        w2_scale: weights2 scale with shape (num_experts, hidden_size).
        group_list: number of tokens assigned to each expert, either as a
            cumulative sum (group_list_type=0) or as per-expert counts
            (group_list_type=1), with shape (num_experts,).
        dynamic_scale: optional per-token quantization scale; if None, the
            input is dynamically quantized here.
        group_list_type: how group_list is interpreted (see above).

    Note:
        The expert weights are expected to have been transposed beforehand
        (e.g. in process_weights_after_loading):
            w1: (num_experts, intermediate_size * 2, hidden_size) ->
                (num_experts, hidden_size, intermediate_size * 2)
            w2: (num_experts, hidden_size, intermediate_size) ->
                (num_experts, intermediate_size, hidden_size)

    Returns:
        hidden_states: output hidden states after MLP.
    """
    if dynamic_scale is None:
        unquantized_hidden_states = hidden_states
        hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
            hidden_states)
        # Release the original-dtype tensor once the int8 copy exists.
        dispose_tensor(unquantized_hidden_states)
    else:
        pertoken_scale = dynamic_scale

    # gate_up_proj: grouped matmul over the per-expert token groups.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        scale=[w1_scale],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=w2_scale.dtype)[0]

    # swiglu activation, then re-quantize for the down projection.
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
        hidden_states)

    # down_proj.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale],
        per_token_scale=[swiglu_out_scale],
        split_item=2,
        group_list_type=group_list_type,
        group_type=0,
        group_list=group_list,
        output_dtype=w2_scale.dtype)[0]

    return hidden_states
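# Usage sketch (illustrative only): callers quantize/permute tokens per expert
# and then invoke apply_mlp with a group_list describing how many rows belong
# to each expert, e.g.
#
#   out = apply_mlp(sorted_tokens, w13, w13_scale, w2, w2_scale,
#                   expert_tokens, group_list_type=1)
#
# where `sorted_tokens` and `expert_tokens` come from one of the routing paths
# below; the names in this sketch are placeholders, not part of this module's
# API.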
|
|
|
|
|
|
|
|
def fused_experts_with_mc2(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w2: torch.Tensor,
        w1_scale: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        top_k: int,
        expert_map: Optional[torch.Tensor] = None,
        moe_all_to_all_group_name: str = "",
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    # Map logical expert ids to physical ones when redundant experts are used.
    if log2phy is not None:
        topk_ids = log2phy[topk_ids]
    global_bs = 0
    moe_expert_num = len(expert_map) + global_redundant_expert_num

    kwargs_mc2 = {
        "x": hidden_states,
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": global_bs,
        "expert_scales": topk_weights.to(torch.float32),
    }

    rank = torch.distributed.get_rank()

    quant_mode = 2  # dynamically quantize activations inside the dispatch op
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)

    world_size = torch.distributed.get_world_size()
    tp_size = world_size // all_to_all_group_size
    tp_rank = rank % tp_size

    # Stage 1: dispatch tokens to the ranks that own their experts.
    stage1_kwargs = {
        "scales": None,
        "quant_mode": quant_mode,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs_mc2.update(stage1_kwargs)

    output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
        0:7]

    # Run the shared experts' gate/up projection on a secondary stream so it
    # overlaps with the routed-expert pipeline.
    if shared_experts is not None:
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(hidden_states, topk_weights)
            shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
            npu_wait_tensor(shared_gate_up[0], expand_x)
            shared_act = shared_experts.act_fn(shared_gate_up)

    # Stage 2: grouped expert MLP on the locally received tokens.
    down_out_list = apply_mlp(expand_x,
                              w1,
                              w1_scale,
                              w2,
                              w2_scale,
                              expert_token_nums,
                              dynamic_scale=dynamic_scale)

    # Stage 3: combine expert outputs back to the original token order.
    kwargs_mc2 = {
        "expand_x": down_out_list,
        "expert_ids": topk_ids,
        "expand_idx": expand_idx,
        "expert_scales": topk_weights.to(torch.float32),
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
        "expand_scales": expand_scales,
    }
    tp_recv_counts = torch.empty(1,
                                 dtype=torch.int32,
                                 device=hidden_states.device)
    stage3_kwargs = {
        "ep_send_counts": ep_recv_counts,
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        "tp_send_counts": tp_recv_counts,
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }
    kwargs_mc2.update(stage3_kwargs)

    hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)

    if shared_experts is None:
        return hidden_states
    else:
        with npu_stream_switch("moe_secondary", 0):
            npu_wait_tensor(shared_act[0], down_out_list)
            shared_output, _ = shared_experts.down_proj(shared_act)
        return hidden_states, shared_output
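# fused_experts_with_mc2 (above) follows a three-stage MC2 flow:
#   1. npu_moe_distribute_dispatch quantizes tokens and routes them to the
#      ranks that own their experts,
#   2. apply_mlp runs the grouped expert MLP on the local shard,
#   3. npu_moe_distribute_combine returns weighted expert outputs to the
#      original token order.
# When shared experts are given, their projections are issued on a secondary
# NPU stream so they overlap with the routed-expert computation.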
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def fused_experts_with_all2all(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w1_scale: torch.Tensor,
        w2: torch.Tensor,
        w2_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        top_k: int,
        expert_map: Optional[torch.Tensor] = None,
        ep_group: Optional[GroupCoordinator] = None,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
):
    if log2phy is not None:
        topk_ids = log2phy[topk_ids]
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    num_experts = w1.shape[0]
    device = hidden_states.device

    if expert_map is not None:
        global_num_experts = len(expert_map) + global_redundant_expert_num
        local_num_experts = global_num_experts // ep_group.world_size
        row_idx_len = num_tokens * top_k
        row_idx = (torch.arange(0,
                                row_idx_len,
                                dtype=torch.int32,
                                device=device).view(top_k, -1).permute(
                                    1, 0).contiguous())
        # Expand tokens to (num_tokens * top_k) rows grouped by expert id.
        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        # Exchange per-rank token counts so every rank knows how much it will
        # send to and receive from each expert-parallel peer.
        global_expert_tokens = torch.bincount(expanded_expert_idx,
                                              minlength=global_num_experts)
        scatter_sizes = global_expert_tokens.view(ep_group.world_size,
                                                  -1).sum(-1)

        gather_sizes = torch.empty_like(scatter_sizes)
        dist.all_to_all_single(gather_sizes,
                               scatter_sizes,
                               group=ep_group.device_group)
        scatter_size_list = scatter_sizes.cpu().tolist()
        gather_size_list = gather_sizes.cpu().tolist()

        # Remap global expert ids to local ones, then exchange the tokens.
        expanded_expert_idx = expanded_expert_idx % local_num_experts
        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
                                            scatter_size_list,
                                            gather_size_list)
        local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
                                               scatter_size_list,
                                               gather_size_list)

        # Sort received tokens by local expert id for the grouped matmul.
        sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            sorted_local_expert_idx, local_num_experts).to(torch.int64)

        hidden_states = hidden_states[sorted_idx]
        group_list_type = 0
    else:
        row_idx_len = num_tokens * top_k
        row_idx = torch.arange(0,
                               row_idx_len,
                               dtype=torch.int32,
                               device=topk_weights.device).view(
                                   top_k, -1).permute(1, 0).contiguous()
        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 0

    hidden_states = apply_mlp(
        hidden_states,
        w1,
        w1_scale,
        w2,
        w2_scale,
        expert_tokens,
        group_list_type=group_list_type)

    if expert_map is not None:
        # Undo the sort, send tokens back to their source ranks, then fold the
        # top-k results back into per-token outputs.
        resorted_idx = torch.argsort(sorted_idx)
        hidden_states = hidden_states[resorted_idx]
        hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
                                            gather_size_list,
                                            scatter_size_list)

        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            hidden_states,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )
    else:
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            hidden_states,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )
    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states
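# fused_experts_with_all2all (above) redistributes tokens across expert-parallel
# ranks with two all-to-all exchanges: one to send each token to the rank that
# owns its expert (split sizes negotiated via dist.all_to_all_single), and one
# to send the MLP outputs back before npu_moe_finalize_routing re-applies the
# top-k weights in the original token order.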
|
|
|
|
|
|
|
|
def fused_experts_with_allgather(hidden_states: torch.Tensor,
                                 w1: torch.Tensor,
                                 w1_scale: torch.Tensor,
                                 w2: torch.Tensor,
                                 w2_scale: torch.Tensor,
                                 topk_weights: torch.Tensor,
                                 topk_ids: torch.Tensor,
                                 top_k: int,
                                 expert_map: Optional[torch.Tensor] = None):
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    num_tokens = hidden_states.shape[0]

    ep_group = get_ep_group().device_group
    ep_rank = torch.distributed.get_rank(group=ep_group)
    ep_size = torch.distributed.get_world_size(ep_group)

    global_num_experts = len(expert_map)
    local_num_experts = global_num_experts // ep_size

    # Quantize activations once; they stay in int8 through routing and the
    # first grouped matmul.
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

    # Route only the tokens that hit this rank's expert range.
    hidden_states, expanded_x_idx, expert_tokens, pertoken_scale = torch_npu.npu_moe_init_routing_v2(
        hidden_states,
        topk_ids,
        scale=pertoken_scale,
        offset=None,
        active_num=num_tokens * top_k,
        expert_num=global_num_experts,
        expert_tokens_num_type=1,
        expert_tokens_num_flag=True,
        active_expert_range=[
            ep_rank * local_num_experts, (ep_rank + 1) * local_num_experts
        ],
        quant_mode=-1,
        row_idx_type=0)
    group_list_type = 1

    # gate_up_proj in int8; keep int32 accumulators for the fused
    # dequant + SwiGLU + requant step below.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w1],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=torch.int32)[0]

    hidden_states, pertoken_scale = torch_npu.npu_dequant_swiglu_quant(
        x=hidden_states,
        weight_scale=w1_scale.to(torch.float32),
        activation_scale=pertoken_scale,
        bias=None,
        quant_scale=None,
        quant_offset=None,
        group_index=expert_tokens,
        activate_left=True,
        quant_mode=1,
    )

    # down_proj back to bfloat16.
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale.to(torch.bfloat16)],
        per_token_scale=[pertoken_scale.view(-1)],
        split_item=3,
        group_list_type=group_list_type,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=torch.bfloat16)[0]

    final_hidden_states = torch_npu.npu_moe_finalize_routing(
        expanded_permuted_rows=hidden_states.unsqueeze(1),
        skip1=None,
        skip2=None,
        bias=None,
        scales=topk_weights.to(torch.bfloat16),
        expanded_src_to_dst_row=expanded_x_idx.to(torch.int32),
        export_for_source_row=topk_ids,
        drop_pad_mode=3).to(torch.bfloat16)

    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)

    return final_hidden_states
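# fused_experts_with_allgather (above) keeps activations in int8 end to end:
# npu_moe_init_routing_v2 routes the already-quantized tokens for the local
# expert range, the first grouped matmul emits int32 accumulators, and
# npu_dequant_swiglu_quant fuses dequantization, SwiGLU and re-quantization
# before the int8 down-projection produces bfloat16 output.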
|
|
|
|
|
|
|
|
def fused_experts(hidden_states: torch.Tensor,
                  w1: torch.Tensor,
                  w1_scale: torch.Tensor,
                  w2: torch.Tensor,
                  w2_scale: torch.Tensor,
                  topk_weights: torch.Tensor,
                  topk_ids: torch.Tensor,
                  top_k: int,
                  expert_map: Optional[torch.Tensor] = None):
    original_shape = hidden_states.shape
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])

    num_tokens, _ = hidden_states.shape
    num_experts = w1.shape[0]
    dtype = hidden_states.dtype
    device = hidden_states.device

    if expert_map is not None:
        # Each token is replicated top_k times; token_indices[i] is the source
        # token of the i-th (token, expert) pair.
        token_indices = (torch.arange(num_tokens,
                                      device=device,
                                      dtype=torch.int64).unsqueeze(1).expand(
                                          -1, top_k).reshape(-1))

        weights_flat = topk_weights.view(-1)
        experts_flat = topk_ids.view(-1)
        local_experts_flat = expert_map[experts_flat]

        # Pairs routed to non-local experts (mapped to -1) get zero weight and
        # are bucketed into a dummy expert id `num_experts`.
        mask = local_experts_flat != -1
        filtered_weights = torch.where(
            mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
        filtered_experts = torch.where(
            mask, local_experts_flat,
            torch.full_like(local_experts_flat,
                            num_experts)).to(topk_ids.dtype)

        # Sort pairs by expert id so each expert's tokens are contiguous.
        sort_indices = torch.argsort(filtered_experts)
        sorted_token_indices = token_indices[sort_indices]
        sorted_weights = filtered_weights[sort_indices]

        # Per-expert token counts for the grouped matmul (count mode).
        token_counts = torch.zeros(num_experts + 1,
                                   device=device,
                                   dtype=torch.int64)
        ones = torch.ones_like(filtered_experts, dtype=torch.int64)
        token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
        expert_tokens = token_counts[:num_experts]

        hidden_states = hidden_states[sorted_token_indices]
        group_list_type = 1
    else:
        row_idx_len = num_tokens * top_k
        row_idx = torch.arange(0,
                               row_idx_len,
                               dtype=torch.int32,
                               device=topk_weights.device).view(
                                   top_k, -1).permute(1, 0).contiguous()
        hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
            hidden_states,
            row_idx=row_idx,
            expert_idx=topk_ids,
            active_num=num_tokens)

        expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
            expanded_expert_idx, num_experts)
        expert_tokens = expert_tokens.to(torch.int64)
        group_list_type = 0

    hidden_states = apply_mlp(hidden_states,
                              w1,
                              w1_scale,
                              w2,
                              w2_scale,
                              expert_tokens,
                              group_list_type=group_list_type)

    if expert_map is not None:
        # Weight each pair's output, zero the invalid tail, and scatter-add the
        # results back to their source tokens.
        hidden_states.mul_(sorted_weights.unsqueeze(1))
        final_hidden_states = torch.zeros(*original_shape,
                                          device=device,
                                          dtype=dtype)

        num_valid_tokens = mask.sum()
        valid_token_mask = torch.arange(
            0, sorted_token_indices.shape[0],
            device=device).unsqueeze(1) < num_valid_tokens
        hidden_states = hidden_states.masked_fill_(~valid_token_mask,
                                                   0).to(dtype)
        final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
    else:
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            hidden_states,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights,
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
        )

    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states
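# fused_experts (above) is the all-gather / naive-multicast path: with an
# expert_map it masks out tokens routed to non-local experts, sorts the
# remaining (token, expert) pairs by local expert id, runs the grouped MLP, and
# scatters weighted results back with index_add_; without an expert_map it uses
# the npu_moe_init_routing / npu_moe_finalize_routing ops instead.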
|
|
|
|
|
|
|
|
class AscendW8A8DynamicLinearMethod:
    """Linear method for Ascend W8A8_DYNAMIC quantization."""

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(input_size: int, output_size: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        return {}

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        config = getattr(layer, "_ascend_quant_config", {})
        if not isinstance(x, tuple):
            # Unquantized input: quantize per token on the fly.
            output_dtype = config.get("output_dtype", x.dtype)
            quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
        else:
            # Pre-quantized input: the caller passes (int8 tensor, scale).
            assert "output_dtype" in config.keys(), (
                "DynamicLinearMethod needs explicitly specified `output_dtype`"
                f" for pre-quantized input, got config [{config}]")
            output_dtype = config["output_dtype"]
            quantized_x, dynamic_scale = x
        pertoken_scale = (dynamic_scale
                          if config.get("pertoken_scale", True) else None)

        output = torch_npu.npu_quant_matmul(
            quantized_x,
            layer.weight,
            layer.weight_scale,
            pertoken_scale=pertoken_scale,
            bias=bias,
            output_dtype=output_dtype,
        )
        return ((output, dynamic_scale)
                if config.get("return_scale", False) else output)

    def process_weights_after_loading(self, layer):
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        # Cast the int8 weight to FRACTAL_NZ layout (format id 29), which is
        # generally faster for matmul on Ascend NPUs.
        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
                                                      ACL_FORMAT_FRACTAL_NZ)
        layer.weight_scale.data = layer.weight_scale.data.flatten()
        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
        layer.weight_offset.data = layer.weight_offset.data.flatten()
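# Keys read from `layer._ascend_quant_config` by
# AscendW8A8DynamicLinearMethod.apply:
#   - "output_dtype":   dtype of the matmul output (required for pre-quantized
#                       tuple input)
#   - "pertoken_scale": whether to pass the per-token scale to
#                       npu_quant_matmul (default True)
#   - "return_scale":   whether to return (output, dynamic_scale) instead of
#                       the output alone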
|
|
|
|
|
|
|
|
class AscendW8A8DynamicFusedMoEMethod:
    """FusedMoE method for Ascend W8A8_DYNAMIC quantization."""

    def __init__(self):
        self.transpose_weight = True

        self.ep_group = get_ep_group()

        ascend_config = get_ascend_config()
        self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled

        try:
            device_group = self.ep_group.device_group
            local_rank = torch.distributed.get_rank(group=device_group)
            backend = device_group._get_backend(torch.device("npu"))
            self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
                local_rank)
        except AttributeError:
            self.moe_all_to_all_group_name = ""

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(num_experts,
                                               2 *
                                               intermediate_size_per_partition,
                                               hidden_sizes,
                                               dtype=torch.int8)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              hidden_sizes,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8)
        return param_dict

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=params_dtype)
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=params_dtype)
        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=params_dtype)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=params_dtype)
        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = True,
        log2phy: Optional[torch.Tensor] = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        is_deepseek_v3_r1 = global_num_experts == 256
        use_grouped_topk = (topk_group is not None and topk_group > 1) or (
            num_expert_group is not None and num_expert_group > 1)

        # For the DeepSeek V3/R1 256-expert pattern, use the fused NPU gating op.
        if use_grouped_topk and is_deepseek_v3_r1:
            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
                router_logits,
                k=top_k,
                bias=e_score_correction_bias,
                k_group=topk_group,
                group_count=num_expert_group,
                group_select_mode=1,
                renorm=0,
                norm_type=1,
                routed_scaling_factor=1,
                eps=float(1e-20))
        else:
            topk_weights, topk_ids = select_experts(
                hidden_states=x,
                router_logits=router_logits,
                top_k=top_k,
                use_grouped_topk=use_grouped_topk,
                renormalize=renormalize,
                topk_group=topk_group,
                num_expert_group=num_expert_group,
                custom_routing_function=custom_routing_function,
                scoring_func=scoring_func,
                e_score_correction_bias=e_score_correction_bias,
            )

        # Override routing with random expert ids to force an even load,
        # e.g. during profiling or dummy runs.
        if enable_force_load_balance:
            topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)

        topk_weights = topk_weights.to(x.dtype)

        # Choose the communication strategy for this step and dispatch to the
        # corresponding fused-experts implementation.
        fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
                                              is_prefill, is_deepseek_v3_r1)
        if fused_moe_state == FusedMoEState.AllGatherEP:
            return fused_experts_with_allgather(
                hidden_states=x,
                w1=layer.w13_weight,
                w1_scale=layer.w13_weight_scale,
                w2=layer.w2_weight,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map)
        elif fused_moe_state == FusedMoEState.MC2:
            return fused_experts_with_mc2(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                w1_scale=layer.w13_weight_scale,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                moe_all_to_all_group_name=self.moe_all_to_all_group_name,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
                shared_experts=shared_experts)
        elif fused_moe_state in [
                FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
        ]:
            return fused_experts(hidden_states=x,
                                 w1=layer.w13_weight,
                                 w1_scale=layer.w13_weight_scale,
                                 w2=layer.w2_weight,
                                 w2_scale=layer.w2_weight_scale,
                                 topk_weights=topk_weights,
                                 topk_ids=topk_ids,
                                 top_k=top_k,
                                 expert_map=expert_map)
        else:
            # Fall back to the all-to-all implementation.
            return fused_experts_with_all2all(
                hidden_states=x,
                w1=layer.w13_weight,
                w1_scale=layer.w13_weight_scale,
                w2=layer.w2_weight,
                w2_scale=layer.w2_weight_scale,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                top_k=top_k,
                expert_map=expert_map,
                ep_group=self.ep_group,
                log2phy=log2phy,
                global_redundant_expert_num=global_redundant_expert_num,
            )

    def process_weights_after_loading(self, layer):
        if self.transpose_weight:
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)
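# After process_weights_after_loading, w13_weight and w2_weight are stored
# transposed as (num_experts, hidden_size, 2 * intermediate_size) and
# (num_experts, intermediate_size, hidden_size) respectively, matching the
# layout apply_mlp expects, and the per-channel scales/offsets are flattened
# to 2-D (num_experts, -1) tensors.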
|
|
|