| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
|
|
|
|
| import math |
|
|
| import torch |
| import torch.nn.functional as F |
| import torch.nn.init as init |
| from torch.nn.parameter import Parameter |
|
|
| from .initialize import get_tensor_model_parallel_rank |
| from .initialize import get_tensor_model_parallel_world_size |
| from .initialize import get_tensor_model_parallel_group |
| from .mappings import copy_to_tensor_model_parallel_region |
| from .mappings import gather_from_tensor_model_parallel_region |
| from .mappings import gather_from_sequence_parallel_region |
| from .mappings import reduce_from_tensor_model_parallel_region |
| from .mappings import scatter_to_tensor_model_parallel_region |
| from .mappings import reduce_scatter_to_sequence_parallel_region |
|
|
| from .random import get_cuda_rng_tracker |
| from .utils import divide |
| from .utils import split_tensor_along_last_dim |
| from .utils import VocabUtility |
| from megatron import get_args, get_global_memory_buffer |
|
|
# Default partitioning metadata attached to parameters.  A parameter without
# these attributes is treated as replicated across the tensor-model-parallel
# group; `partition_dim == -1` means "not partitioned".
_MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                      'partition_dim': -1,
                                      'partition_stride': 1}
|
|
def param_is_not_tensor_parallel_duplicate(param):
    """Return True if this rank uniquely owns (or represents) *param*.

    A parameter is considered non-duplicate when it is explicitly marked as
    tensor-model-parallel (each rank holds a distinct shard), or when this
    process is tensor-parallel rank 0, which acts as the canonical owner of
    replicated parameters.
    """
    if getattr(param, 'tensor_model_parallel', False):
        return True
    return get_tensor_model_parallel_rank() == 0
|
|
|
|
def set_tensor_model_parallel_attributes(tensor, is_parallel, dim, stride):
    """Tag *tensor* with its tensor-model-parallel partitioning metadata.

    Records whether the tensor is partitioned, along which dimension, and
    with what stride.  The tensor must not already carry any of the
    partitioning attributes.
    """
    # Refuse to silently overwrite metadata set elsewhere.
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        assert not hasattr(tensor, attr_name)
    tensor.tensor_model_parallel = is_parallel
    tensor.partition_dim = dim
    tensor.partition_stride = stride
|
|
|
|
def set_defaults_if_not_set_tensor_model_parallel_attributes(tensor):
    """Fill in default partitioning attributes on *tensor* where missing.

    Attributes the tensor already carries are left untouched; absent ones
    get the replicated-parameter defaults.
    """
    for attr_name, default in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS.items():
        if not hasattr(tensor, attr_name):
            setattr(tensor, attr_name, default)
|
|
|
|
def copy_tensor_model_parallel_attributes(destination_tensor, source_tensor):
    """Copy partitioning attributes from *source_tensor* to *destination_tensor*.

    Only attributes actually present on the source are copied; missing ones
    are skipped rather than defaulted.
    """
    for attr_name in _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS:
        if hasattr(source_tensor, attr_name):
            setattr(destination_tensor, attr_name,
                    getattr(source_tensor, attr_name))
|
|
|
|
def _initialize_affine_weight_gpu(weight, init_method,
                                  partition_dim, stride=1):
    """Initialize an already-sharded affine weight directly on the GPU.

    Tags *weight* with its partitioning metadata, then runs *init_method*
    on it under the model-parallel CUDA RNG tracker so each rank draws from
    its own reproducible random stream.
    """
    set_tensor_model_parallel_attributes(weight, True, partition_dim, stride)
    # Fork the RNG state so tensor-parallel ranks get distinct, reproducible
    # initializations for their shards.
    with get_cuda_rng_tracker().fork():
        init_method(weight)
|
|
|
|
def _initialize_affine_weight_cpu(weight, output_size, input_size,
                                  per_partition_size, partition_dim,
                                  init_method, stride=1,
                                  return_master_weight=False):
    """Initialize affine weight for model parallel.

    Build the master weight on all processes and scatter
    the relevant chunk."""
    set_tensor_model_parallel_attributes(weight, True, partition_dim, stride)

    # Materialize the full (unpartitioned) weight in fp32 on every rank so
    # initialization is identical regardless of the parallel degree, then
    # cast to the configured parameter dtype.
    master_weight = torch.empty(output_size, input_size,
                                dtype=torch.float,
                                requires_grad=False)
    init_method(master_weight)
    master_weight = master_weight.to(dtype=get_args().params_dtype)

    # Split along the partition dimension and take every world_size-th chunk
    # starting at this rank; the stride-sized sub-chunks support strided
    # (e.g. fused QKV) layouts.
    chunk_size = divide(per_partition_size, stride)
    all_chunks = torch.split(master_weight, chunk_size, dim=partition_dim)
    rank = get_tensor_model_parallel_rank()
    world_size = get_tensor_model_parallel_world_size()
    local_chunks = all_chunks[rank::world_size]

    # Write this rank's shard into the preallocated parameter storage.
    with torch.no_grad():
        torch.cat(local_chunks, dim=partition_dim, out=weight)
    return master_weight if return_master_weight else None
