# MultiModal / transformer.py
# szxllm's picture
# Update transformer.py
# 1ef8665 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple, List
import math
from components import RMSNorm, SwiGLU, YARNRotaryEmbedding, QKNorm
from peft_ import LinearWithLoRA, AdapterLayer
from moe import MixtureOfExperts
class GroupedQueryAttention(nn.Module):
    """Causal grouped-query self-attention with an optional KV cache.

    ``n_heads`` query heads share ``n_kv_heads`` key/value heads; every
    key/value head is repeated ``n_heads // n_kv_heads`` times before the
    attention product. Positional information comes either from the rotary
    embedding module (default) or from ALiBi biases (``use_alibi=True``).

    Two numerically equivalent code paths exist:

    * a fused path through ``F.scaled_dot_product_attention`` (SDPA), used
      when it is available, no explicit ``attention_mask`` is given, attention
      weights are not requested and ALiBi is off;
    * an explicit matmul/softmax path used in every other case.

    Both paths apply the same causal and sliding-window constraints, and both
    align the causal mask to the *end* of the key sequence so that decoding
    with ``past_kv`` behaves like a full forward pass.

    Args:
        dim: Size of the last dimension of the input and of the output.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; defaults to ``n_heads``.
            ``n_heads`` must be divisible by it.
        head_dim: Per-head width; defaults to ``dim // n_heads``.
        dropout: Dropout probability applied after the output projection.
        attn_dropout: Dropout probability applied to attention probabilities.
        use_flash: Allow the SDPA path when the running torch provides it.
        qkv_bias: Whether the q/k/v projections carry a bias (the output
            projection never does).
        use_lora: Forwarded to every ``LinearWithLoRA`` projection.
        lora_rank: Forwarded to every ``LinearWithLoRA`` projection.
        max_seq_len: Forwarded to ``YARNRotaryEmbedding``.
        rope_scaling_factor: Forwarded to ``YARNRotaryEmbedding`` as
            ``scaling_factor``.
        rope_scaling_type: Accepted for interface compatibility; this class
            does not read it.
        use_qk_norm: Apply ``QKNorm`` to queries and keys before the
            positional encoding.
        sliding_window: If set, each query only sees itself and the
            ``sliding_window - 1`` keys before it.
        use_alibi: Use ALiBi biases instead of the rotary embedding.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_flash: bool = True,
        qkv_bias: bool = False,
        use_lora: bool = False,
        lora_rank: int = 8,
        max_seq_len: int = 8192,
        rope_scaling_factor: float = 1.0,
        rope_scaling_type: str = "yarn",
        use_qk_norm: bool = False,
        sliding_window: Optional[int] = None,
        use_alibi: bool = False
    ):
        super().__init__()
        self.dim = dim
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads if n_kv_heads is not None else n_heads
        assert n_heads % self.n_kv_heads == 0, \
            f"n_heads ({n_heads}) must be divisible by n_kv_heads ({self.n_kv_heads})"
        # How many query heads share one key/value head.
        self.n_rep = n_heads // self.n_kv_heads
        self.head_dim = head_dim if head_dim is not None else dim // n_heads
        self.scale = self.head_dim ** -0.5
        self.use_flash = use_flash and hasattr(F, 'scaled_dot_product_attention')
        self.sliding_window = sliding_window

        self.q_proj = LinearWithLoRA(
            dim, n_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.k_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.v_proj = LinearWithLoRA(
            dim, self.n_kv_heads * self.head_dim,
            bias=qkv_bias, use_lora=use_lora, lora_rank=lora_rank
        )
        self.o_proj = LinearWithLoRA(
            n_heads * self.head_dim, dim,
            bias=False, use_lora=use_lora, lora_rank=lora_rank
        )

        self.attn_dropout = nn.Dropout(attn_dropout) if attn_dropout > 0 else nn.Identity()
        self.resid_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        self.use_qk_norm = use_qk_norm
        if use_qk_norm:
            self.q_norm = QKNorm(self.head_dim)
            self.k_norm = QKNorm(self.head_dim)

        self.use_alibi = use_alibi
        if use_alibi:
            # Derived from n_heads only, so it is kept out of the state dict.
            self.register_buffer(
                "alibi_slopes",
                self._get_alibi_slopes(n_heads),
                persistent=False
            )
        else:
            self.rotary_emb = YARNRotaryEmbedding(
                self.head_dim,
                max_seq_len=max_seq_len,
                original_max_len=4096,
                scaling_factor=rope_scaling_factor,
                rope_percentage=1.0
            )

    def _get_alibi_slopes(self, n_heads: int) -> torch.Tensor:
        """Compute the per-head ALiBi slopes.

        For a power-of-two head count the slopes form a geometric sequence.
        Otherwise the sequence for the largest power of two below ``n_heads``
        is extended with every second slope of the next power of two.

        Returns:
            Tensor of shape ``(n_heads, 1, 1)``.
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        if math.log2(n_heads).is_integer():
            slopes = get_slopes_power_of_2(n_heads)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n_heads))
            slopes = get_slopes_power_of_2(closest_power_of_2)
            extra_slopes = get_slopes_power_of_2(2 * closest_power_of_2)[::2]
            slopes.extend(extra_slopes[:n_heads - closest_power_of_2])
        return torch.tensor(slopes).view(n_heads, 1, 1)

    def repeat_kv(self, x: torch.Tensor) -> torch.Tensor:
        """Repeat each KV head ``n_rep`` times so KV heads line up with Q heads.

        Args:
            x: Tensor of shape ``(B, n_kv_heads, seq_len, head_dim)``.

        Returns:
            ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
            ``(B, n_kv_heads * n_rep, seq_len, head_dim)`` in which the copies
            of one KV head are adjacent.
        """
        if self.n_rep == 1:
            return x
        B, n_kv_heads, seq_len, head_dim = x.shape
        return x[:, :, None, :, :].expand(
            B, n_kv_heads, self.n_rep, seq_len, head_dim
        ).reshape(B, n_kv_heads * self.n_rep, seq_len, head_dim)

    @staticmethod
    def _causal_keep_mask(
        q_len: int,
        kv_len: int,
        device: torch.device,
        window: Optional[int] = None
    ) -> torch.Tensor:
        """Build a ``(q_len, kv_len)`` bool mask that is True where attention is allowed.

        The queries are treated as the LAST ``q_len`` positions of the
        ``kv_len``-long key sequence, which is the layout produced by
        concatenating cached keys in front of the new ones. A query at
        absolute position ``i`` may see keys ``j <= i`` and, when ``window``
        is given, only those with ``j > i - window``.
        """
        q_pos = torch.arange(kv_len - q_len, kv_len, device=device).unsqueeze(-1)
        k_pos = torch.arange(kv_len, device=device).unsqueeze(0)
        keep = k_pos <= q_pos
        if window is not None:
            keep = keep & (k_pos > q_pos - window)
        return keep

    def _apply_sliding_window_mask(
        self,
        attn_scores: torch.Tensor,
        seq_len: int
    ) -> torch.Tensor:
        """Set scores outside the causal sliding window to ``-inf``.

        Args:
            attn_scores: Scores of shape ``(..., q_len, seq_len)``.
            seq_len: Key length.

        Returns:
            ``attn_scores`` unchanged when no window is configured or the key
            sequence fits inside it; otherwise a masked copy of the same shape.
        """
        if self.sliding_window is None or seq_len <= self.sliding_window:
            return attn_scores
        # Take the query length from the scores: with a KV cache it is shorter
        # than seq_len, and a square (seq_len, seq_len) mask would not line up
        # with the query rows.
        keep = self._causal_keep_mask(
            attn_scores.size(-2), seq_len, attn_scores.device, self.sliding_window
        )
        return attn_scores.masked_fill(~keep, float('-inf'))

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Run causal self-attention over ``x``.

        Args:
            x: Input of shape ``(B, T, C)``.
            attention_mask: Optional mask. A 2-D mask is treated as
                ``(B, key_len)`` and broadcast over heads and queries. A
                ``torch.float32`` mask is added to the scores as is; any other
                dtype (bool, integer, ...) is read as 1 = keep, 0 = masked.
            position_ids: Passed through to the rotary embedding module.
            use_cache: Return the concatenated keys/values for later reuse.
            past_kv: Cached ``(keys, values)`` from earlier steps, each of
                shape ``(B, n_kv_heads, past_len, head_dim)``.
            output_attentions: Also return the attention probabilities
                (forces the explicit, non-fused path).

        Returns:
            ``(output, present_kv, attention_weights)``. ``output`` is
            ``(B, T, dim)``. ``present_kv`` is ``(k, v)`` before head
            repetition when ``use_cache`` is set, else ``None``.
            ``attention_weights`` is ``(B, n_heads, T, key_len)`` when
            ``output_attentions`` is set, else ``None``.
        """
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.n_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)

        if self.use_qk_norm:
            q_shape = q.shape
            k_shape = k.shape
            # q/k are non-contiguous after the transpose; .view() would raise
            # for T > 1 with more than one head, so flatten with reshape().
            q = self.q_norm.query_norm(q.reshape(-1, self.head_dim)).reshape(q_shape)
            k = self.k_norm.key_norm(k.reshape(-1, self.head_dim)).reshape(k_shape)

        if not self.use_alibi:
            q, k = self.rotary_emb(q, k, position_ids)

        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat([past_k, k], dim=2)
            v = torch.cat([past_v, v], dim=2)
        # Cache the un-repeated heads to keep the cache small.
        present_kv = (k, v) if use_cache else None

        k = self.repeat_kv(k)
        v = self.repeat_kv(v)
        seq_len_k = k.size(2)

        window_active = (
            self.sliding_window is not None and seq_len_k > self.sliding_window
        )
        # ALiBi needs an additive bias, so it always uses the explicit path.
        use_sdpa = (
            self.use_flash
            and not output_attentions
            and attention_mask is None
            and not self.use_alibi
        )

        if use_sdpa:
            dropout_p = (
                self.attn_dropout.p
                if isinstance(self.attn_dropout, nn.Dropout) and self.training
                else 0.0
            )
            if window_active or (T != seq_len_k and T > 1):
                # SDPA's is_causal mask is aligned to the top-left corner and
                # knows nothing about the window, so pass an explicit mask
                # aligned to the end of the key sequence.
                keep = self._causal_keep_mask(
                    T, seq_len_k, x.device,
                    self.sliding_window if window_active else None
                )
                attn_output = F.scaled_dot_product_attention(
                    q, k, v, attn_mask=keep, dropout_p=dropout_p
                )
            else:
                # T == seq_len_k: plain causal attention. T == 1 with a cache:
                # the newest token may see every cached key, so no mask.
                attn_output = F.scaled_dot_product_attention(
                    q, k, v,
                    dropout_p=dropout_p,
                    is_causal=(T == seq_len_k and T > 1)
                )
            attention_weights = None
        else:
            attn_scores = (q @ k.transpose(-2, -1)) * self.scale

            if self.use_alibi:
                # slope * key_position; the per-row constant that would turn
                # this into a relative distance cancels inside the softmax.
                position_bias = self.alibi_slopes.to(x.device) * torch.arange(
                    seq_len_k, device=x.device
                ).view(1, 1, -1)
                attn_scores = attn_scores + position_bias

            if self.sliding_window is not None:
                attn_scores = self._apply_sliding_window_mask(attn_scores, seq_len_k)

            if attention_mask is not None:
                if attention_mask.dim() == 2:
                    attention_mask = attention_mask[:, None, None, :]
                if attention_mask.dtype != torch.float:
                    # Keep/drop mask. masked_fill works for bool masks (which
                    # do not support `1.0 - mask`) and writes exactly
                    # finfo.min instead of `score + finfo.min`, which can
                    # overflow to -inf in half precision.
                    attn_scores = attn_scores.masked_fill(
                        attention_mask == 0, torch.finfo(attn_scores.dtype).min
                    )
                else:
                    attn_scores = attn_scores + attention_mask

            if seq_len_k > 1:
                keep = self._causal_keep_mask(T, seq_len_k, x.device)
                attn_scores = attn_scores.masked_fill(~keep, float('-inf'))

            # Softmax in float32 for stability, then back to the compute dtype.
            attention_weights = F.softmax(attn_scores, dim=-1, dtype=torch.float32).to(q.dtype)
            attention_weights = self.attn_dropout(attention_weights)
            attn_output = attention_weights @ v

        attn_output = attn_output.transpose(1, 2).contiguous().view(B, T, -1)
        output = self.resid_dropout(self.o_proj(attn_output))
        return output, present_kv, attention_weights if output_attentions else None
class OptimizedTransformerBlock(nn.Module):
    """Pre-norm transformer block: grouped-query attention plus an FFN.

    The feed-forward part is either a dense ``SwiGLU`` or a
    ``MixtureOfExperts`` layer. Two residual layouts are supported:

    * sequential (default): ``x = x + attn(norm(x))``, optional adapter, then
      ``x = x + ffn(norm(x))``;
    * parallel (``use_parallel_residual=True``): attention and FFN both read
      the block input and their outputs are added to it in one step.

    After every forward pass ``moe_aux_loss`` holds the auxiliary loss
    returned by the MoE layer, or a zero scalar when the FFN is dense.

    Args:
        dim: Width of the block input and output.
        n_heads, n_kv_heads, head_dim, attn_dropout, use_lora, lora_rank,
            sliding_window: Forwarded to ``GroupedQueryAttention``.
        dropout: Forwarded to the attention, the FFN and the adapter.
        use_moe: Use ``MixtureOfExperts`` instead of ``SwiGLU``.
        num_experts: Forwarded to ``MixtureOfExperts``.
        moe_top_k: Forwarded to ``MixtureOfExperts`` as ``top_k``.
        use_adapter: Create an ``AdapterLayer`` and apply it after the
            attention residual in the sequential layout.
        adapter_dim: Second positional argument of ``AdapterLayer``.
        use_parallel_residual: Select the parallel residual layout.
        norm_eps: Epsilon for both ``RMSNorm`` layers.
        ffn_dim_multiplier: Forwarded to the FFN.
        layer_idx: Stored on the instance; not read by this class.
    """

    def __init__(
        self,
        dim: int,
        n_heads: int,
        n_kv_heads: Optional[int] = None,
        head_dim: Optional[int] = None,
        dropout: float = 0.0,
        attn_dropout: float = 0.0,
        use_moe: bool = False,
        num_experts: int = 8,
        moe_top_k: int = 2,
        use_adapter: bool = False,
        adapter_dim: int = 64,
        use_lora: bool = False,
        lora_rank: int = 8,
        use_parallel_residual: bool = False,
        norm_eps: float = 1e-6,
        sliding_window: Optional[int] = None,
        ffn_dim_multiplier: Optional[float] = None,
        layer_idx: int = 0
    ):
        super().__init__()
        self.layer_idx = layer_idx
        self.use_moe = use_moe
        self.use_adapter = use_adapter
        self.use_parallel_residual = use_parallel_residual

        self.attention = GroupedQueryAttention(
            dim=dim,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            head_dim=head_dim,
            dropout=dropout,
            attn_dropout=attn_dropout,
            use_lora=use_lora,
            lora_rank=lora_rank,
            sliding_window=sliding_window,
            rope_scaling_type="yarn"
        )

        if use_moe:
            self.ffn = MixtureOfExperts(
                dim=dim,
                num_experts=num_experts,
                top_k=moe_top_k,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )
        else:
            self.ffn = SwiGLU(
                dim=dim,
                dropout=dropout,
                ffn_dim_multiplier=ffn_dim_multiplier
            )

        if use_adapter:
            self.adapter = AdapterLayer(dim, adapter_dim, dropout)

        self.attention_norm = RMSNorm(dim, eps=norm_eps)
        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
        # Overwritten on every forward pass; see _run_ffn.
        self.moe_aux_loss = torch.tensor(0.0)

    def _run_ffn(self, ffn_input: torch.Tensor) -> torch.Tensor:
        """Apply the FFN (dense or MoE) and record the auxiliary loss."""
        if self.use_moe:
            ffn_out, aux_loss = self.ffn(ffn_input)
            self.moe_aux_loss = aux_loss
        else:
            ffn_out = self.ffn(ffn_input)
            self.moe_aux_loss = torch.tensor(0.0, device=ffn_input.device)
        return ffn_out

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache: bool = False,
        past_kv: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        output_attentions: bool = False
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]], Optional[torch.Tensor]]:
        """Run the block on ``x``.

        Args:
            x: Block input.
            attention_mask, position_ids, use_cache, past_kv,
                output_attentions: Passed unchanged to the attention layer.

        Returns:
            ``(hidden_states, present_kv, attention_weights)``; the last two
            items are whatever the attention layer returned.
        """
        attn_out, present_kv, attn_weights = self.attention(
            self.attention_norm(x),
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            past_kv=past_kv,
            output_attentions=output_attentions
        )

        if self.use_parallel_residual:
            # NOTE(review): the adapter is not applied in this layout even
            # when use_adapter is set - confirm whether that is intended.
            ffn_out = self._run_ffn(self.ffn_norm(x))
            x = x + attn_out + ffn_out
        else:
            x = x + attn_out
            if self.use_adapter:
                x = self.adapter(x)
            x = x + self._run_ffn(self.ffn_norm(x))

        return x, present_kv, attn_weights