from __future__ import annotations

import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.utils.checkpoint
from einops import rearrange
from flash_attn.flash_attn_interface import flash_attn_func
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
from transformers.modeling_utils import PreTrainedModel

from fla.modules import RMSNorm, RotaryEmbedding
from fla.modules.activations import swiglu_linear

from forgetting_transformer.model.alibi.configuration_alibi import AlibiConfig

class Attention(nn.Module):
    """Multi-head attention with optional ALiBi biases and/or RoPE, computed with FlashAttention."""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: Optional[int],
        layer_idx: int,
        use_alibi: bool,
        use_rope: bool,
        rope_base: float,
    ):
        super().__init__()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads or num_heads
        self.num_kv_groups = num_heads // self.num_kv_heads
        self.hidden_size = hidden_size
        self.head_dim = hidden_size // num_heads
        self.layer_idx = layer_idx

        self.q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
        self.k_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(hidden_size, self.num_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(hidden_size, hidden_size, bias=False)

        self.use_rope = use_rope
        self.rotary = RotaryEmbedding(self.head_dim, base=rope_base) if use_rope else None

        self.use_alibi = use_alibi
        if use_alibi:
            # One slope per head; stored as a non-persistent buffer so it follows the module's device.
            slopes = torch.tensor(self._get_slopes(num_heads), dtype=torch.float32)
            self.register_buffer("alibi_slopes", slopes.view(num_heads), persistent=False)

        if use_alibi and use_rope:
            warnings.warn(
                "Both use_alibi and use_rope are enabled. This is an unusual configuration.",
                UserWarning,
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        B, T, _ = x.shape

        # Project to (batch, seq, heads, head_dim), the layout expected by flash_attn_func.
        q = rearrange(self.q_proj(x), "b t (h d) -> b t h d", h=self.num_heads)
        k = rearrange(self.k_proj(x), "b t (h d) -> b t h d", h=self.num_kv_heads)
        v = rearrange(self.v_proj(x), "b t (h d) -> b t h d", h=self.num_kv_heads)

        # Apply rotary embeddings, offsetting positions by the number of cached tokens.
        if self.use_rope:
            seqlen_offset = 0
            if past_key_value is not None:
                seqlen_offset = past_key_value[0].shape[1]
            q, k = self.rotary(q, k, seqlen_offset=seqlen_offset)

        # Prepend the cached keys/values (keys are cached after RoPE has been applied).
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=1)
            v = torch.cat([past_v, v], dim=1)

        present_key_value = (k, v) if use_cache else None

        # FlashAttention only accepts fp16/bf16 activations; cast if needed and cast back afterwards.
        original_dtype = q.dtype
        compute_dtype = original_dtype
        if original_dtype not in (torch.float16, torch.bfloat16):
            compute_dtype = torch.bfloat16
            q = q.to(compute_dtype)
            k = k.to(compute_dtype)
            v = v.to(compute_dtype)
            warnings.warn(
                f"Flash Attention requires fp16/bf16 input, converting from {original_dtype} to {compute_dtype}",
                UserWarning,
                stacklevel=2,
            )

        alibi = None
        if self.use_alibi:
            # flash_attn expects the ALiBi slopes in fp32, independent of the activation dtype.
            alibi = self.alibi_slopes.to(device=x.device, dtype=torch.float32)

        # Note: `attention_mask` is not forwarded here; flash_attn_func assumes dense, unpadded batches.
        try:
            out = flash_attn_func(
                q, k, v,
                dropout_p=0.0,
                causal=True,
                alibi_slopes=alibi,
            )
        except Exception as e:
            # Older flash-attn builds may not handle grouped-query attention natively;
            # fall back to repeating K/V across the query-head groups.
            if self.num_kv_groups > 1:
                warnings.warn(
                    f"Flash Attention native GQA failed, falling back to manual repeat. Error: {e}",
                    UserWarning,
                )
                k = k.repeat_interleave(self.num_kv_groups, dim=2)
                v = v.repeat_interleave(self.num_kv_groups, dim=2)
                out = flash_attn_func(
                    q, k, v,
                    dropout_p=0.0,
                    causal=True,
                    alibi_slopes=alibi,
                )
            else:
                raise

        if compute_dtype != original_dtype:
            out = out.to(original_dtype)

        out = self.o_proj(out.reshape(B, T, self.hidden_size))
        return out, present_key_value

    def _get_slopes(self, n):
        """Generate ALiBi slopes (one per head), following the original ALiBi recipe."""
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * (ratio ** i) for i in range(n)]

        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)

        # For non-power-of-two head counts, take the slopes of the closest smaller power of two
        # and fill the remainder with every other slope from the next power of two.
        closest = 2 ** math.floor(math.log2(n))
        return get_slopes_power_of_2(closest) + \
            self._get_slopes(2 * closest)[0::2][: n - closest]

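# Worked example (for reference only, nothing here is executed by the model): with num_heads = 8,
# _get_slopes returns the geometric sequence 1/2, 1/4, 1/8, ..., 1/256, i.e. ratio 2 ** (-8 / num_heads).
# FlashAttention then adds a bias of -slope * distance(query, key) to the pre-softmax attention
# scores, so heads with larger slopes attend more locally.
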
class TransformerMLP(nn.Module):
    def __init__(self, hidden_size, hidden_ratio):
        super().__init__()
        # SwiGLU intermediate size: 2/3 of hidden_size * hidden_ratio, rounded up to a multiple of 256.
        inter = 256 * (((hidden_size * hidden_ratio * 2 // 3) + 255) // 256)
        self.gate_proj = nn.Linear(hidden_size, inter * 2, bias=False)
        self.down_proj = nn.Linear(inter, hidden_size, bias=False)

    def forward(self, x):
        y = self.gate_proj(x)
        gate, y = y.chunk(2, dim=-1)
        return swiglu_linear(gate, y, self.down_proj.weight.to(y.dtype), None)

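# For reference, a minimal unfused sketch of the forward pass above, assuming fla's swiglu_linear
# fuses down_proj(silu(gate) * y) into a single kernel (kept as a comment so it is never executed):
#
#     gate, y = self.gate_proj(x).chunk(2, dim=-1)
#     return self.down_proj(torch.nn.functional.silu(gate) * y)
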
class TransformerBlock(nn.Module):
    def __init__(self, cfg: AlibiConfig, idx: int):
        super().__init__()
        self.attn_norm = RMSNorm(cfg.hidden_size, eps=cfg.norm_eps)
        self.attn = Attention(
            cfg.hidden_size, cfg.num_heads, cfg.num_kv_heads,
            idx, cfg.use_alibi, cfg.use_rope, cfg.rope_base
        )
        self.mlp_norm = RMSNorm(cfg.hidden_size, eps=cfg.norm_eps)
        self.mlp = TransformerMLP(cfg.hidden_size, cfg.hidden_ratio)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        # Pre-norm attention sub-block with a residual connection.
        attn_input = self.attn_norm(x)
        attn_output, present_key_value = self.attn(
            attn_input,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
            position_ids=position_ids,
        )
        x = x + attn_output

        # Pre-norm MLP sub-block with a residual connection.
        mlp_input = self.mlp_norm(x)
        mlp_output = self.mlp(mlp_input)
        x = x + mlp_output

        return x, present_key_value

class AlibiModel(PreTrainedModel):
    config_class = AlibiConfig

    def __init__(self, config: AlibiConfig):
        super().__init__(config)
        self.config = config
        self.emb = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([
            TransformerBlock(config, i) for i in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def _init_weights(self, module):
        """Initialize weights."""
        std = getattr(self.config, "initializer_range", 0.02)

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        use_cache: Optional[bool] = None,
        position_ids: Optional[torch.Tensor] = None,
        **kwargs
    ) -> BaseModelOutputWithPast:
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        x = self.emb(input_ids)

        if past_key_values is None:
            past_key_values = [None] * len(self.layers)

        new_past_key_values = () if use_cache else None

        for i, layer in enumerate(self.layers):
            layer_past = past_key_values[i]

            if self.gradient_checkpointing and self.training:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                x, layer_present = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(layer),
                    x,
                    attention_mask,
                    layer_past,
                    use_cache,
                    position_ids,
                    use_reentrant=False,
                )
            else:
                x, layer_present = layer(
                    x,
                    attention_mask=attention_mask,
                    past_key_value=layer_past,
                    use_cache=use_cache,
                    position_ids=position_ids,
                )

            if use_cache:
                new_past_key_values = new_past_key_values + (layer_present,)

        x = self.norm(x)

        return BaseModelOutputWithPast(
            last_hidden_state=x,
            past_key_values=new_past_key_values if use_cache else None,
        )

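# Cache layout note: `past_key_values` is a tuple with one `(k, v)` pair per layer, where each
# tensor has shape (batch, total_seq_len, num_kv_heads, head_dim) and keys are stored after RoPE
# has been applied (see Attention.forward above).
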
class AlibiForCausalLM(AlibiModel):
    _no_split_modules = ["TransformerBlock"]
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        if config.tie_word_embeddings:
            self.lm_head.weight = self.emb.weight

        self.post_init()

    def get_input_embeddings(self):
        return self.emb

    def set_input_embeddings(self, value):
        self.emb = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor, torch.Tensor], ...]] = None,
        use_cache: Optional[bool] = None,
        position_ids: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs
    ) -> CausalLMOutputWithPast:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        out = super().forward(
            input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_ids=position_ids,
            **kwargs
        )

        logits = self.lm_head(out.last_hidden_state)

        loss = None
        if labels is not None:
            # Shift so that the token at position t predicts the token at position t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Keep the loss unreduced so the caller receives one value per token.
            loss_fct = nn.CrossEntropyLoss(reduction='none')
            loss = loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1)
            )

            # Restore the (batch, seq_len - 1) layout, then append a zero column so the
            # per-token loss has the same length as the input sequence.
            loss = loss.view(shift_labels.size(0), shift_labels.size(1))
            loss = torch.cat([loss, torch.zeros_like(loss[:, :1])], dim=1)

        if not return_dict:
            output = (logits,) + out[1:]
            return ((loss,) + output) if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=out.past_key_values,
            hidden_states=out.hidden_states if hasattr(out, 'hidden_states') else None,
            attentions=out.attentions if hasattr(out, 'attentions') else None,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        **kwargs
    ):
        """Prepare inputs for generation."""
        if past_key_values is not None:
            # With a cache present, only the most recent token needs to be fed through the model.
            input_ids = input_ids[:, -1:]

        # Derive position ids from the attention mask if they were not provided explicitly.
        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        return {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "position_ids": position_ids,
            "attention_mask": attention_mask,
        }

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Reorder the cache for beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past
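

# Minimal usage sketch (illustrative only): it assumes AlibiConfig accepts the field names
# referenced throughout this file (vocab_size, hidden_size, num_heads, num_kv_heads,
# num_hidden_layers, hidden_ratio, norm_eps, rope_base, use_alibi, use_rope, use_cache,
# tie_word_embeddings) and that a CUDA device with flash-attn installed is available.
if __name__ == "__main__":
    config = AlibiConfig(
        vocab_size=32000,
        hidden_size=512,
        num_heads=8,
        num_kv_heads=8,
        num_hidden_layers=4,
        hidden_ratio=4,
        norm_eps=1e-6,
        rope_base=10000.0,
        use_alibi=True,
        use_rope=False,
        use_cache=True,
        tie_word_embeddings=True,
    )
    model = AlibiForCausalLM(config).cuda().to(torch.bfloat16)

    input_ids = torch.randint(0, config.vocab_size, (2, 16), device="cuda")
    out = model(input_ids, labels=input_ids)
    # `out.loss` is an unreduced per-token loss of shape (batch, seq_len); reduce it yourself,
    # e.g. by masking out padding (and the trailing zero column) before averaging.
    print(out.logits.shape, out.loss.shape)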