Update modeling_gpt_refact.py
Browse files- modeling_gpt_refact.py +2 -2
modeling_gpt_refact.py
CHANGED
|
@@ -337,9 +337,9 @@ class GPTRefactPreTrainedModel(PreTrainedModel):
|
|
| 337 |
elif isinstance(module, LayerNormNoBias):
|
| 338 |
module.weight.data.fill_(1.0)
|
| 339 |
|
| 340 |
-
def _set_gradient_checkpointing(self, module,
|
| 341 |
if isinstance(module, GPTRefactModel):
|
| 342 |
-
module.gradient_checkpointing =
|
| 343 |
|
| 344 |
|
| 345 |
class GPTRefactModel(GPTRefactPreTrainedModel):
|
|
|
|
| 337 |
elif isinstance(module, LayerNormNoBias):
|
| 338 |
module.weight.data.fill_(1.0)
|
| 339 |
|
| 340 |
+
def _set_gradient_checkpointing(self, module, enable=False):
|
| 341 |
if isinstance(module, GPTRefactModel):
|
| 342 |
+
module.gradient_checkpointing = enable
|
| 343 |
|
| 344 |
|
| 345 |
class GPTRefactModel(GPTRefactPreTrainedModel):
|