import torch
import triton
from .topk_details._topk_forward import _topk_forward
from .topk_details._topk_backward import _topk_backward
from .tensor import Tensor, Bitmatrix


def topk_forward(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    if not isinstance(x, Tensor):
        # wrap plain torch tensors so the logical row count (`n_rows`) is
        # tracked separately from the allocated row count
        x_shape = [x.shape[0] if n_rows is None else n_rows, x.shape[1]]
        x_shape_max = [x.shape[0], x.shape[1]]
        x = Tensor(x, shape=x_shape, shape_max=x_shape_max)
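        # e.g. an allocated (1024, 128) buffer with n_rows=1000 becomes a logical
        # (1000, 128) tensor whose backing storage can hold up to 1024 rows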
    cdiv = lambda a, b: (a + b - 1) // b
    BLOCK_M = 32
    BLOCK_N = 32
    BLOCK_S = 128
    assert len(x.shape) == 2
    assert x.shape_max[-1] < 32768  # column indices must fit in int16
    assert dim == 1
    assert return_bitmatrix
    n_rows, n_cols = x.shape
    n_rows_max, _ = x.shape_max
    dev = x.device
    # output buffers: top-k values and their column (expert) indices
    y_vals = torch.empty((n_rows_max, k), dtype=x.dtype, device=dev)
    if y_indx is not None:
        use_provided_indx = True
    else:
        y_indx = torch.empty((n_rows_max, k), dtype=torch.int16, device=dev)
        use_provided_indx = False

    # bitmatrix: one bit per (row, column) pair marking the selected columns;
    # allocated transposed so that each 32-column word is contiguous across
    # rows, then viewed back as [rows, words]
    n_cols_pad = cdiv(n_cols, BLOCK_N) * BLOCK_N
    n_cols_words = n_cols_pad // 32
    bitmatrix = torch.empty((n_cols_words, cdiv(n_rows_max, 32) * 32), dtype=torch.uint32, device=dev)
    bitmatrix = torch.transpose(bitmatrix, 0, 1)[:n_rows_max]
    s_blocks = cdiv(n_cols, BLOCK_S)
    s_cols = s_blocks * BLOCK_S
    scratchpad = torch.empty((s_cols, ), dtype=torch.int32, device=dev)
    pids = max(cdiv(n_rows_max, BLOCK_M), s_blocks)
    _topk_forward[(pids, )](
        x, x.stride(0),
        y_vals, y_indx, y_vals.stride(0), use_provided_indx,
        bitmatrix, bitmatrix.stride(0), bitmatrix.stride(1),
        n_rows, n_cols,
        scratchpad, BLOCK_S, s_blocks,
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N,
        APPLY_SOFTMAX=apply_softmax, N_EXPTS_PAD=n_cols_pad, N_EXPTS_ACT=k,
    )
    bitmatrix_shape = [n_rows, n_cols_words * 32]
    bitmatrix_shape_max = [n_rows_max, None]
    bitmatrix = Bitmatrix(bitmatrix, shape=bitmatrix_shape, shape_max=bitmatrix_shape_max, scratchpad=scratchpad)
    return y_vals, y_indx, bitmatrix
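
# Worked example of the sizing above: with n_cols=100, BLOCK_N=32 gives
# n_cols_pad=128 and n_cols_words=4, so each row of the bitmatrix owns four
# uint32 words (one bit per padded column). A dense selection mask could be
# recovered from the raw uint32 buffer (shape [n_rows, n_cols_words]; called
# `words` below, a hypothetical handle) assuming bits are packed LSB-first
# within each word:
#
#   shifts = torch.arange(32, device=words.device)
#   bits = (words.view(torch.int32).unsqueeze(-1) >> shifts) & 1
#   mask = bits.reshape(words.shape[0], -1)[:, :n_cols].bool()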


def topk_backward(x, y_indx, dy_vals, k, n_rows, apply_softmax):
    assert dy_vals.shape[-1] == k
    n_expts_pad = triton.next_power_of_2(x.shape[-1])
    dx = torch.empty_like(x)
    _topk_backward[(dy_vals.shape[0], )](
        y_indx, y_indx.stride(0), dy_vals, dy_vals.stride(0), x, x.stride(0),
        dx, dx.stride(0),
        x.shape[0], n_rows, x.shape[-1],
        APPLY_SOFTMAX=apply_softmax, N_EXPTS_ACT=k, N_EXPTS_PAD=n_expts_pad,
    )
    return dx


class TopK(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows):
        y_vals, y_indx, bitmatrix = topk_forward(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
        ctx.save_for_backward(x, y_indx)
        ctx.apply_softmax = apply_softmax
        ctx.k = k
        ctx.n_rows = n_rows
        return y_vals, y_indx, bitmatrix

    @staticmethod
    def backward(ctx, dy_vals, _0, _1):
        x, y_indx = ctx.saved_tensors
        dx = topk_backward(x, y_indx, dy_vals, ctx.k, ctx.n_rows, ctx.apply_softmax)
        return dx, None, None, None, None, None, None


def topk(x, k, apply_softmax=True, dim=1, return_bitmatrix=True, y_indx=None, n_rows=None):
    return TopK.apply(x, k, apply_softmax, dim, return_bitmatrix, y_indx, n_rows)
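
# Minimal usage sketch (assumptions: a CUDA device is available; shapes and
# dtypes below are only examples):
#
#   logits = torch.randn(128, 64, device="cuda", dtype=torch.float16, requires_grad=True)
#   vals, idx, bitmatrix = topk(logits, 4)   # per-row top-4 (softmax applied when apply_softmax=True)
#   vals.sum().backward()                    # gradients flow back to `logits` via TopK.backward
#   # vals: (128, 4) float16, idx: (128, 4) int16, bitmatrix: 128 x 64 selection bits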