# NOTE: non-source artifact lines from a file-upload page were converted to
# this comment so the module parses. Original residue:
#   joebruce1313's picture / Upload 38004 files / 1f5470c verified
import copy
import inspect
import typing
import warnings
from keras.src import backend
from keras.src import ops
from keras.src import tree
from keras.src.backend.common import global_state
from keras.src.layers.core.input_layer import Input
from keras.src.layers.core.input_layer import InputLayer
from keras.src.layers.input_spec import InputSpec
from keras.src.layers.layer import Layer
from keras.src.legacy.saving import saving_utils
from keras.src.legacy.saving import serialization as legacy_serialization
from keras.src.models.model import Model
from keras.src.ops.function import Function
from keras.src.ops.function import _build_map
from keras.src.ops.function import make_node_key
from keras.src.ops.node import KerasHistory
from keras.src.ops.node import Node
from keras.src.ops.operation import Operation
from keras.src.saving import serialization_lib
from keras.src.utils import tracking
class Functional(Function, Model):
    """A `Functional` model is a `Model` defined as a directed graph of layers.

    Three types of `Model` exist: subclassed `Model`, `Functional` model,
    and `Sequential` (a special case of `Functional`).

    A `Functional` model can be instantiated by passing two arguments to
    `__init__()`. The first argument is the `keras.Input` objects
    that represent the inputs to the model.
    The second argument specifies the output tensors that represent
    the outputs of this model. Both arguments can be a nested structure
    of tensors.

    Example:

    ```
    inputs = {'x1': keras.Input(shape=(10,), name='x1'),
              'x2': keras.Input(shape=(1,), name='x2')}
    t = keras.layers.Dense(1, activation='relu')(inputs['x1'])
    outputs = keras.layers.Add()([t, inputs['x2']])
    model = keras.Model(inputs, outputs)
    ```

    A `Functional` model constructed using the Functional API can also
    include raw Keras 3 ops.

    Example:

    ```python
    inputs = keras.Input(shape=(10,))
    x = keras.layers.Dense(1)(inputs)
    outputs = ops.nn.relu(x)
    model = keras.Model(inputs, outputs)
    ```

    A new `Functional` model can also be created by using the
    intermediate tensors. This enables you to quickly extract sub-components
    of the model.

    Example:

    ```python
    inputs = keras.Input(shape=(None, None, 3))
    processed = keras.layers.RandomCrop(width=32, height=32)(inputs)
    conv = keras.layers.Conv2D(filters=2, kernel_size=3)(processed)
    pooling = keras.layers.GlobalAveragePooling2D()(conv)
    feature = keras.layers.Dense(10)(pooling)

    full_model = keras.Model(inputs, feature)
    backbone = keras.Model(processed, conv)
    activations = keras.Model(conv, feature)
    ```

    Note that the `backbone` and `activations` models are not
    created with `keras.Input` objects, but with the tensors
    that are originated from `keras.Input` objects.
    Under the hood, the layers and weights will
    be shared across these models, so that user can train the `full_model`, and
    use `backbone` or `activations` to do feature extraction.
    The inputs and outputs of the model can be nested structures of tensors as
    well, and the created models are standard `Functional` model that support
    all the existing API.

    Args:
        inputs: List of input tensors (must be created via `keras.Input()`
            or originated from `keras.Input()`).
        outputs: List of output tensors.
        name: String, optional. Name of the model.
        trainable: Boolean, optional. If the model's variables should be
            trainable.
    """

    def __new__(cls, *args, **kwargs):
        # `typing.cast` so static type checkers see the concrete subclass
        # being instantiated rather than the base `object` return type.
        return typing.cast(cls, super().__new__(cls))

    @tracking.no_automatic_dependency_tracking
    def __init__(self, inputs, outputs, name=None, **kwargs):
        """Build the model graph connecting `inputs` to `outputs`.

        Args:
            inputs: Nested structure of `KerasTensor` inputs.
            outputs: Nested structure of `KerasTensor` outputs.
            name: Optional model name.
            **kwargs: May include `trainable`; forwarded nothing else here.

        Raises:
            ValueError: If any input/output value is not a `KerasTensor`.
        """
        if isinstance(inputs, dict):
            # Dict inputs are matched by key at call time, so each key must
            # agree with its tensor's name; warn (not raise) on mismatch.
            for k, v in inputs.items():
                if isinstance(v, backend.KerasTensor) and k != v.name:
                    warnings.warn(
                        "When providing `inputs` as a dict, all keys in the "
                        "dict must match the names of the corresponding "
                        f"tensors. Received key '{k}' mapping to value {v} "
                        f"which has name '{v.name}'. Change the tensor name to "
                        f"'{k}' (via `Input(..., name='{k}')`)"
                    )
        trainable = kwargs.pop("trainable", None)
        flat_inputs = tree.flatten(inputs)
        flat_outputs = tree.flatten(outputs)
        for x in flat_inputs:
            if not isinstance(x, backend.KerasTensor):
                raise ValueError(
                    "All `inputs` values must be KerasTensors. Received: "
                    f"inputs={inputs} including invalid value {x} of "
                    f"type {type(x)}"
                )
        for x in flat_outputs:
            if not isinstance(x, backend.KerasTensor):
                raise ValueError(
                    "All `outputs` values must be KerasTensors. Received: "
                    f"outputs={outputs} including invalid value {x} of "
                    f"type {type(x)}"
                )
        if not all(is_input_keras_tensor(t) for t in flat_inputs):
            # Some inputs are intermediate tensors (not from `keras.Input`):
            # clone the subgraph so this model gets its own Input nodes.
            inputs, outputs = clone_graph_nodes(inputs, outputs)
        Function.__init__(self, inputs, outputs, name=name)
        if trainable is not None:
            self.trainable = trainable
        self._layers = self.layers
        self.build(None)
        # We will convert directly (to the correct dtype per input).
        self._convert_input_args = False
        self._allow_non_tensor_positional_args = True
        output_layers = [x._keras_history[0] for x in self.outputs]
        self.output_names = [x.name for x in output_layers]

    def _lock_state(self):
        # Unlike other layers, we allow Functional state to be mutable after
        # build. E.g. to attach a layer to a model that is not part of the
        # functional DAG.
        pass

    def _obj_type(self):
        # Used for serialization/debugging to identify the object kind.
        return "Functional"

    @property
    def layers(self):
        """All `Layer` operations in the graph (raw ops are filtered out)."""
        layers = []
        for operation in self._operations:
            if isinstance(operation, Layer):
                layers.append(operation)
        return layers

    @layers.setter
    def layers(self, _):
        # The layer list is derived from the graph; it cannot be assigned.
        raise AttributeError(
            "`Model.layers` attribute is reserved and should not be used. "
            "Please use another name."
        )

    def call(self, inputs, training=None, mask=None, **kwargs):
        """Run the functional graph on `inputs`.

        Args:
            inputs: Input tensors matching the model's input structure.
            training: Optional training-mode flag, forwarded to each op
                that declares support for it.
            mask: Optional mask (or flat structure of masks), attached to
                the corresponding input tensors.
            **kwargs: Extra call-context args forwarded to supporting ops.
        """
        # Add support for training, masking
        inputs = self._standardize_inputs(inputs)
        if mask is None:
            masks = [None] * len(inputs)
        else:
            masks = tree.flatten(mask)
        # NOTE(review): the loop variable `mask` shadows the `mask`
        # argument; harmless here since the argument is not used afterwards.
        for x, mask in zip(inputs, masks):
            if mask is not None:
                backend.set_keras_mask(x, mask)
        outputs = self._run_through_graph(
            inputs,
            operation_fn=lambda op: operation_fn(
                op, training=training, **kwargs
            ),
        )
        return unpack_singleton(outputs)

    def compute_output_spec(self, inputs, training=None, mask=None):
        # From Function. `training` and `mask` do not affect output specs,
        # so they are intentionally dropped.
        return super().compute_output_spec(inputs)

    def compute_output_shape(self, input_shape):
        # From Function
        return super().compute_output_shape(input_shape)

    def build(self, input_shape):
        # The graph is fully built at construction; nothing shape-dependent
        # remains to do here.
        self.built = True

    @property
    def input_shape(self):
        """Shape(s) of the model inputs; unwrapped if there is only one."""
        input_shapes = tree.map_structure(lambda x: x.shape, self.inputs)
        if isinstance(input_shapes, list) and len(input_shapes) == 1:
            return input_shapes[0]
        return input_shapes

    @property
    def output_shape(self):
        """Shape(s) of the model outputs; unwrapped if there is only one."""
        output_shapes = tree.map_structure(lambda x: x.shape, self.outputs)
        if isinstance(output_shapes, list) and len(output_shapes) == 1:
            return output_shapes[0]
        return output_shapes

    def _assert_input_compatibility(self, *args):
        # Skip `Model`'s override and use the base `Layer` implementation
        # (hence `super(Model, self)` rather than `super()`).
        return super(Model, self)._assert_input_compatibility(*args)

    def _maybe_warn_inputs_struct_mismatch(self, inputs, raise_exception=False):
        """Warn (or raise) when `inputs` doesn't match the expected structure.

        Args:
            inputs: The structure the caller provided.
            raise_exception: If True, raise `ValueError` instead of warning.
        """
        try:
            # We first normalize to tuples before performing the check to
            # suppress warnings when encountering mismatched tuples and lists.
            tree.assert_same_structure(
                tree.lists_to_tuples(inputs),
                tree.lists_to_tuples(self._inputs_struct),
            )
        # NOTE(review): bare `except` — any structure-check failure falls
        # through to the warning/error path below.
        except:
            model_inputs_struct = tree.map_structure(
                lambda x: x.name, self._inputs_struct
            )
            inputs_struct = tree.map_structure(
                lambda x: f"Tensor(shape={x.shape})", inputs
            )
            msg = (
                "The structure of `inputs` doesn't match the expected "
                f"structure.\nExpected: {model_inputs_struct}\n"
                f"Received: inputs={inputs_struct}"
            )
            if raise_exception:
                raise ValueError(msg)
            warnings.warn(msg)

    def _convert_inputs_to_tensors(self, flat_inputs):
        """Convert each flat input to a tensor of its declared input dtype."""
        converted = []
        for x, input in zip(flat_inputs, self._inputs):
            if x is None:  # TODO: check if optional
                converted.append(x)
            else:
                converted.append(
                    ops.convert_to_tensor(
                        x, dtype=input.dtype, sparse=input.sparse
                    )
                )
        return converted

    def _adjust_input_rank(self, flat_inputs):
        """Squeeze/expand a trailing size-1 axis to match declared ranks.

        Accepts inputs whose rank differs from the declared input rank by
        exactly one, provided the extra/missing axis has size 1; anything
        else raises `ValueError`.
        """
        flat_ref_shapes = [x.shape for x in self._inputs]
        adjusted = []
        for x, ref_shape in zip(flat_inputs, flat_ref_shapes):
            if x is None:
                adjusted.append(x)
                continue
            x_rank = len(x.shape)
            ref_rank = len(ref_shape)
            if x_rank == ref_rank:
                adjusted.append(x)
                continue
            if x_rank == ref_rank + 1:
                if x.shape[-1] == 1:
                    adjusted.append(ops.squeeze(x, axis=-1))
                    continue
            if x_rank == ref_rank - 1:
                if ref_shape[-1] == 1:
                    adjusted.append(ops.expand_dims(x, axis=-1))
                    continue
            raise ValueError(
                f"Invalid input shape for input {x}. Expected shape "
                f"{ref_shape}, but input has incompatible shape {x.shape}"
            )
        # Add back metadata.
        for i in range(len(flat_inputs)):
            if hasattr(flat_inputs[i], "_keras_history"):
                adjusted[i]._keras_history = flat_inputs[i]._keras_history
            mask = backend.get_keras_mask(flat_inputs[i])
            if mask is not None:
                backend.set_keras_mask(adjusted[i], mask)
        return adjusted

    def _standardize_inputs(self, inputs):
        """Normalize `inputs` into the flat tensor list the graph expects.

        Handles single-tensor-for-single-input models, dict inputs matched
        by input name against a list/single-tensor input struct, then
        converts dtypes and reconciles ranks.
        """
        raise_exception = False
        if (
            isinstance(self._inputs_struct, list)
            and len(self._inputs_struct) == 1
            and ops.is_tensor(inputs)
        ):
            # Single-input model called with a bare tensor: wrap it.
            inputs = [inputs]
        elif isinstance(inputs, dict) and not isinstance(
            self._inputs_struct, dict
        ):
            # This is to avoid warning
            # when we have reconciable dict/list structs
            if hasattr(self._inputs_struct, "__len__") and all(
                isinstance(i, backend.KerasTensor) for i in self._inputs_struct
            ):
                expected_keys = set(i.name for i in self._inputs_struct)
                keys = set(inputs.keys())
                if expected_keys.issubset(keys):
                    inputs = [inputs[i.name] for i in self._inputs_struct]
                else:
                    raise_exception = True
            elif isinstance(self._inputs_struct, backend.KerasTensor):
                if self._inputs_struct.name in inputs:
                    inputs = [inputs[self._inputs_struct.name]]
                else:
                    raise_exception = True
            else:
                raise_exception = True
        if (
            isinstance(self._inputs_struct, dict)
            and not isinstance(inputs, dict)
            and list(self._inputs_struct.keys())
            != sorted(self._inputs_struct.keys())
        ):
            # Dict inputs with non-sorted keys cannot be unambiguously
            # matched against a flat (non-dict) structure.
            raise_exception = True
        self._maybe_warn_inputs_struct_mismatch(
            inputs, raise_exception=raise_exception
        )
        flat_inputs = tree.flatten(inputs)
        flat_inputs = self._convert_inputs_to_tensors(flat_inputs)
        return self._adjust_input_rank(flat_inputs)

    @property
    def input(self):
        # For backwards compatibility,
        # override `input` to retrieve the used-provided
        # constructor inputs
        return self._inputs_struct

    @property
    def output(self):
        # Mirror of `input`: the outputs structure as passed by the user.
        return self._outputs_struct

    def add_loss(self, loss):
        # Symbolic only. TODO
        raise NotImplementedError

    @property
    def input_spec(self):
        """`InputSpec`s derived from the model inputs (or a manual override).

        Returns None for deeply nested dict inputs, where per-tensor checks
        are skipped.
        """
        if hasattr(self, "_manual_input_spec"):
            return self._manual_input_spec

        def shape_with_no_batch_size(x):
            # Drop the batch dimension: specs should match any batch size.
            x = list(x)
            if x:
                x[0] = None
            return tuple(x)

        def make_spec_for_tensor(x, name=None):
            optional = False
            if isinstance(x._keras_history[0], InputLayer):
                if x._keras_history[0].optional:
                    optional = True
            return InputSpec(
                shape=shape_with_no_batch_size(x.shape),
                allow_last_axis_squeeze=True,
                name=x._keras_history[0].name if name is None else name,
                optional=optional,
            )

        if isinstance(self._inputs_struct, dict):
            if all(
                isinstance(x, backend.KerasTensor)
                for x in self._inputs_struct.values()
            ):
                # Case where `_nested_inputs` is a plain dict of Inputs.
                names = sorted(self._inputs_struct.keys())
                return [
                    make_spec_for_tensor(self._inputs_struct[name], name=name)
                    for name in names
                ]
            return None  # Deeply nested dict: skip checks.
        return [make_spec_for_tensor(x) for x in self.inputs]

    @input_spec.setter
    def input_spec(self, value):
        self._manual_input_spec = value

    def get_config(self):
        """Serialize the model: layer configs plus graph connectivity.

        Returns a deep-copied config dict with `layers`, `input_layers`
        and `output_layers` entries (node indices remapped to only count
        nodes belonging to this model's graph).
        """
        if not functional_like_constructor(self.__class__):
            # Subclassed networks are not serializable
            # (unless serialization is implemented by
            # the author of the subclassed network).
            return Model.get_config(self)
        config = {
            "name": self.name,
            "trainable": self.trainable,
        }
        # Build a map from a layer unique name (make_node_key)
        # to the index of the nodes that are saved in the config.
        # Only nodes in network_nodes are saved.
        node_reindexing_map = {}
        for operation in self.operations:
            if issubclass(operation.__class__, Functional):
                # Functional models start with a pre-existing node
                # linking their input to output.
                kept_nodes = 1
            else:
                kept_nodes = 0
            for original_node_index, node in enumerate(
                operation._inbound_nodes
            ):
                node_key = make_node_key(operation, original_node_index)
                if node_key in self._nodes:
                    # i.e. we mark it to be saved
                    node_reindexing_map[node_key] = kept_nodes
                    kept_nodes += 1
        # serialize and save the layers in layer_configs
        layer_configs = []
        for operation in self.operations:  # From the earliest layers on.
            filtered_inbound_nodes = []
            for original_node_index, node in enumerate(
                operation._inbound_nodes
            ):
                node_key = make_node_key(operation, original_node_index)
                if node_key in self._nodes:
                    # The node is relevant to the model:
                    # add to filtered_inbound_nodes.
                    node_data = serialize_node(node, own_nodes=self._nodes)
                    if node_data is not None:
                        filtered_inbound_nodes.append(node_data)
            serialize_obj_fn = serialization_lib.serialize_keras_object
            if global_state.get_global_attribute("use_legacy_config", False):
                # Legacy format serialization used for H5 and SavedModel
                serialize_obj_fn = legacy_serialization.serialize_keras_object
            layer_config = serialize_obj_fn(operation)
            layer_config["name"] = operation.name
            layer_config["inbound_nodes"] = filtered_inbound_nodes
            layer_configs.append(layer_config)
        config["layers"] = layer_configs

        # Gather info about inputs and outputs.
        def get_tensor_config(tensor):
            operation = tensor._keras_history[0]
            node_index = tensor._keras_history[1]
            tensor_index = tensor._keras_history[2]
            node_key = make_node_key(operation, node_index)
            assert node_key in self._nodes
            new_node_index = node_reindexing_map[node_key]
            return [operation.name, new_node_index, tensor_index]

        def map_tensors(tensors):
            if isinstance(tensors, backend.KerasTensor):
                return [get_tensor_config(tensors)]
            return tree.map_structure(get_tensor_config, tensors)

        config["input_layers"] = map_tensors(self._inputs_struct)
        config["output_layers"] = map_tensors(self._outputs_struct)
        return copy.deepcopy(config)
