|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""PyTorch BERT model.""" |
|
|
|
|
|
import math |
|
|
import os |
|
|
import warnings |
|
|
from dataclasses import dataclass |
|
|
from typing import Optional, Union |
|
|
|
|
|
import torch |
|
|
from torch import nn |
|
|
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss |
|
|
|
|
|
from ...activations import ACT2FN |
|
|
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache |
|
|
from ...generation import GenerationMixin |
|
|
from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa |
|
|
from ...modeling_layers import GradientCheckpointingLayer |
|
|
from ...modeling_outputs import ( |
|
|
BaseModelOutputWithPastAndCrossAttentions, |
|
|
BaseModelOutputWithPoolingAndCrossAttentions, |
|
|
CausalLMOutputWithCrossAttentions, |
|
|
MaskedLMOutput, |
|
|
MultipleChoiceModelOutput, |
|
|
NextSentencePredictorOutput, |
|
|
QuestionAnsweringModelOutput, |
|
|
SequenceClassifierOutput, |
|
|
TokenClassifierOutput, |
|
|
) |
|
|
from ...modeling_utils import PreTrainedModel |
|
|
from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer |
|
|
from ...utils import ModelOutput, auto_docstring, logging |
|
|
from ...utils.deprecation import deprecate_kwarg |
|
|
from .configuration_bert import BertConfig |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
    """Copy the variables of a TensorFlow checkpoint into a PyTorch model.

    Each checkpoint variable name is split on "/" and resolved component by component with
    `getattr` (and integer indexing for `name_<n>` components), starting from `model`.

    Args:
        model: PyTorch module that receives the weights; it is modified in place.
        config: Unused by this function; kept for the loader signature.
        tf_checkpoint_path: Path of the TensorFlow checkpoint to read.

    Returns:
        The same `model` object, with the matched parameters overwritten.

    Raises:
        ImportError: If `re`, `numpy` or `tensorflow` cannot be imported.
        ValueError: If a resolved parameter and the checkpoint array have different shapes.
    """
    try:
        import re

        import numpy as np
        import tensorflow as tf
    except ImportError:
        logger.error(
            "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
            "https://www.tensorflow.org/install/ for installation instructions."
        )
        raise

    tf_path = os.path.abspath(tf_checkpoint_path)
    logger.info(f"Converting TensorFlow checkpoint from {tf_path}")

    # Read every variable of the checkpoint before touching the model.
    checkpoint_vars = []
    for var_name, var_shape in tf.train.list_variables(tf_path):
        logger.info(f"Loading TF weight {var_name} with shape {var_shape}")
        checkpoint_vars.append((var_name, tf.train.load_variable(tf_path, var_name)))

    # Scope names that mark optimizer / bookkeeping variables; these are never copied.
    ignored_scopes = ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]

    for var_name, array in checkpoint_vars:
        name = var_name.split("/")
        if any(part in ignored_scopes for part in name):
            logger.info(f"Skipping {'/'.join(name)}")
            continue

        pointer = model
        for m_name in name:
            # "layer_3" -> ["layer", "3", ""]; anything else is kept as a single scope name.
            if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
                scope_names = re.split(r"_(\d+)", m_name)
            else:
                scope_names = [m_name]
            scope = scope_names[0]

            if scope in ("kernel", "gamma", "output_weights"):
                pointer = getattr(pointer, "weight")
            elif scope in ("output_bias", "beta"):
                pointer = getattr(pointer, "bias")
            elif scope == "squad":
                pointer = getattr(pointer, "classifier")
            else:
                try:
                    pointer = getattr(pointer, scope)
                except AttributeError:
                    # NOTE: this only skips the current path component, not the whole variable;
                    # resolution goes on from the same `pointer` with the next component.
                    logger.info(f"Skipping {'/'.join(name)}")
                    continue
            if len(scope_names) >= 2:
                pointer = pointer[int(scope_names[1])]

        # `m_name` is the last path component here.
        if m_name.endswith("_embeddings"):
            pointer = getattr(pointer, "weight")
        elif m_name == "kernel":
            array = np.transpose(array)

        if pointer.shape != array.shape:
            raise ValueError(
                f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched",
                pointer.shape,
                array.shape,
            )

        logger.info(f"Initialize PyTorch weight {name}")
        pointer.data = torch.from_numpy(array)

    return model
|
|
|
|
|
|
|
|
class BertEmbeddings(nn.Module):
    """Construct the embeddings from word, position and token_type embeddings."""

    def __init__(self, config):
        """Build the three embedding tables, the layer norm and the dropout.

        Args:
            config: Object providing `vocab_size`, `hidden_size`, `pad_token_id`,
                `max_position_embeddings`, `type_vocab_size`, `layer_norm_eps`,
                `hidden_dropout_prob` and optionally `position_embedding_type`
                (defaults to `"absolute"`).
        """
        super().__init__()
        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
        self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)

        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

        self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
        # Non-persistent buffers: they follow the module across devices but are not saved in the state dict.
        self.register_buffer(
            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
        )
        self.register_buffer(
            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
        )

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """Sum the embeddings, then apply layer norm and dropout.

        Args:
            input_ids: Token ids; looked up in `word_embeddings` when `inputs_embeds` is not given.
            token_type_ids: Segment ids; all zeros when omitted.
            position_ids: Position ids; when omitted, a slice of the `position_ids` buffer starting
                at `past_key_values_length` is used.
            inputs_embeds: Pre-computed word embeddings used instead of looking up `input_ids`.
            past_key_values_length: Offset of the first position when `position_ids` is omitted.

        Returns:
            The normalized embeddings.

        Raises:
            ValueError: If neither `input_ids` nor `inputs_embeds` is provided.
        """
        if input_ids is not None:
            input_shape = input_ids.size()
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            # Fail with a clear message instead of an AttributeError on `None.size()`.
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        seq_length = input_shape[1]

        if position_ids is None:
            position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]

        if token_type_ids is None:
            if hasattr(self, "token_type_ids"):
                # Reuse the registered all-zero buffer, expanded to the batch size.
                buffered_token_type_ids = self.token_type_ids[:, :seq_length]
                token_type_ids = buffered_token_type_ids.expand(input_shape[0], seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)
        token_type_embeddings = self.token_type_embeddings(token_type_ids)

        embeddings = inputs_embeds + token_type_embeddings
        # Position embeddings are only added here in the "absolute" scheme.
        if self.position_embedding_type == "absolute":
            position_embeddings = self.position_embeddings(position_ids)
            embeddings += position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings
|
|
|
|
|
|
|
|
class BertSelfAttention(nn.Module):
    """Multi-head attention (self- or cross-) with optional relative position scores and key/value caching."""

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        """Create the query/key/value projections.

        Args:
            config: Object providing `hidden_size`, `num_attention_heads`,
                `attention_probs_dropout_prob`, `max_position_embeddings`, `is_decoder` and
                optionally `position_embedding_type`.
            position_embedding_type: Overrides the value from `config` when given.
            layer_idx: Index used to address this layer's entry in the cache.

        Raises:
            ValueError: If `hidden_size` is not a multiple of `num_attention_heads` and the config
                has no `embedding_size` attribute.
        """
        super().__init__()
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(
                f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                f"heads ({config.num_attention_heads})"
            )

        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.position_embedding_type = position_embedding_type or getattr(
            config, "position_embedding_type", "absolute"
        )
        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            self.max_position_embeddings = config.max_position_embeddings
            # One embedding per signed distance in [-(max - 1), max - 1].
            self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

        self.is_decoder = config.is_decoder
        self.layer_idx = layer_idx

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Compute the attention context.

        Args:
            hidden_states: 3-D input; queries are always projected from it.
            attention_mask: Added to the scaled attention scores before the softmax.
            head_mask: Multiplied with the attention probabilities after dropout.
            encoder_hidden_states: When given, keys and values are projected from it (cross-attention).
            past_key_values: Cache updated with (or, for already-filled cross-attention, read for)
                this layer's keys and values.
            output_attentions: Accepted for interface compatibility; the probabilities are always returned.
            cache_position: Forwarded to the cache update for self-attention only.

        Returns:
            `(context_layer, attention_probs)`.
        """
        # Unpacking three values also enforces a 3-D input.
        batch_size, _, _ = hidden_states.shape
        query_layer = self.query(hidden_states)
        query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
            1, 2
        )

        is_updated = False
        is_cross_attention = encoder_hidden_states is not None
        if past_key_values is not None:
            if isinstance(past_key_values, EncoderDecoderCache):
                is_updated = past_key_values.is_updated.get(self.layer_idx)
                if is_cross_attention:
                    curr_past_key_value = past_key_values.cross_attention_cache
                else:
                    curr_past_key_value = past_key_values.self_attention_cache
            else:
                curr_past_key_value = past_key_values

        current_states = encoder_hidden_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_values is not None and is_updated:
            # Cross-attention keys/values were already stored for this layer: reuse them.
            key_layer = curr_past_key_value.layers[self.layer_idx].keys
            value_layer = curr_past_key_value.layers[self.layer_idx].values
        else:
            key_layer = self.key(current_states)
            key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
                1, 2
            )
            value_layer = self.value(current_states)
            value_layer = value_layer.view(
                batch_size, -1, self.num_attention_heads, self.attention_head_size
            ).transpose(1, 2)

            if past_key_values is not None:
                # `cache_position` is only passed on for self-attention.
                cache_position = cache_position if not is_cross_attention else None
                key_layer, value_layer = curr_past_key_value.update(
                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                )
                # Mark the cross-attention entry as filled so later calls reuse it.
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Raw scores: dot product between queries and keys.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
            query_length, key_length = query_layer.shape[2], key_layer.shape[2]
            if past_key_values is not None:
                # With a cache the keys span past + current tokens, so the queries are the last
                # `query_length` positions. (A single hard-coded `key_length - 1` position was
                # only right for one-token steps and was silently broadcast otherwise.)
                position_ids_l = torch.arange(
                    key_length - query_length, key_length, dtype=torch.long, device=hidden_states.device
                ).view(-1, 1)
            else:
                position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
            position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
            distance = position_ids_l - position_ids_r

            # Shift distances so they index into `distance_embedding` from 0.
            positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
            positional_embedding = positional_embedding.to(dtype=query_layer.dtype)

            if self.position_embedding_type == "relative_key":
                relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores
            elif self.position_embedding_type == "relative_key_query":
                relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        if attention_mask is not None:
            # The mask is additive.
            attention_scores = attention_scores + attention_mask

        # Normalize the scores to probabilities over the key axis.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # Dropout is applied to the probabilities themselves.
        attention_probs = self.dropout(attention_probs)

        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge the head axis back into the feature axis.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(new_context_layer_shape)

        return context_layer, attention_probs
