# joebruce1313's picture
# Upload 38004 files
# 1f5470c verified
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.layers.preprocessing.tf_data_layer import TFDataLayer
from keras.src.saving import serialization_lib
@keras_export("keras.layers.Rescaling")
class Rescaling(TFDataLayer):
    """A preprocessing layer which rescales input values to a new range.

    This layer rescales every value of an input (often an image) by multiplying
    by `scale` and adding `offset`.

    For instance:

    1. To rescale an input in the `[0, 255]` range
    to be in the `[0, 1]` range, you would pass `scale=1./255`.

    2. To rescale an input in the `[0, 255]` range to be in the `[-1, 1]` range,
    you would pass `scale=1./127.5, offset=-1`.

    The rescaling is applied both during training and inference. Inputs can be
    of integer or floating point dtype, and by default the layer will output
    floats.

    `scale` and `offset` may also be array-like (for example one value per
    channel); they are cast to the layer's compute dtype and broadcast against
    the inputs. When the image data format is `"channels_first"`, any
    non-scalar `scale` or `offset` is reshaped by appending trailing size-1
    axes up to rank 3 (e.g. a `(C,)` vector becomes `(C, 1, 1)`), so that it
    broadcasts along the channel axis of `(..., C, H, W)` inputs instead of
    the last spatial axis.

    **Note:** This layer is safe to use inside a `tf.data` pipeline
    (independently of which backend you're using).

    Args:
        scale: Float (or array-like), the scale to apply to the inputs.
        offset: Float (or array-like), the offset to apply to the inputs.
        **kwargs: Base layer keyword arguments, such as `name` and `dtype`.
    """

    def __init__(self, scale, offset=0.0, **kwargs):
        super().__init__(**kwargs)
        self.scale = scale
        self.offset = offset
        # Elementwise op: let any incoming mask propagate through the layer.
        self.supports_masking = True

    def call(self, inputs):
        """Return `inputs * scale + offset`, computed in `self.compute_dtype`.

        Args:
            inputs: Input tensor of integer or floating point dtype.

        Returns:
            The rescaled tensor, cast to the layer's compute dtype.
        """
        dtype = self.compute_dtype
        scale = self.backend.cast(self.scale, dtype)
        offset = self.backend.cast(self.offset, dtype)
        if backend.image_data_format() == "channels_first":
            # Both parameters must be aligned with the channel axis; the
            # offset previously kept its raw shape and broadcast against the
            # last (spatial) axis instead.
            scale = self._expand_for_channels_first(scale)
            offset = self._expand_for_channels_first(offset)
        return self.backend.cast(inputs, dtype) * scale + offset

    def _expand_for_channels_first(self, value):
        """Pad a non-scalar `value` with trailing size-1 axes up to rank 3.

        Scalars are returned unchanged. Values already of rank >= 3 are
        reshaped to their own shape (a no-op).

        Args:
            value: Tensor holding `scale` or `offset`, already cast.

        Returns:
            `value`, possibly reshaped (e.g. `(C,)` -> `(C, 1, 1)`).
        """
        value_shape = tuple(self.backend.core.shape(value))
        if not value_shape:
            # A scalar broadcasts to every element; nothing to align.
            return value
        return self.backend.numpy.reshape(
            value, value_shape + (1,) * (3 - len(value_shape))
        )

    def compute_output_shape(self, input_shape):
        """Return `input_shape` unchanged.

        Assumes `scale` and `offset` broadcast against the input without
        enlarging it.
        """
        return input_shape

    def get_config(self):
        """Return the layer config, including `scale` and `offset`.

        Both values go through `serialize_keras_object` because they might
        be NumPy arrays rather than plain floats.
        """
        config = super().get_config()
        config.update(
            {
                # `scale` and `offset` might be numpy array.
                "scale": serialization_lib.serialize_keras_object(self.scale),
                "offset": serialization_lib.serialize_keras_object(self.offset),
            }
        )
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        """Rebuild the layer from a `get_config()` dict.

        Args:
            config: Dict produced by `get_config()`. It is shallow-copied, so
                the caller's dict is not mutated.
            custom_objects: Optional mapping forwarded to
                `deserialize_keras_object`.

        Returns:
            A new `Rescaling` instance.
        """
        config = config.copy()
        config["scale"] = serialization_lib.deserialize_keras_object(
            config["scale"], custom_objects=custom_objects
        )
        config["offset"] = serialization_lib.deserialize_keras_object(
            config["offset"], custom_objects=custom_objects
        )
        return cls(**config)