File size: 16,542 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 |
import collections
from keras.src import tree
from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.backend.config import backend
from keras.src.ops.operation import Operation
@keras_export("keras.Function")
class Function(Operation):
    """Class that encapsulates a computation graph of Keras operations.

    You can use a `Function` to capture the computation graph linking
    some input tensors to some output tensors, and reapply the same
    computation on new inputs.

    A `Function` is similar to a Functional Model, with the difference
    that it is stateless (it does not track state variables)
    and does not implement the `Layer` API.

    Example:

    ```python
    input_1 = keras.KerasTensor(shape=(None, 2, 3))
    input_2 = keras.KerasTensor(shape=(None, 2, 3))
    x = input_1 + input_2
    output = keras.ops.sigmoid(x)
    fn = keras.Function(inputs=[input_1, input_2], outputs=output)

    input_1_val = np.random.random((4, 2, 3))
    input_2_val = np.random.random((4, 2, 3))
    output_val = fn([input_1_val, input_2_val])
    ```

    Args:
        inputs: `KerasTensor` instance or nested structured of
            `KerasTensor` instances.
        outputs: `KerasTensor` instance or nested structured of
            `KerasTensor` instances. They should be computable
            given only the values of `inputs`.
        name: String. The name of the function.
    """
    def __init__(self, inputs, outputs, name=None):
        super().__init__(name=name)
        if backend() == "tensorflow":
            # Temporary work around for
            # https://github.com/keras-team/keras/issues/931
            # This stop tensorflow from wrapping tf.function output in a
            # _DictWrapper object.
            _self_setattr_tracking = getattr(
                self, "_self_setattr_tracking", True
            )
            self._self_setattr_tracking = False
        # Keep the original (possibly nested) structures so outputs can
        # later be repacked into the same shape; the identity
        # `map_structure` makes a structural copy so mutation of the
        # caller's containers does not affect us. Flat lists are kept for
        # graph traversal.
        self._inputs_struct = tree.map_structure(lambda x: x, inputs)
        self._outputs_struct = tree.map_structure(lambda x: x, outputs)
        self._inputs = tree.flatten(inputs)
        self._outputs = tree.flatten(outputs)
        if not self._inputs:
            raise ValueError(
                "`inputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )
        if not self._outputs:
            raise ValueError(
                "`outputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )
        if backend() == "tensorflow":
            # Restore the attribute-tracking setting disabled above.
            self._self_setattr_tracking = _self_setattr_tracking
        # Validate the graph topology and cache nodes/operations indexed
        # by depth for ordered replay in `_run_through_graph`.
        (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
            self._inputs, self._outputs
        )
        self._nodes = nodes
        self._nodes_by_depth = nodes_by_depth
        self._operations = operations
        self._operations_by_depth = operations_by_depth
        for input in self._inputs:
            # An input whose producing operation has no outbound nodes was
            # never consumed on any path towards `outputs`.
            if (
                input._keras_history.operation
                and not input._keras_history.operation._outbound_nodes
            ):
                raise ValueError("`inputs` not connected to `outputs`")
    @property
    def operations(self):
        """List of the operations in the graph (copy; safe to mutate)."""
        return self._operations[:]
    @property
    def inputs(self):
        """Flat list of the symbolic inputs of the Function."""
        return self._inputs
    @property
    def outputs(self):
        """Flat list of the symbolic outputs of the Function."""
        return self._outputs
    def compute_output_spec(self, inputs):
        """Computes output `KerasTensor` specs for the given inputs.

        Args:
            inputs: Structure matching `self._inputs_struct` whose leaves
                expose `.shape` and `.dtype`.

        Returns:
            Structure of `KerasTensor`s matching the outputs structure.
        """
        self._assert_input_compatibility(inputs)
        # Check if input shapes are identical to ref input shapes,
        # if so take a shortcut.
        shortcut = True
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if x.shape != x_ref.shape:
                shortcut = False
                break
        if shortcut:
            # Shapes unchanged since construction: reuse the recorded
            # output shapes/dtypes without walking the graph.
            return tree.map_structure(
                lambda x: KerasTensor(shape=x.shape, dtype=x.dtype),
                self._outputs_struct,
            )
        # No luck; take the long road through the graph.
        # Original Keras used a cache to avoid recomputing all this
        # when known input shapes where seen again. Perhaps a good
        # idea to bring that back.
        return self._run_through_graph(
            inputs, operation_fn=lambda op: op.compute_output_spec
        )
    def compute_output_shape(self, input_shape):
        """Computes output shapes for the given input shapes."""
        # Wrap `input_shape` into the structure of KerasTensor to utilize
        # `compute_output_spec`.
        input_shape_struct = tree.map_shape_structure(
            lambda x: KerasTensor(shape=x), input_shape
        )
        # Ensure that dtype and sparse settings are the same as self._inputs,
        # because we only care about the shape in this function.
        for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs):
            x._dtype = x_ref.dtype
            x._sparse = x_ref.sparse
        output_spec = self.compute_output_spec(input_shape_struct)
        return tree.map_structure(lambda x: x.shape, output_spec)
    def call(self, inputs):
        """Computes output tensors for new inputs."""
        self._assert_input_compatibility(inputs)
        return self._run_through_graph(inputs, operation_fn=lambda op: op)
    def _run_through_graph(self, inputs, operation_fn, call_fn=None):
        """Execute the graph.

        At each node we compute outputs via
        `operation_fn(node.operation)(*args, **kwargs)`. If `call_fn` is
        provided, it is invoked as `call_fn(op, *args, **kwargs)` instead
        of calling `op` directly.
        """
        inputs = tree.flatten(inputs)
        # Dictionary mapping reference tensors to computed tensors.
        # Keyed by id() since symbolic tensors are compared by identity.
        tensor_dict = {}
        for x, y in zip(self.inputs, inputs):
            tensor_dict[id(x)] = y
        nodes_by_depth = self._nodes_by_depth
        depth_keys = list(nodes_by_depth.keys())
        # Depth increases towards the inputs (see `map_graph`), so
        # reverse-sorted order replays nodes from inputs to outputs.
        depth_keys.sort(reverse=True)
        for depth in depth_keys:
            nodes = nodes_by_depth[depth]
            for node in nodes:
                if not node.operation or node.is_input:
                    continue  # Input tensors already exist.
                if any(id(x) not in tensor_dict for x in node.input_tensors):
                    continue  # Node is not computable, try skipping.
                args, kwargs = node.arguments.fill_in(tensor_dict)
                op = operation_fn(node.operation)
                if call_fn is not None:
                    outputs = call_fn(op, *args, **kwargs)
                else:
                    outputs = op(*args, **kwargs)
                # Update tensor_dict.
                for x, y in zip(node.outputs, tree.flatten(outputs)):
                    tensor_dict[id(x)] = y
        output_tensors = []
        for x in self.outputs:
            output_tensors.append(tensor_dict[id(x)])
        # Repack the flat outputs into the structure given at construction.
        return tree.pack_sequence_as(self._outputs_struct, output_tensors)
    def _assert_input_compatibility(self, inputs):
        """Checks that `inputs` matches the reference structure and shapes.

        Raises:
            ValueError: If the nested structure, the rank, or any fully
                defined dimension differs from the reference inputs.
        """
        try:
            tree.assert_same_structure(inputs, self._inputs_struct)
        except ValueError:
            raise ValueError(
                "Function was called with an invalid input structure. "
                f"Expected input structure: {self._inputs_struct}\n"
                f"Received input structure: {inputs}"
            )
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            if len(x.shape) != len(x_ref.shape):
                raise ValueError(
                    f"{self.__class__.__name__} was passed "
                    f"incompatible inputs. For input '{x_ref.name}', "
                    f"expected shape {x_ref.shape}, but received "
                    f"instead a tensor with shape {x.shape}."
                )
            for dim, ref_dim in zip(x.shape, x_ref.shape):
                # `None` reference dims are dynamic and match anything.
                if ref_dim is not None and dim is not None:
                    if dim != ref_dim:
                        raise ValueError(
                            f"{self.__class__.__name__} was passed "
                            f"incompatible inputs. For input '{x_ref.name}', "
                            f"expected shape {x_ref.shape}, but received "
                            f"instead a tensor with shape {x.shape}."
                        )
def make_node_key(op, node_index):
    """Return a unique string key for the `node_index`-th inbound node of
    `op`, based on the operation's identity."""
    return f"{id(op)}_ib-{node_index}"
def map_graph(inputs, outputs):
    """Validates a graph's topology and gather its operations and nodes.

    Args:
        inputs: List of input tensors.
        outputs: List of outputs tensors.

    Returns:
        A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.
        - nodes: set of Node instances
        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
        - operations: list of Operation instances.
        - operations_by_depth: dict mapping ints (depth) to lists of Operation
            instances.

    Raises:
        ValueError: If the graph is disconnected (a required tensor cannot
            be computed from `inputs`) or operation names are not unique.
    """
    # "depth" is number of operations between output Node and the Node.
    # Nodes are ordered from inputs -> outputs.
    nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs)
    # String keys identifying every node reachable from `outputs`
    # (see `make_node_key`).
    network_nodes = {
        make_node_key(node.operation, node.operation._inbound_nodes.index(node))
        for node in nodes_in_decreasing_depth
    }
    nodes_depths = {}  # dict {node: depth value}
    operations_depths = {}  # dict {operation: depth value}
    # Iterate from outputs towards inputs so a node's depth is final
    # before its parents are updated.
    for node in reversed(nodes_in_decreasing_depth):
        # If the depth is not set, the node has no outbound nodes (depth 0).
        depth = nodes_depths.setdefault(node, 0)
        # Update the depth of the corresponding operation
        previous_depth = operations_depths.get(node.operation, 0)
        # If we've seen this operation before at a higher depth,
        # we should use that depth instead of the node depth.
        # This is necessary for shared operations that have inputs at different
        # depth levels in the graph.
        depth = max(depth, previous_depth)
        operations_depths[node.operation] = depth
        nodes_depths[node] = depth
        # Update the depth of inbound nodes.
        # The "depth" of a node is the max of the depths
        # of all nodes it is connected to + 1.
        for node_dep in node.parent_nodes:
            previous_depth = nodes_depths.get(node_dep, 0)
            nodes_depths[node_dep] = max(depth + 1, previous_depth)
    # Handle inputs that are not connected to outputs.
    # We do not error out here because the inputs may be used to compute losses
    # and metrics.
    for input_t in inputs:
        input_operation = input_t._keras_history[0]
        if input_operation and input_operation not in operations_depths:
            operations_depths[input_operation] = 0
            # -1 sorts disconnected inputs before all traversed operations.
            operation_indices[input_operation] = -1
            nodes_depths[input_operation._inbound_nodes[0]] = 0
            network_nodes.add(make_node_key(input_operation, 0))
    # Build a dict {depth: list of nodes with this depth}
    nodes_by_depth = collections.defaultdict(list)
    for node, depth in nodes_depths.items():
        nodes_by_depth[depth].append(node)
    # Build a dict {depth: list of operations with this depth}
    operations_by_depth = collections.defaultdict(list)
    for operation, depth in operations_depths.items():
        operations_by_depth[depth].append(operation)
    # Get sorted list of operation depths.
    depth_keys = list(operations_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Set self.operations ordered by depth.
    operations = []
    for depth in depth_keys:
        operations_for_depth = operations_by_depth[depth]
        # Network.operations needs to have a deterministic order:
        # here we order them by traversal order.
        operations_for_depth.sort(key=lambda x: operation_indices[x])
        operations.extend(operations_for_depth)
    # Get sorted list of node depths.
    depth_keys = list(nodes_by_depth.keys())
    depth_keys.sort(reverse=True)
    # Check that all tensors required are computable.
    # computable_tensors: all tensors in the graph
    # that can be computed from the inputs provided.
    computable_tensors = set()
    for x in inputs:
        computable_tensors.add(x)
    operations_with_complete_input = []  # To provide a better error msg.
    # Walk from inputs (highest depth) towards outputs, verifying that
    # every node's inputs were produced by an earlier node.
    for depth in depth_keys:
        for node in nodes_by_depth[depth]:
            for x in tree.flatten(node.input_tensors):
                if x not in computable_tensors:
                    operation = node.operation
                    raise ValueError(
                        "Graph disconnected: cannot find parent for "
                        f"tensor {x} at operation '{operation}'. "
                        "The following previous operations were accessed "
                        f"without issue: {operations_with_complete_input}"
                    )
                operations_with_complete_input.append(node.operation.name)
            for x in tree.flatten(node.outputs):
                computable_tensors.add(x)
    # Ensure name unicity, which will be crucial for serialization
    # (since serialized nodes refer to operations by their name).
    all_names = [operation.name for operation in operations]
    for name in all_names:
        if all_names.count(name) != 1:
            raise ValueError(
                f'The name "{name}" is used {all_names.count(name)} '
                "times in the model. All operation names should be unique."
            )
    return network_nodes, nodes_by_depth, operations, operations_by_depth
def _build_map(inputs, outputs):
    """Topologically sort nodes in order from inputs to outputs.

    It uses a depth-first search to topologically sort nodes that appear in
    the `_keras_history` connectivity metadata of `outputs`.

    Args:
        inputs: input tensors; traversal stops when one of these is reached.
        outputs: the output tensors whose `_keras_history` metadata should
            be walked. This may be an arbitrary nested structure.

    Returns:
        A tuple `(ordered_nodes, operation_to_first_traversal_index)`.
        ordered_nodes: list of nodes appearing in the keras history,
            topologically sorted from original inputs to the `outputs`.
            (If outputs have different sets of ancestors, the inputs to one
            output may appear after a different output).
        operation_to_first_traversal_index: dict mapping operation to the
            traversal index in the DFS where it is seen. Note: if an
            operation is shared by several nodes, the dict will only store
            the index corresponding to the *first* time the operation is
            seen.
    """
    done_nodes = set()
    active_nodes = set()
    ordered_nodes = []  # Accumulated from inputs towards outputs.
    first_seen_index = {}  # operation -> first DFS traversal index.
    for out_tensor in tree.flatten(outputs):
        _build_map_helper(
            inputs,
            out_tensor,
            done_nodes,
            active_nodes,
            ordered_nodes,
            first_seen_index,
        )
    return ordered_nodes, first_seen_index
def _build_map_helper(
    inputs,
    tensor,
    finished_nodes,
    nodes_in_progress,
    nodes_in_decreasing_depth,
    operation_indices,
):
    """Recursive DFS step for `_build_map`.

    Visits the node that produced `tensor`, recurses into its parent
    tensors (unless the node is an input boundary), and appends the node
    to `nodes_in_decreasing_depth` after all its ancestors.
    """
    operation, node_index, _ = tensor._keras_history
    # Tensors with no producing operation have nothing to traverse.
    if not operation:
        return
    node = operation._inbound_nodes[node_index]
    # Don't repeat work for shared subgraphs.
    if node in finished_nodes:
        return
    # A node currently on the DFS stack means the graph has a cycle.
    if node in nodes_in_progress:
        raise ValueError(
            f"Tensor {tensor} from operation '{operation.name}' is part of"
            " a cycle."
        )
    # Record the first traversal index, used for deterministic
    # operation ordering later.
    operation_indices.setdefault(operation, len(operation_indices))
    nodes_in_progress.add(node)
    # Stop descending at input nodes or at tensors the caller supplied
    # as graph inputs. Short-circuit: `tree.flatten` is only evaluated
    # when the node is not an input node.
    reached_boundary = node.is_input or tensor in tree.flatten(inputs)
    if not reached_boundary:
        for parent_tensor in node.input_tensors:
            _build_map_helper(
                inputs,
                parent_tensor,
                finished_nodes,
                nodes_in_progress,
                nodes_in_decreasing_depth,
                operation_indices,
            )
    finished_nodes.add(node)
    nodes_in_progress.remove(node)
    nodes_in_decreasing_depth.append(node)
|