| from collections import namedtuple |
|
|
| from keras.src import backend |
| from keras.src.api_export import keras_export |
| from keras.src.backend.common import global_state |
|
|
|
|
@keras_export("keras.RematScope")
class RematScope:
    """A context manager for enabling rematerialization in Keras.

    Rematerialization (gradient checkpointing) trades memory for computation by
    recomputing intermediate activations during the backward pass. This is
    particularly useful for training large models or large batch sizes within
    limited memory constraints.

    This should be used when initializing the layer (e.g., `layer(input)`).
    Rematerialization applies at execution time, not at creation time.

    Args:
        mode: Rematerialization mode to apply.
            Options:
            - `"full"`: Apply rematerialization globally to all supported
                operations.
            - `"activations"`: Apply rematerialization to activations on any
                layers that contain `keras.activations` (e.g., `Dense(...,
                activation=relu)`).
            - `"larger_than"`: Apply rematerialization to layers with output
                sizes larger than `output_size_threshold`.
            - `"list_of_layers"`: Apply rematerialization to a specific list of
                layer names.
            - `None`: Disable rematerialization.
        output_size_threshold: Output size threshold for the
            `"larger_than"` mode. Layers producing outputs larger than this
            threshold will be rematerialized. Default is `1024`.
        layer_names: List of layer names for the
            `"list_of_layers"` mode. Default is an empty list.

    Examples:
    Using "list_of_layers" mode:

    ```python
    from keras import RematScope
    input_tensor = tf.random.normal((1, 32, 32, 3))
    with RematScope(mode="list_of_layers", layer_names=["dense_1",
        "conv2d_1"]):
        layer1 = keras.layers.Dense(128, name="dense_1")
        layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
        layer3 = keras.layers.Dense(64, name="dense_2")
        # Only layer1 and layer2 will apply rematerialization
        output1 = layer1(input_tensor)
        output2 = layer2(output1)
        output3 = layer3(output2)
    ```

    Using "larger_than" mode with a specific output size threshold:

    ```python
    with RematScope(mode="larger_than", output_size_threshold=2048):
        layer = keras.layers.Conv2D(64, (3, 3))
        output = layer(input_tensor)  # Conv2D outputs larger than 2048
    ```

    Nested scopes for fine-grained control:

    ```python
    with RematScope(mode="full"):
        # Create layers
        layer1 = keras.layers.Dense(128, activation='relu')
        output1 = layer1(input_tensor)  # layer1 is fully rematerialized
        with RematScope(mode="larger_than", output_size_threshold=512):
            layer2 = keras.layers.Conv2D(32, (3, 3))
            output2 = layer2(output1)  # layer2 is conditionally rematerialized
            # if output > 512
    ```
    """

    def __init__(
        self, mode="full", output_size_threshold=1024, layer_names=None
    ):
        # Validate eagerly so a bad mode fails at construction time rather
        # than deep inside layer execution.
        if mode not in {
            "full",
            "activations",
            "larger_than",
            "list_of_layers",
            None,
        }:
            raise ValueError(
                f"Invalid mode '{mode}'. Supported modes are: "
                "'full', 'activations', 'larger_than', 'list_of_layers', "
                "or None."
            )
        self.mode = mode
        self.output_size_threshold = output_size_threshold
        # Avoid a shared mutable default: fall back to a fresh list.
        self.layer_names = layer_names or []
        # Tracks whether this scope pushed itself onto the global stack, so
        # __exit__ only pops what __enter__ pushed.
        self._pop_on_exit = False

    def __enter__(self):
        # Scopes nest via a process-global stack; the innermost active scope
        # is the one consulted by `get_current_remat_mode`.
        remat_scope_stack = global_state.get_global_attribute(
            "remat_scope_stack", default=[], set_to_default=True
        )
        remat_scope_stack.append(self)
        self._pop_on_exit = True
        return self

    def __exit__(self, *args, **kwargs):
        # Only pop if __enter__ ran; exceptions during __enter__ (or manual
        # __exit__ calls) must not corrupt the stack.
        if self._pop_on_exit:
            remat_scope_stack = global_state.get_global_attribute(
                "remat_scope_stack"
            )
            remat_scope_stack.pop()
|
|
|
|
# Lightweight, immutable record of rematerialization settings
# (mode, output_size_threshold, layer_names); returned by
# `get_current_remat_mode`.
RematMode = namedtuple(
    "RematMode", ["mode", "output_size_threshold", "layer_names"]
)
|
|
|
|
def get_current_remat_mode():
    """Return the rematerialization settings of the innermost active scope.

    Returns:
        RematMode or None: A `RematMode` namedtuple holding the active
        scope's `mode`, `output_size_threshold`, and `layer_names`, or
        `None` when no `RematScope` is currently active.
    """
    stack = global_state.get_global_attribute("remat_scope_stack")
    # An absent or empty stack means no scope is active.
    if not stack:
        return None
    innermost = stack[-1]
    return RematMode(
        innermost.mode,
        innermost.output_size_threshold,
        innermost.layer_names,
    )
|
|
|
|
@keras_export("keras.remat")
def remat(f):
    """Applies rematerialization to a function or layer for memory optimization.

    Rematerialization is a memory optimization technique that trades off
    computation for memory. Instead of storing intermediate results
    (e.g. activations) for backpropagation, they are recomputed during the
    backward pass. This reduces peak memory usage at the cost of increased
    computation time, allowing the training of larger models or using larger
    batch sizes within the same memory constraints.

    Args:
        f: A callable function, to which rematerialization is
            applied. This is typically a computationally expensive operation
            where intermediate states can be recomputed instead of stored.

    Returns:
        A wrapped function that applies rematerialization. The returned
        function defines a custom gradient, ensuring that during the backward
        pass, the forward computation is recomputed as needed.

    Example:

    ```python
    from keras import Model
    class CustomRematLayer(layers.Layer):
        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.remat_function = remat(self.intermediate_function)

        def intermediate_function(self, x):
            for _ in range(2):
                x = x + x * 0.1  # Simple scaled transformation
            return x

        def call(self, inputs):
            return self.remat_function(inputs)

    # Define a simple model using the custom layer
    inputs = layers.Input(shape=(4,))
    x = layers.Dense(4, activation="relu")(inputs)
    x = CustomRematLayer()(x)  # Custom layer with rematerialization
    outputs = layers.Dense(1)(x)

    # Create and compile the model
    model = Model(inputs=inputs, outputs=outputs)
    model.compile(optimizer="sgd", loss="mse")
    ```
    """
    # The actual checkpointing machinery is backend-specific (TF/JAX/torch),
    # so delegate to the active backend's implementation.
    wrapped = backend.core.remat(f)
    return wrapped
|
|