Pull request #10: enable activation checkpointing
Opened by smangrul
Files changed: modeling_phi.py (+1 −2)
modeling_phi.py
CHANGED
|
@@ -525,7 +525,6 @@ class MHA(nn.Module):
|
|
| 525 |
softmax_scale: Optional[float] = None,
|
| 526 |
layer_idx: Optional[int] = None,
|
| 527 |
return_residual: bool = False,
|
| 528 |
-
checkpointing: bool = False,
|
| 529 |
) -> None:
|
| 530 |
super().__init__()
|
| 531 |
|
|
@@ -585,7 +584,7 @@ class MHA(nn.Module):
|
|
| 585 |
self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
|
| 586 |
self.layer_idx = layer_idx
|
| 587 |
self.return_residual = return_residual
|
| 588 |
-
self.checkpointing = checkpointing
|
| 589 |
|
| 590 |
def _forward_self_attn(
|
| 591 |
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
|
|
|
|
| 525 |
softmax_scale: Optional[float] = None,
|
| 526 |
layer_idx: Optional[int] = None,
|
| 527 |
return_residual: bool = False,
|
|
|
|
| 528 |
) -> None:
|
| 529 |
super().__init__()
|
| 530 |
|
|
|
|
| 584 |
self.flash_attn = config.flash_attn and attn_cls is FlashSelfAttention
|
| 585 |
self.layer_idx = layer_idx
|
| 586 |
self.return_residual = return_residual
|
| 587 |
+
self.checkpointing = getattr(config, "checkpointing", False)
|
| 588 |
|
| 589 |
def _forward_self_attn(
|
| 590 |
self, x: torch.FloatTensor, key_padding_mask: Optional[torch.BoolTensor]
|