File size: 9,849 Bytes
1f5470c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 | from keras.src import backend
from keras.src import tree
from keras.src.api_export import keras_export
@keras_export(["keras.InputSpec", "keras.layers.InputSpec"])
class InputSpec:
    """Specifies the rank, dtype and shape of every input to a layer.

    Layers can expose (if appropriate) an `input_spec` attribute:
    an instance of `InputSpec`, or a nested structure of `InputSpec` instances
    (one per input tensor). These objects enable the layer to run input
    compatibility checks for input structure, input rank, input shape, and
    input dtype for the first argument of `Layer.__call__`.

    A `None` entry in a shape is compatible with any dimension.

    Args:
        dtype: Expected dtype of the input.
        shape: Shape tuple, expected shape of the input
            (may include `None` for dynamic axes).
            Includes the batch size.
        ndim: Integer, expected rank of the input. Ignored when `shape` is
            given, since the rank is then `len(shape)`.
        max_ndim: Integer, maximum rank of the input.
        min_ndim: Integer, minimum rank of the input.
        axes: Dictionary mapping integer axes to
            a specific dimension value.
        allow_last_axis_squeeze: If `True`, allow inputs of rank N+1 as long
            as the last axis of the input is 1, as well as inputs of rank N-1
            as long as the last axis of the spec is 1.
        name: Expected key corresponding to this input when passing data as
            a dictionary.
        optional: Boolean, whether the input is optional or not.
            An optional input can accept `None` values.

    Example:

    ```python
    class MyLayer(Layer):
        def __init__(self):
            super().__init__()
            # The layer will accept inputs with
            # shape (*, 28, 28) & (*, 28, 28, 1)
            # and raise an appropriate error message otherwise.
            self.input_spec = InputSpec(
                shape=(None, 28, 28, 1),
                allow_last_axis_squeeze=True)
    ```
    """

    def __init__(
        self,
        dtype=None,
        shape=None,
        ndim=None,
        max_ndim=None,
        min_ndim=None,
        axes=None,
        allow_last_axis_squeeze=False,
        name=None,
        optional=False,
    ):
        self.dtype = (
            backend.standardize_dtype(dtype) if dtype is not None else None
        )
        if shape is not None:
            self.shape = backend.standardize_shape(shape)
            # The shape fully determines the rank, so `ndim` is not consulted.
            self.ndim = len(self.shape)
        else:
            self.ndim = ndim
            self.shape = None
        self.max_ndim = max_ndim
        self.min_ndim = min_ndim
        self.name = name
        self.optional = optional
        self.allow_last_axis_squeeze = allow_last_axis_squeeze
        try:
            axes = axes or {}
            # Keys are coerced with `int()` so that configs whose dict keys
            # were stringified still load.
            self.axes = {int(k): axes[k] for k in axes}
        except (ValueError, TypeError):
            raise TypeError(
                "Argument `axes` must be a dict with integer keys. "
                f"Received: axes={axes}"
            )
        if self.axes and (self.ndim is not None or self.max_ndim is not None):
            # Compare against None explicitly: a rank-0 spec (ndim=0) must
            # not fall through to `max_ndim`, which may itself be None.
            rank = self.ndim if self.ndim is not None else self.max_ndim
            max_dim = rank - 1
            max_axis = max(self.axes)
            if max_axis > max_dim:
                raise ValueError(
                    "Axis {} is greater than the maximum "
                    "allowed value: {}".format(max_axis, max_dim)
                )

    def __repr__(self):
        """Returns a compact representation listing only the set fields."""
        # Numeric/shape fields are compared against None so that legitimate
        # falsy values such as `ndim=0` or `shape=()` are still shown.
        spec = [
            ("dtype=" + str(self.dtype)) if self.dtype else "",
            ("shape=" + str(self.shape)) if self.shape is not None else "",
            ("ndim=" + str(self.ndim)) if self.ndim is not None else "",
            ("max_ndim=" + str(self.max_ndim))
            if self.max_ndim is not None
            else "",
            ("min_ndim=" + str(self.min_ndim))
            if self.min_ndim is not None
            else "",
            ("axes=" + str(self.axes)) if self.axes else "",
        ]
        return f"InputSpec({', '.join(x for x in spec if x)})"

    def get_config(self):
        """Returns a dict of constructor arguments that recreates this spec.

        Every argument accepted by `__init__` is included, so
        `InputSpec.from_config(spec.get_config())` round-trips.
        """
        return {
            "dtype": self.dtype,
            "shape": self.shape,
            "ndim": self.ndim,
            "max_ndim": self.max_ndim,
            "min_ndim": self.min_ndim,
            "axes": self.axes,
            "allow_last_axis_squeeze": self.allow_last_axis_squeeze,
            "name": self.name,
            "optional": self.optional,
        }

    @classmethod
    def from_config(cls, config):
        """Creates an `InputSpec` from a dict produced by `get_config`."""
        return cls(**config)
