Commit ·
4311911
1
Parent(s): f784b6e
Make compatible with recent versions of triton
Browse files- flash_attn_triton.py +4 -4
flash_attn_triton.py
CHANGED
|
@@ -188,7 +188,7 @@ def _fwd_kernel(
|
|
| 188 |
(offs_d[None, :] < headdim),
|
| 189 |
other=0.0)
|
| 190 |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 191 |
-
qk += tl.dot(q, k, trans_b=True)
|
| 192 |
# Trying to combine the two masks seem to make the result wrong
|
| 193 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 194 |
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
|
@@ -431,7 +431,7 @@ def _bwd_kernel_one_col_block(
|
|
| 431 |
(offs_d[None, :] < headdim),
|
| 432 |
other=0.0)
|
| 433 |
# recompute p = softmax(qk, dim=-1).T
|
| 434 |
-
qk = tl.dot(q, k, trans_b=True)
|
| 435 |
# Trying to combine the two masks seem to make the result wrong
|
| 436 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 437 |
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|
|
@@ -491,7 +491,7 @@ def _bwd_kernel_one_col_block(
|
|
| 491 |
# else:
|
| 492 |
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
|
| 493 |
# & (offs_d[None, :] < headdim), other=0.0)
|
| 494 |
-
dv += tl.dot(p.to(do.dtype), do, trans_a=True)
|
| 495 |
# compute dp = dot(v, do)
|
| 496 |
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
|
| 497 |
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
|
|
@@ -509,7 +509,7 @@ def _bwd_kernel_one_col_block(
|
|
| 509 |
# for BLOCK_HEADDIM=128
|
| 510 |
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
| 511 |
# compute dk = dot(ds.T, q)
|
| 512 |
-
dk += tl.dot(ds, q, trans_a=True)
|
| 513 |
# compute dq
|
| 514 |
if not ATOMIC_ADD:
|
| 515 |
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
|
|
|
|
| 188 |
(offs_d[None, :] < headdim),
|
| 189 |
other=0.0)
|
| 190 |
qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
|
| 191 |
+
qk += tl.dot(q, tl.trans(k))
|
| 192 |
# Trying to combine the two masks seem to make the result wrong
|
| 193 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 194 |
qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0,
|
|
|
|
| 431 |
(offs_d[None, :] < headdim),
|
| 432 |
other=0.0)
|
| 433 |
# recompute p = softmax(qk, dim=-1).T
|
| 434 |
+
qk = tl.dot(q, tl.trans(k))
|
| 435 |
# Trying to combine the two masks seem to make the result wrong
|
| 436 |
if not EVEN_N: # Need to mask out otherwise the softmax is wrong
|
| 437 |
qk = tl.where(offs_n[None, :] < seqlen_k, qk, float('-inf'))
|
|
|
|
| 491 |
# else:
|
| 492 |
# do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q)
|
| 493 |
# & (offs_d[None, :] < headdim), other=0.0)
|
| 494 |
+
dv += tl.dot(tl.trans(p).to(do.dtype), do)
|
| 495 |
# compute dp = dot(v, do)
|
| 496 |
# There seems to be a race condition when headdim=48/96, and dq, dk are wrong.
|
| 497 |
# Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True
|
|
|
|
| 509 |
# for BLOCK_HEADDIM=128
|
| 510 |
ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype)
|
| 511 |
# compute dk = dot(ds.T, q)
|
| 512 |
+
dk += tl.dot(tl.trans(ds), q)
|
| 513 |
# compute dq
|
| 514 |
if not ATOMIC_ADD:
|
| 515 |
if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M
|