from typing import Optional
import torch
from torch import nn, Tensor
try:
    from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn
except ImportError:
    # The Triton kernels are optional; the unfused path below still works.
    RMSNorm, layer_norm_fn = None, None
from transformers.models.bart.modeling_bart import BartSdpaAttention
from transformers.activations import ACT2FN


class BertEmbeddings(nn.Module):
    """Construct the embeddings from word and token_type embeddings.

    Position embeddings are not used here; `max_position_embeddings` is kept
    only for interface compatibility.
    """

    def __init__(self, vocab_size, hidden_size, max_position_embeddings, type_vocab_size=2, pad_token_id=2, layer_norm_eps=1e-12, hidden_dropout_prob=0.1, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.word_embeddings = nn.Embedding(vocab_size, hidden_size, padding_idx=pad_token_id, **factory_kwargs)
self.token_type_embeddings = nn.Embedding(type_vocab_size, hidden_size, **factory_kwargs)
# self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
# any TensorFlow checkpoint file
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)
self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
token_type_ids: Optional[torch.LongTensor] = None,
position_ids: Optional[torch.LongTensor] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
past_key_values_length: int = 0,
) -> torch.Tensor:
        # `position_ids` and `past_key_values_length` are accepted for
        # interface compatibility but are unused: no position embeddings are added.
        if input_ids is not None:
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        # Default token_type_ids to all zeros, the usual case when they are
        # auto-generated rather than passed by the caller.
        if token_type_ids is None:
            token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
if inputs_embeds is None:
inputs_embeds = self.word_embeddings(input_ids)
token_type_embeddings = self.token_type_embeddings(token_type_ids)
embeddings = inputs_embeds + token_type_embeddings
embeddings = self.LayerNorm(embeddings)
embeddings = self.dropout(embeddings)
return embeddings
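
# Example usage (a sketch with made-up sizes): for `input_ids` of shape
# (batch, seq_len), BertEmbeddings returns a tensor of shape
# (batch, seq_len, hidden_size), e.g.:
#
#   emb = BertEmbeddings(vocab_size=30522, hidden_size=768, max_position_embeddings=512)
#   out = emb(torch.randint(0, 30522, (2, 16)))  # -> (2, 16, 768)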


class BertPooler(nn.Module):
def __init__(self, hidden_size, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# We "pool" the model by simply taking the hidden state corresponding
# to the first token.
first_token_tensor = hidden_states[:, 0]
pooled_output = self.dense(first_token_tensor)
pooled_output = self.activation(pooled_output)
return pooled_output


class BertPredictionHeadTransform(nn.Module):
def __init__(self, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.dense = nn.Linear(hidden_size, hidden_size, **factory_kwargs)
if isinstance(hidden_act, str):
self.transform_act_fn = ACT2FN[hidden_act]
else:
self.transform_act_fn = hidden_act
self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps, **factory_kwargs)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states = self.dense(hidden_states)
hidden_states = self.transform_act_fn(hidden_states)
hidden_states = self.LayerNorm(hidden_states)
return hidden_states


class BertLMPredictionHead(nn.Module):
def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.transform = BertPredictionHeadTransform(hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
# The output weights are the same as the input embeddings, but there is
# an output-only bias for each token.
self.decoder = nn.Linear(hidden_size, vocab_size, bias=False, **factory_kwargs)
self.bias = nn.Parameter(torch.zeros(vocab_size, **factory_kwargs))
# Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
self.decoder.bias = self.bias

    def _tie_weights(self):
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
hidden_states = self.transform(hidden_states)
hidden_states = self.decoder(hidden_states)
return hidden_states


class BertPreTrainingHeads(nn.Module):
def __init__(self, vocab_size, hidden_size, hidden_act="gelu", layer_norm_eps=1e-12, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
self.predictions = BertLMPredictionHead(vocab_size, hidden_size, hidden_act, layer_norm_eps, **factory_kwargs)
self.seq_relationship = nn.Linear(hidden_size, 2, **factory_kwargs)

    def forward(self, sequence_output, pooled_output):
prediction_scores = self.predictions(sequence_output)
seq_relationship_score = self.seq_relationship(pooled_output)
return prediction_scores, seq_relationship_score
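
# How the pieces above compose for BERT-style pretraining (a sketch, not used
# by this module itself): `encoder` stands in for whatever stack of blocks
# produces the sequence output, e.g. a chain of BlockCrossAttention blocks.
def _example_pretraining_forward(embeddings, encoder, pooler, heads, input_ids):
    sequence_output = encoder(embeddings(input_ids))  # (batch, seq, hidden)
    pooled_output = pooler(sequence_output)  # (batch, hidden)
    # Returns (batch, seq, vocab_size) MLM logits and (batch, 2) NSP logits.
    return heads(sequence_output, pooled_output)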


class BlockCrossAttention(nn.Module):
def __init__(
self, dim, mixer_cls, mlp_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False
):
"""
Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
This Block has a slightly different structure compared to a regular
prenorm Transformer block.
The standard block is: LN -> MHA/MLP -> Add.
[Ref: https://arxiv.org/abs/2002.04745]
Here we have: Add -> LN -> Mixer, returning both
the hidden_states (output of the mixer) and the residual.
This is purely for performance reasons, as we can fuse add and LayerNorm.
The residual needs to be provided (except for the very first block).
"""
super().__init__()
self.residual_in_fp32 = residual_in_fp32
self.fused_add_norm = fused_add_norm
self.norm = norm_cls(dim)
self.mixer = mixer_cls(dim)
        # Single-head BART-style SDPA cross-attention over the encoder states.
        self.encoder_attn = BartSdpaAttention(embed_dim=dim, num_heads=1)
if mlp_cls is not nn.Identity:
self.norm2 = norm_cls(dim)
self.mlp = mlp_cls(dim)
else:
self.mlp = None
if self.fused_add_norm:
assert RMSNorm is not None, "RMSNorm import fails"
assert isinstance(
self.norm, (nn.LayerNorm, RMSNorm)
), "Only LayerNorm and RMSNorm are supported for fused_add_norm"

    def forward(
self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None, encoder_hidden_states=None, attention_mask=None, **mixer_kwargs
):
r"""Pass the input through the encoder layer.
Args:
hidden_states: the sequence to the encoder layer (required).
residual: hidden_states = Mixer(LN(residual))
"""
if not self.fused_add_norm:
residual = (hidden_states + residual) if residual is not None else hidden_states
hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
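            # Fused path: layer_norm_fn adds the residual and applies the norm
            # in one kernel, returning both the normalized states and the
            # pre-norm sum (the new residual), matching the unfused branch above.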
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm.weight,
self.norm.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm.eps,
is_rms_norm=isinstance(self.norm, RMSNorm)
)
hidden_states = self.mixer(hidden_states, inference_params=inference_params, **mixer_kwargs)
        # Cross-attention over the encoder states; the attention weights and
        # past_key_value outputs of BartSdpaAttention are discarded.
        hidden_states, _, _ = self.encoder_attn(hidden_states, encoder_hidden_states, attention_mask=attention_mask)
if self.mlp is not None:
if not self.fused_add_norm:
residual = hidden_states + residual
hidden_states = self.norm2(residual.to(dtype=self.norm2.weight.dtype))
if self.residual_in_fp32:
residual = residual.to(torch.float32)
else:
hidden_states, residual = layer_norm_fn(
hidden_states,
self.norm2.weight,
self.norm2.bias,
residual=residual,
prenorm=True,
residual_in_fp32=self.residual_in_fp32,
eps=self.norm2.eps,
is_rms_norm=isinstance(self.norm2, RMSNorm)
)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual

    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
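

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original interface):
    # wires BlockCrossAttention around mamba_ssm's Mamba as the mixer.
    # Assumes a CUDA device, since the fused Mamba kernels are GPU-only.
    from functools import partial

    from mamba_ssm.modules.mamba_simple import Mamba

    dim = 64
    block = BlockCrossAttention(
        dim,
        mixer_cls=partial(Mamba, layer_idx=0),
        mlp_cls=nn.Identity,  # no MLP sub-block in this sketch
    ).to("cuda")
    x = torch.randn(2, 16, dim, device="cuda")  # mixer-side hidden states
    enc = torch.randn(2, 8, dim, device="cuda")  # encoder states to attend to
    out, residual = block(x, encoder_hidden_states=enc)
    print(out.shape, residual.shape)  # torch.Size([2, 16, 64]) for both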