| | import torch |
| | from torch.nn import functional as F |
| |
|
| | |
def moe_matmul(inputs, weight_list, group_index, linear_fn=torch.matmul):
    """Project every token with the expert weight selected by ``group_index``.

    Tokens are bucketed by group, each bucket is projected with its own
    weight in a single ``linear_fn`` call, and the per-group results are
    scattered back to the original token order.

    Args:
        inputs: tensor of shape ``(bs, sl, dim)``.
        weight_list: one weight per expert; with the default ``linear_fn``
            each is a ``(dim, dim')`` matrix.
        group_index: tensor of shape ``(bs, sl)`` holding, for every token,
            the index of the expert in ``weight_list`` to use. Values are
            expected to lie in ``[0, len(weight_list))``; tokens with any
            other value are not validated and receive an arbitrary row.
        linear_fn: callable ``(tokens, weight) -> projected`` applied once
            per group to the ``(n_tokens_in_group, dim)`` slice.

    Returns:
        Tensor of shape ``(bs, sl, dim')`` in the original token order.

    Example (flattened ``group_index``, 17 tokens, 2 experts)::

        group_index : 0 0 0 1 1 1 0 0 1  1  1 0 0 0  1  1  1
        group 0 (8 tokens) occupies slots 0..7 of the concatenated result,
        group 1 (9 tokens) occupies slots 8..16, so the slot of each token is
        encode      : 0 1 2 8 9 10 3 4 11 12 13 5 6 7 14 15 16

    Selecting rows ``encode`` from the concatenated per-group results
    restores the original token order.
    """
    bs, sl = group_index.size()
    num_tokens = bs * sl
    flat_inputs = inputs.reshape(num_tokens, inputs.size(-1))
    flat_groups = group_index.to(torch.int32).reshape(num_tokens)

    group_outputs = []
    encode_index = torch.zeros(
        num_tokens, dtype=torch.int64, device=flat_groups.device)
    # Number of tokens already placed by previous groups. Tracking the count
    # (rather than the max index seen so far) keeps the offset correct when
    # a group receives no tokens at all.
    offset = 0
    for group_i, weight in enumerate(weight_list):
        mask = torch.eq(flat_groups, group_i)
        group_outputs.append(linear_fn(flat_inputs[mask], weight))
        # Zero-based rank of each selected token inside its own group.
        rank = torch.cumsum(mask.to(torch.int64), dim=0) - 1
        # Multiplying by the mask zeroes the entries of other groups so the
        # per-group indices can simply be summed into one table.
        encode_index += (offset + rank) * mask
        offset = offset + mask.sum()

    merged = torch.cat(group_outputs, dim=0)
    outputs = torch.index_select(merged, 0, encode_index)
    # Use the explicit feature size so a zero-token input can be reshaped.
    return outputs.view(bs, sl, merged.size(-1))
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: with identity expert weights the output must equal the
    # input, and the input gradient of the weighted loss below must equal
    # (group_index + 1) for every feature.
    bs, sl, d = 13, 997, 97
    dtype = torch.bfloat16
    # Fall back to CPU so the check can also run on machines without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    group_num = 2

    inputs = torch.randn(
        [bs, sl, d], dtype=dtype, device=device).requires_grad_(True)
    # Sample group ids over all experts so every weight is exercised.
    group_index = torch.randint(0, group_num, (bs, sl), device=device)
    weights = [
        torch.eye(d, device=device).to(dtype).requires_grad_(True)
        for _ in range(group_num)
    ]

    output = moe_matmul(inputs, weights, group_index)
    print(inputs - output)  # expected: all zeros

    token_scale = (group_index + 1).to(dtype)
    loss = torch.sum(output * token_scale.view(bs, sl, 1))
    print(loss)
    loss.backward()
    # d(loss)/d(inputs) is the per-token scale, so this prints zeros.
    print(inputs.grad[:, :, 0] - token_scale)
    print(weights[-1].grad)
| |
|
| |
|