|
|
import torch |
|
|
import torch.nn as nn |
|
|
from . import kernels |
|
|
from typing import Optional |
|
|
|
|
|
@torch.library.custom_op("scattermoe::bincount", mutates_args={}) |
|
|
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor: |
|
|
return x.bincount(minlength=minlength) |
|
|
|
|
|
@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
    """Fake implementation of ``compileable_bincount``.

    Allocates an uninitialised ``torch.long`` tensor with ``minlength``
    elements on ``x``'s device; no counting is performed.

    The real ``bincount`` may return more than ``minlength`` bins when ``x``
    holds a value ``>= minlength``; this fake always reports exactly
    ``minlength`` bins, so callers must keep values below ``minlength``.

    Args:
        x: Input tensor; only its device is used.
        minlength: Number of elements in the returned tensor.

    Returns:
        An uninitialised 1-D ``torch.long`` tensor of length ``minlength``.
    """
    return torch.empty(minlength, dtype=torch.long, device=x.device)
|
|
|
|
|
@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
    """Flatten, sort and count expert indices.

    All work is done under ``torch.no_grad()``.

    Args:
        expert_idxs: Tensor of expert indices; it is flattened before use.
            NOTE(review): values are presumably in ``[0, num_experts)`` --
            confirm against callers.
        num_experts: Passed as ``minlength`` to the bincount, i.e. the
            minimum number of bins in the per-expert count.

    Returns:
        A 3-tuple ``(sorted_expert_idxs, sorted_scattered_idxs,
        expert_offsets)``:

        * ``sorted_expert_idxs`` -- the flattened indices in sorted order
          (values returned by ``torch.sort``).
        * ``sorted_scattered_idxs`` -- the positions in the flattened input
          that each sorted value came from (indices returned by
          ``torch.sort``).
        * ``expert_offsets`` -- inclusive cumulative sum of the per-expert
          counts.
    """
    with torch.no_grad():
        flattened_expert_idxs = expert_idxs.flatten()
        sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
        expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
        # Inclusive prefix sum: entry e is the total count for experts 0..e.
        expert_offsets = expert_counts.cumsum(-1)
    return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
|
|
|
|
|
|
|
|
|
|
|
class ParallelLinear(torch.autograd.Function):
    """Autograd function wrapping the ``kernels.ops`` expert-linear kernels.

    ``forward`` calls ``kernels.ops.scatter2scatter`` and optionally combines
    the result with ``gates``; ``backward`` produces gradients for the input,
    the expert weights, the expert biases and the gates.
    NOTE(review): the exact semantics of ``kernels.ops.scatter2scatter``,
    ``kernels.ops.group`` and ``kernels.ops.group_bwd_W`` are not visible
    here -- confirm against the ``kernels`` package.
    """

    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor, expert_weights: torch.Tensor, k: int,
        sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
        expert_offsets: torch.Tensor,
        expert_biases: Optional[torch.Tensor] = None,
        gates: Optional[torch.Tensor] = None,
        grouped_in: bool = False, grouped_out: bool = False,
    ):
        """Run the expert linear layer and optionally mix outputs by ``gates``.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            x: Input tensor, passed to the kernel as ``X``.
            expert_weights: Expert weight tensor, passed to the kernel as
                ``W``; ``backward`` uses ``size(0)`` as the expert count.
            k: Passed to the kernel as ``k``.
                NOTE(review): presumably the number of experts selected per
                input row -- confirm.
            sorted_expert_idxs: Forwarded to the kernel.
            sorted_scattered_idxs: Forwarded to the kernel.
            expert_offsets: Not used in forward; saved for ``backward``.
            expert_biases: Optional bias tensor, passed to the kernel as ``b``.
            gates: Optional 2-D gate tensor. When given, the kernel output is
                viewed as ``(gates.size(0), gates.size(1), -1)`` and reduced
                over dim 1 with ``gates`` as coefficients.
            grouped_in: Passed to the kernel as ``x_grouped``.
            grouped_out: Passed to the kernel as ``y_grouped``.

        Returns:
            The kernel output, gate-combined when ``gates`` is not ``None``.
        """
        with torch.device(x.device):
            output = kernels.ops.scatter2scatter(
                X=x, W=expert_weights,
                b=expert_biases, k=k,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                x_grouped=grouped_in, y_grouped=grouped_out
            )
            if gates is not None:
                # (N, 1, K) @ (N, K, D) -> (N, 1, D) -> (N, D): gate-weighted
                # sum over the K per-row outputs.
                output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
            else:
                output_expanded = None

            ctx.save_for_backward(
                x, expert_weights,
                expert_biases,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                gates,
                output_expanded
            )
            # Non-tensor arguments are kept as plain attributes on ctx.
            ctx.grouped_in = grouped_in
            ctx.grouped_out = grouped_out
            ctx.k = k
        return output

    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        """Compute gradients for ``x``, weights, biases and gates.

        Args:
            ctx: Autograd context populated by ``forward``.
            grad_out: Gradient with respect to the output of ``forward``.

        Returns:
            A 10-tuple with one entry per ``forward`` argument after ``ctx``;
            entries for non-differentiable arguments are ``None``.
        """
        with torch.device(grad_out.device):
            (x, expert_weights, expert_biases,
             sorted_expert_idxs,
             sorted_scattered_idxs,
             expert_offsets,
             gates, output_expanded) = ctx.saved_tensors
            k = ctx.k
            grouped_in = ctx.grouped_in
            grouped_out = ctx.grouped_out

            if gates is not None:
                # Gate gradient: (N, K, D) @ (N, D, 1) -> (N, K).
                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
                gates_flat = gates.flatten()
                gate_fan = gates.size(1)
                # d_gates is already computed, so the flattened view of
                # output_expanded is handed to kernels.ops.group as `out=`.
                # NOTE(review): presumably to reuse it as the output buffer
                # -- confirm.
                grouped_grad_out = output_expanded.flatten(0, 1)
            else:
                d_gates = None
                gates_flat = None
                gate_fan = 1
                grouped_grad_out = None

            if grouped_out:
                grouped_grad_out = grad_out
            else:
                grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
                                                     fan_out=gate_fan, coeff=gates_flat,
                                                     out=grouped_grad_out)
            if grouped_in:
                grouped_x = x
                d_expanded_input = None
            else:
                grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
                # Passed as `out=` to scatter2scatter below, after
                # group_bwd_W has consumed grouped_x.
                # NOTE(review): presumably a buffer reuse -- confirm.
                d_expanded_input = grouped_x

            d_weights, d_biases = kernels.ops.group_bwd_W(
                DY=grouped_grad_out, X=grouped_x,
                expert_offsets=expert_offsets,
                E=expert_weights.size(0),
                has_bias=expert_biases is not None
            )

            # Input gradient: the same kernel as forward, applied to the
            # grouped output gradient with the last two weight dims swapped.
            d_expanded_input = kernels.ops.scatter2scatter(
                X=grouped_grad_out, x_grouped=True,
                W=expert_weights.permute(0, 2, 1),
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                k=1,
                y_grouped=grouped_in,
                out=d_expanded_input
            )

            if k == 1:
                d_input = d_expanded_input
            else:
                # Sum the k per-row contributions back into one row per input.
                d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
        return (
            # x, expert_weights
            d_input, d_weights,
            # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
            None, None, None, None,
            # expert_biases, gates
            d_biases, d_gates,
            # grouped_in, grouped_out
            None, None
        )
|
|
|
|
|
def parallel_linear(inputs, expert_weights, k,
                    sorted_expert_idxs, sorted_scattered_idxs,
                    expert_offsets,
                    expert_biases=None,
                    gates=None, grouped_in=False, grouped_out=False):
    """Functional entry point for ``ParallelLinear``.

    Forwards every argument positionally, in order, to
    ``ParallelLinear.apply`` and returns its result.

    Args:
        inputs: Input tensor (``x`` of ``ParallelLinear.forward``).
        expert_weights: Expert weight tensor.
        k: Forwarded as ``k``.
        sorted_expert_idxs: Forwarded unchanged.
        sorted_scattered_idxs: Forwarded unchanged.
        expert_offsets: Forwarded unchanged.
        expert_biases: Optional expert bias tensor.
        gates: Optional gate tensor.
        grouped_in: Forwarded as ``grouped_in``.
        grouped_out: Forwarded as ``grouped_out``.

    Returns:
        Whatever ``ParallelLinear.apply`` returns.
    """
    return ParallelLinear.apply(
        inputs, expert_weights, k,
        sorted_expert_idxs, sorted_scattered_idxs,
        expert_offsets,
        expert_biases,
        gates, grouped_in, grouped_out,
    )
|
|
|
|
|
class ParallelExperts(nn.Module):
    """Bank of per-expert linear layers evaluated through ``parallel_linear``.

    Parameters are stored as ``weight`` with shape
    ``(num_experts, output_size, input_size)`` and, when ``bias=True``,
    ``bias`` with shape ``(num_experts, output_size)``.
    """

    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
        """Allocate and initialise the expert parameters.

        Args:
            num_experts: Number of experts (leading parameter dimension).
            input_size: Size of the last ``weight`` dimension.
            output_size: Size of the middle ``weight`` dimension and of the
                last ``bias`` dimension.
            bias: When true, create a ``bias`` parameter; otherwise
                ``self.bias`` is ``None``.
        """
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))

        if bias:
            self.bias = nn.Parameter(torch.empty(num_experts, output_size))
        else:
            self.bias = None

        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        # Parameters were allocated uninitialised; fill them now.
        self.reset_parameters()

    def extra_repr(self):
        """Return the size summary shown in the module's ``repr``."""
        return (
            f'num_experts={self.num_experts}, '
            f'input_size={self.input_size}, '
            f'output_size={self.output_size}'
        )

    def reset_parameters(self) -> None:
        """Draw ``weight`` from N(0, 0.02**2) and zero ``bias`` if present."""
        nn.init.normal_(self.weight, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
                expert_offsets,
                gates=None, grouped_in=False, grouped_out=False):
        """Apply the experts to ``inputs`` via ``parallel_linear``.

        The weight is handed over with its last two dimensions swapped, i.e.
        as ``(num_experts, input_size, output_size)``; ``self.bias`` is
        passed as ``expert_biases``. All other arguments are forwarded
        unchanged.

        Returns:
            The result of ``parallel_linear``.
        """
        results = parallel_linear(
            inputs, self.weight.permute(0, 2, 1), k,
            sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
            expert_biases=self.bias,
            gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
        )
        return results
|
|
|