def functional_from_config(cls, config, custom_objects=None):
    """Instantiates a Functional model from its config (from `get_config()`).

    Args:
        cls: Class of the model, e.g. a custom subclass of `Model`.
        config: Output of `get_config()` for the original model instance.
        custom_objects: Optional dict of custom objects.

    Returns:
        An instance of `cls`.
    """
    # Layer instances created during
    # the graph reconstruction process
    created_layers = {}

    # Dictionary mapping layer instances to
    # node data that specifies a layer call.
    # It acts as a queue that maintains any unprocessed
    # layer call until it becomes possible to process it
    # (i.e. until the input tensors to the call all exist).
    unprocessed_nodes = {}

    def add_unprocessed_node(layer, node_data):
        """Add node to layer list

        Arg:
            layer: layer object
            node_data: Node data specifying layer call
        """
        if layer not in unprocessed_nodes:
            unprocessed_nodes[layer] = [node_data]
        else:
            unprocessed_nodes[layer].append(node_data)

    def process_node(layer, node_data):
        """Reconstruct node by linking to inbound layers

        Args:
            layer: Layer to process
            node_data: List of layer configs
        """
        args, kwargs = deserialize_node(node_data, created_layers)
        # Call layer on its inputs, thus creating the node
        # and building the layer if needed.
        layer(*args, **kwargs)

    def process_layer(layer_data):
        """Deserializes a layer and index its inbound nodes.

        Args:
            layer_data: layer config dict.
        """
        layer_name = layer_data["name"]

        # Instantiate layer.
        if "module" not in layer_data:
            # Legacy format deserialization (no "module" key)
            # used for H5 and SavedModel formats
            layer = saving_utils.model_from_config(
                layer_data, custom_objects=custom_objects
            )
        else:
            layer = serialization_lib.deserialize_keras_object(
                layer_data, custom_objects=custom_objects
            )
        if not isinstance(layer, Operation):
            raise ValueError(
                "Unexpected object from deserialization, expected a layer or "
                f"operation, got a {type(layer)}"
            )
        created_layers[layer_name] = layer

        # Gather layer inputs.
        inbound_nodes_data = layer_data["inbound_nodes"]
        for node_data in inbound_nodes_data:
            # We don't process nodes (i.e. make layer calls)
            # on the fly because the inbound node may not yet exist,
            # in case of layer shared at different topological depths
            # (e.g. a model such as A(B(A(B(x)))))
            add_unprocessed_node(layer, node_data)

    # Extract config used to instantiate Functional model from the config. The
    # remaining config will be passed as keyword arguments to the Model
    # constructor.
    functional_config = {}
    for key in ["layers", "input_layers", "output_layers"]:
        functional_config[key] = config.pop(key)
    for key in ["name", "trainable"]:
        if key in config:
            functional_config[key] = config.pop(key)
        else:
            functional_config[key] = None

    # First, we create all layers and enqueue nodes to be processed
    for layer_data in functional_config["layers"]:
        process_layer(layer_data)

    # Then we process nodes in order of layer depth.
    # Nodes that cannot yet be processed (if the inbound node
    # does not yet exist) are re-enqueued, and the process
    # is repeated until all nodes are processed.
    while unprocessed_nodes:
        for layer_data in functional_config["layers"]:
            layer = created_layers[layer_data["name"]]
            # Process all nodes in layer, if not yet processed
            if layer in unprocessed_nodes:
                node_data_list = unprocessed_nodes[layer]
                # Process nodes in order
                node_index = 0
                while node_index < len(node_data_list):
                    node_data = node_data_list[node_index]
                    try:
                        process_node(layer, node_data)
                    # If the node does not have all inbound layers
                    # available, stop processing and continue later
                    except IndexError:
                        break
                    node_index += 1
                # If not all nodes processed then store unprocessed nodes
                if node_index < len(node_data_list):
                    unprocessed_nodes[layer] = node_data_list[node_index:]
                # If all nodes processed remove the layer
                else:
                    del unprocessed_nodes[layer]

    # Create list of input and output tensors and return new class
    name = functional_config["name"]
    trainable = functional_config["trainable"]

    def get_tensor(layer_name, node_index, tensor_index):
        # Resolve a serialized [layer, node, tensor] triple back to the
        # live KerasTensor produced by the reconstructed graph.
        assert layer_name in created_layers
        layer = created_layers[layer_name]
        if isinstance(layer, Functional):
            # Functional models start out with a built-in node.
            node_index -= 1
        layer_output_tensors = layer._inbound_nodes[node_index].output_tensors
        return layer_output_tensors[tensor_index]

    def map_tensors(tensors):
        # Recursively resolve nested structures of serialized tensor refs.
        if (
            isinstance(tensors, list)
            and len(tensors) == 3
            and isinstance(tensors[0], str)
        ):
            # Leaf
            return get_tensor(*tensors)
        if isinstance(tensors, dict):
            return {k: map_tensors(v) for k, v in tensors.items()}
        if isinstance(tensors, tuple):
            return tuple([map_tensors(v) for v in tensors])
        return [map_tensors(v) for v in tensors]

    input_tensors = map_tensors(functional_config["input_layers"])
    output_tensors = map_tensors(functional_config["output_layers"])
    if isinstance(output_tensors, list) and len(output_tensors) == 1:
        output_tensors = output_tensors[0]

    return cls(
        inputs=input_tensors,
        outputs=output_tensors,
        name=name,
        trainable=trainable,
        **config,
    )
