| |
| |
| |
|
|
| import os |
|
|
| import torch |
| import torch.nn.functional as F |
| from ..quack.gemm_interface import gemm, gemm_dgated, gemm_gated |
|
|
| from ..enums import ActivationType, is_glu |
| from .backward import ( |
| _down_projection_backward_act, |
| _down_projection_backward_weight, |
| _token_broadcast_backward, |
| _topk_softmax_bwd, |
| _up_projection_backward_act, |
| _up_projection_backward_weight, |
| ) |
| from .forward import _down_projection_forward, _router_forward, _topk_softmax_fwd, _up_projection_forward |
| from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton |
|
|
|
|
| class TC_Softmax_Topk_Router_Function(torch.autograd.Function): |
| @staticmethod |
| def forward( |
| ctx, router_logits: torch.Tensor, E: int, K: int, is_softmax_over_topk: bool, norm_topk_probs: bool |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| T = router_logits.size(0) |
|
|
| topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device) |
| topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device) |
|
|
| _topk_softmax_fwd( |
| router_logits, |
| topk_router_score, |
| topk_router_indices, |
| E, |
| K, |
| is_softmax_over_topk=is_softmax_over_topk, |
| norm_topk_probs=norm_topk_probs, |
| ) |
|
|
| |
| |
| ctx.save_for_backward(topk_router_score, topk_router_indices, router_logits) |
| ctx.E = E |
| ctx.dtype = router_logits.dtype |
| ctx.is_softmax_over_topk = is_softmax_over_topk |
| ctx.norm_topk_probs = norm_topk_probs |
|
|
| return topk_router_score, topk_router_indices |
|
|
| @staticmethod |
| def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor): |
| T, K = dtopk_score.size() |
| E = ctx.E |
| topk_router_score, topk_router_indices, router_logits = ctx.saved_tensors |
| dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device) |
|
|
| _topk_softmax_bwd( |
| router_logits, |
| dlogits, |
| None, |
| dtopk_score, |
| topk_router_score, |
| topk_router_indices, |
| E, |
| K, |
| is_softmax_over_topk=ctx.is_softmax_over_topk, |
| norm_topk_probs=ctx.norm_topk_probs, |
| ) |
|
|
| return dlogits, None, None, None, None |
|
|
|
|
| class _UpProjection(torch.autograd.Function): |
| @staticmethod |
| def forward( |
| ctx, |
| x: torch.Tensor, |
| w1: torch.Tensor, |
| b1: torch.Tensor | None, |
| expert_frequency_offset: torch.Tensor, |
| total_expert_freq: int, |
| K: int, |
| x_gather_idx: torch.Tensor, |
| s_scatter_idx: torch.Tensor, |
| s_reverse_scatter_idx: torch.Tensor, |
| num_activated_expert_per_token_offset: torch.Tensor, |
| is_each_token_has_variable_activated_experts: bool, |
| activation_type: ActivationType, |
| is_inference_mode_enabled: bool, |
| concat_layout: bool = False, |
| ) -> torch.Tensor: |
| T, H = x.shape |
| I, H, E = w1.shape |
| is_glu_activation = is_glu(activation_type) |
| if is_glu_activation: |
| I //= 2 |
| TK = total_expert_freq |
|
|
| a = torch.empty(TK, I, dtype=x.dtype, device=x.device) |
| h = ( |
| torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device) |
| if (not is_inference_mode_enabled) |
| else None |
| ) |
|
|
| _up_projection_forward( |
| x=x, |
| w1=w1, |
| h=h, |
| a=a, |
| b1=b1, |
| expert_frequency_offset=expert_frequency_offset, |
| x_gather_idx=x_gather_idx, |
| activation_type=activation_type.value, |
| is_inference_mode_enabled=is_inference_mode_enabled, |
| concat_layout=concat_layout, |
| ) |
|
|
| ctx.T = T |
| ctx.TK = TK |
| ctx.E = E |
| ctx.K = K |
| ctx.H = H |
| ctx.I = I |
| ctx.is_each_token_has_variable_activated_experts = is_each_token_has_variable_activated_experts |
| ctx.is_glu_activation = is_glu_activation |
| ctx.concat_layout = concat_layout |
|
|
| ctx.save_for_backward( |
| x, |
| w1, |
| b1, |
| expert_frequency_offset, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset, |
| ) |
|
|
| ctx.mark_non_differentiable(a) |
| ctx.set_materialize_grads(False) |
|
|
| return a, h |
|
|
| @staticmethod |
| def backward(ctx, _: None, dh: torch.Tensor): |
| T = ctx.T |
| TK = ctx.TK |
| E = ctx.E |
| K = ctx.K |
| H = ctx.H |
| is_glu_activation = ctx.is_glu_activation |
| is_each_token_has_variable_activated_experts = ctx.is_each_token_has_variable_activated_experts |
| concat_layout = ctx.concat_layout |
|
|
| ( |
| x, |
| w1, |
| b1, |
| expert_frequency_offset, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset, |
| ) = ctx.saved_tensors |
|
|
| dx_expanded = torch.empty(TK, H, dtype=dh.dtype, device=dh.device) |
| dw1 = torch.empty_like(w1) |
| db1 = None if b1 is None else torch.empty_like(b1) |
|
|
| _up_projection_backward_act( |
| w1=w1, |
| dx_expanded=dx_expanded, |
| dh=dh, |
| db1=db1, |
| expert_frequency_offset=expert_frequency_offset, |
| is_glu_activation=is_glu_activation, |
| concat_layout=concat_layout, |
| ) |
|
|
| _up_projection_backward_weight( |
| x=x, |
| dw1=dw1, |
| dh=dh, |
| expert_frequency_offset=expert_frequency_offset, |
| x_gather_idx=x_gather_idx, |
| is_glu_activation=is_glu_activation, |
| concat_layout=concat_layout, |
| ) |
|
|
| dx_reduced = torch.empty(T, H, dtype=dh.dtype, device=dh.device) |
|
|
| _token_broadcast_backward( |
| dx_reduced=dx_reduced, |
| dx_expanded=dx_expanded, |
| s_reverse_scatter_idx=s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset=num_activated_expert_per_token_offset, |
| varlen_K_max=(E if is_each_token_has_variable_activated_experts else K), |
| H=H, |
| is_varlen_K=is_each_token_has_variable_activated_experts, |
| ) |
|
|
| return dx_reduced, dw1, db1, *[None] * 13 |
|
|
|
|
| class _DownProjection(torch.autograd.Function): |
| @staticmethod |
| def forward( |
| ctx, |
| a: torch.Tensor, |
| h: torch.Tensor, |
| w2: torch.Tensor, |
| b2: torch.Tensor | None, |
| topk_scores: torch.Tensor, |
| expert_frequency_offset: torch.Tensor, |
| T: int, |
| K: int, |
| x_gather_idx: torch.Tensor, |
| s_scatter_idx: torch.Tensor, |
| s_reverse_scatter_idx: torch.Tensor, |
| num_activated_expert_per_token_offset: torch.Tensor, |
| is_varlen_K: bool, |
| activation_type: ActivationType, |
| ) -> torch.Tensor: |
| TK = a.size(0) |
| H, I, E = w2.shape |
|
|
| y = torch.empty(TK, H, dtype=a.dtype, device=a.device) |
|
|
| _down_projection_forward( |
| w2=w2, |
| a=a, |
| y=y, |
| b2=b2, |
| expert_frequency_offset=expert_frequency_offset, |
| ) |
|
|
| o = torch.empty(T, H, device=a.device, dtype=a.dtype) |
| topk_scores = topk_scores.view(-1) |
|
|
| _router_forward( |
| y=y, |
| o=o, |
| topk_scores=topk_scores, |
| s_reverse_scatter_idx=s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset=num_activated_expert_per_token_offset, |
| varlen_K_max=(E if is_varlen_K else K), |
| H=H, |
| is_varlen_K=is_varlen_K, |
| ) |
|
|
| ctx.T = T |
| ctx.K = K |
| ctx.is_varlen_K = is_varlen_K |
| ctx.activation_type = activation_type |
|
|
| ctx.save_for_backward( |
| h, |
| w2, |
| b2, |
| topk_scores, |
| expert_frequency_offset, |
| x_gather_idx, |
| s_scatter_idx, |
| ) |
|
|
| return o |
|
|
| @staticmethod |
| def backward(ctx, dout: torch.Tensor): |
| T = ctx.T |
| K = ctx.K |
| is_varlen_K = ctx.is_varlen_K |
| activation_type = ctx.activation_type |
|
|
| ( |
| h, |
| w2, |
| b2, |
| topk_scores, |
| expert_frequency_offset, |
| x_gather_idx, |
| s_scatter_idx, |
| ) = ctx.saved_tensors |
|
|
| dw2 = torch.empty_like(w2) |
| db2 = None if b2 is None else torch.empty_like(b2) |
| dh = torch.empty_like(h) |
|
|
| I = w2.size(1) |
| TK = x_gather_idx.size(0) |
|
|
| a_prime = torch.empty(TK, I, dtype=h.dtype, device=h.device) |
| ds = torch.empty_like(topk_scores) |
|
|
| _down_projection_backward_act( |
| dout=dout, |
| h=h, |
| w2=w2, |
| dh=dh, |
| ds=ds, |
| b2=b2, |
| db2=db2, |
| a_prime=a_prime, |
| topk_scores=topk_scores, |
| expert_frequency_offset=expert_frequency_offset, |
| x_gather_idx=x_gather_idx, |
| s_scatter_idx=s_scatter_idx, |
| activation_type=activation_type.value, |
| ) |
|
|
| _down_projection_backward_weight( |
| dout=dout, |
| a_prime=a_prime, |
| dw2=dw2, |
| expert_frequency_offset=expert_frequency_offset, |
| x_gather_idx=x_gather_idx, |
| ) |
|
|
| |
| if not is_varlen_K: |
| ds = ds.view(T, K) |
|
|
| return None, dh, dw2, db2, ds, *[None] * 10 |
|
|
|
|
| def moe_TC_softmax_topk_layer( |
| x: torch.Tensor, |
| router_w: torch.Tensor, |
| w1: torch.Tensor, |
| b1: torch.Tensor | None, |
| w2: torch.Tensor, |
| b2: torch.Tensor | None, |
| K: int, |
| stream_id: int, |
| activation_type: ActivationType | str = ActivationType.SWIGLU, |
| is_inference_mode_enabled: bool = False, |
| is_softmax_over_topk: bool = True, |
| norm_topk_probs: bool = False, |
| concat_layout: bool = False, |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: |
| assert ((b1 is None) and (b2 is None)) or ( |
| (b1 is not None) and (b2 is not None) |
| ), "b1 and b2 has to be None or not None at the same time!" |
| E = router_w.size(0) |
| router_logits = F.linear(x, router_w) |
| topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply( |
| router_logits, E, K, is_softmax_over_topk, norm_topk_probs |
| ) |
|
|
| T, K = topk_indices.size() |
| TK = T * K |
| device = topk_indices.device |
|
|
| s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) |
| s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) |
| expert_frequency = torch.empty(E, dtype=torch.int32, device=device) |
| expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device) |
| x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device) |
|
|
| TC_topk_router_metadata_triton( |
| topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx |
| ) |
|
|
| if type(activation_type) == str: |
| activation_type = ActivationType(activation_type) |
|
|
| assert not torch.compiler.is_compiling() |
| assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet" |
|
|
| a, h = _UpProjection.apply( |
| x, |
| w1, |
| b1, |
| expert_frequency_offset, |
| TK, |
| K, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| None, |
| False, |
| activation_type, |
| is_inference_mode_enabled, |
| concat_layout, |
| ) |
|
|
| o = _DownProjection.apply( |
| a, |
| h, |
| w2, |
| b2, |
| topk_scores, |
| expert_frequency_offset, |
| T, |
| K, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| None, |
| False, |
| activation_type, |
| ) |
|
|
| return o, router_logits, expert_frequency |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| def moe_general_routing_inputs( |
| x: torch.Tensor, |
| router_scores: torch.Tensor, |
| token_indices: torch.Tensor, |
| expert_indices: torch.Tensor, |
| w1: torch.Tensor, |
| b1: torch.Tensor | None, |
| w2: torch.Tensor, |
| b2: torch.Tensor | None, |
| E: int, |
| stream_id: int, |
| activation_type: ActivationType, |
| is_inference_mode_enabled: bool = False, |
| concat_layout: bool = False, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| assert ((b1 is None) and (b2 is None)) or ( |
| (b1 is not None) and (b2 is not None) |
| ), "b1 and b2 has to be None or not None at the same time!" |
|
|
| T = x.size(0) |
| TK = router_scores.size(0) |
| E = w2.size(-1) |
| device = router_scores.device |
|
|
| if router_scores.dtype != torch.float32: |
| router_scores = router_scores.float() |
|
|
| s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) |
| s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) |
| expert_frequency = torch.empty(E, dtype=torch.int32, device=device) |
| expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device) |
| x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device) |
| num_activated_expert_per_token_offset = torch.empty(T + 1, dtype=torch.int32, device=device) |
|
|
| general_routing_router_metadata_triton( |
| token_indices, |
| expert_indices, |
| T, |
| E, |
| expert_frequency, |
| expert_frequency_offset, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset, |
| ) |
|
|
| assert not torch.compiler.is_compiling() |
| assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet" |
|
|
| a, h = _UpProjection.apply( |
| x, |
| w1, |
| b1, |
| expert_frequency_offset, |
| TK, |
| None, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset, |
| True, |
| activation_type, |
| is_inference_mode_enabled, |
| concat_layout, |
| ) |
|
|
| o = _DownProjection.apply( |
| a, |
| h, |
| w2, |
| b2, |
| router_scores, |
| expert_frequency_offset, |
| T, |
| None, |
| x_gather_idx, |
| s_scatter_idx, |
| s_reverse_scatter_idx, |
| num_activated_expert_per_token_offset, |
| True, |
| activation_type, |
| ) |
|
|
| return o, expert_frequency |
|
|