fix

modeling_gptbert.py (+1 -2)
@@ -345,7 +345,6 @@ class SelfAttention(nn.Module):
     def set_window_length(self, window_length: int):
         self.window_length = window_length
 
-    @lru_cache(maxsize=32)
     def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
         """Create and cache window attention mask."""
         if self.is_causal:
@@ -532,7 +531,7 @@ class Encoder(nn.Module):
 
         for layer in self.layers:
             if checkpoint_activations:
-                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(
+                hidden_layer, v1 = torch.utils.checkpoint.checkpoint(layer, hidden_layer, embeddings, v1, padding_info, use_reentrant=True)
             else:
                 hidden_layer, v1 = layer(hidden_layer, embeddings, v1, padding_info)
 
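Note on the first hunk: the commit doesn't say why @lru_cache(maxsize=32) was dropped, but a common reason to remove functools.lru_cache from an nn.Module method is that the cache lives on the class and is keyed on self, so it pins the module and every cached device tensor until entries are evicted. The docstring still promises caching, so here is a minimal sketch of keeping it with an instance-level dict instead; the constructor and the mask construction below are illustrative assumptions, not the repository's actual code.

import torch
import torch.nn as nn


class SelfAttention(nn.Module):
    """Trimmed to the mask logic; everything beyond the method names and
    signatures visible in the diff is a hypothetical stand-in."""

    def __init__(self, window_length: int = 128, is_causal: bool = True):
        super().__init__()
        self.window_length = window_length
        self.is_causal = is_causal
        # Instance-level cache: freed together with the module, unlike
        # the class-level cache that @lru_cache(maxsize=32) created.
        self._window_mask_cache = {}

    def set_window_length(self, window_length: int):
        self.window_length = window_length
        self._window_mask_cache.clear()  # old masks are stale after a resize

    def _get_window_mask(self, query_length: int, key_length: int, device: torch.device):
        """Create and cache window attention mask."""
        key = (query_length, key_length, device)
        if key not in self._window_mask_cache:
            # Hypothetical construction: attend only within window_length
            # positions, and only backwards when causal.
            q_pos = torch.arange(query_length, device=device).unsqueeze(1)
            k_pos = torch.arange(key_length, device=device).unsqueeze(0)
            distance = k_pos - q_pos
            mask = distance.abs() < self.window_length
            if self.is_causal:
                mask &= distance <= 0
            self._window_mask_cache[key] = mask
        return self._window_mask_cache[key]

Clearing the cache in set_window_length is likewise a hypothetical detail: resizing the window invalidates any previously built masks.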
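Note on the second hunk: the removed line is truncated in this view, so its exact arguments aren't visible; the replacement follows torch.utils.checkpoint.checkpoint's calling convention of passing the callable first and its positional inputs after it, with use_reentrant=True set explicitly. A self-contained sketch of that pattern, using a toy layer in place of the real encoder layers (which additionally thread embeddings, v1, and padding_info through):

import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyLayer(nn.Module):
    """Stand-in for the real encoder layers; the extra arguments from
    the diff (embeddings, v1, padding_info) are omitted here."""

    def __init__(self, dim: int = 16):
        super().__init__()
        self.linear = nn.Linear(dim, dim)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(hidden))


layers = nn.ModuleList(ToyLayer() for _ in range(3))
# With use_reentrant=True, at least one tensor input must require grad,
# or no gradients will flow through the checkpointed segment.
hidden = torch.randn(2, 16, requires_grad=True)

for layer in layers:
    # The callable comes first, then its positional inputs, mirroring
    # the fixed line of the diff; activations inside `layer` are
    # recomputed during backward instead of being stored.
    hidden = torch.utils.checkpoint.checkpoint(layer, hidden, use_reentrant=True)

hidden.sum().backward()

PyTorch's documentation now recommends use_reentrant=False; passing use_reentrant=True explicitly, as the fix does, preserves the legacy recomputation path and avoids the warning about relying on the default.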