def operation_fn(operation, **call_context_args):
    """Wrap `operation` so supported call-context args are injected.

    Args:
        operation: The callable op/layer to wrap.
        **call_context_args: Candidate context args (e.g. `training`).

    Returns:
        A callable forwarding `*args, **kwargs` to `operation`, adding each
        context arg that the op declares in `_call_context_args` and whose
        value is not None.
    """

    def call(*args, **kwargs):
        supported = getattr(operation, "_call_context_args", {})
        for arg_name, arg_value in call_context_args.items():
            # Only forward args the op opted into, and skip unset (None)
            # values so the op's own default applies.
            if arg_name in supported and arg_value is not None:
                kwargs[arg_name] = arg_value
        return operation(*args, **kwargs)

    return call
def functional_like_constructor(cls):
    """Whether `cls.__init__` has the same signature as `Functional.__init__`.

    Used to decide if a model subclass can be revived by calling its
    constructor with `(inputs, outputs, ...)` like a `Functional` model.

    Args:
        cls: The class whose constructor signature to inspect.

    Returns:
        bool: True if the positional argument names (excluding `self`)
        match `Functional.__init__` exactly.
    """
    init_args = inspect.getfullargspec(cls.__init__).args[1:]
    functional_init_args = inspect.getfullargspec(Functional.__init__).args[1:]
    # Return the comparison directly instead of the redundant
    # `if ...: return True / return False` form.
    return init_args == functional_init_args
