davda54 commited on
Commit
1251650
·
verified ·
1 Parent(s): ef18fbe

Update modeling_norbert.py

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +4 -7
modeling_norbert.py CHANGED
@@ -119,11 +119,7 @@ class Attention(nn.Module):
119
  self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
120
  self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
121
 
122
- position_indices = torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(1) \
123
- - torch.arange(config.max_position_embeddings, dtype=torch.long).unsqueeze(0)
124
- position_indices = self.make_log_bucket_position(position_indices, config.position_bucket_size, config.max_position_embeddings)
125
- position_indices = config.position_bucket_size - 1 + position_indices
126
- self.register_buffer("position_indices", position_indices.contiguous(), persistent=False)
127
 
128
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
129
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
@@ -140,8 +136,8 @@ class Attention(nn.Module):
140
  batch_size, key_len, _ = hidden_states.size()
141
  query_len = key_len
142
 
143
- # Recompute position_indices if sequence length exceeds the precomputed size
144
- if self.position_indices.size(0) < query_len:
145
  position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
146
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
147
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
@@ -223,6 +219,7 @@ class NorbertPreTrainedModel(PreTrainedModel):
223
  base_model_prefix = "norbert3"
224
  supports_gradient_checkpointing = True
225
  _tied_weights_keys = {}
 
226
 
227
  def _set_gradient_checkpointing(self, module, value=False):
228
  if isinstance(module, Encoder):
 
119
  self.pre_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=False)
120
  self.post_layer_norm = nn.LayerNorm(config.hidden_size, config.layer_norm_eps, elementwise_affine=True)
121
 
122
+ self.position_indices = None
 
 
 
 
123
 
124
  self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
125
  self.scale = 1.0 / math.sqrt(3 * self.head_size)
 
136
  batch_size, key_len, _ = hidden_states.size()
137
  query_len = key_len
138
 
139
+ # Recompute position_indices at the beginning or if sequence length exceeds the precomputed size
140
+ if self.position_indices is None or self.position_indices.size(0) < query_len:
141
  position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
142
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
143
  position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
 
219
  base_model_prefix = "norbert3"
220
  supports_gradient_checkpointing = True
221
  _tied_weights_keys = {}
222
+ _keys_to_ignore_on_load_unexpected = [r".*position_indices.*"]
223
 
224
  def _set_gradient_checkpointing(self, module, value=False):
225
  if isinstance(module, Encoder):