|
|
|
|
|
|
|
|
class BertSdpaSelfAttention(BertSelfAttention):
    """`BertSelfAttention` variant that calls `torch.nn.functional.scaled_dot_product_attention`.

    Falls back to the parent implementation when the position embedding type is not
    `"absolute"`, when attention weights are requested, or when a head mask is given.
    """

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
        # Kept as a plain float because it is passed to the functional SDPA call.
        self.dropout_prob = config.attention_probs_dropout_prob

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Compute the attention context; returns `(attn_output, None)` on the SDPA path."""
        if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
            logger.warning_once(
                "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
                "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
                "the manual attention implementation, but specifying the manual implementation will be required from "
                "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
                '`attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states,
                attention_mask,
                head_mask,
                encoder_hidden_states,
                past_key_values,
                output_attentions,
                cache_position,
            )

        bsz, tgt_len, _ = hidden_states.size()

        def split_heads(projected):
            # (bsz, len, all_head_size) -> (bsz, heads, len, head_size)
            return projected.view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)

        query_layer = split_heads(self.query(hidden_states))

        is_cross_attention = encoder_hidden_states is not None
        kv_source = encoder_hidden_states if is_cross_attention else hidden_states

        # Pick the cache that belongs to this kind of attention, if any cache was passed.
        cross_cache_filled = False
        if past_key_values is not None:
            if not isinstance(past_key_values, EncoderDecoderCache):
                layer_cache = past_key_values
            else:
                cross_cache_filled = past_key_values.is_updated.get(self.layer_idx)
                layer_cache = (
                    past_key_values.cross_attention_cache
                    if is_cross_attention
                    else past_key_values.self_attention_cache
                )

        if is_cross_attention and past_key_values is not None and cross_cache_filled:
            # Cross-attention keys/values were already stored for this layer: reuse them.
            stored = layer_cache.layers[self.layer_idx]
            key_layer = stored.keys
            value_layer = stored.values
        else:
            key_layer = split_heads(self.key(kv_source))
            value_layer = split_heads(self.value(kv_source))

            if past_key_values is not None:
                # `cache_position` is only passed on for self-attention.
                cache_position = None if is_cross_attention else cache_position
                key_layer, value_layer = layer_cache.update(
                    key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                )
                # Mark the cross-attention entry as filled so later calls reuse it.
                if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                    past_key_values.is_updated[self.layer_idx] = True

        # Causal masking is delegated to SDPA only for decoder self-attention without an explicit
        # mask and with more than one query token.
        is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_layer,
            key_layer,
            value_layer,
            attn_mask=attention_mask,
            dropout_p=self.dropout_prob if self.training else 0.0,
            is_causal=is_causal,
        )

        # Merge the head axis back into the feature axis.
        attn_output = attn_output.transpose(1, 2).reshape(bsz, tgt_len, self.all_head_size)

        return attn_output, None
|
|
|
|
|
|
|
|
class BertSelfOutput(nn.Module):
    """Dense projection + dropout, followed by a residual add and layer norm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