def assert_input_compatibility(input_spec, inputs, layer_name):
    """Checks compatibility between the layer and provided inputs.

    This checks that the tensor(s) `inputs` verify the input assumptions
    of a layer (if any). If not, a clear and actionable exception gets raised.

    Args:
        input_spec: An InputSpec instance, list of InputSpec instances, a nested
            structure of InputSpec instances, or None.
        inputs: Input tensor, list of input tensors, or a nested structure of
            input tensors.
        layer_name: String, name of the layer (for error message formatting).

    Raises:
        ValueError: in case of mismatch between
            the provided inputs and the expectations of the layer.
    """
    if not input_spec:
        return

    input_spec = tree.flatten(input_spec)
    if isinstance(inputs, dict):
        # Flatten `inputs` by reference order if input spec names are
        # provided. A `None` placeholder in the spec has no name, which
        # disables by-name ordering just like any other unnamed spec.
        names = [getattr(spec, "name", None) for spec in input_spec]
        if all(names):
            list_inputs = []
            for name in names:
                if name not in inputs:
                    raise ValueError(
                        f'Missing data for input "{name}". '
                        "You passed a data dictionary with keys "
                        f"{list(inputs.keys())}. "
                        f"Expected the following keys: {names}"
                    )
                list_inputs.append(inputs[name])
            inputs = list_inputs

    inputs = tree.flatten(inputs)
    if len(inputs) != len(input_spec):
        raise ValueError(
            f'Layer "{layer_name}" expects {len(input_spec)} input(s),'
            f" but it received {len(inputs)} input tensors. "
            f"Inputs received: {inputs}"
        )
    for input_index, (x, spec) in enumerate(zip(inputs, input_spec)):
        if spec is None:
            continue
        if x is None and spec.optional:
            continue

        # Having a shape/dtype is the only commonality of the various
        # tensor-like objects that may be passed. The most common kind of
        # invalid type we are guarding for is a Layer instance (Functional API),
        # which does not have a `shape` attribute.
        if not hasattr(x, "shape"):
            raise ValueError(
                f"Inputs to a layer should be tensors. Got '{x}' "
                f"(of type {type(x)}) as input for layer '{layer_name}'."
            )

        shape = backend.standardize_shape(x.shape)
        ndim = len(shape)

        # Check ndim. With `allow_last_axis_squeeze` the rank may differ by
        # one; when `spec.shape` is set, that is enforced by the shape check.
        if spec.ndim is not None and not spec.allow_last_axis_squeeze:
            if ndim != spec.ndim:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected ndim={spec.ndim}, found ndim={ndim}. "
                    f"Full shape received: {shape}"
                )
        if spec.max_ndim is not None and ndim > spec.max_ndim:
            raise ValueError(
                f'Input {input_index} of layer "{layer_name}" '
                "is incompatible with the layer: "
                f"expected max_ndim={spec.max_ndim}, "
                f"found ndim={ndim}"
            )
        if spec.min_ndim is not None and ndim < spec.min_ndim:
            raise ValueError(
                f'Input {input_index} of layer "{layer_name}" '
                "is incompatible with the layer: "
                f"expected min_ndim={spec.min_ndim}, "
                f"found ndim={ndim}. "
                f"Full shape received: {shape}"
            )

        # Check dtype.
        if spec.dtype is not None:
            dtype = backend.standardize_dtype(x.dtype)
            if dtype != spec.dtype:
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" '
                    "is incompatible with the layer: "
                    f"expected dtype={spec.dtype}, "
                    f"found dtype={dtype}"
                )

        # Check specific shape axes.
        if spec.axes:
            for axis, value in spec.axes.items():
                if value is None:
                    continue
                # An axis outside the input's rank cannot hold `value`;
                # report it as the documented ValueError rather than letting
                # `shape[axis]` raise an IndexError.
                out_of_range = not -ndim <= axis < ndim
                if out_of_range or shape[axis] not in {value, None}:
                    raise ValueError(
                        f'Input {input_index} of layer "{layer_name}" is '
                        f"incompatible with the layer: expected axis {axis} "
                        f"of input shape to have value {value}, "
                        "but received input with "
                        f"shape {shape}"
                    )

        # Check shape. The message reports the input's real (unsqueezed)
        # shape so it matches what the caller passed.
        if spec.shape is not None:
            if not _shape_matches_spec(
                spec.shape, shape, spec.allow_last_axis_squeeze
            ):
                raise ValueError(
                    f'Input {input_index} of layer "{layer_name}" is '
                    "incompatible with the layer: "
                    f"expected shape={spec.shape}, "
                    f"found shape={shape}"
                )


def _dims_match(spec_shape, shape):
    """Returns True if both shapes have equal rank and agree per dimension.

    A `None` dimension on either side is treated as a wildcard.
    """
    if len(spec_shape) != len(shape):
        return False
    return all(
        spec_dim is None or dim is None or spec_dim == dim
        for spec_dim, dim in zip(spec_shape, shape)
    )


def _shape_matches_spec(spec_shape, shape, allow_last_axis_squeeze):
    """Returns True if `shape` satisfies `spec_shape`.

    Without `allow_last_axis_squeeze`, the shapes must match exactly (modulo
    `None` wildcards). With it, two more alignments are accepted:
    - an input of rank N+1 whose last axis is 1, and
    - an input of rank N-1 when the spec's last axis is 1.
    """
    if _dims_match(spec_shape, shape):
        return True
    if not allow_last_axis_squeeze:
        return False
    # Input has one extra trailing axis of size 1.
    if shape and shape[-1] == 1 and _dims_match(spec_shape, shape[:-1]):
        return True
    # Input lacks the spec's trailing axis of size 1.
    if (
        spec_shape
        and spec_shape[-1] == 1
        and _dims_match(spec_shape[:-1], shape)
    ):
        return True
    return False
|