|
|
import collections |
|
|
|
|
|
from keras.src import tree |
|
|
from keras.src.backend import KerasTensor |
|
|
from keras.src.ops.symbolic_arguments import SymbolicArguments |
|
|
|
|
|
|
|
|
class Node:
    """A `Node` describes an operation `__call__()` event.

    A Keras Function is a DAG with `Node` instances as nodes, and
    `KerasTensor` instances as edges. Nodes aren't `Operation` instances,
    because a single operation could be called multiple times, which would
    result in graph cycles.

    A `__call__()` event involves input tensors (and other input arguments),
    the operation that was called, and the resulting output tensors.
    A `Node` will include all this information.

    Since a single `Operation` could be called multiple times,
    the `Node` instances are stored on operations as a list.
    Each time an operation is called, a node is added to `op._inbound_nodes`.
    Each time the output of an operation is used by another operation,
    a node is added to `op._outbound_nodes`.

    Every `KerasTensor` instance has a `KerasHistory` object attached,
    which tracks the `Node` that records the `__call__()` event that created
    the tensor. By recursively walking through `Node` instances
    via the `KerasHistory` metadata of `KerasTensor` instances, one can
    retrieve the entire DAG of a Keras Function.

    Args:
        operation: The Operation that was called in the `op.__call__()`
            event that this node represents.
        call_args: The positional arguments the operation was called with.
            `None` (the default) is treated as no positional arguments.
        call_kwargs: The keyword arguments the operation was called with.
            `None` (the default) is treated as no keyword arguments.
        outputs: The output tensors of the `op.__call__()` call. May be a
            nested structure; it is flattened before being stored.

    Raises:
        ValueError: If any element of the flattened `outputs` is not a
            `KerasTensor`.
    """

    def __init__(
        self, operation, call_args=None, call_kwargs=None, outputs=None
    ):
        # The signature advertises `None` defaults, but `*None` / `**None`
        # raise `TypeError`, so normalize them to empty containers first.
        call_args = () if call_args is None else call_args
        call_kwargs = {} if call_kwargs is None else call_kwargs

        self.operation = operation
        self.arguments = SymbolicArguments(*call_args, **call_kwargs)
        self.outputs = [] if outputs is None else tree.flatten(outputs)
        for x in self.outputs:
            if not isinstance(x, KerasTensor):
                raise ValueError(
                    "All operation outputs must be tensors. "
                    f"Operation {operation} returned a non-tensor. "
                    f"Non-tensor received: {x}"
                )

        # History is only recorded when every input tensor allows it.
        zero_history = any(
            not x.record_history for x in self.arguments.keras_tensors
        )

        # If inputs don't have metadata yet, add it. Such tensors get a
        # history with `operation=None`, which the wiring loop below skips.
        if not zero_history:
            for tensor in self.arguments.keras_tensors:
                if not hasattr(tensor, "_keras_history"):
                    tensor._keras_history = KerasHistory(
                        operation=None, node_index=0, tensor_index=0
                    )

        # Wire up this node to the operations on both sides of the call.
        # NOTE(review): when `zero_history` is true the block above is
        # skipped, so this loop assumes every input tensor already carries
        # `_keras_history` -- confirm against callers.
        self.operation._inbound_nodes.append(self)
        for kt in self.arguments.keras_tensors:
            inbound_op = kt._keras_history.operation
            if inbound_op is not None:
                inbound_op._outbound_nodes.append(self)

        # Set metadata on outputs. This node was just appended to
        # `_inbound_nodes`, so its index is the last position in that list.
        if not zero_history:
            node_index = len(self.operation._inbound_nodes) - 1
            for i, tensor in enumerate(self.outputs):
                tensor._keras_history = KerasHistory(
                    operation=operation, node_index=node_index, tensor_index=i
                )

        # A node with no `KerasTensor` arguments is flagged as an input node.
        self.is_input = not self.arguments.keras_tensors

    def __repr__(self):
        return f"<Node operation={self.operation.name}, id={id(self)}>"

    @property
    def input_tensors(self):
        """The `KerasTensor`s found among the call arguments."""
        return self.arguments.keras_tensors

    @property
    def output_tensors(self):
        """The flattened list of output tensors of the call."""
        return self.outputs

    @property
    def parent_nodes(self):
        """The parent `Node`s.

        Returns:
            all the `Node`s whose output this node immediately depends on.
        """
        node_deps = []
        for kt in self.arguments.keras_tensors:
            history = kt._keras_history
            op = history.operation
            # Tensors whose history has no operation have no parent node.
            if op is not None:
                node_deps.append(op._inbound_nodes[history.node_index])
        return node_deps
|
|
|
|
|
|
|
|
class KerasHistory(
    collections.namedtuple(
        "KerasHistory", ("operation", "node_index", "tensor_index")
    )
):
    """Records which Operation call produced a given Tensor.

    While a Keras Function is being built, every Tensor returned by an
    Operation is tagged with one of these records. The `Function` class
    later follows the records back from tensor to tensor in order to
    rebuild the graph of Operations.

    Attributes:
        operation: The Operation instance that produced the Tensor.
        node_index: Which call of the Operation produced the Tensor.
            An Operation may be called several times (for example to share
            weights), and each call creates a new node. The node for the
            call that produced the Tensor is
            `op._inbound_nodes[node_index]`.
        tensor_index: Position of the Tensor among the outputs of that
            call. Always zero when the Operation has a single output.
            For nested output structures the index is the Tensor's
            position in the flattened outputs (`tree.flatten` order).
    """

    # Empty `__slots__` keeps instances free of a per-instance `__dict__`,
    # so the subclass stays as lightweight as the underlying namedtuple.
    __slots__ = ()
|
|
|
|
|
|
|
|
def is_keras_tensor(obj):
    """Report whether `obj` carries Keras history metadata.

    Args:
        obj: Any Python object.

    Returns:
        `True` if `obj` has a `_keras_history` attribute, `False` otherwise.
    """
    has_history = hasattr(obj, "_keras_history")
    return has_history
|
|
|