diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e235d9274d8a8f3fc355f8aa70f464d026ff6c11 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/__init__.py @@ -0,0 +1,11 @@ +__version__ = "2.6.3" + +from flash_attn.flash_attn_interface import ( + flash_attn_func, + flash_attn_kvpacked_func, + flash_attn_qkvpacked_func, + flash_attn_varlen_func, + flash_attn_varlen_kvpacked_func, + flash_attn_varlen_qkvpacked_func, + flash_attn_with_kvcache, +) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/bert_padding.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/bert_padding.py new file mode 100644 index 0000000000000000000000000000000000000000..1d447d3f660e1a6ddd7e7f6fb7d1ae4241bfec73 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/bert_padding.py @@ -0,0 +1,213 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py + +import torch +import torch.nn.functional as F +from einops import rearrange, repeat + + +class IndexFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + # return input[indices] + return torch.gather( + rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim) + ).reshape(-1, *other_shape) + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + grad_output = rearrange(grad_output, "b ... 
-> b (...)") + grad_input = torch.zeros( + [ctx.first_axis_dim, grad_output.shape[1]], + device=grad_output.device, + dtype=grad_output.dtype, + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + # grad_input[indices] = grad_output + grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis = IndexFirstAxis.apply + + +class IndexPutFirstAxis(torch.autograd.Function): + @staticmethod + def forward(ctx, values, indices, first_axis_dim): + ctx.save_for_backward(indices) + assert indices.ndim == 1 + assert values.ndim >= 2 + output = torch.zeros( + first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype + ) + # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing. + output[indices] = values + # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values) + return output + + @staticmethod + def backward(ctx, grad_output): + (indices,) = ctx.saved_tensors + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + grad_values = grad_output[indices] + # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1])) + return grad_values, None, None + + +index_put_first_axis = IndexPutFirstAxis.apply + + +class IndexFirstAxisResidual(torch.autograd.Function): + @staticmethod + def forward(ctx, input, indices): + ctx.save_for_backward(indices) + assert input.ndim >= 2 + ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:] + second_dim = other_shape.numel() + # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing. + output = input[indices] + # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last + # memory format to channel_first. In other words, input might not be contiguous. 
+ # If we don't detach, Pytorch complains about output being a view and is being modified inplace + return output, input.detach() + + @staticmethod + def backward(ctx, grad_output, grad_residual): + (indices,) = ctx.saved_tensors + assert grad_output.ndim >= 2 + other_shape = grad_output.shape[1:] + assert grad_residual.shape[1:] == other_shape + grad_input = grad_residual + # grad_input[indices] += grad_output + indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1))) + indices = indices.expand_as(grad_output) + grad_input.scatter_add_(0, indices, grad_output) + return grad_input.reshape(ctx.first_axis_dim, *other_shape), None + + +index_first_axis_residual = IndexFirstAxisResidual.apply + + +def unpad_input(hidden_states, attention_mask): + """ + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. + max_seqlen_in_batch: int + """ + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. 
+ return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length): + """ + Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model). + The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286). + + For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is: + ``` + [ + [2, 3, 0, 0, 0, 0], + [3, 2, 0, 0, 0, 0], + [6, 0, 0, 0, 0, 0] + ] + ``` + , which refers to the 3D-attention mask: + ``` + [ + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 1, 1, 0, 0], + [0, 0, 1, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 1, 0], + [0, 0, 0, 0, 0, 1] + ], + [ + [1, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0], + [1, 1, 1, 1, 0, 0], + [1, 1, 1, 1, 1, 0], + [1, 1, 1, 1, 1, 1] + ] + ] + ```. + + Arguments: + hidden_states: (batch, seqlen, ...) + attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none. + Return: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence. + cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states. 
+ max_seqlen_in_batch: int + """ + length = attention_mask_in_length.sum(dim=-1) + seqlen = attention_mask_in_length.size(-1) + attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(len(length), seqlen) < length.unsqueeze(1) + real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten() + seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx] + indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) + # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the + # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim + # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to + # index with integer indices. Moreover, torch's index is a bit slower than it needs to be, + # so we write custom forward and backward to make it a bit faster. + return ( + index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices), + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +def pad_input(hidden_states, indices, batch, seqlen): + """ + Arguments: + hidden_states: (total_nnz, ...), where total_nnz = number of tokens in selected in attention_mask. + indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence. + batch: int, batch size for the padded sequence. + seqlen: int, maximum sequence length for the padded sequence. + Return: + hidden_states: (batch, seqlen, ...) + """ + dim = hidden_states.shape[-1] + # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype) + # output[indices] = hidden_states + output = index_put_first_axis(hidden_states, indices, batch * seqlen) + return rearrange(output, "(b s) ... 
-> b s ...", b=batch) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_interface.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..ecb3515c0fd86bf84c14cb232f484d72c7722364 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_interface.py @@ -0,0 +1,1286 @@ +# Copyright (c) 2023, Tri Dao. + +from typing import Optional, Union + +import torch +import torch.nn as nn + +# isort: off +# We need to import the CUDA kernels after importing torch +import flash_attn_2_cuda as flash_attn_cuda + +# isort: on + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +def _get_block_size_n(device, head_dim, is_dropout, is_causal): + # This should match the block sizes in the CUDA kernel + assert head_dim <= 256 + major, minor = torch.cuda.get_device_capability(device) + is_sm8x = major == 8 and minor > 0 # Only include sm86 and sm89, exclude sm80 (A100) + is_sm80 = major == 8 and minor == 0 + is_sm90 = major == 9 and minor == 0 + if head_dim <= 32: + return 128 + if head_dim <= 64: + return 128 if not is_dropout else 64 + elif head_dim <= 96: + return 64 + elif head_dim <= 128: + if is_sm8x: + return 64 if (not is_dropout and is_causal) else 32 + else: + return 64 if not is_dropout else 32 + elif head_dim <= 160: + if is_sm8x: + return 64 + else: + return 32 + elif head_dim <= 192: + return 64 + elif head_dim <= 224: + return 64 + elif head_dim <= 256: + return 64 + + +def _flash_attn_forward( + q, k, v, dropout_p, softmax_scale, causal, window_size, softcap, alibi_slopes, return_softmax +): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd( + q, + k, + v, + None, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + 
None, + ) + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +def _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size=(-1, -1), + softcap=0.0, + alibi_slopes=None, + return_softmax=False, + block_table=None, + leftpad_k=None, + seqused_k=None, +): + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd( + q, + k, + v, + None, + cu_seqlens_q, + cu_seqlens_k, + seqused_k, + leftpad_k, + block_table, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + return_softmax, + None, + ) + # if out.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + return out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state + + +def _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state=None, +): + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_attn_cuda.bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + alibi_slopes, + dropout_p, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + deterministic, + None, + rng_state, + ) + return dq, dk, dv, softmax_d + + +def _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + rng_state=None, +): + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x 
in (dout, q, k, v, out)] + ( + dq, + dk, + dv, + softmax_d, + ) = flash_attn_cuda.varlen_bwd( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + alibi_slopes, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + False, + causal, + window_size[0], + window_size[1], + softcap, + deterministic, + None, + rng_state, + ) + # if dk.isnan().any() or dk.isnan().any() or dv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dq, dk, dv, softmax_d + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dqkv[:, :, 0], + dqkv[:, :, 1], + dqkv[:, :, 2], + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dqkv = dqkv[..., : 
dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + qkv[:, 0], + qkv[:, 1], + qkv[:, 2], + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen = max_seqlen + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens, rng_state = ctx.saved_tensors + qkv_shape = q.shape[:-2] + (3, *q.shape[-2:]) + dqkv = torch.empty(qkv_shape, dtype=q.dtype, device=q.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dqkv[:, 0], + dqkv[:, 1], + dqkv[:, 2], + cu_seqlens, + cu_seqlens, + ctx.max_seqlen, + ctx.max_seqlen, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dqkv = dqkv[..., : dout.shape[-1]] # We could have padded the head dimension + return dqkv, None, None, None, None, None, None, None, None, None, None + + +class 
FlashAttnKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( + q, + kv[:, :, 0], + kv[:, :, 1], + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, :, 0], + dkv[:, :, 1], + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, 
out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + kv[:, 0], + kv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=None, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq = torch.empty_like(q) + kv_shape = k.shape[:-2] + (2, *k.shape[-2:]) + dkv = torch.empty(kv_shape, dtype=k.dtype, device=k.device) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dkv[:, 0], + dkv[:, 1], + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dkv = dkv[..., : dout.shape[-1]] + return dq, dkv, None, None, None, None, None, None, None, None, None, None, None, None + + +class FlashAttnFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_forward( 
+ q, + k, + v, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state) + ctx.dropout_p = dropout_p + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, rng_state = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None + + +class FlashAttnVarlenFunc(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_softmax, + block_table, + ): + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = _flash_attn_varlen_forward( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + alibi_slopes=alibi_slopes, + return_softmax=return_softmax and dropout_p > 0, + block_table=block_table, + ) + ctx.save_for_backward( + q, k, v, out_padded, softmax_lse, 
cu_seqlens_q, cu_seqlens_k, rng_state + ) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.window_size = window_size + ctx.softcap = softcap + ctx.alibi_slopes = alibi_slopes + ctx.deterministic = deterministic + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_varlen_backward( + dout, + q, + k, + v, + out, + softmax_lse, + dq, + dk, + dv, + cu_seqlens_q, + cu_seqlens_k, + ctx.max_seqlen_q, + ctx.max_seqlen_k, + ctx.dropout_p, + ctx.softmax_scale, + ctx.causal, + ctx.window_size, + ctx.softcap, + ctx.alibi_slopes, + ctx.deterministic, + rng_state=rng_state, + ) + dq = dq[..., : dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., : dout.shape[-1]] + dv = dv[..., : dout.shape[-1]] + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None, None, None + + +def flash_attn_qkvpacked_func( + qkv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # <=0.0 means deactivate + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_kvpacked_func and flash_attn_func. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. 
+ + Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) is added to + the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). 
+ """ + return FlashAttnQKVPackedFunc.apply( + qkv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_kvpacked_func( + q, + kv, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + kv: (batch_size, seqlen, 2, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnKVPackedFunc.apply( + q, + kv, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_func( + q, + k, + v, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
+ For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen). 
The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnFunc.apply( + q, + k, + v, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_varlen_qkvpacked_func( + qkv, + cu_seqlens, + max_seqlen, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If Q, K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_varlen_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of Q, K, V. + For multi-query and grouped-query attention (MQA/GQA), please see + flash_attn_varlen_kvpacked_func and flash_attn_varlen_func. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between [i - window_size[0], i + window_size[1]] inclusive. + + Arguments: + qkv: (total, 3, nheads, headdim), where total = total number of tokens in the batch. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into qkv. + max_seqlen: int. Maximum sequence length in the batch. + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). 
If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of (-alibi_slope * |i - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenQKVPackedFunc.apply( + qkv, + cu_seqlens, + max_seqlen, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, +): + """dropout_p should be set to 0.0 during evaluation + If K, V are already stacked into 1 tensor, this function will be faster than + calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation + of the gradients of K, V. 
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
+    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        kv: (total_k, 2, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+ alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenKVPackedFunc.apply( + q, + kv, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + ) + + +def flash_attn_varlen_func( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p=0.0, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + alibi_slopes=None, + deterministic=False, + return_attn_probs=False, + block_table=None, +): + """dropout_p should be set to 0.0 during evaluation + Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
+    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention. Query at position i
+    will only attend to keys between
+    [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
+
+    Arguments:
+        q: (total_q, nheads, headdim), where total_q = total number of query tokens in the batch.
+        k: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        v: (total_k, nheads_k, headdim), where total_k = total number of key tokens in the batch.
+        cu_seqlens_q: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into q.
+        cu_seqlens_k: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths
+            of the sequences in the batch, used to index into kv.
+        max_seqlen_q: int. Maximum query sequence length in the batch.
+        max_seqlen_k: int. Maximum key sequence length in the batch.
+        dropout_p: float. Dropout probability.
+        softmax_scale: float. The scaling of QK^T before applying softmax.
+            Default to 1 / sqrt(headdim).
+        causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
+        window_size: (left, right). If not (-1, -1), implements sliding window local attention.
+        softcap: float. Anything > 0 activates softcapping attention.
+        alibi_slopes: (nheads,) or (batch_size, nheads), fp32.
A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + deterministic: bool. Whether to use the deterministic implementation of the backward pass, + which is slightly slower and uses more memory. The forward pass is always deterministic. + return_attn_probs: bool. Whether to return the attention probabilities. This option is for + testing only. The returned probabilities are not guaranteed to be correct + (they might not have the right scaling). + Return: + out: (total, nheads, headdim). + softmax_lse [optional, if return_attn_probs=True]: (nheads, total_q_seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + S_dmask [optional, if return_attn_probs=True]: (batch_size, nheads, seqlen, seqlen). + The output of softmax (possibly with different scaling). It also encodes the dropout + pattern (negative means that location was dropped, nonnegative means it was kept). + """ + return FlashAttnVarlenFunc.apply( + q, + k, + v, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + dropout_p, + softmax_scale, + causal, + window_size, + softcap, + alibi_slopes, + deterministic, + return_attn_probs, + block_table, + ) + + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + block_table: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + alibi_slopes=None, + num_splits=0, + return_softmax_lse=False, +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. 
This is useful for incremental decoding: you can pass in the cached keys/values from
+    the previous step, and update them with the new keys/values from the current step, and do
+    attention with the updated cache, all in 1 kernel.
+
+    If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
+    For example, the KV cache could be pre-allocated with the max sequence length, and you can use
+    cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
+
+    Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
+    rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
+    and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
+    If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
+    indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
+
+    See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
+
+    Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
+    than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
+    For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attend to head
+    0 of K, V, and head 3, 4, 5 of Q will attend to head 1 of K, V.
+
+    If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
+    For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
+        1 1 1 1 0
+        1 1 1 1 1
+    If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
+        0 0
+        0 0
+        0 0
+        1 0
+        1 1
+    If the row of the mask is all zero, the output will be zero.
+
+    If window_size != (-1, -1), implements sliding window local attention.
Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no block_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a block_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim). Similar to k. + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + block_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). 
+ window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + alibi_slopes: (nheads,) or (batch_size, nheads), fp32. A bias of + (-alibi_slope * |i + seqlen_k - seqlen_q - j|) + is added to the attention score of query i and key j. + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
+ """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + cache_batch_idx = maybe_contiguous(cache_batch_idx) + block_table = maybe_contiguous(block_table) + out, softmax_lse = flash_attn_cuda.fwd_kvcache( + q, + k_cache, + v_cache, + k, + v, + cache_seqlens, + rotary_cos, + rotary_sin, + cache_batch_idx, + cache_leftpad, + block_table, + alibi_slopes, + None, + softmax_scale, + causal, + window_size[0], + window_size[1], + softcap, + rotary_interleaved, + num_splits, + ) + return (out, softmax_lse) if return_softmax_lse else out diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton.py new file mode 100644 index 0000000000000000000000000000000000000000..30420c057adf1916e16403d3f0d02d0e26c8b7a3 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton.py @@ -0,0 +1,1160 @@ +""" +*Experimental* implementation of FlashAttention in Triton. +Tested with triton==2.0.0.dev20221202. +Triton 2.0 has a new backend (MLIR) but seems like it doesn't yet work for head dimensions +other than 64: +https://github.com/openai/triton/blob/d376020f90002757eea3ea9475d4f7cfc2ec5ead/python/triton/ops/flash_attention.py#L207 +We'll update this implementation with the new Triton backend once this is fixed. + +We use the FlashAttention implementation from Phil Tillet a starting point. 
+https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +Changes: +- Implement both causal and non-causal attention. +- Implement both self-attention and cross-attention. +- Support arbitrary seqlens (not just multiples of 128), for both forward and backward. +- Support all head dimensions up to 128 (not just 16, 32, 64, 128), for both forward and backward. +- Support attention bias. +- Speed up the forward pass a bit, and only store the LSE instead of m and l. +- Make the backward for d=128 much faster by reducing register spilling. +- Optionally parallelize the backward pass across seqlen_k, to deal with the case of +small batch size * nheads. + +Caution: +- This is an *experimental* implementation. The forward pass should be quite robust but +I'm not 100% sure that the backward pass doesn't have race conditions (due to the Triton compiler). +- This implementation has only been tested on A100. +- If you plan to use headdim other than 64 and 128, you should test for race conditions +(due to the Triton compiler), as done in tests/test_flash_attn.py +"test_flash_attn_triton_race_condition". I've tested and fixed many race conditions +for different head dimensions (40, 48, 64, 128, 80, 88, 96), but I'm still not 100% confident +that there are none left for other head dimensions. + +Differences between this Triton version and the CUDA version: +- Triton version doesn't support dropout. +- Triton forward is generally faster than CUDA forward, while Triton backward is +generally slower than CUDA backward. Overall Triton forward + backward is slightly slower +than CUDA forward + backward. +- Triton version doesn't support different sequence lengths in a batch (i.e., RaggedTensor/NestedTensor). +- Triton version supports attention bias, while CUDA version doesn't. 
+""" + +import math + +import torch +import triton +import triton.language as tl + + +# Disabling autotune for now, set num_warps=4 if headdim=64 and num_warps=8 if headdim=128 +# @triton.autotune( +# configs=[ +# triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1), +# # This config has a race condition when EVEN_M == False, disabling it for now. +# # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1), +# ], +# key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM'] +# ) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _fwd_kernel( + Q, + K, + V, + Bias, + Out, + Lse, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_ob, + stride_oh, + stride_om, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # off_b = tl.program_id(1) + # off_h = tl.program_id(2) + # off_hb = off_b * nheads + off_h + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # Initialize pointers to Q, K, V + # Adding parenthesis around indexing might use int32 math instead of int64 math? 
+ # https://github.com/openai/triton/issues/741 + # I'm seeing a tiny bit of difference (5-7us) + q_ptrs = ( + Q + off_b * stride_qb + off_h * stride_qh + (offs_m[:, None] * stride_qm + offs_d[None, :]) + ) + k_ptrs = ( + K + off_b * stride_kb + off_h * stride_kh + (offs_n[:, None] * stride_kn + offs_d[None, :]) + ) + v_ptrs = ( + V + off_b * stride_vb + off_h * stride_vh + (offs_n[:, None] * stride_vn + offs_d[None, :]) + ) + if BIAS_TYPE == "vector": + b_ptrs = Bias + off_b * stride_bb + off_h * stride_bh + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = ( + Bias + + off_b * stride_bb + + off_h * stride_bh + + (offs_m[:, None] * stride_bm + offs_n[None, :]) + ) + # initialize pointer to m and l + t_ptrs = TMP + off_hb * seqlen_q_rounded + offs_m + lse_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32) + # load q: it will stay in SRAM throughout + # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call + # tl.load(q_ptrs), we get the wrong output! 
+ if EVEN_M & EVEN_N: + if EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + q = tl.load(q_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), other=0.0 + ) + # loop over k, v and update accumulator + end_n = seqlen_k if not IS_CAUSAL else tl.minimum((start_m + 1) * BLOCK_M, seqlen_k) + for start_n in range(0, end_n, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + k = tl.load(k_ptrs + start_n * stride_kn) + else: + k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + k = tl.load( + k_ptrs + start_n * stride_kn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf")) + if IS_CAUSAL: + qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf")) + if BIAS_TYPE != "none": + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, mask=(start_n + offs_n) < seqlen_k, other=0.0 + ).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs + start_n).to(tl.float32) + else: + bias = tl.load( + b_ptrs + start_n, + mask=(offs_m[:, None] < seqlen_q) + & ((start_n + offs_n)[None, :] < seqlen_k), + 
other=0.0, + ).to(tl.float32) + # Slightly faster to multiply the softmax_scale in the tl.exp below since the compiler + # can then fuse the mult and add into an fma instruction. But if we have bias we need to + # to multiply with softmax_scale here. + qk = qk * softmax_scale + bias + m_ij = tl.maximum(tl.max(qk, 1), lse_i) + p = tl.exp(qk - m_ij[:, None]) + else: + m_ij = tl.maximum(tl.max(qk, 1) * softmax_scale, lse_i) + p = tl.exp(qk * softmax_scale - m_ij[:, None]) + l_ij = tl.sum(p, 1) + + # scale acc_o + acc_o_scale = tl.exp(m_i - m_ij) + + # # -- update output accumulator -- + # BUG: have to store and immediately load + tl.store(t_ptrs, acc_o_scale) + acc_o_scale = tl.load(t_ptrs) + acc_o = acc_o * acc_o_scale[:, None] + # update acc_o + if EVEN_N & EVEN_M: # If we just do "if EVEN_N", there seems to be some race condition + if EVEN_HEADDIM: + v = tl.load(v_ptrs + start_n * stride_vn) + else: + v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=(start_n + offs_n)[:, None] < seqlen_k, + other=0.0, + ) + else: + v = tl.load( + v_ptrs + start_n * stride_vn, + mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), + other=0.0, + ) + p = p.to(v.dtype) + acc_o += tl.dot(p, v) + + # -- update statistics + m_i = m_ij + l_i_new = tl.exp(lse_i - m_ij) + l_ij + lse_i = m_ij + tl.log(l_i_new) + + o_scale = tl.exp(m_i - lse_i) + # BUG: have to store and immediately load + tl.store(t_ptrs, o_scale) + o_scale = tl.load(t_ptrs) + acc_o = acc_o * o_scale[:, None] + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + lse_ptrs = Lse + off_hb * seqlen_q_rounded + offs_m + tl.store(lse_ptrs, lse_i) + # initialize pointers to output + offs_d = tl.arange(0, BLOCK_HEADDIM) + out_ptrs = ( + Out + + off_b * stride_ob + + off_h * stride_oh + + (offs_m[:, 
None] * stride_om + offs_d[None, :]) + ) + if EVEN_M: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o) + else: + tl.store(out_ptrs, acc_o, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(out_ptrs, acc_o, mask=offs_m[:, None] < seqlen_q) + else: + tl.store( + out_ptrs, acc_o, mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim) + ) + + +@triton.jit +def _bwd_preprocess_do_o_dot( + Out, + DO, + Delta, + stride_ob, + stride_oh, + stride_om, + stride_dob, + stride_doh, + stride_dom, + nheads, + seqlen_q, + seqlen_q_rounded, + headdim, + BLOCK_M: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, +): + start_m = tl.program_id(0) + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # load + o = tl.load( + Out + off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + do = tl.load( + DO + + off_b * stride_dob + + off_h * stride_doh + + offs_m[:, None] * stride_dom + + offs_d[None, :], + mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ).to(tl.float32) + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(Delta + off_hb * seqlen_q_rounded + offs_m, delta) + + +@triton.jit +def _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, +): + # [2022-11-01] TD: Same bug. 
In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.store(dv_ptrs), there's a race condition + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + else: + tl.store(dv_ptrs, dv, mask=offs_d[None, :] < headdim) + tl.store(dk_ptrs, dk, mask=offs_d[None, :] < headdim) + else: + if EVEN_HEADDIM: + tl.store(dv_ptrs, dv, mask=offs_n[:, None] < seqlen_k) + tl.store(dk_ptrs, dk, mask=offs_n[:, None] < seqlen_k) + else: + tl.store(dv_ptrs, dv, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + tl.store(dk_ptrs, dk, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim)) + + +@triton.jit +def _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD: tl.constexpr, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # We need to make sure begin_m is a multiple of BLOCK_M (not BLOCK_N) + begin_m = 0 if not IS_CAUSAL else ((start_n * BLOCK_N) // BLOCK_M) * BLOCK_M + # initialize row/col offsets + offs_qm = begin_m + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + offs_m = tl.arange(0, BLOCK_M) + offs_d = tl.arange(0, BLOCK_HEADDIM) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_d[None, :]) + k_ptrs = K + (offs_n[:, None] * stride_kn + offs_d[None, :]) + v_ptrs = V + (offs_n[:, None] * stride_vn + offs_d[None, :]) + do_ptrs = DO + (offs_qm[:, None] * stride_dom + offs_d[None, :]) + dq_ptrs = DQ + (offs_qm[:, None] * stride_dqm + offs_d[None, :]) + if BIAS_TYPE == "vector": + b_ptrs = Bias + offs_n + elif BIAS_TYPE == "matrix": + b_ptrs = Bias + (offs_qm[:, 
None] * stride_bm + offs_n[None, :]) + # initialize dv and dk + dv = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + dk = tl.zeros([BLOCK_N, BLOCK_HEADDIM], dtype=tl.float32) + # There seems to be some problem with Triton pipelining that makes results wrong for + # headdim=64, seqlen=(113, 255), bias_type='matrix'. In this case the for loop + # may have zero step, and pipelining with the bias matrix could screw it up. + # So we just exit early. + if begin_m >= seqlen_q: + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + return + # k and v stay in SRAM throughout + # [2022-10-30] TD: Same bug as the fwd. In the case of EVEN_N=True and EVEN_M=False, + # if we just call tl.load(k_ptrs), we get the wrong output! + if EVEN_N & EVEN_M: + if EVEN_HEADDIM: + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + else: + k = tl.load(k_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + v = tl.load(v_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + else: + if EVEN_HEADDIM: + k = tl.load(k_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + v = tl.load(v_ptrs, mask=offs_n[:, None] < seqlen_k, other=0.0) + else: + k = tl.load( + k_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + v = tl.load( + v_ptrs, mask=(offs_n[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0 + ) + # loop over rows + num_block_m = tl.cdiv(seqlen_q, BLOCK_M) + for start_m in range(begin_m, num_block_m * BLOCK_M, BLOCK_M): + start_m = tl.multiple_of(start_m, BLOCK_M) + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + # Same bug as below. 
Otherwise gives wrong result for headdim=40, seqlen=(128, 117) + if EVEN_M & EVEN_HEADDIM: + q = tl.load(q_ptrs) + else: + if EVEN_HEADDIM: + q = tl.load(q_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + else: + q = tl.load( + q_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # recompute p = softmax(qk, dim=-1).T + qk = tl.dot(q, k, trans_b=True) + # Trying to combine the two masks seem to make the result wrong + if not EVEN_N: # Need to mask out otherwise the softmax is wrong + qk = tl.where(offs_n[None, :] < seqlen_k, qk, float("-inf")) + if IS_CAUSAL: + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + if BIAS_TYPE != "none": + tl.debug_barrier() # Race condition otherwise + if BIAS_TYPE == "vector": + if EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load(b_ptrs, mask=offs_n < seqlen_k, other=0.0).to(tl.float32) + bias = bias[None, :] + elif BIAS_TYPE == "matrix": + if EVEN_M & EVEN_N: + bias = tl.load(b_ptrs).to(tl.float32) + else: + bias = tl.load( + b_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_n[None, :] < seqlen_k), + other=0.0, + ).to(tl.float32) + qk = qk * softmax_scale + bias + # There seems to be a race condition when headdim=48/96, and dq, dk, dv are wrong. + # Also wrong for headdim=64. + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + lse_i = tl.load(LSE + offs_m_curr) + if BIAS_TYPE == "none": + p = tl.exp(qk * softmax_scale - lse_i[:, None]) + else: + p = tl.exp(qk - lse_i[:, None]) + # compute dv + # [2022-10-30] TD: A Triton bug: if EVEN_M=True and EVEN_HEADDIM=False, if we call + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0), we get wrong outputs + # in the case of headdim=48/96, seqlen_q & seqlen_k >= 512. If headdim=40 or seqlen < 512, + # the output is correct. 
+ if EVEN_M & EVEN_HEADDIM: + do = tl.load(do_ptrs) + else: + # [2022-11-01] TD: Triton bug, there's a race condition if we just use m_mask and not d_mask. + do = tl.load( + do_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + ) + # if EVEN_M: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs) + # else: + # do = tl.load(do_ptrs, mask=offs_d[None, :] < headdim, other=0.0) + # else: + # if EVEN_HEADDIM: + # do = tl.load(do_ptrs, mask=offs_m_curr[:, None] < seqlen_q, other=0.0) + # else: + # do = tl.load(do_ptrs, mask=(offs_m_curr[:, None] < seqlen_q) + # & (offs_d[None, :] < headdim), other=0.0) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + # There seems to be a race condition when headdim=48/96, and dq, dk are wrong. + # Also wrong for headdim=128, seqlen=(108, 256), and ATOMIC_ADD=True + # Also wrong for headdim=64, seqlen=(1023, 1024), and ATOMIC_ADD=False + if not (EVEN_M & EVEN_HEADDIM): + tl.debug_barrier() + dp = tl.dot(do, v, trans_b=True) + # There's a race condition for headdim=48 + if not EVEN_HEADDIM: + tl.debug_barrier() + # compute ds = p * (dp - delta[:, None]) + # Putting the subtraction after the dp matmul (instead of before) is slightly faster + Di = tl.load(D + offs_m_curr) + # Converting ds to q.dtype here reduces register pressure and makes it much faster + # for BLOCK_HEADDIM=128 + ds = (p * (dp - Di[:, None]) * softmax_scale).to(q.dtype) + # compute dk = dot(ds.T, q) + dk += tl.dot(ds, q, trans_a=True) + # compute dq + if not ( + EVEN_M & EVEN_HEADDIM + ): # Otherewise there's a race condition when BIAS_TYPE='matrix' + tl.debug_barrier() + if not ATOMIC_ADD: + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds, k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + else: + if EVEN_HEADDIM: + dq = tl.load( + dq_ptrs, + mask=offs_m_curr[:, None] < seqlen_q, + other=0.0, + 
eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=offs_m_curr[:, None] < seqlen_q, + eviction_policy="evict_last", + ) + else: + dq = tl.load( + dq_ptrs, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + other=0.0, + eviction_policy="evict_last", + ) + dq += tl.dot(ds, k) + tl.store( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + eviction_policy="evict_last", + ) + else: # If we're parallelizing across the seqlen_k dimension + dq = tl.dot(ds, k) + if EVEN_M & EVEN_HEADDIM: # Race condition if we just do EVEN_M + tl.atomic_add(dq_ptrs, dq) + else: + if EVEN_HEADDIM: + tl.atomic_add(dq_ptrs, dq, mask=offs_m_curr[:, None] < seqlen_q) + else: + tl.atomic_add( + dq_ptrs, + dq, + mask=(offs_m_curr[:, None] < seqlen_q) & (offs_d[None, :] < headdim), + ) + # increment pointers + dq_ptrs += BLOCK_M * stride_dqm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_dom + if BIAS_TYPE == "matrix": + b_ptrs += BLOCK_M * stride_bm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_dvn + offs_d[None, :]) + dk_ptrs = DK + (offs_n[:, None] * stride_dkn + offs_d[None, :]) + _bwd_store_dk_dv( + dk_ptrs, + dv_ptrs, + dk, + dv, + offs_n, + offs_d, + seqlen_k, + headdim, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + ) + + +def init_to_zero(name): + return lambda nargs: nargs[name].zero_() + + +@triton.autotune( + configs=[ + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": False}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + triton.Config( + {"BLOCK_M": 128, "BLOCK_N": 128, "SEQUENCE_PARALLEL": True}, + num_warps=8, + num_stages=1, + pre_hook=init_to_zero("DQ"), + ), + # Other configs seem to give wrong results when seqlen_q % 128 != 0, disabling them for now + # # Kernel is buggy (give wrong result) if we set BLOCK_m=128, BLOCK_n=64, num_warps=*4* + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, 
"SEQUENCE_PARALLEL": False}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 128, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=8, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": False}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64, "SEQUENCE_PARALLEL": True}, num_warps=4, num_stages=1, pre_hook=init_to_zero('DQ')), + ], + key=["CACHE_KEY_SEQLEN_Q", "CACHE_KEY_SEQLEN_K", "BIAS_TYPE", "IS_CAUSAL", "BLOCK_HEADDIM"], +) +@triton.heuristics( + { + "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0, + "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0, + "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"], + } +) +@triton.jit +def _bwd_kernel( + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qb, + stride_qh, + stride_qm, + stride_kb, + stride_kh, + stride_kn, + stride_vb, + stride_vh, + stride_vn, + stride_bb, + stride_bh, + stride_bm, + stride_dob, + stride_doh, + stride_dom, + stride_dqb, + stride_dqh, + stride_dqm, + stride_dkb, + stride_dkh, + stride_dkn, + stride_dvb, + stride_dvh, + stride_dvn, + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + headdim, + CACHE_KEY_SEQLEN_Q, + CACHE_KEY_SEQLEN_K, + BIAS_TYPE: tl.constexpr, + IS_CAUSAL: tl.constexpr, + BLOCK_HEADDIM: tl.constexpr, + SEQUENCE_PARALLEL: tl.constexpr, + EVEN_M: tl.constexpr, + EVEN_N: tl.constexpr, + EVEN_HEADDIM: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hb = tl.program_id(1) + off_b = off_hb // nheads + off_h = off_hb % nheads + # offset pointers for batch/head + Q += off_b * stride_qb + off_h * stride_qh + K += off_b * stride_kb + off_h * stride_kh + V += off_b * stride_vb + off_h * stride_vh + DO += off_b * stride_dob + off_h * stride_doh + DQ += off_b * stride_dqb + off_h * stride_dqh + DK += off_b * stride_dkb + 
off_h * stride_dkh + DV += off_b * stride_dvb + off_h * stride_dvh + if BIAS_TYPE != "none": + Bias += off_b * stride_bb + off_h * stride_bh + # pointer to row-wise quantities in value-like data + D += off_hb * seqlen_q_rounded + LSE += off_hb * seqlen_q_rounded + if not SEQUENCE_PARALLEL: + num_block_n = tl.cdiv(seqlen_k, BLOCK_N) + for start_n in range(0, num_block_n): + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=False, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + else: + start_n = tl.program_id(0) + _bwd_kernel_one_col_block( + start_n, + Q, + K, + V, + Bias, + DO, + DQ, + DK, + DV, + LSE, + D, + softmax_scale, + stride_qm, + stride_kn, + stride_vn, + stride_bm, + stride_dom, + stride_dqm, + stride_dkn, + stride_dvn, + seqlen_q, + seqlen_k, + headdim, + ATOMIC_ADD=True, + BIAS_TYPE=BIAS_TYPE, + IS_CAUSAL=IS_CAUSAL, + BLOCK_HEADDIM=BLOCK_HEADDIM, + EVEN_M=EVEN_M, + EVEN_N=EVEN_N, + EVEN_HEADDIM=EVEN_HEADDIM, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + ) + + +def _flash_attn_forward(q, k, v, bias=None, causal=False, softmax_scale=None): + # shape constraints + batch, seqlen_q, nheads, d = q.shape + _, seqlen_k, _, _ = k.shape + assert k.shape == (batch, seqlen_k, nheads, d) + assert v.shape == (batch, seqlen_k, nheads, d) + assert d <= 128, "FlashAttention only support head dimensions up to 128" + assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type" + assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16" + assert q.is_cuda and k.is_cuda and v.is_cuda + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + + has_bias = bias is not None + bias_type = "none" + if 
has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + if bias.stride(-1) != 1: + bias = bias.contiguous() + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError( + "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" + ) + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + lse = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + tmp = torch.empty((batch, nheads, seqlen_q_rounded), device=q.device, dtype=torch.float32) + o = torch.empty_like(q) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + BLOCK = 128 + num_warps = 4 if d <= 64 else 8 + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _fwd_kernel[grid]( + q, + k, + v, + bias, + o, + lse, + tmp, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + o.stride(0), + o.stride(2), + o.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return o, lse, softmax_scale # softmax_scale could have been updated + + +def _flash_attn_backward( + do, q, k, v, o, lse, dq, dk, dv, bias=None, causal=False, softmax_scale=None +): + # Make sure that the last dimension is contiguous + if do.stride(-1) != 1: + do = do.contiguous() + batch, seqlen_q, nheads, d = q.shape + _, 
seqlen_k, _, _ = k.shape + # assert d in {16, 32, 64, 128} + assert d <= 128 + seqlen_q_rounded = math.ceil(seqlen_q / 128) * 128 + assert lse.shape == (batch, nheads, seqlen_q_rounded) + assert q.stride(-1) == k.stride(-1) == v.stride(-1) == o.stride(-1) == 1 + assert dq.stride(-1) == dk.stride(-1) == dv.stride(-1) == 1 + softmax_scale = softmax_scale or 1.0 / math.sqrt(d) + # dq_accum = torch.zeros_like(q, dtype=torch.float32) + dq_accum = torch.empty_like(q, dtype=torch.float32) + delta = torch.empty_like(lse) + # delta = torch.zeros_like(lse) + + BLOCK_HEADDIM = max(triton.next_power_of_2(d), 16) + grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads) + _bwd_preprocess_do_o_dot[grid]( + o, + do, + delta, + o.stride(0), + o.stride(2), + o.stride(1), + do.stride(0), + do.stride(2), + do.stride(1), + nheads, + seqlen_q, + seqlen_q_rounded, + d, + BLOCK_M=128, + BLOCK_HEADDIM=BLOCK_HEADDIM, + ) + + has_bias = bias is not None + bias_type = "none" + if has_bias: + assert bias.dtype in [q.dtype, torch.float] + assert bias.is_cuda + assert bias.dim() == 4 + assert bias.stride(-1) == 1 + if bias.shape[2:] == (1, seqlen_k): + bias_type = "vector" + elif bias.shape[2:] == (seqlen_q, seqlen_k): + bias_type = "matrix" + else: + raise RuntimeError( + "Last 2 dimensions of bias must be (1, seqlen_k)" " or (seqlen_q, seqlen_k)" + ) + bias = bias.expand(batch, nheads, seqlen_q, seqlen_k) + bias_strides = (bias.stride(0), bias.stride(1), bias.stride(2)) if has_bias else (0, 0, 0) + + # BLOCK_M = 128 + # BLOCK_N = 64 + # num_warps = 4 + grid = lambda META: ( + triton.cdiv(seqlen_k, META["BLOCK_N"]) if META["SEQUENCE_PARALLEL"] else 1, + batch * nheads, + ) + _bwd_kernel[grid]( + q, + k, + v, + bias, + do, + dq_accum, + dk, + dv, + lse, + delta, + softmax_scale, + q.stride(0), + q.stride(2), + q.stride(1), + k.stride(0), + k.stride(2), + k.stride(1), + v.stride(0), + v.stride(2), + v.stride(1), + *bias_strides, + do.stride(0), + do.stride(2), + 
do.stride(1), + dq_accum.stride(0), + dq_accum.stride(2), + dq_accum.stride(1), + dk.stride(0), + dk.stride(2), + dk.stride(1), + dv.stride(0), + dv.stride(2), + dv.stride(1), + nheads, + seqlen_q, + seqlen_k, + seqlen_q_rounded, + d, + seqlen_q // 32, + seqlen_k // 32, # key for triton cache (limit number of compilations) + # Can't use kwargs here because triton autotune expects key to be args, not kwargs + # IS_CAUSAL=causal, BLOCK_HEADDIM=d, + bias_type, + causal, + BLOCK_HEADDIM, + # SEQUENCE_PARALLEL=False, + # BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, + # num_warps=num_warps, + # num_stages=1, + ) + dq.copy_(dq_accum) + + +class FlashAttnQKVPackedFunc(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, bias=None, causal=False, softmax_scale=None): + """ + qkv: (batch, seqlen, 3, nheads, headdim) + bias: optional, shape broadcastible to (batch, nheads, seqlen, seqlen). + For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen). + ALiBi mask for non-causal would have shape (1, nheads, seqlen, seqlen) + """ + # Make sure that the last dimension is contiguous + if qkv.stride(-1) != 1: + qkv = qkv.contiguous() + o, lse, ctx.softmax_scale = _flash_attn_forward( + qkv[:, :, 0], + qkv[:, :, 1], + qkv[:, :, 2], + bias=bias, + causal=causal, + softmax_scale=softmax_scale, + ) + ctx.save_for_backward(qkv, o, lse, bias) + ctx.causal = causal + return o + + @staticmethod + def backward(ctx, do): + qkv, o, lse, bias = ctx.saved_tensors + assert not ctx.needs_input_grad[1], "FlashAttention does not support bias gradient yet" + # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd + # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version. 
+        with torch.inference_mode():
+            # q/k/v gradients are written into one packed buffer so the caller
+            # gets a single tensor with the same (batch, seqlen, 3, nheads, headdim)
+            # layout as the forward input.
+            dqkv = torch.empty_like(qkv)
+            _flash_attn_backward(
+                do,
+                qkv[:, :, 0],
+                qkv[:, :, 1],
+                qkv[:, :, 2],
+                o,
+                lse,
+                dqkv[:, :, 0],
+                dqkv[:, :, 1],
+                dqkv[:, :, 2],
+                bias=bias,
+                causal=ctx.causal,
+                softmax_scale=ctx.softmax_scale,
+            )
+        # One gradient slot per forward arg: (qkv, bias, causal, softmax_scale).
+        return dqkv, None, None, None
+
+
+flash_attn_qkvpacked_func = FlashAttnQKVPackedFunc.apply
+
+
+class FlashAttnKVPackedFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, kv, bias=None, causal=False, softmax_scale=None):
+        """
+        q: (batch, seqlen_q, nheads, headdim)
+        kv: (batch, seqlen_k, 2, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
+        """
+        # Make sure that the last dimension is contiguous
+        q, kv = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, kv]]
+        o, lse, ctx.softmax_scale = _flash_attn_forward(
+            q, kv[:, :, 0], kv[:, :, 1], bias=bias, causal=causal, softmax_scale=softmax_scale
+        )
+        # lse (log-sum-exp per row) is saved so the backward kernel can recompute
+        # the softmax without storing the full attention matrix.
+        ctx.save_for_backward(q, kv, o, lse, bias)
+        ctx.causal = causal
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        q, kv, o, lse, bias = ctx.saved_tensors
+        # NOTE(review): needs_input_grad is guarded by length here, presumably in
+        # case trailing forward args were omitted — the other Func classes below
+        # assert unconditionally; confirm which convention is intended.
+        if len(ctx.needs_input_grad) >= 3:
+            assert not ctx.needs_input_grad[2], "FlashAttention does not support bias gradient yet"
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+        with torch.inference_mode():
+            # dq is accumulated in float32 inside _flash_attn_backward (dq_accum)
+            # and copied back into this buffer; dkv is written directly.
+            dq = torch.empty_like(q)
+            dkv = torch.empty_like(kv)
+            _flash_attn_backward(
+                do,
+                q,
+                kv[:, :, 0],
+                kv[:, :, 1],
+                o,
+                lse,
+                dq,
+                dkv[:, :, 0],
+                dkv[:, :, 1],
+                bias=bias,
+                causal=ctx.causal,
+                softmax_scale=ctx.softmax_scale,
+            )
+        # One gradient slot per forward arg: (q, kv, bias, causal, softmax_scale).
+        return dq, dkv, None, None, None
+
+
+flash_attn_kvpacked_func = FlashAttnKVPackedFunc.apply
+
+
+class FlashAttnFunc(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, q, k, v, bias=None, causal=False, softmax_scale=None):
+        """
+        q: (batch_size, seqlen_q, nheads, headdim)
+        k, v: (batch_size, seqlen_k, nheads, headdim)
+        bias: optional, shape broadcastable to (batch, nheads, seqlen_q, seqlen_k).
+            For example, ALiBi mask for causal would have shape (1, nheads, 1, seqlen_k).
+            ALiBi mask for non-causal would have shape (1, nheads, seqlen_q, seqlen_k)
+        """
+        # Make sure that the last dimension is contiguous
+        q, k, v = [x if x.stride(-1) == 1 else x.contiguous() for x in [q, k, v]]
+        o, lse, ctx.softmax_scale = _flash_attn_forward(
+            q, k, v, bias=bias, causal=causal, softmax_scale=softmax_scale
+        )
+        # lse (log-sum-exp per row) is saved so the backward kernel can recompute
+        # the softmax without storing the full attention matrix.
+        ctx.save_for_backward(q, k, v, o, lse, bias)
+        ctx.causal = causal
+        return o
+
+    @staticmethod
+    def backward(ctx, do):
+        q, k, v, o, lse, bias = ctx.saved_tensors
+        assert not ctx.needs_input_grad[3], "FlashAttention does not support bias gradient yet"
+        # Triton's autotune causes the Tensor._version to change, and so Pytorch autograd
+        # does a memcpy. To avoid this we run in inference_mode, which doesn't track the version.
+ with torch.inference_mode(): + dq = torch.empty_like(q) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + _flash_attn_backward( + do, + q, + k, + v, + o, + lse, + dq, + dk, + dv, + bias=bias, + causal=ctx.causal, + softmax_scale=ctx.softmax_scale, + ) + return dq, dk, dv, None, None, None + + +flash_attn_func = FlashAttnFunc.apply diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton_og.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton_og.py new file mode 100644 index 0000000000000000000000000000000000000000..f2ddb99487b4f162745e2f6dd3d1744946dc3fb2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_attn_triton_og.py @@ -0,0 +1,365 @@ +# [2022-10-23] Downloaded from https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py +# for benchmarking. +# We fixed a few dtype cast to make it work for bf16 + +""" +Fused Attention +=============== +This is a Triton implementation of the Flash Attention algorithm +(see: Dao et al., https://arxiv.org/pdf/2205.14135v2.pdf; Rabe and Staats https://arxiv.org/pdf/2112.05682v2.pdf) +""" + +import pytest +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + TMP, + L, + M, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + stride_oz, + stride_oh, + stride_om, + stride_on, + Z, + H, + N_CTX, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk + off_k 
= off_hz * stride_qh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk + off_v = off_hz * stride_qh + offs_n[:, None] * stride_qm + offs_d[None, :] * stride_qk + # Initialize pointers to Q, K, V + q_ptrs = Q + off_q + k_ptrs = K + off_k + v_ptrs = V + off_v + # initialize pointer to m and l + t_ptrs = TMP + off_hz * N_CTX + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # load q: it will stay in SRAM throughout + q = tl.load(q_ptrs) + # loop over k, v and update accumulator + for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + start_n * stride_kn) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k, trans_b=True) + qk *= sm_scale + qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float("-inf")) + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + start_n * stride_vk) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # rematerialize offsets to save registers + start_m = tl.program_id(0) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + m_ptrs = M + off_hz * N_CTX + offs_m + tl.store(l_ptrs, l_i) + tl.store(m_ptrs, m_i) + # initialize pointers to 
output + offs_n = tl.arange(0, BLOCK_DMODEL) + off_o = off_hz * stride_oh + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on + out_ptrs = Out + off_o + tl.store(out_ptrs, acc) + + +@triton.jit +def _bwd_preprocess( + Out, + DO, + L, + NewDO, + Delta, + BLOCK_M: tl.constexpr, + D_HEAD: tl.constexpr, +): + off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M) + off_n = tl.arange(0, D_HEAD) + # load + o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32) + denom = tl.load(L + off_m).to(tl.float32) + # compute + do = do / denom[:, None] + delta = tl.sum(o * do, axis=1) + # write-back + tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do) + tl.store(Delta + off_m, delta) + + +@triton.jit +def _bwd_kernel( + Q, + K, + V, + sm_scale, + Out, + DO, + DQ, + DK, + DV, + L, + M, + D, + stride_qz, + stride_qh, + stride_qm, + stride_qk, + stride_kz, + stride_kh, + stride_kn, + stride_kk, + stride_vz, + stride_vh, + stride_vk, + stride_vn, + Z, + H, + N_CTX, + num_block, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + off_hz = tl.program_id(0) + off_z = off_hz // H + off_h = off_hz % H + # offset pointers for batch/head + Q += off_z * stride_qz + off_h * stride_qh + K += off_z * stride_qz + off_h * stride_qh + V += off_z * stride_qz + off_h * stride_qh + DO += off_z * stride_qz + off_h * stride_qh + DQ += off_z * stride_qz + off_h * stride_qh + DK += off_z * stride_qz + off_h * stride_qh + DV += off_z * stride_qz + off_h * stride_qh + for start_n in range(0, num_block): + lo = start_n * BLOCK_M + # initialize row/col offsets + offs_qm = lo + tl.arange(0, BLOCK_M) + offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M) + offs_m = tl.arange(0, BLOCK_N) + offs_k = tl.arange(0, BLOCK_DMODEL) + # initialize pointers to value-like data + q_ptrs = Q + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + k_ptrs = K + 
(offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + v_ptrs = V + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + do_ptrs = DO + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk) + # pointer to row-wise quantities in value-like data + D_ptrs = D + off_hz * N_CTX + m_ptrs = M + off_hz * N_CTX + # initialize dv amd dk + dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # k and v stay in SRAM throughout + k = tl.load(k_ptrs) + v = tl.load(v_ptrs) + # loop over rows + for start_m in range(lo, num_block * BLOCK_M, BLOCK_M): + offs_m_curr = start_m + offs_m + # load q, k, v, do on-chip + q = tl.load(q_ptrs) + # recompute p = softmax(qk, dim=-1).T + # NOTE: `do` is pre-divided by `l`; no normalization here + qk = tl.dot(q, k, trans_b=True) + qk = tl.where(offs_m_curr[:, None] >= (offs_n[None, :]), qk, float("-inf")) + m = tl.load(m_ptrs + offs_m_curr) + p = tl.exp(qk * sm_scale - m[:, None]) + # compute dv + do = tl.load(do_ptrs) + dv += tl.dot(p.to(do.dtype), do, trans_a=True) + # compute dp = dot(v, do) + Di = tl.load(D_ptrs + offs_m_curr) + dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None] + dp += tl.dot(do, v, trans_b=True) + # compute ds = p * (dp - delta[:, None]) + ds = p * dp * sm_scale + # compute dk = dot(ds.T, q) + dk += tl.dot(ds.to(q.dtype), q, trans_a=True) + # # compute dq + dq = tl.load(dq_ptrs, eviction_policy="evict_last") + dq += tl.dot(ds.to(k.dtype), k) + tl.store(dq_ptrs, dq, eviction_policy="evict_last") + # # increment pointers + dq_ptrs += BLOCK_M * stride_qm + q_ptrs += BLOCK_M * stride_qm + do_ptrs += BLOCK_M * stride_qm + # write-back + dv_ptrs = DV + (offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk) + dk_ptrs = DK + (offs_n[:, None] * stride_kn + offs_k[None, :] * stride_kk) + tl.store(dv_ptrs, dv) + tl.store(dk_ptrs, dk) + + +class 
_attention(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, sm_scale): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1]) + tmp = torch.empty( + (q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32 + ) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + m = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + tmp, + L, + m, + o, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + o.stride(0), + o.stride(1), + o.stride(2), + o.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + BLOCK_M=BLOCK, + BLOCK_N=BLOCK, + BLOCK_DMODEL=Lk, + num_warps=num_warps, + num_stages=1, + ) + ctx.save_for_backward(q, k, v, o, L, m) + ctx.BLOCK = BLOCK + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + return o + + @staticmethod + def backward(ctx, do): + q, k, v, o, l, m = ctx.saved_tensors + do = do.contiguous() + dq = torch.zeros_like(q, dtype=torch.float32) + dk = torch.empty_like(k) + dv = torch.empty_like(v) + do_scaled = torch.empty_like(do) + delta = torch.empty_like(l) + _bwd_preprocess[(ctx.grid[0] * ctx.grid[1],)]( + o, + do, + l, + do_scaled, + delta, + BLOCK_M=ctx.BLOCK, + D_HEAD=ctx.BLOCK_DMODEL, + ) + + # NOTE: kernel currently buggy for other values of `num_warps` + num_warps = 8 + _bwd_kernel[(ctx.grid[1],)]( + q, + k, + v, + ctx.sm_scale, + o, + do_scaled, + dq, + dk, + dv, + l, + m, + delta, + q.stride(0), + q.stride(1), + q.stride(2), + q.stride(3), + k.stride(0), + k.stride(1), + k.stride(2), + k.stride(3), + 
v.stride(0), + v.stride(1), + v.stride(2), + v.stride(3), + q.shape[0], + q.shape[1], + q.shape[2], + ctx.grid[0], + BLOCK_M=ctx.BLOCK, + BLOCK_N=ctx.BLOCK, + BLOCK_DMODEL=ctx.BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + return dq.to(q.dtype), dk, dv, None + + +attention = _attention.apply diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attention.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..03798d16ffbb3cbf1806296d5b33f81360717315 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attention.py @@ -0,0 +1,197 @@ +import math + +import hydra +import torch +import torch.nn as nn +from einops import rearrange + +from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input +from flash_attn.flash_blocksparse_attn_interface import ( + convert_blockmask, + flash_blocksparse_attn_func, +) + + +class FlashBlocksparseAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_temp: The temperature to use for the softmax attention. 
+ (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.1) + """ + + def __init__( + self, + sparsity_config, + softmax_temp=None, + attention_dropout=0.0, + max_seq_length=2048, + device=None, + dtype=None, + ): + super().__init__() + self.sparsity_config = hydra.utils.instantiate(sparsity_config) + self.softmax_temp = softmax_temp + self.dropout_p = attention_dropout + + # initialize sparse layout and register as buffer + max_seq_length = ((max_seq_length + 256 - 1) // 256) * 256 + layout = self.sparsity_config.make_layout(max_seq_length) + self.register_buffer("layout", layout) + blockmask_converted = convert_blockmask(self.layout, causal=False) + self.register_buffer("blockmask_converted", blockmask_converted) + # logger.info(f'Attention class {self.__class__}: saving={self.layout.float().mean()}') + + def forward( + self, + qkv, + attn_mask=None, + key_padding_mask=None, + causal=False, + cu_seqlens=None, + max_s=None, + need_weights=False, + convert_mask=True, + ): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None + attn_mask: An implementation of BaseMask that encodes where each + query can attend to + key_padding_mask: An implementation of BaseMask that encodes how + many query each sequence in the batch consists of + """ + assert not need_weights + assert attn_mask is None + assert qkv.dtype == torch.float16 + assert qkv.is_cuda + + if cu_seqlens is None: + batch_size = qkv.shape[0] + seqlen = qkv.shape[1] + # Convert mask to take a subset + seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 + assert seqlen_rounded // 16 <= self.layout.shape[0], ( + seqlen_rounded // 256 <= self.layout.shape[1] + ) + blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256] + if key_padding_mask is None: + qkv = rearrange(qkv, "b s ... 
-> (b s) ...") + max_s = seqlen + cu_seqlens = torch.arange( + 0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, device=qkv.device + ) + output = flash_blocksparse_attn_func( + qkv, + cu_seqlens, + blockmask, + self.dropout_p if self.training else 0.0, + max_s, + softmax_scale=self.softmax_temp, + causal=causal, + ) + output = rearrange(output, "(b s) ... -> b s ...", b=batch_size) + else: + key_padding_mask_bool = key_padding_mask.bool_matrix + nheads = qkv.shape[-2] + x = rearrange(qkv, "b s three h d -> b s (three h d)") + x_unpad, indices, cu_seqlens, max_s = unpad_input(x, key_padding_mask_bool) + x_unpad = rearrange(x_unpad, "nnz (three h d) -> nnz three h d", three=3, h=nheads) + output_unpad = flash_blocksparse_attn_func( + x_unpad, + cu_seqlens, + blockmask, + self.dropout_p if self.training else 0.0, + max_s, + softmax_scale=self.softmax_temp, + causal=causal, + ) + output = rearrange( + pad_input( + rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices, batch_size, seqlen + ), + "b s (h d) -> b s h d", + h=nheads, + ) + else: + assert max_s is not None + seqlen = max_s + # Convert mask to take a subset + seqlen_rounded = ((seqlen + 256 - 1) // 256) * 256 + assert seqlen_rounded // 16 <= self.layout.shape[0], ( + seqlen_rounded // 256 <= self.layout.shape[1] + ) + blockmask = self.layout[: seqlen_rounded // 16, : seqlen_rounded // 256] + if convert_mask: + output = flash_blocksparse_attn_func( + qkv, + cu_seqlens, + blockmask, + self.dropout_p if self.training else 0.0, + max_s, + softmax_scale=self.softmax_temp, + causal=causal, + ) + else: + output = flash_blocksparse_attn_func( + qkv, + cu_seqlens, + self.blockmask_converted, + self.dropout_p if self.training else 0.0, + max_s, + softmax_scale=self.softmax_temp, + causal=causal, + convert_mask=False, + ) + + return output, None + + +class FlashBlocksparseMHA(nn.Module): + def __init__( + self, + embed_dim, + num_heads, + sparsity_config, + bias=True, + batch_first=True, + 
attention_dropout=0.0, + causal=False, + max_seq_length=2048, + device=None, + dtype=None, + **kwargs, + ) -> None: + assert batch_first + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + + self.num_heads = num_heads + assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + assert self.head_dim in [16, 32, 64], "Only support head_dim == 16, 32, or 64" + + self.Wqkv = nn.Linear(embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs) + self.inner_attn = FlashBlocksparseAttention( + sparsity_config, + attention_dropout=attention_dropout, + max_seq_length=max_seq_length, + **factory_kwargs, + ) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs) + + def forward( + self, x, x_ignored_, x_ignored_1_, attn_mask=None, key_padding_mask=None, need_weights=False + ): + qkv = self.Wqkv(x) + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads) + context, attn_weights = self.inner_attn( + qkv, key_padding_mask=key_padding_mask, need_weights=need_weights, causal=self.causal + ) + return self.out_proj(rearrange(context, "b s h d -> b s (h d)")), attn_weights diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attn_interface.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attn_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..9ce3fe8c1344dd33165c43e4cc1ef0f70feb5d04 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/flash_blocksparse_attn_interface.py @@ -0,0 +1,200 @@ +# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/fmha.py +import flash_attn_cuda +import torch +import torch.nn as nn + + +def convert_blockmask(blockmask, causal): + """Convert from the 0-1 format to the format used 
by the CUDA code. + 0 means the block is skipped. + nonzero means the block is not skipped. + Argument: + blockmask: (row, col): a 0-1 tensor + Return: + blockmask_converted: (col, row), dtype torch.int32: for each column, it contains the row + indices of the nonzero blocks, padded with -1 to reach length @row. + The indices are multiplied by 4, with the smallest bit used to encode whether + it is the first nonzero in its row, and the 2nd smallest bit to encode whether it is + the last nonzero in its row.. + """ + assert not causal + # TD [2022-05-13]: The indexing and sorting is very tricky + nrow, ncol = blockmask.shape + # Sort does not support bool on CUDA + blockmask = blockmask.to(dtype=torch.uint8) + nonzero_val, nonzero_sorted_rowidx = blockmask.sort(dim=0, stable=True, descending=True) + nonzero_unsorted_rowidx = nonzero_sorted_rowidx.argsort(dim=0) + last_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True).indices[:, -1] + last_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ + torch.arange(nrow, device=blockmask.device), last_nonzero_col_per_row + ] + first_nonzero_col_per_row = blockmask.sort(dim=-1, stable=True, descending=True).indices[:, 0] + first_nonzero_col_per_row_after_sort = nonzero_unsorted_rowidx[ + torch.arange(nrow, device=blockmask.device), first_nonzero_col_per_row + ] + nonzero_idx = nonzero_sorted_rowidx * 4 + nonzero_idx[last_nonzero_col_per_row_after_sort, last_nonzero_col_per_row] += 2 + nonzero_idx[first_nonzero_col_per_row_after_sort, first_nonzero_col_per_row] += 1 + nonzero_idx[nonzero_val == 0] = -1 + return nonzero_idx.T.contiguous().to(dtype=torch.int32) + + +def _flash_blocksparse_attn_forward( + qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax +): + context, softmax_lse, *rest = flash_attn_cuda.fwd_block( + qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal, return_softmax, None + ) + # if context.isnan().any() or softmax_lse.isnan().any(): + # breakpoint() + 
S_dmask = rest[0] if return_softmax else None + return context, softmax_lse, S_dmask + + +def _flash_blocksparse_attn_backward( + dout, + qkv, + out, + S_dmask, + softmax_lse, + cu_seqlens, + blockmask, + dropout_p, + max_s, + softmax_scale, + causal, +): + dqkv, dp, softmax_d = flash_attn_cuda.bwd_block( + dout, + qkv, + out, + S_dmask, + softmax_lse, + cu_seqlens, + blockmask, + dropout_p, + softmax_scale, + max_s, + causal, + None, + ) + # if dqkv.isnan().any() or softmax_d.isnan().any(): + # breakpoint() + return dqkv + + +class FlashBlocksparseAttnFun(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): + # Save rng_state because the backward pass will regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( + qkv, + cu_seqlens, + blockmask, + dropout_p, + max_s, + softmax_scale, + causal=causal, + return_softmax=False, + ) + ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) + ctx.dropout_p = dropout_p + ctx.max_s = max_s + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return context + + @staticmethod + def backward(ctx, dout): + qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + # S_dmask is None, temporarily use another tensor just to get it running + dqkv = _flash_blocksparse_attn_backward( + dout, + qkv, + context, + context, + softmax_lse, + cu_seqlens, + blockmask, + ctx.dropout_p, + ctx.max_s, + ctx.softmax_scale, + ctx.causal, + ) + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None, None + + +# We duplicate code to return both the 
output and the softmax for testing +# Returning both makes backward a bit slower, so we want to keep using the other version for speed. +class FlashBlocksparseAttnFunWithS(torch.autograd.Function): + @staticmethod + def forward(ctx, qkv, cu_seqlens, blockmask, dropout_p, max_s, softmax_scale, causal): + # Save rng_state because the backward pass is gonna regenerate the dropout mask + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + if softmax_scale is None: + softmax_scale = qkv.shape[-1] ** (-0.5) + context, softmax_lse, S_dmask = _flash_blocksparse_attn_forward( + qkv, + cu_seqlens, + blockmask, + dropout_p, + max_s, + softmax_scale, + causal=causal, + return_softmax=True, + ) + ctx.save_for_backward(qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state) + ctx.dropout_p = dropout_p + ctx.max_s = max_s + ctx.softmax_scale = softmax_scale + ctx.causal = causal + return context, S_dmask, softmax_lse + + @staticmethod + def backward(ctx, dout, _dS_dmask_ignored, _dsoftmax_sum_ignored): + qkv, context, S_dmask, softmax_lse, cu_seqlens, blockmask, rng_state = ctx.saved_tensors + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dqkv = _flash_blocksparse_attn_backward( + dout, + qkv, + context, + S_dmask, + softmax_lse, + cu_seqlens, + blockmask, + ctx.dropout_p, + ctx.max_s, + ctx.softmax_scale, + ctx.causal, + ) + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + return dqkv, None, None, None, None, None, None + + +def flash_blocksparse_attn_func( + qkv, + cu_seqlens, + blockmask, + dropout_p, + max_s, + softmax_scale=None, + causal=False, + return_attn_probs=False, + convert_mask=True, +): + """dropout_p should be set to 0.0 during evaluation""" + func = FlashBlocksparseAttnFun if not return_attn_probs else FlashBlocksparseAttnFunWithS + if convert_mask: + blockmask = convert_blockmask(blockmask, causal=causal) + return func.apply(qkv, cu_seqlens, 
blockmask, dropout_p, max_s, softmax_scale, causal) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/fused_softmax.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/fused_softmax.py new file mode 100644 index 0000000000000000000000000000000000000000..382f94f092cd3999b2378dfc2fa165a7c08017e2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/fused_softmax.py @@ -0,0 +1,201 @@ +# [2022-10-23] Copied from https://github.com/NVIDIA/apex/blob/master/apex/transformer/functional/fused_softmax.py +# for benchmarking. +# We added support for seqlen=2k and seqlen=4k + +# coding=utf-8 +# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +from apex._autocast_utils import _cast_if_autocast_enabled +from apex.transformer.enums import AttnMaskType +from fused_softmax_lib import ( + scaled_masked_softmax_backward, + scaled_masked_softmax_forward, + scaled_masked_softmax_get_batch_per_block, + scaled_upper_triang_masked_softmax_backward, + scaled_upper_triang_masked_softmax_forward, +) + + +class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function): + """ + Fused operation which performs following three operations in sequence + 1. Scale the tensor. + 2. Apply upper triangular mask (typically used in gpt models). + 3. Perform softmax. 
+ """ + + @staticmethod + def forward(ctx, inputs, scale): + scale_t = torch.tensor([scale]) + softmax_results = scaled_upper_triang_masked_softmax_forward(inputs, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_upper_triang_masked_softmax_backward( + output_grads, softmax_results, scale_t[0] + ) + return input_grads, None + + +def scaled_upper_triang_masked_softmax(inputs, _, scale): + b, np, sq, sk = inputs.size() + assert sq == sk, "causal mask is only for self attention" + # Reshaping input to 3D tensor (attn_batches, sq, sk) + inputs = inputs.view(-1, sq, sk) + args = _cast_if_autocast_enabled(inputs, scale) + with torch.cuda.amp.autocast(enabled=False): + probs = ScaledUpperTriangMaskedSoftmax.apply(*args) + return probs.view(b, np, sq, sk) + + +# NOTE (mkozuki): `ScaledMaskedSoftmax` somehow doesn't work well with `torch.cuda.amp.custom_fwd`. +# Without `cast_inputs` kwarg, somehow inputs are not cast to dtype used in the autocast context. +# So I needed to manually write two `torch.autograd.Function` inheritances. +# Fused operation which performs following three operations in sequence +# 1. Scale the tensor. +# 2. Apply the mask. +# 3. Perform softmax. 
+class ScaledMaskedSoftmax(torch.autograd.Function): + @staticmethod + def forward(ctx, inputs, mask, scale): + scale_t = torch.tensor([scale]) + softmax_results = scaled_masked_softmax_forward(inputs, mask, scale_t[0]) + ctx.save_for_backward(softmax_results, scale_t) + return softmax_results + + @staticmethod + def backward(ctx, output_grads): + softmax_results, scale_t = ctx.saved_tensors + input_grads = scaled_masked_softmax_backward(output_grads, softmax_results, scale_t[0]) + return input_grads, None, None + + +def scaled_masked_softmax(inputs, mask, scale): + # input is 4D tensor (b, np, sq, sk) + args = _cast_if_autocast_enabled(inputs, mask, scale) + with torch.cuda.amp.autocast(enabled=False): + return ScaledMaskedSoftmax.apply(*args) + + +class FusedScaleMaskSoftmax(torch.nn.Module): + """ + fused operation: scaling + mask + softmax + + Arguments: + input_in_fp16: flag to indicate if input in fp16 data format. + input_in_bf16: flag to indicate if input in bf16 data format. + attn_mask_type: attention mask type (pad or causal) + scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion + mask_func: mask function to be applied. + softmax_in_fp32: if true, softmax in performed at fp32 precision. + scale: scaling factor used in input tensor scaling. 
+ """ + + def __init__( + self, + input_in_fp16, + input_in_bf16, + attn_mask_type, + scaled_masked_softmax_fusion, + mask_func, + softmax_in_fp32, + scale, + ): + super().__init__() + self.input_in_fp16 = input_in_fp16 + self.input_in_bf16 = input_in_bf16 + if self.input_in_fp16 and self.input_in_bf16: + raise RuntimeError("both fp16 and bf16 flags cannot be active at the same time.") + self.input_in_float16 = self.input_in_fp16 or self.input_in_bf16 + self.attn_mask_type = attn_mask_type + self.scaled_masked_softmax_fusion = scaled_masked_softmax_fusion + self.mask_func = mask_func + self.softmax_in_fp32 = softmax_in_fp32 + self.scale = scale + + if not (self.scale is None or softmax_in_fp32): + raise RuntimeError("softmax should be in fp32 when scaled") + + if self.scaled_masked_softmax_fusion: + if self.attn_mask_type == AttnMaskType.causal: + self.fused_softmax_func = scaled_upper_triang_masked_softmax + elif self.attn_mask_type == AttnMaskType.padding: + self.fused_softmax_func = scaled_masked_softmax + else: + raise ValueError("Invalid attn_mask_type.") + + def forward(self, input, mask): + # [b, np, sq, sk] + assert input.dim() == 4 + + if self.is_kernel_available(mask, *input.size()): + return self.forward_fused_softmax(input, mask) + else: + return self.forward_torch_softmax(input, mask) + + def is_kernel_available(self, mask, b, np, sq, sk): + attn_batches = b * np + + if ( + self.scaled_masked_softmax_fusion # user want to fuse + and self.input_in_float16 # input must be fp16 + and ( + self.attn_mask_type == AttnMaskType.causal + or (self.attn_mask_type == AttnMaskType.padding and mask is not None) + ) + and 16 < sk <= 8192 # sk must be 16 ~ 8192 + and sq % 4 == 0 # sq must be divisor of 4 + and sk % 4 == 0 # sk must be divisor of 4 + and attn_batches % 4 == 0 # np * b must be divisor of 4 + ): + if 0 <= sk <= 8192: + batch_per_block = self.get_batch_per_block(sq, sk, b, np) + + if self.attn_mask_type == AttnMaskType.causal: + if attn_batches % 
batch_per_block == 0: + return True + else: + if sq % batch_per_block == 0: + return True + return False + + def forward_fused_softmax(self, input, mask): + # input.shape = [b, np, sq, sk] + scale = self.scale if self.scale is not None else 1.0 + return self.fused_softmax_func(input, mask, scale) + + def forward_torch_softmax(self, input, mask): + if self.input_in_float16 and self.softmax_in_fp32: + input = input.float() + + if self.scale is not None: + input = input * self.scale + mask_output = self.mask_func(input, mask) if mask is not None else input + probs = torch.nn.Softmax(dim=-1)(mask_output) + + if self.input_in_float16 and self.softmax_in_fp32: + if self.input_in_fp16: + probs = probs.half() + else: + probs = probs.bfloat16() + + return probs + + @staticmethod + def get_batch_per_block(sq, sk, b, np): + return scaled_masked_softmax_get_batch_per_block(sq, sk, b, np) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__init__.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c539d8792dbf96986155bcb6bf44d00fd2d82e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/block.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/block.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..439794d3fc1cf27b8f6a569a57457d85c9ab616f Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/block.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/embedding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/embedding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c98d0bc17f603b184d3bc3ec9c0c7355847cb289 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/embedding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mha.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mha.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f625fb4be00a25be0711c1dc933351d6b6d7dada Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mha.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mlp.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mlp.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56ba7d8c75c0550c66e9400bbcacd1cd93854110 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/__pycache__/mlp.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/block.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/block.py new file mode 100644 index 0000000000000000000000000000000000000000..be8e8b864b600220068c2ec16aba5e2f1a81c121 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/block.py @@ -0,0 +1,397 @@ +# Copyright (c) 2024, Tri Dao. 
+ +from functools import partial +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from torchvision.ops import StochasticDepth + +from flash_attn.modules.mha import MHA +from flash_attn.modules.mlp import Mlp + +try: + from flash_attn.ops.triton.layer_norm import layer_norm_fn, RMSNorm +except ImportError: + layer_norm_fn, RMSNorm = None, None + + +class Block(nn.Module): + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + prenorm=True, + resid_dropout1=0.0, + resid_dropout2=0.0, + drop_path1=0.0, + drop_path2=0.0, + fused_dropout_add_ln=False, + return_residual=False, + residual_in_fp32=False, + sequence_parallel=False, + mark_shared_params=False, + ): + """ + For prenorm=True, this Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA -> Dropout -> Add -> LN -> MLP, returning both + the hidden_states (output of the MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). + + For prenorm=False, this Block has the same structure as a regular postnorm Transformer + block: MHA -> Dropout -> Add -> LN -> MLP -> Dropout -> Add -> LN. + + return_residual: whether each of the sub-layers (mixer and mlp) will return the residual. + This is for performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. 
+ """ + super().__init__() + self.prenorm = prenorm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.return_residual = return_residual + self.residual_in_fp32 = residual_in_fp32 + if self.residual_in_fp32: + assert self.prenorm, "residual_in_fp32 is only compatible with prenorm=True" + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.drop_path1 = StochasticDepth(drop_path1, mode="row") + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + if not isinstance(self.mlp, nn.Identity): + self.dropout2 = dropout_cls(resid_dropout2) + self.drop_path2 = StochasticDepth(drop_path2, mode="row") + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert layer_norm_fn is not None, "Triton is not installed" + assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( + self.dropout1, nn.Dropout + ) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. 
+ if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._shared_params = True + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward( + self, + hidden_states: Tensor, + residual: Optional[Tensor] = None, + mixer_subset=None, + mixer_kwargs=None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states: the sequence to the encoder layer (required). + residual: if postnorm, residual=None, If prenorm, hidden_states = Attn/MLP(LN(residual)) + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + """ + if self.prenorm: + if not self.fused_dropout_add_ln: + dropped = self.drop_path1(self.dropout1(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm1.weight, + self.norm1.bias, + residual=residual, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, + rowscale=rowscale1, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm1, RMSNorm) + ) + if mixer_kwargs is None: + mixer_kwargs = {} + if mixer_subset is not None: + mixer_kwargs["mixer_subset"] = mixer_subset + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + if mixer_subset is not 
None: + residual = residual[:, mixer_subset] + if not isinstance(self.mlp, nn.Identity): + if not self.fused_dropout_add_ln: + dropped = self.drop_path2(self.dropout2(hidden_states)) + residual = (dropped + residual) if residual is not None else dropped + hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + hidden_states.shape[:-1], + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + ) + hidden_states, residual = layer_norm_fn( + hidden_states, + self.norm2.weight, + self.norm2.bias, + residual=residual, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + is_rms_norm=isinstance(self.norm2, RMSNorm) + ) + hidden_states = self.mlp(hidden_states) + return hidden_states, residual + else: + assert residual is None + mixer_out = self.mixer( + hidden_states, **(mixer_kwargs if mixer_kwargs is not None else {}) + ) + if self.return_residual: # mixer out is actually a pair here + mixer_out, hidden_states = mixer_out + if not self.fused_dropout_add_ln: + hidden_states = self.norm1( + (self.drop_path1(self.dropout1(mixer_out)) + hidden_states).to( + dtype=self.norm1.weight.dtype + ) + ) + else: + if self.drop_path1.p == 0 or not self.training: + rowscale1 = None + else: + rowscale1 = self.drop_path1( + torch.ones( + mixer_out.shape[:-1], device=mixer_out.device, dtype=mixer_out.dtype + ) + ) + hidden_states = layer_norm_fn( + mixer_out, + self.norm1.weight, + self.norm1.bias, + residual=hidden_states, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, + rowscale=rowscale1, + prenorm=False, + is_rms_norm=isinstance(self.norm1, RMSNorm) + ) + if not isinstance(self.mlp, nn.Identity): + mlp_out = self.mlp(hidden_states) 
+ if self.return_residual: # mlp out is actually a pair here + mlp_out, hidden_states = mlp_out + if not self.fused_dropout_add_ln: + hidden_states = self.norm2( + (self.drop_path2(self.dropout2(mlp_out)) + hidden_states).to( + dtype=self.norm2.weight.dtype + ) + ) + else: + if self.drop_path2.p == 0 or not self.training: + rowscale2 = None + else: + rowscale2 = self.drop_path2( + torch.ones( + mlp_out.shape[:-1], device=mlp_out.device, dtype=mlp_out.dtype + ) + ) + hidden_states = layer_norm_fn( + mlp_out, + self.norm2.weight, + self.norm2.bias, + residual=hidden_states, + eps=self.norm2.eps, + dropout_p=self.dropout2.p if self.training else 0.0, + rowscale=rowscale2, + prenorm=False, + is_rms_norm=isinstance(self.norm2, RMSNorm) + ) + return hidden_states + + +class ParallelBlock(nn.Module): + """The attention (mixer) and MLP blocks are done in parallel, similar to GPT-J, GPT-NeoX, + and PaLM. + """ + + def __init__( + self, + dim, + mixer_cls=None, + mlp_cls=None, + norm_cls=nn.LayerNorm, + dropout_cls=nn.Dropout, + resid_dropout1=0.0, + resid_dropout2=0.0, + tied_norm=False, + fused_dropout_add_ln=False, + residual_in_fp32=False, + sequence_parallel=False, + mark_shared_params=False, + ): + """ + This Block has a slightly different structure compared to a regular + prenorm Transformer block. + The standard block is: LN -> MHA / MLP -> Dropout -> Add. + [Ref: https://arxiv.org/abs/2002.04745] + Here we have: Dropout -> Add -> LN -> MHA / MLP, returning both + the hidden_states (output1 of the MHA / MLP) and the residual. + This is for performance reasons, as we can fuse the dropout, add and LayerNorm. + The residual needs to be provided (except for the very first block). 
+ """ + super().__init__() + self.tied_norm = tied_norm + self.fused_dropout_add_ln = fused_dropout_add_ln + self.residual_in_fp32 = residual_in_fp32 + if mixer_cls is None: + mixer_cls = partial(MHA, num_heads=dim // 64) + if mlp_cls is None: + mlp_cls = partial(Mlp, hidden_features=4 * dim) + self.mixer = mixer_cls(dim) + self.dropout1 = dropout_cls(resid_dropout1) + self.norm1 = norm_cls(dim) + self.mlp = mlp_cls(dim) + self.dropout2 = dropout_cls(resid_dropout2) + if not self.tied_norm: + self.norm2 = norm_cls(dim) + + if self.fused_dropout_add_ln: + assert layer_norm_fn is not None, "Triton is not installed" + assert isinstance(self.norm1, (nn.LayerNorm, RMSNorm)) and isinstance( + self.dropout1, nn.Dropout + ) + + # TD [2023-01-07]: TODO: During training, if sequence_parallel is False and dropout != 0.0, + # then the input to each worker in the tensor parallel group will be different. + # This would produce wrong outputs? Somehow we'd need to sync the RNG state across workers. + # For now this is not an issue because we always use sequence_parallel=True during training + # and only use sequence_parallel=False during inference. + + # Mark the norm parameters as "sequence_parallel" so that we run all-reduce on their grads. + if sequence_parallel: + for p in self.norm1.parameters(): + p._sequence_parallel = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._sequence_parallel = True + # Mark the norm parameters as "shared_params" so that we sync their values at init. 
+ if mark_shared_params: + for p in self.norm1.parameters(): + p._shared_params = True + if hasattr(self, "norm2"): + for p in self.norm2.parameters(): + p._shared_params = True + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): + return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) + + def forward( + self, + hidden_states1: Tensor, + hidden_states2: Optional[Tensor] = None, + residual: Optional[Tensor] = None, + mixer_kwargs=None, + ): + r"""Pass the input through the encoder layer. + + Args: + hidden_states1: the output of the previous attention (mixer) or embedding layer. + hidden_states2: the output of the previous MLP layer (if None, will use hidden_states1). + residual. + """ + # TODO: Ideally we should only do the allgather / allreduce once for + # the Linear to MLP & Attention + if not self.fused_dropout_add_ln: + dropped1 = self.dropout1(hidden_states1) + # For the very 1st block, we only want 1 dropout, not two different dropouts + if hidden_states2 is not None: + dropped2 = self.dropout2(hidden_states2) + residual = ( + (residual + dropped1 + dropped2) + if residual is not None + else dropped1 + dropped2 + ) + else: + residual = (residual + dropped1) if residual is not None else dropped1 + hidden_states1 = self.norm1(residual.to(dtype=self.norm1.weight.dtype)) + hidden_states2 = ( + self.norm2(residual.to(dtype=self.norm2.weight.dtype)) + if not self.tied_norm + else hidden_states1 + ) + if self.residual_in_fp32: + residual = residual.to(torch.float32) + else: + weight2, bias2 = ( + (self.norm2.weight, self.norm2.bias) if not self.tied_norm else (None, None) + ) + hidden_states1, *rest, residual = layer_norm_fn( + hidden_states1, + self.norm1.weight, + self.norm1.bias, + residual=residual, + x1=hidden_states2, + weight1=weight2, + bias1=bias2, + eps=self.norm1.eps, + dropout_p=self.dropout1.p if self.training else 0.0, + prenorm=True, + residual_in_fp32=self.residual_in_fp32, + 
is_rms_norm=isinstance(self.norm1, RMSNorm) + ) + if self.tied_norm: + hidden_states2 = hidden_states1 + else: + hidden_states2, = rest + if mixer_kwargs is None: + mixer_kwargs = {} + hidden_states1 = self.mixer(hidden_states1, **mixer_kwargs) + hidden_states2 = self.mlp(hidden_states2) + return hidden_states1, hidden_states2, residual diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/embedding.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..33587d09413dbab5edccfa3806fca829a6f9f9da --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/embedding.py @@ -0,0 +1,216 @@ +# Copyright (c) 2022, Tri Dao. + +import torch +import torch.nn as nn +from einops import rearrange +from torch import Tensor + +from flash_attn.utils.distributed import all_reduce, reduce_scatter + + +class GPT2Embeddings(nn.Module): + def __init__( + self, + embed_dim, + vocab_size, + max_position_embeddings, + padding_idx=None, + word_embed_proj_dim=None, + device=None, + dtype=None, + ): + """ + If max_position_embeddings <= 0, there's no position embeddings + If word_embe_proj_dim is not None (e.g., OPT-350m), we embed to that dimension + the project up to embed_dim + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + if word_embed_proj_dim is None: + self.word_embeddings = nn.Embedding( + vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs + ) + self.project_in = None + else: + self.word_embeddings = nn.Embedding( + vocab_size, word_embed_proj_dim, padding_idx=padding_idx, **factory_kwargs + ) + self.project_in = nn.Linear( + word_embed_proj_dim, embed_dim, bias=False, **factory_kwargs + ) + self.max_position_embeddings = max_position_embeddings + if self.max_position_embeddings > 0: + self.position_embeddings = nn.Embedding( + max_position_embeddings, embed_dim, **factory_kwargs + ) + + def 
forward(self, input_ids, position_ids=None): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + embeddings = self.word_embeddings(input_ids) + if self.project_in is not None: + embeddings = self.project_in(embeddings) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + return embeddings + + +class BertEmbeddings(nn.Module): + def __init__( + self, + embed_dim, + vocab_size, + max_position_embeddings, + type_vocab_size, + padding_idx=None, + device=None, + dtype=None, + ): + """ + If max_position_embeddings <= 0, there's no position embeddings + If type_vocab_size <= 0, there's no token type embeddings + """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.word_embeddings = nn.Embedding( + vocab_size, embed_dim, padding_idx=padding_idx, **factory_kwargs + ) + self.max_position_embeddings = max_position_embeddings + self.type_vocab_size = type_vocab_size + if self.max_position_embeddings > 0: + self.position_embeddings = nn.Embedding( + max_position_embeddings, embed_dim, **factory_kwargs + ) + if self.type_vocab_size > 0: + self.token_type_embeddings = nn.Embedding(type_vocab_size, embed_dim, **factory_kwargs) + + def forward(self, input_ids, position_ids=None, token_type_ids=None): + """ + input_ids: (batch, seqlen) + position_ids: (batch, seqlen) + token_type_ids: (batch, seqlen) + """ + batch_size, seqlen = input_ids.shape + embeddings = self.word_embeddings(input_ids) + if self.max_position_embeddings > 0: + if position_ids is None: + position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + embeddings = embeddings + position_embeddings + if self.type_vocab_size > 0: + if 
class VocabParallelEmbedding(nn.Embedding):
    """Embedding whose vocabulary is sharded across the ranks of `process_group`.

    Each rank holds num_embeddings // world_size rows; out-of-shard ids produce
    zero vectors (the shards are summed by an all-reduce elsewhere).
    """

    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        world_size = (
            torch.distributed.get_world_size(process_group) if process_group is not None else 1
        )
        if process_group is not None:
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        rank = torch.distributed.get_rank(self.process_group)
        shard_size = self.num_embeddings
        start, end = rank * shard_size, (rank + 1) * shard_size
        # Ids outside this rank's shard get looked up at row 0 and zeroed afterwards.
        out_of_shard = (input < start) | (input >= end)
        local_ids = input - start
        local_ids[out_of_shard] = 0
        out = super().forward(local_ids)
        out[out_of_shard] = 0.0
        return out
