import torch
import torch.nn as nn
from . import kernels
from typing import Optional
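
# bincount has a data-dependent output when minlength is not fixed, which
# torch.compile cannot trace through directly; wrapping it as a custom op,
# with the fake (meta) implementation registered below, keeps the surrounding
# graph compileable.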
@torch.library.custom_op("scattermoe::bincount", mutates_args={})
def compileable_bincount(x: torch.Tensor, minlength: int) -> torch.Tensor:
    return x.bincount(minlength=minlength)
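
# Fake implementation: reports the output shape and dtype of the custom op
# without executing it, so the compiler can trace with meta tensors.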
@compileable_bincount.register_fake
def _(x: torch.Tensor, minlength: int) -> torch.Tensor:
    return torch.empty(minlength, dtype=torch.long, device=x.device)
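
# Sort the flattened token-to-expert assignments so that tokens routed to the
# same expert become contiguous; expert_offsets marks where each expert's
# segment ends in the sorted order.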
@torch.compile
def flatten_sort_count(expert_idxs: torch.Tensor, num_experts: int):
    with torch.no_grad():
        flattened_expert_idxs = expert_idxs.flatten()
        sorted_expert_idxs, sorted_scattered_idxs = torch.sort(flattened_expert_idxs)
        expert_counts = compileable_bincount(flattened_expert_idxs, minlength=num_experts)
        expert_offsets = expert_counts.cumsum(-1)
    return sorted_expert_idxs, sorted_scattered_idxs, expert_offsets
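
# Autograd wrapper around the fused scatter2scatter kernel: the forward pass
# gathers, transforms, and scatters tokens in one kernel; the backward pass
# recomputes the grouped layout to produce input, weight, bias, and gate
# gradients.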
class ParallelLinear(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x: torch.Tensor, expert_weights: torch.Tensor, k: int,
        sorted_expert_idxs: torch.Tensor, sorted_scattered_idxs: torch.Tensor,
        expert_offsets: torch.Tensor,
        expert_biases: Optional[torch.Tensor] = None,
        gates: Optional[torch.Tensor] = None,
        grouped_in: bool = False, grouped_out: bool = False,
    ):
        with torch.device(x.device):
            output = kernels.ops.scatter2scatter(
                X=x, W=expert_weights,
                b=expert_biases, k=k,
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                x_grouped=grouped_in, y_grouped=grouped_out
            )
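            # If gates are provided, combine each token's k expert outputs:
            # (T, 1, k) @ (T, k, d_out) -> (T, 1, d_out), squeezed to (T, d_out).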
            if gates is not None:
                output_expanded = output.view(gates.size(0), gates.size(1), output.size(-1))
                output = (gates.unsqueeze(1) @ output_expanded).squeeze(1)
            else:
                output_expanded = None
            ctx.save_for_backward(
                x, expert_weights,
                expert_biases,
                sorted_expert_idxs,
                sorted_scattered_idxs,
                expert_offsets,
                gates,
                output_expanded
            )
            ctx.grouped_in = grouped_in
            ctx.grouped_out = grouped_out
            ctx.k = k
        return output
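
    # Backward: regroup the incoming gradient into per-expert contiguous form,
    # compute weight/bias gradients segment by segment (group_bwd_W), then push
    # the gradient back through the expert weights with a transposed
    # scatter2scatter and un-fan the k gradient copies per token.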
    @staticmethod
    def backward(ctx, grad_out: torch.Tensor):
        with torch.device(grad_out.device):
            (x, expert_weights, expert_biases,
             sorted_expert_idxs,
             sorted_scattered_idxs,
             expert_offsets,
             gates, output_expanded) = ctx.saved_tensors
            k = ctx.k
            grouped_in = ctx.grouped_in
            grouped_out = ctx.grouped_out
            if gates is not None:
                # Gate gradient: d_gates[t, i] = <i-th expert output for token t, grad_out[t]>
                d_gates = (output_expanded @ grad_out.unsqueeze(-1)).squeeze(-1)
                gates_flat = gates.flatten()
                gate_fan = gates.size(1)
                grouped_grad_out = output_expanded.flatten(0, 1)  # reuse expanded buffer later
            else:
                d_gates = None
                gates_flat = None
                gate_fan = 1
                grouped_grad_out = None
            if grouped_out:
                grouped_grad_out = grad_out
            else:
                grouped_grad_out = kernels.ops.group(grad_out, sorted_scattered_idxs,
                                                     fan_out=gate_fan, coeff=gates_flat,
                                                     out=grouped_grad_out)
            if grouped_in:
                grouped_x = x
                d_expanded_input = None
            else:
                grouped_x = kernels.ops.group(x, sorted_scattered_idxs, fan_out=k)
                d_expanded_input = grouped_x
            d_weights, d_biases = kernels.ops.group_bwd_W(
                DY=grouped_grad_out, X=grouped_x,
                expert_offsets=expert_offsets,
                E=expert_weights.size(0),
                has_bias=expert_biases is not None
            )
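            # Input gradient: dX = dY @ W^T in grouped order, scattered back to
            # the original token order when the input was not pre-grouped. k=1
            # because grouped_grad_out already has one row per (token, expert).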
            d_expanded_input = kernels.ops.scatter2scatter(
                X=grouped_grad_out, x_grouped=True,
                W=expert_weights.permute(0, 2, 1),
                sorted_expert_idxs=sorted_expert_idxs,
                sorted_scattered_idxs=sorted_scattered_idxs,
                k=1,
                y_grouped=grouped_in,
                out=d_expanded_input  # reuse grouped_x buffer
            )
            if k == 1:
                d_input = d_expanded_input
            else:
                # Sum the k fanned-out gradient rows back into one row per token.
                d_input = d_expanded_input.view(x.size(0), k, d_expanded_input.size(-1)).sum(-2)
        return (
            # x, expert_weights,
            d_input, d_weights,
            # k, sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
            None, None, None, None,
            # expert_biases, gates,
            d_biases, d_gates,
            # grouped_in, grouped_out
            None, None
        )
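

# Thin functional wrapper so callers don't invoke ParallelLinear.apply directly.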
def parallel_linear(inputs, expert_weights, k,
                    sorted_expert_idxs, sorted_scattered_idxs,
                    expert_offsets,
                    expert_biases=None,
                    gates=None, grouped_in=False, grouped_out=False):
    return ParallelLinear.apply(inputs, expert_weights, k,
                                sorted_expert_idxs, sorted_scattered_idxs,
                                expert_offsets,
                                expert_biases,
                                gates, grouped_in, grouped_out)
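

# A bank of num_experts linear layers evaluated with the fused kernels above.
# The weight is stored as (num_experts, output_size, input_size) and permuted
# to (num_experts, input_size, output_size) at call time.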
class ParallelExperts(nn.Module):
    def __init__(self, num_experts, input_size, output_size, bias=False) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.empty(num_experts, output_size, input_size))
        if bias:
            self.bias = nn.Parameter(torch.empty(num_experts, output_size))
        else:
            self.bias = None
        self.num_experts = num_experts
        self.input_size = input_size
        self.output_size = output_size
        self.reset_parameters()

    def extra_repr(self):
        return 'num_experts={}, input_size={}, output_size={}'.format(
            self.num_experts, self.input_size, self.output_size)

    def reset_parameters(self) -> None:
        nn.init.normal_(self.weight, std=0.02)
        if self.bias is not None:
            nn.init.zeros_(self.bias)

    def forward(self, inputs, k, sorted_expert_idxs, sorted_scattered_idxs,
                expert_offsets,
                gates=None, grouped_in=False, grouped_out=False):
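        """
        Args:
            inputs: (num_tokens, input_size) token embeddings, or the grouped
                layout if grouped_in is True.
            k: number of experts each token is routed to.
            sorted_expert_idxs, sorted_scattered_idxs, expert_offsets:
                routing metadata as produced by flatten_sort_count.
            gates: optional (num_tokens, k) combination weights; if given, the
                k expert outputs per token are mixed into a single output row.
        """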
        return parallel_linear(
            inputs, self.weight.permute(0, 2, 1), k,
            sorted_expert_idxs, sorted_scattered_idxs, expert_offsets,
            expert_biases=self.bias,
            gates=gates, grouped_in=grouped_in, grouped_out=grouped_out
        )
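

if __name__ == "__main__":
    # Minimal smoke test (an illustrative sketch, not part of the library API).
    # Assumes a CUDA device and the triton kernels in scattermoe.kernels; run
    # as `python -m scattermoe.parallel_experts`. The random "router" below is
    # hypothetical and stands in for a learned gating network.
    E, d_in, d_out, T, k = 8, 64, 128, 16, 2
    experts = ParallelExperts(E, d_in, d_out).cuda().half()
    x = torch.randn(T, d_in, dtype=torch.float16, device="cuda")
    logits = torch.randn(T, E, device="cuda")
    gates, expert_idxs = torch.topk(logits.softmax(-1), k, dim=-1)
    sorted_expert_idxs, sorted_scattered_idxs, expert_offsets = flatten_sort_count(
        expert_idxs, num_experts=E
    )
    y = experts(x, k, sorted_expert_idxs, sorted_scattered_idxs,
                expert_offsets, gates=gates.half())
    print(y.shape)  # expected: torch.Size([16, 128])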