# Hanrui / progress / SpecForge / specforge / layers / ring / ring_flash_attn.py
# Lekr0's picture
# Add files using upload-large-folder tool
# 62dca4c verified
import torch
from yunchang.kernels import AttnType, select_flash_attn_impl
from .utils import RingComm, update_out_and_lse
def ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Run the forward attention pass while rotating K/V through ``RingComm``.

    For each of ``world_size`` steps the local ``q`` is attended against the
    K/V block currently held, and the per-block result is folded into the
    running output via ``update_out_and_lse`` (except for
    ``AttnType.SPARSE_SAGE``, where the latest block result is used as-is).
    The exchange for the next K/V block is issued before the kernel call and
    waited on afterwards.

    Args:
        process_group: Handed to ``RingComm``.
        q, k, v: Query / key / value tensors for this rank.
        softmax_scale: Forwarded to the attention kernel.
        dropout_p: Forwarded to the kernel; also decides ``return_softmax``.
        causal: When truthy, only steps with ``step <= rank`` are computed and
            only step 0 is run with the kernel's causal flag set.
        window_size, softcap, alibi_slopes: Forwarded to the kernel.
        deterministic: Accepted but not used by this function.
        attn_type: Selects the kernel via ``select_flash_attn_impl``.
        attn_processor: Forwarded to ``select_flash_attn_impl``.

    Returns:
        ``(out, lse)`` where ``out`` is cast to ``q.dtype``. Unless
        ``attn_type`` is ``SPARSE_SAGE``, ``lse`` has its last dimension
        squeezed and dimensions 1 and 2 swapped before being returned.
    """
    ring = RingComm(process_group)
    acc_out, acc_lse = None, None
    incoming_k, incoming_v = None, None

    for step in range(ring.world_size):
        more_steps = step + 1 != ring.world_size

        # Start fetching the next K/V block before computing on the current one.
        if more_steps:
            incoming_k = ring.send_recv(k)
            incoming_v = ring.send_recv(v)
            ring.commit()

        if not causal or step <= ring.rank:
            attn_fn = select_flash_attn_impl(
                attn_type, stage="fwd-only", attn_processor=attn_processor
            )
            # Only the first step is run with the kernel's causal flag.
            step_out, step_lse = attn_fn(
                q,
                k,
                v,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal and step == 0,
                window_size=window_size,
                softcap=softcap,
                alibi_slopes=alibi_slopes,
                return_softmax=dropout_p > 0,
            )
            if attn_type != AttnType.SPARSE_SAGE:
                acc_out, acc_lse = update_out_and_lse(
                    acc_out, acc_lse, step_out, step_lse
                )
            else:
                # SPARSE_SAGE keeps the most recent block result unmerged.
                acc_out, acc_lse = step_out, step_lse

        # Swap in the received K/V once the exchange has finished.
        if more_steps:
            ring.wait()
            k, v = incoming_k, incoming_v

    acc_out = acc_out.to(q.dtype)
    if attn_type != AttnType.SPARSE_SAGE:
        acc_lse = acc_lse.squeeze(dim=-1).transpose(1, 2)
    return acc_out, acc_lse
def ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Run the backward attention pass while rotating K/V and dK/dV.

    Two ``RingComm`` instances are used: ``kv_comm`` moves the K/V blocks and
    ``d_kv_comm`` moves the partially accumulated dK/dV. ``dq`` is accumulated
    locally in float32; dK/dV are accumulated in float32 and exchanged after
    every step.

    Args:
        process_group: Handed to both ``RingComm`` instances.
        dout: Gradient of the attention output.
        q, k, v: Tensors saved from the forward pass.
        out, softmax_lse: Forward results, passed to the backward kernel.
        softmax_scale, dropout_p, window_size, softcap, alibi_slopes,
        deterministic: Forwarded to the backward kernel.
        causal: When truthy, only steps with ``step <= rank`` run the kernel
            and only step 0 is run with the kernel's causal flag set.
        attn_type: Selects the kernel via ``select_flash_attn_impl``.

    Returns:
        ``(dq, dk, dv)``, each cast to ``q.dtype``.
    """
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None

    # Scratch buffers passed to the backward kernel on every step.
    block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    for step in range(kv_comm.world_size):
        # Start fetching the next K/V block before computing on the current one.
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step <= kv_comm.rank or not causal:
            bwd_causal = causal and step == 0
            fn = select_flash_attn_impl(attn_type, stage="bwd-only")
            fn(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                block_dq_buffer,
                block_dk_buffer,
                block_dv_buffer,
                dropout_p,
                softmax_scale,
                bwd_causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                rng_state=None,
            )

            if dq is None:
                # First computed block: promote to float32 for accumulation.
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                dq += block_dq_buffer
                # The dK/dV exchange issued last step must finish before use.
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            # Nothing to compute this step; pass the received dK/dV along.
            d_kv_comm.wait()
            dk = next_dk
            dv = next_dv

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        next_dk = d_kv_comm.send_recv(dk)
        next_dv = d_kv_comm.send_recv(dv)
        d_kv_comm.commit()

    d_kv_comm.wait()

    # Cast dq to the input dtype like dk/dv; a hard-coded bfloat16 would round
    # the query gradient through bf16 for fp16/fp32 inputs.
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
class RingFlashAttnFunc(torch.autograd.Function):
    """Autograd function pairing ``ring_flash_attn_forward`` with
    ``ring_flash_attn_backward``."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_type,
        attn_processor,
    ):
        """Compute ring attention and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale`` defaults to ``head_dim ** -0.5`` (taken from
        ``q.shape[-1]``) when ``None``. ``alibi_slopes`` must be ``None``.
        Returns ``out`` alone, or ``(out, softmax_lse, None)`` when
        ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None

        k, v = k.contiguous(), v.contiguous()
        # The forward helper is always called with deterministic=False; the
        # caller's flag is only stored for the backward pass.
        out, softmax_lse = ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=False,
            attn_type=attn_type,
            attn_processor=attn_processor,
        )

        # NOTE(review): an earlier comment here said "this should be
        # out_padded" -- confirm whether a padded output is expected.
        ctx.save_for_backward(q, k, v, out, softmax_lse)

        # Non-tensor settings replayed in backward.
        ctx.group = group
        ctx.attn_type = attn_type
        ctx.attn_processor = attn_processor
        ctx.softmax_scale = softmax_scale
        ctx.dropout_p = dropout_p
        ctx.causal = causal
        ctx.window_size = window_size
        ctx.softcap = softcap
        ctx.alibi_slopes = alibi_slopes
        ctx.deterministic = deterministic

        if return_softmax:
            return out, softmax_lse, None
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients for ``q``, ``k``, ``v`` and ``None`` for the
        eleven remaining forward inputs."""
        q, k, v, out, softmax_lse = ctx.saved_tensors
        grad_q, grad_k, grad_v = ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )
        # One gradient slot per forward input: 3 tensors + 11 non-tensor args.
        return (grad_q, grad_k, grad_v) + (None,) * 11
def ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Ring attention for a packed QKV tensor.

    ``qkv`` is split along its third dimension: index 0 is used as Q, 1 as K
    and 2 as V. All other arguments are forwarded positionally to
    ``RingFlashAttnFunc.apply``.

    Args:
        qkv: Packed tensor indexed as ``qkv[:, :, i]`` for Q/K/V.
        dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes,
        deterministic: Attention settings forwarded unchanged.
        return_attn_probs: Forwarded as ``return_softmax``.
        group: Process group used for the ring exchange.
        attn_type: Attention implementation selector.
        attn_processor: Optional processor forwarded to the forward pass.

    Returns:
        Whatever ``RingFlashAttnFunc.apply`` returns.
    """
    # attn_processor must be passed: RingFlashAttnFunc.forward takes it as a
    # required positional argument.
    return RingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
def ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Ring attention for a separate Q and a packed KV tensor.

    ``kv`` is split along its third dimension: index 0 is used as K and 1 as
    V. All other arguments are forwarded positionally to
    ``RingFlashAttnFunc.apply``.

    Args:
        q: Query tensor.
        kv: Packed tensor indexed as ``kv[:, :, i]`` for K/V.
        dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes,
        deterministic: Attention settings forwarded unchanged.
        return_attn_probs: Forwarded as ``return_softmax``.
        group: Process group used for the ring exchange.
        attn_type: Attention implementation selector.
        attn_processor: Optional processor forwarded to the forward pass.

    Returns:
        Whatever ``RingFlashAttnFunc.apply`` returns.
    """
    # attn_processor must be passed: RingFlashAttnFunc.forward takes it as a
    # required positional argument.
    return RingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
def ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Ring attention for separate Q, K and V tensors.

    Thin keyword-friendly front end: every argument is handed positionally to
    ``RingFlashAttnFunc.apply`` in the order its ``forward`` expects, with
    ``return_attn_probs`` filling the ``return_softmax`` slot.

    Returns:
        Whatever ``RingFlashAttnFunc.apply`` returns.
    """
    # Order must match RingFlashAttnFunc.forward's positional parameters.
    forward_args = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
    return RingFlashAttnFunc.apply(*forward_args)