| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Optional, Tuple, List | |
| import math | |
| from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm | |
| from peft_ import LinearWithLoRA, AdapterLayer | |
| from moe import MixtureOfExperts | |
class GroupedQueryAttention(nn.Module):
    """Grouped-query attention (GQA) with optional ALiBi, sliding window, QK-norm and LoRA.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; every KV head is
    repeated ``n_heads // n_kv_heads`` times before the attention product.
    Positions are encoded with ``YARNRotaryEmbedding`` by default, or with ALiBi
    linear biases when ``use_alibi=True`` (no rotary module is built then).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_flash: bool = True,
        qkv_bias: bool = False,
        use_lora: bool = False,
        lora_rank: int = 8,
        max_seq_len: int = 8192,
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        use_qk_norm: bool = False,
        sliding_window: Optional[int] = None,
        use_alibi: bool = False
    ):
        """Build the projections and the positional-encoding module.

        Args:
            dim: model (input and output) width.
            n_heads: number of query heads.
            n_kv_heads: number of key/value heads; defaults to ``n_heads``.
                Must divide ``n_heads``.
            head_dim: per-head width; defaults to ``dim // n_heads``.
            dropout: dropout applied after the output projection.
            attn_dropout: dropout applied to the attention probabilities.
            use_flash: use ``F.scaled_dot_product_attention`` when it exists and
                yields the exact result (see ``_can_use_sdpa``).
            qkv_bias: add a bias to the q/k/v projections.
            use_lora, lora_rank: forwarded to every ``LinearWithLoRA`` projection.
            max_seq_len, rope_scaling_factor: forwarded to ``YARNRotaryEmbedding``.
            rope_scaling_type: stored on the module but currently not used;
                ``YARNRotaryEmbedding`` is always built when ALiBi is off.
            use_qk_norm: normalize queries and keys per head with ``QKNorm``.
            sliding_window: if set, each query attends only to the most recent
                ``sliding_window`` keys (itself included).
            use_alibi: use ALiBi biases instead of rotary embeddings.

        Raises:
            AssertionError: if ``n_heads`` is not divisible by ``n_kv_heads``.
        """
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
        assert n_heads % self.n_kv_heads == 0, \
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = head_dim if head_dim is not None else dim // n_heads
        self.scale = self.head_dim ** -0.5
        # SDPA only exists in newer PyTorch; fall back to the eager path otherwise.
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
        self.sliding_window = sliding_window
        # NOTE(review): not consumed anywhere yet; kept so callers can inspect it.
        self.rope_scaling_type = rope_scaling_type
        self.q_proj = LinearWithLoRA(
            dim, n_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.k_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.v_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.o_proj = LinearWithLoRA(
            n_heads * self.head_dim, dim,
            bias=False, use_lora=use_lora, lora_rank=lora_rank
        )
        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
        self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()
        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = QKNorm(self.head_dim)
            self.k_norm = QKNorm(self.head_dim)
        self.use_alibi = use_alibi
        if use_alibi:
            # Non-persistent: recomputed at construction, not saved in state_dict.
            self.register_buffer(
                "alibi_slopes",
                self._get_alibi_slopes(n_heads),
                persistent=False
            )
        else:
            # NOTE(review): original_max_len is hard-coded to 4096.
            self.rotary_emb = YARNRotaryEmbedding(
                self.head_dim,
                max_seq_len=max_seq_len,
                original_max_len=4096,
                scaling_factor=rope_scaling_factor,
                rope_percentage=1.0
            )

    def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
        """Compute the per-head ALiBi slopes, shaped ``(n_heads, 1, 1)``.

        For a power-of-two head count the slopes form the geometric sequence
        ``start, start**2, ...`` with ``start = 2 ** (-8 / n_heads)``. Otherwise
        the slopes of the closest smaller power of two are used, topped up with
        every other slope of the next power of two.
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]
        if math.log2(n_heads).is_integer():
            slopes = get_slopes_power_of_2(n_heads)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
            slopes.extend(extra_slopes[:n_heads - closest_power_of_2])
        return torch.tensor(slopes).view(n_heads, 1, 1)

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat KV heads so they line up with the query heads.

        ``(B, n_kv_heads, S, D)`` -> ``(B, n_kv_heads * n_rep, S, D)``; each KV
        head is repeated ``n_rep`` times consecutively.
        """
        if self.n_rep == 1:
            return x
        B, n_kv_heads, seq_len, head_dim = x.shape
        return x[:, :, None, :, :].expand(
            B, n_kv_heads, self.n_rep, seq_len, head_dim
        ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)

    def _apply_sliding_window_mask(
        self,
        attn_scores: torch.Tensor,
        seq_len: int
    ) -> torch.Tensor:
        """Mask out keys outside each query's causal sliding window.

        Args:
            attn_scores: scores whose last two dims are ``(q_len, seq_len)``.
                The queries are taken to be the LAST ``q_len`` of the ``seq_len``
                positions, which is the layout produced by a KV cache.
            seq_len: number of key positions (the last dim of ``attn_scores``).

        Returns:
            Scores of the same shape with disallowed entries set to ``-inf``.
            Query at absolute position ``i`` keeps keys ``j`` with
            ``i - sliding_window < j <= i``.
        """
        if self.sliding_window is None or seq_len <= self.sliding_window:
            return attn_scores
        q_len = attn_scores.size(-2)
        device = attn_scores.device
        # Absolute positions, so the mask is right whether or not a cache is used
        # (a square (seq_len, seq_len) mask would broadcast scores to a wrong shape).
        q_pos = torch.arange(seq_len - q_len, seq_len, device=device).unsqueeze(-1)
        k_pos = torch.arange(seq_len, device=device).unsqueeze(0)
        allowed = (k_pos <= q_pos) & (k_pos > q_pos - self.sliding_window)
        return attn_scores.masked_fill(~allowed, float('-inf'))

    def _can_use_sdpa(
        self,
        q_len: int,
        kv_len: int,
        attention_mask: Optional[torch.Tensor],
        output_attentions: bool
    ) -> bool:
        """Return True when SDPA with at most a plain causal mask gives the exact result."""
        if not self.use_flash or output_attentions or attention_mask is not None:
            return False
        # ALiBi biases and the sliding window are only implemented in the eager path.
        if self.use_alibi:
            return False
        if self.sliding_window is not None and kv_len > self.sliding_window:
            return False
        # SDPA's is_causal mask is aligned to the top-left corner, so it is only
        # correct for q_len == kv_len. A single decoding query needs no mask.
        return q_len == kv_len or q_len == 1

    def _eager_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Explicit softmax attention; returns ``(attn_output, attention_weights)``.

        ``q`` is ``(B, H, q_len, D)``; ``k``/``v`` are ``(B, H, kv_len, D)`` after
        KV-head repetition. Queries are the last ``q_len`` key positions.
        """
        q_len, kv_len = q.size(2), k.size(2)
        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        if self.use_alibi:
            # Bias slope_h * key_index. Softmax is shift-invariant per row, so this
            # equals -slope_h * (i - j) on the causally visible keys.
            position_bias = self.alibi_slopes.to(q.device) * torch.arange(
                kv_len, device=q.device
            ).view(1, 1, -1)
            attn_scores = attn_scores + position_bias
        if self.sliding_window is not None:
            attn_scores = self._apply_sliding_window_mask(attn_scores, kv_len)
        if attention_mask is not None:
            if attention_mask.dim() == 2:
                # (B, kv_len) -> (B, 1, 1, kv_len), broadcast over heads and queries.
                attention_mask = attention_mask[:, None, None, :]
            if attention_mask.dtype != torch.float:
                # Keep-mask (1 = attend, 0 = masked). Bool masks cannot be used in
                # subtraction, so promote them to float32 first.
                keep = attention_mask.float() if attention_mask.dtype == torch.bool else attention_mask
                attention_mask = (1.0 - keep) * torch.finfo(attn_scores.dtype).min
            # float32 masks are treated as already-additive biases.
            # NOTE(review): fp16/bf16 masks are treated as keep-masks, not biases.
            attn_scores = attn_scores + attention_mask
        if q_len > 1:
            # Bottom-right aligned causal mask, built directly at (q_len, kv_len).
            causal_mask = torch.ones(
                q_len, kv_len, device=q.device, dtype=torch.bool
            ).triu(diagonal=kv_len - q_len + 1)
            attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))
        attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
        attention_weights = self.attn_dropout(attention_weights)
        return attention_weights @ v, attention_weights

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Run causal grouped-query attention.

        Args:
            x: input of shape ``(B, T, dim)``.
            attention_mask: ``(B, kv_len)`` or a tensor broadcastable to
                ``(B, n_heads, T, kv_len)``. A non-float32 mask is a keep-mask
                (1 = attend); a float32 mask is added to the scores.
            position_ids: forwarded to the rotary embedding (unused with ALiBi).
            use_cache: return the concatenated (un-repeated) keys/values.
            past_kv: cached ``(k, v)``, each ``(B, n_kv_heads, past_len, head_dim)``.
            output_attentions: return attention probabilities (forces eager path).

        Returns:
            ``(output (B, T, dim), present_kv or None, attention_weights or None)``.
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        if self.use_qk_norm:
            # q/k are non-contiguous after transpose(1, 2); view() would raise here.
            q = self.q_norm.query_norm(q.reshape(-1, self.head_dim)).reshape(q.shape)
            k = self.k_norm.key_norm(k.reshape(-1, self.head_dim)).reshape(k.shape)
        if not self.use_alibi:
            q, k = self.rotary_emb(q, k, position_ids)
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # Cache the un-repeated KV heads to keep the cache n_rep times smaller.
        present_kv = (k, v) if use_cache else None
        k = self.repeat_kv(k)
        v = self.repeat_kv(v)

        attention_weights = None
        if self._can_use_sdpa(T, k.size(2), attention_mask, output_attentions):
            dropout_p = self.attn_dropout.p if isinstance(self.attn_dropout, nn.Dropout) and self.training else 0.0
            attn_output = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=dropout_p,
                is_causal=T > 1
            )
        else:
            attn_output, attention_weights = self._eager_attention(q, k, v, attention_mask)
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        output = self.resid_dropout(self.o_proj(attn_output))
        return output, present_kv, attention_weights if output_attentions else None
