from __future__ import annotations

import math
import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn import functional as F
from transformers import PreTrainedModel
from transformers.activations import ACT2FN
from transformers.generation.utils import GenerationConfig
from transformers.modeling_outputs import CausalLMOutputWithPast

from fla.modules import FusedCrossEntropyLoss, RMSNorm, RotaryEmbedding

from .configuration_stickbreaking import StickbreakingConfig


class StickbreakingAttention(nn.Module):
    """
    Stick-breaking attention (ICLR 2025): each query allocates attention mass to
    earlier positions through a stick-breaking process over sigmoid gates, in place
    of the causal softmax.
    """

    def __init__(self, config: StickbreakingConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_heads
        self.num_kv_heads = config.num_kv_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_kv_groups = self.num_heads // self.num_kv_heads
        self.scale = 1.0 / math.sqrt(self.head_dim)

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
        self.k_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias)
        self.v_proj = nn.Linear(self.hidden_size, self.num_kv_heads * self.head_dim, bias=config.attention_bias)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        if config.use_rope:
            self.rotary = RotaryEmbedding(dim=self.head_dim, base=config.rope_base)

        if config.qk_norm:
            if config.qk_norm_share_param_across_head:
                # One set of norm parameters of size head_dim, shared across heads.
                self.q_norm = RMSNorm(hidden_size=self.head_dim, eps=config.norm_eps)
                self.k_norm = RMSNorm(hidden_size=self.head_dim, eps=config.norm_eps)
            else:
                # Per-head parameters, applied over the flattened head dimension.
                self.q_norm = RMSNorm(hidden_size=self.hidden_size, eps=config.norm_eps)
                self.k_norm = RMSNorm(hidden_size=self.num_kv_heads * self.head_dim, eps=config.norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor]]]:
        batch_size, seq_len, _ = hidden_states.size()

        q = self.q_proj(hidden_states)
        k = self.k_proj(hidden_states)
        v = self.v_proj(hidden_states)

        # [batch_size, num_heads, seq_len, head_dim]
        q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)
        v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        if self.config.use_rope:
            q, k = self.rotary(q, k)

        if self.config.qk_norm:
            if self.config.qk_norm_share_param_across_head:
                q = self.q_norm(q)
                k = self.k_norm(k)
            else:
                q = self.q_norm(q.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
                k = self.k_norm(k.transpose(1, 2).contiguous().view(batch_size, seq_len, -1))
                q = q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
                k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim).transpose(1, 2)

        # Expand KV heads to match the query heads for grouped-query attention.
        if self.num_kv_groups > 1:
            k = k.repeat_interleave(self.num_kv_groups, dim=1)
            v = v.repeat_interleave(self.num_kv_groups, dim=1)

        # Imported lazily so the module can be imported without the fused kernel installed.
        from forgetting_transformer.ops.stickbreaking_attention_std import stickbreaking_attention_std

        o = stickbreaking_attention_std(
            q, k, v,
            head_first=True,
            sm_scale=self.scale,
            normalize=self.config.normalize_attention,
            attend_current=self.config.attend_current,
        )

        o = o.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        o = self.o_proj(o)

        return o, None
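

# A minimal quadratic-time reference for the fused kernel used above, kept here for
# readability and testing only; it is not wired into the layers in this file. This is a
# sketch assuming the standard un-normalized stick-breaking formulation from the ICLR
# 2025 paper: beta_{i,j} = sigmoid(q_i . k_j * sm_scale) and the weight on key j is
# beta_{i,j} * prod_{l > j} (1 - beta_{i,l}) over allowed positions l. The fused
# kernel's `normalize` option and the exact semantics of `attend_current` are
# assumptions not reproduced exactly here.
def naive_stickbreaking_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    sm_scale: float,
    attend_current: bool = False,
) -> torch.Tensor:
    # q, k, v: [batch, num_heads, seq_len, head_dim] (head-first layout).
    seq_len = q.shape[-2]
    logits = torch.einsum('bhid,bhjd->bhij', q, k) * sm_scale
    # Allowed keys are strictly before the query (or include it if attend_current).
    diagonal = 0 if attend_current else -1
    allowed = torch.ones(seq_len, seq_len, dtype=torch.bool, device=q.device).tril(diagonal)
    log_beta = F.logsigmoid(logits).masked_fill(~allowed, 0.0)
    log_rem = F.logsigmoid(-logits).masked_fill(~allowed, 0.0)
    # Remaining stick for key j: sum of log(1 - beta) over allowed positions l > j.
    remaining = log_rem.flip(-1).cumsum(-1).flip(-1) - log_rem
    weights = torch.exp(log_beta + remaining).masked_fill(~allowed, 0.0)
    return torch.einsum('bhij,bhjd->bhid', weights, v)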


class StickbreakingMLP(nn.Module):
    def __init__(self, config: StickbreakingConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size or config.hidden_ratio * config.hidden_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class StickbreakingBlock(nn.Module):
    def __init__(self, config: StickbreakingConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        self.attn_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.attn = StickbreakingAttention(config, layer_idx)

        self.mlp_norm = RMSNorm(hidden_size=config.hidden_size, eps=config.norm_eps)
        self.mlp = StickbreakingMLP(config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        use_cache: bool = False,
        **kwargs
    ):
        # Attention sub-layer with pre-norm residual connection.
        residual = hidden_states
        hidden_states = self.attn_norm(hidden_states)
        hidden_states, present_key_value = self.attn(
            hidden_states,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # MLP sub-layer with pre-norm residual connection.
        residual = hidden_states
        hidden_states = self.mlp_norm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        return hidden_states, present_key_value


class StickbreakingPreTrainedModel(PreTrainedModel):
    config_class = StickbreakingConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["StickbreakingBlock"]

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)


class StickbreakingModel(StickbreakingPreTrainedModel):
    def __init__(self, config: StickbreakingConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embeddings = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([
            StickbreakingBlock(config, layer_idx)
            for layer_idx in range(config.num_hidden_layers)
        ])
        self.norm = RMSNorm(config.hidden_size, eps=config.norm_eps)

        self.gradient_checkpointing = False
        self.post_init()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.embeddings(input_ids)

        # KV caching is not implemented here: every layer is called with
        # past_key_value=None and the attention module returns None for its cache.
        for layer in self.layers:
            if self.gradient_checkpointing and self.training:
                hidden_states, _ = torch.utils.checkpoint.checkpoint(
                    layer.__call__,
                    hidden_states,
                    attention_mask,
                    None,
                    use_cache,
                )
            else:
                hidden_states, _ = layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=None,
                    use_cache=use_cache,
                )

        hidden_states = self.norm(hidden_states)

        # Returns the final hidden states directly rather than a ModelOutput.
        return hidden_states


class StickbreakingForCausalLM(StickbreakingPreTrainedModel):
    _tied_weights_keys = ["lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = StickbreakingModel(config)
        self.vocab_size = config.vocab_size
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.model.embeddings

    def set_input_embeddings(self, value):
        self.model.embeddings = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            if self.config.fuse_cross_entropy:
                loss_fct = FusedCrossEntropyLoss(inplace_backward=True, reduction='none')
            else:
                loss_fct = nn.CrossEntropyLoss(reduction='none')

            # Logits and labels are used as-is (no shifting here), so labels are expected
            # to already be aligned with the logits by the caller. The loss is returned
            # per token with the same shape as `labels`, not reduced.
            logits = logits.to(torch.float32)
            labels = labels.to(logits.device)
            loss = loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
            loss = loss.view(*labels.size())

        if not return_dict:
            output = (logits,)
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=None,
            hidden_states=None,
            attentions=None,
        )
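

# Example usage (a sketch only, not executed on import). It assumes StickbreakingConfig
# follows the usual transformers PretrainedConfig pattern of accepting its fields as
# keyword arguments; the actual signature is defined in configuration_stickbreaking.py.
# The fused attention kernel is expected to run on GPU, hence .cuda().
#
#     config = StickbreakingConfig(vocab_size=32000, hidden_size=512, num_heads=8,
#                                  num_kv_heads=8, num_hidden_layers=4)
#     model = StickbreakingForCausalLM(config).cuda()
#     input_ids = torch.randint(0, config.vocab_size, (2, 128), device="cuda")
#     labels = input_ids.new_full(input_ids.shape, -100)
#     labels[:, :-1] = input_ids[:, 1:]  # next-token targets; forward() does not shift
#     out = model(input_ids=input_ids, labels=labels)
#     out.loss.mean().backward()  # loss is per-token (unreduced), so reduce it first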