# scattermoe/build/torch-rocm/parallel_experts.py
from typing import Optional

import torch
import torch.nn as nn

from . import kernels


@torch.library.custom_op("scattermoe::bincount", mutates_args={})
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
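    """torch.bincount wrapped as a custom op so torch.compile can trace it."""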
return x.bincount(minlength=minlength)


@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
return torch.empty(minlength, dtype=torch.long, device=x.device)


@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
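    """Sort flattened expert assignments and compute per-expert offsets.

    Returns the sorted expert ids, the permutation that groups token slots by
    expert, and the (inclusive) cumulative token count per expert.

    Example: for expert_idxs = [[2, 0], [1, 2]] with num_experts = 3, the
    flattened ids are [2, 0, 1, 2]; sorting gives expert ids [0, 1, 2, 2] at
    scattered positions [1, 2, 0, 3], and expert_offsets = cumsum([1, 1, 2])
    = [1, 2, 4].
    """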
with torch.no_grad():
flattened_expert_idxs = expert_idxs.flatten()
sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
expert_offsets = expert_counts.cumsum(-1)
return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets


class ParallelLinear(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: torch.Tensor, expert_weights: torch.Tensor, k: int,
sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
expert_offsets: torch.Tensor,
        expert_biases: Optional[torch.Tensor] = None,
        gates: Optional[torch.Tensor] = None,
        grouped_in: bool = False, grouped_out: bool = False,
):
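        """Apply per-expert weights to scattered tokens via scatter2scatter.

        `expert_weights` is stacked as (num_experts, in_features, out_features),
        and the sorted index tensors come from `flatten_sort_count`. When
        `gates` is given, the k expert outputs per token are combined into a
        gate-weighted sum.
        """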
with torch.device(x.device):
output = kernels.ops.scatter2scatter(
X=x, W=expert_weights,
b=expert_biases, k=k,
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
x_grouped=grouped_in, y_grouped=grouped_out
)
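            # Combine the k expert outputs per token with the gate weights:
            # (N, 1, k) @ (N, k, d) -> (N, 1, d) -> (N, d).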
if gates is not None:
output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
else:
output_expanded = None
ctx.save_for_backward(
x, expert_weights,
expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates,
output_expanded
)
ctx.grouped_in = grouped_in
ctx.grouped_out = grouped_out
ctx.k = k
return output

    @staticmethod
def backward(ctx, grad_out: torch.Tensor):
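        """Backpropagate through scatter2scatter.

        Regroups `grad_out` by expert, accumulates per-expert weight and bias
        gradients with group_bwd_W, then maps the gradient back to the input
        layout with a transposed scatter2scatter pass.
        """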
with torch.device(grad_out.device):
(x, expert_weights, expert_biases,
sorted_expert_idxs,
sorted_scattered_idxs,
expert_offsets,
gates, output_expanded) = ctx.saved_tensors
k = ctx.k
grouped_in = ctx.grouped_in
grouped_out = ctx.grouped_out
            if gates is not None:
                # Gate gradient: dot each of the k expert outputs with the
                # incoming gradient, giving one scalar per (token, expert) pair.
                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
                gates_flat = gates.flatten()
                gate_fan = gates.size(1)
                # Reuse the expanded-output buffer for the grouped gradient below.
                grouped_grad_out = output_expanded.flatten(0, 1)
else:
d_gates = None
gates_flat = None
gate_fan = 1
grouped_grad_out = None
if grouped_out:
grouped_grad_out = grad_out
else:
grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
fan_out=gate_fan, coeff=gates_flat,
out=grouped_grad_out)
if grouped_in:
grouped_x = x
d_expanded_input = None
else:
grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
d_expanded_input = grouped_x
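            # Per-expert weight (and bias) gradients, accumulated over each
            # expert's contiguous slice of the grouped tokens.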
d_weights, d_biases = kernels.ops.group_bwd_W(
DY=grouped_grad_out, X=grouped_x,
expert_offsets=expert_offsets,
E=expert_weights.size(0),
has_bias=expert_biases is not None
)
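            # Input gradient: run the grouped output gradient through the
            # transposed expert weights, scattering back to the input layout.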
d_expanded_input = kernels.ops.scatter2scatter(
X=grouped_grad_out, x_grouped=True,
W=expert_weights.permute(0, 2, 1),
sorted_expert_idxs=sorted_expert_idxs,
sorted_scattered_idxs=sorted_scattered_idxs,
k=1,
y_grouped=grouped_in,
out=d_expanded_input # Reuse grouped_x buffer
)
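            # Each token was fanned out to k experts; fold the k gradient rows
            # back into a single row per token.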
if k == 1:
d_input = d_expanded_input
else:
d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
return (
# x, expert_weights,
d_input, d_weights,
# k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
None, None, None, None,
# bias, gates
d_biases, d_gates,
# grouped_in, grouped_out,
None, None
)


def parallel_linear(inputs, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
expert_offsets,
expert_biases=None,
gates=None, grouped_in=False, grouped_out=False):
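    """Functional wrapper around ParallelLinear.apply."""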
results = ParallelLinear.apply(inputs, expert_weights, k,
sorted_expert_idxs, sorted_scattered_idxs,
expert_offsets,
expert_biases,
gates, grouped_in, grouped_out)
return results


class ParallelExperts(nn.Module):
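    """A bank of per-expert linear layers evaluated with ScatterMoE kernels.

    Weights are stored as (num_experts, output_size, input_size) and transposed
    to (num_experts, input_size, output_size) on the fly in forward().
    """
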
def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
super().__init__()
self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
if bias:
self.bias = nn.Parameter(torch.empty(num_experts, output_size))
else:
self.bias = None
self.num_experts = num_experts
self.input_size = input_size
self.output_size = output_size
self.reset_parameters()

    def extra_repr(self):
        return 'num_experts={}, input_size={}, output_size={}, bias={}'.format(
            self.num_experts, self.input_size, self.output_size,
            self.bias is not None)

    def reset_parameters(self) -> None:
nn.init.normal_(self.weight, std=0.02)
if self.bias is not None:
nn.init.zeros_(self.bias)

    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
expert_offsets,
gates=None, grouped_in=False, grouped_out=False):
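        """Compute expert outputs for `inputs` routed according to the sorted
        index tensors produced by `flatten_sort_count`."""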
results = parallel_linear(
inputs, self.weight.permute(0, 2, 1), k,
sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
expert_biases=self.bias,
gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
)
return results
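

# A minimal usage sketch (illustration only, not part of the library): route a
# small batch of tokens to its top-k experts with a random gating matrix and
# run them through ParallelExperts. The router below is hypothetical, and the
# Triton kernels need a GPU build to actually execute.
if __name__ == "__main__":
    num_experts, k, d_in, d_out, n_tokens = 4, 2, 8, 16, 10
    device = torch.device("cuda")
    x = torch.randn(n_tokens, d_in, device=device)
    # Hypothetical router: softmax over random logits, keep the top-k experts.
    logits = torch.randn(n_tokens, num_experts, device=device)
    gates, expert_idxs = torch.topk(logits.softmax(-1), k)
    sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
        expert_idxs, num_experts=num_experts
    )
    experts = ParallelExperts(num_experts, d_in, d_out).to(device)
    y = experts(x, k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
                gates=gates)
    print(y.shape)  # torch.Size([10, 16])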