def unpack_singleton(x):
    """Return the sole element of a one-item list/tuple, else `x` unchanged."""
    is_sequence = isinstance(x, (list, tuple))
    if is_sequence and len(x) == 1:
        return x[0]
    return x
def serialize_node(node, own_nodes=()):
    """Serialize a `Node`'s call arguments for `get_config()`.

    Args:
        node: The `Node` whose call args/kwargs to serialize.
        own_nodes: The set of node keys belonging to the model; node
            indices in serialized tensors are remapped to only count
            these nodes.

    Returns:
        A dict with serialized "args" and "kwargs", or None if the node
        has no input tensors (nothing to serialize).
    """
    if not node.input_tensors:
        # Does not need to be serialized.
        return

    def serialize_keras_tensor(x):
        # Serialize KerasTensor while converting
        # node indices to only include nodes relevant to `own_nodes`.
        if isinstance(x, backend.KerasTensor):
            operation, node_index, tensor_index = x._keras_history
            irrelevant_node_count = 0
            # NOTE(review): the loop variable `node` shadows the outer
            # `node` parameter; it is unused afterwards, so this is benign.
            for i, node in enumerate(operation._inbound_nodes[:node_index]):
                node_key = make_node_key(operation, i)
                if node_key not in own_nodes:
                    irrelevant_node_count += 1
            # Temporarily rewrite the history with the remapped index,
            # serialize, then restore the original history in place.
            x._keras_history = KerasHistory(
                operation, node_index - irrelevant_node_count, tensor_index
            )
            serialized = serialization_lib.serialize_keras_object(x)
            x._keras_history = KerasHistory(operation, node_index, tensor_index)
            return serialized
        return x

    args = node.arguments.args
    kwargs = node.arguments.kwargs
    args = tree.map_structure(serialize_keras_tensor, args)
    kwargs = tree.map_structure(serialize_keras_tensor, kwargs)
    return {
        "args": serialization_lib.serialize_keras_object(args),
        "kwargs": serialization_lib.serialize_keras_object(kwargs),
    }