|
|
# Attention implementation name -> self-attention class.
BERT_SELF_ATTENTION_CLASSES = dict(
    eager=BertSelfAttention,
    sdpa=BertSdpaSelfAttention,
)
|
|
|
|
|
|
|
|
class BertAttention(nn.Module):
    """Attention sub-layer: a self-attention module followed by `BertSelfOutput`, with head pruning."""

    def __init__(self, config, position_embedding_type=None, layer_idx=None):
        super().__init__()
        # The concrete attention class is selected by `config._attn_implementation`.
        attention_cls = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation]
        self.self = attention_cls(
            config,
            position_embedding_type=position_embedding_type,
            layer_idx=layer_idx,
        )
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def prune_heads(self, heads):
        """Remove the given attention heads from the projections and update the bookkeeping."""
        if len(heads) == 0:
            return
        attn = self.self
        heads, index = find_pruneable_heads_and_indices(
            heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
        )

        # Shrink the q/k/v projections and the input side of the output projection.
        for proj_name in ("query", "key", "value"):
            setattr(attn, proj_name, prune_linear_layer(getattr(attn, proj_name), index))
        self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

        # Keep the head counters consistent with the pruned layers.
        attn.num_attention_heads = attn.num_attention_heads - len(heads)
        attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Run attention and the output projection; returns `(attention_output, *rest_of_self_outputs)`."""
        self_outputs = self.self(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            past_key_values=past_key_values,
            output_attentions=output_attentions,
            cache_position=cache_position,
        )
        # `hidden_states` is the residual input of the output block.
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output,) + self_outputs[1:]
|
|
|
|
|
|
|
|
class BertIntermediate(nn.Module):
    """Feed-forward expansion: a linear layer to `intermediate_size` followed by the activation."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        # `hidden_act` is either a key of ACT2FN or already a callable.
        activation = config.hidden_act
        self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `intermediate_act_fn(dense(hidden_states))`."""
        return self.intermediate_act_fn(self.dense(hidden_states))
|
|
|
|
|
|
|
|
class BertOutput(nn.Module):
    """Feed-forward contraction back to `hidden_size`, with dropout, residual add and layer norm."""

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(config.intermediate_size, width)
        self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(dropout(dense(hidden_states)) + input_tensor)`."""
        projected = self.dropout(self.dense(hidden_states))
        return self.LayerNorm(projected + input_tensor)
