vidfom's picture
Upload folder using huggingface_hub
31112ad verified
import triton
import triton.language as tl
import os
if os.environ.get('TRITON_AUTOTUNE_ENBALE', '0') == '1':
autotune = triton.autotune
else:
def autotune(*args, **kwargs):
def decorator(func):
return func
return decorator
# Named fixed launch settings for _attn_fwd_gating (tile sizes plus pipeline
# stages / warps). NOTE(review): presumably consumed by the launcher when
# autotuning is disabled and BLOCK_M/BLOCK_N must be passed explicitly — confirm.
configs_gating_preset = dict(
    default=dict(BLOCK_M=64, BLOCK_N=64, num_stages=3, num_warps=8),
)
# Autotuning search space for _attn_fwd_gating: the full cross product of
# tile sizes, pipeline stages and warp counts (2 * 2 * 4 * 2 = 32 configs),
# ordered with BLOCK_M varying slowest and num_warps fastest.
configs_gating = [
    triton.Config(
        {'BLOCK_M': block_m, 'BLOCK_N': block_n},
        num_stages=stages,
        num_warps=warps,
    )
    for block_m in (64, 128)
    for block_n in (32, 64)
    for stages in (2, 3, 4, 5)
    for warps in (4, 8)
]
# Autotune cache key: with TRITON_REEVALUATE_KEY=1 the tuner re-runs whenever
# M or N changes; otherwise the key is empty and the first tuning result is reused.
if os.environ.get('TRITON_REEVALUATE_KEY', '0') == '1':
    gating_reevaluate_keys = ["M", "N"]
else:
    gating_reevaluate_keys = []
@autotune(configs_gating, key=gating_reevaluate_keys)
@triton.jit
def _attn_fwd_gating(
    Q, K, Out,
    stride_qz, stride_qh, stride_qm, stride_qk,
    stride_kz, stride_kh, stride_kn, stride_kk,
    stride_oz, stride_oh, stride_om, stride_on,
    H, M, N,
    HEAD_DIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
    """Write the raw score matrix ``Out = Q @ K^T`` for one (z, h) slice.

    Grid axis 0 selects a BLOCK_M-row tile of Q; grid axis 1 is a flattened
    index split into (z, h) using H. Each program keeps its Q tile loaded and
    sweeps all N columns in BLOCK_N-wide tiles. No scaling, masking or softmax
    is applied; the tl.dot result is cast to Out's element type and stored.

    Q   is addressed as (z, h, m, d) via stride_q*: M rows, HEAD_DIM columns.
    K   is addressed as (z, h, n, d) via stride_k*: N rows, HEAD_DIM columns.
    Out is addressed as (z, h, m, n) via stride_o*: M x N scores.
    """
    tl.static_assert(BLOCK_N <= HEAD_DIM)
    start_m = tl.program_id(0)
    off_hz = tl.program_id(1)
    off_z = off_hz // H
    off_h = off_hz % H
    # 64-bit base offsets so large z/h strides cannot overflow int32.
    q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh
    k_offset = off_z.to(tl.int64) * stride_kz + off_h.to(tl.int64) * stride_kh
    o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh
    # block pointers
    Q_block_ptr = tl.make_block_ptr(
        base=Q + q_offset,
        shape=(M, HEAD_DIM),
        strides=(stride_qm, stride_qk),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, HEAD_DIM),
        order=(1, 0),
    )
    # K is viewed transposed, as (HEAD_DIM, N), so tl.dot(q, k) yields q @ K^T.
    K_block_ptr = tl.make_block_ptr(
        base=K + k_offset,
        shape=(HEAD_DIM, N),
        strides=(stride_kk, stride_kn),
        offsets=(0, 0),
        block_shape=(HEAD_DIM, BLOCK_N),
        order=(0, 1),
    )
    O_block_ptr = tl.make_block_ptr(
        base=Out + o_offset,
        shape=(M, N),
        strides=(stride_om, stride_on),
        offsets=(start_m * BLOCK_M, 0),
        block_shape=(BLOCK_M, BLOCK_N),
        order=(1, 0),
    )
    # Load the Q tile once; it is reused for every column tile below.
    # The loads bounds-check without a padding_option, so out-of-range rows of q
    # (and columns of k) hold unspecified values. Each such value only feeds
    # output entries in the same out-of-range row/column, which the
    # boundary-checked store discards.
    q = tl.load(Q_block_ptr, boundary_check=(0,))
    # The loop index itself is unused: the block pointers are advanced instead.
    for _ in range(0, N, BLOCK_N):
        k = tl.load(K_block_ptr, boundary_check=(1,))
        qk = tl.dot(q, k)
        tl.store(O_block_ptr, qk.to(Out.type.element_ty), boundary_check=(0, 1))
        K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))
        O_block_ptr = tl.advance(O_block_ptr, (0, BLOCK_N))
@triton.jit
def _attn_bwd_preprocess(
    O, DO,
    Delta,  # output
    N_CTX,
    BLOCK_M: tl.constexpr,
    HEAD_DIM: tl.constexpr
):
    """Compute ``Delta[hz, m] = sum_d O[hz, m, d] * DO[hz, m, d]`` for each row.

    O and DO are addressed as densely packed (hz, N_CTX, HEAD_DIM) buffers
    (row stride HEAD_DIM, hz stride N_CTX * HEAD_DIM); Delta is addressed as a
    dense (hz, N_CTX) buffer. NOTE(review): hz is presumably the flattened
    batch*head index, as in _attn_fwd_gating — confirm at the launch site.

    Grid axis 0 tiles the N_CTX rows in BLOCK_M chunks; grid axis 1 is hz.
    Products are accumulated in float32.
    """
    off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
    off_hz = tl.program_id(1)
    off_n = tl.arange(0, HEAD_DIM)
    # Guard the last row tile when N_CTX is not a multiple of BLOCK_M;
    # without this the tail tile reads and writes past the end of each slice.
    mask_m = off_m < N_CTX
    # 64-bit slice base, so hz * N_CTX * HEAD_DIM cannot overflow int32
    # (consistent with the int64 offsets in _attn_fwd_gating).
    off_hz64 = off_hz.to(tl.int64)
    od_offs = off_hz64 * N_CTX * HEAD_DIM + off_m[:, None] * HEAD_DIM + off_n[None, :]
    # load (masked-off rows read as 0 and are never stored)
    o = tl.load(O + od_offs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    do = tl.load(DO + od_offs, mask=mask_m[:, None], other=0.0).to(tl.float32)
    delta = tl.sum(o * do, axis=1)
    # write-back
    tl.store(Delta + off_hz64 * N_CTX + off_m, delta, mask=mask_m)