|
|
import collections |
|
|
|
|
|
from keras.src import tree |
|
|
from keras.src.api_export import keras_export |
|
|
from keras.src.backend import KerasTensor |
|
|
from keras.src.backend.config import backend |
|
|
from keras.src.ops.operation import Operation |
|
|
|
|
|
|
|
|
@keras_export("keras.Function")
class Function(Operation):
    """Class that encapsulates a computation graph of Keras operations.

    You can use a `Function` to capture the computation graph linking
    some input tensors to some output tensors, and reapply the same
    computation on new inputs.

    A `Function` is similar to a Functional Model, with the difference
    that it is stateless (it does not track state variables)
    and does not implement the `Layer` API.

    Example:

    ```python
    input_1 = keras.KerasTensor(shape=(None, 2, 3))
    input_2 = keras.KerasTensor(shape=(None, 2, 3))
    x = input_1 + input_2
    output = keras.ops.sigmoid(x)
    fn = keras.Function(inputs=[input_1, input_2], outputs=output)

    input_1_val = np.random.random((4, 2, 3))
    input_2_val = np.random.random((4, 2, 3))
    output_val = fn([input_1_val, input_2_val])
    ```

    Args:
        inputs: `KerasTensor` instance or nested structure of
            `KerasTensor` instances.
        outputs: `KerasTensor` instance or nested structure of
            `KerasTensor` instances. They should be computable
            given only the values of `inputs`.
        name: String. The name of the function.
    """

    def __init__(self, inputs, outputs, name=None):
        super().__init__(name=name)

        if backend() == "tensorflow":
            # NOTE(review): attribute tracking is switched off while the
            # input/output structures are stored and restored afterwards,
            # presumably so the TF backend does not wrap them in trackable
            # containers -- confirm against the TF backend.
            _self_setattr_tracking = getattr(
                self, "_self_setattr_tracking", True
            )
            self._self_setattr_tracking = False
        # Rebuild the (possibly nested) containers so later mutation of the
        # caller's structures does not leak in; leaves are shared. The flat
        # lists define the positional order used by every other method.
        self._inputs_struct = tree.map_structure(lambda x: x, inputs)
        self._outputs_struct = tree.map_structure(lambda x: x, outputs)
        self._inputs = tree.flatten(inputs)
        self._outputs = tree.flatten(outputs)
        if not self._inputs:
            raise ValueError(
                "`inputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )
        if not self._outputs:
            raise ValueError(
                "`outputs` argument cannot be empty. Received:\n"
                f"inputs={inputs}\n"
                f"outputs={outputs}"
            )

        if backend() == "tensorflow":
            self._self_setattr_tracking = _self_setattr_tracking

        (nodes, nodes_by_depth, operations, operations_by_depth) = map_graph(
            self._inputs, self._outputs
        )
        self._nodes = nodes
        self._nodes_by_depth = nodes_by_depth
        self._operations = operations
        self._operations_by_depth = operations_by_depth
        # An input produced by an operation that has no outbound nodes
        # cannot feed any of the outputs.
        for input_tensor in self._inputs:
            input_operation = input_tensor._keras_history.operation
            if input_operation and not input_operation._outbound_nodes:
                raise ValueError("`inputs` not connected to `outputs`")

    @property
    def operations(self):
        """Copy of the list of operations in the graph, ordered by depth."""
        return self._operations[:]

    @property
    def inputs(self):
        """Flat list of the symbolic inputs of the Function."""
        return self._inputs

    @property
    def outputs(self):
        """Flat list of the symbolic outputs of the Function."""
        return self._outputs

    def compute_output_spec(self, inputs):
        """Return symbolic output specs for the given symbolic `inputs`.

        If every flattened input has exactly the shape of its reference
        input, the specs are built directly from the reference outputs
        (shape, dtype and sparsity); input dtypes are not compared on this
        path. Otherwise each operation's `compute_output_spec` is run
        through the graph.

        Raises:
            ValueError: If `inputs` is incompatible with the reference
                inputs (see `_assert_input_compatibility`).
        """
        self._assert_input_compatibility(inputs)
        # Fast path: identical input shapes -> reuse the reference specs.
        shortcut = all(
            x.shape == x_ref.shape
            for x, x_ref in zip(tree.flatten(inputs), self._inputs)
        )
        if shortcut:
            # Carry `sparse` as well so sparse reference outputs are not
            # reported as dense (`compute_output_shape` carries it too).
            return tree.map_structure(
                lambda x: KerasTensor(
                    shape=x.shape, dtype=x.dtype, sparse=x.sparse
                ),
                self._outputs_struct,
            )
        # Slow path: propagate specs operation by operation.
        return self._run_through_graph(
            inputs, operation_fn=lambda op: op.compute_output_spec
        )

    def compute_output_shape(self, input_shape):
        """Return the output shape(s) for the given input shape(s).

        Each shape is wrapped in a `KerasTensor` that takes dtype and
        sparsity from the matching reference input, then
        `compute_output_spec` is used and only the shapes are returned.
        """
        input_shape_struct = tree.map_shape_structure(
            lambda x: KerasTensor(shape=x), input_shape
        )
        # Only shapes matter here, so copy dtype/sparsity from the
        # reference inputs.
        for x, x_ref in zip(tree.flatten(input_shape_struct), self._inputs):
            x._dtype = x_ref.dtype
            x._sparse = x_ref.sparse
        output_spec = self.compute_output_spec(input_shape_struct)
        return tree.map_structure(lambda x: x.shape, output_spec)

    def call(self, inputs):
        """Computes output tensors for new inputs."""
        self._assert_input_compatibility(inputs)
        return self._run_through_graph(inputs, operation_fn=lambda op: op)

    def _run_through_graph(self, inputs, operation_fn, call_fn=None):
        """Execute the graph.

        At each node we compute outputs via
        `operation_fn(node.operation)(*args, **kwargs)`, or via
        `call_fn(operation_fn(node.operation), *args, **kwargs)` when
        `call_fn` is given. Nodes are visited from the deepest depth key
        down; input nodes and nodes whose input tensors are not all
        available yet are skipped.

        Returns:
            The computed values packed into the structure of the outputs.
        """
        inputs = tree.flatten(inputs)

        # Map id(reference tensor) -> computed value.
        tensor_dict = {id(x): y for x, y in zip(self.inputs, inputs)}

        nodes_by_depth = self._nodes_by_depth
        for depth in sorted(nodes_by_depth.keys(), reverse=True):
            for node in nodes_by_depth[depth]:
                if not node.operation or node.is_input:
                    continue  # Input values are already in tensor_dict.

                if any(id(x) not in tensor_dict for x in node.input_tensors):
                    continue  # Not computable from the given inputs.

                args, kwargs = node.arguments.fill_in(tensor_dict)
                op = operation_fn(node.operation)
                if call_fn is not None:
                    outputs = call_fn(op, *args, **kwargs)
                else:
                    outputs = op(*args, **kwargs)

                # Record this node's results for downstream nodes.
                for x, y in zip(node.outputs, tree.flatten(outputs)):
                    tensor_dict[id(x)] = y

        output_tensors = [tensor_dict[id(x)] for x in self.outputs]
        return tree.pack_sequence_as(self._outputs_struct, output_tensors)

    def _assert_input_compatibility(self, inputs):
        """Check that `inputs` is compatible with the reference inputs.

        Raises:
            ValueError: If the nested structure differs from the reference
                structure, or if a flattened input has a different rank than
                its reference, or a known dimension that conflicts with a
                known reference dimension (`None` matches anything).
        """
        try:
            tree.assert_same_structure(inputs, self._inputs_struct)
        except ValueError as e:
            raise ValueError(
                "Function was called with an invalid input structure. "
                f"Expected input structure: {self._inputs_struct}\n"
                f"Received input structure: {inputs}"
            ) from e
        for x, x_ref in zip(tree.flatten(inputs), self._inputs):
            incompatible = len(x.shape) != len(x_ref.shape) or any(
                dim is not None and ref_dim is not None and dim != ref_dim
                for dim, ref_dim in zip(x.shape, x_ref.shape)
            )
            if incompatible:
                raise ValueError(
                    f"{self.__class__.__name__} was passed "
                    f"incompatible inputs. For input '{x_ref.name}', "
                    f"expected shape {x_ref.shape}, but received "
                    f"instead a tensor with shape {x.shape}."
                )
|
|
|
|
|
|
|
|
def make_node_key(op, node_index):
    """Build a string key for inbound node number `node_index` of `op`.

    The key combines `id(op)` with the node index. Because `id` values may
    be reused once an object is garbage-collected, the key is only
    meaningful while `op` is alive.
    """
    return f"{id(op)}_ib-{node_index}"
|
|
|
|
|
|
|
|
def map_graph(inputs, outputs):
    """Validates a graph's topology and gathers its operations and nodes.

    Args:
        inputs: List of input tensors.
        outputs: List of outputs tensors.

    Returns:
        A tuple `(nodes, nodes_by_depth, operations, operations_by_depth)`.
        - nodes: set of node keys (strings produced by `make_node_key`)
            for the nodes that belong to this graph.
        - nodes_by_depth: dict mapping ints (depth) to lists of node instances.
        - operations: list of Operation instances, ordered by decreasing
            depth and, within a depth, by traversal order.
        - operations_by_depth: dict mapping ints (depth) to lists of Operation
            instances.

    Raises:
        ValueError: If a node needs a tensor that cannot be computed from
            `inputs` (disconnected graph), if two operations share a name,
            or if the traversal finds a cycle.
    """
    # Nodes reachable backwards from `outputs`, in topological order.
    nodes_in_decreasing_depth, operation_indices = _build_map(inputs, outputs)
    network_nodes = {
        make_node_key(node.operation, node.operation._inbound_nodes.index(node))
        for node in nodes_in_decreasing_depth
    }

    nodes_depths = {}  # {node: depth}
    operations_depths = {}  # {operation: depth}

    for node in reversed(nodes_in_decreasing_depth):
        # Nodes not yet given a depth by a downstream consumer start at 0.
        depth = nodes_depths.setdefault(node, 0)

        # A shared operation keeps the largest depth of any of its nodes.
        previous_depth = operations_depths.get(node.operation, 0)
        depth = max(depth, previous_depth)
        operations_depths[node.operation] = depth
        nodes_depths[node] = depth

        # Parents must sit strictly deeper than every node they feed.
        for node_dep in node.parent_nodes:
            previous_depth = nodes_depths.get(node_dep, 0)
            nodes_depths[node_dep] = max(depth + 1, previous_depth)

    # Inputs whose operation was never reached from `outputs` are still
    # registered at depth 0.
    for input_t in inputs:
        input_operation = input_t._keras_history[0]
        if input_operation and input_operation not in operations_depths:
            operations_depths[input_operation] = 0
            operation_indices[input_operation] = -1
            nodes_depths[input_operation._inbound_nodes[0]] = 0
            network_nodes.add(make_node_key(input_operation, 0))

    # Group nodes by depth.
    nodes_by_depth = collections.defaultdict(list)
    for node, depth in nodes_depths.items():
        nodes_by_depth[depth].append(node)

    # Group operations by depth.
    operations_by_depth = collections.defaultdict(list)
    for operation, depth in operations_depths.items():
        operations_by_depth[depth].append(operation)

    # Flatten operations, deepest first; within one depth use traversal
    # order so the result is deterministic.
    operations = []
    for depth in sorted(operations_by_depth.keys(), reverse=True):
        operations_for_depth = operations_by_depth[depth]
        operations_for_depth.sort(key=lambda x: operation_indices[x])
        operations.extend(operations_for_depth)

    # Walk nodes deepest-first and verify every input tensor is computable
    # from `inputs` or from an already-visited node's outputs.
    computable_tensors = set(inputs)
    operations_with_complete_input = []  # Only used for the error message.
    for depth in sorted(nodes_by_depth.keys(), reverse=True):
        for node in nodes_by_depth[depth]:
            for x in tree.flatten(node.input_tensors):
                if x not in computable_tensors:
                    operation = node.operation
                    raise ValueError(
                        "Graph disconnected: cannot find parent for "
                        f"tensor {x} at operation '{operation}'. "
                        "The following previous operations were accessed "
                        f"without issue: {operations_with_complete_input}"
                    )
                operations_with_complete_input.append(node.operation.name)

            computable_tensors.update(tree.flatten(node.outputs))

    # Operation names must be unique. Counter keeps first-seen order, so
    # the first duplicated name is reported, with its count, in O(n).
    name_counts = collections.Counter(
        operation.name for operation in operations
    )
    for name, count in name_counts.items():
        if count != 1:
            raise ValueError(
                f'The name "{name}" is used {count} '
                "times in the model. All operation names should be unique."
            )
    return network_nodes, nodes_by_depth, operations, operations_by_depth
|
|
|
|
|
|
|
|
def _build_map(inputs, outputs):
    """Topologically sort the nodes leading from `inputs` to `outputs`.

    Runs a depth-first search over the `_keras_history` connectivity
    metadata of every flattened tensor in `outputs`, delegating each step to
    `_build_map_helper`.

    Args:
        inputs: tensors at which the search stops descending. This may be an
            arbitrary nested structure.
        outputs: the output tensors whose `_keras_history` metadata should be
            walked. This may be an arbitrary nested structure.

    Returns:
        A tuple `(ordered_nodes, operation_to_first_traversal_index)`:
        - ordered_nodes: list of nodes appearing in the Keras history,
          topologically sorted from the original inputs to `outputs`. (If
          outputs have different sets of ancestors, the inputs of one output
          may appear after a different output.)
        - operation_to_first_traversal_index: dict mapping each operation to
          the DFS index at which it was first seen. If an operation is shared
          by several nodes, only the index of its *first* visit is stored.

    Raises:
        ValueError: If the traversal encounters a cycle.
    """
    ordered_nodes = []
    first_traversal_index = {}
    visited = set()
    on_stack = set()
    for output_tensor in tree.flatten(outputs):
        _build_map_helper(
            inputs,
            output_tensor,
            visited,
            on_stack,
            ordered_nodes,
            first_traversal_index,
        )
    return ordered_nodes, first_traversal_index
|
|
|
|
|
|
|
|
def _build_map_helper( |
|
|
inputs, |
|
|
tensor, |
|
|
finished_nodes, |
|
|
nodes_in_progress, |
|
|
nodes_in_decreasing_depth, |
|
|
operation_indices, |
|
|
): |
|
|
"""Recursive helper for `_build_map`.""" |
|
|
( |
|
|
operation, |
|
|
node_index, |
|
|
_, |
|
|
) = tensor._keras_history |
|
|
if not operation: |
|
|
return |
|
|
|
|
|
node = operation._inbound_nodes[node_index] |
|
|
|
|
|
|
|
|
if node in finished_nodes: |
|
|
return |
|
|
|
|
|
|
|
|
if node in nodes_in_progress: |
|
|
raise ValueError( |
|
|
f"Tensor {tensor} from operation '{operation.name}' is part of a " |
|
|
"cycle." |
|
|
) |
|
|
|
|
|
|
|
|
if operation not in operation_indices: |
|
|
operation_indices[operation] = len(operation_indices) |
|
|
|
|
|
|
|
|
nodes_in_progress.add(node) |
|
|
if not node.is_input and tensor not in tree.flatten(inputs): |
|
|
for tensor in node.input_tensors: |
|
|
_build_map_helper( |
|
|
inputs, |
|
|
tensor, |
|
|
finished_nodes, |
|
|
nodes_in_progress, |
|
|
nodes_in_decreasing_depth, |
|
|
operation_indices, |
|
|
) |
|
|
|
|
|
finished_nodes.add(node) |
|
|
nodes_in_progress.remove(node) |
|
|
nodes_in_decreasing_depth.append(node) |
|
|
|