|
|
from typing import Optional |
|
|
|
|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
|
|
|
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn |
|
|
from transformers.models.bart.modeling_bart import BartSdpaAttention |
|
|
from transformers.activations import ACT2FN |
|
|
|
|
|
|
|
|
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and token_type embeddings.

    NOTE(review): despite the class docstring's lineage and the
    ``max_position_embeddings`` parameter, no position-embedding table is
    created — presumably position is handled by the downstream sequence
    mixer; the parameter is kept only for interface compatibility.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size=2, pad_token_id=2, layer_norm_eps=1e-12, hidden_dropout_prob=0.1, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, **factory_kwargs)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size, **factory_kwargs)
        # Name `LayerNorm` (not snake_case) kept to stay compatible with
        # upstream BERT checkpoint variable names.
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed ``input_ids`` (or take ``inputs_embeds`` as-is), add
        token-type embeddings, then apply LayerNorm and dropout.

        ``position_ids`` and ``past_key_values_length`` are accepted for
        BERT-interface compatibility but are unused (no position table).

        Returns:
            Tensor of shape ``(*input_shape, hidden_size)``.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            # Bug fix: previously this path still read `input_ids.device`
            # below, crashing whenever only `inputs_embeds` was provided.
            device = inputs_embeds.device

        if token_type_ids is None:
            # Default: every token belongs to segment 0.
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence by pushing its first ([CLS]) token's hidden state
    through a tanh-activated dense layer."""

    def __init__(self, hidden_size, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # "Pooling" here simply selects the hidden state of the first token.
        cls_state = hidden_states[:, 0]
        return self.activation(self.dense(cls_state))
|
|
|
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Dense -> activation -> LayerNorm transform applied to hidden states
    before the masked-LM decoder."""

    def __init__(self, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        # A string selects a named activation from ACT2FN; a callable is used as-is.
        self.transform_act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, str) else hidden_act
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return self.LayerNorm(self.transform_act_fn(self.dense(hidden_states)))
|
|
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Masked-LM head: transform the hidden states, then decode them to
    vocabulary logits with a bias held in a separate tied parameter."""

    def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.transform = BertPredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
        # The decoder projects hidden states to the vocabulary. Its bias is
        # kept as a standalone parameter so it can be resized alongside the
        # token embeddings, then linked onto the linear layer.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False, **factory_kwargs)
        self.bias = nn.Parameter(torch.zeros(vocab_size, **factory_kwargs))
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-link after any operation (e.g. embedding resize) that may have
        # replaced the bias parameter.
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        return self.decoder(self.transform(hidden_states))
|
|
|
|
|
|
|
|
class BertPreTrainingHeads(nn.Module):
    """Bundle the masked-LM prediction head with the binary
    next-sentence-prediction classifier."""

    def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.predictions = BertLMPredictionHead(vocab_size, hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
        # Two-way classifier over the pooled [CLS] representation.
        self.seq_relationship = nn.Linear(hidden_size, 2, **factory_kwargs)

    def forward(self, sequence_output, pooled_output):
        # Returns (vocab logits per token, NSP logits per sequence).
        return self.predictions(sequence_output), self.seq_relationship(pooled_output)
|
|
|
|
|
|
|
|
|
|
|
class BlockCrossAttention(nn.Module):
    def __init__(
        self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
    ):
        """
        Simple block wrapping a mixer class with LayerNorm/RMSNorm, residual
        connection, and a single-head cross-attention over encoder states.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer, returning both
        the hidden_states (output of the mixer) and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        # Cross-attention to encoder_hidden_states. NOTE(review): num_heads=1
        # is hard-coded — presumably intentional, but confirm it should not
        # match the model's attention-head count.
        self.encoder_attn = BartSdpaAttention(embed_dim=dim, num_heads=1)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            # No MLP sub-block requested; norm2 is deliberately not created.
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs
    ):
        r"""Pass the input through the encoder layer.

        Args:
            hidden_states: the sequence to the encoder layer (required).
            residual: hidden_states = Mixer(LN(residual))
            inference_params: forwarded to the mixer (e.g. cached decoding state).
            encoder_hidden_states: key/value states for the cross-attention.
            attention_mask: mask forwarded to the cross-attention.

        Returns:
            ``(hidden_states, residual)`` so the next block can fuse its add.
        """
        if not self.fused_add_norm:
            # Unfused path: explicit residual add then norm; optionally keep
            # the residual stream in fp32 for numerical stability.
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            # Fused path: the Triton kernel performs add + (RMS)LayerNorm in
            # one pass and returns the updated residual as well (prenorm=True).
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        # Cross-attend to the encoder states. BartSdpaAttention returns a
        # 3-tuple (attn_output, attn_weights, past_key_value); only the
        # attention output is kept. NOTE(review): this output feeds straight
        # into the residual add below without its own pre-norm — confirm this
        # is the intended placement.
        hidden_states, _, _ = self.encoder_attn(hidden_states, encoder_hidden_states, attention_mask=attention_mask)

        if self.mlp is not None:
            # Second Add -> LN -> MLP sub-block, mirroring the mixer path above.
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Delegate inference-cache allocation to the mixer (Mamba-style API).
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
|
|
|