|
|
|
|
|
|
|
|
class BertLayer(GradientCheckpointingLayer):
    """One transformer layer: self-attention, optional cross-attention, then the feed-forward block."""

    def __init__(self, config, layer_idx=None):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config, layer_idx=layer_idx)
        self.is_decoder = config.is_decoder
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # Cross-attention is only allowed on decoder layers.
            if not self.is_decoder:
                raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
            self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        output_attentions: Optional[bool] = False,
        cache_position: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor]:
        """Run the layer.

        Returns:
            `(layer_output, *self_attention_extras, *cross_attention_extras)`, where the extras are
            whatever the attention modules return after their first element.

        Raises:
            ValueError: If `encoder_hidden_states` is passed to a decoder layer that was built
                without cross-attention.
        """
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            past_key_values=past_key_values,
            cache_position=cache_position,
        )
        attention_output, extras = self_attention_outputs[0], self_attention_outputs[1:]

        wants_cross_attention = self.is_decoder and encoder_hidden_states is not None
        if wants_cross_attention:
            if not hasattr(self, "crossattention"):
                raise ValueError(
                    f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                    " by setting `config.add_cross_attention=True`"
                )

            # Queries come from the self-attention output, keys/values from the encoder states.
            cross_attention_outputs = self.crossattention(
                attention_output,
                attention_mask=encoder_attention_mask,
                head_mask=head_mask,
                encoder_hidden_states=encoder_hidden_states,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )
            attention_output = cross_attention_outputs[0]
            extras = extras + cross_attention_outputs[1:]

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        return (layer_output,) + extras

    def feed_forward_chunk(self, attention_output):
        """Apply the intermediate and output blocks to one chunk of the attention output."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)
|
|
|
|
|
|
|
|
class BertEncoder(nn.Module):
    """Stack of `config.num_hidden_layers` `BertLayer` modules."""

    def __init__(self, config, layer_idx=None):
        """Build the layer stack.

        Args:
            config: Model configuration; also stored as `self.config`.
            layer_idx: Not used here; each layer receives its own index in the stack.
        """
        super().__init__()
        self.config = config
        self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = False,
        output_hidden_states: Optional[bool] = False,
        return_dict: Optional[bool] = True,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
        """Run all layers in order.

        Args:
            hidden_states: Input to the first layer.
            attention_mask: Passed unchanged to every layer.
            head_mask: Indexed per layer (`head_mask[i]`) when given.
            encoder_hidden_states: Passed to every layer for cross-attention.
            encoder_attention_mask: Passed to every layer for cross-attention.
            past_key_values: Cache object; a legacy tuple is converted (with a deprecation warning)
                when `use_cache` is set on a decoder.
            use_cache: When set on a decoder without a cache, a new `EncoderDecoderCache` is created.
                Forced to `False` while training with gradient checkpointing.
            output_attentions: Collect per-layer attention outputs.
            output_hidden_states: Collect the input of every layer plus the final output.
            return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
            cache_position: Passed unchanged to every layer.

        Returns:
            The model output object, or a tuple of its non-`None` fields when `return_dict` is false.
        """
        all_hidden_states = () if output_hidden_states else None
        all_self_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if use_cache and self.config.is_decoder and past_key_values is None:
            past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))

        if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
            logger.warning_once(
                "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
            )
            past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

        for i, layer_module in enumerate(self.layer):
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_head_mask = head_mask[i] if head_mask is not None else None

            layer_outputs = layer_module(
                hidden_states,
                attention_mask,
                layer_head_mask,
                encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                past_key_values=past_key_values,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]
            if output_attentions:
                all_self_attentions = all_self_attentions + (layer_outputs[1],)
                # A layer only returns a cross-attention entry when it actually ran cross-attention
                # (i.e. `encoder_hidden_states` was given); guard the index so a cross-attention
                # capable model called without encoder states does not raise IndexError.
                if self.config.add_cross_attention and len(layer_outputs) > 2:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple output keeps only the fields that were actually produced.
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_self_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_self_attentions,
            cross_attentions=all_cross_attentions,
        )
|
|
|
|
|
|
|
|
class BertPooler(nn.Module):
    """Pool a sequence by passing its first token's hidden state through a dense layer and tanh."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `tanh(dense(hidden_states[:, 0]))`."""
        # Only index 0 along the second axis is used; the rest of the sequence is ignored.
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))
|
|
|
|
|
|
|
|
class BertPredictionHeadTransform(nn.Module):
    """Dense layer, activation and layer norm applied before the LM decoder."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        # `hidden_act` is either a key of ACT2FN or already a callable.
        activation = config.hidden_act
        self.transform_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return `LayerNorm(transform_act_fn(dense(hidden_states)))`."""
        activated = self.transform_act_fn(self.dense(hidden_states))
        return self.LayerNorm(activated)
|
|
|
|
|
|
|
|
class BertLMPredictionHead(nn.Module):
    """Language-modeling head: the prediction transform followed by a projection to the vocabulary."""

    def __init__(self, config):
        super().__init__()
        self.transform = BertPredictionHeadTransform(config)

        # The decoder is created without its own bias; a separate per-token bias parameter is
        # created below and attached to it.
        self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        vocab_bias = torch.zeros(config.vocab_size)
        self.bias = nn.Parameter(vocab_bias)

        # Make `decoder.bias` and `self.bias` the same parameter object.
        self.decoder.bias = self.bias

    def _tie_weights(self):
        """Re-attach `self.bias` as the decoder bias."""
        self.decoder.bias = self.bias

    def forward(self, hidden_states):
        """Return the vocabulary scores `decoder(transform(hidden_states))`."""
        return self.decoder(self.transform(hidden_states))
