import os import math import datetime from dataclasses import dataclass from typing import Tuple, Optional, Literal import torch from torch import nn import torch.nn.functional as F import torch.distributed as dist from safetensors.torch import load_model from kernel import act_quant, weight_dequant, fp8_gemm, int64_bmm_broadcast, \ complex_int64_mul_broadcast, einsum_bshd_hdc_bshc, einsum_bshc_btc_bsht, softmax_init_q21, softmax_q21, einsum_bsht_btc_bshc, einsum_bshc_hdc_bshd, \ silu_init_q25, silu_q25, sigmoid_q25, softmax_init_q19, softmax_q19, silu_init_q23, silu_q23, sigmoid_q23, RMS_Norm_int64 world_size = 1 rank = 0 block_size = 128 gemm_impl: Literal["bf16", "fp8"] = "bf16" attn_impl: Literal["naive", "absorb"] = "absorb" snark = True zkDataDir = '../zkdata' @dataclass class ModelArgs: """ Data class for defining model arguments and hyperparameters. Attributes: max_batch_size (int): Maximum batch size. max_seq_len (int): Maximum sequence length. dtype (Literal["bf16", "fp8"]): Data type for computations. vocab_size (int): Vocabulary size. dim (int): Model dimension. inter_dim (int): Intermediate dimension for MLP layers. moe_inter_dim (int): Intermediate dimension for MoE layers. n_layers (int): Number of transformer layers. n_dense_layers (int): Number of dense layers in the model. n_heads (int): Number of attention heads. n_routed_experts (int): Number of routed experts for MoE layers. n_shared_experts (int): Number of shared experts for MoE layers. n_activated_experts (int): Number of activated experts in MoE layers. n_expert_groups (int): Number of expert groups. n_limited_groups (int): Number of limited groups for MoE routing. score_func (Literal["softmax", "sigmoid"]): Scoring function for MoE routing. route_scale (float): Scaling factor for routing scores. q_lora_rank (int): LoRA rank for query projections. kv_lora_rank (int): LoRA rank for key-value projections. qk_nope_head_dim (int): Dimension for query-key projections without positional embeddings. qk_rope_head_dim (int): Dimension for query-key projections with rotary embeddings. v_head_dim (int): Dimension for value projections. original_seq_len (int): Original sequence length. rope_theta (float): Base for rotary positional encoding. rope_factor (float): Scaling factor for extended sequence lengths. beta_fast (int): Fast beta correction factor. beta_slow (int): Slow beta correction factor. mscale (float): Scaling factor for extended attention. """ max_batch_size: int = 8 max_seq_len: int = 4096 * 4 dtype: Literal["bf16", "fp8"] = "bf16" vocab_size: int = 102400 dim: int = 2048 inter_dim: int = 10944 moe_inter_dim: int = 1408 n_layers: int = 27 n_dense_layers: int = 1 n_heads: int = 16 # moe n_routed_experts: int = 64 n_shared_experts: int = 2 n_activated_experts: int = 6 n_expert_groups: int = 1 n_limited_groups: int = 1 score_func: Literal["softmax", "sigmoid"] = "softmax" route_scale: float = 1. # mla q_lora_rank: int = 0 kv_lora_rank: int = 512 qk_nope_head_dim: int = 128 qk_rope_head_dim: int = 64 v_head_dim: int = 128 # yarn original_seq_len: int = 4096 rope_theta: float = 10000.0 rope_factor: float = 40 beta_fast: int = 32 beta_slow: int = 1 mscale: float = 1. def saveTensor(fileName, t): with open(fileName, "w", encoding="utf-8") as f: t = t.detach() if t.device.type != "cpu": t = t.cpu() t = t.contiguous() with open(fileName, "wb") as f: # .numpy() -> bytes(C-order) f.write(t.numpy().tobytes(order="C")) class ParallelEmbedding(nn.Module): """ Embedding layer with parallelism support across distributed processes. Args: vocab_size (int): Vocabulary size. dim (int): Embedding dimension. """ def __init__(self, vocab_size: int, dim: int): super().__init__() self.vocab_size = vocab_size self.dim = dim assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size # weight 的 shape: [129280, 7168] self.register_buffer("weight", torch.empty(self.part_vocab_size, self.dim, dtype=torch.int64)) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for parallel embedding layer. Args: x (torch.Tensor): Input tensor containing token indices. Returns: torch.Tensor: Embedded representations. Raises: ValueError: If `world_size` is not defined. """ # print('aaab ' + str(self.weight[0][0].type())) if world_size > 1: # 找出 x 中 的值不在 [vocab_start_idx, vocab_end_idx) 范围内的下标 mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) # x 中所有的值都减去 vocab_start_idx x = x - self.vocab_start_idx # 之前找出的标记为 mask 下标的值设置为0 x[mask] = 0 y = F.embedding(x, self.weight) if world_size > 1: y[mask] = 0 dist.all_reduce(y) # print(f'ParallelEmbedding x: {x}', flush=True) return y def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """ Applies a linear transformation to the incoming data: y = xA^T + b. This function supports specialized implementations based on quantization and tensor formats. Args: x (torch.Tensor): The input tensor. weight (torch.Tensor): The weight tensor. It may be quantized and requires dequantization for certain cases. bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None. Returns: torch.Tensor: The result of the linear transformation, which may involve quantization-aware computations depending on the input parameters. Notes: - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version is used for computation. - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied. - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. """ element_size = weight.element_size() typ = weight.type() # print(f'linear weight element_size {element_size}, type: {typ}', flush=True) if weight.element_size() > 1: # print('linear weight.element_size > 1, element_size=' + str(weight.element_size()), flush=True) return F.linear(x, weight, bias) elif gemm_impl == "bf16": weight = weight_dequant(weight, weight.scale) return F.linear(x, weight, bias) else: # print('linear act_quant', flush=True) x, scale = act_quant(x, block_size) y = fp8_gemm(x, scale, weight, weight.scale) if bias is not None: y += bias return y def linear_int(x: torch.Tensor, weight: torch.Tensor, x_rescale, weight_rescale, res_rescale, bias: Optional[torch.Tensor] = None) -> tuple[torch.Tensor]: if weight.element_size() > 1: (q, r) = int64_bmm_broadcast(x, weight, x_rescale, weight_rescale, res_rescale) return (q, r) elif gemm_impl == "bf16": weight = weight_dequant(weight, weight.scale) return (F.linear(x, weight, bias), torch.tensor(0, dtype=torch.int64)) else: print('linear act_quant', flush=True) x, scale = act_quant(x, block_size) y = fp8_gemm(x, scale, weight, weight.scale) if bias is not None: y += bias return (y, torch.tensor(0, dtype=torch.int64)) class Linear_int(nn.Module): """ Custom linear layer with support for quantized weights and optional bias. Args: in_features (int): Number of input features. out_features (int): Number of output features. bias (bool): Whether to include a bias term. Defaults to False. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ dtype = torch.int64 def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, res_rescale, dtype, bias: bool = False): super().__init__() self.layer_id = layer_id self.in_features = in_features self.out_features = out_features self.x_rescale = x_rescale self.weight_rescale = weight_rescale self.res_rescale = res_rescale self.register_buffer("weight", torch.empty(out_features, in_features, dtype=dtype)) if bias: self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> tuple[torch.Tensor]: q, r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, self.res_rescale, self.bias) return q, r class Linear_rescale_int(nn.Module): """ Custom linear layer with support for quantized weights and optional bias. Args: in_features (int): Number of input features. out_features (int): Number of output features. bias (bool): Whether to include a bias term. Defaults to False. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ dtype = torch.int64 def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, dtype, bias: bool = False): super().__init__() self.layer_id = layer_id self.in_features = in_features self.out_features = out_features self.x_rescale = x_rescale self.weight_rescale = weight_rescale self.register_buffer("weight", torch.empty(out_features, in_features, dtype=dtype)) self.register_buffer("scale", torch.tensor(0, dtype=torch.int32)) if bias: self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: rescale = self.scale.item() y, _r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, rescale, self.bias) return y class Linear(nn.Module): """ Custom linear layer with support for quantized weights and optional bias. Args: in_features (int): Number of input features. out_features (int): Number of output features. bias (bool): Whether to include a bias term. Defaults to False. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ dtype = torch.bfloat16 def __init__(self, layer_id, in_features: int, out_features: int, bias: bool = False, dtype = None): super().__init__() self.layer_id = layer_id self.in_features = in_features self.out_features = out_features self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)) # print('Linear.weight.element_size: ' + str(self.weight.element_size())) # nn.Parameter.element_size() 返回的是 每个元素在内存中占用的字节数 # torch.float32 -> 4 字节 # torch.float64 -> 8 字节 # torch.int64 -> 8 字节 # torch.bfloat16 -> 2 字节 # torch.float8_e4m3fn -> 1 字节 if self.weight.element_size() == 1: scale_out_features = (out_features + block_size - 1) // block_size scale_in_features = (in_features + block_size - 1) // block_size self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)) else: self.register_parameter("scale", None) if bias: self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter("bias", None) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the custom linear layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Transformed tensor after linear computation. """ return linear(x, self.weight, self.bias) class ColumnParallelLinear(Linear): """ Linear layer with column parallelism, splitting output features across distributed processes. Args: in_features (int): Number of input features. out_features (int): Total number of output features. bias (bool): Whether to include a bias term. Defaults to False. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, layer_id, in_features: int, out_features: int, bias: bool = False, dtype = None): assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(layer_id, in_features, self.part_out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for column parallel linear layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Transformed tensor with column-parallel computation. """ y = linear(x, self.weight, self.bias) return y class ColumnParallelLinear_int(Linear_int): def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, res_rescale, dtype, bias: bool = False): assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(layer_id, in_features, self.part_out_features, x_rescale, weight_rescale, res_rescale, dtype, bias) def forward(self, x: torch.Tensor) -> torch.Tensor: y, _r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, self.res_rescale, self.bias) return y class ColumnParallelLinear_rescale_int(Linear_int): def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, dtype, bias: bool = False): assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(layer_id, in_features, self.part_out_features, x_rescale, weight_rescale, 1, dtype, bias) self.register_buffer("scale", torch.tensor(0, dtype=torch.int32)) # self.res_rescale = self.scale def forward(self, x: torch.Tensor) -> torch.Tensor: rescale = self.scale.item() y, _r = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, rescale, self.bias) return y class RowParallelLinear(Linear): """ Linear layer with row parallelism, splitting input features across distributed processes. Args: in_features (int): Total number of input features. out_features (int): Number of output features. bias (bool): Whether to include a bias term. Defaults to False. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, layer_id, in_features: int, out_features: int, bias: bool = False, dtype = None): assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" self.part_in_features = in_features // world_size super().__init__(layer_id, self.part_in_features, out_features, bias, dtype) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for row parallel linear layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Transformed tensor with row-parallel computation. """ y = linear(x, self.weight) if world_size > 1: dist.all_reduce(y) if self.bias is not None: y += self.bias return y class RowParallelLinear_rescale_int(Linear_int): """ Linear layer with row parallelism, splitting input features across distributed processes. Args: in_features (int): Total number of input features. out_features (int): Number of output features. bias (bool): Whether to include a bias term. Defaults to False. dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, layer_id, in_features: int, out_features: int, x_rescale, weight_rescale, res_rescale, dtype, bias: bool = False): assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" self.part_in_features = in_features // world_size super().__init__(layer_id, self.part_in_features, out_features, x_rescale, weight_rescale, res_rescale, dtype, bias) self.register_buffer("scale", torch.tensor(0, dtype=torch.int32)) self.res_rescale = self.scale # useless def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for row parallel linear layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Transformed tensor with row-parallel computation. """ # rescale = 2 ** self.scale.item() rescale = self.scale.item() # print(f'RowParallelLinear_rescale_int forward scale: {self.scale} ' + str(rescale), flush=True) y, _ = linear_int(x, self.weight, self.x_rescale, self.weight_rescale, rescale, self.bias) if world_size > 1: dist.all_reduce(y) if self.bias is not None: y += self.bias return y class RMSNorm(nn.Module): """ Root Mean Square Layer Normalization (RMSNorm). Args: dim (int): Dimension of the input tensor. eps (float): Epsilon value for numerical stability. Defaults to 1e-6. """ def __init__(self, dim: int, eps: float = 1e-6): super().__init__() self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x: torch.Tensor): """ Forward pass for RMSNorm. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Normalized tensor with the same shape as input. """ return F.rms_norm(x, (self.dim,), self.weight, self.eps) class RMSNorm_int(nn.Module): def __init__(self, dim: int, dtype, eps: float = 1e-6): super().__init__() self.dim = dim self.eps = eps self.register_buffer( "weight", torch.ones(dim, dtype=dtype)) def forward(self, x: torch.Tensor): # x 的 scale 为 2 ** 31 # weight的scale 为 2 ** 15, 范围为 2^7 - 2^14 # rms 的 scale 为 2 ** 28 # 返回的结果 scale 为 2 ** 16,因为中间计算的时候 除以了 (1 << 15),44 + 15 - 28 - 15 = 16 (c, rms) = RMS_Norm_int64(x[0], self.weight, 1, self.dim) return (c[None, :], rms) def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor: """ Precomputes frequency-based complex exponential values for rotary positional embeddings. Args: args (ModelArgs): Model arguments containing positional embedding parameters. Returns: torch.Tensor: Precomputed complex exponential values for positional embeddings. """ # dim = 64 dim = args.qk_rope_head_dim seqlen = args.max_seq_len beta_fast = args.beta_fast beta_slow = args.beta_slow base = args.rope_theta factor = args.rope_factor def find_correction_dim(num_rotations, dim, base, max_seq_len): """ Computes the correction dimension for a given number of rotations in the rotary positional embedding. Args: num_rotations (float): Number of rotations to compute the correction for. dim (int): Dimensionality of the embedding space. base (float): Base value for the exponential computation. max_seq_len (int): Maximum sequence length. Returns: float: The correction dimension based on the input parameters. """ return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base)) def find_correction_range(low_rot, high_rot, dim, base, max_seq_len): """ Computes the range of correction dimensions for rotary positional embeddings. Args: low_rot (float): Lower bound for the number of rotations. high_rot (float): Upper bound for the number of rotations. dim (int): Dimensionality of the embedding space. base (float): Base value for the exponential computation. max_seq_len (int): Maximum sequence length. Returns: Tuple[int, int]: The range of correction dimensions (low, high), clamped to valid indices. """ low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len)) high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len)) return max(low, 0), min(high, dim-1) def linear_ramp_factor(min, max, dim): """ Computes a linear ramp function used to smooth values between a minimum and maximum range. Args: min (float): Minimum value for the ramp function. max (float): Maximum value for the ramp function. dim (int): Dimensionality of the ramp tensor. Returns: torch.Tensor: A tensor of shape (dim,) with values linearly interpolated between 0 and 1, clamped to the range [0, 1]. """ if min == max: max += 0.001 linear_func = (torch.arange(dim, dtype=torch.float32) - min) / (max - min) ramp_func = torch.clamp(linear_func, 0, 1) return ramp_func # torch.arange(0, dim, 2, dtype=torch.float32) 的作用是: 生成从 0 开始、步长为 2、到 dim 之前(不含 dim)的一维张量,数据类型为 float32 # 1/10000^(2k/d_model) # freqs shape: 一维向量,长度为 dim /2 freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim)) # original_seq_len=4096 if seqlen > args.original_seq_len: low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len) smooth = 1 - linear_ramp_factor(low, high, dim // 2) freqs = freqs / factor * (1 - smooth) + freqs * smooth t = torch.arange(seqlen) # torch.outer 的作用是计算两个向量的 外积 (outer product),比如: # t = torch.tensor([1, 2, 3]) # shape = [3] # freqs = torch.tensor([10, 20]) # shape = [2] # out = torch.outer(t, freqs) # tensor([[10, 20], # [20, 40], # [30, 60]]) # freqs shape为 [seqlen, dim/2] freqs = torch.outer(t, freqs) # torch.polar(abs, angle) 的作用: 把 极坐标 (r, θ) 转换成 复数 (x + iy) 的函数 # freqs_cis_0 shape为 [seqlen, dim/2] freqs_cis_0 = torch.polar(torch.ones_like(freqs), freqs) # return freqs_cis_0 # 复数转换成实数, freqs_cis_1 shape为 [seqlen, dim] freqs_cis_1 = torch.view_as_real(freqs_cis_0) # freqs_cis = torch.empty_like(freqs_cis_1, dtype=torch.int64, device='cuda') # cols 为 2 * freqs_cis_1.shape[1] 是因为 复数的实部 和 虚部 # rescale 参数为 19 = 42 - 23, ex 部分加 +19,总的rescale为 2^42 freqs_cis = (freqs_cis_1 * (2 ** 42)).round().to(torch.int64) freqs_cis_abs = freqs_cis.abs() min1 = freqs_cis_abs.min() max1 = freqs_cis_abs.max() print(f'freqs_cis min {min1}, max: {max1}', flush=True) # print(f'freqs_cis: {freqs_cis}') # freqs_cis 的 rescale 为 2^42 return freqs_cis # x(q_pe) 的维度 [batch, seqLen, 128, 64] def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor: """ Applies rotary positional embeddings to the input tensor. Args: x (torch.Tensor): Input tensor with positional embeddings to be applied. freqs_cis (torch.Tensor): Precomputed complex exponential values for positional embeddings. Returns: torch.Tensor: Tensor with rotary embeddings applied. """ # if x.dtype == torch.int64: # x 的维度 变为 [batch, seqLen, 128, 32, 2] ### important!!! 调用 so lib库之前,必须确保内存连续 x = x.contiguous().view(*x.shape[:-1], -1, 2) # freqs_cis 的维度为 [1, seqLen, 1, 32, 2] freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-2), 2) # freqs_cis = freqs_cis.view(1, x.size(1), 1, x.size(-1)) # 4194304 = 1 << (64 - 42), 42是 rescale, int64 * int64 结果的高 64位 乘以 4194304 # 4398046511104 = 1 << 42 # print(x) # print(f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}') # y = complex_int64_mul_broadcast(x, freqs_cis, 4194304, 4398046511104) y = complex_int64_mul_broadcast(x, freqs_cis) y2 = y.flatten(3) return y2 def getBF16PrintStr(ele): v = int(ele.cpu().view(torch.uint16).item()) ex = v >> 7 & 0xFF r = '(1+' + str(v & 0x7F) + '/128)' rraw = v & 0x7F if v & 0x8000: vstr = '-' + r + '*2^' + str(ex - 127) else: vstr = r + '*2^' + str(ex - 127) return vstr class MLA(nn.Module): """ Multi-Headed Attention Layer (MLA). Attributes: dim (int): Dimensionality of the input features. n_heads (int): Number of attention heads. n_local_heads (int): Number of local attention heads for distributed systems. q_lora_rank (int): Rank for low-rank query projection. kv_lora_rank (int): Rank for low-rank key/value projection. qk_nope_head_dim (int): Dimensionality of non-positional query/key projections. qk_rope_head_dim (int): Dimensionality of rotary-positional query/key projections. qk_head_dim (int): Total dimensionality of query/key projections. v_head_dim (int): Dimensionality of value projections. softmax_scale (float): Scaling factor for softmax in attention computation. """ def __init__(self, layer_id, args: ModelArgs): super().__init__() # RowParallelLinear和ColumnParallelLinear是将Linear层按照行和列划分为多个子线性层并分配到各个设备上,每个设备维护一个子线性层, # 如线性层的shape为[in_features, out_features],RowParallelLinear的shape为[in_features/world_size, out_features], # ColumnParallelLinear的shape为[in_features,out_features/world_size],world_size是设备数 self.layer_id = layer_id # 7168 self.dim = args.dim # 128 self.n_heads = args.n_heads # 当前进程跑的header数目 self.n_local_heads = args.n_heads // world_size # query向下投影矩阵维度,默认为0表示不压缩,实际使用过程为 1536 self.q_lora_rank = args.q_lora_rank # key和value向下投影矩阵维度,实际使用过程为 512; self.kv_lora_rank = args.kv_lora_rank # query/key不包含位置信息的隐藏层维度, 实际使用过程为 128 self.qk_nope_head_dim = args.qk_nope_head_dim # query/key包含rope位置信息的隐藏层维度, 实际使用过程为 64 self.qk_rope_head_dim = args.qk_rope_head_dim # 192 self.qk_head_dim = args.qk_nope_head_dim + args.qk_rope_head_dim # value隐藏层维度, 实际使用过程为 128 self.v_head_dim = args.v_head_dim # query向下投影矩阵维度,默认为0表示不压缩,实际使用过程为 1536 if self.q_lora_rank == 0: self.wq = ColumnParallelLinear(layer_id, self.dim, self.n_heads * self.qk_head_dim) else: # query向下投影矩阵, shape [7168, 1536], Float8_e4m3fnTensor self.wq_a = Linear_int(layer_id, self.dim, self.q_lora_rank, 1, 1, 30, torch.int32) self.q_norm = RMSNorm_int(self.q_lora_rank, torch.int32) # query向上投影矩阵的列并行线性层, shape [1536, 24576(128 * 192)], Float8_e4m3fnTensor # self.wq_b = ColumnParallelLinear_int(layer_id, self.q_lora_rank, self.n_heads * self.qk_head_dim, 1, 1, (1 << 30), torch.int32) self.wq_b1 = ColumnParallelLinear_int(layer_id, self.q_lora_rank, self.n_heads * args.qk_nope_head_dim, 1, 1, 30, torch.int32) self.wq_b2 = ColumnParallelLinear_int(layer_id, self.q_lora_rank, self.n_heads * args.qk_rope_head_dim, 1, 1, 30, torch.int32) # key和value的向下投影矩阵, shape [576, 7168], Float8_e4m3fnTensor, kv_lora_rank=512, qk_rope_head_dim=64 # self.wkv_a = Linear_int(layer_id, self.dim, self.kv_lora_rank + self.qk_rope_head_dim, 1, 1, (1 << 29), torch.int32) self.wkv_a1 = Linear_int(layer_id, self.dim, self.kv_lora_rank, 1, 1, 29, torch.int32) self.wkv_a2 = Linear_int(layer_id, self.dim, self.qk_rope_head_dim, 1, 1, 29, torch.int32) # self.kv_norm = RMSNorm(self.kv_lora_rank) self.kv_norm = RMSNorm_int(self.kv_lora_rank, torch.int32) # key和value向上投影矩阵的列并行线性层, shape [32768, 512], Float8_e4m3fnTensor # kv_lora_rank=512, n_heads = 128, qk_nope_head_dim = 128, v_head_dim = 128 # self.wkv_b = ColumnParallelLinear(layer_id, self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)) self.wkv_b_1 = ColumnParallelLinear_rescale_int(layer_id, self.kv_lora_rank, self.n_heads * self.qk_nope_head_dim, 1, 1, torch.int32) self.wkv_b_2 = ColumnParallelLinear_rescale_int(layer_id, self.kv_lora_rank, self.n_heads * self.v_head_dim, 1, 1, torch.int32) # 输出投影行并行线性层, shape [7168, 16384], Float8_e4m3fnTensor self.wo = RowParallelLinear_rescale_int(layer_id, self.n_heads * self.v_head_dim, self.dim, 1, 1, 1, torch.int32) # softmax缩放系数, qk_head_dim = 192 # self.softmax_scale = self.qk_head_dim ** -0.5 # # max_seq_len = 4096 * 4, original_seq_len = 4096 # if args.max_seq_len > args.original_seq_len: # # mscale = 1.0, rope_factor = 40, math.log = ln 自然对数 # mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0 # self.softmax_scale = self.softmax_scale * mscale * mscale self.softmax_scale1 = 94 self.softmax_scale2 = 695 if attn_impl == "naive": self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False) self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False) else: # 缓存key和value向下投影表示 # self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False) # self.register_buffer("kv_cache", torch.zeros(1, args.max_seq_len, self.kv_lora_rank), persistent=False) self.register_buffer("kv_cache", torch.zeros(1, args.max_seq_len, self.kv_lora_rank, dtype=torch.int64), persistent=False) # 缓存key执行rope操作后的表示 # self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False) # self.register_buffer("pe_cache", torch.zeros(1, args.max_seq_len, self.qk_rope_head_dim), persistent=False) self.register_buffer("pe_cache", torch.zeros(1, args.max_seq_len, self.qk_rope_head_dim, dtype=torch.int64), persistent=False) # x shape [1, seqLen, 7168], x 的resacle 为 2^21 def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]): """ Forward pass for the Multi-Headed Attention Layer (MLA). Args: x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). start_pos (int): Starting position in the sequence for caching. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. Returns: torch.Tensor: Output tensor with the same shape as the input. """ # 从输入获取batch size和序列长度seqlen,并根据输入序列的起始位置计算输入序列的结束位置end_pos=start_pos+seqlen; bsz, seqlen, _ = x.size() end_pos = start_pos + seqlen # 获取query的投影表示:如果对query投影矩阵进行压缩(即q_lora_rank不为0),则将输入乘以query的向下投影矩阵wq_a,然后经过归一化层q_norm, # 再乘以向上投影矩阵wq_b,否则直接乘以原始投影矩阵wq;将其维度调整为[batchsize, n_local_threads, qk_head_dim]; if self.q_lora_rank == 0: q = self.wq(x) else: # query向下投影矩阵, shape [7168, 1536], Float8_e4m3fnTensor # x(也就是 attn_normed) 的 scale 为 2^21, wq_a weight 的scale 为 2^30, q_down 的 scale 为 2^21 q_down, q_down_rem = self.wq_a(x) # q_down = self.wq_a(x) if snark: dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}' os.makedirs(dirStr, exist_ok=True) saveTensor(f'{dirStr}/wq_a_x.bin', x.cpu()) saveTensor(f'{dirStr}/wq_a_w.bin', self.wq_a.weight.view(torch.uint32).cpu()) saveTensor(f'{dirStr}/wq_a_y.bin', q_down.cpu()) saveTensor(f'{dirStr}/q_norm_r.bin', q_down_rem.cpu()) # q_down = (q_down.detach().to(torch.float32) * (2 ** -23)).to(torch.bfloat16) # q_norm 的 rescale 为 2^19 (q_normed, rms) = self.q_norm(q_down) if snark: dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}' os.makedirs(dirStr, exist_ok=True) saveTensor(f'{dirStr}/q_norm_x.bin', q_down.cpu()) saveTensor(f'{dirStr}/q_norm_weight.bin', self.q_norm.weight.view(torch.uint32).cpu()) saveTensor(f'{dirStr}/q_norm_rms.bin', rms.cpu()) saveTensor(f'{dirStr}/q_norm_y.bin', q_normed.cpu()) # q 的 rescale 为 2^19 # q = self.wq_b(q_normed) q_nope = self.wq_b1(q_normed) q_pe = self.wq_b2(q_normed) # 在pytorch中view函数的作用为重构张量的维度 # q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) q_nope = q_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim) q_pe = q_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim) # 将query的投影表示按照最后一个维度拆分,前面qk_nope_head_dim维(128)作为query不包含位置信息的表示q_nope,后面qk_rope_head_dim维(64)添加rope位置信息 # (调用apply_rotary_emb函数,参考秀才经商:DeepSeek源码解析之RoPE)作为query包含位置信息的表示q_pe(即公式39); # q_nope 的维度[batch, seqLen, 128, 128], q_pe 的维度 [batch, seqLen, 128, 64] # q_nope, q_pe 的 rescale 为 2^19 # q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) # freqs_cis 的 rescale 为 2^42, 计算之后 q_pe 的 rescale 为 2^19 if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/q_pe_x.bin', q_pe.cpu()) saveTensor(f'{zkDataDir}/freqs_cis.bin', freqs_cis.cpu()) q_pe = apply_rotary_emb(q_pe, freqs_cis) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/q_pe_y.bin', self.q_norm.weight.view(torch.uint32).cpu()) # 获取key和value的联合表示kv(即公式41中的)和包含位置信息的key表示k_pe(即公式43中的):输入乘以向下投影矩阵wkv_a后,按照最后一个维度拆分, # 前面kv_lora_rank维作为key和value的联合表示,后面qk_rope_head_dim维添加rope位置信息(调用apply_rotary_emb)后得到包含rope位置信息的key表示; # x 的resacle 为 2^21, kv shape [batch, seqLen, 512], kv 的resacle 为 2^21 kv, kv_rem = self.wkv_a1(x) if snark: dirStr = f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}' os.makedirs(dirStr, exist_ok=True) saveTensor(f'{dirStr}/wkv_a1_x.bin', x.cpu()) saveTensor(f'{dirStr}/wkv_a1_w.bin', self.wkv_a1.weight.view(torch.uint32).cpu()) saveTensor(f'{dirStr}/wkv_a1_y.bin', kv.cpu()) saveTensor(f'{dirStr}/wkv_a1_r.bin', kv_rem.cpu()) k_pe, _ = self.wkv_a2(x) # print(f'k_pe 1 shape: {k_pe.shape}', flush=True) # unsqueeze()用于增加一个维度, k_pe.unsqueeze(2) 把 k_pe reshape 成 [batch, seqLen, 1, dim] # # kv, k_pe 的resacle 为 2^21 k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) # print(f'k_pe 2 shape: {k_pe.shape}', flush=True) if attn_impl == "naive": q = torch.cat([q_nope, q_pe], dim=-1) kv = self.wkv_b(self.kv_norm(kv)) kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) self.k_cache[:bsz, start_pos:end_pos] = k self.v_cache[:bsz, start_pos:end_pos] = v scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale else: # 计算query和key的注意力: # query中不包含位置信息的q_nope(乘以了key的向上投影矩阵后)与缓存kv_cache中的key表示求内积; # query中包含位置信息的q_pe与缓存pe_cache中的key表示求内积; # 两者相加后乘以softmax缩放系数softmax_scale # q_nope 的维度[batch, seqLen, 128, 128], wkv_b_1 shape: [128, 128, 512] # q_nope rescale 2^19, wkv_b_1 rescale 2 ** 32 # q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b_1) # 调用 einsum_bshd_hdc_bshc 之后, q_nope维度 [batch, seqLen, 128, 512] wkv_b_1 = self.wkv_b_1.weight.view(self.n_local_heads, -1, self.kv_lora_rank) q_nope = einsum_bshd_hdc_bshc(q_nope.contiguous(), wkv_b_1.contiguous(), self.wkv_b_1.scale.item()) # print('q_nope type: ' + str(q_nope.type())) # print('q_nope shape: ' + str(q_nope.shape)) # kv_normed 的 rescale 为 2^23 (kv_normed, rms) = self.kv_norm(kv) # kv_cache 的 rescale 为 2^23, shape [batch, seqLen, 512], self.kv_cache[:bsz, start_pos:end_pos] = kv_normed # self.kv_cache[:bsz, start_pos:end_pos] = kv2 # kv = (kv.detach().to(torch.float32) * (2 ** -23)).to(torch.bfloat16) # pe_cache 的 rescale 为 2^21 self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) # q_nope rescale: 2^19, kv_cache rescale: 2^23 # q_nope 的维度 [batch, seqLen, 128, 512], kv_cache 维度 (batch, args.max_seq_len, 512) # score1 = torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) kv_cache1 = self.kv_cache[:bsz, :end_pos] # score1 = einsum_bshc_btc_bsht(q_nope.contiguous(), kv_cache1.contiguous(), 25) # score1 的 rescale 为 2^19 score1 = einsum_bshc_btc_bsht(q_nope.contiguous(), kv_cache1.contiguous(), 23) # print(f'kv_cache1 type: {kv_cache1.type()}, shape: {kv_cache1.shape}', flush=True) # score1 = (score1.detach().to(torch.float32) * (2 ** -21)).to(torch.bfloat16) # score2 = torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos]) pe_cache1 = self.pe_cache[:bsz, :end_pos] # score2 = einsum_bshc_btc_bsht(q_pe.contiguous(), pe_cache1.contiguous(), 23) # q_pe 的 rescale 为 2^19, scores2 的rescale 为 2^19 score2 = einsum_bshc_btc_bsht(q_pe.contiguous(), pe_cache1.contiguous(), 21) # score2 = (score2.detach().to(torch.float32) * (2 ** -21)).to(torch.bfloat16) # scores = (score1 + score2) * self.softmax_scale # scores 的 rescale 为 2 ** 19 scores = (score1 + score2) * self.softmax_scale1 // self.softmax_scale2 # scores = torch.round(((score1 + score2) * self.softmax_scale1).to(torch.float32) / self.softmax_scale2).to(torch.int64) # mask 在 unsqueeze(1) 之后的 shape 为 [seqLen, 1, senLen], scores 的shape 为 [batch, seqLen, heads , t] if mask is not None: # print('mask type: ' + str(mask.type())) # print('mask shape: ' + str(mask.shape)) scores += mask.unsqueeze(1) # query和key的内积按照最后一个维度计算softmax值; # scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) scores_new = torch.empty_like(scores, dtype=torch.int64, device='cuda') # scores 和 scores_new 的 rescale 为 2 ** 19, shape: [bsz, seqLen, headCount, seqLen] # # softmax_q19 会破坏 scores 的原始数据,先拷贝一份数据 if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_x.bin', scores.contiguous().cpu()) softmax_q19(scores.contiguous(), scores_new) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/scores_softmax_y.bin', scores_new.cpu()) if attn_impl == "naive": x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) else: kv_cache2 = self.kv_cache[:bsz, :end_pos] # kv_cache2 = (kv_cache2.detach().to(torch.float32) * (2 ** -25)).to(torch.bfloat16) # x = (x.detach().to(torch.float32) * (2 ** -23)).to(torch.bfloat16) # 计算最终输出: # 注意力分数乘以kv缓存后,再乘以value的向上投影矩阵wkv_b(实现公式45和46); # 乘以输出投影矩阵wo(公式47); # x = torch.einsum("bsht,btc->bshc", scores_new, kv_cache2) # scores_new 的 rescale 为 2^19, kv_cache2 的 rescale 为 2^23, bshc 的 rescale 为 2^19 # scores_new shape: [1, 8, 128, 8], bshc shape: [1, 8, 128, 512] # bshc = einsum_bsht_btc_bshc(scores_new.contiguous(), kv_cache2.contiguous(), 25) bshc = einsum_bsht_btc_bshc(scores_new.contiguous(), kv_cache2.contiguous(), 23) # # v_head_dim = 128, kv_lora_rank = 512, n_local_heads = 128 # wkv_b_2 = wkv_b[:, -self.v_head_dim:] # # print('wkv_b 2 type: ' + str(wkv_b_2.type())) # # print('wkv_b 2 shape: ' + str(wkv_b_2.shape)) wkv_b_2 = self.wkv_b_2.weight wkv_b_2 = wkv_b_2.view(self.n_local_heads, -1, self.kv_lora_rank) # wkv_b_2 = (wkv_b_2.detach().to(torch.float32) * (2 ** -self.wkv_b_2.scale.item())).to(torch.bfloat16) # x = torch.einsum("bshc,hdc->bshd", x, wkv_b_2) # bshc 的 rescale 为 2^19, wkv_b_2 的 rescale 为 self.wkv_b_2.scale # x 的 rescale 为 2 ** 19 # bshc shape: [1, seqLen, 128, 512], wkv_b_2 shape: [128, 128, 512] x = einsum_bshc_hdc_bshd(bshc.contiguous(), wkv_b_2.contiguous(), self.wkv_b_2.scale.item()) # x = (x.detach().to(torch.float32) * (2 ** -21)).to(torch.bfloat16) # x 返回的的 shape [1, seqLen, 7168] x = self.wo(x.flatten(2)) return x class MLP(nn.Module): """ Multi-Layer Perceptron (MLP) used as a feed-forward layer. Attributes: w1 (nn.Module): Linear layer for input-to-hidden transformation. w2 (nn.Module): Linear layer for hidden-to-output transformation. w3 (nn.Module): Additional linear layer for feature transformation. """ def __init__(self, layer_id, dim: int, inter_dim: int): """ Initializes the MLP layer. Args: dim (int): Input and output dimensionality. inter_dim (int): Hidden layer dimensionality. """ super().__init__() self.w1 = ColumnParallelLinear(layer_id, dim, inter_dim) self.w2 = RowParallelLinear(layer_id, inter_dim, dim) self.w3 = ColumnParallelLinear(layer_id, dim, inter_dim) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the MLP layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Output tensor after MLP computation. """ return self.w2(F.silu(self.w1(x)) * self.w3(x)) class MLP_int(nn.Module): """ Multi-Layer Perceptron (MLP) used as a feed-forward layer. Attributes: w1 (nn.Module): Linear layer for input-to-hidden transformation. w2 (nn.Module): Linear layer for hidden-to-output transformation. w3 (nn.Module): Additional linear layer for feature transformation. """ def __init__(self, layer_id, dim: int, inter_dim: int): """ Initializes the MLP layer. Args: dim (int): Input and output dimensionality. inter_dim (int): Hidden layer dimensionality. """ super().__init__() self.layer_id = layer_id self.w1 = ColumnParallelLinear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32) self.w2 = RowParallelLinear_rescale_int(layer_id, inter_dim, dim, 1, 1, 1, torch.int32) self.w3 = ColumnParallelLinear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32) # 输入的 x 的rescale 为 2^23, [bsz, seqLen, 7168] def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the MLP layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Output tensor after MLP computation. """ # r1 shape: [bsz, seqLen, inter_dim], r1 rescale: 2^23 r1 = self.w1(x) # s1 = F.silu(r1) # s1 shape: [bsz, seqLen, inter_dim], s1 rescale: 2^23 s1 = torch.empty_like(r1, dtype=torch.int64, device='cuda') # silu_q25(r1, s1) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_x.bin', r1.contiguous().cpu()) silu_q23(r1, s1) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/mlp_silu_y.bin', s1.cpu()) # r2 rescale: 2^23, shape: [1, seqLen, inter_dim] r2 = self.w3(x) # 返回的 shape [bsz, seqLen, dim] q = self.w2(s1 * r2 // (1 << 23)) return q class Gate(nn.Module): """ Gating mechanism for routing inputs in a mixture-of-experts (MoE) model. Attributes: dim (int): Dimensionality of input features. topk (int): Number of top experts activated for each input. n_groups (int): Number of groups for routing. topk_groups (int): Number of groups to route inputs to. score_func (str): Scoring function ('softmax' or 'sigmoid'). route_scale (float): Scaling factor for routing weights. weight (torch.nn.Parameter): Learnable weights for the gate. bias (Optional[torch.nn.Parameter]): Optional bias term for the gate. """ def __init__(self, layer_id: int, args: ModelArgs): """ Initializes the Gate module. Args: args (ModelArgs): Model arguments containing gating parameters. """ super().__init__() self.layer_id = layer_id self.dim = args.dim # n_activated_experts = 8 self.topk = args.n_activated_experts # n_expert_groups = 8 self.n_groups = args.n_expert_groups # n_limited_groups = 4 self.topk_groups = args.n_limited_groups # score_func = 'sigmoid' self.score_func = args.score_func # route_scale = 2.5 self.route_scale = args.route_scale # n_routed_experts = 256 # self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim)) self.register_buffer("weight", torch.empty(args.n_routed_experts, args.dim, dtype=torch.int32)) self.register_buffer("scale", torch.tensor(0, dtype=torch.int32)) # self.bias = nn.Parameter(torch.empty(args.n_routed_experts, dtype=torch.int32)) if self.dim == 7168 else None if self.dim == 7168: self.register_buffer("bias", torch.empty(args.n_routed_experts, dtype=torch.int32)) else: self.bias = None # x 的 rescale 为 2^23 def forward(self, start_pos: int, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: """ Forward pass for the gating mechanism. Args: x (torch.Tensor): Input tensor. Returns: Tuple[torch.Tensor, torch.Tensor]: Routing weights and selected expert indices. """ x = x.view(1, -1, self.dim) # scores = linear(x, self.weight) # self.weight shape: [256, 7168] # 当前 scores shape: [1, seqLen, 256] # rescale = 2 ** self.scale.item() rescale = self.scale.item() # scores 的 rescale 为 2^23 scores, scores_rem = linear_int(x, self.weight, 1, 1, rescale) # scores = int64_bmm_with_bias(x, self.weight, bias, 1, 1, self.scale) # x shape: [seqLen, 7168] x = x.view(-1, self.dim) if self.score_func == "softmax": scores = scores.softmax(dim=-1, dtype=torch.float32) else: # scores = scores.sigmoid() C = torch.empty_like(scores, dtype=torch.int64, device='cuda') if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_x.bin', scores.cpu()) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_r.bin', scores_rem.cpu()) sigmoid_q23(scores, C) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/sigmoid_gate_y.bin', C.cpu()) # 当前 scores shape: [seqLen, 256] scores = C.squeeze(0) # bias的rescale为2^23 original_scores = scores if self.bias is not None: # scores = scores + self.bias # 当前 scores shape: [seqLen, 256] scores = scores + self.bias if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_original_scores.bin', original_scores.contiguous().cpu()) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_bias.bin', self.bias.view(torch.uint32).cpu()) # n_groups = 8 if self.n_groups > 1: # x.size(0) = 8,当前 scores shape: [seqLen, 8, 32] scores = scores.view(x.size(0), self.n_groups, -1) # print(f'scores shape 111: {scores.shape}', flush=True) if self.bias is None: group_scores = scores.amax(dim=-1) else: # topk 返回 -1维度上 最大的 前 2 个值,同时返回值和索引,[0] 表示 取值,sum(-1) 再把最大的两个值相加. # 256维,分成8个组,每个组挑最大的两个数相加,得到 [seqLen, 8] 的结果,代表 8 个组的 最大两个值的和。 # group_scores 的 shape: [8, 8] group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) # print(group_scores[0], flush=True) # print(f'group_scores shape: {group_scores.shape}') # topk_groups = 4, 从 8 个group中选择最大的 4个,返回其索引,比如返回 [[0, 2, 4, 6], ...] # indices shape: [seqLen, 4] indices = group_scores.topk(self.topk_groups, dim=-1)[1] # print(indices[0], flush=True) # mask shape: [seqLen, 8] # scatter_: 按照给定索引,把某个源张量的值写入到目标张量对应位置。 Tensor.scatter_(dim, index, src, reduce=None) # 比如 mask 为[[False, True, False, True, False, True, False, True], ...] # mask: 每一行最大的4个值相对应的 mask 为 False mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) # print(mask[0], flush=True) # 把满足布尔 mask 的位置替换成 "-inf", mask.unsqueeze(-1) shape: [8, 8, 1] # 把 scores 中 淘汰掉的4个group中的每一个值设置为 "-inf",总共设置 128个 "-inf",占每一行中的一半 # scores shape: [seqLen, 256] # scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) scores = scores.masked_fill_(mask.unsqueeze(-1), -(1 << 42)).flatten(1) # 没有淘汰掉的group中的 128个值中,选择最大的8个值,返回其下标 # self.topk = 8, indices shape: [8, 8] indices = torch.topk(scores, self.topk, dim=-1)[1] # print(indices[0], flush=True) # gather 用来按照索引从一个张量中取值,按照8个最大值的下标,获取其值 # weights shape: [8, 8] weights = original_scores.gather(1, indices) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_indices.bin', indices.contiguous().cpu()) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/gate_weights.bin', weights.contiguous().cpu()) # print(f'weights shape: {weights.shape}') if self.score_func == "sigmoid": sum1 = weights.sum(dim=-1, keepdim=True) # weights = (weights * (2 ** 25) + sum1 // 2) // sum1 weights = (weights * (2 ** 23)) // sum1 # weights /= weights.sum(dim=-1, keepdim=True) #self.route_scale = 2.5 # weights *= self.route_scale weights = weights * 5 // 2 # weights = (weights.to(torch.float32) * (2 ** -23)).to(torch.bfloat16) # return weights.type_as(x), indices return weights, indices class Expert_int(nn.Module): """ Expert layer for Mixture-of-Experts (MoE) models. Attributes: w1 (nn.Module): Linear layer for input-to-hidden transformation. w2 (nn.Module): Linear layer for hidden-to-output transformation. w3 (nn.Module): Additional linear layer for feature transformation. """ def __init__(self, layer_id, idx, dim: int, inter_dim: int): """ Initializes the Expert layer. Args: dim (int): Input and output dimensionality. inter_dim (int): Hidden layer dimensionality. """ super().__init__() # # w1 shape: [2048, 7168] # self.w1 = Linear(layer_id, dim, inter_dim) # # w2 shape: [7168, 2048] # self.w2 = Linear(layer_id, inter_dim, dim) # # w3 shape: [2048, 7168] # self.w3 = Linear(layer_id, dim, inter_dim) self.layer_id = layer_id self.idx = idx self.w1 = Linear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32) self.w2 = Linear_rescale_int(layer_id, inter_dim, dim, 1, 1, torch.int32) self.w3 = Linear_rescale_int(layer_id, dim, inter_dim, 1, 1, torch.int32) # todo: add row id in the forward function def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the Expert layer. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Output tensor after expert computation. """ # 返回的 shape [bsz, seqLen, 7168] # return self.w2(F.silu(self.w1(x)) * self.w3(x)) # r1 shape: [bsz, seqLen, 18432], r1 rescale: 2^23 r1 = self.w1(x) # s1 = F.silu(r1) # s1 shape: [bsz, seqLen, 18432], s1 rescale: 2^23 s1 = torch.empty_like(r1, dtype=torch.int64, device='cuda') # silu_q25(r1, s1) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_x.bin', r1.contiguous().cpu()) silu_q23(r1, s1) if snark: saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/expert_{self.idx}_silu_y.bin', s1.cpu()) # r2 rescale: 2^23 r2 = self.w3(x) # 返回的 shape [bsz, seqLen, 7168] q = self.w2((s1 * r2) >> 23) return q class MoE(nn.Module): """ Mixture-of-Experts (MoE) module. Attributes: dim (int): Dimensionality of input features. n_routed_experts (int): Total number of experts in the model. n_local_experts (int): Number of experts handled locally in distributed systems. n_activated_experts (int): Number of experts activated for each input. gate (nn.Module): Gating mechanism to route inputs to experts. experts (nn.ModuleList): List of expert modules. shared_experts (nn.Module): Shared experts applied to all inputs. """ def __init__(self, layer_id, args: ModelArgs, ckpt_path): """ Initializes the MoE module. Args: args (ModelArgs): Model arguments containing MoE parameters. """ super().__init__() self.layer_id = layer_id self.ckpt_path = ckpt_path self.dim = args.dim self.moe_inter_dim = args.moe_inter_dim assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" self.n_routed_experts = args.n_routed_experts self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts self.experts_start_idx = rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts self.gate = Gate(layer_id, args) # moe_inter_dim = 2048 # self.experts = nn.ModuleList([Expert(layer_id, args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None # for i in range(self.n_routed_experts)]) # self.experts = torch.nn.ModuleList() # dim = 7168, n_shared_experts = 1, moe_inter_dim = 2048 self.shared_experts = MLP_int(layer_id, args.dim, args.n_shared_experts * args.moe_inter_dim) # x 的 rescale 为 2^23, shape: [1, seqLen, 7168] def forward(self, start_pos: int, x: torch.Tensor) -> torch.Tensor: """ Forward pass for the MoE module. Args: x (torch.Tensor): Input tensor. Returns: torch.Tensor: Output tensor after expert routing and computation. """ # ffn_normed 的 rescale 为 2^23 # x = (x.to(torch.float32) * (2 ** -23)).to(torch.bfloat16) # z rescale: 2^23, z 的 shape [seqLen, 7168] z = self.shared_experts(start_pos, x) # x shape 之前为: [bsz, seqLen, 7168], 之后为 [8, 7168] shape = x.size() x = x.view(-1, self.dim) # weights shape: [seqLen, 8], indices shape: [seqLen, 8] # weights 的 rescale 为 2^23 weights, indices = self.gate(start_pos, x) # y shape: [seqLen, 7168] y = torch.zeros_like(x) # torch.bincount 用来统计非负整数张量中各个数值出现的次数,类似于直方图计数 # torch.bincount(input, weights=None, minlength=0) -> Tensor, weights: 可选的一维浮点张量,和 input 形状一致。若提供,就不是“次数统计”,而是“权重和” # 统计 256 个 专家 出现的次数 counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist() for i in range(self.experts_start_idx, self.experts_end_idx): if counts[i] == 0: continue # expert = self.experts[i] with torch.device("cuda"): expert = Expert_int(self.layer_id, i, self.dim, self.moe_inter_dim) # load_model(expert, f'/data3/DeepSeek-V3-Demo1/experts-{self.layer_id}/{i}.safetensors') expertModelPath = os.path.join(self.ckpt_path, f"experts-{self.layer_id}/{i}.safetensors") load_model(expert, expertModelPath) # 第 idx 个 token, 专家 i 出现的编号是 top # 比如 # [0, 1, 3, 2, 5, 4, 6, 9] # [7, 8, 3, 12, 5, 11, 6, 1] # [16, 10, 3, 2, 15, 4, 6, 9] # [10, 21, 3, 2, 5, 4, 1, 9] # torch.where(indices == 1) 返回的结果是 ([0, 1, 3], [1, 7, 6]) idx, top = torch.where(indices == i) # expert(x[idx]) 返回的 shape [seqLen, 2048], weights[idx, top, None] 的 shape 为 [seqLen, 1], 包含一个 weight 值 # y[idx] += expert(x[idx]) * weights[idx, top, None] x2 = x[idx].unsqueeze(0) y2 = expert(start_pos, x2) y2 = y2.view(-1, self.dim) # y[idx] += y2 * weights[idx, top, None] // (1 << 25) y[idx] += y2 * weights[idx, top, None] // (1 << 23) # z = self.shared_experts(x) if world_size > 1: dist.all_reduce(y) return (y + z).view(shape) def getBF8PrintStr(ele): v = int(ele.cpu().view(torch.uint8).item()) ex = v >> 3 & 0xF r = v & 0x7 if ex == 15 and r == 7: print(f'BF8 Nan: {ex} {r} !!!', flush=True) elif ex == 0: print(f'BF8 subnormal: {ex} {r} !!!', flush=True) if v & 0x80: vstr = f'-{ex} {r}' else: vstr = f'{ex} {r}' return vstr class Block(nn.Module): """ Transformer block combining attention and feed-forward layers. Attributes: attn (nn.Module): Attention layer (MLA). ffn (nn.Module): Feed-forward network (MLP or MoE). attn_norm (nn.Module): Layer normalization for attention. ffn_norm (nn.Module): Layer normalization for feed-forward network. """ def __init__(self, layer_id: int, args: ModelArgs, ckpt_path): """ Initializes the Transformer block. Args: layer_id (int): Layer index in the transformer. args (ModelArgs): Model arguments containing block parameters. """ super().__init__() self.layer_id = layer_id self.ckpt_path = ckpt_path self.attn = MLA(layer_id, args) self.ffn = MLP_int(layer_id, args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(layer_id, args, ckpt_path) # print('args.dim: ' + str(args.dim)) # args.dim = 7168 self.attn_norm = RMSNorm_int(args.dim, torch.int32) self.ffn_norm = RMSNorm_int(args.dim, torch.int32) # self.ffn_norm = RMSNorm(args.dim) def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: """ Forward pass for the Transformer block. Args: x (torch.Tensor): Input tensor. start_pos (int): Starting position in the sequence. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings. mask (Optional[torch.Tensor]): Mask tensor to exclude certain positions from attention. Returns: torch.Tensor: Output tensor after block computation. """ x_abs = x.abs() x_abs_min = x_abs.min().item() x_abs_max = x_abs.max().item() print(f'x abs min: {x_abs_min}, max: {x_abs_max}', flush=True) # self.attn_norm(x): 在进行attention之前,先将7168维的embeding 进行 归一化 # attn_norm 的 scale 为 2^21, x 的 scale 为 2^31 (atten_normed, rms) = self.attn_norm(x) if snark: os.makedirs(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}', exist_ok=True) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_x.bin', x.cpu()) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_weight.bin', self.attn_norm.weight.view(torch.uint32).cpu()) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_y.bin', atten_normed.cpu()) saveTensor(f'{zkDataDir}/pos_{start_pos}/layer_{self.layer_id}/attn_norm_rms.bin', rms.cpu()) # attned 的 rescale 是 2^19, shape: [1, seqLen, 7168] attned = self.attn(atten_normed, start_pos, freqs_cis, mask) # 调整 rescale,因为 x 的 rescale 是 2^31, attned 的 rescale 是 2^19,因此要乘以 2^12 # x = x + attned * (2 ** 10) x = x + attned * (2 ** 12) # ffn_normed 的 rescale 为 2^23 (ffn_normed, rms) = self.ffn_norm(x) ffned = self.ffn(start_pos, ffn_normed) # x = x + ffned * (2 ** 6) x = x + ffned * (2 ** 8) # 返回的 x 的rescale 为 2^31 return x # Transformer 类在初始化中就已经明确好了自己的进程(rank),并且可以发现它是由比较经典的transformer组件构成的: # embedding层(self.embed)、堆叠的decoding block(self.layers),标准的RMSnorm层(self.norm)与最后将隐藏状态投射到词表分布的output层(self.head) # 根据前面提及的初始化的参数来看,词表大小为129280,模型的hidden dim为7168,堆叠的decode block一共有61个。维度变换会在下面举例说明。 # Transformer 由61个Block组成,每个Block有 attn 和 ffd # Transformer类在初始化中就已经明确好了自己的进程(rank),并且可以发现它是由比较经典的transformer组件构成的 # embedding层(self.embed)、堆叠的decoding block(self.layers),标准的RMSnorm层(self.norm)与最后将隐藏状态投射到词表分布的output层(self.head)。 class Transformer(nn.Module): """ Transformer model with positional embeddings, multiple layers, and output projection. Attributes: max_seq_len (int): Maximum sequence length for the transformer. embed (nn.Module): Embedding layer for input tokens. layers (torch.nn.ModuleList): List of transformer blocks. norm (nn.Module): Layer normalization applied after all blocks. head (nn.Module): Output projection layer mapping to vocabulary size. freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary(旋转的) embeddings. """ def __init__(self, args: ModelArgs): """ Initializes the Transformer model. Args: args (ModelArgs): Model arguments containing transformer parameters. """ global world_size, rank world_size = dist.get_world_size() if dist.is_initialized() else 1 rank = dist.get_rank() if dist.is_initialized() else 0 Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16 super().__init__() self.args = args self.max_seq_len = args.max_seq_len self.embed = ParallelEmbedding(args.vocab_size, args.dim) self.layers = torch.nn.ModuleList() for layer_id in range(args.n_layers): # self.layers.append(Block(layer_id, args)) self.layers.append(nn.Module()) self.norm = RMSNorm_int(args.dim, torch.int64) # self.head = ColumnParallelLinear(-1, args.dim, args.vocab_size, dtype=torch.get_default_dtype()) # 模型中的 head 的 rescale 为 2^43, 使用的过程中的rescale为 2^35, head 输入的 rescale为 2^15, 输出的 rescale为 2^21 # self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, 1, (1 << 8), (1 << 29), torch.int64) self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, 1, (1 << 8), 29, torch.int64) # self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, 1, (1 << 8), (1 << 31), torch.int64) # self.head = ColumnParallelLinear_int(-1, args.dim, args.vocab_size, (1 << 5), (1 << 11), (1 << 21), torch.int64) # register_buffer()注册了名为 "freqs_cis" 的缓冲区,缓冲区的值由 precompute_freqs_cis(args) 提供,并且由于设置了 persistent=False, # 该缓冲区不会被保存到模型的状态字典中。缓冲区注册的张量是该Transformer类的位置编码。 # register_buffer 用于注册一个非参数张量(tensor),这个张量虽然不是模型的可学习参数,但仍然是模型状态的一部分。 # 与参数不同,缓冲区不会在反向传播中计算梯度,也不会被优化器更新,但它会随模型一起移动到相应的设备(如 GPU)上。 # persistent=False表示这个参数表示该缓冲区不属于持久状态(persistent state)。也就是说,当你调用 model.state_dict() 保存模型时, # 这个缓冲区不会被包含进去。位置编码可以在模型加载后重新计算,不需要存储。 self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False) @torch.inference_mode() def prep_inference(self, tokens: torch.Tensor, start_pos: int = 0): # softmax_init() softmax_init_q19() softmax_init_q21() silu_init_q23() seqlen = tokens.size(1) # h 是经过embed之后的结果,embed将文本表达转化为词嵌入,h的形状为 (batch_size, seq_len, 7168) h = self.embed(tokens) # h = h.to(torch.bfloat16) * (1.0 / (1 << 44)) return (h, start_pos, seqlen) @torch.inference_mode() def layer_inference(self, layer_id, h, start_pos, seqlen): freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] mask = None # triu = triangle up # 返回上三角矩阵 # 参数 k=0 代表主对角线,k 为正数则从主对角线开始向上数第 k 条,k 为负数则从主对角线开始向下数第 k 条 if seqlen > 1: # mask = torch.full((seqlen, seqlen), float("-inf"), device="cuda").triu_(1) mask = torch.full((seqlen, seqlen), -(64 << 36), dtype=torch.int64, device="cuda").triu_(1) h = self.layers[layer_id](h, start_pos, freqs_cis, mask) h_abs = (h.to(torch.float32) * (2 ** -31)).to(torch.bfloat16).abs() h_abs_max = h_abs.max() h_abs[h_abs < (2 ** -125)] = h_abs_max h_abs_min = h_abs.min() h_abs_min_str = getBF16PrintStr(h_abs_min) h_abs_max_str = getBF16PrintStr(h_abs_max) print(f'h_abs min: {h_abs_min_str}, max: {h_abs_max_str}') # 返回的 h 的rescale 为 2^31 return h @torch.inference_mode() def finish_inference(self, h): # norm的结果的scale = 2^15, h 的 scale = 2^15 h = self.norm(h)[0][:, -1] # logits 的rescale 为 2^21 logits = self.head(h[None, :]) if world_size > 1: all_logits = [torch.empty_like(logits) for _ in range(world_size)] dist.all_gather(all_logits, logits) logits = torch.cat(all_logits, dim=-1) # logits 的 scale = 2^21 return logits # # 这里开始推理了,torch.inference_mode 这句话 关闭梯度计算 并 禁止 autograd 构建计算图,同时比 torch.no_grad() 还高效,专门为推理场景优化 # @torch.inference_mode() # def forward(self, tokens: torch.Tensor, start_pos: int = 0): # """ # Forward pass for the Transformer model. # Args: # tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). # start_pos (int, optional): Starting position in the sequence for rotary(旋转的) embeddings. Defaults to 0. # Returns: # torch.Tensor: Logits tensor of shape (batch_size, vocab_size). # """ # seqlen = tokens.size(1) # # h 是经过embed之后的结果,embed将文本表达转化为词嵌入,h的形状为 (batch_size, seq_len, 7168) # h = self.embed(tokens) # freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] # print('freqs_cis: ' + str(freqs_cis.tolist())) # mask = None # # triu = triangle up # # 返回上三角矩阵 # # 参数 k=0 代表主对角线,k 为正数则从主对角线开始向上数第 k 条,k 为负数则从主对角线开始向下数第 k 条 # if seqlen > 1: # mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) # for layer in self.layers: # h = layer(h, start_pos, freqs_cis, mask) # # 只取最后一个 token # h = self.norm(h)[:, -1] # logits = self.head(h) # if world_size > 1: # all_logits = [torch.empty_like(logits) for _ in range(world_size)] # dist.all_gather(all_logits, logits) # logits = torch.cat(all_logits, dim=-1) # return logits if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) torch.set_default_device("cuda") torch.manual_seed(0) args = ModelArgs() x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(0, args) print(model(x).size())