Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Callable | |
| import torch | |
| from megatron.core import ModelParallelConfig, parallel_state | |
| from megatron.core.tensor_parallel import ColumnParallelLinear as McoreColumnParallelLinear | |
| from megatron.core.tensor_parallel import RowParallelLinear as McoreRowParallelLinear | |
| from megatron.core.tensor_parallel import VocabParallelEmbedding as McoreVocabParallelEmbedding | |
| from megatron.core.tensor_parallel.mappings import ( | |
| reduce_from_tensor_model_parallel_region, | |
| reduce_scatter_to_sequence_parallel_region, | |
| ) | |
| from megatron.core.tensor_parallel.utils import VocabUtility | |
| from torch.distributed import _functional_collectives as funcol | |
| from torch.distributed._functional_collectives import all_reduce | |
class VocabParallelEmbedding(torch.nn.Module):
    """
    Embedding table sharded along the vocabulary dimension.

    Each tensor-model-parallel rank keeps only the rows for the vocabulary
    range ``[vocab_start_index, vocab_end_index)`` reported by ``VocabUtility``.
    The weight is allocated with ``torch.empty`` on the current CUDA device and
    is not initialized here.

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): size of hidden state.
        precision (str): name of the torch dtype used for the weight; it is
            looked up as an attribute of ``torch`` (e.g. ``"bfloat16"``).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        precision: str = "bfloat16",
    ):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.tensor_model_parallel_size = parallel_state.get_tensor_model_parallel_world_size()

        # Divide the weight matrix along the vocabulary dimension.
        tp_rank = parallel_state.get_tensor_model_parallel_rank()
        vocab_range = VocabUtility.vocab_range_from_global_vocab_size(
            self.num_embeddings,
            tp_rank,
            self.tensor_model_parallel_size,
        )
        self.vocab_start_index, self.vocab_end_index = vocab_range
        self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index

        # Only this rank's slice of the table is materialized.
        shard = torch.empty(
            self.num_embeddings_per_partition,
            self.embedding_dim,
            device=torch.cuda.current_device(),
            dtype=getattr(torch, precision),
        )
        self.weight = torch.nn.Parameter(shard)

    def forward(self, input_):
        """Look up embeddings for ``input_``.

        Args:
            input_ (torch.Tensor): Tensor of vocabulary indices.

        Returns:
            torch.Tensor: The embeddings, with one trailing dimension added by
            indexing ``self.weight``. When the table is sharded, rows for
            indices owned by other ranks are zeroed locally and the result is
            sum-all-reduced over the tensor-model-parallel group.
        """
        if not self.tensor_model_parallel_size > 1:
            # Single partition: the whole vocabulary is local, plain lookup.
            return self.weight[input_]

        # Indices that fall outside this rank's vocabulary slice.
        out_of_range = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Shift to local row numbers; foreign indices are pointed at row 0 so
        # the lookup stays in bounds (their rows are zeroed right after).
        local_ids = (input_ - self.vocab_start_index).masked_fill(out_of_range, 0)

        output = self.weight[local_ids]
        output[out_of_range, :] = 0.0

        tp_group = parallel_state.get_tensor_model_parallel_group()
        return all_reduce(output, "sum", group=tp_group)
class ColumnParallelLinear(McoreColumnParallelLinear):
    """
    Thin wrapper around Mcore's ColumnParallelLinear whose ``forward`` returns only the output tensor.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are passed through to the Mcore layer unchanged.
        super().__init__(*args, **kwargs)

    def forward(self, input_: torch.Tensor):
        """
        Run the column parallel linear layer and return just its output.

        The parent ``forward`` is called with ``input_`` only and is expected to
        return a pair; the second element is discarded.

        Args:
            input_ (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after the linear transformation.
        """
        result, _discarded = super().forward(input_)
        return result
