import torch
import torch.nn.functional as F

from nemo.collections.nlp.modules.common.megatron.adapters.parallel_adapters import (
    AdapterName,
    MLPInfusedAdapterConfig,
)
from nemo.collections.nlp.modules.common.megatron.fused_bias_geglu import fused_bias_geglu
from nemo.collections.nlp.modules.common.megatron.fused_bias_gelu import fused_bias_gelu
from nemo.collections.nlp.modules.common.megatron.fused_layer_norm import get_layer_norm
from nemo.collections.nlp.modules.common.megatron.layer_norm_1p import LayerNorm1P
from nemo.collections.nlp.modules.common.megatron.module import MegatronModule
from nemo.collections.nlp.modules.common.megatron.utils import ApexGuardDefaults, erf_gelu
from nemo.collections.nlp.modules.common.megatron.utils import openai_gelu as openai_gelu_func
from nemo.collections.nlp.modules.common.megatron.utils import squared_relu
from nemo.core import adapter_mixins

try:
    from apex.normalization import MixedFusedRMSNorm
    from apex.transformer import parallel_state, tensor_parallel
    from apex.transformer.parallel_state import get_tensor_model_parallel_world_size

    HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

    HAVE_APEX = False

    # Apex is unavailable: substitute guard defaults so module-level imports still succeed.
    ModelType = AttnMaskType = AttnType = LayerType = ApexGuardDefaults()

class ParallelMLP(MegatronModule, adapter_mixins.AdapterModuleMixin):
    """MLP.

    Takes the input with hidden size h, projects it to ffn_hidden_size
    (4*h by convention), applies a nonlinear activation, and projects the
    state back to hidden size h.
    """

    def __init__(
        self,
        init_method,
        output_layer_init_method,
        hidden_size,
        ffn_hidden_size,
        use_cpu_initialization=False,
        bias_activation_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
        activation='gelu',
        bias=True,
        transformer_block_type='pre_ln',
        normalization='layernorm',
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        sequence_parallel=False,
        gradient_accumulation_fusion=False,
        dropout=0.0,
    ):
        super(ParallelMLP, self).__init__()
        self.activation = activation
        self.bias = bias
        self.transformer_block_type = transformer_block_type
        self.normalization = normalization
        self.layernorm_epsilon = layernorm_epsilon
        self.persist_layer_norm = persist_layer_norm
        self.dropout = dropout
        self.set_accepted_adapter_types([MLPInfusedAdapterConfig._target_])
        supported_activations = [
            'gelu',
            'geglu',
            'reglu',
            'swiglu',
            'squared-relu',
            'fast-geglu',
            'fast-swiglu',
            'fast-reglu',
        ]

        if activation not in supported_activations:
            raise ValueError(
                f"Activation {activation} not supported. Supported activations are {supported_activations}"
            )

        # The 'fast-' GLU variants use a single doubled projection that is chunked in forward()
        # instead of a second linear layer.
        self.fast_glu_activation = activation in ['fast-geglu', 'fast-swiglu', 'fast-reglu']
        no_async_tensor_model_parallel_allreduce = (
            parallel_state.get_tensor_model_parallel_world_size() == 1 or sequence_parallel
        )
        # Project to the FFN hidden size (doubled when a fast GLU activation is used).
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            hidden_size,
            ffn_hidden_size * 2 if self.fast_glu_activation else ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
        )
        if activation in ['geglu', 'reglu', 'swiglu']:
            # The non-fast GLU variants use a separate linear layer for the gating branch.
            self.dense_h_to_4h_2 = tensor_parallel.ColumnParallelLinear(
                hidden_size,
                ffn_hidden_size,
                gather_output=False,
                init_method=init_method,
                skip_bias_add=True,
                use_cpu_initialization=use_cpu_initialization,
                bias=bias,
                sequence_parallel_enabled=sequence_parallel,
                no_async_tensor_model_parallel_allreduce=no_async_tensor_model_parallel_allreduce,
                gradient_accumulation_fusion=gradient_accumulation_fusion,
            )
        self.glu_activation_family = activation in [
            'geglu',
            'reglu',
            'swiglu',
            'fast-geglu',
            'fast-reglu',
            'fast-swiglu',
        ]
        bias_activation_fusion_unavailable = activation in ['reglu', 'swiglu']

        if bias_activation_fusion_unavailable and bias_activation_fusion:
            raise ValueError(
                f"Cannot use bias_activation_fusion with {activation} activation. Please turn bias_activation_fusion off."
            )

        if self.glu_activation_family and onnx_safe and bias_activation_fusion:
            raise ValueError(
                f"Cannot use onnx_safe together with bias_activation_fusion and the {activation} activation. Please turn onnx_safe off."
            )

        if bias_activation_fusion and not bias:
            raise ValueError(
                "Cannot use bias_activation_fusion without bias terms. Please set bias=True or bias_activation_fusion=False."
            )

        self.bias_activation_fusion = bias_activation_fusion

        # openai_gelu takes precedence over the other activation settings.
        if openai_gelu:
            self.activation_func = openai_gelu_func
        elif activation in ["gelu", "geglu", "fast-geglu"]:
            self.activation_func = F.gelu
        elif onnx_safe:
            self.activation_func = erf_gelu
        elif activation in ["reglu", "fast-reglu"]:
            self.activation_func = F.relu
        elif activation in ["swiglu", "fast-swiglu"]:
            # SiLU (sigmoid linear unit) is swish with beta = 1.
            self.activation_func = F.silu
        elif activation == 'squared-relu':
            self.activation_func = squared_relu
        # Project back to the hidden size.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            ffn_hidden_size,
            hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
        )
        # Normformer applies an extra normalization over the activated intermediate states.
        if transformer_block_type == 'normformer':
            if normalization == 'layernorm':
                self.normalization = get_layer_norm(
                    ffn_hidden_size // get_tensor_model_parallel_world_size(), layernorm_epsilon, persist_layer_norm
                )
            elif normalization == 'layernorm1p':
                self.normalization = LayerNorm1P(
                    ffn_hidden_size // get_tensor_model_parallel_world_size(),
                    layernorm_epsilon,
                    sequence_parallel_enabled=sequence_parallel,
                )
            else:
                self.normalization = MixedFusedRMSNorm(
                    ffn_hidden_size // get_tensor_model_parallel_world_size(), layernorm_epsilon
                )
    def forward(self, hidden_states):

        # [s, b, ffn_hidden_size/p]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.fast_glu_activation:
            # Fast GLU variants: split the doubled projection into value and gate halves.
            intermediate_parallel, intermediate_parallel_2 = torch.chunk(intermediate_parallel, 2, dim=-1)
            if bias_parallel is not None:
                bias_parallel, bias_parallel_2 = torch.chunk(bias_parallel, 2, dim=-1)
        elif self.glu_activation_family and not self.fast_glu_activation:
            # Non-fast GLU variants: compute the gate with the second projection.
            intermediate_parallel_2, bias_parallel_2 = self.dense_h_to_4h_2(hidden_states)

        if self.bias_activation_fusion:
            if self.activation == 'gelu':
                intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
            elif self.activation in ['geglu', 'fast-geglu']:
                intermediate_parallel = fused_bias_geglu(
                    intermediate_parallel, bias_parallel, intermediate_parallel_2, bias_parallel_2
                )

        elif self.glu_activation_family and not self.bias_activation_fusion:
            if bias_parallel is not None:
                intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel) * (
                    intermediate_parallel_2 + bias_parallel_2
                )
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel) * intermediate_parallel_2

        else:
            if bias_parallel is not None:
                intermediate_parallel = self.activation_func(intermediate_parallel + bias_parallel)
            else:
                intermediate_parallel = self.activation_func(intermediate_parallel)

        if self.dropout > 0:
            intermediate_parallel = F.dropout(intermediate_parallel, p=self.dropout, training=self.training)

        infused_adapter = self.get_adapter_module(AdapterName.MLP_INFUSED)
        if infused_adapter:
            intermediate_parallel = infused_adapter(intermediate_parallel)

        # Normformer normalization.
        if self.transformer_block_type == 'normformer':
            intermediate_parallel = self.normalization(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias
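
# Illustrative usage sketch (not part of the original module): ParallelMLP is normally built by the
# transformer layer once tensor-model-parallel state is initialized. Assuming the init_method_normal /
# scaled_init_method_normal helpers that Megatron-style utils typically provide, construction would
# look roughly like:
#
#     mlp = ParallelMLP(
#         init_method=init_method_normal(0.02),
#         output_layer_init_method=scaled_init_method_normal(0.02, num_layers),
#         hidden_size=1024,
#         ffn_hidden_size=4096,
#         activation='fast-swiglu',
#         bias_activation_fusion=False,  # the fused kernels in forward() only cover gelu/geglu
#     )
#     output, output_bias = mlp(hidden_states)  # hidden_states: [s, b, hidden_size]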


class SwitchMLP(MegatronModule):
    """Top-1 Mixture of Experts (MoE).

    Currently supports Sinkhorn-based expert routing."""
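    # Each token is dispatched to exactly one expert (the argmax of its router score) and that
    # expert's output is scaled by the sigmoid of the winning router logit; see forward() below.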

    def __init__(
        self,
        num_experts,
        init_method,
        output_layer_init_method,
        hidden_size,
        ffn_hidden_size,
        use_cpu_initialization=False,
        bias_activation_fusion=True,
        openai_gelu=False,
        onnx_safe=False,
        activation='gelu',
        bias=True,
        transformer_block_type='pre_ln',
        normalization='layernorm',
        layernorm_epsilon=1e-5,
        persist_layer_norm=False,
        sequence_parallel=False,
        gradient_accumulation_fusion=False,
        dropout=0.0,
    ):
        super(SwitchMLP, self).__init__()

        self.num_experts = num_experts
        self.route_algo = SwitchMLP.sinkhorn
        self.router = tensor_parallel.RowParallelLinear(
            hidden_size,
            num_experts,
            input_is_parallel=False,
            init_method=init_method,
            skip_bias_add=False,
            use_cpu_initialization=use_cpu_initialization,
            bias=bias,
            sequence_parallel_enabled=sequence_parallel,
            gradient_accumulation_fusion=gradient_accumulation_fusion,
        )

        # Every expert is a full ParallelMLP built with the same configuration.
        mlp_args = {
            'init_method': init_method,
            'output_layer_init_method': output_layer_init_method,
            'hidden_size': hidden_size,
            'ffn_hidden_size': ffn_hidden_size,
            'use_cpu_initialization': use_cpu_initialization,
            'bias_activation_fusion': bias_activation_fusion,
            'openai_gelu': openai_gelu,
            'onnx_safe': onnx_safe,
            'activation': activation,
            'bias': bias,
            'transformer_block_type': transformer_block_type,
            'normalization': normalization,
            'layernorm_epsilon': layernorm_epsilon,
            'persist_layer_norm': persist_layer_norm,
            'sequence_parallel': sequence_parallel,
            'gradient_accumulation_fusion': gradient_accumulation_fusion,
            'dropout': dropout,
        }
        self.experts = torch.nn.ModuleList([ParallelMLP(**mlp_args) for _ in range(num_experts)])
    def forward(self, hidden_states):
        hidden_shape = hidden_states.shape
        route, _ = self.router(hidden_states)
        route = route.view(-1, self.num_experts)
        if self.training:
            with torch.no_grad():
                # During training, pick experts from the Sinkhorn-balanced scores (no gradient),
                # but take the gating probability from the raw sigmoid logits below.
                norm_route = self.route_algo(route.detach().to(dtype=torch.float32))
                _, max_ind = torch.max(norm_route, dim=1)
            route = torch.sigmoid(route)
            max_prob = route[torch.arange(route.size(0)), max_ind]
        else:
            route = torch.sigmoid(route)
            max_prob, max_ind = torch.max(route, dim=1)
        max_prob = torch.unsqueeze(max_prob, 1)

        hidden_states = hidden_states.view(-1, hidden_shape[-1])

        # Run expert 0 first so the output buffers can be allocated with the experts' dtype.
        local_indices = (max_ind == 0).nonzero()
        hidden = hidden_states[local_indices, :]
        output, output_bias = self.experts[0](hidden)
        output_bias = output_bias.expand_as(output)

        output_total = torch.empty_like(hidden_states, dtype=output.dtype)
        output_bias_total = torch.empty_like(hidden_states, dtype=output_bias.dtype)

        output_total[local_indices, :] = output
        output_bias_total[local_indices, :] = output_bias

        # Dispatch the remaining tokens to their assigned experts.
        for expert_num, expert in enumerate(self.experts):
            if expert_num == 0:
                continue
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices, :] = output
            output_bias_total[local_indices, :] = output_bias

        # Scale by the top-1 gate probability and restore the original shape.
        output_total = output_total * max_prob
        output_bias_total = output_bias_total * max_prob
        output_total = output_total.view(hidden_shape)
        output_bias_total = output_bias_total.view(hidden_shape)

        return output_total, output_bias_total
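
    # Routing example (illustrative numbers only): at inference time, with num_experts=2 and router
    # logits [[2.0, -1.0], [0.5, 3.0]], token 0 goes to expert 0 and token 1 to expert 1, and their
    # outputs are scaled by sigmoid(2.0) ~= 0.88 and sigmoid(3.0) ~= 0.95 respectively.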

    @classmethod
    def sinkhorn(cls, cost, tol=0.0001):
        "Megatron-LM's Sinkhorn implementation."

        cost = torch.exp(cost)
        d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
        d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

        eps = 0.00000001
        error = 1e9
        d1_old = d1
        while error > tol:
            d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
            d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
            error = torch.mean(torch.abs(d1_old - d1))
            d1_old = d1
        return d1 * cost * d0.unsqueeze(1)
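
    # The loop above is a Sinkhorn-Knopp fixed point: with K = exp(cost), it alternates
    #     d0 <- 1 / (n0 * (K @ d1 + eps))
    #     d1 <- 1 / (n1 * (K.T @ d0 + eps))
    # until d1 stops changing, so that diag(d0) @ K @ diag(d1) has approximately uniform row and
    # column sums. It is only used (under torch.no_grad) to pick balanced expert assignments
    # during training; gradients never flow through it.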