from typing import Callable

import torch
from megatron.core import ModelParallelConfig, parallel_state
from megatron.core.tensor_parallel import ColumnParallelLinear as McoreColumnParallelLinear
from megatron.core.tensor_parallel import RowParallelLinear as McoreRowParallelLinear
from megatron.core.tensor_parallel import VocabParallelEmbedding as McoreVocabParallelEmbedding
from megatron.core.tensor_parallel.mappings import (
    reduce_from_tensor_model_parallel_region,
    reduce_scatter_to_sequence_parallel_region,
)
from megatron.core.tensor_parallel.utils import VocabUtility
from torch.distributed import _functional_collectives as funcol
from torch.distributed._functional_collectives import all_reduce


class VocabParallelEmbedding(torch.nn.Module):
""" |
|
|
Embedding parallelized in the vocabulary dimension. |
|
|
|
|
|
This is mainly adapted from torch.nn.Embedding and all the default |
|
|
values are kept. |
|
|
|
|
|
Args: |
|
|
num_embeddings (int): vocabulary size. |
|
|
embedding_dim (int): size of hidden state. |
|
|
precision (str): precision of the embedding. |
|
|
""" |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
num_embeddings: int, |
|
|
embedding_dim: int, |
|
|
precision: str = "bfloat16", |
|
|
): |
|
|
super().__init__() |
|
|
self.num_embeddings = num_embeddings |
|
|
self.embedding_dim = embedding_dim |
|
|
self.tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size() |
|
|
|
|
|
(self.vocab_start_index, self.vocab_end_index) = VocabUtility.vocab_range_from_global_vocab_size( |
|
|
self.num_embeddings, |
|
|
parallel_state.get_tensor_model_parallel_rank(), |
|
|
self.tensor_model_parallel_size, |
|
|
) |
|
|
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index |
|
|
|
|
|
self.weight = torch.nn.Parameter( |
|
|
torch.empty( |
|
|
self.num_embeddings_per_partition, |
|
|
self.embedding_dim, |
|
|
device=torch.cuda.current_device(), |
|
|
dtype=getattr(torch, precision), |
|
|
) |
|
|
) |
|
|
|
|
|
    def forward(self, input_):
        """Forward.

        Args:
            input_ (torch.Tensor): Input tensor of token ids.
        """
        if self.tensor_model_parallel_size > 1:
            # Token ids outside this rank's vocabulary slice are remapped to 0 so the
            # local lookup stays in bounds; their rows are zeroed out below.
            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_

        output = self.weight[masked_input]

        if self.tensor_model_parallel_size > 1:
            output[input_mask, :] = 0.0
            # Each token's embedding is non-zero on exactly one rank, so a "sum"
            # all-reduce reconstructs the full embedding on every rank.
            output = all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group())
        return output
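
    # A minimal sketch of the vocabulary partitioning this class relies on (sizes
    # assumed for illustration; they do not come from this module):
    #
    #     vocab_size, tp_size = 1024, 4
    #     per_rank = vocab_size // tp_size                  # 256 rows per rank
    #     start, end = r * per_rank, (r + 1) * per_rank     # rank r owns ids [start, end)
    #
    # Ids outside [start, end) are remapped to 0 before the local lookup and their
    # rows zeroed afterwards, so the "sum" all-reduce leaves exactly one rank
    # contributing each embedding row.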
class ColumnParallelLinear(McoreColumnParallelLinear):
    """
    A modified version of Mcore's ColumnParallelLinear that only returns the output tensor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input_: torch.Tensor):
        """
        Performs the forward pass of the column parallel linear layer.

        Args:
            input_ (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after the linear transformation.
        """
        output, _ = super().forward(input_)
        return output
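
    # The Mcore base class returns an (output, bias) tuple from forward(); this
    # wrapper drops the bias handle so callers can use the layer directly.
    # A minimal usage sketch (assumed call site, not from this module):
    #
    #     out = layer(x)           # with this wrapper
    #     out, _ = mcore_layer(x)  # what the unwrapped Mcore layer would require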
class RowParallelLinear(McoreRowParallelLinear):
    """
    A modified version of Mcore's RowParallelLinear that only returns the output tensor.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, input_: torch.Tensor):
        """
        Performs the forward pass of the row parallel linear layer.

        Args:
            input_ (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after the linear transformation.
        """
        output, _ = super().forward(input_)
        return output
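
    # Sketch of the usual tensor-parallel pairing (illustrative only; the sizes and
    # keyword choices below are assumptions, not taken from this module). A
    # ColumnParallelLinear splits the output features across ranks, and a
    # RowParallelLinear consumes the already-partitioned activation and reduces the
    # partial sums back to a full tensor:
    #
    #     up   = ColumnParallelLinear(hidden, 4 * hidden, config=config,
    #                                 init_method=torch.nn.init.xavier_normal_,
    #                                 bias=False, gather_output=False)
    #     down = RowParallelLinear(4 * hidden, hidden, config=config,
    #                              init_method=torch.nn.init.xavier_normal_,
    #                              bias=False, input_is_parallel=True)
    #     y = down(torch.nn.functional.gelu(up(x)))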
class TrainingVocabParallelEmbedding(McoreVocabParallelEmbedding):
    """
    Embedding parallelized in the vocabulary dimension.

    This is mainly adapted from torch.nn.Embedding and all the default
    values are kept.

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): size of the hidden state.

    Keyword Args:
        init_method (Callable): method used to initialize the embedding weight.
        sequence_parallel (bool): Decides whether to perform a reduce-scatter after the embedding lookup.
        batch_first (bool): If True, the output tensor shape is [batch, seq, feature]; if False, it is
            [seq, batch, feature]. Note: the input tensor is assumed to always be in the shape [batch, seq].
        config: A megatron.core.ModelParallelConfig object.
        use_inference_allreduce (bool): If True, Megatron's all-reduce in the forward pass is disabled and
            PyTorch's functional all-reduce is used instead (inference mode only).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *,
        init_method: Callable,
        sequence_parallel: bool = False,
        batch_first: bool = False,
        config: ModelParallelConfig,
        use_inference_allreduce: bool = False,
    ):
        super(TrainingVocabParallelEmbedding, self).__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            init_method=init_method,
            config=config,
        )
        self.sequence_parallel = sequence_parallel
        if sequence_parallel:
            # Sequence parallelism expects [seq, batch, feature] activations.
            batch_first = False
        self.batch_first = batch_first
        self.use_inference_allreduce = use_inference_allreduce
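
    # A construction sketch (assumed call site, not from this module). `init_method`
    # is forwarded to Mcore's weight initialization, and the parallel layout comes
    # from a Megatron ModelParallelConfig:
    #
    #     config = ModelParallelConfig(tensor_model_parallel_size=2, sequence_parallel=True)
    #     emb = TrainingVocabParallelEmbedding(
    #         vocab_size,
    #         hidden_size,
    #         init_method=torch.nn.init.xavier_normal_,
    #         sequence_parallel=config.sequence_parallel,
    #         config=config,
    #     )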
    def forward(self, input_):
        """Forward.

        Args:
            input_ (torch.Tensor): Input tensor of token ids, shaped [batch, seq].
        """
        if self.tensor_model_parallel_size > 1:
            # Token ids outside this rank's vocabulary slice are remapped to 0 so the
            # local lookup stays in bounds; their rows are zeroed out below.
            input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            masked_input = input_.clone() - self.vocab_start_index
            masked_input[input_mask] = 0
        else:
            masked_input = input_

        output = self.weight[masked_input]

        if self.tensor_model_parallel_size > 1:
            output[input_mask, :] = 0.0

        if self.sequence_parallel:
            assert not self.batch_first
            # The reduce-scatter expects [seq, batch, feature], so move the sequence
            # dimension to the front before scattering along it.
            output = output.transpose(0, 1).contiguous()
            if not self.use_inference_allreduce:
                output = reduce_scatter_to_sequence_parallel_region(output)
        else:
            if not self.use_inference_allreduce:
                output = reduce_from_tensor_model_parallel_region(output)
            if not self.batch_first:
                # Convert [batch, seq, feature] to [seq, batch, feature].
                output = output.transpose(0, 1).contiguous()

        if self.use_inference_allreduce:
            # Inference path: use PyTorch's functional all-reduce instead of
            # Megatron's autograd-aware reduction.
            output = funcol.all_reduce(output, "sum", group=parallel_state.get_tensor_model_parallel_group())
        return output
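
    # Shape-flow sketch for the forward pass above (sizes assumed for illustration):
    #
    #     input_:  [batch, seq]                    token ids
    #     lookup:  [batch, seq, hidden]            partial embeddings on each TP rank
    #     SP path: transpose -> [seq, batch, hidden], then reduce-scatter along seq,
    #              leaving each rank with [seq / tp_size, batch, hidden]
    #     TP path: all-reduce (or Megatron reduce) of the partial embeddings, then an
    #              optional transpose to [seq, batch, hidden] when batch_first=False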