Update bert_layers.py
Browse files- bert_layers.py +114 -0
bert_layers.py
CHANGED
|
@@ -870,3 +870,117 @@ class BertForSequenceClassification(BertPreTrainedModel):
|
|
| 870 |
hidden_states=None,
|
| 871 |
attentions=None,
|
| 872 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 870 |
hidden_states=None,
|
| 871 |
attentions=None,
|
| 872 |
)
|
| 873 |
+
|
| 874 |
+
class BertForTextEncoding(BertPreTrainedModel):
|
| 875 |
+
|
| 876 |
+
def __init__(self, config):
|
| 877 |
+
super().__init__(config)
|
| 878 |
+
|
| 879 |
+
if config.is_decoder:
|
| 880 |
+
warnings.warn(
|
| 881 |
+
'If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for '
|
| 882 |
+
'bi-directional self-attention.')
|
| 883 |
+
|
| 884 |
+
self.bert = BertModel(config, add_pooling_layer=False)
|
| 885 |
+
|
| 886 |
+
# Initialize weights and apply final processing
|
| 887 |
+
self.post_init()
|
| 888 |
+
|
| 889 |
+
@classmethod
|
| 890 |
+
def from_composer(cls,
|
| 891 |
+
pretrained_checkpoint,
|
| 892 |
+
state_dict=None,
|
| 893 |
+
cache_dir=None,
|
| 894 |
+
from_tf=False,
|
| 895 |
+
config=None,
|
| 896 |
+
*inputs,
|
| 897 |
+
**kwargs):
|
| 898 |
+
"""Load from pre-trained."""
|
| 899 |
+
model = cls(config, *inputs, **kwargs)
|
| 900 |
+
if from_tf:
|
| 901 |
+
raise ValueError(
|
| 902 |
+
'TensorFlow is not supported.')
|
| 903 |
+
|
| 904 |
+
state_dict = torch.load(pretrained_checkpoint)
|
| 905 |
+
# If the state_dict was saved after wrapping with `composer.HuggingFaceModel`, it takes on the `model` prefix
|
| 906 |
+
consume_prefix_in_state_dict_if_present(state_dict, prefix='model.')
|
| 907 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict,
|
| 908 |
+
strict=False)
|
| 909 |
+
|
| 910 |
+
if len(missing_keys) > 0:
|
| 911 |
+
logger.warning(
|
| 912 |
+
f"Found these missing keys in the checkpoint: {', '.join(missing_keys)}"
|
| 913 |
+
)
|
| 914 |
+
if len(unexpected_keys) > 0:
|
| 915 |
+
logger.warning(
|
| 916 |
+
f"Found these unexpected keys in the checkpoint: {', '.join(unexpected_keys)}"
|
| 917 |
+
)
|
| 918 |
+
|
| 919 |
+
return model
|
| 920 |
+
|
| 921 |
+
def forward(
|
| 922 |
+
self,
|
| 923 |
+
input_ids: Optional[torch.Tensor] = None,
|
| 924 |
+
attention_mask: Optional[torch.Tensor] = None,
|
| 925 |
+
token_type_ids: Optional[torch.Tensor] = None,
|
| 926 |
+
position_ids: Optional[torch.Tensor] = None,
|
| 927 |
+
head_mask: Optional[torch.Tensor] = None,
|
| 928 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
| 929 |
+
encoder_hidden_states: Optional[torch.Tensor] = None,
|
| 930 |
+
encoder_attention_mask: Optional[torch.Tensor] = None,
|
| 931 |
+
labels: Optional[torch.Tensor] = None,
|
| 932 |
+
output_attentions: Optional[bool] = None,
|
| 933 |
+
output_hidden_states: Optional[bool] = None,
|
| 934 |
+
return_dict: Optional[bool] = None,
|
| 935 |
+
) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
|
| 936 |
+
|
| 937 |
+
if (input_ids is not None) == (inputs_embeds is not None):
|
| 938 |
+
raise ValueError('Must specify either input_ids or input_embeds!')
|
| 939 |
+
|
| 940 |
+
if labels is None:
|
| 941 |
+
masked_tokens_mask = None
|
| 942 |
+
else:
|
| 943 |
+
masked_tokens_mask = labels > 0
|
| 944 |
+
|
| 945 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 946 |
+
|
| 947 |
+
outputs = self.bert(
|
| 948 |
+
input_ids,
|
| 949 |
+
attention_mask=attention_mask,
|
| 950 |
+
token_type_ids=token_type_ids,
|
| 951 |
+
position_ids=position_ids,
|
| 952 |
+
head_mask=head_mask,
|
| 953 |
+
inputs_embeds=inputs_embeds,
|
| 954 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 955 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 956 |
+
output_attentions=output_attentions,
|
| 957 |
+
output_hidden_states=output_hidden_states,
|
| 958 |
+
return_dict=return_dict,
|
| 959 |
+
masked_tokens_mask=masked_tokens_mask,
|
| 960 |
+
)
|
| 961 |
+
|
| 962 |
+
pooled_output = outputs[1]
|
| 963 |
+
|
| 964 |
+
return {"sentence_embedding": pooled_output}
|
| 965 |
+
|
| 966 |
+
def prepare_inputs_for_generation(self, input_ids: torch.Tensor,
|
| 967 |
+
attention_mask: torch.Tensor,
|
| 968 |
+
**model_kwargs):
|
| 969 |
+
input_shape = input_ids.shape
|
| 970 |
+
effective_batch_size = input_shape[0]
|
| 971 |
+
|
| 972 |
+
# add a dummy token
|
| 973 |
+
if self.config.pad_token_id is None:
|
| 974 |
+
raise ValueError('The PAD token should be defined for generation')
|
| 975 |
+
|
| 976 |
+
attention_mask = torch.cat([
|
| 977 |
+
attention_mask,
|
| 978 |
+
attention_mask.new_zeros((attention_mask.shape[0], 1))
|
| 979 |
+
], dim=-1)
|
| 980 |
+
dummy_token = torch.full((effective_batch_size, 1),
|
| 981 |
+
self.config.pad_token_id,
|
| 982 |
+
dtype=torch.long,
|
| 983 |
+
device=input_ids.device)
|
| 984 |
+
input_ids = torch.cat([input_ids, dummy_token], dim=1)
|
| 985 |
+
|
| 986 |
+
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|