# alibi_2_4_256_fla/modeling_transformer.py
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import warnings
from fla.modules import RMSNorm, RotaryEmbedding
from fla.modules.activations import swiglu_linear
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_func
from forgetting_transformer.model.alibi.configuration_alibi import AlibiConfig
class Attention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: Optional[int],
layer_idx: int,
use_alibi: bool,
use_rope: bool,
rope_base: float,
):
super().__init__()
self.num_heads = num_heads
self.num_kv_heads = num_kv_heads or num_heads
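        # With GQA, each of the num_kv_heads KV heads is shared by num_kv_groups
        # query heads (num_kv_groups == 1 recovers standard multi-head attention)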
self.num_kv_groups = num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = hidden_size // num_heads
self.layer_idx = layer_idx
self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)
self.use_rope = use_rope
self.rotary = RotaryEmbedding(self.head_dim, base=rope_base) if use_rope else None
self.use_alibi = use_alibi
if use_alibi:
slopes = torch.tensor(self._get_slopes(num_heads), dtype=torch.float32)
self.register_buffer("alibi_slopes", slopes.view(num_heads), persistent=False)
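            # When alibi_slopes is passed, flash-attn adds a bias of
            # -slope * |i + seqlen_k - seqlen_q - j| to the attention score of
            # query i and key j (per the flash-attn documentation)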
        # Warning: ALiBi and RoPE are normally mutually exclusive
if use_alibi and use_rope:
warnings.warn(
"Both use_alibi and use_rope are enabled. This is an unusual configuration.",
UserWarning
)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
B, T, _ = x.shape
# [B, T, H, D]
q = rearrange(self.q_proj(x), "b t (h d) -> b t h d", h=self.num_heads)
# [B, T, H_kv, D]
k = rearrange(self.k_proj(x), "b t (h d) -> b t h d", h=self.num_kv_heads)
v = rearrange(self.v_proj(x), "b t (h d) -> b t h d", h=self.num_kv_heads)
        # --- Apply RoPE (only to the new q, k) ---
        if self.use_rope:
            seqlen_offset = 0
            if past_key_value is not None:
                # With an existing cache, the offset is the cached sequence length
                seqlen_offset = past_key_value[0].shape[1]
            # Apply RoPE only to the newly generated q, k
            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)
        # --- KV cache concatenation ---
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)
        # Save the current state (KV before any repeat)
        present_key_value = (k, v) if use_cache else None
        # --- Dtype conversion ---
original_dtype = q.dtype
compute_dtype = original_dtype
if original_dtype not in [torch.float16, torch.bfloat16]:
compute_dtype = torch.bfloat16
q = q.to(compute_dtype)
k = k.to(compute_dtype)
v = v.to(compute_dtype)
warnings.warn(
f"Flash Attention requires fp16/bf16 input, converting from {original_dtype} to {compute_dtype}",
UserWarning,
stacklevel=2
)
# --- ALiBi slopes ---
alibi = None
if self.use_alibi:
            # NOTE: flash-attn requires ALiBi slopes in float32, independent of the q/k/v dtype
            alibi = self.alibi_slopes.to(dtype=torch.float32, device=x.device)
# --- Flash Attention ---
        # Note: Flash Attention 2.3+ supports native GQA, so no manual repeat is needed
# q: [B, T_q, H, D]
# k: [B, T_k, H_kv, D] (GQA: H_kv < H)
# v: [B, T_k, H_kv, D]
try:
out = flash_attn_func(
q, k, v,
dropout_p=0.0,
causal=True,
alibi_slopes=alibi,
) # -> [B, T_q, H, D]
except Exception as e:
            # If native GQA fails, fall back to a manual repeat
if self.num_kv_groups > 1:
warnings.warn(
f"Flash Attention native GQA failed, falling back to manual repeat. Error: {e}",
UserWarning
)
k = k.repeat_interleave(self.num_kv_groups, dim=2)
v = v.repeat_interleave(self.num_kv_groups, dim=2)
out = flash_attn_func(
q, k, v,
dropout_p=0.0,
causal=True,
alibi_slopes=alibi,
)
else:
raise
if compute_dtype != original_dtype:
out = out.to(original_dtype)
out = self.o_proj(out.reshape(B, T, self.hidden_size))
return out, present_key_value
    def _get_slopes(self, n):
        """Generate ALiBi slopes.

        For n heads with n a power of two, head i (0-indexed) gets slope
        2**(-8 * (i + 1) / n), e.g. [2**-1, 2**-2, ..., 2**-8] for n = 8.
        For other n, slopes are interleaved from the two nearest powers of
        two, following the original ALiBi paper.
        """
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio ** i) for i in range(n)]
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        closest = 2 ** math.floor(math.log2(n))
        return get_slopes_power_of_2(closest) + \
            self._get_slopes(2 * closest)[0::2][: n - closest]
class TransformerMLP(nn.Module):
def __init__(self, hidden_size, hidden_ratio):
super().__init__()
        # Intermediate size: (2/3) * hidden_ratio * hidden_size, rounded up to a multiple of 256
        inter = 256 * (((hidden_size * hidden_ratio * 2 // 3) + 255) // 256)
self.gate_proj = nn.Linear(hidden_size, inter * 2, bias=False)
self.down_proj = nn.Linear(inter, hidden_size, bias=False)
def forward(self, x):
y = self.gate_proj(x)
gate, y = y.chunk(2, dim=-1)
return swiglu_linear(gate, y, self.down_proj.weight.to(y.dtype), None)
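# For reference, a minimal eager-mode sketch of the fused kernel used above,
# assuming fla's swiglu_linear(x, y, weight, bias) computes
# F.linear(silu(x) * y, weight, bias). This hypothetical helper is for
# illustration only and is not used by the model:
def _swiglu_linear_ref(gate: torch.Tensor, y: torch.Tensor,
                       weight: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    import torch.nn.functional as F
    # SwiGLU: silu(gate) gates y elementwise, followed by the down projection
    return F.linear(F.silu(gate) * y, weight, bias)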
class TransformerBlock(nn.Module):
def __init__(self, cfg: AlibiConfig, idx: int):
super().__init__()
self.attn_norm = RMSNorm(cfg.hidden_size, eps=cfg.norm_eps)
self.attn = Attention(
cfg.hidden_size, cfg.num_heads, cfg.num_kv_heads,
idx, cfg.use_alibi, cfg.use_rope, cfg.rope_base
)
self.mlp_norm = RMSNorm(cfg.hidden_size, eps=cfg.norm_eps)
self.mlp = TransformerMLP(cfg.hidden_size, cfg.hidden_ratio)
def forward(
self,
x: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
use_cache: bool = False,
position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
# Pre-Norm Transformer
        # 1. Attention block
attn_input = self.attn_norm(x)
attn_output, present_key_value = self.attn(
attn_input,
attention_mask=attention_mask,
past_key_value=past_key_value,
use_cache=use_cache,
position_ids=position_ids,
)
x = x + attn_output
        # 2. MLP block
mlp_input = self.mlp_norm(x)
mlp_output = self.mlp(mlp_input)
x = x + mlp_output
return x, present_key_value
class AlibiModel(PreTrainedModel):
config_class = AlibiConfig
def __init__(self, config: AlibiConfig):
super().__init__(config)
self.config = config
self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
self.layers = nn.ModuleList([
TransformerBlock(config, i) for i in range(config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
self.gradient_checkpointing = False
self.post_init()
def _init_weights(self, module):
"""初始化权重"""
std = self.config.initializer_range if hasattr(self.config, 'initializer_range') else 0.02
if isinstance(module, nn.Linear):
module.weight.data.normal_(mean=0.0, std=std)
if module.bias is not None:
module.bias.data.zero_()
elif isinstance(module, nn.Embedding):
module.weight.data.normal_(mean=0.0, std=std)
if module.padding_idx is not None:
module.weight.data[module.padding_idx].zero_()
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
use_cache: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
**kwargs
) -> BaseModelOutputWithPast:
use_cache = use_cache if use_cache is not None else self.config.use_cache
x = self.emb(input_ids)
if past_key_values is None:
            # KV cache layout: ((L0_k, L0_v), (L1_k, L1_v), ...)
past_key_values = [None] * len(self.layers)
new_past_key_values = () if use_cache else None
for i, layer in enumerate(self.layers):
layer_past = past_key_values[i]
if self.gradient_checkpointing and self.training:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
x, layer_present = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer),
x,
attention_mask,
layer_past,
use_cache,
position_ids,
use_reentrant=False
)
else:
x, layer_present = layer(
x,
attention_mask=attention_mask,
past_key_value=layer_past,
use_cache=use_cache,
position_ids=position_ids,
)
if use_cache:
new_past_key_values = new_past_key_values + (layer_present,)
x = self.norm(x)
return BaseModelOutputWithPast(
last_hidden_state=x,
past_key_values=new_past_key_values if use_cache else None,
)
class AlibiForCausalLM(AlibiModel):
_no_split_modules = ["TransformerBlock"]
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
if config.tie_word_embeddings:
self.lm_head.weight = self.emb.weight
self.post_init()
def get_input_embeddings(self):
return self.emb
def set_input_embeddings(self, value):
self.emb = value
def get_output_embeddings(self):
return self.lm_head
def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings
def forward(
self,
input_ids: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
use_cache: Optional[bool] = None,
position_ids: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
**kwargs
) -> CausalLMOutputWithPast:
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
out = super().forward(
input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
position_ids=position_ids,
**kwargs
)
logits = self.lm_head(out.last_hidden_state)
        # Causal LM loss (computed per position)
        loss = None
        if labels is not None:
            # 1. Drop the last logit: it has no target to predict
            shift_logits = logits[..., :-1, :].contiguous()
            # 2. Drop the first label: position t predicts token t + 1
            shift_labels = labels[..., 1:].contiguous()
            # 3. Cross-entropy with reduction='none' to keep a per-token loss
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )
            # 4. Reshape to [batch_size, seq_len - 1]
            loss = loss.view(shift_labels.size(0), shift_labels.size(1))
            # 5. Append a trailing 0 so the shape becomes [batch_size, seq_len]
            #    (the last token has no prediction target)
            loss = torch.cat([loss, torch.zeros_like(loss[:, :1])], dim=1)
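            # A scalar training loss can then be derived downstream, e.g. via a
            # masked mean (sketch; pad_mask is a caller-supplied [B, T] 0/1 tensor):
            #   scalar_loss = (loss * pad_mask).sum() / pad_mask.sum()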
if not return_dict:
output = (logits,) + out[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=out.past_key_values,
hidden_states=out.hidden_states if hasattr(out, 'hidden_states') else None,
attentions=out.attentions if hasattr(out, 'attentions') else None,
)
def prepare_inputs_for_generation(
self,
input_ids,
past_key_values=None,
attention_mask=None,
**kwargs
):
"""为生成准备输入"""
if past_key_values is not None:
# 只需要最后一个 token
input_ids = input_ids[:, -1:]
# 计算 position_ids
position_ids = kwargs.get("position_ids", None)
if attention_mask is not None and position_ids is None:
# [B, T]
position_ids = attention_mask.long().cumsum(-1) - 1
position_ids.masked_fill_(attention_mask == 0, 1)
if past_key_values:
# [B, 1]
position_ids = position_ids[:, -1].unsqueeze(-1)
return {
"input_ids": input_ids,
"past_key_values": past_key_values,
"use_cache": kwargs.get("use_cache"),
"position_ids": position_ids,
"attention_mask": attention_mask,
}
@staticmethod
def _reorder_cache(past_key_values, beam_idx):
"""为 beam search 重排序缓存"""
reordered_past = ()
for layer_past in past_key_values:
reordered_past += (
tuple(
past_state.index_select(0, beam_idx.to(past_state.device))
for past_state in layer_past
),
)
return reordered_past
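# A minimal smoke-test sketch. It assumes a CUDA device (flash-attn ships no CPU
# kernels) and that AlibiConfig accepts the field names referenced above
# (vocab_size, hidden_size, num_heads, ...); adjust to the actual config:
if __name__ == "__main__":
    config = AlibiConfig(
        vocab_size=32000,
        hidden_size=256,
        num_heads=4,
        num_kv_heads=2,
        num_hidden_layers=2,
        use_alibi=True,
        use_rope=False,
    )
    model = AlibiForCausalLM(config).cuda().to(torch.bfloat16)
    input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
    out = model(input_ids, labels=input_ids, use_cache=True)
    print(out.logits.shape)  # torch.Size([2, 16, 32000])
    print(out.loss.shape)    # torch.Size([2, 16]) -- per-token loss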