# Copyright (c) 2025 NVIDIA CORPORATION.
# Licensed under the MIT license.

# Adapted from https://github.com/NVlabs/VILA/tree/main under the Apache 2.0 license.
# LICENSE is in incl_licenses directory.

# Copyright 2024 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Qwen2 model."""
import math
import os
from dataclasses import asdict, dataclass, field
from fnmatch import fnmatch
from functools import lru_cache
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM, PretrainedConfig
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2Attention,
    Qwen2DecoderLayer,
    Qwen2FlashAttention2,
    Qwen2ForCausalLM,
    Qwen2MLP,
    Qwen2Model,
    Qwen2PreTrainedModel,
    Qwen2RMSNorm,
    Qwen2RotaryEmbedding,
    Qwen2SdpaAttention,
    apply_rotary_pos_emb,
    repeat_kv,
    rotate_half,
)
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)

# FP8 related
from ..coat.activation.models._fp8_quantization_config import QuantizationConfig
from ..coat.activation.models._fp8_weightcache import FP8CacheWeightModule
from ..coat.activation.models._fp8manager import FP8Manager
from ..coat.activation.real_quantization import (
    Coat_quantize_bgn,
    Coat_quantize_end,
    fp8_add_Ifp_Ifp_Ofp_Og16,
    fp8_add_Ifp_Ifp_Ofp_Opt,
    fp8_division,
    fp8_division_transpose,
    fp8_gelu_backward,
    fp8_gelu_forward,
    fp8_layernorm_noparam_backward,
    fp8_layernorm_noparam_forward,
    fp8_linear_backward,
    fp8_linear_forward,
    fp8_mul_backward,
    fp8_mul_forward,
    fp8_quantize,
    fp8_quantize_pertensor,
    fp8_quantize_pertensor_transpose,
    fp8_rmsnorm_backward,
    fp8_rmsnorm_forward,
    fp8_silu_backward,
    fp8_silu_forward,
    fp8_transpose,
)
from ..liger.cross_entropy import LigerForCausalLMLoss
from ..qlinear_te import QLinearTE

# from .configuration_quantize import QuantizationConfig

if is_flash_attn_2_available():
    from transformers.modeling_flash_attention_utils import _flash_attention_forward

logger = logging.get_logger(__name__)
class FP8ActivationQwen2Config(Qwen2Config):
    model_type = "fp8activation_qwen2"

    def __init__(
        self,
        coat_fp8_args=None,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=True,
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        **kwargs,
    ):
        super().__init__(
            vocab_size,
            hidden_size,
            intermediate_size,
            num_hidden_layers,
            num_attention_heads,
            num_key_value_heads,
            hidden_act,
            max_position_embeddings,
            initializer_range,
            rms_norm_eps,
            use_cache,
            tie_word_embeddings,
            rope_theta,
            rope_scaling,
            use_sliding_window,
            sliding_window,
            max_window_layers,
            attention_dropout,
            **kwargs,
        )

        self.coat_fp8_args = coat_fp8_args
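

# --- Illustrative sketch (not part of the original module) ---
# A minimal example of how this config might be constructed. The exact schema of
# `coat_fp8_args` is an assumption inferred from the attributes this file reads
# off `QuantizationConfig` (fabit/fwbit/fobit, babit/bwbit/bobit, group_size,
# weight_memory_efficient); consult the COAT source for the real field set.
def _example_build_fp8_config():
    coat_fp8_args = dict(
        fabit="E4M3",  # forward activation format (this file asserts E4M3)
        fwbit="E4M3",  # forward weight format
        fobit="E4M3",  # forward output format
        babit="E5M2",  # backward activation-gradient format (this file asserts E5M2)
        bwbit="E5M2",  # backward weight-gradient format
        bobit="E5M2",  # backward output-gradient format
        group_size=16,  # per-group quantization granularity (assumed value)
        weight_memory_efficient=True,
    )
    return FP8ActivationQwen2Config(coat_fp8_args=coat_fp8_args)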
class FP8ActivationQwen2BeforeAttentionResidual(FP8CacheWeightModule):
    """
    This is a typical transformer attention module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 1 * Linear layers
    """

    def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_idx: Optional[int] = None):
        super().__init__(config, qargs, layer_idx)

        self.qargs = qargs
        self.fwobits = {
            "fabit": self.qargs.fabit,
            "fwbit": self.qargs.fwbit,
            "fobit": self.qargs.fobit,
            "babit": self.qargs.babit,
            "bwbit": self.qargs.bwbit,
            "bobit": self.qargs.bobit,
        }

        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will "
                "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.attention_dropout = config.attention_dropout
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
    def forward(self, re_x, x, s, rmsnorm_weight):
        if self.training:
            if self.qargs.weight_memory_efficient:
                # Prepare
                with torch.no_grad():
                    if FP8Manager.is_first_microbatch:
                        # Directly use the corresponding weight
                        weight1, _, weight1_s = self.prepare_weight(
                            self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
                        )
                        weight2, _, weight2_s = self.prepare_weight(
                            self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
                        )
                        weight3, _, weight3_s = self.prepare_weight(
                            self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
                        )
                    else:
                        weight1_s = self.prepare_weight(self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch)
                        weight2_s = self.prepare_weight(self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch)
                        weight3_s = self.prepare_weight(self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch)
                        weight1, weight2, weight3 = None, None, None
                return _FP8ActivationQwen2BeforeAttentionResidual.apply(
                    re_x,
                    x,
                    s,
                    self.q_proj.weight,
                    weight1,
                    None,
                    weight1_s,
                    self.q_proj.bias,
                    self.k_proj.weight,
                    weight2,
                    None,
                    weight2_s,
                    self.k_proj.bias,
                    self.v_proj.weight,
                    weight3,
                    None,
                    weight3_s,
                    self.v_proj.bias,
                    rmsnorm_weight,
                    self.qargs.group_size,
                    self.fwobits,
                    self.layer_id,
                    self.config,
                    self.qargs,
                )
            else:
                # Prepare
                with torch.no_grad():
                    weight1, weight1_t, weight1_s = self.prepare_weight(
                        self.q_proj.weight, "q_proj", FP8Manager.is_first_microbatch
                    )
                    weight2, weight2_t, weight2_s = self.prepare_weight(
                        self.k_proj.weight, "k_proj", FP8Manager.is_first_microbatch
                    )
                    weight3, weight3_t, weight3_s = self.prepare_weight(
                        self.v_proj.weight, "v_proj", FP8Manager.is_first_microbatch
                    )
                return _FP8ActivationQwen2BeforeAttentionResidual.apply(
                    re_x,
                    x,
                    s,
                    self.q_proj.weight,
                    weight1,
                    weight1_t,
                    weight1_s,
                    self.q_proj.bias,
                    self.k_proj.weight,
                    weight2,
                    weight2_t,
                    weight2_s,
                    self.k_proj.bias,
                    self.v_proj.weight,
                    weight3,
                    weight3_t,
                    weight3_s,
                    self.v_proj.bias,
                    rmsnorm_weight,
                    self.qargs.group_size,
                    self.fwobits,
                    self.layer_id,
                    self.config,
                    self.qargs,
                )
        else:
            raise NotImplementedError("This should be implemented in the future")
            return re_x, self.att_proj(self.attn_norm(re_x))
class _FP8ActivationQwen2BeforeAttentionResidual(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        re_x,
        in_x,
        in_s,
        weight1_origin,
        weight1,
        weight1_t,
        weight1_s,
        weight1_bias,
        weight2_origin,
        weight2,
        weight2_t,
        weight2_s,
        weight2_bias,
        weight3_origin,
        weight3,
        weight3_t,
        weight3_s,
        weight3_bias,
        rmsnorm_weight,
        group_size,
        fwobits,
        layer_id,
        config,
        qargs,
        eps=1e-5,
    ):
        # for autograd
        if fwobits["fabit"] == "E4M3":
            # in_x = in_x.to(torch.float8_e4m3fn)
            in_x = in_x.view(torch.float8_e4m3fn)
        else:
            raise ValueError("fabit should be E4M3")

        # LayerNorm
        ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
            in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
        )

        # Linear Layer QKV Projection
        if qargs.weight_memory_efficient:
            assert (
                weight1_t is None and weight2_t is None and weight3_t is None
            )  # we should not pass W^T to here if weight memory efficient
            if weight1 is None:
                assert weight2 is None and weight3 is None
                weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
                weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
                weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)

        fc1_x = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, False, group_size, bias=weight1_bias)  # query states
        fc2_x = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, False, group_size, bias=weight2_bias)  # key states
        fc3_x = fp8_linear_forward(ln_x, ln_s, weight3, weight3_s, False, group_size, bias=weight3_bias)  # value states

        # ==================== save for backward ====================
        ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s)
        if qargs.weight_memory_efficient:
            assert weight1_t is None and weight2_t is None and weight3_t is None
            ctx.weight = weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s
        else:
            ctx.weight = weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s
        ctx.bias = weight1_bias, weight2_bias, weight3_bias
        ctx.group_size = group_size
        ctx.ln_utils = ln_utils
        ctx.utils = fwobits, layer_id, config, qargs

        return re_x, fc1_x, fc2_x, fc3_x
    @staticmethod
    def backward(ctx, fp_grad, query_g, key_g, value_g):
        in_x, in_s, ln_x_t, ln_s = ctx.saved_tensors
        weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s = ctx.weight
        weight1_bias, weight2_bias, weight3_bias = ctx.bias
        group_size = ctx.group_size
        rms_weight, rstd, num_warps = ctx.ln_utils
        fwobits, layer_id, config, qargs = ctx.utils

        # ==================== Begin backward ====================
        # Gradient of Bias TODO: make this better
        if weight1_bias is not None and weight2_bias is not None and weight3_bias is not None:
            att_q_bg = query_g.reshape(-1, query_g.shape[-1]).sum(0)
            att_k_bg = key_g.reshape(-1, key_g.shape[-1]).sum(0)
            att_v_bg = value_g.reshape(-1, value_g.shape[-1]).sum(0)
        else:
            att_q_bg = None
            att_k_bg = None
            att_v_bg = None

        # Quantize the RoPE and FlashAttention Output. grad_input and grad_weight require different data layouts.
        query_g, query_gs, query_g_t = fp8_quantize_pertensor_transpose(
            query_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
        )
        key_g, key_gs, key_g_t = fp8_quantize_pertensor_transpose(
            key_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
        )
        value_g, value_gs, value_g_t = fp8_quantize_pertensor_transpose(
            value_g, group_size, fwobits["babit"], transpose_output_2d=True, stochastic=False
        )

        # Linear Layer QKV Projection
        if qargs.weight_memory_efficient:
            weight1_t, weight1_s = fp8_division_transpose(
                weight1_t, qargs.group_size, fwobits["fwbit"], weight1_s, only_transposed=True
            )
            weight2_t, weight2_s = fp8_division_transpose(
                weight2_t, qargs.group_size, fwobits["fwbit"], weight2_s, only_transposed=True
            )
            weight3_t, weight3_s = fp8_division_transpose(
                weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
            )

        fc1_g1, att_q_wg = fp8_linear_backward(
            ln_x_t, ln_s, query_g, query_gs, query_g_t, weight1_t, weight1_s, group_size
        )
        fc1_g2, att_k_wg = fp8_linear_backward(ln_x_t, ln_s, key_g, key_gs, key_g_t, weight2_t, weight2_s, group_size)
        fc1_g3, att_v_wg = fp8_linear_backward(
            ln_x_t, ln_s, value_g, value_gs, value_g_t, weight3_t, weight3_s, group_size
        )
        fc1_g = fc1_g1 + fc1_g2 + fc1_g3

        # LayerNorm
        in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc1_g, rms_weight, rstd, group_size, num_warps)

        # Add the gradients together, and prepare the input of the next layer.
        re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
            fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
        )

        # for autograd. forward's data type should be the same as the backward tensor. this will not change the actual binary representation.
        in_g = in_g.view(torch.float8_e4m3fn)

        # Although the next operator is a linear layer in the MLPResidual module, we return in_sg_g16 to make the size compatible with the forward. Otherwise it will not pass autograd.
        return (
            re_g,
            in_g,
            in_sg_g16,
            att_q_wg,
            None,
            None,
            None,
            att_q_bg,
            att_k_wg,
            None,
            None,
            None,
            att_k_bg,
            att_v_wg,
            None,
            None,
            None,
            att_v_bg,
            rms_weight_grad,
            None,
            None,
            None,
            None,
            None,
            None,
        )
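

# --- Illustrative sketch (not part of the original module) ---
# The Function above follows the standard custom-autograd contract: `forward`
# may take many non-tensor arguments, and `backward` must return exactly one
# gradient (or None) per `forward` input, in the same order. That is why the
# long tuple above is padded with None for weights/scales/configs that need no
# gradient. A minimal toy version of the same pattern, with invented names:
class _ScaledLinearSketch(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, weight, scale, config):
        ctx.save_for_backward(x, weight)
        ctx.scale = scale
        return (x @ weight.t()) * scale

    @staticmethod
    def backward(ctx, grad_out):
        x, weight = ctx.saved_tensors
        grad_x = grad_out * ctx.scale @ weight          # (N, out) @ (out, in) -> (N, in)
        grad_w = (grad_out * ctx.scale).t() @ x         # (out, N) @ (N, in) -> (out, in)
        # One slot per forward input: x, weight, scale, config.
        return grad_x, grad_w, None, None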
class FP8ActivationQwen2AfterAttentionResidual(FP8CacheWeightModule):
    """
    This is a typical transformer attention module that contains (1) Residual (2) 1 * Linear layers
    """

    def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id):
        super().__init__(config, qargs, layer_id)

        self.qargs = qargs
        self.fwobits = {
            "fabit": self.qargs.fabit,
            "fwbit": self.qargs.fwbit,
            "fobit": self.qargs.fobit,
            "babit": self.qargs.babit,
            "bwbit": self.qargs.bwbit,
            "bobit": self.qargs.bobit,
        }

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = getattr(config, "head_dim", self.hidden_size // self.num_heads)

        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
    def forward(self, re_x, in_x):
        if self.training:
            if self.qargs.weight_memory_efficient:
                # prepare for the weight
                with torch.no_grad():
                    if FP8Manager.is_first_microbatch:
                        weight4, _, weight4_s = self.prepare_weight(
                            self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
                        )
                    else:
                        weight4_s = self.prepare_weight(self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch)
                        weight4 = None

                return _FP8ActivationQwen2AfterAttentionResidual.apply(
                    re_x,
                    in_x,
                    self.o_proj.weight,
                    weight4,
                    None,
                    weight4_s,
                    self.qargs.group_size,
                    self.fwobits,
                    self.layer_id,
                    self.config,
                    self.qargs,
                )
            else:
                # prepare for the weight
                with torch.no_grad():
                    weight4, weight4_t, weight4_s = self.prepare_weight(
                        self.o_proj.weight, "o_proj", FP8Manager.is_first_microbatch
                    )

                return _FP8ActivationQwen2AfterAttentionResidual.apply(
                    re_x,
                    in_x,
                    self.o_proj.weight,
                    weight4,
                    weight4_t,
                    weight4_s,
                    self.qargs.group_size,
                    self.fwobits,
                    self.layer_id,
                    self.config,
                    self.qargs,
                )
        else:
            return re_x + self.o_proj(in_x), None, None
class _FP8ActivationQwen2AfterAttentionResidual(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx, re_x, flash_x, weight4_origin, weight4, weight4_t, weight4_s, group_size, fwobits, layer_id, config, qargs
    ):
        time_bench = os.getenv("TIME_BENCH")

        if time_bench:
            start_1 = torch.cuda.Event(enable_timing=True)
            start_1.record()

        # Quantize the FlashAttention Output
        flash_qx, flash_s, _ = fp8_quantize_pertensor(
            flash_x, group_size, fwobits["fabit"]
        )  # Modified to make it memory efficient

        if time_bench:
            end_1 = torch.cuda.Event(enable_timing=True)
            end_1.record()
            start_2 = torch.cuda.Event(enable_timing=True)
            start_2.record()

        # Attention Projection Linear Layer
        if qargs.weight_memory_efficient:
            assert weight4_t is None
            if weight4 is None:  # the second batch
                weight4, weight4_s = fp8_division(weight4_origin, qargs.group_size, fwobits["fwbit"], weight4_s)

        if time_bench:
            end_2 = torch.cuda.Event(enable_timing=True)
            end_2.record()
            start_3 = torch.cuda.Event(enable_timing=True)
            start_3.record()

        fc4_x = fp8_linear_forward(flash_qx, flash_s, weight4, weight4_s, False, group_size)

        if time_bench:
            end_3 = torch.cuda.Event(enable_timing=True)
            end_3.record()
            start_4 = torch.cuda.Event(enable_timing=True)
            start_4.record()

        # Add the activations together
        fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc4_x, flash_qx.dtype, group_size)

        if time_bench:
            end_4 = torch.cuda.Event(enable_timing=True)
            end_4.record()

            torch.cuda.synchronize()
            if int(os.environ.get("LOCAL_RANK", "0")) == 0:
                print(
                    f"[AfterAt] Part 1: {start_1.elapsed_time(end_1):.6f} ms | "
                    f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | "
                    f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | "
                    f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | "
                    f" Input shape: {re_x.shape}"
                )

        # ==================== save for backward ====================
        ctx.save_for_backward(flash_x, flash_s)
        if qargs.weight_memory_efficient:
            assert weight4_t is None
            ctx.weight = weight4_origin, weight4_s
        else:
            ctx.weight = weight4_t, weight4_s
        ctx.group_size = group_size
        ctx.fwobits = fwobits
        ctx.utils = fwobits, layer_id, config, qargs

        # For autograd
        out_x = out_x.view(torch.float8_e4m3fn)

        return fp_x, out_x, out_s
    @staticmethod
    def backward(ctx, fp_grad, out_g, out_gs):
        flash_x, flash_s = ctx.saved_tensors
        weight4_t, weight4_s = ctx.weight
        group_size = ctx.group_size
        fwobits = ctx.fwobits
        fwobits, layer_id, config, qargs = ctx.utils

        # for autograd
        if fwobits["babit"] == "E5M2":
            # out_g = out_g.to(torch.float8_e5m2)
            out_g = out_g.view(torch.float8_e5m2)
        else:
            raise ValueError("babit should be E5M2")
        out_gs_max = out_gs.max()

        # ==================== Begin backward ====================
        # Output Projection
        out_g_t = fp8_transpose(out_g, transpose_output_2d=True)

        # We do not save an extra flash_x to save the memory usage
        flash_x_t, flash_s = fp8_division_transpose(
            flash_x, group_size, fwobits["fabit"], flash_s, stochastic=False, only_transposed=True
        )
        if qargs.weight_memory_efficient:
            weight4_t, weight4_s = fp8_division_transpose(
                weight4_t, qargs.group_size, fwobits["fwbit"], weight4_s, only_transposed=True
            )
        fc4_g, attn_out_wg = fp8_linear_backward(
            flash_x_t, flash_s, out_g, out_gs_max, out_g_t, weight4_t, weight4_s, group_size
        )

        return fp_grad, fc4_g, attn_out_wg, None, None, None, None, None, None, None, None
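

# --- Illustrative sketch (not part of the original module) ---
# The TIME_BENCH path above uses the standard CUDA-event timing pattern:
# record an event before and after a region, synchronize, then read the
# elapsed time. A self-contained version of that pattern (invented name):
def _time_cuda_region_sketch(fn, *args):
    """Time `fn(*args)` on the current CUDA stream; returns (result, milliseconds)."""
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    result = fn(*args)
    end.record()
    # elapsed_time is only valid after both events have completed on the GPU.
    torch.cuda.synchronize()
    return result, start.elapsed_time(end)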
class FP8ActivationQwen2MLPResidual(FP8CacheWeightModule):
    """
    This is a typical transformer MLP module that contains (1) Residual (2) LayerNorm / RMSNorm (3) 2 / 3 * Linear layers
    (4) GELU / Silu Activation
    """

    def __init__(self, config: FP8ActivationQwen2Config, qargs: QuantizationConfig, layer_id, hidden_size: int):
        super().__init__(config, qargs, layer_id)

        self.qargs = qargs
        self.fwobits = {
            "fabit": self.qargs.fabit,
            "fwbit": self.qargs.fwbit,
            "fobit": self.qargs.fobit,
            "babit": self.qargs.babit,
            "bwbit": self.qargs.bwbit,
            "bobit": self.qargs.bobit,
        }

        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.training = True

        # below is only used when training = False
        assert config.hidden_act == "silu", "We only support silu activation currently"
        self.act_fn = ACT2FN[config.hidden_act]
    def forward(self, re_x, x, s, rmsnorm_weight):
        if self.training:
            if self.qargs.weight_memory_efficient:  # prepare for the weight
                with torch.no_grad():
                    if FP8Manager.is_first_microbatch:
                        # Directly use the corresponding weight
                        weight1, _, weight1_s = self.prepare_weight(
                            self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
                        )
                        weight2, _, weight2_s = self.prepare_weight(
                            self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
                        )
                        weight3, _, weight3_s = self.prepare_weight(
                            self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
                        )
                    else:
                        weight1_s = self.prepare_weight(
                            self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
                        )
                        weight2_s = self.prepare_weight(self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch)
                        weight3_s = self.prepare_weight(
                            self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
                        )
                        weight1, weight2, weight3 = None, None, None

                return _FP8ActivationQwen2MLPResidual.apply(
                    re_x,
                    x,
                    s,
                    self.gate_proj.weight,
                    weight1,
                    None,
                    weight1_s,
                    self.up_proj.weight,
                    weight2,
                    None,
                    weight2_s,
                    self.down_proj.weight,
                    weight3,
                    None,
                    weight3_s,
                    rmsnorm_weight,
                    self.qargs.group_size,
                    self.fwobits,
                    self.layer_id,
                    self.config,
                    self.qargs,
                )
            else:
                # prepare for the weight
                with torch.no_grad():
                    weight1, weight1_t, weight1_s = self.prepare_weight(
                        self.gate_proj.weight, "gate_proj", FP8Manager.is_first_microbatch
                    )
                    weight2, weight2_t, weight2_s = self.prepare_weight(
                        self.up_proj.weight, "up_proj", FP8Manager.is_first_microbatch
                    )
                    weight3, weight3_t, weight3_s = self.prepare_weight(
                        self.down_proj.weight, "down_proj", FP8Manager.is_first_microbatch
                    )

                return _FP8ActivationQwen2MLPResidual.apply(
                    re_x,
                    x,
                    s,
                    self.gate_proj.weight,
                    weight1,
                    weight1_t,
                    weight1_s,
                    self.up_proj.weight,
                    weight2,
                    weight2_t,
                    weight2_s,
                    self.down_proj.weight,
                    weight3,
                    weight3_t,
                    weight3_s,
                    rmsnorm_weight,
                    self.qargs.group_size,
                    self.fwobits,
                    self.layer_id,
                    self.config,
                    self.qargs,
                )
        else:
            raise NotImplementedError("Need TODO")
            og_x = re_x
            re_x = self.ff_norm(re_x)
            re_x = self.ff_proj(re_x)
            re_x = self.act(re_x)
            re_x = self.ff_out(re_x)
            re_x = og_x + re_x
            return re_x, None, None
class _FP8ActivationQwen2MLPResidual(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        re_x,
        in_x,
        in_s,
        weight1_origin,
        weight1,
        weight1_t,
        weight1_s,
        weight2_origin,
        weight2,
        weight2_t,
        weight2_s,
        weight3_origin,
        weight3,
        weight3_t,
        weight3_s,
        rmsnorm_weight,
        group_size,
        fwobits,
        layer_id,
        config,
        qargs,
        eps=1e-5,
    ):
        # For autograd
        if fwobits["fabit"] == "E4M3":
            # in_x = in_x.to(torch.float8_e4m3fn)
            in_x = in_x.view(torch.float8_e4m3fn)
        else:
            raise ValueError("fabit should be E4M3")

        # LayerNorm
        ln_x, ln_s, ln_x_t, ln_utils = fp8_rmsnorm_forward(
            in_x, in_s, rmsnorm_weight, group_size, eps, transpose_output_2d=True
        )

        # Linear Layer of Up Projection and Gate Projection. They are fused as one linear layer.
        if qargs.weight_memory_efficient:
            assert weight1_t is None and weight2_t is None and weight3_t is None  # memory efficient
            if weight1 is None:
                assert weight2 is None and weight3 is None
                weight1, weight1_s = fp8_division(weight1_origin, qargs.group_size, fwobits["fwbit"], weight1_s)
                weight2, weight2_s = fp8_division(weight2_origin, qargs.group_size, fwobits["fwbit"], weight2_s)
                weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)

        gate_x, gate_s = fp8_linear_forward(ln_x, ln_s, weight1, weight1_s, True, group_size)  # Gate Proj
        up_x, up_s = fp8_linear_forward(ln_x, ln_s, weight2, weight2_s, True, group_size)  # Up Proj

        # silu Activation
        silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)

        # Element-wise Multiplication
        mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)

        # Output Projection
        if weight3 is None:  # memory efficient
            weight3, weight3_s = fp8_division(weight3_origin, qargs.group_size, fwobits["fwbit"], weight3_s)
        fc3_x = fp8_linear_forward(mul_x, mul_s, weight3, weight3_s, False, group_size)

        # Add the activations together
        fp_x, (out_x, out_s) = fp8_add_Ifp_Ifp_Ofp_Og16(re_x, fc3_x, mul_x.dtype, group_size)

        # ==================== save for backward ====================
        ctx.save_for_backward(in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s)
        if (
            qargs.weight_memory_efficient
        ):  # Weight_1/2_origin will not be saved twice, so it will be more memory efficient.
            assert weight1_t is None and weight2_t is None and weight3_t is None
            ctx.weight = (weight1_origin, weight1_s, weight2_origin, weight2_s, weight3_origin, weight3_s)
        else:  # Weight1/2_t is different from the origin weight, so saving it will consume additional memory footprint.
            ctx.weight = (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s)
        ctx.group_size = group_size
        ctx.ln_utils = ln_utils
        ctx.utils = fwobits, layer_id, config, qargs

        out_x = out_x.view(torch.float8_e4m3fn)

        return fp_x, out_x, out_s
    @staticmethod
    def backward(ctx, fp_grad, out_g, out_gs):
        in_x, in_s, ln_x_t, ln_s, gate_x, gate_s, up_x, up_s = ctx.saved_tensors
        (weight1_t, weight1_s, weight2_t, weight2_s, weight3_t, weight3_s) = ctx.weight
        group_size = ctx.group_size
        rms_weight, rstd, num_warps = ctx.ln_utils
        fwobits, layer_id, config, qargs = ctx.utils

        # For autograd
        if fwobits["babit"] == "E5M2":
            # out_g = out_g.to(torch.float8_e5m2)
            out_g = out_g.view(torch.float8_e5m2)
        else:
            raise ValueError("babit should be E5M2")
        out_gs_max = out_gs.max()

        # ==================== Begin backward ====================
        # Output Projection
        out_g_t = fp8_transpose(out_g, transpose_output_2d=True)

        # Gradient checkpointing: recompute the silu activation and the element-wise multiplication
        silu_x, silu_s = fp8_silu_forward(gate_x, gate_s, group_size)
        mul_x, mul_s, mul_x_t = fp8_mul_forward(silu_x, silu_s, up_x, up_s, group_size, transpose_output_2d=True)

        if qargs.weight_memory_efficient:
            weight3_t, weight3_s = fp8_division_transpose(
                weight3_t, qargs.group_size, fwobits["fwbit"], weight3_s, only_transposed=True
            )
        fc3_g, weight3_grad = fp8_linear_backward(
            mul_x_t, mul_s, out_g, out_gs_max, out_g_t, weight3_t, weight3_s, group_size
        )

        # [MEM TEST]
        del out_g, out_g_t, weight3_t

        # Element-wise Multiplication, 1 means gate, 2 means up
        mul_g1, (mul_g2, mul_gs2, mul_g2_t) = fp8_mul_backward(
            silu_x, silu_s, up_x, up_s, fc3_g, group_size, fwobits["babit"], output_quantized_transpose=True
        )

        # Silu activation
        silu_g, silu_gs, silu_g_t = fp8_silu_backward(
            gate_x, gate_s, mul_g1, group_size, fwobits["babit"], output_quantized_transpose=True
        )

        # Linear Layer of Up and Gate Projection
        if qargs.weight_memory_efficient:
            weight1_t, weight1_s = fp8_division_transpose(
                weight1_t, group_size, fwobits["fwbit"], weight1_s, only_transposed=True
            )
            weight2_t, weight2_s = fp8_division_transpose(
                weight2_t, group_size, fwobits["fwbit"], weight2_s, only_transposed=True
            )

        # Gate Proj
        fc1_g, weight1_grad = fp8_linear_backward(
            ln_x_t, ln_s, silu_g, silu_gs, silu_g_t, weight1_t, weight1_s, group_size
        )
        # Up Proj
        fc2_g, weight2_grad = fp8_linear_backward(
            ln_x_t, ln_s, mul_g2, mul_gs2, mul_g2_t, weight2_t, weight2_s, group_size
        )
        fc_g = fc1_g + fc2_g

        # LayerNorm
        in_g, rms_weight_grad = fp8_rmsnorm_backward(in_x, in_s, fc_g, rms_weight, rstd, group_size, num_warps)

        # Add the gradients together
        re_g, (in_g, in_sg, in_sg_g16) = fp8_add_Ifp_Ifp_Ofp_Opt(
            fp_grad, in_g, group_size, fwobits["babit"], stochastic=False
        )
        in_g = in_g.view(torch.float8_e4m3fn)

        return (
            re_g,
            in_g,
            in_sg_g16,
            weight1_grad,
            None,
            None,
            None,
            weight2_grad,
            None,
            None,
            None,
            weight3_grad,
            None,
            None,
            None,
            rms_weight_grad,
            None,
            None,
            None,
            None,
            None,
            None,
        )
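

# --- Illustrative sketch (not part of the original module) ---
# Numerically, the fused FP8 MLP path above computes the standard SwiGLU MLP
# used by Qwen2, down_proj(silu(gate_proj(x)) * up_proj(x)), with the RMSNorm,
# linears, silu, and multiply each running on quantized tensors plus per-group
# scales. A plain-precision reference of the MLP body (invented name):
def _swiglu_mlp_reference(x, gate_proj, up_proj, down_proj):
    """Unquantized reference of the MLP body handled by the FP8 kernels above."""
    return down_proj(torch.nn.functional.silu(gate_proj(x)) * up_proj(x))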
class FP8ActivationQwen2AttentionWithoutLinear(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".
    """

    def __init__(self, config: Qwen2Config, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will "
                "lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = True
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.rotary_emb = Qwen2RotaryEmbedding(config=self.config)
    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
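

# --- Illustrative sketch (not part of the original module) ---
# `repeat_kv` implements grouped-query attention: each of the
# num_key_value_heads K/V heads is shared by num_heads // num_key_value_heads
# query heads. A shape-level demo of the eager path above (CPU, toy sizes):
def _gqa_shapes_demo():
    bsz, q_len, num_heads, num_kv_heads, head_dim = 1, 4, 8, 2, 16
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_kv_heads, q_len, head_dim)
    k = repeat_kv(k, num_heads // num_kv_heads)  # -> (1, 8, 4, 16)
    attn = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(head_dim)
    assert attn.shape == (bsz, num_heads, q_len, q_len)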
class FP8ActivationQwen2FlashAttention2WithoutLinear(FP8ActivationQwen2AttentionWithoutLinear):
    """
    Qwen2 flash attention module, following the Qwen2 attention module. This module inherits from `Qwen2Attention`
    as the weights of the module stay untouched. The only required change would be on the forward pass
    where it needs to correctly call the public API of flash attention and deal with padding tokens
    in case the input contains any of them. Additionally, for sliding window attention, we apply SWA only to the bottom
    config.max_window_layers layers.
    """

    # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1.
        # flash_attn<2.1 generates a top-left aligned causal mask, while what is needed here is bottom-right alignment, which was made the default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0.
        # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left).
        self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10()

    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ):
        bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # Activate slicing cache only if the config has a `sliding_window` attribute
            cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
            kv_seq_len = key_states.shape[-2] + cache_position[0]
            if (
                getattr(self.config, "sliding_window", None) is not None
                and kv_seq_len > self.config.sliding_window
                and cache_has_contents
            ):
                slicing_tokens = 1 - self.config.sliding_window

                past_key = past_key_value[self.layer_idx][0]
                past_value = past_key_value[self.layer_idx][1]

                past_key = past_key[:, :, slicing_tokens:, :].contiguous()
                past_value = past_value[:, :, slicing_tokens:, :].contiguous()

                if past_key.shape[-2] != self.config.sliding_window - 1:
                    raise ValueError(
                        f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
                        f" {past_key.shape}"
                    )

                if attention_mask is not None:
                    attention_mask = attention_mask[:, slicing_tokens:]
                    attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)

            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        dropout_rate = 0.0 if not self.training else self.attention_dropout

        # In PEFT, usually we cast the layer norms in float32 for training stability reasons,
        # therefore the input hidden states get silently casted in float32. Hence, we need
        # to cast them back in float16 just to be sure everything works as expected.
        input_dtype = query_states.dtype
        if input_dtype == torch.float32:
            if torch.is_autocast_enabled():
                target_dtype = torch.get_autocast_gpu_dtype()
            # Handle the case where the model is quantized
            elif hasattr(self.config, "_pre_quantization_dtype"):
                target_dtype = self.config._pre_quantization_dtype
            else:
                target_dtype = self.q_proj.weight.dtype

            logger.warning_once(
                f"The input hidden states seem to be silently casted in float32, this might be related to"
                f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in"
                f" {target_dtype}."
            )

            query_states = query_states.to(target_dtype)
            key_states = key_states.to(target_dtype)
            value_states = value_states.to(target_dtype)

        # Reshape to the expected shape for Flash Attention
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        if (
            self.config.use_sliding_window
            and getattr(self.config, "sliding_window", None) is not None
            and self.layer_idx >= self.config.max_window_layers
        ):
            sliding_window = self.config.sliding_window
        else:
            sliding_window = None

        attn_output = _flash_attention_forward(
            query_states,
            key_states,
            value_states,
            attention_mask,
            q_len,
            position_ids=position_ids,
            dropout=dropout_rate,
            sliding_window=sliding_window,
            is_causal=self.is_causal,
            use_top_left_mask=self._flash_attn_uses_top_left_mask,
        )

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous()

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
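

# --- Illustrative sketch (not part of the original module) ---
# The sliding-window branch above trims the cached K/V to the most recent
# `sliding_window - 1` positions via a negative-start slice
# (`slicing_tokens = 1 - sliding_window`). The slice semantics in isolation
# (toy tensor, invented names):
def _sliding_window_slice_demo():
    sliding_window = 4
    slicing_tokens = 1 - sliding_window  # == -3
    past_key = torch.arange(10).view(1, 1, 10, 1)  # (bsz, heads, seq, head_dim)
    kept = past_key[:, :, slicing_tokens:, :]  # keeps the last 3 positions
    assert kept.shape[-2] == sliding_window - 1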
class FP8ActivationQwen2SdpaAttentionWithoutLinear(FP8ActivationQwen2AttentionWithoutLinear):
    """
    Qwen2 attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `Qwen2Attention` as the weights of the module stay untouched. The only changes are on the forward pass to adapt to
    the SDPA API.
    """

    # Adapted from Qwen2Attention.forward
    def forward(
        self,
        query_states: torch.Tensor,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "Qwen2Model is using Qwen2SdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                query_states=query_states,
                key_states=key_states,
                value_states=value_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        bsz, q_len, _ = query_states.size()

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        causal_mask = attention_mask
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
        # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
        # The q_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create a causal mask in case q_len == 1.
        is_causal = True if causal_mask is None and q_len > 1 else False

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=causal_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=is_causal,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        return attn_output, None, past_key_value
FP8LINEARQWEN2_ATTENTION_CLASSES = {
    "eager": FP8ActivationQwen2AttentionWithoutLinear,
    "flash_attention_2": FP8ActivationQwen2FlashAttention2WithoutLinear,
    "sdpa": FP8ActivationQwen2SdpaAttentionWithoutLinear,
}
class FP8ActivationQwen2DecoderLayer(nn.Module):
    def __init__(self, config: FP8ActivationQwen2Config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )
        self.self_attn = FP8LINEARQWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)

        self.qargs = QuantizationConfig(**config.coat_fp8_args)
        self.BeforeAttention = FP8ActivationQwen2BeforeAttentionResidual(config, self.qargs, layer_idx)
        self.AfterAttention = FP8ActivationQwen2AfterAttentionResidual(config, self.qargs, layer_idx)
        self.MLPResidual = FP8ActivationQwen2MLPResidual(config, self.qargs, layer_idx, self.hidden_size)

        self.input_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        quant_hidden_states: torch.Tensor,
        scale_hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, sequence_length)` where padding elements are indicated by 0.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
            cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
                Indices depicting the position of the input sequence tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`, *optional*):
                Tuple containing the cosine and sine positional embeddings of shape `(batch_size, seq_len, head_dim)`,
                with `head_dim` being the embedding dimension of each attention head.
            kwargs (`dict`, *optional*):
                Arbitrary kwargs to be ignored, used for FSDP and other methods that inject code
                into the model
        """
| time_bench = os.getenv("TIME_BENCH") | |
| if time_bench: | |
| start_1 = torch.cuda.Event(enable_timing=True) | |
| start_1.record() | |
| # Coat: The residual, LayerNorm, and the Q/K/V Projection Linear Layer | |
| residual, query_states, key_states, value_states = self.BeforeAttention( | |
| hidden_states, quant_hidden_states, scale_hidden_states, self.input_layernorm.weight | |
| ) | |
| if time_bench: | |
| end_1 = torch.cuda.Event(enable_timing=True) | |
| end_1.record() | |
| start_2 = torch.cuda.Event(enable_timing=True) | |
| start_2.record() | |
| # Self Attention without any linear layer | |
| hidden_states, self_attn_weights, present_key_value = self.self_attn( | |
| query_states=query_states, | |
| key_states=key_states, | |
| value_states=value_states, | |
| attention_mask=attention_mask, | |
| position_ids=position_ids, | |
| past_key_value=past_key_value, | |
| output_attentions=output_attentions, | |
| use_cache=use_cache, | |
| cache_position=cache_position, | |
| position_embeddings=position_embeddings, | |
| **kwargs, | |
| ) | |
| if time_bench: | |
| end_2 = torch.cuda.Event(enable_timing=True) | |
| end_2.record() | |
| start_3 = torch.cuda.Event(enable_timing=True) | |
| start_3.record() | |
| # Coat: The Output Projection Linear Layer and Residual | |
| hidden_states, quant_hidden_states, scale_hidden_states = self.AfterAttention(residual, hidden_states) | |
| if time_bench: | |
| end_3 = torch.cuda.Event(enable_timing=True) | |
| end_3.record() | |
| start_4 = torch.cuda.Event(enable_timing=True) | |
| start_4.record() | |
| # Residual Connection, LayerNorm, and the whole MLP module | |
| hidden_states, quant_hidden_states, scale_hidden_states = self.MLPResidual( | |
| hidden_states, quant_hidden_states, scale_hidden_states, self.post_attention_layernorm.weight | |
| ) | |
| if time_bench: | |
| end_4 = torch.cuda.Event(enable_timing=True) | |
| end_4.record() | |
| torch.cuda.synchronize() | |
| if int(os.environ.get("LOCAL_RANK")) == 0: | |
| print( | |
| f"[Forward] Part 1: {start_1.elapsed_time(end_1):.6f} ms | " | |
| f" Part 2: {start_2.elapsed_time(end_2):.6f} ms | " | |
| f" Part 3: {start_3.elapsed_time(end_3):.6f} ms | " | |
| f" Part 4: {start_4.elapsed_time(end_4):.6f} ms | " | |
| f" Input shape: {hidden_states.shape}" | |
| ) | |
| outputs = ((hidden_states, quant_hidden_states, scale_hidden_states),) | |
| # if int(os.environ.get("LOCAL_RANK")) == 0: | |
| # import IPython | |
| # IPython.embed() | |
| # else: | |
| # import time | |
| # time.sleep(1000) | |
| # if output_attentions: | |
| # outputs += (self_attn_weights,) | |
| # if use_cache: | |
| # outputs += (present_key_value,) | |
| return outputs | |
class FP8ActivationQwen2PreTrainedModel(Qwen2PreTrainedModel):
    config_class = FP8ActivationQwen2Config
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["FP8ActivationQwen2DecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class FP8ActivationQwen2Model(FP8ActivationQwen2PreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2DecoderLayer`]

    Args:
        config: Qwen2Config
    """

    def __init__(self, config: FP8ActivationQwen2Config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [FP8ActivationQwen2DecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = Qwen2RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = Qwen2RotaryEmbedding(config=config)

        # Quantize
        self.qargs = QuantizationConfig(**config.coat_fp8_args)
        self.quantize_input_before_block = Coat_quantize_bgn(self.qargs)
        self.quantize_output_after_block = Coat_quantize_end(self.qargs)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

        self.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
| def get_input_embeddings(self): | |
| return self.embed_tokens | |
| def set_input_embeddings(self, value): | |
| self.embed_tokens = value | |
| def forward( | |
| self, | |
| input_ids: torch.LongTensor = None, | |
| attention_mask: Optional[torch.Tensor] = None, | |
| position_ids: Optional[torch.LongTensor] = None, | |
| past_key_values: Optional[List[torch.FloatTensor]] = None, | |
| inputs_embeds: Optional[torch.FloatTensor] = None, | |
| use_cache: Optional[bool] = None, | |
| output_attentions: Optional[bool] = None, | |
| output_hidden_states: Optional[bool] = None, | |
| return_dict: Optional[bool] = None, | |
| cache_position: Optional[torch.LongTensor] = None, | |
| ) -> Union[Tuple, BaseModelOutputWithPast]: | |
| output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions | |
| output_hidden_states = ( | |
| output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states | |
| ) | |
| use_cache = use_cache if use_cache is not None else self.config.use_cache | |
| return_dict = return_dict if return_dict is not None else self.config.use_return_dict | |
| if (input_ids is None) ^ (inputs_embeds is not None): | |
| raise ValueError( | |
| "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one" | |
| ) | |
| if self.gradient_checkpointing and self.training: | |
| if use_cache: | |
| logger.warning_once( | |
| "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." | |
| ) | |
| use_cache = False | |
| # kept for BC (non `Cache` `past_key_values` inputs) | |
| return_legacy_cache = False | |
| if use_cache and not isinstance(past_key_values, Cache): | |
| return_legacy_cache = True | |
| if past_key_values is None: | |
| past_key_values = DynamicCache() | |
| else: | |
| past_key_values = DynamicCache.from_legacy_cache(past_key_values) | |
| logger.warning_once( | |
| "We detected that you are passing `past_key_values` as a tuple of tuples. This is deprecated and " | |
| "will be removed in v4.47. Please convert your cache or use an appropriate `Cache` class " | |
| "(https://huggingface.co/docs/transformers/kv_cache#legacy-cache-format)" | |
| ) | |
| if inputs_embeds is None: | |
| inputs_embeds = self.embed_tokens(input_ids) | |
| if cache_position is None: | |
| past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 | |
| cache_position = torch.arange( | |
| past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device | |
| ) | |
| if position_ids is None: | |
| position_ids = cache_position.unsqueeze(0) | |
| causal_mask = self._update_causal_mask( | |
| attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions | |
| ) | |
| hidden_states = inputs_embeds | |
| # create position embeddings to be shared across the decoder layers | |
| position_embeddings = self.rotary_emb(hidden_states, position_ids) | |
| # decoder layers | |
| all_hidden_states = () if output_hidden_states else None | |
| all_self_attns = () if output_attentions else None | |
| next_decoder_cache = None | |
| # Prepare the input triplet for the COAT FP8 decoder layers | |
| hidden_states, quant_hidden_states, scale_hidden_states = self.quantize_input_before_block(hidden_states) | |
| for layer_idx, decoder_layer in enumerate(self.layers): | |
| if output_hidden_states: | |
| all_hidden_states += (hidden_states,) | |
| # NOTE: We explicitly force the LLM not to use gradient checkpointing; the | |
| # checkpointed path below is kept for reference only. | |
| # if self.gradient_checkpointing and self.training: | |
| # layer_outputs = self._gradient_checkpointing_func( | |
| # decoder_layer.__call__, | |
| # hidden_states, | |
| # quant_hidden_states, | |
| # scale_hidden_states, | |
| # causal_mask, | |
| # position_ids, | |
| # past_key_values, | |
| # output_attentions, | |
| # use_cache, | |
| # cache_position, | |
| # position_embeddings, | |
| # ) | |
| # else: | |
| layer_outputs = decoder_layer( | |
| hidden_states, | |
| quant_hidden_states, | |
| scale_hidden_states, | |
| attention_mask=causal_mask, | |
| position_ids=position_ids, | |
| past_key_value=past_key_values, | |
| output_attentions=output_attentions, | |
| use_cache=use_cache, | |
| cache_position=cache_position, | |
| position_embeddings=position_embeddings, | |
| ) | |
| hidden_states, quant_hidden_states, scale_hidden_states = layer_outputs[0] | |
| if use_cache: | |
| # The FP8 decoder layer returns only the hidden-state triplet; with a | |
| # `Cache` object the key/value states are updated in place, so the cache | |
| # itself already holds the new entries. | |
| next_decoder_cache = past_key_values | |
| # Attention weights are likewise not returned by the FP8 decoder layer, so | |
| # there is nothing to collect into `all_self_attns` here. | |
| # Summarize the output of the Decoder Layer | |
| hidden_states = self.quantize_output_after_block(hidden_states, quant_hidden_states, scale_hidden_states) | |
| hidden_states = self.norm(hidden_states) | |
| # add hidden states from the last decoder layer | |
| if output_hidden_states: | |
| all_hidden_states += (hidden_states,) | |
| next_cache = next_decoder_cache if use_cache else None | |
| if return_legacy_cache: | |
| next_cache = next_cache.to_legacy_cache() | |
| if not return_dict: | |
| return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) | |
| return BaseModelOutputWithPast( | |
| last_hidden_state=hidden_states, | |
| past_key_values=next_cache, | |
| hidden_states=all_hidden_states, | |
| attentions=all_self_attns, | |
| ) | |
| # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask | |
| _update_causal_mask = Qwen2Model._update_causal_mask | |
| class FP8ActivationQwen2ForCausalLM(FP8ActivationQwen2PreTrainedModel): | |
| _tied_weights_keys = ["lm_head.weight"] | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.model = FP8ActivationQwen2Model(config) | |
| self.vocab_size = config.vocab_size | |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) | |
| # Initialize weights and apply final processing | |
| self.post_init() | |
| def get_input_embeddings(self): | |
| return self.model.embed_tokens | |
| def set_input_embeddings(self, value): | |
| self.model.embed_tokens = value | |
| def get_output_embeddings(self): | |
| return self.lm_head | |
| def set_output_embeddings(self, new_embeddings): | |
| self.lm_head = new_embeddings | |
| def set_decoder(self, decoder): | |
| self.model = decoder | |
| def get_decoder(self): | |
| return self.model | |
| @property | |
| def loss_function(self): | |
| # Use Liger's fused causal-LM loss (`LigerForCausalLMLoss` is assumed to be | |
| # imported elsewhere in this module). Exposed as a property so the inherited | |
| # `Qwen2ForCausalLM.forward` can call `self.loss_function(...)` directly. | |
| return LigerForCausalLMLoss | |
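| # The token-level forward is inherited unchanged from `Qwen2ForCausalLM`; only | |
| # the backbone (`self.model`) and the loss function differ. | |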
| forward = Qwen2ForCausalLM.forward | |
| AutoConfig.register("fp8activation_qwen2", FP8ActivationQwen2Config) | |
| AutoModel.register(FP8ActivationQwen2Config, FP8ActivationQwen2Model) | |
| AutoModelForCausalLM.register(FP8ActivationQwen2Config, FP8ActivationQwen2ForCausalLM) | |
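| # Usage sketch (illustrative; the config path below is hypothetical): once the | |
| # registrations above have run, the Auto* factories resolve the custom classes. | |
| # | |
| #   config = AutoConfig.from_pretrained("path/to/fp8activation_qwen2") | |
| #   model = AutoModelForCausalLM.from_config(config) | |
| #   assert isinstance(model, FP8ActivationQwen2ForCausalLM) | |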
| def make_state_dict_compatible(state_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: | |
| """Remap a stock Qwen2 checkpoint so its attention and MLP projection weights | |
| match this module's fused BeforeAttention / AfterAttention / MLPResidual blocks.""" | |
| compatible_state_dict = {} | |
| for key, value in state_dict.items(): | |
| if fnmatch(key, "*self_attn.q_proj*"): | |
| new_key = key.replace("self_attn.q_proj", "BeforeAttention.q_proj") | |
| elif fnmatch(key, "*self_attn.k_proj*"): | |
| new_key = key.replace("self_attn.k_proj", "BeforeAttention.k_proj") | |
| elif fnmatch(key, "*self_attn.v_proj*"): | |
| new_key = key.replace("self_attn.v_proj", "BeforeAttention.v_proj") | |
| elif fnmatch(key, "*self_attn.o_proj*"): | |
| new_key = key.replace("self_attn.o_proj", "AfterAttention.o_proj") | |
| elif fnmatch(key, "*mlp.gate_proj*"): | |
| new_key = key.replace("mlp.gate_proj", "MLPResidual.gate_proj") | |
| elif fnmatch(key, "*mlp.up_proj*"): | |
| new_key = key.replace("mlp.up_proj", "MLPResidual.up_proj") | |
| elif fnmatch(key, "*mlp.down_proj*"): | |
| new_key = key.replace("mlp.down_proj", "MLPResidual.down_proj") | |
| else: | |
| new_key = key | |
| compatible_state_dict[new_key] = value | |
| return compatible_state_dict | |
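| # Example (sketch): porting a stock Qwen2 checkpoint into the FP8 layout. The | |
| # checkpoint name is a stand-in for any Qwen2 model, and `config` is assumed to | |
| # be an FP8ActivationQwen2Config matching the base model's sizes; `strict=False` | |
| # tolerates FP8 bookkeeping buffers with no counterpart in the source weights. | |
| # | |
| #   base_state = Qwen2ForCausalLM.from_pretrained("Qwen/Qwen2-0.5B").state_dict() | |
| #   fp8_model = FP8ActivationQwen2ForCausalLM(config) | |
| #   fp8_model.load_state_dict(make_state_dict_compatible(base_state), strict=False) | |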