Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Any, Dict, Optional | |
| import torch | |
| import torch.nn as nn | |
| import transformer_engine as te | |
| from megatron.core import InferenceParams, ModelParallelConfig, parallel_state | |
| from megatron.core.tensor_parallel import gather_from_tensor_model_parallel_region | |
| from torch.distributed import ProcessGroup | |
| from torch.distributed import _functional_collectives as funcol | |
| from torch.distributed import broadcast, get_process_group_ranks | |
| from torch.nn.modules.module import _IncompatibleKeys | |
| from transformer_engine.pytorch.module.linear import Linear as LinearTE | |
| from transformer_engine.pytorch.module.rmsnorm import RMSNorm as RMSNormTE | |
| from cosmos_predict1.utils import log | |
| _ACTION_DIM = 8 | |
| from cosmos_predict1.autoregressive.modules.embedding import ( | |
| RotaryPositionEmbeddingPytorch, | |
| RotaryPositionEmbeddingPytorchV2, | |
| RotaryPositionEmbeddingTE, | |
| SinCosPosEmbAxisTE, | |
| get_pos_emb_on_this_cp_rank, | |
| get_pos_emb_on_this_sptp_rank, | |
| ) | |
| from cosmos_predict1.autoregressive.modules.linear import ColumnParallelLinear, TrainingVocabParallelEmbedding | |
| from cosmos_predict1.autoregressive.modules.mlp import TrainingMLP, compute_llama3_ffn_hidden_dim | |
| from cosmos_predict1.autoregressive.modules.normalization import create_norm | |
| from cosmos_predict1.autoregressive.training.modules.attention import ( | |
| GQA, | |
| create_group_causal_attn_mask, | |
| enable_different_context_dim_in_te_ca, | |
| enable_qk_normalization_in_te_mha, | |
| ) | |
| from cosmos_predict1.autoregressive.utils.checkpoint import process_state_dict, substrings_to_ignore | |
| from cosmos_predict1.autoregressive.utils.misc import maybe_convert_to_namespace | |
| from cosmos_predict1.autoregressive.utils.parallel import ( | |
| AllReduceBWDRMSNormTE, | |
| allreduce_layernorm_grads, | |
| sync_1d_parameters, | |
| ) | |
# Hidden width of the action-embedding MLP is `action_embedding_dim // _MLP_HIDDEN_DIM_DIVISOR`.
_MLP_HIDDEN_DIM_DIVISOR = 4