|
|
|
|
|
|
|
|
class BertOnlyMLMHead(nn.Module):
    """Wrapper holding only the masked-LM prediction head under the attribute `predictions`."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        """Return the prediction scores for `sequence_output`."""
        return self.predictions(sequence_output)
|
|
|
|
|
|
|
|
class BertOnlyNSPHead(nn.Module):
    """Two-way linear classifier applied to the pooled output."""

    def __init__(self, config):
        super().__init__()
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, pooled_output):
        """Return the two logits computed by `seq_relationship`."""
        return self.seq_relationship(pooled_output)
|
|
|
|
|
|
|
|
class BertPreTrainingHeads(nn.Module):
    """Both pre-training heads: the LM prediction head and the two-way sequence-relationship classifier."""

    def __init__(self, config):
        super().__init__()
        self.predictions = BertLMPredictionHead(config)
        self.seq_relationship = nn.Linear(config.hidden_size, 2)

    def forward(self, sequence_output, pooled_output):
        """Return `(prediction_scores, seq_relationship_score)`."""
        lm_scores = self.predictions(sequence_output)
        relationship_scores = self.seq_relationship(pooled_output)
        return lm_scores, relationship_scores
|
|
|
|
|
|
|
|
@auto_docstring
class BertPreTrainedModel(PreTrainedModel):
    config: BertConfig
    load_tf_weights = load_tf_weights_in_bert
    base_model_prefix = "bert"
    supports_gradient_checkpointing = True
    _supports_sdpa = True

    def _init_weights(self, module):
        """Initialize the weights"""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            # Both kinds of weight matrices are drawn from N(0, initializer_range).
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if isinstance(module, nn.Linear):
                if module.bias is not None:
                    module.bias.data.zero_()
            elif module.padding_idx is not None:
                # The padding row of an embedding is zeroed.
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            # Layer norm is reset to scale 1, shift 0.
            module.weight.data.fill_(1.0)
            module.bias.data.zero_()
            return

        if isinstance(module, BertLMPredictionHead):
            module.bias.data.zero_()
|
|
|
|
|
|
|
|
# Container for the outputs of the pre-training model; every field defaults to None.
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`BertForPreTraining`].
    """
)
class BertForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss as the sum of the masked language modeling loss and the next sequence prediction
        (classification) loss.
    prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
        Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
        before SoftMax).
    """

    loss: Optional[torch.FloatTensor] = None
    prediction_logits: Optional[torch.FloatTensor] = None
    seq_relationship_logits: Optional[torch.FloatTensor] = None
    # NOTE(review): the two fields below are not described in the docstring above; presumably
    # `auto_docstring` supplies their documentation — confirm.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
    cross-attention is added between the self-attention layers, following the architecture described in [Attention is
    all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
    Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.

    To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
    to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
    `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
    """
)
class BertModel(BertPreTrainedModel):
    _no_split_modules = ["BertEmbeddings", "BertLayer"]

    def __init__(self, config, add_pooling_layer=True):
        r"""
        add_pooling_layer (bool, *optional*, defaults to `True`):
            Whether to add a pooling layer
        """
        super().__init__(config)
        self.config = config

        # Sub-modules are registered in a fixed order: embeddings, encoder, then the optional pooler.
        self.embeddings = BertEmbeddings(config)
        self.encoder = BertEncoder(config)
        self.pooler = BertPooler(config) if add_pooling_layer else None

        # Cached here because `forward` consults both when choosing how to build attention masks.
        self.attn_implementation = config._attn_implementation
        self.position_embedding_type = config.position_embedding_type

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    def get_input_embeddings(self):
        """Return the word embedding module used to embed `input_ids`."""
        return self.embeddings.word_embeddings

    def set_input_embeddings(self, value):
        """Replace the word embedding module with `value`."""
        self.embeddings.word_embeddings = value

    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
        class PreTrainedModel
        """
        for layer_idx, head_ids in heads_to_prune.items():
            self.encoder.layer[layer_idx].attention.prune_heads(head_ids)

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
    ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
        # Per-call switches fall back to the configuration when left unset.
        if output_attentions is None:
            output_attentions = self.config.output_attentions
        if output_hidden_states is None:
            output_hidden_states = self.config.output_hidden_states
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Caching is only honoured in decoder mode; otherwise it is forced off.
        if not self.config.is_decoder:
            use_cache = False
        elif use_cache is None:
            use_cache = self.config.use_cache

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if input_ids is None and inputs_embeds is None:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if input_ids is not None:
            self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
            input_shape = input_ids.size()
            device = input_ids.device
        else:
            # Drop the trailing hidden dimension to recover (batch, sequence).
            input_shape = inputs_embeds.size()[:-1]
            device = inputs_embeds.device
        batch_size, seq_length = input_shape

        # Number of positions already held in the cache (legacy tuple layout or `Cache` object).
        if past_key_values is None:
            past_key_values_length = 0
        elif isinstance(past_key_values, Cache):
            past_key_values_length = past_key_values.get_seq_length()
        else:
            past_key_values_length = past_key_values[0][0].shape[-2]

        if token_type_ids is None:
            if hasattr(self.embeddings, "token_type_ids"):
                # Reuse the buffer registered on the embeddings, cropped and broadcast to this batch.
                token_type_ids = self.embeddings.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
            else:
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        if attention_mask is None:
            # No mask given: attend to every cached and current position.
            attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)

        # The SDPA-specific mask builders are only used when all four conditions hold.
        sdpa_masks_ok = (
            self.attn_implementation == "sdpa"
            and self.position_embedding_type == "absolute"
            and head_mask is None
            and not output_attentions
        )

        if not (sdpa_masks_ok and attention_mask.dim() == 2):
            # Generic path: defer to the model's own mask expansion helper.
            extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
        elif self.config.is_decoder:
            extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
                attention_mask,
                input_shape,
                embedding_output,
                past_key_values_length,
            )
        else:
            extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                attention_mask, embedding_output.dtype, tgt_len=seq_length
            )

        # A cross-attention mask is only built for decoders that actually receive encoder states.
        encoder_extended_attention_mask = None
        if self.config.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones((encoder_batch_size, encoder_sequence_length), device=device)

            if sdpa_masks_ok and encoder_attention_mask.dim() == 2:
                encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
                    encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
                )
            else:
                encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)

        # Converted only after the `head_mask is None` test above, which needs the caller's raw value.
        head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=extended_attention_mask,
            head_mask=head_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_extended_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        sequence_output = encoder_outputs[0]
        pooled_output = None if self.pooler is None else self.pooler(sequence_output)

        if return_dict:
            return BaseModelOutputWithPoolingAndCrossAttentions(
                last_hidden_state=sequence_output,
                pooler_output=pooled_output,
                past_key_values=encoder_outputs.past_key_values,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
                cross_attentions=encoder_outputs.cross_attentions,
            )

        # Tuple form: sequence output, pooled output (possibly None), then the remaining encoder outputs.
        return (sequence_output, pooled_output) + encoder_outputs[1:]
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
    sentence prediction (classification)` head.
    """
)
class BertForPreTraining(BertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Backbone (with pooler) followed by the joint MLM + NSP heads.
        self.bert = BertModel(config)
        self.cls = BertPreTrainingHeads(config)

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder layer of the language modeling head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the LM decoder and point the head's bias at its bias."""
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        next_sentence_label: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], BertForPreTrainingOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
            the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
            pair (see `input_ids` docstring) Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BertForPreTraining
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")

        >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
        >>> outputs = model(**inputs)

        >>> prediction_logits = outputs.prediction_logits
        >>> seq_relationship_logits = outputs.seq_relationship_logits
        ```
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The first two backbone outputs feed the two heads.
        sequence_output, pooled_output = outputs[:2]
        prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

        # The loss is only defined when BOTH label tensors are supplied.
        total_loss = None
        if labels is not None and next_sentence_label is not None:
            criterion = CrossEntropyLoss()
            masked_lm_loss = criterion(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
            next_sentence_loss = criterion(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
            total_loss = masked_lm_loss + next_sentence_loss

        if return_dict:
            return BertForPreTrainingOutput(
                loss=total_loss,
                prediction_logits=prediction_scores,
                seq_relationship_logits=seq_relationship_score,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, both logit tensors, then any extra backbone outputs.
        logits_and_extras = (prediction_scores, seq_relationship_score) + outputs[2:]
        if total_loss is None:
            return logits_and_extras
        return (total_loss,) + logits_and_extras
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Bert Model with a `language modeling` head on top for CLM fine-tuning.
    """
)
class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
    _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Only a warning: a non-decoder configuration is still accepted.
        if not config.is_decoder:
            logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")

        # The LM head consumes per-token states only, so the backbone is built without a pooler.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder layer of the language modeling head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the LM decoder and point the head's bias at its bias."""
        predictions = self.cls.predictions
        predictions.decoder = new_embeddings
        predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        past_key_values: Optional[Cache] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.Tensor] = None,
        **loss_kwargs,
    ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
            `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
            ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # Supplying labels always disables the key/value cache for this call.
        if labels is not None:
            use_cache = False

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        prediction_scores = self.cls(outputs[0])

        lm_loss = None
        if labels is not None:
            # Extra keyword arguments are passed straight through to the loss function.
            lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)

        if return_dict:
            return CausalLMOutputWithCrossAttentions(
                loss=lm_loss,
                logits=prediction_scores,
                past_key_values=outputs.past_key_values,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
                cross_attentions=outputs.cross_attentions,
            )

        # Tuple form: optional loss, logits, then backbone outputs from index 2 onwards.
        logits_and_extras = (prediction_scores,) + outputs[2:]
        if lm_loss is None:
            return logits_and_extras
        return (lm_loss,) + logits_and_extras
