davda54 commited on
Commit
39265fc
·
verified ·
1 Parent(s): 87e0acb
Files changed (1) hide show
  1. modeling_gptbert.py +1 -2
modeling_gptbert.py CHANGED
@@ -345,7 +345,6 @@ class SelfAttention(nn.Module):
345
  def set_window_length(self, window_length: int):
346
  self.window_length = window_length
347
 
348
- @lru_cache(maxsize=32)
349
  def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
350
  """Create and cache window attention mask."""
351
  if self.is_causal:
@@ -532,7 +531,7 @@ class Encoder(nn.Module):
532
 
533
  for layer in self.layers:
534
  if checkpoint_activations:
535
- hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layers, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
536
  else:
537
  hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
538
 
 
345
  def set_window_length(self, window_length: int):
346
  self.window_length = window_length
347
 
 
348
  def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
349
  """Create and cache window attention mask."""
350
  if self.is_causal:
 
531
 
532
  for layer in self.layers:
533
  if checkpoint_activations:
534
+ hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
535
  else:
536
  hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
537