| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from functools import partial |
|
|
| import torch.nn as nn |
|
|
|
|
class GradientCheckpointingLayer(nn.Module):
    """Layer base class that can route its call through gradient checkpointing.

    Checkpointing is off by default (`gradient_checkpointing = False`). Calling
    `model.set_gradient_checkpointing()` switches it on by setting
    `gradient_checkpointing = True` and storing the checkpointing function in
    `_gradient_checkpointing_func`. The checkpointed path is only taken while the
    module is in training mode; otherwise the call goes straight to `nn.Module.__call__`.

    Important:

        With `use_reentrant=True`, any input that requires gradients (e.g. hidden states)
        has to be given positionally (`*args`), not as a keyword argument, otherwise
        gradients are not propagated properly.

        Example:

        ```python
        >>> # Correct - hidden_states passed as positional arg
        >>> out = self.layer(hidden_states, attention_mask=attention_mask)

        >>> # Incorrect - hidden_states passed as keyword arg
        >>> out = self.layer(hidden_states=hidden_states, attention_mask=attention_mask)
        ```
    """

    # Class-level default; instances flip this to True to enable checkpointing.
    gradient_checkpointing = False

    def __call__(self, *args, **kwargs):
        """Invoke the layer, through the checkpointing function when it is active.

        Positional arguments are handed to `_gradient_checkpointing_func`; keyword
        arguments are pre-bound onto the regular module call with `partial`.
        """
        module_call = super().__call__

        # Regular path: checkpointing disabled, or the module is not training.
        if not (self.gradient_checkpointing and self.training):
            return module_call(*args, **kwargs)

        # Checkpointed path: only *args travel through the checkpointing function,
        # so kwargs are frozen into the callable beforehand.
        call_with_kwargs = partial(module_call, **kwargs)
        return self._gradient_checkpointing_func(call_with_kwargs, *args)
|
|