| import torch |
| import triton |
| from dataclasses import dataclass, field |
| from .routing_details._routing_compute import _combined_routing_compute |
| from .routing_details._routing_compute import _combined_routing_memset |
| from .routing_details._routing_compute import _routing_clear_bitmatrix |
| from .routing_details._expt_data import _expt_data_memset |
| from .routing_details._expt_data import _expt_data_compute |
| from .target_info import is_hip |
|
|
|
|
| @dataclass |
| class GatherIndx: |
| """ |
| Indices for an operation that performs: |
| Y = X[src_idx, :] |
| """ |
| |
| src_indx: torch.Tensor |
| dst_indx: torch.Tensor |
|
|
|
|
| @dataclass |
| class ScatterIndx: |
| """ |
| Indices for an operation that performs: |
| Y[dst_idx, :] = X |
| """ |
| |
| src_indx: torch.Tensor |
| dst_indx: torch.Tensor |
|
|
|
|
| @dataclass |
| class ExptData: |
| |
| hist: torch.Tensor |
| |
| |
| token_offs_raw: torch.Tensor |
| |
| |
| |
| token_offs_pad: dict[int, torch.Tensor] |
| |
| |
| |
| |
| |
| |
| |
| |
| block_pid_map: dict[int, torch.Tensor] |
|
|
| def __post_init__(self): |
| if self.hist is not None: |
| assert self.hist.dtype == torch.int32 |
| if self.token_offs_raw is not None: |
| assert self.token_offs_raw.dtype == torch.int32 |
| if self.token_offs_pad is not None: |
| for v in self.token_offs_pad.values(): |
| assert v.dtype == torch.int32 |
| if self.block_pid_map is not None: |
| for v in self.block_pid_map.values(): |
| assert v.dtype == torch.int32 |
|
|
|
|
| @dataclass |
| class RoutingData: |
| gate_scal: torch.Tensor = field() |
| expt_hist: torch.Tensor = field() |
| n_expts_tot: int = field() |
| n_expts_act: int = field() |
| expt_data: ExptData = None |
|
|
| |
| |
| |
| expected_tokens_per_expt: int = field(default=None) |
|
|
| def n_blocks(self, n_rows, block_m): |
| if n_rows <= self.n_expts_tot: |
| return n_rows |
| else: |
| return triton.cdiv(max(n_rows - self.n_expts_tot + 1, 0), block_m) + self.n_expts_tot - 1 |
|
|
|
|
| |
| |
| |
|
|
|
|
| class SortTokens(torch.autograd.Function): |
|
|
| @staticmethod |
| def forward(ctx, expt_scal, expt_indx, n_expts_tot, bitmatrix): |
| HIST_BLOCK_M = 32 |
| INDX_OFFS_BLOCK_M = 512 |
| MEMSET_BLOCK = 1024 |
| cdiv = triton.cdiv |
|
|
| device = expt_scal.device |
| dtype = expt_scal.dtype |
| n_tokens_raw, _ = bitmatrix.shape |
| n_tokens_pad, n_expts_act = expt_scal.shape |
| n_gates_pad = n_tokens_pad * n_expts_act |
|
|
| hist, partial_hist = bitmatrix.sum(partials_block_size=HIST_BLOCK_M) |
| hist = hist[:n_expts_tot] |
| assert hist.dtype == torch.int32 |
| |
| expt_offs = torch.empty(n_expts_tot, dtype=torch.int32, device=device) |
| combined_indx = torch.empty(n_gates_pad * 2, dtype=torch.int32, device=device) |
| |
| topk_indx = combined_indx[:n_gates_pad] |
| gate_indx = combined_indx[n_gates_pad:] |
| gate_scal = torch.empty(n_gates_pad, dtype=dtype, device=device) |
|
|
| token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1a, blocks2a, MEMSET_BLOCK_A, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal( |
| hist, n_expts_tot, n_gates_pad) |
|
|
| blocks1b = cdiv(n_gates_pad * 2, MEMSET_BLOCK) + n_expts_tot + 1 |
| blocks2b = cdiv(n_tokens_pad, HIST_BLOCK_M) |
|
|
| _combined_routing_memset[(blocks1a + blocks1b, )]( |
| combined_indx, n_gates_pad * 2, -1, MEMSET_BLOCK, hist, |
| expt_offs, hist.shape[0], n_expts_tot, partial_hist, |
| partial_hist.shape[0], partial_hist.stride(0), partial_hist.stride(1), |
| token_offs_combined, token_offs_combined.stride(0), |
| blocks1a, block_pid_map, |
| block_m_log2_start, SIZES=block_m_num, BLOCK_A=MEMSET_BLOCK_A, |
| BLOCK_N=512, BLOCK_M=INDX_OFFS_BLOCK_M, |
| ) |
|
|
| indx_offs = partial_hist |
|
|
| _combined_routing_compute[(blocks2a + blocks2b, )]( |
| topk_indx, gate_indx, gate_scal, |
| expt_scal, expt_indx, indx_offs, indx_offs.stride(0), indx_offs.stride(1), |
| expt_offs, n_tokens_raw, |
| HIST_BLOCK_M, n_expts_act, |
| hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), |
| block_m_log2_start, block_m_num, HIST2_BLOCK_M, blocks2a, |
| ) |
|
|
| ctx.n_tokens_raw = n_tokens_raw |
| ctx.n_tokens_pad = n_tokens_pad |
| ctx.n_expts_act = n_expts_act |
| ctx.save_for_backward(gate_indx) |
| return hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map |
|
|
| @staticmethod |
| def backward(ctx, _0, _1, _2, dgate_scal, _3, _4, _5): |
| (gate_indx, ) = ctx.saved_tensors |
| dgate_scal = dgate_scal[gate_indx] |
| dgate_scal = dgate_scal.reshape(ctx.n_tokens_pad, ctx.n_expts_act) |
| return dgate_scal, None, None, None |
|
|
|
|
| def sort_tokens(expt_scal, expt_indx, n_expts_tot, bitmatrix): |
| return SortTokens.apply(expt_scal, expt_indx, n_expts_tot, bitmatrix) |
|
|
|
|
| |
| |
| |
|
|
|
|
| class PruneRouting(torch.autograd.Function): |
|
|
| @staticmethod |
| def forward(ctx, expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): |
| from .compaction import compaction |
| n_tokens_pad = expt_scal.shape[0] |
| assert n_expts_tot % simulated_ep == 0 |
| _routing_clear_bitmatrix[(n_tokens_pad, )]( |
| bitmatrix.storage.data, |
| bitmatrix.storage.data.stride(0), |
| bitmatrix.storage.data.stride(1), |
| bitmatrix.storage.data.shape[1], |
| n_expts_tot // simulated_ep, |
| BLOCK_N=512, |
| ) |
| |
| expt_scal, expt_indx = compaction(expt_scal, expt_indx, bitmatrix) |
| n_expts_tot = n_expts_tot // simulated_ep |
| bitmatrix.shape[-1] = n_expts_tot |
| return expt_scal, expt_indx, bitmatrix |
|
|
|
|
| def prune_routing(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep): |
| return PruneRouting.apply(expt_scal, expt_indx, bitmatrix, n_expts_tot, simulated_ep) |
|
|
|
|
| |
| |
| |
|
|
|
|
| def log2_power_of_two(x): |
| assert x > 0 and (x & (x - 1)) == 0, "x must be a power of two" |
| return x.bit_length() - 1 |
|
|
|
|
| block_m_log2_start = 4 |
|
|
|
|
| def _compute_expt_data_internal(expt_hist, n_expts_tot, n_gates): |
|
|
| MEMSET_BLOCK = 512 |
| HIST2_BLOCK_M = 512 |
| device = expt_hist.device |
| n_expts_tot = n_expts_tot |
| cdiv = triton.cdiv |
| |
| block_m_log2_end = 9 if is_hip() else 8 |
| block_m_num = block_m_log2_end - block_m_log2_start |
| if n_gates <= n_expts_tot: |
| max_n_tiles = n_gates |
| else: |
| max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // 2**block_m_log2_start) |
| |
| pad = lambda x: cdiv(x, MEMSET_BLOCK) * MEMSET_BLOCK |
| dtype = torch.int32 |
|
|
| token_offs_combined = torch.empty((block_m_num + 1, pad(n_expts_tot + 1)), dtype=dtype, device=device) |
|
|
| token_offs_raw = token_offs_combined[0][:n_expts_tot + 1] |
| token_offs_pad = token_offs_combined[1:] |
|
|
| block_pid_map = torch.empty((block_m_num, pad(max_n_tiles)), dtype=dtype, device=device) |
| memset_grid = torch.numel(block_pid_map) // MEMSET_BLOCK |
| |
| token_offs_pad = token_offs_pad[:, :n_expts_tot + 1] |
| block_pid_map = block_pid_map[:, :max_n_tiles] |
|
|
| blocks1 = memset_grid + block_m_num + 1 |
| blocks2 = n_expts_tot * block_m_num |
| return token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num |
|
|
|
|
| def _unpack_into_dict(x): |
|
|
| block_m_log2_end = block_m_log2_start + x.shape[0] |
| x = {2**j: x[i, :] for i, j in enumerate(range(block_m_log2_start, block_m_log2_end))} |
| return x |
|
|
|
|
| def compute_expt_data(expt_hist, n_expts_tot, n_gates): |
|
|
| if expt_hist is None: |
| return ExptData(None, None, None, None) |
|
|
| |
| token_offs_combined, token_offs_raw, token_offs_pad, block_pid_map, blocks1, blocks2, MEMSET_BLOCK, HIST2_BLOCK_M, block_m_log2_start, block_m_num = _compute_expt_data_internal( |
| expt_hist, n_expts_tot, n_gates) |
|
|
| _expt_data_memset[(blocks1, )]( |
| expt_hist, n_expts_tot, |
| token_offs_combined, token_offs_combined.stride(0), |
| block_pid_map, |
| block_m_log2_start, SIZES=block_m_num, BLOCK=MEMSET_BLOCK, |
| num_warps=4) |
| _expt_data_compute[(blocks2, )]( |
| expt_hist, token_offs_pad, token_offs_pad.stride(0), block_pid_map, block_pid_map.stride(0), |
| block_m_log2_start, SIZES=block_m_num, BLOCK=HIST2_BLOCK_M, |
| num_warps=4) |
|
|
| token_offs_pad = _unpack_into_dict(token_offs_pad) |
| block_pid_map = _unpack_into_dict(block_pid_map) |
| return ExptData(expt_hist, token_offs_raw, token_offs_pad, block_pid_map) |
|
|
|
|
| |
| |
| |
|
|
|
|
| def routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act): |
| hist, topk_indx, gate_indx, gate_scal, token_offs_raw, token_offs_pad, block_pid_map = sort_tokens( |
| expt_scal, expt_indx, n_expts_tot, bitmatrix) |
| token_offs_pad = _unpack_into_dict(token_offs_pad) |
| block_pid_map = _unpack_into_dict(block_pid_map) |
| expt_data = ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) |
|
|
| |
| gather_indx = GatherIndx(src_indx=topk_indx, dst_indx=gate_indx) |
| scatter_indx = ScatterIndx(src_indx=gate_indx, dst_indx=topk_indx) |
| return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx |
|
|
|
|
| def routing(logits, n_expts_act, sm_first=False, expt_indx=None, simulated_ep=1, n_rows=None): |
| from .topk import topk |
| if sm_first: |
| logits = torch.softmax(logits, dim=-1) |
| expt_scal, expt_indx, bitmatrix = topk(logits, n_expts_act, |
| apply_softmax=not sm_first, y_indx=expt_indx, n_rows=n_rows) |
| n_expts_tot = logits.shape[-1] // simulated_ep |
| |
| if simulated_ep > 1: |
| expt_scal, expt_indx, bitmatrix = prune_routing(expt_scal, expt_indx, bitmatrix, logits.shape[-1], simulated_ep) |
|
|
| return routing_from_bitmatrix(bitmatrix, expt_scal, expt_indx, n_expts_tot, n_expts_act) |
|
|
|
|
| |
| |
| |
|
|
|
|
| def compute_expt_data_torch(hist, n_expts_tot, n_gates): |
| |
| device = hist.device |
| token_offs_raw = torch.cumsum(hist, dim=0) |
| token_offs_raw = torch.cat((torch.zeros(1, device=device), token_offs_raw)) |
| token_offs_raw = token_offs_raw.int() |
| |
| block_ms = [16, 32, 64, 128] |
| if is_hip(): |
| block_ms.append(256) |
| if n_gates <= n_expts_tot: |
| max_n_tiles = n_gates |
| else: |
| |
| |
| max_n_tiles = n_expts_tot - 1 - ((n_expts_tot - n_gates - 1) // min(block_ms)) |
| |
| token_offs_pad = dict() |
| block_pid_map = dict() |
| for block_m in block_ms: |
| n_tiles = (hist + block_m - 1) // block_m |
| token_offs_pad[block_m] = torch.cumsum(n_tiles, dim=0) |
| token_offs_pad[block_m] = torch.cat((torch.zeros(1, device=device), token_offs_pad[block_m])) |
| token_offs_pad[block_m] = token_offs_pad[block_m].int() |
| |
| block_pid_map[block_m] = -torch.ones(max_n_tiles, device=device) |
| for e in range(n_expts_tot): |
| offset = token_offs_pad[block_m][e] |
| for b in range(n_tiles[e]): |
| block_pid_map[block_m][offset + b] = (b << 16) + e |
| block_pid_map[block_m] = block_pid_map[block_m].int() |
| return ExptData(hist, token_offs_raw, token_offs_pad, block_pid_map) |
|
|
|
|
| def routing_torch(logits, n_expts_act, sm_first=False, expt_indx=None, n_rows=None): |
| has_user_provided_indx = expt_indx is not None |
| n_gates_pad = logits.shape[0] * n_expts_act |
|
|
| if n_rows is not None: |
| logits = logits[:n_rows, :] |
|
|
| def topk(vals, k, expt_indx): |
| |
| if has_user_provided_indx: |
| tk_indx = expt_indx |
| else: |
| tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k] |
| tk_indx = tk_indx.long() |
| tk_val = torch.take_along_dim(vals, tk_indx, dim=1) |
| tk_indx = tk_indx.int() |
| return tk_val, tk_indx |
|
|
| _, n_expts_tot = logits.shape |
| if sm_first: |
| logits = torch.softmax(logits, dim=-1) |
| expt_scal, expt_indx = topk(logits, n_expts_act, expt_indx) |
| if not sm_first: |
| expt_scal = torch.softmax(expt_scal, dim=-1) |
| |
| if not has_user_provided_indx: |
| expt_indx, sort_indices = torch.sort(expt_indx, dim=1) |
| expt_scal = torch.gather(expt_scal, 1, sort_indices) |
| |
| expt_scal = expt_scal.reshape(-1) |
| expt_indx = expt_indx.reshape(-1).to(torch.int32) |
| |
| topk_indx = torch.argsort(expt_indx, stable=True) |
| gate_indx = torch.argsort(topk_indx, stable=True) |
| gate_scal = expt_scal[topk_indx] |
| hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1).int() |
| |
| gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int()) |
| scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int()) |
| |
| expt_data = compute_expt_data_torch(hist, n_expts_tot, n_gates_pad) |
| return RoutingData(gate_scal, hist, n_expts_tot, n_expts_act, expt_data), gather_indx, scatter_indx |
|
|