davda54 committed on
Commit
86bb1b0
·
verified ·
1 Parent(s): 52fc641

Fix compatibility with transformers v5

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +3 -7
modeling_norbert.py CHANGED
@@ -22,11 +22,6 @@ class Encoder(nn.Module):
22
  def __init__(self, config, activation_checkpointing=False):
23
  super().__init__()
24
  self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
25
-
26
- for i, layer in enumerate(self.layers):
27
- layer.mlp.mlp[1].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
28
- layer.mlp.mlp[-2].weight.data *= math.sqrt(1.0 / (2.0 * (1 + i)))
29
-
30
  self.activation_checkpointing = activation_checkpointing
31
 
32
  def forward(self, hidden_states, attention_mask, relative_embedding):
@@ -142,6 +137,7 @@ class Attention(nn.Module):
142
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
143
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
144
  position_indices = self.config.position_bucket_size - 1 + position_indices
 
145
  self.position_indices = position_indices.to(hidden_states.device)
146
 
147
  # Pre-LN and project query/key/value.
@@ -320,10 +316,10 @@ class NorbertForMaskedLM(NorbertModel):
320
  self.post_init()
321
 
322
  def get_output_embeddings(self):
323
- return self.classifier.nonlinearity[-1].weight
324
 
325
  def set_output_embeddings(self, new_embeddings):
326
- self.classifier.nonlinearity[-1].weight = new_embeddings
327
 
328
  def forward(
329
  self,
 
22
  def __init__(self, config, activation_checkpointing=False):
23
  super().__init__()
24
  self.layers = nn.ModuleList([EncoderLayer(config) for _ in range(config.num_hidden_layers)])
 
 
 
 
 
25
  self.activation_checkpointing = activation_checkpointing
26
 
27
  def forward(self, hidden_states, attention_mask, relative_embedding):
 
137
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
138
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
139
  position_indices = self.config.position_bucket_size - 1 + position_indices
140
+ if self.position_indices.device != hidden_states.device:
141
  self.position_indices = position_indices.to(hidden_states.device)
142
 
143
  # Pre-LN and project query/key/value.
 
316
  self.post_init()
317
 
318
  def get_output_embeddings(self):
319
+ return self.classifier.nonlinearity[-1]
320
 
321
  def set_output_embeddings(self, new_embeddings):
322
+ self.classifier.nonlinearity[-1] = new_embeddings
323
 
324
  def forward(
325
  self,