del gradient_checkpointing_enable()
#11
by
chandler88
- opened
- modeling_chatglm.py +0 -4
modeling_chatglm.py
CHANGED
|
@@ -797,10 +797,6 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
|
|
| 797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
| 798 |
return position_ids
|
| 799 |
|
| 800 |
-
def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
|
| 801 |
-
if not self.supports_gradient_checkpointing:
|
| 802 |
-
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
|
| 803 |
-
|
| 804 |
|
| 805 |
class Embedding(torch.nn.Module):
|
| 806 |
"""Language model embeddings."""
|
|
|
|
| 797 |
position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
|
| 798 |
return position_ids
|
| 799 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 800 |
|
| 801 |
class Embedding(torch.nn.Module):
|
| 802 |
"""Language model embeddings."""
|