position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device) + position_embeddings = self.position_embeddings(position_ids) + if world_size <= 1: + embeddings = embeddings + position_embeddings + else: + partition_dim = self.position_embeddings.embedding_dim + rank = torch.distributed.get_rank(self.process_group) + embeddings[ + ..., rank * partition_dim : (rank + 1) * partition_dim + ] += position_embeddings + if combine_batch_seqlen_dim: + embeddings = rearrange(embeddings, "b s d -> (b s) d") + reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce + return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group) diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mha.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mha.py new file mode 100644 index 0000000000000000000000000000000000000000..77640c2b239ac729cad79ce3b2504e0eeacb5f73 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mha.py @@ -0,0 +1,1020 @@ +# Copyright (c) 2023, Tri Dao. 
# Adapted from https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
def get_alibi_slopes(nheads):
    """Return the per-head ALiBi slopes for `nheads` attention heads.

    For a power-of-two head count the slopes are the geometric sequence whose
    first term (and ratio) is 2 ** (-8 / nheads). For other head counts, the
    list is the slopes of the closest smaller power of two, extended with every
    other slope from the next power of two, as in the reference implementation.
    """

    def _pow2_slopes(n):
        # first == ratio == 2 ** (-8 / n), written as in the reference code.
        first = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [first ** (i + 1) for i in range(n)]

    if math.log2(nheads).is_integer():
        return _pow2_slopes(nheads)
    base = 2 ** math.floor(math.log2(nheads))
    extra = get_alibi_slopes(2 * base)[0::2][: nheads - base]
    return _pow2_slopes(base) + extra