class OptimizedTransformerBlock(nn.Module):
    """Pre-norm Transformer block: grouped-query attention plus a feed-forward network.

    The feed-forward network is a dense ``SwiGLU`` or, with ``use_moe=True``, a
    ``MixtureOfExperts``. With ``use_parallel_residual`` the attention and FFN
    branches both read the block input and are summed into one residual update.
    Otherwise they run one after the other, with the optional adapter in between.
    After every forward pass ``self.moe_aux_loss`` holds the MoE auxiliary loss
    (a zero scalar tensor for the dense FFN).
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        use_parallel_residual: bool = False,
        norm_eps: float = 1e-6,
        sliding_window: Optional[int] = None,
        ffn_dim_multiplier: Optional[float] = None,
        layer_idx: int = 0
    ):
        """Create the attention, FFN, optional adapter and the two RMSNorms.

        Attention-related arguments (``n_heads``, ``n_kv_heads``, ``head_dim``,
        ``attn_dropout``, ``use_lora``, ``lora_rank``, ``sliding_window``) go to
        ``GroupedQueryAttention``. ``num_experts``/``moe_top_k`` only matter with
        ``use_moe``; ``adapter_dim`` only with ``use_adapter``. ``norm_eps`` is
        the RMSNorm epsilon and ``layer_idx`` is stored for reference.
        """
        super().__init__()
        self.layer_idx = layer_idx
        self.use_moe = use_moe
        self.use_adapter = use_adapter
        self.use_parallel_residual = use_parallel_residual

        # Submodules are registered in a fixed order: attention, ffn, adapter, norms.
        self.attention = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            attn_dropout=attn_dropout,
            use_lora=use_lora,
            lora_rank=lora_rank,
            sliding_window=sliding_window,
            rope_scaling_type="yarn",
        )
        shared_ffn_kwargs = {
            "dim": dim,
            "dropout": dropout,
            "ffn_dim_multiplier": ffn_dim_multiplier,
        }
        if use_moe:
            self.ffn = MixtureOfExperts(
                num_experts=num_experts, top_k=moe_top_k, **shared_ffn_kwargs
            )
        else:
            self.ffn = SwiGLU(**shared_ffn_kwargs)
        if use_adapter:
            self.adapter = AdapterLayer(dim, adapter_dim, dropout)
        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        # Overwritten on every forward pass.
        self.moe_aux_loss = torch.tensor(0.0)

    def _feed_forward(self, h: torch.Tensor) -> torch.Tensor:
        """Apply the FFN to ``h``, record the MoE auxiliary loss, return the FFN output."""
        if not self.use_moe:
            self.moe_aux_loss = torch.tensor(0.0, device=h.device)
            return self.ffn(h)
        # The MoE layer returns (output, aux_loss).
        out, self.moe_aux_loss = self.ffn(h)
        return out

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Apply attention and FFN with residual connections.

        All keyword arguments are forwarded unchanged to the attention layer.

        Returns:
            ``(hidden_states, present_kv, attention_weights)``, the last two
            passed through from the attention layer.
        """
        attn_out, present_kv, attn_weights = self.attention(
            self.attention_norm(x),
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
            output_attentions=output_attentions,
        )

        if self.use_parallel_residual:
            # Both branches read the block input x.
            # NOTE(review): the adapter is not applied on this path, so with
            # use_adapter=True its parameters go unused here — confirm intended.
            ffn_out = self._feed_forward(self.ffn_norm(x))
            return x + attn_out + ffn_out, present_kv, attn_weights

        hidden = x + attn_out
        if self.use_adapter:
            hidden = self.adapter(hidden)
        hidden = hidden + self._feed_forward(self.ffn_norm(hidden))
        return hidden, present_kv, attn_weights