import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.parallel import parallel_apply
from typing import Tuple, List, Optional, Union
import torch.utils.checkpoint as checkpoint


class MultiHeadAttention(nn.Module):
    """Efficient multi-head attention implementation."""

    def __init__(self, model_dim: int, n_heads: int):
        super().__init__()
        assert model_dim % n_heads == 0, "model_dim must be divisible by n_heads"
        self.model_dim = model_dim
        self.d_k = model_dim // n_heads
        self.n_heads = n_heads
        # A single linear layer computes the Q, K, V projections together,
        # reducing kernel-launch overhead.
        self.qkv_linear = nn.Linear(model_dim, 3 * model_dim, bias=False)
        self.out_linear = nn.Linear(model_dim, model_dim, bias=False)
        # Initialize parameters for better training stability.
        nn.init.xavier_uniform_(self.qkv_linear.weight)
        nn.init.xavier_uniform_(self.out_linear.weight)
        self.scale = 1.0 / math.sqrt(self.d_k)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                mask: Optional[torch.Tensor] = None,
                key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size = q.size(0)
        # If all three inputs share the same storage, take the fused
        # self-attention fast path.
        is_self_attention = q.data_ptr() == k.data_ptr() == v.data_ptr()
        if is_self_attention:
            # [batch, seq, 3*dim] -> 3 x [batch, seq, dim]
            qkv = self.qkv_linear(q).chunk(3, dim=-1)
            q, k, v = map(
                lambda x: x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2),
                qkv
            )
        else:
            # Cross-attention: slice the fused QKV projection per input. Note
            # that this computes the full 3*model_dim projection for each input
            # and discards two thirds of it.
            q = self.qkv_linear(q)[:, :, :self.model_dim]
            k = self.qkv_linear(k)[:, :, self.model_dim:2 * self.model_dim]
            v = self.qkv_linear(v)[:, :, 2 * self.model_dim:]
            q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
            v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)

        # Scaled dot-product attention.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Mask handling; a large-but-finite fill value keeps fp16 numerically stable.
        if mask is not None:
            scores = scores.masked_fill(mask == 0, -6.0e4)
        if key_padding_mask is not None:
            scores = scores.masked_fill(
                key_padding_mask.unsqueeze(1).unsqueeze(2), -6.0e4
            )

        attn_weights = F.softmax(scores, dim=-1)
        # Aggregate the values with the attention weights.
        context = torch.matmul(attn_weights, v)
        # Reassemble the heads and apply the output projection.
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.model_dim)
        output = self.out_linear(context)
        return output
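
# A minimal shape sanity check for MultiHeadAttention, exercising both the
# fused self-attention path (shared data_ptr) and the cross-attention path.
# The dimensions below (batch 2, query length 10, memory length 7, model_dim
# 64, 8 heads) are illustrative assumptions, not values used elsewhere here.
def _demo_multi_head_attention() -> None:
    mha = MultiHeadAttention(model_dim=64, n_heads=8)
    x = torch.randn(2, 10, 64)
    mem = torch.randn(2, 7, 64)
    out_self = mha(x, x, x)        # same tensor three times -> self-attention path
    out_cross = mha(x, mem, mem)   # distinct tensors -> cross-attention path
    assert out_self.shape == (2, 10, 64)
    assert out_cross.shape == (2, 10, 64)  # output keeps the query's seq length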
class MoE(nn.Module):
    """Optimized mixture-of-experts module with parallel expert execution
    and efficient expert selection.

    Args:
        d_model (int): hidden dimension of the model
        num_experts (int): number of experts
        d_ff (int): feed-forward dimension
        dropout (float): dropout probability
        top_k (int): number of experts selected per token
    """

    def __init__(self, d_model: int, num_experts: int, d_ff: int,
                 dropout: float, top_k: int):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)  # clamp top_k to the number of experts
        self.d_model = d_model

        # Gating network: maps inputs to expert scores [d_model -> num_experts].
        self.gate = nn.Linear(d_model, num_experts, bias=False)

        # Expert networks: a list of parallel feed-forward experts.
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff, bias=False),  # [d_model -> d_ff]
                nn.GELU(),                             # activation, shape preserved
                nn.Dropout(dropout),                   # shape preserved
                nn.Linear(d_ff, d_model, bias=False)   # [d_ff -> d_model]
            )
            for _ in range(num_experts)
        ])

        # Parameter initialization.
        for expert in self.experts:
            nn.init.kaiming_uniform_(expert[0].weight, a=math.sqrt(5))  # first linear layer
            nn.init.zeros_(expert[3].weight)  # zero-init output layer, weight shape [d_model, d_ff]
        nn.init.zeros_(self.gate.weight)  # zero-init gate (weight shape [num_experts, d_model]) so routing starts uniform

    def orthogonal_loss(self) -> torch.Tensor:
        """Orthogonality loss between expert weights, encouraging expert diversity.

        Returns:
            torch.Tensor: scalar orthogonality loss
        """
        total_loss = 0.0
        num_pairs = 0
        # Stack the first and last linear weights of all experts.
        # expert_weights_1: [num_experts, d_ff, d_model]
        expert_weights_1 = torch.stack([expert[0].weight for expert in self.experts])
        # expert_weights_2: [num_experts, d_model, d_ff]
        expert_weights_2 = torch.stack([expert[3].weight for expert in self.experts])
        # Accumulate the loss over all expert pairs.
        for i in range(self.num_experts):
            w1_i = expert_weights_1[i]  # [d_ff, d_model]
            w2_i = expert_weights_2[i]  # [d_model, d_ff]
            for j in range(i + 1, self.num_experts):
                w1_j = expert_weights_1[j]  # [d_ff, d_model]
                w2_j = expert_weights_2[j]  # [d_model, d_ff]
                # Similarity of the first-layer weights (scalar).
                w1_sim = torch.sum((w1_i @ w1_j.T) ** 2) / (w1_i.size(0) * w1_j.size(0))
                # Similarity of the second-layer weights (scalar).
                w2_sim = torch.sum((w2_i.T @ w2_j) ** 2) / (w2_i.size(1) * w2_j.size(1))
                total_loss += (w1_sim + w2_sim) / 2  # average of the two similarities
                num_pairs += 1
        return total_loss / max(num_pairs, 1)  # mean over expert pairs

    def entropy_regularization_loss(self, routing_probs: torch.Tensor) -> torch.Tensor:
        """Entropy regularization loss, encouraging a more uniform routing distribution.

        Args:
            routing_probs (torch.Tensor): routing probabilities, shape [batch*seq_len, num_experts]

        Returns:
            torch.Tensor: scalar entropy loss
        """
        # Numerically stable log via clamping.
        log_probs = torch.log(torch.clamp(routing_probs, min=1e-6))  # [batch*seq, num_experts]
        # Per-token entropy.
        entropy = -torch.sum(routing_probs * log_probs, dim=-1)  # [batch*seq]
        return entropy.mean()  # scalar

    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """MoE forward pass with efficient expert selection and combination.

        Args:
            hidden_states (torch.Tensor): input tensor, shape [batch_size, seq_len, d_model]

        Returns:
            Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
                - output tensor [batch_size, seq_len, d_model]
                - router logits [batch_size*seq_len, num_experts]
                - scalar entropy regularization loss
        """
        batch_size, seq_len, d_model = hidden_states.shape
        combined_batch_size = batch_size * seq_len

        # Flatten the input for parallel processing.
        flat_hidden = hidden_states.reshape(combined_batch_size, d_model)  # [batch*seq, d_model]

        # Routing.
        router_logits = self.gate(flat_hidden)            # [batch*seq, num_experts]
        routing_probs = F.softmax(router_logits, dim=-1)  # [batch*seq, num_experts]

        # Select the top-k experts per token.
        routing_weights, selected_experts = torch.topk(
            routing_probs, self.top_k, dim=-1
        )  # both [batch*seq, top_k]
        # Renormalize the selected weights.
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)  # [batch*seq, top_k]

        # Run all experts in parallel. Note: every expert sees every token here;
        # sparsity is applied afterwards through the weight matrix.
        flat_expert_inputs = [flat_hidden] * self.num_experts  # num_experts views of [batch*seq, d_model]
        expert_outputs = parallel_apply(self.experts, flat_expert_inputs)  # num_experts x [batch*seq, d_model]
        expert_outputs = torch.stack(expert_outputs, dim=1)  # [batch*seq, num_experts, d_model]

        # Build the per-token expert weight matrix (match the input's dtype
        # and device so the bmm below works under mixed precision).
        expert_weights_matrix = torch.zeros(
            combined_batch_size, self.num_experts,
            device=hidden_states.device, dtype=hidden_states.dtype
        )  # [batch*seq, num_experts]

        # Scatter the top-k routing weights into the matrix.
        for k in range(self.top_k):
            k_indices = selected_experts[:, k]              # [batch*seq]
            k_weights = routing_weights[:, k].unsqueeze(1)  # [batch*seq, 1]
            expert_weights_matrix.scatter_add_(
                1,
                k_indices.unsqueeze(1),  # [batch*seq, 1]
                k_weights                # [batch*seq, 1]
            )

        # Combine the expert outputs with a batched matrix product.
        combined_output = torch.bmm(
            expert_weights_matrix.unsqueeze(1),  # [batch*seq, 1, num_experts]
            expert_outputs                       # [batch*seq, num_experts, d_model]
        ).squeeze(1)  # [batch*seq, d_model]

        # Restore the original shape.
        output = combined_output.reshape(batch_size, seq_len, d_model)  # [batch_size, seq_len, d_model]

        # Entropy regularization loss.
        entropy_loss = self.entropy_regularization_loss(routing_probs)
        return output, router_logits, entropy_loss
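
# A minimal usage sketch for MoE showing the returned shapes and one way to
# fold the auxiliary terms into a training objective. All hyperparameters and
# `aux_weight` are illustrative assumptions; the sign convention below is a
# judgment call: subtracting the entropy term maximizes routing entropy, which
# matches the docstring's goal of a more uniform routing distribution.
def _demo_moe() -> None:
    moe = MoE(d_model=64, num_experts=4, d_ff=128, dropout=0.1, top_k=2)
    x = torch.randn(2, 10, 64)
    out, router_logits, entropy_loss = moe(x)
    assert out.shape == (2, 10, 64)
    assert router_logits.shape == (2 * 10, 4)  # [batch*seq, num_experts]
    aux_weight = 0.01                          # hypothetical coefficient
    task_loss = out.pow(2).mean()              # stand-in for a real objective
    total_loss = task_loss + aux_weight * moe.orthogonal_loss() - aux_weight * entropy_loss
    total_loss.backward()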
class EncoderLayer(nn.Module):
    """Optimized encoder layer with gradient checkpointing and pre-norm residual connections."""

    def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
                 dropout: float, num_experts: int, top_k: int):
        super().__init__()
        self.model_dim = model_dim
        # Pre-layer-norm (Pre-LN) structure for better training stability.
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        self.self_attn = MultiHeadAttention(model_dim, n_heads)
        self.moe = MoE(model_dim, num_experts, ff_hidden_dim, dropout, top_k)
        self.dropout = nn.Dropout(dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        # Optional projection for residual-dimension mismatches; disabled here,
        # so a learnable residual scale is used instead.
        self.use_projection = False
        if not self.use_projection:
            self.residual_scale = nn.Parameter(torch.ones(1))

    def _sa_block(self, x: torch.Tensor,
                  mask: Optional[torch.Tensor] = None,
                  key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Wrap self-attention so it can be used with gradient checkpointing."""
        x = self.self_attn(x, x, x, mask=mask, key_padding_mask=key_padding_mask)
        return self.dropout1(x)

    def _moe_block(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Wrap the MoE computation so it can be used with gradient checkpointing."""
        return self.moe(x)

    def forward(self, x: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None,
                use_checkpoint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encoder layer forward pass.

        Args:
            x: input tensor [batch_size, seq_len, model_dim]
            src_mask: source sequence mask
            src_key_padding_mask: padding mask
            use_checkpoint: use gradient checkpointing to save memory
        """
        # Pre-LN before attention.
        normalized_x = self.norm1(x)

        # Self-attention block (optionally checkpointed).
        if use_checkpoint and self.training:
            attn_output = checkpoint.checkpoint(
                self._sa_block, normalized_x, src_mask, src_key_padding_mask
            )
        else:
            attn_output = self._sa_block(normalized_x, src_mask, src_key_padding_mask)

        # First residual connection.
        x = x + attn_output * self.residual_scale

        # Pre-LN before the MoE block.
        normalized_x = self.norm2(x)

        # MoE block (optionally checkpointed).
        if use_checkpoint and self.training:
            moe_output, router_logits, entropy_loss = checkpoint.checkpoint(
                self._moe_block, normalized_x
            )
        else:
            moe_output, router_logits, entropy_loss = self._moe_block(normalized_x)

        # Second residual connection.
        x = x + self.dropout2(moe_output) * self.residual_scale
        return x, router_logits, entropy_loss


class PositionwiseFeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.linear2(self.dropout(self.relu(self.linear1(x))))


class EncoderLayer_nomoe(nn.Module):
    def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int, dropout: float):
        super().__init__()
        # Pre-layer-norm (Pre-LN) structure for better training stability.
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        self.self_attn = MultiHeadAttention(model_dim, n_heads)
        self.feed_forward = PositionwiseFeedForward(model_dim, ff_hidden_dim, dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-LN before attention.
        normalized_x = self.norm1(x)
        attn_output = self.self_attn(normalized_x, normalized_x, normalized_x,
                                     mask=src_mask, key_padding_mask=src_key_padding_mask)
        # First residual connection.
        x = x + self.dropout1(attn_output)
        # Pre-LN before the feed-forward block.
        normalized_x = self.norm2(x)
        ff_output = self.feed_forward(normalized_x)
        # Second residual connection.
        x = x + self.dropout2(ff_output)
        return x


class PositionalEncoding(nn.Module):
    """Efficient sinusoidal positional encoding."""

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Precompute and cache the positional encodings once.
        pe = torch.zeros(1, max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Even indices get sine, odd indices get cosine.
        pe[0, :, 0::2] = torch.sin(position * div_term)
        pe[0, :, 1::2] = torch.cos(position * div_term)
        # Register as a buffer rather than a parameter (no gradients needed).
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Add positional encodings to the input.

        Args:
            x: input tensor [batch_size, seq_len, model_dim]
        """
        pos_encoding = self.pe[:, :x.size(1)]
        x = x + pos_encoding
        return self.dropout(x)
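
# A minimal sketch exercising EncoderLayer, including the gradient-checkpointing
# path (taken only in training mode). Shapes and hyperparameters below are
# illustrative assumptions.
def _demo_encoder_layer() -> None:
    layer = EncoderLayer(model_dim=64, n_heads=8, ff_hidden_dim=128,
                         dropout=0.1, num_experts=4, top_k=2)
    layer.train()  # use_checkpoint only takes effect while training
    x = torch.randn(2, 10, 64, requires_grad=True)
    out, router_logits, entropy_loss = layer(x, use_checkpoint=True)
    assert out.shape == x.shape
    out.sum().backward()  # checkpointed blocks are recomputed during backward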
class Encoder(nn.Module):
    """Optimized encoder with MoE feed-forward layers."""

    def __init__(self, input_dim: int, model_dim: int, n_heads: int, num_layers: int,
                 ff_hidden_dim: int, dropout: float, num_experts: int, top_k: int,
                 if_embedding: bool = True, if_pos_encoding: bool = True,
                 use_checkpointing: bool = False):
        super().__init__()
        self.model_dim = model_dim
        self.num_layers = num_layers
        self.if_embedding = if_embedding
        self.if_pos_encoding = if_pos_encoding
        self.use_checkpointing = use_checkpointing

        # Embedding layer.
        if if_embedding:
            self.embedding = nn.Embedding(input_dim, model_dim)
            # Scaled-normal initialization improves early training.
            nn.init.normal_(self.embedding.weight, mean=0, std=model_dim ** -0.5)

        # Positional encoding.
        if if_pos_encoding:
            self.pos_encoding = PositionalEncoding(model_dim, dropout)

        # Encoder layers.
        self.layers = nn.ModuleList([
            EncoderLayer(model_dim, n_heads, ff_hidden_dim, dropout, num_experts, top_k)
            for _ in range(num_layers)
        ])

        # Final output normalization.
        self.final_norm = nn.LayerNorm(model_dim)

    def forward(self, src: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None
                ) -> Tuple[torch.Tensor, List, torch.Tensor]:
        """Encoder forward pass.

        Args:
            src: input sequence [batch_size, seq_len] or [batch_size, seq_len, model_dim]
            src_mask: source sequence mask
            src_key_padding_mask: padding mask

        Returns:
            tuple: (output tensor, list of router logits, average entropy loss)
        """
        # Embedding.
        if self.if_embedding:
            x = self.embedding(src) * math.sqrt(self.model_dim)
        else:
            x = src

        # Positional encoding.
        if self.if_pos_encoding:
            x = self.pos_encoding(x)

        # Track entropy losses and router logits across layers.
        total_entropy_loss = 0.0
        router_logits_list = []

        # Pass through the encoder layers.
        for layer in self.layers:
            x, router_logits, entropy_loss = layer(
                x,
                src_mask=src_mask,
                src_key_padding_mask=src_key_padding_mask,
                use_checkpoint=self.use_checkpointing
            )
            total_entropy_loss += entropy_loss
            # Only keep CPU copies of the router logits, and only at inference
            # time, to reduce memory usage.
            if not self.training:
                router_logits_list.append(router_logits.detach().cpu().tolist())

        # Final layer normalization.
        x = self.final_norm(x)

        # Average entropy loss over layers.
        avg_entropy_loss = total_entropy_loss / self.num_layers
        return x, router_logits_list, avg_entropy_loss


class Encoder_nomoe(nn.Module):
    """Encoder variant with standard feed-forward layers instead of MoE."""

    def __init__(self, input_dim: int, model_dim: int, n_heads: int, num_layers: int,
                 ff_hidden_dim: int, dropout: float,
                 if_embedding: bool = True, if_pos_encoding: bool = True):
        super().__init__()
        self.model_dim = model_dim
        self.num_layers = num_layers
        self.if_embedding = if_embedding
        self.if_pos_encoding = if_pos_encoding

        # Embedding layer.
        if if_embedding:
            self.embedding = nn.Embedding(input_dim, model_dim)
            # Scaled-normal initialization improves early training.
            nn.init.normal_(self.embedding.weight, mean=0, std=model_dim ** -0.5)

        # Positional encoding.
        if if_pos_encoding:
            self.pos_encoding = PositionalEncoding(model_dim, dropout)

        # Encoder layers.
        self.layers = nn.ModuleList([
            EncoderLayer_nomoe(model_dim, n_heads, ff_hidden_dim, dropout)
            for _ in range(num_layers)
        ])

        # Final output normalization.
        self.final_norm = nn.LayerNorm(model_dim)

    def forward(self, src: torch.Tensor,
                src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Embedding.
        if self.if_embedding:
            x = self.embedding(src) * math.sqrt(self.model_dim)
        else:
            x = src

        # Positional encoding.
        if self.if_pos_encoding:
            x = self.pos_encoding(x)

        # Pass through the encoder layers.
        for layer in self.layers:
            x = layer(x, src_mask=src_mask, src_key_padding_mask=src_key_padding_mask)

        # Final layer normalization.
        x = self.final_norm(x)
        return x
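
# A minimal end-to-end smoke test for both encoder variants, assuming a toy
# vocabulary of 100 tokens and small illustrative hyperparameters; not part
# of the module's public API.
if __name__ == "__main__":
    torch.manual_seed(0)
    tokens = torch.randint(0, 100, (2, 10))  # [batch_size, seq_len]

    encoder = Encoder(input_dim=100, model_dim=64, n_heads=8, num_layers=2,
                      ff_hidden_dim=128, dropout=0.1, num_experts=4, top_k=2)
    encoder.eval()  # router logits are only collected at inference time
    with torch.no_grad():
        out, router_logits_list, avg_entropy_loss = encoder(tokens)
    print("MoE encoder output:", tuple(out.shape))                        # (2, 10, 64)
    print("layers with stored router logits:", len(router_logits_list))  # 2

    encoder_nomoe = Encoder_nomoe(input_dim=100, model_dim=64, n_heads=8,
                                  num_layers=2, ff_hidden_dim=128, dropout=0.1)
    encoder_nomoe.eval()
    with torch.no_grad():
        out_nomoe = encoder_nomoe(tokens)
    print("plain encoder output:", tuple(out_nomoe.shape))               # (2, 10, 64)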