| import torch |
| from yunchang.kernels import AttnType, select_flash_attn_impl |
|
|
| from .utils import RingComm, update_out_and_lse |
|
|
|
|
def ring_flash_attn_forward(
    process_group,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Run the forward pass of ring attention over ``process_group``.

    ``q`` stays local while ``k``/``v`` are rotated through a ``RingComm``. On
    every step except the last, the current ``k``/``v`` are handed to
    ``send_recv`` before the attention kernel is called, and the received
    tensors replace them once ``wait()`` returns.

    Args:
        process_group: Passed to ``RingComm``.
        q: Local query tensor.
        k: Local key tensor; rebound to the received tensor after each step.
        v: Local value tensor; rebound to the received tensor after each step.
        softmax_scale: Forwarded to the attention kernel.
        dropout_p: Forwarded to the kernel; also drives ``return_softmax``,
            which is only requested when ``dropout_p > 0``.
        causal: When true, a step is computed only if ``step <= rank`` and the
            kernel's own ``causal`` flag is set on step 0 only.
        window_size: Forwarded to the attention kernel.
        softcap: Forwarded to the attention kernel.
        alibi_slopes: Forwarded to the attention kernel.
        deterministic: Accepted for signature parity; not read in this
            function.
        attn_type: Kernel selector handed to ``select_flash_attn_impl``.
        attn_processor: Forwarded to ``select_flash_attn_impl``.

    Returns:
        ``(out, lse)``. ``out`` is cast to ``q.dtype``. Unless ``attn_type`` is
        ``AttnType.SPARSE_SAGE``, ``lse`` has its last dimension squeezed and
        dimensions 1 and 2 transposed before being returned.
    """
    ring = RingComm(process_group)
    world_size = ring.world_size
    is_sparse_sage = attn_type == AttnType.SPARSE_SAGE

    acc_out, acc_lse = None, None
    incoming_k, incoming_v = None, None

    for step in range(world_size):
        has_next_step = step + 1 != world_size

        # Launch the k/v exchange first so it can proceed while the kernel runs.
        if has_next_step:
            incoming_k = ring.send_recv(k)
            incoming_v = ring.send_recv(v)
            ring.commit()

        if not causal or step <= ring.rank:
            kernel = select_flash_attn_impl(
                attn_type, stage="fwd-only", attn_processor=attn_processor
            )
            partial_out, partial_lse = kernel(
                q,
                k,
                v,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                # Only the first step uses the kernel's own causal masking.
                causal=causal and step == 0,
                window_size=window_size,
                softcap=softcap,
                alibi_slopes=alibi_slopes,
                return_softmax=dropout_p > 0,
            )
            if is_sparse_sage:
                # SPARSE_SAGE results are taken as-is instead of being merged.
                acc_out, acc_lse = partial_out, partial_lse
            else:
                acc_out, acc_lse = update_out_and_lse(
                    acc_out, acc_lse, partial_out, partial_lse
                )

        # Block until the exchange finishes, then switch to the received k/v.
        if has_next_step:
            ring.wait()
            k, v = incoming_k, incoming_v

    acc_out = acc_out.to(q.dtype)
    if not is_sparse_sage:
        acc_lse = acc_lse.squeeze(dim=-1).transpose(1, 2)
    return acc_out, acc_lse
|
|
|
|
def ring_flash_attn_backward(
    process_group,
    dout,
    q,
    k,
    v,
    out,
    softmax_lse,
    softmax_scale,
    dropout_p=0,
    causal=True,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    attn_type: AttnType = AttnType.FA,
):
    """Run the backward pass of ring attention over ``process_group``.

    Two ``RingComm`` instances are used: ``kv_comm`` rotates ``k``/``v`` and
    ``d_kv_comm`` rotates the partially accumulated ``dk``/``dv``. After a
    step's kernel call the current ``dk``/``dv`` are sent on; on the next step
    the received gradients are added to that step's freshly computed block.

    Args:
        process_group: Passed to both ``RingComm`` instances.
        dout: Gradient of the attention output.
        q: Local query tensor.
        k: Local key tensor; rebound to the received tensor after each step.
        v: Local value tensor; rebound to the received tensor after each step.
        out: Attention output saved from the forward pass.
        softmax_lse: Log-sum-exp saved from the forward pass.
        softmax_scale: Forwarded to the backward kernel.
        dropout_p: Forwarded to the backward kernel.
        causal: When true, a step is computed only if ``step <= rank`` and the
            kernel's causal flag is set on step 0 only.
        window_size: Forwarded to the backward kernel.
        softcap: Forwarded to the backward kernel.
        alibi_slopes: Forwarded to the backward kernel.
        deterministic: Forwarded to the backward kernel.
        attn_type: Kernel selector handed to ``select_flash_attn_impl``.

    Returns:
        ``(dq, dk, dv)``, each cast to ``q.dtype``. ``dk``/``dv`` are the
        tensors received from the final ``d_kv_comm`` exchange.
    """
    kv_comm = RingComm(process_group)
    d_kv_comm = RingComm(process_group)
    dq, dk, dv = None, None, None
    next_dk, next_dv = None, None
    next_k, next_v = None, None

    # Scratch buffers the backward kernel writes each step's gradients into.
    block_dq_buffer = torch.empty(q.shape, dtype=q.dtype, device=q.device)
    block_dk_buffer = torch.empty(k.shape, dtype=k.dtype, device=k.device)
    block_dv_buffer = torch.empty(v.shape, dtype=v.dtype, device=v.device)

    for step in range(kv_comm.world_size):
        # Launch the k/v exchange for the next step before computing.
        if step + 1 != kv_comm.world_size:
            next_k = kv_comm.send_recv(k)
            next_v = kv_comm.send_recv(v)
            kv_comm.commit()

        if step <= kv_comm.rank or not causal:
            bwd_causal = causal and step == 0
            fn = select_flash_attn_impl(attn_type, stage="bwd-only")
            fn(
                dout,
                q,
                k,
                v,
                out,
                softmax_lse,
                block_dq_buffer,
                block_dk_buffer,
                block_dv_buffer,
                dropout_p,
                softmax_scale,
                bwd_causal,
                window_size,
                softcap,
                alibi_slopes,
                deterministic,
                rng_state=None,
            )

            if dq is None:
                # First contribution: start the accumulators in float32.
                dq = block_dq_buffer.to(torch.float32)
                dk = block_dk_buffer.to(torch.float32)
                dv = block_dv_buffer.to(torch.float32)
            else:
                dq += block_dq_buffer
                # The dk/dv sent on the previous step must have arrived
                # before they can be combined with this step's block.
                d_kv_comm.wait()
                dk = block_dk_buffer + next_dk
                dv = block_dv_buffer + next_dv
        elif step != 0:
            # Nothing to compute on this step: just relay the received dk/dv.
            d_kv_comm.wait()
            dk = next_dk
            dv = next_dv

        if step + 1 != kv_comm.world_size:
            kv_comm.wait()
            k = next_k
            v = next_v

        # Pass the accumulated dk/dv along the ring.
        next_dk = d_kv_comm.send_recv(dk)
        next_dv = d_kv_comm.send_recv(dv)
        d_kv_comm.commit()

    d_kv_comm.wait()

    # Cast every gradient back to the input dtype. dq used to be forced to
    # bfloat16, which rounded the query gradient for fp16/fp32 inputs while
    # dk/dv already followed q.dtype.
    return dq.to(q.dtype), next_dk.to(q.dtype), next_dv.to(q.dtype)
|
|
|
|
class RingFlashAttnFunc(torch.autograd.Function):
    """Autograd function tying the ring attention forward and backward together."""

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_softmax,
        group,
        attn_type,
        attn_processor,
    ):
        """Compute ring attention and stash what ``backward`` needs on ``ctx``.

        ``softmax_scale`` falls back to ``q.shape[-1] ** (-0.5)`` when it is
        ``None``. ``alibi_slopes`` must be ``None``. ``k`` and ``v`` are made
        contiguous before being passed on.

        Returns:
            ``out`` alone, or ``(out, softmax_lse, None)`` when
            ``return_softmax`` is truthy.
        """
        if softmax_scale is None:
            softmax_scale = q.shape[-1] ** (-0.5)

        assert alibi_slopes is None
        k, v = k.contiguous(), v.contiguous()

        out, softmax_lse = ring_flash_attn_forward(
            group,
            q,
            k,
            v,
            softmax_scale=softmax_scale,
            dropout_p=dropout_p,
            causal=causal,
            window_size=window_size,
            softcap=softcap,
            alibi_slopes=alibi_slopes,
            deterministic=False,
            attn_type=attn_type,
            attn_processor=attn_processor,
        )

        ctx.save_for_backward(q, k, v, out, softmax_lse)

        # Non-tensor settings are kept as plain attributes on ctx.
        saved_settings = {
            "dropout_p": dropout_p,
            "softmax_scale": softmax_scale,
            "causal": causal,
            "window_size": window_size,
            "softcap": softcap,
            "alibi_slopes": alibi_slopes,
            "deterministic": deterministic,
            "group": group,
            "attn_type": attn_type,
            "attn_processor": attn_processor,
        }
        for attr_name, attr_value in saved_settings.items():
            setattr(ctx, attr_name, attr_value)

        if return_softmax:
            return out, softmax_lse, None
        return out

    @staticmethod
    def backward(ctx, dout, *args):
        """Return gradients for ``q``, ``k``, ``v`` and ``None`` for the rest.

        Only ``dout`` is used; any extra incoming gradients in ``args`` are
        ignored.
        """
        q, k, v, out, softmax_lse = ctx.saved_tensors
        dq, dk, dv = ring_flash_attn_backward(
            ctx.group,
            dout,
            q,
            k,
            v,
            out,
            softmax_lse,
            softmax_scale=ctx.softmax_scale,
            dropout_p=ctx.dropout_p,
            causal=ctx.causal,
            window_size=ctx.window_size,
            softcap=ctx.softcap,
            alibi_slopes=ctx.alibi_slopes,
            deterministic=ctx.deterministic,
            attn_type=ctx.attn_type,
        )
        # forward() takes 14 inputs after ctx; the 11 after q, k, v get no
        # gradient.
        return (dq, dk, dv) + (None,) * 11
|
|
|
|
def ring_flash_attn_qkvpacked_func(
    qkv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Ring attention for a packed ``qkv`` tensor.

    ``q``, ``k`` and ``v`` are taken as ``qkv[:, :, 0]``, ``qkv[:, :, 1]`` and
    ``qkv[:, :, 2]``; all other arguments are forwarded unchanged to
    ``RingFlashAttnFunc``.

    Args:
        qkv: Packed tensor indexed along its third axis for q/k/v.
        dropout_p: Dropout probability forwarded to the kernel.
        softmax_scale: Forwarded to ``RingFlashAttnFunc``.
        causal: Whether to apply causal masking.
        window_size: Forwarded to the kernel.
        softcap: Forwarded to the kernel.
        alibi_slopes: Forwarded to ``RingFlashAttnFunc``.
        deterministic: Forwarded to the backward kernel.
        return_attn_probs: Passed as ``return_softmax`` to
            ``RingFlashAttnFunc``.
        group: Process group used for the ring.
        attn_type: Kernel selector.
        attn_processor: Optional processor forwarded to ``RingFlashAttnFunc``.

    Returns:
        Whatever ``RingFlashAttnFunc.apply`` returns for these arguments.
    """
    # RingFlashAttnFunc.forward takes attn_processor as a required positional
    # argument; leaving it out made every call raise TypeError.
    return RingFlashAttnFunc.apply(
        qkv[:, :, 0],
        qkv[:, :, 1],
        qkv[:, :, 2],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
|
|
|
|
def ring_flash_attn_kvpacked_func(
    q,
    kv,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Ring attention for a separate ``q`` and a packed ``kv`` tensor.

    ``k`` and ``v`` are taken as ``kv[:, :, 0]`` and ``kv[:, :, 1]``; all other
    arguments are forwarded unchanged to ``RingFlashAttnFunc``.

    Args:
        q: Query tensor.
        kv: Packed tensor indexed along its third axis for k/v.
        dropout_p: Dropout probability forwarded to the kernel.
        softmax_scale: Forwarded to ``RingFlashAttnFunc``.
        causal: Whether to apply causal masking.
        window_size: Forwarded to the kernel.
        softcap: Forwarded to the kernel.
        alibi_slopes: Forwarded to ``RingFlashAttnFunc``.
        deterministic: Forwarded to the backward kernel.
        return_attn_probs: Passed as ``return_softmax`` to
            ``RingFlashAttnFunc``.
        group: Process group used for the ring.
        attn_type: Kernel selector.
        attn_processor: Optional processor forwarded to ``RingFlashAttnFunc``.

    Returns:
        Whatever ``RingFlashAttnFunc.apply`` returns for these arguments.
    """
    # RingFlashAttnFunc.forward takes attn_processor as a required positional
    # argument; leaving it out made every call raise TypeError.
    return RingFlashAttnFunc.apply(
        q,
        kv[:, :, 0],
        kv[:, :, 1],
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
|
|
|
|
def ring_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),
    softcap=0.0,
    alibi_slopes=None,
    deterministic=False,
    return_attn_probs=False,
    group=None,
    attn_type: AttnType = AttnType.FA,
    attn_processor=None,
):
    """Ring attention for separate ``q``, ``k`` and ``v`` tensors.

    Every argument is forwarded positionally, in declaration order, to
    ``RingFlashAttnFunc.apply``; ``return_attn_probs`` fills that function's
    ``return_softmax`` slot.

    Returns:
        Whatever ``RingFlashAttnFunc.apply`` returns for these arguments.
    """
    # Order must match RingFlashAttnFunc.forward's parameters after ctx.
    forward_args = (
        q,
        k,
        v,
        dropout_p,
        softmax_scale,
        causal,
        window_size,
        softcap,
        alibi_slopes,
        deterministic,
        return_attn_probs,
        group,
        attn_type,
        attn_processor,
    )
    return RingFlashAttnFunc.apply(*forward_args)
|
|