Taykhoom committed on
Commit
b4fb6f9
·
verified ·
1 Parent(s): ac7f7ab

Upload modeling_bert.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_bert.py +14 -1
modeling_bert.py CHANGED
@@ -256,8 +256,19 @@ class BertPooler(nn.Module):
256
  return self.activation(self.dense(hidden_states[:, 0]))
257
 
258
 
 
 
 
 
 
 
 
 
 
 
259
  class BertModel(PreTrainedModel):
260
  config_class = BertUpdatedConfig
 
261
  _supports_sdpa = True
262
  _supports_flash_attn_2 = True
263
 
@@ -314,12 +325,14 @@ class BertModel(PreTrainedModel):
314
 
315
  class BertForMaskedLM(PreTrainedModel):
316
  config_class = BertUpdatedConfig
 
317
  _supports_sdpa = True
318
  _supports_flash_attn_2 = True
319
 
320
  def __init__(self, config):
321
  super().__init__(config)
322
  self.bert = BertModel(config)
 
323
  self.cls = nn.Linear(config.hidden_size, config.vocab_size)
324
  self.post_init()
325
 
@@ -343,7 +356,7 @@ class BertForMaskedLM(PreTrainedModel):
343
  output_hidden_states=output_hidden_states, output_attentions=output_attentions,
344
  return_dict=True,
345
  )
346
- logits = self.cls(outputs.last_hidden_state)
347
 
348
  loss = None
349
  if labels is not None:
 
256
  return self.activation(self.dense(hidden_states[:, 0]))
257
 
258
 
259
class BertPredictionHeadTransform(nn.Module):
    """Dense -> GELU -> LayerNorm transform applied to hidden states.

    The input and output feature sizes are both ``config.hidden_size``, so
    the module does not change the size of the last dimension.

    Args:
        config: Configuration object; this class reads ``hidden_size`` and
            ``layer_norm_eps`` from it.
    """

    def __init__(self, config):
        super().__init__()
        # Attribute names (``dense``, ``LayerNorm``) become state-dict keys,
        # so they are kept exactly as written.
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``LayerNorm(gelu(dense(hidden_states)))``.

        Args:
            hidden_states: Tensor whose last dimension is
                ``config.hidden_size`` (required by ``nn.Linear``).

        Returns:
            The transformed tensor, with the same last-dimension size.
        """
        projected = self.dense(hidden_states)
        activated = F.gelu(projected)
        return self.LayerNorm(activated)
267
+
268
+
269
  class BertModel(PreTrainedModel):
270
  config_class = BertUpdatedConfig
271
+ base_model_prefix = "bert"
272
  _supports_sdpa = True
273
  _supports_flash_attn_2 = True
274
 
 
325
 
326
  class BertForMaskedLM(PreTrainedModel):
327
  config_class = BertUpdatedConfig
328
+ base_model_prefix = "bert"
329
  _supports_sdpa = True
330
  _supports_flash_attn_2 = True
331
 
332
  def __init__(self, config):
333
  super().__init__(config)
334
  self.bert = BertModel(config)
335
+ self.transform = BertPredictionHeadTransform(config)
336
  self.cls = nn.Linear(config.hidden_size, config.vocab_size)
337
  self.post_init()
338
 
 
356
  output_hidden_states=output_hidden_states, output_attentions=output_attentions,
357
  return_dict=True,
358
  )
359
+ logits = self.cls(self.transform(outputs.last_hidden_state))
360
 
361
  loss = None
362
  if labels is not None: