from collections import namedtuple
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
@keras_export("keras.RematScope")
class RematScope:
"""A context manager for enabling rematerialization in Keras.
Rematerialization (gradient checkpointing) trades memory for computation by
recomputing intermediate activations during the backward pass. This is
particularly useful for training large models or large batch sizes within
limited memory constraints.
The scope must be active when the layer is called (e.g., `layer(input)`),
not just when it is created, since rematerialization applies at execution
time, not at creation time.
Args:
mode: Rematerialization mode to apply.
Options:
- `"full"`: Apply rematerialization globally to all supported
operations.
- `"activations"`: Apply rematerialization to activations on any
layers that contain `keras.activations` (e.g., `Dense(...,
activation=relu)`).
- `"larger_than"`: Apply rematerialization to layers with output
sizes larger than `output_size_threshold`.
- `"list_of_layers"`: Apply rematerialization to a specific list of
layer names.
- `None`: Disable rematerialization.
output_size_threshold: Output size threshold for the
`"larger_than"` mode. Layers producing outputs larger than this
threshold will be rematerialized. Default is `1024`.
layer_names: List of layer names for the
`"list_of_layers"` mode. Default is an empty list.
Examples:
Using "list_of_layers" mode:
```python
import keras
from keras import RematScope

input_tensor = keras.random.normal((1, 32, 32, 3))
with RematScope(mode="list_of_layers",
                layer_names=["dense_1", "conv2d_1"]):
    layer1 = keras.layers.Dense(128, name="dense_1")
    layer2 = keras.layers.Conv2D(64, (3, 3), name="conv2d_1")
    layer3 = keras.layers.Dense(64, name="dense_2")
    # Only layer1 and layer2 apply rematerialization; layer3 does not.
    output1 = layer1(input_tensor)
    output2 = layer2(output1)
    output3 = layer3(output2)
```
Using "larger_than" mode with a specific output size threshold:
```python
with RematScope(mode="larger_than", output_size_threshold=2048):
    layer = keras.layers.Conv2D(64, (3, 3))
    # Rematerialized only if the layer's output size exceeds 2048.
    output = layer(input_tensor)
```
Nested scopes for fine-grained control:
```python
with RematScope(mode="full"):
    # Layers called here are fully rematerialized.
    layer1 = keras.layers.Dense(128, activation="relu")
    output1 = layer1(input_tensor)
    with RematScope(mode="larger_than", output_size_threshold=512):
        layer2 = keras.layers.Conv2D(32, (3, 3))
        # layer2 is rematerialized only if its output size exceeds 512.
        output2 = layer2(output1)
```
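Using "activations" mode (a minimal sketch, reusing `input_tensor` from
the first example; only layers that use an activation function are
affected):
```python
with RematScope(mode="activations"):
    layer = keras.layers.Dense(128, activation="relu")
    # The activation computation may be recomputed during backprop.
    output = layer(input_tensor)
```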
"""
def __init__(
self, mode="full", output_size_threshold=1024, layer_names=None
):
if mode not in {
"full",
"activations",
"larger_than",
"list_of_layers",
None,
}:
raise ValueError(
    f"Invalid mode '{mode}'. Supported modes are: 'full', "
    "'activations', 'larger_than', 'list_of_layers', or None."
)
self.mode = mode
self.output_size_threshold = output_size_threshold
self.layer_names = layer_names or []
self._pop_on_exit = False
def __enter__(self):
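# Push this scope onto a process-wide stack so scopes can nest;
# the innermost (top-of-stack) scope determines the active mode.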
remat_scope_stack = global_state.get_global_attribute(
"remat_scope_stack", default=[], set_to_default=True
)
remat_scope_stack.append(self)
self._pop_on_exit = True
return self
def __exit__(self, *args, **kwargs):
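# Pop only if this instance pushed itself in __enter__, which
# restores the enclosing scope (if any) as the active one.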
if self._pop_on_exit:
remat_scope_stack = global_state.get_global_attribute(
"remat_scope_stack"
)
remat_scope_stack.pop()
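# Immutable record of the active rematerialization settings, as returned
# by `get_current_remat_mode()`.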
RematMode = namedtuple(
"RematMode", ["mode", "output_size_threshold", "layer_names"]
)
def get_current_remat_mode():
"""Get the current rematerialization mode and associated settings.
Returns:
RematMode or None: The current rematerialization mode, or None if not
set.
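Example (a minimal sketch):
```python
with RematScope(mode="larger_than", output_size_threshold=2048):
    current = get_current_remat_mode()
    # current.mode == "larger_than"
    # current.output_size_threshold == 2048
```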
"""
remat_scope_stack = global_state.get_global_attribute("remat_scope_stack")
if not remat_scope_stack:
return None
active_scope = remat_scope_stack[-1]
return RematMode(
active_scope.mode,
active_scope.output_size_threshold,
active_scope.layer_names,
)
@keras_export("keras.remat")
def remat(f):
"""Applies rematerialization to a function or layer for memory optimization.
Rematerialization is a memory optimization technique that trades off
computation for memory. Instead of storing intermediate results
(e.g. activations) for backpropagation, they are recomputed during the
backward pass. This reduces peak memory usage at the cost of increased
computation time, allowing the training of larger models or using larger
batch sizes within the same memory constraints.
Args:
f: A callable function to which rematerialization is applied.
This is typically a computationally expensive operation whose
intermediate states can be recomputed instead of stored.
Returns:
A wrapped function that applies rematerialization. The returned
function defines a custom gradient, ensuring that during the backward
pass, the forward computation is recomputed as needed.
Example:
```python
from keras import Model, layers, remat

class CustomRematLayer(layers.Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.remat_function = remat(self.intermediate_function)

    def intermediate_function(self, x):
        for _ in range(2):
            x = x + x * 0.1  # Simple scaled transformation
        return x

    def call(self, inputs):
        return self.remat_function(inputs)

# Define a simple model using the custom layer
inputs = layers.Input(shape=(4,))
x = layers.Dense(4, activation="relu")(inputs)
x = CustomRematLayer()(x)  # Custom layer with rematerialization
outputs = layers.Dense(1)(x)

# Create and compile the model
model = Model(inputs=inputs, outputs=outputs)
model.compile(optimizer="sgd", loss="mse")
```
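Note: the actual recomputation strategy is delegated to the backend's
rematerialization primitive; depending on the backend, this is expected
to correspond to something like `jax.checkpoint` or `tf.recompute_grad`.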
"""
return backend.core.remat(f)