class FlashSelfAttention(nn.Module):
    """Softmax self-attention backed by FlashAttention kernels.

    Arguments
    ---------
        softmax_scale: temperature for the softmax; when None the kernel uses
            1/sqrt(head_dim) computed at runtime.
        attention_dropout: dropout rate applied to the attention weights.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        window_size=(-1, -1),
        alibi_slopes=None,
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_qkvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_qkvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        # Buffer (not parameter): follows .to()/.cuda() but is never trained or saved.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(self, qkv, causal=None, cu_seqlens=None, max_seqlen=None):
        """Run FlashAttention on packed QKV.

        qkv: (B, S, 3, H, D) when cu_seqlens/max_seqlen are None, otherwise
            (total, 3, H, D) with `total` the sum of the batch's sequence lengths.
        causal: overrides self.causal when not None.
        cu_seqlens: (batch_size + 1,) int32 cumulative sequence lengths (varlen mode).
        max_seqlen: int, maximum sequence length in the batch (varlen mode).
        Returns (total, H, D) in varlen mode, else (B, S, H, D).
        """
        assert qkv.dtype in [torch.float16, torch.bfloat16]
        assert qkv.is_cuda
        if causal is None:
            causal = self.causal
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        dropout_p = self.drop.p if self.training else 0.0
        common = dict(
            softmax_scale=self.softmax_scale,
            causal=causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.window_size,
            deterministic=self.deterministic,
        )
        if cu_seqlens is None:
            return flash_attn_qkvpacked_func(qkv, dropout_p, **common)
        assert cu_seqlens.dtype == torch.int32
        assert max_seqlen is not None
        assert isinstance(max_seqlen, int)
        return flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen, dropout_p, **common)
class FlashCrossAttention(nn.Module):
    """Softmax cross-attention backed by FlashAttention kernels.

    Arguments
    ---------
        softmax_scale: temperature for the softmax; when None the kernel uses
            1/sqrt(head_dim) computed at runtime.
        attention_dropout: dropout rate applied to the attention weights.
    """

    def __init__(
        self,
        causal=False,
        softmax_scale=None,
        attention_dropout=0.0,
        alibi_slopes=None,
        window_size=(-1, -1),
        deterministic=False,
    ):
        super().__init__()
        assert flash_attn_varlen_kvpacked_func is not None, "FlashAttention is not installed"
        assert flash_attn_kvpacked_func is not None, "FlashAttention is not installed"
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)
        # Buffer (not parameter): follows .to()/.cuda() but is never trained or saved.
        self.register_buffer("alibi_slopes", alibi_slopes, persistent=False)
        self.window_size = window_size
        self.deterministic = deterministic

    def forward(
        self,
        q,
        kv,
        causal=None,
        cu_seqlens=None,
        max_seqlen=None,
        cu_seqlens_k=None,
        max_seqlen_k=None,
    ):
        """Run FlashAttention with separate query and packed key/value.

        q: (B, Sq, H, D); kv: (B, Sk, 2, H_k, D).
        causal: overrides self.causal when not None.
        cu_seqlens/max_seqlen: int32 cumulative lengths / max length for q (varlen mode).
        cu_seqlens_k/max_seqlen_k: same for k and v (varlen mode).
        """
        assert q.dtype in [torch.float16, torch.bfloat16]
        assert q.is_cuda and kv.is_cuda
        if causal is None:
            causal = self.causal
        if self.alibi_slopes is not None:
            self.alibi_slopes = self.alibi_slopes.to(torch.float32)
        dropout_p = self.drop.p if self.training else 0.0
        common = dict(
            softmax_scale=self.softmax_scale,
            causal=causal,
            alibi_slopes=self.alibi_slopes,
            window_size=self.window_size,
            deterministic=self.deterministic,
        )
        if cu_seqlens is not None:
            assert cu_seqlens.dtype == torch.int32
            assert max_seqlen is not None
            assert isinstance(max_seqlen, int)
            assert cu_seqlens_k is not None
            assert cu_seqlens_k.dtype == torch.int32
            assert max_seqlen_k is not None
            assert isinstance(max_seqlen_k, int)
            return flash_attn_varlen_kvpacked_func(
                q,
                kv,
                cu_seqlens,
                cu_seqlens_k,
                max_seqlen,
                max_seqlen_k,
                dropout_p,
                **common,
            )
        assert kv.shape[0] == q.shape[0] and kv.shape[4] == q.shape[3]
        return flash_attn_kvpacked_func(q, kv, dropout_p, **common)
+ """ + assert q.dtype in [torch.float16, torch.bfloat16] + assert q.is_cuda and kv.is_cuda + causal = self.causal if causal is None else causal + unpadded = cu_seqlens is not None + if self.alibi_slopes is not None: + self.alibi_slopes = self.alibi_slopes.to(torch.float32) + if unpadded: + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + assert cu_seqlens_k is not None + assert cu_seqlens_k.dtype == torch.int32 + assert max_seqlen_k is not None + assert isinstance(max_seqlen_k, int) + return flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + cu_seqlens_k, + max_seqlen, + max_seqlen_k, + self.drop.p if self.training else 0.0, + softmax_scale=self.softmax_scale, + causal=causal, + alibi_slopes=self.alibi_slopes, + window_size=self.window_size, + deterministic=self.deterministic, + ) + else: + batch_size, seqlen_q = q.shape[0], q.shape[1] + seqlen_k = kv.shape[1] + assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3] + return flash_attn_kvpacked_func( + q, + kv, + self.drop.p if self.training else 0.0, + causal=causal, + softmax_scale=self.softmax_scale, + alibi_slopes=self.alibi_slopes, + window_size=self.window_size, + deterministic=self.deterministic, + ) + + +class SelfAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. + (default: 1/sqrt(d_keys) where d_keys is computed at + runtime) + attention_dropout: The dropout rate to apply to the attention + (default: 0.0) + """ + + def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0): + super().__init__() + self.causal = causal + self.softmax_scale = softmax_scale + self.drop = nn.Dropout(attention_dropout) + + def forward(self, qkv, causal=None, key_padding_mask=None): + """Implements the multihead softmax attention. + Arguments + --------- + qkv: The tensor containing the query, key, and value. 
(B, S, 3, H, D) + causal: if passed, will override self.causal + key_padding_mask: boolean mask to apply to the attention weights. True means to keep, + False means to mask out. (B, S) + """ + batch_size, seqlen = qkv.shape[0], qkv.shape[1] + causal = self.causal if causal is None else causal + q, k, v = qkv.unbind(dim=2) + softmax_scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1]) + scores = torch.einsum("bthd,bshd->bhts", q, k * softmax_scale) + if key_padding_mask is not None: + padding_mask = torch.full( + (batch_size, seqlen), -10000.0, dtype=scores.dtype, device=scores.device + ) + padding_mask.masked_fill_(key_padding_mask, 0.0) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + rearrange(padding_mask, "b s -> b 1 1 s") + if causal: + # "triu_tril_cuda_template" not implemented for 'BFloat16' + # So we have to construct the mask in float + causal_mask = torch.triu( + torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1 + ) + # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess) + scores = scores + causal_mask.to(dtype=scores.dtype) + attention = torch.softmax(scores, dim=-1, dtype=v.dtype) + attention_drop = self.drop(attention) + output = torch.einsum("bhts,bshd->bthd", attention_drop, v) + return output + + +class CrossAttention(nn.Module): + """Implement the scaled dot product attention with softmax. + Arguments + --------- + softmax_scale: The temperature to use for the softmax attention. 
class CrossAttention(nn.Module):
    """Pure-PyTorch scaled dot-product cross-attention with softmax.

    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
            (default: 1/sqrt(d_keys) where d_keys is computed at runtime)
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.drop = nn.Dropout(attention_dropout)

    def forward(self, q, kv, causal=None, key_padding_mask=None):
        """Apply multihead softmax attention.

        q: (B, Sq, H, D); kv: (B, Sk, 2, H_k, D).
        causal: overrides self.causal when not None.
        key_padding_mask: (B, Sk) boolean; True = keep, False = mask out.
        """
        batch_size, seqlen_q = q.shape[0], q.shape[1]
        if causal is None:
            causal = self.causal
        seqlen_k = kv.shape[1]
        assert kv.shape[0] == batch_size and kv.shape[4] == q.shape[3]
        if kv.shape[3] != q.shape[2]:  # MQA/GQA: replicate each kv head to match q heads
            kv = kv.repeat_interleave(q.shape[2] // kv.shape[3], dim=3)
        k, v = kv.unbind(dim=2)
        scale = self.softmax_scale or 1.0 / math.sqrt(q.shape[-1])
        scores = torch.einsum("bthd,bshd->bhts", q, k * scale)
        if key_padding_mask is not None:
            # Additive bias: 0 where kept, -10000 where padded.
            # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
            pad_bias = torch.full(
                (batch_size, seqlen_k), -10000.0, dtype=scores.dtype, device=scores.device
            )
            pad_bias.masked_fill_(key_padding_mask, 0.0)
            scores = scores + pad_bias[:, None, None, :]
        if causal:
            # The causal diagonal is aligned with the *end* of the key sequence,
            # taking into account the difference between seqlen_q and seqlen_k.
            row_idx = torch.arange(seqlen_q, device=q.device, dtype=torch.long)[:, None]
            col_idx = torch.arange(seqlen_k, device=kv.device, dtype=torch.long)
            sk = (
                seqlen_k
                if key_padding_mask is None
                else key_padding_mask.sum(-1)[:, None, None, None]
            )
            causal_mask = col_idx > row_idx + sk - seqlen_q
            scores = scores.masked_fill(causal_mask, -10000.0)
        attn_weights = torch.softmax(scores, dim=-1, dtype=v.dtype)
        return torch.einsum("bhts,bshd->bthd", self.drop(attn_weights), v)
+ num_heads, head_dim = kv.shape[-2:] + if layer_idx not in inference_params.key_value_memory_dict: + kv_cache = torch.empty( + inference_params.max_batch_size, + inference_params.max_seqlen, + 2, + num_heads, + head_dim, + dtype=kv.dtype, + device=kv.device, + ) + inference_params.key_value_memory_dict[layer_idx] = kv_cache + else: + kv_cache = inference_params.key_value_memory_dict[layer_idx] + # Adjust key and value for inference + batch_start = inference_params.batch_size_offset + batch_end = batch_start + kv.shape[0] + sequence_start = inference_params.seqlen_offset + sequence_end = sequence_start + kv.shape[1] + assert batch_end <= kv_cache.shape[0] + assert sequence_end <= kv_cache.shape[1] + assert kv_cache is not None + kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv + return kv_cache[batch_start:batch_end, :sequence_end, ...] + + +class MHA(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + num_heads_kv=None, + cross_attn=False, + qkv_proj_bias=True, + out_proj_bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + dwconv=False, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved=False, + use_alibi=False, + window_size=(-1, -1), + fused_bias_fc=False, + use_flash_attn=False, + return_residual=False, + checkpointing=False, + device=None, + dtype=None, + ) -> None: + """ + num_heads_kv: can be used to toggle MQA / GQA. If None, use num_heads. + return_residual: whether to return the input x along with the output. This is for + performance reason: for post-norm architecture, returning the input allows us + to fuse the backward of nn.Linear with the residual connection. 
+ """ + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.cross_attn = cross_attn + self.causal = causal + self.layer_idx = layer_idx + self.dwconv = dwconv + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.return_residual = return_residual + self.checkpointing = checkpointing + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + alibi_slopes = torch.tensor(get_alibi_slopes(num_heads), device=device) + else: + alibi_slopes = None + if window_size != (-1, -1): + assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" + + self.num_heads = num_heads + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads" + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + kv_dim = 2 * self.head_dim * self.num_heads_kv + + if self.rotary_emb_dim > 0: + assert not cross_attn, "MHA with rotary embedding does not support cross-attention yet" + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + if fused_bias_fc and FusedDense is None: + raise ImportError("fused_dense is not installed") + linear_cls = nn.Linear if not fused_bias_fc else FusedDense + linear_resid_cls = ( + LinearResidual if not fused_bias_fc else partial(FusedDense, return_residual=True) + ) + wqkv_cls = linear_cls if not self.return_residual else linear_resid_cls + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else SelfAttention + ) 
+ inner_cross_attn_cls = ( + partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else CrossAttention + ) + if not self.cross_attn: + self.Wqkv = wqkv_cls(embed_dim, qkv_dim, bias=qkv_proj_bias, **factory_kwargs) + else: + self.Wq = linear_cls(embed_dim, embed_dim, bias=qkv_proj_bias, **factory_kwargs) + self.Wkv = wqkv_cls(embed_dim, kv_dim, bias=qkv_proj_bias, **factory_kwargs) + if self.dwconv: + if self.num_heads_kv == self.num_heads: + self.dwconv_qkv = nn.Conv1d( + qkv_dim, qkv_dim, kernel_size=3, padding=2, groups=qkv_dim + ) + else: + self.dwconv_q = nn.Conv1d( + embed_dim, embed_dim, kernel_size=3, padding=2, groups=embed_dim + ) + self.dwconv_kv = nn.Conv1d(kv_dim, kv_dim, kernel_size=3, padding=2, groups=kv_dim) + self.inner_attn = inner_attn_cls( + causal=causal, + softmax_scale=softmax_scale, + attention_dropout=dropout, + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.out_proj = linear_cls(embed_dim, embed_dim, bias=out_proj_bias, **factory_kwargs) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + return torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv, + self.head_dim, + dtype=dtype, + device=device, + ) + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + assert not self.dwconv, "Generation does not support dwconv yet" + assert self.layer_idx is not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): + """ + Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. 
+ q: (batch_size, seqlen_q, nheads, head_dim) + kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) + """ + assert inference_params is not None and inference_params.seqlen_offset > 0 + assert self.use_flash_attn + if self.rotary_emb_dim > 0: + assert self.rotary_emb.scale is None, "This code path does not support xPos" + self.rotary_emb._update_cos_sin_cache( + inference_params.max_seqlen, device=q.device, dtype=q.dtype + ) + rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached + else: + rotary_cos, rotary_sin = None, None + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + alibi_slopes=alibi_slopes, + ) + return context + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then do attention""" + if ( + inference_params.seqlen_offset == 0 + or flash_attn_with_kvcache is None + or not self.use_flash_attn + ): + # TODO: this only uses seqlen_offset and not lengths_per_sample. 
+ kv = self._update_kv_cache(kv, inference_params) + return self.inner_cross_attn(q, kv) + else: + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + return flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + alibi_slopes=alibi_slopes, + ) + + def forward( + self, + x, + x_kv=None, + key_padding_mask=None, + cu_seqlens=None, + max_seqlen=None, + mixer_subset=None, + inference_params=None, + **kwargs, + ): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if + cu_seqlens is None and max_seqlen is None, else (total, hidden_dim) where total + is the is the sum of the sequence lengths in the batch. + x_kv: (batch, seqlen, hidden_dim), only applicable for cross-attention. If None, use x. + cu_seqlens: (batch_size + 1,), dtype torch.int32. The cumulative sequence lengths + of the sequences in the batch, used to index into x. Only applicable when using + FlashAttention. + max_seqlen: int. Maximum sequence length in the batch. + key_padding_mask: boolean mask, True means to keep, False means to mask out. + (batch, seqlen). Only applicable when not using FlashAttention. + mixer_subset: for cross-attention only. If not None, will take a subset of x + before applying the query projection. Useful for e.g., ViT where we only care + about the CLS token in the last layer. + inference_params: for generation. 
Adapted from Megatron-LM (and Apex) + https://github.com/NVIDIA/apex/blob/3ff1a10f72ec07067c4e44759442329804ac5162/apex/transformer/testing/standalone_transformer_lm.py#L470 + """ + if cu_seqlens is not None: + assert max_seqlen is not None + assert key_padding_mask is None + assert self.use_flash_attn + assert not self.dwconv + assert self.rotary_emb_dim == 0 + if key_padding_mask is not None: + assert cu_seqlens is None + assert max_seqlen is None + assert not self.use_flash_attn + if inference_params is not None: + assert key_padding_mask is None + assert cu_seqlens is None and max_seqlen is None + assert not self.dwconv + + kwargs = ( + {"cu_seqlens": cu_seqlens, "max_seqlen": max_seqlen, **kwargs} + if self.use_flash_attn + else {"key_padding_mask": key_padding_mask, **kwargs} + ) + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + batch, seqlen = x.shape[:2] + if not self.cross_attn and self.num_heads_kv == self.num_heads: + assert x_kv is None and mixer_subset is None + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + if self.dwconv: + qkv = rearrange( + self.dwconv_qkv(rearrange(qkv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" + ).contiguous() + qkv = rearrange(qkv, "... (three h d) -> ... 
three h d", three=3, d=self.head_dim) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + else: + context = self._update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: + context = self._apply_rotary_update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: + if self.cross_attn: + if not self.return_residual: + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + kv = self.Wkv(x_kv if x_kv is not None else x) + else: + if x_kv is not None: + kv, x_kv = self.Wkv(x_kv) + else: + kv, x = self.Wkv(x) + q = self.Wq(x if mixer_subset is None else x[:, mixer_subset]) + else: + assert self.num_heads_kv != self.num_heads + if not self.return_residual: + qkv = self.Wqkv(x) + else: + qkv, x = self.Wqkv(x) + q = qkv[..., : self.num_heads * self.head_dim] + kv = qkv[..., self.num_heads * self.head_dim :] + q = rearrange(q, "... (h d) -> ... h d", d=self.head_dim) + kv = rearrange(kv, "... (two hkv d) -> ... 
two hkv d", two=2, d=self.head_dim) + if self.dwconv: + q = rearrange( + self.dwconv_q(rearrange(q, "b s d -> b d s"))[..., :-2], "b d s -> b s d" + ).contiguous() + kv = rearrange( + self.dwconv_kv(rearrange(kv, "b s d -> b d s"))[..., :-2], "b d s -> b s d" + ).contiguous() + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_cross_attn, q, kv, **kwargs + ) + else: + context = self._update_kvcache_attention(q, kv, inference_params) + else: + context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) + out = self.out_proj(rearrange(context, "... h d -> ... 
(h d)")) + return out if not self.return_residual else (out, x) + + +class ParallelMHA(nn.Module): + """Multi-head self-attention and cross-attention""" + + def __init__( + self, + embed_dim, + num_heads, + process_group, + num_heads_kv=None, + qkv_proj_bias=True, + out_proj_bias=True, + dropout=0.0, + softmax_scale=None, + causal=False, + layer_idx=None, + rotary_emb_dim=0, + rotary_emb_base=10000.0, + rotary_emb_scale_base=None, + rotary_emb_interleaved=False, + use_alibi=False, + window_size=(-1, -1), + use_flash_attn=False, + checkpointing=False, + sequence_parallel=True, + device=None, + dtype=None, + ) -> None: + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.embed_dim = embed_dim + self.causal = causal + self.layer_idx = layer_idx + self.rotary_emb_dim = rotary_emb_dim + self.use_flash_attn = use_flash_attn + self.checkpointing = checkpointing + self.process_group = process_group + self.world_size = process_group.size() + self.local_rank = torch.distributed.get_rank(process_group) + + self.num_heads = num_heads + assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads" + + self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads + assert ( + self.num_heads % self.num_heads_kv == 0 + ), "num_heads must be divisible by num_heads_kv" + + self.num_heads_per_rank = get_dim_for_local_rank( + self.num_heads, self.world_size, self.local_rank + ) + self.num_heads_kv_per_rank = get_dim_for_local_rank( + self.num_heads_kv, self.world_size, self.local_rank + ) + self.head_dim = self.embed_dim // num_heads + qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv) + + if use_alibi: + assert use_flash_attn, "ALiBi code path requires flash_attn" + num_heads_local = math.ceil(self.num_heads / self.world_size) + alibi_slopes = torch.tensor( + get_alibi_slopes(num_heads)[ + self.local_rank * num_heads_local : (self.local_rank + 1) * num_heads_local + ], + device=device, + ) + else: + 
alibi_slopes = None + if window_size != (-1, -1): + assert use_flash_attn, "Local (sliding window) attention code path requires flash_attn" + + if self.rotary_emb_dim > 0: + assert RotaryEmbedding is not None, "rotary_emb is not installed" + self.rotary_emb = RotaryEmbedding( + self.rotary_emb_dim, + base=rotary_emb_base, + scale_base=rotary_emb_scale_base, + interleaved=rotary_emb_interleaved, + device=device, + ) + + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.Wqkv = ColumnParallelLinear( + embed_dim, + qkv_dim, + process_group, + bias=qkv_proj_bias, + sequence_parallel=sequence_parallel, + multiple_of=self.head_dim * (self.num_heads // self.num_heads_kv + 2), + **factory_kwargs, + ) + inner_attn_cls = ( + partial(FlashSelfAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else SelfAttention + ) + inner_cross_attn_cls = ( + partial(FlashCrossAttention, alibi_slopes=alibi_slopes, window_size=window_size) + if use_flash_attn + else CrossAttention + ) + self.inner_attn = inner_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.inner_cross_attn = inner_cross_attn_cls( + causal=causal, softmax_scale=softmax_scale, attention_dropout=dropout + ) + self.out_proj = RowParallelLinear( + embed_dim, + embed_dim, + process_group, + bias=out_proj_bias, + sequence_parallel=sequence_parallel, + multiple_of=self.head_dim, + **factory_kwargs, + ) + + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None): + dtype = self.out_proj.weight.dtype if dtype is None else dtype + device = self.out_proj.weight.device + return torch.empty( + batch_size, + max_seqlen, + 2, + self.num_heads_kv_per_rank, + self.head_dim, + dtype=dtype, + device=device, + ) + + def _update_kv_cache(self, kv, inference_params): + """kv: (batch_size, seqlen, 2, nheads, head_dim) or (batch_size, 1, 2, nheads, head_dim)""" + assert self.layer_idx is 
not None, "Generation requires layer_idx in the constructor" + return _update_kv_cache(kv, inference_params, self.layer_idx) + + def _apply_rotary_update_kvcache_attention(self, q, kv, inference_params): + """ + Fast path that combine 3 steps: apply rotary to Q and K, update kv cache, and apply attention. + q: (batch_size, seqlen_q, nheads, head_dim) + kv: (batch_size, seqlen_k, 2, nheads_kv, head_dim) + """ + assert inference_params is not None and inference_params.seqlen_offset > 0 + assert self.use_flash_attn + if self.rotary_emb_dim > 0: + assert self.rotary_emb.scale is None, "This code path does not support xPos" + self.rotary_emb._update_cos_sin_cache( + inference_params.max_seqlen, device=q.device, dtype=q.dtype + ) + rotary_cos, rotary_sin = self.rotary_emb._cos_cached, self.rotary_emb._sin_cached + else: + rotary_cos, rotary_sin = None, None + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + rotary_cos=rotary_cos, + rotary_sin=rotary_sin, + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + rotary_interleaved=self.rotary_emb.interleaved if self.rotary_emb_dim > 0 else False, + alibi_slopes=alibi_slopes, + ) + return context + + def _update_kvcache_attention(self, q, kv, inference_params): + """Write kv to inference_params, then do attention""" + if inference_params.seqlen_offset == 0 or not self.use_flash_attn: + # TODO: this only uses seqlen_offset and not lengths_per_sample. 
+ kv = self._update_kv_cache(kv, inference_params) + return self.inner_cross_attn(q, kv) + else: + batch = q.shape[0] + kv_cache = inference_params.key_value_memory_dict[self.layer_idx][:batch] + cache_seqlens = ( + inference_params.lengths_per_sample[:batch] + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + alibi_slopes = getattr(self.inner_cross_attn, "alibi_slopes", None) + context = flash_attn_with_kvcache( + q, + kv_cache[:, :, 0], + kv_cache[:, :, 1], + kv[:, :, 0], + kv[:, :, 1], + cache_seqlens=cache_seqlens, + softmax_scale=self.inner_cross_attn.softmax_scale, + causal=self.inner_cross_attn.causal, + alibi_slopes=alibi_slopes, + ) + return context + + def forward(self, x, seqlen=None, inference_params=None, **kwargs): + """ + Arguments: + x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None. + If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we + split x during sequence parallel, we split the batch * seqlen dimension + (in case batch is small). + """ + qkv = self.Wqkv(x) + if seqlen is not None: + qkv = rearrange(qkv, "(b s) ... 
-> b s ...", s=seqlen) + seqlen_offset = ( + 0 + if inference_params is None + else ( + inference_params.lengths_per_sample + if inference_params.lengths_per_sample is not None + else inference_params.seqlen_offset + ) + ) + rotary_max_seqlen = inference_params.max_seqlen if inference_params is not None else None + if self.num_heads_kv == self.num_heads: + qkv = rearrange(qkv, "b s (three h d) -> b s three h d", three=3, d=self.head_dim) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + if self.rotary_emb_dim > 0: + qkv = self.rotary_emb( + qkv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_attn(qkv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs) + else: + context = self._update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: + context = self._apply_rotary_update_kvcache_attention( + qkv[:, :, 0], qkv[:, :, 1:], inference_params + ) + else: + q = rearrange( + qkv[..., : self.num_heads_per_rank * self.head_dim], + "... (h d) -> ... h d", + d=self.head_dim, + ) + kv = rearrange( + qkv[..., self.num_heads_per_rank * self.head_dim :], + "... (two hkv d) -> ... 
two hkv d", + two=2, + d=self.head_dim, + ) + if ( + inference_params is None + or inference_params.seqlen_offset == 0 + or (self.rotary_emb_dim == 0 or self.rotary_emb_dim % 16 != 0) + or not self.use_flash_attn + ): + if self.rotary_emb_dim > 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=seqlen_offset, max_seqlen=rotary_max_seqlen + ) + if inference_params is None: + if not self.checkpointing: + context = self.inner_cross_attn(q, kv, **kwargs) + else: + context = torch.utils.checkpoint.checkpoint( + self.inner_cross_attn, q, kv, **kwargs + ) + else: + context = self._update_kvcache_attention(q, kv, inference_params) + else: + context = self._apply_rotary_update_kvcache_attention(q, kv, inference_params) + context = rearrange(context, "b s h d -> b s (h d)") + if seqlen is not None: + context = rearrange(context, "b s d -> (b s) d") + out = self.out_proj(context) + return out diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mlp.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..23584d3098a245bcb2b09653bdf0f426eb63bde2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/modules/mlp.py @@ -0,0 +1,191 @@ +# Copyright (c) 2023, Tri Dao. 
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.distributed import ProcessGroup + + +try: + from flash_attn.ops.activations import swiglu +except ImportError: + swiglu = None + +try: + from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear +except ImportError: + ColumnParallelLinear, RowParallelLinear = None, None + +try: + from flash_attn.ops.fused_dense import FusedMLP, ParallelFusedMLP +except ImportError: + FusedMLP, ParallelFusedMLP = None, None + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.gelu, + bias1=True, + bias2=True, + return_residual=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = hidden_features if hidden_features is not None else in_features * 4 + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias1, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y if not self.return_residual else (y, x) + + +class ParallelMLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.gelu, + process_group: ProcessGroup = None, + sequence_parallel=True, + bias1=True, + bias2=True, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + assert ColumnParallelLinear is not None, "Need to install fused_dense" + assert RowParallelLinear is not None, "Need to install fused_dense" + out_features = out_features if out_features is not None else in_features + hidden_features = hidden_features if hidden_features is not None else in_features * 4 + 
self.fc1 = ColumnParallelLinear( + in_features, + hidden_features, + process_group, + bias=bias1, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.activation = activation + self.fc2 = RowParallelLinear( + hidden_features, + out_features, + process_group, + bias=bias2, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + def forward(self, x): + y = self.fc1(x) + y = self.activation(y) + y = self.fc2(y) + return y + + +class GatedMlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + activation=F.sigmoid, + bias1=True, + bias2=True, + multiple_of=128, + return_residual=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = ( + hidden_features if hidden_features is not None else int(8 * in_features / 3) + ) + hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of + self.return_residual = return_residual + self.fc1 = nn.Linear(in_features, 2 * hidden_features, bias=bias1, **factory_kwargs) + self.activation = activation + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias2, **factory_kwargs) + + def forward(self, x): + y = self.fc1(x) + if self.activation == F.sigmoid: # Special case for GLU + y = F.glu(y, dim=-1) + elif self.activation == F.silu and swiglu is not None: # Special case for SwiGLU + y, gate = y.chunk(2, dim=-1) + y = swiglu(gate, y) + else: + y, gate = y.chunk(2, dim=-1) + y = y * self.activation(gate) + y = self.fc2(y) + return y if not self.return_residual else (y, x) + + +class ParallelGatedMlp(nn.Module): + """Parallel GatedMlp""" + + def __init__( + self, + in_features, + process_group, + hidden_features=None, + out_features=None, + activation=F.sigmoid, + bias1=True, + bias2=True, + multiple_of=128, + sequence_parallel=True, + device=None, + dtype=None, + ): + factory_kwargs = 
{"device": device, "dtype": dtype} + super().__init__() + out_features = out_features if out_features is not None else in_features + hidden_features = ( + hidden_features if hidden_features is not None else int(8 * in_features / 3) + ) + hidden_features = (hidden_features + multiple_of - 1) // multiple_of * multiple_of + if ColumnParallelLinear is None or RowParallelLinear is None: + raise ImportError("fused_dense is not installed") + self.fc1 = ColumnParallelLinear( + in_features, + 2 * hidden_features, + process_group, + bias=bias1, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + self.activation = activation + self.fc2 = RowParallelLinear( + hidden_features, + out_features, + process_group, + bias=bias2, + sequence_parallel=sequence_parallel, + **factory_kwargs, + ) + + def forward(self, x): + y = self.fc1(x) + if self.activation == F.sigmoid: # Special case for GLU + y = F.glu(y, dim=-1) + else: + y, gate = y.chunk(2, dim=-1) + y = y * self.activation(gate) + y = self.fc2(y) + return y diff --git a/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/layer_norm.py b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..4b6cd798fd02844ef9cd3897f8ab95e490e638bf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/_flash_attn/ops/layer_norm.py @@ -0,0 +1,800 @@ +# Copyright (c) 2022, Tri Dao. 
+# Adapted from https://github.com/NVIDIA/apex/blob/master/apex/contrib/layer_norm/layer_norm.py + +import dropout_layer_norm +import torch +from torch.nn import init + + +def maybe_align(x, alignment_in_bytes=16): + """Assume that x already has last dim divisible by alignment_in_bytes""" + # TD [2023-07-04] I'm not 100% sure that clone will align the memory + # https://discuss.pytorch.org/t/how-to-ensure-that-tensor-data-ptr-is-aligned-to-16-bytes/183440 + return x if x.data_ptr() % alignment_in_bytes == 0 else x.clone() + + +def _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + rowscale, + colscale, + None, + None, + dropout_p, + epsilon, + 1.0, + 0, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. 
+ """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(xmat.shape) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + rowscale = rowscale.view(-1) if rowscale is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + None, + None, + dropout_p, + 1.0, + 0, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma.numel() + x0mat = x0.view((-1, hidden_size)) + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + zmat, xmat, dmask, mu, rsigma = dropout_layer_norm.dropout_add_ln_fwd( + x0mat, + residualmat, + gamma, + beta, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask is None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return zmat, xmat if xmat is not None else x0mat, dmask, mu, rsigma + + +def _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + 
rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + x0 must not be None if we have colscale. + """ + hidden_size = gamma.numel() + xmat = x.view((-1, hidden_size)) + dzmat = dz.view(-1, hidden_size) + dxmat = dx.view(xmat.shape) if dx is not None else None + x0mat = x0.view((-1, hidden_size)) if x0 is not None else None + x0_subset = x0_subset.view(-1) if x0_subset is not None else None + out_subset = out_subset.view(-1) if out_subset is not None else None + if colscale is not None: + assert x0 is not None, "x0 is required to compute the gradient of colscale" + dx0mat, dresidualmat, dgamma, dbeta, _, _, *rest = dropout_layer_norm.dropout_add_ln_bwd( + dzmat, + dxmat, + xmat, + x0mat, + dmask, + mu, + rsigma, + gamma, + None, + colscale, + x0_subset, + out_subset, + dropout_p, + rowscale_const, + x0_numrows, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + if colscale is None: + return dx0mat, dresidualmat, dgamma, dbeta + else: + dcolscale = rest[0] + return dx0mat, dresidualmat, dgamma, dbeta, dcolscale + + +def _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes""" + hidden_size = gamma0.numel() + x0mat = x0.view((-1, hidden_size)) + x1mat = x1.view((-1, hidden_size)) if x1 is not None else None + residualmat = residual.view((-1, hidden_size)) if residual is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_fwd( + x0mat, + x1mat, + residualmat, + gamma0, + beta0, + gamma1, + 
beta1, + dropout_p, + epsilon, + None, + residual_in_fp32, + is_rms_norm, + ) + # dmask0 and dmask1 are None if dropout_p == 0.0 + # xmat is None if dropout_p == 0.0 and residual is None and residual_dtype != input_dtype + return z0mat, z1mat, xmat if xmat is not None else x0mat, dmask0, dmask1, mu, rsigma + + +def _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm=False, +): + """Assume that arguments are contiguous and aligned to 16 bytes + dx == None means that it was a post-norm architecture + (x = drop(x0) + residual was not returned in the fwd). + """ + hidden_size = gamma0.numel() + xmat = x.view((-1, hidden_size)) + dz0mat = dz0.view(xmat.shape) + dz1mat = dz1.view(xmat.shape) if dz1 is not None else None + dxmat = dx.view(xmat.shape) if dx is not None else None + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + *rest, + ) = dropout_layer_norm.dropout_add_ln_parallel_residual_bwd( + dz0mat, + dz1mat, + dxmat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + is_rms_norm, + ) + # dresidualmat is None if not has_residual + return dx0mat, dx1mat, dresidualmat, dgamma0, dbeta0, dgamma1, dbeta1 + + +class DropoutAddLayerNormFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + rowscale = maybe_align(rowscale.contiguous(), 16) if rowscale is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if 
colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_forward( + x0, + residual, + gamma, + beta, + rowscale, + colscale, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + ctx.save_for_backward( + xmat.view(x0.shape), x0_saved, dmask, gamma, mu, rsigma, rowscale, colscale + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + if not return_dmask: + return ( + zmat.view(x0.shape) if not prenorm else (zmat.view(x0.shape), xmat.view(x0.shape)) + ) + else: + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return ( + (zmat.view(x0.shape), dmask) + if not prenorm + else (zmat.view(x0.shape), xmat.view(x0.shape), dmask) + ) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! 
+ dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, rowscale, colscale = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + rowscale, + colscale, + dropout_p, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + None, + dcolscale, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormSubsetFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma = maybe_align(gamma.contiguous(), 16) + beta = maybe_align(beta.contiguous(), 16) if beta is not None else None + colscale = maybe_align(colscale.contiguous(), 16) if colscale is not None else None + zmat, xmat, dmask, mu, rsigma = _dropout_add_layer_norm_subset_forward( + x0, + residual, + gamma, + beta, + colscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + is_rms_norm, + ) + # Only need to save x0 if we need to compute gradient wrt colscale + x0_saved = x0 if colscale is not None else None + x_shape = (-1, *x0.shape[1:]) + ctx.save_for_backward( + xmat.view(x_shape), x0_saved, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset + ) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + 
ctx.rowscale_const = rowscale_const + ctx.x0_numrows = x0.shape[:-1].numel() + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta is not None + z_shape = (-1, *x0.shape[1:]) + if not return_dmask: + return zmat.view(z_shape) if not prenorm else (zmat.view(z_shape), xmat.view(x0.shape)) + else: + z = zmat.view(z_shape) + dmask = ( + dmask.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask) + return (z, dmask) if not prenorm else (z, xmat.view(x_shape), dmask) + + @staticmethod + def backward(ctx, dz, *args): + # assert dz.is_contiguous() + dz = maybe_align(dz.contiguous(), 16) # this happens! + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, x0, dmask, gamma, mu, rsigma, colscale, x0_subset, out_subset = ctx.saved_tensors + # x0 is None if colscale is None + dropout_p = ctx.dropout_p + has_residual = ctx.has_residual + dx0mat, dresidualmat, dgamma, dbeta, *rest = _dropout_add_layer_norm_subset_backward( + dz, + dx, + x, + x0, + dmask, + mu, + rsigma, + gamma, + colscale, + x0_subset, + out_subset, + dropout_p, + ctx.rowscale_const, + ctx.x0_numrows, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(-1, *x.shape[1:]) + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + dcolscale = rest[0] if colscale is not None else None + return ( + dx0, + dresidual, + dgamma, + dbeta if ctx.has_beta else None, + dcolscale, + None, + None, + None, + None, + None, + None, + None, + None, + None, + None, + ) + + +class DropoutAddLayerNormParallelResidualFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32=False, + prenorm=False, + is_rms_norm=False, + return_dmask=False, + ): + x0 = maybe_align(x0.contiguous(), 16) + x1 = maybe_align(x1.contiguous(), 16) if x1 is not None 
else None + residual = maybe_align(residual.contiguous(), 16) if residual is not None else None + gamma0 = maybe_align(gamma0.contiguous(), 16) + beta0 = maybe_align(beta0.contiguous(), 16) if beta0 is not None else None + gamma1 = maybe_align(gamma1.contiguous(), 16) if gamma1 is not None else None + beta1 = maybe_align(beta1.contiguous(), 16) if beta1 is not None else None + ( + z0mat, + z1mat, + xmat, + dmask0, + dmask1, + mu, + rsigma, + ) = _dropout_add_layer_norm_parallel_residual_forward( + x0, + x1, + residual, + gamma0, + beta0, + gamma1, + beta1, + dropout_p, + epsilon, + residual_in_fp32, + is_rms_norm, + ) + ctx.save_for_backward(xmat.view(x0.shape), dmask0, dmask1, gamma0, gamma1, mu, rsigma) + ctx.prenorm = prenorm + ctx.dropout_p = dropout_p + ctx.has_x1 = x1 is not None + ctx.has_residual = residual is not None + ctx.is_rms_norm = is_rms_norm + ctx.has_beta = beta0 is not None + z = (z0mat.view(x0.shape), z1mat.view(x0.shape) if z1mat is not None else None) + if not return_dmask: + return z if not prenorm else (*z, xmat.view(x0.shape)) + else: + dmask0 = ( + dmask0.view(x0.shape) + if dropout_p > 0.0 + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + dmask1 = ( + dmask1.view(x0.shape) + if dropout_p > 0.0 and x1 is not None + else torch.ones(x0.shape, dtype=torch.uint8, device=x0.device) + ) + ctx.mark_non_differentiable(dmask0) + ctx.mark_non_differentiable(dmask1) + return ( + (*z, dmask0, dmask1) if not prenorm else (*z, xmat.view(x0.shape), dmask0, dmask1) + ) + + @staticmethod + def backward(ctx, dz0, dz1, *args): + dz0 = maybe_align(dz0.contiguous(), 16) # this happens! 
+ dz1 = maybe_align(dz1.contiguous(), 16) if dz1 is not None else None + dx = maybe_align(args[0].contiguous(), 16) if ctx.prenorm else None + x, dmask0, dmask1, gamma0, gamma1, mu, rsigma = ctx.saved_tensors + dropout_p = ctx.dropout_p + has_x1 = ctx.has_x1 + has_residual = ctx.has_residual + ( + dx0mat, + dx1mat, + dresidualmat, + dgamma0, + dbeta0, + dgamma1, + dbeta1, + ) = _dropout_add_layer_norm_parallel_residual_backward( + dz0, + dz1, + dx, + x, + dmask0, + dmask1, + mu, + rsigma, + gamma0, + gamma1, + dropout_p, + has_x1, + has_residual, + ctx.is_rms_norm, + ) + dx0 = dx0mat.view(x.shape) + dx1 = dx1mat.view(x.shape) if dx1mat is not None else None + dresidual = dresidualmat.view(x.shape) if dresidualmat is not None else None + return ( + dx0, + dx1, + dresidual, + dgamma0, + dbeta0 if ctx.has_beta else None, + dgamma1, + dbeta1 if ctx.has_beta else None, + None, + None, + None, + None, + None, + None, + ) + + +def layer_norm(x, weight, bias, epsilon): + return DropoutAddLayerNormFn.apply(x, None, weight, bias, None, None, 0.0, epsilon, False) + + +def dropout_add_layer_norm( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + rowscale=None, + layerscale=None, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormFn.apply( + x0, + residual, + weight, + bias, + rowscale, + layerscale, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_subset( + x0, + residual, + weight, + bias, + dropout_p, + epsilon, + layerscale=None, + x0_subset=None, + out_subset=None, + rowscale_const=1.0, + out_numrows=0, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. 
+ """ + return DropoutAddLayerNormSubsetFn.apply( + x0, + residual, + weight, + bias, + layerscale, + x0_subset, + out_subset, + dropout_p, + epsilon, + rowscale_const, + out_numrows, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +def dropout_add_layer_norm_parallel_residual( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + prenorm=False, + residual_in_fp32=False, + return_dropout_mask=False, +): + """residual_in_fp32 only has an effect if residual is None. + Otherwise residual dtype is residual.dtype. + """ + return DropoutAddLayerNormParallelResidualFn.apply( + x0, + x1, + residual, + weight0, + bias0, + weight1, + bias1, + dropout_p, + epsilon, + residual_in_fp32, + prenorm, + False, + return_dropout_mask, + ) + + +class DropoutAddLayerNorm(torch.nn.Module): + def __init__( + self, + hidden_size, + prenorm=False, + p=0.0, + eps=1e-5, + residual_in_fp32=False, + device=None, + dtype=None, + ): + factory_kwargs = {"device": device, "dtype": dtype} + super().__init__() + self.prenorm = prenorm + self.p = p + self.eps = eps + self.residual_in_fp32 = residual_in_fp32 + self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.bias = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs)) + self.reset_parameters() + + def reset_parameters(self): + init.ones_(self.weight) + init.zeros_(self.bias) + + def forward(self, x0, residual=None): + return dropout_add_layer_norm( + x0, + residual, + self.weight, + self.bias, + self.p if self.training else 0.0, + self.eps, + prenorm=self.prenorm, + residual_in_fp32=self.residual_in_fp32, + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__init__.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6677db08c0dcd87f7570a3249f61cc3eabe0532a --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..422168c4ed2524e41947063353ddf9544dfabfb2 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e35a0378ed7ec85d3fe8de27af5a63ba0cddbcf Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_fetch_results.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d4ac9856f99f66c678b01b900269f60c90dba527 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/batch_submit.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..5f16b05b182c49b18ef8cd3f7de6b1f2e79eed0f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_grid_search.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1fc6bd1cb2add052e09172e031865b87a34895f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_tasks.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a33a7a8307108d5cf06573d8601c4b80ed690a6e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/__pycache__/run_with_submitit.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_fetch_results.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_fetch_results.py new file mode 100644 index 0000000000000000000000000000000000000000..88227ac3312c13e8a88ae6e8ebd4cd44c265aedd --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_fetch_results.py @@ -0,0 +1,96 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import argparse +import json +import logging +from pathlib import Path +from typing import Any, Dict + +if __name__ == "__main__": + # Get the user requests + parser = argparse.ArgumentParser( + "Collect results from a given batch of distributed results" + ) + parser.add_argument("-ck", "--checkpoint_path", required=True) + args = parser.parse_args() + + logging.getLogger().setLevel(logging.INFO) + + # Go through all the data in the given repo, try to find the end results + root = Path(args.checkpoint_path) + + # - list all the mechanisms being benchmarked + results: Dict[str, Any] = {} + + for attention in filter(lambda x: x.is_dir(), root.iterdir()): + logging.info(f"\nFound results for {attention.stem}") + task_jsons = attention.glob("*/test_eval_summary.json") + results[attention.stem] = {} + + for task in task_jsons: + task_name = task.stem.split("__")[0] + logging.info(f"Logs found for task: {task_name}") + results[attention.stem][task_name] = -1 + found_result = False + + # - collect the individual results + with open(task, "r") as result_file: + dct = json.load(result_file) + if "test_accu_mean" in dct: + found_result = True + results[attention.stem][task_name] = dct["test_accu_mean"] + + logging.info( + f"Final result found for {task_name} at epoch {dct['train_step_idx']}: " + f"{results[attention.stem][task_name]}" + ) + else: + break + + # - report an error if no result was found + if not found_result: + ERR_TAIL = 30 + + logging.warning( + f"No result found for {task_name}, showing the error log in {task.parent}" + ) + err_log = Path(task.parent).glob("*.err") + print("*****************************************************") + with open(next(err_log), "r") as err_file: + for i, line in enumerate(reversed(err_file.readlines())): + print(line, end="") + if i > ERR_TAIL: + break + print("*****************************************************") + + logging.info(f"\nCollected results: {json.dumps(results, indent=2)}") + + # - reduction: compute the 
average + tasks = set(t for v in results.values() for t in v.keys()) + # -- fill in the possible gaps + for att in results.keys(): + for t in tasks: + if t not in results[att].keys(): + results[att][t] = 0.0 + + # -- add the average value + for att in results.keys(): + results[att]["AVG"] = round(sum(results[att][t] for t in tasks) / len(tasks), 2) + + # - Format as an array, markdown style + tasks_sort = sorted( + set(t for v in results.values() for t in v.keys()), reverse=True + ) + print( + "{0:<20}".format("") + "".join("{0:<20} ".format(t[:10]) for t in tasks_sort) + ) + + for att in results.keys(): + print( + "{0:<20}".format(att) + + "".join("{0:<20} ".format(results[att][t]) for t in tasks_sort) + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_submit.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_submit.py new file mode 100644 index 0000000000000000000000000000000000000000..a3077aa62a889fe28074d54ebaabe8e6a019f583 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/batch_submit.py @@ -0,0 +1,49 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import argparse +import os +from pathlib import Path + +from xformers.benchmarks.LRA.run_tasks import Task +from xformers.components.attention import ATTENTION_REGISTRY + + +def get_default_shared_folder() -> str: + checkpoint_paths = ["/checkpoint", "/checkpoints"] + for checkpoint_path in checkpoint_paths: + if Path(checkpoint_path).is_dir(): + return checkpoint_path + + return "." 
+ + +if __name__ == "__main__": + default_checkpoint_path = get_default_shared_folder() + + # Get the user requests + parser = argparse.ArgumentParser( + "Benchmark different attention mechanisms on various sequence lengths" + ) + parser.add_argument("-c", "--config_path", required=True) + parser.add_argument("-ck", "--checkpoint_path", required=True) + parser.add_argument( + "-a", "--attentions", nargs="+", default=list(ATTENTION_REGISTRY.keys()) + ) + parser.add_argument("-t", "--tasks", nargs="+", default=[t.value for t in Task]) + parser.add_argument( + "--partition", default="a100", type=str, help="Partition where to submit" + ) + args = parser.parse_args() + + for attention in args.attentions: + for task in args.tasks: + os.system( + "python3 run_with_submitit.py" + + f" --attention {attention} --task {task} --config {args.config_path}" + + f" --checkpoint_dir {args.checkpoint_path}/{attention}/{task}" + + f" --partition {args.partition}" + ) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__init__.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6677db08c0dcd87f7570a3249f61cc3eabe0532a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7f1743e0feee76c0345aae80c66b73c5868ca54 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e115a048b78459d2b97b7e4cb2a47e0263e68fe5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/dataset.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f07fdfdcd83f8d6169557614a35c77e23798e9d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/__pycache__/model_wrapper.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/dataset.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cf7e003219d73cf0123476b3ad08e0b0913d2642 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/dataset.py @@ -0,0 +1,46 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +# CREDITS: Almost as-is from the Nystromformer repo +# https://github.com/mlpen/Nystromformer + +import logging +import pickle + +import torch +from torch.utils.data.dataset import Dataset + +logging.getLogger().setLevel(logging.INFO) + + +class LRADataset(Dataset): + def __init__(self, file_path, seq_len): + with open(file_path, "rb") as f: + self.examples = pickle.load(f) + + self.seq_len = seq_len + logging.info(f"Loaded {file_path}... size={len(self.examples)}") + + def __len__(self): + return len(self.examples) + + def __getitem__(self, i): + return self.create_inst(self.examples[i], self.seq_len) + + @staticmethod + def create_inst(inst, seq_len): + output = { + "input_ids_0": torch.tensor(inst["input_ids_0"], dtype=torch.long)[:seq_len] + } + output["mask_0"] = (output["input_ids_0"] != 0).float() + + if "input_ids_1" in inst: + output["input_ids_1"] = torch.tensor(inst["input_ids_1"], dtype=torch.long)[ + :seq_len + ] + output["mask_1"] = (output["input_ids_1"] != 0).float() + output["label"] = torch.tensor(inst["label"], dtype=torch.long) + return output diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/model_wrapper.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/model_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5eb3e2ca742614755a44740b4c18463f06e206d7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/code/model_wrapper.py @@ -0,0 +1,288 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +# CREDITS: adapted from the Nystromformer repo +# https://github.com/mlpen/Nystromformer + +from enum import Enum +from typing import Dict, Union + +import pytorch_lightning as pl +import torch +import torch.nn as nn + +from xformers.components import build_attention +from xformers.components.multi_head_dispatch import MultiHeadDispatchConfig +from xformers.factory import xFormer, xFormerConfig, xFormerEncoderConfig +from xformers.utils import generate_matching_config + +PLOutput = Dict[str, Union[float, torch.Tensor]] + + +class Pooling(str, Enum): + MEAN = "mean" + CLS = "cls" + + +def pooling(mode: Pooling): + def pool_cls(inp): + return inp[:, 0, :] + + def pool_mean(inp): + return inp.mean(dim=1) + + return {Pooling.MEAN: pool_mean, Pooling.CLS: pool_cls}[mode] + + +def append_cls(inp, mask, vocab_size): + batch_size = inp.size(0) + cls_id = ( + (vocab_size - 1) * torch.ones(batch_size, dtype=torch.long, device=inp.device) + ).long() + cls_mask = torch.ones(batch_size, dtype=torch.float, device=mask.device) + inp = torch.cat([cls_id[:, None], inp[:, :-1]], dim=-1) + mask = torch.cat([cls_mask[:, None], mask[:, :-1]], dim=-1) + return inp, mask + + +def patch_model_config(config, attention_name): + # Rebuild a specific config out of generic + extra params + commons = config["common"] + try: + extra_attention_settings = config["extra_settings"]["attention"][attention_name] + except KeyError: + extra_attention_settings = None + + for bc in config["xformer"]: + bc["dim_model"] = commons["dim_model"] + bc["position_encoding_config"].update(commons) + bc["feedforward_config"].update(commons) + bc["multi_head_config"].update(commons) + bc["multi_head_config"]["attention"].update(commons) + bc["multi_head_config"]["attention"]["name"] = attention_name + bc["multi_head_config"]["attention"]["dim_head"] = ( + commons["dim_model"] / commons["num_heads"] + ) + if extra_attention_settings is not None: + 
bc["multi_head_config"]["attention"].update(extra_attention_settings) + + bc["multi_head_config"] = generate_matching_config( + bc["multi_head_config"], MultiHeadDispatchConfig + ) + bc["multi_head_config"].attention = build_attention( + bc["multi_head_config"].attention + ) + bc = generate_matching_config(bc, xFormerEncoderConfig) + + return config + + +class SCHead(nn.Module): + def __init__(self, config, dim_embedding, dim_mlp): + super().__init__() + self.pooling = pooling(Pooling(config["pooling_mode"])) + + self.mlpblock = nn.Sequential( + nn.Linear(dim_embedding, dim_mlp), + nn.ReLU(), + nn.Linear(dim_mlp, config["common"]["num_classes"]), + ) + + def forward(self, inp: torch.Tensor): + seq_score = self.mlpblock(self.pooling(inp)) + return seq_score + + +class SCHeadDual(nn.Module): + def __init__(self, config, dim_embedding, dim_mlp): + super().__init__() + self.pooling = pooling(Pooling(config["pooling_mode"])) + + self.mlpblock = nn.Sequential( + nn.Linear( + dim_embedding * 4, + dim_mlp, + ), + nn.ReLU(), + nn.Linear(dim_mlp, config["common"]["num_classes"]), + ) + + def forward(self, inp_0: torch.Tensor, inp_1: torch.Tensor): + X_0 = self.pooling(inp_0) + X_1 = self.pooling(inp_1) + seq_score = self.mlpblock(torch.cat([X_0, X_1, X_0 * X_1, X_0 - X_1], dim=-1)) + return seq_score + + +class ModelTrunk(pl.LightningModule): + def __init__(self, config, model_name): + super().__init__() + + config_model = config["model"] + self.config_training = config["training"] + + self.enable_amp = config["training"]["mixed_precision"] + self.pooling_mode = Pooling(config_model["pooling_mode"]) + self.vocab_size = config_model["common"]["vocab_size"] + + # Rebuild a specific config out of generic + extra params + self.config_model = patch_model_config(config_model, model_name) + self.model = xFormer.from_config(xFormerConfig(config_model["xformer"])) + self.norm = nn.LayerNorm(self.config_model["common"]["dim_model"]) + + ff_config = 
self.config_model["xformer"][0]["feedforward_config"] + self.dim_mlp = ( + self.config_model["common"]["dim_model"] + * ff_config["hidden_layer_multiplier"] + ) + + def training_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: + outputs = self(**batch) + self.logger.log_metrics({f"train_{k}": v for k, v in outputs.items()}) # type: ignore + self.log("train_accu", outputs["accu"], sync_dist=True) + return outputs + + def training_epoch_end(self, outputs): + logs = self.eval_epoch_end(outputs) + self.log("train_accu_mean", logs["accu"], sync_dist=True) + + def configure_optimizers(self): + optimizer = torch.optim.AdamW( + self.parameters(), + lr=self.config_training["learning_rate"], + betas=(0.9, 0.999), + eps=1e-6, + weight_decay=self.config_training["weight_decay"], + ) + + lr_scheduler = torch.optim.lr_scheduler.OneCycleLR( + optimizer=optimizer, + max_lr=self.config_training["learning_rate"], + pct_start=self.config_training["warmup"] + / self.config_training["num_train_steps"], + anneal_strategy=self.config_training["lr_decay"], + total_steps=self.config_training["num_train_steps"], + ) + + return [optimizer], [lr_scheduler] + + def eval_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> PLOutput: + outputs = self(**batch) + return outputs + + def eval_epoch_end(self, outputs, prefix: str = "train"): + logs = {} + counts = torch.tensor([x["count"] for x in outputs]).float() + logs["count"] = counts.sum() + for k in ("accu", "loss"): + logs[k] = (torch.tensor([x[k] for x in outputs]) * counts).sum() / logs[ + "count" + ] + self.log(f"{prefix}_{k}_mean", logs[k], sync_dist=True) + return logs + + def validation_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: + outputs = self.eval_step(batch, batch_idx) + self.logger.log_metrics({f"val_{k}": v for k, v in outputs.items()}) # type: ignore + self.log("val_accu", outputs["accu"], sync_dist=True, prog_bar=True) + return 
outputs + + def validation_epoch_end(self, outputs): + self.eval_epoch_end(outputs, prefix="val") + + def test_step( # type: ignore + self, batch: Dict[str, torch.Tensor], batch_idx: int + ) -> PLOutput: + return self.eval_step(batch, batch_idx) + + def test_epoch_end(self, outputs): + self.eval_epoch_end(outputs, prefix="test") + + +class ModelForSC(ModelTrunk): + def __init__(self, config, model_name): + # Setup trunk + super().__init__(config, model_name) + + self.seq_classifer = SCHead( + self.config_model, + dim_embedding=self.config_model["common"]["dim_model"], + dim_mlp=self.dim_mlp, + ) + + def forward( # type: ignore + self, input_ids_0: torch.Tensor, mask_0: torch.Tensor, label: torch.Tensor + ): + + if self.pooling_mode == Pooling.CLS: + input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) + + token_out = self.norm( + self.model(input_ids_0, encoder_input_mask=mask_0) + ) * mask_0.unsqueeze(-1) + + seq_scores = self.seq_classifer(token_out) + + seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) + seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) + outputs = { + "loss": seq_loss.mean(), + "accu": seq_accu.mean(), + "count": label.size(0), + } + + return outputs + + +class ModelForSCDual(ModelTrunk): + def __init__(self, config, model_name): + # Setup trunk + super().__init__(config, model_name) + + self.seq_classifer = SCHeadDual( + self.config_model, + dim_embedding=self.config_model["common"]["dim_model"], + dim_mlp=self.dim_mlp, + ) + + def forward( # type: ignore + self, + input_ids_0: torch.Tensor, + input_ids_1: torch.Tensor, + mask_0: torch.Tensor, + mask_1: torch.Tensor, + label: torch.Tensor, + ): + + mask_0, mask_1 = mask_0.long(), mask_1.long() + + if self.pooling_mode == Pooling.CLS: + input_ids_0, mask_0 = append_cls(input_ids_0, mask_0, self.vocab_size) + input_ids_1, mask_1 = append_cls(input_ids_1, mask_1, self.vocab_size) + + # Concatenate the two inputs into one batch + input_ids 
= torch.cat([input_ids_0, input_ids_1], dim=0) + masks = torch.cat([mask_0, mask_1], dim=0) + + tokens_out = self.norm( + self.model(input_ids, encoder_input_mask=masks) + ) * masks.unsqueeze(-1) + + seq_scores = self.seq_classifer(*torch.chunk(tokens_out, 2, dim=0)) + + seq_loss = torch.nn.CrossEntropyLoss(reduction="none")(seq_scores, label) + seq_accu = (seq_scores.argmax(dim=-1) == label).to(torch.float32) + outputs = { + "loss": seq_loss.mean(), + "accu": seq_accu.mean(), + "count": label.size(0), + } + + return outputs diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_grid_search.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_grid_search.py new file mode 100644 index 0000000000000000000000000000000000000000..a30b9e43e1e06ef343972fec155509175a2b9e9e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_grid_search.py @@ -0,0 +1,148 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import itertools +import os +import uuid +from datetime import date +from pathlib import Path +from typing import Dict, Iterable + +import submitit + +from xformers.benchmarks.LRA.run_with_submitit import ( + Trainer, + get_init_file, + get_shared_folder, + parse_args, +) + + +def grid_parameters(grid: Dict): + """ + Yield all combinations of parameters in the grid (as a dict) + """ + grid_copy = dict(grid) + # Turn single value in an Iterable + for k in grid_copy: + if not isinstance(grid_copy[k], Iterable): + grid_copy[k] = [grid_copy[k]] + for p in itertools.product(*grid_copy.values()): + yield dict(zip(grid.keys(), p)) + + +def grid_search(args): + if args.checkpoint_dir == "": + args.checkpoint_dir = get_shared_folder() / "%j" + + date_curr = date.today().strftime("%m-%d-%Y") + orig_check_dir = os.path.join(args.checkpoint_dir, date_curr) + + # Create the executor + # Note that the folder will depend on the job_id, to easily track experiments + executor = submitit.AutoExecutor( + folder=get_shared_folder() / "%j", slurm_max_num_timeout=30 + ) + num_gpus_per_node = args.ngpus + nodes = args.nodes + args.world_size = args.nodes * args.ngpus + partition = args.partition + + executor.update_parameters( + gpus_per_node=num_gpus_per_node, + tasks_per_node=num_gpus_per_node, # one task per GPU + cpus_per_task=10, + nodes=nodes, + timeout_min=60 * 72, + slurm_signal_delay_s=120, + slurm_partition=partition, + ) + executor.update_parameters(name="lra") + + if args.task == "text": + grid_meta = { + "training:learning_rate": ( + [1e-4, 2e-4, 3e-4, 5e-5], + lambda val: f"lr{val}", + ), + "training:warmup": ([3000, 8000], lambda val: f"warmup{val}"), + "training:seed": ([1234, 32, 1994], lambda val: f"seed{val}"), + "training:weight_decay": ([0.02, 0.05, 0.01], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0, 0.05], lambda val: f"drop{val}"), + } + elif args.task == "retrieval": + grid_meta = { + 
"training:learning_rate": ([1e-4, 3e-4], lambda val: f"lr{val}"), + "training:warmup": ([2000, 8000], lambda val: f"warmup{val}"), + "training:seed": ([4096, 1234, 3, 15, 5], lambda val: f"seed{val}"), + "training:weight_decay": ([0.01, 0], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0], lambda val: f"drop{val}"), + } + elif args.task == "listops": + grid_meta = { + "training:learning_rate": ( + [1e-4, 2e-4, 3e-4, 5e-5], + lambda val: f"lr{val}", + ), + "training:warmup": ([3000, 2000], lambda val: f"warmup{val}"), + "training:seed": ( + [ + 1234, + ], + lambda val: f"seed{val}", + ), + "training:weight_decay": ([0.02, 0.05, 0, 1], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0], lambda val: f"drop{val}"), + } + else: + grid_meta = { + "training:learning_rate": ([1e-4, 5e-5], lambda val: f"lr{val}"), + "training:warmup": ([8000], lambda val: f"warmup{val}"), + "training:seed": ([1234, 4321, 3], lambda val: f"seed{val}"), + "training:weight_decay": ([0.01], lambda val: f"wd{val}"), + "model:pooling_model": (["cls"], lambda val: f"pool-{val}"), + "model:common:dropout": ([0.1], lambda val: f"drop{val}"), + } + + grid = {k: v[0] for k, v in grid_meta.items()} + save_key = {k: v[1] for k, v in grid_meta.items()} + + hyper_parameters = list(grid_parameters(grid)) + jobs = [] + + for i, grid_data in enumerate(hyper_parameters): + + args.sweep_parameters = grid_data + run_name = f"{args.attention}" + # run_name = "paper_config" + for k, v in grid_data.items(): + run_name += "prenorm-" + save_key[k](v) + args.checkpoint_dir = os.path.join( + orig_check_dir, f"{args.task}", "logs", run_name + ) + Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True) + args.tb_dir = os.path.join(orig_check_dir, f"{args.task}", "tb", run_name) + Path(args.tb_dir).mkdir(parents=True, exist_ok=True) + + # Chronos needs a different job name each 
time + executor.update_parameters(name=f"lra_{args.task}_{i:02d}_{uuid.uuid4().hex}") + + args.dist_url = get_init_file().as_uri() + args.temp_file = str(get_init_file()) + + trainer = Trainer(args) + job = executor.submit(trainer) + jobs.append(job) + print(f"Run {i:02d} submitted with train cfg: {args}") + print(f"Submitted jobs ids: {','.join([str(job.job_id) for job in jobs])}") + + +if __name__ == "__main__": + args = parse_args() + grid_search(args) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_tasks.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_tasks.py new file mode 100644 index 0000000000000000000000000000000000000000..41c5fbe55ed9bece3fc9faa7c7075ee8b7db2a37 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_tasks.py @@ -0,0 +1,302 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import argparse +import json +import logging +import os +from enum import Enum +from pathlib import Path +from typing import Dict, Tuple, cast + +import pytorch_lightning as pl +import torch +import torch.nn as nn +from fvcore.nn import FlopCountAnalysis, flop_count_str +from pytorch_lightning.callbacks import ModelCheckpoint, TQDMProgressBar +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.strategies import DDPStrategy +from torch.utils.data import DataLoader + +from xformers.benchmarks.LRA.code.dataset import LRADataset +from xformers.benchmarks.LRA.code.model_wrapper import ModelForSC, ModelForSCDual +from xformers.components.attention import ATTENTION_REGISTRY + + +class Task(str, Enum): + Retrieval = "retrieval" + ListOps = "listops" + Image = "image" + PathfinderBaseline = "pathfinder32-curv_baseline" + PathfinderContour9 = "pathfinder32-curv_contour_length_9" + PathfinderContour14 = "pathfinder32-curv_contour_length_14" + Text = "text" + + +def load_config(path: str) -> Dict: + with open(Path(path).absolute(), "r") as fileio: + config = json.load(fileio) + + # Duplicate the pathfinder configs + config["pathfinder32-curv_baseline"] = config["pathfinder32"] + config["pathfinder32-curv_contour_length_9"] = config["pathfinder32"] + config["pathfinder32-curv_contour_length_14"] = config["pathfinder32"] + return config + + +def build_model(args: argparse.Namespace, config: Dict) -> nn.Module: + task = args.task + attention_name = args.attention + + model = cast( + pl.LightningModule, + ( + ModelForSCDual(config[f"{task}"], attention_name) + if task == Task.Retrieval + else ModelForSC(config[f"{task}"], attention_name) + ), + ) + + logging.info(model) + summary = pl.utilities.model_summary.LayerSummary(model) + logging.info(f"num_parameter: {summary.num_parameters // 1e3 / 1e3}M") + + with torch.no_grad(): + # Check the flops + seq_len = config[f"{task}"]["model"]["common"]["seq_len"] + x = torch.rand(1, seq_len).long() + mask = 
torch.rand(1, seq_len).long() + indices = torch.rand(1, seq_len).long() + flops = FlopCountAnalysis(model.model, (x, mask, indices)) + logging.info(f"complexity: {round(flops.total()/1e9, 3)} GFlops") + logging.info(flop_count_str(flops)) + + return model + + +def get_arg_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--attention", + type=str, + help=f"Attention mechanism to chose, among {list(ATTENTION_REGISTRY.keys())}. \ + A list can be passed to test several mechanisms in sequence", + dest="attention", + required=True, + ) + parser.add_argument( + "--task", + type=Task, + help=f"Task to chose, among {[t.value for t in Task]}.", + dest="task", + required=True, + ) + parser.add_argument( + "--skip_train", + type=bool, + help="Whether to skip training, and test an existing model", + dest="skip_train", + default=False, + ) + parser.add_argument( + "--config", + type=str, + help="Path to the config being used", + dest="config", + default="./config.json", + ) + parser.add_argument( + "--checkpoint_dir", + type=str, + help="Path to the checkpoint directory", + dest="checkpoint_dir", + default=f"/checkpoints/{os.getenv('USER')}/xformers", + ) + parser.add_argument( + "--checkpoint_path", + type=str, + help="Path to checkpoint", + ) + parser.add_argument( + "--debug", + help="Make it easier to debug a possible issue", + dest="debug", + default=False, + action="store_true", + ) + parser.add_argument( + "--world_size", + help="Number of GPUs used", + dest="world_size", + type=int, + default=1, + ) + parser.add_argument( + "--sweep_parameters", + help="Rewrite some hyperparameters in the config", + dest="sweep_parameters", + type=dict, + default=None, + ) + return parser + + +def setup_log(args, attention_name, task) -> Tuple[str, TensorBoardLogger]: + experiment_name = f"{task}__{attention_name}" + logger = TensorBoardLogger( + save_dir=args.checkpoint_dir, + name="", # remove lightning_logs subdirectory + version=experiment_name, + ) + log_dir = 
os.path.join(logger._save_dir, experiment_name) + return log_dir, logger + + +def rewrite_hyper(config, rewrites): + def replace(config_dict, k, v): + if len(k.split(":")) == 1: + config_dict[k] = v + return + first_key = k.split(":")[0] + assert first_key in config_dict, first_key + k = k[len(first_key) + 1 :] + replace(config_dict[first_key], k, v) + + for k, v in rewrites.items(): + replace(config, k, v) + return config + + +def build_dataloaders( + args: argparse.Namespace, + config_training: Dict, + num_workers: int = 4, +) -> Dict[str, DataLoader]: + datasets = {} + for component in ("train", "dev", "test"): + datasets[component] = LRADataset( + file_path=f"datasets/{args.task}.{component}.pickle", + seq_len=config_training["seq_len"], + ) + + # Gradient accumulation + accumu_steps = config_training["gradient_accumulation"] + logging.info(f"accumu_steps={accumu_steps}") + + # Batch size + per_gpu_batch_size = ( + config_training["batch_size"] // args.world_size // accumu_steps + ) + logging.warning( + f"Requested batch size: {config_training['batch_size']}. 
Given world\ + size and grad accumulation, per-gpu batch is\ + {per_gpu_batch_size}" + ) + + dataloaders = { + k: DataLoader( + v, + batch_size=per_gpu_batch_size, + shuffle=False, + pin_memory=True, + num_workers=num_workers, + ) + for k, v in datasets.items() + } + return dataloaders + + +def get_eval_summary(trainer: pl.Trainer) -> Dict[str, float]: + eval_summary: Dict[str, float] = {"train_step_idx": trainer.global_step} + for k, v in trainer.callback_metrics.items(): + eval_summary[k] = v.item() + return eval_summary + + +class BasicProgressBar(TQDMProgressBar): + def get_metrics(self, trainer, model): + items = super().get_metrics(trainer, model) + items.pop("v_num", None) + return items + + +def benchmark(args): + log_dir, logger = setup_log(args, f"{args.attention}", f"{args.task}") + args.logger = logger + + config = load_config(args.config) + + config_task = config[f"{args.task}"] + if args.sweep_parameters is not None: + logging.info("Replacing hyperparameters") + rewrite_hyper(config_task, args.sweep_parameters) + + config_training = config_task["training"] + config_training["seq_len"] = config_task["model"]["common"]["seq_len"] + logging.info(f"Learning rate: {config_training['learning_rate']}") + + pl.seed_everything(config_training.get("seed", 0)) + dataloaders = build_dataloaders(args, config_training) + + model = build_model(args, config) + + progress_bar = BasicProgressBar() + checkpoint_callback = ModelCheckpoint( + monitor="val_accu", + mode="max", + dirpath=args.checkpoint_dir, + filename="{epoch}-{val_accu:.2f}", + every_n_train_steps=config_training["eval_frequency"], + ) + + trainer = pl.Trainer( + accelerator="gpu", + strategy=( + DDPStrategy(find_unused_parameters=args.debug) + if not args.skip_train + else None + ), + accumulate_grad_batches=config_training["gradient_accumulation"], + callbacks=[progress_bar, checkpoint_callback], + detect_anomaly=args.debug, + deterministic=True, + gpus=args.world_size, + 
limit_val_batches=config_training["num_eval_steps"], + logger=logger, + max_steps=config_training["num_train_steps"], + num_sanity_val_steps=int(not args.skip_train), + precision=16 if config_training["mixed_precision"] else 32, + val_check_interval=config_training["eval_frequency"] + / float(len(dataloaders["train"])), + ) + + if not args.skip_train: + trainer.fit( + model, + train_dataloaders=dataloaders["train"], + val_dataloaders=dataloaders["dev"], + ) + ckpt_path = checkpoint_callback.best_model_path + else: + ckpt_path = args.checkpoint_path + + trainer.test( + model, + dataloaders=dataloaders["test"], + ckpt_path=ckpt_path, + ) + eval_summary = get_eval_summary(trainer) + with open(os.path.join(log_dir, "test_eval_summary.json"), "w") as f: + logging.info(f"Saving test results at {f.name}") + json.dump(eval_summary, f) + + +if __name__ == "__main__": + parser = get_arg_parser() + args = parser.parse_args() + if args.skip_train and args.checkpoint_path is None: + raise parser.error("Must provide --checkpoint_path if --skip_train=True") + benchmark(args) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_with_submitit.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_with_submitit.py new file mode 100644 index 0000000000000000000000000000000000000000..6fefdd1f54ff7a3c312f8010913dc3955805b539 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/LRA/run_with_submitit.py @@ -0,0 +1,153 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +""" +A script to run multinode training with submitit. 
+Almost copy-paste from https://github.com/facebookresearch/deit/blob/main/run_with_submitit.py +""" + +import argparse +import os +import uuid +from pathlib import Path + +import submitit + +from xformers.benchmarks.LRA.run_tasks import benchmark, get_arg_parser + + +def parse_args(): + parser = argparse.ArgumentParser( + "Submitit for LRA", parents=[get_arg_parser()], add_help=False + ) + parser.add_argument( + "--ngpus", default=1, type=int, help="Number of gpus to request on each node" + ) + parser.add_argument( + "--nodes", default=1, type=int, help="Number of nodes to request" + ) + parser.add_argument("--timeout", default=2800, type=int, help="Duration of the job") + + parser.add_argument( + "--partition", default="a100", type=str, help="Partition where to submit" + ) + parser.add_argument( + "--use_volta32", action="store_true", help="Big models? Use this" + ) + parser.add_argument( + "--enforce_host_memory", action="store_true", help="Use if the host OOMs" + ) + + parser.add_argument( + "--comment", + default="", + type=str, + help="Comment to pass to scheduler, e.g. priority message", + ) + return parser.parse_args() + + +def get_shared_folder() -> Path: + user = os.getenv("USER") + checkpoint_paths = ["/checkpoint", "/checkpoints"] + for checkpoint_path in checkpoint_paths: + if Path(checkpoint_path).is_dir(): + p = Path(f"{checkpoint_path}/{user}/xformers/submitit") + p.mkdir(exist_ok=True, parents=True) + return p + raise RuntimeError(f"No shared folder available - considering {checkpoint_paths}") + + +def get_init_file(): + # Init file must not exist, but it's parent dir must exist. 
+ os.makedirs(str(get_shared_folder()), exist_ok=True) + init_file = get_shared_folder() / f"{uuid.uuid4().hex}_init" + if init_file.exists(): + os.remove(str(init_file)) + return init_file + + +class Trainer: + def __init__(self, args): + self.args = args + + def __call__(self): + self._setup_gpu_args() + benchmark(self.args) + + def checkpoint(self): + self.args.dist_url = get_init_file().as_uri() + print("Requeuing ", self.args) + empty_trainer = type(self)(self.args) + return submitit.helpers.DelayedSubmission(empty_trainer) + + def _setup_gpu_args(self): + job_env = submitit.JobEnvironment() + self.args.checkpoint_dir = Path( + str(self.args.checkpoint_dir).replace("%j", str(job_env.job_id)) + ) + self.args.gpu = job_env.local_rank + self.args.rank = job_env.global_rank + self.args.world_size = job_env.num_tasks + print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}") + + +def main(): + args = parse_args() + if args.checkpoint_dir == "": + args.checkpoint_dir = get_shared_folder() / "%j" + Path(args.checkpoint_dir).mkdir(parents=True, exist_ok=True) + executor = submitit.AutoExecutor( + folder=args.checkpoint_dir, slurm_max_num_timeout=30 + ) + + num_gpus_per_node = args.ngpus + nodes = args.nodes + timeout_min = args.timeout + args.world_size = args.nodes * args.ngpus + + partition = args.partition + + kwargs = { + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "timeout_min": timeout_min, # max is 60 * 72 + # Below are cluster dependent parameters + "slurm_partition": partition, + "slurm_signal_delay_s": 120, + } + + if args.enforce_host_memory: + kwargs["mem_gb"] = (40 * num_gpus_per_node,) + + if args.use_volta32: + kwargs["slurm_constraint"] = "volta32gb" + + if args.comment: + kwargs["slurm_comment"] = args.comment + + executor.update_parameters( + **kwargs, + ) + + executor.update_parameters(name="lra") + + args.dist_url = 
get_init_file().as_uri() + args.temp_file = str(get_init_file()) + + trainer = Trainer(args) + job = executor.submit(trainer) + + print(f"Submitted job_id: {job.job_id}") + print(f"Logs and checkpoints will be saved at: {args.checkpoint_dir}") + with open(Path(f"{args.checkpoint_dir}") / Path("jobs.txt"), "a") as jobfile: + jobfile.write(f"{job.job_id}\n") + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__init__.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6677db08c0dcd87f7570a3249f61cc3eabe0532a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f814b5d4499c80169e3d537344c7c0e5c62e6015 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c28259d99b91baf4877b28fd3ff67f6e8dd75e8 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_attn_decoding.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_core.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_core.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..77779ed5c10fc037e5258dd6780fe9a0f8d17f78 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_core.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08afb870b3ac9d2627519aae32695ff83188b5af Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_indexing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fad283160eb7842d90e759dd54ab9dc3f9a1b7d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_mem_eff_attention.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_merge_attentions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_merge_attentions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b2f5930dd3b2c2b612f3588e0f09fb94584e3d1 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_merge_attentions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46a5bec55f12362761927d5b0e79ee6b28e42636 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_multi_head_dispatch.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..098603238c834303db7e1700bac98a3563b38715 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_nystrom_utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62d90ab603fed80bd431e84180b4f637c306e93f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_revnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09ecd9b573e776696536bfb89841330a7b6e2b84 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sddmm.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sequence_parallel_fused.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sequence_parallel_fused.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8536326e047d27743f1ea061fdce69446ba97839 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sequence_parallel_fused.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sp24.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sp24.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dd993ae79294c949ebb025d1013aa96b2cd1e5f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_sp24.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a12b8ff1d88b25ea750eae6247f1d38c3aa1bb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_swiglu.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_tiled_matmul.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_tiled_matmul.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2357e9ebe0296d778f52b58c1cf3c1bd8fd7b26b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/benchmark_tiled_matmul.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/utils.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..1785f1d276ce5484c1132ba5329f33e77eeb5082 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/benchmarks/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_attn_decoding.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_attn_decoding.py new file mode 100644 index 0000000000000000000000000000000000000000..f78fa9806c61afd1e1d4b38f277816b2d450322f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_attn_decoding.py @@ -0,0 +1,405 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +# Run with --omit-baselines to skip slow baselines. +# See other CLI arguments in benchmark_main_helper in utils.py. + +import sys +from typing import Any, Dict, Type + +import pytest +import torch + +import xformers.ops as xops +from xformers.attn_bias_utils import create_attn_bias +from xformers.benchmarks.utils import NotSupportedInputError, benchmark_main_helper2 + +min_run_time = 0.5 +device = torch.device("cuda") + + +CASES = [ + dict( + B=max(1, 2 ** (16 - i)), + Mq=1, + Mkv=2**i, + Hq=16, + Hkv=hkv, + K=128, + attn_bias_type=xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + ) + for i in range(8, 18) + for hkv in (1, 2) +] + + +def quantize_kv_int4(k: torch.Tensor, num_groups: int = 1) -> torch.Tensor: + """ + Auxiliary int4 row quantization function used for benchmarking and tests. + Matches the behaviour of torch.ops.llama_cpp.dequantize_int4_cache - + quantization parameters (scale and offset) of each row along the last + dimension of the tensor are assumed to be packed into two float16 values + at the beginning of the row. 
+ """ + # Scale and shift are such that quantization linearly maps int4 values range [0..15] + # to input values range min(k)..max(k) individually for every row + k = k.reshape(*k.shape[:-1], num_groups, k.shape[-1] // num_groups) + # print(f"k_reshape = {k.shape}") + max_vals = torch.max(k, dim=-1, keepdim=True).values + min_vals = torch.min(k, dim=-1, keepdim=True).values + scale_k: torch.Tensor = (max_vals - min_vals) / 15 + # print(f"scale_k_shape = {scale_k.shape}") + + shift_k = torch.min(k, dim=-1, keepdim=True).values + scale_k = scale_k.to(torch.float16) + shift_k = shift_k.to(torch.float16) + in_bytes = ((k - shift_k.expand(k.shape)) / scale_k.expand(k.shape)) + 0.5 + in_bytes = in_bytes.to(torch.uint8) + in_int4 = in_bytes & 0xF + in_int4_packed = in_int4[..., ::2] + (in_int4[..., 1::2] << 4) + scale_shift = torch.concat( + [scale_k.view(torch.uint8), shift_k.view(torch.uint8)], dim=-1 + ) + k_quant = torch.concat( + [ + scale_shift.flatten(start_dim=-2), + in_int4_packed.flatten(start_dim=-2), + ], + dim=-1, + ).view(torch.int16) + return k_quant + + +class AttentionDecodingBase: + OP: Any = None + + def __init__( + self, + B: int, + Mq: int, + Mkv: int, + Hq: int, + Hkv: int, + K: int, + bw: bool, + attn_bias_type, + ) -> None: + dtype = torch.float16 + torch.manual_seed(10) + self.sub_label = ( + f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K} TotalBytes=" + f"{((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2}" + ) + self.label = "attn_decoding" + self.shapes = (B, Mq, Mkv, Hq, Hkv, K) + + assert Hkv <= Hq + assert Hq % Hkv == 0 + self.q = torch.randn( + [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw + ) + self.k = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw + ).expand(-1, -1, -1, Hq // Hkv, -1) + self.v = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw + ).expand(-1, -1, -1, Hq // Hkv, -1) + + if Hq == Hkv: + self.q = self.q[:, :, :, 
0] + self.k = self.k[:, :, :, 0] + self.v = self.v[:, :, :, 0] + if Hkv == 1: + self.q = self.q[:, :, 0] + self.k = self.k[:, :, 0] + self.v = self.v[:, :, 0] + + self.attn_bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=Hq, + num_heads_groups=Hq // Hkv, + q_len=Mq, + kv_len=Mkv, + dtype=dtype, + device=device, + requires_grad=False, + fmt="BMHK", + op=self.OP, + ) + + if isinstance( + self.attn_bias, + xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + ): + self.q = self.q.view(1, -1, *self.q.shape[2:]) + self.k = self.k.view(1, -1, *self.k.shape[2:]) + self.v = self.v.view(1, -1, *self.v.shape[2:]) + + if hasattr(self.OP, "not_supported_reasons"): + inp = xops.fmha.Inputs( + query=self.q, key=self.k, value=self.v, attn_bias=self.attn_bias + ) + not_supported_reasons = self.OP.not_supported_reasons(inp) + if not_supported_reasons: + raise NotSupportedInputError(not_supported_reasons) + + def get_inputs(self): + inp = xops.fmha.Inputs( + query=self.q, key=self.k, value=self.v, attn_bias=self.attn_bias + ) + return inp + + def fw(self) -> None: + try: + xops.memory_efficient_attention_forward( + self.q, self.k, self.v, op=self.OP, attn_bias=self.attn_bias + ) + except (RuntimeError, ValueError) as e: + print(f"Runtime error: {e}") + + +class AttentionDecodingCUTLASS(AttentionDecodingBase): + OP = xops.fmha.cutlass.FwOp + + +class AttentionDecodingCK(AttentionDecodingBase): + OP = xops.fmha.ck.FwOp + + +class AttentionDecodingCKDecoder(AttentionDecodingBase): + OP = xops.fmha.ck_decoder.FwOp + + +class AttentionDecodingSplitKV(AttentionDecodingBase): + OP = xops.fmha.triton_splitk.FwOp + + +class AttentionDecodingCKSplitKV(AttentionDecodingBase): + OP = xops.fmha.ck_splitk.FwOp + + +class AttentionDecodingSplitInt4KV(AttentionDecodingBase): + OP = xops.fmha.triton_splitk.FwOp + + def __init__( + self, + B: int, + Mq: int, + Mkv: int, + Hq: int, + Hkv: int, + K: int, + bw: bool, + attn_bias_type, + ) -> None: + # 
super(AttentionDecodingSplitInt4KV, self).__init__(B, Mq, Mkv, Hq, Hkv, K, bw, attn_bias_type) + dtype = torch.float16 + torch.manual_seed(10) + self.sub_label = ( + f"B={B} Mq={Mq} Mkv={Mkv} Hq={Hq} Hkv={Hkv} K={K} TotalBytes=" + f"{((B * Mkv * Hkv * K * 2) + (B * Mq * Hq * K) + (B * Mq * Hq * K)) * 2}" + ) + self.label = "attn_decoding" + self.shapes = (B, Mq, Mkv, Hq, Hkv, K) + + assert Hkv <= Hq + assert Hq % Hkv == 0 + self.q = torch.randn( + [B, Mq, Hkv, Hq // Hkv, K], device="cuda", dtype=dtype, requires_grad=bw + ) + self.k = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw + ) + self.v = torch.randn( + [B, Mkv, Hkv, 1, K], device="cuda", dtype=dtype, requires_grad=bw + ) + + num_groups = 1 + self.k = ( + quantize_kv_int4(self.k, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ).expand(-1, -1, -1, Hq // Hkv, -1) + self.v = ( + quantize_kv_int4(self.v, num_groups=num_groups) + .contiguous() + .view(torch.int32) + ).expand(-1, -1, -1, Hq // Hkv, -1) + + if Hq == Hkv: + self.q = self.q[:, :, :, 0] + self.k = self.k[:, :, :, 0] + self.v = self.v[:, :, :, 0] + if Hkv == 1: + self.q = self.q[:, :, 0] + self.k = self.k[:, :, 0] + self.v = self.v[:, :, 0] + + self.attn_bias = create_attn_bias( + attn_bias_type, + batch_size=B, + num_heads=Hq, + num_heads_groups=Hq // Hkv, + q_len=Mq, + kv_len=Mkv, + dtype=dtype, + device=device, + requires_grad=False, + fmt="BMHK", + op=self.OP, + ) + + if isinstance( + self.attn_bias, + xops.fmha.attn_bias.BlockDiagonalCausalWithOffsetPaddedKeysMask, + ): + self.q = self.q.view(1, -1, *self.q.shape[2:]) + self.k = self.k.view(1, -1, *self.k.shape[2:]) + self.v = self.v.view(1, -1, *self.v.shape[2:]) + + if hasattr(self.OP, "not_supported_reasons"): + inp = xops.fmha.Inputs( + query=self.q, key=self.k, value=self.v, attn_bias=self.attn_bias + ) + not_supported_reasons = self.OP.not_supported_reasons(inp) + if not_supported_reasons: + raise NotSupportedInputError(not_supported_reasons) + 
+ +class AttentionDecodingPyTorchRepeat(AttentionDecodingBase): + def fw(self) -> None: + B, Mq, Mkv, Hq, Hkv, K = self.shapes + scale = 1 / K**0.5 + q = self.q.reshape([B, Mq, -1, K]).permute(0, 2, 1, 3) + k = self.k.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + v = self.v.reshape([B, Mkv, -1, K]).permute(0, 2, 1, 3) + attn = (q @ k.transpose(-1, -2) * scale).softmax(-1) + return attn @ v + + +BENCHMARKS: Dict[str, Type[AttentionDecodingBase]] = { + "pytorch": AttentionDecodingPyTorchRepeat, +} + +if torch.version.cuda: + BENCHMARKS["cutlass"] = AttentionDecodingCUTLASS + +if torch.version.hip: + BENCHMARKS.update( + { + "ck": AttentionDecodingCK, + "ck-decoder": AttentionDecodingCKDecoder, + "ck_splitK": AttentionDecodingCKSplitKV, + } + ) + + +if (sys.version_info.major, sys.version_info.minor) >= (3, 9): + BENCHMARKS["triton_splitK"] = AttentionDecodingSplitKV + BENCHMARKS["triton_int4KV"] = AttentionDecodingSplitInt4KV + +try: + import flash_attn + + class AttentionDecodingFlashAttention(AttentionDecodingBase): + def fw(self) -> None: + q, k, v = self.q, self.k, self.v + if q.ndim == 5: + B, Mq, H1, H2, K = q.shape + B, Mkv, H1, H2, K = k.shape + q = q.reshape([B, Mq, H1 * H2, K]) + k = k[:, :, :, 0] + v = v[:, :, :, 0] + return flash_attn.flash_attn_func(q, k, v) + + BENCHMARKS[ + f"flash-attention@{flash_attn.__version__}" + ] = AttentionDecodingFlashAttention +except ImportError: + pass + + +TEST_CASES = [ + dict( + B=max(1, 2 ** (16 - i)), + Mq=1, + Mkv=2**i, + Hq=16, + Hkv=hkv, + K=128, + attn_bias_type=None, + ) + for i in range(8, 18) + for hkv in range(1, 3) +] + [ + dict(B=i, Mq=1, Mkv=4097, Hq=8, Hkv=1, K=128, attn_bias_type=None) + for i in [2, 4, 8, 16, 32, 64, 128] +] + + +def get_benchmark_names(): + decoder_names = list(BENCHMARKS.keys()) + decoder_names.remove("pytorch") + return decoder_names + + +# tests to verify correctness of each decoder implementation +@pytest.mark.parametrize( + "name, case", + [(name, case) for name in 
get_benchmark_names() for case in TEST_CASES], +) +def test_flash_attention_decoder(name, case): + baseline = AttentionDecodingPyTorchRepeat( + case["B"], + case["Mq"], + case["Mkv"], + case["Hq"], + case["Hkv"], + case["K"], + False, + case["attn_bias_type"], + ) + if name == "ck-decoder" and case["Mkv"] >= 2**14: + pytest.skip("ck-decoder does not support Mkv >= 16K") + + baseline_out = baseline.fw() + inputs = baseline.get_inputs() + decoder = BENCHMARKS[name] + + assert name in ["ck-decoder", "ck_splitK", "ck", "triton_splitK", "triton_int4KV"] + decoder_output, ctx = decoder.OP.apply(inputs, False) + + q, k, v = inputs.get_qkv_in_bmghk() + B, M, G, H, Kq = q.shape + mqa_swap_seqlen_head = False + if k.shape[3] > 1 and k.stride(3) == 0 and v.stride(3) == 0: + mqa_swap_seqlen_head = True + if mqa_swap_seqlen_head: + decoder_output = ( + decoder_output.reshape(B, -1, M, Kq).transpose(1, 2).contiguous() + ) + else: + decoder_output = decoder_output.reshape(B, H * G, -1, Kq).contiguous() + + decoder_output = decoder_output.transpose(2, 1).contiguous() + torch.testing.assert_close(decoder_output, baseline_out, atol=1e-2, rtol=0) + + +def main() -> None: + """ + run performance benchmark + """ + benchmark_main_helper2( + "attn_decoding", + fw=True, + cases=CASES, + functions=BENCHMARKS, + min_run_time=min_run_time, + ) + + +if __name__ == "__main__": + main() # pragma: no cover diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_core.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_core.py new file mode 100644 index 0000000000000000000000000000000000000000..2a4d675605b1c5ee72c45d22bc55b01dc0f2b5d0 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_core.py @@ -0,0 +1,261 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import itertools + +import torch +from torch.utils import benchmark + +from xformers.components.attention.core import ( + SparseCS, + _create_random_sparsity, + _matmul_with_mask, + _softmax, + bmm, +) + +MIN_RUN_TIME = 1 +SHAPES = [[8, 8], [256, 1024], [128, 256]] +SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99] + + +def bench_sddmm(): + min_run_time = MIN_RUN_TIME + SPARSITIES = [0.95, 0.98, 0.99, 0.995, 0.999] + + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, K, device=device) + b = torch.rand(B, M, K, device=device) + + for backend, prob in itertools.product( + ["coo_pytorch", "csr_sputnik", "csr_ge"], SPARSITIES + ): + mask = _create_random_sparsity(torch.ones(B, M, M, dtype=torch.bool), prob) + aa = a + bb = b + if "csr" in backend: + mask = SparseCS(mask, device) + aa = a + bb = b + row_indices = mask.row_indices + row_offsets = mask.row_offsets + column_indices = mask.column_indices + if "_ge" in backend: + fn = torch.ops.xformers.csr_sddmm + else: + fn = torch.ops.xformers.sddmm_sputnik + fn_str = "fn(a, b, row_indices, row_offsets, column_indices)" + else: + mask = mask.to_sparse().to(device) + _, row_offsets, column_indices = mask.indices().int().unbind() + row_offsets = row_offsets.contiguous() + column_indices = column_indices.contiguous() + row_indices = row_offsets + + bb = b.transpose(-2, -1) + fn = _matmul_with_mask + fn_str = "fn(a, b, mask)" + + results.append( + benchmark.Timer( + stmt=fn_str, + globals={ + "a": aa, + "b": bb, + "mask": mask, + "row_indices": row_indices, + "row_offsets": row_offsets, + "column_indices": column_indices, + "fn": fn, + }, + label="sddmm", + sub_label=f"sparsity {backend}: {prob:0.4f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def bench_matmul_with_mask(): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for 
B, M, K in zip(*SHAPES): + a = torch.rand(B, M, K, device=device) + b = torch.rand(B, K, M, device=device) + mask = torch.rand(B, M, M, device=device) > prob + + results.extend( + [ + benchmark.Timer( + stmt="_matmul_with_mask(a, b, mask)", + globals={ + "a": a, + "b": b, + "mask": None, + "_matmul_with_mask": _matmul_with_mask, + }, + label="matmul_with_mask", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + benchmark.Timer( + stmt="_matmul_with_mask(a, b, mask)", + globals={ + "a": a, + "b": b, + "mask": mask, + "_matmul_with_mask": _matmul_with_mask, + }, + label="matmul_with_mask", + sub_label="dense with masking", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for sputnik, prob in itertools.product([False, True], SPARSITIES): + mask = _create_random_sparsity( + torch.ones(B, M, M, dtype=torch.bool, device=device), prob + ) + aa = a + bb = b + if sputnik: + mask = SparseCS(mask, device) + aa = a + bb = b.transpose(-2, -1).contiguous().transpose(-2, -1) + else: + mask = mask.to_sparse() + results.append( + benchmark.Timer( + stmt="_matmul_with_mask(a, b, mask)", + globals={ + "a": aa, + "b": bb, + "mask": mask, + "_matmul_with_mask": _matmul_with_mask, + }, + label="matmul_with_mask", + sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def bench_softmax(): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + + results.extend( + [ + benchmark.Timer( + stmt="_softmax(a)", + globals={ + "a": a, + "_softmax": _softmax, + }, + label="softmax", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + 
] + ) + for sputnik, prob in itertools.product([False, True], SPARSITIES): + a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob) + if sputnik: + a = SparseCS(a, device) + else: + a = a.to_sparse() + results.append( + benchmark.Timer( + stmt="_softmax(a)", + globals={ + "a": a, + "_softmax": _softmax, + }, + label="softmax", + sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def bench_bmm(): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + b = torch.rand(B, M, K, device=device) + + results.extend( + [ + benchmark.Timer( + stmt="bmm(a, b)", + globals={ + "a": a, + "b": b, + "bmm": bmm, + }, + label="bmm", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for sputnik, prob in itertools.product([False, True], SPARSITIES): + a = _create_random_sparsity(torch.rand(B, M, M, device=device), prob) + bb = b + if sputnik: + a = SparseCS(a, device) + bb = b + else: + a = a.to_sparse() + results.append( + benchmark.Timer( + stmt="bmm(a, b)", + globals={ + "a": a, + "b": bb, + "bmm": bmm, + }, + label="bmm", + sub_label=f"sparsity {'sputnik' if sputnik else 'pytorch'}: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + bench_sddmm() + bench_matmul_with_mask() + bench_softmax() + bench_bmm() diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_indexing.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_indexing.py new file mode 100644 index 
0000000000000000000000000000000000000000..353b9dba7dfe19223b61ee9b81e619ef9024e2aa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_indexing.py @@ -0,0 +1,191 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import random + +import torch +from utils import DTYPE2STR, benchmark_main_helper2, product_dict + +import xformers.ops as xops + +min_run_time = 0.2 +device = torch.device("cuda") + +CASES_IADD = list( + product_dict( + shape=[ + (int(48 * 0.6), 48, 1, 257 * 1536), + (int(48 * 0.6), 48, 257, 1536), + ], + scaling=[False, True], + dtype=[torch.half], + ) +) + list( + product_dict( + shape=[ + # Format: [B_src, B_inp, M, D] + (int(192 * 0.6), 192, 50, 1536), + (int(48 * 257 * 0.6), 257 * 48, 1, 1536), + (int(192 * 50 * 0.6), 192 * 50, 1, 1536), + (int(16 * 257 * 0.6), 48 * 257, 1, 1536), + ], + scaling=[False], + dtype=[torch.half], + ) +) + +CASES_ISELECT = list( + product_dict( + batches=[((48, 257), (50, 192))], + D=[1536], + keep_ratio=[0.6], + dtype=[torch.half], + ) +) + + +class ScaledIndexAddBenchmark: + def __init__(self, dtype, scaling: bool, shape, bw: bool) -> None: + B_src, B_out, M, D = shape + torch.manual_seed(B_out + B_src) + dtype_str = DTYPE2STR.get(dtype, dtype) + self.sub_label = f"{dtype_str} B_src={B_src}, B_out={B_out}, M={M}, D={D} s={'Y' if scaling else 'N'}" + self.label = "scaled_index_add" + self.alpha = 0.73 + + self.inp = torch.randn( + [B_out, M, D], device="cuda", dtype=dtype, requires_grad=bw + ) + self.src = torch.randn( + [B_src, M, D], device="cuda", dtype=dtype, requires_grad=bw + ) + self.scaling = ( + torch.randn([D], device="cuda", dtype=dtype, requires_grad=bw) + if scaling + else None + ) + self.index = torch.tensor( + [i for i in range(self.src.shape[0])], dtype=torch.int64, device="cuda" + ) + self.grad = torch.randn([B_out, 
M, D], device="cuda", dtype=dtype) + self.out = torch.Tensor() + + def fw(self) -> None: + self.out = xops.scaled_index_add( + input=self.inp.clone(), + index=self.index, + source=self.src, + scaling=self.scaling, + alpha=self.alpha, + ) + + def bw(self): + self.inp.grad = None + self.src.grad = None + if self.scaling is not None: + self.scaling.grad = None + self.out.backward(self.grad, retain_graph=True) + + +class ScaledIndexAddBenchmarkBaseline(ScaledIndexAddBenchmark): + def fw(self) -> None: + src_scaled = self.src + if self.scaling is not None: + src_scaled * self.scaling.unsqueeze(0).unsqueeze(0) + self.out = self.inp.index_add( + dim=0, + source=src_scaled, + index=self.index, + alpha=self.alpha, + ) + + +class IndexSelectBenchmark: + def __init__(self, dtype, batches, D, keep_ratio, bw: bool) -> None: + dtype_str = DTYPE2STR.get(dtype, dtype) + self.sub_label = f"{dtype_str} D={D} batches={batches} keep={keep_ratio}" + self.label = "index_select" + + indices = [] + sources = [] + for B, seqlen in batches: + index = [i for i in range(B)] + random.Random(B).shuffle(index) + indices.append( + torch.zeros( + index[int(keep_ratio * B)], + dtype=torch.int64, + device="cuda", + ) + ) + source_i = torch.randn( + [B, seqlen * D], dtype=dtype, device="cuda", requires_grad=bw + ) + sources.append(source_i) + self.indices, self.sources = indices, sources + self.out = torch.Tensor() + + def fw(self) -> None: + self.out = xops.index_select_cat(self.sources, self.indices) + + def bw(self): + for src in self.sources: + src.grad = None + self.out.backward(self.out, retain_graph=True) + + +class IndexSelectBenchmarkBaseline(IndexSelectBenchmark): + def fw(self) -> None: + self.out = torch.cat( + [s[i].flatten() for s, i in zip(self.sources, self.indices)], dim=0 + ) + + +benchmark_main_helper2( + "scaled_index_add_fw", + fw=True, + functions={ + "xformers": ScaledIndexAddBenchmark, + "pytorch": ScaledIndexAddBenchmarkBaseline, + }, + cases=CASES_IADD, + 
min_run_time=min_run_time, +) + +benchmark_main_helper2( + "scaled_index_add_fwbw", + fw=True, + bw=True, + functions={ + "xformers": ScaledIndexAddBenchmark, + "pytorch": ScaledIndexAddBenchmarkBaseline, + }, + cases=CASES_IADD, + min_run_time=min_run_time, +) + +benchmark_main_helper2( + "index_select_fw", + fw=True, + functions={ + "xformers": IndexSelectBenchmark, + "pytorch": IndexSelectBenchmarkBaseline, + }, + cases=CASES_ISELECT, + min_run_time=min_run_time, +) + +benchmark_main_helper2( + "index_select_fwbw", + fw=True, + bw=True, + functions={ + "xformers": IndexSelectBenchmark, + "pytorch": IndexSelectBenchmarkBaseline, + }, + cases=CASES_ISELECT, + min_run_time=min_run_time, +) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_merge_attentions.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_merge_attentions.py new file mode 100644 index 0000000000000000000000000000000000000000..d89561ae5ea30580b654bfd8e0879e83ce4bed54 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_merge_attentions.py @@ -0,0 +1,92 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import torch + +from xformers.ops import fmha +from xformers.utils import do_bench_cudagraph + + +def _merge_attentions_varargs_ref(attn_split, lse_split): + """ + attn_split: list of [B, M, (G,) H, Kq] + lse_split: list of [B, (G,) H, M] + """ + attn_split = torch.stack(attn_split) + lse_split = torch.stack(lse_split) + + lse_split = lse_split[..., None].moveaxis(4, 2) # [split_k, B, M, G, H, 1] + + lse_max, _ = torch.max(lse_split, dim=0) # [B, M, G, H, 1] + sumexp_normalized = torch.exp(lse_split - lse_max) # [split_k, B, M, G, H, 1] + denominator = sumexp_normalized.sum(dim=0) # [B, M, G, H, 1] + numerator = (sumexp_normalized * attn_split).sum(dim=0) # [B, M, G, H, K] + + attn_out = numerator / denominator # [B, M_ceil, G, H, Kq] + lse_out = lse_max + torch.log(denominator) + lse_out = lse_out.squeeze(4).permute(0, 2, 3, 1) # [B, G, H, M] + + return attn_out, lse_out + + +def benchmark_merge_attentions_backward(split_k, B, M, G, N_H_L, D_H, dtype): + """ + Benchmark backward pass for merge_attentions. Assumes "varargs" path, + i.e. LSE and attention of chunks are provided as two lists of tensors, and not as two stacked tensors. 
+ """ + + bench_stream = torch.cuda.Stream() + with torch.cuda.stream(bench_stream): + + attn_split = [ + torch.randn( + [B, M, G, N_H_L, D_H], dtype=dtype, device="cuda", requires_grad=True + ) + for _ in range(split_k) + ] + lse_split = [ + torch.randn( + [B, G, N_H_L, M], dtype=dtype, device="cuda", requires_grad=True + ) + for _ in range(split_k) + ] + + attn_out_ref, lse_out_ref = _merge_attentions_varargs_ref(attn_split, lse_split) + out_grad = torch.randn_like(attn_out_ref) + attn_out_ref.backward(out_grad, retain_graph=True) + t_ms_ref = do_bench_cudagraph( + lambda: attn_out_ref.backward(out_grad, retain_graph=True) + ) + + for x in attn_split + lse_split: + x.detach_() + x.requires_grad_(True) + + attn_out, lse_out = fmha.merge_attentions(attn_split, lse_split) + attn_out.backward(out_grad, retain_graph=True) + t_ms = do_bench_cudagraph( + lambda: attn_out.backward(out_grad, retain_graph=True) + ) + + print( + f"{split_k=}, {B=}, {M=}, {G=}, {N_H_L=}, {D_H=}, {dtype=}. " + f"Baseline: {t_ms_ref * 1e3:.2f}us, " + f"Triton: {t_ms * 1e3:.2f}us, {t_ms_ref/t_ms:.1f}x faster" + ) + + +def main(): + G = 2 + N_H_L = 8 + D_H = 128 + dtype = torch.float32 + for split_k in [2, 4, 8, 16]: + for B in [1, 32, 128]: + for M in [1, 32, 512]: + benchmark_merge_attentions_backward(split_k, B, M, G, N_H_L, D_H, dtype) + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_multi_head_dispatch.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_multi_head_dispatch.py new file mode 100644 index 0000000000000000000000000000000000000000..2345cf2a5fa97f1e361bb217698eba1b2a051c16 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_multi_head_dispatch.py @@ -0,0 +1,105 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Any, Dict + +import torch +import torch.nn as nn +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components import MultiHeadDispatch +from xformers.components.attention import ScaledDotProduct + +SHAPES = [ + (8, 384, 128), + (8, 784, 512), + (4, 1024, 768), + (4, 2048, 1024), + (2, 2048, 2048), + (2, 2048, 4096), + (2, 4096, 4096), + (1, 2048, 12288), +] + +N_HEADS = [4] + + +def bench_multihead_dispatch(backward: bool, self_attention: bool): + device = torch.device("cuda") + bw = "+bw" if backward else "" + sa = " (self_attn)" if self_attention else "" + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for B, M, K in SHAPES: + for heads in N_HEADS: + xf_multi_head = MultiHeadDispatch( + dim_model=K, + residual_dropout=0.0, + num_heads=heads, + attention=ScaledDotProduct(), + bias=(True, True, True, True), + ).to(device=device, dtype=dtype) + torch_multi_head = nn.MultiheadAttention( + embed_dim=K, num_heads=heads, batch_first=True + ).to(device=device, dtype=dtype) + + q = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + + if self_attention: + k = q + v = q + else: + k = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + v = torch.randn( + (B, M, K), requires_grad=backward, device=device, dtype=dtype + ) + + def torch_mha(): + y, _ = torch_multi_head(query=q, key=k, value=v) + if backward: + torch.norm(y).backward() + return y + + def xformers_mha(): + y = xf_multi_head(query=q, key=k, value=v) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase(torch_mha, f"torch - fw{bw}{sa}"), + TestCase(xformers_mha, f"xf - fw{bw}{sa}"), + ]: + time = triton.testing.do_bench(testcase.function)[0] + key = f"B={B}, M={M}, K={K}, 
N_HEADS={heads}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{time:.2f}" + + pretty_print( + results, + title=f"\n --- Type: {dtype} --- ", + units="runtime in ms, lower is better", + ) + pretty_plot( + results, + title=f"MHA-FW{bw}-{dtype}", + units="runtime in ms, lower is better", + dash_key="torch", + ) + + +for bw in [False, True]: + for self_attention in [False, True]: + bench_multihead_dispatch(bw, self_attention) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_nystrom_utils.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_nystrom_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c85b0345683a9cc6077a3fb4d25bb50cbdffeead --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_nystrom_utils.py @@ -0,0 +1,101 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable + +import torch +from torch.utils import benchmark + +from xformers.components.attention.utils import iterative_pinv + +MIN_RUN_TIME = 1 +SHAPES = [[8, 8], [256, 1024], [128, 256]] +SPARSITIES = [0.5, 0.8, 0.9, 0.95, 0.99] + + +def bench_inverse(inverse_fn: Callable[[torch.Tensor], torch.Tensor]): + min_run_time = MIN_RUN_TIME + prob = 0.9 + device = torch.device("cuda") + results = [] + + for B, M, K in zip(*SHAPES): + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + a = torch.softmax(a, dim=-1) + + results.extend( + [ + benchmark.Timer( + stmt=f"{inverse_fn.__name__}(a)", + globals={ + "a": a, + f"{inverse_fn.__name__}": inverse_fn, + }, + label=f"{inverse_fn.__name__}", + sub_label="dense", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time), + ] + ) + for prob in SPARSITIES: + a = torch.rand(B, M, M, device=device) + a[a < prob] = 0 + a = a.to_sparse() + results.append( + benchmark.Timer( + stmt=f"{inverse_fn.__name__}(a)", + globals={ + "a": a, + f"{inverse_fn.__name__}": inverse_fn, + }, + label=f"{inverse_fn.__name__}", + sub_label=f"sparsity: {prob:0.2f}", + description=f"B={B}, M={M}, K={K}", + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + + +def iterative_pinv_analysis( + identity_tolerance: float = 1e-1, + pinv_tolerance: float = 5e-1, + max_iters: int = 30, + plot: bool = True, +): + + for i in range(1, 10): + B, M = 1, 2**i + a = torch.rand(B, M, M) + a = torch.softmax(a, dim=-1) + + for n_iter in range(1, max_iters + 1): + result = iterative_pinv(a, n_iter=n_iter) + expected = torch.linalg.pinv(a) + + result_identity = torch.matmul(a, result) + identity = torch.eye(M) + + # Default is frobenius norm. 
+ identity_error = torch.linalg.norm(identity - result_identity, dim=(-2, -1)) + inverse_error = torch.linalg.norm(expected - result, dim=(-2, -1)) + + if (identity_error < identity_tolerance).all() or n_iter == max_iters: + print( + f"Size {M}, n_iters {n_iter}: \n\t \ + Final Error from Identity: {identity_error.item()} \n\t \ + Final Error from linalg.pinv {inverse_error.item()}" + ) + break + + +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + iterative_pinv_analysis() + bench_inverse(iterative_pinv) + bench_inverse(torch.linalg.pinv) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_revnet.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_revnet.py new file mode 100644 index 0000000000000000000000000000000000000000..8561481dd4459eca180b34bd98f96aff691b7c4c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_revnet.py @@ -0,0 +1,83 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +from typing import Any, Dict + +import torch +import triton + +from xformers.benchmarks.utils import TestCase, pretty_plot, pretty_print +from xformers.components.reversible import ReversibleSequence + +SHAPES = [(16384, 32), (2048, 256), (128, 4096)] + +DEPTH = [4, 32, 256] + + +def bench_revnet(backward: bool): + device = torch.device("cuda") + bw = "+bw" if backward else "" + + for dtype in [torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for B, K in SHAPES: + for depth in DEPTH: + f = torch.nn.Linear(K, K).to(device=device, dtype=dtype) + g = torch.nn.Linear(K, K).to(device=device, dtype=dtype) + revseq = ReversibleSequence( + torch.nn.ModuleList([torch.nn.ModuleList([f, g])] * depth) + ) + revseq = revseq.to(device=device, dtype=dtype) + + a = torch.rand( + 1, B, K, device=device, dtype=dtype, requires_grad=backward + ) + b = torch.rand( + 1, B, K * 2, device=device, dtype=dtype, requires_grad=backward + ) + + def normal_step(): + y = a + for _ in range(depth): + y = y + f(y) + y = y + g(y) + if backward: + torch.norm(y).backward() + return y + + def reversible_step(): + y = revseq(b) + if backward: + torch.norm(y).backward() + return y + + for testcase in [ + TestCase(normal_step, f"residual - fw{bw}"), + TestCase(reversible_step, f"reversible - fw{bw}"), + ]: + time = triton.testing.do_bench(testcase.function)[0] + key = f"Batch={B}, Features={K}, Depth={depth}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{time:.2f}" + + pretty_print( + results, + title=f"\n --- Type: {dtype} --- ", + units="runtime in ms, lower is better", + ) + pretty_plot( + results, + title=f"RevNet-FW{bw}-{dtype}", + units="runtime in ms, lower is better", + dash_key="torch", + ) + + +for bw in [False, True]: + bench_revnet(bw) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sddmm.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sddmm.py new file mode 100644 index 
0000000000000000000000000000000000000000..536fc5ef8e207c8452a9bab6f10a286379bb80e7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sddmm.py @@ -0,0 +1,120 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import itertools + +import torch +from torch.utils import benchmark + +from xformers.components.attention._sputnik_sparse import _csr_to_coo +from xformers.components.attention.core import SparseCS, _create_random_sparsity + +MIN_RUN_TIME = 0.2 + + +def _get_fn(backend): + if backend == "csr_ge": + fn = torch.ops.xformers.csr_sddmm + elif backend == "csr_sputnik": + fn = torch.ops.xformers.sddmm_sputnik + elif backend == "coo_ge": + + def fn(a, b, row_indices, row_offsets, column_indices): + row_coo, _ = _csr_to_coo( + a.shape[-2], b.shape[-2], row_offsets, column_indices + ) + return torch.ops.xformers.coo_sddmm( + a, b, row_indices, row_coo, column_indices + ) + + elif backend == "csr_to_coo": + + def fn(a, b, row_indices, row_offsets, column_indices): + row_coo, _ = _csr_to_coo( + a.shape[-2], b.shape[-2], row_offsets, column_indices + ) + return row_coo + + return fn + + +def bench_sddmm(configs): + min_run_time = MIN_RUN_TIME + + device = torch.device("cuda") + results = [] + + for (B, M, K), prob in configs: + a = torch.rand(B, M, K, device=device) + b = torch.rand(B, M, K, device=device) + + mask = _create_random_sparsity( + torch.ones(1, M, M, dtype=torch.bool), prob, divisible_by=16 + ) + aa = a + bb = b + mask = SparseCS(mask, device) + row_indices = mask.row_indices + row_offsets = mask.row_offsets + column_indices = mask.column_indices + + for backend in ["csr_sputnik", "csr_ge", "coo_ge", "csr_to_coo"]: + + fn_str = "fn(a, b, row_indices, row_offsets, column_indices)" + fn = _get_fn(backend) + + results.append( + benchmark.Timer( + stmt=fn_str, + globals={ + 
"a": aa, + "b": bb, + "mask": mask, + "row_indices": row_indices, + "row_offsets": row_offsets, + "column_indices": column_indices, + "fn": fn, + }, + label="sddmm", + sub_label=f"B={B:>4d}, M={M:>4d}, K={K:>3d}, prob={prob:0.4f}", + description=backend, + ).blocked_autorange(min_run_time=min_run_time) + ) + + compare = benchmark.Compare(results) + compare.print() + return results + + +# batch size 32, for different layers +SWIN_T_SIZES = [(96, 3136, 32), (192, 784, 32), (384, 196, 32), (768, 49, 32)] +swin_t_config = list(zip(SWIN_T_SIZES, (0.9844, 0.9375, 0.75, 0.0))) + +# some random values +BASIC_SIZES = [(32, 1024, 32), (32, 1024, 128), (8, 4096, 32), (8, 4096, 128)] +SPARSITIES = [0.90, 0.93, 0.95, 0.97, 0.98, 0.99, 0.995, 0.999] +basic_config = list(itertools.product(BASIC_SIZES, SPARSITIES)) + +# batch size 32 here +vit_sizes = [ + (192, 785, 64), # deit_small_patch8_224 + (192, 197, 64), # deit_small_patch16_224 + (384, 785, 64), # deit_base_patch8_224 + (384, 197, 64), # deit_base_patch16_224 +] +SPARSITIES = [0.70, 0.80, 0.85, 0.90, 0.93, 0.95, 0.97] +vit_config = list(itertools.product(vit_sizes, SPARSITIES)) + +results = [] + +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + print("Swin Transformer") + results += bench_sddmm(swin_t_config) + print("ViT") + results += bench_sddmm(vit_config) + print("Basic cases") + results += bench_sddmm(basic_config) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sequence_parallel_fused.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sequence_parallel_fused.py new file mode 100644 index 0000000000000000000000000000000000000000..cd0d646a53917524a9855e5977b827eb1ee5420e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_sequence_parallel_fused.py @@ -0,0 +1,473 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import contextlib +import dataclasses +import enum +import multiprocessing +import os +import random +from collections import deque +from statistics import mean, stdev +from typing import Callable + +import torch + +# torch._C._set_print_stack_traces_on_fatal_signal(True) + + +@dataclasses.dataclass +class Scenario: + # The number of tokens, i.e., the batch size times the sequence length + num_samples: int + # The per-sample features outside of the MHA/FFN block, and inside of it + outer_dim: int + inner_dim: int + # Simulate this many matmuls during the all-gather step + num_ag_matrices: int + + +class Step(enum.Enum): + AllGather = "ag" + ReduceScatter = "rs" + + def __str__(self): + return self.value + + +@dataclasses.dataclass +class Bench: + ag: Callable[[], None] + rs: Callable[[], None] + + def __getitem__(self, step: Step): + if step is Step.AllGather: + return self.ag + elif step is Step.ReduceScatter: + return self.rs + else: + raise KeyError(f"{step}") + + +LLAMA_07B_SLEN = 4096 +LLAMA_07B_D = 4096 + +LLAMA_70B_SLEN = 2048 +LLAMA_70B_D = 8192 + + +def round_up_to_nearest_multiple(n: int, m: int) -> int: + return m * ((n + m - 1) // m) + + +def llama_07B_MHA(world_size: int) -> Scenario: + batch_size = 8 + return Scenario( + num_samples=batch_size * LLAMA_07B_SLEN, + outer_dim=LLAMA_07B_D, + inner_dim=LLAMA_07B_D // world_size, + num_ag_matrices=3, + ) + + +def llama_07B_FFN(world_size: int) -> Scenario: + batch_size = 8 + return Scenario( + num_samples=batch_size * LLAMA_07B_SLEN, + outer_dim=LLAMA_07B_D, + inner_dim=round_up_to_nearest_multiple(2 * (4 * LLAMA_07B_D) // 3, 256) + // world_size, + num_ag_matrices=2, + ) + + +def llama_70B_MHA(world_size: int) -> Scenario: + batch_size = world_size + return Scenario( + num_samples=batch_size * LLAMA_70B_SLEN, + outer_dim=LLAMA_70B_D, + inner_dim=LLAMA_70B_D // 
world_size, + num_ag_matrices=3, + ) + + +def llama_70B_FFN(world_size: int) -> Scenario: + batch_size = world_size + return Scenario( + num_samples=batch_size * LLAMA_70B_SLEN, + outer_dim=LLAMA_70B_D, + inner_dim=round_up_to_nearest_multiple(2 * (4 * LLAMA_70B_D) // 3, 256) + // world_size, + num_ag_matrices=2, + ) + + +SCENARIOS = { + "llama_07B_MHA": llama_07B_MHA, + "llama_07B_FFN": llama_07B_FFN, + "llama_70B_MHA": llama_70B_MHA, + "llama_70B_FFN": llama_70B_FFN, +} + +DTYPES = { + "bfloat16": torch.bfloat16, +} + + +def run_one_rank( + my_rank, + world_size, + scenario_name, + step, + dtype_str, + num_rounds, + num_warmup_iters, + num_bench_iters, + profile, + conn_from_prev, + conn_to_next, +): + print(f"RANK {my_rank} started") + + torch.cuda.set_device(my_rank) + my_device = torch.device(f"cuda:{my_rank}") + + os.environ["RANK"] = f"{my_rank}" + os.environ["WORLD_SIZE"] = f"{world_size}" + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + torch.distributed.init_process_group(backend="nccl", init_method="env://") + + subgroup = torch.distributed.new_group() + subgroup_nowait = torch.distributed.new_group() + subgroup_nowait_nomemcpy = torch.distributed.new_group() + + scenario = SCENARIOS[scenario_name](world_size) + if step is Step.AllGather: + M = scenario.num_samples + N = scenario.inner_dim + K = scenario.outer_dim + num_matrices = scenario.num_ag_matrices + elif step is Step.ReduceScatter: + M = scenario.num_samples + N = scenario.outer_dim + K = scenario.inner_dim + num_matrices = 1 + + dtype = DTYPES[dtype_str] + + scattered_input = torch.randn((M // world_size, K), dtype=dtype, device=my_device) + gathered_input = torch.randn((M, K), dtype=dtype, device=my_device) + weights = [ + torch.randn((K, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + gathered_outputs = [ + torch.randn((M, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + scattered_outputs = [ + torch.randn((M // 
world_size, N), dtype=dtype, device=my_device) + for _ in range(num_matrices) + ] + + gathered_outputs_nccl_reference = [ + torch.randn((M, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + gathered_outputs_fused = [ + torch.randn((M, N), dtype=dtype, device=my_device) for _ in range(num_matrices) + ] + scattered_outputs_nccl_reference = [ + torch.randn((M // world_size, N), dtype=dtype, device=my_device) + for _ in range(num_matrices) + ] + scattered_outputs_fused = [ + torch.randn((M // world_size, N), dtype=dtype, device=my_device) + for _ in range(num_matrices) + ] + + def run_compute_lower_bound_ag(): + for w, go in zip(weights, gathered_outputs): + torch.matmul(gathered_input, w, out=go) + + def run_compute_lower_bound_rs(): + for w, go, so in zip(weights, gathered_outputs, scattered_outputs): + torch.matmul(gathered_input, w, out=go) + torch.sum(go.view((world_size, M // world_size, N)), dim=0, out=so) + + def run_comms_lower_bound_ag(): + torch.distributed.all_gather_into_tensor(gathered_input, scattered_input) + + def run_comms_lower_bound_rs(): + for so, go in zip(scattered_outputs, gathered_outputs): + torch.distributed.reduce_scatter_tensor(so, go) + + def run_nccl_reference_ag(): + torch.distributed.all_gather_into_tensor(gathered_input, scattered_input) + for w, go in zip(weights, gathered_outputs_nccl_reference): + torch.matmul(gathered_input, w, out=go) + + def run_nccl_reference_rs(): + for w, go, so in zip( + weights, gathered_outputs, scattered_outputs_nccl_reference + ): + torch.matmul(gathered_input, w, out=go) + torch.distributed.reduce_scatter_tensor(so, go) + + def run_fused_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup, + timeout_s=10, + ) + + def run_fused_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter 
+ + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup, + timeout_s=10, + ) + + def run_fused_nowait_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup_nowait, + _wait=False, + timeout_s=10, + ) + + def run_fused_nowait_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup_nowait, + _wait=False, + timeout_s=10, + ) + + def run_fused_nowait_nomemcpy_ag(): + nonlocal gathered_outputs_fused + from xformers.ops import fused_allgather_and_linear + + gathered_outputs_fused = fused_allgather_and_linear( + scattered_input, + [w.t() for w in weights], + group=subgroup_nowait_nomemcpy, + _wait=False, + _memcpy=False, + timeout_s=10, + ) + + def run_fused_nowait_nomemcpy_rs(): + nonlocal scattered_outputs_fused + from xformers.ops import fused_linear_and_reducescatter + + scattered_outputs_fused = fused_linear_and_reducescatter( + gathered_input, + [w.t() for w in weights], + group=subgroup_nowait_nomemcpy, + _wait=False, + _memcpy=False, + timeout_s=10, + ) + + print(f"Sizes: ({world_size}x{M // world_size})x({num_matrices}x{N})x{K}") + + if step is Step.AllGather: + run_nccl_reference_ag() + run_fused_ag() + if my_rank == 0: + print("fused:") + print( + "Are equal? " + + " ".join( + str(torch.equal(ref, fus)) + for ref, fus in zip( + gathered_outputs_nccl_reference, gathered_outputs_fused + ) + ) + ) + print( + "Are allclose? 
" + + " ".join( + str(torch.allclose(ref, fus)) + for ref, fus in zip( + gathered_outputs_nccl_reference, gathered_outputs_fused + ) + ) + ) + + elif step is Step.ReduceScatter: + run_nccl_reference_rs() + run_fused_rs() + if my_rank == 0: + print("fused:") + print( + "Are equal? " + + " ".join( + str(torch.equal(ref, fus)) + for ref, fus in zip( + scattered_outputs_nccl_reference, scattered_outputs_fused + ) + ) + ) + print( + "Are allclose? " + + " ".join( + str(torch.allclose(ref, fus)) + for ref, fus in zip( + scattered_outputs_nccl_reference, scattered_outputs_fused + ) + ) + ) + + # The above checks might still return False for, e.g., bfloat16 because they + # have too little tolerance for its lower precision. This method, OTOH, uses + # variable tolerances based on dtype. + # for ref, fus in zip(gathered_outputs_nccl_reference, gathered_outputs_fused): + # torch.testing.assert_close(ref, fus) + # for ref, fus in zip(scattered_outputs_nccl_reference, scattered_outputs_fused): + # torch.testing.assert_close(ref, fus) + + all_benchs = { + "compute_lower_bound": Bench( + ag=run_compute_lower_bound_ag, rs=run_compute_lower_bound_rs + ), + "comms_lower_bound": Bench( + ag=run_comms_lower_bound_ag, rs=run_comms_lower_bound_rs + ), + "nccl_reference": Bench(ag=run_nccl_reference_ag, rs=run_nccl_reference_rs), + "fused": Bench(ag=run_fused_ag, rs=run_fused_rs), + "fused_nowait": Bench(ag=run_fused_nowait_ag, rs=run_fused_nowait_rs), + "fused_nowait_nomemcpy": Bench( + ag=run_fused_nowait_nomemcpy_ag, rs=run_fused_nowait_nomemcpy_rs + ), + } + + unused_events = deque( + tuple(torch.cuda.Event(enable_timing=my_rank == 0) for _ in range(2)) + for f in range(len(all_benchs)) + ) + used_events = deque() + + timings = {} + + gen = random.Random(42) + + if profile: + profiler = torch.profiler.profile() + else: + profiler = contextlib.nullcontext() + + with profiler as p: + for method in gen.sample( + list(all_benchs), + k=num_rounds * len(all_benchs), + counts=[num_rounds] 
* len(all_benchs), + ): + fun = all_benchs[method][step] + + if unused_events: + start_ev, end_ev = unused_events.popleft() + else: + old_method, start_ev, end_ev = used_events.popleft() + end_ev.synchronize() + if my_rank == 0: + timings.setdefault(old_method, []).append( + start_ev.elapsed_time(end_ev) / num_bench_iters + ) + + for _ in range(num_warmup_iters): + fun() + start_ev.record() + for _ in range(num_bench_iters): + fun() + end_ev.record() + + used_events.append((method, start_ev, end_ev)) + + torch.cuda.synchronize() + + if profile: + p.export_chrome_trace(f"fusion_trace_{my_rank}.json") + + if my_rank == 0: + for method, start_ev, end_ev in used_events: + timings.setdefault(method, []).append( + start_ev.elapsed_time(end_ev) / num_bench_iters + ) + + for method in all_benchs: + print( + f"{method} = {mean(timings[method]):g}ms (+/- {stdev(timings[method]):g})" + ) + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("scenario", choices=SCENARIOS.keys()) + parser.add_argument("step", choices=list(Step), type=Step) + parser.add_argument("--world-size", type=int, default=8) + parser.add_argument("--dtype", choices=DTYPES.keys(), default="bfloat16") + parser.add_argument("--num-rounds", type=int, default=20) + parser.add_argument("--num-warmup-iters", type=int, default=5) + parser.add_argument("--num-bench-iters", type=int, default=50) + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + + conns_from_prev = [None] * args.world_size + conns_to_next = [None] * args.world_size + for rank in range(args.world_size): + end1, end2 = multiprocessing.get_context("spawn").Pipe(duplex=True) + conns_to_next[rank] = end1 + conns_from_prev[(rank + 1) % args.world_size] = end2 + + processes = [] + for rank in range(args.world_size): + p = multiprocessing.get_context("spawn").Process( + target=run_one_rank, + args=( + rank, + args.world_size, + args.scenario, + args.step, + args.dtype, + args.num_rounds, + 
args.num_warmup_iters, + args.num_bench_iters, + args.profile, + conns_from_prev[rank], + conns_to_next[rank], + ), + daemon=True, + ) + p.start() + processes.append(p) + + print("LAUNCHED") + + for rank, p in enumerate(processes): + p.join() + print(f"Rank {rank} exited with {p.exitcode}") + + print("JOINED") + + +if __name__ == "__main__": + main() diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_swiglu.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_swiglu.py new file mode 100644 index 0000000000000000000000000000000000000000..92a543d38ca5ea271d843d65be27d034f07aa322 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_swiglu.py @@ -0,0 +1,163 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import itertools +from contextlib import nullcontext +from functools import partial +from typing import Any + +import torch +from torch.utils import benchmark + +import xformers.ops.swiglu_op as xsw +from xformers.benchmarks.utils import benchmark_main_helper + +min_run_time = 0.5 +device = torch.device("cuda") + +SHAPES = [ + # Format: [inp.shape[0], inp.shape[1], hidden.shape[1]] + # ViT-Giant + (9456, 1536, 2736), + (4440, 1536, 2736), + (4728, 1536, 2736), + # Some smaller shapes as well + (4728, 1536, 1024), + # GPT-3 (small) + (32768, 2048, 5632), + # Chinchilla + (32768, 8192, 22016), +] + + +# OP = xsw._SwiGLUDecomposedOp +# OP = xsw.SwiGLUFusedOp +OP = xsw.SwiGLUPackedFusedOp + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + shape=SHAPES, + dtype=[torch.bfloat16, torch.half, "autocast_half"], + bias=[True, False], + ) +) + +DTYPE2STR = { + torch.bfloat16: "b16 ", + torch.half: "f16 
", + "autocast_half": "f16.ac", +} + + +def benchmark_swiglu(shape, dtype, bias: bool): + if dtype == "autocast_half": + inp_dtype, model_dtype, autocast = torch.float, torch.float, True + else: + inp_dtype, model_dtype, autocast = dtype, dtype, False + + x = torch.randn(shape[:2], device=device, dtype=inp_dtype) + module = ( + xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias) + .to(device) + .to(model_dtype) + ) + + dtype_str = DTYPE2STR.get(dtype, dtype) + bstr = "bias" if bias else "nobi" + sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}" + + params = module._ordered_params() + + PREFIX = 'with torch.autocast("cuda", dtype=torch.half):\n ' if autocast else "" + yield benchmark.Timer( + stmt=f"{PREFIX}fn(x, *args)", + globals={ + "x": x, + "args": params, + "fn": partial(xsw.swiglu, op=OP), + }, + label="swiglu_fw", + description=OP.NAME, + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt=f"{PREFIX}fn(x, *args)", + globals={ + "x": x, + "args": params, + "fn": partial(xsw.swiglu, op=xsw.SwiGLUEagerOp), + }, + label="swiglu_fw", + description="eager", + sub_label=sub_label, + ) + + +def benchmark_swiglu_bw(shape, dtype, bias: bool): + if dtype == "autocast_half": + inp_dtype, model_dtype = torch.float, torch.float + cm: Any = partial(torch.amp.autocast, "cuda", enabled=True, dtype=torch.float16) + else: + inp_dtype, model_dtype = dtype, dtype + cm = nullcontext + + x = torch.randn(shape[:2], device=device, dtype=inp_dtype) + x.requires_grad_() + module = ( + xsw.SwiGLU(in_features=shape[1], hidden_features=shape[2], bias=bias) + .to(device) + .to(model_dtype) + ) + + dtype_str = DTYPE2STR.get(dtype, dtype) + bstr = "bias" if bias else "nobi" + sub_label = f"{dtype_str} B={shape[0]}, I={shape[1]}, H={shape[2]} {bstr}" + + params = module._ordered_params() + with cm(): + out = xsw.swiglu(x, *params, op=OP) + grad = torch.zeros_like(out) + + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + 
globals={ + "out": out, + "grad": grad, + }, + label="swiglu_bw", + description=OP.NAME, + sub_label=sub_label, + ) + del out + + with cm(): + out = xsw.swiglu(x, *params, op=xsw.SwiGLUEagerOp) + + yield benchmark.Timer( + stmt="out.backward(grad, retain_graph=True)", + globals={ + "out": out, + "grad": grad, + }, + label="swiglu_bw", + description="eager", + sub_label=sub_label, + ) + + +if torch.version.hip: + print("This benchmark could not be done on ROCM!") +else: + benchmark_main_helper(benchmark_swiglu, CASES, min_run_time=min_run_time) + benchmark_main_helper(benchmark_swiglu_bw, CASES, min_run_time=min_run_time) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_tiled_matmul.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_tiled_matmul.py new file mode 100644 index 0000000000000000000000000000000000000000..f584d1bcbb6a166a31fb54d1017ce6e64ae4cf2c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/benchmark_tiled_matmul.py @@ -0,0 +1,135 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import itertools + +import torch +from torch.utils import benchmark +from triton.ops.matmul import matmul as triton_matmul + +from xformers.benchmarks.utils import DTYPE2STR, benchmark_main_helper +from xformers.ops.tiled_matmul import tiled_matmul + +min_run_time = 5 + + +SHAPES = { + "llama1_65b_mha_fwd": ([16384], [1024] * 3, [8192]), + "llama1_65b_mha_bwd_input": ([16384], [8192], [1024] * 3), + "llama1_65b_mha_bwd_weight": ([8192], [1024] * 3, [16384]), + "llama1_65b_ffn_fwd": ([16384], [2752] * 2, [8192]), + "llama1_65b_ffn_bwd_input": ([16384], [8192], [2752] * 2), + "llama1_65b_ffn_bwd_weight": ([8192], [2752] * 2, [16384]), + "llama2_150b_mha_fwd": ([16384], [1536, 128, 128], [12288]), + "llama2_150b_mha_bwd_input": ([16384], [12288], [1536, 128, 128]), + "llama2_150b_mha_bwd_weight": ([12288], [1536, 128, 128], [16384]), + "llama2_150b_ffn_fwd": ([16384], [4096] * 2, [12288]), + "llama2_150b_ffn_bwd_input": ([16384], [12288], [4096] * 2), + "llama2_150b_ffn_bwd_weight": ([12288], [4096] * 2, [16384]), +} + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +CASES = list( + product_dict( + shape_name=SHAPES.keys(), + dtype=[ + # torch.float32, + torch.bfloat16, + # torch.float16, + ], + ) +) + + +def matmul_per_tile(a, b): + c = [] + for n in range(len(a)): + c.append([]) + for m in range(len(b[0])): + c[-1].append( + sum([torch.matmul(a[n][k], b[k][m]) for k in range(len(a[0]))]) + ) + return c + + +def benchmark_tiled_matmul(shape_name, dtype): + ms, ns, ks = SHAPES[shape_name] + m, n, k = sum(ms), sum(ns), sum(ks) + + a = torch.randn((m, k), device="cuda", dtype=dtype) + b = torch.randn((k, n), device="cuda", dtype=dtype) + + a_tiles = [[y.clone() for y in x.split(ks, dim=1)] for x in a.split(ms, dim=0)] + b_tiles = [[y.clone() for y in x.split(ns, dim=1)] for x in b.split(ks, dim=0)] + + dtype_str = DTYPE2STR.get(dtype, dtype) + sub_label 
= ( + f"{dtype_str} {shape_name} " + f"M={'+'.join(f'{m}' for m in ms)} " + f"N={'+'.join(f'{n}' for n in ns)} " + f"K={'+'.join(f'{k}' for k in ks)}" + ) + + # Warmup (maybe not needed?) + torch.mm(a, b) + matmul_per_tile(a_tiles, b_tiles) + triton_matmul(a, b) + tiled_matmul(a_tiles, b_tiles) + + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a, + "b": b, + "fn": torch.mm, + }, + label="tiled_matmul", + description="pytorch_fused", + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a_tiles, + "b": b_tiles, + "fn": matmul_per_tile, + }, + label="tiled_matmul", + description="pytorch_tiled", + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a, + "b": b, + "fn": triton_matmul, + }, + label="tiled_matmul", + description="triton_fused", + sub_label=sub_label, + ) + yield benchmark.Timer( + stmt="fn(a, b)", + globals={ + "a": a_tiles, + "b": b_tiles, + "fn": tiled_matmul, + }, + label="tiled_matmul", + description="xformers_tiled", + sub_label=sub_label, + ) + + +benchmark_main_helper(benchmark_tiled_matmul, CASES, min_run_time=min_run_time) diff --git a/.venv/lib/python3.11/site-packages/xformers/benchmarks/utils.py b/.venv/lib/python3.11/site-packages/xformers/benchmarks/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..dc990b2204b03a990d242bac44150c420336505e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/benchmarks/utils.py @@ -0,0 +1,758 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import argparse +import contextlib +import copy +import csv +import functools +import glob +import itertools +import logging +import math +import os +import tempfile +from collections import defaultdict, namedtuple +from dataclasses import replace +from typing import Any, Dict, Generator, Iterator, List, Set, Tuple + +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +import tqdm +from torch.utils import benchmark + +sns.set() + +TestCase = namedtuple("TestCase", ["function", "name"]) + + +class NotSupportedInputError(Exception): + pass + + +_triton_is_available = torch.cuda.is_available() +if _triton_is_available: + try: + import triton + except ImportError as e: + logging.warning(f"Triton is not available: {e}.\nbench_functions") + _triton_is_available = False + + +def get_func_name(fn): + if isinstance(fn, functools.partial): + return fn.func.__name__ + return fn.__name__ + + +def pretty_print(results, title, units) -> None: + """Printout the contents of a dict as a human-readable and Markdown compatible array""" + print(title) + header = " Units: {:<45}".format(units) + print("| " + header + "|" + "".join("{0:<20}|".format(k) for k in results.keys())) + + offset = len(header) + print( + "|-{}|".format("-" * offset) + + "".join("{}|".format("-" * 20) for _ in results.keys()) + ) + + workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()} + for v in results.values(): + for k in v.keys(): + workloads[k].append(v[k]) + + for k, w in workloads.items(): + print( + "| {0:<{offset}}|".format(k, offset=offset) + + "".join("{:<20}|".format(v) for v in w) + ) + + print("") + + +def pretty_plot( + results, title, units: str, filename=None, dash_key="", legend_loc="lower right" +): + """Graph out the contents of a dict. 
+ Dash key means that if the result label has this key, then it will be displayed with a dash + """ + + if not filename: + filename = title + ".png" + + # Sanitize the filename + filename = ( + filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "") + ) + + # Gather all the results in "collumns" + workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()} + for v in results.values(): + for k in v.keys(): + workloads[k].append(float(v[k])) + + # Make sure that the plot is big enough + f = plt.figure() + f.set_figwidth(6) + f.set_figheight(6) + + # Display the collections + for k, v in workloads.items(): + if dash_key and dash_key in k: + plt.plot(list(results.keys()), v, "--") + else: + plt.plot(list(results.keys()), v) + + plt.title(title) + plt.legend(list(workloads.keys()), loc=legend_loc) + plt.ylabel(units) + plt.xticks(rotation=45) + + plt.savefig(filename, bbox_inches="tight") + plt.close(f) + + +if _triton_is_available: + + def bench_functions( + test_cases: List[TestCase], shapes, metric_transform, unit, title="" + ): + device = torch.device("cuda") + + for dtype in [torch.bfloat16, torch.float16, torch.float32]: + results: Dict[str, Any] = {} + + for B, M, K in shapes: + a = torch.rand(B, M, K, device=device, dtype=dtype, requires_grad=True) + + for testcase in test_cases: + time = triton.testing.do_bench(lambda: testcase.function(a))[0] + + metric = metric_transform(a, time) + + key = f"B={B}, M={M}, K={K}" + if key not in results: + results[key] = {} + + results[key][testcase.name] = f"{metric:.1f}" + + pretty_print( + results, + title=" ------------- Type: {} ------------- ".format(dtype), + units=unit, + ) + pretty_plot(results, title + str(dtype), unit, dash_key="pytorch") + + +def pretty_barplot(results, title, units: str, filename=None, dash_key=""): + """Graph out the contents of a dict. 
+ Dash key means that if the result label has this key, then it will be displayed with a dash + """ + + if not filename: + filename = title + ".png" + + # Sanitize the filename + filename = ( + filename.replace(" ", "_").replace("/", "_").replace("-", "_").replace(":", "") + ) + + xlabels = list(results.keys()) + # Gather all the results in "collumns" + workloads: Dict[str, Any] = {k: [] for v in results.values() for k in v.keys()} + for v in results.values(): + for k in v.keys(): + workloads[k].append(float(v[k])) + + options = list(workloads.keys()) + group_len = len(options) + for key in workloads.keys(): + num_groups = len(workloads[key]) + break + group_width = group_len + 1 + + # Make sure that the plot is big enough + f = plt.figure() + f.set_figwidth(6) + f.set_figheight(6) + + for idx in range(group_len): + option = options[idx] + values = workloads[option] + xloc = np.arange(1 + idx, group_width * num_groups, group_width) + plt.bar(xloc, values, width=1, edgecolor="black") + + plt.title(title) + plt.legend(list(workloads.keys()), loc="upper right") + plt.ylabel(units) + + ax = plt.gca() + xticks_loc = np.arange( + 1 + (group_len - 1) / 2.0, group_width * num_groups, group_width + ) + ax.set_xticks(xticks_loc, xlabels) + plt.xticks(rotation=45) + + plt.setp(ax.xaxis.get_majorticklabels(), ha="right") + ax.set_axisbelow(True) + ax.yaxis.grid(color="gray", linestyle="dashed") + ax.xaxis.grid(color="gray", linestyle="dashed") + + plt.savefig(filename, bbox_inches="tight") + plt.close(f) + + +def rmf(filename: str) -> None: + """Remove a file like rm -f.""" + try: + os.remove(filename) + except FileNotFoundError: + pass + + +@contextlib.contextmanager +def temp_files_ctx(num: int) -> Generator: + """A context to get tempfiles and ensure they are cleaned up.""" + files = [tempfile.mkstemp()[1] for _ in range(num)] + + yield tuple(files) + + # temp files could have been removed, so we use rmf. 
+ for name in files: + rmf(name) + + +META_ALGORITHM = "algorithm" +BASELINE_DESCRIPTIONS = ["eager", "vanilla", "pytorch"] + + +# Serialize/unserialize to CSV +# We could use pkl, but resort to CSV for readability +def _benchmark_results_from_csv(filename: str) -> List[Tuple[Dict[str, Any], Any]]: + parts = os.path.basename(filename).split(".") + env = "" + description = "" + if len(parts) == 3: + env = parts[1] + description = parts[0] + + data = [] + with open(filename, "r") as csvfile: + reader = csv.DictReader(csvfile) + for row in reader: + if description != "" and row["description"] not in BASELINE_DESCRIPTIONS: + row["description"] = description + task_spec = benchmark.utils.common.TaskSpec( + stmt="", + setup="", + global_setup="", + label=row["label"], + sub_label=row["sub_label"], + description=row["description"], + env=env, + num_threads=int(row["num_threads"]), + ) + measurement = benchmark.utils.common.Measurement( + number_per_run=1, + raw_times=[float(row["runtime_us"]) / (1000.0 * 1000)], + task_spec=task_spec, + ) + measurement.mem_use = float(row["mem_use_mb"]) # type: ignore + data.append( + ( + { + META_ALGORITHM: ( + row["algorithm"] if row["algorithm"] != "" else None + ), + }, + measurement, + ) + ) + return data + + +def _benchmark_results_to_csv( + filename: str, results: List[Tuple[Dict[str, Any], Any]] +) -> None: + data = [ + { + "sub_label": r.task_spec.sub_label, + "label": r.task_spec.label, + "num_threads": r.task_spec.num_threads, + "algorithm": metadata.get(META_ALGORITHM, ""), + "description": ( + r.task_spec.description + if r.task_spec.description in BASELINE_DESCRIPTIONS + else "" + ), + "runtime_us": int(1000 * 1000 * r.mean), + "mem_use_mb": r.mem_use, + } + for metadata, r in results + ] + with open(filename, "w+", newline="") as csvfile: + writer = csv.DictWriter(csvfile, fieldnames=list(data[0].keys())) + writer.writeheader() + for d in data: + writer.writerow(d) + + +def _finalize_results(results: List[Tuple[Dict[str, 
Any], Any]]) -> List[Any]: + """ + Returns a `benchmark.Compare` object, except that if we have runs + with different algorithms, we also add the algorithm name + in the column titles + """ + all_algorithms: Set[str] = set() + all_description: Set[str] = set() + for metadata, r in results: + algo = metadata.get(META_ALGORITHM, None) + if algo is not None: + all_algorithms.add(algo) + all_description.add(r.task_spec.description) + display_algo = len(all_algorithms) > 1 + display_descr = len(all_description) > 1 + + display_results = [] + for metadata, r in results: + algo = metadata.get(META_ALGORITHM, None) + if algo is None: + display_results.append(r) + else: + r = copy.copy(r) + description = "" + if display_descr: + description = r.task_spec.description + if display_algo: + if display_descr: + description += "[" + description += algo + if display_descr: + description += "]" + r.task_spec = replace(r.task_spec, description=description) + display_results.append(r) + return display_results + + +def _render_bar_plot(results: List[Any], store_results_folder: str) -> None: + if not results: + return + runtime: Dict[str, Dict[str, float]] = defaultdict(dict) + memory_usage: Dict[str, Dict[str, float]] = defaultdict(dict) + all_descriptions: List[str] = [] + for r in results: + # Hacky: use a list to preserve order + if r.task_spec.description not in all_descriptions: + if r.task_spec.description in BASELINE_DESCRIPTIONS: + all_descriptions.insert(0, r.task_spec.description) + else: + all_descriptions.append(r.task_spec.description) + runtime[r.task_spec.sub_label][r.task_spec.description] = r.mean + memory_usage[r.task_spec.sub_label][r.task_spec.description] = r.mem_use + all_data_mem: List[Any] = [] + all_data_run: List[Any] = [] + for key, runtime_values in runtime.items(): + memory_values = memory_usage[key] + denom = memory_values.get(all_descriptions[0], math.inf) + if denom == 0: + all_data_mem.append([key] + [0] * len(all_descriptions)) + else: + 
all_data_mem.append( + [key] + [memory_values.get(d, 0) / denom for d in all_descriptions] + ) + all_data_run.append( + [key] + + [ + runtime_values.get(all_descriptions[0], 0) + / runtime_values.get(d, math.inf) + for d in all_descriptions + ] + ) + if all_descriptions[0] == "": + all_descriptions[0] = "baseline" + else: + all_descriptions[0] = f"{all_descriptions[0]} (baseline)" + + for data, filename, title in [ + (all_data_mem, "mem.png", "Memory usage (vs baseline, lower is better)"), + ( + all_data_run, + "runtime.png", + "Runtime speedup (vs baseline, higher is better)", + ), + ]: + df = pd.DataFrame(data, columns=["Configuration"] + all_descriptions) + df.plot( + x="Configuration", + kind="bar", + stacked=False, + title=title, + ) + plt.tight_layout() + filename_full = os.path.join(store_results_folder, filename) + plt.savefig(filename_full) + print(f"Saved plot: {filename_full}") + + +def create_argparser() -> argparse.ArgumentParser: + """ + Create CLI argument parser. + """ + parser = argparse.ArgumentParser() + parser.add_argument( + "--fn", default=None, type=str, help="Only benchmark this function" + ) + parser.add_argument( + "--label", default=None, type=str, help="Store results to a file" + ) + parser.add_argument( + "--fail_if_regression", + action="store_true", + help="Enabled in CI to check against performance regressions", + ) + parser.add_argument( + "--compare", + default=None, + type=str, + help="Compare to previously stored benchmarks (coma separated)", + ) + parser.add_argument( + "--omit-baselines", + action="store_true", + help="Do not run the (potentially slow) baselines", + ) + parser.add_argument( + "--quiet", + action="store_true", + help="Skip intermediate results and progress bar", + ) + return parser + + +def benchmark_main_helper( + benchmark_fn, cases: List[Dict[str, Any]], arg_parser=None, **kwargs +) -> None: + """ + Helper function to run benchmarks. 
+ Supports loading previous results for comparison, and saving current results to file. + """ + arg_parser = arg_parser or create_argparser() + args = arg_parser.parse_args() + + if args.fn is not None and args.fn != get_func_name(benchmark_fn): + print(f'Skipping benchmark "{get_func_name(benchmark_fn)}"') + return + benchmark_run_and_compare( + benchmark_fn=benchmark_fn, + cases=cases, + optimized_label="optimized" if args.label is None else args.label, + fail_if_regression=args.fail_if_regression, + compare=args.compare.split(",") if args.compare is not None else [], + quiet=args.quiet, + omit_baselines=args.omit_baselines, + **kwargs, + ) + + +def benchmark_run_and_compare( + benchmark_fn, + cases: List[Dict[str, Any]], + compare: List[str], + omit_baselines: bool = False, + fail_if_regression: bool = False, + quiet: bool = False, + optimized_label: str = "optimized", + *, + min_run_time: float = 2.0, + atol_s: float = 30e-6, + rtol: float = 0.05, +) -> None: + SKIP_VANILLA_TASKS_IF_ALREADY_DONE = True + results_compare_to = [] + results = [] + + store_results_folder = os.path.expanduser( + os.path.join( + os.environ.get( + "XFORMERS_BENCHMARKS_CACHE", + os.path.join("~", ".cache", "xformers", "benchmarks"), + ), + get_func_name(benchmark_fn), + ) + ) + + try: + env = ( + torch.cuda.get_device_name(torch.cuda.current_device()) + .replace(" ", "_") + .replace("-", "_") + .replace(".", "_") + .replace("/", "_") + ) + except (RuntimeError, AssertionError): # No GPU + env = "cpu" + assert ( + "." not in optimized_label + ), f"label=`{optimized_label}` should not contain dots" + assert "." not in env, f"env=`{env}` should not contain dots" + + os.makedirs(store_results_folder, exist_ok=True) + + # Load runs that we want to compare to + skip_vanilla_tasks = set() + for cmp_name in compare: + name_with_env = cmp_name if "." 
in cmp_name else f"{cmp_name}.*" + for filename in glob.glob( + os.path.join(store_results_folder, f"{name_with_env}.csv") + ): + loaded = _benchmark_results_from_csv(filename) + for m, r in loaded: + if m.get(META_ALGORITHM) is not None: + m[META_ALGORITHM] = m[META_ALGORITHM].partition("@")[0] + if r.task_spec.env == env and SKIP_VANILLA_TASKS_IF_ALREADY_DONE: + skip_vanilla_tasks.add( + (r.task_spec.sub_label, r.task_spec.num_threads) + ) + results_compare_to += loaded + + if not quiet: + pbar = tqdm.tqdm(cases, leave=False) + cases = pbar + for case in cases: + if quiet: + print(str(case)) + else: + pbar.write(f"====== {str(case)} ======") + try: + benchmarks_generator = benchmark_fn(**case) + except NotImplementedError: + # pbar.write(f"Skipped (NotImplementedError)") + continue + except RuntimeError as e: + if not _is_oom_error(e): + raise + if not quiet: + pbar.write("Skipped (OOM)") + continue + + name = None + try: + for benchmark_object in benchmarks_generator: + is_optimized = ( + benchmark_object._task_spec.description not in BASELINE_DESCRIPTIONS + ) + metadata = {} + if is_optimized: + metadata[META_ALGORITHM] = benchmark_object._task_spec.description + benchmark_object._task_spec = replace( + benchmark_object._task_spec, description=optimized_label + ) + elif ( + omit_baselines + or ( + benchmark_object._task_spec.sub_label, + benchmark_object._task_spec.num_threads, + ) + in skip_vanilla_tasks + ): + continue + + memory = math.inf + try: + torch.cuda.synchronize() + torch.cuda.reset_peak_memory_stats() + mem_begin = torch.cuda.max_memory_allocated() / 2**20 + benchmark_object._task_spec = replace( + benchmark_object._task_spec, env=env + ) + measurement = benchmark_object.blocked_autorange( + min_run_time=min_run_time + ) + torch.cuda.synchronize() + results.append((metadata, measurement)) + name = measurement.task_spec.description + memory = torch.cuda.max_memory_allocated() / 2**20 - mem_begin + measurement.mem_use = memory + except RuntimeError 
as e: + if not _is_oom_error(e): + raise + if not quiet: + pbar.write("Skipped (OOM)") + finally: + del benchmark_object + if not quiet: + pbar.write(f"{name}: memory used: {memory} MB") + except RuntimeError as e: + if not _is_oom_error(e): + raise + if not quiet: + pbar.write("Skipped (OOM)") + # Display results for benchmarks we just calculated + if name is not None and not quiet: + + def matches_current(r): + return ( + r[1].task_spec.sub_label == results[-1][1].task_spec.sub_label + and r[1].task_spec.label == results[-1][1].task_spec.label + ) + + pbar.write( + str( + benchmark.Compare( + _finalize_results( + list(filter(matches_current, results)) + + list(filter(matches_current, results_compare_to)) + ) + ) + ) + ) + + results_for_print = _finalize_results(results + results_compare_to) + benchmark.Compare(results_for_print).print() + _render_bar_plot(results_for_print, store_results_folder) + + # Save runs to a file + if results and optimized_label is not None: + write_to_path = os.path.join( + store_results_folder, f"{optimized_label}.{env}.csv" + ) + _benchmark_results_to_csv(write_to_path, results) + print(f"Saved results to {write_to_path}") + + if fail_if_regression: + _fail_if_regressions( + results, reference=results_compare_to, atol_s=atol_s, rtol=rtol + ) + + +def _is_oom_error(e): + return isinstance( + e, (torch.cuda.OutOfMemoryError, triton.runtime.autotuner.OutOfResources) + ) + + +def _fail_if_regressions( + results: List[Any], reference: List[Any], atol_s: float, rtol: float +) -> None: + def get_measurement_id(r): + return ( + r[0].get(META_ALGORITHM, "").partition("@")[0], + r[1].task_spec.label, + r[1].task_spec.sub_label, + r[1].task_spec.env, + ) + + id_to_result = {} + for r in results: + id_to_result[get_measurement_id(r)] = r[1] + + num_better = 0 + num_worse = 0 + num_nochange = 0 + num_unk = 0 + reference_set = set() + for ref in reference: + if ref[1].task_spec.description in BASELINE_DESCRIPTIONS: + continue + benchmark_id = 
get_measurement_id(ref) + if benchmark_id in reference_set: + raise ValueError(f"Duplicate benchmark in reference for {benchmark_id}") + reference_set.add(benchmark_id) + if benchmark_id not in id_to_result: + num_unk += 1 + continue + res = id_to_result[benchmark_id] + # If significative change + if abs(ref[1].mean - res.mean) - rtol * ref[1].mean > atol_s: + is_now_better = res.mean < ref[1].mean + if is_now_better: + num_better += 1 + else: + num_worse += 1 + cmp = "IMPROVED" if is_now_better else "REGRESS " + print(cmp, benchmark_id, f"ref={ref[1].mean}", f"now={res.mean}") + else: + num_nochange += 1 + + print("Regression test summary:") + print(f" Better : {num_better}") + print(f" No change: {num_nochange}") + print(f" Worse : {num_worse}") + if num_unk > 0: + print(f" (no ref) : {num_unk}") + benchmarks_run = num_better + num_nochange + num_worse + if num_worse > 1: + raise RuntimeError("At least one benchmark regressed!") + elif num_unk == benchmarks_run: + raise RuntimeError("No reference found") + elif benchmarks_run == 0: + raise RuntimeError("No benchmark was run") + + +def benchmark_main_helper2( + name: str, + functions, + fw: bool = False, + bw: bool = False, + cuda_graph: bool = True, + **kwargs, +) -> None: + assert fw or bw + + def handle_case(**case) -> Iterator[benchmark.Timer]: + for k, benchmark_cls in functions.items(): + try: + benchmark_object = benchmark_cls(**case, bw=bw) + except NotSupportedInputError: + continue + label = benchmark_object.label + label += "fw" if fw else "" + label += "bw" if bw else "" + + def run_one(): + if fw: + benchmark_object.fw() + if bw: + benchmark_object.bw() + + if cuda_graph: + run_one() + g = torch.cuda.CUDAGraph() + with torch.cuda.graph(g): + run_one() + + def run_one(): + g.replay() + + yield benchmark.Timer( + stmt="fn()", + globals={ + "fn": run_one, + }, + label=label, + description=k, + sub_label=benchmark_object.sub_label, + ) + + handle_case.__name__ = name + benchmark_main_helper(handle_case, 
**kwargs) + + +def product_dict(**kwargs): + keys = kwargs.keys() + vals = kwargs.values() + for instance in itertools.product(*vals): + yield dict(zip(keys, instance)) + + +DTYPE2STR = { + torch.bfloat16: "b16", + torch.half: "f16", + torch.float32: "f32", +} diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/__init__.py b/.venv/lib/python3.11/site-packages/xformers/sparse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7238f5be0142089cbcf25a72f26559c6a9ee0da2 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/sparse/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + +from .blocksparse_tensor import BlockSparseTensor # noqa: F401 +from .csr_tensor import SparseCSRTensor # noqa: F401 diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8e15ad12059203390669468101e0de6f81a26480 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/_csr_ops.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/_csr_ops.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b072a985e8d4c0455bd0f127647f4a7e42720e6d Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/_csr_ops.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/blocksparse_tensor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/blocksparse_tensor.cpython-311.pyc new 
file mode 100644 index 0000000000000000000000000000000000000000..e3d465f906b7f46c08d84665a44a57872231a9f6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/blocksparse_tensor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/csr_tensor.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/csr_tensor.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..207adb3c25061253cadc9982214205f315db597b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/csr_tensor.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0eb50dcab1838b9fa89c44bdc431c1b96e4a2a8f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/xformers/sparse/__pycache__/utils.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/_csr_ops.py b/.venv/lib/python3.11/site-packages/xformers/sparse/_csr_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..3b3115ee1ca58e238b6bb3c68082d453f7ec7b9c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/sparse/_csr_ops.py @@ -0,0 +1,166 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch + +from .utils import _csr_to_coo, _transpose_with_info + + +def _should_use_coo(a, sparsity): + if not a.is_cuda: + return False + B, M, K = a.shape + # amortize overhead of converting from csr to coo + if B < 32 and M < 4096: + return False + if sparsity > 0.995: + return False + if sparsity < 0.9: + return False + if K > 64: + return False + # let's be overly cautious here for now + return sparsity > 0.97 + + +def _should_use_csr_ge(a, sparsity): + if not a.is_cuda: + return False + return sparsity > 0.99 + + +def _sddmm_func(a, b, row_indices, row_offsets, column_indices): + sparsity = 1 - column_indices.shape[0] / (a.shape[1] * b.shape[1]) + if _should_use_coo(a, sparsity): + m = a.shape[-2] + n = b.shape[-2] + # converting from csr to coo has a constant overhead of ~150us + # so only dispatch to it for reasonably large problem sizes + ro, ci = _csr_to_coo(m, n, row_offsets, column_indices) + return torch.ops.xformers.coo_sddmm(a, b, row_indices, ro, ci) + elif _should_use_csr_ge(a, sparsity): + return torch.ops.xformers.csr_sddmm( + a, b, row_indices, row_offsets, column_indices + ) + return torch.ops.xformers.sddmm_sputnik( + a, b, row_indices, row_offsets, column_indices + ) + + +class _SparseSoftmax(torch.autograd.Function): + @staticmethod + def forward(ctx, m, n, row_indices, values, row_offsets, column_indices): + out = torch.ops.xformers.sparse_softmax_sputnik( + m, n, row_indices, values, row_offsets, column_indices + ) + # note: save out and not values, as an optimization step + ctx.save_for_backward(row_indices, out, row_offsets, column_indices) + ctx.size = (m, n) + return out + + @staticmethod + def backward(ctx, grad): + row_indices, out, row_offsets, column_indices = ctx.saved_tensors + m, n = ctx.size + + # gradients w.r.t. 
values + grad = grad.contiguous() + ga = torch.ops.xformers.sparse_softmax_backward_sputnik( + m, n, row_indices, out, grad, row_offsets, column_indices + ) + + return None, None, None, ga, None, None + + +class _sddmm(torch.autograd.Function): + @staticmethod + def forward(ctx, a, b, row_indices, row_offsets, column_indices, _transp_info): + out = _sddmm_func(a, b, row_indices, row_offsets, column_indices) + + ctx.save_for_backward( + a, b, row_indices, row_offsets, column_indices, *_transp_info + ) + return out + + @staticmethod + def backward(ctx, grad): + ( + a, + b, + row_indices, + row_offsets, + column_indices, + *_transp_info, + ) = ctx.saved_tensors + m, n = a.shape[1], b.shape[1] + + # gradients w.r.t. values + grad = grad.contiguous() + a = a.contiguous() + b = b.contiguous() + + a_grad = torch.ops.xformers.spmm_sputnik( + b, row_indices, grad, row_offsets, column_indices, m + ) + + ( + row_indices_t, + grad_t, + row_offsets_t, + column_indices_t, + ) = _transpose_with_info(grad, _transp_info) + + b_grad = torch.ops.xformers.spmm_sputnik( + a, row_indices_t, grad_t, row_offsets_t, column_indices_t, n + ) + + return a_grad, b_grad, None, None, None, None + + +class _spmm(torch.autograd.Function): + @staticmethod + def forward( + ctx, b, row_indices, values, row_offsets, column_indices, m, _transp_info + ): + b = b.contiguous() + out = torch.ops.xformers.spmm_sputnik( + b, row_indices, values, row_offsets, column_indices, m + ) + + ctx.save_for_backward( + b, row_indices, values, row_offsets, column_indices, *_transp_info + ) + return out + + @staticmethod + def backward(ctx, grad): + ( + b, + row_indices, + values, + row_offsets, + column_indices, + *_transp_info, + ) = ctx.saved_tensors + k = b.shape[1] + + # gradients w.r.t. 
values + grad = grad.contiguous() + + grad_sparse = _sddmm_func(grad, b, row_indices, row_offsets, column_indices) + + ( + row_indices_t, + values_t, + row_offsets_t, + column_indices_t, + ) = _transpose_with_info(values, _transp_info) + + grad_dense = torch.ops.xformers.spmm_sputnik( + grad, row_indices_t, values_t, row_offsets_t, column_indices_t, k + ) + + return grad_dense, None, grad_sparse, None, None, None, None diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/blocksparse_tensor.py b/.venv/lib/python3.11/site-packages/xformers/sparse/blocksparse_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..a3b9db039b25a4b494591aed9bd425b0d4eeac3c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/sparse/blocksparse_tensor.py @@ -0,0 +1,278 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging + +import torch + +from xformers.ops import masked_matmul + +logger = logging.getLogger("xformers") + + +# TODO: This is all now deprecated because PyTorch has its own blocksparse ops +def _spmm(b, layout, values): + N, nnz, _, block_size = values.shape + br = b.reshape( + b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3] + ) + # perform matmul on blocks + h, r, c = layout.nonzero(as_tuple=True) + temp = values @ br[:, h, c, :] + + linear_idx = h * (b.shape[2] // block_size) + r + out = torch.zeros( + N, + b.shape[1] * layout.shape[-2], + block_size, + b.shape[3], + dtype=b.dtype, + device=b.device, + ) + # now aggregate the results of the different blocks + out.index_add_(1, linear_idx.to(b.device), temp) + out = out.reshape(N, b.shape[1], -1, b.shape[3]) + return out + + +def _softmax(layout, values): + h, r, c = layout.nonzero(as_tuple=True) + norms = torch.logsumexp(values, dim=-1, keepdim=True) + linear_idx = h * layout.shape[1] + r + + out_t = torch.zeros( + norms.shape[0], + layout.shape[0] * layout.shape[1], + norms.shape[2], + norms.shape[3], + dtype=norms.dtype, + device=norms.device, + ) + max_val = norms.max() + out_t.index_add_( + 1, linear_idx.to(values.device), (norms - max_val).exp() + ).clamp_min_(1e-24).log_().add_(max_val) + out = torch.exp(values - out_t[:, linear_idx]) + return out + + +def _sddmm(a, b, layout): + block_size = a.shape[-2] // layout.shape[-2] + a = a.reshape( + a.shape[0], a.shape[1], a.shape[2] // block_size, block_size, a.shape[3] + ) + b = b.reshape( + b.shape[0], b.shape[1], b.shape[2] // block_size, block_size, b.shape[3] + ) + + h, r, c = layout.nonzero(as_tuple=True) + + out = torch.einsum("nhik,nhjk->nhij", a[:, h, r, :, :], b[:, h, c, :, :]) + return out + + +class BlockSparseTensor(torch.Tensor): + @staticmethod + def __new__(cls, values, layout): + kwargs = {} + kwargs["device"] = values.device + kwargs["dtype"] = values.dtype + kwargs["layout"] = values.layout + 
kwargs["requires_grad"] = values.requires_grad + assert values.ndim == 4 + B, _, block_size, _ = values.shape + C, h, w = layout.shape + # TODO validate shape of layout vs values + shape = (B, C, block_size * h, block_size * w) + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, values, layout): + assert values.shape[-2] == values.shape[-1] + assert ( + values.device == layout.device + ), "Both values and layout need to reside on the same device" + block_size = values.shape[-1] + # TODO: make this check conditioned on the use of Triton + assert block_size >= 16, "Minimum block size is 16, for now at least" + + # Pure blocksparse data + self.__values = values + self.__layout = layout + + def __repr__(self): + return f"block_sparse_tensor(shape={self.shape}, values={self.__values})" + + def values(self): + return self.__values + + @classmethod + def _raw_wrap(cls, values, layout): + matrix = cls.__new__(cls, values, layout) + matrix.__values = values + matrix.__layout = layout + return matrix + + @classmethod + def _wrap(cls, values, bmat): + matrix = cls.__new__(cls, values, bmat.__layout) + matrix.__values = values + matrix.__layout = bmat.__layout + return matrix + + @classmethod + def _bmm(cls, arg0, arg1): + if not (isinstance(arg0, cls) and type(arg1) is torch.Tensor): + return NotImplemented + res = _spmm(arg1, arg0.__layout, arg0.__values) + return res + + @classmethod + def _masked_matmul(cls, a, b, mask): + if not (type(a) is torch.Tensor and type(b) is torch.Tensor): + return NotImplemented + b = b.transpose(-2, -1) + assert b.is_contiguous() + res = _sddmm(a, b, mask.__layout) + return cls._wrap(res, mask) + + @classmethod + def _softmax(cls, arg0, dim): + if not (dim == -1 or dim == 2): + return NotImplemented + res = _softmax(arg0.__layout, arg0.__values) + return cls._wrap(res, arg0) + + @classmethod + def _to(cls, arg0, device): + if isinstance(device, str): + device = torch.device(device) + assert 
isinstance(device, torch.device) + return cls( + arg0.__values.to(device=device), + arg0.__layout, + ) + + @classmethod + def _copy(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + assert arg0.shape == arg1.shape + av0, av1 = arg0.__values, arg1.__values + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__layout, arg1.__layout + av0.resize_as_(av1).copy_(av1) + return arg0 + + @classmethod + def _equal(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + if arg0.shape != arg1.shape: + return False + if not torch.equal(arg0.__values, arg1.__values): + return False + if not torch.equal(arg0.__layout, arg1.__layout): + return False + return True + + @classmethod + def _to_dense(cls, arg0): + # out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device, requires_grad=arg0.requires_grad) + out = torch.zeros(arg0.shape, dtype=arg0.dtype, device=arg0.device) + values = arg0.__values + layout = arg0.__layout + block_size = values.shape[-1] + blocks_i = layout.shape[-2] + blocks_j = layout.shape[-1] + + out_r = out.reshape( + arg0.shape[0], arg0.shape[1], blocks_i, block_size, blocks_j, block_size + ) + + for idx, (h, i, j) in enumerate(zip(*layout.nonzero(as_tuple=True))): + out_r[:, h, i, :, j, :] = values[:, idx, :, :] + + return out + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in [ + torch.Tensor.bmm, + torch.bmm, + torch.Tensor.__matmul__, + torch.matmul, + torch.Tensor.matmul, + ]: + assert len(args) == 2 + return cls._bmm(args[0], args[1]) + + if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]: + return cls._softmax(args[0], kwargs["dim"]) + + if func == masked_matmul: + assert len(args) == 3 + return cls._masked_matmul(args[0], args[1], args[2]) + + if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]: + x = args[0] + values = 
x.__values.clone() + values = func(values, *args[1:], **kwargs) + return cls._wrap(values, x) + + if func == torch.Tensor.to: + # print(args, kwargs) + assert len(args) >= 2 + return cls._to(args[0], args[1]) + # return cls._to(args[0], kwargs["device"]) + + if func in [torch.Tensor.copy_]: + assert len(args) == 2 + return cls._copy(args[0], args[1]) + + if func in [torch.Tensor.equal, torch.equal]: + assert len(args) == 2 + return cls._equal(args[0], args[1]) + + if func == torch.Tensor.to_dense: + assert len(args) == 1 + return cls._to_dense(args[0]) + + if func == torch.Tensor.detach: + x = args[0] + values = x.__values.clone() + values = func(values, *args[1:], **kwargs) + return cls._wrap(values, x) + + if func == torch.Tensor.__deepcopy__: + x = args[0] + memo = args[1] + return cls._raw_wrap( + x.__values.__deepcopy__(memo), + x.__layout.__deepcopy__(memo), + ) + + if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]: + assert len(args) == 1 + assert len(kwargs) == 0 + x = args[0] + return cls._wrap(x.__values.grad, x) + + if func == torch.Tensor.requires_grad_: + func(args[0].__values) + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + # TODO: check this + if func in torch.overrides.get_default_nowrap_functions(): + return ret + return torch._tensor._convert(ret, cls) + + return NotImplemented + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + return NotImplemented diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/csr_tensor.py b/.venv/lib/python3.11/site-packages/xformers/sparse/csr_tensor.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec9846c39f204b42f057409fd2dc97787af33b6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/sparse/csr_tensor.py @@ -0,0 +1,437 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. 
+# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. + + +import torch + +from xformers.ops import masked_matmul +from xformers.sparse import _csr_ops +from xformers.sparse.utils import ( + _csr_to_coo, + _dense3d_to_sparse, + _diffsort, + _get_transpose_info, + _transpose_with_info, +) + + +class SparseCSRTensor(torch.Tensor): + @staticmethod + def __new__(cls, row_offsets, column_indices, values, shape): + kwargs = {} + kwargs["device"] = values.device + kwargs["dtype"] = values.dtype + kwargs["layout"] = values.layout + kwargs["requires_grad"] = values.requires_grad + assert len(shape) == 3 + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) + + def __init__(self, row_offsets, column_indices, values, shape): + assert row_offsets.ndim == 1 + assert column_indices.ndim == 1 + assert values.ndim == 2 + + self.__row_offsets = row_offsets.contiguous() + self.__row_indices = _diffsort(row_offsets).to(row_offsets.dtype) + self.__column_indices = column_indices.contiguous() + self.__values = values.contiguous() + + self.__transp_info = _get_transpose_info( + self.shape[1], + self.shape[2], + self.__row_indices, + self.__row_offsets, + self.__column_indices, + ) + + def __repr__(self): + return f"sparse_csr_tensor(shape={self.shape}, values={self.__values})" + + @classmethod + def from_dense(cls, matrix): + values, row_indices, row_offsets, column_indices = _dense3d_to_sparse( + matrix, matrix.device + ) + return cls(row_offsets, column_indices, values, matrix.shape) + + @classmethod + def from_sparse_coo(cls, arg0): + """ + assert arg0.is_sparse + x = arg0.coalesce() + rows, cols = x.indices().unbind(0) + vals = x.values() + _coo_to_csr() + """ + pass + + @classmethod + def _wrap( + cls, shape, values, row_indices, row_offsets, column_indices, _transp_info + ): + matrix = cls.__new__(cls, row_offsets, column_indices, values, shape) + matrix.__values = values + matrix.__row_indices 
= row_indices + matrix.__row_offsets = row_offsets + matrix.__column_indices = column_indices + matrix.__transp_info = _transp_info + return matrix + + def values(self): + return self.__values + + @property + def _csr_row_indices(self): + return self.__row_indices + + @property + def _csr_row_offsets(self): + return self.__row_offsets + + @property + def _csr_column_indices(self): + return self.__column_indices + + @property + def _csr_transp_info(self): + return self.__transp_info + + @classmethod + def _bmm(cls, arg0, arg1): + if not (isinstance(arg0, cls) and type(arg1) is torch.Tensor): + return NotImplemented + + assert arg0.ndim == 3 + assert arg1.ndim == 3 + + self = arg0 + b = arg1 + + _, m, n = self.shape + row_indices = self.__row_indices + values = self.__values + row_offsets = self.__row_offsets + column_indices = self.__column_indices + + out = _csr_ops._spmm.apply( + b, row_indices, values, row_offsets, column_indices, m, self.__transp_info + ) + return out + + @classmethod + def _softmax(cls, arg0, dim): + if not (dim == -1 or dim == 2): + return NotImplemented + + self = arg0 + _, m, n = self.shape + row_indices = self.__row_indices + values = self.__values + row_offsets = self.__row_offsets + column_indices = self.__column_indices + out = _csr_ops._SparseSoftmax.apply( + m, n, row_indices, values, row_offsets, column_indices + ) + return cls._wrap( + self.shape, + out, + row_indices, + row_offsets, + column_indices, + self.__transp_info, + ) + + @classmethod + def _transpose(cls, arg0, dim0, dim1): + # TODO: check if need to return this or not + if not (dim0 == 1 or dim0 == -2): + return NotImplemented + if not (dim1 == 2 or dim1 == -1): + return NotImplemented + + B, m, n = arg0.shape + values = arg0.__values + + ( + output_row_indices, + output_values, + output_row_offsets, + output_column_indices, + ) = _transpose_with_info(values, arg0.__transp_info) + new_transp_info = _get_transpose_info( + n, m, output_row_indices, output_row_offsets, 
output_column_indices + ) + + return cls._wrap( + (B, n, m), + output_values, + output_row_indices, + output_row_offsets, + output_column_indices, + new_transp_info, + ) + + @classmethod + def _masked_matmul(cls, a, b, mask): + if not (type(a) is torch.Tensor and type(b) is torch.Tensor): + return NotImplemented + assert mask.shape[1] == a.shape[1] + assert mask.shape[2] == b.shape[2] + row_indices = mask.__row_indices + row_offsets = mask.__row_offsets + column_indices = mask.__column_indices + a = a.contiguous() + out = _csr_ops._sddmm.apply( + a, + b.transpose(-2, -1).contiguous(), + row_indices, + row_offsets, + column_indices, + mask.__transp_info, + ) + # TODO add bias here + return cls._wrap( + mask.shape, + out, + row_indices, + row_offsets, + column_indices, + mask.__transp_info, + ) + + @classmethod + def _to(cls, arg0, device): + if isinstance(device, str): + device = torch.device(device) + assert isinstance(device, torch.device) + return cls._wrap( + arg0.shape, + arg0.__values.to(device=device), + arg0.__row_indices.to(device=device), + arg0.__row_offsets.to(device=device), + arg0.__column_indices.to(device=device), + tuple(t.to(device=device) for t in arg0.__transp_info), + ) + + @classmethod + def _copy(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + assert arg0.shape == arg1.shape + av0, av1 = arg0.__values, arg1.__values + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__row_indices, arg1.__row_indices + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__row_offsets, arg1.__row_offsets + av0.resize_as_(av1).copy_(av1) + av0, av1 = arg0.__column_indices, arg1.__column_indices + av0.resize_as_(av1).copy_(av1) + for v0, v1 in zip(arg0.__transp_info, arg1.__transp_info): + v0.resize_as_(v1).copy_(v1) + return arg0 + + @classmethod + def _equal(cls, arg0, arg1): + if not (isinstance(arg0, cls) and isinstance(arg1, cls)): + return NotImplemented + if arg0.shape != arg1.shape: + return False + if 
not torch.equal(arg0.__values, arg1.__values): + return False + if not torch.equal(arg0.__row_offsets, arg1.__row_offsets): + return False + if not torch.equal(arg0.__column_indices, arg1.__column_indices): + return False + return True + + @classmethod + def _to_dense(cls, arg0): + _, m, n = arg0.shape + shape = arg0.shape + matrix = torch.zeros(shape, dtype=arg0.dtype, device=arg0.device) + row_offsets = arg0.__row_offsets.long() + column_indices = arg0.__column_indices.long() + row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices) + b_idxs = torch.arange(len(arg0.__values), device=arg0.device)[:, None] + matrix[b_idxs, row_coo, column_indices] = arg0.__values + return matrix + + @classmethod + def _binary_op(cls, func, arg0, arg1): + if not ( + isinstance(arg0, (cls, int, float)) and isinstance(arg1, (cls, int, float)) + ): + return NotImplemented + v0, v1 = arg0, arg1 + if isinstance(arg0, cls): + v0 = arg0.__values + if isinstance(arg1, cls): + v1 = arg1.__values + # assert arg0.shape == arg1.shape + if isinstance(arg0, cls) and isinstance(arg1, cls): + msg = f"arg0 and arg1 need to have the same sparsity pattern in {func} (for now)" + if not arg0.__row_offsets.shape == arg1.__row_offsets.shape: + raise NotImplementedError(msg) + if not arg0.__column_indices.shape == arg1.__column_indices.shape: + raise NotImplementedError(msg) + if not arg0.__values.shape == arg1.__values.shape: + raise NotImplementedError(msg) + # TODO this is not always true, but is a fast approximation for now + if arg0.__row_offsets is not arg1.__row_offsets: + raise NotImplementedError(msg) + if arg0.__column_indices is not arg1.__column_indices: + raise NotImplementedError(msg) + out = func(v0, v1) + return cls._wrap( + arg0.shape, + out, + arg0.__row_indices, + arg0.__row_offsets, + arg0.__column_indices, + arg0.__transp_info, + ) + + @classmethod + def _binary_op_slow(cls, func, arg0, arg1): + # assert arg0.shape == arg1.shape + v0, v1 = arg0, arg1 + if isinstance(arg0, cls): + 
v0 = arg0.to_dense() + if isinstance(arg1, cls): + v1 = arg1.to_dense() + out = func(v0, v1) + return cls.from_dense(out) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + if func in [ + torch.Tensor.bmm, + torch.bmm, + torch.Tensor.__matmul__, + torch.matmul, + torch.Tensor.matmul, + ]: + assert len(args) == 2 + return cls._bmm(args[0], args[1]) + + if func in [torch.Tensor.softmax, torch.nn.functional.softmax, torch.softmax]: + return cls._softmax(args[0], kwargs["dim"]) + + if func in [torch.Tensor.transpose, torch.transpose]: + assert len(kwargs) == 0 + return cls._transpose(args[0], args[1], args[2]) + + if func == masked_matmul: + assert len(args) == 3 + return cls._masked_matmul(args[0], args[1], args[2]) + + if func in [ + torch.Tensor.add, + torch.add, + torch.Tensor.__add__, + ]: + assert len(args) == 2 + if not (isinstance(args[0], cls) and isinstance(args[1], cls)): + raise NotImplementedError( + f"{func} with {type(args[0])} and {type(args[1])} not implemented" + ) + return cls._binary_op(func, args[0], args[1]) + + if func in [ + torch.Tensor.mul, + torch.mul, + torch.Tensor.__mul__, + ]: + assert len(args) == 2 + return cls._binary_op(func, args[0], args[1]) + + if func in [torch.Tensor.logical_and, torch.logical_and, torch.Tensor.__and__]: + assert len(args) == 2 + return cls._binary_op_slow(func, args[0], args[1]) + + if func in [torch.nn.functional.dropout, torch.dropout, torch.dropout_]: + x = args[0] + values = x.__values.clone() + values = func(values, *args[1:], **kwargs) + return cls._wrap( + x.shape, + values, + x.__row_indices, + x.__row_offsets, + x.__column_indices, + x.__transp_info, + ) + + if func == torch.Tensor.to: + # print(args, kwargs) + assert len(args) >= 2 + return cls._to(args[0], args[1]) + # return cls._to(args[0], kwargs["device"]) + + if func in [torch.Tensor.copy_]: + assert len(args) == 2 + return cls._copy(args[0], args[1]) + + if func in 
[torch.Tensor.equal, torch.equal]: + assert len(args) == 2 + return cls._equal(args[0], args[1]) + + if func == torch.Tensor.to_dense: + assert len(args) == 1 + return cls._to_dense(args[0]) + + if func == torch.Tensor.detach: + x = args[0] + return cls._wrap( + x.shape, + x.__values.detach(), + x.__row_indices, + x.__row_offsets, + x.__column_indices, + x.__transp_info, + ) + + if func == torch.Tensor.__deepcopy__: + x = args[0] + memo = args[1] + return cls._wrap( + x.shape, + x.__values.__deepcopy__(memo), + x.__row_indices.__deepcopy__(memo), + x.__row_offsets.__deepcopy__(memo), + x.__column_indices.__deepcopy__(memo), + tuple(v.__deepcopy__(memo) for v in x.__transp_info), + ) + + if func in [torch.Tensor.grad.__get__, torch.Tensor._grad.__get__]: + assert len(args) == 1 + assert len(kwargs) == 0 + x = args[0] + return cls._wrap( + x.shape, + x.__values.grad, + x.__row_indices, + x.__row_offsets, + x.__column_indices, + x.__transp_info, + ) + + if func == torch.Tensor.requires_grad_: + func(args[0].__values) + + with torch._C.DisableTorchFunction(): + ret = func(*args, **kwargs) + # TODO: check this + if func in torch.overrides.get_default_nowrap_functions(): + return ret + return torch._tensor._convert(ret, cls) + + return NotImplemented + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + return NotImplemented diff --git a/.venv/lib/python3.11/site-packages/xformers/sparse/utils.py b/.venv/lib/python3.11/site-packages/xformers/sparse/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0031e4b0997fb283c6b97dbd6de5cd237dafe55b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/xformers/sparse/utils.py @@ -0,0 +1,123 @@ +# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. +# +# This source code is licensed under the BSD license found in the +# LICENSE file in the root directory of this source tree. 
+ + +import torch + + +def _coo_to_csr(m, n, row_indices, column_indices): + # assumes coalesced coo + row_offsets = row_indices.bincount(minlength=n).cumsum(0, dtype=row_indices.dtype) + row_offsets = torch.nn.functional.pad(row_offsets, (1, 0)) + return row_offsets, column_indices + + +def _csr_to_coo(m, n, row_offsets, column_indices): + # convert from compressed rows to uncompressed + indices = torch.arange(m, dtype=row_offsets.dtype, device=row_offsets.device) + row_sizes = torch.diff(row_offsets) + row_coo = torch.repeat_interleave(indices, row_sizes.long()) + return row_coo, column_indices + + +def _diffsort(a): + return torch.argsort(torch.diff(a), dim=0, descending=True) + + +def _get_transpose_info(m, n, row_indices, row_offsets, column_indices): + # strategy: + # - uncompress the rows to have data in COO format + # - get permutation for stable sort of the columns to get the rows for the transposed matrix + # - compress the new rows and return the permutation to be applied on the values + + # convert from compressed rows to uncompressed + row_coo, _ = _csr_to_coo(m, n, row_offsets, column_indices) + + # get the permutation for the stable sort + row_offsets_t, perm = column_indices.sort(dim=0, stable=True) + column_indices_t = row_coo[perm] + + row_offsets_t, _ = _coo_to_csr(m, n, row_offsets_t, column_indices) + row_indices_t = _diffsort(row_offsets_t).int() + + return row_indices_t, row_offsets_t, column_indices_t, perm + + +def _transpose_with_info(values, _transpose_info): + row_indices_t, row_offsets_t, column_indices_t, perm = _transpose_info + values_t = values[:, perm] + return row_indices_t, values_t, row_offsets_t, column_indices_t + + +def _transpose(m, n, row_indices, values, row_offsets, column_indices): + _transpose_info = _get_transpose_info( + m, n, row_indices, row_offsets, column_indices + ) + return _transpose_with_info(values, _transpose_info) + + +def _nonzero_mask_to_sparse_csr_indices(mask, device): + """Converts dense 2d matrix to a 
csr sparse matrix.""" + + assert len(mask.shape) == 2 + index_dtype = torch.int32 + + # Calculate the offset of each row. + row_offsets = mask.sum(dim=-1, dtype=index_dtype).cumsum(dim=-1, dtype=index_dtype) + row_offsets = torch.nn.functional.pad(row_offsets, (1, 0)) + + # Create the row indices and sort them. + row_indices = _diffsort(row_offsets).to(index_dtype) + + # Extract the column indices for the nonzero values. + column_indices = torch.where(mask)[1].to(index_dtype).contiguous() + + row_indices = row_indices.to(device) + row_offsets = row_offsets.to(device) + column_indices = column_indices.to(device) + return row_indices, row_offsets, column_indices + + +def _dense_to_sparse(matrix, device): + """Converts dense 2d matrix to a csr sparse matrix.""" + + assert len(matrix.shape) == 2 + value_dtype = torch.float32 + + # Extract the nonzero values. + mask = matrix != 0 + values = matrix[mask].to(dtype=value_dtype, device=device) + + row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices( + mask, device + ) + return values, row_indices, row_offsets, column_indices + + +def _round_nnz(mask, divisible_by=4): + nonzero = torch.where(mask) + nnz = nonzero[0].shape[0] + nonzero = tuple(n[: (nnz - nnz % divisible_by)] for n in nonzero) + nm = torch.zeros_like(mask) + nm[nonzero] = True + return nm + + +def _dense3d_to_sparse(matrix, device): + assert len(matrix.shape) == 3 + mask = matrix != 0 + if not torch.all(mask == mask[0]): + raise ValueError("Expected the same sparsity pattern over the batch dimension") + + # for now, our kernels assume that we have the number of + # nnz to be divisible by 4 + mask = _round_nnz(mask[0], divisible_by=4) + mask = mask[None].expand(matrix.shape) + + values = matrix[mask].reshape(matrix.shape[0], -1).to(device) + row_indices, row_offsets, column_indices = _nonzero_mask_to_sparse_csr_indices( + mask[0], device + ) + return values, row_indices, row_offsets, column_indices