davda54 committed on
Commit
10d36d5
·
verified ·
1 Parent(s): 86bb1b0

Update modeling_norbert.py

Browse files
Files changed (1) hide show
  1. modeling_norbert.py +4 -4
modeling_norbert.py CHANGED
@@ -133,12 +133,12 @@ class Attention(nn.Module):
133
 
134
  # Recompute position_indices at the beginning or if sequence length exceeds the precomputed size
135
  if self.position_indices is None or self.position_indices.size(0) < query_len:
136
- position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
137
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
138
- position_indices = self.make_log_bucket_position(position_indices, self.config.position_bucket_size, 512)
139
- position_indices = self.config.position_bucket_size - 1 + position_indices
140
  if self.position_indices.device != hidden_states.device:
141
- self.position_indices = position_indices.to(hidden_states.device)
142
 
143
  # Pre-LN and project query/key/value.
144
  hidden_states = self.pre_layer_norm(hidden_states) # shape: [B, T, D]
 
133
 
134
  # Recompute position_indices at the beginning or if sequence length exceeds the precomputed size
135
  if self.position_indices is None or self.position_indices.size(0) < query_len:
136
+ self.position_indices = torch.arange(query_len, dtype=torch.long).unsqueeze(1) \
137
  - torch.arange(query_len, dtype=torch.long).unsqueeze(0)
138
+ self.position_indices = self.make_log_bucket_position(self.position_indices, self.config.position_bucket_size, 512)
139
+ self.position_indices = self.config.position_bucket_size - 1 + self.position_indices
140
  if self.position_indices.device != hidden_states.device:
141
+ self.position_indices = self.position_indices.to(hidden_states.device)
142
 
143
  # Pre-LN and project query/key/value.
144
  hidden_states = self.pre_layer_norm(hidden_states) # shape: [B, T, D]