# Copyright (c) 2024, Tri Dao.
# The TensorParallel linear modules are inspired by https://github.com/NVIDIA/apex/blob/master/apex/transformer/tensor_parallel/layers.py
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.distributed import ProcessGroup

from einops import rearrange

from mamba_ssm.distributed.distributed_utils import (
    all_gather_raw,
    all_reduce,
    all_reduce_raw,
    reduce_scatter,
    reduce_scatter_raw,
)

class ParallelLinearFunc(torch.autograd.Function):
    @staticmethod
    @custom_fwd
    def forward(ctx, x, weight, bias, process_group=None, sequence_parallel=True):
        """
        If process_group is not None and sequence_parallel=True, we're doing Tensor Parallel
        with sequence parallelism: we do an all_gather_raw of x before doing the matmul.
        """
        ctx.compute_weight_gradient = weight.requires_grad
        ctx.process_group = process_group
        ctx.sequence_parallel = sequence_parallel

        if torch.is_autocast_enabled():
            x = x.to(dtype=torch.get_autocast_gpu_dtype())
        x = x.contiguous()
        if process_group is not None and sequence_parallel:
            # We want to kick off the all_gather early, before weight dtype conversion
            total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
        else:
            total_x = x

        if torch.is_autocast_enabled():
            weight = weight.to(dtype=torch.get_autocast_gpu_dtype())
            bias = bias.to(dtype=torch.get_autocast_gpu_dtype()) if bias is not None else None
        weight = weight.contiguous()
        if process_group is not None and sequence_parallel:
            handle_x.wait()
        batch_shape, n = total_x.shape[:-1], total_x.shape[-1]
        batch_dim = batch_shape.numel()
        # https://github.com/pytorch/pytorch/blob/5b51849b48a7dbccd297286cc0110def4706f9e7/aten/src/ATen/native/cuda/Blas.cpp#L174
        output = F.linear(total_x, weight, bias)
        if ctx.compute_weight_gradient:
            ctx.save_for_backward(x, weight)
        else:
            ctx.save_for_backward(weight)
        return output

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_output):
        grad_output = grad_output.contiguous()
        process_group = ctx.process_group
        sequence_parallel = ctx.sequence_parallel
        if ctx.compute_weight_gradient:
            x, weight = ctx.saved_tensors
            if process_group is not None and sequence_parallel:
                total_x, handle_x = all_gather_raw(x, process_group, async_op=True)
            else:
                total_x = x
        else:
            (weight,) = ctx.saved_tensors
            total_x = None
        batch_shape = grad_output.shape[:-1]
        batch_dim = batch_shape.numel()
        grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = F.linear(grad_output, weight.t())
            grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
            if process_group is not None:
                reduce_fn = reduce_scatter_raw if sequence_parallel else all_reduce_raw
                grad_input, handle_grad_input = reduce_fn(grad_input, process_group, async_op=True)
        else:
            grad_input = None
        if ctx.needs_input_grad[1]:
            assert ctx.compute_weight_gradient
            if process_group is not None and sequence_parallel:
                handle_x.wait()
            grad_weight = torch.einsum(
                "bo,bi->oi", grad_output, total_x.reshape(batch_dim, total_x.shape[-1])
            )
        else:
            grad_weight = None
        grad_bias = grad_output.sum(dim=0) if ctx.needs_input_grad[2] else None
        if process_group is not None and ctx.needs_input_grad[0]:
            handle_grad_input.wait()
        return grad_input, grad_weight, grad_bias, None, None

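# Shape walkthrough for the autograd function above (illustrative; the exact
# sequence-parallel layout depends on all_gather_raw, which is assumed here to
# concatenate shards along the first dimension):
#   total_x:     (..., in_features), flattened to (batch_dim, in_features)
#   grad_output: (..., out_features), flattened to (batch_dim, out_features)
#   grad_weight = einsum("bo,bi->oi") -> (out_features, in_features)
#   grad_input  = grad_output @ weight -> (batch_dim, in_features), then
#                 reduce-scattered (sequence parallel) or all-reduced across ranks.
# Note that forward saves the *un-gathered* x and backward re-gathers it, trading
# an extra all_gather for lower activation memory.
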
def parallel_linear_func(
    x: Tensor,
    weight: Tensor,
    bias: Optional[Tensor] = None,
    process_group: Optional[ProcessGroup] = None,
    sequence_parallel: bool = True,
):
    return ParallelLinearFunc.apply(x, weight, bias, process_group, sequence_parallel)

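# Minimal single-device sanity check (illustrative only): with process_group=None the
# function reduces to a plain F.linear, so the custom autograd path can be exercised
# without any distributed setup.
#
#   x = torch.randn(4, 16, 128, device="cuda", requires_grad=True)
#   w = torch.randn(256, 128, device="cuda", requires_grad=True)
#   out = parallel_linear_func(x, w, bias=None, process_group=None)
#   out.sum().backward()   # matches F.linear(x, w) and its gradients
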
class ColumnParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        if out_features % multiple_of:
            raise ValueError(f"out_features ({out_features}) must be a multiple of {multiple_of}")
        multiple = out_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
        super().__init__(
            in_features, local_multiple * multiple_of, bias=bias, device=device, dtype=dtype
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        # If self.sequence_parallel is True, we're doing Tensor Parallel with sequence parallelism:
        # we do an all_gather of x before doing the matmul.
        # If not, then the input is already gathered.
        return parallel_linear_func(
            x,
            self.weight,
            self.bias,
            process_group=self.process_group,
            sequence_parallel=self.sequence_parallel,
        )

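# Worked example of the shard sizing above (pure arithmetic from __init__):
# out_features=10, multiple_of=1, world_size=4 -> multiple=10, div=2, mod=2,
# so ranks 0-1 hold 3 output features each and ranks 2-3 hold 2 each (total 10).
# The output is thus sharded along its last dimension; downstream layers either
# consume the local shard directly (e.g., a RowParallelLinear) or all_gather it.
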
class RowParallelLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        process_group: ProcessGroup,
        bias: bool = True,
        sequence_parallel=True,
        multiple_of=1,
        device=None,
        dtype=None,
    ) -> None:
        world_size = torch.distributed.get_world_size(process_group)
        rank = torch.distributed.get_rank(process_group)
        if in_features % multiple_of:
            raise ValueError(f"in_features ({in_features}) must be a multiple of {multiple_of}")
        multiple = in_features // multiple_of
        # We want to split @multiple across world_size, but it could be an uneven split
        div = multiple // world_size
        mod = multiple % world_size
        # The first @mod ranks get @div + 1 copies, the rest get @div copies
        local_multiple = div + int(rank < mod)
        # Only rank 0 will have bias
        super().__init__(
            local_multiple * multiple_of,
            out_features,
            bias=bias and rank == 0,
            device=device,
            dtype=dtype,
        )
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel

    def forward(self, x):
        """
        We're doing Tensor Parallel with sequence parallelism: we do the matmul and then
        a reduce_scatter of the result.
        """
        out = parallel_linear_func(x, self.weight, self.bias)
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return reduce_fn(out, self.process_group)

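# Note on the typical Column -> Row pairing: the column-parallel layer leaves its
# output sharded along the feature dimension, and this row-parallel layer consumes
# exactly that local shard (no process_group is passed to parallel_linear_func, so
# no all_gather happens here). Each rank produces a partial sum of the full output,
# which reduce_scatter / all_reduce then combines. Since only rank 0 owns a bias
# and the bias is added before the reduction, it is counted exactly once.
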
class VocabParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, *args, process_group=None, padding_idx=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if num_embeddings % world_size != 0:
                raise ValueError(
                    f"num_embeddings ({num_embeddings}) must be divisible by "
                    f"world_size ({world_size})"
                )
            if world_size > 1 and padding_idx is not None:
                raise RuntimeError("ParallelEmbedding does not support padding_idx")
        else:
            world_size = 1
        super().__init__(num_embeddings // world_size, *args, padding_idx=padding_idx, **kwargs)

    def forward(self, input: Tensor) -> Tensor:
        if self.process_group is None:
            return super().forward(input)
        else:
            rank = torch.distributed.get_rank(self.process_group)
            vocab_size = self.num_embeddings
            vocab_start_index, vocab_end_index = rank * vocab_size, (rank + 1) * vocab_size
            # Create a mask of valid vocab ids (1 means it needs to be masked).
            input_ids_mask = (input < vocab_start_index) | (input >= vocab_end_index)
            input = input - vocab_start_index
            input[input_ids_mask] = 0
            embeddings = super().forward(input)
            embeddings[input_ids_mask] = 0.0
            return embeddings

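# Worked example of the vocab sharding above (numbers are illustrative):
# num_embeddings=32000, world_size=4 -> each rank stores 8000 rows, and rank 1
# owns token ids [8000, 16000). A token id of 12345 becomes local index 4345 on
# rank 1; on every other rank it is masked, looked up as index 0, and its output
# row zeroed. The per-rank partial embeddings are then summed by the caller's
# reduction (see ParallelEmbeddings below), so each token gets exactly one
# non-zero contribution.
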
class ColumnParallelEmbedding(nn.Embedding):
    def __init__(self, num_embeddings, embedding_dim, *args, process_group=None, **kwargs):
        self.process_group = process_group
        if process_group is not None:
            world_size = torch.distributed.get_world_size(process_group)
            if embedding_dim % world_size != 0:
                raise ValueError(
                    f"embedding_dim ({embedding_dim}) must be divisible by "
                    f"world_size ({world_size})"
                )
        else:
            world_size = 1
        super().__init__(num_embeddings, embedding_dim // world_size, *args, **kwargs)

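# Each rank stores embedding_dim // world_size columns of every embedding vector;
# ParallelEmbeddings.forward below adds this local slice into the matching feature
# slice of the word embeddings before the final reduction.
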
class ParallelEmbeddings(nn.Module):
    def __init__(
        self,
        embed_dim,
        vocab_size,
        max_position_embeddings,
        process_group,
        padding_idx=None,
        sequence_parallel=True,
        device=None,
        dtype=None,
    ):
        """
        If max_position_embeddings <= 0, there are no position embeddings.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.process_group = process_group
        self.sequence_parallel = sequence_parallel
        self.word_embeddings = VocabParallelEmbedding(
            vocab_size,
            embed_dim,
            padding_idx=padding_idx,
            process_group=process_group,
            **factory_kwargs,
        )
        self.max_position_embeddings = max_position_embeddings
        if self.max_position_embeddings > 0:
            self.position_embeddings = ColumnParallelEmbedding(
                max_position_embeddings, embed_dim, process_group=process_group, **factory_kwargs
            )

    def forward(self, input_ids, position_ids=None, combine_batch_seqlen_dim=False):
        """
        input_ids: (batch, seqlen)
        position_ids: (batch, seqlen)
        """
        batch_size, seqlen = input_ids.shape
        world_size = torch.distributed.get_world_size(self.process_group)
        embeddings = self.word_embeddings(input_ids)
        if self.max_position_embeddings > 0:
            if position_ids is None:
                position_ids = torch.arange(seqlen, dtype=torch.long, device=input_ids.device)
            position_embeddings = self.position_embeddings(position_ids)
            if world_size <= 1:
                embeddings = embeddings + position_embeddings
            else:
                partition_dim = self.position_embeddings.embedding_dim
                rank = torch.distributed.get_rank(self.process_group)
                embeddings[
                    ..., rank * partition_dim : (rank + 1) * partition_dim
                ] += position_embeddings
        if combine_batch_seqlen_dim:
            embeddings = rearrange(embeddings, "b s d -> (b s) d")
        reduce_fn = reduce_scatter if self.sequence_parallel else all_reduce
        return embeddings if world_size <= 1 else reduce_fn(embeddings, self.process_group)
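

# Usage sketch (illustrative only; assumes torch.distributed is initialized and the
# sizes divide evenly across the group):
#
#   pg = torch.distributed.new_group()
#   emb = ParallelEmbeddings(
#       embed_dim=1024, vocab_size=32000, max_position_embeddings=0,
#       process_group=pg, sequence_parallel=True, device="cuda", dtype=torch.bfloat16,
#   )
#   hidden = emb(input_ids, combine_batch_seqlen_dim=True)
#   # sequence_parallel=True: each rank receives a reduce-scattered shard of the
#   # (batch * seqlen, embed_dim) output; sequence_parallel=False: the full
#   # all-reduced output on every rank.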