| import torch |
|
|
| from .grouped_gemm.interface import grouped_gemm |
|
|
|
|
| |
def _moe_fused_linear_naive_fwd(
    input: torch.Tensor, weight: torch.Tensor, selected_experts: torch.Tensor
) -> torch.Tensor:
    """
    Reference (row-by-row) forward pass of the MoE linear operation.

    Each row of `input` is multiplied by the weight matrix of the expert selected for that row:
    `output[b, o] = sum_i weight[selected_experts[b], o, i] * input[b, i]`

    Args:
        input (`torch.FloatTensor`): input tensor of shape `(batch_size, in_features)`.
        weight (`torch.FloatTensor`): weight tensor of shape `(num_experts, out_features, in_features)`.
        selected_experts (`torch.LongTensor`): expert index for every row, shape `(batch_size,)`.
            Each element is in the range `[0, num_experts)`.

    Returns:
        output (`torch.FloatTensor`): output tensor of shape `(batch_size, out_features)`,
            allocated on the device and with the dtype of `input`.
    """
    # The unpacking doubles as a rank check: `input` must be 2-D and `weight` 3-D.
    batch_size, _ = input.shape
    _, out_features, _ = weight.shape

    output = torch.empty(batch_size, out_features, dtype=input.dtype, device=input.device)
    for row in range(batch_size):
        # (out_features, in_features) matrix of the expert assigned to this row.
        expert_weight = weight[selected_experts[row]]
        output[row] = torch.matmul(expert_weight, input[row])
    return output
|
|
|
|
| |
def _moe_fused_linear_naive_bwd_input(
    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, selected_experts: torch.Tensor
) -> torch.Tensor:
    """
    Reference (row-by-row) gradient of the MoE linear operation with respect to `input`.

    The operation is defined as:
    `grad_input[b, i] = sum_o grad_output[b, o] * weight[selected_experts[b], o, i]`

    Args:
        grad_output (`torch.Tensor`): gradient of the output, shape `(batch_size, out_features)`.
        input (`torch.Tensor`): forward input of shape `(batch_size, in_features)`. Only its shape,
            dtype and device are used (the result is allocated with `torch.empty_like(input)`).
        weight (`torch.Tensor`): weight tensor of shape `(num_experts, out_features, in_features)`.
        selected_experts (`torch.Tensor`): expert index for every row, shape `(batch_size,)`.

    Returns:
        grad_input (`torch.Tensor`): tensor with the same shape, dtype and device as `input`.
    """
    # The unpacking doubles as a rank check: `input` must be 2-D and `weight` 3-D.
    batch_size, _ = input.shape
    _, _, _ = weight.shape

    grad_input = torch.empty_like(input)
    for row in range(batch_size):
        expert_weight = weight[selected_experts[row]]
        # Vector-matrix product: (out_features,) @ (out_features, in_features) -> (in_features,)
        grad_input[row] = torch.matmul(grad_output[row], expert_weight)
    return grad_input
|
|
|
|
| |
def _moe_fused_linear_naive_bwd_weight(
    grad_output: torch.Tensor, input: torch.Tensor, weight: torch.Tensor, selected_experts: torch.Tensor
) -> torch.Tensor:
    """
    Reference (row-by-row) gradient of the MoE linear operation with respect to `weight`.

    The operation is defined as:
    `grad_weight[e, o, i] = sum_{b : selected_experts[b] == e} grad_output[b, o] * input[b, i]`

    Experts that are not selected by any row keep a zero gradient.

    Args:
        grad_output (`torch.Tensor`): gradient of the output, shape `(batch_size, out_features)`.
        input (`torch.Tensor`): forward input of shape `(batch_size, in_features)`.
        weight (`torch.Tensor`): weight tensor of shape `(num_experts, out_features, in_features)`.
            Only its shape, dtype and device are used (the result starts as `torch.zeros_like(weight)`).
        selected_experts (`torch.Tensor`): expert index for every row, shape `(batch_size,)`.

    Returns:
        grad_weight (`torch.Tensor`): tensor with the same shape, dtype and device as `weight`.
    """
    # The unpacking doubles as a rank check: `input` must be 2-D and `weight` 3-D.
    batch_size, _ = input.shape
    _, _, _ = weight.shape

    grad_weight = torch.zeros_like(weight)
    for row in range(batch_size):
        expert = selected_experts[row]
        # Outer product via broadcasting: (out_features, 1) * (1, in_features).
        outer = grad_output[row].unsqueeze(1) * input[row].unsqueeze(0)
        # Accumulate, because several rows may be routed to the same expert.
        grad_weight[expert] += outer
    return grad_weight
|
|
|
|
def moe_fused_linear(input: torch.Tensor, weight: torch.Tensor, m_sizes: torch.Tensor) -> torch.Tensor:
    """
    Apply the MoE linear operation by delegating to `grouped_gemm`.

    The intended operation is:
    `output[b, o] = sum_i weight[expert_of_row[b], o, i] * input[b, i]`

    NOTE(review): the expert of each row is not passed explicitly; presumably the rows of `input`
    are grouped contiguously by expert, with `m_sizes[e]` consecutive rows belonging to expert `e`.
    Confirm against the `grouped_gemm` implementation.

    Args:
        input (`torch.FloatTensor`): input tensor of shape `(batch_size, in_features)`.
        weight (`torch.FloatTensor`): weight tensor of shape `(num_experts, out_features, in_features)`.
        m_sizes (`torch.LongTensor`): number of rows per expert, shape `(num_experts,)`.
            Should sum to `batch_size`.

    Returns:
        output (`torch.FloatTensor`): output tensor of shape `(batch_size, out_features)`.
    """
    # Arguments are forwarded positionally and unchanged; all of the work happens in the kernel.
    output = grouped_gemm(input, weight, m_sizes)
    return output
|
|