from typing import Any, Callable, Dict, Optional

import torch
import torch_npu

from vllm.attention.backends.abstract import AttentionType

from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p


def quant_per_tensor(in_tensor: torch.Tensor,
                     input_scale: torch.Tensor,
                     input_offset: torch.Tensor,
                     function=False):
    return torch_npu.npu_quantize(in_tensor, input_scale, input_offset,
                                  torch.qint8, -1, function)


class AscendW8A8LinearMethod:
    """Linear method for Ascend W8A8 quantization.

    Weights are stored as int8 with per-channel scales; activations are
    quantized per tensor at runtime with ``aclnn_input_scale`` /
    ``aclnn_input_offset``.
    """

    def __init__(self) -> None:
        # npu_quant_matmul expects the weight pre-transposed to (K, N); on
        # 310P the weight is kept as (N, K) and transposed at call time.
        self.transpose_weight = not is_310p()

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {
            "weight": torch.empty(output_size, input_size, dtype=torch.int8)
        }
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
        if params_dtype == torch.bfloat16:
            params_dict["deq_scale"] = torch.empty(output_size,
                                                   dtype=torch.float32)
        elif params_dtype == torch.float16:
            params_dict["deq_scale"] = torch.empty(output_size,
                                                   dtype=torch.int64)
        params_dict["weight_scale"] = torch.empty(output_size,
                                                  1,
                                                  dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size,
                                                   1,
                                                   dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
        tp_rank: Optional[int] = 0,
    ) -> torch.Tensor:
        original_dtype = x.dtype
        if original_dtype != torch.int8:
            x = quant_per_tensor(x, layer.aclnn_input_scale,
                                 layer.aclnn_input_offset)
        quant_bias = layer.quant_bias if tp_rank == 0 else None
        if is_310p():
            # On 310P the weight is kept in its original (N, K) layout, so it
            # is transposed here at matmul time.
            output = torch_npu.npu_quant_matmul(
                x,
                layer.weight.data.transpose(1, 0),
                layer.deq_scale,
                bias=quant_bias,
                output_dtype=original_dtype,
            )
        else:
            output = torch_npu.npu_quant_matmul(
                x,
                layer.weight,
                layer.deq_scale,
                bias=quant_bias,
                output_dtype=original_dtype,
            )
        return output

    def process_weights_after_loading(self, layer):
        expanding_factor = layer.weight.data.shape[1]
        layer.aclnn_input_scale = 1 / torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor),
            requires_grad=False)
        layer.aclnn_input_offset = torch.nn.Parameter(
            layer.input_offset.data.repeat(expanding_factor),
            requires_grad=False).to(layer.aclnn_input_scale.dtype)
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data,
                                                      ACL_FORMAT_FRACTAL_NZ)
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
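

# Illustrative call flow for AscendW8A8LinearMethod (a sketch only; it is not
# executed anywhere in this module, and the `layer` object, its parameter
# registration and the checkpoint loading are assumed to be handled by the
# caller's quantization config):
#
#   method = AscendW8A8LinearMethod()
#   weights = method.get_weight(input_size, output_size, torch.bfloat16)
#   weights.update(method.get_pertensor_param(torch.bfloat16))
#   weights.update(method.get_perchannel_param(output_size, torch.bfloat16))
#   # ... register the tensors on `layer` and load the quantized checkpoint ...
#   method.process_weights_after_loading(layer)
#   y = AscendW8A8LinearMethod.apply(layer, x)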


class AscendW8A8FusedMoEMethod:
    """FusedMoE method for Ascend W8A8."""

    def __init__(self):
        self.transpose_weight = True

    @staticmethod
    def get_weight(num_experts: int, intermediate_size_per_partition: int,
                   hidden_sizes: int,
                   params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight"] = torch.empty(num_experts,
                                               2 *
                                               intermediate_size_per_partition,
                                               hidden_sizes,
                                               dtype=torch.int8,
                                               requires_grad=False)
        param_dict["w2_weight"] = torch.empty(num_experts,
                                              hidden_sizes,
                                              intermediate_size_per_partition,
                                              dtype=torch.int8,
                                              requires_grad=False)
        return param_dict

    @staticmethod
    def get_dynamic_quant_param(num_experts: int,
                                intermediate_size_per_partition: int,
                                hidden_sizes: int,
                                params_dtype: torch.dtype) -> Dict[str, Any]:
        param_dict = {}
        param_dict["w13_weight_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float32)
        param_dict["w13_weight_offset"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            1,
            dtype=torch.float16)
        param_dict["w2_weight_scale"] = torch.empty(num_experts,
                                                    hidden_sizes,
                                                    1,
                                                    dtype=torch.float32)
        param_dict["w2_weight_offset"] = torch.empty(num_experts,
                                                     hidden_sizes,
                                                     1,
                                                     dtype=torch.float16)
        param_dict["w2_deq_scale"] = torch.empty(num_experts,
                                                 hidden_sizes,
                                                 dtype=torch.float32)
        param_dict["w13_deq_scale"] = torch.empty(
            num_experts,
            2 * intermediate_size_per_partition,
            dtype=torch.float32)
        param_dict["w2_input_scale"] = torch.empty(num_experts,
                                                   1,
                                                   dtype=torch.float32)
        param_dict["w13_input_scale"] = torch.empty(num_experts,
                                                    1,
                                                    dtype=torch.float32)
        param_dict["w2_input_offset"] = torch.empty(num_experts,
                                                    1,
                                                    dtype=torch.int8)
        param_dict["w13_input_offset"] = torch.empty(num_experts,
                                                     1,
                                                     dtype=torch.int8)
        param_dict["quant_bias"] = torch.empty(num_experts,
                                               hidden_sizes,
                                               dtype=torch.int32)

        return param_dict

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        renormalize: bool,
        use_grouped_topk: bool = False,
        global_num_experts: int = -1,
        expert_map: Optional[torch.Tensor] = None,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        is_prefill: bool = True,
        enable_force_load_balance: bool = False,
        log2phy: torch.Tensor = None,
        global_redundant_expert_num: int = 0,
        shared_experts: Optional[Any] = None,
        **kwargs,
    ) -> torch.Tensor:
        assert router_logits.shape[
            1] == global_num_experts, "Number of global experts mismatch"

        topk_weights, topk_ids = select_experts(
            hidden_states=x,
            router_logits=router_logits,
            top_k=top_k,
            use_grouped_topk=use_grouped_topk,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            global_num_experts=global_num_experts,
        )

        if is_310p():
            return fused_experts_310p(hidden_states=x,
                                      w1=layer.w13_weight,
                                      w1_scale=layer.w13_weight_scale,
                                      w1_input_scale=layer.w13_input_scale,
                                      w2=layer.w2_weight,
                                      w2_scale=layer.w2_weight_scale,
                                      w2_input_scale=layer.w2_input_scale,
                                      topk_weights=topk_weights,
                                      topk_ids=topk_ids,
                                      top_k=top_k,
                                      global_num_experts=global_num_experts,
                                      expert_map=expert_map)
        return fused_experts(hidden_states=x,
                             w1=layer.w13_weight,
                             w1_scale=layer.w13_weight_scale,
                             w1_input_scale=layer.w13_input_scale,
                             w1_input_offset=layer.w13_input_offset,
                             w2=layer.w2_weight,
                             w2_scale=layer.w2_weight_scale,
                             w2_input_scale=layer.w2_input_scale,
                             w2_input_offset=layer.w2_input_offset,
                             topk_weights=topk_weights,
                             topk_ids=topk_ids,
                             top_k=top_k,
                             global_num_experts=global_num_experts,
                             expert_map=expert_map)

    def process_weights_after_loading(self, layer):
        if not is_310p():
            layer.w13_weight.data = layer.w13_weight.data.transpose(
                1, 2).contiguous()
            layer.w2_weight.data = layer.w2_weight.data.transpose(
                1, 2).contiguous()
        layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
            layer.w13_weight_scale.data.shape[0], -1)
        layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
            layer.w13_weight_offset.data.shape[0], -1)
        layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
            layer.w2_weight_scale.data.shape[0], -1)
        layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
            layer.w2_weight_offset.data.shape[0], -1)
        expanding_factor_w13 = layer.w13_weight.data.shape[1]
        expanding_factor_w2 = layer.w2_weight.data.shape[1]

        if is_310p():
            # 310P uses a single scalar per-tensor input scale (the max over
            # experts) for each projection.
            layer.w13_input_scale.data = torch.nn.Parameter(
                layer.w13_input_scale.data.max())
            layer.w2_input_scale.data = torch.nn.Parameter(
                layer.w2_input_scale.data.max())
        else:
            layer.w13_input_scale.data = torch.nn.Parameter(
                layer.w13_input_scale.data.repeat(1,
                                                  expanding_factor_w13)[0:1])
            layer.w2_input_scale.data = torch.nn.Parameter(
                layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])

        layer.w13_input_offset.data = torch.nn.Parameter(
            layer.w13_input_scale.data.repeat(1, expanding_factor_w13)[0:1])
        layer.w2_input_offset.data = torch.nn.Parameter(
            layer.w2_input_scale.data.repeat(1, expanding_factor_w2)[0:1])

        if not is_310p():
            layer.w13_weight.data = torch_npu.npu_format_cast(
                layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
            layer.w2_weight.data = torch_npu.npu_format_cast(
                layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ).contiguous()
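

# Note on activation quantization in the W8A8 MoE path: fused_experts() below
# collapses the per-expert input scales with max() so that a single per-tensor
# scale is applied to all routed tokens before the int8 grouped matmuls, and
# the product of weight scale and input scale dequantizes the matmul outputs
# back to the original dtype.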


class AscendC8KVCacheMethod:
    """Int8 (C8) KV cache quantization method for Ascend."""

    def __init__(self) -> None:
        self.antiquant_scale_comb = None

    @staticmethod
    def create_weights(layer) -> None:
        param_dict = {}
        param_dict["key_antiquant_scale"] = torch.empty(layer.num_kv_heads *
                                                        layer.head_size,
                                                        dtype=torch.float16,
                                                        requires_grad=False)
        param_dict["value_antiquant_scale"] = torch.empty(layer.num_kv_heads *
                                                          layer.head_size,
                                                          dtype=torch.float16,
                                                          requires_grad=False)
        for weight_name, weight_param in param_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            layer.register_parameter(weight_name, param)

    def process_weights_after_loading(self, layer):
        self.antiquant_scale_comb = torch.cat(
            (layer.key_antiquant_scale.data.unsqueeze(0),
             layer.value_antiquant_scale.data.unsqueeze(0)),
            dim=0).to(torch.float16).contiguous()

    def apply(self, layer, query, key, value, kv_cache, attn_metadata,
              attn_type, scale, output) -> torch.Tensor:
        num_tokens = query.shape[0]
        if attn_metadata is None:
            return output.view(num_tokens, layer.num_heads * layer.head_size)
        assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
        if attn_type != AttentionType.DECODER:
            raise NotImplementedError("Encoder self-attention and "
                                      "encoder/decoder cross-attention "
                                      "are not implemented for "
                                      "AscendC8KVCacheMethod")

        # Quantize the incoming keys/values to int8 with the loaded antiquant
        # scales before they are written into the paged KV cache.
        quant_key = quant_per_tensor(
            key.view(-1, layer.num_kv_heads * layer.head_size),
            layer.key_antiquant_scale.data.view(-1), None, True)
        quant_value = quant_per_tensor(
            value.view(-1, layer.num_kv_heads * layer.head_size),
            layer.value_antiquant_scale.data.view(-1), None, True)

        query = query.view(-1, layer.num_heads, layer.head_size)
        key = key.view(-1, layer.num_kv_heads, layer.head_size)
        value = value.view(-1, layer.num_kv_heads, layer.head_size)

        value = value.contiguous()

        if kv_cache[0].numel() > 0:
            key_cache, value_cache = kv_cache[0], kv_cache[1]
            slots = attn_metadata.slot_mapping

            # Convert flat slot ids into (block_index, in_block_offset) pairs
            # for the scatter update into the paged cache.
            block_size = key_cache.shape[1]
            slots_indices = slots.reshape(-1, 1)
            block_indices = slots_indices // block_size
            slots_indices = slots_indices % block_size
            indices = torch.cat((block_indices, slots_indices), dim=1)

            torch_npu.npu_scatter_nd_update_(key_cache, indices, quant_key)
            torch_npu.npu_scatter_nd_update_(value_cache, indices, quant_value)

        if attn_metadata.attn_state == AscendAttentionState.PrefillNoCache:
            assert attn_metadata is not None
            assert attn_metadata.attn_mask is not None
            mask = attn_metadata.attn_mask
            torch_npu._npu_flash_attention(query=query,
                                           key=key,
                                           value=value,
                                           mask=mask,
                                           seq_len=attn_metadata.seq_lens,
                                           scale_value=scale,
                                           num_heads=layer.num_heads,
                                           num_kv_heads=layer.num_kv_heads,
                                           out=output.reshape(query.shape))
        elif attn_metadata.attn_state == AscendAttentionState.PrefillCacheHit:
            raise NotImplementedError("kv cache int8 is not "
                                      "implemented for "
                                      "PrefillCacheHit")
        elif attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
            if hasattr(attn_metadata, "decode"):
                decode_meta = attn_metadata.decode
                seq_lens = decode_meta.seq_lens_list
            else:
                seq_lens = attn_metadata.seq_lens
            block_size = key_cache.shape[1]
            query = query.view(num_tokens, 1, layer.num_heads *
                               layer.head_size).contiguous()

            # Attend directly over the int8 paged caches; the kernel
            # dequantizes with the combined antiquant scales.
            key = key_cache
            value = value_cache

            output = torch_npu.npu_incre_flash_attention(
                query,
                key,
                value,
                num_key_value_heads=layer.num_kv_heads,
                num_heads=layer.num_heads,
                actual_seq_lengths=seq_lens,
                scale_value=scale,
                input_layout='BSH',
                block_size=block_size,
                block_table=attn_metadata.block_tables,
                antiquant_scale=self.antiquant_scale_comb,
            )
        else:
            raise NotImplementedError("kv cache int8 is not "
                                      "implemented for "
                                      "other cases")
        return output
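

# Rough sketch of the int8 KV-cache flow in AscendC8KVCacheMethod.apply():
# the per-head antiquant scales loaded from the checkpoint quantize incoming
# key/value tokens before they are scattered into the paged cache, and the
# same scales (concatenated as antiquant_scale_comb) let
# npu_incre_flash_attention dequantize the cached int8 values on the fly
# during decode.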


def fused_experts_310p(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w1_scale: torch.Tensor,
        w1_input_scale: torch.Tensor,
        w2: torch.Tensor,
        w2_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        top_k: int,
        global_num_experts: int,
        expert_map: torch.Tensor = None,
) -> torch.Tensor:
    ep_size = get_ep_group().world_size
    local_num_experts = global_num_experts // ep_size
    local_num_group = top_k // ep_size

    bsz, _ = hidden_states.shape
    # Sort the (token, expert) pairs by expert id so that tokens routed to the
    # same expert are contiguous for the grouped matmul.
    flatten_topk_ids = topk_ids.view(-1)
    sorted_topk_ids = torch.argsort(flatten_topk_ids.float())
    sorted_topk_ids = sorted_topk_ids.to(torch.int32)
    sorted_hidden_states = hidden_states.index_select(
        0, sorted_topk_ids // local_num_group)

    experts_id = torch.arange(0,
                              local_num_experts,
                              dtype=topk_ids.dtype,
                              device=topk_ids.device)
    num_tokens_per_expert = (flatten_topk_ids.unsqueeze(-1) == experts_id).to(
        torch.float32).sum(0)
    topk_scales = topk_weights.view(-1).index_select(
        0, sorted_topk_ids).unsqueeze(-1)
    group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)

    gate_up_out = torch_npu.npu_quant_grouped_matmul_dequant(
        x=sorted_hidden_states,
        quantized_weight=w1,
        weight_scale=w1_scale,
        group_list=group_list,
        x_scale=w1_input_scale,
        quant_mode="pertensor")

    gate_up_out = torch_npu.npu_swiglu(gate_up_out.to(torch.float32)).to(
        torch.float16)
    gate_up_out *= topk_scales

    down_out = torch_npu.npu_quant_grouped_matmul_dequant(
        x=gate_up_out,
        quantized_weight=w2,
        weight_scale=w2_scale,
        group_list=group_list,
        x_scale=w2_input_scale,
        quant_mode="pertensor")

    # Restore the original token order and sum over the top-k experts.
    unsorted_topk_ids = torch.argsort(sorted_topk_ids.float()).to(torch.int32)
    unsorted_hidden_states = down_out.index_select(0, unsorted_topk_ids)
    final_hidden_states = unsorted_hidden_states.reshape(
        bsz, top_k // ep_size, -1).sum(1)

    return final_hidden_states


def fused_experts(
        hidden_states: torch.Tensor,
        w1: torch.Tensor,
        w1_scale: torch.Tensor,
        w1_input_scale: torch.Tensor,
        w1_input_offset: torch.Tensor,
        w2: torch.Tensor,
        w2_scale: torch.Tensor,
        w2_input_scale: torch.Tensor,
        w2_input_offset: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        top_k: int,
        global_num_experts: int,
        expert_map: torch.Tensor = None,
) -> torch.Tensor:
    """
    Fused experts with top-k routing.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        w1: Expert weights1 of shape
            (num_experts, intermediate_size * 2, hidden_size).
        w2: Expert weights2 of shape
            (num_experts, hidden_size, intermediate_size).
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).
        top_k: Number of experts to select.
        expert_map: Expert mapping of shape (num_experts,).

    Returns:
        hidden_states: Hidden states after routing.
    """
    # Constraints (kept as documentation; not enforced at runtime):
    #   hidden_states.shape[1] == w1.shape[2]   ("Hidden size mismatch")
    #   topk_weights.shape == topk_ids.shape    ("topk shape mismatch")
    #   hidden_states.is_contiguous()
    #   w1.is_contiguous()
    #   w2.is_contiguous()

    original_dtype = hidden_states.dtype
    ep_size = get_ep_group().world_size
    local_num_experts = global_num_experts // ep_size
    # Per-tensor activation quantization: take the max over the per-expert
    # input scales so one scale covers every routed token.
    w1_input_scale, _ = w1_input_scale.max(0)
    quant_sorted_hidden_states = quant_per_tensor(
        hidden_states,
        w1_input_scale,
        None,
        True,
    )
    if expert_map is not None:
        expanded_x, expanded_row_idx, expert_token_count, expanded_scale = torch_npu.npu_moe_init_routing_v2(
            quant_sorted_hidden_states,
            topk_ids,
            scale=None,
            active_num=topk_ids.numel(),
            expert_capacity=-1,
            expert_num=local_num_experts,
            drop_pad_mode=0,
            expert_tokens_num_type=1,
            expert_tokens_num_flag=True,
            quant_mode=-1,
            active_expert_range=[0, local_num_experts],
            row_idx_type=0,
        )
    else:
        raise NotImplementedError(
            "The quantized version of MoE models "
            "currently does not support tensor parallelism")
    if expanded_x.dtype != w1.dtype:
        w1_input_scale, _ = w1_input_scale.max(0)
        quant_sorted_hidden_states = quant_per_tensor(
            expanded_x,
            w1_input_scale,
            None,
            True,
        )
    else:
        quant_sorted_hidden_states = expanded_x
    gate_up_out = torch_npu.npu_grouped_matmul(
        x=[quant_sorted_hidden_states],
        weight=[w1],
        scale=[w1_scale * w1_input_scale[0]],
        split_item=2,
        group_list_type=1,
        group_type=0,
        group_list=expert_token_count,
        output_dtype=original_dtype,
    )[0]
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    if gate_up_out.dtype != w2.dtype:
        w2_input_scale, _ = w2_input_scale.max(0)
        quant_gate_up_out = quant_per_tensor(
            gate_up_out,
            w2_input_scale,
            None,
            True,
        )
    else:
        quant_gate_up_out = gate_up_out

    down_out = torch_npu.npu_grouped_matmul(
        x=[quant_gate_up_out],
        weight=[w2],
        scale=[w2_scale * w2_input_scale[0]],
        split_item=2,
        group_list_type=1,
        group_type=0,
        group_list=expert_token_count,
        output_dtype=original_dtype,
    )[0]

    if expert_map is not None:
        final_hidden_states = torch_npu.npu_moe_finalize_routing(
            down_out,
            skip1=None,
            skip2=None,
            bias=None,
            scales=topk_weights.to(down_out.dtype),
            expanded_src_to_dst_row=expanded_row_idx,
            export_for_source_row=topk_ids,
            drop_pad_mode=2,
        )
    else:
        raise NotImplementedError(
            "The quantized version of MoE models "
            "currently does not support tensor parallelism")

    return final_hidden_states
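

# Shape flow through fused_experts(), assuming num_tokens tokens, hidden size
# H, top_k routed experts per token and all selected experts local (a
# descriptive sketch, not executed):
#
#   hidden_states             (num_tokens, H)            quantized to int8
#   npu_moe_init_routing_v2   (num_tokens * top_k, H)    grouped by expert
#   npu_grouped_matmul w1     (num_tokens * top_k, 2 * intermediate_size)
#   npu_swiglu                (num_tokens * top_k, intermediate_size)
#   npu_grouped_matmul w2     (num_tokens * top_k, H)
#   npu_moe_finalize_routing  (num_tokens, H)            weighted top_k sum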


def select_experts(
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        top_k: int,
        use_grouped_topk: bool,
        renormalize: bool,
        topk_group: Optional[int] = None,
        num_expert_group: Optional[int] = None,
        custom_routing_function: Optional[Callable] = None,
        scoring_func: str = "softmax",
        e_score_correction_bias: Optional[torch.Tensor] = None,
        global_num_experts=-1,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Select top-k experts based on router logits.

    Args:
        hidden_states: Hidden states of shape (num_tokens, hidden_size).
        router_logits: Router logits of shape (num_tokens, num_experts).
        top_k: Number of experts to select.
        use_grouped_topk: Whether to group experts before selecting top-k.
        renormalize: Whether to renormalize the routing weights.
        topk_group: Number of expert groups to select from.
        num_expert_group: Number of experts in each group.
        custom_routing_function: Custom routing function.
        scoring_func: Scoring function to use.
        e_score_correction_bias: Correction bias to apply to expert scores.

    Returns:
        topk_weights: Routing weights of shape (num_tokens, top_k).
        topk_ids: Selected expert IDs of shape (num_tokens, top_k).

    Raises:
        ValueError: If an unsupported scoring function is provided.
    """

    if scoring_func == "softmax":
        topk_weights = router_logits.softmax(dim=-1)
    elif scoring_func == "sigmoid":
        topk_weights = router_logits.sigmoid()
    else:
        raise ValueError(f"Unsupported scoring function: {scoring_func}")

    if use_grouped_topk:
        assert topk_group is not None
        assert num_expert_group is not None

        if e_score_correction_bias is not None:
            # Routing is done on the biased scores, but the returned weights
            # are gathered from the original, unbiased scores.
            original_weights = topk_weights
            topk_weights = topk_weights + e_score_correction_bias.unsqueeze(0)

        topk_weights = native_grouped_topk(topk_weights, num_expert_group,
                                           topk_group)

        if e_score_correction_bias is not None:
            topk_ids = torch.topk(topk_weights.to(torch.float32),
                                  k=top_k,
                                  dim=-1,
                                  sorted=False)[1]
            topk_weights = original_weights.gather(1, topk_ids)
        else:
            topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
                                                k=top_k,
                                                dim=-1,
                                                sorted=False)
    elif custom_routing_function is None:
        topk_weights, topk_ids = topk_weights.topk(top_k, dim=-1)
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            global_num_experts=global_num_experts,
        )
        # Custom routing functions handle renormalization themselves, so
        # return early after casting the expert ids.
        topk_ids = topk_ids.to(torch.int32)
        return topk_weights, topk_ids

    # Downstream NPU routing kernels expect int32 expert ids.
    topk_ids = topk_ids.to(torch.int32)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
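

# Example invocation of select_experts for grouped top-k routing
# (illustrative shapes and values only; `x`, `logits` and `bias` are assumed
# to exist):
#
#   topk_weights, topk_ids = select_experts(
#       hidden_states=x,               # (num_tokens, hidden_size)
#       router_logits=logits,          # (num_tokens, 256)
#       top_k=8,
#       use_grouped_topk=True,
#       renormalize=True,
#       topk_group=4,
#       num_expert_group=8,
#       scoring_func="sigmoid",
#       e_score_correction_bias=bias,  # (256,)
#   )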


def native_grouped_topk(
    topk_weights: torch.Tensor,
    num_expert_group: Optional[int],
    topk_group: Optional[int],
):
    topk_group = 0 if topk_group is None else topk_group
    num_expert_group = 0 if num_expert_group is None else num_expert_group

    num_token = topk_weights.shape[0]
    grouped_weights = topk_weights.view(num_token, num_expert_group,
                                        -1).max(dim=-1).values
    topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
                                    k=topk_group,
                                    dim=-1,
                                    sorted=False)[1]
    topk_group_mask = torch.zeros_like(grouped_weights)
    topk_group_mask.scatter_(1, topk_group_indices, 1)
    topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
        num_token, num_expert_group,
        topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
    topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)

    return topk_weights
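

# Worked example for native_grouped_topk (hypothetical numbers): with 8
# experts split into num_expert_group=4 groups of 2 and topk_group=2, the
# scores are viewed as (num_token, 4, 2), each group is ranked by its best
# expert, the top 2 groups are kept, and experts in the remaining groups have
# their scores masked to 0.0 so a subsequent top-k cannot pick them.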