#pragma once

// NOTE(review): the original include preamble was lost in extraction; this set is
// reconstructed from the types used below (__nv_bfloat16, uint16_t, cudaStream_t,
// fprintf/exit) — confirm against the upstream header.
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

////////////////////////////////////////////////////////////////////////////////////////////////////
// Abort-on-error guard for CUDA runtime calls: evaluates `call` once and, if it
// did not return cudaSuccess, prints the failing file/line and the CUDA error
// string to stderr, then terminates the process.
// NOTE(review): the `#define` line itself was lost in extraction; the name
// `CHECK_CUDA` matches the FasterTransformer v5.2.1 header this file was copied
// from — confirm against the .cu files that invoke it.
#define CHECK_CUDA(call)                                                                                \
    do {                                                                                                \
        cudaError_t status_ = call;                                                                     \
        if (status_ != cudaSuccess) {                                                                   \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                    \
        }                                                                                               \
    } while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////

// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B:  Batch size (number of sequences),
// L:  Sequence length,
// D:  Hidden dimension,
// H:  Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.

template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;
    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;
    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;
    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;
    // The indirections to use for cache when beam sampling.
    const int* cache_indir = nullptr;

    // Strides to handle the case when Q, K and V live in a single fused QKV buffer.
    int stride_q = 0;
    int stride_k = 0;
    int stride_v = 0;

    // The batch size.
    int batch_size = 0;
    // The beam width.
    int beam_width = 0;
    // The maximum sequence length the K/V caches can hold.
    int memory_max_len = 0;
    // The number of (query) heads (H).
    int num_heads = 0;
    // The number of key/value heads. NOTE(review): presumably < num_heads for
    // grouped/multi-query attention — confirm in the kernel implementation.
    int num_heads_kv = 0;
    // NOTE(review): presumably num_heads / num_heads_kv, i.e. how many query
    // heads share one KV head — confirm at the call site.
    int num_heads_q_kv_ratio = 0;
    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;

    // The per-head latent space reserved for rotary embeddings.
    int rotary_embedding_dim = 0;
    // Whether to use the GPT-NeoX style of rotary embedding.
    bool neox_rotary_style = false;
    // The base (theta) used to generate the rotary frequencies.
    float rotary_base = 0.0f;

    // The maximum length of input sentences.
    int max_input_length = 0;
    // The current timestep. TODO(bhsueh): check whether this param is only needed in cross attention.
    int timestep = 0;

    // The 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Used when we have some input context, like GPT.
    const int* total_padding_tokens = nullptr;

    const bool* masked_tokens = nullptr;
    const int* prefix_prompt_lengths = nullptr;
    int max_prefix_prompt_length = 0;

    // Optional relative-attention bias and its per-head stride.
    const T* relative_attention_bias = nullptr;
    int relative_attention_bias_stride = 0;
    // The slope per head of linear position bias to attention score (H).
    const T* linear_bias_slopes = nullptr;

    // IA3 adapter weights and per-sequence task ids.
    const T* ia3_key_weights = nullptr;
    const T* ia3_value_weights = nullptr;
    const int* ia3_tasks = nullptr;

    // Quantization scales, used when int8_mode != 0.
    const float* qkv_scale_out = nullptr;
    const float* attention_out_scale = nullptr;
    int int8_mode = 0;

    // NOTE(review): presumably precomputed rotary cos/sin tables — semantics
    // inferred from the names; confirm in the kernel implementation.
    const T* rotary_cos = nullptr;
    const T* rotary_sin = nullptr;

    // Indices of the non-zero heads and their count. NOTE(review): usage
    // inferred from the names; confirm in the kernel implementation.
    const int* nnz_head_idx = nullptr;
    int nnz_heads = 0;
};
// Primary template: the non-cross-attention (CROSS_ATTENTION == false) case.
template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {

    // Output buffer for cross attentions, written when is_return_cross_attentions is set.
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // Allows the attention to exit early for finished sequences.
    bool* finished = nullptr;

    // Required in case of cross attention.
    // Will need it here until `if constexpr` (C++17) is available.
    int* memory_length_per_sample = nullptr;

    // Required in case of masked attention with different lengths.
    const int* length_per_sample = nullptr;
};
// Specialization for the cross-attention (CROSS_ATTENTION == true) case.
// Mirrors the primary template member-for-member.
template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {

    // Output buffer for cross attentions, written when is_return_cross_attentions is set.
    float* cross_attention_out = nullptr;
    int max_decoder_seq_len = 0;
    bool is_return_cross_attentions = false;

    // Allows the attention to exit early for finished sequences.
    bool* finished = nullptr;

    // Required in case of cross attention.
    int* memory_length_per_sample = nullptr;

    // Required in case of masked attention with different lengths.
    const int* length_per_sample = nullptr;
};
// Convenience aliases selecting the self-attention vs. cross-attention variant.
template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;
// Parameters controlling whether (and where) cross-attention scores are returned.
template<typename T>
struct outputCrossAttentionParam {
    // Max decoder output length.
    int max_decoder_seq_len = 0;
    // Destination buffer for the cross-attention output.
    T* cross_attention_out = nullptr;
    bool is_return_cross_attentions = false;
};
////////////////////////////////////////////////////////////////////////////////////////////////////

// Host-side launchers for the masked (self) multihead attention kernel, one
// overload per supported element type (fp32, fp16 as uint16_t, bf16). Work is
// issued on `stream`; implementations live in the corresponding .cu file.
void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);

// Host-side launchers for the cross multihead attention kernel, same overload set.
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);

////////////////////////////////////////////////////////////////////////////////////////////////////