|
|
|
|
class VocabParallelEmbedding(torch.nn.Module):
    """Embedding parallelized in the vocabulary dimension.

    Each tensor-parallel rank holds a contiguous slice of the vocabulary
    rows; forward() zeroes contributions for token ids owned by other ranks
    and all-reduces so every rank ends up with the full embedding output.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.
    Arguments:
        num_embeddings: vocabulary size.
        embedding_dim: size of hidden state.
        init_method: method to initialize weights.
    """

    def __init__(self, num_embeddings, embedding_dim,
                 init_method=init.xavier_normal_):
        super(VocabParallelEmbedding, self).__init__()
        # Keep the input dimensions.
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        # Defaults kept for parity with torch.nn.Embedding (unused features).
        self.padding_idx = None
        self.max_norm = None
        self.norm_type = 2.
        self.scale_grad_by_freq = False
        self.sparse = False
        self._weight = None
        self.tensor_model_parallel_size = get_tensor_model_parallel_world_size()
        # Divide the vocabulary into contiguous per-rank ranges
        # [vocab_start_index, vocab_end_index).
        self.vocab_start_index, self.vocab_end_index = \
            VocabUtility.vocab_range_from_global_vocab_size(
                self.num_embeddings, get_tensor_model_parallel_rank(),
                self.tensor_model_parallel_size)
        self.num_embeddings_per_partition = self.vocab_end_index - \
            self.vocab_start_index

        # Allocate this rank's weight shard and initialize it: either on CPU
        # (scattered from a full master copy) or directly on the CUDA device.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_cpu(
                    self.weight, self.num_embeddings, self.embedding_dim,
                    self.num_embeddings_per_partition, 0, init_method)
        else:
            self.weight = Parameter(torch.empty(
                self.num_embeddings_per_partition, self.embedding_dim,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=1)

    def forward(self, input_):
        """Embed integer token ids in *input_* (full result on every rank)."""
        if self.tensor_model_parallel_size > 1:
            # Mark token ids that belong to other ranks' vocabulary slices.
            input_mask = (input_ < self.vocab_start_index) | \
                         (input_ >= self.vocab_end_index)
            # Shift ids into the local shard's row space; out-of-range ids
            # are clamped to row 0 (their output is zeroed below).
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_
        # Look up embeddings in the local shard.
        output_parallel = F.embedding(masked_input, self.weight,
                                      self.padding_idx, self.max_norm,
                                      self.norm_type, self.scale_grad_by_freq,
                                      self.sparse)
        # Zero rows for tokens owned by other ranks so the all-reduce sums
        # exactly one real contribution per token.
        if self.tensor_model_parallel_size > 1:
            output_parallel[input_mask, :] = 0.0
        # Sum partial results across the tensor-model-parallel group.
        output = reduce_from_tensor_model_parallel_region(output_parallel)
        return output
|
|
|
|
class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
    """
    Linear layer execution with asynchronous communication and gradient accumulation
    fusion in backprop.

    Forward computes input @ weight.T (+ bias), optionally all-gathering a
    sequence-parallel input first.  Backward overlaps the input-gradient
    all-reduce / reduce-scatter with the weight-gradient GEMM, and can fuse
    the weight-gradient accumulation into `weight.main_grad` via the
    `fused_dense_cuda` extension.
    """

    @staticmethod
    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
                async_grad_allreduce, sequence_parallel):
        # Stash tensors and flags for backward().
        ctx.save_for_backward(input, weight)
        ctx.use_bias = bias is not None
        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
        ctx.async_grad_allreduce = async_grad_allreduce
        ctx.sequence_parallel = sequence_parallel

        if sequence_parallel:
            # Input holds only this rank's sequence slice; gather the full
            # sequence (dim 0 grows by world_size) into a shared buffer.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            # Reuse a global buffer keyed "mpu" to avoid reallocating the
            # gathered activation every step.
            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group())
            total_input = all_gather_buffer
        else:
            total_input = input

        # Y = X @ A^T (+ b); weight is stored as (output, input).
        output = torch.matmul(total_input, weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        use_bias = ctx.use_bias

        if ctx.sequence_parallel:
            # Re-gather the full-sequence input (it was not saved gathered);
            # kick it off async so it overlaps with the grad_input matmul.
            world_size = get_tensor_model_parallel_world_size()
            dim_size = list(input.size())
            dim_size[0] = dim_size[0] * world_size

            all_gather_buffer = \
                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
            handle = torch.distributed._all_gather_base(
                all_gather_buffer,
                input,
                group=get_tensor_model_parallel_group(), async_op=True)

            # Dummy tensor op: presumably inserted so the all-gather kernel is
            # launched before the matmul below and the two overlap — keep it.
            _ = torch.empty(1, device=grad_output.device) + 1
            total_input = all_gather_buffer
        else:
            total_input = input
        grad_input = grad_output.matmul(weight)

        if ctx.sequence_parallel:
            # grad_weight below needs the fully gathered input.
            handle.wait()

        # Collapse (seq, batch) into one leading dim so the weight-gradient
        # GEMM is a plain 2-D matmul.  NOTE(review): assumes 3-D activations.
        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
                                       grad_output.shape[2])
        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
                                       total_input.shape[2])

        if ctx.async_grad_allreduce:
            # Start reducing grad_input across the TP group while we compute
            # grad_weight/grad_bias below.
            handle = torch.distributed.all_reduce(
                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
            # Dummy tensor op to encourage comm/compute overlap — keep it.
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.sequence_parallel:
            # Sequence-parallel and async all-reduce are mutually exclusive
            # (enforced by the caller as well).
            assert not ctx.async_grad_allreduce
            dim_size = list(input.size())
            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
                                         device=torch.cuda.current_device(),
                                         requires_grad=False)
            # Reduce-scatter grad_input back to per-rank sequence slices.
            handle = torch.distributed._reduce_scatter_base(sub_grad_input, grad_input,
                                                            group=get_tensor_model_parallel_group(),
                                                            async_op=True)
            # Dummy tensor op to encourage comm/compute overlap — keep it.
            _ = torch.empty(1, device=grad_output.device) + 1

        if ctx.gradient_accumulation_fusion:
            # Fused kernel accumulates directly into weight.main_grad in fp32;
            # returning None for grad_weight skips autograd's own accumulation.
            import fused_dense_cuda
            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
            grad_weight = None
        else:
            grad_weight = grad_output.t().matmul(total_input)
        grad_bias = grad_output.sum(dim=0) if use_bias else None

        if ctx.sequence_parallel:
            handle.wait()
            # One gradient per forward argument (flags get None).
            return sub_grad_input, grad_weight, grad_bias, None, None, None

        if ctx.async_grad_allreduce:
            handle.wait()

        return grad_input, grad_weight, grad_bias, None, None, None