def deserialize_node(node_data, created_layers):
    """Return (args, kwargs) for calling the node layer.

    Supports two formats: the legacy list format
    `[[layer_name, node_index, tensor_index, (kwargs)], ...]` and the
    current dict format with serialized "args"/"kwargs".

    Args:
        node_data: Serialized node data (list or dict), possibly empty.
        created_layers: Dict mapping layer name to the revived layer.

    Raises:
        IndexError: If a referenced inbound node has not been created yet
            (callers use this to defer processing).
        ValueError: On malformed legacy data or unknown layer names.
    """
    if not node_data:
        return [], {}

    if isinstance(node_data, list):
        # Legacy case.
        input_tensors = []
        for input_data in node_data:
            inbound_layer_name = input_data[0]
            inbound_node_index = input_data[1]
            inbound_tensor_index = input_data[2]
            if len(input_data) == 3:
                kwargs = {}
            elif len(input_data) == 4:
                kwargs = input_data[3]
            else:
                raise ValueError(
                    "Cannot deserialize the model (invalid config data?)"
                )
            inbound_layer = created_layers[inbound_layer_name]

            # Raise an error if the corresponding layer node
            # has not yet been created
            if len(inbound_layer._inbound_nodes) <= inbound_node_index:
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {inbound_layer}\n"
                    "inbound_layer._inbound_nodes = "
                    f"{inbound_layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
            input_tensors.append(
                inbound_node.output_tensors[inbound_tensor_index]
            )
        return [unpack_singleton(input_tensors)], kwargs

    args = serialization_lib.deserialize_keras_object(node_data["args"])
    kwargs = serialization_lib.deserialize_keras_object(node_data["kwargs"])

    def convert_revived_tensor(x):
        # Replace a revived placeholder KerasTensor with the live tensor
        # produced by the reconstructed graph, using its saved history.
        if isinstance(x, backend.KerasTensor):
            history = x._pre_serialization_keras_history
            if history is None:
                return x
            layer = created_layers.get(history[0], None)
            if layer is None:
                raise ValueError(f"Unknown layer: {history[0]}")
            inbound_node_index = history[1]
            inbound_tensor_index = history[2]
            if len(layer._inbound_nodes) <= inbound_node_index:
                raise IndexError(
                    "Layer node index out of bounds.\n"
                    f"inbound_layer = {layer}\n"
                    f"inbound_layer._inbound_nodes = {layer._inbound_nodes}\n"
                    f"inbound_node_index = {inbound_node_index}"
                )
            inbound_node = layer._inbound_nodes[inbound_node_index]
            return inbound_node.output_tensors[inbound_tensor_index]
        return x

    args = tree.map_structure(convert_revived_tensor, args)
    kwargs = tree.map_structure(convert_revived_tensor, kwargs)
    return args, kwargs
