Kirim1 committed on
Commit 9fbc5c2 · verified · 1 Parent(s): 30b7c91

Create modeling_bert.py

Files changed (1)
  1. modeling_bert.py +145 -0
modeling_bert.py ADDED
@@ -0,0 +1,145 @@
+ from transformers import BertPreTrainedModel, BertModel
+ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions
+ import torch
+ import torch.nn as nn
+ from typing import Optional, Tuple, Union
+
+
+ class BertForCausalLM(BertPreTrainedModel):
+     """
+     BERT model with a language modeling head for instruction following and text generation.
+     Supports 100+ languages, with a primary focus on English.
+     """
+
+     _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
+
+     def __init__(self, config):
+         super().__init__(config)
+
+         self.bert = BertModel(config, add_pooling_layer=False)
+         self.cls = BertOnlyMLMHead(config)
+
+         self.post_init()
+
+     def get_output_embeddings(self):
+         return self.cls.predictions.decoder
+
+     def set_output_embeddings(self, new_embeddings):
+         self.cls.predictions.decoder = new_embeddings
+
+     def forward(
+         self,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.FloatTensor] = None,
+         token_type_ids: Optional[torch.LongTensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         head_mask: Optional[torch.FloatTensor] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+         encoder_attention_mask: Optional[torch.FloatTensor] = None,
+         labels: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
+         use_cache: Optional[bool] = None,
+         output_attentions: Optional[bool] = None,
+         output_hidden_states: Optional[bool] = None,
+         return_dict: Optional[bool] = None,
+     ) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
+
+         return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+         outputs = self.bert(
+             input_ids,
+             attention_mask=attention_mask,
+             token_type_ids=token_type_ids,
+             position_ids=position_ids,
+             head_mask=head_mask,
+             inputs_embeds=inputs_embeds,
+             encoder_hidden_states=encoder_hidden_states,
+             encoder_attention_mask=encoder_attention_mask,
+             past_key_values=past_key_values,
+             use_cache=use_cache,
+             output_attentions=output_attentions,
+             output_hidden_states=output_hidden_states,
+             return_dict=return_dict,
+         )
+
+         sequence_output = outputs[0]
+         prediction_scores = self.cls(sequence_output)
+
+         lm_loss = None
+         if labels is not None:
+             # Shift so that tokens < n predict token n (the standard causal LM objective).
+             shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
+             shifted_labels = labels[:, 1:].contiguous()
+             loss_fct = nn.CrossEntropyLoss()
+             lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), shifted_labels.view(-1))
+
+         if not return_dict:
+             output = (prediction_scores,) + outputs[2:]
+             return ((lm_loss,) + output) if lm_loss is not None else output
+
+         return CausalLMOutputWithCrossAttentions(
+             loss=lm_loss,
+             logits=prediction_scores,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             cross_attentions=outputs.cross_attentions,
+         )
+
+     def prepare_inputs_for_generation(
+         self, input_ids, past_key_values=None, attention_mask=None, **model_kwargs
+     ):
+         input_shape = input_ids.shape
+
+         if attention_mask is None:
+             attention_mask = input_ids.new_ones(input_shape)
+
+         # When a cache is present, only the newest token needs to be re-encoded.
+         if past_key_values is not None:
+             input_ids = input_ids[:, -1:]
+
+         return {
+             "input_ids": input_ids,
+             "attention_mask": attention_mask,
+             "past_key_values": past_key_values,
+         }
+
+
+ class BertOnlyMLMHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.predictions = BertLMPredictionHead(config)
+
+     def forward(self, sequence_output):
+         prediction_scores = self.predictions(sequence_output)
+         return prediction_scores
+
+
+ class BertLMPredictionHead(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.transform = BertPredictionHeadTransform(config)
+         self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+         self.bias = nn.Parameter(torch.zeros(config.vocab_size))
+         # Link the bias to the decoder so it is resized together with the token embeddings.
+         self.decoder.bias = self.bias
+
+     def forward(self, hidden_states):
+         hidden_states = self.transform(hidden_states)
+         hidden_states = self.decoder(hidden_states)
+         return hidden_states
+
+
+ class BertPredictionHeadTransform(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.transform_act_fn = nn.GELU()
+         self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+
+     def forward(self, hidden_states):
+         hidden_states = self.dense(hidden_states)
+         hidden_states = self.transform_act_fn(hidden_states)
+         hidden_states = self.LayerNorm(hidden_states)
+         return hidden_states
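
A minimal usage sketch, not part of the commit: it loads the stock bert-base-uncased checkpoint purely for illustration (the head weights will be freshly initialized), and the prompt strings are placeholders. Two caveats apply: BertModel only applies a causal attention mask when config.is_decoder is True, so that flag must be set before generating, and on recent transformers releases .generate() additionally requires the class to inherit GenerationMixin explicitly.

import torch
from transformers import AutoTokenizer, BertConfig
from modeling_bert import BertForCausalLM

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained("bert-base-uncased")
config.is_decoder = True  # required so self-attention is causally masked

model = BertForCausalLM.from_pretrained("bert-base-uncased", config=config)
model.eval()

# Greedy decoding from a short placeholder prompt.
inputs = tokenizer("The capital of France is", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))

# Training-style call: passing labels returns the shifted cross-entropy loss.
batch = tokenizer("Paris is the capital of France.", return_tensors="pt")
loss = model(**batch, labels=batch["input_ids"]).loss

Because forward shifts logits and labels internally, the labels passed here are simply the input_ids; no manual offset is needed.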