|
|
|
|
class ColumnParallelLinear(torch.nn.Module):
    """Linear layer with column parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its second dimension as A = [A_1, ..., A_p].

    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias
        gather_output: If true, call all-gather on output and make Y available
                       to all GPUs, otherwise, every GPU will have its output
                       which is Y_i = XA_i
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimations where bias
                       can be fused with other elementwise operations. we skip
                       adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True, gather_output=True,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(ColumnParallelLinear, self).__init__()

        # Keep input parameters.
        self.input_size = input_size
        self.output_size = output_size
        self.gather_output = gather_output
        # Each rank owns output_size / world_size output columns.
        world_size = get_tensor_model_parallel_world_size()
        self.output_size_per_partition = divide(output_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Allocate the weight shard (stored transposed as
        # (output_per_partition, input) to match F.linear convention) and
        # initialize on CPU or directly on the CUDA device.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size_per_partition,
                                                self.input_size,
                                                dtype=args.params_dtype))
            if args.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.output_size_per_partition, 0, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size_per_partition, self.input_size,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=0, stride=stride)

        if bias:
            # Bias is sharded like the output columns and starts at zero.
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition, dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size_per_partition,
                    device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            set_tensor_model_parallel_attributes(self.bias, True, 0, stride)
            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        # Communication strategy flags; both require a parallel group and are
        # mutually exclusive (asserted below).
        self.async_tensor_model_parallel_allreduce = (
            args.async_tensor_model_parallel_allreduce and
            world_size > 1)
        self.sequence_parallel = (
            args.sequence_parallel and
            world_size > 1)
        assert not self.async_tensor_model_parallel_allreduce or \
            not self.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        """Run the column-parallel GEMM; returns (output, output_bias)."""
        # When skip_bias_add is set, the bias is returned (not added) so the
        # caller can fuse it with a later elementwise op.
        bias = self.bias if not self.skip_bias_add else None

        if self.async_tensor_model_parallel_allreduce or \
                self.sequence_parallel:
            # The autograd Function handles the communication itself.
            input_parallel = input_
        else:
            # Identity forward, all-reduce backward across the TP group.
            input_parallel = copy_to_tensor_model_parallel_region(input_)
        # Matrix multiply with this rank's column shard of A.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
        if self.gather_output:
            # All-gather the column shards into the full output.  Incompatible
            # with sequence parallelism, which keeps activations sharded.
            assert not self.sequence_parallel
            output = gather_from_tensor_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        output_bias = self.bias if self.skip_bias_add else None
        return output, output_bias
|
|
|
|
class RowParallelLinear(torch.nn.Module):
    """Linear layer with row parallelism.

    The linear layer is defined as Y = XA + b. A is parallelized along
    its first dimension and X along its second dimension as:
               -   -
              | A_1 |
              | .   |
          A = | .   |        X = [X_1, ..., X_p]
              | .   |
              | A_p |
               -   -
    Arguments:
        input_size: first dimension of matrix A.
        output_size: second dimension of matrix A.
        bias: If true, add bias. Note that bias is not parallelized.
        input_is_parallel: If true, we assume that the input is already
                           split across the GPUs and we do not split
                           again.
        init_method: method to initialize weights. Note that bias is always set
                     to zero.
        stride: For the strided linear layers.
        keep_master_weight_for_test: This was added for testing and should be
                                     set to False. It returns the master weights
                                     used for initialization.
        skip_bias_add: This was added to enable performance optimization where bias
                       can be fused with other elementwise operations. We skip
                       adding bias but instead return it.
    """

    def __init__(self, input_size, output_size, bias=True,
                 input_is_parallel=False,
                 init_method=init.xavier_normal_, stride=1,
                 keep_master_weight_for_test=False,
                 skip_bias_add=False):
        super(RowParallelLinear, self).__init__()

        # Keep input parameters.
        self.input_size = input_size
        self.output_size = output_size
        self.input_is_parallel = input_is_parallel
        # Each rank owns input_size / world_size input rows of A.
        world_size = get_tensor_model_parallel_world_size()
        self.input_size_per_partition = divide(input_size, world_size)
        self.skip_bias_add = skip_bias_add

        # Allocate the weight shard (stored transposed as
        # (output, input_per_partition)) and initialize on CPU or CUDA.
        args = get_args()
        if args.use_cpu_initialization:
            self.weight = Parameter(torch.empty(self.output_size,
                                                self.input_size_per_partition,
                                                dtype=args.params_dtype))
            if args.perform_initialization:
                self.master_weight = _initialize_affine_weight_cpu(
                    self.weight, self.output_size, self.input_size,
                    self.input_size_per_partition, 1, init_method,
                    stride=stride, return_master_weight=keep_master_weight_for_test)
        else:
            self.weight = Parameter(torch.empty(
                self.output_size, self.input_size_per_partition,
                device=torch.cuda.current_device(), dtype=args.params_dtype))
            if args.perform_initialization:
                _initialize_affine_weight_gpu(self.weight, init_method,
                                              partition_dim=1, stride=stride)
        if bias:
            # Bias is replicated (not sharded) since it is applied after the
            # cross-rank reduction.
            if args.use_cpu_initialization:
                self.bias = Parameter(torch.empty(self.output_size,
                                                  dtype=args.params_dtype))
            else:
                self.bias = Parameter(torch.empty(
                    self.output_size, device=torch.cuda.current_device(),
                    dtype=args.params_dtype))
            # NOTE(review): presumably this flag tells the grad handling to
            # all-reduce this replicated param under sequence parallelism —
            # confirm against the optimizer/grad-sync code.
            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)

            # Always initialize bias to zero.
            with torch.no_grad():
                self.bias.zero_()
        else:
            self.register_parameter('bias', None)
        self.sequence_parallel = args.sequence_parallel
        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

    def forward(self, input_):
        """Run the row-parallel GEMM; returns (output, output_bias)."""
        # Set up backprop all-reduce / obtain this rank's input slice.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            # Sequence parallelism requires the caller to provide the input
            # already split across ranks.
            assert not self.sequence_parallel
            input_parallel = scatter_to_tensor_model_parallel_region(input_)
        # Matrix multiply with this rank's row shard of A.  The last two
        # arguments (async allreduce / sequence parallel) are deliberately
        # falsy: communication happens explicitly below instead.
        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
            input_parallel, self.weight, None,
            self.gradient_accumulation_fusion, None, None)
        # Combine partial sums across the tensor-parallel group.
        if self.sequence_parallel:
            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
        else:
            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
        if not self.skip_bias_add:
            output = output_ + self.bias if self.bias is not None else output_
            output_bias = None
        else:
            # Return the bias separately so callers can fuse the add.
            output = output_
            output_bias = self.bias
        return output, output_bias
|
|