import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
import math

from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm
from peft_ import LinearWithLoRA, AdapterLayer
from moe import MixtureOfExperts


class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) with optional extras.

    When ``n_kv_heads < n_heads`` the K/V heads are shared across groups of
    query heads (each repeated ``n_rep`` times), shrinking the KV cache.
    Supports LoRA-augmented projections, QK normalization, YARN rotary or
    ALiBi position encoding (mutually exclusive), sliding-window masking,
    KV caching, and a fused SDPA ("flash") fast path.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_flash: bool = True,
        qkv_bias: bool = False,
        use_lora: bool = False,
        lora_rank: int = 8,
        max_seq_len: int = 8192,
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",  # kept for interface compatibility; YARN is the only implementation here
        use_qk_norm: bool = False,
        sliding_window: Optional[int] = None,
        use_alibi: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
        assert n_heads % self.n_kv_heads == 0, \
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = head_dim if head_dim is not None else dim // n_heads
        self.scale = self.head_dim ** -0.5
        # Only take the fused path when this torch build actually provides it.
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
        self.sliding_window = sliding_window

        self.q_proj = LinearWithLoRA(
            dim, n_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.k_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.v_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.o_proj = LinearWithLoRA(
            n_heads * self.head_dim, dim,
            bias=False, use_lora=use_lora, lora_rank=lora_rank
        )

        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
        self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = QKNorm(self.head_dim)
            self.k_norm = QKNorm(self.head_dim)

        self.use_alibi = use_alibi
        if use_alibi:
            # ALiBi replaces rotary embeddings entirely.
            self.register_buffer(
                "alibi_slopes",
                self._get_alibi_slopes(n_heads),
                persistent=False
            )
        else:
            self.rotary_emb = YARNRotaryEmbedding(
                self.head_dim,
                max_seq_len=max_seq_len,
                original_max_len=4096,
                scaling_factor=rope_scaling_factor,
                rope_percentage=1.0
            )

    def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
        """Compute the per-head ALiBi slopes, shaped (n_heads, 1, 1).

        For power-of-two head counts, slopes form a geometric sequence
        starting at 2^(-8/n); otherwise extra slopes are interpolated from
        the next power of two (the standard ALiBi recipe).
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n_heads).is_integer():
            slopes = get_slopes_power_of_2(n_heads)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            # Fill the remaining heads with every other slope of the 2x set.
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
            slopes.extend(extra_slopes[:n_heads - closest_power_of_2])
        return torch.tensor(slopes).view(n_heads, 1, 1)

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads ``n_rep`` times so they line up with the Q heads."""
        if self.n_rep == 1:
            return x
        B, n_kv_heads, seq_len, head_dim = x.shape
        return x[:, :, None, :, :].expand(
            B, n_kv_heads, self.n_rep, seq_len, head_dim
        ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)

    def _apply_sliding_window_mask(
        self,
        attn_scores: torch.Tensor,
        seq_len: int
    ) -> torch.Tensor:
        """Mask keys outside the sliding window.

        ``attn_scores`` is (B, H, q_len, seq_len). The band mask is built over
        (seq_len, seq_len) and only its last ``q_len`` rows are kept, so the
        window stays aligned when a KV cache makes q_len < seq_len.
        """
        if self.sliding_window is None or seq_len <= self.sliding_window:
            return attn_scores
        mask = torch.ones(seq_len, seq_len, device=attn_scores.device, dtype=torch.bool)
        # Allowed band: key j attends from query i iff i - window < j <= i.
        mask = torch.triu(mask, diagonal=-self.sliding_window + 1)
        mask = torch.tril(mask, diagonal=0)
        # Bugfix: align rows to the current queries (bottom of the band) so the
        # mask broadcasts correctly when q_len != seq_len (cached decoding).
        mask = mask[-attn_scores.size(-2):, :]
        return attn_scores.masked_fill(~mask, float('-inf'))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Attend over ``x`` of shape (B, T, C).

        Returns ``(output, present_kv, attn_weights)``. ``present_kv`` is set
        only when ``use_cache``; ``attn_weights`` only when
        ``output_attentions`` (which forces the explicit, non-fused path).
        """
        B, T, C = x.shape

        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            q_shape, k_shape = q.shape, k.shape
            q = self.q_norm.query_norm(q.view(-1, self.head_dim)).view(q_shape)
            k = self.k_norm.key_norm(k.view(-1, self.head_dim)).view(k_shape)

        if not self.use_alibi:
            q, k = self.rotary_emb(q, k, position_ids)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        present_kv = (k, v) if use_cache else None

        k = self.repeat_kv(k)
        v = self.repeat_kv(v)
        seq_len_k = k.size(2)

        # Fused SDPA fast path. Bugfix vs. the original: the fused kernel
        # cannot express ALiBi or sliding-window biases, and its is_causal
        # mask is top-left aligned — wrong once a KV cache makes q shorter
        # than k — so those cases now fall through to the explicit path.
        use_sdpa = (
            self.use_flash
            and not output_attentions
            and attention_mask is None
            and not self.use_alibi
            and self.sliding_window is None
            and (past_kv is None or T == 1)
        )
        if use_sdpa:
            dropout_p = (
                self.attn_dropout.p
                if isinstance(self.attn_dropout, nn.Dropout) and self.training
                else 0.0
            )
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=dropout_p,
                # With a cache and T == 1 the single query legitimately sees
                # every cached key, so no causal mask is required.
                is_causal=past_kv is None
            )
            attention_weights = None
        else:
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale

            if self.use_alibi:
                # Per-key linear bias slope * j; per-row softmax shift
                # invariance makes this equivalent to the relative form
                # slope * (j - i) under causal masking.
                position_bias = self.alibi_slopes.to(x.device) * torch.arange(
                    seq_len_k, device=x.device
                ).view(1, 1, -1)
                attn_scores = attn_scores + position_bias

            if self.sliding_window is not None:
                attn_scores = self._apply_sliding_window_mask(attn_scores, seq_len_k)

            if attention_mask is not None:
                if attention_mask.dim() == 2:
                    # (B, seq) padding mask -> broadcastable (B, 1, 1, seq).
                    attention_mask = attention_mask[:, None, None, :]
                # Bugfix: accept any floating dtype (fp16/bf16) as an additive
                # mask; the original only recognized torch.float32 and would
                # mangle half-precision additive masks.
                if not torch.is_floating_point(attention_mask):
                    extended_mask = (1.0 - attention_mask) * torch.finfo(attn_scores.dtype).min
                else:
                    extended_mask = attention_mask
                attn_scores = attn_scores + extended_mask

            if seq_len_k > 1:
                causal_mask = torch.triu(
                    torch.ones(seq_len_k, seq_len_k, device=x.device, dtype=torch.bool),
                    diagonal=1
                )
                # Keep only the rows for the current queries (bottom-right
                # alignment) so cached decoding is not over-masked.
                causal_mask = causal_mask[-q.shape[2]:, :]
                attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

            # Softmax in fp32 for numerical stability, then back to q's dtype.
            attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attention_weights = self.attn_dropout(attention_weights)
            attn_output = attention_weights @ v

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        output = self.resid_dropout(self.o_proj(attn_output))
        return output, present_kv, attention_weights if output_attentions else None


class OptimizedTransformerBlock(nn.Module):
    """Pre-norm transformer block: GQA attention + SwiGLU or MoE FFN.

    Optionally uses a parallel residual (attention and FFN both read the
    same normalized input, GPT-J style), a bottleneck adapter, and LoRA
    projections. The most recent MoE auxiliary (load-balancing) loss is
    stashed on ``self.moe_aux_loss`` for the trainer to collect.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        use_parallel_residual: bool = False,
        norm_eps: float = 1e-6,
        sliding_window: Optional[int] = None,
        ffn_dim_multiplier: Optional[float] = None,
        layer_idx: int = 0
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_moe = use_moe
        self.use_adapter = use_adapter
        self.use_parallel_residual = use_parallel_residual

        self.attention = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            attn_dropout=attn_dropout,
            use_lora=use_lora,
            lora_rank=lora_rank,
            sliding_window=sliding_window,
            rope_scaling_type="yarn"
        )

        if use_moe:
            self.ffn = MixtureOfExperts(
                dim=dim,
                num_experts=num_experts,
                top_k=moe_top_k,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )
        else:
            self.ffn = SwiGLU(
                dim=dim,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )

        if use_adapter:
            self.adapter = AdapterLayer(dim, adapter_dim, dropout)

        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        # Placeholder until the first forward; overwritten every call.
        self.moe_aux_loss = torch.tensor(0.0)

    def _run_ffn(self, ffn_input: torch.Tensor, device: torch.device) -> torch.Tensor:
        """Run the FFN and record the MoE aux loss (zero for the dense FFN)."""
        if self.use_moe:
            ffn_out, aux_loss = self.ffn(ffn_input)
            self.moe_aux_loss = aux_loss
        else:
            ffn_out = self.ffn(ffn_input)
            self.moe_aux_loss = torch.tensor(0.0, device=device)
        return ffn_out

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Run one transformer block over ``x`` (B, T, C).

        Returns ``(hidden_states, present_kv, attn_weights)`` with the same
        cache/attention semantics as ``GroupedQueryAttention.forward``.
        """
        attn_out, present_kv, attn_weights = self.attention(
            self.attention_norm(x),
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
            output_attentions=output_attentions
        )

        if self.use_parallel_residual:
            # NOTE(review): the adapter is skipped on the parallel-residual
            # path in the original code; preserved as-is — confirm intended.
            ffn_out = self._run_ffn(self.ffn_norm(x), x.device)
            x = x + attn_out + ffn_out
        else:
            x = x + attn_out
            if self.use_adapter:
                x = self.adapter(x)
            x = x + self._run_ffn(self.ffn_norm(x), x.device)

        return x, present_kv, attn_weights