def is_input_keras_tensor(x):
    """Whether `x` was produced directly by an input node."""
    history = x._keras_history
    producing_op, node_index = history[0], history[1]
    producing_node = producing_op._inbound_nodes[node_index]
    return producing_node.is_input
def clone_single_keras_tensor(x):
    """Create a standalone copy of `x` with a "_clone"-suffixed name."""
    cloned_name = x.name + "_clone"
    return backend.KerasTensor(
        shape=x.shape, dtype=x.dtype, sparse=x.sparse, name=cloned_name
    )
def clone_keras_tensors(tensors, kt_id_mapping):
    """Clone every KerasTensor in `tensors`, reusing prior clones.

    Args:
        tensors: Nested structure possibly containing KerasTensors.
        kt_id_mapping: Dict from `id(original)` to its clone; updated
            in place as new clones are created.

    Returns:
        The same structure with each KerasTensor replaced by its clone.
    """

    def _clone_or_reuse(t):
        if not isinstance(t, backend.KerasTensor):
            # Non-tensor leaves pass through untouched.
            return t
        key = id(t)
        if key not in kt_id_mapping:
            kt_id_mapping[key] = clone_single_keras_tensor(t)
        return kt_id_mapping[key]

    return tree.map_structure(_clone_or_reuse, tensors)
