# RNARL / transformer_encoder_MoE.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch.nn.parallel import parallel_apply
from typing import Tuple, List, Optional, Union
import torch.utils.checkpoint as checkpoint
class MultiHeadAttention(nn.Module):
"""高效实现的多头注意力机制"""
def __init__(self, model_dim: int, n_heads: int):
super().__init__()
assert model_dim % n_heads == 0, "model_dim must be divisible by n_heads"
self.model_dim = model_dim
self.d_k = model_dim // n_heads
self.n_heads = n_heads
        # A single linear layer computes the Q, K and V projections together to reduce overhead
self.qkv_linear = nn.Linear(model_dim, 3 * model_dim, bias=False)
self.out_linear = nn.Linear(model_dim, model_dim, bias=False)
        # Initialize parameters for better training stability
nn.init.xavier_uniform_(self.qkv_linear.weight)
nn.init.xavier_uniform_(self.out_linear.weight)
self.scale = 1.0 / math.sqrt(self.d_k)
def forward(self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
batch_size = q.size(0)
        # If Q, K and V are the same tensor, use the faster fused self-attention path
is_self_attention = q.data_ptr() == k.data_ptr() == v.data_ptr()
if is_self_attention:
# [batch, seq, 3*dim] -> 3 x [batch, seq, dim]
qkv = self.qkv_linear(q).chunk(3, dim=-1)
q, k, v = map(
lambda x: x.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2),
qkv
)
else:
            # Cross-attention: apply the fused projection to each input and keep only the relevant slice
q = self.qkv_linear(q)[:, :, :self.model_dim]
k = self.qkv_linear(k)[:, :, self.model_dim:2*self.model_dim]
v = self.qkv_linear(v)[:, :, 2*self.model_dim:]
q = q.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
k = k.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
v = v.view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        # Scaled dot-product attention scores
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        # Apply masks with a large negative (but finite) value for numerical stability
if mask is not None:
scores = scores.masked_fill(mask == 0, -6.0e4)
if key_padding_mask is not None:
scores = scores.masked_fill(key_padding_mask.unsqueeze(1).unsqueeze(2), -6.0e4)
attn_weights = F.softmax(scores, dim=-1)
        # Aggregate the values with the attention weights
context = torch.matmul(attn_weights, v)
        # Merge the heads and project the output
context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.model_dim)
output = self.out_linear(context)
return output
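
# The following helper is an illustrative, hypothetical usage sketch (not part of the original
# module): it shows the mask conventions MultiHeadAttention expects, i.e. positions where
# `mask == 0` are blocked and positions where `key_padding_mask` is True are padded keys.
def _demo_multi_head_attention() -> None:
    torch.manual_seed(0)
    attn = MultiHeadAttention(model_dim=16, n_heads=4)
    x = torch.randn(2, 5, 16)  # [batch, seq, model_dim]
    key_padding_mask = torch.zeros(2, 5, dtype=torch.bool)
    key_padding_mask[:, 3:] = True  # mark the last two positions of each sequence as padding
    out = attn(x, x, x, key_padding_mask=key_padding_mask)
    assert out.shape == (2, 5, 16)  # output keeps the input shape
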
class MoE(nn.Module):
"""优化的混合专家模块,支持并行计算和更高效的专家选择
Args:
d_model (int): 模型隐藏层维度
num_experts (int): 专家数量
d_ff (int): 前馈层维度
dropout (float): Dropout概率
top_k (int): 每个token选择的专家数量
"""
def __init__(self, d_model: int, num_experts: int, d_ff: int, dropout: float, top_k: int):
super().__init__()
        # Store configuration
        self.num_experts = num_experts
        self.top_k = min(top_k, num_experts)  # top_k can never exceed the number of experts
        self.d_model = d_model
        # Gating network: maps each token to per-expert scores [d_model -> num_experts]
        self.gate = nn.Linear(d_model, num_experts, bias=False)
        # Expert networks: a list of independent feed-forward experts run in parallel
self.experts = nn.ModuleList([
            nn.Sequential(  # each expert:
                nn.Linear(d_model, d_ff, bias=False),  # [d_model -> d_ff]
                nn.GELU(),                             # activation, no shape change
                nn.Dropout(dropout),                   # no shape change
                nn.Linear(d_ff, d_model, bias=False)   # [d_ff -> d_model]
) for _ in range(num_experts)
])
        # Weight initialization
        for expert in self.experts:
            nn.init.kaiming_uniform_(expert[0].weight, a=math.sqrt(5))  # first linear layer
            nn.init.zeros_(expert[3].weight)  # output layer starts at zero, weight shape [d_model, d_ff]
        nn.init.zeros_(self.gate.weight)  # gate starts at zero so routing begins uniform, weight shape [num_experts, d_model]
def orthogonal_loss(self) -> torch.Tensor:
"""计算专家网络之间的正交损失,提高专家多样性
Returns:
torch.Tensor: 正交损失标量值
"""
total_loss = 0.0
num_pairs = 0
        # Stack the first and last linear weights of every expert
        # expert_weights_1 shape: [num_experts, d_ff, d_model]
        expert_weights_1 = torch.stack([expert[0].weight for expert in self.experts])
        # expert_weights_2 shape: [num_experts, d_model, d_ff]
        expert_weights_2 = torch.stack([expert[3].weight for expert in self.experts])
        # Accumulate the pairwise orthogonality loss over all expert pairs
for i in range(self.num_experts):
w1_i = expert_weights_1[i] # [d_ff, d_model]
w2_i = expert_weights_2[i] # [d_model, d_ff]
for j in range(i+1, self.num_experts):
w1_j = expert_weights_1[j] # [d_ff, d_model]
w2_j = expert_weights_2[j] # [d_model, d_ff]
                # Similarity between the first-layer weights
                w1_sim = torch.sum((w1_i @ w1_j.T)**2) / (w1_i.size(0) * w1_j.size(0))  # scalar
                # Similarity between the second-layer weights
                w2_sim = torch.sum((w2_i.T @ w2_j)**2) / (w2_i.size(1) * w2_j.size(1))  # scalar
                total_loss += (w1_sim + w2_sim) / 2  # average of the two similarities
                num_pairs += 1
        return total_loss / max(num_pairs, 1)  # mean orthogonality loss over all pairs
def entropy_regularization_loss(self, routing_probs: torch.Tensor) -> torch.Tensor:
"""计算熵正则化损失,鼓励更均匀的路由分布
Args:
routing_probs (torch.Tensor): 路由概率分布,形状 [batch*seq_len, num_experts]
Returns:
torch.Tensor: 熵损失标量值
"""
# 使用数值稳定的log计算
log_probs = torch.log(torch.clamp(routing_probs, min=1e-6)) # 保持形状 [batch*seq, num_experts]
# 逐元素计算熵,保持维度
entropy = -torch.sum(routing_probs * log_probs, dim=-1) # 形状 [batch*seq]
return entropy.mean() # 标量
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""MoE前向传播,高效实现专家选择和组合
Args:
hidden_states (torch.Tensor): 输入张量,形状 [batch_size, seq_len, d_model]
Returns:
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- 输出张量 [batch_size, seq_len, d_model]
- 路由逻辑分数 [batch_size*seq_len, num_experts]
- 熵正则化损失标量值
"""
batch_size, seq_len, d_model = hidden_states.shape
combined_batch_size = batch_size * seq_len
        # Flatten the input so every token is routed independently
        flat_hidden = hidden_states.reshape(combined_batch_size, d_model)  # [batch*seq, d_model]
        # Routing
        router_logits = self.gate(flat_hidden)  # [batch*seq, num_experts]
        routing_probs = F.softmax(router_logits, dim=-1)  # [batch*seq, num_experts]
        # Select the top-k experts per token
        routing_weights, selected_experts = torch.topk(routing_probs, self.top_k, dim=-1)  # both [batch*seq, top_k]
        # Renormalize the selected weights so they sum to 1
        routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)  # [batch*seq, top_k]
        # Run all experts in parallel on the full token batch
        flat_expert_inputs = [flat_hidden] * self.num_experts  # num_experts copies of [batch*seq, d_model]
        expert_outputs = parallel_apply(self.experts, flat_expert_inputs)  # list of num_experts tensors [batch*seq, d_model]
expert_outputs = torch.stack(expert_outputs, dim=1) # [batch*seq, num_experts, d_model]
        # Build a dense per-token expert weight matrix
expert_weights_matrix = torch.zeros(
combined_batch_size, self.num_experts, device=hidden_states.device
) # [batch*seq, num_experts]
        # Scatter the top-k routing weights into the dense matrix
for k in range(self.top_k):
k_indices = selected_experts[:, k] # [batch*seq]
k_weights = routing_weights[:, k].unsqueeze(1) # [batch*seq, 1]
            # Accumulate this expert's weight at the selected index
expert_weights_matrix.scatter_add_(
1,
k_indices.unsqueeze(1), # [batch*seq, 1]
k_weights # [batch*seq, 1]
            )  # updates expert_weights_matrix in place
        # Combine expert outputs as a weighted sum via batched matrix multiplication
combined_output = torch.bmm(
expert_weights_matrix.unsqueeze(1), # [batch*seq, 1, num_experts]
expert_outputs # [batch*seq, num_experts, d_model]
).squeeze(1) # [batch*seq, d_model]
        # Restore the original shape
output = combined_output.reshape(batch_size, seq_len, d_model) # [batch_size, seq_len, d_model]
        # Entropy regularization loss
entropy_loss = self.entropy_regularization_loss(routing_probs)
return output, router_logits, entropy_loss
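
# Illustrative, hypothetical sketch (not part of the original module): it reproduces the
# routing arithmetic used in MoE.forward on toy data, showing that each token's top-k
# routing weights are renormalized to sum to one. All sizes are made-up example values.
def _demo_moe_routing() -> None:
    torch.manual_seed(0)
    moe = MoE(d_model=16, num_experts=4, d_ff=32, dropout=0.1, top_k=2)
    flat_hidden = torch.randn(2 * 5, 16)  # 10 tokens of width d_model
    router_logits = moe.gate(flat_hidden)             # [tokens, num_experts]
    routing_probs = F.softmax(router_logits, dim=-1)
    routing_weights, selected_experts = torch.topk(routing_probs, moe.top_k, dim=-1)
    routing_weights = routing_weights / routing_weights.sum(dim=-1, keepdim=True)
    assert torch.allclose(routing_weights.sum(dim=-1), torch.ones(2 * 5))
    assert selected_experts.shape == (2 * 5, moe.top_k)
    # The auxiliary losses are plain scalars
    assert moe.entropy_regularization_loss(routing_probs).dim() == 0
    assert moe.orthogonal_loss().dim() == 0
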
class EncoderLayer(nn.Module):
"""优化的编码器层,支持梯度检查点和残差连接预归一化"""
def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
dropout: float, num_experts: int, top_k: int):
super().__init__()
self.model_dim = model_dim
        # Pre-LayerNorm (Pre-LN) structure for better training stability
self.norm1 = nn.LayerNorm(model_dim)
self.norm2 = nn.LayerNorm(model_dim)
self.self_attn = MultiHeadAttention(model_dim, n_heads)
self.moe = MoE(model_dim, num_experts, ff_hidden_dim, dropout, top_k)
self.dropout = nn.Dropout(dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
        # A projection for mismatched residual dimensions is not used here; instead the residual branch is scaled by a learnable factor
self.use_projection = False
if not self.use_projection:
self.residual_scale = nn.Parameter(torch.ones(1))
def _sa_block(self, x: torch.Tensor,
mask: Optional[torch.Tensor] = None,
key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""封装自注意力计算,便于梯度检查点使用"""
x = self.self_attn(x, x, x, mask=mask, key_padding_mask=key_padding_mask)
return self.dropout1(x)
def _moe_block(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""封装MoE计算,便于梯度检查点使用"""
return self.moe(x)
def forward(self,
x: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None,
use_checkpoint: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
编码器层前向传播
Args:
x: 输入张量 [batch_size, seq_len, model_dim]
src_mask: 源序列掩码
src_key_padding_mask: 填充掩码
use_checkpoint: 是否使用梯度检查点以节省内存
"""
        # Pre-LN: normalize before the sub-layer
        normalized_x = self.norm1(x)
        # Self-attention block (optionally gradient-checkpointed)
        if use_checkpoint and self.training:
            attn_output = checkpoint.checkpoint(
                self._sa_block, normalized_x, src_mask, src_key_padding_mask,
                use_reentrant=False  # passed explicitly per current PyTorch guidance; assumes torch >= 1.11
            )
else:
attn_output = self._sa_block(normalized_x, src_mask, src_key_padding_mask)
        # First residual connection
x = x + attn_output * self.residual_scale
        # Pre-LN before the MoE block
normalized_x = self.norm2(x)
        # MoE block (optionally gradient-checkpointed)
        if use_checkpoint and self.training:
            moe_output, router_logits, entropy_loss = checkpoint.checkpoint(
                self._moe_block, normalized_x,
                use_reentrant=False  # passed explicitly per current PyTorch guidance; assumes torch >= 1.11
            )
else:
moe_output, router_logits, entropy_loss = self._moe_block(normalized_x)
        # Second residual connection
x = x + self.dropout2(moe_output) * self.residual_scale
return x, router_logits, entropy_loss
class PositionwiseFeedForward(nn.Module):
    """Standard position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear."""
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
self.linear1 = nn.Linear(d_model, d_ff)
self.linear2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(dropout)
self.relu = nn.ReLU()
def forward(self, x):
return self.linear2(self.dropout(self.relu(self.linear1(x))))
class EncoderLayer_nomoe(nn.Module):
    """Pre-LN encoder layer with a standard dense feed-forward network instead of MoE."""
def __init__(self, model_dim: int, n_heads: int, ff_hidden_dim: int,
dropout: float):
super().__init__()
        # Pre-LayerNorm (Pre-LN) structure for better training stability
self.norm1 = nn.LayerNorm(model_dim)
self.norm2 = nn.LayerNorm(model_dim)
self.self_attn = MultiHeadAttention(model_dim, n_heads)
self.feed_forward = PositionwiseFeedForward(model_dim, ff_hidden_dim, dropout)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
def forward(self,
x: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Pre-LN structure
        normalized_x = self.norm1(x)
        attn_output = self.self_attn(normalized_x, normalized_x, normalized_x,
                                     mask=src_mask, key_padding_mask=src_key_padding_mask)
        # First residual connection
        x = x + self.dropout1(attn_output)
        # Pre-LN before the feed-forward block
        normalized_x = self.norm2(x)
        ff_output = self.feed_forward(normalized_x)
        # Second residual connection
x = x + self.dropout2(ff_output)
return x
class PositionalEncoding(nn.Module):
"""高效实现的位置编码"""
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
        # Precompute and cache the positional encoding table once
pe = torch.zeros(1, max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        # Vectorized computation: sine for even dimensions, cosine for odd dimensions
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
        # Register as a buffer (not a parameter): saved with the module but never trained
self.register_buffer('pe', pe)
def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
添加位置编码到输入
Args:
x: 输入张量 [batch_size, seq_len, model_dim]
"""
pos_encoding = self.pe[:, :x.size(1)]
x = x + pos_encoding
return self.dropout(x)
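
# Illustrative, hypothetical sketch (not part of the original module): the sinusoidal table is
# stored as a buffer, so it travels with state_dict() and .to(device) but is never updated by
# the optimizer; the forward pass simply adds the first seq_len rows of the table to the input.
def _demo_positional_encoding() -> None:
    torch.manual_seed(0)
    pos_enc = PositionalEncoding(d_model=16, dropout=0.0, max_len=64)
    x = torch.zeros(2, 10, 16)  # [batch, seq, model_dim]
    out = pos_enc(x)
    assert out.shape == (2, 10, 16)
    assert len(list(pos_enc.parameters())) == 0    # nothing trainable
    assert "pe" in pos_enc.state_dict()            # but the table is saved
    assert torch.allclose(out, pos_enc.pe[:, :10])  # with zero input, the output is the table itself
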
class Encoder(nn.Module):
"""优化的Encoder架构"""
def __init__(self,
input_dim: int,
model_dim: int,
n_heads: int,
num_layers: int,
ff_hidden_dim: int,
dropout: float,
num_experts: int,
top_k: int,
if_embedding: bool = True,
if_pos_encoding: bool = True,
use_checkpointing: bool = False):
super().__init__()
self.model_dim = model_dim
self.num_layers = num_layers
self.if_embedding = if_embedding
self.if_pos_encoding = if_pos_encoding
self.use_checkpointing = use_checkpointing
        # Token embedding
        if if_embedding:
            self.embedding = nn.Embedding(input_dim, model_dim)
            # Scaled-normal initialization for the embedding table
            nn.init.normal_(self.embedding.weight, mean=0, std=model_dim**-0.5)
        # Positional encoding
        if if_pos_encoding:
            self.pos_encoding = PositionalEncoding(model_dim, dropout)
        # Encoder layers
self.layers = nn.ModuleList([
EncoderLayer(
model_dim, n_heads, ff_hidden_dim, dropout, num_experts, top_k
) for _ in range(num_layers)
])
        # Final output normalization
self.final_norm = nn.LayerNorm(model_dim)
def forward(self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
src_key_padding_mask: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, List, float]:
"""
编码器前向传播
Args:
src: 输入序列 [batch_size, seq_len] 或 [batch_size, seq_len, model_dim]
src_mask: 源序列掩码
src_key_padding_mask: 填充掩码
Returns:
tuple: (输出张量, 路由逻辑列表, 熵损失)
"""
        # Embedding
if self.if_embedding:
x = self.embedding(src) * math.sqrt(self.model_dim)
else:
x = src
        # Positional encoding
if self.if_pos_encoding:
x = self.pos_encoding(x)
        # Track the entropy loss and router logits across layers
total_entropy_loss = 0.0
router_logits_list = []
        # Pass through the encoder layers
for layer in self.layers:
x, router_logits, entropy_loss = layer(
x,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask,
use_checkpoint=self.use_checkpointing
)
total_entropy_loss += entropy_loss
            # Keep router logits on the CPU to reduce GPU memory use
            if not self.training:  # only collect routing information at inference time
                router_logits_list.append(router_logits.detach().cpu().tolist())
        # Final layer normalization
x = self.final_norm(x)
        # Average entropy loss over layers
avg_entropy_loss = total_entropy_loss / self.num_layers
return x, router_logits_list, avg_entropy_loss
class Encoder_nomoe(nn.Module):
"""优化的Encoder架构"""
def __init__(self,
input_dim: int,
model_dim: int,
n_heads: int,
num_layers: int,
ff_hidden_dim: int,
dropout: float,
if_embedding: bool = True,
if_pos_encoding: bool = True):
super().__init__()
self.model_dim = model_dim
self.num_layers = num_layers
self.if_embedding = if_embedding
self.if_pos_encoding = if_pos_encoding
        # Token embedding
        if if_embedding:
            self.embedding = nn.Embedding(input_dim, model_dim)
            # Scaled-normal initialization for the embedding table
            nn.init.normal_(self.embedding.weight, mean=0, std=model_dim**-0.5)
        # Positional encoding
        if if_pos_encoding:
            self.pos_encoding = PositionalEncoding(model_dim, dropout)
        # Encoder layers
self.layers = nn.ModuleList([
EncoderLayer_nomoe(
model_dim, n_heads, ff_hidden_dim, dropout
) for _ in range(num_layers)
])
        # Final output normalization
self.final_norm = nn.LayerNorm(model_dim)
def forward(self,
src: torch.Tensor,
src_mask: Optional[torch.Tensor] = None,
                src_key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Embedding
if self.if_embedding:
x = self.embedding(src) * math.sqrt(self.model_dim)
else:
x = src
        # Positional encoding
if self.if_pos_encoding:
x = self.pos_encoding(x)
        # Pass through the encoder layers
for layer in self.layers:
x = layer(
x,
src_mask=src_mask,
src_key_padding_mask=src_key_padding_mask
)
        # Final layer normalization
x = self.final_norm(x)
return x
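
# Hypothetical smoke test, not part of the original training pipeline: it instantiates both
# encoder variants with made-up toy hyperparameters and runs one forward pass each. The
# MoE encoder relies on torch.nn.parallel.parallel_apply, which is primarily intended for
# CUDA, so that part is only attempted when a GPU is available.
if __name__ == "__main__":
    torch.manual_seed(0)
    _demo_multi_head_attention()
    _demo_moe_routing()
    _demo_positional_encoding()

    vocab_size, seq_len, batch_size = 100, 12, 2
    tokens = torch.randint(0, vocab_size, (batch_size, seq_len))
    padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    padding_mask[:, -2:] = True  # treat the last two positions as padding

    # Dense (no-MoE) encoder runs everywhere
    encoder_nomoe = Encoder_nomoe(
        input_dim=vocab_size, model_dim=32, n_heads=4, num_layers=2,
        ff_hidden_dim=64, dropout=0.1,
    )
    encoder_nomoe.eval()
    with torch.no_grad():
        out = encoder_nomoe(tokens, src_key_padding_mask=padding_mask)
    print("Encoder_nomoe output:", tuple(out.shape))

    # MoE encoder (GPU only, see note above)
    if torch.cuda.is_available():
        encoder = Encoder(
            input_dim=vocab_size, model_dim=32, n_heads=4, num_layers=2,
            ff_hidden_dim=64, dropout=0.1, num_experts=4, top_k=2,
        ).cuda()
        encoder.eval()
        with torch.no_grad():
            out, router_logits_list, entropy_loss = encoder(
                tokens.cuda(), src_key_padding_mask=padding_mask.cuda()
            )
        print("Encoder (MoE) output:", tuple(out.shape), "entropy loss:", float(entropy_loss))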