class RowParallelLinear(McoreRowParallelLinear):
    """
    Thin wrapper around Mcore's RowParallelLinear whose ``forward`` returns only the output tensor.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are passed through to the Mcore layer unchanged.
        super().__init__(*args, **kwargs)

    def forward(self, input_: torch.Tensor):
        """
        Run the row parallel linear layer and return just its output.

        The parent ``forward`` is called with ``input_`` only and is expected to
        return a pair; the second element is discarded.

        Args:
            input_ (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The output tensor after the linear transformation.
        """
        result, _discarded = super().forward(input_)
        return result
class TrainingVocabParallelEmbedding(McoreVocabParallelEmbedding):
    """
    Vocabulary-sharded embedding built on Mcore's VocabParallelEmbedding.

    Adds control over the output layout (``batch_first``), an optional
    reduce-scatter for sequence parallelism, and an option to replace Megatron's
    reduction with PyTorch's functional all-reduce.

    Args:
        num_embeddings (int): vocabulary size.
        embedding_dim (int): size of hidden state.

    Keyword Args:
        init_method (Callable): forwarded to the Mcore embedding constructor.
        sequence_parallel (bool): Decides whether to perform ReduceScatter after embedding lookup.
            When True, ``batch_first`` is forced to False.
        batch_first (bool): If True, then output tensor shape is [batch, seq, feature]. If False, then shape
            becomes [seq, batch, feature]. Note: We assume the input tensor is always in the shape of [seq, batch].
            NOTE(review): the shape comments inside ``forward`` describe a [b, s, h] --> [s, b, h] transpose,
            which does not match the input layout stated here -- confirm against callers.
        config: A megatron.core.ModelParallelConfig object.
        use_inference_allreduce (bool): If True, then Megatron's allreduce in the forward pass is disabled, and
            the pytorch's allreduce is used instead (inference mode only).
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        *,
        init_method: Callable,
        sequence_parallel: bool = False,
        batch_first: bool = False,
        config: ModelParallelConfig,
        use_inference_allreduce: bool = False,
    ):
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            init_method=init_method,
            config=config,
        )
        self.sequence_parallel = sequence_parallel
        # With sequence parallelism the output is always [seq, batch, feature],
        # so a requested batch_first layout is overridden.
        self.batch_first = False if sequence_parallel else batch_first
        self.use_inference_allreduce = use_inference_allreduce

    def forward(self, input_):
        """Look up embeddings and combine the partial results across ranks.

        Args:
            input_ (torch.Tensor): Tensor of vocabulary indices.

        Returns:
            torch.Tensor: The embedding output after the configured reduction
            and, unless ``batch_first`` is set, a transpose of its first two
            dimensions.
        """
        vocab_is_sharded = self.tensor_model_parallel_size > 1
        megatron_reduce = not self.use_inference_allreduce

        if vocab_is_sharded:
            # Indices that fall outside this rank's vocabulary slice.
            out_of_range = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
            # Shift to local row numbers; foreign indices are pointed at row 0
            # so the lookup stays in bounds.
            local_ids = (input_ - self.vocab_start_index).masked_fill(out_of_range, 0)
        else:
            local_ids = input_

        # Get the embeddings.
        output = self.weight[local_ids]

        # Zero the rows that belong to other ranks so the later sum is correct.
        if vocab_is_sharded:
            output[out_of_range, :] = 0.0

        if self.sequence_parallel:
            assert not self.batch_first
            # Data format change to avoid explicit transposes: [b s h] --> [s b h].
            output = output.transpose(0, 1).contiguous()
            if megatron_reduce:
                output = reduce_scatter_to_sequence_parallel_region(output)
        else:
            # Reduce across all the model parallel GPUs.
            if megatron_reduce:
                output = reduce_from_tensor_model_parallel_region(output)
            if not self.batch_first:
                # Shape: [b, s, h] --> [s, b, h]
                output = output.transpose(0, 1).contiguous()

        if not megatron_reduce:
            # Inference path: PyTorch's functional all-reduce replaces Megatron's reduction.
            tp_group = parallel_state.get_tensor_model_parallel_group()
            output = funcol.all_reduce(output, "sum", group=tp_group)
        return output