# NOTE(review): not referenced in the visible part of this file; presumably the number of
# T5 text tokens used as cross-attention context — confirm against callers.
_T5_NUM_TOKENS = 512
class TransformerBlock(nn.Module):
    """
    A single pre-norm transformer block (PyTorch backend).

    The block applies, in order and each with a residual connection: self-attention,
    an optional cross-attention layer, and a feed-forward MLP.
    """

    def __init__(self, layer_id: int, model_parallel: Optional[ModelParallelConfig] = None, args=None):
        """
        Builds the attention, (optional) cross-attention, MLP and normalization sub-modules.

        Args:
            layer_id: The ID (index) of the transformer block. It decides whether this block gets a
                cross-attention layer and, with `depth_init`, the weight-init standard deviation.
            model_parallel: Model parallel configuration, forwarded to `GQA` and `TrainingMLP`.
            args: The model arguments containing hyperparameters. They are passed through
                `maybe_convert_to_namespace` and then read both by key and via `getattr`.
        """
        super().__init__()
        args = maybe_convert_to_namespace(args)

        self_attn_kwargs = dict(
            n_heads=args["n_heads"],
            n_kv_heads=args["n_kv_heads"],
            dim=args["dim"],
            context_dim=None,
            max_batch_size=args["max_batch_size"],
            max_seq_len=args["max_seq_len"],
            inference=args["inference"],
            flash_attn=args["flash_attn"],
            use_qk_normalization=args["use_qk_normalization"],
            attention_dropout=getattr(args, "attention_dropout", 0.0),
            set_parallel_mode=args["set_parallel_mode"],
            model_parallel=model_parallel,
            attention_tp=args["attention_tp"],
            causal_mask=args["causal_mask"],
            head_dim=args["head_dim"],
            fuse_qkv=getattr(args, "fuse_qkv", False),
            precision=getattr(args, "precision", "bfloat16"),
            attention_type=getattr(args, "attention_type", "self"),
        )
        self.attention = GQA(**self_attn_kwargs)

        # Cross-attention is inserted on every k-th block (layer 0 included) when enabled.
        insert_here = bool(args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0)
        self.has_cross_attention = insert_here
        if insert_here:
            # Same settings as self-attention, but attending to the external context and never fusing QKV.
            cross_attn_kwargs = {
                **self_attn_kwargs,
                "context_dim": args["context_dim"],
                "fuse_qkv": False,
                "attention_type": "cross",
            }
            self.cross_attention = GQA(**cross_attn_kwargs)
            self.cross_attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
        else:
            self.cross_attention = None
            self.cross_attention_norm = None

        # An explicit `ffn_hidden_size` wins; otherwise fall back to the Llama3 sizing rule.
        ffn_hidden_dim = args["ffn_hidden_size"]
        if ffn_hidden_dim is None:
            ffn_hidden_dim = compute_llama3_ffn_hidden_dim(
                dim=args["dim"], multiple_of=args["multiple_of"], ffn_dim_multiplier=args["ffn_dim_multiplier"]
            )
        self.feed_forward = TrainingMLP(
            dim=args["dim"],
            hidden_dim=ffn_hidden_dim,
            hidden_dropout=getattr(args, "hidden_dropout", 0.0),
            set_parallel_mode=args["set_parallel_mode"],
            model_parallel=model_parallel,
            inference=args["inference"],
        )

        self.layer_id = layer_id
        self.num_layers = args["n_layers"]
        self.attention_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])
        self.ffn_norm = create_norm(args["norm_type"], dim=args["dim"], eps=args["norm_eps"])

        # With `depth_init` (default `True`, following the TorchTitan implementation of Llama3) the init std
        # shrinks with this block's own depth; otherwise every block uses the total number of blocks.
        depth = self.layer_id + 1 if getattr(args, "depth_init", True) else self.num_layers
        self.weight_init_std = 0.02 / (2 * depth) ** 0.5

    def forward(
        self,
        x: torch.Tensor,
        rope: RotaryPositionEmbeddingPytorch,
        input_pos: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
        context: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Runs the block on `x`.

        Args:
            x: The input tensor.
            rope: The rotary position embedding module, forwarded to the attention layers.
            input_pos: The position of the current sequence. Used in inference (with KV cache) only.
            mask: The self-attention mask tensor.
            context (Optional[torch.Tensor]): The context tensor added via cross-attn.
            context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.

        Returns:
            The output tensor after applying the transformer block.
        """
        # Self-attention + residual.
        attn_out = self.attention(self.attention_norm(x), rope=rope, input_pos=input_pos, mask=mask)
        h = x + attn_out

        # Optional cross-attention + residual.
        if self.has_cross_attention:
            cross_out = self.cross_attention(
                self.cross_attention_norm(h), rope=rope, input_pos=input_pos, mask=context_mask, context=context
            )
            h = h + cross_out

        # Feed-forward + residual.
        return h + self.feed_forward(self.ffn_norm(h))

    def init_weights(self):
        """
        Initializes the weights of the transformer block.

        Norm layers are reset; attention and MLP layers are initialized with `self.weight_init_std`.
        """
        self.attention_norm.reset_parameters()
        self.ffn_norm.reset_parameters()
        self.attention.init_weights(self.weight_init_std)
        self.feed_forward.init_weights(self.weight_init_std)
        if self.has_cross_attention:
            self.cross_attention_norm.reset_parameters()
            self.cross_attention.init_weights(self.weight_init_std)
            # NOTE: zero-initializing the cross-attention output projection (`cross_attention.wo.weight`)
            # is not applied in this backend.
class TransformerBlockTE(te.pytorch.TransformerLayer):
    """
    Thin wrapper around Transformer Engine's `TransformerLayer`.

    Args:
        layer_id (int): The ID of the transformer block.
        args: The model arguments containing hyperparameters.
        tp_group (Optional[ProcessGroup]): Tensor parallel process group, forwarded to TE.
        set_parallel_mode (bool): Forwarded to TE's `set_parallel_mode`.
        attn_input_format (str): Attention input layout forwarded to TE (default "bshd").
    """

    def __init__(
        self,
        layer_id: int,
        args,
        tp_group: Optional[ProcessGroup] = None,
        set_parallel_mode: bool = False,
        attn_input_format: str = "bshd",
    ):
        # An explicit `ffn_hidden_size` wins; otherwise fall back to the Llama3 sizing rule.
        ffn_hidden_size = args["ffn_hidden_size"]
        if ffn_hidden_size is None:
            ffn_hidden_size = compute_llama3_ffn_hidden_dim(
                dim=args["dim"], multiple_of=args["multiple_of"], ffn_dim_multiplier=args["ffn_dim_multiplier"]
            )

        attention_args = dict(
            hidden_size=args["dim"],
            ffn_hidden_size=ffn_hidden_size,
            num_attention_heads=args["n_heads"],
            bias=False,
            layernorm_epsilon=args["norm_eps"],
            hidden_dropout=getattr(args, "hidden_dropout", 0.0),
            attention_dropout=getattr(args, "attention_dropout", 0.0),
            normalization="RMSNorm",
            activation="swiglu",
            attn_input_format=attn_input_format,
            num_gqa_groups=args["n_kv_heads"],
            fuse_wgrad_accumulation=False,
            fuse_qkv_params=False,
            tp_group=tp_group,
            sequence_parallel=args["sequence_parallel"],
            set_parallel_mode=set_parallel_mode,
            layer_number=layer_id + 1,
            self_attn_mask_type="causal" if args["causal_mask"] else "no_mask",
            # If None, te.pytorch.TransformerLayer defaults it to dim // n_heads
            kv_channels=args["head_dim"],
            layer_type="encoder",
        )

        # Blocks that carry cross-attention are built as TE "decoder" layers.
        # NOTE(review): this attribute is assigned before `super().__init__()` runs, as in the original.
        self.has_cross_attention = False
        if args["insert_cross_attn"] and layer_id % args["insert_cross_attn_every_k_layers"] == 0:
            self.has_cross_attention = True
            attention_args["layer_type"] = "decoder"

        super().__init__(**attention_args)

        if args["use_qk_normalization"]:
            # Add QK normalization layers and replace the forward function of the original Multi-Head
            # Attention module with our custom one that applies the QK normalization.
            enable_qk_normalization_in_te_mha(self.self_attention, norm_eps=args["norm_eps"], is_self_attn=True)
            if self.has_cross_attention:
                enable_qk_normalization_in_te_mha(self.inter_attention, norm_eps=args["norm_eps"], is_self_attn=False)

        if self.has_cross_attention:
            enable_different_context_dim_in_te_ca(
                self.inter_attention, context_dim=args["context_dim"], args=attention_args
            )

        self.layer_id = layer_id
        self.num_layers = args["n_layers"]

        # With `depth_init` (default `True`, following the TorchTitan implementation of Llama3) the init std
        # shrinks with this block's own depth; otherwise every block uses the total number of blocks.
        depth = self.layer_id + 1 if getattr(args, "depth_init", True) else self.num_layers
        self.weight_init_std = 0.02 / (2 * depth) ** 0.5

        self.args = args
        self.inference = args["inference"]

    def set_inference_flag(self, flag: bool):
        """
        Set the inference flag for the transformer layer.

        When the flag is falsy, `forward` drops any `inference_params` it is given.
        """
        self.inference = flag

    def forward(
        self,
        x: torch.Tensor,
        rotary_pos_emb: torch.Tensor,
        mask: Optional[torch.Tensor],
        inference_params: Optional[InferenceParams] = None,
        context: Optional[torch.Tensor] = None,
        context_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Custom forward that passes only the relevant arguments to `TransformerLayer.forward`.

        Args:
            x (torch.Tensor): The input tensor.
            rotary_pos_emb (torch.Tensor): Rotary position embeddings; moved to `x.device` before use.
            mask (Optional[torch.Tensor]): The attention mask tensor.
            inference_params (Optional[InferenceParams]): Inference parameters used for caching key-value
                pairs in the TE backend. Ignored (replaced by None) unless the inference flag is set.
            context (Optional[torch.Tensor]): The context tensor added via cross-attn.
            context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.

        Returns:
            torch.Tensor: The output tensor after applying the transformer block.
        """
        if not self.inference:
            inference_params = None
        return super().forward(
            x,
            attention_mask=mask,
            rotary_pos_emb=rotary_pos_emb.to(x.device),
            inference_params=inference_params,
            encoder_output=context,
            enc_dec_attn_mask=context_mask,
        )

    def init_weights(self):
        """
        Initializes the weights of the transformer block.

        The order of the random initializations below is significant for seeded reproducibility.
        """
        self_attn = self.self_attention
        qkv = self_attn.layernorm_qkv

        # Self Attention: Q/K/V use std 0.02, the output projection uses the depth-scaled std.
        for weight in (qkv.query_weight, qkv.key_weight, qkv.value_weight):
            nn.init.trunc_normal_(weight, mean=0.0, std=0.02)
        nn.init.trunc_normal_(self_attn.proj.weight, mean=0.0, std=self.weight_init_std)

        # Cross Attention
        if self.has_cross_attention:
            cross_attn = self.inter_attention
            for weight in (
                cross_attn.layernorm_query.query_weight,
                cross_attn.key_value.key_weight,
                cross_attn.key_value.value_weight,
            ):
                nn.init.trunc_normal_(weight, mean=0.0, std=0.02)
            # Optionally zero-init the final output layer of cross-attention.
            if self.args["zero_init_cross_attn_proj"]:
                nn.init.zeros_(cross_attn.proj.weight)
            else:
                nn.init.trunc_normal_(cross_attn.proj.weight, mean=0.0, std=self.weight_init_std)

        # RMS Normalization
        for norm_weight in (self.layernorm_mlp.layer_norm_weight, qkv.layer_norm_weight):
            nn.init.ones_(norm_weight)

        # With QK normalization, also reset the self-attention Q/K norm weights.
        # NOTE(review): the cross-attention QK-norm weights are not reset here — confirm this is intended.
        if self.args["use_qk_normalization"]:
            for norm_weight in (self_attn.q_norm.weight, self_attn.k_norm.weight):
                nn.init.ones_(norm_weight)

        # MLP
        mlp = self.layernorm_mlp
        for weight in (mlp.fc1_weight, mlp.fc2_weight):
            nn.init.trunc_normal_(weight, mean=0.0, std=self.weight_init_std)

        # The fc1_weight is a fused weight of w1 and w2 in the MLP of the PyTorch backend, where w1 is
        # initialized with a different std (0.02 by TorchTitan), so the w1 half is re-initialized here.
        w1_rows = mlp.fc1_weight.shape[0] // 2
        nn.init.trunc_normal_(mlp.fc1_weight[:w1_rows], mean=0.0, std=0.02)
| class Transformer(nn.Module): | |
| """ | |
| The Transformer network consisting of transformer blocks. | |
| """ | |
def __init__(self, params, model_parallel=None, tokenizer_config=None, init_weights: bool = True):
    """
    Initializes the Transformer module.

    Construction order: token embeddings and RoPE config, backend-specific layers, output projection,
    optional action-embedding MLP, sequence-parallel checks, weight init, PEFT / cross-attention
    freezing, and finally the optional absolute position embeddings.

    Args:
        params: The model parameters containing hyperparameters.
        model_parallel: The model parallel configuration.
        tokenizer_config: The model tokenizer configuration.
        init_weights (bool): Whether to initialize the weights of the transformer following
            TorchTitan's Llama3 initialization scheme.

    Raises:
        ValueError: If `params["backend"]` is neither "pytorch" nor "transformer_engine".
    """
    super().__init__()
    # NOTE(review): hyperparameters are read below from both the raw `params` and the converted
    # `self.params`; whether they are the same object depends on `maybe_convert_to_namespace`.
    self.params = maybe_convert_to_namespace(params)
    self.vocab_size = params["vocab_size"]
    self.n_layers = params["n_layers"]
    self.precision = getattr(torch, params["precision"])
    self.inference = params["inference"]
    self.backend = params["backend"]
    self.tokenizer_config = tokenizer_config
    self.model_parallel = model_parallel
    self.num_video_frames = params["num_video_frames"]
    self.token_emb_dropout = nn.Dropout(getattr(params, "embedding_dropout", 0.0))
    tp_group = self._get_tp_group()

    # Sequence parallelism requires the first dimension to be the sequence dimension, so the input
    # format becomes "sbhd" (sequence, batch, head, dim); otherwise it is "bshd" (batch, sequence, head, dim).
    self.attn_input_format = "sbhd" if params["sequence_parallel"] else "bshd"

    # Token embeddings
    self.tok_embeddings = self._create_token_embeddings(self.model_parallel)
    self.rope_config = self._create_rope_config()

    # Backend-specific layers, final norm and rotary embeddings.
    if self.backend == "pytorch":
        self._initialize_pytorch_backend(model_parallel)
    elif self.backend == "transformer_engine":
        self._initialize_transformer_engine_backend(tp_group)
    else:
        raise ValueError(f"Unknown backend: {self.backend}")

    self.output = self._create_output_projection(model_parallel)

    # Action conditioning
    self.use_action_condition = getattr(params, "use_action_condition", False)
    if self.use_action_condition:
        # e.g., [Δx, Δy, Δz, rx, ry, rz, gripper_open, zero_pad]
        self.action_dim = getattr(params, "action_dim", _ACTION_DIM)
        self.action_embedding_dim = self.params["action_embedding_dim"]  # 1024
        # Defaults to the "mlp" mode.
        self.action_embedding_mode = getattr(params, "action_embedding_mode", "mlp")
        # Defaults to None; otherwise 'causal' or 'group_diagonal'.
        self.group_causal_mask_mode = getattr(params, "group_causal_mask_mode", None)
        self.action_embedding_layers = self._create_action_projection()

    # Sequence parallelism is only honoured with a model-parallel config, the TE backend and TP > 1.
    self.sequence_parallel_enabled = False
    if params["sequence_parallel"]:
        if model_parallel is None:
            # NOTE(review): this flips the flag on the caller's `params` after the layers were built.
            params.sequence_parallel = False
            log.critical("model_parallel is None. Disabling sequence parallelism.")
        else:
            assert self.backend == "transformer_engine", f"Invalid backend: {self.backend} for sequence parallelism"
            assert (
                params["tensor_model_parallel_size"] > 1
            ), f"Invalid tensor_model_parallel_size: {params['tensor_model_parallel_size']}"
            self.sequence_parallel_enabled = True

    if init_weights:
        self.init_weights()

    # PEFT is enabled when either knob is positive (both default to 0).
    self.peft_last_n_layers = getattr(params, "peft_last_n_layers", 0)
    self.peft_every_n_layers = getattr(params, "peft_every_n_layers", 0)
    if self.peft_last_n_layers > 0 or self.peft_every_n_layers > 0:
        self._setup_peft()

    # Freeze network parameters for finetuning w/ cross-attention.
    # NOTE(review): `_setup_cross_attn_ft` starts by freezing every parameter again, so it overrides
    # whatever `_setup_peft` selected above when both are configured.
    self.has_cross_attention = getattr(params, "insert_cross_attn", False)
    if self.has_cross_attention:
        self.ca_every_k_layers = getattr(params, "insert_cross_attn_every_k_layers", 1)
        self.finetune_layers_with_cross_attn = getattr(params, "finetune_layers_with_cross_attn", False)
        self.finetune_layers_without_cross_attn = getattr(params, "finetune_layers_without_cross_attn", False)
        self._setup_cross_attn_ft()

    # Optional absolute position embeddings, laid out to match the attention input format.
    if self.params["apply_abs_pos_emb"]:
        self.pos_emb_config = self._create_abs_pos_emb_config()
        self.pos_emb, self.abs_pos_emb = self._initialize_abs_pos_emb()
        if self.attn_input_format == "sbhd":
            self.abs_pos_emb = self.abs_pos_emb.transpose(0, 1).contiguous()
        self._broadcast_pos_emb(self.abs_pos_emb, tp_group)
def _initialize_pytorch_backend(self, model_parallel):
    """
    Builds the PyTorch-backend layers, final norm, rotary embedding and causal mask.

    Args:
        model_parallel: The model parallel configuration forwarded to each `TransformerBlock`.

    Raises:
        ValueError: If `pytorch_rope_version` is neither "v1" nor "v2".
    """
    blocks = []
    for layer_id in range(self.n_layers):
        blocks.append(TransformerBlock(layer_id, model_parallel, self.params).to(self.precision))
    self.layers = nn.ModuleList(blocks)

    final_norm = create_norm(self.params["norm_type"], dim=self.params["dim"], eps=self.params["norm_eps"])
    self.norm = final_norm.to(self.precision)

    rope_version = getattr(self.params, "pytorch_rope_version", "v2")
    if rope_version not in ("v1", "v2"):
        raise ValueError(f"Unknown pytorch_rope_version: {rope_version}")

    if rope_version == "v1":
        self.rope = RotaryPositionEmbeddingPytorch(**self.rope_config)
    else:
        training_type = None if self.tokenizer_config is None else self.tokenizer_config.training_type
        self.rope = RotaryPositionEmbeddingPytorchV2(
            seq_len=self.params["max_seq_len"], training_type=training_type, **self.rope_config
        )
        # Broadcast the cached cos/sin tables across the tensor parallel group (no-op when TP size <= 1).
        for cache_name in ("cos_cached", "sin_cached"):
            self._broadcast_pos_emb(getattr(self.rope, cache_name), tp_group=self._get_tp_group())

    # Lower-triangular boolean mask over the maximum sequence length.
    # NOTE(review): stored as a plain attribute (not a registered buffer) and moved with `.cuda()`.
    max_len = self.params["max_seq_len"]
    lower_triangular = torch.tril(torch.ones(max_len, max_len, dtype=torch.bool))
    self.causal_mask = lower_triangular.cuda()
def _initialize_transformer_engine_backend(self, tp_group):
    """
    Builds the Transformer Engine layers, final norm and rotary position embeddings.

    Args:
        tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group. With sequence
            parallelism enabled it is replaced by the group reported by `parallel_state`.
    """
    self.layers = self._create_transformer_layers(tp_group)

    hidden_dim = self.params["dim"]
    norm_eps = self.params["norm_eps"]
    if not self.params["sequence_parallel"]:
        self.norm = RMSNormTE(hidden_dim, eps=norm_eps).to(self.precision)
    else:
        tp_group = parallel_state.get_tensor_model_parallel_group()
        sp_norm = AllReduceBWDRMSNormTE(
            hidden_dim,
            process_group=tp_group,
            eps=norm_eps,
            sequence_parallel=True,
        )
        self.norm = sp_norm.to(self.precision)

    self.rope, self.rotary_pos_emb = self._initialize_rope()
    self._broadcast_pos_emb(self.rotary_pos_emb, tp_group)
def _create_rope_config(self) -> Dict:
    """
    Assembles the keyword arguments used to construct the rotary position embedding modules.

    Returns:
        Dict: RoPE settings read from `self.params`. `latent_shape` follows `rope_dim`
        ("3D" -> video latent shape, "2D" -> image latent shape, anything else -> None), and
        `dim` falls back to `dim // n_heads` when `head_dim` is None.
    """
    cfg = self.params
    latent_shape_by_rope_dim = {
        "3D": cfg["video_latent_shape"],
        "2D": cfg["image_latent_shape"],
        "1D": None,
    }
    latent_shape = latent_shape_by_rope_dim.get(cfg["rope_dim"], None)

    rotary_dim = cfg["head_dim"]
    if rotary_dim is None:
        rotary_dim = cfg["dim"] // cfg["n_heads"]

    return dict(
        dim=rotary_dim,
        max_position_embeddings=cfg["max_seq_len"],
        original_max_position_embeddings=cfg["original_seq_len"],
        rope_theta=cfg["rope_theta"],
        apply_yarn=cfg["apply_yarn"],
        scale=cfg["yarn_scale"],
        beta_fast=cfg["yarn_beta_fast"],
        beta_slow=cfg["yarn_beta_slow"],
        rope_dim=cfg["rope_dim"],
        latent_shape=latent_shape,
        original_latent_shape=cfg["original_latent_shape"],
        pad_to_multiple_of=cfg["pad_to_multiple_of"],
    )
def _create_abs_pos_emb_config(self):
    """
    Assembles the keyword arguments used to construct the absolute position embedding module.

    Returns:
        dict: `dim`, `latent_shape` and `pad_to_multiple_of`. `latent_shape` is selected by
        `params["rope_dim"]` ("3D" -> video latent shape, "2D" -> image latent shape, else None).
    """
    cfg = self.params
    latent_shape_by_rope_dim = {
        "3D": cfg["video_latent_shape"],
        "2D": cfg["image_latent_shape"],
        "1D": None,
    }
    return dict(
        dim=cfg["dim"],
        latent_shape=latent_shape_by_rope_dim.get(cfg["rope_dim"], None),
        pad_to_multiple_of=cfg["pad_to_multiple_of"],
    )
def _create_token_embeddings(self, model_parallel=None, vocab_size: int = None):
    """
    Create token embeddings.

    Args:
        model_parallel: The model parallel configuration (used only with tensor parallelism).
        vocab_size (int): Vocabulary size; defaults to `params["vocab_size"]` when None.

    Returns:
        nn.Module: A plain `nn.Embedding` without tensor parallelism, otherwise a
        `TrainingVocabParallelEmbedding`; either one is cast to `self.precision`.
    """
    if vocab_size is None:
        vocab_size = self.params["vocab_size"]

    # Without tensor parallelism a regular embedding table is enough.
    if self.params["tensor_model_parallel_size"] <= 1:
        return nn.Embedding(vocab_size, self.params["dim"]).to(self.precision)

    # For inference in the PyTorch backend, we use PyTorch's allreduce (traceable) in the forward
    # pass to enable torch.compile.
    use_inference_allreduce = self.inference and self.params["backend"] == "pytorch"
    parallel_embedding = TrainingVocabParallelEmbedding(
        vocab_size,
        self.params["dim"],
        init_method=lambda x: x,
        config=model_parallel,
        sequence_parallel=self.params["sequence_parallel"],
        batch_first=not self.params["sequence_parallel"],
        use_inference_allreduce=use_inference_allreduce,
    )
    return parallel_embedding.to(self.precision)
def _create_action_projection(self):
    """
    Create the action projection layer.

    The projection is a two-layer MLP (Linear -> ReLU -> Linear) mapping `action_dim` to
    `action_embedding_dim` through a hidden layer of `action_embedding_dim // _MLP_HIDDEN_DIM_DIVISOR` units.

    Returns:
        nn.Module: Action projection layer.
    """
    assert self.action_embedding_mode == "mlp", f"Invalid action embedding mode: {self.action_embedding_mode}"
    # Original author's note: "This method is not working well. (option 1. default) exp102e"
    embedding_dim = self.action_embedding_dim
    hidden_dim = embedding_dim // _MLP_HIDDEN_DIM_DIVISOR
    return nn.Sequential(
        nn.Linear(self.action_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, embedding_dim),
    )
def _get_tp_group(self):
    """
    Get tensor parallel process group if applicable.

    Returns:
        torch.distributed.ProcessGroup or None: The group reported by `parallel_state` when
        `tensor_model_parallel_size` is greater than 1 (the group is also logged), else None.
    """
    if self.params["tensor_model_parallel_size"] <= 1:
        return None
    tp_group = parallel_state.get_tensor_model_parallel_group()
    log.info(f"Using tensor model parallel group: {tp_group}")
    return tp_group
def _create_transformer_layers(self, tp_group):
    """
    Create the transformer layers for the Transformer Engine backend.

    Args:
        tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group.

    Returns:
        nn.ModuleList: `params["n_layers"]` `TransformerBlockTE` layers, each cast to `self.precision`.
    """
    te_blocks = []
    for layer_id in range(self.params["n_layers"]):
        te_block = TransformerBlockTE(
            layer_id,
            self.params,
            tp_group,
            set_parallel_mode=self.params["set_parallel_mode"],
            attn_input_format=self.attn_input_format,
        )
        te_blocks.append(te_block.to(self.precision))
    return nn.ModuleList(te_blocks)
def _create_output_projection(self, model_parallel=None, vocab_size: int = None):
    """
    Create the output projection layer.

    Args:
        model_parallel: The model parallel configuration (used only by `ColumnParallelLinear`).
        vocab_size (int): Vocabulary size (to override the default vocab size).

    Returns:
        nn.Module: A bias-free projection cast to `self.precision`:
            - TP > 1, PyTorch backend in inference: `nn.Linear` over this rank's `vocab_size // tp_size` slice.
            - TP > 1 otherwise: `ColumnParallelLinear` with `gather_output=False`.
            - No TP: `nn.Linear` (PyTorch backend) or `LinearTE` (Transformer Engine backend).

    Raises:
        ValueError: Without tensor parallelism, if the backend is unknown.
    """
    if vocab_size is None:
        vocab_size = self.params["vocab_size"]

    tp_size = self.params["tensor_model_parallel_size"]
    backend = self.params["backend"]

    if tp_size > 1:
        if backend == "pytorch" and self.inference:
            projection = nn.Linear(self.params["dim"], vocab_size // tp_size, bias=False)
        else:
            projection = ColumnParallelLinear(
                self.params["dim"],
                vocab_size,
                bias=False,
                gather_output=False,
                init_method=lambda x: x,
                config=model_parallel,
            )
    elif backend == "pytorch":
        # No Tensor Parallelism
        projection = nn.Linear(self.params["dim"], vocab_size, bias=False)
    elif backend == "transformer_engine":
        projection = LinearTE(self.params["dim"], vocab_size, bias=False)
    else:
        raise ValueError("Unknown backend: " + self.params["backend"])

    return projection.to(self.precision)
def _initialize_rope(self):
    """
    Initialize the rotary position embedding for the Transformer Engine backend.

    Returns:
        tuple: (RotaryPositionEmbeddingTE, torch.Tensor) The RoPE module and the rotary position
        embeddings it produces for `params["max_seq_len"]`.
    """
    rope = RotaryPositionEmbeddingTE(**self.rope_config)
    # The tokenizer's training type is forwarded when a tokenizer config is available.
    tokenizer_cfg = self.tokenizer_config
    training_type = None if tokenizer_cfg is None else tokenizer_cfg.training_type
    return rope, rope.forward(seq_len=self.params["max_seq_len"], training_type=training_type)
def _initialize_abs_pos_emb(self):
    """
    Initialize the absolute position embedding from `self.pos_emb_config`.

    Returns:
        tuple: The `SinCosPosEmbAxisTE` module and the embeddings returned by its `forward`.
    """
    pos_emb = SinCosPosEmbAxisTE(**self.pos_emb_config)
    # The tokenizer's training type is forwarded when a tokenizer config is available.
    tokenizer_cfg = self.tokenizer_config
    training_type = None if tokenizer_cfg is None else tokenizer_cfg.training_type
    return pos_emb, pos_emb.forward(training_type=training_type)
def _broadcast_pos_emb(self, pos_emb, tp_group):
    """
    Broadcast the position embeddings across the tensor parallel group.

    This is a no-op unless `tensor_model_parallel_size` is greater than 1. The source of the
    broadcast is the lowest rank in `tp_group`.

    Args:
        pos_emb (torch.Tensor): Position embeddings to broadcast.
        tp_group (torch.distributed.ProcessGroup or None): Tensor parallel process group.
    """
    if self.params["tensor_model_parallel_size"] <= 1:
        return
    source_rank = min(get_process_group_ranks(tp_group))
    broadcast(pos_emb, source_rank, group=tp_group)
def _setup_peft(self):
    """
    Set up Parameter Efficient Fine-Tuning (PEFT) by selectively freezing and unfreezing layers.

    Steps:
        1. Freeze all parameters in the model.
        2. Unfreeze the embedding, normalization and output layers.
        3. Unfreeze the first and last (peft_last_n_layers - 1) transformer layers if peft_last_n_layers
           is set, or every n-th layer (flamingo style) if peft_every_n_layers is set.
    """

    def _set_requires_grad(module, flag):
        # Toggle trainability for every parameter of `module`.
        for param in module.parameters():
            param.requires_grad = flag

    # Ensure only one of peft_last_n_layers and peft_every_n_layers is set
    assert (
        self.peft_last_n_layers == 0 or self.peft_every_n_layers == 0
    ), "Only one of peft_last_n_layers and peft_every_n_layers can be set."

    # Freeze everything, then re-enable the embedding, final norm and output projection.
    _set_requires_grad(self, False)
    for module in (self.tok_embeddings, self.norm, self.output):
        _set_requires_grad(module, True)

    if self.peft_last_n_layers > 0:
        # PEFT last n layers: the first layer plus the last (n - 1) layers, hence n >= 2.
        assert self.peft_last_n_layers >= 2, "peft_last_n_layers must be at least 2"
        first_tail_index = len(self.layers) - self.peft_last_n_layers + 1
        for index, layer in enumerate(self.layers):
            if index == 0 or index >= first_tail_index:
                _set_requires_grad(layer, True)
        log.info(
            f"PEFT setup complete. Trainable components: embeddings, un-embedding, normalization layer, "
            f"first transformer layer, last {self.peft_last_n_layers - 1} transformer layers."
        )
    else:
        # PEFT every n layers (flamingo style, e.g. every 4 layers = layer 0,1,2,4,5,6,... frozen,
        # layer 3,7,11,... is trainable)
        trainable_layers = []
        for index, layer in enumerate(self.layers):
            if (index + 1) % self.peft_every_n_layers == 0:
                _set_requires_grad(layer, True)
                trainable_layers.append(index)
        log.info(
            f"PEFT setup complete. Trainable components: embeddings, un-embedding, normalization layer, "
            f"every {self.peft_every_n_layers} transformer layers (layer idx {trainable_layers}; total {len(trainable_layers)} layers)."
        )
def _setup_cross_attn_ft(self):
    """
    Set up Cross Attention Fine-Tuning by selectively freezing and unfreezing layers.

    This method configures the model for fine-tuning by:
        1. Freezing all parameters in the model.
        2. Unfreezing the embedding, normalization and output layers.
        3. Unfreezing all the added cross-attention layers.
        4. If `finetune_layers_with_cross_attn` is True, unfreezing the whole transformer layer for
           layers with cross attention.
        5. If `finetune_layers_without_cross_attn` is True, unfreezing the whole transformer layer for
           layers without cross attention.
        6. If `use_action_condition` is True, unfreezing the action embedding layers.

    Raises:
        AssertionError: If the model has no cross-attention layers.
        ValueError: If the backend is neither "pytorch" nor "transformer_engine".
    """
    assert self.has_cross_attention, "Must insert cross-attention layers for finetuning."

    # First, freeze all parameters
    for param in self.parameters():
        param.requires_grad = False

    # Unfreeze embedding, normalization and output layers
    for module in (self.tok_embeddings, self.norm, self.output):
        for param in module.parameters():
            param.requires_grad = True

    # Single pass over the layers: a layer carries cross-attention when its index is a multiple of
    # `ca_every_k_layers`.
    finetune_layer_num = 0
    backend = self.params["backend"]
    for index, layer in enumerate(self.layers):
        layer_has_cross_attn = index % self.ca_every_k_layers == 0

        if layer_has_cross_attn:
            # The added cross-attention module is always trainable.
            # NOTE(review): in the PyTorch backend `cross_attention_norm` is a separate module and
            # stays frozen unless the whole layer is unfrozen below — confirm this is intended.
            if backend == "pytorch":
                cross_attn_module = layer.cross_attention
            elif backend == "transformer_engine":
                cross_attn_module = layer.inter_attention
            else:
                raise ValueError("Unknown backend: " + backend)
            for param in cross_attn_module.parameters():
                param.requires_grad = True

        # Optionally unfreeze the entire layer, depending on which kind of layer it is.
        if layer_has_cross_attn:
            unfreeze_whole_layer = self.finetune_layers_with_cross_attn
        else:
            unfreeze_whole_layer = self.finetune_layers_without_cross_attn
        if unfreeze_whole_layer:
            for param in layer.parameters():
                param.requires_grad = True
            finetune_layer_num += 1

    # Unfreeze the action embedding layers
    if self.use_action_condition:
        for param in self.action_embedding_layers.parameters():
            param.requires_grad = True

    log.info(
        f"cross attention finetune setup complete. Trainable components: cross-attention layer, "
        f"fully trainable transformer layer number is {finetune_layer_num}."
    )
def enable_context_parallel(self, cp_group: ProcessGroup):
    """
    Enable context parallelism for the transformer model.

    Stores the context-parallel group and its size on the model, then hands the group, its ranks and a
    freshly created CUDA stream to every layer that is a `TransformerBlockTE` (either directly or
    wrapped behind a `.module` attribute). Any other layer only triggers a warning.

    Args:
        cp_group (ProcessGroup): The process group for context parallelism.
    """
    ranks = get_process_group_ranks(cp_group)

    # Kept on the model: the group is used for splitting the data after embedding,
    # the size for computing the loss.
    self.cp_group = cp_group
    self.cp_size = len(ranks)

    for idx, block in enumerate(self.layers):
        if isinstance(block, TransformerBlockTE):
            te_block = block
        elif hasattr(block, "module") and isinstance(block.module, TransformerBlockTE):
            # Wrapped layer: the actual block sits behind `.module`.
            te_block = block.module
        else:
            log.warning(f"Layer {idx} does not support context parallelism")
            continue
        # Each supported layer receives its own CUDA stream.
        te_block.set_context_parallel_group(cp_group, ranks, torch.cuda.Stream())
def set_inference_flag(self, flag: bool):
    """
    Record the inference flag on the model and propagate it to the transformer layers.

    When the flag is truthy the model is also switched to eval mode. Sequence parallelism is
    requested to be turned off afterwards.

    Args:
        flag (bool): The new value of the inference flag.
    """
    log.info(f"Setting inference flag to {flag}")
    self.inference = flag
    if self.inference:
        self.eval()

    backend = self.params["backend"]
    if backend == "pytorch":
        # The PyTorch blocks take the flag on their attention sub-module.
        for block in self.layers:
            block.attention.set_inference_flag(flag)
    elif backend == "transformer_engine":
        for block in self.layers:
            block.set_inference_flag(flag)

    # NOTE(review): the source's indentation was lost; this call is assumed to run for every
    # backend and for both flag values -- confirm against the upstream file.
    self._maybe_change_sequence_parallel_status(enable=False)
def _maybe_change_sequence_parallel_status(self, enable: bool):
    """
    Switch sequence parallelism on or off for every sub-module that exposes the flag.

    Nothing happens when the model is already in the requested state. Otherwise the boolean
    `sequence_parallel` attribute of each module returned by `named_modules()` is overwritten and
    `self.sequence_parallel_enabled` is updated to match.

    Args:
        enable (bool): Whether sequence parallelism should be active.
    """
    target_state = bool(enable)
    if target_state == bool(self.sequence_parallel_enabled):
        # Already in the requested state.
        return

    for module_name, submodule in self.named_modules():
        if not hasattr(submodule, "sequence_parallel"):
            continue
        assert isinstance(
            submodule.sequence_parallel, bool
        ), f"Invalid type of {module_name}: {type(submodule.sequence_parallel)}"
        setattr(submodule, "sequence_parallel", target_state)
    self.sequence_parallel_enabled = target_state
def forward(
    self,
    tokens: Optional[torch.Tensor] = None,
    input_pos: Optional[torch.Tensor] = None,
    inference_params: Optional[InferenceParams] = None,
    token_embeddings: Optional[torch.Tensor] = None,
    context: Optional[torch.Tensor] = None,
    context_mask: Optional[torch.Tensor] = None,
    action: Optional[torch.Tensor] = None,
    total_seq_len: Optional[int] = None,
    return_hidden_states: bool = False,
    mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Performs the forward pass of the Transformer module.

    Args:
        tokens (torch.Tensor, optional): The input tensor of token IDs.
        input_pos (Optional[torch.Tensor]): The position of the current sequence. Used in inference with KV cache. PyTorch backend only.
        inference_params (InferenceParams, optional): Parameters for inference.
        token_embeddings (torch.Tensor, optional): Precomputed token embeddings. If provided, tokens should be None.
        context (Optional[torch.Tensor]): The context tensor added via cross-attn. Replaced by the action
            embedding when action conditioning is enabled and `action` is given.
        context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor. Replaced together with
            `context` when action conditioning is active.
        action (Optional[torch.Tensor]): The robot action tensor for conditioning.
        total_seq_len (Optional[int]): The total sequence length (before applying context parallelism).
        return_hidden_states (bool): Whether to return hidden states.
        mask (Optional[torch.Tensor]): Self-attention mask. When None, one is obtained from
            `_create_attention_mask(input_pos=input_pos)`.

    Returns:
        The normalized hidden states if `return_hidden_states` is True, otherwise the output of the
        output projection after `process_output`.

    Raises:
        AssertionError: If both `tokens` and `token_embeddings` are provided.
        ValueError: If neither `tokens` nor `token_embeddings` is provided.
    """
    # Turn on/off sequence parallelism based on the training status
    self._maybe_change_sequence_parallel_status(enable=self.training and self.params["sequence_parallel"])

    # Token embeddings
    assert (
        tokens is None or token_embeddings is None
    ), "Either tokens or token_embeddings should be provided, not both."
    if tokens is None and token_embeddings is None:
        # Fail with a clear message instead of an AttributeError on `None.shape` below.
        raise ValueError("One of tokens or token_embeddings must be provided.")

    if token_embeddings is None:
        seq_len = tokens.shape[1]
        h = self.token_emb_dropout(self.tok_embeddings(tokens))
    else:
        seq_len = token_embeddings.shape[1]
        h = self.token_emb_dropout(token_embeddings)

    if mask is None:
        # Create attention mask
        mask = self._create_attention_mask(input_pos=input_pos)

    # Action embedding
    if self.use_action_condition and action is not None:
        assert self.action_embedding_mode == "mlp", f"Invalid action embedding mode: {self.action_embedding_mode}"
        # change action type to bfloat16, of shape [batch_size, action_dim]
        action = action.to(torch.bfloat16)
        # action_emb shape: [batch_size, action_dim, action_embedding_dim]
        action_emb = self.action_embedding_layers(action).unsqueeze(1).repeat(1, self.action_dim, 1)

        # Use action_emb as context (this overrides any `context` passed by the caller)
        if self.params["concat_action_to_context"]:
            # Zero placeholder of `_T5_NUM_TOKENS` tokens followed by a single action-embedding token.
            context = torch.zeros(
                (action_emb.shape[0], _T5_NUM_TOKENS, self.action_embedding_dim), device=h.device, dtype=h.dtype
            )
            context = torch.cat([context, action_emb[:, 0:1, :]], dim=1)
        else:
            context = action_emb  # [batch_size, action_dim, action_embedding_dim]

        # Create context mask (this overrides any `context_mask` passed by the caller)
        if self.group_causal_mask_mode is not None:
            num_temporal_groups = self.num_video_frames - 1  # number of latent frames
            num_query_per_group = seq_len // num_temporal_groups  # number of latent tokens per frame
            num_key_per_group = self.action_dim // num_temporal_groups
            context_mask = create_group_causal_attn_mask(
                num_temporal_groups=num_temporal_groups,
                num_query_per_group=num_query_per_group,
                num_key_per_group=num_key_per_group,
                mode=self.group_causal_mask_mode,
            )  # [L (query), S (key)]
            context_mask = context_mask.unsqueeze(0)  # [1, L (query), S (key)]
            context_mask = context_mask.repeat(context.shape[0], 1, 1)  # [batch_size, L (query), S (key)]
            context_mask = context_mask.to(context.device)
        else:
            context_mask = torch.ones(
                (context.shape[0], context.shape[1]), device=context.device, dtype=torch.bool
            )  # [batch_size, action_dim]

    # Prepare layer arguments
    layer_kwargs = self._prepare_layer_kwargs(
        total_seq_len=total_seq_len,
        input_pos=input_pos,
        mask=mask,
        inference_params=inference_params,
        context=context,
        context_mask=context_mask,
    )

    # Apply transformer layers
    for layer in self.layers:
        # When enabled, the absolute position embedding is added before every layer.
        if self.params["apply_abs_pos_emb"]:
            h = self.apply_abs_pos_emb(h, input_pos=input_pos, total_seq_len=total_seq_len)
        h = layer(h, **layer_kwargs)

    # Apply final layer normalization
    h = self.norm(h)
    if return_hidden_states:
        return h

    # Output linear projection
    output = self.output(h)
    output = self.process_output(output)
    return output
def process_output(self, output: torch.Tensor) -> torch.Tensor:
    """
    Gather the output across tensor-parallel ranks and bring it to batch-first layout.

    Two steps are applied in order:
        1. When `params["tensor_model_parallel_size"] > 1`, the tensor is gathered over the
           tensor-model-parallel group. The PyTorch backend in inference mode uses the functional
           collective `all_gather_tensor` along the last dimension; every other case uses
           `gather_from_tensor_model_parallel_region`.
        2. When the attention input format is "sbhd", dimensions 0 and 1 are swapped and the result
           is made contiguous.

    Args:
        output (torch.Tensor): The tensor before modification.

    Returns:
        torch.Tensor: The tensor after modification.
    """
    if self.params["tensor_model_parallel_size"] > 1:
        use_functional_collective = self.params["backend"] == "pytorch" and self.inference
        if use_functional_collective:
            # Use PyTorch all gather
            tp_group = parallel_state.get_tensor_model_parallel_group()
            output = funcol.all_gather_tensor(output, gather_dim=-1, group=tp_group)
        else:
            # [*, *, hidden_dim // tp_size] --> [*, *, hidden_dim]
            output = gather_from_tensor_model_parallel_region(output)

    if self.attn_input_format != "sbhd":
        return output
    # [seq_len, batch_size, hidden_dim] --> [batch_size, seq_len, hidden_dim]
    return output.transpose(0, 1).contiguous()
def _create_attention_mask(self, input_pos: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """
    Build the self-attention mask handed to the transformer layers.

    Only the PyTorch backend in inference mode uses an explicit mask; it is looked up by indexing
    `self.causal_mask` with `input_pos`. Every other configuration returns None.

    Args:
        input_pos (Optional[torch.Tensor]): The position of input sequence (used for inference only).

    Returns:
        Optional[torch.Tensor]: The attention mask, or None for causal mask.
    """
    needs_explicit_mask = self.backend == "pytorch" and self.inference
    if not needs_explicit_mask:
        return None  # None means causal mask

    assert input_pos is not None, "input_pos must be provided for inference"
    return self.causal_mask[input_pos]
def _prepare_layer_kwargs(
    self,
    total_seq_len: Optional[int],
    input_pos: Optional[torch.Tensor],
    mask: Optional[torch.Tensor],
    inference_params: Optional[InferenceParams],
    context: Optional[torch.Tensor],
    context_mask: Optional[torch.Tensor],
) -> Dict[str, Any]:
    """
    Prepares the keyword arguments for transformer layers.

    Args:
        total_seq_len (Optional[int]): The total sequence length (before applying context parallelism).
        input_pos (Optional[torch.Tensor]): The position of the current sequence.
        mask (Optional[torch.Tensor]): The attention mask.
        inference_params (Optional[InferenceParams]): Parameters for inference.
        context (Optional[torch.Tensor]): The context tensor added via cross-attn.
        context_mask (Optional[torch.Tensor]): The context cross-attn mask tensor.

    Returns:
        Dict[str, Any]: A dictionary of keyword arguments for the transformer layers. It always holds
        "mask", "context" and "context_mask"; the PyTorch backend adds "input_pos" and "rope", the
        Transformer Engine backend adds "rotary_pos_emb" and "inference_params".
    """
    if context is not None:
        context = context.to(self.precision)
        if self.attn_input_format == "sbhd":
            # Swap the first two dimensions of the context for the "sbhd" layout.
            context = context.transpose(0, 1).contiguous()

    if self.backend == "pytorch":
        # Give 2-D masks two leading singleton dimensions.
        if isinstance(mask, torch.Tensor) and mask.ndim == 2:
            mask = mask[None, None, :, :]
        if isinstance(context_mask, torch.Tensor) and context_mask.ndim == 2:
            context_mask = context_mask[None, None, :, :]

    layer_kwargs = {
        "mask": mask,
        "context": context,
        "context_mask": context_mask,
    }

    if self.backend == "pytorch":
        layer_kwargs["input_pos"] = input_pos
        layer_kwargs["rope"] = self.rope
    elif self.backend == "transformer_engine":
        rotary_pos_emb = self.rotary_pos_emb

        # Context parallelism is only considered while training, mirroring `apply_abs_pos_emb`.
        # The previous `try/except/else` reset `cp_size` to 1 whenever the query succeeded, which
        # made the slicing branch below unreachable.
        if self.training:
            try:
                cp_size = parallel_state.get_context_parallel_world_size()
            except (AssertionError, RuntimeError):
                # Fallback if context parallel group isn't initialized
                cp_size = 1
                log.warning("Context parallel group not initialized, falling back to size 1")
        else:
            cp_size = 1

        if cp_size > 1:
            assert input_pos is None, "input_pos must be None for context parallelism"
            # Truncate to the full (pre-split) sequence length, then keep this rank's share.
            rotary_pos_emb = rotary_pos_emb[:total_seq_len]
            rotary_pos_emb = get_pos_emb_on_this_cp_rank(rotary_pos_emb, 0)

        layer_kwargs["rotary_pos_emb"] = rotary_pos_emb
        layer_kwargs["inference_params"] = inference_params

    return layer_kwargs
def apply_abs_pos_emb(
    self, x: torch.Tensor, input_pos: Optional[torch.Tensor] = None, total_seq_len: Optional[int] = None
) -> torch.Tensor:
    """
    Applies the absolute position embeddings to the input tensor.

    Args:
        x (torch.Tensor): The tensor the position embedding is added to.
        input_pos (Optional[torch.Tensor]): Positions used to index the embedding along its sequence
            dimension. Must be None when context parallelism or sequence parallelism is active.
        total_seq_len (Optional[int]): If given, the embedding is first truncated to this length along
            its sequence dimension (dim 0 for "sbhd", dim 1 otherwise).

    Returns:
        torch.Tensor: `x` plus the selected absolute position embedding.
    """
    abs_pos_emb = self.abs_pos_emb
    seq_first = self.attn_input_format == "sbhd"

    if total_seq_len is not None:
        # Truncate the absolute position embeddings to the total sequence length
        abs_pos_emb = abs_pos_emb[:total_seq_len, :, :] if seq_first else abs_pos_emb[:, :total_seq_len, :]

    # Context parallelism is only queried while training.
    cp_size = parallel_state.get_context_parallel_world_size() if self.training else 1
    if cp_size > 1:
        assert input_pos is None, "input_pos must be None for context parallelism"
        seq_dim = 0 if seq_first else 1
        abs_pos_emb = get_pos_emb_on_this_cp_rank(abs_pos_emb, seq_dim=seq_dim)

    if seq_first:
        if self.sequence_parallel_enabled:
            # Training
            assert input_pos is None, "input_pos must be None when training with sequence parallelism"
            abs_pos_emb = get_pos_emb_on_this_sptp_rank(abs_pos_emb, seq_dim=0)
        elif input_pos is not None:
            # Inference or Evaluation
            abs_pos_emb = abs_pos_emb[input_pos, :, :]
    elif input_pos is not None:
        abs_pos_emb = abs_pos_emb[:, input_pos, :]

    return x + abs_pos_emb
def expand_vocab(
    self,
    new_vocab_size: int,
    init_method: str = "gaussian",
    multiple_of: int = 64,
    expand_output_layer: bool = True,
) -> None:
    """
    Expands the vocabulary of the model to the new size.

    Args:
        new_vocab_size (int): The new vocabulary size.
        init_method (str): The initialization method for new embeddings.
            Can be "zero" or "gaussian". Default is "gaussian".
        multiple_of (int): The new vocabulary size must be a multiple of this value; otherwise it is rounded
            up to the next multiple. Defaults to 64 to fully leverage the power of NVIDIA TensorCore
            (source 1: https://x.com/karpathy/status/1621578354024677377,
            source 2: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc)
        expand_output_layer (bool): Whether to also expand the output layer. Defaults to True. The output
            layer is re-created either way; when False it is created with `vocab_size=None`.

    Returns:
        None

    Raises:
        ValueError: If `new_vocab_size` is not larger than the current vocabulary size, or if
            `init_method` is not "zero" or "gaussian".
    """
    tp_size = self.params["tensor_model_parallel_size"]
    if new_vocab_size <= self.vocab_size:
        raise ValueError(
            f"New vocabulary size ({new_vocab_size}) must be " f"larger than current size ({self.vocab_size})"
        )
    # Validate all arguments before touching the model so that a bad `init_method` cannot leave the
    # token embedding replaced while the output layer and `vocab_size` are still the old ones.
    if init_method not in ["zero", "gaussian"]:
        raise ValueError(f"Unknown initialization method: {init_method}")

    if new_vocab_size % multiple_of != 0:
        log.critical(f"New vocabulary size must be a multiple of {multiple_of}. Obtained {new_vocab_size}.")
        new_vocab_size = (new_vocab_size // multiple_of + 1) * multiple_of
        log.critical(f"Rounded vocabulary size to {new_vocab_size}.")

    # Number of leading rows in the local weight that belong to the old vocabulary.
    old_rows = self.vocab_size // tp_size

    # Resize token embeddings
    old_embeddings = self.tok_embeddings
    old_embeddings_requires_grad = old_embeddings.weight.requires_grad
    tensor_kwargs = {"device": old_embeddings.weight.device, "dtype": old_embeddings.weight.dtype}
    self.tok_embeddings = self._create_token_embeddings(
        model_parallel=self.model_parallel, vocab_size=new_vocab_size
    ).to(**tensor_kwargs)

    # Initialize new embeddings
    # The default initialization of nn.Embedding is Gaussian, so we don't need to do anything
    # if init_method == "gaussian". Only if init_method == "zero", we need to zero out the new embeddings.
    if init_method == "zero":
        self.tok_embeddings.weight.data[old_rows:].zero_()

    # Copy old embeddings
    log.info(
        f"old_embeddings: {old_embeddings.weight.data.shape}, new_embeddings: {self.tok_embeddings.weight.data.shape}, vocab_size: {self.vocab_size}"
    )
    self.tok_embeddings.weight.data[:old_rows] = old_embeddings.weight.data
    self.tok_embeddings.weight.requires_grad = old_embeddings_requires_grad

    # Resize output layer
    # NOTE(review): unlike the token embedding, the new output layer is not moved to the old layer's
    # device/dtype here -- confirm that `_create_output_projection` takes care of placement.
    old_output = self.output
    old_output_requires_grad = old_output.weight.requires_grad
    self.output = self._create_output_projection(
        self.model_parallel, vocab_size=new_vocab_size if expand_output_layer else None
    )

    # Initialize new output weights
    if init_method == "zero":
        self.output.weight.data[old_rows:].zero_()
    elif init_method == "gaussian":
        # Follows the parameter initialization in TorchTitan:
        # https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama/model.py
        final_out_std = self.params["dim"] ** -0.5
        cutoff_factor = 3
        nn.init.trunc_normal_(
            self.output.weight,
            mean=0.0,
            std=final_out_std,
            a=-cutoff_factor * final_out_std,
            b=cutoff_factor * final_out_std,
        )

    # Copy old output weights
    self.output.weight.data[:old_rows] = old_output.weight.data
    self.output.weight.requires_grad = old_output_requires_grad

    # Update vocab size
    self.vocab_size = new_vocab_size
    log.critical(f"Expanded vocabulary size to {new_vocab_size}")
def init_weights(self):
    """
    Apply the custom weight initialization of the root ``Transformer`` module.

    [Note: On ``init_weights`` vs. ``reset_parameters`` (copied from github.com/pytorch/torchtitan)]
    ``reset_parameters`` is meant to initialize only the parameters/buffers a module owns directly, not
    those of its children. Custom initialization that differs from ``reset_parameters`` lives in
    ``init_weights``, which is only called from the constructor of this root module to avoid
    reinitializing tensors.

    Raises:
        ValueError: If ``self.backend`` is neither "pytorch" nor "transformer_engine".
    """
    nn.init.normal_(self.tok_embeddings.weight)

    for block in self.layers:
        block.init_weights()

    # The final norm is reset through its own method on the PyTorch backend and set to ones on the
    # Transformer Engine backend.
    backend = self.backend
    if backend == "pytorch":
        self.norm.reset_parameters()
    elif backend == "transformer_engine":
        nn.init.ones_(self.norm.weight)
    else:
        raise ValueError(f"Unknown backend: {self.backend}")

    # Output projection: truncated normal with std dim^-0.5, cut off at three standard deviations.
    out_std = self.params["dim"] ** -0.5
    cutoff = 3
    nn.init.trunc_normal_(self.output.weight, mean=0.0, std=out_std, a=-cutoff * out_std, b=cutoff * out_std)

    if not self.use_action_condition:
        return
    # Only the linear members of the action embedding stack are re-initialized.
    for member in self.action_embedding_layers:
        if not isinstance(member, nn.Linear):
            continue
        nn.init.xavier_uniform_(member.weight)
        nn.init.zeros_(member.bias)
def state_dict(self, *args, **kwargs):
    """
    Return the module's state dict after passing it through `process_state_dict`
    (e.g., to remove "_extra_state" keys imposed by TransformerEngine for FP8).

    All positional and keyword arguments are forwarded unchanged to the parent implementation.
    """
    return process_state_dict(super().state_dict(*args, **kwargs))
def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True, assign: bool = False):
    """
    Load a state dict while tolerating missing keys that contain any entry of `substrings_to_ignore`
    (e.g., "_extra_state" keys imposed by TransformerEngine for FP8).

    The state dict is first passed through `process_state_dict` and always loaded non-strictly by the
    parent class; strictness is then enforced here on the filtered key lists.

    Args:
        state_dict (Dict[str, Any]): The state dict to load.
        strict (bool): If True, raise when non-ignorable missing keys or any unexpected keys remain.
        assign (bool): Forwarded to the parent `load_state_dict`.

    Returns:
        _IncompatibleKeys: The missing and unexpected keys (missing keys are filtered when `strict`).

    Raises:
        ValueError: If `strict` is True and there are relevant missing keys or unexpected keys.
    """
    cleaned_state = process_state_dict(state_dict)
    missing_keys, unexpected_keys = super().load_state_dict(cleaned_state, strict=False, assign=assign)

    if strict:
        relevant_missing = [
            key for key in missing_keys if all(ignored not in key for ignored in substrings_to_ignore)
        ]
        if relevant_missing or len(unexpected_keys) > 0:
            raise ValueError(f"Missing keys: {relevant_missing}\n\nUnexpected keys: {unexpected_keys}")
        missing_keys = relevant_missing

    return _IncompatibleKeys(missing_keys, unexpected_keys)
def on_after_backward(self, *args, **kwargs):
    """
    Hook that calls `allreduce_layernorm_grads` on this model, passing the tensor-parallel size and
    the sequence-parallel setting from `self.params`. Positional and keyword arguments are ignored.

    Reference implementation: https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/distributed/finalize_model_grads.py
    """
    tp_size = self.params["tensor_model_parallel_size"]
    use_sequence_parallel = self.params["sequence_parallel"]
    allreduce_layernorm_grads([self], tensor_model_parallel_size=tp_size, sequence_parallel=use_sequence_parallel)
def on_before_zero_grad(
    self, optimizer: torch.optim.Optimizer, scheduler: torch.optim.lr_scheduler.LRScheduler, iteration: int
) -> None:
    """Hook before zero_grad() is called.

    When `params["sync_1d_parameters"]` is set, `sync_1d_parameters` is invoked on this model once per
    parallel dimension (tensor, then context) whose configured size is greater than one. The hook
    arguments themselves are not used.

    Args:
        optimizer (torch.optim.Optimizer): The model optimizer.
        scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler.
        iteration (int): Current iteration number.
    """
    if not self.params["sync_1d_parameters"]:
        return

    if self.params["tensor_model_parallel_size"] > 1:
        tp_group = parallel_state.get_tensor_model_parallel_group()
        sync_1d_parameters(self, process_group=tp_group)

    if self.params["context_parallel_size"] > 1:
        cp_group = parallel_state.get_context_parallel_group()
        sync_1d_parameters(self, process_group=cp_group)