|
|
|
|
|
|
|
|
@auto_docstring
class BertForMaskedLM(BertPreTrainedModel):
    _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]

    def __init__(self, config):
        super().__init__(config)

        # Only a warning: a decoder configuration is still accepted.
        if config.is_decoder:
            logger.warning(
                "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        # The MLM head consumes per-token states only, so the backbone is built without a pooler.
        self.bert = BertModel(config, add_pooling_layer=False)
        self.cls = BertOnlyMLMHead(config)

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    def get_output_embeddings(self):
        """Return the decoder layer of the language modeling head."""
        return self.cls.predictions.decoder

    def set_output_embeddings(self, new_embeddings):
        """Install `new_embeddings` as the LM decoder and point the head's bias at its bias."""
        self.cls.predictions.decoder = new_embeddings
        self.cls.predictions.bias = new_embeddings.bias

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
            loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
        """

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        prediction_scores = self.cls(sequence_output)

        masked_lm_loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

        if not return_dict:
            # Tuple form: optional loss, logits, then backbone outputs from index 2 onwards.
            output = (prediction_scores,) + outputs[2:]
            return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output

        return MaskedLMOutput(
            loss=masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
        """Append one masked-out PAD token to every sequence and return the resulting model inputs.

        Args:
            input_ids: Token ids of shape `(batch_size, sequence_length)`.
            attention_mask: Optional mask with the same leading shape as `input_ids`. When omitted, an all-ones
                mask over the existing tokens is used, mirroring how `BertModel.forward` treats a missing mask.
            **model_kwargs: Accepted for signature compatibility; not used.

        Returns:
            A dict with `input_ids` and `attention_mask`, each one column longer than the input. The added
            column holds `config.pad_token_id` in `input_ids` and `0` in `attention_mask`.

        Raises:
            ValueError: If `config.pad_token_id` is not defined.
        """
        input_shape = input_ids.shape
        effective_batch_size = input_shape[0]

        if self.config.pad_token_id is None:
            raise ValueError("The PAD token should be defined for generation")

        # The signature allows `attention_mask=None`, but the concatenation below needs a tensor:
        # default to "attend to every provided token" instead of failing on `None.new_zeros`.
        if attention_mask is None:
            attention_mask = input_ids.new_ones(input_shape)

        # The extra column is zero so the dummy token is never attended to.
        attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
        dummy_token = torch.full(
            (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
        )
        input_ids = torch.cat([input_ids, dummy_token], dim=1)

        return {"input_ids": input_ids, "attention_mask": attention_mask}

    @classmethod
    def can_generate(cls) -> bool:
        """
        Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
        `prepare_inputs_for_generation` method.
        """
        return False
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Bert Model with a `next sentence prediction (classification)` head on top.
    """
)
class BertForNextSentencePrediction(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        # Backbone (with pooler) followed by the NSP-only head.
        self.bert = BertModel(config)
        self.cls = BertOnlyNSPHead(config)

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], NextSentencePredictorOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
            (see `input_ids` docstring). Indices should be in `[0, 1]`:

            - 0 indicates sequence B is a continuation of sequence A,
            - 1 indicates sequence B is a random sequence.

        Example:

        ```python
        >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
        >>> import torch

        >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
        >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")

        >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
        >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
        >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")

        >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
        >>> logits = outputs.logits
        >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
        ```
        """

        # Deprecated alias: `next_sentence_label` overrides `labels` when present.
        if "next_sentence_label" in kwargs:
            warnings.warn(
                "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                " `labels` instead.",
                FutureWarning,
            )
            labels = kwargs.pop("next_sentence_label")

        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # The NSP head works on the pooled output (index 1 of the backbone outputs).
        seq_relationship_scores = self.cls(outputs[1])

        next_sentence_loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            next_sentence_loss = criterion(seq_relationship_scores.view(-1, 2), labels.view(-1))

        if return_dict:
            return NextSentencePredictorOutput(
                loss=next_sentence_loss,
                logits=seq_relationship_scores,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, logits, then backbone outputs from index 2 onwards.
        logits_and_extras = (seq_relationship_scores,) + outputs[2:]
        if next_sentence_loss is None:
            return logits_and_extras
        return (next_sentence_loss,) + logits_and_extras
|
|
|
|
|
|
|
|
@auto_docstring(
    custom_intro="""
    Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
    output) e.g. for GLUE tasks.
    """
)
class BertForSequenceClassification(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.bert = BertModel(config)

        # The head-specific dropout rate falls back to the general hidden dropout rate when unset.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Classify from the pooled output (index 1 of the backbone outputs).
        logits = self.classifier(self.dropout(outputs[1]))

        loss = None
        if labels is not None:
            # On first use, infer the problem type from the label count and dtype and store it on the config.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            problem_type = self.config.problem_type
            if problem_type == "regression":
                criterion = MSELoss()
                if self.num_labels == 1:
                    # Single-target regression compares squeezed predictions with squeezed targets.
                    loss = criterion(logits.squeeze(), labels.squeeze())
                else:
                    loss = criterion(logits, labels)
            elif problem_type == "single_label_classification":
                criterion = CrossEntropyLoss()
                loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
            elif problem_type == "multi_label_classification":
                criterion = BCEWithLogitsLoss()
                loss = criterion(logits, labels)

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, logits, then backbone outputs from index 2 onwards.
        logits_and_extras = (logits,) + outputs[2:]
        if loss is None:
            return logits_and_extras
        return (loss,) + logits_and_extras
|
|
|
|
|
|
|
|
@auto_docstring
class BertForMultipleChoice(BertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)

        self.bert = BertModel(config)

        # The head-specific dropout rate falls back to the general hidden dropout rate when unset.
        dropout_rate = config.classifier_dropout
        if dropout_rate is None:
            dropout_rate = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_rate)
        # One score per flattened (example, choice) row; `forward` regroups the scores per example.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Must run last: finalizes the module once every sub-module exists.
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
        r"""
        input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
            1]`:

            - 0 corresponds to a *sentence A* token,
            - 1 corresponds to a *sentence B* token.

            [What are token type IDs?](../glossary#token-type-ids)
        position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.max_position_embeddings - 1]`.

            [What are position IDs?](../glossary#position-ids)
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
            num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
            `input_ids` above)
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        # The choice axis is dimension 1 of whichever input was supplied; each tensor that is
        # present is then flattened so every (example, choice) pair becomes its own row.
        if input_ids is not None:
            num_choices = input_ids.shape[1]
            input_ids = input_ids.view(-1, input_ids.size(-1))
        else:
            num_choices = inputs_embeds.shape[1]
        if attention_mask is not None:
            attention_mask = attention_mask.view(-1, attention_mask.size(-1))
        if token_type_ids is not None:
            token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1))
        if position_ids is not None:
            position_ids = position_ids.view(-1, position_ids.size(-1))
        if inputs_embeds is not None:
            inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Score each flattened row from the pooled output, then regroup the scores per example.
        logits = self.classifier(self.dropout(outputs[1]))
        reshaped_logits = logits.view(-1, num_choices)

        loss = None
        if labels is not None:
            criterion = CrossEntropyLoss()
            loss = criterion(reshaped_logits, labels)

        if return_dict:
            return MultipleChoiceModelOutput(
                loss=loss,
                logits=reshaped_logits,
                hidden_states=outputs.hidden_states,
                attentions=outputs.attentions,
            )

        # Tuple form: optional loss, per-example logits, then backbone outputs from index 2 onwards.
        logits_and_extras = (reshaped_logits,) + outputs[2:]
        if loss is None:
            return logits_and_extras
        return (loss,) + logits_and_extras
