import math

import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import AdapterName, InfusedAdapterConfig
from nemo.collections.nlp.modules.common.megatron.fused_softmax import MatchedScaleMaskSoftmax
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.rotary_pos_embedding import apply_rotary_pos_emb
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, attention_mask_func
from nemo.core import adapter_mixins

try:
    from apex.transformer import parallel_state, tensor_parallel
    from apex.transformer.enums import AttnMaskType, AttnType
    from apex.transformer.utils import divide as safe_divide

    HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

    HAVE_APEX = False

    # Fake the missing Apex symbols so this module still imports without Apex.
    ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults()
| """ We use the following notation throughout this file: |
| h: hidden size |
| n: number of attention heads |
| p: number of model parallel partitions |
| np: n/p |
| hp: h/p |
| hn: h/n |
| b: batch size |
| s: sequence length |
| l: number of layers |
| Transformer takes input of size [s, b, h] and returns a |
| tensor of the same size. We use the following arguments: |
| hyperparameters: transformer hyperparameters |
| """ |


class ParallelAttention(MegatronModule, adapter_mixins.AdapterModuleMixin):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_number,
        num_attention_heads,
        hidden_size,
        attention_type=AttnType.self_attn,
        attn_mask_type=AttnMaskType.padding,
        precision=16,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        use_cpu_initialization=False,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
        layer_type=None,
        megatron_legacy=False,
        bias=True,
        headscale=False,
        position_embedding_type='learned_absolute',
        multi_query_attention=False,
        activations_checkpoint_granularity=None,
        sequence_parallel=False,
        gradient_accumulation_fusion=False,
        normalize_attention_scores=True,
    ):
        super(ParallelAttention, self).__init__()

        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.normalize_attention_scores = normalize_attention_scores
        self.position_embedding_type = position_embedding_type
        self.multi_query_attention = multi_query_attention

        self.megatron_legacy = megatron_legacy

        self.set_accepted_adapter_types([InfusedAdapterConfig._target_])

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads
        projection_size = kv_channels * num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = safe_divide(projection_size, num_attention_heads)
        self.num_attention_heads_per_partition = safe_divide(num_attention_heads, world_size)
        self.num_attention_heads_partition_offset = (
            self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank()
        )

        no_async_tensor_model_parallel_allreduce = (
            parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
        )

        # Strided linear layer producing the fused [query, key, value] projection.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method,
                use_cpu_initialization=use_cpu_initialization,
                bias=bias,
                sequence_parallel_enabled=sequence_parallel,
                no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
            )
        else:
            assert attention_type == AttnType.cross_attn
            self.query = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method,
                bias=bias,
                sequence_parallel_enabled=sequence_parallel,
                no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
            )

            self.key_value = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method,
                bias=bias,
                sequence_parallel_enabled=sequence_parallel,
                no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
            )

        self.core_attention = CoreAttention(
            layer_number=self.layer_number,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_type=self.attention_type,
            attn_mask_type=self.attn_mask_type,
            precision=precision,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            masked_softmax_fusion=masked_softmax_fusion,
            attention_dropout=attention_dropout,
            multi_query_attention=multi_query_attention,
            sequence_parallel=sequence_parallel,
            normalize_attention_scores=normalize_attention_scores,
        )

        # Output projection back to the full hidden size.
        self.dense = tensor_parallel.RowParallelLinear(
            projection_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
        )

        self.headscale = headscale
        if headscale:
            self.head_scale_tensor = torch.nn.Parameter(
                torch.ones(1, self.num_attention_heads_per_partition, 1, 1), requires_grad=True
            )

        # Inference key-value memory (allocated lazily on the first inference step).
        self.inference_key_memory = None
        self.inference_value_memory = None
        self.inference_current_sequence_len = 0

        # relative position embedding
        self.layer_type = layer_type

    def _checkpointed_attention_forward(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        rotary_pos_emb=None,
        relative_position_bias=None,
        headscale_tensor=None,
    ):
        """Forward method with activation checkpointing."""

        def custom_forward(*inputs):
            # The rotary embedding tuple is flattened before checkpointing, so we
            # receive either 7 inputs (no rotary tuple) or 8 (separate query and
            # key embeddings).
            if len(inputs) == 7:
                query_layer = inputs[0]
                key_layer = inputs[1]
                value_layer = inputs[2]
                attention_mask = inputs[3]
                rotary_pos_emb = inputs[4]
                relative_position_bias = inputs[5]
                headscale_tensor = inputs[6]
            elif len(inputs) == 8:
                query_layer = inputs[0]
                key_layer = inputs[1]
                value_layer = inputs[2]
                attention_mask = inputs[3]
                rotary_pos_emb = (inputs[4], inputs[5])
                relative_position_bias = inputs[6]
                headscale_tensor = inputs[7]
            else:
                raise ValueError('unexpected number of inputs')
            output_ = self.core_attention(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                relative_position_bias=relative_position_bias,
                headscale_tensor=headscale_tensor,
            )
            return output_

        if rotary_pos_emb is None:
            rot_tuple = (rotary_pos_emb,)
        else:
            rot_tuple = (rotary_pos_emb[0], rotary_pos_emb[1])

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False,
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            *rot_tuple,
            relative_position_bias,
            headscale_tensor,
        )

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size, dtype):
        # Buffer shape: [max_seq_len, b, np, hn], allocated on the current CUDA device.
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=dtype,
            device=torch.cuda.current_device(),
        )

    def _transpose_last_dim(self, mixed_layer, num_splits, num_splits_first):
        input_shape = mixed_layer.size()
        if num_splits_first:
            """[s, b, num_splits * np * hn]
            -->(view) [s, b, num_splits, np, hn]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] + (
                num_splits,
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-2, -3).contiguous()
        else:
            """[s, b, np * hn * num_splits]
            -->(view) [s, b, np, hn, num_splits]
            -->(transpose) [s, b, np, num_splits, hn]
            -->(view) [s, b, np * num_splits * hn] """

            intermediate_shape = input_shape[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
                num_splits,
            )

            mixed_layer = mixed_layer.view(*intermediate_shape)
            mixed_layer = mixed_layer.transpose(-1, -2).contiguous()
        mixed_layer = mixed_layer.view(*input_shape)

        return mixed_layer
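
    # A minimal worked example of why the transpose above is needed (values are
    # illustrative, not part of this module): legacy Megatron checkpoints store
    # the fused QKV output split-major, e.g. [q1 q2 k1 k2 v1 v2] for np=2 heads,
    # while this module expects per-head grouping, [q1 k1 v1 q2 k2 v2]. With
    # num_splits=3, np=2, hn=1, _transpose_last_dim views the last dimension as
    # (3, 2, 1), transposes it to (2, 3, 1), and flattens back, converting one
    # layout into the other.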

    def forward(
        self,
        hidden_states,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        encoder_output=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        rotary_pos_emb=None,
        relative_position_bias=None,
        checkpoint_core_attention=False,
    ):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if set_inference_key_value_memory:
            assert inference_max_sequence_len and inference_max_sequence_len > 0
            self.inference_key_memory = self._allocate_memory(
                inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype
            )
            self.inference_value_memory = self._allocate_memory(
                inference_max_sequence_len, hidden_states.size(1), hidden_states.dtype
            )
            self.inference_current_sequence_len = 0

        # Some consistency checks.
        if inference_max_sequence_len:
            assert self.inference_current_sequence_len < self.inference_key_memory.size(0)
            assert inference_max_sequence_len == self.inference_key_memory.size(0)
        # This is added for safety. In case inference_max_sequence_len
        # is not provided, make sure there is no memory left over from a
        # previous inference run.
        if not inference_max_sequence_len:
            self.inference_key_memory = None
            self.inference_value_memory = None

        # =====================
        # Query, Key, and Value
        # =====================

        if self.attention_type == AttnType.self_attn:
            # Attention heads [sq, b, h] --> [sq, b, (np * 3 * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            if self.megatron_legacy:
                mixed_x_layer = self._transpose_last_dim(mixed_x_layer, 3, True)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn]
            (query_layer, key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                2 * self.hidden_size_per_attention_head,
            )
            if self.megatron_legacy:
                mixed_kv_layer = self._transpose_last_dim(mixed_kv_layer, 2, True)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer, value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + (
                self.num_attention_heads_per_partition,
                self.hidden_size_per_attention_head,
            )
            query_layer = query_layer.view(*new_tensor_shape)

        if self.is_adapter_available():
            key_infused_adapter = self.get_adapter_module(AdapterName.KEY_INFUSED)
            value_infused_adapter = self.get_adapter_module(AdapterName.VALUE_INFUSED)
            if key_infused_adapter:
                assert value_infused_adapter is not None, "Expected value_infused_adapter not found!"
                kls = key_layer.shape
                key_layer = key_infused_adapter(key_layer.reshape(kls[0], kls[1], -1)).reshape(kls)
            if value_infused_adapter:
                assert key_infused_adapter is not None, "Expected key_infused_adapter not found!"
                vls = value_layer.shape
                value_layer = value_infused_adapter(value_layer.reshape(vls[0], vls[1], -1)).reshape(vls)

        # ===================================================
        # Adjust key, value, and attention mask for inference
        # ===================================================

        # Duplicate the pos_emb for self-attention.
        if rotary_pos_emb is not None:
            rotary_pos_emb = rotary_pos_emb if isinstance(rotary_pos_emb, tuple) else ((rotary_pos_emb,) * 2)

        if inference_max_sequence_len:
            # Adjust the range variables.
            start = self.inference_current_sequence_len
            self.inference_current_sequence_len += key_layer.size(0)
            end = self.inference_current_sequence_len
            # Copy keys and values into the cache.
            self.inference_key_memory[start:end, ...] = key_layer
            self.inference_value_memory[start:end, ...] = value_layer
            key_layer = self.inference_key_memory[:end, ...]
            value_layer = self.inference_value_memory[:end, ...]
            # Adjust the attention mask to the populated part of the cache.
            attention_mask = attention_mask[..., start:end, :end]
            # Adjust the rotary positional embeddings.
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                if not set_inference_key_value_memory:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding.
                    q_pos_emb = q_pos_emb[end - 1 : end]
                else:
                    q_pos_emb = q_pos_emb[:end, :, :, :]
                k_pos_emb = k_pos_emb[:end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)
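
        # A worked example of the cache bookkeeping above (illustrative numbers):
        # priming with a 12-token prompt gives start=0, end=12 and the mask slice
        # [..., 0:12, :12]; the next single-token step gives start=12, end=13,
        # a query rotary embedding sliced to position 12 only, and attention
        # over the cached keys/values [:13].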

        if layer_past is not None:
            past_key, past_value = layer_past
            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=0)
            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=0)

        if get_key_value:
            present = (key_layer, value_layer)

        if checkpoint_core_attention:
            context_layer = self._checkpointed_attention_forward(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                rotary_pos_emb=rotary_pos_emb,
                relative_position_bias=relative_position_bias,
                headscale_tensor=self.head_scale_tensor if self.headscale else None,
            )
        else:
            context_layer = self.core_attention(
                query_layer,
                key_layer,
                value_layer,
                attention_mask,
                layer_past=layer_past,
                get_key_value=get_key_value,
                rotary_pos_emb=rotary_pos_emb,
                relative_position_bias=relative_position_bias,
                headscale_tensor=self.head_scale_tensor if self.headscale else None,
            )

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        if get_key_value:
            output = [output, present]

        return output, bias


class ParallelChunkedCrossAttention(MegatronModule):
    """Parallel chunked cross-attention layer class.

    Cross-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        layer_number,
        num_attention_heads,
        hidden_size,
        precision=16,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        use_cpu_initialization=False,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
        megatron_legacy=False,
        chunk_size=64,
        bias=True,
        headscale=False,
        gradient_accumulation_fusion=False,
        normalize_attention_scores=True,
    ):
        super(ParallelChunkedCrossAttention, self).__init__()
        self.cross_attention = ParallelAttention(
            init_method=init_method,
            output_layer_init_method=output_layer_init_method,
            layer_number=layer_number,
            num_attention_heads=num_attention_heads,
            hidden_size=hidden_size,
            attention_type=AttnType.cross_attn,
            attn_mask_type=AttnMaskType.padding,
            precision=precision,
            apply_query_key_layer_scaling=apply_query_key_layer_scaling,
            kv_channels=kv_channels,
            use_cpu_initialization=use_cpu_initialization,
            masked_softmax_fusion=masked_softmax_fusion,
            attention_dropout=attention_dropout,
            megatron_legacy=megatron_legacy,
            bias=bias,
            headscale=headscale,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
            normalize_attention_scores=normalize_attention_scores,
        )
        self.chunk_size = chunk_size
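
    # A worked example of the chunk arithmetic used in forward() below
    # (illustrative numbers, not values enforced here): with chunk_size=64 and
    # an input of n=150 tokens, seq_index = (150 // 64) * 64 = 128, so the first
    # 128 tokens form num_chunks=2 chunks that attend to the retrieved
    # neighbors, while the remaining 22 tokens are carried through as
    # x_remainder and padded back into the output.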

    def forward(
        self,
        hidden_states,
        attention_mask,
        encoder_output=None,
        set_inference_key_value_memory=False,
        inference_max_sequence_len=None,
        rotary_pos_emb=None,
        checkpoint_core_attention=False,
    ):
        if checkpoint_core_attention:
            raise ValueError(
                'checkpoint_core_attention during forward not implemented yet for ParallelChunkedCrossAttention'
            )

        # hidden_states is expected to have dimension [token length, batch, dimension].
        # encoder_output here is the retrieved context.
        context = encoder_output
        # context is expected to have dimension
        # [num_chunks, num_neighbors, context_token_len, batch, dimension].
        chunk_size = self.chunk_size
        # n: token length, b: batch (hidden_states is [n, b, dim]).
        b, n, dim = (
            hidden_states.shape[1],
            hidden_states.shape[0],
            hidden_states.shape[2],
        )
        default_bias = self.cross_attention.dense.bias
        if set_inference_key_value_memory:
            seq_index = (n // chunk_size) * chunk_size
            self.current_len = n
        elif inference_max_sequence_len is not None:
            # Only handles single-token increments.
            assert n == 1
            self.current_len += n
            chunk_id = self.current_len // chunk_size
            if chunk_id <= 0:
                # If the sequence is shorter than a chunk, do an early return.
                return torch.zeros_like(hidden_states), default_bias
            causal_padding = chunk_size - 1
            # Pad to a full chunk, placing the token at the end of the chunk position.
            hidden_states = F.pad(hidden_states, (0, 0, 0, 0, causal_padding, 0), value=0.0)
            # Only use the relevant context.
            context = context[chunk_id - 1 : chunk_id, :, :, :, :]
            attention_mask = rearrange(attention_mask, '(b k) 1 q v -> b k 1 q v', b=b)
            # Select the relevant chunk attention mask.
            attention_mask = attention_mask[:, chunk_id - 1]
            seq_index = chunk_size
        else:
            # Normal forward without inference.
            seq_index = (n // chunk_size) * chunk_size

        # If the sequence is shorter than a chunk, do an early return.
        if n < self.chunk_size and set_inference_key_value_memory and inference_max_sequence_len is not None:
            return torch.zeros_like(hidden_states), default_bias

        num_chunks, num_retrieved = (
            context.shape[-5],
            context.shape[-4],
        )

        # Causal padding.
        causal_padding = chunk_size - 1

        x = F.pad(hidden_states, (0, 0, 0, 0, -causal_padding, causal_padding), value=0.0)

        # Remove the positions that are ahead of the retrieved neighbors.
        x, x_remainder = x[:seq_index], x[seq_index:]

        seq_remain_len = x_remainder.shape[0]

        # Take care of rotary positional embeddings:
        # make sure query positions are properly shifted to the chunk boundary.
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            if inference_max_sequence_len is not None and not set_inference_key_value_memory:
                token_pos = (self.current_len - 1) % chunk_size
                q_pos_emb = F.pad(
                    q_pos_emb, (0, 0, 0, 0, 0, 0, -causal_padding - token_pos, -causal_padding + token_pos), value=0.0
                )
            else:
                q_pos_emb = F.pad(q_pos_emb, (0, 0, 0, 0, 0, 0, -causal_padding, 0), value=0.0)

            k_pos_emb = repeat(k_pos_emb, 'n b h d -> (r n) b h d', r=num_retrieved)
            rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # Make sure the number of context chunks matches.
        assert x.shape[0] // chunk_size == num_chunks

        # Reshape so we have chunk-to-chunk attention, without breaking causality.
        x = rearrange(x, '(k n) b d -> n (b k) d', k=num_chunks)
        context = rearrange(context, 'k r n b d -> (r n) (b k) d')
        # Cross attention.
        out, bias = self.cross_attention(x, attention_mask, encoder_output=context, rotary_pos_emb=rotary_pos_emb)

        # Reshape back to the original sequence.
        out = rearrange(out, 'n (b k) d -> (k n) b d', b=b)

        # Pad back to the original length, with zeros at the beginning
        # (which will be added to the residual and are harmless).
        out = F.pad(out, (0, 0, 0, 0, causal_padding, -causal_padding + seq_remain_len), value=0.0)
        if not set_inference_key_value_memory and inference_max_sequence_len is not None:
            out = out[-1:]
        return out, bias


class CoreAttention(MegatronModule):
    """Region where selective activation recomputation is applied.
    See Figure 3 in "Reducing Activation Recomputation in Large Transformer Models",
    https://arxiv.org/pdf/2205.05198.pdf, for more details.
    """

    def __init__(
        self,
        layer_number,
        num_attention_heads,
        hidden_size,
        attention_type=AttnType.self_attn,
        attn_mask_type=AttnMaskType.padding,
        precision=16,
        apply_query_key_layer_scaling=True,
        kv_channels=None,
        masked_softmax_fusion=True,
        attention_dropout=0.1,
        sequence_parallel=False,
        normalize_attention_scores=True,
        multi_query_attention=False,
    ):

        super(CoreAttention, self).__init__()

        self.precision = precision
        self.fp16 = precision == 16
        self.bf16 = precision == 'bf16'
        self.multi_query_attention = multi_query_attention

        self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = False
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = sequence_parallel
        # If True, scale attention scores by 1 / sqrt(hidden_size_per_attention_head).
        # This arg is provided mostly to support weight conversion of Huggingface
        # models (e.g. T5v1.1).
        self.normalize_attention_scores = normalize_attention_scores

        if kv_channels is None:
            assert (
                hidden_size % num_attention_heads == 0
            ), 'hidden_size must be divisible by num_attention_heads if kv_channels is None'
            kv_channels = hidden_size // num_attention_heads

        projection_size = kv_channels * num_attention_heads

        # Per attention head and per partition values.
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = safe_divide(projection_size, world_size)
        self.hidden_size_per_attention_head = safe_divide(projection_size, num_attention_heads)
        self.num_attention_heads_per_partition = safe_divide(num_attention_heads, world_size)
        self.num_attention_heads_partition_offset = (
            self.num_attention_heads_per_partition * parallel_state.get_tensor_model_parallel_rank()
        )

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff
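
        # Worked example of query-key layer scaling (illustrative numbers,
        # assuming MatchedScaleMaskSoftmax applies `coeff` as a pre-softmax
        # scale, as in Megatron's fused softmax): with hn=64 and layer_number=4,
        # norm_factor = sqrt(64) * 4 = 32, so raw scores are divided by 32 to
        # keep fp16 values in range; the softmax then multiplies by coeff=4,
        # restoring the standard 1/sqrt(hn) scaling in fp32.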

        self.scale_mask_softmax = MatchedScaleMaskSoftmax(
            self.fp16,
            self.bf16,
            self.attn_mask_type,
            masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff,
        )

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different numbers of parallel partitions, but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(attention_dropout)

    def forward(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask,
        layer_past=None,
        get_key_value=False,
        rotary_pos_emb=None,
        relative_position_bias=None,
        headscale_tensor=None,
    ):

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1), query_layer.size(2), query_layer.size(0), key_layer.size(0))

        # Apply rotary positional embedding (relative positional encoding).
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb

            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb)
            # Note: value_layer is left unchanged, so only relative positional
            # information is injected.

        if self.multi_query_attention:
            # [sq, b, np, hn] -> [b, np * sq, hn]
            query_layer = query_layer.permute([1, 2, 0, 3]).reshape(
                output_size[0], output_size[1] * output_size[2], -1
            )

            # [sk, b, 1, hn] -> [b, hn, sk]
            key_layer = key_layer.squeeze(2).permute(1, 2, 0)

            # Preallocating the result tensor: [b, np * sq, sk].
            # (The buffer must match the baddbmm output shape; the batch
            # dimension here is b, since all np query heads share one key head.)
            matmul_input_buffer = torch.empty(
                output_size[0],
                output_size[1] * output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=torch.cuda.current_device(),
            )

            # Raw attention scores. [b, np * sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer,
                key_layer,
                beta=0.0,
                alpha=(1.0 / self.norm_factor),
            )
        else:
            # [sq, b, np, hn] -> [sq, b * np, hn]
            query_layer = query_layer.view(output_size[2], output_size[0] * output_size[1], -1)
            # [sk, b, np, hn] -> [sk, b * np, hn]
            key_layer = key_layer.view(output_size[3], output_size[0] * output_size[1], -1)

            # Preallocating the input tensor: [b * np, sq, sk]
            matmul_input_buffer = torch.empty(
                output_size[0] * output_size[1],
                output_size[2],
                output_size[3],
                dtype=query_layer.dtype,
                device=torch.cuda.current_device(),
            )

            # Raw attention scores. [b * np, sq, sk]
            matmul_result = torch.baddbmm(
                matmul_input_buffer,
                query_layer.transpose(0, 1),
                key_layer.transpose(0, 1).transpose(1, 2),
                beta=0.0,
                alpha=(1.0 / self.norm_factor) if self.normalize_attention_scores else 1.0,
            )

        # Change view to [b, np, sq, sk].
        attention_scores = matmul_result.view(*output_size)

        if relative_position_bias is not None:
            attention_scores += relative_position_bias[
                :,
                self.num_attention_heads_partition_offset : self.num_attention_heads_partition_offset
                + self.num_attention_heads_per_partition,
                : attention_scores.size(2),
                : attention_scores.size(3),
            ]

        # ===================================================
        # Update attention mask for inference. [b, np, sq, sk]
        # ===================================================

        if get_key_value:
            with torch.no_grad():
                if layer_past is not None:
                    attention_mask = attention_mask[
                        ..., attention_scores.size(3) - 1, : attention_scores.size(3)
                    ].unsqueeze(2)
                else:
                    attention_mask = attention_mask[..., : attention_scores.size(3), : attention_scores.size(3)]

        # ===========================
        # Attention probs and dropout
        # ===========================

        # Attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if not self.sequence_parallel:
            with tensor_parallel.random.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # Context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1), value_layer.size(2), query_layer.size(0), value_layer.size(3))

        # Change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0), output_size[0] * output_size[1], -1)

        # Change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1], output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # Change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        if headscale_tensor is not None:
            context_layer = context_layer * headscale_tensor

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
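

# ---------------------------------------------------------------------------
# A minimal, self-contained sketch of the shape flow used by CoreAttention
# above, written in plain torch so it runs on CPU without Apex or NeMo.
# The dimensions (b, np, sq, sk, hn) are illustrative, and the local names
# below (`_core_attention_shape_demo`, `scores`, `context`) are demo-only,
# not part of this module's API.
# ---------------------------------------------------------------------------
def _core_attention_shape_demo():
    b, np_, sq, sk, hn = 2, 4, 5, 5, 8  # batch, heads/partition, seq lens, head dim

    # Layers arrive as [s, b, np, hn], matching the comments in CoreAttention.forward.
    query = torch.randn(sq, b, np_, hn)
    key = torch.randn(sk, b, np_, hn)
    value = torch.randn(sk, b, np_, hn)

    # [sq, b, np, hn] -> [b * np, sq, hn] (and likewise for key), then
    # scores = Q @ K^T / sqrt(hn), giving [b * np, sq, sk].
    q = query.view(sq, b * np_, hn).transpose(0, 1)
    k = key.view(sk, b * np_, hn).transpose(0, 1)
    scores = torch.bmm(q, k.transpose(1, 2)) / math.sqrt(hn)

    # Softmax over the key dimension, then context = P @ V: [b * np, sq, hn].
    probs = torch.softmax(scores, dim=-1)
    v = value.view(sk, b * np_, hn).transpose(0, 1)
    context = torch.bmm(probs, v)

    # [b * np, sq, hn] -> [sq, b, np * hn], the module's output layout [sq, b, hp].
    context = context.view(b, np_, sq, hn).permute(2, 0, 1, 3).reshape(sq, b, np_ * hn)
    assert context.shape == (sq, b, np_ * hn)


if __name__ == '__main__':
    _core_attention_shape_demo()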