def find_nodes_by_inputs_and_outputs(inputs, outputs):
    """Return the graph `Node`s connecting `inputs` to `outputs`."""
    nodes_and_tensor_map = _build_map(inputs, outputs)
    return nodes_and_tensor_map[0]
def clone_graph_nodes(inputs, outputs):
    """Clone the `Node` between the inputs and output tensors.

    This function is used to create a new functional model from any
    intermediate Keras tensors. The clone of the nodes mimic the behavior of
    reconstructing the functional graph network by re-executing all the
    `__call__()` methods. The cloned nodes will be appended to the layers.

    Note that a new `keras.Input` will be created for any items in the
    `inputs`

    Args:
        inputs: A nested structure of `KerasTensor` instances.
        outputs: A nested structure of `KerasTensor` instances.

    Returns:
        A pair of inputs and outputs, with cloned `KerasTensor` instances.
        They can be used to create a new functional model.
    """
    nodes_to_clone = find_nodes_by_inputs_and_outputs(inputs, outputs)
    cloned_inputs = []
    cloned_outputs = []
    # We not only need to create copies of Nodes (mimic the calls), also need
    # to clone Keras tensors to avoid the override of _keras_history attached
    # on the Keras tensor. The following dict is used to track any keras
    # tensor we cloned. The key is the string ID of the original keras
    # tensor, and value is the cloned Keras tensor instance.
    kt_id_mapping = {}
    # Maps original producing ops to the InputLayer op that replaces them
    # for intermediate tensors promoted to model inputs.
    op_id_mapping = {}
    for kt_input in tree.flatten(inputs):
        if is_input_keras_tensor(kt_input):
            # For any existing Keras tensor from keras.Input, leave them as is.
            cloned_inputs.append(kt_input)
            kt_id_mapping[id(kt_input)] = kt_input
        else:
            # We need to create a new Keras tensor for any intermediate tensor
            cloned_input = Input(
                batch_shape=kt_input.shape,
                dtype=kt_input.dtype,
                sparse=kt_input.sparse,
                name=kt_input.name + "CLONE",
            )
            cloned_inputs.append(cloned_input)
            kt_id_mapping[id(kt_input)] = cloned_input
            op_id_mapping[id(kt_input._keras_history[0])] = (
                cloned_input._keras_history[0]
            )
    cloned_inputs = tree.pack_sequence_as(inputs, cloned_inputs)

    for kt_output in tree.flatten(outputs):
        cpy = clone_single_keras_tensor(kt_output)
        # We reuse the _keras_history here, which contains the old information.
        cpy._keras_history = kt_output._keras_history
        cloned_outputs.append(cpy)
        kt_id_mapping[id(kt_output)] = cpy
    cloned_outputs = tree.pack_sequence_as(outputs, cloned_outputs)

    for node in nodes_to_clone:
        if id(node.operation) in op_id_mapping:
            operation = op_id_mapping[id(node.operation)]
        else:
            operation = node.operation
        # Clone any Keras tensor to avoid override of _keras_history
        # Or reuse an existing Keras tensor if it has already been cloned.
        output_copy = clone_keras_tensors(node.output_tensors, kt_id_mapping)
        if not isinstance(operation, InputLayer):
            call_args_copy = clone_keras_tensors(
                node.arguments.args, kt_id_mapping
            )
            call_kwargs_copy = clone_keras_tensors(
                node.arguments.kwargs, kt_id_mapping
            )
        else:
            # InputLayer nodes take no call arguments.
            call_args_copy = ()
            call_kwargs_copy = {}
        # Creating new nodes based on the existing node information. Node
        # wires itself to inbound and outbound layers. The Node constructor
        # actually updates this layer's self._inbound_nodes, sets
        # _keras_history on the outputs, and adds itself to the
        # `_outbound_nodes` of the layers that produced the inputs to this
        # layer call.
        Node(
            operation,
            call_args=call_args_copy,
            call_kwargs=call_kwargs_copy,
            outputs=output_copy,
        )
    return cloned_inputs, cloned_outputs