# MTP-120M/fla/ops/nsa/utils.py
# Erland's picture
# Add files using upload-large-folder tool
# 7fdd671 verified
# -*- coding: utf-8 -*-
# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
# Implements argsort based on bitonic sort.
# [What is bitonic sort?](https://en.wikipedia.org/wiki/Bitonic_sorter)
# Code adapted from https://github.com/triton-lang/triton/issues/3698#issuecomment-2067681396
import triton
import triton.language as tl
from fla.ops.utils.op import log2
@triton.jit
def _compare_and_swap(
    x,
    ids,
    flip,
    i: tl.constexpr,
    n_dims: tl.constexpr,
):
    # One compare-and-swap round of the bitonic network over the trailing
    # 2**n_dims elements of `x`, with the same swaps mirrored onto `ids`.
    #
    # `x` is viewed as [n_outer * 2**i, 2, 2**(n_dims - i - 1)]: element k of the
    # "left" half is paired with element k of the "right" half, so partners are
    # 2**(n_dims - i - 1) positions apart. A pair is swapped when
    # `(left > right) != flip`, i.e. flip == 0 leaves the smaller element on the
    # left and flip == 1 leaves the larger one on the left. In this file `flip`
    # is either a scalar constexpr or a 0/1 tensor reshaped to `x.shape`.
    # `ids` is reshaped to the same blocked shape, so it must have as many
    # elements as `x`.
    #
    # Returns the updated (x, ids).
    n_outer: tl.constexpr = x.numel >> n_dims
    shape: tl.constexpr = [n_outer * 2**i, 2, 2**(n_dims - i - 1)]
    # selector along the middle axis: 0 picks the left half, 1 the right half
    mask = tl.arange(0, 2)[None, :, None]

    # Extract the left/right partners from the raw bit patterns instead of the
    # values. Masking by multiplication is only exact for integers: for floats,
    # `inf * 0` is NaN (poisoning the partner's sum) and `-0.0 + 0.0` is `+0.0`,
    # and either would break the XOR swap below, which needs the exact bits.
    idtype = tl.core.get_int_dtype(bitwidth=x.dtype.primitive_bitwidth, signed=True)
    ix = x.to(idtype, bitcast=True)
    iy = tl.reshape(ix, shape)
    ileft = tl.broadcast_to(tl.sum(iy * (1 - mask), 1)[:, None, :], shape).to(idtype)
    iright = tl.broadcast_to(tl.sum(iy * mask, 1)[:, None, :], shape).to(idtype)
    ileft = tl.reshape(ileft, x.shape)
    iright = tl.reshape(iright, x.shape)
    # the comparison itself must use the ordering of the original dtype
    left = ileft.to(x.dtype, bitcast=True)
    right = iright.to(x.dtype, bitcast=True)

    # same extraction for the companion indices (already integers, so exact)
    y_idx = tl.reshape(ids, shape)
    left_idx = tl.broadcast_to(tl.sum(y_idx * (1 - mask), 1)[:, None, :], shape)
    right_idx = tl.broadcast_to(tl.sum(y_idx * mask, 1)[:, None, :], shape)
    left_idx = tl.reshape(left_idx, x.shape).to(y_idx.dtype)
    right_idx = tl.reshape(right_idx, x.shape).to(y_idx.dtype)

    # actual compare-and-swap: every position holds either `left` or `right`,
    # so XOR-ing it with (left ^ right) exchanges the pair without a gather
    cond = (left > right) != flip
    ret = ix ^ tl.where(cond, ileft ^ iright, tl.zeros_like(ix))
    new_ids = ids ^ tl.where(cond, left_idx ^ right_idx, tl.zeros_like(ids))
    return ret.to(x.dtype, bitcast=True), new_ids
@triton.jit
def _bitonic_merge(
    x,
    ids,
    stage: tl.constexpr,
    order: tl.constexpr,
    n_dims: tl.constexpr,
):
    # One bitonic merge stage over the trailing 2**n_dims elements of `x`,
    # permuting `ids` in lockstep.
    #
    # stage: number of compare-and-swap rounds (must be <= n_dims); round r
    #        (r = 0 .. stage-1) compares elements 2**(stage - 1 - r) apart.
    # order: direction of every pairwise compare-and-swap:
    #        0 -> smaller element ends up first (ascending),
    #        1 -> larger element ends up first (descending),
    #        2 -> alternate direction every 2**stage elements (ascending,
    #             descending, ascending, ...), producing the bitonic runs that
    #             the next stage merges.
    #        NOTE: order == 2 requires stage < n_dims, otherwise the flip shape
    #        below contains 2**-1.
    # Returns the updated (x, ids).
    n_outer: tl.constexpr = x.numel >> n_dims
    tl.static_assert(stage <= n_dims)
    # flip selects, per element, whether its sub-sequence is re-arranged in
    # ascending (0) or descending (1) order.
    # if flip = 00000000... then all elements are re-arranged ascendingly at this stage
    # if flip = 00110011... (stage == 1) then consecutive runs of 2 elements are
    # re-arranged in alternating directions at this stage
    if order == 2:
        # pattern of 2**stage zeros followed by 2**stage ones, repeated
        shape: tl.constexpr = [n_outer * 2**(n_dims - 1 - stage), 2, 2**stage]
        flip = tl.reshape(tl.broadcast_to(tl.arange(0, 2)[None, :, None], shape), x.shape)
    else:
        flip = order
    # perform `stage` rounds of `compare-and-swap`, halving the partner
    # distance each round
    for i in tl.static_range(stage):
        x, ids = _compare_and_swap(x, ids, flip, i + (n_dims - stage), n_dims)
    return x, ids
@triton.jit
def argsort(
    x,
    ids,
    dim: tl.constexpr = None,
    descending: tl.constexpr = tl.core.CONSTEXPR_0,
):
    # Sort `x` along its last dimension with a bitonic sorting network and
    # apply the identical permutation to `ids`.
    #
    # x:          tensor to sort. Only the last (most minor) axis is supported.
    #             The network covers 2**log2(len) elements of that axis, so the
    #             length is presumably a power of two — TODO confirm against
    #             `fla.ops.utils.op.log2`.
    # ids:        companion tensor permuted in lockstep with `x` (e.g. the
    #             original positions, which makes this an argsort). It is
    #             reshaped to the same blocked layout as `x`, so it must have
    #             as many elements as `x`.
    # dim:        axis to sort. None means the last axis; any other value must
    #             equal len(x.shape) - 1 or the static_assert fires.
    # descending: falsy (the default, tl.core.CONSTEXPR_0) for ascending order,
    #             1 for descending.
    # Returns (sorted x, correspondingly permuted ids).

    # handle default dimension or check that it is the most minor dim
    _dim: tl.constexpr = len(x.shape) - 1 if dim is None else dim
    tl.static_assert(_dim == len(x.shape) - 1, "only minor dimension is currently supported")
    # iteratively run bitonic merge-sort steps: stages 1 .. n_dims-1 build
    # runs of alternating direction (order 2); the final stage merges the whole
    # axis in the requested direction
    n_dims: tl.constexpr = log2(x.shape[_dim])
    for i in tl.static_range(1, n_dims + 1):
        x, ids = _bitonic_merge(x, ids, i, 2 if i < n_dims else descending, n_dims)
    return x, ids