from typing import Optional

import torch
from torch import nn, Tensor

try:
    # Triton kernels are only required for ``fused_add_norm=True``; keep the rest
    # of this module importable without them. BlockCrossAttention asserts on
    # ``RMSNorm is not None`` before using the fused path.
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
except ImportError:
    RMSNorm, layer_norm_fn = None, None

from transformers.models.bart.modeling_bart import BartSdpaAttention
from transformers.activations import ACT2FN


class BertEmbeddings(nn.Module):
    """Construct embeddings from word and token-type embeddings.

    No position embedding is added in this module; ``max_position_embeddings``
    is accepted for signature compatibility but is not used (positions are
    presumably handled elsewhere, e.g. by the mixer — TODO confirm).

    Args:
        vocab_size: number of rows in the word-embedding table.
        hidden_size: embedding width.
        max_position_embeddings: unused; kept for interface compatibility.
        type_vocab_size: number of token-type (segment) ids.
        pad_token_id: ``padding_idx`` for the word-embedding table.
        layer_norm_eps: epsilon of the output LayerNorm.
        hidden_dropout_prob: dropout probability applied after LayerNorm.
        device, dtype: forwarded to the parameter constructors.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size=2, pad_token_id=2,
                 layer_norm_eps=1e-12, hidden_dropout_prob=0.1, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, **factory_kwargs)
        self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size, **factory_kwargs)

        # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
        # any TensorFlow checkpoint file
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Embed tokens, add token-type embeddings, then LayerNorm and dropout.

        Args:
            input_ids: token ids; used when ``inputs_embeds`` is not given.
            token_type_ids: segment ids; defaults to all zeros of the input shape.
            position_ids: ignored (kept for interface compatibility).
            inputs_embeds: precomputed word embeddings; takes precedence over
                ``input_ids`` for the embedding lookup.
            past_key_values_length: ignored (kept for interface compatibility).

        Returns:
            The normalized, dropped-out sum of word and token-type embeddings.

        Raises:
            ValueError: if neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if token_type_ids is None:
            # Default every token to segment 0, on the same device as the inputs
            # (``input_ids`` may be None when only ``inputs_embeds`` is passed).
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class BertPooler(nn.Module):
    """Pool a sequence by projecting the hidden state at position 0 through Linear + Tanh."""

    def __init__(self, hidden_size, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # We "pool" the model by simply taking the hidden state corresponding
        # to the first token.
        first_token_tensor = hidden_states[:, 0]
        pooled_output = self.dense(first_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output


class BertPredictionHeadTransform(nn.Module):
    """Apply Linear -> activation -> LayerNorm, keeping the hidden size.

    ``hidden_act`` may be a key into ``ACT2FN`` or a callable.
    """

    def __init__(self, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
        if isinstance(hidden_act, str):
            self.transform_act_fn = ACT2FN[hidden_act]
        else:
            self.transform_act_fn = hidden_act
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dense(hidden_states)
        hidden_states = self.transform_act_fn(hidden_states)
        hidden_states = self.LayerNorm(hidden_states)
        return hidden_states


class BertLMPredictionHead(nn.Module):
    """Project hidden states to vocabulary logits (transform + linear decoder with a separate bias)."""

    def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.transform = BertPredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)

        # The output weights are the same as the input embeddings, but there is
        # an output-only bias for each token.
        self.decoder = nn.Linear(hidden_size, vocab_size, bias=False, **factory_kwargs)

        self.bias = nn.Parameter(torch.zeros(vocab_size, **factory_kwargs))

        # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
        self.decoder.bias = self.bias

    def _tie_weights(self):
        # Re-link the decoder bias to ``self.bias`` (e.g. after it was replaced).
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        hidden_states = self.transform(hidden_states)
        hidden_states = self.decoder(hidden_states)
        return hidden_states


class BertPreTrainingHeads(nn.Module):
    """Masked-LM head on the sequence output plus a 2-way classifier on the pooled output."""

    def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()
        self.predictions = BertLMPredictionHead(vocab_size, hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
        self.seq_relationship = nn.Linear(hidden_size, 2, **factory_kwargs)

    def forward(self, sequence_output, pooled_output):
        """Return ``(prediction_scores, seq_relationship_score)``."""
        prediction_scores = self.predictions(sequence_output)
        seq_relationship_score = self.seq_relationship(pooled_output)
        return prediction_scores, seq_relationship_score


class BlockCrossAttention(nn.Module):
    def __init__(
        self,
        dim,
        mixer_cls,
        mlp_cls,
        norm_cls=nn.LayerNorm,
        fused_add_norm=False,
        residual_in_fp32=False,
        cross_attn_num_heads=1,
    ):
        """
        Block wrapping a mixer class with LayerNorm/RMSNorm, a cross-attention layer,
        an optional MLP, and a residual connection.

        This Block has a slightly different structure compared to a regular
        prenorm Transformer block.
        The standard block is: LN -> MHA/MLP -> Add.
        [Ref: https://arxiv.org/abs/2002.04745]
        Here we have: Add -> LN -> Mixer -> CrossAttn [-> Add -> LN -> MLP],
        returning both the hidden_states and the residual.
        This is purely for performance reasons, as we can fuse add and LayerNorm.
        The residual needs to be provided (except for the very first block).

        Args:
            dim: model width, passed to ``norm_cls``, ``mixer_cls`` and ``mlp_cls``.
            mixer_cls: callable ``dim -> module`` building the sequence mixer.
            mlp_cls: callable ``dim -> module`` building the MLP; ``nn.Identity``
                disables the MLP sub-block entirely.
            norm_cls: normalization class; must be LayerNorm/RMSNorm when fused.
            fused_add_norm: use the Triton ``layer_norm_fn`` fused add+norm.
            residual_in_fp32: keep the residual stream in float32.
            cross_attn_num_heads: number of heads of the cross-attention layer
                (default 1, the original hard-coded value); ``dim`` must be
                divisible by it.
        """
        super().__init__()
        self.residual_in_fp32 = residual_in_fp32
        self.fused_add_norm = fused_add_norm
        self.norm = norm_cls(dim)
        self.mixer = mixer_cls(dim)
        self.encoder_attn = BartSdpaAttention(embed_dim=dim, num_heads=cross_attn_num_heads)
        if mlp_cls is not nn.Identity:
            self.norm2 = norm_cls(dim)
            self.mlp = mlp_cls(dim)
        else:
            self.mlp = None
        if self.fused_add_norm:
            assert RMSNorm is not None, "RMSNorm import fails"
            assert isinstance(
                self.norm, (nn.LayerNorm, RMSNorm)
            ), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
        self,
        hidden_states: Tensor,
        residual: Optional[Tensor] = None,
        inference_params=None,
        encoder_hidden_states=None,
        attention_mask=None,
        **mixer_kwargs
    ):
        r"""Pass the input through norm, mixer, cross-attention and (optionally) the MLP.

        Args:
            hidden_states: the sequence to the block (required).
            residual: running residual; hidden_states = Mixer(LN(residual)).
                ``None`` for the first block.
            inference_params: forwarded to the mixer.
            encoder_hidden_states: keys/values for the cross-attention layer.
                Passed as the second positional argument of ``BartSdpaAttention``;
                presumably ``None`` makes it self-attention — confirm against the
                installed transformers version.
            attention_mask: passed unchanged to the cross-attention layer; its
                expected shape is defined by ``BartSdpaAttention`` (TODO confirm).
            **mixer_kwargs: forwarded to the mixer.

        Returns:
            ``(hidden_states, residual)``.
        """
        if not self.fused_add_norm:
            residual = (hidden_states + residual) if residual is not None else hidden_states
            hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
            if self.residual_in_fp32:
                residual = residual.to(torch.float32)
        else:
            hidden_states, residual = layer_norm_fn(
                hidden_states,
                self.norm.weight,
                self.norm.bias,
                residual=residual,
                prenorm=True,
                residual_in_fp32=self.residual_in_fp32,
                eps=self.norm.eps,
                is_rms_norm=isinstance(self.norm, RMSNorm)
            )
        hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)

        # Cross-attention: the mixer output is the query. NOTE(review): the mixer
        # output is not added to the residual on its own — only the cross-attention
        # output flows on; confirm this is the intended architecture.
        # Only the attention output (element 0 of the returned tuple) is used.
        hidden_states = self.encoder_attn(hidden_states, encoder_hidden_states, attention_mask=attention_mask)[0]

        if self.mlp is not None:
            if not self.fused_add_norm:
                residual = hidden_states + residual
                hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
                if self.residual_in_fp32:
                    residual = residual.to(torch.float32)
            else:
                hidden_states, residual = layer_norm_fn(
                    hidden_states,
                    self.norm2.weight,
                    self.norm2.bias,
                    residual=residual,
                    prenorm=True,
                    residual_in_fp32=self.residual_in_fp32,
                    eps=self.norm2.eps,
                    is_rms_norm=isinstance(self.norm2, RMSNorm)
                )
            hidden_states = self.mlp(hidden_states)

        return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        # Only the mixer keeps inference state; delegate to it.
        return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)