feat: implement task type embeddings
#1 opened by Markus28

Files changed:
- configuration_bert.py  +4 -0
- modeling_bert.py  +14 -2
configuration_bert.py  CHANGED

@@ -81,6 +81,8 @@ class JinaBertConfig(PretrainedConfig):
         fused_dropout_add_ln=False,
         fused_bias_fc=False,
         pad_vocab_size_multiple=1,
+        num_tasks=0,
+        use_flash_attn=True,
         **kwargs,
     ):
         assert 'position_embedding_type' not in kwargs
@@ -106,3 +108,5 @@ class JinaBertConfig(PretrainedConfig):
         self.fused_dropout_add_ln = fused_dropout_add_ln
         self.fused_bias_fc = fused_bias_fc
         self.pad_vocab_size_multiple = pad_vocab_size_multiple
+        self.num_tasks = num_tasks
+        self.use_flash_attn = use_flash_attn
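For context, a minimal sketch of how the two new options might be passed when building a config. The keyword names come straight from the diff; the import path and the chosen values are illustrative and assume the config can be instantiated directly from this repo's configuration_bert.py:

from configuration_bert import JinaBertConfig  # assumption: local module path

# num_tasks sizes the new task-type embedding table; 0 (the default) means
# task types are not used. use_flash_attn was previously hard-coded in the
# modeling code and is now configurable here.
config = JinaBertConfig(
    num_tasks=5,          # illustrative: five downstream task types
    use_flash_attn=True,  # keep the previous always-on behaviour
)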
modeling_bert.py  CHANGED

@@ -59,6 +59,7 @@ logger = logging.getLogger(__name__)
 
 
 def create_mixer_cls(config, cross_attn=False, return_residual=False):
+    use_flash_attn = getattr(config, "use_flash_attn", False)
     fused_bias_fc = getattr(config, "fused_bias_fc", False)
     window_size = getattr(config, "window_size", (-1, -1))
     mixer_cls = partial(
@@ -68,7 +69,7 @@ def create_mixer_cls(config, cross_attn=False, return_residual=False):
         dropout=config.attention_probs_dropout_prob,
         causal=False,
         fused_bias_fc=fused_bias_fc,
-        use_flash_attn=True,
+        use_flash_attn=use_flash_attn,
         return_residual=return_residual,
         use_alibi=True,
         window_size=window_size,
@@ -151,6 +152,7 @@ def _init_weights(module, initializer_range=0.02):
 class BertEncoder(nn.Module):
     def __init__(self, config: JinaBertConfig):
         super().__init__()
+        self.use_flash_attn = getattr(config, "use_flash_attn", False)
         self.layers = nn.ModuleList(
             [create_block(config, layer_idx=i) for i in range(config.num_hidden_layers)]
         )
@@ -171,7 +173,7 @@ class BertEncoder(nn.Module):
         This means that we only compute the last layer output for these tokens.
         subset_mask: (batch, seqlen), dtype=torch.bool
         """
-        if key_padding_mask is None:
+        if key_padding_mask is None or not self.use_flash_attn:
             mixer_kwargs = (
                 {"key_padding_mask": key_padding_mask} if key_padding_mask is not None else None
             )
@@ -340,14 +342,21 @@ class BertModel(BertPreTrainedModel):
         self.emb_ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
         self.encoder = BertEncoder(config)
         self.pooler = BertPooler(config) if add_pooling_layer else None
+        self.task_type_embeddings = nn.Embedding(config.num_tasks, config.hidden_size)
 
         self.apply(partial(_init_weights, initializer_range=config.initializer_range))
+        # We now initialize the task embeddings to 0; We do not use task types during
+        # pretraining. When we start using task types during embedding training,
+        # we want the model to behave exactly as in pretraining (i.e. task types
+        # have no effect).
+        nn.init.zeros_(self.task_type_embeddings.weight)
 
     def forward(
         self,
         input_ids,
         position_ids=None,
         token_type_ids=None,
+        task_type_ids=None,
         attention_mask=None,
         masked_tokens_mask=None,
     ):
@@ -359,6 +368,9 @@ class BertModel(BertPreTrainedModel):
         hidden_states = self.embeddings(
             input_ids, position_ids=position_ids, token_type_ids=token_type_ids
        )
+        if task_type_ids is not None:
+            hidden_states = hidden_states + self.task_type_embeddings(task_type_ids)
+
         # TD [2022-12:18]: Don't need to force residual in fp32
         # BERT puts embedding LayerNorm before embedding dropout.
         if not self.fused_dropout_add_ln:
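To make the intent of the zero-initialized table concrete, here is a rough usage sketch. It assumes configuration_bert.py and modeling_bert.py are importable from this repo, that the config defaults suffice to build the model, and it turns flash attention off so the snippet does not depend on the CUDA kernels; the shapes and task ids are made up. Because task_type_embeddings starts at zero, passing task_type_ids right after initialization does not change the output, which is exactly the pretraining-compatibility property the new in-code comment describes:

import torch

from configuration_bert import JinaBertConfig  # assumption: local module paths
from modeling_bert import BertModel

# Illustrative config: 4 task types; flash attention disabled so the sketch
# can run on CPU (assumes the non-flash attention path is usable there).
config = JinaBertConfig(num_tasks=4, use_flash_attn=False)
model = BertModel(config).eval()

input_ids = torch.randint(0, config.vocab_size, (2, 16))  # (batch, seqlen)
task_type_ids = torch.full_like(input_ids, 2)             # every token tagged with task 2

with torch.no_grad():
    out_plain = model(input_ids)                               # no task information
    out_task = model(input_ids, task_type_ids=task_type_ids)  # adds task embeddings to token embeddings

# With the zero-initialized task embedding table the two outputs match; they
# only start to differ once the task embeddings are trained.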