|
|
|
|
|
|
|
|
@auto_docstring
class BertForTokenClassification(BertPreTrainedModel):
    # BERT encoder (built without its pooling layer) followed by dropout and a
    # linear layer that maps each position's hidden vector to `config.num_labels` scores.
    # NOTE: documentation is kept in `#` comments on purpose; docstrings here are
    # handed to the `@auto_docstring` decorator at runtime.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)

        # Prefer the dedicated classifier dropout rate when the config sets one;
        # otherwise fall back to the general hidden dropout probability.
        dropout_prob = config.classifier_dropout
        if dropout_prob is None:
            dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Base-class hook, called once every submodule exists
        # (its implementation lives outside this class).
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
        """
        # An explicit argument wins; otherwise defer to the model config.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Entry 0 of the encoder result is used as the per-token representation.
        token_states = self.dropout(encoder_outputs[0])
        logits = self.classifier(token_states)

        loss = None
        if labels is not None:
            # Collapse all leading dimensions so each token is one classification sample.
            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            loss = CrossEntropyLoss()(flat_logits, flat_labels)

        if return_dict:
            return TokenClassifierOutput(
                loss=loss,
                logits=logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: logits, then whatever the encoder returned from index 2 onward;
        # the loss is prepended only when it was computed.
        tuple_output = (logits,) + encoder_outputs[2:]
        return tuple_output if loss is None else (loss,) + tuple_output
|
|
|
|
|
|
|
|
@auto_docstring
class BertForQuestionAnswering(BertPreTrainedModel):
    # BERT encoder (built without its pooling layer) plus one linear layer whose
    # output is split along the last dimension into start-logits and end-logits.
    # NOTE: documentation is kept in `#` comments on purpose; docstrings here are
    # handed to the `@auto_docstring` decorator at runtime.

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.bert = BertModel(config, add_pooling_layer=False)
        # `forward` unpacks the projection into exactly two tensors (start, end),
        # so it only works when `config.num_labels` is 2.
        self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

        # Base-class hook, called once every submodule exists
        # (its implementation lives outside this class).
        self.post_init()

    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        # An explicit argument wins; otherwise defer to the model config.
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Entry 0 of the encoder result is used as the per-token representation.
        span_logits = self.qa_outputs(encoder_outputs[0])

        # Split into width-1 slices on the last dimension, then drop that dimension.
        start_logits, end_logits = (
            piece.squeeze(-1).contiguous() for piece in span_logits.split(1, dim=-1)
        )

        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Targets may carry a trailing size-1 dimension; remove it when present.
            if start_positions.dim() > 1:
                start_positions = start_positions.squeeze(-1)
            if end_positions.dim() > 1:
                end_positions = end_positions.squeeze(-1)

            # Targets past the last valid index are clamped to `out_of_range`
            # (size of dim 1 of the logits), which is also the loss's `ignore_index`,
            # so they add nothing to the loss. Negative targets are clamped to 0.
            out_of_range = start_logits.size(1)
            start_positions = start_positions.clamp(min=0, max=out_of_range)
            end_positions = end_positions.clamp(min=0, max=out_of_range)

            span_loss = CrossEntropyLoss(ignore_index=out_of_range)
            start_loss = span_loss(start_logits, start_positions)
            end_loss = span_loss(end_logits, end_positions)
            total_loss = (start_loss + end_loss) / 2

        if return_dict:
            return QuestionAnsweringModelOutput(
                loss=total_loss,
                start_logits=start_logits,
                end_logits=end_logits,
                hidden_states=encoder_outputs.hidden_states,
                attentions=encoder_outputs.attentions,
            )

        # Tuple form: both logit tensors, then whatever the encoder returned from
        # index 2 onward; the loss is prepended only when it was computed.
        tuple_output = (start_logits, end_logits) + encoder_outputs[2:]
        return tuple_output if total_loss is None else (total_loss,) + tuple_output
|
|
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = [
    "BertForMaskedLM",
    "BertForMultipleChoice",
    "BertForNextSentencePrediction",
    "BertForPreTraining",
    "BertForQuestionAnswering",
    "BertForSequenceClassification",
    "BertForTokenClassification",
    "BertLayer",
    "BertLMHeadModel",
    "BertModel",
    "BertPreTrainedModel",
    "load_tf_weights_in_bert",
]
|
|
|