# dynamic_alibi_pile_2layer / modeling_dynamic_alibi.py
# Uploaded by Lanni-ni; commit c67d7da "add remote code + model files" (verified).
# -*- coding: utf-8 -*-
from __future__ import annotations
import math
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from fla.modules import FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding
from torch.nn import functional as F
from fla.modules.activations import swiglu_linear
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_outputs import (BaseModelOutputWithPast,
CausalLMOutputWithPast)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from einops import rearrange
# Resolve DynamicAlibiConfig from wherever this module is loaded:
#   1. a sibling module via relative import (package / remote-code loading),
#   2. a top-level module (when run from the model directory),
#   3. the forgetting_transformer package.
try:
    from .configuration_dynamic_alibi import DynamicAlibiConfig
except (ImportError, ValueError):
    try:
        from configuration_dynamic_alibi import DynamicAlibiConfig
    except ImportError:
        from forgetting_transformer.model.dynamic_alibi.configuration_dynamic_alibi import DynamicAlibiConfig
from functools import partial
logger = logging.get_logger(__name__)
class DynamicAttention(nn.Module):
"""
Attention module with Dynamic ALiBi support
参照GPT2的动态ALiBi实现:m_t = m_0 * r^t
"""
def __init__(
self,
hidden_size: int = 2048,
num_heads: int = 32,
num_kv_heads: Optional[int] = None,
window_size: Optional[int] = None,
max_position_embeddings: Optional[int] = None,
rope_base: float = 500000.0,
use_rope: bool = False,
use_alibi: bool = True,
layer_idx: int = None,
# 🆕 动态ALiBi参数
use_dynamic_alibi: bool = False,
alibi_initial_slope: float = 1.0,
alibi_decay_rate: float = 0.6,
):
super().__init__()
self.num_heads = num_heads
if num_kv_heads is None:
self.num_kv_heads = self.num_heads
else:
self.num_kv_heads = num_kv_heads
self.num_kv_groups = num_heads // self.num_kv_heads
self.hidden_size = hidden_size
self.head_dim = self.hidden_size // self.num_heads
self.kv_dim = self.num_kv_heads * self.head_dim
self.window_size = window_size
self.max_position_embeddings = max_position_embeddings
self.layer_idx = layer_idx
self.q_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
self.k_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
self.v_proj = nn.Linear(self.hidden_size, self.kv_dim, bias=False)
self.o_proj = nn.Linear(self.hidden_size, self.hidden_size, bias=False)
if use_rope:
self.rotary = RotaryEmbedding(self.head_dim, base=rope_base)
else:
self.rotary = None
if use_alibi:
# 基础slopes(每个head一个slope)
slopes = torch.tensor(self._get_slopes(self.num_heads), dtype=torch.float32)
self.register_buffer("alibi_base_slopes", slopes, persistent=False)
# 🆕 动态ALiBi配置
self.use_dynamic_alibi = use_dynamic_alibi
self.alibi_initial_slope = alibi_initial_slope
self.alibi_decay_rate = alibi_decay_rate
self.current_epoch = 0 # 当前epoch,训练时更新
self.apply(self._initialize_weights)
def _initialize_weights(self, module: nn.Module):
pass
def update_epoch(self, epoch):
"""
更新当前epoch(参照GPT2实现)
Args:
epoch: 当前epoch数 (0-based)
"""
self.current_epoch = epoch
def _get_dynamic_scale(self):
"""
计算动态slope缩放因子
公式:m_t = m_0 * r^t
Returns:
float: 当前epoch的slope缩放因子
"""
if not self.use_dynamic_alibi:
return 1.0
# m_t = m_0 * r^t
scale = self.alibi_initial_slope * (self.alibi_decay_rate ** self.current_epoch)
return scale
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.LongTensor] = None,
past_key_values: Optional[Cache] = None,
output_attentions: bool = False,
use_cache: bool = False,
**kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
B, T, _ = hidden_states.size()
q = rearrange(self.q_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_heads)
k = rearrange(self.k_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_kv_heads)
v = rearrange(self.v_proj(hidden_states), 'b t (h d) -> b t h d', h=self.num_kv_heads)
seqlen_offset = 0
max_seqlen = q.shape[1]
if past_key_values is not None:
seqlen_offset = past_key_values.get_seq_length(self.layer_idx)
max_seqlen = q.shape[1] + seqlen_offset
if self.max_position_embeddings is not None:
max_seqlen = max(max_seqlen, self.max_position_embeddings)
if self.rotary is not None:
q, k = self.rotary(q, k, seqlen_offset, max_seqlen)
q = rearrange(q, 'b t h d -> b h t d')
k = rearrange(k, 'b t h d -> b h t d')
v = rearrange(v, 'b t h d -> b h t d')
if past_key_values is not None:
k, v = past_key_values.update(k, v, self.layer_idx)
if self.num_kv_groups > 1:
k = k.repeat_interleave(self.num_kv_groups, dim=1)
v = v.repeat_interleave(self.num_kv_groups, dim=1)
B, H, Tq, Dh = q.shape
Tk = k.size(2)
scale = 1.0 / math.sqrt(Dh)
scores = torch.matmul(q, k.transpose(-2, -1)) * scale
# 🆕 动态ALiBi bias计算
if hasattr(self, "alibi_base_slopes"):
positions = torch.arange(Tk, device=scores.device, dtype=torch.float32)
# 根据是否使用动态ALiBi选择slopes
if self.use_dynamic_alibi and self.training:
# 动态模式:slopes随epoch变化
dynamic_scale = self._get_dynamic_scale()
current_slopes = self.alibi_base_slopes * dynamic_scale
else:
# 静态模式:slopes固定
current_slopes = self.alibi_base_slopes
# 计算ALiBi bias(GPTNeoX方式)
alibi_slopes = current_slopes.view(H, 1).to(scores.device) # [H, 1]
alibi_bias = torch.matmul(alibi_slopes, positions.unsqueeze(0)) # [H, Tk]
alibi_bias = alibi_bias.view(1, H, 1, Tk).expand(B, -1, Tq, -1) # [B, H, Tq, Tk]
scores = scores + alibi_bias.to(scores.dtype)
# Causal mask:基于绝对位置
pos_q = seqlen_offset + torch.arange(Tq, device=scores.device)
pos_k = torch.arange(Tk, device=scores.device)
causal_mask = (pos_k.unsqueeze(0) > pos_q.unsqueeze(1))
scores = scores.masked_fill(causal_mask.view(1, 1, Tq, Tk), float('-inf'))
# Padding mask
if attention_mask is not None and attention_mask.shape[-1] == Tk:
pad_mask = (attention_mask == 0).view(B, 1, 1, Tk)
scores = scores.masked_fill(pad_mask, float('-inf'))
# Window mask
if self.window_size is not None:
past_too_far = (pos_k.view(1, Tk) < (pos_q.view(Tq, 1) - (self.window_size - 1)))
scores = scores.masked_fill(past_too_far.view(1, 1, Tq, Tk), float('-inf'))
attn = torch.softmax(scores, dim=-1)
o = torch.matmul(attn, v)
o = rearrange(o, 'b h t d -> b t (h d)')
o = self.o_proj(o)
attentions = attn if output_attentions else None
return o, attentions, past_key_values
def _get_slopes(self, n):
"""
Get slopes for ALiBi positional embedding
Based on the original ALiBi paper and GPTNeoX implementation
Returns negative slopes that will be multiplied by position indices
"""
def get_slopes_power_of_2(n):
start = 2 ** (-(2 ** -(math.log2(n) - 3)))
ratio = start
return [start * ratio**i for i in range(n)]
if math.log2(n).is_integer():
slopes = get_slopes_power_of_2(n)
else:
closest_power_of_2 = 2 ** math.floor(math.log2(n))
slopes = (
get_slopes_power_of_2(closest_power_of_2)
+ self._get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
)
# 返回负的slopes(与GPTNeoX一致)
return [-x for x in slopes]
class TransformerMLP(nn.Module):
    """
    SwiGLU feed-forward network.

    ``gate_proj`` produces the gate and up projections in a single matmul; the fused
    ``swiglu_linear`` kernel applies the gating and the ``down_proj`` projection.
    """

    def __init__(
        self,
        hidden_size: int,
        hidden_ratio: Optional[int] = None,
        intermediate_size: Optional[int] = None,
        hidden_act: str = 'swish'
    ) -> None:
        """
        Args:
            hidden_size: Input/output width.
            hidden_ratio: Expansion ratio used when ``intermediate_size`` is not given
                (default 4). The gated width is then ``2/3 * hidden_size * hidden_ratio``,
                rounded up to a multiple of 256.
            intermediate_size: Explicit width of the gated hidden layer.
            hidden_act: Activation name, looked up in ``ACT2FN``. The fused kernel always
                uses SiLU/swish gating, so any other value is ignored (with a warning).
        """
        super().__init__()
        self.hidden_size = hidden_size
        if hidden_ratio is None:
            hidden_ratio = 4
        if intermediate_size is None:
            intermediate_size = int(hidden_size * hidden_ratio * 2 / 3)
            intermediate_size = 256 * ((intermediate_size + 256 - 1) // 256)
        self.hidden_ratio = hidden_ratio
        self.intermediate_size = intermediate_size
        # One matmul yields both halves (gate, up) of the SwiGLU input.
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size * 2, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        # Kept for module-structure compatibility; forward() does not use it.
        self.act_fn = ACT2FN[hidden_act]
        if hidden_act not in ('swish', 'silu'):
            warnings.warn(
                f"TransformerMLP always uses SiLU (swish) gating via `swiglu_linear`; "
                f"hidden_act={hidden_act!r} is ignored."
            )

    def forward(self, x):
        gate, up = self.gate_proj(x).chunk(2, -1)
        bias = self.down_proj.bias
        return swiglu_linear(
            gate, up,
            self.down_proj.weight.to(up.dtype),
            bias.to(up.dtype) if bias is not None else None,
        )
class DynamicTransformerBlock(nn.Module):
    """
    Pre-norm decoder layer: ``attn_norm`` -> ``DynamicAttention``, then ``mlp_norm`` ->
    ``TransformerMLP`` with a residual connection.

    ``mlp_norm`` is called with the block input as ``residual`` and returns
    ``(normalised, residual)`` -- presumably fla's fused add + RMSNorm with
    ``prenorm=True``; confirm against the fla version in use.
    """

    def __init__(self, config, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        attn_kwargs = dict(
            hidden_size=config.hidden_size,
            num_heads=config.num_heads,
            num_kv_heads=config.num_kv_heads,
            window_size=config.window_size,
            use_alibi=config.use_alibi,
            max_position_embeddings=config.max_position_embeddings,
            rope_base=config.rope_base,
            use_rope=config.use_rope,
            layer_idx=layer_idx,
            # Dynamic ALiBi schedule parameters
            use_dynamic_alibi=config.use_dynamic_alibi,
            alibi_initial_slope=config.alibi_initial_slope,
            alibi_decay_rate=config.alibi_decay_rate,
        )
        self.attn = DynamicAttention(**attn_kwargs)
        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = TransformerMLP(
            hidden_size=config.hidden_size,
            hidden_ratio=config.hidden_ratio,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )

    def forward_attn(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        **kwargs,
    ):
        """Normalise the input and run self-attention; returns (output, attentions, cache)."""
        normed = self.attn_norm(hidden_states)
        attn_out, attn_weights, cache = self.attn(
            hidden_states=normed,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        return attn_out, attn_weights, cache

    def forward_mlp(
        self,
        hidden_states: torch.Tensor,
        residual: torch.Tensor,
    ):
        """Merge the attention output with the residual, normalise, and add the MLP output."""
        normed, residual = self.mlp_norm(hidden_states, residual, True)
        return residual + self.mlp(normed)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        gradient_checkpointing: bool = False
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Returns:
            ``(hidden_states,)`` followed by the attention weights if
            ``output_attentions`` and the cache if ``use_cache``.
        """
        block_input = hidden_states
        attn_fn, mlp_fn = self.forward_attn, self.forward_mlp
        if gradient_checkpointing:
            # Recompute each sub-layer during backward instead of storing its activations.
            attn_fn = partial(torch.utils.checkpoint.checkpoint, attn_fn, use_reentrant=False)
            mlp_fn = partial(torch.utils.checkpoint.checkpoint, mlp_fn, use_reentrant=False)

        attn_out, attentions, past_key_values = attn_fn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
        )
        layer_out = mlp_fn(attn_out, block_input)

        extras = []
        if output_attentions:
            extras.append(attentions)
        if use_cache:
            extras.append(past_key_values)
        return (layer_out, *extras)
class DynamicTransformerPreTrainedModel(PreTrainedModel):
    """Shared base: ties DynamicAlibiConfig and weight initialisation into PreTrainedModel."""

    config_class = DynamicAlibiConfig
    supports_gradient_checkpointing = True
    _no_split_modules = ['DynamicTransformerBlock']

    def __init__(self, config, *inputs, **kwargs):
        # When the config class was imported dynamically (remote code under
        # `transformers_modules`, or the top-level `configuration_dynamic_alibi` module),
        # repoint config_class at that class -- presumably so Transformers treats the
        # loaded config as this model's config class; confirm.
        source_module = config.__class__.__module__
        from_dynamic_module = (
            source_module == 'configuration_dynamic_alibi'
            or 'transformers_modules' in source_module
        )
        if from_dynamic_module:
            self.__class__.config_class = config.__class__
        super().__init__(config, *inputs, **kwargs)

    def _init_weights(
        self,
        module: nn.Module,
    ):
        """Normal(0, initializer_range) for Linear/Conv1d/Embedding weights; zero biases and padding row."""
        if isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, (nn.Linear, nn.Conv1d)):
            nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
class DynamicAlibiModel(DynamicTransformerPreTrainedModel):
    """
    Decoder-only backbone: token embeddings -> ``num_hidden_layers`` DynamicTransformerBlock
    layers -> final RMSNorm.
    """

    def __init__(self, config):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DynamicTransformerBlock(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)
        self.gradient_checkpointing = False
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, value):
        self.embeddings = value

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """
        Run the backbone.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. Attention weights
        are never returned: ``output_attentions`` is forced to ``False`` (with a warning)
        whether it was requested explicitly or through ``config.output_attentions``.
        With ``use_cache``, a missing or legacy (tuple) cache is wrapped in a
        ``DynamicCache`` and converted back to the legacy format on return.

        Returns:
            ``BaseModelOutputWithPast``, or a tuple of the non-``None`` values among
            ``(last_hidden_state, past_key_values, hidden_states, attentions)`` when
            ``return_dict`` is false.

        Raises:
            ValueError: if both or neither of ``input_ids`` and ``inputs_embeds`` are given.
        """
        # Resolve the config default first so the guard below also covers
        # ``config.output_attentions=True`` (not just an explicit argument).
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        if output_attentions:
            warnings.warn(
                "`DynamicAlibiModel` does not support output attention weights now, so `output_attentions` is set to `False`."
            )
            output_attentions = False
        output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        # Caching defaults to off in training mode.
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        elif input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        use_legacy_cache = False
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                if past_key_values is None:
                    past_key_values = DynamicCache()
                else:
                    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)
        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        all_hidden_states = () if output_hidden_states else None
        all_attns = () if output_attentions else None
        next_decoder_cache = None
        for layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)
            layer_outputs = layer(
                hidden_states,
                attention_mask=attention_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                use_cache=use_cache,
                gradient_checkpointing=self.gradient_checkpointing and self.training
            )
            hidden_states = layer_outputs[0]
            # Layer output layout: (hidden_states, [attentions], [cache]).
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
            if output_attentions:
                all_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_attns
        )
class DynamicAlibiForCausalLM(DynamicTransformerPreTrainedModel):
    """DynamicAlibiModel with a linear LM head, plus hooks for the dynamic-ALiBi epoch schedule."""

    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = DynamicAlibiModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def update_alibi_epoch(self, current_epoch: int):
        """
        Propagate the current training epoch to every attention layer.

        Call at the start of each epoch (modelled after a GPT-2 implementation) so the
        dynamic ALiBi scale ``m_t = m_0 * r ** t`` uses the right ``t``.

        Args:
            current_epoch: Current training epoch (0-based).
        """
        for layer in self.model.layers:
            if hasattr(layer.attn, 'update_epoch'):
                layer.attn.update_epoch(current_epoch)

    def get_working_memory_capacity(self):
        """
        Return the "working memory capacity" ``w_t = 1 - m_t``.

        ``m_t`` is the dynamic ALiBi scale of the first attention layer at its current
        epoch. Returns 1.0 when dynamic ALiBi is disabled. The value lies in ``[0, 1)``
        only when ``0 < m_t <= 1`` (e.g. ``alibi_initial_slope <= 1`` and a decay rate in
        ``(0, 1]``); a larger initial slope makes it negative.

        Returns:
            float: the current capacity.
        """
        if not self.config.use_dynamic_alibi:
            return 1.0  # static ALiBi: capacity fixed at 1
        first_attn = self.model.layers[0].attn
        if hasattr(first_attn, '_get_dynamic_scale'):
            m_t = first_attn._get_dynamic_scale()
            return 1.0 - m_t
        return 1.0

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor = None,
        past_key_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        **kwargs
    ):
        """
        Build the model inputs for one generation step.

        Once the cache holds tokens, only the newest token is fed. The decision is based
        on the cached length rather than on ``past_key_values is not None``, because an
        empty cache object may be supplied for the first step (e.g. an explicit
        ``DynamicCache()``, or by some transformers versions). Treating that as
        "has history" would drop the prompt and ignore ``inputs_embeds``.
        """
        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                past_length = past_key_values.get_seq_length()
            elif len(past_key_values) > 0:
                # Legacy format: per-layer (key, value); keys are (batch, heads, seq, head_dim).
                past_length = past_key_values[0][0].shape[2]

        if past_length > 0:
            input_ids = input_ids[:, -1:]
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {'inputs_embeds': inputs_embeds}
        else:
            model_inputs = {'input_ids': input_ids.contiguous()}
        model_inputs.update({
            'past_key_values': past_key_values,
            'use_cache': kwargs.get('use_cache'),
            'attention_mask': attention_mask,
        })
        return model_inputs

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """
        Run the backbone and the LM head.

        When ``labels`` is given, no shift is applied here: each position's logits are
        scored against the label at the same position. The loss is per token
        (``reduction='none'``) with the same shape as ``labels``, and ``logits`` is
        returned as ``None`` to free memory. Without labels, full logits are returned.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        hidden_states = outputs[0]
        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True, reduction='none')
            else:
                loss_fct = nn.CrossEntropyLoss(reduction='none')
            logits = self.lm_head(hidden_states)
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
            loss = loss.view(*labels.size())
            del logits
            logits = None
        else:
            logits = self.lm_head(hidden_states)
        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output
        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )