Implement Longformer sliding window
#9 opened by alaeddine-13
- modeling_bert.py +240 -214

modeling_bert.py CHANGED
@@ -16,7 +16,6 @@
 # limitations under the License.
 """PyTorch BERT model."""

 import math
 import os
 import warnings

@@ -96,6 +95,15 @@ _SEQ_CLASS_EXPECTED_OUTPUT = "'LABEL_1'"
 _SEQ_CLASS_EXPECTED_LOSS = 0.01


+def create_k_diag_mask(k, n):
+    mask = torch.zeros(n, n, dtype=bool)
+    for i in range(n):
+        for j in range(n):
+            if not math.fabs(i - j) < k:
+                mask[i, j] = True
+    return mask
+
+
 def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
     """Load tf checkpoints in a pytorch model."""
     try:
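Note: the helper above builds the band mask with a nested Python loop, which is quadratic in interpreted code. A minimal vectorized sketch of the same k-diagonal mask (the name create_k_diag_mask_vectorized is hypothetical, not part of this diff):

    import torch

    def create_k_diag_mask_vectorized(k: int, n: int) -> torch.Tensor:
        # True outside the band |i - j| < k, i.e. the positions to mask out
        idx = torch.arange(n)
        return (idx[None, :] - idx[:, None]).abs() >= k

For n = 6 and k = 2 only the main diagonal and the first off-diagonals stay unmasked, matching the loop-based helper.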
@@ -126,15 +134,15 @@ def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
         # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
         # which are not required for using pretrained model
         if any(
+            n
+            in [
+                "adam_v",
+                "adam_m",
+                "AdamWeightDecayOptimizer",
+                "AdamWeightDecayOptimizer_1",
+                "global_step",
+            ]
+            for n in name
         ):
             logger.info(f"Skipping {'/'.join(name)}")
             continue

@@ -214,12 +222,12 @@ class JinaBertEmbeddings(nn.Module):
         )

     def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
     ) -> torch.Tensor:
         if input_ids is not None:
             input_shape = input_ids.size()

@@ -230,8 +238,8 @@ class JinaBertEmbeddings(nn.Module):

         if position_ids is None:
             position_ids = self.position_ids[
+                :, past_key_values_length: seq_length + past_key_values_length
+            ]

         # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
         # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves

@@ -265,13 +273,13 @@ class JinaBertSelfAttention(nn.Module):
     def __init__(self, config: JinaBertConfig, position_embedding_type=None):
         super().__init__()
         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(
+            config, "embedding_size"
         ):
             raise ValueError(
                 f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                 f"heads ({config.num_attention_heads})"
             )
+
         self.attn_implementation = config.attn_implementation
         self.num_attention_heads = config.num_attention_heads
         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)

@@ -286,8 +294,8 @@ class JinaBertSelfAttention(nn.Module):
             config, "position_embedding_type", "absolute"
         )
         if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
         ):
             self.max_position_embeddings = config.max_position_embeddings
             self.distance_embedding = nn.Embedding(
@@ -305,15 +313,16 @@ class JinaBertSelfAttention(nn.Module):
         return x.permute(0, 2, 1, 3)

     def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+        bias: Optional[torch.FloatTensor] = None,
+        sliding_window: Optional[int] = None,
     ) -> Tuple[torch.Tensor]:
         mixed_query_layer = self.query(hidden_states)

@@ -364,8 +373,8 @@ class JinaBertSelfAttention(nn.Module):
         attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

         if (
+            self.position_embedding_type == "relative_key"
+            or self.position_embedding_type == "relative_key_query"
         ):
             query_length, key_length = query_layer.shape[2], key_layer.shape[2]
             if use_cache:

@@ -401,9 +410,9 @@ class JinaBertSelfAttention(nn.Module):
                 "bhrd,lrd->bhlr", key_layer, positional_embedding
             )
             attention_scores = (
+                attention_scores
+                + relative_position_scores_query
+                + relative_position_scores_key
             )

         attention_scores = attention_scores / math.sqrt(self.attention_head_size)

@@ -414,6 +423,10 @@ class JinaBertSelfAttention(nn.Module):
         # Normalize the attention scores to probabilities.
         attention_probs = nn.functional.softmax(attention_scores + bias, dim=-1)

+        if sliding_window is not None:
+            mask = create_k_diag_mask(sliding_window, int(attention_scores.size(dim=2)))
+            attention_probs.masked_fill_(mask, 0)
+
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
         attention_probs = self.dropout(attention_probs)
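Note: the new branch masks after the softmax, so out-of-window probabilities are set to zero without renormalizing the remaining ones. A toy sketch of the behaviour (not part of the diff; assumes create_k_diag_mask is in scope):

    import torch

    scores = torch.randn(1, 1, 4, 4)              # (batch, heads, seq, seq)
    probs = torch.softmax(scores, dim=-1)
    mask = create_k_diag_mask(2, scores.size(2))  # True where |i - j| >= 2
    probs = probs.masked_fill(mask, 0)            # each row now sums to at most 1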
@@ -445,7 +458,7 @@ class JinaBertSelfOutput(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
     ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)

@@ -481,20 +494,21 @@ class JinaBertAttention(nn.Module):
         # Update hyper params and store pruned heads
         self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
         self.self.all_head_size = (
+            self.self.attention_head_size * self.self.num_attention_heads
         )
         self.pruned_heads = self.pruned_heads.union(heads)

     def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+        bias: Optional[torch.FloatTensor] = None,
+        sliding_window: Optional[int] = None,
     ) -> Tuple[torch.Tensor]:
         self_outputs = self.self(
             hidden_states,

@@ -505,11 +519,12 @@ class JinaBertAttention(nn.Module):
             past_key_value,
             output_attentions,
             bias,
+            sliding_window=sliding_window
         )
         attention_output = self.output(self_outputs[0], hidden_states)
         outputs = (attention_output,) + self_outputs[
+            1:
+        ] # add attentions if we output them
         return outputs


@@ -536,7 +551,7 @@ class JinaBertOutput(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)

     def forward(
+        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
     ) -> torch.Tensor:
         hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)

@@ -568,7 +583,7 @@ class JinaBertGLUMLP(nn.Module):
         # compute the activation
         hidden_states = self.gated_layers(hidden_states)
         gated = hidden_states[:, :, : self.config.intermediate_size]
+        non_gated = hidden_states[:, :, self.config.intermediate_size:]
         hidden_states = self.act(gated) * non_gated
         hidden_states = self.dropout(hidden_states)
         # multiply by the second matrix
@@ -602,15 +617,16 @@ class JinaBertLayer(nn.Module):
         self.output = JinaBertOutput(config)

     def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        bias: Optional[torch.FloatTensor] = None,
+        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        output_attentions: Optional[bool] = False,
+        sliding_window: Optional[int] = None,
     ) -> Tuple[torch.Tensor]:
         # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
         self_attn_past_key_value = (

@@ -623,6 +639,7 @@ class JinaBertLayer(nn.Module):
             output_attentions=output_attentions,
             past_key_value=self_attn_past_key_value,
             bias=bias,
+            sliding_window=sliding_window
         )
         attention_output = self_attention_outputs[0]

@@ -632,8 +649,8 @@ class JinaBertLayer(nn.Module):
             present_key_value = self_attention_outputs[-1]
         else:
             outputs = self_attention_outputs[
+                1:
+            ] # add self attentions if we output attention weights

         cross_attn_present_key_value = None
         if self.is_decoder and encoder_hidden_states is not None:

@@ -658,7 +675,7 @@ class JinaBertLayer(nn.Module):
             )
             attention_output = cross_attention_outputs[0]
             outputs = (
+                outputs + cross_attention_outputs[1:-1]
             ) # add cross attentions if we output attention weights

             # add cross-attn cache to positions 3,4 of present_key_value tuple

@@ -704,7 +721,7 @@ class JinaBertEncoder(nn.Module):
         )

     def rebuild_alibi_tensor(
+        self, size: int, device: Optional[Union[torch.device, str]] = None
     ):
         # Alibi
         # Following https://github.com/ofirpress/attention_with_linear_biases/issues/5 (Implementation 1)

@@ -717,7 +734,7 @@ class JinaBertEncoder(nn.Module):
             def get_slopes_power_of_2(n):
                 start = 2 ** (-(2 ** -(math.log2(n) - 3)))
                 ratio = start
+                return [start * ratio ** i for i in range(n)]

             if math.log2(n_heads).is_integer():
                 return get_slopes_power_of_2(

@@ -728,10 +745,10 @@ class JinaBertEncoder(nn.Module):
                 math.log2(n_heads)
             ) # when the number of heads is not a power of 2, we use this workaround.
             return (
+                get_slopes_power_of_2(closest_power_of_2)
+                + _get_alibi_head_slopes(2 * closest_power_of_2)[0::2][
+                    : n_heads - closest_power_of_2
+                ]
             )

         context_position = torch.arange(size, device=device)[:, None]
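For reference, a small worked example of the slope schedule above (a sketch; for a power-of-two head count the recursive helper reduces to get_slopes_power_of_2, so it is called directly here):

    import math

    def get_slopes_power_of_2(n):
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))  # equals 2 ** (-8 / n)
        return [start * start ** i for i in range(n)]

    print(get_slopes_power_of_2(8))  # [0.5, 0.25, 0.125, ..., 0.00390625]

    # non-power-of-two workaround, e.g. 12 heads:
    closest_power_of_2 = 2 ** math.floor(math.log2(12))  # 8
    slopes = (
        get_slopes_power_of_2(closest_power_of_2)
        + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][: 12 - closest_power_of_2]
    )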
@@ -749,17 +766,18 @@ class JinaBertEncoder(nn.Module):
         return alibi

     def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = False,
+        output_hidden_states: Optional[bool] = False,
+        return_dict: Optional[bool] = True,
+        sliding_window: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
         all_hidden_states = () if output_hidden_states else None
         all_self_attentions = () if output_attentions else None

@@ -828,6 +846,7 @@ class JinaBertEncoder(nn.Module):
                 alibi_bias,
                 past_key_value,
                 output_attentions,
+                sliding_window
             )

             hidden_states = layer_outputs[0]

@@ -1117,16 +1136,17 @@ class JinaBertModel(JinaBertPreTrainedModel):

     @torch.inference_mode()
     def encode(
+        self: 'JinaBertModel',
+        sentences: Union[str, List[str]],
+        batch_size: int = 32,
+        show_progress_bar: Optional[bool] = None,
+        output_value: str = 'sentence_embedding',
+        convert_to_numpy: bool = True,
+        convert_to_tensor: bool = False,
+        device: Optional[torch.device] = None,
+        normalize_embeddings: bool = False,
+        sliding_window: Optional[int] = None,
+        **tokenizer_kwargs,
     ) -> Union[List[torch.Tensor], np.ndarray, torch.Tensor]:
         """
         Computes sentence embeddings
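A minimal usage sketch of the new argument (the checkpoint name and window size are placeholders; only sliding_window is introduced by this PR):

    from transformers import AutoModel

    # hypothetical checkpoint; trust_remote_code loads this modeling_bert.py
    model = AutoModel.from_pretrained('some-org/some-jina-bert-checkpoint', trust_remote_code=True)

    embeddings = model.encode(
        ['a very long document ...', 'another document'],
        sliding_window=256,  # each token attends only to keys with |i - j| < 256
    )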
@@ -1172,8 +1192,8 @@ class JinaBertModel(JinaBertPreTrainedModel):

         if show_progress_bar is None:
             show_progress_bar = (
+                logger.getEffectiveLevel() == logging.INFO
+                or logger.getEffectiveLevel() == logging.DEBUG
             )

         if convert_to_tensor:

@@ -1215,11 +1235,11 @@ class JinaBertModel(JinaBertPreTrainedModel):

         for i in range_iter:
             encoded_input = self.tokenizer(
+                sentences[i: i + batch_size],
                 return_tensors='pt',
                 **tokenizer_kwargs,
             ).to(self.device)
+            token_embs = self.forward(sliding_window=sliding_window, **encoded_input)[0]

             # Accumulate in fp32 to avoid overflow
             token_embs = token_embs.float()

@@ -1254,7 +1274,7 @@ class JinaBertModel(JinaBertPreTrainedModel):
         return all_embeddings

     def mean_pooling(
+        self, token_embeddings: torch.Tensor, attention_mask: torch.Tensor
     ):
         input_mask_expanded = (
             attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
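The rest of mean_pooling lies outside this hunk; the standard masked-mean formulation it appears to follow is sketched below (an assumption based on the visible lines, not taken from the diff):

    import torch

    def mean_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        # average token embeddings, counting only non-padding positions
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
        )
        summed = torch.sum(token_embeddings * input_mask_expanded, dim=1)
        counts = torch.clamp(input_mask_expanded.sum(dim=1), min=1e-9)
        return summed / counts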
@@ -1286,20 +1306,21 @@ class JinaBertModel(JinaBertPreTrainedModel):
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.FloatTensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        sliding_window: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):

@@ -1425,6 +1446,7 @@ class JinaBertModel(JinaBertPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            sliding_window=sliding_window
         )
         sequence_output = encoder_outputs[0]
         pooled_output = (

@@ -1476,18 +1498,19 @@ class JinaBertForPreTraining(JinaBertPreTrainedModel):
         output_type=JinaBertForPreTrainingOutput, config_class=_CONFIG_FOR_DOC
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        next_sentence_label: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        sliding_window: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], JinaBertForPreTrainingOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -1519,6 +1542,7 @@ class JinaBertForPreTraining(JinaBertPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            sliding_window=sliding_window
         )

         sequence_output, pooled_output = outputs[:2]

@@ -1586,21 +1610,21 @@ class JinaBertLMHeadModel(JinaBertPreTrainedModel):
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        past_key_values: Optional[List[torch.Tensor]] = None,
+        use_cache: Optional[bool] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):

@@ -1676,12 +1700,12 @@ class JinaBertLMHeadModel(JinaBertPreTrainedModel):
         )

     def prepare_inputs_for_generation(
+        self,
+        input_ids,
+        past_key_values=None,
+        attention_mask=None,
+        use_cache=True,
+        **model_kwargs,
     ):
         input_shape = input_ids.shape
         # if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
@@ -1748,19 +1772,20 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
         expected_loss=0.88,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        sliding_window: Optional[int] = None,
     ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -1785,6 +1810,7 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
             output_attentions=output_attentions,
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
+            sliding_window=sliding_window
         )

         sequence_output = outputs[0]

@@ -1811,7 +1837,7 @@ class JinaBertForMaskedLM(JinaBertPreTrainedModel):
         )

     def prepare_inputs_for_generation(
+        self, input_ids, attention_mask=None, **model_kwargs
     ):
         input_shape = input_ids.shape
         effective_batch_size = input_shape[0]

@@ -1856,18 +1882,18 @@ class JinaBertForNextSentencePrediction(JinaBertPreTrainedModel):
         output_type=NextSentencePredictorOutput, config_class=_CONFIG_FOR_DOC
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[Tuple[torch.Tensor], NextSentencePredictorOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -1967,17 +1993,17 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
         expected_loss=_SEQ_CLASS_EXPECTED_LOSS,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -2012,7 +2038,7 @@ class JinaBertForSequenceClassification(JinaBertPreTrainedModel):
                 if self.num_labels == 1:
                     self.config.problem_type = "regression"
                 elif self.num_labels > 1 and (
+                    labels.dtype == torch.long or labels.dtype == torch.int
                 ):
                     self.config.problem_type = "single_label_classification"
                 else:
@@ -2074,17 +2100,17 @@ class JinaBertForMultipleChoice(JinaBertPreTrainedModel):
         config_class=_CONFIG_FOR_DOC,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], MultipleChoiceModelOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):

@@ -2193,17 +2219,17 @@ class JinaBertForTokenClassification(JinaBertPreTrainedModel):
         expected_loss=_TOKEN_CLASS_EXPECTED_LOSS,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        labels: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):

@@ -2278,18 +2304,18 @@ class JinaBertForQuestionAnswering(JinaBertPreTrainedModel):
         expected_loss=_QA_EXPECTED_LOSS,
     )
     def forward(
+        self,
+        input_ids: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        token_type_ids: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+        head_mask: Optional[torch.Tensor] = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        start_positions: Optional[torch.Tensor] = None,
+        end_positions: Optional[torch.Tensor] = None,
+        output_attentions: Optional[bool] = None,
+        output_hidden_states: Optional[bool] = None,
+        return_dict: Optional[bool] = None,
     ) -> Union[Tuple[torch.Tensor], QuestionAnsweringModelOutput]:
         r"""
         start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):