Adding _set_gradient_checkpointing for compatibility (#22)
- Adding _set_gradient_checkpointing for compatibility (a30a931294ac0f344a0c1547877c692ceb17123c)
Co-authored-by: Vicente Rivera <vriveras@users.noreply.huggingface.co>
modeling_mixformer_sequential.py
CHANGED
|
@@ -711,6 +711,10 @@ class MixFormerSequentialPreTrainedModel(PreTrainedModel):
|
|
 711              "past_key_values": past_key_values,
 712              "attention_mask": attention_mask,
 713          }
+714
+715      def _set_gradient_checkpointing(self, module, value=False):
+716          if isinstance(module, MixFormerSequentialPreTrainedModel):
+717              module.gradient_checkpointing = value
 718
 719
 720  class MixFormerSequentialForCausalLM(MixFormerSequentialPreTrainedModel):