joebruce1313's picture
Upload 38004 files
1f5470c verified
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.layers.layer import Layer
from keras.src.saving import serialization_lib
@keras_export("keras.layers.Pipeline")
class Pipeline(Layer):
"""Applies a series of layers to an input.
This class is useful to build a preprocessing pipeline,
in particular an image data augmentation pipeline.
Compared to a `Sequential` model, `Pipeline` features
a few important differences:
- It's not a `Model`, just a plain layer.
- When the layers in the pipeline are compatible
with `tf.data`, the pipeline will also
remain `tf.data` compatible. That is to say,
the pipeline will not attempt to convert
its inputs to backend-native tensors
when in a tf.data context (unlike a `Sequential`
model).
Example:
```python
from keras import layers
preprocessing_pipeline = layers.Pipeline([
layers.AutoContrast(),
layers.RandomZoom(0.2),
layers.RandomRotation(0.2),
])
# `ds` is a tf.data.Dataset
preprocessed_ds = ds.map(
preprocessing_pipeline,
num_parallel_calls=4,
)
```
"""
def __init__(self, layers, name=None):
super().__init__(name=name)
self._pipeline_layers = layers
self._convert_input_args = False
self._allow_non_tensor_positional_args = True
@property
def layers(self):
return self._pipeline_layers
def call(self, inputs, training=True, mask=None):
for layer in self._pipeline_layers:
kwargs = {}
if layer._call_has_mask_arg:
kwargs["mask"] = mask
if layer._call_has_training_arg and training is not None:
kwargs["training"] = training
outputs = layer(inputs, **kwargs)
inputs = outputs
def _get_mask_from_keras_tensor(kt):
return getattr(kt, "_keras_mask", None)
mask = tree.map_structure(_get_mask_from_keras_tensor, outputs)
return outputs
@classmethod
def from_config(cls, config):
config["layers"] = [
serialization_lib.deserialize_keras_object(x)
for x in config["layers"]
]
return cls(**config)
def get_config(self):
config = {
"layers": serialization_lib.serialize_keras_object(
self._pipeline_layers
),
"name": self.name,
}
return config