| |
| |
| |
| |
| @@ -86,7 +86,7 @@ class TorchCommonLoadStrategy(LoadCommonStrategy): |
| msc = MultiStorageClientFeature.import_package() |
| return msc.torch.load(load_path, map_location='cpu') |
| else: |
| - return torch.load(load_path, map_location='cpu') |
| + return torch.load(load_path, map_location='cpu', weights_only=False) |
| except FileNotFoundError as e: |
| err_msg = f'Common file {load_path} does not exist' |
| if MultiStorageClientFeature.is_enabled(): |
| |
| |
| |
| |
| @@ -597,10 +597,12 @@ class MCoreLoadPlanner(DefaultLoadPlanner): |
| def _validate_global_shapes(self, metadata, sharded_tensors): |
| for sh_ten in sharded_tensors: |
| if sh_ten.key not in metadata.state_dict_metadata: |
| - raise KeyError( |
| - f"{sh_ten.key} from model not in state dict:" |
| - f" {sorted(metadata.state_dict_metadata.keys())}" |
| - ) |
| + # raise KeyError( |
| + # f"{sh_ten.key} from model not in state dict:" |
| + # f" {sorted(metadata.state_dict_metadata.keys())}" |
| + # ) |
| + print(f"{sh_ten.key} from model not in state dict, will skip") |
| + continue |
| loaded_shape = metadata.state_dict_metadata[sh_ten.key].size |
| expected_shape = self._expected_shape(sh_ten) |
| if loaded_shape != expected_shape: |
| @@ -630,7 +632,7 @@ class MCoreLoadPlanner(DefaultLoadPlanner): |
| tensor_metadata = self.metadata.state_dict_metadata |
| metadata_with_sizes = [ |
| (tensor_metadata[key], tensor_metadata[key].size, sharded_tensor) |
| - for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() |
| + for key, sharded_tensor in self.allow_shape_mismatch_sharded_tensors.items() if key in tensor_metadata |
| ] |
| try: |
| # Temporarily set sizes to expected shapes |
| @@ -959,6 +961,7 @@ class TorchDistLoadShardedStrategy(LoadShardedStrategy): |
| planner=MCoreLoadPlanner( |
| shapes_validation_sharded_tensors=flexible_shape_sharded_tensors, |
| allow_shape_mismatch_sharded_tensors=allow_shape_mismatch_sharded_tensors, |
| + allow_partial_load=True, |
| ), |
| ) |
| |
| |
| |
| |
| |
| @@ -408,6 +408,7 @@ class TELinear(te.pytorch.Linear): |
| ) |
| |
| for param in self.parameters(): |
| + setattr(param, "parallel_mode", parallel_mode) |
| if is_expert: |
| # Reduce the gradient on the expert_data_parallel group for expert linear layers |
| setattr(param, "allreduce", not self.expert_parallel) |
| @@ -1161,6 +1162,61 @@ class TEDotProductAttention(te.pytorch.DotProductAttention): |
| |
| |
| if HAVE_TE and is_te_min_version("1.9.0.dev0"): |
| + def ceil_div(x: int, y: int) -> int: |
| + return (x + y - 1) // y |
| + |
| + class _FakeInt4QuantizationSTE(torch.autograd.Function): |
| + @staticmethod |
| + def forward(ctx, x, group_size): |
| + m, n = x.shape |
| + block_size_m, block_size_n = 1, group_size |
| + |
| + |
| + m_padded = ceil_div(m, block_size_m) * block_size_m |
| + n_padded = ceil_div(n, block_size_n) * block_size_n |
| + |
| + x_padded = torch.zeros( |
| + (m_padded, n_padded), |
| + dtype=x.dtype, device=x.device |
| + ) |
| + x_padded[:m, :n] = x |
| + |
| + x_view = x_padded.view( |
| + m_padded // block_size_m, |
| + block_size_m, |
| + n_padded // block_size_n, |
| + block_size_n |
| + ) |
| + |
| + x_max = x_view.abs().float().amax(dim=(1, 3), keepdim=True) |
| + q_max = 7 |
| + x_scale = x_max / q_max |
| + |
| + x_scale = x_scale.clamp(min=1e-5) |
| + |
| + x_div = x_view / x_scale |
| + x_round = torch.round(x_div) |
| + |
| + x_q_clamped = x_round.clamp(-q_max, q_max) |
| + |
| + x_dequant_view = x_q_clamped * x_scale |
| + |
| + x_dequant_full = x_dequant_view.view_as(x_padded) |
| + x_out = x_dequant_full[:m, :n].contiguous().to(x.dtype) |
| + |
| + return x_out |
| + |
| + @staticmethod |
| + def backward(ctx, grad_output): |
| + return grad_output, None |
| + |
| + def fake_int4_quantization_ste(x, group_size): |
| + x_out = _FakeInt4QuantizationSTE.apply(x, group_size) |
| + |
| + if hasattr(x, 'main_grad'): |
| + x_out.main_grad = x.main_grad |
| + |
| + return x_out |
| |
| class TEGroupedLinear(te.pytorch.GroupedLinear): |
| """ |
| @@ -1351,6 +1407,7 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): |
| _is_first_microbatch = ( |
| None if self.disable_parameter_transpose_cache else self.is_first_microbatch |
| ) |
| + |
| out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) |
| self.is_first_microbatch = False |
| |
| @@ -1361,6 +1418,20 @@ if HAVE_TE and is_te_min_version("1.9.0.dev0"): |
| return out |
| return out, None |
| |
| + def _get_weight_tensors(self): |
| + """Get the weight tensors of the module.""" |
| + weight_tensors = super()._get_weight_tensors() |
| + |
| + if os.getenv("OPEN_TRAINING_INT4_FAKE_QAT_FLAG", "0") == "1": |
| + group_size = int(os.getenv("OPEN_TRAINING_INT4_GROUP_SIZE", "128")) |
| + |
| + weight_tensors = [ |
| + fake_int4_quantization_ste(w, group_size) |
| + for w in weight_tensors |
| + ] |
| + |
| + return weight_tensors |
| + |
| def _encode_extra_state(self, state): |
| # TE 2.0 changed the format of extra_state to be a byte tensor |
| if is_te_min_version("2.0.0"): |
| |
| |
| |
| |
| @@ -385,6 +385,7 @@ def rotary_fwd_kv_kernel( |
| SIN, |
| emb_dim: tl.constexpr, |
| k_dim: tl.constexpr, |
| + k_dim_ceil: tl.constexpr, |
| v_dim: tl.constexpr, |
| head_num: tl.constexpr, |
| batch_size, |
| @@ -434,21 +435,27 @@ def rotary_fwd_kv_kernel( |
| cos_right = tl.load(COS + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) |
| sin_right = tl.load(SIN + token_idx * emb_dim + emb_dim // 2 + tl.arange(0, emb_dim // 2)) |
| |
| - KV_ptr = KV + pid_m * stride_kv_seq + pid_head * BLOCK_H * stride_kv_nheads |
| - kv_off = tl.arange(0, BLOCK_H)[:, None] * stride_kv_nheads |
| - mask = kv_off < head_num * stride_kv_nheads |
| - k_in_off = kv_off + tl.arange(0, k_dim)[None, :] |
| - v_in_off = kv_off + k_dim + tl.arange(0, v_dim)[None, :] |
| - k = tl.load(KV_ptr + k_in_off, mask=mask) |
| - v = tl.load(KV_ptr + v_in_off, mask=mask) |
| + KV_ptr = KV + pid_m * stride_kv_seq # + pid_head * BLOCK_H * stride_kv_nheads |
| + ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H |
| + kj_range = tl.arange(0, k_dim_ceil)[None, :] |
| + mask_k = (ki_range < head_num) & (kj_range < k_dim) |
| + mask_v = ki_range < head_num |
| + k_off = ki_range * stride_kv_nheads + kj_range |
| + if v_dim > 0: |
| + v_off = ki_range * stride_kv_nheads + k_dim + tl.arange(0, v_dim)[None, :] |
| + v = tl.load(KV_ptr + v_off, mask=mask_v) |
| + else: |
| + v = tl.zeros((BLOCK_H, 1), dtype=KV.dtype.element_ty) |
| + k = tl.load(KV_ptr + k_off, mask=mask_k) |
| |
| - K_ptr = O_KEY + pid_m * stride_k_seq + pid_head * BLOCK_H * stride_k_nheads |
| - V_ptr = O_VALUE + pid_m * stride_v_seq + pid_head * BLOCK_H * stride_v_nheads |
| + K_ptr = O_KEY + pid_m * stride_k_seq # + pid_head * BLOCK_H * stride_k_nheads |
| + V_ptr = O_VALUE + pid_m * stride_v_seq # + pid_head * BLOCK_H * stride_v_nheads |
| |
| - k_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads + tl.arange(0, k_dim)[None, :] |
| - v_out_off = tl.arange(0, BLOCK_H)[:, None] * stride_v_nheads + tl.arange(0, v_dim)[None, :] |
| - tl.store(K_ptr + k_out_off, k, mask=mask) |
| - tl.store(V_ptr + v_out_off, v, mask=mask) |
| + k_out_off = ki_range * stride_k_nheads + kj_range |
| + tl.store(K_ptr + k_out_off, k, mask=mask_k) |
| + if v_dim > 0: |
| + v_out_off = ki_range * stride_v_nheads + tl.arange(0, v_dim)[None, :] |
| + tl.store(V_ptr + v_out_off, v, mask=mask_v) |
| |
| EMB = K_POS_EMB + pid_m * stride_emb_seq |
| # x1 = t[..., 0::2], x2 = t[..., 1::2] |
| @@ -460,14 +467,16 @@ def rotary_fwd_kv_kernel( |
| x_left = x_left.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) |
| x_right = x_right.expand_dims(0).broadcast_to(BLOCK_H, emb_dim // 2) |
| |
| + x_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H |
| + mask_x = x_range < head_num |
| x_left_off = ( |
| - tl.arange(0, BLOCK_H)[:, None] * stride_k_nheads |
| + x_range * stride_k_nheads |
| + k_dim |
| + tl.arange(0, emb_dim // 2)[None, :] |
| ) |
| x_right_off = x_left_off + emb_dim // 2 |
| - tl.store(K_ptr + x_left_off, x_left, mask=mask) |
| - tl.store(K_ptr + x_right_off, x_right, mask=mask) |
| + tl.store(K_ptr + x_left_off, x_left, mask=mask_x) |
| + tl.store(K_ptr + x_right_off, x_right, mask=mask_x) |
| |
| |
| @triton.autotune( |
| @@ -493,6 +502,7 @@ def rotary_bwd_kv_kernel( |
| SIN, |
| emb_dim: tl.constexpr, |
| k_dim: tl.constexpr, |
| + k_dim_ceil: tl.constexpr, |
| v_dim: tl.constexpr, |
| head_num: tl.constexpr, |
| batch_size, |
| @@ -533,27 +543,32 @@ def rotary_bwd_kv_kernel( |
| else: |
| token_idx = _get_thd_token_idx(cu_seqlens_kv, pid_m, seq_num, cp_rank, cp_size) |
| |
| - dKV_ptr = dKV + pid_m * stride_dkv_seq + pid_head * BLOCK_H * stride_dkv_nheads |
| - dkv_off = tl.arange(0, BLOCK_H)[:, None] * stride_dkv_nheads |
| - mask = dkv_off < head_num * stride_dkv_nheads |
| - dk_out_off = dkv_off + tl.arange(0, k_dim)[None, :] |
| - dv_out_off = dkv_off + k_dim + tl.arange(0, v_dim)[None, :] |
| - |
| - dK_ptr = dK + pid_m * stride_dk_seq + pid_head * BLOCK_H * stride_dk_nheads |
| - dV_ptr = dV + pid_m * stride_dv_seq + pid_head * BLOCK_H * stride_dv_nheads |
| - dk_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + tl.arange(0, k_dim)[None, :] |
| - dv_in_off = tl.arange(0, BLOCK_H)[:, None] * stride_dv_nheads + tl.arange(0, v_dim)[None, :] |
| - dk = tl.load(dK_ptr + dk_in_off, mask=mask) |
| - dv = tl.load(dV_ptr + dv_in_off, mask=mask) |
| - tl.store(dKV_ptr + dk_out_off, dk, mask=mask) |
| - tl.store(dKV_ptr + dv_out_off, dv, mask=mask) |
| + dKV_ptr = dKV + pid_m * stride_dkv_seq # + pid_head * BLOCK_H * stride_dkv_nheads |
| + ki_range = tl.arange(0, BLOCK_H)[:, None] + pid_head * BLOCK_H |
| + kj_range = tl.arange(0, k_dim_ceil)[None, :] |
| + mask_k = (ki_range < head_num) & (kj_range < k_dim) |
| + mask_v = ki_range < head_num |
| + dk_out_off = ki_range * stride_dkv_nheads + kj_range |
| + |
| + dK_ptr = dK + pid_m * stride_dk_seq # + pid_head * BLOCK_H * stride_dk_nheads |
| + dV_ptr = dV + pid_m * stride_dv_seq # + pid_head * BLOCK_H * stride_dv_nheads |
| + dk_in_off = ki_range * stride_dk_nheads + kj_range |
| + |
| + dk = tl.load(dK_ptr + dk_in_off, mask=mask_k) |
| + tl.store(dKV_ptr + dk_out_off, dk, mask=mask_k) |
| + |
| + if v_dim > 0: |
| + dv_out_off = ki_range * stride_dkv_nheads + k_dim + tl.arange(0, v_dim)[None, :] |
| + dv_in_off = ki_range * stride_dv_nheads + tl.arange(0, v_dim)[None, :] |
| + dv = tl.load(dV_ptr + dv_in_off, mask=mask_v) |
| + tl.store(dKV_ptr + dv_out_off, dv, mask=mask_v) |
| |
| if pid_head == 0: |
| x_left_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) |
| x_right_accum = tl.zeros((BLOCK_H, emb_dim // 2), dtype=tl.float32) |
| for i in tl.static_range(triton.cdiv(head_num, BLOCK_H)): |
| - dK_ptr = dK + pid_m * stride_dk_seq + i * BLOCK_H * stride_dk_nheads |
| - x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim |
| + dK_ptr = dK + pid_m * stride_dk_seq # + i * BLOCK_H * stride_dk_nheads |
| + x_off = tl.arange(0, BLOCK_H)[:, None] * stride_dk_nheads + k_dim + i * BLOCK_H * stride_dk_nheads |
| mask = x_off < head_num * stride_dk_nheads |
| x_left_off = x_off + tl.arange(0, emb_dim // 2)[None, :] |
| x_right_off = x_left_off + emb_dim // 2 |
| @@ -632,6 +647,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): |
| |
| o_key = kv.new_empty(total_seqlen, nheads, emb_dim + k_dim) |
| o_value = kv.new_empty(total_seqlen, nheads, v_dim) |
| + k_dim_ceil = triton.next_power_of_2(k_dim) |
| |
| grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) |
| rotary_fwd_kv_kernel[grid]( |
| @@ -643,6 +659,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): |
| sin, |
| emb_dim, |
| k_dim, |
| + k_dim_ceil, |
| v_dim, |
| nheads, |
| batch_size, |
| @@ -700,6 +717,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): |
| |
| d_kv = dk.new_empty(total_seqlen, nheads, ctx.k_dim + ctx.v_dim) |
| d_emb = dk.new_empty(total_seqlen, 1, ctx.emb_dim) |
| + k_dim_ceil = triton.next_power_of_2(ctx.k_dim) |
| |
| grid = lambda META: (total_seqlen, triton.cdiv(nheads, META["BLOCK_H"])) |
| rotary_bwd_kv_kernel[grid]( |
| @@ -711,6 +729,7 @@ class ApplyMLARotaryEmbKV(torch.autograd.Function): |
| sin, |
| ctx.emb_dim, |
| ctx.k_dim, |
| + k_dim_ceil, |
| ctx.v_dim, |
| nheads, |
| batch_size, |
| |
| |
| |
| |
| @@ -184,7 +184,15 @@ class LanguageModule(MegatronModule): |
| assert ( |
| column_parallel_linear is not None |
| ), "column_parallel_linear cannot be None when not using fused linear cross entropy." |
| - logits, _ = column_parallel_linear(hidden, **col_linear_kwargs) |
| + # output |
| + output_layer_params = {k: v.detach() for k, v in column_parallel_linear.named_parameters()} |
| + output_layer_buffers = dict(column_parallel_linear.named_buffers()) |
| + logits, _ = torch.func.functional_call( |
| + column_parallel_linear, |
| + {**output_layer_params, **output_layer_buffers}, |
| + (hidden,), |
| + col_linear_kwargs, |
| + ) |
| |
| return self.compute_language_model_loss(labels, logits) |
| |
| |
| |
| |
| |
| @@ -188,6 +188,8 @@ def get_gpt_layer_with_transformer_engine_spec( |
| use_kitchen: bool = False, |
| use_te_activation_func: bool = False, |
| fallback_to_eager_attn: bool = False, |
| + post_self_attn_layernorm: bool = False, |
| + post_mlp_layernorm: bool = False, |
| ) -> ModuleSpec: |
| """Use this spec to use lower-level Transformer Engine modules (required for fp8 training). |
| |
| @@ -260,6 +262,8 @@ def get_gpt_layer_with_transformer_engine_spec( |
| mlp=mlp, |
| sharded_state_dict_keys_map=sharded_state_dict_keys_map, |
| normalization=normalization, |
| + post_self_attn_layernorm=post_self_attn_layernorm, |
| + post_mlp_layernorm=post_mlp_layernorm, |
| ) |
| |
| |
| @@ -349,6 +353,8 @@ def get_transformer_layer_spec_for_backend( |
| mlp: ModuleSpec, |
| sharded_state_dict_keys_map: Optional[dict] = None, |
| normalization: Optional[str] = None, |
| + post_self_attn_layernorm: bool = False, |
| + post_mlp_layernorm: bool = False, |
| ) -> ModuleSpec: |
| """Helper function to get module spec for TransformerLayer""" |
| |
| @@ -371,9 +377,11 @@ def get_transformer_layer_spec_for_backend( |
| input_layernorm=input_layernorm, |
| self_attention=attention, |
| self_attn_bda=get_bias_dropout_add, |
| + post_self_attn_layernorm=TENorm if post_self_attn_layernorm else IdentityOp, |
| pre_mlp_layernorm=pre_mlp_layernorm, |
| mlp=mlp, |
| mlp_bda=get_bias_dropout_add, |
| + post_mlp_layernorm=TENorm if post_mlp_layernorm else IdentityOp, |
| sharded_state_dict_keys_map=sharded_state_dict_keys_map, |
| ), |
| ) |
| |
| |
| |
| |
| @@ -446,6 +446,7 @@ class GPTModel(LanguageModule): |
| *, |
| inference_params: Optional[BaseInferenceContext] = None, |
| loss_mask: Optional[Tensor] = None, |
| + mtp_kwargs: Optional[dict] = {}, |
| ) -> Tensor: |
| """Forward function of the GPT Model This function passes the input tensors |
| through the embedding layer, and then the decoder and finally into the post |
| @@ -508,6 +509,7 @@ class GPTModel(LanguageModule): |
| runtime_gather_output=runtime_gather_output, |
| extra_block_kwargs=extra_block_kwargs, |
| inference_context=inference_context, |
| + mtp_kwargs=mtp_kwargs, |
| ) |
| |
| def _postprocess( |
| @@ -529,6 +531,7 @@ class GPTModel(LanguageModule): |
| runtime_gather_output=None, |
| extra_block_kwargs=None, |
| inference_context=None, |
| + mtp_kwargs={}, |
| ): |
| """Postprocesses decoder hidden states to generate logits or compute loss. |
| |
| @@ -543,7 +546,8 @@ class GPTModel(LanguageModule): |
| output_weight = None |
| if self.share_embeddings_and_output_weights: |
| output_weight = self.shared_embedding_or_output_weight() |
| - if mtp_in_postprocess: |
| + |
| + if mtp_in_postprocess and mtp_kwargs.get('mtp_labels', None) is not None: |
| hidden_states = self.mtp( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| @@ -563,13 +567,18 @@ class GPTModel(LanguageModule): |
| return hidden_states |
| |
| # Skip when mtp_num_layers is None or 0 |
| - if self.config.mtp_num_layers: |
| - mtp_labels = labels.clone() |
| + if self.config.mtp_num_layers and mtp_kwargs.get('mtp_labels', None) is not None: |
| + mtp_labels = mtp_kwargs['mtp_labels'].clone() |
| + mtp_labels, _ = roll_tensor(mtp_labels, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) |
| + |
| hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0) |
| hidden_states = hidden_states_list[0] |
| if loss_mask is None: |
| # if loss_mask is not provided, use all ones as loss_mask |
| loss_mask = torch.ones_like(mtp_labels) |
| + else: |
| + # Otherwise, roll the loss_mask to keep up with the mtp_labels |
| + loss_mask, _ = roll_tensor(loss_mask, shifts=-1, dims=-1, cp_group=self.cp_group, packed_seq_params=packed_seq_params) |
| for mtp_layer_number in range(self.config.mtp_num_layers): |
| # Calc loss for the current Multi-Token Prediction (MTP) layers. |
| mtp_labels, _ = roll_tensor( |
| @@ -595,7 +604,7 @@ class GPTModel(LanguageModule): |
| sequence_parallel_enabled=self.output_layer.sequence_parallel, |
| column_parallel_linear=self.output_layer, |
| col_linear_kwargs={ |
| - 'weight': output_weight, |
| + 'weight': output_weight.detach() if output_weight else None, |
| 'runtime_gather_output': runtime_gather_output, |
| }, |
| ) |
| |
| |
| |
| |
| @@ -677,6 +677,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): |
| # TE FusedAdam will not accumulate step for empty param groups, so we need to |
| # align the step across param groups. |
| param_group["step"] = int(step) |
| + if "step" in param_group and param_group["step"] is None: |
| + del param_group["step"] |
| |
| # Grad scaler state. |
| if self.grad_scaler: |
| @@ -1646,6 +1648,8 @@ class DistributedOptimizer(MixedPrecisionOptimizer): |
| if key == 'padding': |
| tensors[key] = LocalNonpersistentObject(tensors[key]) |
| continue |
| + if key == 'step': |
| + continue |
| assert tensors[key].shape == (gbuf_local_end - gbuf_local_start,), ( |
| tensors[key].shape, |
| gbuf_local_start, |
| |
| |
| |
| |
| @@ -11,6 +11,7 @@ from typing import Callable, List, Optional |
| |
| import numpy as np |
| import torch |
| +import torch.distributed as dist |
| |
| from .utils import GlobalMemoryBuffer, is_torch_min_version |
| |
| |
| |
| |
| |
| @@ -26,22 +26,22 @@ def _batched_p2p_ops( |
| ops = [] |
| if tensor_send_prev is not None: |
| send_prev_op = torch.distributed.P2POp( |
| - torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, group |
| + torch.distributed.isend, tensor_send_prev, prev_pipeline_rank, |
| ) |
| ops.append(send_prev_op) |
| if tensor_recv_prev is not None: |
| recv_prev_op = torch.distributed.P2POp( |
| - torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, group |
| + torch.distributed.irecv, tensor_recv_prev, prev_pipeline_rank, |
| ) |
| ops.append(recv_prev_op) |
| if tensor_send_next is not None: |
| send_next_op = torch.distributed.P2POp( |
| - torch.distributed.isend, tensor_send_next, next_pipeline_rank, group |
| + torch.distributed.isend, tensor_send_next, next_pipeline_rank, |
| ) |
| ops.append(send_next_op) |
| if tensor_recv_next is not None: |
| recv_next_op = torch.distributed.P2POp( |
| - torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, group |
| + torch.distributed.irecv, tensor_recv_next, next_pipeline_rank, |
| ) |
| ops.append(recv_next_op) |
| if len(ops) > 0: |
| |
| |
| |
| |
| @@ -587,6 +587,9 @@ def topk_routing_with_score_function( |
| else: |
| return torch.topk(scores, k=topk, dim=1) |
| |
| + from slime.utils.routing_replay import get_routing_replay_compute_topk |
| + compute_topk = get_routing_replay_compute_topk(compute_topk) |
| + |
| if score_function == "softmax": |
| if use_pre_softmax: |
| scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) |
| |
| |
| |
| |
| @@ -201,6 +201,9 @@ class TopKRouter(Router): |
| self.global_tokens_per_expert = None |
| self.ga_steps = None |
| |
| + from slime.utils.routing_replay import register_routing_replay |
| + register_routing_replay(self) |
| + |
| def _maintain_float32_expert_bias(self): |
| """ |
| Maintain the expert bias in float32. |
| |
| |
| |
| |
| @@ -6,6 +6,7 @@ from typing import Callable, List, Optional, Union |
| |
| import torch |
| from torch import Tensor |
| +import warnings |
| |
| from megatron.core import InferenceParams, parallel_state, tensor_parallel |
| from megatron.core.dist_checkpointing.mapping import ShardedStateDict |
| @@ -714,17 +715,19 @@ class MultiTokenPredictionLayer(MegatronModule): |
| cp_group=self.cp_group, |
| packed_seq_params=packed_seq_params, |
| ) |
| - position_ids, _ = roll_tensor( |
| - position_ids, |
| - shifts=-1, |
| - dims=-1, |
| - cp_group=self.cp_group, |
| - packed_seq_params=packed_seq_params, |
| - ) |
| + if position_ids is not None: |
| + position_ids, _ = roll_tensor( |
| + position_ids, |
| + shifts=-1, |
| + dims=-1, |
| + cp_group=self.cp_group, |
| + packed_seq_params=packed_seq_params, |
| + ) |
| # embedding |
| decoder_input = embedding(input_ids=input_ids, position_ids=position_ids) |
| + decoder_input = decoder_input.detach() |
| |
| - hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True) |
| + hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=False) |
| |
| return input_ids, position_ids, decoder_input, hidden_states |
| |
| @@ -826,6 +829,51 @@ class MultiTokenPredictionLayer(MegatronModule): |
| return hidden_states |
| |
| def _checkpointed_forward(self, forward_func, *args, **kwargs): |
| + """Wrap `forward_func` with activation checkpointing while only passing tensors. |
| + |
| + Non-tensor arguments (e.g., configuration objects, None) are captured via closure so |
| + that checkpoint implementations never receive them directly, avoiding save_for_backward |
| + issues with non-tensor inputs. |
| + """ |
| + |
| + # TODO(jiajun): Is there any better implementation here? |
| + positional_specs = [] |
| + kw_specs = [] |
| + tensor_args: List[torch.Tensor] = [] |
| + |
| + for arg in args: |
| + if torch.is_tensor(arg): |
| + positional_specs.append(('tensor', len(tensor_args))) |
| + tensor_args.append(arg) |
| + else: |
| + positional_specs.append(('const', arg)) |
| + |
| + for key, value in kwargs.items(): |
| + if torch.is_tensor(value): |
| + kw_specs.append((key, ('tensor', len(tensor_args)))) |
| + tensor_args.append(value) |
| + else: |
| + kw_specs.append((key, ('const', value))) |
| + |
| + def run(*flat_tensor_args): |
| + rebuilt_args = [] |
| + for spec_type, payload in positional_specs: |
| + if spec_type == 'tensor': |
| + rebuilt_args.append(flat_tensor_args[payload]) |
| + else: |
| + rebuilt_args.append(payload) |
| + |
| + rebuilt_kwargs = {} |
| + for key, (spec_type, payload) in kw_specs: |
| + if spec_type == 'tensor': |
| + rebuilt_kwargs[key] = flat_tensor_args[payload] |
| + else: |
| + rebuilt_kwargs[key] = payload |
| + |
| + return forward_func(*rebuilt_args, **rebuilt_kwargs) |
| + |
| + tensor_args_tuple = tuple(tensor_args) |
| + |
| def checkpoint_handler(): |
| """Determines whether to use the `te_checkpoint` or `tensor_parallel.checkpoint`""" |
| if self.config.fp8: |
| @@ -836,12 +884,11 @@ class MultiTokenPredictionLayer(MegatronModule): |
| self.config.distribute_saved_activations, |
| tensor_parallel.random.get_cuda_rng_tracker, |
| parallel_state.get_tensor_model_parallel_group(), |
| - *args, |
| - **kwargs, |
| + *tensor_args_tuple, |
| ) |
| else: |
| return tensor_parallel.checkpoint( |
| - forward_func, self.config.distribute_saved_activations, *args, *kwargs.values() |
| + run, self.config.distribute_saved_activations, *tensor_args_tuple |
| ) |
| |
| if self.config.recompute_method == 'uniform': |
| |
| |
| |
| |
| @@ -210,6 +210,9 @@ class TransformerConfig(ModelParallelConfig): |
| attention_output_gate: bool = False |
| """Whether to apply output gate to the attention layers.""" |
| |
| + post_self_attn_layernorm: bool = False |
| + post_mlp_layernorm: bool = False |
| + |
| test_mode: bool = False |
| """Whether to run real-time tests.""" |
| |
| |
| |
| |
| |
| @@ -223,6 +223,7 @@ class TransformerLayerSubmodules: |
| input_layernorm: Union[ModuleSpec, type] = IdentityOp |
| self_attention: Union[ModuleSpec, type] = IdentityOp |
| self_attn_bda: Union[ModuleSpec, type] = IdentityFuncOp |
| + post_self_attn_layernorm: Union[ModuleSpec, type] = IdentityOp |
| |
| pre_cross_attn_layernorm: Union[ModuleSpec, type] = IdentityOp |
| cross_attention: Union[ModuleSpec, type] = IdentityOp |
| @@ -231,6 +232,7 @@ class TransformerLayerSubmodules: |
| pre_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp |
| mlp: Union[ModuleSpec, type] = IdentityOp |
| mlp_bda: Union[ModuleSpec, type] = IdentityFuncOp |
| + post_mlp_layernorm: Union[ModuleSpec, type] = IdentityOp |
| |
| # Mapping for sharded tensor keys to be applied in `sharded_state_dict` method |
| sharded_state_dict_keys_map: Dict[str, str] = field(default_factory=dict) |
| @@ -310,6 +312,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): |
| # [Module 3: BiasDropoutFusion] |
| self.self_attn_bda = build_module(submodules.self_attn_bda) |
| |
| + self.post_self_attn_layernorm = build_module( |
| + submodules.post_self_attn_layernorm, |
| + config=self.config, |
| + hidden_size=self.config.hidden_size, |
| + eps=self.config.layernorm_epsilon, |
| + ) |
| + |
| # [Module 4: Post SelfAttention] Optional Layernorm after self-attn |
| self.pre_cross_attn_layernorm = build_module( |
| submodules.pre_cross_attn_layernorm, |
| @@ -375,6 +384,13 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): |
| |
| self.is_moe_layer = isinstance(self.mlp, MoELayer) |
| |
| + self.post_mlp_layernorm = build_module( |
| + submodules.post_mlp_layernorm, |
| + config=self.config, |
| + hidden_size=self.config.hidden_size, |
| + eps=self.config.layernorm_epsilon |
| + ) |
| + |
| self.recompute_input_layernorm = False |
| self.recompute_pre_mlp_layernorm = False |
| self.recompute_mlp = False |
| @@ -551,6 +567,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): |
| attention_output_with_bias[0] |
| ) |
| |
| + attention_output, attention_output_bias = attention_output_with_bias |
| + attention_output = self.post_self_attn_layernorm(attention_output) |
| + attention_output_with_bias = (attention_output, attention_output_bias) |
| + |
| # TODO: could we move `bias_dropout_add_exec_handler` itself |
| # inside the module provided in the `bias_dropout_add_spec` module? |
| nvtx_range_push(suffix="self_attn_bda") |
| @@ -677,6 +697,10 @@ class TransformerLayer(GraphableMegatronModule, BaseTransformerLayer): |
| else: |
| mlp_output_with_bias = self.mlp(pre_mlp_layernorm_output) |
| |
| + mlp_output, mlp_output_bias = mlp_output_with_bias |
| + mlp_output = self.post_mlp_layernorm(mlp_output) |
| + mlp_output_with_bias = (mlp_output, mlp_output_bias) |
| + |
| if self.recompute_pre_mlp_layernorm: |
| # discard the output of the pre-mlp layernorm and register the recompute |
| # as a gradient hook of mlp_output_with_bias[0] |
| |
| |
| |
| |
| @@ -1398,6 +1398,9 @@ def core_transformer_config_from_args(args, config_class=None): |
| |
| kw_args['inference_sampling_seed'] = args.seed |
| |
| + kw_args['post_self_attn_layernorm'] = args.post_self_attn_layernorm |
| + kw_args['post_mlp_layernorm'] = args.post_mlp_layernorm |
| + |
| # handle quantization config |
| # NOTE: Kitchen arguments are only added to the namespace when |
| # Kitchen library is available. |
| @@ -1764,6 +1767,12 @@ def _add_network_size_args(parser): |
| action='store_true', |
| help='If set, use original BERT residula connection ' |
| 'ordering.') |
| + group.add_argument('--post-self-attn-layernorm', action='store_true', |
| + help='If set, use post self attention layernorm.') |
| + group.add_argument('--post-mlp-layernorm', action='store_true', |
| + help='If set, use post MLP layernorm.') |
| + group.add_argument('--use-gated-attention', action='store_true', |
| + help='If set, use gated attention as in Qwen3Next') |
| group.add_argument('--openai-gelu', action='store_true', |
| help='Use OpenAIs GeLU implementation. This option' |
| 'should not be used unless for backward compatibility' |
| |
| |
| |
| |
| @@ -136,7 +136,7 @@ class _HuggingFaceTokenizer(MegatronLegacyTokenizer): |
| # TODO(bnorick): download tokenizer once to lustre and use force offline to make sure all tasks read it from there |
| self._tokenizer = transformers.AutoTokenizer.from_pretrained( |
| pretrained_model_name_or_path=pretrained_model_name_or_path, |
| - trust_remote_code=trust_remote_code, |
| + trust_remote_code=True, |
| **kwargs, |
| ) |
| self._vocab = self._tokenizer.get_vocab() |
|
|