Add files using upload-large-folder tool
Browse files. This view is limited to 50 files because it contains too many changes.
See raw diff
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py +279 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py +1173 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py +808 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__init__.py +5 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_case.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py +796 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_utils.py +163 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__init__.py +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/compile_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py +820 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py +154 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py +372 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py +520 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py +97 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py +329 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py +87 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py +692 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/tf_dataset_adapter.py +141 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +83 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py +161 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py +1147 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py +12 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py +394 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py +187 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py +404 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py +26 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc +0 -0
- SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc +0 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py
ADDED
|
@@ -0,0 +1,279 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import zipfile
|
| 3 |
+
|
| 4 |
+
from absl import logging
|
| 5 |
+
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
from keras.src.legacy.saving import legacy_h5_format
|
| 8 |
+
from keras.src.saving import saving_lib
|
| 9 |
+
from keras.src.utils import file_utils
|
| 10 |
+
from keras.src.utils import io_utils
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
import h5py
|
| 14 |
+
except ImportError:
|
| 15 |
+
h5py = None
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
    """Saves a model as a `.keras` file.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        zipped: Whether to save the model as a zipped `.keras`
            archive (default when saving locally), or as an unzipped directory
            (default when saving on the Hugging Face Hub).

    Example:

    ```python
    model = keras.Sequential(
        [
            keras.layers.Dense(5, input_shape=(3,)),
            keras.layers.Softmax(),
        ],
    )
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = keras.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `keras.saving.save_model()`.

    The saved `.keras` file is a `zip` archive that contains:

    - The model's configuration (architecture)
    - The model's weights
    - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state.
    """
    # Legacy kwargs (Keras 2 API): `include_optimizer` is still honored for
    # the HDF5 path; `save_format` is deprecated and only tolerated when it
    # agrees with the file extension.
    include_optimizer = kwargs.pop("include_optimizer", True)
    save_format = kwargs.pop("save_format", False)
    if save_format:
        if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith(
            ".keras"
        ):
            # Extension already determines the format: warn, but proceed.
            logging.warning(
                "The `save_format` argument is deprecated in Keras 3. "
                "We recommend removing this argument as it can be inferred "
                "from the file path. "
                f"Received: save_format={save_format}"
            )
        else:
            # No recognized extension to infer from: hard error.
            raise ValueError(
                "The `save_format` argument is deprecated in Keras 3. "
                "Please remove this argument and pass a file path with "
                "either `.keras` or `.h5` extension."
                f"Received: save_format={save_format}"
            )
    if kwargs:
        # Anything left in kwargs is unknown/unsupported.
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    # Deprecation warnings
    if str(filepath).endswith((".h5", ".hdf5")):
        logging.warning(
            "You are saving your model as an HDF5 file via "
            "`model.save()` or `keras.saving.save_model(model)`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')` or "
            "`keras.saving.save_model(model, 'my_model.keras')`. "
        )

    is_hf = str(filepath).startswith("hf://")
    if zipped is None:
        zipped = not is_hf  # default behavior depends on destination

    # If file exists and should not be overwritten.
    try:
        # `os.path.exists` can raise TypeError for non-path-like objects
        # (e.g. file objects); treat those as "does not exist".
        exists = (not is_hf) and os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        # Interactive confirmation; silently abort the save if declined.
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    # Dispatch in priority order: native zipped `.keras`, unzipped directory
    # (including `hf://` destinations), then legacy HDF5.
    if zipped and str(filepath).endswith(".keras"):
        return saving_lib.save_model(model, filepath)
    if not zipped:
        return saving_lib.save_model(model, filepath, zipped=False)
    if str(filepath).endswith((".h5", ".hdf5")):
        return legacy_h5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    raise ValueError(
        "Invalid filepath extension for saving. "
        "Please add either a `.keras` extension for the native Keras "
        f"format (recommended) or a `.h5` extension. "
        "Use `model.export(filepath)` if you want to export a SavedModel "
        "for use with TFLite/TFServing/etc. "
        f"Received: filepath={filepath}."
    )
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@keras_export(["keras.saving.load_model", "keras.models.load_model"])
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Loads a model saved via `model.save()`.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model file.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Example:

    ```python
    model = keras.Sequential([
        keras.layers.Dense(5, input_shape=(3,)),
        keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = np.random.random((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    # Classify the source: a local `.keras` zip archive, an unzipped model
    # directory (contains config.json), or a Hugging Face Hub reference.
    is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
        filepath
    )
    is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
        file_utils.join(filepath, "config.json")
    )
    is_hf = str(filepath).startswith("hf://")

    # Support for remote zip files
    if (
        file_utils.is_remote_path(filepath)
        and not file_utils.isdir(filepath)
        and not is_keras_zip
        and not is_hf
    ):
        local_path = file_utils.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )

        # Copy from remote to temporary local directory
        file_utils.copy(filepath, local_path)

        # Switch filepath to local zipfile for loading model
        # NOTE: `filepath` is intentionally rebound to the local copy here.
        if zipfile.is_zipfile(local_path):
            filepath = local_path
            is_keras_zip = True

    # Native Keras v3 formats (zip, directory, or Hub) all go through
    # saving_lib; `safe_mode` only applies on this path.
    if is_keras_zip or is_keras_dir or is_hf:
        return saving_lib.load_model(
            filepath,
            custom_objects=custom_objects,
            compile=compile,
            safe_mode=safe_mode,
        )
    if str(filepath).endswith((".h5", ".hdf5")):
        # Legacy Keras 2 HDF5 format.
        return legacy_h5_format.load_model_from_hdf5(
            filepath, custom_objects=custom_objects, compile=compile
        )
    elif str(filepath).endswith(".keras"):
        # `.keras` extension but not a readable zip: most likely a missing
        # or corrupted file.
        raise ValueError(
            f"File not found: filepath={filepath}. "
            "Please ensure the file is an accessible `.keras` "
            "zip file."
        )
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
            "Keras 3 only supports V3 `.keras` files and "
            "legacy H5 format files (`.h5` extension). "
            "Note that the legacy SavedModel format is not "
            "supported by `load_model()` in Keras 3. In "
            "order to reload a TensorFlow SavedModel as an "
            "inference-only layer in Keras 3, use "
            "`keras.layers.TFSMLayer("
            f"{filepath}, call_endpoint='serving_default')` "
            "(note that your `call_endpoint` "
            "might have a different name)."
        )
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@keras_export("keras.saving.save_weights")
def save_weights(model, filepath, overwrite=True, **kwargs):
    """Save the weights of `model` to a `.weights.h5` file.

    If the target file already exists and `overwrite` is False, the user is
    prompted interactively and the save is silently skipped on refusal.
    Extra keyword arguments are forwarded to `saving_lib.save_weights_only`.
    """
    if not str(filepath).endswith(".weights.h5"):
        raise ValueError(
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath}"
        )
    try:
        target_exists = os.path.exists(filepath)
    except TypeError:
        # Non-path-like targets (e.g. file objects) cannot be probed.
        target_exists = False
    if target_exists and not overwrite:
        if not io_utils.ask_to_proceed_with_overwrite(filepath):
            return
    saving_lib.save_weights_only(model, filepath, **kwargs)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
@keras_export("keras.saving.load_weights")
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Load weights into `model` from a `.keras`, `.weights.h5`, or legacy H5 file.

    The accepted extra keyword arguments depend on the file format:
    `objects_to_skip` for `.weights.h5`, `by_name` for legacy `.h5`/`.hdf5`;
    any other keyword raises `ValueError`.
    """
    path = str(filepath)

    if path.endswith(".keras"):
        # Full `.keras` archive: only its stored weights are loaded.
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
        return

    if path.endswith(".weights.h5"):
        # Native Keras 3 weights file.
        skip_list = kwargs.pop("objects_to_skip", None)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model,
            filepath,
            skip_mismatch=skip_mismatch,
            objects_to_skip=skip_list,
        )
        return

    if path.endswith((".h5", ".hdf5")):
        # Legacy Keras 2 HDF5 weights.
        use_names = kwargs.pop("by_name", False)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        if not h5py:
            raise ImportError(
                "Loading a H5 file requires `h5py` to be installed."
            )
        with h5py.File(filepath, "r") as h5_file:
            group = h5_file
            # Full-model HDF5 files nest the weights under "model_weights".
            if "layer_names" not in group.attrs and "model_weights" in group:
                group = group["model_weights"]
            if use_names:
                legacy_h5_format.load_weights_from_hdf5_group_by_name(
                    group, model, skip_mismatch
                )
            else:
                legacy_h5_format.load_weights_from_hdf5_group(group, model)
        return

    raise ValueError(
        f"File format not supported: filepath={filepath}. "
        "Keras 3 only supports V3 `.keras` and `.weights.h5` "
        "files, or legacy V1/V2 `.h5` files."
    )
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py
ADDED
|
@@ -0,0 +1,1173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Python-based idempotent model-saving functionality."""
|
| 2 |
+
|
| 3 |
+
import datetime
|
| 4 |
+
import io
|
| 5 |
+
import json
|
| 6 |
+
import os
|
| 7 |
+
import pathlib
|
| 8 |
+
import shutil
|
| 9 |
+
import tempfile
|
| 10 |
+
import warnings
|
| 11 |
+
import zipfile
|
| 12 |
+
|
| 13 |
+
import ml_dtypes
|
| 14 |
+
import numpy as np
|
| 15 |
+
|
| 16 |
+
from keras.src import backend
|
| 17 |
+
from keras.src.backend.common import global_state
|
| 18 |
+
from keras.src.layers.layer import Layer
|
| 19 |
+
from keras.src.losses.loss import Loss
|
| 20 |
+
from keras.src.metrics.metric import Metric
|
| 21 |
+
from keras.src.optimizers.optimizer import Optimizer
|
| 22 |
+
from keras.src.saving.serialization_lib import ObjectSharingScope
|
| 23 |
+
from keras.src.saving.serialization_lib import deserialize_keras_object
|
| 24 |
+
from keras.src.saving.serialization_lib import serialize_keras_object
|
| 25 |
+
from keras.src.trainers.compile_utils import CompileMetrics
|
| 26 |
+
from keras.src.utils import file_utils
|
| 27 |
+
from keras.src.utils import io_utils
|
| 28 |
+
from keras.src.utils import naming
|
| 29 |
+
from keras.src.utils import plot_model
|
| 30 |
+
from keras.src.utils.model_visualization import check_pydot
|
| 31 |
+
from keras.src.utils.summary_utils import weight_memory_size
|
| 32 |
+
from keras.src.version import __version__ as keras_version
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
import h5py
|
| 36 |
+
except ImportError:
|
| 37 |
+
h5py = None
|
| 38 |
+
try:
|
| 39 |
+
import psutil
|
| 40 |
+
except ImportError:
|
| 41 |
+
psutil = None
|
| 42 |
+
try:
|
| 43 |
+
import huggingface_hub
|
| 44 |
+
except ImportError:
|
| 45 |
+
huggingface_hub = None
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
# File names used inside a `.keras` archive / unzipped model directory.
_CONFIG_FILENAME = "config.json"
_METADATA_FILENAME = "metadata.json"
_VARS_FNAME = "model.weights"  # Will become e.g. "model.weights.h5"
_VARS_FNAME_H5 = _VARS_FNAME + ".h5"
_VARS_FNAME_NPZ = _VARS_FNAME + ".npz"
_ASSETS_DIRNAME = "assets"
# Fraction of available memory above which loading falls back to a
# lower-memory path (presumably checked via psutil — confirm at use sites).
_MEMORY_UPPER_BOUND = 0.5  # 50%


# Auto-generated README content for models uploaded to the Hugging Face Hub.
# Runtime string: must be kept byte-identical.
_MODEL_CARD_TEMPLATE = """
---
library_name: keras
---

This model has been uploaded using the Keras library and can be used with JAX,
TensorFlow, and PyTorch backends.

This model card has been generated automatically and should be completed by the
model author.
See [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for
more information.

For more details about the model architecture, check out
[config.json](./config.json)."""
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def save_model(model, filepath, weights_format="h5", zipped=True):
    """Save a zip-archive representing a Keras model to the given file or path.

    The zip-based archive contains the following structure:

    - JSON-based configuration file (config.json): Records of model, layer, and
        other saveables' configuration.
    - H5-based saveable state files, found in respective directories, such as
        model/states.npz, model/dense_layer/states.npz, etc.
    - Metadata file.

    The states of Keras saveables (layers, optimizers, loss, and metrics) are
    automatically saved as long as they can be discovered through the attributes
    returned by `dir(Model)`. Typically, the state includes the variables
    associated with the saveable, but some specially purposed layers may
    contain more such as the vocabularies stored in the hashmaps. The saveables
    define how their states are saved by exposing `save_state()` and
    `load_state()` APIs.

    For the case of layer states, the variables will be visited as long as
    they are either 1) referenced via layer attributes, or 2) referenced via a
    container (list, tuple, or dict), and the container is referenced via a
    layer attribute.
    """
    # h5py is only required for the (default) h5 weights format.
    if weights_format == "h5" and h5py is None:
        raise ImportError("h5py must be installed in order to save a model.")

    if not model.built:
        # Saving still proceeds; this is just a heads-up that the archive
        # may contain no weights.
        warnings.warn(
            "You are saving a model that has not yet been built. "
            "It might not contain any weights yet. "
            "Consider building the model first by calling it "
            "on some data.",
            stacklevel=2,
        )

    if isinstance(filepath, io.IOBase):
        # Caller passed an open file-like object: stream the zip into it.
        _save_model_to_fileobj(model, filepath, weights_format)
        return

    filepath = str(filepath)
    is_hf = filepath.startswith("hf://")
    # Validate that the extension is consistent with the `zipped` flag
    # and the destination before doing any work.
    if zipped and not filepath.endswith(".keras"):
        raise ValueError(
            "Invalid `filepath` argument: expected a `.keras` extension. "
            f"Received: filepath={filepath}"
        )
    if not zipped and filepath.endswith(".keras"):
        raise ValueError(
            "When using `zipped=False`, the `filepath` argument should not "
            f"end in `.keras`. Received: filepath={filepath}"
        )
    if zipped and is_hf:
        raise ValueError(
            "When saving to the Hugging Face Hub, you should not save the "
            f"model as zipped. Received: filepath={filepath}, zipped={zipped}"
        )
    if is_hf:
        # Upload as an unzipped model repo to the Hugging Face Hub.
        _upload_model_to_hf(model, filepath, weights_format)
    elif not zipped:
        # Write an unzipped model directory locally.
        _save_model_to_dir(model, filepath, weights_format)
    else:
        if file_utils.is_remote_path(filepath):
            # Remote path. Zip to local memory byte io and copy to remote
            zip_filepath = io.BytesIO()
            _save_model_to_fileobj(model, zip_filepath, weights_format)
            with file_utils.File(filepath, "wb") as f:
                f.write(zip_filepath.getvalue())
        else:
            with open(filepath, "wb") as f:
                _save_model_to_fileobj(model, f, weights_format)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def _serialize_model_as_json(model):
    """Serialize `model` to JSON strings for saving.

    Returns:
        A `(config_json, metadata_json)` tuple. `config_json` is the
        serialized model config; `metadata_json` records the Keras
        version and a save timestamp.
    """
    with ObjectSharingScope():
        model_config = serialize_keras_object(model)
    metadata = {
        "keras_version": keras_version,
        "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
    }
    return json.dumps(model_config), json.dumps(metadata)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def _save_model_to_dir(model, dirpath, weights_format):
    """Save `model` unzipped into the directory `dirpath`.

    Writes the metadata JSON, the config JSON, the weights file
    (`h5` or `npz` depending on `weights_format`), and an assets
    subdirectory.

    Raises:
        ValueError: If `weights_format` is not `"h5"` or `"npz"`.
    """
    if not file_utils.exists(dirpath):
        file_utils.makedirs(dirpath)
    config_json, metadata_json = _serialize_model_as_json(model)
    with open(file_utils.join(dirpath, _METADATA_FILENAME), "w") as f:
        f.write(metadata_json)
    with open(file_utils.join(dirpath, _CONFIG_FILENAME), "w") as f:
        f.write(config_json)
    weights_filepath = file_utils.join(dirpath, _VARS_FNAME_H5)
    # Fixed local name (was the typo `assert_dirpath`).
    asset_dirpath = file_utils.join(dirpath, _ASSETS_DIRNAME)
    # Initialize to None so the `finally` clause never references an
    # unbound name: previously an invalid `weights_format` raised
    # `ValueError`, and the unconditional `weights_store.close()` in
    # `finally` then raised `UnboundLocalError`, masking the real error.
    weights_store = None
    asset_store = None
    try:
        if weights_format == "h5":
            weights_store = H5IOStore(weights_filepath, mode="w")
        elif weights_format == "npz":
            weights_store = NpzIOStore(weights_filepath, mode="w")
        else:
            raise ValueError(
                "Unknown `weights_format` argument. "
                "Expected 'h5' or 'npz'. "
                f"Received: weights_format={weights_format}"
            )
        asset_store = DiskIOStore(asset_dirpath, mode="w")
        _save_state(
            model,
            weights_store=weights_store,
            assets_store=asset_store,
            inner_path="",
            visited_saveables=set(),
        )
    finally:
        if weights_store:
            weights_store.close()
        if asset_store:
            asset_store.close()
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def _save_model_to_fileobj(model, fileobj, weights_format):
    """Write `model` as a zipped `.keras` archive to the open `fileobj`.

    The archive contains the metadata JSON, the config JSON, the weights
    (h5 or npz) and an assets directory. For h5 weights, three strategies
    are tried in order of preference: in-memory (when system memory
    suffices), a local temp file next to `fileobj` (written into the zip
    at the end), and finally in-memory as a last resort.
    """
    config_json, metadata_json = _serialize_model_as_json(model)

    with zipfile.ZipFile(fileobj, "w") as zf:
        with zf.open(_METADATA_FILENAME, "w") as f:
            f.write(metadata_json.encode())
        with zf.open(_CONFIG_FILENAME, "w") as f:
            f.write(config_json.encode())

        weights_file_path = None
        weights_store = None
        asset_store = None
        # Whether the temp h5 file still needs to be copied into `zf`.
        write_zf = False
        try:
            if weights_format == "h5":
                try:
                    if is_memory_sufficient(model):
                        # Load the model weights into memory before writing
                        # .keras if the system memory is sufficient.
                        weights_store = H5IOStore(
                            _VARS_FNAME_H5, archive=zf, mode="w"
                        )
                    else:
                        # Try opening the .h5 file, then writing it to `zf` at
                        # the end of the function call. This is more memory
                        # efficient than writing the weights into memory first.
                        # NOTE: requires `fileobj` to have a `.name` on disk.
                        working_dir = pathlib.Path(fileobj.name).parent
                        weights_file_path = tempfile.NamedTemporaryFile(
                            dir=working_dir
                        )
                        weights_store = H5IOStore(
                            weights_file_path.name, mode="w"
                        )
                        write_zf = True
                except:
                    # If we can't use the local disk for any reason, write the
                    # weights into memory first, which consumes more memory.
                    weights_store = H5IOStore(
                        _VARS_FNAME_H5, archive=zf, mode="w"
                    )
            elif weights_format == "npz":
                weights_store = NpzIOStore(
                    _VARS_FNAME_NPZ, archive=zf, mode="w"
                )
            else:
                raise ValueError(
                    "Unknown `weights_format` argument. "
                    "Expected 'h5' or 'npz'. "
                    f"Received: weights_format={weights_format}"
                )

            asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w")

            _save_state(
                model,
                weights_store=weights_store,
                assets_store=asset_store,
                inner_path="",
                visited_saveables=set(),
            )
        except:
            # Skip the final `zf.write` if any exception is raised
            write_zf = False
            if weights_store:
                # Detach the archive so close() won't write partial weights.
                weights_store.archive = None
            raise
        finally:
            # Stores must be closed before the temp file is copied into the
            # zip, so the h5 data is fully flushed.
            if weights_store:
                weights_store.close()
            if asset_store:
                asset_store.close()
            if write_zf and weights_file_path:
                zf.write(weights_file_path.name, _VARS_FNAME_H5)
            if weights_file_path:
                weights_file_path.close()
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def _upload_model_to_hf(model, hf_path, weights_format):
    """Save `model` unzipped and upload it to the Hugging Face Hub.

    `hf_path` is a `hf://namespace/model_name` (or bare `namespace/model_name`)
    repo path. The model is saved to a temporary directory, a README model
    card (with an optional architecture plot) is added, and the folder is
    uploaded with `huggingface_hub`.

    Raises:
        ImportError: If `huggingface_hub` is not installed.
        ValueError: If `hf_path` has more than one `/` after the prefix.
    """
    if huggingface_hub is None:
        raise ImportError(
            "To save models to the Hugging Face Hub, "
            "you must install the `huggingface_hub` package."
        )

    original_hf_path = hf_path
    if hf_path.startswith("hf://"):
        hf_path = hf_path[5:]
    if hf_path.count("/") > 1:
        raise ValueError(
            "Invalid `hf_path` argument: expected `namespace/model_name`"
            f" format. Received: hf_path={original_hf_path}"
        )

    api = huggingface_hub.HfApi(
        library_name="keras", library_version=keras_version
    )
    # `exist_ok=True`: uploading to an existing repo updates it in place.
    repo_url = api.create_repo(hf_path, exist_ok=True)
    repo_id = repo_url.repo_id

    with tempfile.TemporaryDirectory() as tmp_dir:
        _save_model_to_dir(model, tmp_dir, weights_format)

        model_card = _MODEL_CARD_TEMPLATE

        # Only render the architecture plot when pydot is available.
        if check_pydot():
            plot_path = file_utils.join(tmp_dir, "assets", "summary_plot.png")
            plot_model(
                model,
                to_file=plot_path,
                show_layer_names=True,
                show_shapes=True,
                show_dtype=True,
            )
            # For larger models, link to the plot instead of embedding it.
            if len(model.layers) <= 10:
                model_card += "\n\n"
            else:
                model_card += (
                    "A plot of the model can be found "
                    "[here](./assets/summary_plot.png)."
                )

        with open(file_utils.join(tmp_dir, "README.md"), "w") as f:
            f.write(model_card)

        api.upload_folder(
            repo_id=repo_id,
            folder_path=tmp_dir,
            commit_message="Save model using Keras.",
        )
    io_utils.print_msg(
        f"Model saved to the Hugging Face Hub: {repo_url}\n"
        "To load back the model, use "
        f"`keras.saving.load_model('hf://{repo_id}')`"
    )
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Load a zip archive representing a Keras model."""
    # Already-open file object: read the zip archive from it directly.
    if isinstance(filepath, io.IOBase):
        return _load_model_from_fileobj(
            filepath, custom_objects, compile, safe_mode
        )

    # Hugging Face Hub path: download a snapshot, then load the directory.
    if str(filepath).startswith("hf://"):
        if huggingface_hub is None:
            raise ImportError(
                "To load models from the Hugging Face Hub, "
                "you must install the `huggingface_hub` package."
            )
        repo_id = filepath[5:]
        folder_path = huggingface_hub.snapshot_download(
            repo_id=repo_id,
            library_name="keras",
            library_version=keras_version,
        )
        return _load_model_from_dir(
            folder_path, custom_objects, compile, safe_mode
        )

    filepath = str(filepath)
    # Zipped `.keras` archive: open it and load from the file object.
    if filepath.endswith(".keras"):
        with open(filepath, "rb") as f:
            return _load_model_from_fileobj(
                f, custom_objects, compile, safe_mode
            )

    # No `.keras` extension: accept an unzipped model directory
    # (one produced by saving with `zipped=False`).
    if file_utils.isdir(filepath) and file_utils.exists(
        file_utils.join(filepath, "config.json")
    ):
        return _load_model_from_dir(
            filepath, custom_objects, compile, safe_mode
        )

    raise ValueError(
        "Invalid filename: expected a `.keras` extension. "
        f"Received: filepath={filepath}"
    )
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def _load_model_from_dir(dirpath, custom_objects, compile, safe_mode):
    """Load a model saved unzipped (via `zipped=False`) from `dirpath`.

    Raises:
        ValueError: If `dirpath` does not exist, is not a directory, or
            contains neither an h5 nor an npz weights file.
    """
    if not file_utils.exists(dirpath):
        raise ValueError(f"Directory doesn't exist: {dirpath}")
    if not file_utils.isdir(dirpath):
        raise ValueError(f"Path isn't a directory: {dirpath}")

    with open(file_utils.join(dirpath, _CONFIG_FILENAME), "r") as f:
        config_json = f.read()
    model = _model_from_config(config_json, custom_objects, compile, safe_mode)

    all_filenames = file_utils.listdir(dirpath)
    # Initialize to None so the `finally` clause never references an unbound
    # name: previously, when no weights file was found, the `ValueError` was
    # masked by an `UnboundLocalError` from `weights_store.close()`.
    weights_store = None
    asset_store = None
    try:
        if _VARS_FNAME_H5 in all_filenames:
            weights_file_path = file_utils.join(dirpath, _VARS_FNAME_H5)
            weights_store = H5IOStore(weights_file_path, mode="r")
        elif _VARS_FNAME_NPZ in all_filenames:
            weights_file_path = file_utils.join(dirpath, _VARS_FNAME_NPZ)
            weights_store = NpzIOStore(weights_file_path, mode="r")
        else:
            raise ValueError(
                f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file."
            )
        # More than config + metadata + weights implies an assets dir exists.
        if len(all_filenames) > 3:
            asset_store = DiskIOStore(
                file_utils.join(dirpath, _ASSETS_DIRNAME), mode="r"
            )

        failed_saveables = set()
        error_msgs = {}
        _load_state(
            model,
            weights_store=weights_store,
            assets_store=asset_store,
            inner_path="",
            visited_saveables=set(),
            failed_saveables=failed_saveables,
            error_msgs=error_msgs,
        )
    finally:
        if weights_store:
            weights_store.close()
        if asset_store:
            asset_store.close()

    if failed_saveables:
        _raise_loading_failure(error_msgs)
    return model
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def _model_from_config(config_json, custom_objects, compile, safe_mode):
    """Deserialize a model object from its JSON config string."""
    # Note: we should NOT use a custom JSON decoder. Anything that
    # needs custom decoding must be handled in deserialize_keras_object.
    config_dict = json.loads(config_json)
    if not compile:
        # Strip the compile config so the model comes back uncompiled.
        config_dict["compile_config"] = None
    # Rebuild the model from its serialized configuration.
    with ObjectSharingScope():
        return deserialize_keras_object(
            config_dict, custom_objects, safe_mode=safe_mode
        )
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
    """Load a model from an open `.keras` zip archive file object.

    For h5 weights, three read strategies are tried in order: fully
    in-memory (when system memory suffices), extraction to a temp
    directory next to `fileobj`, and finally streaming from the zip
    archive on the fly.
    """
    with zipfile.ZipFile(fileobj, "r") as zf:
        with zf.open(_CONFIG_FILENAME, "r") as f:
            config_json = f.read()

        model = _model_from_config(
            config_json, custom_objects, compile, safe_mode
        )

        all_filenames = zf.namelist()
        extract_dir = None
        weights_store = None
        asset_store = None
        try:
            if _VARS_FNAME_H5 in all_filenames:
                try:
                    if is_memory_sufficient(model):
                        # Load the entire file into memory if the system memory
                        # is sufficient.
                        io_file = io.BytesIO(
                            zf.open(_VARS_FNAME_H5, "r").read()
                        )
                        weights_store = H5IOStore(io_file, mode="r")
                    else:
                        # Try extracting the model.weights.h5 file, and then
                        # loading it using h5py. This is significantly
                        # faster than reading from the zip archive on the fly.
                        # NOTE: requires `fileobj` to have a `.name` on disk.
                        extract_dir = tempfile.TemporaryDirectory(
                            dir=pathlib.Path(fileobj.name).parent
                        )
                        zf.extract(_VARS_FNAME_H5, extract_dir.name)
                        weights_store = H5IOStore(
                            pathlib.Path(extract_dir.name, _VARS_FNAME_H5),
                            mode="r",
                        )
                except:
                    # If we can't use the local disk for any reason, read the
                    # weights from the zip archive on the fly, which is less
                    # efficient.
                    weights_store = H5IOStore(_VARS_FNAME_H5, zf, mode="r")
            elif _VARS_FNAME_NPZ in all_filenames:
                weights_store = NpzIOStore(_VARS_FNAME_NPZ, zf, mode="r")
            else:
                raise ValueError(
                    f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file."
                )

            # More than config + metadata + weights implies assets exist.
            if len(all_filenames) > 3:
                asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r")

            failed_saveables = set()
            error_msgs = {}
            _load_state(
                model,
                weights_store=weights_store,
                assets_store=asset_store,
                inner_path="",
                visited_saveables=set(),
                failed_saveables=failed_saveables,
                error_msgs=error_msgs,
            )
        finally:
            if weights_store:
                weights_store.close()
            if asset_store:
                asset_store.close()
            if extract_dir:
                extract_dir.cleanup()

        if failed_saveables:
            _raise_loading_failure(error_msgs)
    return model
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def save_weights_only(model, filepath, objects_to_skip=None):
    """Save only the weights of a model to a target filepath.

    Supports both `.weights.h5` and `.keras`.

    Args:
        model: A built Keras model.
        filepath: Target path ending in `.weights.h5` (may be remote).
        objects_to_skip: Optional iterable of saveables to exclude.

    Raises:
        ValueError: If the model is not built or `filepath` lacks the
            `.weights.h5` extension.
    """
    if not model.built:
        raise ValueError(
            "You are saving a model that has not yet been built. "
            "Try building the model first by calling it on some data or "
            "by using `build()`."
        )

    filepath = str(filepath)
    tmp_dir = None
    remote_filepath = None
    if not filepath.endswith(".weights.h5"):
        raise ValueError(
            "Invalid `filepath` argument: expected a `.weights.h5` extension. "
            f"Received: filepath={filepath}"
        )
    try:
        if file_utils.is_remote_path(filepath):
            # Save to a local temp file first, then copy to the remote path.
            tmp_dir = get_temp_dir()
            local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))
            remote_filepath = filepath
            filepath = local_filepath

        weights_store = H5IOStore(filepath, mode="w")
        if objects_to_skip is not None:
            # Pre-mark skipped objects as visited so they are never saved.
            visited_saveables = set(id(o) for o in objects_to_skip)
        else:
            visited_saveables = set()
        try:
            _save_state(
                model,
                weights_store=weights_store,
                assets_store=None,
                inner_path="",
                visited_saveables=visited_saveables,
            )
        finally:
            # Always release the HDF5 handle, even if saving fails
            # (previously the handle leaked on error).
            weights_store.close()
        if remote_filepath is not None:
            # Copy only after a successful save, so a partially written
            # weights file is never uploaded to the remote location.
            file_utils.copy(filepath, remote_filepath)
    finally:
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir)
|
| 557 |
+
|
| 558 |
+
|
| 559 |
+
def load_weights_only(
    model, filepath, skip_mismatch=False, objects_to_skip=None
):
    """Load the weights of a model from a filepath (.keras or .weights.h5).

    Note: only supports h5 for now.

    Args:
        model: A built Keras model.
        filepath: `.weights.h5` file or `.keras` archive (may be remote).
        skip_mismatch: If True, warn instead of raising on mismatches.
        objects_to_skip: Optional iterable of saveables to exclude.

    Raises:
        ValueError: If the model is not built, the extension is
            unsupported, or (when `skip_mismatch=False`) loading failed.
    """
    if not model.built:
        raise ValueError(
            "You are loading weights into a model that has not yet been built. "
            "Try building the model first by calling it on some data or "
            "by using `build()`."
        )

    archive = None
    tmp_dir = None
    weights_store = None
    filepath = str(filepath)

    try:
        if file_utils.is_remote_path(filepath):
            # Copy the remote file to a local temp dir before reading.
            tmp_dir = get_temp_dir()
            local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))
            file_utils.copy(filepath, local_filepath)
            filepath = local_filepath

        if filepath.endswith(".weights.h5"):
            weights_store = H5IOStore(filepath, mode="r")
        elif filepath.endswith(".keras"):
            archive = zipfile.ZipFile(filepath, "r")
            weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode="r")
        else:
            # Previously this fell through and crashed later with an opaque
            # `UnboundLocalError`; fail fast with a clear message instead.
            raise ValueError(
                "Invalid `filepath` argument: expected a `.weights.h5` or "
                f"`.keras` extension. Received: filepath={filepath}"
            )

        failed_saveables = set()
        if objects_to_skip is not None:
            # Pre-mark skipped objects as visited so they are never loaded.
            visited_saveables = set(id(o) for o in objects_to_skip)
        else:
            visited_saveables = set()
        error_msgs = {}
        _load_state(
            model,
            weights_store=weights_store,
            assets_store=None,
            inner_path="",
            skip_mismatch=skip_mismatch,
            visited_saveables=visited_saveables,
            failed_saveables=failed_saveables,
            error_msgs=error_msgs,
        )
    finally:
        # Always release file handles, even when loading fails
        # (previously both leaked on error).
        if weights_store is not None:
            weights_store.close()
        if archive is not None:
            archive.close()
        if tmp_dir is not None:
            shutil.rmtree(tmp_dir)

    if failed_saveables:
        _raise_loading_failure(error_msgs, warn_only=skip_mismatch)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
def _raise_loading_failure(error_msgs, warn_only=False):
|
| 618 |
+
first_key = list(error_msgs.keys())[0]
|
| 619 |
+
ex_saveable, ex_error = error_msgs[first_key]
|
| 620 |
+
msg = (
|
| 621 |
+
f"A total of {len(error_msgs)} objects could not "
|
| 622 |
+
"be loaded. Example error message for "
|
| 623 |
+
f"object {ex_saveable}:\n\n"
|
| 624 |
+
f"{ex_error}\n\n"
|
| 625 |
+
"List of objects that could not be loaded:\n"
|
| 626 |
+
f"{[x[0] for x in error_msgs.values()]}"
|
| 627 |
+
)
|
| 628 |
+
if warn_only:
|
| 629 |
+
warnings.warn(msg)
|
| 630 |
+
else:
|
| 631 |
+
raise ValueError(msg)
|
| 632 |
+
|
| 633 |
+
|
| 634 |
+
def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):
    """Recursively add a file or directory tree to an open zip archive."""
    if not file_utils.isdir(system_path):
        # Leaf: a regular file, write it directly into the archive.
        zipfile_to_save.write(system_path, zip_path)
        return
    for entry_name in file_utils.listdir(system_path):
        child_system_path = file_utils.join(
            system_path, entry_name
        ).replace("\\", "/")
        child_zip_path = file_utils.join(zip_path, entry_name).replace(
            "\\", "/"
        )
        _write_to_zip_recursively(
            zipfile_to_save, child_system_path, child_zip_path
        )
|
| 648 |
+
|
| 649 |
+
|
| 650 |
+
def _name_key(name):
|
| 651 |
+
"""Make sure that private attributes are visited last."""
|
| 652 |
+
if name.startswith("_"):
|
| 653 |
+
return "~" + name
|
| 654 |
+
return name
|
| 655 |
+
|
| 656 |
+
|
| 657 |
+
def _walk_saveable(saveable):
    """Yield `(attr_name, value)` pairs for a saveable's child attributes.

    Attributes in the type's skipset and dunder names are skipped, and
    private names are visited last.

    Raises:
        ValueError: If `saveable` is not a `KerasSaveable`.
    """
    from keras.src.saving.keras_saveable import KerasSaveable

    if not isinstance(saveable, KerasSaveable):
        raise ValueError(
            "Expected object to be an "
            "instance of `KerasSaveable`, but "
            f"got {saveable} of type {type(saveable)}"
        )

    obj_type = saveable._obj_type()
    attr_skipset = get_attr_skipset(obj_type)

    # Save all layers directly tracked by Sequential and Functional first.
    # This helps avoid ordering concerns for subclassed Sequential or
    # Functional models with extra attributes--the internal Keras state
    # takes precedence.
    if obj_type in ("Sequential", "Functional"):
        yield "layers", saveable.layers

    for attr_name in sorted(dir(saveable), key=_name_key):
        if attr_name.startswith("__") or attr_name in attr_skipset:
            continue
        try:
            attr_value = getattr(saveable, attr_name)
        except Exception:
            # Some attributes raise on access; skip them silently.
            continue
        yield attr_name, attr_value
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def _save_state(
    saveable,
    weights_store,
    assets_store,
    inner_path,
    visited_saveables,
):
    """Recursively save the state of `saveable` and its children.

    Writes the object's own variables to `weights_store` and its assets
    to `assets_store` under `inner_path`, then recurses into child
    saveables and containers of saveables. `visited_saveables` holds
    `id()`s of already-saved objects to prevent duplicate saves.
    """
    from keras.src.saving.keras_saveable import KerasSaveable

    # If the saveable has already been saved, skip it.
    if id(saveable) in visited_saveables:
        return

    if hasattr(saveable, "save_own_variables") and weights_store:
        # Record the object's name as store metadata when available.
        if hasattr(saveable, "name") and isinstance(saveable.name, str):
            metadata = {"name": saveable.name}
        else:
            metadata = None
        saveable.save_own_variables(
            weights_store.make(inner_path, metadata=metadata)
        )
    if hasattr(saveable, "save_assets") and assets_store:
        saveable.save_assets(assets_store.make(inner_path))

    visited_saveables.add(id(saveable))

    # Recursively save state of children saveables (layers, optimizers, etc.)
    for child_attr, child_obj in _walk_saveable(saveable):
        if isinstance(child_obj, KerasSaveable):
            _save_state(
                child_obj,
                weights_store,
                assets_store,
                inner_path=file_utils.join(inner_path, child_attr).replace(
                    "\\", "/"
                ),
                visited_saveables=visited_saveables,
            )
        elif isinstance(child_obj, (list, dict, tuple, set)):
            _save_container_state(
                child_obj,
                weights_store,
                assets_store,
                inner_path=file_utils.join(inner_path, child_attr).replace(
                    "\\", "/"
                ),
                visited_saveables=visited_saveables,
            )
|
| 735 |
+
|
| 736 |
+
|
| 737 |
+
def _load_state(
    saveable,
    weights_store,
    assets_store,
    inner_path,
    skip_mismatch=False,
    visited_saveables=None,
    failed_saveables=None,
    error_msgs=None,
):
    """Recursively load the state of `saveable` and its children.

    Mirror of `_save_state`. When `skip_mismatch` is set or
    `failed_saveables` is provided, per-object load errors are collected
    into `failed_saveables`/`error_msgs` instead of raised immediately.
    """
    from keras.src.saving.keras_saveable import KerasSaveable

    if visited_saveables and id(saveable) in visited_saveables:
        return

    # Tracks whether this object itself (not a child) failed to load.
    failure = False

    if hasattr(saveable, "load_own_variables") and weights_store:
        if skip_mismatch or failed_saveables is not None:
            try:
                saveable.load_own_variables(weights_store.get(inner_path))
            except Exception as e:
                failed_saveables.add(id(saveable))
                error_msgs[id(saveable)] = saveable, e
                failure = True
        else:
            saveable.load_own_variables(weights_store.get(inner_path))

    if hasattr(saveable, "load_assets") and assets_store:
        if skip_mismatch or failed_saveables is not None:
            try:
                saveable.load_assets(assets_store.get(inner_path))
            except Exception as e:
                failed_saveables.add(id(saveable))
                error_msgs[id(saveable)] = saveable, e
                failure = True
        else:
            saveable.load_assets(assets_store.get(inner_path))

    # Snapshot the failure count to detect failures in children below.
    if failed_saveables is not None:
        currently_failed = len(failed_saveables)
    else:
        currently_failed = 0

    # Recursively load states for Keras saveables such as layers/optimizers.
    for child_attr, child_obj in _walk_saveable(saveable):
        if isinstance(child_obj, KerasSaveable):
            _load_state(
                child_obj,
                weights_store,
                assets_store,
                inner_path=file_utils.join(inner_path, child_attr).replace(
                    "\\", "/"
                ),
                skip_mismatch=skip_mismatch,
                visited_saveables=visited_saveables,
                failed_saveables=failed_saveables,
                error_msgs=error_msgs,
            )
        elif isinstance(child_obj, (list, dict, tuple, set)):
            _load_container_state(
                child_obj,
                weights_store,
                assets_store,
                inner_path=file_utils.join(inner_path, child_attr).replace(
                    "\\", "/"
                ),
                skip_mismatch=skip_mismatch,
                visited_saveables=visited_saveables,
                failed_saveables=failed_saveables,
                error_msgs=error_msgs,
            )

    if failed_saveables is not None:
        newly_failed = len(failed_saveables) - currently_failed
    else:
        newly_failed = 0

    if not failure:
        # Mark as visited only if no child failed either, so a retry
        # (e.g. a second loading pass) can revisit this object.
        if visited_saveables is not None and newly_failed <= 0:
            visited_saveables.add(id(saveable))
        if id(saveable) in failed_saveables:
            failed_saveables.remove(id(saveable))
            error_msgs.pop(id(saveable))
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
def _save_container_state(
    container, weights_store, assets_store, inner_path, visited_saveables
):
    """Save state for every `KerasSaveable` held in a container attribute."""
    from keras.src.saving.keras_saveable import KerasSaveable

    used_names = {}
    if isinstance(container, dict):
        container = list(container.values())

    for saveable in container:
        if not isinstance(saveable, KerasSaveable):
            continue
        # Do NOT address the saveable via `saveable.name`, since
        # names are usually autogenerated and thus not reproducible
        # (i.e. they may vary across two instances of the same model).
        name = naming.to_snake_case(saveable.__class__.__name__)
        if name in used_names:
            used_names[name] += 1
            name = f"{name}_{used_names[name]}"
        else:
            used_names[name] = 0
        _save_state(
            saveable,
            weights_store,
            assets_store,
            inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
            visited_saveables=visited_saveables,
        )
|
| 850 |
+
|
| 851 |
+
|
| 852 |
+
def _load_container_state(
    container,
    weights_store,
    assets_store,
    inner_path,
    skip_mismatch,
    visited_saveables,
    failed_saveables,
    error_msgs,
):
    """Load state for every `KerasSaveable` held in a container attribute."""
    from keras.src.saving.keras_saveable import KerasSaveable

    used_names = {}
    if isinstance(container, dict):
        container = list(container.values())

    for saveable in container:
        if not isinstance(saveable, KerasSaveable):
            continue
        # Names must match the deterministic scheme used at save time
        # (class name in snake_case, suffixed on collision).
        name = naming.to_snake_case(saveable.__class__.__name__)
        if name in used_names:
            used_names[name] += 1
            name = f"{name}_{used_names[name]}"
        else:
            used_names[name] = 0
        _load_state(
            saveable,
            weights_store,
            assets_store,
            inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
            skip_mismatch=skip_mismatch,
            visited_saveables=visited_saveables,
            failed_saveables=failed_saveables,
            error_msgs=error_msgs,
        )
|
| 886 |
+
|
| 887 |
+
|
| 888 |
+
class DiskIOStore:
    """Asset store backed by disk storage.

    If `archive` is specified, then `root_path` refers to the filename
    inside the archive.

    If `archive` is not specified, then `root_path` refers to the full path of
    the target directory.
    """

    def __init__(self, root_path, archive=None, mode=None):
        # mode: "r" to read existing assets, "w" to stage new ones.
        self.mode = mode
        self.root_path = root_path
        self.archive = archive
        self.tmp_dir = None
        if self.archive:
            # Archive-backed: stage everything in a temp dir; on read,
            # extract the archive there first.
            self.tmp_dir = get_temp_dir()
            if self.mode == "r":
                self.archive.extractall(path=self.tmp_dir)
            self.working_dir = file_utils.join(
                self.tmp_dir, self.root_path
            ).replace("\\", "/")
            if self.mode == "w":
                file_utils.makedirs(self.working_dir)
        else:
            if mode == "r":
                # Read directly from the existing directory.
                self.working_dir = root_path
            else:
                # Write via a temp dir (contents persist until close()).
                self.tmp_dir = get_temp_dir()
                self.working_dir = file_utils.join(
                    self.tmp_dir, self.root_path
                ).replace("\\", "/")
                file_utils.makedirs(self.working_dir)

    def make(self, path):
        """Create (if needed) and return the directory for `path`."""
        if not path:
            return self.working_dir
        path = file_utils.join(self.working_dir, path).replace("\\", "/")
        if not file_utils.exists(path):
            file_utils.makedirs(path)
        return path

    def get(self, path):
        """Return the directory for `path`, or None if it doesn't exist."""
        if not path:
            return self.working_dir
        path = file_utils.join(self.working_dir, path).replace("\\", "/")
        if file_utils.exists(path):
            return path
        return None

    def close(self):
        """Flush staged assets into the archive (if any) and clean up."""
        if self.mode == "w" and self.archive:
            # Copy the staged tree into the zip before removing it.
            _write_to_zip_recursively(
                self.archive, self.working_dir, self.root_path
            )
        if self.tmp_dir and file_utils.exists(self.tmp_dir):
            file_utils.rmtree(self.tmp_dir)
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
class H5IOStore:
    def __init__(self, root_path, archive=None, mode="r"):
        """Numerical variable store backed by HDF5.

        If `archive` is specified, then `root_path` refers to the filename
        inside the archive.

        If `archive` is not specified, then `root_path` refers to the path of
        the h5 file on disk.
        """
        self.root_path = root_path
        self.mode = mode
        self.archive = archive
        self.io_file = None

        if self.archive:
            # Archive-backed: buffer writes in memory; reads stream from
            # the archive member directly.
            if self.mode == "w":
                self.io_file = io.BytesIO()
            else:
                self.io_file = self.archive.open(self.root_path, "r")
            self.h5_file = h5py.File(self.io_file, mode=self.mode)
        else:
            # Plain file (or file-like object) on disk.
            self.h5_file = h5py.File(root_path, mode=self.mode)

    def make(self, path, metadata=None):
        """Create a writable entry at `path` with optional attrs metadata."""
        return H5Entry(self.h5_file, path, mode="w", metadata=metadata)

    def get(self, path):
        """Return a readable entry at `path`."""
        return H5Entry(self.h5_file, path, mode="r")

    def close(self):
        """Close the h5 file, flushing buffered writes into the archive."""
        self.h5_file.close()
        if self.mode == "w" and self.archive:
            # Write the fully-buffered h5 payload into the zip archive.
            self.archive.writestr(self.root_path, self.io_file.getvalue())
        if self.io_file:
            self.io_file.close()
|
| 983 |
+
|
| 984 |
+
|
| 985 |
+
class H5Entry:
    """Leaf entry in a H5IOStore."""

    def __init__(self, h5_file, path, mode, metadata=None):
        # h5_file: open h5py.File; path: group path within the file
        # ("" means the file root); mode: "w" or "r"; metadata: optional
        # dict written as HDF5 attributes on the "vars" group.
        self.h5_file = h5_file
        self.path = path
        self.mode = mode
        self.metadata = metadata

        if mode == "w":
            # Write mode: create "<path>/vars" (or just "vars" at root)
            # and stamp any metadata onto it as attributes.
            if not path:
                self.group = self.h5_file.create_group("vars")
            else:
                self.group = self.h5_file.create_group(self.path).create_group(
                    "vars"
                )
            if self.metadata:
                for k, v in self.metadata.items():
                    self.group.attrs[k] = v
        else:
            # Read mode: look the group up; fall back to an empty dict
            # so callers can iterate a missing entry harmlessly.
            found = False
            if not path:
                if "vars" in self.h5_file:
                    self.group = self.h5_file["vars"]
                    found = True
            elif path in self.h5_file and "vars" in self.h5_file[path]:
                self.group = self.h5_file[path]["vars"]
                found = True
            else:
                # No hit.
                # Fix for 2.13 compatibility: older files stored layers
                # under "_layer_checkpoint_dependencies", so retry the
                # lookup with the legacy path segment.
                if "_layer_checkpoint_dependencies" in self.h5_file:
                    path = path.replace(
                        "layers", "_layer_checkpoint_dependencies"
                    )
                    self.path = path
                    if path in self.h5_file and "vars" in self.h5_file[path]:
                        self.group = self.h5_file[path]["vars"]
                        found = True
            if not found:
                self.group = {}

    def __len__(self):
        return self.group.__len__()

    def keys(self):
        return self.group.keys()

    def items(self):
        return self.group.items()

    def values(self):
        return self.group.values()

    def __setitem__(self, key, value):
        # Store one variable. Values are materialized to numpy first.
        if self.mode != "w":
            raise ValueError("Setting a value is only allowed in write mode.")
        value = backend.convert_to_numpy(value)
        if backend.standardize_dtype(value.dtype) == "bfloat16":
            # bfloat16 is tagged via an attribute — presumably because
            # the h5 dataset itself cannot carry that dtype natively.
            ds = self.group.create_dataset(key, data=value)
            ds.attrs["dtype"] = "bfloat16"
        else:
            self.group[key] = value

    def __getitem__(self, name):
        # Retrieve one variable, restoring bfloat16 from its dtype tag.
        value = self.group[name]
        if "dtype" in value.attrs and value.attrs["dtype"] == "bfloat16":
            value = np.array(value, dtype=ml_dtypes.bfloat16)
        return value
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
class NpzIOStore:
    """Numerical variable store backed by `np.savez`/`np.load`.

    If `archive` is specified, `root_path` is the filename inside the
    archive; otherwise it is the path of the npz file on disk.
    """

    def __init__(self, root_path, archive=None, mode="r"):
        self.root_path = root_path
        self.mode = mode
        self.archive = archive
        if mode == "w":
            # Writes accumulate in memory and are flushed by `close()`.
            self.contents = {}
            return
        # Read mode: open the underlying stream and load it eagerly.
        if self.archive:
            self.f = archive.open(root_path, mode="r")
        else:
            self.f = open(root_path, mode="rb")
        self.contents = np.load(self.f, allow_pickle=True)

    def make(self, path, metadata=None):
        """Create (and return) a fresh dict entry for `path`."""
        key = path if path else "__root__"
        self.contents[key] = {}
        return self.contents[key]

    def get(self, path):
        """Return the stored dict for `path`, or `{}` when absent."""
        if not path:
            if "__root__" in self.contents:
                return dict(self.contents["__root__"])
            return {}
        if path in self.contents:
            # savez stored the dict as a 0-d object array; unwrap it.
            return self.contents[path].tolist()
        return {}

    def close(self):
        """Flush accumulated contents (write mode) and close the stream."""
        if self.mode == "w":
            if self.archive:
                self.f = self.archive.open(
                    self.root_path, mode="w", force_zip64=True
                )
            else:
                self.f = open(self.root_path, mode="wb")
            np.savez(self.f, **self.contents)
        self.f.close()
|
| 1104 |
+
|
| 1105 |
+
|
| 1106 |
+
def get_temp_dir():
    """Create a scratch directory, verifying it is actually writable."""
    scratch = tempfile.mkdtemp()
    # Probe that files can be created in the new directory before
    # handing it out to callers.
    with tempfile.TemporaryFile(dir=scratch):
        pass
    return scratch
|
| 1111 |
+
|
| 1112 |
+
|
| 1113 |
+
def get_attr_skipset(obj_type):
    """Return the set of attribute names to skip when saving `obj_type`.

    The skipset is computed once per object type and cached in global
    state; it is the set of attributes present on a freshly constructed
    reference object (plus a few hard-coded names).

    Args:
        obj_type: One of `"Layer"`, `"Functional"`, `"Sequential"`,
            `"Metric"`, `"Optimizer"` or `"Loss"`.

    Returns:
        A `set` of attribute names that should not be serialized.

    Raises:
        ValueError: If `obj_type` is not one of the accepted values.
    """
    # BUG FIX: the cache key must be identical on read and write. The
    # previous code read "saving_attr_skiplist_..." but wrote
    # "saving_attr_skipset_...", so the cache never produced a hit and
    # the skipset was rebuilt on every call.
    cache_key = f"saving_attr_skiplist_{obj_type}"
    skipset = global_state.get_global_attribute(cache_key, None)
    if skipset is not None:
        return skipset

    skipset = set(
        [
            "_self_unconditional_dependency_names",
        ]
    )
    if obj_type == "Layer":
        ref_obj = Layer()
        skipset.update(dir(ref_obj))
    elif obj_type == "Functional":
        ref_obj = Layer()
        skipset.update(dir(ref_obj) + ["operations", "_operations"])
    elif obj_type == "Sequential":
        ref_obj = Layer()
        skipset.update(dir(ref_obj) + ["_functional"])
    elif obj_type == "Metric":
        ref_obj_a = Metric()
        ref_obj_b = CompileMetrics([], [])
        skipset.update(dir(ref_obj_a) + dir(ref_obj_b))
    elif obj_type == "Optimizer":
        ref_obj = Optimizer(1.0)
        skipset.update(dir(ref_obj))
        # "variables" must remain saveable on optimizers.
        skipset.remove("variables")
    elif obj_type == "Loss":
        ref_obj = Loss()
        skipset.update(dir(ref_obj))
    else:
        raise ValueError(
            f"get_attr_skipset got invalid {obj_type=}. "
            "Accepted values for `obj_type` are "
            "['Layer', 'Functional', 'Sequential', 'Metric', "
            "'Optimizer', 'Loss']"
        )

    global_state.set_global_attribute(cache_key, skipset)
    return skipset
|
| 1157 |
+
|
| 1158 |
+
|
| 1159 |
+
def is_memory_sufficient(model):
    """Check if there is sufficient memory to load the model into memory.

    If psutil is installed, we can use it to determine whether the memory
    is sufficient. Otherwise, a predefined value of 1 GB is assumed for
    available memory.
    """
    if psutil is not None:
        available_memory = psutil.virtual_memory().available  # In bytes
    else:
        # Conservative fallback when psutil is unavailable.
        available_memory = 1024**3  # 1 GB in bytes
    required = weight_memory_size(model.variables)
    return required < available_memory * _MEMORY_UPPER_BOUND
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Object config serialization and deserialization logic."""
|
| 2 |
+
|
| 3 |
+
import importlib
|
| 4 |
+
import inspect
|
| 5 |
+
import types
|
| 6 |
+
import warnings
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
|
| 10 |
+
from keras.src import api_export
|
| 11 |
+
from keras.src import backend
|
| 12 |
+
from keras.src.api_export import keras_export
|
| 13 |
+
from keras.src.backend.common import global_state
|
| 14 |
+
from keras.src.saving import object_registration
|
| 15 |
+
from keras.src.utils import python_utils
|
| 16 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 17 |
+
|
| 18 |
+
PLAIN_TYPES = (str, int, float, bool)
|
| 19 |
+
|
| 20 |
+
# List of Keras modules with built-in string representations for Keras defaults
|
| 21 |
+
BUILTIN_MODULES = (
|
| 22 |
+
"activations",
|
| 23 |
+
"constraints",
|
| 24 |
+
"initializers",
|
| 25 |
+
"losses",
|
| 26 |
+
"metrics",
|
| 27 |
+
"optimizers",
|
| 28 |
+
"regularizers",
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class SerializableDict:
    """Holds keyword arguments and serializes them on demand."""

    def __init__(self, **config):
        # Keep the raw kwargs; conversion happens lazily in `serialize()`.
        self.config = config

    def serialize(self):
        """Return `self.config` converted via `serialize_keras_object`."""
        return serialize_keras_object(self.config)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class SafeModeScope:
    """Context manager propagating the safe-mode flag to nested
    deserialization calls, restoring the previous value on exit."""

    def __init__(self, safe_mode=True):
        self.safe_mode = safe_mode

    def __enter__(self):
        # Remember whatever flag was active so nested scopes restore
        # correctly when they unwind.
        self.original_value = in_safe_mode()
        global_state.set_global_attribute("safe_mode_saving", self.safe_mode)

    def __exit__(self, *args, **kwargs):
        global_state.set_global_attribute("safe_mode_saving", self.original_value)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
@keras_export("keras.config.enable_unsafe_deserialization")
def enable_unsafe_deserialization():
    """Disables safe mode globally, allowing deserialization of lambdas."""
    # NOTE: with safe mode off, loading a model has the potential to
    # trigger arbitrary code execution — only use on trusted files.
    global_state.set_global_attribute("safe_mode_saving", False)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def in_safe_mode():
    """Return the global safe-mode flag.

    Presumably returns the `global_state` default (likely `None`) when no
    scope or override has set the flag — callers treat that as "use the
    explicit `safe_mode` argument"; verify against `global_state`.
    """
    return global_state.get_global_attribute("safe_mode_saving")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
class ObjectSharingScope:
    """Scope enabling detection and reuse of previously seen objects."""

    def __enter__(self):
        # Install fresh tracking maps for this (de)serialization session.
        for key in (
            "shared_objects/id_to_obj_map",
            "shared_objects/id_to_config_map",
        ):
            global_state.set_global_attribute(key, {})

    def __exit__(self, *args, **kwargs):
        # Tear the maps down so code outside the scope sees no sharing.
        for key in (
            "shared_objects/id_to_obj_map",
            "shared_objects/id_to_config_map",
        ):
            global_state.set_global_attribute(key, None)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def get_shared_object(obj_id):
    """Retrieve an object previously seen during deserialization."""
    obj_map = global_state.get_global_attribute("shared_objects/id_to_obj_map")
    if obj_map is None:
        # Not inside an ObjectSharingScope.
        return None
    return obj_map.get(obj_id, None)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def record_object_after_serialization(obj, config):
    """Call after serializing an object, to keep track of its config."""
    if config["module"] == "__main__":
        config["module"] = None  # Ensures module is None when no module found
    config_map = global_state.get_global_attribute(
        "shared_objects/id_to_config_map"
    )
    if config_map is None:
        return  # Not in a sharing scope
    obj_id = int(id(obj))
    if obj_id in config_map:
        # Seen before: tag both the new and the first-recorded config
        # with the shared id so deserialization can reuse one instance.
        config["shared_object_id"] = obj_id
        config_map[obj_id]["shared_object_id"] = obj_id
    else:
        config_map[obj_id] = config
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def record_object_after_deserialization(obj, obj_id):
    """Call after deserializing an object, to keep track of it in the future."""
    obj_map = global_state.get_global_attribute("shared_objects/id_to_obj_map")
    if obj_map is None:
        return  # Not in a sharing scope
    obj_map[obj_id] = obj
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
@keras_export(
    [
        "keras.saving.serialize_keras_object",
        "keras.utils.serialize_keras_object",
    ]
)
def serialize_keras_object(obj):
    """Retrieve the config dict by serializing the Keras object.

    `serialize_keras_object()` serializes a Keras object to a python dictionary
    that represents the object, and is a reciprocal function of
    `deserialize_keras_object()`. See `deserialize_keras_object()` for more
    information about the config format.

    Args:
        obj: the Keras object to serialize.

    Returns:
        A python dict that represents the object. The python dict can be
        deserialized via `deserialize_keras_object()`.
    """
    # Pass-through for None and JSON-native scalar types.
    if obj is None:
        return obj

    if isinstance(obj, PLAIN_TYPES):
        return obj

    # Containers: recurse element-wise, preserving tuple-ness.
    if isinstance(obj, (list, tuple)):
        config_arr = [serialize_keras_object(x) for x in obj]
        return tuple(config_arr) if isinstance(obj, tuple) else config_arr
    if isinstance(obj, dict):
        return serialize_dict(obj)

    # Special cases:
    if isinstance(obj, bytes):
        return {
            "class_name": "__bytes__",
            "config": {"value": obj.decode("utf-8")},
        }
    if isinstance(obj, slice):
        return {
            "class_name": "__slice__",
            "config": {
                "start": serialize_keras_object(obj.start),
                "stop": serialize_keras_object(obj.stop),
                "step": serialize_keras_object(obj.step),
            },
        }
    # Ellipsis is an instance, and ellipsis class is not in global scope.
    # checking equality also fails elsewhere in the library, so we have
    # to dynamically get the type.
    if isinstance(obj, type(Ellipsis)):
        return {"class_name": "__ellipsis__", "config": {}}
    # Keras tensors serialize to shape/dtype plus (when present) the name
    # of the operation recorded first in their `_keras_history`.
    if isinstance(obj, backend.KerasTensor):
        history = getattr(obj, "_keras_history", None)
        if history:
            history = list(history)
            history[0] = history[0].name
        return {
            "class_name": "__keras_tensor__",
            "config": {
                "shape": obj.shape,
                "dtype": obj.dtype,
                "keras_history": history,
            },
        }
    if tf.available and isinstance(obj, tf.TensorShape):
        return obj.as_list() if obj._dims is not None else None
    # Concrete tensors are materialized to numpy for a portable
    # nested-list representation.
    if backend.is_tensor(obj):
        return {
            "class_name": "__tensor__",
            "config": {
                "value": backend.convert_to_numpy(obj).tolist(),
                "dtype": backend.standardize_dtype(obj.dtype),
            },
        }
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return {
                "class_name": "__numpy__",
                "config": {
                    "value": obj.tolist(),
                    "dtype": backend.standardize_dtype(obj.dtype),
                },
            }
        else:
            # Treat numpy floats / etc as plain types.
            return obj.item()
    if tf.available and isinstance(obj, tf.DType):
        return obj.name
    # Lambdas can only be round-tripped via bytecode dumps, which is
    # unsafe to load; warn loudly before serializing one.
    if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
        warnings.warn(
            "The object being serialized includes a `lambda`. This is unsafe. "
            "In order to reload the object, you will have to pass "
            "`safe_mode=False` to the loading function. "
            "Please avoid using `lambda` in the "
            "future, and use named Python functions instead. "
            f"This is the `lambda` being serialized: {inspect.getsource(obj)}",
            stacklevel=2,
        )
        return {
            "class_name": "__lambda__",
            "config": {
                "value": python_utils.func_dump(obj),
            },
        }
    if tf.available and isinstance(obj, tf.TypeSpec):
        ts_config = obj._serialize()
        # TensorShape and tf.DType conversion
        ts_config = list(
            map(
                lambda x: (
                    x.as_list()
                    if isinstance(x, tf.TensorShape)
                    else (x.name if isinstance(x, tf.DType) else x)
                ),
                ts_config,
            )
        )
        return {
            "class_name": "__typespec__",
            "spec_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
            "config": ts_config,
            "registered_name": None,
        }

    # Fallback: a genuine Keras object (class instance or function).
    inner_config = _get_class_or_fn_config(obj)
    config_with_public_class = serialize_with_public_class(
        obj.__class__, inner_config
    )

    if config_with_public_class is not None:
        get_build_and_compile_config(obj, config_with_public_class)
        record_object_after_serialization(obj, config_with_public_class)
        return config_with_public_class

    # Any custom object or otherwise non-exported object
    if isinstance(obj, types.FunctionType):
        module = obj.__module__
    else:
        module = obj.__class__.__module__
    class_name = obj.__class__.__name__

    if module == "builtins":
        registered_name = None
    else:
        if isinstance(obj, types.FunctionType):
            registered_name = object_registration.get_registered_name(obj)
        else:
            registered_name = object_registration.get_registered_name(
                obj.__class__
            )

    config = {
        "module": module,
        "class_name": class_name,
        "config": inner_config,
        "registered_name": registered_name,
    }
    get_build_and_compile_config(obj, config)
    record_object_after_serialization(obj, config)
    return config
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def get_build_and_compile_config(obj, config):
    """Attach serialized build/compile configs to `config` when available."""
    build_config = (
        obj.get_build_config() if hasattr(obj, "get_build_config") else None
    )
    if build_config is not None:
        config["build_config"] = serialize_dict(build_config)
    compile_config = (
        obj.get_compile_config() if hasattr(obj, "get_compile_config") else None
    )
    if compile_config is not None:
        config["compile_config"] = serialize_dict(compile_config)
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def serialize_with_public_class(cls, inner_config=None):
    """Serializes classes from public Keras API or object registration.

    Called to check and retrieve the config of any class that has a public
    Keras API or has been registered as serializable via
    `keras.saving.register_keras_serializable()`.
    """
    # This gets the `keras.*` exported name, such as
    # "keras.optimizers.Adam".
    keras_api_name = api_export.get_name_from_symbol(cls)

    if keras_api_name is not None:
        # Public class: split the canonical name into module + class.
        module_path, _, class_name = keras_api_name.rpartition(".")
        return {
            "module": module_path,
            "class_name": class_name,
            "config": inner_config,
            "registered_name": None,
        }

    # Case of custom or unknown class object.
    registered_name = object_registration.get_registered_name(cls)
    if registered_name is None:
        return None
    # Return custom object config with corresponding registration name.
    return {
        "module": cls.__module__,
        "class_name": cls.__name__,
        "config": inner_config,
        "registered_name": registered_name,
    }
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def serialize_with_public_fn(fn, config, fn_module_name=None):
    """Serializes functions from public Keras API or object registration.

    Called to check and retrieve the config of any function that has a public
    Keras API or has been registered as serializable via
    `keras.saving.register_keras_serializable()`. If the function's module
    name is already known, the corresponding config is returned directly.
    """
    if fn_module_name:
        return {
            "module": fn_module_name,
            "class_name": "function",
            "config": config,
            "registered_name": config,
        }
    keras_api_name = api_export.get_name_from_symbol(fn)
    if keras_api_name:
        # Public function: use its canonical keras module path.
        module_path, _, _ = keras_api_name.rpartition(".")
        return {
            "module": module_path,
            "class_name": "function",
            "config": config,
            "registered_name": config,
        }
    registered_name = object_registration.get_registered_name(fn)
    if not registered_name and fn.__module__ != "builtins":
        return None
    return {
        "module": fn.__module__,
        "class_name": "function",
        "config": config,
        "registered_name": registered_name,
    }
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def _get_class_or_fn_config(obj):
|
| 366 |
+
"""Return the object's config depending on its type."""
|
| 367 |
+
# Functions / lambdas:
|
| 368 |
+
if isinstance(obj, types.FunctionType):
|
| 369 |
+
return object_registration.get_registered_name(obj)
|
| 370 |
+
# All classes:
|
| 371 |
+
if hasattr(obj, "get_config"):
|
| 372 |
+
config = obj.get_config()
|
| 373 |
+
if not isinstance(config, dict):
|
| 374 |
+
raise TypeError(
|
| 375 |
+
f"The `get_config()` method of {obj} should return "
|
| 376 |
+
f"a dict. It returned: {config}"
|
| 377 |
+
)
|
| 378 |
+
return serialize_dict(config)
|
| 379 |
+
elif hasattr(obj, "__name__"):
|
| 380 |
+
return object_registration.get_registered_name(obj)
|
| 381 |
+
else:
|
| 382 |
+
raise TypeError(
|
| 383 |
+
f"Cannot serialize object {obj} of type {type(obj)}. "
|
| 384 |
+
"To be serializable, "
|
| 385 |
+
"a class must implement the `get_config()` method."
|
| 386 |
+
)
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def serialize_dict(obj):
    """Serialize every value in `obj`, keeping the keys untouched."""
    return {k: serialize_keras_object(v) for k, v in obj.items()}
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@keras_export(
|
| 394 |
+
[
|
| 395 |
+
"keras.saving.deserialize_keras_object",
|
| 396 |
+
"keras.utils.deserialize_keras_object",
|
| 397 |
+
]
|
| 398 |
+
)
|
| 399 |
+
def deserialize_keras_object(
|
| 400 |
+
config, custom_objects=None, safe_mode=True, **kwargs
|
| 401 |
+
):
|
| 402 |
+
"""Retrieve the object by deserializing the config dict.
|
| 403 |
+
|
| 404 |
+
The config dict is a Python dictionary that consists of a set of key-value
|
| 405 |
+
pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
|
| 406 |
+
`Metrics`, etc. The saving and loading library uses the following keys to
|
| 407 |
+
record information of a Keras object:
|
| 408 |
+
|
| 409 |
+
- `class_name`: String. This is the name of the class,
|
| 410 |
+
as exactly defined in the source
|
| 411 |
+
code, such as "LossesContainer".
|
| 412 |
+
- `config`: Dict. Library-defined or user-defined key-value pairs that store
|
| 413 |
+
the configuration of the object, as obtained by `object.get_config()`.
|
| 414 |
+
- `module`: String. The path of the python module. Built-in Keras classes
|
| 415 |
+
expect to have prefix `keras`.
|
| 416 |
+
- `registered_name`: String. The key the class is registered under via
|
| 417 |
+
`keras.saving.register_keras_serializable(package, name)` API. The
|
| 418 |
+
key has the format of '{package}>{name}', where `package` and `name` are
|
| 419 |
+
the arguments passed to `register_keras_serializable()`. If `name` is not
|
| 420 |
+
provided, it uses the class name. If `registered_name` successfully
|
| 421 |
+
resolves to a class (that was registered), the `class_name` and `config`
|
| 422 |
+
values in the dict will not be used. `registered_name` is only used for
|
| 423 |
+
non-built-in classes.
|
| 424 |
+
|
| 425 |
+
For example, the following dictionary represents the built-in Adam optimizer
|
| 426 |
+
with the relevant config:
|
| 427 |
+
|
| 428 |
+
```python
|
| 429 |
+
dict_structure = {
|
| 430 |
+
"class_name": "Adam",
|
| 431 |
+
"config": {
|
| 432 |
+
"amsgrad": false,
|
| 433 |
+
"beta_1": 0.8999999761581421,
|
| 434 |
+
"beta_2": 0.9990000128746033,
|
| 435 |
+
"decay": 0.0,
|
| 436 |
+
"epsilon": 1e-07,
|
| 437 |
+
"learning_rate": 0.0010000000474974513,
|
| 438 |
+
"name": "Adam"
|
| 439 |
+
},
|
| 440 |
+
"module": "keras.optimizers",
|
| 441 |
+
"registered_name": None
|
| 442 |
+
}
|
| 443 |
+
# Returns an `Adam` instance identical to the original one.
|
| 444 |
+
deserialize_keras_object(dict_structure)
|
| 445 |
+
```
|
| 446 |
+
|
| 447 |
+
If the class does not have an exported Keras namespace, the library tracks
|
| 448 |
+
it by its `module` and `class_name`. For example:
|
| 449 |
+
|
| 450 |
+
```python
|
| 451 |
+
dict_structure = {
|
| 452 |
+
"class_name": "MetricsList",
|
| 453 |
+
"config": {
|
| 454 |
+
...
|
| 455 |
+
},
|
| 456 |
+
"module": "keras.trainers.compile_utils",
|
| 457 |
+
"registered_name": "MetricsList"
|
| 458 |
+
}
|
| 459 |
+
|
| 460 |
+
# Returns a `MetricsList` instance identical to the original one.
|
| 461 |
+
deserialize_keras_object(dict_structure)
|
| 462 |
+
```
|
| 463 |
+
|
| 464 |
+
And the following dictionary represents a user-customized `MeanSquaredError`
|
| 465 |
+
loss:
|
| 466 |
+
|
| 467 |
+
```python
|
| 468 |
+
@keras.saving.register_keras_serializable(package='my_package')
|
| 469 |
+
class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
|
| 470 |
+
...
|
| 471 |
+
|
| 472 |
+
dict_structure = {
|
| 473 |
+
"class_name": "ModifiedMeanSquaredError",
|
| 474 |
+
"config": {
|
| 475 |
+
"fn": "mean_squared_error",
|
| 476 |
+
"name": "mean_squared_error",
|
| 477 |
+
"reduction": "auto"
|
| 478 |
+
},
|
| 479 |
+
"registered_name": "my_package>ModifiedMeanSquaredError"
|
| 480 |
+
}
|
| 481 |
+
# Returns the `ModifiedMeanSquaredError` object
|
| 482 |
+
deserialize_keras_object(dict_structure)
|
| 483 |
+
```
|
| 484 |
+
|
| 485 |
+
Args:
|
| 486 |
+
config: Python dict describing the object.
|
| 487 |
+
custom_objects: Python dict containing a mapping between custom
|
| 488 |
+
object names the corresponding classes or functions.
|
| 489 |
+
safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
|
| 490 |
+
When `safe_mode=False`, loading an object has the potential to
|
| 491 |
+
trigger arbitrary code execution. This argument is only
|
| 492 |
+
applicable to the Keras v3 model format. Defaults to `True`.
|
| 493 |
+
|
| 494 |
+
Returns:
|
| 495 |
+
The object described by the `config` dictionary.
|
| 496 |
+
"""
|
| 497 |
+
safe_scope_arg = in_safe_mode() # Enforces SafeModeScope
|
| 498 |
+
safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode
|
| 499 |
+
|
| 500 |
+
module_objects = kwargs.pop("module_objects", None)
|
| 501 |
+
custom_objects = custom_objects or {}
|
| 502 |
+
tlco = global_state.get_global_attribute("custom_objects_scope_dict", {})
|
| 503 |
+
gco = object_registration.GLOBAL_CUSTOM_OBJECTS
|
| 504 |
+
custom_objects = {**custom_objects, **tlco, **gco}
|
| 505 |
+
|
| 506 |
+
if config is None:
|
| 507 |
+
return None
|
| 508 |
+
|
| 509 |
+
if (
|
| 510 |
+
isinstance(config, str)
|
| 511 |
+
and custom_objects
|
| 512 |
+
and custom_objects.get(config) is not None
|
| 513 |
+
):
|
| 514 |
+
# This is to deserialize plain functions which are serialized as
|
| 515 |
+
# string names by legacy saving formats.
|
| 516 |
+
return custom_objects[config]
|
| 517 |
+
|
| 518 |
+
if isinstance(config, (list, tuple)):
|
| 519 |
+
return [
|
| 520 |
+
deserialize_keras_object(
|
| 521 |
+
x, custom_objects=custom_objects, safe_mode=safe_mode
|
| 522 |
+
)
|
| 523 |
+
for x in config
|
| 524 |
+
]
|
| 525 |
+
|
| 526 |
+
if module_objects is not None:
|
| 527 |
+
inner_config, fn_module_name, has_custom_object = None, None, False
|
| 528 |
+
|
| 529 |
+
if isinstance(config, dict):
|
| 530 |
+
if "config" in config:
|
| 531 |
+
inner_config = config["config"]
|
| 532 |
+
if "class_name" not in config:
|
| 533 |
+
raise ValueError(
|
| 534 |
+
f"Unknown `config` as a `dict`, config={config}"
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
# Check case where config is function or class and in custom objects
|
| 538 |
+
if custom_objects and (
|
| 539 |
+
config["class_name"] in custom_objects
|
| 540 |
+
or config.get("registered_name") in custom_objects
|
| 541 |
+
or (
|
| 542 |
+
isinstance(inner_config, str)
|
| 543 |
+
and inner_config in custom_objects
|
| 544 |
+
)
|
| 545 |
+
):
|
| 546 |
+
has_custom_object = True
|
| 547 |
+
|
| 548 |
+
# Case where config is function but not in custom objects
|
| 549 |
+
elif config["class_name"] == "function":
|
| 550 |
+
fn_module_name = config["module"]
|
| 551 |
+
if fn_module_name == "builtins":
|
| 552 |
+
config = config["config"]
|
| 553 |
+
else:
|
| 554 |
+
config = config["registered_name"]
|
| 555 |
+
|
| 556 |
+
# Case where config is class but not in custom objects
|
| 557 |
+
else:
|
| 558 |
+
if config.get("module", "_") is None:
|
| 559 |
+
raise TypeError(
|
| 560 |
+
"Cannot deserialize object of type "
|
| 561 |
+
f"`{config['class_name']}`. If "
|
| 562 |
+
f"`{config['class_name']}` is a custom class, please "
|
| 563 |
+
"register it using the "
|
| 564 |
+
"`@keras.saving.register_keras_serializable()` "
|
| 565 |
+
"decorator."
|
| 566 |
+
)
|
| 567 |
+
config = config["class_name"]
|
| 568 |
+
|
| 569 |
+
if not has_custom_object:
|
| 570 |
+
# Return if not found in either module objects or custom objects
|
| 571 |
+
if config not in module_objects:
|
| 572 |
+
# Object has already been deserialized
|
| 573 |
+
return config
|
| 574 |
+
if isinstance(module_objects[config], types.FunctionType):
|
| 575 |
+
return deserialize_keras_object(
|
| 576 |
+
serialize_with_public_fn(
|
| 577 |
+
module_objects[config], config, fn_module_name
|
| 578 |
+
),
|
| 579 |
+
custom_objects=custom_objects,
|
| 580 |
+
)
|
| 581 |
+
return deserialize_keras_object(
|
| 582 |
+
serialize_with_public_class(
|
| 583 |
+
module_objects[config], inner_config=inner_config
|
| 584 |
+
),
|
| 585 |
+
custom_objects=custom_objects,
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
if isinstance(config, PLAIN_TYPES):
|
| 589 |
+
return config
|
| 590 |
+
if not isinstance(config, dict):
|
| 591 |
+
raise TypeError(f"Could not parse config: {config}")
|
| 592 |
+
|
| 593 |
+
if "class_name" not in config or "config" not in config:
|
| 594 |
+
return {
|
| 595 |
+
key: deserialize_keras_object(
|
| 596 |
+
value, custom_objects=custom_objects, safe_mode=safe_mode
|
| 597 |
+
)
|
| 598 |
+
for key, value in config.items()
|
| 599 |
+
}
|
| 600 |
+
|
| 601 |
+
class_name = config["class_name"]
|
| 602 |
+
inner_config = config["config"] or {}
|
| 603 |
+
custom_objects = custom_objects or {}
|
| 604 |
+
|
| 605 |
+
# Special cases:
|
| 606 |
+
if class_name == "__keras_tensor__":
|
| 607 |
+
obj = backend.KerasTensor(
|
| 608 |
+
inner_config["shape"], dtype=inner_config["dtype"]
|
| 609 |
+
)
|
| 610 |
+
obj._pre_serialization_keras_history = inner_config["keras_history"]
|
| 611 |
+
return obj
|
| 612 |
+
|
| 613 |
+
if class_name == "__tensor__":
|
| 614 |
+
return backend.convert_to_tensor(
|
| 615 |
+
inner_config["value"], dtype=inner_config["dtype"]
|
| 616 |
+
)
|
| 617 |
+
if class_name == "__numpy__":
|
| 618 |
+
return np.array(inner_config["value"], dtype=inner_config["dtype"])
|
| 619 |
+
if config["class_name"] == "__bytes__":
|
| 620 |
+
return inner_config["value"].encode("utf-8")
|
| 621 |
+
if config["class_name"] == "__ellipsis__":
|
| 622 |
+
return Ellipsis
|
| 623 |
+
if config["class_name"] == "__slice__":
|
| 624 |
+
return slice(
|
| 625 |
+
deserialize_keras_object(
|
| 626 |
+
inner_config["start"],
|
| 627 |
+
custom_objects=custom_objects,
|
| 628 |
+
safe_mode=safe_mode,
|
| 629 |
+
),
|
| 630 |
+
deserialize_keras_object(
|
| 631 |
+
inner_config["stop"],
|
| 632 |
+
custom_objects=custom_objects,
|
| 633 |
+
safe_mode=safe_mode,
|
| 634 |
+
),
|
| 635 |
+
deserialize_keras_object(
|
| 636 |
+
inner_config["step"],
|
| 637 |
+
custom_objects=custom_objects,
|
| 638 |
+
safe_mode=safe_mode,
|
| 639 |
+
),
|
| 640 |
+
)
|
| 641 |
+
if config["class_name"] == "__lambda__":
|
| 642 |
+
if safe_mode:
|
| 643 |
+
raise ValueError(
|
| 644 |
+
"Requested the deserialization of a `lambda` object. "
|
| 645 |
+
"This carries a potential risk of arbitrary code execution "
|
| 646 |
+
"and thus it is disallowed by default. If you trust the "
|
| 647 |
+
"source of the saved model, you can pass `safe_mode=False` to "
|
| 648 |
+
"the loading function in order to allow `lambda` loading, "
|
| 649 |
+
"or call `keras.config.enable_unsafe_deserialization()`."
|
| 650 |
+
)
|
| 651 |
+
return python_utils.func_load(inner_config["value"])
|
| 652 |
+
if tf is not None and config["class_name"] == "__typespec__":
|
| 653 |
+
obj = _retrieve_class_or_fn(
|
| 654 |
+
config["spec_name"],
|
| 655 |
+
config["registered_name"],
|
| 656 |
+
config["module"],
|
| 657 |
+
obj_type="class",
|
| 658 |
+
full_config=config,
|
| 659 |
+
custom_objects=custom_objects,
|
| 660 |
+
)
|
| 661 |
+
# Conversion to TensorShape and DType
|
| 662 |
+
inner_config = map(
|
| 663 |
+
lambda x: (
|
| 664 |
+
tf.TensorShape(x)
|
| 665 |
+
if isinstance(x, list)
|
| 666 |
+
else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x)
|
| 667 |
+
),
|
| 668 |
+
inner_config,
|
| 669 |
+
)
|
| 670 |
+
return obj._deserialize(tuple(inner_config))
|
| 671 |
+
|
| 672 |
+
# Below: classes and functions.
|
| 673 |
+
module = config.get("module", None)
|
| 674 |
+
registered_name = config.get("registered_name", class_name)
|
| 675 |
+
|
| 676 |
+
if class_name == "function":
|
| 677 |
+
fn_name = inner_config
|
| 678 |
+
return _retrieve_class_or_fn(
|
| 679 |
+
fn_name,
|
| 680 |
+
registered_name,
|
| 681 |
+
module,
|
| 682 |
+
obj_type="function",
|
| 683 |
+
full_config=config,
|
| 684 |
+
custom_objects=custom_objects,
|
| 685 |
+
)
|
| 686 |
+
|
| 687 |
+
# Below, handling of all classes.
|
| 688 |
+
# First, is it a shared object?
|
| 689 |
+
if "shared_object_id" in config:
|
| 690 |
+
obj = get_shared_object(config["shared_object_id"])
|
| 691 |
+
if obj is not None:
|
| 692 |
+
return obj
|
| 693 |
+
|
| 694 |
+
cls = _retrieve_class_or_fn(
|
| 695 |
+
class_name,
|
| 696 |
+
registered_name,
|
| 697 |
+
module,
|
| 698 |
+
obj_type="class",
|
| 699 |
+
full_config=config,
|
| 700 |
+
custom_objects=custom_objects,
|
| 701 |
+
)
|
| 702 |
+
|
| 703 |
+
if isinstance(cls, types.FunctionType):
|
| 704 |
+
return cls
|
| 705 |
+
if not hasattr(cls, "from_config"):
|
| 706 |
+
raise TypeError(
|
| 707 |
+
f"Unable to reconstruct an instance of '{class_name}' because "
|
| 708 |
+
f"the class is missing a `from_config()` method. "
|
| 709 |
+
f"Full object config: {config}"
|
| 710 |
+
)
|
| 711 |
+
|
| 712 |
+
# Instantiate the class from its config inside a custom object scope
|
| 713 |
+
# so that we can catch any custom objects that the config refers to.
|
| 714 |
+
custom_obj_scope = object_registration.CustomObjectScope(custom_objects)
|
| 715 |
+
safe_mode_scope = SafeModeScope(safe_mode)
|
| 716 |
+
with custom_obj_scope, safe_mode_scope:
|
| 717 |
+
try:
|
| 718 |
+
instance = cls.from_config(inner_config)
|
| 719 |
+
except TypeError as e:
|
| 720 |
+
raise TypeError(
|
| 721 |
+
f"{cls} could not be deserialized properly. Please"
|
| 722 |
+
" ensure that components that are Python object"
|
| 723 |
+
" instances (layers, models, etc.) returned by"
|
| 724 |
+
" `get_config()` are explicitly deserialized in the"
|
| 725 |
+
" model's `from_config()` method."
|
| 726 |
+
f"\n\nconfig={config}.\n\nException encountered: {e}"
|
| 727 |
+
)
|
| 728 |
+
build_config = config.get("build_config", None)
|
| 729 |
+
if build_config and not instance.built:
|
| 730 |
+
instance.build_from_config(build_config)
|
| 731 |
+
instance.built = True
|
| 732 |
+
compile_config = config.get("compile_config", None)
|
| 733 |
+
if compile_config:
|
| 734 |
+
instance.compile_from_config(compile_config)
|
| 735 |
+
instance.compiled = True
|
| 736 |
+
|
| 737 |
+
if "shared_object_id" in config:
|
| 738 |
+
record_object_after_deserialization(
|
| 739 |
+
instance, config["shared_object_id"]
|
| 740 |
+
)
|
| 741 |
+
return instance
|
| 742 |
+
|
| 743 |
+
|
| 744 |
+
def _retrieve_class_or_fn(
|
| 745 |
+
name, registered_name, module, obj_type, full_config, custom_objects=None
|
| 746 |
+
):
|
| 747 |
+
# If there is a custom object registered via
|
| 748 |
+
# `register_keras_serializable()`, that takes precedence.
|
| 749 |
+
if obj_type == "function":
|
| 750 |
+
custom_obj = object_registration.get_registered_object(
|
| 751 |
+
name, custom_objects=custom_objects
|
| 752 |
+
)
|
| 753 |
+
else:
|
| 754 |
+
custom_obj = object_registration.get_registered_object(
|
| 755 |
+
registered_name, custom_objects=custom_objects
|
| 756 |
+
)
|
| 757 |
+
if custom_obj is not None:
|
| 758 |
+
return custom_obj
|
| 759 |
+
|
| 760 |
+
if module:
|
| 761 |
+
# If it's a Keras built-in object,
|
| 762 |
+
# we cannot always use direct import, because the exported
|
| 763 |
+
# module name might not match the package structure
|
| 764 |
+
# (e.g. experimental symbols).
|
| 765 |
+
if module == "keras" or module.startswith("keras."):
|
| 766 |
+
api_name = module + "." + name
|
| 767 |
+
|
| 768 |
+
obj = api_export.get_symbol_from_name(api_name)
|
| 769 |
+
if obj is not None:
|
| 770 |
+
return obj
|
| 771 |
+
|
| 772 |
+
# Configs of Keras built-in functions do not contain identifying
|
| 773 |
+
# information other than their name (e.g. 'acc' or 'tanh'). This special
|
| 774 |
+
# case searches the Keras modules that contain built-ins to retrieve
|
| 775 |
+
# the corresponding function from the identifying string.
|
| 776 |
+
if obj_type == "function" and module == "builtins":
|
| 777 |
+
for mod in BUILTIN_MODULES:
|
| 778 |
+
obj = api_export.get_symbol_from_name(
|
| 779 |
+
"keras." + mod + "." + name
|
| 780 |
+
)
|
| 781 |
+
if obj is not None:
|
| 782 |
+
return obj
|
| 783 |
+
|
| 784 |
+
# Otherwise, attempt to retrieve the class object given the `module`
|
| 785 |
+
# and `class_name`. Import the module, find the class.
|
| 786 |
+
try:
|
| 787 |
+
mod = importlib.import_module(module)
|
| 788 |
+
except ModuleNotFoundError:
|
| 789 |
+
raise TypeError(
|
| 790 |
+
f"Could not deserialize {obj_type} '{name}' because "
|
| 791 |
+
f"its parent module {module} cannot be imported. "
|
| 792 |
+
f"Full object config: {full_config}"
|
| 793 |
+
)
|
| 794 |
+
obj = vars(mod).get(name, None)
|
| 795 |
+
|
| 796 |
+
# Special case for keras.metrics.metrics
|
| 797 |
+
if obj is None and registered_name is not None:
|
| 798 |
+
obj = vars(mod).get(registered_name, None)
|
| 799 |
+
|
| 800 |
+
if obj is not None:
|
| 801 |
+
return obj
|
| 802 |
+
|
| 803 |
+
raise TypeError(
|
| 804 |
+
f"Could not locate {obj_type} '{name}'. "
|
| 805 |
+
"Make sure custom classes are decorated with "
|
| 806 |
+
"`@keras.saving.register_keras_serializable()`. "
|
| 807 |
+
f"Full object config: {full_config}"
|
| 808 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.testing.test_case import TestCase
|
| 2 |
+
from keras.src.testing.test_case import jax_uses_gpu
|
| 3 |
+
from keras.src.testing.test_case import tensorflow_uses_gpu
|
| 4 |
+
from keras.src.testing.test_case import torch_uses_gpu
|
| 5 |
+
from keras.src.testing.test_case import uses_gpu
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (398 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_case.cpython-310.pyc
ADDED
|
Binary file (22.5 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_utils.cpython-310.pyc
ADDED
|
Binary file (5.43 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py
ADDED
|
@@ -0,0 +1,796 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import shutil
|
| 3 |
+
import tempfile
|
| 4 |
+
import unittest
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
from absl.testing import parameterized
|
| 9 |
+
|
| 10 |
+
from keras.src import backend
|
| 11 |
+
from keras.src import distribution
|
| 12 |
+
from keras.src import ops
|
| 13 |
+
from keras.src import tree
|
| 14 |
+
from keras.src import utils
|
| 15 |
+
from keras.src.backend.common import is_float_dtype
|
| 16 |
+
from keras.src.backend.common import standardize_dtype
|
| 17 |
+
from keras.src.backend.common.global_state import clear_session
|
| 18 |
+
from keras.src.backend.common.keras_tensor import KerasTensor
|
| 19 |
+
from keras.src.models import Model
|
| 20 |
+
from keras.src.utils import traceback_utils
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class TestCase(parameterized.TestCase, unittest.TestCase):
|
| 24 |
+
maxDiff = None
|
| 25 |
+
|
| 26 |
+
def __init__(self, *args, **kwargs):
|
| 27 |
+
super().__init__(*args, **kwargs)
|
| 28 |
+
|
| 29 |
+
def setUp(self):
|
| 30 |
+
# clear global state so that test cases are independent
|
| 31 |
+
# required for the jit enabled torch tests since dynamo has
|
| 32 |
+
# a global cache for guards, compiled fn, etc
|
| 33 |
+
clear_session(free_memory=False)
|
| 34 |
+
if traceback_utils.is_traceback_filtering_enabled():
|
| 35 |
+
traceback_utils.disable_traceback_filtering()
|
| 36 |
+
|
| 37 |
+
def get_temp_dir(self):
|
| 38 |
+
temp_dir = tempfile.mkdtemp()
|
| 39 |
+
self.addCleanup(lambda: shutil.rmtree(temp_dir))
|
| 40 |
+
return temp_dir
|
| 41 |
+
|
| 42 |
+
def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
|
| 43 |
+
if not isinstance(x1, np.ndarray):
|
| 44 |
+
x1 = backend.convert_to_numpy(x1)
|
| 45 |
+
if not isinstance(x2, np.ndarray):
|
| 46 |
+
x2 = backend.convert_to_numpy(x2)
|
| 47 |
+
np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg)
|
| 48 |
+
|
| 49 |
+
def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
|
| 50 |
+
try:
|
| 51 |
+
self.assertAllClose(x1, x2, atol=atol, rtol=rtol, msg=msg)
|
| 52 |
+
except AssertionError:
|
| 53 |
+
return
|
| 54 |
+
msg = msg or ""
|
| 55 |
+
raise AssertionError(
|
| 56 |
+
f"The two values are close at all elements. \n"
|
| 57 |
+
f"{msg}.\n"
|
| 58 |
+
f"Values: {x1}"
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
|
| 62 |
+
msg = msg or ""
|
| 63 |
+
if not isinstance(x1, np.ndarray):
|
| 64 |
+
x1 = backend.convert_to_numpy(x1)
|
| 65 |
+
if not isinstance(x2, np.ndarray):
|
| 66 |
+
x2 = backend.convert_to_numpy(x2)
|
| 67 |
+
np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg)
|
| 68 |
+
|
| 69 |
+
def assertAllEqual(self, x1, x2, msg=None):
|
| 70 |
+
self.assertEqual(len(x1), len(x2), msg=msg)
|
| 71 |
+
for e1, e2 in zip(x1, x2):
|
| 72 |
+
if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)):
|
| 73 |
+
self.assertAllEqual(e1, e2, msg=msg)
|
| 74 |
+
else:
|
| 75 |
+
e1 = backend.convert_to_numpy(e1)
|
| 76 |
+
e2 = backend.convert_to_numpy(e2)
|
| 77 |
+
self.assertEqual(e1, e2, msg=msg)
|
| 78 |
+
|
| 79 |
+
def assertLen(self, iterable, expected_len, msg=None):
|
| 80 |
+
self.assertEqual(len(iterable), expected_len, msg=msg)
|
| 81 |
+
|
| 82 |
+
def assertSparse(self, x, sparse=True):
|
| 83 |
+
if isinstance(x, KerasTensor):
|
| 84 |
+
self.assertEqual(x.sparse, sparse)
|
| 85 |
+
elif backend.backend() == "tensorflow":
|
| 86 |
+
import tensorflow as tf
|
| 87 |
+
|
| 88 |
+
if sparse:
|
| 89 |
+
self.assertIsInstance(x, tf.SparseTensor)
|
| 90 |
+
else:
|
| 91 |
+
self.assertNotIsInstance(x, tf.SparseTensor)
|
| 92 |
+
elif backend.backend() == "jax":
|
| 93 |
+
import jax.experimental.sparse as jax_sparse
|
| 94 |
+
|
| 95 |
+
if sparse:
|
| 96 |
+
self.assertIsInstance(x, jax_sparse.JAXSparse)
|
| 97 |
+
else:
|
| 98 |
+
self.assertNotIsInstance(x, jax_sparse.JAXSparse)
|
| 99 |
+
else:
|
| 100 |
+
self.assertFalse(
|
| 101 |
+
sparse,
|
| 102 |
+
f"Backend {backend.backend()} does not support sparse tensors",
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
def assertDType(self, x, dtype, msg=None):
|
| 106 |
+
if hasattr(x, "dtype"):
|
| 107 |
+
x_dtype = backend.standardize_dtype(x.dtype)
|
| 108 |
+
else:
|
| 109 |
+
# If x is a python number
|
| 110 |
+
x_dtype = backend.standardize_dtype(type(x))
|
| 111 |
+
standardized_dtype = backend.standardize_dtype(dtype)
|
| 112 |
+
default_msg = (
|
| 113 |
+
"The dtype of x does not match the expected one. "
|
| 114 |
+
f"Received: x.dtype={x_dtype} and dtype={dtype}"
|
| 115 |
+
)
|
| 116 |
+
msg = msg or default_msg
|
| 117 |
+
self.assertEqual(x_dtype, standardized_dtype, msg=msg)
|
| 118 |
+
|
| 119 |
+
def assertFileExists(self, path):
|
| 120 |
+
if not Path(path).is_file():
|
| 121 |
+
raise AssertionError(f"File {path} does not exist")
|
| 122 |
+
|
| 123 |
+
def run_class_serialization_test(self, instance, custom_objects=None):
|
| 124 |
+
from keras.src.saving import custom_object_scope
|
| 125 |
+
from keras.src.saving import deserialize_keras_object
|
| 126 |
+
from keras.src.saving import serialize_keras_object
|
| 127 |
+
|
| 128 |
+
# get_config roundtrip
|
| 129 |
+
cls = instance.__class__
|
| 130 |
+
config = instance.get_config()
|
| 131 |
+
config_json = to_json_with_tuples(config)
|
| 132 |
+
ref_dir = dir(instance)[:]
|
| 133 |
+
with custom_object_scope(custom_objects):
|
| 134 |
+
revived_instance = cls.from_config(config)
|
| 135 |
+
revived_config = revived_instance.get_config()
|
| 136 |
+
revived_config_json = to_json_with_tuples(revived_config)
|
| 137 |
+
self.assertEqual(config_json, revived_config_json)
|
| 138 |
+
self.assertEqual(set(ref_dir), set(dir(revived_instance)))
|
| 139 |
+
|
| 140 |
+
# serialization roundtrip
|
| 141 |
+
serialized = serialize_keras_object(instance)
|
| 142 |
+
serialized_json = to_json_with_tuples(serialized)
|
| 143 |
+
with custom_object_scope(custom_objects):
|
| 144 |
+
revived_instance = deserialize_keras_object(
|
| 145 |
+
from_json_with_tuples(serialized_json)
|
| 146 |
+
)
|
| 147 |
+
revived_config = revived_instance.get_config()
|
| 148 |
+
revived_config_json = to_json_with_tuples(revived_config)
|
| 149 |
+
self.assertEqual(config_json, revived_config_json)
|
| 150 |
+
new_dir = dir(revived_instance)[:]
|
| 151 |
+
for lst in [ref_dir, new_dir]:
|
| 152 |
+
if "__annotations__" in lst:
|
| 153 |
+
lst.remove("__annotations__")
|
| 154 |
+
self.assertEqual(set(ref_dir), set(new_dir))
|
| 155 |
+
return revived_instance
|
| 156 |
+
|
| 157 |
+
def run_layer_test(
|
| 158 |
+
self,
|
| 159 |
+
layer_cls,
|
| 160 |
+
init_kwargs,
|
| 161 |
+
input_shape=None,
|
| 162 |
+
input_dtype=None,
|
| 163 |
+
input_sparse=False,
|
| 164 |
+
input_data=None,
|
| 165 |
+
call_kwargs=None,
|
| 166 |
+
expected_output_shape=None,
|
| 167 |
+
expected_output_dtype=None,
|
| 168 |
+
expected_output_sparse=False,
|
| 169 |
+
expected_output=None,
|
| 170 |
+
expected_num_trainable_weights=None,
|
| 171 |
+
expected_num_non_trainable_weights=None,
|
| 172 |
+
expected_num_non_trainable_variables=None,
|
| 173 |
+
expected_num_seed_generators=None,
|
| 174 |
+
expected_num_losses=None,
|
| 175 |
+
supports_masking=None,
|
| 176 |
+
expected_mask_shape=None,
|
| 177 |
+
custom_objects=None,
|
| 178 |
+
run_training_check=True,
|
| 179 |
+
run_mixed_precision_check=True,
|
| 180 |
+
assert_built_after_instantiation=False,
|
| 181 |
+
):
|
| 182 |
+
"""Run basic checks on a layer.
|
| 183 |
+
|
| 184 |
+
Args:
|
| 185 |
+
layer_cls: The class of the layer to test.
|
| 186 |
+
init_kwargs: Dict of arguments to be used to
|
| 187 |
+
instantiate the layer.
|
| 188 |
+
input_shape: Shape tuple (or list/dict of shape tuples)
|
| 189 |
+
to call the layer on.
|
| 190 |
+
input_dtype: Corresponding input dtype.
|
| 191 |
+
input_sparse: Whether the input is a sparse tensor (this requires
|
| 192 |
+
the backend to support sparse tensors).
|
| 193 |
+
input_data: Tensor (or list/dict of tensors)
|
| 194 |
+
to call the layer on.
|
| 195 |
+
call_kwargs: Dict of arguments to use when calling the
|
| 196 |
+
layer (does not include the first input tensor argument)
|
| 197 |
+
expected_output_shape: Shape tuple
|
| 198 |
+
(or list/dict of shape tuples)
|
| 199 |
+
expected as output.
|
| 200 |
+
expected_output_dtype: dtype expected as output.
|
| 201 |
+
expected_output_sparse: Whether the output is expected to be sparse
|
| 202 |
+
(this requires the backend to support sparse tensors).
|
| 203 |
+
expected_output: Expected output tensor -- only
|
| 204 |
+
to be specified if input_data is provided.
|
| 205 |
+
expected_num_trainable_weights: Expected number
|
| 206 |
+
of trainable weights of the layer once built.
|
| 207 |
+
expected_num_non_trainable_weights: Expected number
|
| 208 |
+
of non-trainable weights of the layer once built.
|
| 209 |
+
expected_num_seed_generators: Expected number of
|
| 210 |
+
SeedGenerators objects of the layer once built.
|
| 211 |
+
expected_num_losses: Expected number of loss tensors
|
| 212 |
+
produced when calling the layer.
|
| 213 |
+
supports_masking: If True, will check that the layer
|
| 214 |
+
supports masking.
|
| 215 |
+
expected_mask_shape: Expected mask shape tuple
|
| 216 |
+
returned by compute_mask() (only supports 1 shape).
|
| 217 |
+
custom_objects: Dict of any custom objects to be
|
| 218 |
+
considered during deserialization.
|
| 219 |
+
run_training_check: Whether to attempt to train the layer
|
| 220 |
+
(if an input shape or input data was provided).
|
| 221 |
+
run_mixed_precision_check: Whether to test the layer with a mixed
|
| 222 |
+
precision dtype policy.
|
| 223 |
+
assert_built_after_instantiation: Whether to assert `built=True`
|
| 224 |
+
after the layer's instantiation.
|
| 225 |
+
"""
|
| 226 |
+
if input_shape is not None and input_data is not None:
|
| 227 |
+
raise ValueError(
|
| 228 |
+
"input_shape and input_data cannot be passed "
|
| 229 |
+
"at the same time."
|
| 230 |
+
)
|
| 231 |
+
if expected_output_shape is not None and expected_output is not None:
|
| 232 |
+
raise ValueError(
|
| 233 |
+
"expected_output_shape and expected_output cannot be passed "
|
| 234 |
+
"at the same time."
|
| 235 |
+
)
|
| 236 |
+
if expected_output is not None and input_data is None:
|
| 237 |
+
raise ValueError(
|
| 238 |
+
"In order to use expected_output, input_data must be provided."
|
| 239 |
+
)
|
| 240 |
+
if expected_mask_shape is not None and supports_masking is not True:
|
| 241 |
+
raise ValueError(
|
| 242 |
+
"In order to use expected_mask_shape, supports_masking "
|
| 243 |
+
"must be True."
|
| 244 |
+
)
|
| 245 |
+
|
| 246 |
+
init_kwargs = init_kwargs or {}
|
| 247 |
+
call_kwargs = call_kwargs or {}
|
| 248 |
+
|
| 249 |
+
if input_shape is not None and input_dtype is not None:
|
| 250 |
+
if isinstance(input_shape, tuple) and is_shape_tuple(
|
| 251 |
+
input_shape[0]
|
| 252 |
+
):
|
| 253 |
+
self.assertIsInstance(input_dtype, tuple)
|
| 254 |
+
self.assertEqual(
|
| 255 |
+
len(input_shape),
|
| 256 |
+
len(input_dtype),
|
| 257 |
+
msg="The number of input shapes and dtypes does not match",
|
| 258 |
+
)
|
| 259 |
+
elif isinstance(input_shape, dict):
|
| 260 |
+
self.assertIsInstance(input_dtype, dict)
|
| 261 |
+
self.assertEqual(
|
| 262 |
+
set(input_shape.keys()),
|
| 263 |
+
set(input_dtype.keys()),
|
| 264 |
+
msg="The number of input shapes and dtypes does not match",
|
| 265 |
+
)
|
| 266 |
+
elif isinstance(input_shape, list):
|
| 267 |
+
self.assertIsInstance(input_dtype, list)
|
| 268 |
+
self.assertEqual(
|
| 269 |
+
len(input_shape),
|
| 270 |
+
len(input_dtype),
|
| 271 |
+
msg="The number of input shapes and dtypes does not match",
|
| 272 |
+
)
|
| 273 |
+
elif not isinstance(input_shape, tuple):
|
| 274 |
+
raise ValueError("The type of input_shape is not supported")
|
| 275 |
+
if input_shape is not None and input_dtype is None:
|
| 276 |
+
input_dtype = tree.map_shape_structure(
|
| 277 |
+
lambda _: "float32", input_shape
|
| 278 |
+
)
|
| 279 |
+
|
| 280 |
+
# Estimate actual number of weights, variables, seed generators if
|
| 281 |
+
# expected ones not set. When using layers uses composition it should
|
| 282 |
+
# build each sublayer manually.
|
| 283 |
+
if input_data is not None or input_shape is not None:
|
| 284 |
+
if input_data is None:
|
| 285 |
+
input_data = create_eager_tensors(
|
| 286 |
+
input_shape, input_dtype, input_sparse
|
| 287 |
+
)
|
| 288 |
+
layer = layer_cls(**init_kwargs)
|
| 289 |
+
if isinstance(input_data, dict):
|
| 290 |
+
layer(**input_data, **call_kwargs)
|
| 291 |
+
else:
|
| 292 |
+
layer(input_data, **call_kwargs)
|
| 293 |
+
|
| 294 |
+
if expected_num_trainable_weights is None:
|
| 295 |
+
expected_num_trainable_weights = len(layer.trainable_weights)
|
| 296 |
+
if expected_num_non_trainable_weights is None:
|
| 297 |
+
expected_num_non_trainable_weights = len(
|
| 298 |
+
layer.non_trainable_weights
|
| 299 |
+
)
|
| 300 |
+
if expected_num_non_trainable_variables is None:
|
| 301 |
+
expected_num_non_trainable_variables = len(
|
| 302 |
+
layer.non_trainable_variables
|
| 303 |
+
)
|
| 304 |
+
if expected_num_seed_generators is None:
|
| 305 |
+
expected_num_seed_generators = len(get_seed_generators(layer))
|
| 306 |
+
|
| 307 |
+
# Serialization test.
|
| 308 |
+
layer = layer_cls(**init_kwargs)
|
| 309 |
+
self.run_class_serialization_test(layer, custom_objects)
|
| 310 |
+
|
| 311 |
+
# Basic masking test.
|
| 312 |
+
if supports_masking is not None:
|
| 313 |
+
self.assertEqual(
|
| 314 |
+
layer.supports_masking,
|
| 315 |
+
supports_masking,
|
| 316 |
+
msg="Unexpected supports_masking value",
|
| 317 |
+
)
|
| 318 |
+
|
| 319 |
+
def run_build_asserts(layer):
|
| 320 |
+
self.assertTrue(layer.built)
|
| 321 |
+
if expected_num_trainable_weights is not None:
|
| 322 |
+
self.assertLen(
|
| 323 |
+
layer.trainable_weights,
|
| 324 |
+
expected_num_trainable_weights,
|
| 325 |
+
msg="Unexpected number of trainable_weights",
|
| 326 |
+
)
|
| 327 |
+
if expected_num_non_trainable_weights is not None:
|
| 328 |
+
self.assertLen(
|
| 329 |
+
layer.non_trainable_weights,
|
| 330 |
+
expected_num_non_trainable_weights,
|
| 331 |
+
msg="Unexpected number of non_trainable_weights",
|
| 332 |
+
)
|
| 333 |
+
if expected_num_non_trainable_variables is not None:
|
| 334 |
+
self.assertLen(
|
| 335 |
+
layer.non_trainable_variables,
|
| 336 |
+
expected_num_non_trainable_variables,
|
| 337 |
+
msg="Unexpected number of non_trainable_variables",
|
| 338 |
+
)
|
| 339 |
+
if expected_num_seed_generators is not None:
|
| 340 |
+
self.assertLen(
|
| 341 |
+
get_seed_generators(layer),
|
| 342 |
+
expected_num_seed_generators,
|
| 343 |
+
msg="Unexpected number of seed_generators",
|
| 344 |
+
)
|
| 345 |
+
if (
|
| 346 |
+
backend.backend() == "torch"
|
| 347 |
+
and expected_num_trainable_weights is not None
|
| 348 |
+
and expected_num_non_trainable_weights is not None
|
| 349 |
+
and expected_num_seed_generators is not None
|
| 350 |
+
):
|
| 351 |
+
self.assertLen(
|
| 352 |
+
layer.torch_params,
|
| 353 |
+
expected_num_trainable_weights
|
| 354 |
+
+ expected_num_non_trainable_weights
|
| 355 |
+
+ expected_num_seed_generators,
|
| 356 |
+
msg="Unexpected number of torch_params",
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
def run_output_asserts(layer, output, eager=False):
|
| 360 |
+
if expected_output_shape is not None:
|
| 361 |
+
if isinstance(expected_output_shape, tuple) and is_shape_tuple(
|
| 362 |
+
expected_output_shape[0]
|
| 363 |
+
):
|
| 364 |
+
self.assertIsInstance(output, tuple)
|
| 365 |
+
self.assertEqual(
|
| 366 |
+
len(output),
|
| 367 |
+
len(expected_output_shape),
|
| 368 |
+
msg="Unexpected number of outputs",
|
| 369 |
+
)
|
| 370 |
+
output_shape = tuple(v.shape for v in output)
|
| 371 |
+
self.assertEqual(
|
| 372 |
+
expected_output_shape,
|
| 373 |
+
output_shape,
|
| 374 |
+
msg="Unexpected output shape",
|
| 375 |
+
)
|
| 376 |
+
elif isinstance(expected_output_shape, tuple):
|
| 377 |
+
self.assertEqual(
|
| 378 |
+
expected_output_shape,
|
| 379 |
+
output.shape,
|
| 380 |
+
msg="Unexpected output shape",
|
| 381 |
+
)
|
| 382 |
+
elif isinstance(expected_output_shape, dict):
|
| 383 |
+
self.assertIsInstance(output, dict)
|
| 384 |
+
self.assertEqual(
|
| 385 |
+
set(output.keys()),
|
| 386 |
+
set(expected_output_shape.keys()),
|
| 387 |
+
msg="Unexpected output dict keys",
|
| 388 |
+
)
|
| 389 |
+
output_shape = {k: v.shape for k, v in output.items()}
|
| 390 |
+
self.assertEqual(
|
| 391 |
+
expected_output_shape,
|
| 392 |
+
output_shape,
|
| 393 |
+
msg="Unexpected output shape",
|
| 394 |
+
)
|
| 395 |
+
elif isinstance(expected_output_shape, list):
|
| 396 |
+
self.assertIsInstance(output, list)
|
| 397 |
+
self.assertEqual(
|
| 398 |
+
len(output),
|
| 399 |
+
len(expected_output_shape),
|
| 400 |
+
msg="Unexpected number of outputs",
|
| 401 |
+
)
|
| 402 |
+
output_shape = [v.shape for v in output]
|
| 403 |
+
self.assertEqual(
|
| 404 |
+
expected_output_shape,
|
| 405 |
+
output_shape,
|
| 406 |
+
msg="Unexpected output shape",
|
| 407 |
+
)
|
| 408 |
+
else:
|
| 409 |
+
raise ValueError(
|
| 410 |
+
"The type of expected_output_shape is not supported"
|
| 411 |
+
)
|
| 412 |
+
if expected_output_dtype is not None:
|
| 413 |
+
if isinstance(expected_output_dtype, tuple):
|
| 414 |
+
self.assertIsInstance(output, tuple)
|
| 415 |
+
self.assertEqual(
|
| 416 |
+
len(output),
|
| 417 |
+
len(expected_output_dtype),
|
| 418 |
+
msg="Unexpected number of outputs",
|
| 419 |
+
)
|
| 420 |
+
output_dtype = tuple(
|
| 421 |
+
backend.standardize_dtype(v.dtype) for v in output
|
| 422 |
+
)
|
| 423 |
+
self.assertEqual(
|
| 424 |
+
expected_output_dtype,
|
| 425 |
+
output_dtype,
|
| 426 |
+
msg="Unexpected output dtype",
|
| 427 |
+
)
|
| 428 |
+
elif isinstance(expected_output_dtype, dict):
|
| 429 |
+
self.assertIsInstance(output, dict)
|
| 430 |
+
self.assertEqual(
|
| 431 |
+
set(output.keys()),
|
| 432 |
+
set(expected_output_dtype.keys()),
|
| 433 |
+
msg="Unexpected output dict keys",
|
| 434 |
+
)
|
| 435 |
+
output_dtype = {
|
| 436 |
+
k: backend.standardize_dtype(v.dtype)
|
| 437 |
+
for k, v in output.items()
|
| 438 |
+
}
|
| 439 |
+
self.assertEqual(
|
| 440 |
+
expected_output_dtype,
|
| 441 |
+
output_dtype,
|
| 442 |
+
msg="Unexpected output dtype",
|
| 443 |
+
)
|
| 444 |
+
elif isinstance(expected_output_dtype, list):
|
| 445 |
+
self.assertIsInstance(output, list)
|
| 446 |
+
self.assertEqual(
|
| 447 |
+
len(output),
|
| 448 |
+
len(expected_output_dtype),
|
| 449 |
+
msg="Unexpected number of outputs",
|
| 450 |
+
)
|
| 451 |
+
output_dtype = [
|
| 452 |
+
backend.standardize_dtype(v.dtype) for v in output
|
| 453 |
+
]
|
| 454 |
+
self.assertEqual(
|
| 455 |
+
expected_output_dtype,
|
| 456 |
+
output_dtype,
|
| 457 |
+
msg="Unexpected output dtype",
|
| 458 |
+
)
|
| 459 |
+
else:
|
| 460 |
+
output_dtype = tree.flatten(output)[0].dtype
|
| 461 |
+
self.assertEqual(
|
| 462 |
+
expected_output_dtype,
|
| 463 |
+
backend.standardize_dtype(output_dtype),
|
| 464 |
+
msg="Unexpected output dtype",
|
| 465 |
+
)
|
| 466 |
+
if expected_output_sparse:
|
| 467 |
+
for x in tree.flatten(output):
|
| 468 |
+
self.assertSparse(x)
|
| 469 |
+
if eager:
|
| 470 |
+
if expected_output is not None:
|
| 471 |
+
self.assertEqual(type(expected_output), type(output))
|
| 472 |
+
for ref_v, v in zip(
|
| 473 |
+
tree.flatten(expected_output), tree.flatten(output)
|
| 474 |
+
):
|
| 475 |
+
self.assertAllClose(
|
| 476 |
+
ref_v, v, msg="Unexpected output value"
|
| 477 |
+
)
|
| 478 |
+
if expected_num_losses is not None:
|
| 479 |
+
self.assertLen(layer.losses, expected_num_losses)
|
| 480 |
+
|
| 481 |
+
def run_training_step(layer, input_data, output_data):
|
| 482 |
+
class TestModel(Model):
|
| 483 |
+
def __init__(self, layer):
|
| 484 |
+
super().__init__()
|
| 485 |
+
self.layer = layer
|
| 486 |
+
|
| 487 |
+
def call(self, x, training=False):
|
| 488 |
+
return self.layer(x, training=training)
|
| 489 |
+
|
| 490 |
+
model = TestModel(layer)
|
| 491 |
+
|
| 492 |
+
data = (input_data, output_data)
|
| 493 |
+
if backend.backend() == "torch":
|
| 494 |
+
data = tree.map_structure(backend.convert_to_numpy, data)
|
| 495 |
+
|
| 496 |
+
def data_generator():
|
| 497 |
+
while True:
|
| 498 |
+
yield data
|
| 499 |
+
|
| 500 |
+
# test the "default" path for each backend by setting
|
| 501 |
+
# jit_compile="auto".
|
| 502 |
+
# for tensorflow and jax backends auto is jitted
|
| 503 |
+
# Note that tensorflow cannot be jitted with sparse tensors
|
| 504 |
+
# for torch backend auto is eager
|
| 505 |
+
#
|
| 506 |
+
# NB: for torch, jit_compile=True turns on torchdynamo
|
| 507 |
+
# which may not always succeed in tracing depending
|
| 508 |
+
# on the model. Run your program with these env vars
|
| 509 |
+
# to get debug traces of dynamo:
|
| 510 |
+
# TORCH_LOGS="+dynamo"
|
| 511 |
+
# TORCHDYNAMO_VERBOSE=1
|
| 512 |
+
# TORCHDYNAMO_REPORT_GUARD_FAILURES=1
|
| 513 |
+
jit_compile = "auto"
|
| 514 |
+
if backend.backend() == "tensorflow" and input_sparse:
|
| 515 |
+
jit_compile = False
|
| 516 |
+
model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
|
| 517 |
+
model.fit(data_generator(), steps_per_epoch=1, verbose=0)
|
| 518 |
+
|
| 519 |
+
# Build test.
|
| 520 |
+
if input_data is not None or input_shape is not None:
|
| 521 |
+
if input_shape is None:
|
| 522 |
+
build_shape = tree.map_structure(
|
| 523 |
+
lambda x: ops.shape(x), input_data
|
| 524 |
+
)
|
| 525 |
+
else:
|
| 526 |
+
build_shape = input_shape
|
| 527 |
+
layer = layer_cls(**init_kwargs)
|
| 528 |
+
if isinstance(build_shape, dict):
|
| 529 |
+
layer.build(**build_shape)
|
| 530 |
+
else:
|
| 531 |
+
layer.build(build_shape)
|
| 532 |
+
run_build_asserts(layer)
|
| 533 |
+
|
| 534 |
+
# Symbolic call test.
|
| 535 |
+
if input_shape is None:
|
| 536 |
+
keras_tensor_inputs = tree.map_structure(
|
| 537 |
+
lambda x: create_keras_tensors(
|
| 538 |
+
ops.shape(x), x.dtype, input_sparse
|
| 539 |
+
),
|
| 540 |
+
input_data,
|
| 541 |
+
)
|
| 542 |
+
else:
|
| 543 |
+
keras_tensor_inputs = create_keras_tensors(
|
| 544 |
+
input_shape, input_dtype, input_sparse
|
| 545 |
+
)
|
| 546 |
+
layer = layer_cls(**init_kwargs)
|
| 547 |
+
if isinstance(keras_tensor_inputs, dict):
|
| 548 |
+
keras_tensor_outputs = layer(
|
| 549 |
+
**keras_tensor_inputs, **call_kwargs
|
| 550 |
+
)
|
| 551 |
+
else:
|
| 552 |
+
keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
|
| 553 |
+
run_build_asserts(layer)
|
| 554 |
+
run_output_asserts(layer, keras_tensor_outputs, eager=False)
|
| 555 |
+
|
| 556 |
+
if expected_mask_shape is not None:
|
| 557 |
+
output_mask = layer.compute_mask(keras_tensor_inputs)
|
| 558 |
+
self.assertEqual(expected_mask_shape, output_mask.shape)
|
| 559 |
+
|
| 560 |
+
# The stateless layers should be built after instantiation.
|
| 561 |
+
if assert_built_after_instantiation:
|
| 562 |
+
layer = layer_cls(**init_kwargs)
|
| 563 |
+
self.assertTrue(
|
| 564 |
+
layer.built,
|
| 565 |
+
msg=(
|
| 566 |
+
f"{type(layer)} is stateless, so it should be built "
|
| 567 |
+
"after instantiation."
|
| 568 |
+
),
|
| 569 |
+
)
|
| 570 |
+
|
| 571 |
+
# Eager call test and compiled training test.
|
| 572 |
+
if input_data is not None or input_shape is not None:
|
| 573 |
+
if input_data is None:
|
| 574 |
+
input_data = create_eager_tensors(
|
| 575 |
+
input_shape, input_dtype, input_sparse
|
| 576 |
+
)
|
| 577 |
+
layer = layer_cls(**init_kwargs)
|
| 578 |
+
if isinstance(input_data, dict):
|
| 579 |
+
output_data = layer(**input_data, **call_kwargs)
|
| 580 |
+
else:
|
| 581 |
+
output_data = layer(input_data, **call_kwargs)
|
| 582 |
+
run_output_asserts(layer, output_data, eager=True)
|
| 583 |
+
|
| 584 |
+
if run_training_check:
|
| 585 |
+
run_training_step(layer, input_data, output_data)
|
| 586 |
+
|
| 587 |
+
# Never test mixed precision on torch CPU. Torch lacks support.
|
| 588 |
+
if run_mixed_precision_check and backend.backend() == "torch":
|
| 589 |
+
import torch
|
| 590 |
+
|
| 591 |
+
run_mixed_precision_check = torch.cuda.is_available()
|
| 592 |
+
|
| 593 |
+
if run_mixed_precision_check:
|
| 594 |
+
layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"})
|
| 595 |
+
input_spec = tree.map_structure(
|
| 596 |
+
lambda spec: KerasTensor(
|
| 597 |
+
spec.shape,
|
| 598 |
+
dtype=(
|
| 599 |
+
layer.compute_dtype
|
| 600 |
+
if layer.autocast
|
| 601 |
+
and backend.is_float_dtype(spec.dtype)
|
| 602 |
+
else spec.dtype
|
| 603 |
+
),
|
| 604 |
+
),
|
| 605 |
+
keras_tensor_inputs,
|
| 606 |
+
)
|
| 607 |
+
if isinstance(input_data, dict):
|
| 608 |
+
output_data = layer(**input_data, **call_kwargs)
|
| 609 |
+
output_spec = layer.compute_output_spec(**input_spec)
|
| 610 |
+
else:
|
| 611 |
+
output_data = layer(input_data, **call_kwargs)
|
| 612 |
+
output_spec = layer.compute_output_spec(input_spec)
|
| 613 |
+
for tensor, spec in zip(
|
| 614 |
+
tree.flatten(output_data), tree.flatten(output_spec)
|
| 615 |
+
):
|
| 616 |
+
dtype = standardize_dtype(tensor.dtype)
|
| 617 |
+
self.assertEqual(
|
| 618 |
+
dtype,
|
| 619 |
+
spec.dtype,
|
| 620 |
+
f"expected output dtype {spec.dtype}, got {dtype}",
|
| 621 |
+
)
|
| 622 |
+
for weight in layer.weights:
|
| 623 |
+
dtype = standardize_dtype(weight.dtype)
|
| 624 |
+
if is_float_dtype(dtype):
|
| 625 |
+
self.assertEqual(dtype, "float32")
|
| 626 |
+
|
| 627 |
+
|
| 628 |
+
def tensorflow_uses_gpu():
    """Return True if the active backend is TensorFlow and a GPU device is listed."""
    return backend.backend() == "tensorflow" and uses_gpu()
|
| 630 |
+
|
| 631 |
+
|
| 632 |
+
def jax_uses_gpu():
    """Return True if the active backend is JAX and a GPU device is listed."""
    return backend.backend() == "jax" and uses_gpu()
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
def torch_uses_gpu():
    """Return True if the active backend is torch and its default device is CUDA."""
    if backend.backend() == "torch":
        # Imported lazily so non-torch installs never import the torch backend.
        from keras.src.backend.torch.core import get_device

        return get_device() == "cuda"
    return False
|
| 642 |
+
|
| 643 |
+
|
| 644 |
+
def uses_gpu():
    """Return True if any available device is a GPU.

    Used as a condition to skip tests that cannot (or must) run on GPU.
    """
    devices = distribution.list_devices()
    # Collapse the redundant `if ...: return True / return False` pattern.
    return any(d.startswith("gpu") for d in devices)
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def create_keras_tensors(input_shape, dtype, sparse):
    """Build symbolic `KerasTensor`s matching `input_shape` and `dtype`.

    A dict `input_shape` maps `<name>_shape` keys to shapes; the `_shape`
    suffix is stripped to recover the call-argument names. Any other
    (possibly nested) shape structure is mapped pairwise against `dtype`.
    """
    if isinstance(input_shape, dict):
        tensors = {}
        for key, shape in input_shape.items():
            arg_name = utils.removesuffix(key, "_shape")
            tensors[arg_name] = KerasTensor(
                shape, dtype=dtype[key], sparse=sparse
            )
        return tensors

    def _make_tensor(shape, dt):
        return KerasTensor(shape, dtype=dt, sparse=sparse)

    return map_shape_dtype_structure(_make_tensor, input_shape, dtype)
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def create_eager_tensors(input_shape, dtype, sparse):
    """Create concrete (eager) random tensors for layer testing.

    Args:
        input_shape: Shape tuple, nested structure of shape tuples, or dict
            mapping `<name>_shape` keys to shapes.
        dtype: Dtype string (or matching structure/dict of dtype strings).
            Must be standard float or int dtypes.
        sparse: Whether to create sparse tensors (TensorFlow and JAX only).

    Returns:
        A structure of eager tensors mirroring `input_shape`.

    Raises:
        ValueError: If a non-standard dtype is requested, or if `sparse=True`
            with a backend that has no sparse support here.
    """
    from keras.src.backend import random

    # Only standard float/int dtypes are supported for random generation.
    if set(tree.flatten(dtype)).difference(
        [
            "float16",
            "float32",
            "float64",
            "int8",
            "uint8",
            "int16",
            "uint16",
            "int32",
            "uint32",
            "int64",
            "uint64",
        ]
    ):
        raise ValueError(
            "dtype must be a standard float or int dtype. "
            f"Received: dtype={dtype}"
        )

    if sparse:
        if backend.backend() == "tensorflow":
            import tensorflow as tf

            def create_fn(shape, dt):
                # Fixed seed keeps sparse fixtures deterministic; entries are
                # zeroed with probability ~0.3 before sparse conversion.
                rng = np.random.default_rng(0)
                x = (4 * rng.standard_normal(shape)).astype(dt)
                x = np.multiply(x, rng.random(shape) < 0.7)
                return tf.sparse.from_dense(x)

        elif backend.backend() == "jax":
            import jax.experimental.sparse as jax_sparse

            def create_fn(shape, dt):
                rng = np.random.default_rng(0)
                x = (4 * rng.standard_normal(shape)).astype(dt)
                x = np.multiply(x, rng.random(shape) < 0.7)
                # n_batch=1 keeps the leading (batch) dimension dense.
                return jax_sparse.BCOO.fromdense(x, n_batch=1)

        else:
            raise ValueError(
                f"Sparse is unsupported with backend {backend.backend()}"
            )

    else:

        def create_fn(shape, dt):
            # Dense case: uniform values in [0, 3), cast to the target dtype.
            return ops.cast(
                random.uniform(shape, dtype="float32") * 3, dtype=dt
            )

    if isinstance(input_shape, dict):
        # Dict inputs use `<name>_shape` keys; strip the suffix to recover
        # the call-argument names.
        return {
            utils.removesuffix(k, "_shape"): create_fn(v, dtype[k])
            for k, v in input_shape.items()
        }
    return map_shape_dtype_structure(create_fn, input_shape, dtype)
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def is_shape_tuple(x):
    """Return True if `x` looks like a shape: a list/tuple of ints and/or Nones."""
    if not isinstance(x, (list, tuple)):
        return False
    return all(isinstance(element, (int, type(None))) for element in x)
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def map_shape_dtype_structure(fn, shape, dtype):
    """Variant of `tree.map_structure` that treats shape tuples as leaves.

    Applies `fn(shape_tuple, dtype)` at every shape leaf, recursing through
    lists, tuples and dicts while pairing each sub-shape with the
    corresponding entry of `dtype`.
    """
    # A flat tuple/list of ints (or Nones) is a leaf: apply `fn` directly.
    if is_shape_tuple(shape):
        return fn(tuple(shape), dtype)
    if isinstance(shape, list):
        return [
            map_shape_dtype_structure(fn, sub_shape, sub_dtype)
            for sub_shape, sub_dtype in zip(shape, dtype)
        ]
    if isinstance(shape, tuple):
        return tuple(
            map_shape_dtype_structure(fn, sub_shape, sub_dtype)
            for sub_shape, sub_dtype in zip(shape, dtype)
        )
    if isinstance(shape, dict):
        return {
            key: map_shape_dtype_structure(fn, sub_shape, dtype[key])
            for key, sub_shape in shape.items()
        }
    raise ValueError(
        f"Cannot map function to unknown objects {shape} and {dtype}"
    )
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
def get_seed_generators(layer):
    """Collect every unique seed generator in `layer` and its sublayers."""
    collected = []
    seen = set()
    for sublayer in layer._flatten_layers(True, True):
        for generator in sublayer._seed_generators:
            # Deduplicate by object identity: the same generator may be
            # reachable through several sublayers.
            marker = id(generator)
            if marker in seen:
                continue
            seen.add(marker)
            collected.append(generator)
    return collected
|
| 768 |
+
|
| 769 |
+
|
| 770 |
+
def to_json_with_tuples(value):
    """Serialize `value` to JSON, tagging tuples so they survive a round trip.

    Tuples are replaced by `{"__class__": "tuple", "__value__": [...]}`
    before encoding, so `from_json_with_tuples` can restore them (plain JSON
    would silently turn them into lists). Output is sorted and indented.
    """

    def _tag_tuples(obj):
        if isinstance(obj, tuple):
            return {"__class__": "tuple", "__value__": list(obj)}
        if isinstance(obj, list):
            return [_tag_tuples(item) for item in obj]
        if isinstance(obj, dict):
            return {key: _tag_tuples(item) for key, item in obj.items()}
        return obj

    class _PreserveTupleJsonEncoder(json.JSONEncoder):
        def encode(self, obj):
            return super().encode(_tag_tuples(obj))

    return _PreserveTupleJsonEncoder(sort_keys=True, indent=4).encode(value)
|
| 786 |
+
|
| 787 |
+
|
| 788 |
+
def from_json_with_tuples(value):
    """Deserialize JSON produced by `to_json_with_tuples`, restoring tuples."""

    def _restore_tuple(obj):
        # `object_hook` only receives dicts, but guard anyway for safety.
        if (
            isinstance(obj, dict)
            and "__class__" in obj
            and "__value__" in obj
        ):
            return tuple(obj["__value__"])
        return obj

    return json.loads(value, object_hook=_restore_tuple)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_utils.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_test_data(
    train_samples, test_samples, input_shape, num_classes, random_seed=None
):
    """Generates balanced, stratified synthetic test data to train a model on.

    Args:
        train_samples: Integer, how many training samples to generate.
        test_samples: Integer, how many test samples to generate.
        input_shape: Tuple of integers, shape of the inputs.
        num_classes: Integer, number of classes for the data and targets.
        random_seed: Integer, random seed used by Numpy to generate data.

    Returns:
        A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    # NOTE(review): this seeds numpy's *global* RNG, which also affects any
    # other code using `np.random` in the same process.
    np.random.seed(random_seed)

    # Total samples
    total_samples = train_samples + test_samples

    # Ensure that we generate a balanced dataset: `samples_per_class` copies
    # of each label.
    samples_per_class = total_samples // num_classes
    y = np.array(
        [i for i in range(num_classes) for _ in range(samples_per_class)],
        dtype=np.int32,
    )

    # Generate extra samples in a deterministic manner (covers the remainder
    # when `total_samples` is not divisible by `num_classes`, cycling
    # through the class labels).
    extra_samples = total_samples - len(y)
    y_extra = np.array(
        [i % num_classes for i in range(extra_samples)], dtype=np.int64
    )
    y = np.concatenate([y, y_extra])

    # Generate data: one random "template" per class; each sample is its
    # class template plus unit Gaussian noise.
    templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
    x = np.zeros((total_samples,) + input_shape, dtype=np.float32)
    for i in range(total_samples):
        x[i] = templates[y[i]] + np.random.normal(
            loc=0, scale=1.0, size=input_shape
        )

    # Shuffle the entire dataset to ensure randomness based on seed
    indices = np.arange(total_samples)
    np.random.shuffle(indices)
    x, y = x[indices], y[indices]

    # Stratified Shuffle Split: take an equal share of each class for the
    # training set; the remaining samples of that class go to the test set.
    x_train, y_train, x_test, y_test = [], [], [], []
    for cls in range(num_classes):
        cls_indices = np.where(y == cls)[0]
        np.random.shuffle(cls_indices)
        train_count = int(train_samples / num_classes)

        x_train.extend(x[cls_indices[:train_count]])
        y_train.extend(y[cls_indices[:train_count]])

        x_test.extend(x[cls_indices[train_count:]])
        y_test.extend(y[cls_indices[train_count:]])

    # Convert to numpy arrays
    x_train, y_train = np.array(x_train), np.array(y_train)
    x_test, y_test = np.array(x_test), np.array(y_test)

    # Shuffle training and test sets after stratified split so samples are
    # not grouped by class.
    train_indices = np.arange(len(x_train))
    test_indices = np.arange(len(x_test))
    np.random.shuffle(train_indices)
    np.random.shuffle(test_indices)

    x_train, y_train = x_train[train_indices], y_train[train_indices]
    x_test, y_test = x_test[test_indices], y_test[test_indices]

    return (x_train, y_train), (x_test, y_test)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def named_product(*args, **kwargs):
    """Cartesian product of parameter values with generated test case names.

    A replacement for `@parameterized.product` whose result is meant for
    `@parameterized.named_parameters`: every combination gets an explicit
    `testcase_name` built by joining the per-value names with underscores.

    Args:
        *args: Sequences of keyword-arg dicts; each generated test case
            includes exactly one dict from each sequence. Every dict must
            contain a `"testcase_name"` key used to build the case name.
        **kwargs: Mapping of parameter names to lists/tuples of values; a
            lowercased string form of each value names it in the test case.

    Returns:
        A list of parameter dicts to pass to
        `@parameterized.named_parameters`.
    """

    def _name_of(value):
        raw = value.__name__ if hasattr(value, "__name__") else str(value)
        return raw.lower()

    # Normalize keyword args into the same shape as the positional groups:
    # a tuple of dicts, each carrying a testcase_name plus one parameter.
    keyword_groups = tuple(
        tuple({"testcase_name": _name_of(v), key: v} for v in values)
        for key, values in kwargs.items()
    )

    # Fold each group into the running product, starting from one empty case.
    combos = [{}]
    for group in args + keyword_groups:
        expanded = []
        for candidate in group:
            for base in combos:
                merged = {**base, **candidate}
                prefix = base.get("testcase_name", "")
                suffix = candidate["testcase_name"]
                merged["testcase_name"] = (
                    f"{prefix}_{suffix}" if prefix else suffix
                )
                expanded.append(merged)
        combos = expanded

    return combos
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__init__.py
ADDED
|
File without changes
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (194 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/compile_utils.cpython-310.pyc
ADDED
|
Binary file (20.7 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc
ADDED
|
Binary file (4.34 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc
ADDED
|
Binary file (46.1 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py
ADDED
|
@@ -0,0 +1,820 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import namedtuple
|
| 2 |
+
|
| 3 |
+
from keras.src import losses as losses_module
|
| 4 |
+
from keras.src import metrics as metrics_module
|
| 5 |
+
from keras.src import ops
|
| 6 |
+
from keras.src import tree
|
| 7 |
+
from keras.src.backend.common.keras_tensor import KerasTensor
|
| 8 |
+
from keras.src.losses import loss as loss_module
|
| 9 |
+
from keras.src.utils.naming import get_object_name
|
| 10 |
+
from keras.src.utils.tracking import Tracker
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class MetricsList(metrics_module.Metric):
    """Composite metric that fans every call out to a list of sub-metrics.

    Optionally carries the name of the model output it is attached to
    (`output_name`), which callers use to prefix result keys.
    """

    def __init__(self, metrics, name="metrics_list", output_name=None):
        super().__init__(name=name)
        self.metrics = metrics
        self.output_name = output_name

    def update_state(self, y_true, y_pred, sample_weight=None):
        # Forward the same inputs unchanged to every wrapped metric.
        for metric in self.metrics:
            metric.update_state(y_true, y_pred, sample_weight=sample_weight)

    def reset_state(self):
        for metric in self.metrics:
            metric.reset_state()

    def get_result(self):
        # Map each wrapped metric's name to its current result.
        results = {}
        for metric in self.metrics:
            results[metric.name] = metric.result()
        return results

    def get_config(self):
        raise NotImplementedError

    @classmethod
    def from_config(cls, config):
        raise NotImplementedError
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def is_function_like(value):
    """Return True if `value` is acceptable as a metric/loss spec.

    Accepts `None` (meaning "no metric for this output"), string
    identifiers, and any callable.
    """
    return value is None or isinstance(value, str) or callable(value)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def is_binary_or_sparse_categorical(y_true, y_pred):
    """Classify the problem flavor from label/prediction shapes.

    Returns a `(is_binary, is_sparse_categorical)` pair of booleans.
    Binary means the predictions have a single output channel.
    Sparse-categorical means the labels look like integer class indices:
    either lower rank than the predictions, or a single channel while the
    predictions carry multiple per-class scores.
    """
    true_rank = len(y_true.shape)
    pred_rank = len(y_pred.shape)
    true_last = y_true.shape[-1]
    pred_last = y_pred.shape[-1]

    binary = pred_last == 1
    sparse_categorical = true_rank < pred_rank or (
        true_last == 1 and pred_last > 1
    )
    return binary, sparse_categorical
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_metric(identifier, y_true, y_pred):
    """Resolve a user-provided metric identifier to a `Metric` instance.

    The convenience strings `"accuracy"`/`"acc"` are dispatched to the
    binary, sparse-categorical, or categorical accuracy variant based on
    the shapes of `y_true` and `y_pred`. Plain callables are wrapped in a
    `MeanMetricWrapper` so they expose the `Metric` interface.
    Returns `None` when `identifier` is `None`.
    """
    if identifier is None:
        return None  # Ok to have no metric for an output.

    if str(identifier).lower() in ("accuracy", "acc"):
        # Disambiguate the generic "accuracy" request using tensor shapes.
        is_binary, is_sparse_categorical = is_binary_or_sparse_categorical(
            y_true, y_pred
        )
        if is_binary:
            metric_obj = metrics_module.BinaryAccuracy(name=str(identifier))
        elif is_sparse_categorical:
            metric_obj = metrics_module.SparseCategoricalAccuracy(
                name=str(identifier)
            )
        else:
            metric_obj = metrics_module.CategoricalAccuracy(
                name=str(identifier)
            )
    else:
        metric_obj = metrics_module.get(identifier)

    # Determine the display name before wrapping, so callables keep the
    # name derived from the callable itself.
    if isinstance(identifier, str):
        metric_name = identifier
    else:
        metric_name = get_object_name(metric_obj)

    if not isinstance(metric_obj, metrics_module.Metric):
        metric_obj = metrics_module.MeanMetricWrapper(metric_obj)

    metric_obj.name = metric_name
    return metric_obj
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_loss(identifier, y_true, y_pred):
    """Resolve a user-provided loss identifier to a `Loss` instance.

    The convenience strings `"crossentropy"`/`"ce"` are dispatched to the
    binary, sparse-categorical, or categorical crossentropy based on the
    shapes of `y_true` and `y_pred`. Plain callables are wrapped in a
    `LossFunctionWrapper`. Returns `None` when `identifier` is `None`.
    """
    if identifier is None:
        return None  # Ok to have no loss for an output.

    if str(identifier).lower() in ("crossentropy", "ce"):
        # Disambiguate the generic "crossentropy" request via shapes.
        is_binary, is_sparse_categorical = is_binary_or_sparse_categorical(
            y_true, y_pred
        )
        if is_binary:
            loss_obj = losses_module.binary_crossentropy
        elif is_sparse_categorical:
            loss_obj = losses_module.sparse_categorical_crossentropy
        else:
            loss_obj = losses_module.categorical_crossentropy
    else:
        loss_obj = losses_module.get(identifier)

    if not isinstance(loss_obj, losses_module.Loss):
        if isinstance(identifier, str):
            loss_name = identifier
        else:
            loss_name = get_object_name(loss_obj)
        loss_obj = losses_module.LossFunctionWrapper(loss_obj, name=loss_name)
    return loss_obj
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class CompileMetrics(metrics_module.Metric):
    """Container metric holding all metrics passed to `Model.compile()`.

    Lazily built on the first `update_state()` call, once the structure of
    `y_true`/`y_pred` is known. Internally keeps two parallel flat lists
    (one per model output): `_flat_metrics` for unweighted metrics and
    `_flat_weighted_metrics` for sample-weighted ones.
    """

    def __init__(
        self,
        metrics,
        weighted_metrics,
        name="compile_metric",
        output_names=None,
    ):
        super().__init__(name=name)
        if metrics and not isinstance(metrics, (list, tuple, dict)):
            raise ValueError(
                "Expected `metrics` argument to be a list, tuple, or dict. "
                f"Received instead: metrics={metrics} of type {type(metrics)}"
            )
        if weighted_metrics and not isinstance(
            weighted_metrics, (list, tuple, dict)
        ):
            raise ValueError(
                "Expected `weighted_metrics` argument to be a list, tuple, or "
                f"dict. Received instead: weighted_metrics={weighted_metrics} "
                f"of type {type(weighted_metrics)}"
            )
        # User-provided specs are stored as-is and resolved at build time.
        self._user_metrics = metrics
        self._user_weighted_metrics = weighted_metrics
        self.built = False
        # NOTE: overrides the `name` passed to super().__init__ above.
        self.name = "compile_metrics"
        self.output_names = output_names

    @property
    def metrics(self):
        """Flat list of all underlying `Metric` objects (empty until built)."""
        if not self.built:
            return []
        metrics = []
        for m in self._flat_metrics + self._flat_weighted_metrics:
            if isinstance(m, MetricsList):
                metrics.extend(m.metrics)
            elif m is not None:
                metrics.append(m)
        return metrics

    @property
    def variables(self):
        """All state variables of the underlying metrics."""
        # Avoiding relying on implicit tracking since
        # CompileMetrics may be instantiated or built in a no tracking scope.
        if not self.built:
            return []
        vars = []
        for m in self.metrics:
            if m is not None:
                vars.extend(m.variables)
        return vars

    def build(self, y_true, y_pred):
        """Resolve user metric specs against the actual output structure."""
        num_outputs = 1  # default
        if self.output_names:
            output_names = self.output_names
        elif isinstance(y_pred, dict):
            # Sort for a deterministic output order.
            output_names = sorted(list(y_pred.keys()))
        elif isinstance(y_pred, (list, tuple)):
            num_outputs = len(y_pred)
            if all(hasattr(x, "_keras_history") for x in y_pred):
                # Symbolic tensors: recover names from the producing ops.
                output_names = [x._keras_history.operation.name for x in y_pred]
            else:
                output_names = None
        else:
            output_names = None
        if output_names:
            num_outputs = len(output_names)

        y_pred = self._flatten_y(y_pred)
        y_true = self._flatten_y(y_true)

        metrics = self._user_metrics
        weighted_metrics = self._user_weighted_metrics
        self._flat_metrics = self._build_metrics_set(
            metrics,
            num_outputs,
            output_names,
            y_true,
            y_pred,
            argument_name="metrics",
        )
        self._flat_weighted_metrics = self._build_metrics_set(
            weighted_metrics,
            num_outputs,
            output_names,
            y_true,
            y_pred,
            argument_name="weighted_metrics",
        )
        self.built = True

    def _build_metrics_set(
        self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
    ):
        """Return a per-output list of `MetricsList` (or `None`) objects.

        `argument_name` ("metrics" or "weighted_metrics") is used only for
        error messages.
        """
        flat_metrics = []
        if isinstance(metrics, dict):
            # Validate dict keys up front, before any resolution.
            for name in metrics.keys():
                if name not in output_names:
                    raise ValueError(
                        f"In the dict argument `{argument_name}`, key "
                        f"'{name}' does not correspond to any model "
                        f"output. Received:\n{argument_name}={metrics}"
                    )
        if num_outputs == 1:
            # Single-output model: everything collapses to one MetricsList.
            if not metrics:
                flat_metrics.append(None)
            else:
                if isinstance(metrics, dict):
                    metrics = tree.flatten(metrics)
                if not isinstance(metrics, list):
                    metrics = [metrics]
                if not all(is_function_like(m) for m in metrics):
                    raise ValueError(
                        f"Expected all entries in the `{argument_name}` list "
                        f"to be metric objects. Received instead:\n"
                        f"{argument_name}={metrics}"
                    )
                flat_metrics.append(
                    MetricsList(
                        [
                            get_metric(m, y_true[0], y_pred[0])
                            for m in metrics
                            if m is not None
                        ]
                    )
                )
        else:
            if isinstance(metrics, (list, tuple)):
                if len(metrics) != len(y_pred):
                    raise ValueError(
                        "For a model with multiple outputs, "
                        f"when providing the `{argument_name}` argument as a "
                        "list, it should have as many entries as the model has "
                        f"outputs. Received:\n{argument_name}={metrics}\nof "
                        f"length {len(metrics)} whereas the model has "
                        f"{len(y_pred)} outputs."
                    )
                for idx, (mls, yt, yp) in enumerate(
                    zip(metrics, y_true, y_pred)
                ):
                    # Each entry may be a single spec or a sublist of specs.
                    if not isinstance(mls, list):
                        mls = [mls]
                    name = output_names[idx] if output_names else None
                    if not all(is_function_like(e) for e in mls):
                        raise ValueError(
                            f"All entries in the sublists of the "
                            f"`{argument_name}` list should be metric objects. "
                            f"Found the following sublist with unknown "
                            f"types: {mls}"
                        )
                    flat_metrics.append(
                        MetricsList(
                            [
                                get_metric(m, yt, yp)
                                for m in mls
                                if m is not None
                            ],
                            output_name=name,
                        )
                    )
            elif isinstance(metrics, dict):
                if output_names is None:
                    raise ValueError(
                        f"Argument `{argument_name}` can only be provided as a "
                        "dict when the model also returns a dict of outputs. "
                        f"Received {argument_name}={metrics}"
                    )
                for name in metrics.keys():
                    # NOTE: mutates the user dict in place (list-ifies values).
                    if not isinstance(metrics[name], list):
                        metrics[name] = [metrics[name]]
                    if not all(is_function_like(e) for e in metrics[name]):
                        raise ValueError(
                            f"All entries in the sublists of the "
                            f"`{argument_name}` dict should be metric objects. "
                            f"At key '{name}', found the following sublist "
                            f"with unknown types: {metrics[name]}"
                        )
                for name, yt, yp in zip(output_names, y_true, y_pred):
                    if name in metrics:
                        flat_metrics.append(
                            MetricsList(
                                [
                                    get_metric(m, yt, yp)
                                    for m in metrics[name]
                                    if m is not None
                                ],
                                output_name=name,
                            )
                        )
                    else:
                        # Outputs with no metrics keep a None placeholder so
                        # the flat list stays aligned with the outputs.
                        flat_metrics.append(None)
        return flat_metrics

    def _flatten_y(self, y):
        """Flatten `y` to a list, ordering dict entries by `output_names`."""
        if isinstance(y, dict) and self.output_names:
            result = []
            for name in self.output_names:
                if name in y:
                    result.append(y[name])
            return result
        return tree.flatten(y)

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Update all tracked metrics; builds lazily on first call."""
        if not self.built:
            self.build(y_true, y_pred)
        y_true = self._flatten_y(y_true)
        y_pred = self._flatten_y(y_pred)
        for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred):
            if m:
                m.update_state(y_t, y_p)
        if sample_weight is not None:
            sample_weight = self._flatten_y(sample_weight)
            # For multi-outputs, repeat sample weights for n outputs.
            if len(sample_weight) < len(y_true):
                sample_weight = [sample_weight[0] for _ in range(len(y_true))]
        else:
            sample_weight = [None for _ in range(len(y_true))]
        for m, y_t, y_p, s_w in zip(
            self._flat_weighted_metrics, y_true, y_pred, sample_weight
        ):
            if m:
                m.update_state(y_t, y_p, s_w)

    def reset_state(self):
        """Reset all tracked metrics; no-op when not yet built."""
        if not self.built:
            return
        for m in self._flat_metrics:
            if m:
                m.reset_state()
        for m in self._flat_weighted_metrics:
            if m:
                m.reset_state()

    def result(self):
        """Return `{display_name: value}` for every tracked metric.

        Names are prefixed with the output name for multi-output models,
        disambiguated with numeric suffixes on collisions, and weighted
        metrics that collide with unweighted ones get a `weighted_` prefix.
        """
        if not self.built:
            raise ValueError(
                "Cannot get result() since the metric has not yet been built."
            )
        results = {}
        unique_name_counters = {}
        for mls in self._flat_metrics:
            if not mls:
                continue
            for m in mls.metrics:
                name = m.name
                if mls.output_name:
                    name = f"{mls.output_name}_{name}"
                if name not in unique_name_counters:
                    results[name] = m.result()
                    unique_name_counters[name] = 1
                else:
                    # Duplicate name: append a running index suffix.
                    index = unique_name_counters[name]
                    unique_name_counters[name] += 1
                    name = f"{name}_{index}"
                    results[name] = m.result()

        for mls in self._flat_weighted_metrics:
            if not mls:
                continue
            for m in mls.metrics:
                name = m.name
                if mls.output_name:
                    name = f"{mls.output_name}_{name}"
                if name not in unique_name_counters:
                    results[name] = m.result()
                    unique_name_counters[name] = 1
                else:
                    # Name already used by an unweighted metric: retry with
                    # a `weighted_` prefix, then fall back to index suffixes.
                    name = f"weighted_{m.name}"
                    if mls.output_name:
                        name = f"{mls.output_name}_{name}"
                    if name not in unique_name_counters:
                        unique_name_counters[name] = 1
                    else:
                        index = unique_name_counters[name]
                        unique_name_counters[name] += 1
                        name = f"{name}_{index}"
                    results[name] = m.result()
        return results

    def get_config(self):
        raise NotImplementedError

    @classmethod
    def from_config(cls, config):
        raise NotImplementedError
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
class CompileLoss(losses_module.Loss):
    """Container loss holding the loss configuration from `Model.compile()`.

    Lazily built on first call, once the structure of `y_true`/`y_pred` is
    known. Resolves the (possibly nested) user loss spec into a flat list
    of `CompileLoss.Loss` records (`path`, resolved loss, weight, display
    name), and tracks one `Mean` metric per loss for multi-output models.
    """

    # Flat record describing one resolved loss leaf.
    Loss = namedtuple("Loss", ["path", "loss", "loss_weights", "name"])

    def __init__(
        self,
        loss,
        loss_weights=None,
        reduction="sum_over_batch_size",
        output_names=None,
    ):
        if loss_weights and not isinstance(
            loss_weights, (list, tuple, dict, float)
        ):
            raise ValueError(
                "Expected `loss_weights` argument to be a float "
                "(single output case) or a list, tuple, or "
                "dict (multiple output case). "
                f"Received instead: loss_weights={loss_weights} "
                f"of type {type(loss_weights)}"
            )
        # User specs are kept verbatim and resolved in `build()`.
        self._user_loss = loss
        self._user_loss_weights = loss_weights
        self.built = False
        self.output_names = output_names
        super().__init__(name="compile_loss", reduction=reduction)

        # Use `Tracker` to track metrics for individual losses.
        self._metrics = []
        self._tracker = Tracker(
            {
                "metrics": (
                    lambda x: isinstance(x, metrics_module.Metric),
                    self._metrics,
                )
            }
        )
        self._flat_losses = None
        self._y_pred_build_structure = None
        self._y_true_build_structure = None

    @property
    def metrics(self):
        """Per-loss `Mean` tracking metrics (populated in `build()`)."""
        return self._metrics

    @property
    def variables(self):
        """State variables of the per-loss tracking metrics."""
        vars = []
        for m in self.metrics:
            vars.extend(m.variables)
        return vars

    def _build_nested(self, y_true, y_pred, loss, output_names, current_path):
        """Recursively resolve `loss` against the output structure.

        Appends one `CompileLoss.Loss` record to `self._flat_losses` per
        resolved leaf. `current_path` is the tuple of keys/indices taken
        so far into the nested structure.
        """
        flat_y_pred = tree.flatten(y_pred)
        if not tree.is_nested(loss):
            # Leaf: `loss` is a WeightedLoss container (see `build()`).
            _loss = loss.loss
            if _loss is None:
                return
            loss_weight = loss.weight
            resolved_loss = get_loss(_loss, y_true, y_pred)
            name_path = current_path
            if not tree.is_nested(output_names):
                # Replace an anonymous trailing index with the output name.
                if output_names is not None:
                    output_name = output_names
                else:
                    output_name = resolved_loss.name
                if len(name_path) == 0:
                    name_path = (output_name,)
                elif isinstance(name_path[-1], int):
                    name_path = name_path[:-1] + (output_name,)
            name = "/".join([str(path) for path in name_path])
            if name == "":
                # Empty path: derive a name from all output names.
                if isinstance(output_names, dict):
                    flat_output_names = list(output_names.keys())
                else:
                    flat_output_names = tree.flatten(output_names)
                name = "_".join(flat_output_names)
            self._flat_losses.append(
                CompileLoss.Loss(current_path, resolved_loss, loss_weight, name)
            )
            return
        elif (
            issubclass(type(loss), (list, tuple))
            and all([not tree.is_nested(_loss) for _loss in loss])
            and len(loss) == len(flat_y_pred)
        ):
            # A flat list of leaves matching the flattened outputs:
            # reshape it to mirror the y_pred structure.
            loss = tree.pack_sequence_as(y_pred, loss)
        elif issubclass(type(loss), (list, tuple)) and not isinstance(
            y_pred, type(loss)
        ):
            # A list of losses applied to a non-list output: apply each
            # entry to the same (y_true, y_pred) pair.
            for _loss in loss:
                self._build_nested(
                    y_true,
                    y_pred,
                    _loss,
                    output_names,
                    current_path,
                )
            return

        if not tree.is_nested(loss):
            # The pack above may have collapsed `loss` to a leaf: recurse.
            return self._build_nested(
                y_true, y_pred, loss, output_names, current_path
            )

        if not isinstance(loss, type(y_pred)):
            raise KeyError(
                f"The path: {current_path} in "
                "the `loss` argument, can't be found in "
                "the model's output (`y_pred`)."
            )

        # shallow traverse the loss config
        if isinstance(loss, dict):
            iterator = loss.items()

            def key_check_fn(key, objs):
                return all(
                    [isinstance(obj, dict) and key in obj for obj in objs]
                )

        elif issubclass(type(loss), (list, tuple)):
            iterator = enumerate(loss)

            def key_check_fn(key, objs):
                return all(
                    [
                        issubclass(type(obj), (list, tuple)) and key < len(obj)
                        for obj in objs
                    ]
                )

        else:
            raise TypeError(
                f"Unsupported type {type(loss)} "
                f"in the `loss` configuration."
            )

        for key, _loss in iterator:
            if _loss is None:
                continue
            if not key_check_fn(key, (y_true, y_pred)):
                raise KeyError(
                    f"The path: {current_path + (key,)} in "
                    "the `loss` argument, can't be found in "
                    "either the model's output (`y_pred`) or in the "
                    "labels (`y_true`)."
                )

            self._build_nested(
                y_true[key],
                y_pred[key],
                _loss,
                output_names[key],
                current_path + (key,),
            )

    def build(self, y_true, y_pred):
        """Resolve the user loss spec against the actual output structure."""
        loss = self._user_loss
        loss_weights = self._user_loss_weights
        flat_output_names = self.output_names
        if (
            self.output_names
            and isinstance(self._user_loss, dict)
            and not isinstance(y_pred, dict)
        ):
            # Dict loss but non-dict outputs: reorder by output names.
            if set(self.output_names) == set(self._user_loss.keys()):
                loss = [self._user_loss[name] for name in self.output_names]
                if isinstance(self._user_loss_weights, dict):
                    loss_weights = [
                        self._user_loss_weights[name]
                        for name in self.output_names
                    ]
            else:
                raise ValueError(
                    f"Expected keys {self.output_names} in loss dict, but "
                    f"found loss.keys()={list(self._user_loss.keys())}"
                )

        # Pytree leaf container
        class WeightedLoss:
            # `__new__` collapses (None, weight) pairs to None so that
            # absent losses stay absent in the tree.
            def __new__(cls, loss, weight):
                if loss is None:
                    return None
                return object.__new__(cls)

            def __init__(self, loss, weight):
                self.loss = loss
                self.weight = weight

        # pack the losses and the weights together
        if loss_weights is not None:
            try:
                tree.assert_same_structure(loss, loss_weights)
            except ValueError:
                # Structures differ: accept a flat weight list of the
                # right length and reshape it to match `loss`.
                flat_loss_weights = tree.flatten(loss_weights)
                if len(tree.flatten(loss)) != len(flat_loss_weights):
                    raise ValueError(
                        f"`loss_weights` must match the number of losses, "
                        f"got {len(tree.flatten(loss))} losses "
                        f"and {len(loss_weights)} weights."
                    )
                loss_weights = tree.pack_sequence_as(loss, flat_loss_weights)
            loss = tree.map_structure(
                lambda _loss, _weight: WeightedLoss(_loss, _weight),
                loss,
                loss_weights,
            )
        else:
            loss = tree.map_structure(
                lambda _loss: WeightedLoss(_loss, None), loss
            )

        self._flat_losses = []

        if (
            isinstance(loss, dict)
            and issubclass(type(y_pred), (list, tuple))
            and set(loss.keys()) == set(flat_output_names)
            and len(y_pred) == len(flat_output_names)
        ):
            # Dict loss with list outputs: view outputs as a dict too.
            y_pred = {name: y_p for name, y_p in zip(flat_output_names, y_pred)}
            y_true = {name: y_t for name, y_t in zip(flat_output_names, y_true)}
        elif (
            isinstance(loss, dict)
            and not tree.is_nested(y_pred)
            and set(loss.keys()) == set(flat_output_names)
            and len(flat_output_names) == 1
        ):
            # Dict loss with a single bare output: wrap it in a dict.
            y_pred = {
                name: y_p for name, y_p in zip(flat_output_names, [y_pred])
            }
            y_true = {
                name: y_t for name, y_t in zip(flat_output_names, [y_true])
            }

        try:
            output_names = tree.pack_sequence_as(y_pred, flat_output_names)
        except:
            # Intentionally broad: on any packing failure, fall back to
            # names inferred from the prediction tensors themselves.
            inferred_flat_output_names = self._get_y_pred_output_names(y_pred)
            output_names = tree.pack_sequence_as(
                y_pred, inferred_flat_output_names
            )

        if not tree.is_nested(loss):
            # A single loss spec applies to every output.
            loss = tree.map_structure(lambda x: loss, y_pred)

        self._build_nested(y_true, y_pred, loss, output_names, ())

        # Add `Mean` metric to the tracker for each loss.
        if len(self._flat_losses) > 1:
            for _loss in self._flat_losses:
                name = _loss.name + "_loss"
                self._tracker.add_to_store(
                    "metrics", metrics_module.Mean(name=name)
                )

        # Remember the build-time structures so `call()` can re-pack
        # inputs delivered with a different but compatible layout.
        self._y_pred_build_structure = tree.map_structure(
            lambda x: None, y_pred
        )
        self._y_true_build_structure = tree.map_structure(
            lambda x: None, y_true
        )
        self.built = True

    def _get_y_pred_output_names(self, y_pred):
        """Infer per-output names from symbolic tensors, else return Nones."""
        flat_y_pred = tree.flatten(y_pred)
        if all((isinstance(x, KerasTensor) for x in flat_y_pred)):
            output_names = []
            for tensor in flat_y_pred:
                if hasattr(tensor, "_keras_history"):
                    output_names.append(tensor._keras_history.operation.name)
                else:
                    output_names.append(tensor.name)
        else:
            output_names = [None] * len(flat_y_pred)
        return output_names

    def __call__(self, y_true, y_pred, sample_weight=None):
        # Wrap in a name scope; the actual work happens in `call()`.
        with ops.name_scope(self.name):
            return self.call(y_true, y_pred, sample_weight)

    def call(self, y_true, y_pred, sample_weight=None):
        """Compute the total (weighted) loss and update per-loss metrics."""
        if not tree.is_nested(y_true) and not tree.is_nested(y_pred):
            # Fast path: single output case / no loss-tracking metric.
            if not self.built:
                self.build(y_true, y_pred)
            _, loss_fn, loss_weight, _ = self._flat_losses[0]
            loss_value = ops.cast(
                loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
            )
            if loss_weight is not None:
                loss_value = ops.multiply(loss_value, loss_weight)
            return loss_value

        # Reconcile the structure of y_true with that of y_pred, trying
        # progressively more permissive strategies; the bare `except:`
        # clauses below deliberately cascade to the next strategy.
        try:
            tree.assert_same_structure(y_pred, y_true)
        except ValueError:
            # Check case where y_true is either flat or leaf
            if (
                not tree.is_nested(y_true)
                and hasattr(y_pred, "__len__")
                and len(y_pred) == 1
            ):
                y_true = [y_true]

            # Check case where y_pred is list/tuple and y_true is dict
            elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict):
                if set(self.output_names) == set(y_true.keys()):
                    y_true = [y_true[name] for name in self.output_names]

            try:
                y_true = tree.pack_sequence_as(y_pred, y_true)
            except:
                # Check case where y_true has the same structure but uses
                # different (but reconcilable) container types,
                # e.g `list` vs `tuple`.
                try:
                    tree.assert_same_paths(y_true, y_pred)
                    y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true))
                except:
                    try:
                        # Check case where loss is partially defined over y_pred
                        flat_y_true = tree.flatten(y_true)
                        flat_loss = tree.flatten(self._user_loss)
                        flat_loss_non_nones = [
                            (i, loss)
                            for i, loss in enumerate(flat_loss)
                            if loss is not None
                        ]
                        assert len(flat_y_true) == len(flat_loss_non_nones)
                        # Scatter the provided y_true entries into the slots
                        # where a loss is actually defined.
                        y_true = [None] * len(flat_loss)
                        for y_t, (i, loss) in zip(
                            flat_y_true, flat_loss_non_nones
                        ):
                            y_true[i] = y_t
                        y_true = tree.pack_sequence_as(self._user_loss, y_true)
                    except:
                        # All strategies failed: report both structures.
                        y_true_struct = tree.map_structure(
                            lambda _: "*", y_true
                        )
                        y_pred_struct = tree.map_structure(
                            lambda _: "*", y_pred
                        )
                        raise ValueError(
                            "y_true and y_pred have different structures.\n"
                            f"y_true: {y_true_struct}\n"
                            f"y_pred: {y_pred_struct}\n"
                        )

        if not self.built:
            self.build(y_true, y_pred)

        # Re-pack runtime inputs into the structures seen at build time.
        try:
            tree.assert_same_structure(self._y_pred_build_structure, y_pred)
        except ValueError:
            y_pred = tree.pack_sequence_as(
                self._y_pred_build_structure, tree.flatten(y_pred)
            )
        try:
            tree.assert_same_structure(self._y_true_build_structure, y_true)
        except ValueError:
            y_true = tree.pack_sequence_as(
                self._y_true_build_structure, tree.flatten(y_true)
            )

        # We need to add a dummy `None` if the model has only a single output.
        metrics = [None] if len(self.metrics) == 0 else self.metrics

        # Iterate all losses in flat form.
        loss_values = []

        def resolve_path(path, object):
            # Walk a tuple of keys/indices into a nested structure.
            for _path in path:
                object = object[_path]
            return object

        for (path, loss_fn, loss_weight, _), metric in zip(
            self._flat_losses, metrics
        ):
            y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred)
            if sample_weight is not None and tree.is_nested(sample_weight):
                _sample_weight = resolve_path(path, sample_weight)
            else:
                _sample_weight = sample_weight

            value = ops.cast(
                loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype
            )
            # Record *unweighted* individual losses.
            if metric:
                metric.update_state(
                    loss_module.unscale_loss_for_distribution(value),
                    sample_weight=tree.flatten(y_p)[0].shape[0],
                )
            if loss_weight is not None:
                value = ops.multiply(value, loss_weight)
            loss_values.append(value)

        if loss_values:
            total_loss = sum(loss_values)
            return total_loss
        return None

    def get_config(self):
        raise NotImplementedError

    @classmethod
    def from_config(cls, config):
        raise NotImplementedError
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import types
|
| 2 |
+
|
| 3 |
+
from keras.src.distribution import distribution_lib
|
| 4 |
+
from keras.src.trainers.data_adapters import array_data_adapter
|
| 5 |
+
from keras.src.trainers.data_adapters import data_adapter
|
| 6 |
+
from keras.src.trainers.data_adapters import py_dataset_adapter
|
| 7 |
+
from keras.src.trainers.data_adapters.array_data_adapter import ArrayDataAdapter
|
| 8 |
+
from keras.src.trainers.data_adapters.generator_data_adapter import (
|
| 9 |
+
GeneratorDataAdapter,
|
| 10 |
+
)
|
| 11 |
+
from keras.src.trainers.data_adapters.py_dataset_adapter import PyDatasetAdapter
|
| 12 |
+
from keras.src.trainers.data_adapters.tf_dataset_adapter import TFDatasetAdapter
|
| 13 |
+
from keras.src.trainers.data_adapters.torch_data_loader_adapter import (
|
| 14 |
+
TorchDataLoaderAdapter,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_data_adapter(
    x,
    y=None,
    sample_weight=None,
    batch_size=None,
    steps_per_epoch=None,
    shuffle=False,
    class_weight=None,
):
    """Return a `DataAdapter` instance suited to the type of `x`.

    Dispatches on the type of `x`: array-likes go to `ArrayDataAdapter`,
    `tf.data.Dataset`s to `TFDatasetAdapter`, `PyDataset`s to
    `PyDatasetAdapter`, torch `DataLoader`s to `TorchDataLoaderAdapter`,
    and Python generators to `GeneratorDataAdapter`. A `DataAdapter`
    passed as `x` is returned unchanged.

    Args:
        x: the input data (or an already-constructed `DataAdapter`).
        y: optional targets; only supported for array-like `x`.
        sample_weight: optional sample weights; only supported for
            array-like `x`.
        batch_size: batch size; only used for array-like `x`.
        steps_per_epoch: number of steps; only used for array-like `x`.
        shuffle: whether to shuffle; ignored for input types expected to
            be pre-shuffled (datasets, loaders, generators).
        class_weight: optional class-weight dict; unsupported for torch
            DataLoaders and generators.

    Returns:
        A `DataAdapter` wrapping `x`.

    Raises:
        ValueError: if `x` is of an unrecognized type, if an argument is
            unsupported for the given input type, or if a multi-worker
            distribution is active and `x` is not a `tf.data.Dataset`.
    """
    # Allow passing a custom data adapter.
    if isinstance(x, data_adapter.DataAdapter):
        return x

    # Check for multi-process/worker distribution. Since only tf.dataset
    # is supported at the moment, we will raise error if the inputs fail
    # the type check
    distribution = distribution_lib.distribution()
    if getattr(distribution, "_is_multi_process", False) and not is_tf_dataset(
        x
    ):
        raise ValueError(
            "When using multi-worker distribution, the data must be provided "
            f"as a `tf.data.Dataset` instance. Received: type(x)={type(x)}."
        )

    if array_data_adapter.can_convert_arrays((x, y, sample_weight)):
        return ArrayDataAdapter(
            x,
            y,
            sample_weight=sample_weight,
            class_weight=class_weight,
            shuffle=shuffle,
            batch_size=batch_size,
            steps=steps_per_epoch,
        )
    elif is_tf_dataset(x):
        # Unsupported args: y, sample_weight, shuffle
        if y is not None:
            raise_unsupported_arg("y", "the targets", "tf.data.Dataset")
        if sample_weight is not None:
            raise_unsupported_arg(
                "sample_weights", "the sample weights", "tf.data.Dataset"
            )
        return TFDatasetAdapter(
            x, class_weight=class_weight, distribution=distribution
        )
        # TODO: should we warn or not?
        # warnings.warn(
        #     "`shuffle=True` was passed, but will be ignored since the "
        #     "data `x` was provided as a tf.data.Dataset. The Dataset is "
        #     "expected to already be shuffled "
        #     "(via `.shuffle(tf.data.AUTOTUNE)`)"
        # )
    elif isinstance(x, py_dataset_adapter.PyDataset):
        if y is not None:
            raise_unsupported_arg("y", "the targets", "PyDataset")
        if sample_weight is not None:
            raise_unsupported_arg(
                "sample_weights", "the sample weights", "PyDataset"
            )
        return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle)
        # TODO: should we warn or not?
        # if x.num_batches is None and shuffle:
        #     warnings.warn(
        #         "`shuffle=True` was passed, but will be ignored since the "
        #         "data `x` was provided as a infinite PyDataset. The "
        #         "PyDataset is expected to already be shuffled."
        #     )
    elif is_torch_dataloader(x):
        if y is not None:
            raise_unsupported_arg("y", "the targets", "torch DataLoader")
        if sample_weight is not None:
            raise_unsupported_arg(
                "sample_weights", "the sample weights", "torch DataLoader"
            )
        if class_weight is not None:
            raise ValueError(
                "Argument `class_weight` is not supported for torch "
                f"DataLoader inputs. Received: class_weight={class_weight}"
            )
        return TorchDataLoaderAdapter(x)
        # TODO: should we warn or not?
        # warnings.warn(
        #     "`shuffle=True` was passed, but will be ignored since the "
        #     "data `x` was provided as a torch DataLoader. The DataLoader "
        #     "is expected to already be shuffled."
        # )
    elif isinstance(x, types.GeneratorType):
        # Fix: these error messages previously named "PyDataset" even though
        # this branch handles Python generators (copy-paste from the
        # PyDataset branch above).
        if y is not None:
            raise_unsupported_arg("y", "the targets", "generator")
        if sample_weight is not None:
            raise_unsupported_arg(
                "sample_weights", "the sample weights", "generator"
            )
        if class_weight is not None:
            raise ValueError(
                "Argument `class_weight` is not supported for Python "
                f"generator inputs. Received: class_weight={class_weight}"
            )
        return GeneratorDataAdapter(x)
        # TODO: should we warn or not?
        # warnings.warn(
        #     "`shuffle=True` was passed, but will be ignored since the "
        #     "data `x` was provided as a generator. The generator "
        #     "is expected to yield already-shuffled data."
        # )
    else:
        raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})")
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def raise_unsupported_arg(arg_name, arg_description, input_type):
    """Raise a `ValueError` for an argument disallowed by `input_type` inputs.

    Args:
        arg_name: name of the unsupported argument (e.g. `"y"`).
        arg_description: human-readable description of what it carries
            (e.g. `"the targets"`).
        input_type: name of the container type of `x`
            (e.g. `"tf.data.Dataset"`).

    Raises:
        ValueError: always.
    """
    message = (
        f"When providing `x` as a {input_type}, `{arg_name}` "
        f"should not be passed. Instead, {arg_description} should "
        f"be included as part of the {input_type}."
    )
    raise ValueError(message)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def is_tf_dataset(x):
    """Duck-type check for `tf.data` datasets without importing TensorFlow.

    Walks the MRO of `x` looking for a class named `DatasetV2` or
    `DistributedDataset` whose module lives under `tensorflow.python.`.
    This avoids importing the TensorFlow backend just for a type check.

    Args:
        x: the object to test.

    Returns:
        `True` if `x` quacks like a TensorFlow dataset, `False` otherwise.
    """
    klass = getattr(x, "__class__", None)
    if klass is None:
        return False
    dataset_class_names = ("DatasetV2", "DistributedDataset")
    return any(
        parent.__name__ in dataset_class_names
        and "tensorflow.python." in str(parent.__module__)
        for parent in klass.__mro__
    )
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def is_torch_dataloader(x):
    """Duck-type check for `torch.utils.data.DataLoader` without importing torch.

    Walks the MRO of `x` looking for a class named `DataLoader` whose
    module lives under `torch.utils.data`, so no torch import is needed
    just for a type check.

    Args:
        x: the object to test.

    Returns:
        `True` if `x` quacks like a torch `DataLoader`, `False` otherwise.
    """
    klass = getattr(x, "__class__", None)
    if klass is None:
        return False
    return any(
        parent.__name__ == "DataLoader"
        and "torch.utils.data" in str(parent.__module__)
        for parent in klass.__mro__
    )
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (3.2 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc
ADDED
|
Binary file (11.9 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc
ADDED
|
Binary file (16.9 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc
ADDED
|
Binary file (4.17 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc
ADDED
|
Binary file (3.63 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc
ADDED
|
Binary file (20.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc
ADDED
|
Binary file (5.46 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc
ADDED
|
Binary file (3.09 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import functools
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from keras.src import tree
|
| 7 |
+
from keras.src.trainers.data_adapters import array_slicing
|
| 8 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 9 |
+
from keras.src.trainers.data_adapters.data_adapter import DataAdapter
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ArrayDataAdapter(DataAdapter):
    """Adapter for array-like objects, e.g. TF/JAX Tensors, NumPy arrays."""

    def __init__(
        self,
        x,
        y=None,
        sample_weight=None,
        batch_size=None,
        steps=None,
        shuffle=False,
        class_weight=None,
    ):
        """Validate inputs and precompute batching metadata.

        Args:
            x: array-like inputs (or nested structure of array-likes).
            y: optional array-like targets.
            sample_weight: optional per-sample weights. For a nested `y`,
                either a matching structure or a single samplewise array
                that is replicated across outputs.
            batch_size: batch size. When falsy, derived from `steps`, or
                defaults to 32 for backwards compatibility.
            steps: optional steps per epoch, used to derive `batch_size`.
            shuffle: `False`, `True`, or `"batch"` (shuffle only within
                each batch).
            class_weight: optional class-weight dict, converted to sample
                weights; mutually exclusive with `sample_weight`.

        Raises:
            ValueError: on non-array inputs or inconsistent
                `sample_weight`/`class_weight` arguments.
        """
        if not can_convert_arrays((x, y, sample_weight)):
            raise ValueError(
                "Expected all elements of `x` to be array-like. "
                f"Received invalid types: x={x}"
            )

        if sample_weight is not None:
            if class_weight is not None:
                raise ValueError(
                    "You cannot `class_weight` and `sample_weight` "
                    "at the same time."
                )
            if tree.is_nested(y):
                if isinstance(sample_weight, (list, tuple, dict)):
                    try:
                        tree.assert_same_structure(y, sample_weight)
                    except ValueError:
                        raise ValueError(
                            "You should provide one `sample_weight` array per "
                            "output in `y`. The two structures did not match:\n"
                            f"- y: {y}\n"
                            f"- sample_weight: {sample_weight}\n"
                        )
                else:
                    # A single array: it must be samplewise (one scalar per
                    # sample) so it can be shared by every output.
                    is_samplewise = len(sample_weight.shape) == 1 or (
                        len(sample_weight.shape) == 2
                        and sample_weight.shape[1] == 1
                    )
                    if not is_samplewise:
                        raise ValueError(
                            "For a model with multiple outputs, when providing "
                            "a single `sample_weight` array, it should only "
                            "have one scalar score per sample "
                            "(i.e. shape `(num_samples,)`). If you want to use "
                            "non-scalar sample weights, pass a `sample_weight` "
                            "argument with one array per model output."
                        )
                    # Replicate the same sample_weight array on all outputs.
                    sample_weight = tree.map_structure(
                        lambda _: sample_weight, y
                    )
        if class_weight is not None:
            if tree.is_nested(y):
                raise ValueError(
                    "`class_weight` is only supported for Models with a single "
                    "output."
                )
            sample_weight = data_adapter_utils.class_weight_to_sample_weights(
                y, class_weight
            )

        inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)

        data_adapter_utils.check_data_cardinality(inputs)
        # All leaves share the same cardinality (checked just above), so the
        # set has exactly one element.
        num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
        self._num_samples = num_samples
        self._inputs = inputs

        # If batch_size is not passed but steps is, calculate from the input
        # data. Defaults to `32` for backwards compatibility.
        if not batch_size:
            batch_size = int(math.ceil(num_samples / steps)) if steps else 32

        self._size = int(math.ceil(num_samples / batch_size))
        self._batch_size = batch_size
        self._partial_batch_size = num_samples % batch_size
        self._shuffle = shuffle

    def get_numpy_iterator(self):
        """Yield batches as NumPy arrays (one structure per batch)."""
        inputs = array_slicing.convert_to_sliceable(
            self._inputs, target_backend="numpy"
        )

        def slice_and_convert_to_numpy(sliceable, indices=None):
            x = sliceable[indices]
            x = sliceable.convert_to_numpy(x)
            return x

        return self._get_iterator(slice_and_convert_to_numpy, inputs)

    def get_tf_dataset(self):
        """Build a batched, optionally shuffled `tf.data.Dataset`."""
        from keras.src.utils.module_utils import tensorflow as tf

        shuffle = self._shuffle
        batch_size = self._batch_size
        num_samples = self._num_samples
        num_full_batches = int(self._num_samples // batch_size)

        # Vectorized version of shuffle.
        # This is a performance improvement over using `from_tensor_slices`.
        # The indices of the data are shuffled and batched, and these indices
        # are then zipped with the data and used to extract a batch of the data
        # at each step. The performance improvements here come from:
        # 1. vectorized batch using gather
        # 2. parallelized map
        # 3. pipelined permutation generation
        # 4. optimized permutation batching
        # 5. disabled static optimizations

        indices_dataset = tf.data.Dataset.range(1)

        def permutation(_):
            # It turns out to be more performant to make a new set of indices
            # rather than reusing the same range Tensor. (presumably because of
            # buffer forwarding.)
            indices = tf.range(num_samples, dtype=tf.int64)
            if shuffle and shuffle != "batch":
                indices = tf.random.shuffle(indices)
            return indices

        # We prefetch a single element. Computing large permutations can take
        # quite a while so we don't want to wait for prefetching over an epoch
        # boundary to trigger the next permutation. On the other hand, too many
        # simultaneous shuffles can contend on a hardware level and degrade all
        # performance.
        indices_dataset = indices_dataset.map(permutation).prefetch(1)

        def slice_batch_indices(indices):
            """Convert a Tensor of indices into a dataset of batched indices.

            This step can be accomplished in several ways. The most natural is
            to slice the Tensor in a Dataset map. (With a condition on the upper
            index to handle the partial batch.) However it turns out that
            coercing the Tensor into a shape which is divisible by the batch
            size (and handling the last partial batch separately) allows for a
            much more favorable memory access pattern and improved performance.

            Args:
                indices: Tensor which determines the data order for an entire
                    epoch.

            Returns:
                A Dataset of batched indices.
            """
            num_in_full_batch = num_full_batches * batch_size
            first_k_indices = tf.slice(indices, [0], [num_in_full_batch])
            first_k_indices = tf.reshape(
                first_k_indices, [num_full_batches, batch_size]
            )

            flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices)
            if self._partial_batch_size:
                index_remainder = tf.data.Dataset.from_tensors(
                    tf.slice(
                        indices, [num_in_full_batch], [self._partial_batch_size]
                    )
                )
                flat_dataset = flat_dataset.concatenate(index_remainder)

            return flat_dataset

        def slice_inputs(indices_dataset, inputs):
            """Slice inputs into a Dataset of batches.

            Given a Dataset of batch indices and the unsliced inputs,
            this step slices the inputs in a parallelized fashion
            and produces a dataset of input batches.

            Args:
                indices_dataset: A Dataset of batched indices.
                inputs: A python data structure that contains the inputs,
                    targets, and possibly sample weights.

            Returns:
                A Dataset of input batches matching the batch indices.
            """
            # NOTE(review): the `inputs` parameter is immediately shadowed by
            # `self._inputs` below — confirm this is intentional (it is
            # harmless only while callers pass `self._inputs`).
            inputs = array_slicing.convert_to_sliceable(
                self._inputs, target_backend="tensorflow"
            )
            inputs = tree.lists_to_tuples(inputs)

            dataset = tf.data.Dataset.zip(
                (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat())
            )

            def grab_batch(i, data):
                def grab_one(x):
                    if isinstance(x, array_slicing.TensorflowSparseWrapper):
                        return array_slicing.slice_tensorflow_sparse_wrapper(
                            x, i
                        )
                    # Returning `None` from `tree.traverse` means "recurse
                    # into this container".
                    if isinstance(x, (list, tuple, dict)):
                        return None
                    if tf.is_tensor(x):
                        return tf.gather(x, i, axis=0)
                    return x

                return tree.traverse(grab_one, data)

            dataset = dataset.map(
                grab_batch, num_parallel_calls=tf.data.AUTOTUNE
            )

            # Default optimizations are disabled to avoid the overhead of
            # (unnecessary) input pipeline graph serialization & deserialization
            options = tf.data.Options()
            options.experimental_optimization.apply_default_optimizations = (
                False
            )
            if self._shuffle:
                options.experimental_external_state_policy = (
                    tf.data.experimental.ExternalStatePolicy.IGNORE
                )
            dataset = dataset.with_options(options)
            return dataset

        indices_dataset = indices_dataset.flat_map(slice_batch_indices)
        if shuffle == "batch":
            # "batch" mode: shuffle only within each batch of indices.
            indices_dataset = indices_dataset.map(tf.random.shuffle)

        dataset = slice_inputs(indices_dataset, self._inputs)

        options = tf.data.Options()
        options.experimental_distribute.auto_shard_policy = (
            tf.data.experimental.AutoShardPolicy.DATA
        )
        dataset = dataset.with_options(options)
        return dataset.prefetch(tf.data.AUTOTUNE)

    def get_jax_iterator(self):
        """Yield batches converted to JAX-compatible arrays."""
        inputs = array_slicing.convert_to_sliceable(
            self._inputs, target_backend="jax"
        )

        def slice_and_convert_to_jax(sliceable, indices=None):
            x = sliceable[indices]
            x = sliceable.convert_to_jax_compatible(x)
            return x

        return self._get_iterator(slice_and_convert_to_jax, inputs)

    def get_torch_dataloader(self):
        """Build a `torch.utils.data.DataLoader` over the wrapped arrays."""
        import torch

        from keras.src.backend.torch.core import convert_to_tensor

        class ArrayDataset(torch.utils.data.Dataset):
            # Dataset that slices whole batches at once via `__getitems__`,
            # returning torch tensors in the original nested structure.
            def __init__(self, array):
                self.array = array

            def __getitems__(self, indices):
                def slice_and_convert(sliceable):
                    x = sliceable[indices]
                    x = sliceable.convert_to_torch_compatible(x)
                    x = convert_to_tensor(x)
                    return x

                return tree.map_structure(slice_and_convert, self.array)

            def __len__(self):
                return len(self.array[0])

        class RandomBatchSampler(torch.utils.data.Sampler):
            # Wraps a batch sampler and permutes indices *within* each batch,
            # implementing `shuffle="batch"` semantics.
            def __init__(self, sampler):
                self.sampler = sampler

            def __iter__(self):
                for batch in self.sampler:
                    yield [batch[i] for i in torch.randperm(len(batch))]

            def __len__(self):
                return len(self.sampler)

        if self._shuffle == "batch":
            batch_sampler = RandomBatchSampler(
                torch.utils.data.BatchSampler(
                    range(self._num_samples),
                    batch_size=self._batch_size,
                    drop_last=False,
                )
            )
        elif self._shuffle:
            batch_sampler = torch.utils.data.BatchSampler(
                torch.utils.data.RandomSampler(range(self._num_samples)),
                batch_size=self._batch_size,
                drop_last=False,
            )
        else:
            batch_sampler = torch.utils.data.BatchSampler(
                torch.utils.data.SequentialSampler(range(self._num_samples)),
                batch_size=self._batch_size,
                drop_last=False,
            )

        # Because ArrayDataset.__getitems__ returns full batches organized in
        # the expected structure, there is nothing to collate.
        def no_op_collate(batch):
            return batch

        inputs = array_slicing.convert_to_sliceable(
            self._inputs, target_backend="torch"
        )
        dataset = ArrayDataset(inputs)
        return torch.utils.data.DataLoader(
            dataset, batch_sampler=batch_sampler, collate_fn=no_op_collate
        )

    def _get_iterator(self, slice_and_convert_fn, inputs):
        """Yield `self._size` batches by slicing `inputs` per batch indices.

        Args:
            slice_and_convert_fn: callable `(sliceable, indices=...)` that
                slices one leaf and converts it to the target backend.
            inputs: structure of `Sliceable`s to draw batches from.
        """
        global_permutation = None
        if self._shuffle and self._shuffle != "batch":
            # Full-epoch shuffle: one permutation shared by every batch.
            global_permutation = np.random.permutation(self._num_samples)

        for i in range(self._size):
            start = i * self._batch_size
            stop = min((i + 1) * self._batch_size, self._num_samples)
            if self._shuffle == "batch":
                # Shuffle only within the current batch's index range.
                indices = np.random.permutation(stop - start) + start
            elif self._shuffle:
                indices = global_permutation[start:stop]
            else:
                indices = slice(start, stop)

            slice_indices_and_convert_fn = functools.partial(
                slice_and_convert_fn, indices=indices
            )
            yield tree.map_structure(slice_indices_and_convert_fn, inputs)

    @property
    def num_batches(self):
        """Number of batches per epoch (including any partial batch)."""
        return self._size

    @property
    def batch_size(self):
        """The (full) batch size."""
        return self._batch_size

    @property
    def has_partial_batch(self):
        """Whether the last batch holds fewer than `batch_size` samples."""
        return self._partial_batch_size > 0

    @property
    def partial_batch_size(self):
        """Size of the last partial batch, or `None` if all batches are full."""
        return self._partial_batch_size or None
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def can_convert_arrays(arrays):
    """Check if array like-inputs can be handled by `ArrayDataAdapter`

    Args:
        arrays: Structure of `Tensor`s, NumPy arrays, or tensor-like.

    Returns:
        `True` if every leaf of `arrays` can be handled by
        `ArrayDataAdapter`, `False` otherwise.
    """
    leaf_checks = tree.map_structure(array_slicing.can_slice_array, arrays)
    return all(tree.flatten(leaf_checks))
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py
ADDED
|
@@ -0,0 +1,520 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import math
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from keras.src import backend
|
| 7 |
+
from keras.src import tree
|
| 8 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 9 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 10 |
+
|
| 11 |
+
# `pandas` is an optional dependency: fall back gracefully when absent.
try:
    import pandas
except ImportError:
    pandas = None


# Leave jax, tf, and torch arrays off this list. Instead we will use
# `__array__` to detect these types. Doing so allows us to avoid importing a
# backend framework we are not currently using just to do type-checking.
ARRAY_TYPES = (np.ndarray,)
if pandas:
    ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Sliceable:
    """`Sliceable` wrapping a tensor.

    A `Sliceable` implements the subscript operator to slice or index against
    the first dimension of the array. It also has conversion methods for each
    one of the backends.

    Args:
        array: the native array or tensor to wrap.

    Attributes:
        shape: the shape of the full dense native array.
    """

    # NOTE(review): the class docstring advertises a `shape` attribute, but
    # this base class does not define one — subclasses (or the wrapped
    # array's own `.shape`) appear to provide it; confirm before relying
    # on `Sliceable.shape` generically.

    def __init__(self, array):
        self.array = array

    def __getitem__(self, indices):
        """Select elements in the 0th dimension.

        Args:
            indices: the indices to select. Only needs to support one dimension,
                the 0th dimension. Should support a `slice` or a list, tuple,
                `np.array` or 1D tensor.
        Returns: A slice of `self.array`.
        """
        return self.array[indices]

    @classmethod
    def cast(cls, x, dtype):
        """Cast a tensor to a different dtype.

        Only called on a full array as provided by the user.

        Args:
            x: the tensor to cast.
        Returns: the cast tensor.
        """
        # Base implementation uses the NumPy-style `astype` protocol.
        return x.astype(dtype)

    @classmethod
    def convert_to_numpy(cls, x):
        """Convert a tensor to a NumPy array.

        Only called after slicing using `__getitem__`.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        # Base implementation is the identity; backend subclasses override.
        return x

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        """Convert a tensor to something compatible with `tf.data.Dataset`.

        This can be a NumPy array, `tf.Tensor` or any other type of tensor that
        `tf.data.Dataset.from_tensors` can consume.
        Only called on a full array as provided by the user.

        Args:
            x: the tensor to convert.
        Returns: converted version tensor.
        """
        return x

    @classmethod
    def convert_to_jax_compatible(cls, x):
        """Convert a tensor to something that the JAX backend can consume.

        This can be a `JAX` array, `JAXSparse` or a NumPy array.
        Only called after slicing using `__getitem__`.
        Used to convert sparse tensors and densify ragged tensors.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        return x

    @classmethod
    def convert_to_torch_compatible(cls, x):
        """Convert a tensor to something that the Torch backend can consume.

        This can be a Torch tensor, NumPy array or any other type of tensor that
        `keras.backend.torch.core.convert_to_tensor()` can consume.
        Only called after slicing using `__getitem__`.
        Used to densify sparse tensors and ragged tensors.

        Args:
            x: the tensor to convert.
        Returns: the converted tensor.
        """
        return x
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
class NumpySliceable(Sliceable):
    # NumPy arrays already support fancy indexing, `astype`, and need no
    # backend conversion, so the base `Sliceable` behavior is sufficient.
    pass
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
class TensorflowSliceable(Sliceable):
    """`Sliceable` for dense TensorFlow tensors."""

    def __getitem__(self, indices):
        from keras.src.utils.module_utils import tensorflow as tf

        # Plain `slice` objects are handled by native tensor slicing; any
        # other index collection (list, tuple, array, tensor) is routed
        # through `tf.gather` on the batch axis.
        if not isinstance(indices, slice):
            return tf.gather(self.array, indices, axis=0)
        return self.array[indices]

    @classmethod
    def cast(cls, x, dtype):
        from keras.src.backend.tensorflow.core import cast

        return cast(x, dtype)

    @classmethod
    def convert_to_numpy(cls, x):
        from keras.src.backend.tensorflow.core import convert_to_numpy

        return convert_to_numpy(x)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class TensorflowRaggedSliceable(TensorflowSliceable):
    """`Sliceable` for `tf.RaggedTensor`s."""

    @classmethod
    def convert_to_jax_compatible(cls, x):
        # JAX has no ragged representation; go through NumPy instead.
        return cls.convert_to_numpy(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        # Torch has no ragged representation; produce a regular tensor.
        return x.to_tensor()
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class TensorflowSparseSliceable(TensorflowSliceable):
    """`Sliceable` for `tf.SparseTensor`s, stored as a sparse wrapper."""

    def __init__(self, array):
        # Keep the ragged-backed wrapper, which supports slicing, rather
        # than the raw `tf.SparseTensor`, which does not.
        super().__init__(to_tensorflow_sparse_wrapper(array))

    @property
    def shape(self):
        # The wrapped original `SparseTensor` carries the true shape.
        return self.array.sparse.shape

    def __getitem__(self, indices):
        return slice_tensorflow_sparse_wrapper(self.array, indices)

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        return to_tensorflow_sparse_wrapper(x)

    @classmethod
    def convert_to_jax_compatible(cls, x):
        return data_adapter_utils.tf_sparse_to_jax_sparse(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        from keras.src.backend.tensorflow import sparse as tf_sparse

        # Torch consumes dense tensors only.
        return tf_sparse.sparse_to_dense(x)
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
class JaxSparseSliceable(Sliceable):
    """`Sliceable` for `jax.experimental.sparse` arrays."""

    def __getitem__(self, indices):
        return self.array[indices, ...]

    @classmethod
    def convert_to_numpy(cls, x):
        from keras.src.backend.jax.core import convert_to_numpy

        return convert_to_numpy(x)

    @classmethod
    def convert_to_tf_dataset_compatible(cls, array):
        # Route through the TF sparse wrapper so `tf.data` can ingest it.
        return to_tensorflow_sparse_wrapper(
            data_adapter_utils.jax_sparse_to_tf_sparse(array)
        )

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return x.todense()
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class TorchSliceable(Sliceable):
    """`Sliceable` for `torch.Tensor`s."""

    @classmethod
    def cast(cls, x, dtype):
        from keras.src.backend.torch.core import cast as torch_cast

        return torch_cast(x, dtype)

    @classmethod
    def convert_to_numpy(cls, x):
        from keras.src.backend.torch.core import (
            convert_to_numpy as torch_convert_to_numpy,
        )

        return torch_convert_to_numpy(x)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
class PandasSliceable(Sliceable):
    """`Sliceable` for pandas objects, indexed positionally via `iloc`."""

    def __getitem__(self, indices):
        return self.array.iloc[indices]

    @classmethod
    def convert_to_numpy(cls, x):
        return x.to_numpy()

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        # All backends consume the NumPy form.
        return cls.convert_to_numpy(x)

    @classmethod
    def convert_to_jax_compatible(cls, x):
        return cls.convert_to_numpy(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return cls.convert_to_numpy(x)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
class PandasDataFrameSliceable(PandasSliceable):
    """`Sliceable` for `pandas.DataFrame`; inherits all behavior."""
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
class PandasSeriesSliceable(PandasSliceable):
    """`Sliceable` for `pandas.Series`; adds a trailing feature axis."""

    @classmethod
    def convert_to_numpy(cls, x):
        # A Series is 1D; expand to (batch, 1) so it behaves like a column.
        return np.expand_dims(x.to_numpy(), axis=-1)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class ScipySparseSliceable(Sliceable):
    """`Sliceable` for scipy sparse matrices."""

    def __init__(self, array):
        # The COO representation cannot be indexed or sliced; convert once
        # up front to CSR, which can.
        super().__init__(array.tocsr())

    @classmethod
    def convert_to_numpy(cls, x):
        return x.todense()

    @classmethod
    def convert_to_tf_dataset_compatible(cls, x):
        # Route through the TF sparse wrapper so `tf.data` can ingest it.
        return to_tensorflow_sparse_wrapper(
            data_adapter_utils.scipy_sparse_to_tf_sparse(x)
        )

    @classmethod
    def convert_to_jax_compatible(cls, x):
        return data_adapter_utils.scipy_sparse_to_jax_sparse(x)

    @classmethod
    def convert_to_torch_compatible(cls, x):
        return x.todense()
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
# `tf.SparseTensor` cannot be indexed or passed to `tf.gather`: its COO
# layout does not lend itself to row extraction. To ease indexing and
# slicing we keep, next to the original tensor, ragged views of its indices
# and values in which each ragged row holds the entries of one sparse-tensor
# row (rows have varying numbers of values). `RaggedTensor`s do support
# indexing and `tf.gather`, although on CPU only, and a `SparseTensor` can
# be reconstructed from the extracted rows. In theory no index/value data is
# duplicated -- only the row splits for the ragged representation are added.
# `TensorflowSparseWrapper` bundles the original `SparseTensor` (kept for
# its shape) with the ragged views. A named tuple is used rather than a
# `Sliceable` so the wrapper can be ingested by
# `tf.data.Dataset.from_tensors()` and mapped over.

TensorflowSparseWrapper = collections.namedtuple(
    "TensorflowSparseWrapper", ["sparse", "ragged_indices", "ragged_values"]
)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
def to_tensorflow_sparse_wrapper(sparse):
    """Build a `TensorflowSparseWrapper` for a `tf.SparseTensor`.

    Creates ragged views of the COO indices and values, one ragged row per
    sparse-tensor row, which makes the wrapper indexable and sliceable.

    Args:
        sparse: the `tf.SparseTensor` to wrap.

    Returns: a `TensorflowSparseWrapper` named tuple.
    """
    from keras.src.utils.module_utils import tensorflow as tf

    # Group the COO entries by the row they belong to (first coordinate).
    batch_row_ids = sparse.indices[:, 0]
    splits = tf.experimental.RowPartition.from_value_rowids(
        batch_row_ids
    ).row_splits()

    ragged_indices = tf.cast(
        tf.RaggedTensor.from_row_splits(sparse.indices, splits), tf.int64
    )
    ragged_values = tf.RaggedTensor.from_row_splits(sparse.values, splits)
    return TensorflowSparseWrapper(sparse, ragged_indices, ragged_values)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def slice_tensorflow_sparse_wrapper(sparse_wrapper, indices):
    """Extract rows from a `TensorflowSparseWrapper` as a `tf.SparseTensor`.

    Args:
        sparse_wrapper: wrapper created by `to_tensorflow_sparse_wrapper`.
        indices: a `slice`, a list of row indices, or an index tensor.

    Returns: a `tf.SparseTensor` containing the selected rows.
    """
    from keras.src.utils.module_utils import tensorflow as tf

    if isinstance(indices, slice):
        indices_rt = sparse_wrapper.ragged_indices[indices]
        values_rt = sparse_wrapper.ragged_values[indices]
        batch_dim = indices.stop - indices.start
    else:
        indices_rt = tf.gather(sparse_wrapper.ragged_indices, indices)
        values_rt = tf.gather(sparse_wrapper.ragged_values, indices)
        if isinstance(indices, list):
            batch_dim = len(indices)
        else:
            batch_dim = indices.shape[0]
            if batch_dim is None:
                # Static shape unknown; fall back to the dynamic shape.
                batch_dim = tf.shape(indices)[0]

    # Renumber the row coordinate of each value: the ragged row id within
    # the extracted batch replaces the original row index.
    row_ids = indices_rt.value_rowids()
    inner_coords = indices_rt.flat_values[:, 1:]  # remove first value
    new_indices = tf.concat(
        [tf.expand_dims(row_ids, -1), inner_coords], axis=1
    )

    new_values = values_rt.flat_values
    new_shape = (batch_dim,) + tuple(
        sparse_wrapper.sparse.shape.as_list()[1:]
    )
    return tf.SparseTensor(new_indices, new_values, new_shape)
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def can_slice_array(x):
    """Return whether `x` is `None` or an array type we know how to slice."""
    if x is None or isinstance(x, ARRAY_TYPES):
        return True
    type_checks = (
        data_adapter_utils.is_tensorflow_tensor,
        data_adapter_utils.is_jax_array,
        data_adapter_utils.is_torch_tensor,
        data_adapter_utils.is_scipy_sparse,
    )
    if any(check(x) for check in type_checks):
        return True
    # Anything exposing `__array__` can be converted to NumPy and sliced.
    return hasattr(x, "__array__")
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def convert_to_sliceable(arrays, target_backend=None):
    """Convert a structure of arrays into `Sliceable` instances.

    Args:
        arrays: the arrays to convert.
        target_backend: the target backend for the output:
            - `None`: wrap `arrays` into `Sliceable`s as-is without using a
              different representation (used by `train_validation_split()`).
            - `"tensorflow"`: call
              `Sliceable.convert_to_tf_dataset_compatible`; the returned
              structure therefore contains arrays, not `Sliceable`s.
            - `"numpy"`, `"jax"` or `"torch"`: the arrays will eventually be
              converted to this backend type after slicing, so the
              intermediary `Sliceable`s may use a different representation
              from the input `arrays` for better performance.

    Returns: the same structure with `Sliceable` instances or arrays.
    """

    def _convert_one(x):
        if x is None:
            return x

        # Special case: np "object" arrays containing strings become string
        # tensors on the TensorFlow backend.
        if (
            isinstance(x, np.ndarray)
            and str(x.dtype) == "object"
            and backend.backend() == "tensorflow"
            and all(isinstance(e, str) for e in x)
        ):
            x = tf.convert_to_tensor(x, dtype="string")

        # Step 1. Pick the `Sliceable` subclass for this array type.
        if isinstance(x, np.ndarray):
            sliceable_class = NumpySliceable
        elif data_adapter_utils.is_tensorflow_tensor(x):
            if data_adapter_utils.is_tensorflow_ragged(x):
                sliceable_class = TensorflowRaggedSliceable
            elif data_adapter_utils.is_tensorflow_sparse(x):
                sliceable_class = TensorflowSparseSliceable
            else:
                sliceable_class = TensorflowSliceable
        elif data_adapter_utils.is_jax_array(x):
            if data_adapter_utils.is_jax_sparse(x):
                sliceable_class = JaxSparseSliceable
            else:
                # Dense JAX arrays are sliced faster through NumPy.
                x = np.asarray(x)
                sliceable_class = NumpySliceable
        elif data_adapter_utils.is_torch_tensor(x):
            sliceable_class = TorchSliceable
        elif pandas is not None and isinstance(x, pandas.DataFrame):
            sliceable_class = PandasDataFrameSliceable
        elif pandas is not None and isinstance(x, pandas.Series):
            sliceable_class = PandasSeriesSliceable
        elif data_adapter_utils.is_scipy_sparse(x):
            sliceable_class = ScipySparseSliceable
        elif hasattr(x, "__array__"):
            x = np.asarray(x)
            sliceable_class = NumpySliceable
        else:
            raise ValueError(
                "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
                "tf.SparseTensor, jax.np.ndarray, "
                "jax.experimental.sparse.JAXSparse, torch.Tensor, "
                "Pandas Dataframe, or Pandas Series. Received invalid input: "
                f"{x} (of type {type(x)})"
            )

        # Step 2. Normalize floats to floatx.
        def _is_non_floatx_float(dtype):
            return (
                dtype is not object
                and backend.is_float_dtype(dtype)
                and not backend.standardize_dtype(dtype) == backend.floatx()
            )

        cast_dtype = None
        if pandas is not None and isinstance(x, pandas.DataFrame):
            # A DataFrame is cast as a whole if any column needs it.
            if any(_is_non_floatx_float(d) for d in x.dtypes.values):
                cast_dtype = backend.floatx()
        elif _is_non_floatx_float(x.dtype):
            cast_dtype = backend.floatx()

        if cast_dtype is not None:
            x = sliceable_class.cast(x, cast_dtype)

        # Step 3. Apply target backend specific logic and optimizations.
        if target_backend is None:
            return sliceable_class(x)

        if target_backend == "tensorflow":
            return sliceable_class.convert_to_tf_dataset_compatible(x)

        # With dense arrays and JAX as output, it is faster to use NumPy as
        # an intermediary representation, so wrap the input array in a NumPy
        # array, which should not use extra memory.
        # See https://github.com/google/jax/issues/1276 for an explanation
        # of why slicing a NumPy array is faster than slicing a JAX array.
        if target_backend == "jax" and sliceable_class in (
            TensorflowSliceable,
            TorchSliceable,
        ):
            x = np.asarray(x)
            sliceable_class = NumpySliceable

        return sliceable_class(x)

    return tree.map_structure(_convert_one, arrays)
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def train_validation_split(arrays, validation_split):
    """Split arrays into train and validation subsets in deterministic order.

    The last part of data will become validation data.

    Args:
        arrays: Tensors to split. Allowed inputs are arbitrarily nested
            structures of Tensors and NumPy arrays.
        validation_split: Float between 0 and 1. The proportion of the dataset
            to include in the validation split. The rest of the dataset will be
            included in the training split.

    Returns:
        `(train_arrays, validation_arrays)`
    """
    flat_arrays = tree.flatten(arrays)
    unsplitable = [type(t) for t in flat_arrays if not can_slice_array(t)]
    if unsplitable:
        raise ValueError(
            "Argument `validation_split` is only supported "
            "for tensors or NumPy arrays."
            f"Found incompatible type in the input: {unsplitable}"
        )

    if all(t is None for t in flat_arrays):
        return arrays, arrays

    # Any non-`None` leaf works here: all leaves are assumed to share the
    # same batch dimension.
    first_non_none = next((t for t in flat_arrays if t is not None), None)

    batch_dim = int(first_non_none.shape[0])
    split_at = int(math.floor(batch_dim * (1.0 - validation_split)))

    if split_at == 0 or split_at == batch_dim:
        raise ValueError(
            f"Training data contains {batch_dim} samples, which is not "
            "sufficient to split it into a validation and training set as "
            f"specified by `validation_split={validation_split}`. Either "
            "provide more data, or a different value for the "
            "`validation_split` argument."
        )

    def _take(t, start, end):
        # `None` leaves pass through untouched.
        return t if t is None else t[start:end]

    sliceables = convert_to_sliceable(arrays)
    train_arrays = tree.map_structure(
        lambda x: _take(x, 0, split_at), sliceables
    )
    val_arrays = tree.map_structure(
        lambda x: _take(x, split_at, batch_dim), sliceables
    )
    return train_arrays, val_arrays
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class DataAdapter:
    """Base class for input data adapters.

    A `DataAdapter` presents a unified iteration interface over input data
    supplied in a variety of formats -- NumPy arrays, tf.Tensors,
    tf.data.Datasets, Keras PyDatasets, and so on.
    """

    def get_numpy_iterator(self):
        """Get a Python iterable for the `DataAdapter`, that yields NumPy
        arrays.

        Returns:
            A Python iterator.
        """
        raise NotImplementedError

    def get_tf_dataset(self):
        """Get a `tf.data.Dataset` instance for the DataAdapter.

        The returned dataset does not repeat across epochs, so the caller
        may need to create a new iterator for the same dataset at the
        beginning of each epoch. This behavior might change in the future.

        Returns:
            A `tf.data.Dataset`. The caller may use it in different
            contexts, e.g. `iter(dataset)` in eager mode to get values
            directly, or in graph mode by providing the iterator tensor to
            the Keras model function.
        """
        raise NotImplementedError

    def get_jax_iterator(self):
        """Get a Python iterable yielding arrays that can be fed to JAX.

        NumPy arrays are preferred for performance.

        Returns:
            A Python iterator.
        """
        raise NotImplementedError

    def get_torch_dataloader(self):
        """Get a Torch `DataLoader` for the `DataAdapter`.

        Returns:
            A Torch `DataLoader`.
        """
        raise NotImplementedError

    @property
    def num_batches(self):
        """Return the size (number of batches) for the dataset created.

        For some input types the number of batches is known, e.g. for NumPy
        data it is number_of_elements / batch_size. For datasets or Python
        generators the size is unknown, since they may or may not have an
        end state.

        Returns:
            int, the number of batches for the dataset, or None if unknown.
            The caller can use this to drive the training loop, show a
            progress bar, or handle an unexpected StopIteration error.
        """
        raise NotImplementedError

    @property
    def batch_size(self):
        """Return the batch size of the dataset created.

        For some input types the batch size is known, and even required,
        e.g. NumPy arrays. For datasets it is unknown unless a batch is
        inspected.

        Returns:
            int, the batch size of the dataset, or None if it is unknown.
        """
        raise NotImplementedError

    @property
    def has_partial_batch(self):
        """Whether the dataset has partial batch at the end."""
        raise NotImplementedError

    @property
    def partial_batch_size(self):
        """The size of the final partial batch for dataset.

        Will return None if has_partial_batch is False or batch_size is None.
        """
        raise NotImplementedError

    def on_epoch_begin(self):
        """A hook called before each epoch."""
        pass

    def on_epoch_end(self):
        """A hook called after each epoch."""
        pass
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
from keras.src import backend
|
| 4 |
+
from keras.src import ops
|
| 5 |
+
from keras.src import tree
|
| 6 |
+
from keras.src.api_export import keras_export
|
| 7 |
+
|
| 8 |
+
NUM_BATCHES_FOR_TENSOR_SPEC = 2
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@keras_export("keras.utils.unpack_x_y_sample_weight")
def unpack_x_y_sample_weight(data):
    """Unpacks user-provided data tuple.

    A convenience utility for overriding `Model.train_step`,
    `Model.test_step`, or `Model.predict_step`; it normalizes data of the
    form `(x,)`, `(x, y)`, or `(x, y, sample_weight)` into a triple.

    Example:

    >>> features_batch = ops.ones((10, 5))
    >>> labels_batch = ops.zeros((10, 5))
    >>> data = (features_batch, labels_batch)
    >>> # `y` and `sample_weight` will default to `None` if not provided.
    >>> x, y, sample_weight = unpack_x_y_sample_weight(data)
    >>> sample_weight is None
    True

    Args:
        data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.

    Returns:
        The unpacked tuple, with `None`s for `y` and `sample_weight` if they
        are not provided.
    """
    if isinstance(data, list):
        data = tuple(data)
    if not isinstance(data, tuple):
        return (data, None, None)
    if 1 <= len(data) <= 3:
        # Pad missing positions with `None` so callers always get a triple.
        return tuple(data) + (None,) * (3 - len(data))
    raise ValueError(
        "Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
        f"or `(x, y, sample_weight)`, found: {data}"
    )
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@keras_export("keras.utils.pack_x_y_sample_weight")
def pack_x_y_sample_weight(x, y=None, sample_weight=None):
    """Packs user-provided data into a tuple.

    A convenience utility for packing data into the tuple formats that
    `Model.fit()` uses.

    Example:

    >>> x = ops.ones((10, 1))
    >>> data = pack_x_y_sample_weight(x)
    >>> isinstance(data, ops.Tensor)
    True
    >>> y = ops.ones((10, 1))
    >>> data = pack_x_y_sample_weight(x, y)
    >>> isinstance(data, tuple)
    True
    >>> x, y = data

    Args:
        x: Features to pass to `Model`.
        y: Ground-truth targets to pass to `Model`.
        sample_weight: Sample weight for each element.

    Returns:
        Tuple in the format used in `Model.fit()`.
    """
    if y is None:
        # A single x-input is kept unwrapped when unambiguous; this keeps
        # NumPy and Dataset inputs consistent, so users do not have to wrap
        # Dataset data in an unnecessary tuple. Tuple/list inputs are
        # wrapped to preserve their structure.
        return x if not isinstance(x, (tuple, list)) else (x,)
    if sample_weight is None:
        return (x, y)
    return (x, y, sample_weight)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def list_to_tuple(maybe_list):
    """Datasets will stack any list of tensors, so we convert them to tuples."""
    return tuple(maybe_list) if isinstance(maybe_list, list) else maybe_list
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def check_data_cardinality(data):
    """Raise a `ValueError` if arrays in `data` disagree on sample count."""
    cardinalities = set(int(arr.shape[0]) for arr in tree.flatten(data))
    if len(cardinalities) > 1:
        msg = (
            "Data cardinality is ambiguous. "
            "Make sure all arrays contain the same number of samples."
        )
        # Report the per-leaf sizes for each of x / y / sample_weight.
        for label, single_data in zip(["x", "y", "sample_weight"], data):
            sizes = ", ".join(
                str(arr.shape[0]) for arr in tree.flatten(single_data)
            )
            msg += f"'{label}' sizes: {sizes}\n"
        raise ValueError(msg)
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def class_weight_to_sample_weights(y, class_weight):
    """Build per-sample weights from a class-to-weight mapping.

    Args:
        y: the targets; may be class indices, a one-hot/probability matrix,
            or a column vector of labels.
        class_weight: dict mapping class index to weight; classes absent
            from the dict default to a weight of 1.0.

    Returns: a 1D NumPy array of per-sample weights.
    """
    # Work in NumPy so operations (e.g. np.round()) behave identically
    # across frameworks like TensorFlow, JAX, and PyTorch.
    labels = ops.convert_to_numpy(y)
    weights = np.ones(shape=(labels.shape[0],), dtype=backend.floatx())
    if labels.ndim > 1:
        if labels.shape[-1] != 1:
            # One-hot / probability rows: pick the argmax class.
            labels = np.argmax(labels, axis=-1)
        else:
            # Column vector: drop the trailing singleton dimension.
            labels = np.squeeze(labels, axis=-1)
    labels = np.round(labels).astype("int32")

    for i in range(labels.shape[0]):
        weights[i] = class_weight.get(int(labels[i]), 1.0)
    return weights
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def get_tensor_spec(batches):
    """Return the common tensor spec for a list of batches.

    Args:
        batches: list of structures of tensors. The structures must be
            identical, but the shape at each leaf may be different.
    Returns: the common tensor spec for all the batches.
    """
    from keras.src.utils.module_utils import tensorflow as tf

    def get_single_tensor_spec(*tensors):
        first = tensors[0]
        rank = len(first.shape)
        if rank < 1:
            raise ValueError(
                "When passing a dataset to a Keras model, the arrays must "
                f"be at least rank 1. Received: {first} of rank "
                f"{len(first.shape)}."
            )
        for other in tensors:
            if len(other.shape) != rank:
                raise ValueError(
                    "When passing a dataset to a Keras model, the "
                    "corresponding arrays in each batch must have the same "
                    f"rank. Received: {first} and {other}"
                )
        # Merge shapes dimension by dimension: keep a static size only when
        # all batches agree on it, otherwise mark the dimension dynamic.
        shape = [
            dims[0] if len(set(dims)) == 1 else None
            for dims in zip(*(list(t.shape) for t in tensors))
        ]
        shape[0] = None  # batch size may not be static

        dtype = backend.standardize_dtype(first.dtype)
        if isinstance(first, tf.RaggedTensor):
            return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
        if (
            isinstance(first, tf.SparseTensor)
            or is_scipy_sparse(first)
            or is_jax_sparse(first)
        ):
            return tf.SparseTensorSpec(shape=shape, dtype=dtype)
        return tf.TensorSpec(shape=shape, dtype=dtype)

    return tree.map_structure(get_single_tensor_spec, *batches)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def get_jax_iterator(iterable):
    """Yield batches from `iterable` with leaves converted for JAX."""
    import jax
    import jax.experimental.sparse as jax_sparse

    def _to_jax_compatible(x):
        # Native JAX arrays, JAX sparse arrays and NumPy arrays pass
        # through untouched; sparse inputs from scipy/TF are converted;
        # everything else goes through NumPy.
        if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):
            return x
        if is_scipy_sparse(x):
            return scipy_sparse_to_jax_sparse(x)
        if is_tensorflow_sparse(x):
            return tf_sparse_to_jax_sparse(x)
        return np.asarray(x)

    for batch in iterable:
        yield tree.map_structure(_to_jax_compatible, batch)
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
def get_numpy_iterator(iterable):
    """Yield batches from `iterable` with every leaf converted to NumPy."""

    def _to_numpy(x):
        if isinstance(x, np.ndarray):
            return x
        # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
        # `torch.Tensor`, as well as any other tensor-like object that has
        # added numpy support.
        if hasattr(x, "__array__"):
            if is_torch_tensor(x):
                # Torch tensors must be on CPU before NumPy conversion.
                x = x.cpu()
            x = np.asarray(x)
        return x

    for batch in iterable:
        yield tree.map_structure(_to_numpy, batch)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def get_torch_dataloader(iterable):
    """Wrap `iterable` in a Torch `DataLoader` that re-emits its batches.

    Each batch is mapped leaf-by-leaf to torch tensors. Batching itself is
    left to the upstream iterable: `batch_size=None` tells the `DataLoader`
    not to re-batch.
    """
    import torch.utils.data as torch_data

    from keras.src.backend.torch.core import convert_to_tensor

    class _ConvertingIterableDataset(torch_data.IterableDataset):
        def __init__(self, source):
            self.source = source

        def __iter__(self):
            for batch in self.source:
                yield tree.map_structure(convert_to_tensor, batch)

    # `batch_size=None` indicates that we should not re-batch.
    return torch_data.DataLoader(
        _ConvertingIterableDataset(iterable), batch_size=None
    )
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def is_tensorflow_tensor(value):
    """Return whether `value` is a TensorFlow tensor, duck-typed.

    Checks the class name and module instead of importing TensorFlow, so
    this works without TF installed. Ragged and sparse tensors are
    recognized by exact class name; dense tensors by a `Tensor` base class
    anywhere in the MRO that lives in a `tensorflow.python.` module.

    Args:
        value: Any object.

    Returns:
        `True` if `value` looks like a TF tensor, `False` otherwise.
    """
    if hasattr(value, "__class__"):
        if value.__class__.__name__ in ("RaggedTensor", "SparseTensor"):
            return "tensorflow.python." in str(value.__class__.__module__)
        for parent in value.__class__.__mro__:
            # Bug fix: the original used `parent.__name__ in ("Tensor")`,
            # which is a *substring* test against the string "Tensor"
            # (e.g. a class named "T" or "Tens" would match). An exact
            # name comparison is intended.
            if parent.__name__ == "Tensor" and "tensorflow.python." in str(
                parent.__module__
            ):
                return True
    return False
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def is_tensorflow_ragged(value):
    """Return whether `value` is a `tf.RaggedTensor`, duck-typed by name."""
    cls = getattr(value, "__class__", None)
    if cls is None:
        return False
    return cls.__name__ == "RaggedTensor" and "tensorflow.python." in str(
        cls.__module__
    )
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def is_tensorflow_sparse(value):
    """Return whether `value` is a `tf.SparseTensor`, duck-typed by name."""
    cls = getattr(value, "__class__", None)
    if cls is None:
        return False
    return cls.__name__ == "SparseTensor" and "tensorflow.python." in str(
        cls.__module__
    )
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
def is_jax_array(value):
    """Return whether `value` is a `jax.Array` or JAX sparse array."""
    cls = getattr(value, "__class__", None)
    if cls is not None and any(
        base.__name__ == "Array" and str(base.__module__) == "jax"
        for base in cls.__mro__
    ):
        return True
    # JAX sparse arrays do not extend jax.Array, so check them separately.
    return is_jax_sparse(value)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def is_jax_sparse(value):
    """Return whether `value` is a JAX sparse array, duck-typed by module."""
    cls = getattr(value, "__class__", None)
    if cls is None:
        return False
    return str(cls.__module__).startswith("jax.experimental.sparse")
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def is_torch_tensor(value):
    """Return whether `value` is (a subclass of) `torch.Tensor`, duck-typed."""
    cls = getattr(value, "__class__", None)
    if cls is None:
        return False
    return any(
        base.__name__ == "Tensor" and str(base.__module__).endswith("torch")
        for base in cls.__mro__
    )
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def is_scipy_sparse(x):
    """Return whether `x` is a scipy sparse matrix/array, duck-typed."""
    from_scipy_sparse = str(x.__class__.__module__).startswith("scipy.sparse")
    # `tocoo` is required downstream by the converters.
    return from_scipy_sparse and hasattr(x, "tocoo")
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def scipy_sparse_to_tf_sparse(x):
    """Convert a scipy sparse matrix to a `tf.SparseTensor`."""
    from keras.src.utils.module_utils import tensorflow as tf

    coo = x.tocoo()
    # Pair up (row, col) into the (nnz, 2) index layout TF expects.
    coords = np.stack((coo.row, coo.col), axis=1)
    return tf.SparseTensor(coords, coo.data, coo.shape)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def scipy_sparse_to_jax_sparse(x):
    """Convert a scipy sparse matrix to a `jax.experimental.sparse.BCOO`."""
    import jax
    import jax.experimental.sparse as jax_sparse

    # Build on the CPU so host-side data is not placed on an accelerator.
    cpu_device = jax.local_devices(backend="cpu")[0]
    with jax.default_device(cpu_device):
        return jax_sparse.BCOO.from_scipy_sparse(x)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def tf_sparse_to_jax_sparse(x):
    """Convert a `tf.SparseTensor` to a `jax.experimental.sparse.BCOO`."""
    import jax
    import jax.experimental.sparse as jax_sparse

    # Extract the component tensors as host numpy arrays first.
    data = np.asarray(x.values)
    coords = np.asarray(x.indices)
    # Build on the CPU so host-side data is not placed on an accelerator.
    with jax.default_device(jax.local_devices(backend="cpu")[0]):
        return jax_sparse.BCOO((data, coords), shape=x.shape)
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def jax_sparse_to_tf_sparse(x):
    """Convert a `jax.experimental.sparse.BCOO` to a `tf.SparseTensor`."""
    from keras.src.utils.module_utils import tensorflow as tf

    # BCOO's (nnz, ndim) indices and (nnz,) data match the layout that
    # `tf.SparseTensor` expects, so no reshaping is needed.
    return tf.SparseTensor(x.indices, x.data, x.shape)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
|
| 3 |
+
from keras.src import tree
|
| 4 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 5 |
+
from keras.src.trainers.data_adapters.data_adapter import DataAdapter
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GeneratorDataAdapter(DataAdapter):
    """Adapter for Python generators.

    Wraps a generator of batch tuples so it can be consumed as a numpy
    iterator, a JAX iterator, a `tf.data.Dataset`, or a torch `DataLoader`.
    """

    def __init__(self, generator):
        # Peek at the first few batches without losing them: we need them to
        # validate the yielded structure now and to infer a tf.data output
        # signature later. `generator` becomes a zero-argument callable that
        # replays the peeked batches before the rest of the stream.
        first_batches, generator = peek_and_restore(generator)
        self.generator = generator
        self._first_batches = first_batches
        # Structure of tf.TypeSpec objects; built lazily in `get_tf_dataset`.
        self._output_signature = None
        if not isinstance(first_batches[0], tuple):
            raise ValueError(
                "When passing a Python generator to a Keras model, "
                "the generator must return a tuple, either "
                "(input,) or (inputs, targets) or "
                "(inputs, targets, sample_weights). "
                f"Received: {first_batches[0]}"
            )

    def get_numpy_iterator(self):
        # Each leaf of each batch is converted to a numpy array.
        return data_adapter_utils.get_numpy_iterator(self.generator())

    def get_jax_iterator(self):
        # Each leaf of each batch is converted to a JAX-compatible array.
        return data_adapter_utils.get_jax_iterator(self.generator())

    def get_tf_dataset(self):
        from keras.src.utils.module_utils import tensorflow as tf

        def convert_to_tf(x, spec):
            # Sparse leaves must become tf.SparseTensor before entering
            # tf.data.
            if data_adapter_utils.is_scipy_sparse(x):
                x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
            elif data_adapter_utils.is_jax_sparse(x):
                x = data_adapter_utils.jax_sparse_to_tf_sparse(x)
            # The signature was inferred from the first batches only; fail
            # with an actionable message if a later batch doesn't fit it.
            if not spec.shape.is_compatible_with(x.shape):
                raise TypeError(
                    f"Generator yielded an element of shape {x.shape} where "
                    f"an element of shape {spec.shape} was expected. Your "
                    "generator provides tensors with variable input "
                    "dimensions other than the batch size. Make sure that the "
                    "generator's first two batches do not have the same "
                    "dimension value wherever there is a variable input "
                    "dimension."
                )
            return x

        def get_tf_iterator():
            for batch in self.generator():
                batch = tree.map_structure(
                    convert_to_tf, batch, self._output_signature
                )
                yield batch

        if self._output_signature is None:
            self._output_signature = data_adapter_utils.get_tensor_spec(
                self._first_batches
            )
        ds = tf.data.Dataset.from_generator(
            get_tf_iterator,
            output_signature=self._output_signature,
        )
        ds = ds.prefetch(tf.data.AUTOTUNE)
        return ds

    def get_torch_dataloader(self):
        # DataLoader with batch_size=None: the generator already batches.
        return data_adapter_utils.get_torch_dataloader(self.generator())

    @property
    def num_batches(self):
        # Unknown: a generator's length cannot be determined up front.
        return None

    @property
    def batch_size(self):
        # Unknown: batches are produced opaquely by the generator.
        return None
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def peek_and_restore(generator):
    """Consume the first few batches of `generator` without losing them.

    Returns a tuple `(head, restart)`: `head` is a list of the first
    `NUM_BATCHES_FOR_TENSOR_SPEC` batches, and `restart` is a zero-argument
    callable yielding an iterator that replays `head` before continuing
    with the untouched remainder of `generator`.
    """
    peek_count = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
    head = list(itertools.islice(generator, peek_count))

    def restart():
        return itertools.chain(head, generator)

    return head, restart
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py
ADDED
|
@@ -0,0 +1,692 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
import multiprocessing.dummy
|
| 3 |
+
import queue
|
| 4 |
+
import random
|
| 5 |
+
import threading
|
| 6 |
+
import warnings
|
| 7 |
+
import weakref
|
| 8 |
+
from contextlib import closing
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
|
| 12 |
+
from keras.src.api_export import keras_export
|
| 13 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 14 |
+
from keras.src.trainers.data_adapters.data_adapter import DataAdapter
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"])
class PyDataset:
    """Base class for defining a parallel dataset using Python code.

    Every `PyDataset` must implement the `__getitem__()` and the `__len__()`
    methods. If you want to modify your dataset between epochs,
    you may additionally implement `on_epoch_end()`,
    or `on_epoch_begin` to be called at the start of each epoch.
    The `__getitem__()` method should return a complete batch
    (not a single sample), and the `__len__` method should return
    the number of batches in the dataset (rather than the number of samples).

    Args:
        workers: Number of workers to use in multithreading or
            multiprocessing.
        use_multiprocessing: Whether to use Python multiprocessing for
            parallelism. Setting this to `True` means that your
            dataset will be replicated in multiple forked processes.
            This is necessary to gain compute-level (rather than I/O level)
            benefits from parallelism. However it can only be set to
            `True` if your dataset can be safely pickled.
        max_queue_size: Maximum number of batches to keep in the queue
            when iterating over the dataset in a multithreaded or
            multiprocessed setting.
            Reduce this value to reduce the CPU memory consumption of
            your dataset. Defaults to 10.

    Notes:

    - `PyDataset` is a safer way to do multiprocessing.
        This structure guarantees that the model will only train
        once on each sample per epoch, which is not the case
        with Python generators.
    - The arguments `workers`, `use_multiprocessing`, and `max_queue_size`
        exist to configure how `fit()` uses parallelism to iterate
        over the dataset. They are not being used by the `PyDataset` class
        directly. When you are manually iterating over a `PyDataset`,
        no parallelism is applied.

    Example:

    ```python
    from skimage.io import imread
    from skimage.transform import resize
    import numpy as np
    import math

    # Here, `x_set` is list of path to the images
    # and `y_set` are the associated classes.

    class CIFAR10PyDataset(keras.utils.PyDataset):

        def __init__(self, x_set, y_set, batch_size, **kwargs):
            super().__init__(**kwargs)
            self.x, self.y = x_set, y_set
            self.batch_size = batch_size

        def __len__(self):
            # Return number of batches.
            return math.ceil(len(self.x) / self.batch_size)

        def __getitem__(self, idx):
            # Return x, y for batch idx.
            low = idx * self.batch_size
            # Cap upper bound at array length; the last batch may be smaller
            # if the total number of items is not a multiple of batch size.
            high = min(low + self.batch_size, len(self.x))
            batch_x = self.x[low:high]
            batch_y = self.y[low:high]

            return np.array([
                resize(imread(file_name), (200, 200))
                for file_name in batch_x]), np.array(batch_y)
    ```
    """

    def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10):
        # Parallelism settings; consumed by the adapter/enqueuer, not by
        # this class itself (see class docstring Notes).
        self._workers = workers
        self._use_multiprocessing = use_multiprocessing
        self._max_queue_size = max_queue_size

    def _warn_if_super_not_called(self):
        # If a subclass forgot to call `super().__init__(**kwargs)`, fall
        # back to the documented defaults and warn once per missing field.
        warn = False
        if not hasattr(self, "_workers"):
            self._workers = 1
            warn = True
        if not hasattr(self, "_use_multiprocessing"):
            self._use_multiprocessing = False
            warn = True
        if not hasattr(self, "_max_queue_size"):
            self._max_queue_size = 10
            warn = True
        if warn:
            warnings.warn(
                "Your `PyDataset` class should call "
                "`super().__init__(**kwargs)` in its constructor. "
                "`**kwargs` can include `workers`, "
                "`use_multiprocessing`, `max_queue_size`. Do not pass "
                "these arguments to `fit()`, as they will be ignored.",
                stacklevel=2,
            )

    @property
    def workers(self):
        self._warn_if_super_not_called()
        return self._workers

    @workers.setter
    def workers(self, value):
        self._workers = value

    @property
    def use_multiprocessing(self):
        self._warn_if_super_not_called()
        return self._use_multiprocessing

    @use_multiprocessing.setter
    def use_multiprocessing(self, value):
        self._use_multiprocessing = value

    @property
    def max_queue_size(self):
        self._warn_if_super_not_called()
        return self._max_queue_size

    @max_queue_size.setter
    def max_queue_size(self, value):
        self._max_queue_size = value

    def __getitem__(self, index):
        """Gets batch at position `index`.

        Args:
            index: position of the batch in the PyDataset.

        Returns:
            A batch
        """
        raise NotImplementedError

    @property
    def num_batches(self):
        """Number of batches in the PyDataset.

        Returns:
            The number of batches in the PyDataset or `None` to indicate that
            the dataset is infinite.
        """
        # For backwards compatibility, support `__len__`.
        # Note: `hasattr` is True only when a subclass defines `__len__`;
        # this base class deliberately does not.
        if hasattr(self, "__len__"):
            return len(self)
        raise NotImplementedError(
            "You need to implement the `num_batches` property:\n\n"
            "@property\ndef num_batches(self):\n return ..."
        )

    def on_epoch_begin(self):
        """Method called at the beginning of every epoch."""
        pass

    def on_epoch_end(self):
        """Method called at the end of every epoch."""
        pass
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class PyDatasetAdapter(DataAdapter):
    """Adapter for `keras.utils.PyDataset` instances."""

    def __init__(
        self,
        x,
        class_weight=None,
        shuffle=False,
    ):
        self.py_dataset = x
        self.class_weight = class_weight
        # Set below only when the dataset requests real parallelism.
        self.enqueuer = None
        self.shuffle = shuffle
        # Structure of tf.TypeSpec objects; built lazily in `get_tf_dataset`.
        self._output_signature = None
        # Guards against nested/unbalanced on_epoch_begin/on_epoch_end.
        self._within_epoch = False

        workers = self.py_dataset.workers
        use_multiprocessing = self.py_dataset.use_multiprocessing
        # Parallel prefetching pays off with more than one thread, or with
        # at least one worker when multiprocessing is requested.
        if workers > 1 or (workers > 0 and use_multiprocessing):
            self.enqueuer = OrderedEnqueuer(
                self.py_dataset,
                workers=workers,
                use_multiprocessing=use_multiprocessing,
                max_queue_size=self.py_dataset.max_queue_size,
                shuffle=self.shuffle,
            )

    def _standardize_batch(self, batch):
        # Normalize a `__getitem__` result to a dict or a tuple of length
        # 1-3, appending class-weight-derived sample weights if configured.
        if isinstance(batch, dict):
            return batch
        if isinstance(batch, np.ndarray):
            batch = (batch,)
        if isinstance(batch, list):
            batch = tuple(batch)
        if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}:
            raise ValueError(
                "PyDataset.__getitem__() must return a tuple or a dict. "
                "If a tuple, it must be ordered either "
                "(input,) or (inputs, targets) or "
                "(inputs, targets, sample_weights). "
                f"Received: {str(batch)[:100]}... of type {type(batch)}"
            )
        if self.class_weight is not None:
            if len(batch) == 3:
                raise ValueError(
                    "You cannot specify `class_weight` "
                    "and `sample_weight` at the same time."
                )
            if len(batch) == 2:
                sw = data_adapter_utils.class_weight_to_sample_weights(
                    batch[1], self.class_weight
                )
                batch = batch + (sw,)
        return batch

    def _infinite_generator(self):
        # Serial iteration over an infinite dataset (num_batches is None).
        for i in itertools.count():
            yield self._standardize_batch(self.py_dataset[i])

    def _finite_generator(self):
        # Serial iteration over a finite dataset, optionally shuffled.
        indices = range(self.py_dataset.num_batches)
        if self.shuffle:
            indices = list(indices)
            random.shuffle(indices)

        for i in indices:
            yield self._standardize_batch(self.py_dataset[i])

    def _infinite_enqueuer_generator(self):
        # `start()` is idempotent, so re-entering a new epoch is safe.
        self.enqueuer.start()
        for batch in self.enqueuer.get():
            yield self._standardize_batch(batch)

    def _finite_enqueuer_generator(self):
        self.enqueuer.start()
        num_batches = self.py_dataset.num_batches
        # `enqueuer.get()` is an endless stream, so we must stop it
        # explicitly after one pass over the dataset.
        for i, batch in enumerate(self.enqueuer.get()):
            yield self._standardize_batch(batch)
            if i >= num_batches - 1:
                self.enqueuer.stop()
                return

    def _get_iterator(self):
        # Pick one of the four iteration strategies based on whether we
        # prefetch in parallel and whether the dataset is finite.
        if self.enqueuer is None:
            if self.py_dataset.num_batches is None:
                return self._infinite_generator()
            else:
                return self._finite_generator()
        else:
            if self.py_dataset.num_batches is None:
                return self._infinite_enqueuer_generator()
            else:
                return self._finite_enqueuer_generator()

    def get_numpy_iterator(self):
        return data_adapter_utils.get_numpy_iterator(self._get_iterator())

    def get_jax_iterator(self):
        return data_adapter_utils.get_jax_iterator(self._get_iterator())

    def get_tf_dataset(self):
        from keras.src.utils.module_utils import tensorflow as tf

        num_batches = self.py_dataset.num_batches
        if self._output_signature is None:
            # Infer the signature from a few leading batches (re-read
            # directly from the dataset, not through the iterator).
            num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
            if num_batches is not None:
                num_samples = min(num_samples, num_batches)
            batches = [
                self._standardize_batch(self.py_dataset[i])
                for i in range(num_samples)
            ]
            if len(batches) == 0:
                raise ValueError("The PyDataset has length 0")
            self._output_signature = data_adapter_utils.get_tensor_spec(batches)

        ds = tf.data.Dataset.from_generator(
            self._get_iterator,
            output_signature=self._output_signature,
        )
        if self.enqueuer is not None:
            # The enqueuer does its own multithreading / multiprocesssing to
            # prefetch items. Disable the tf.data.Dataset prefetching and
            # threading as it interferes.
            options = tf.data.Options()
            options.autotune.enabled = False
            options.threading.private_threadpool_size = 1
            ds = ds.with_options(options)
        else:
            ds = ds.prefetch(tf.data.AUTOTUNE)
        return ds

    def get_torch_dataloader(self):
        return data_adapter_utils.get_torch_dataloader(self._get_iterator())

    def on_epoch_begin(self):
        if self._within_epoch:
            raise ValueError(
                "`on_epoch_begin` was called twice without `on_epoch_end` "
                "having been called."
            )
        self._within_epoch = True
        if self.enqueuer:
            self.enqueuer.start()
        self.py_dataset.on_epoch_begin()

    def on_epoch_end(self):
        if self.enqueuer:
            self.enqueuer.stop()
        self.py_dataset.on_epoch_end()
        self._within_epoch = False

    @property
    def num_batches(self):
        return self.py_dataset.num_batches

    @property
    def batch_size(self):
        # Batches are produced whole by the PyDataset; size is unknown here.
        return None
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
# Global variables to be shared across processes
# Maps enqueuer uid -> PyDataset so workers can look it up via `get_index`.
_SHARED_SEQUENCES = {}
# We use a Value to provide unique id to different processes.
_SEQUENCE_COUNTER = None


# Because multiprocessing pools are inherently unsafe, starting from a clean
# state can be essential to avoiding deadlocks. In order to accomplish this, we
# need to be able to check on the status of Pools that we create.
_DATA_POOLS = weakref.WeakSet()
_WORKER_ID_QUEUE = None  # Only created if needed.
# When True, `get_pool_class` always returns a thread pool even if
# multiprocessing was requested (presumably a test/override hook — confirm).
_FORCE_THREADPOOL = False
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def get_pool_class(use_multiprocessing):
    """Return the pool class to use: process pool or thread pool."""
    if use_multiprocessing and not _FORCE_THREADPOOL:
        return multiprocessing.Pool
    # multiprocessing.dummy.Pool is a ThreadPool with the same interface.
    return multiprocessing.dummy.Pool
|
| 362 |
+
|
| 363 |
+
|
| 364 |
+
def get_worker_id_queue():
    """Return the queue tracking worker ids, creating it on first use."""
    global _WORKER_ID_QUEUE
    if _WORKER_ID_QUEUE is not None:
        return _WORKER_ID_QUEUE
    _WORKER_ID_QUEUE = multiprocessing.Queue()
    return _WORKER_ID_QUEUE
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def get_index(uid, i):
    """Get the value from the PyDataset `uid` at index `i`.

    To allow multiple PyDatasets to be used at the same time, we use `uid` to
    get a specific one. A single PyDataset would cause the validation to
    overwrite the training PyDataset.

    This methods is called from worker threads.

    Args:
        uid: int, PyDataset identifier (key into `_SHARED_SEQUENCES`).
        i: index

    Returns:
        The value at index `i`.
    """
    return _SHARED_SEQUENCES[uid][i]
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class PyDatasetEnqueuer:
    """Base class to enqueue inputs.

    The task of an Enqueuer is to use parallelism to speed up preprocessing.
    This is done with processes or threads.

    Example:

    ```python
    enqueuer = PyDatasetEnqueuer(...)
    enqueuer.start()
    datas = enqueuer.get()
    for data in datas:
        # Use the inputs; training, evaluating, predicting.
        # ... stop sometime.
    enqueuer.stop()
    ```

    The `enqueuer.get()` should be an infinite stream of data.
    """

    def __init__(
        self,
        py_dataset,
        workers=1,
        use_multiprocessing=False,
        max_queue_size=10,
    ):
        self.py_dataset = py_dataset

        # Allocate a unique id for this enqueuer, process-safe when
        # multiprocessing primitives are available.
        global _SEQUENCE_COUNTER
        if _SEQUENCE_COUNTER is None:
            try:
                _SEQUENCE_COUNTER = multiprocessing.Value("i", 0)
            except OSError:
                # In this case the OS does not allow us to use
                # multiprocessing. We resort to an int
                # for enqueuer indexing.
                _SEQUENCE_COUNTER = 0

        if isinstance(_SEQUENCE_COUNTER, int):
            self.uid = _SEQUENCE_COUNTER
            _SEQUENCE_COUNTER += 1
        else:
            # Doing Multiprocessing.Value += x is not process-safe.
            with _SEQUENCE_COUNTER.get_lock():
                self.uid = _SEQUENCE_COUNTER.value
                _SEQUENCE_COUNTER.value += 1

        # Batches already materialized, carried over between runs
        # (see `stop`, which drains `future_queue` into this queue).
        self.ready_queue = queue.Queue()
        # Pending async results (or exceptions), bounded by max_queue_size.
        self.future_queue = queue.Queue(max_queue_size)
        self.running = False
        self.start_stop_lock = threading.Lock()
        self.run_thread = None
        if use_multiprocessing:
            self.executor_fn = self._get_executor_init(workers)
        else:
            # We do not need the init since it's threads.
            self.executor_fn = lambda _: get_pool_class(False)(workers)

    def is_running(self):
        """Whether the enqueuer is running.

        This method is thread safe and called from many threads.

        Returns: boolean indicating whether this enqueuer is running.
        """
        return self.running

    def start(self):
        """Starts the handler's workers.

        This method is thread safe but is called from the main thread.
        It is safe to call this method multiple times, extra calls are ignored.
        """
        with self.start_stop_lock:
            if self.running:
                return
            self.running = True
            self.run_thread = threading.Thread(target=self._run)
            self.run_thread.name = f"Worker_{self.uid}"
            # Daemon thread: it must not keep the interpreter alive.
            self.run_thread.daemon = True
            self.run_thread.start()

    def stop(self, drain_queue_and_join=True):
        """Stops running threads and wait for them to exit, if necessary.

        This method is thread safe and is called from various threads. Note that
        the `drain_queue_and_join` argument must be set correctly.
        It is safe to call this method multiple times, extra calls are ignored.

        Args:
            drain_queue_and_join: set to True to drain the queue of pending
                items and wait for the worker thread to complete. Set to False
                if invoked from a worker thread to avoid deadlocks. Note that
                setting this to False means this enqueuer won't be reused.
        """
        with self.start_stop_lock:
            if not self.running:
                return
            self.running = False

            if drain_queue_and_join:
                # Drain the `future_queue` and put items in `ready_queue` for
                # the next run.
                while True:
                    try:
                        value = self.future_queue.get(block=True, timeout=0.1)
                        if isinstance(value, Exception):
                            raise value  # Propagate exception from other thread
                        inputs = value.get()
                        self.future_queue.task_done()
                        if inputs is not None:
                            self.ready_queue.put(inputs)
                    except queue.Empty:
                        break
                self.run_thread.join()

            self.run_thread = None
            # Release the dataset reference shared with workers.
            _SHARED_SEQUENCES[self.uid] = None

    def _send_py_dataset(self):
        """Sends current Iterable to all workers."""
        # For new processes that may spawn
        _SHARED_SEQUENCES[self.uid] = self.py_dataset

    def __del__(self):
        # Best-effort cleanup; never drain/join during interpreter teardown.
        self.stop(drain_queue_and_join=False)

    def _run(self):
        """Submits request to the executor and queue the `Future` objects."""
        raise NotImplementedError

    def _get_executor_init(self, workers):
        """Gets the Pool initializer for multiprocessing.

        Args:
            workers: Number of workers.

        Returns:
            Function, a Function to initialize the pool
        """
        raise NotImplementedError

    def get(self):
        """Creates a generator to extract data from the queue.

        Skip the data if it is `None`.

        This method is called from the main thread.

        Yields:
            The next element in the queue, i.e. a tuple
            `(inputs, targets)` or
            `(inputs, targets, sample_weights)`.
        """
        raise NotImplementedError
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
class OrderedEnqueuer(PyDatasetEnqueuer):
|
| 551 |
+
"""Builds a Enqueuer from a PyDataset.
|
| 552 |
+
|
| 553 |
+
Args:
|
| 554 |
+
py_dataset: A `keras.utils.PyDataset` object.
|
| 555 |
+
use_multiprocessing: use multiprocessing if True, otherwise threading
|
| 556 |
+
shuffle: whether to shuffle the data at the beginning of each epoch
|
| 557 |
+
"""
|
| 558 |
+
|
| 559 |
+
def __init__(
|
| 560 |
+
self,
|
| 561 |
+
py_dataset,
|
| 562 |
+
workers=1,
|
| 563 |
+
use_multiprocessing=False,
|
| 564 |
+
max_queue_size=10,
|
| 565 |
+
shuffle=False,
|
| 566 |
+
):
|
| 567 |
+
super().__init__(
|
| 568 |
+
py_dataset, workers, use_multiprocessing, max_queue_size
|
| 569 |
+
)
|
| 570 |
+
self.shuffle = shuffle
|
| 571 |
+
if self.py_dataset.num_batches is None:
|
| 572 |
+
# For infinite datasets, `self.indices` is created here once for all
|
| 573 |
+
# so that subsequent runs resume from where they stopped.
|
| 574 |
+
self.indices = itertools.count()
|
| 575 |
+
|
| 576 |
+
def _get_executor_init(self, workers):
|
| 577 |
+
"""Gets the Pool initializer for multiprocessing.
|
| 578 |
+
|
| 579 |
+
Args:
|
| 580 |
+
workers: Number of workers.
|
| 581 |
+
|
| 582 |
+
Returns:
|
| 583 |
+
Function, a Function to initialize the pool
|
| 584 |
+
"""
|
| 585 |
+
|
| 586 |
+
def pool_fn(seqs):
|
| 587 |
+
pool = get_pool_class(True)(
|
| 588 |
+
workers,
|
| 589 |
+
initializer=init_pool_generator,
|
| 590 |
+
initargs=(seqs, None, get_worker_id_queue()),
|
| 591 |
+
)
|
| 592 |
+
_DATA_POOLS.add(pool)
|
| 593 |
+
return pool
|
| 594 |
+
|
| 595 |
+
return pool_fn
|
| 596 |
+
|
| 597 |
+
def _run(self):
|
| 598 |
+
"""Submits request to the executor and queue the `Future` objects.
|
| 599 |
+
|
| 600 |
+
This method is the run method of worker threads.
|
| 601 |
+
"""
|
| 602 |
+
try:
|
| 603 |
+
if self.py_dataset.num_batches is not None:
|
| 604 |
+
# For finite datasets, `self.indices` is created here so that
|
| 605 |
+
# shuffling creates different a order each time.
|
| 606 |
+
indices = range(self.py_dataset.num_batches)
|
| 607 |
+
if self.shuffle:
|
| 608 |
+
indices = list(indices)
|
| 609 |
+
random.shuffle(indices)
|
| 610 |
+
self.indices = iter(indices)
|
| 611 |
+
self._send_py_dataset() # Share the initial py_dataset
|
| 612 |
+
|
| 613 |
+
with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
|
| 614 |
+
while self.is_running():
|
| 615 |
+
try:
|
| 616 |
+
i = next(self.indices)
|
| 617 |
+
self.future_queue.put(
|
| 618 |
+
executor.apply_async(get_index, (self.uid, i)),
|
| 619 |
+
block=True,
|
| 620 |
+
)
|
| 621 |
+
except StopIteration:
|
| 622 |
+
break
|
| 623 |
+
except Exception as e:
|
| 624 |
+
self.future_queue.put(e) # Report exception
|
| 625 |
+
|
| 626 |
+
def get(self):
|
| 627 |
+
"""Creates a generator to extract data from the queue.
|
| 628 |
+
|
| 629 |
+
Skip the data if it is `None`.
|
| 630 |
+
|
| 631 |
+
This method is called from the main thread.
|
| 632 |
+
|
| 633 |
+
Yields:
|
| 634 |
+
The next element in the queue, i.e. a tuple
|
| 635 |
+
`(inputs, targets)` or
|
| 636 |
+
`(inputs, targets, sample_weights)`.
|
| 637 |
+
"""
|
| 638 |
+
while self.is_running():
|
| 639 |
+
try:
|
| 640 |
+
inputs = self.ready_queue.get(block=False)
|
| 641 |
+
yield inputs
|
| 642 |
+
continue # Retry the ready_queue
|
| 643 |
+
except queue.Empty:
|
| 644 |
+
pass
|
| 645 |
+
|
| 646 |
+
try:
|
| 647 |
+
value = self.future_queue.get(block=True, timeout=5)
|
| 648 |
+
self.future_queue.task_done()
|
| 649 |
+
if isinstance(value, Exception):
|
| 650 |
+
raise value # Propagate exception from other thread
|
| 651 |
+
inputs = value.get()
|
| 652 |
+
if inputs is not None:
|
| 653 |
+
yield inputs
|
| 654 |
+
except queue.Empty:
|
| 655 |
+
pass
|
| 656 |
+
except Exception as e:
|
| 657 |
+
self.stop(drain_queue_and_join=True)
|
| 658 |
+
raise e
|
| 659 |
+
|
| 660 |
+
# Note that it is ok to poll the iterator after the initial `start`,
|
| 661 |
+
# which may happen before the first `on_epoch_begin`. But it's not ok to
|
| 662 |
+
# poll after `on_epoch_end`.
|
| 663 |
+
raise ValueError(
|
| 664 |
+
"Iterator called after `on_epoch_end` or before `on_epoch_begin`."
|
| 665 |
+
)
|
| 666 |
+
|
| 667 |
+
|
| 668 |
+
def init_pool_generator(gens, random_seed=None, id_queue=None):
|
| 669 |
+
"""Initializer function for pool workers.
|
| 670 |
+
|
| 671 |
+
Args:
|
| 672 |
+
gens: State which should be made available to worker processes.
|
| 673 |
+
random_seed: An optional value with which to seed child processes.
|
| 674 |
+
id_queue: A multiprocessing Queue of worker ids.
|
| 675 |
+
This is used to indicate that a worker process
|
| 676 |
+
was created by Keras.
|
| 677 |
+
"""
|
| 678 |
+
global _SHARED_SEQUENCES
|
| 679 |
+
_SHARED_SEQUENCES = gens
|
| 680 |
+
|
| 681 |
+
worker_proc = multiprocessing.current_process()
|
| 682 |
+
|
| 683 |
+
# name isn't used for anything, but setting a more descriptive name is
|
| 684 |
+
# helpful when diagnosing orphaned processes.
|
| 685 |
+
worker_proc.name = f"Keras_worker_{worker_proc.name}"
|
| 686 |
+
|
| 687 |
+
if random_seed is not None:
|
| 688 |
+
np.random.seed(random_seed + worker_proc.ident)
|
| 689 |
+
|
| 690 |
+
if id_queue is not None:
|
| 691 |
+
# If a worker dies during init, the pool will just create a replacement.
|
| 692 |
+
id_queue.put(worker_proc.ident, block=True, timeout=0.1)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/tf_dataset_adapter.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src import tree
|
| 2 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 3 |
+
from keras.src.trainers.data_adapters.data_adapter import DataAdapter
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TFDatasetAdapter(DataAdapter):
|
| 7 |
+
"""Adapter that handles `tf.data.Dataset`."""
|
| 8 |
+
|
| 9 |
+
def __init__(self, dataset, class_weight=None, distribution=None):
|
| 10 |
+
"""Initialize the TFDatasetAdapter.
|
| 11 |
+
|
| 12 |
+
Args:
|
| 13 |
+
dataset: The input `tf.data.Dataset` instance.
|
| 14 |
+
class_weight: A map where the keys are integer class ids and values
|
| 15 |
+
are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`.
|
| 16 |
+
distribution: A `keras.distribution.Distribution` instance. Used to
|
| 17 |
+
shard the input dataset into per worker/process dataset
|
| 18 |
+
instance.
|
| 19 |
+
"""
|
| 20 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 21 |
+
|
| 22 |
+
if not isinstance(
|
| 23 |
+
dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)
|
| 24 |
+
):
|
| 25 |
+
raise ValueError(
|
| 26 |
+
"Expected argument `dataset` to be a tf.data.Dataset. "
|
| 27 |
+
f"Received: {dataset}"
|
| 28 |
+
)
|
| 29 |
+
if class_weight is not None:
|
| 30 |
+
dataset = dataset.map(
|
| 31 |
+
make_class_weight_map_fn(class_weight)
|
| 32 |
+
).prefetch(tf.data.AUTOTUNE)
|
| 33 |
+
if distribution is not None:
|
| 34 |
+
dataset = distribution.distribute_dataset(dataset)
|
| 35 |
+
self._dataset = dataset
|
| 36 |
+
|
| 37 |
+
def get_numpy_iterator(self):
|
| 38 |
+
from keras.src.backend.tensorflow.core import convert_to_numpy
|
| 39 |
+
|
| 40 |
+
for batch in self._dataset:
|
| 41 |
+
yield tree.map_structure(convert_to_numpy, batch)
|
| 42 |
+
|
| 43 |
+
def get_jax_iterator(self):
|
| 44 |
+
from keras.src.backend.tensorflow.core import convert_to_numpy
|
| 45 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 46 |
+
|
| 47 |
+
def convert_to_jax(x):
|
| 48 |
+
if isinstance(x, tf.SparseTensor):
|
| 49 |
+
return data_adapter_utils.tf_sparse_to_jax_sparse(x)
|
| 50 |
+
else:
|
| 51 |
+
# We use numpy as an intermediary because it is faster.
|
| 52 |
+
return convert_to_numpy(x)
|
| 53 |
+
|
| 54 |
+
for batch in self._dataset:
|
| 55 |
+
yield tree.map_structure(convert_to_jax, batch)
|
| 56 |
+
|
| 57 |
+
def get_tf_dataset(self):
|
| 58 |
+
return self._dataset
|
| 59 |
+
|
| 60 |
+
def get_torch_dataloader(self):
|
| 61 |
+
return data_adapter_utils.get_torch_dataloader(self._dataset)
|
| 62 |
+
|
| 63 |
+
@property
|
| 64 |
+
def num_batches(self):
|
| 65 |
+
cardinality = self._dataset.cardinality
|
| 66 |
+
if callable(cardinality):
|
| 67 |
+
# `dataset.cardinality` is normally expected to be a callable.
|
| 68 |
+
cardinality = int(self._dataset.cardinality())
|
| 69 |
+
else:
|
| 70 |
+
# However, in the case of `DistributedDataset`, it's a np.int64.
|
| 71 |
+
cardinality = int(cardinality)
|
| 72 |
+
# Return None for Unknown and Infinite cardinality datasets
|
| 73 |
+
if cardinality < 0:
|
| 74 |
+
return None
|
| 75 |
+
return cardinality
|
| 76 |
+
|
| 77 |
+
@property
|
| 78 |
+
def batch_size(self):
|
| 79 |
+
first_element_spec = tree.flatten(self._dataset.element_spec)[0]
|
| 80 |
+
return first_element_spec.shape[0]
|
| 81 |
+
|
| 82 |
+
@property
|
| 83 |
+
def has_partial_batch(self):
|
| 84 |
+
return None
|
| 85 |
+
|
| 86 |
+
@property
|
| 87 |
+
def partial_batch_size(self):
|
| 88 |
+
return None
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def make_class_weight_map_fn(class_weight):
|
| 92 |
+
"""Applies class weighting to a `Dataset`.
|
| 93 |
+
|
| 94 |
+
The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
|
| 95 |
+
`y` must be a single `Tensor`.
|
| 96 |
+
|
| 97 |
+
Args:
|
| 98 |
+
class_weight: A map where the keys are integer class ids and values are
|
| 99 |
+
the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
|
| 100 |
+
|
| 101 |
+
Returns:
|
| 102 |
+
A function that can be used with `tf.data.Dataset.map` to apply class
|
| 103 |
+
weighting.
|
| 104 |
+
"""
|
| 105 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 106 |
+
|
| 107 |
+
class_weight_tensor = tf.convert_to_tensor(
|
| 108 |
+
[
|
| 109 |
+
class_weight.get(int(c), 1.0)
|
| 110 |
+
for c in range(max(class_weight.keys()) + 1)
|
| 111 |
+
]
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def class_weights_map_fn(*data):
|
| 115 |
+
"""Convert `class_weight` to `sample_weight`."""
|
| 116 |
+
x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)
|
| 117 |
+
if sw is not None:
|
| 118 |
+
raise ValueError(
|
| 119 |
+
"You cannot `class_weight` and `sample_weight` "
|
| 120 |
+
"at the same time."
|
| 121 |
+
)
|
| 122 |
+
if tree.is_nested(y):
|
| 123 |
+
raise ValueError(
|
| 124 |
+
"`class_weight` is only supported for Models with a single "
|
| 125 |
+
"output."
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
if y.shape.rank >= 2:
|
| 129 |
+
y_classes = tf.__internal__.smart_cond.smart_cond(
|
| 130 |
+
tf.shape(y)[-1] > 1,
|
| 131 |
+
lambda: tf.argmax(y, axis=-1, output_type=tf.int32),
|
| 132 |
+
lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),
|
| 133 |
+
)
|
| 134 |
+
else:
|
| 135 |
+
# Special casing for rank 1, where we can guarantee sparse encoding.
|
| 136 |
+
y_classes = tf.cast(tf.round(y), tf.int32)
|
| 137 |
+
|
| 138 |
+
cw = tf.gather(class_weight_tensor, y_classes)
|
| 139 |
+
return x, y, cw
|
| 140 |
+
|
| 141 |
+
return class_weights_map_fn
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/torch_data_loader_adapter.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import itertools
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from keras.src import tree
|
| 6 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 7 |
+
from keras.src.trainers.data_adapters.data_adapter import DataAdapter
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TorchDataLoaderAdapter(DataAdapter):
|
| 11 |
+
"""Adapter that handles `torch.utils.data.DataLoader`."""
|
| 12 |
+
|
| 13 |
+
def __init__(self, dataloader):
|
| 14 |
+
import torch
|
| 15 |
+
|
| 16 |
+
if not isinstance(dataloader, torch.utils.data.DataLoader):
|
| 17 |
+
raise ValueError(
|
| 18 |
+
f"Expected argument `dataloader` to be an instance of"
|
| 19 |
+
f"`torch.utils.data.DataLoader`. Received: {dataloader}"
|
| 20 |
+
)
|
| 21 |
+
|
| 22 |
+
self._dataloader = dataloader
|
| 23 |
+
self._output_signature = None
|
| 24 |
+
self._batch_size = dataloader.batch_size
|
| 25 |
+
self._num_batches = None
|
| 26 |
+
self._partial_batch_size = None
|
| 27 |
+
if hasattr(dataloader.dataset, "__len__"):
|
| 28 |
+
self._num_batches = len(dataloader)
|
| 29 |
+
if self._batch_size is not None:
|
| 30 |
+
self._partial_batch_size = (
|
| 31 |
+
len(dataloader.dataset) % self._batch_size
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def get_numpy_iterator(self):
|
| 35 |
+
for batch in self._dataloader:
|
| 36 |
+
# shared memory using `np.asarray`
|
| 37 |
+
yield tuple(
|
| 38 |
+
tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def get_jax_iterator(self):
|
| 42 |
+
# We use numpy as an intermediary because it is faster.
|
| 43 |
+
return self.get_numpy_iterator()
|
| 44 |
+
|
| 45 |
+
def get_tf_dataset(self):
|
| 46 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 47 |
+
|
| 48 |
+
if self._output_signature is None:
|
| 49 |
+
batches = list(
|
| 50 |
+
itertools.islice(
|
| 51 |
+
self._dataloader,
|
| 52 |
+
data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
|
| 53 |
+
)
|
| 54 |
+
)
|
| 55 |
+
self._output_signature = tuple(
|
| 56 |
+
data_adapter_utils.get_tensor_spec(batches)
|
| 57 |
+
)
|
| 58 |
+
return tf.data.Dataset.from_generator(
|
| 59 |
+
self.get_numpy_iterator,
|
| 60 |
+
output_signature=self._output_signature,
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
def get_torch_dataloader(self):
|
| 64 |
+
return self._dataloader
|
| 65 |
+
|
| 66 |
+
@property
|
| 67 |
+
def num_batches(self):
|
| 68 |
+
return self._num_batches
|
| 69 |
+
|
| 70 |
+
@property
|
| 71 |
+
def batch_size(self):
|
| 72 |
+
return self._batch_size
|
| 73 |
+
|
| 74 |
+
@property
|
| 75 |
+
def has_partial_batch(self):
|
| 76 |
+
if self._partial_batch_size:
|
| 77 |
+
return self._partial_batch_size > 0
|
| 78 |
+
else:
|
| 79 |
+
return None
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def partial_batch_size(self):
|
| 83 |
+
return self._partial_batch_size
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Separation of concerns:
|
| 3 |
+
|
| 4 |
+
DataAdapter:
|
| 5 |
+
- x, y
|
| 6 |
+
- sample_weight
|
| 7 |
+
- class_weight
|
| 8 |
+
- shuffle
|
| 9 |
+
- batch_size
|
| 10 |
+
- steps, as it relates to batch_size for array data
|
| 11 |
+
|
| 12 |
+
EpochIterator:
|
| 13 |
+
- whether to yield numpy or tf data
|
| 14 |
+
- steps
|
| 15 |
+
- most argument validation
|
| 16 |
+
|
| 17 |
+
Trainer:
|
| 18 |
+
- steps_per_execution
|
| 19 |
+
- validation_split
|
| 20 |
+
- validation_data
|
| 21 |
+
- callbacks
|
| 22 |
+
- validation_freq
|
| 23 |
+
- epochs
|
| 24 |
+
- initial_epoch
|
| 25 |
+
- any backend-specific concern such as distribution
|
| 26 |
+
|
| 27 |
+
PyDataset:
|
| 28 |
+
- num_workers
|
| 29 |
+
- use_multiprocessing
|
| 30 |
+
- max_queue_size
|
| 31 |
+
|
| 32 |
+
EpochIterator steps:
|
| 33 |
+
|
| 34 |
+
1. Look at data type and select correct DataHandler
|
| 35 |
+
2. Instantiate DataHandler with correct arguments
|
| 36 |
+
3. Raise or warn on unused arguments
|
| 37 |
+
4. in __iter__, iterate, either for a fixed number of steps
|
| 38 |
+
or until there is no data
|
| 39 |
+
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
import contextlib
|
| 43 |
+
import warnings
|
| 44 |
+
|
| 45 |
+
from keras.src.trainers import data_adapters
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class EpochIterator:
|
| 49 |
+
def __init__(
|
| 50 |
+
self,
|
| 51 |
+
x,
|
| 52 |
+
y=None,
|
| 53 |
+
sample_weight=None,
|
| 54 |
+
batch_size=None,
|
| 55 |
+
steps_per_epoch=None,
|
| 56 |
+
shuffle=False,
|
| 57 |
+
class_weight=None,
|
| 58 |
+
steps_per_execution=1,
|
| 59 |
+
):
|
| 60 |
+
self.steps_per_epoch = steps_per_epoch
|
| 61 |
+
self.steps_per_execution = steps_per_execution
|
| 62 |
+
self._current_iterator = None
|
| 63 |
+
self._epoch_iterator = None
|
| 64 |
+
self._steps_seen = 0
|
| 65 |
+
self.data_adapter = data_adapters.get_data_adapter(
|
| 66 |
+
x=x,
|
| 67 |
+
y=y,
|
| 68 |
+
sample_weight=sample_weight,
|
| 69 |
+
batch_size=batch_size,
|
| 70 |
+
steps_per_epoch=steps_per_epoch,
|
| 71 |
+
shuffle=shuffle,
|
| 72 |
+
class_weight=class_weight,
|
| 73 |
+
)
|
| 74 |
+
self._num_batches = self.data_adapter.num_batches
|
| 75 |
+
|
| 76 |
+
def _get_iterator(self):
|
| 77 |
+
return self.data_adapter.get_numpy_iterator()
|
| 78 |
+
|
| 79 |
+
def _interrupted_warning(self):
|
| 80 |
+
warnings.warn(
|
| 81 |
+
"Your input ran out of data; interrupting training. "
|
| 82 |
+
"Make sure that your dataset or generator can generate "
|
| 83 |
+
"at least `steps_per_epoch * epochs` batches. "
|
| 84 |
+
"You may need to use the `.repeat()` "
|
| 85 |
+
"function when building your dataset.",
|
| 86 |
+
stacklevel=2,
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
def reset(self):
|
| 90 |
+
self._current_iterator = None
|
| 91 |
+
self._num_batches = self.data_adapter.num_batches
|
| 92 |
+
self._steps_seen = 0
|
| 93 |
+
self._epoch_iterator = None
|
| 94 |
+
self.data_adapter.on_epoch_end()
|
| 95 |
+
|
| 96 |
+
def _enumerate_iterator(self):
|
| 97 |
+
self.data_adapter.on_epoch_begin()
|
| 98 |
+
steps_per_epoch = self.steps_per_epoch or self._num_batches or -1
|
| 99 |
+
|
| 100 |
+
if steps_per_epoch > 0:
|
| 101 |
+
if self._current_iterator is None or self.steps_per_epoch is None:
|
| 102 |
+
self._current_iterator = iter(self._get_iterator())
|
| 103 |
+
self._steps_seen = 0
|
| 104 |
+
for step in range(0, steps_per_epoch, self.steps_per_execution):
|
| 105 |
+
if self._num_batches and self._steps_seen >= self._num_batches:
|
| 106 |
+
if self.steps_per_epoch:
|
| 107 |
+
self._interrupted_warning()
|
| 108 |
+
break
|
| 109 |
+
self._steps_seen += self.steps_per_execution
|
| 110 |
+
yield step, self._current_iterator
|
| 111 |
+
if self._num_batches and self._steps_seen >= self._num_batches:
|
| 112 |
+
self._current_iterator = iter(self._get_iterator())
|
| 113 |
+
self._steps_seen = 0
|
| 114 |
+
else:
|
| 115 |
+
iterator = iter(self._get_iterator())
|
| 116 |
+
step = -self.steps_per_execution
|
| 117 |
+
while True:
|
| 118 |
+
step += self.steps_per_execution
|
| 119 |
+
self._steps_seen = step + self.steps_per_execution
|
| 120 |
+
yield step, iterator
|
| 121 |
+
self.data_adapter.on_epoch_end()
|
| 122 |
+
|
| 123 |
+
def __iter__(self):
|
| 124 |
+
self._epoch_iterator = self._enumerate_iterator()
|
| 125 |
+
return self
|
| 126 |
+
|
| 127 |
+
def __next__(self):
|
| 128 |
+
buffer = []
|
| 129 |
+
step, iterator = next(self._epoch_iterator)
|
| 130 |
+
with self.catch_stop_iteration():
|
| 131 |
+
for _ in range(self.steps_per_execution):
|
| 132 |
+
data = next(iterator)
|
| 133 |
+
buffer.append(data)
|
| 134 |
+
return step, buffer
|
| 135 |
+
if buffer:
|
| 136 |
+
return step, buffer
|
| 137 |
+
raise StopIteration
|
| 138 |
+
|
| 139 |
+
def enumerate_epoch(self):
|
| 140 |
+
for step, data in self:
|
| 141 |
+
yield step, data
|
| 142 |
+
|
| 143 |
+
@contextlib.contextmanager
|
| 144 |
+
def catch_stop_iteration(self):
|
| 145 |
+
"""Catches errors when an iterator runs out of data."""
|
| 146 |
+
try:
|
| 147 |
+
yield
|
| 148 |
+
except StopIteration:
|
| 149 |
+
if self._num_batches is None:
|
| 150 |
+
self._num_batches = self._steps_seen
|
| 151 |
+
self._interrupted_warning()
|
| 152 |
+
self._current_iterator = None
|
| 153 |
+
self.data_adapter.on_epoch_end()
|
| 154 |
+
|
| 155 |
+
@property
|
| 156 |
+
def num_batches(self):
|
| 157 |
+
if self.steps_per_epoch:
|
| 158 |
+
return self.steps_per_epoch
|
| 159 |
+
# Either copied from the data_adapter, or
|
| 160 |
+
# inferred at the end of an iteration.
|
| 161 |
+
return self._num_batches
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py
ADDED
|
@@ -0,0 +1,1147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inspect
|
| 2 |
+
import platform
|
| 3 |
+
import warnings
|
| 4 |
+
|
| 5 |
+
from keras.src import backend
|
| 6 |
+
from keras.src import metrics as metrics_module
|
| 7 |
+
from keras.src import ops
|
| 8 |
+
from keras.src import optimizers
|
| 9 |
+
from keras.src import tree
|
| 10 |
+
from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
|
| 11 |
+
from keras.src.saving import serialization_lib
|
| 12 |
+
from keras.src.trainers.compile_utils import CompileLoss
|
| 13 |
+
from keras.src.trainers.compile_utils import CompileMetrics
|
| 14 |
+
from keras.src.trainers.data_adapters import data_adapter_utils
|
| 15 |
+
from keras.src.utils import python_utils
|
| 16 |
+
from keras.src.utils import traceback_utils
|
| 17 |
+
from keras.src.utils import tracking
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class Trainer:
|
| 21 |
+
def __init__(self):
|
| 22 |
+
self._lock = False
|
| 23 |
+
self._run_eagerly = False
|
| 24 |
+
self._jit_compile = None
|
| 25 |
+
self.compiled = False
|
| 26 |
+
self.loss = None
|
| 27 |
+
self.steps_per_execution = 1
|
| 28 |
+
# Can be set by callbacks in on_train_begin
|
| 29 |
+
self._initial_epoch = None
|
| 30 |
+
self._compute_loss_has_training_arg = (
|
| 31 |
+
"training" in inspect.signature(self.compute_loss).parameters
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Placeholders used in `compile`
|
| 35 |
+
self._compile_loss = None
|
| 36 |
+
self._compile_metrics = None
|
| 37 |
+
self._loss_tracker = None
|
| 38 |
+
|
| 39 |
+
    @traceback_utils.filter_traceback
    @tracking.no_automatic_dependency_tracking
    def compile(
        self,
        optimizer="rmsprop",
        loss=None,
        loss_weights=None,
        metrics=None,
        weighted_metrics=None,
        run_eagerly=False,
        steps_per_execution=1,
        jit_compile="auto",
        auto_scale_loss=True,
    ):
        """Configures the model for training.

        Example:

        ```python
        model.compile(
            optimizer=keras.optimizers.Adam(learning_rate=1e-3),
            loss=keras.losses.BinaryCrossentropy(),
            metrics=[
                keras.metrics.BinaryAccuracy(),
                keras.metrics.FalseNegatives(),
            ],
        )
        ```

        Args:
            optimizer: String (name of optimizer) or optimizer instance. See
                `keras.optimizers`.
            loss: Loss function. May be a string (name of loss function), or
                a `keras.losses.Loss` instance. See `keras.losses`. A
                loss function is any callable with the signature
                `loss = fn(y_true, y_pred)`, where `y_true` are the ground truth
                values, and `y_pred` are the model's predictions.
                `y_true` should have shape `(batch_size, d0, .. dN)`
                (except in the case of sparse loss functions such as
                sparse categorical crossentropy which expects integer arrays of
                shape `(batch_size, d0, .. dN-1)`).
                `y_pred` should have shape `(batch_size, d0, .. dN)`.
                The loss function should return a float tensor.
            loss_weights: Optional list or dictionary specifying scalar
                coefficients (Python floats) to weight the loss contributions of
                different model outputs. The loss value that will be minimized
                by the model will then be the *weighted sum* of all individual
                losses, weighted by the `loss_weights` coefficients. If a list,
                it is expected to have a 1:1 mapping to the model's outputs. If
                a dict, it is expected to map output names (strings) to scalar
                coefficients.
            metrics: List of metrics to be evaluated by the model during
                training and testing. Each of this can be a string (name of a
                built-in function), function or a `keras.metrics.Metric`
                instance. See `keras.metrics`. Typically you will use
                `metrics=['accuracy']`. A function is any callable with the
                signature `result = fn(y_true, y_pred)`. To specify different
                metrics for different outputs of a multi-output model, you could
                also pass a dictionary, such as
                `metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`.
                You can also pass a list to specify a metric or a list of
                metrics for each output, such as
                `metrics=[['accuracy'], ['accuracy', 'mse']]`
                or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass
                the strings 'accuracy' or 'acc', we convert this to one of
                `keras.metrics.BinaryAccuracy`,
                `keras.metrics.CategoricalAccuracy`,
                `keras.metrics.SparseCategoricalAccuracy` based on the
                shapes of the targets and of the model output. A similar
                conversion is done for the strings `"crossentropy"`
                and `"ce"` as well.
                The metrics passed here are evaluated without sample weighting;
                if you would like sample weighting to apply, you can specify
                your metrics via the `weighted_metrics` argument instead.
            weighted_metrics: List of metrics to be evaluated and weighted by
                `sample_weight` or `class_weight` during training and testing.
            run_eagerly: Bool. If `True`, this model's forward pass
                will never be compiled. It is recommended to leave this
                as `False` when training (for best performance),
                and to set it to `True` when debugging.
            steps_per_execution: Int. The number of batches to run
                during a single compiled function call. Running multiple
                batches inside a single compiled function call can
                greatly improve performance on TPUs or small models with a large
                Python overhead. At most, one full epoch will be run each
                execution. If a number larger than the size of the epoch is
                passed, the execution will be truncated to the size of the
                epoch. Note that if `steps_per_execution` is set to `N`,
                `Callback.on_batch_begin` and `Callback.on_batch_end` methods
                will only be called every `N` batches (i.e. before/after
                each compiled function execution).
                Not supported with the PyTorch backend.
            jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
                compiling a model. For `jax` and `tensorflow` backends,
                `jit_compile="auto"` enables XLA compilation if the model
                supports it, and disables it otherwise.
                For `torch` backend, `"auto"` will default to eager
                execution and `jit_compile=True` will run with `torch.compile`
                with the `"inductor"` backend.
            auto_scale_loss: Bool. If `True` and the model dtype policy is
                `"mixed_float16"`, the passed optimizer will be automatically
                wrapped in a `LossScaleOptimizer`, which will dynamically
                scale the loss to prevent underflow.
        """
        optimizer = optimizers.get(optimizer)
        self.optimizer = optimizer
        # Mixed-precision training: wrap the optimizer so the loss is
        # dynamically scaled, unless the caller already supplied a
        # `LossScaleOptimizer` or opted out via `auto_scale_loss=False`.
        if (
            auto_scale_loss
            and self.dtype_policy.name == "mixed_float16"
            and self.optimizer
            and not isinstance(self.optimizer, LossScaleOptimizer)
        ):
            self.optimizer = LossScaleOptimizer(
                self.optimizer, name="loss_scale_optimizer"
            )
        # `output_names` lets per-output losses/metrics be specified by name
        # (functional models define it; subclassed models may not).
        if hasattr(self, "output_names"):
            output_names = self.output_names
        else:
            output_names = None
        if loss is not None:
            self._compile_loss = CompileLoss(
                loss, loss_weights, output_names=output_names
            )
            self.loss = loss
        if metrics is not None or weighted_metrics is not None:
            self._compile_metrics = CompileMetrics(
                metrics, weighted_metrics, output_names=output_names
            )
        # Resolve `"auto"` and enforce that eager execution and JIT
        # compilation are mutually exclusive.
        if jit_compile == "auto":
            if run_eagerly:
                jit_compile = False
            else:
                jit_compile = self._resolve_auto_jit_compile()
        if jit_compile and run_eagerly:
            jit_compile = False
            warnings.warn(
                "If `run_eagerly` is True, then `jit_compile` "
                "cannot also be True. Disabling `jit_compile`.",
                stacklevel=2,
            )

        self.jit_compile = jit_compile
        self.run_eagerly = run_eagerly
        self.stop_training = False
        self.compiled = True
        self._loss_tracker = metrics_module.Mean(name="loss")
        self.steps_per_execution = steps_per_execution

        # Compiled step functions are built lazily on first use.
        self.train_function = None
        self.test_function = None
        self.predict_function = None

        # Record the raw `compile()` arguments so the configuration can be
        # serialized and restored on model save/load.
        self._compile_config = serialization_lib.SerializableDict(
            optimizer=optimizer,
            loss=loss,
            loss_weights=loss_weights,
            metrics=metrics,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
        )
|
| 201 |
+
|
| 202 |
+
    @property
    def jit_compile(self):
        """Whether XLA/JIT compilation is enabled for this model.

        If never set explicitly (still `None`), the value is lazily
        resolved on first access via `_resolve_auto_jit_compile()`.
        """
        if self._jit_compile is None:
            # Value was never set. Resolve it now.
            self._jit_compile = self._resolve_auto_jit_compile()
        return self._jit_compile
|
| 208 |
+
|
| 209 |
+
@jit_compile.setter
|
| 210 |
+
def jit_compile(self, value):
|
| 211 |
+
if value and not model_supports_jit(self):
|
| 212 |
+
warnings.warn(
|
| 213 |
+
"Model doesn't support `jit_compile=True`. "
|
| 214 |
+
"Proceeding with `jit_compile=False`."
|
| 215 |
+
)
|
| 216 |
+
self._jit_compile = False
|
| 217 |
+
else:
|
| 218 |
+
self._jit_compile = value
|
| 219 |
+
|
| 220 |
+
    def _resolve_auto_jit_compile(self):
        """Resolve `jit_compile="auto"` to a concrete boolean.

        Returns `False` for the torch backend, for CPU-only or
        distributed TensorFlow setups, and for models that do not
        support JIT compilation; `True` otherwise.
        """
        if backend.backend() == "torch":
            # jit_compile = "auto" with the pytorch backend defaults to eager
            return False

        if backend.backend() == "tensorflow":
            import tensorflow as tf

            devices = tf.config.list_physical_devices()
            if not list(filter(lambda x: x.device_type != "CPU", devices)):
                # Disable XLA on CPU-only machines.
                return False

            # NOTE(review): assumes `_distribute_strategy` is set elsewhere
            # on TF models -- confirm against the TF trainer.
            if self._distribute_strategy:
                # Disable XLA with tf.distribute
                return False

        if model_supports_jit(self):
            return True
        return False
|
| 240 |
+
|
| 241 |
+
    @property
    def run_eagerly(self):
        """Whether the model's step functions run eagerly (uncompiled)."""
        return self._run_eagerly
|
| 244 |
+
|
| 245 |
+
    @run_eagerly.setter
    def run_eagerly(self, value):
        # No validation: any truthy/falsy value is stored as-is.
        self._run_eagerly = value
|
| 248 |
+
|
| 249 |
+
    @property
    def metrics(self):
        """All metrics tracked by this model, in reporting order.

        Order: loss tracker, individual loss trackers, compiled metrics,
        custom metrics, sublayer metrics.
        """
        metrics = []
        # Compile-created metrics are only reported once `compile()` ran.
        if self.compiled:
            if self._loss_tracker is not None:
                metrics.append(self._loss_tracker)
            if self._compile_metrics is not None:
                metrics.append(self._compile_metrics)
            if self._compile_loss is not None:
                metrics.extend(self._compile_loss.metrics)
        # Custom (user-created) metrics are tracked regardless of `compile`.
        metrics.extend(self._metrics)
        for layer in self._flatten_layers(include_self=False):
            if isinstance(layer, Trainer):
                # All Trainer-related metrics in sublayers should be ignored
                # because a new Trainer has been instantiated.
                continue
            metrics.extend(layer.metrics)
        return metrics
|
| 269 |
+
|
| 270 |
+
@property
|
| 271 |
+
def metrics_names(self):
|
| 272 |
+
return [m.name for m in self.metrics]
|
| 273 |
+
|
| 274 |
+
def reset_metrics(self):
|
| 275 |
+
for m in self.metrics:
|
| 276 |
+
m.reset_state()
|
| 277 |
+
|
| 278 |
+
def _get_own_metrics(self):
|
| 279 |
+
metrics = []
|
| 280 |
+
if self._loss_tracker is not None:
|
| 281 |
+
metrics.append(self._loss_tracker)
|
| 282 |
+
if self._compile_metrics is not None:
|
| 283 |
+
metrics.append(self._compile_metrics)
|
| 284 |
+
if self._compile_loss is not None:
|
| 285 |
+
metrics.extend(self._compile_loss.metrics)
|
| 286 |
+
metrics.extend(self._metrics)
|
| 287 |
+
return metrics
|
| 288 |
+
|
| 289 |
+
    def compute_loss(
        self,
        x=None,
        y=None,
        y_pred=None,
        sample_weight=None,
        training=True,
    ):
        """Compute the total loss, validate it, and return it.

        Subclasses can optionally override this method to provide custom loss
        computation logic.

        Example:

        ```python
        class MyModel(Model):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.loss_tracker = metrics.Mean(name='loss')

            def compute_loss(self, x, y, y_pred, sample_weight, training=True):
                loss = ops.mean((y_pred - y) ** 2)
                loss += ops.sum(self.losses)
                self.loss_tracker.update_state(loss)
                return loss

            def reset_metrics(self):
                self.loss_tracker.reset_state()

            @property
            def metrics(self):
                return [self.loss_tracker]

        inputs = layers.Input(shape=(10,), name='my_input')
        outputs = layers.Dense(10)(inputs)
        model = MyModel(inputs, outputs)
        model.add_loss(ops.sum(outputs))

        optimizer = SGD()
        model.compile(optimizer, loss='mse', steps_per_execution=10)
        dataset = ...
        model.fit(dataset, epochs=2, steps_per_epoch=10)
        print(f"Custom loss: {model.loss_tracker.result()}")
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model (output of `model(x)`)
            sample_weight: Sample weights for weighting the loss function.
            training: Whether we are training or evaluating the model.

        Returns:
            The total loss as a scalar tensor, or `None` if no loss results
            (which is the case when called by `Model.test_step`).
        """
        # The default implementation does not use `x` or `training`.
        del x
        del training
        losses = []
        # Loss configured in `compile()`, if any.
        if self._compile_loss is not None:
            loss = self._compile_loss(y, y_pred, sample_weight)
            if loss is not None:
                losses.append(loss)
        # Additional losses from `add_loss()`, regularizers and sublayers.
        for loss in self.losses:
            losses.append(self._aggregate_additional_loss(loss))
        # NOTE(review): JAX is exempt from this check -- presumably the
        # stateless JAX code path can legitimately reach here with no loss;
        # confirm before relying on this.
        if backend.backend() != "jax" and len(losses) == 0:
            raise ValueError(
                "No loss to compute. Provide a `loss` argument in `compile()`."
            )
        if len(losses) == 1:
            total_loss = losses[0]
        elif len(losses) == 0:
            # Only reachable on the JAX backend (see the check above).
            total_loss = ops.zeros(())
        else:
            total_loss = ops.sum(losses)
        return total_loss
|
| 367 |
+
|
| 368 |
+
def _compute_loss(
|
| 369 |
+
self,
|
| 370 |
+
x=None,
|
| 371 |
+
y=None,
|
| 372 |
+
y_pred=None,
|
| 373 |
+
sample_weight=None,
|
| 374 |
+
training=True,
|
| 375 |
+
):
|
| 376 |
+
"""Backwards compatibility wrapper for `compute_loss`.
|
| 377 |
+
|
| 378 |
+
This should be used instead `compute_loss` within `train_step` and
|
| 379 |
+
`test_step` to support overrides of `compute_loss` that may not have
|
| 380 |
+
the `training` argument, as this argument was added in Keras 3.3.
|
| 381 |
+
"""
|
| 382 |
+
if self._compute_loss_has_training_arg:
|
| 383 |
+
return self.compute_loss(
|
| 384 |
+
x, y, y_pred, sample_weight, training=training
|
| 385 |
+
)
|
| 386 |
+
else:
|
| 387 |
+
return self.compute_loss(x, y, y_pred, sample_weight)
|
| 388 |
+
|
| 389 |
+
def _aggregate_additional_loss(self, loss):
|
| 390 |
+
"""Aggregates losses from `add_loss`, regularizers and sublayers.
|
| 391 |
+
|
| 392 |
+
Args:
|
| 393 |
+
loss: A tensor representing the additional loss to aggregate.
|
| 394 |
+
|
| 395 |
+
Returns:
|
| 396 |
+
A tensor representing the summed loss, cast to the `floatx()` if
|
| 397 |
+
necessary.
|
| 398 |
+
"""
|
| 399 |
+
if not backend.is_float_dtype(loss.dtype):
|
| 400 |
+
loss = ops.cast(loss, dtype=backend.floatx())
|
| 401 |
+
return ops.sum(loss)
|
| 402 |
+
|
| 403 |
+
    def stateless_compute_loss(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x=None,
        y=None,
        y_pred=None,
        sample_weight=None,
        training=True,
    ):
        """Stateless functional variant of `_compute_loss`.

        Computes the loss inside a `backend.StatelessScope` so that variable
        reads and writes go through the scope's state mapping instead of
        mutating the model's variables in place.

        Args:
            trainable_variables: Current trainable variable values.
            non_trainable_variables: Current non-trainable variable values.
            metrics_variables: Current metrics variable values.
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model.
            sample_weight: Sample weights for weighting the loss function.
            training: Whether we are training or evaluating the model.

        Returns:
            A tuple `(loss, (trainable_variables, non_trainable_variables,
            metrics_variables))`; the non-trainable and metrics values are
            the possibly-updated ones read back from the scope, while the
            trainable values are returned as passed in.
        """
        var_mapping = list(zip(self.trainable_variables, trainable_variables))
        var_mapping.extend(
            zip(self.non_trainable_variables, non_trainable_variables)
        )
        var_mapping.extend(zip(self.metrics_variables, metrics_variables))
        with backend.StatelessScope(state_mapping=var_mapping) as scope:
            # Note that this is needed for the regularization loss, which
            # needs the latest value of train/non-trainable variables.
            loss = self._compute_loss(
                x,
                y,
                y_pred,
                sample_weight=sample_weight,
                training=training,
            )

        # Update non trainable vars (may have been updated in compute_loss)
        non_trainable_variables = []
        for v in self.non_trainable_variables:
            new_v = scope.get_current_value(v)
            non_trainable_variables.append(new_v)

        # Update metrics vars (may have been updated in compute_loss)
        metrics_variables = []
        for v in self.metrics_variables:
            new_v = scope.get_current_value(v)
            metrics_variables.append(new_v)
        # Trainable variables are not modified here; they are echoed back
        # unchanged to keep the functional-state signature uniform.
        return loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        )
|
| 446 |
+
|
| 447 |
+
    def compute_metrics(self, x, y, y_pred, sample_weight=None):
        """Update metric states and collect all metrics to be returned.

        Subclasses can optionally override this method to provide custom metric
        updating and collection logic. Custom metrics are not passed in
        `compile()`, they can be created in `__init__` or `build`. They are
        automatically tracked and returned by `self.metrics`.

        Example:

        ```python
        class MyModel(Sequential):
            def __init__(self, *args, **kwargs):
                super().__init__(*args, **kwargs)
                self.custom_metric = MyMetric(name="custom_metric")

            def compute_metrics(self, x, y, y_pred, sample_weight):
                # This super call updates metrics from `compile` and returns
                # results for all metrics listed in `self.metrics`.
                metric_results = super().compute_metrics(
                    x, y, y_pred, sample_weight)

                # `metric_results` contains the previous result for
                # `custom_metric`, this is where we update it.
                self.custom_metric.update_state(x, y, y_pred, sample_weight)
                metric_results['custom_metric'] = self.custom_metric.result()
                return metric_results
        ```

        Args:
            x: Input data.
            y: Target data.
            y_pred: Predictions returned by the model output of `model.call(x)`.
            sample_weight: Sample weights for weighting the loss function.

        Returns:
            A `dict` containing values that will be passed to
            `keras.callbacks.CallbackList.on_train_batch_end()`. Typically,
            the values of the metrics listed in `self.metrics` are returned.
            Example: `{'loss': 0.2, 'accuracy': 0.7}`.
        """
        del x  # The default implementation does not use `x`.
        if self._compile_metrics is not None:
            self._compile_metrics.update_state(y, y_pred, sample_weight)
        # Results are collected for *all* metrics in `self.metrics`, not
        # only the compiled ones updated above.
        return self.get_metrics_result()
|
| 492 |
+
|
| 493 |
+
def get_metrics_result(self):
|
| 494 |
+
"""Returns the model's metrics values as a dict.
|
| 495 |
+
|
| 496 |
+
If any of the metric result is a dict (containing multiple metrics),
|
| 497 |
+
each of them gets added to the top level returned dict of this method.
|
| 498 |
+
|
| 499 |
+
Returns:
|
| 500 |
+
A `dict` containing values of the metrics listed in `self.metrics`.
|
| 501 |
+
Example: `{'loss': 0.2, 'accuracy': 0.7}`.
|
| 502 |
+
"""
|
| 503 |
+
return_metrics = {}
|
| 504 |
+
for metric in self.metrics:
|
| 505 |
+
result = metric.result()
|
| 506 |
+
if isinstance(result, dict):
|
| 507 |
+
return_metrics.update(result)
|
| 508 |
+
else:
|
| 509 |
+
return_metrics[metric.name] = result
|
| 510 |
+
return python_utils.pythonify_logs(return_metrics)
|
| 511 |
+
|
| 512 |
+
    def fit(
        self,
        x=None,
        y=None,
        batch_size=None,
        epochs=1,
        verbose="auto",
        callbacks=None,
        validation_split=0.0,
        validation_data=None,
        shuffle=True,
        class_weight=None,
        sample_weight=None,
        initial_epoch=0,
        steps_per_epoch=None,
        validation_steps=None,
        validation_batch_size=None,
        validation_freq=1,
    ):
        """Trains the model for a fixed number of epochs (dataset iterations).

        Args:
            x: Input data. It can be:
                - A NumPy array (or array-like), or a list of arrays
                    (in case the model has multiple inputs).
                - A backend-native tensor, or a list of tensors
                    (in case the model has multiple inputs).
                - A dict mapping input names to the corresponding array/tensors,
                    if the model has named inputs.
                - A `keras.utils.PyDataset` returning `(inputs, targets)` or
                    `(inputs, targets, sample_weights)`.
                - A `tf.data.Dataset` yielding `(inputs, targets)` or
                    `(inputs, targets, sample_weights)`.
                - A `torch.utils.data.DataLoader` yielding `(inputs, targets)`
                    or `(inputs, targets, sample_weights)`.
                - A Python generator function yielding `(inputs, targets)` or
                    `(inputs, targets, sample_weights)`.
            y: Target data. Like the input data `x`, it can be either NumPy
                array(s) or backend-native tensor(s). If `x` is a
                `keras.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or a Python generator function,
                `y` should not be specified since targets will be obtained from
                `x`.
            batch_size: Integer or `None`.
                Number of samples per gradient update.
                If unspecified, `batch_size` will default to 32.
                Do not specify the `batch_size` if your input data `x` is a
                `keras.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function
                since they generate batches.
            epochs: Integer. Number of epochs to train the model.
                An epoch is an iteration over the entire `x` and `y`
                data provided
                (unless the `steps_per_epoch` flag is set to
                something other than None).
                Note that in conjunction with `initial_epoch`,
                `epochs` is to be understood as "final epoch".
                The model is not trained for a number of iterations
                given by `epochs`, but merely until the epoch
                of index `epochs` is reached.
            verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
                0 = silent, 1 = progress bar, 2 = one line per epoch.
                "auto" becomes 1 for most cases.
                Note that the progress bar is not
                particularly useful when logged to a file,
                so `verbose=2` is recommended when not running interactively
                (e.g., in a production environment). Defaults to `"auto"`.
            callbacks: List of `keras.callbacks.Callback` instances.
                List of callbacks to apply during training.
                See `keras.callbacks`. Note
                `keras.callbacks.ProgbarLogger` and
                `keras.callbacks.History` callbacks are created
                automatically and need not be passed to `model.fit()`.
                `keras.callbacks.ProgbarLogger` is created
                or not based on the `verbose` argument in `model.fit()`.
            validation_split: Float between 0 and 1.
                Fraction of the training data to be used as validation data.
                The model will set apart this fraction of the training data,
                will not train on it, and will evaluate the loss and any model
                metrics on this data at the end of each epoch. The validation
                data is selected from the last samples in the `x` and `y` data
                provided, before shuffling.
                This argument is only supported when `x` and `y` are made of
                NumPy arrays or tensors.
                If both `validation_data` and `validation_split` are provided,
                `validation_data` will override `validation_split`.
            validation_data: Data on which to evaluate
                the loss and any model metrics at the end of each epoch.
                The model will not be trained on this data. Thus, note the fact
                that the validation loss of data provided using
                `validation_split` or `validation_data` is not affected by
                regularization layers like noise and dropout.
                `validation_data` will override `validation_split`.
                It can be:
                - A tuple `(x_val, y_val)` of NumPy arrays or tensors.
                - A tuple `(x_val, y_val, val_sample_weights)` of NumPy
                    arrays.
                - A `keras.utils.PyDataset`, a `tf.data.Dataset`, a
                    `torch.utils.data.DataLoader` yielding `(inputs, targets)`
                    or a Python generator function yielding `(x_val, y_val)` or
                    `(inputs, targets, sample_weights)`.
            shuffle: Boolean, whether to shuffle the training data before each
                epoch. This argument is ignored when `x` is a
                `keras.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function.
            class_weight: Optional dictionary mapping class indices (integers)
                to a weight (float) value, used for weighting the loss function
                (during training only).
                This can be useful to tell the model to
                "pay more attention" to samples from
                an under-represented class. When `class_weight` is specified
                and targets have a rank of 2 or greater, either `y` must be
                one-hot encoded, or an explicit final dimension of `1` must
                be included for sparse class labels.
            sample_weight: Optional NumPy array or tensor of weights for
                the training samples, used for weighting the loss function
                (during training only). You can either pass a flat (1D)
                NumPy array or tensor with the same length as the input samples
                (1:1 mapping between weights and samples), or in the case of
                temporal data, you can pass a 2D NumPy array or tensor with
                shape `(samples, sequence_length)` to apply a different weight
                to every timestep of every sample.
                This argument is not supported when `x` is a
                `keras.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function.
                Instead, provide `sample_weights` as the third element of `x`.
                Note that sample weighting does not apply to metrics specified
                via the `metrics` argument in `compile()`. To apply sample
                weighting to your metrics, you can specify them via the
                `weighted_metrics` in `compile()` instead.
            initial_epoch: Integer.
                Epoch at which to start training
                (useful for resuming a previous training run).
            steps_per_epoch: Integer or `None`.
                Total number of steps (batches of samples) before declaring one
                epoch finished and starting the next epoch. When training with
                input tensors or NumPy arrays, the default `None` means that the
                value used is the number of samples in your dataset divided by
                the batch size, or 1 if that cannot be determined.
                If `x` is a `keras.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function, the
                epoch will run until the input dataset is exhausted. When
                passing an infinitely repeating dataset, you must specify the
                `steps_per_epoch` argument, otherwise the training will run
                indefinitely.
            validation_steps: Integer or `None`.
                Only relevant if `validation_data` is provided.
                Total number of steps (batches of samples) to draw before
                stopping when performing validation at the end of every epoch.
                If `validation_steps` is `None`, validation will run until the
                `validation_data` dataset is exhausted. In the case of an
                infinitely repeating dataset, it will run indefinitely. If
                `validation_steps` is specified and only part of the dataset
                is consumed, the evaluation will start from the beginning of the
                dataset at each epoch. This ensures that the same validation
                samples are used every time.
            validation_batch_size: Integer or `None`.
                Number of samples per validation batch.
                If unspecified, will default to `batch_size`.
                Do not specify the `validation_batch_size` if your data is a
                `keras.utils.PyDataset`, `tf.data.Dataset`,
                `torch.utils.data.DataLoader` or Python generator function
                since they generate batches.
            validation_freq: Only relevant if validation data is provided.
                Specifies how many training epochs to run
                before a new validation run is performed,
                e.g. `validation_freq=2` runs validation every 2 epochs.

        Unpacking behavior for iterator-like inputs:
            A common pattern is to pass an iterator like object such as a
            `tf.data.Dataset` or a `keras.utils.PyDataset` to `fit()`,
            which will in fact yield not only features (`x`)
            but optionally targets (`y`) and sample weights (`sample_weight`).
            Keras requires that the output of such iterator-likes be
            unambiguous. The iterator should return a tuple
            of length 1, 2, or 3, where the optional second and third elements
            will be used for `y` and `sample_weight` respectively.
            Any other type provided will be wrapped in
            a length-one tuple, effectively treating everything as `x`. When
            yielding dicts, they should still adhere to the top-level tuple
            structure,
            e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
            features, targets, and weights from the keys of a single dict.
            A notable unsupported data type is the `namedtuple`. The reason is
            that it behaves like both an ordered datatype (tuple) and a mapping
            datatype (dict). So given a namedtuple of the form:
            `namedtuple("example_tuple", ["y", "x"])`
            it is ambiguous whether to reverse the order of the elements when
            interpreting the value. Even worse is a tuple of the form:
            `namedtuple("other_tuple", ["x", "y", "z"])`
            where it is unclear if the tuple was intended to be unpacked
            into `x`, `y`, and `sample_weight` or passed through
            as a single element to `x`.

        Returns:
            A `History` object. Its `History.history` attribute is
            a record of training loss values and metrics values
            at successive epochs, as well as validation loss values
            and validation metrics values (if applicable).
        """
        # Abstract in this base class; a concrete implementation is expected
        # to be provided elsewhere (not visible in this file chunk).
        raise NotImplementedError
|
| 713 |
+
|
| 714 |
+
def evaluate(
|
| 715 |
+
self,
|
| 716 |
+
x=None,
|
| 717 |
+
y=None,
|
| 718 |
+
batch_size=None,
|
| 719 |
+
verbose="auto",
|
| 720 |
+
sample_weight=None,
|
| 721 |
+
steps=None,
|
| 722 |
+
callbacks=None,
|
| 723 |
+
return_dict=False,
|
| 724 |
+
**kwargs,
|
| 725 |
+
):
|
| 726 |
+
"""Returns the loss value & metrics values for the model in test mode.
|
| 727 |
+
|
| 728 |
+
Computation is done in batches (see the `batch_size` arg.)
|
| 729 |
+
|
| 730 |
+
Args:
|
| 731 |
+
x: Input data. It can be:
|
| 732 |
+
- A NumPy array (or array-like), or a list of arrays
|
| 733 |
+
(in case the model has multiple inputs).
|
| 734 |
+
- A backend-native tensor, or a list of tensors
|
| 735 |
+
(in case the model has multiple inputs).
|
| 736 |
+
- A dict mapping input names to the corresponding array/tensors,
|
| 737 |
+
if the model has named inputs.
|
| 738 |
+
- A `keras.utils.PyDataset` returning `(inputs, targets)` or
|
| 739 |
+
`(inputs, targets, sample_weights)`.
|
| 740 |
+
- A `tf.data.Dataset` yielding `(inputs, targets)` or
|
| 741 |
+
`(inputs, targets, sample_weights)`.
|
| 742 |
+
- A `torch.utils.data.DataLoader` yielding `(inputs, targets)`
|
| 743 |
+
or `(inputs, targets, sample_weights)`.
|
| 744 |
+
- A Python generator function yielding `(inputs, targets)` or
|
| 745 |
+
`(inputs, targets, sample_weights)`.
|
| 746 |
+
y: Target data. Like the input data `x`, it can be either NumPy
|
| 747 |
+
array(s) or backend-native tensor(s). If `x` is a
|
| 748 |
+
`keras.utils.PyDataset`, `tf.data.Dataset`,
|
| 749 |
+
`torch.utils.data.DataLoader` or a Python generator function,
|
| 750 |
+
`y` should not be specified since targets will be obtained from
|
| 751 |
+
`x`.
|
| 752 |
+
batch_size: Integer or `None`.
|
| 753 |
+
Number of samples per batch of computation.
|
| 754 |
+
If unspecified, `batch_size` will default to 32.
|
| 755 |
+
Do not specify the `batch_size` if your input data `x` is a
|
| 756 |
+
`keras.utils.PyDataset`, `tf.data.Dataset`,
|
| 757 |
+
`torch.utils.data.DataLoader` or Python generator function
|
| 758 |
+
since they generate batches.
|
| 759 |
+
verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
|
| 760 |
+
0 = silent, 1 = progress bar, 2 = single line.
|
| 761 |
+
`"auto"` becomes 1 for most cases.
|
| 762 |
+
Note that the progress bar is not
|
| 763 |
+
particularly useful when logged to a file, so `verbose=2` is
|
| 764 |
+
recommended when not running interactively
|
| 765 |
+
(e.g. in a production environment). Defaults to `"auto"`.
|
| 766 |
+
sample_weight: Optional NumPy array or tensor of weights for
|
| 767 |
+
the training samples, used for weighting the loss function
|
| 768 |
+
(during training only). You can either pass a flat (1D)
|
| 769 |
+
NumPy array or tensor with the same length as the input samples
|
| 770 |
+
(1:1 mapping between weights and samples), or in the case of
|
| 771 |
+
temporal data, you can pass a 2D NumPy array or tensor with
|
| 772 |
+
shape `(samples, sequence_length)` to apply a different weight
|
| 773 |
+
to every timestep of every sample.
|
| 774 |
+
This argument is not supported when `x` is a
|
| 775 |
+
`keras.utils.PyDataset`, `tf.data.Dataset`,
|
| 776 |
+
`torch.utils.data.DataLoader` or Python generator function.
|
| 777 |
+
Instead, provide `sample_weights` as the third element of `x`.
|
| 778 |
+
Note that sample weighting does not apply to metrics specified
|
| 779 |
+
via the `metrics` argument in `compile()`. To apply sample
|
| 780 |
+
weighting to your metrics, you can specify them via the
|
| 781 |
+
`weighted_metrics` in `compile()` instead.
|
| 782 |
+
steps: Integer or `None`.
|
| 783 |
+
Total number of steps (batches of samples) to draw before
|
| 784 |
+
declaring the evaluation round finished. If `steps` is `None`,
|
| 785 |
+
it will run until `x` is exhausted. In the case of an infinitely
|
| 786 |
+
repeating dataset, it will run indefinitely.
|
| 787 |
+
callbacks: List of `keras.callbacks.Callback` instances.
|
| 788 |
+
List of callbacks to apply during evaluation.
|
| 789 |
+
return_dict: If `True`, loss and metric results are returned as a
|
| 790 |
+
dict, with each key being the name of the metric.
|
| 791 |
+
If `False`, they are returned as a list.
|
| 792 |
+
|
| 793 |
+
Returns:
|
| 794 |
+
Scalar test loss (if the model has a single output and no metrics)
|
| 795 |
+
or list of scalars (if the model has multiple outputs
|
| 796 |
+
and/or metrics). The attribute `model.metrics_names` will give you
|
| 797 |
+
the display labels for the scalar outputs.
|
| 798 |
+
"""
|
| 799 |
+
raise NotImplementedError
|
| 800 |
+
|
| 801 |
+
def predict(
|
| 802 |
+
self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
|
| 803 |
+
):
|
| 804 |
+
"""Generates output predictions for the input samples.
|
| 805 |
+
|
| 806 |
+
Computation is done in batches. This method is designed for batch
|
| 807 |
+
processing of large numbers of inputs. It is not intended for use inside
|
| 808 |
+
of loops that iterate over your data and process small numbers of inputs
|
| 809 |
+
at a time.
|
| 810 |
+
|
| 811 |
+
For small numbers of inputs that fit in one batch,
|
| 812 |
+
directly use `__call__()` for faster execution, e.g.,
|
| 813 |
+
`model(x)`, or `model(x, training=False)` if you have layers such as
|
| 814 |
+
`BatchNormalization` that behave differently during
|
| 815 |
+
inference.
|
| 816 |
+
|
| 817 |
+
Note: See [this FAQ entry](
|
| 818 |
+
https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)
|
| 819 |
+
for more details about the difference between `Model` methods
|
| 820 |
+
`predict()` and `__call__()`.
|
| 821 |
+
|
| 822 |
+
Args:
|
| 823 |
+
x: Input data. It can be:
|
| 824 |
+
- A NumPy array (or array-like), or a list of arrays
|
| 825 |
+
(in case the model has multiple inputs).
|
| 826 |
+
- A backend-native tensor, or a list of tensors
|
| 827 |
+
(in case the model has multiple inputs).
|
| 828 |
+
- A dict mapping input names to the corresponding array/tensors,
|
| 829 |
+
if the model has named inputs.
|
| 830 |
+
- A `keras.utils.PyDataset`.
|
| 831 |
+
- A `tf.data.Dataset`.
|
| 832 |
+
- A `torch.utils.data.DataLoader`.
|
| 833 |
+
- A Python generator function.
|
| 834 |
+
batch_size: Integer or `None`.
|
| 835 |
+
Number of samples per batch of computation.
|
| 836 |
+
If unspecified, `batch_size` will default to 32.
|
| 837 |
+
Do not specify the `batch_size` if your input data `x` is a
|
| 838 |
+
`keras.utils.PyDataset`, `tf.data.Dataset`,
|
| 839 |
+
`torch.utils.data.DataLoader` or Python generator function
|
| 840 |
+
since they generate batches.
|
| 841 |
+
verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
|
| 842 |
+
0 = silent, 1 = progress bar, 2 = single line.
|
| 843 |
+
`"auto"` becomes 1 for most cases. Note that the progress bar
|
| 844 |
+
is not particularly useful when logged to a file,
|
| 845 |
+
so `verbose=2` is recommended when not running interactively
|
| 846 |
+
(e.g. in a production environment). Defaults to `"auto"`.
|
| 847 |
+
steps: Total number of steps (batches of samples) to draw before
|
| 848 |
+
declaring the prediction round finished. If `steps` is `None`,
|
| 849 |
+
it will run until `x` is exhausted. In the case of an infinitely
|
| 850 |
+
repeating dataset, it will run indefinitely.
|
| 851 |
+
callbacks: List of `keras.callbacks.Callback` instances.
|
| 852 |
+
List of callbacks to apply during prediction.
|
| 853 |
+
|
| 854 |
+
Returns:
|
| 855 |
+
NumPy array(s) of predictions.
|
| 856 |
+
"""
|
| 857 |
+
raise NotImplementedError
|
| 858 |
+
|
| 859 |
+
def train_on_batch(
|
| 860 |
+
self,
|
| 861 |
+
x,
|
| 862 |
+
y=None,
|
| 863 |
+
sample_weight=None,
|
| 864 |
+
class_weight=None,
|
| 865 |
+
return_dict=False,
|
| 866 |
+
):
|
| 867 |
+
"""Runs a single gradient update on a single batch of data.
|
| 868 |
+
|
| 869 |
+
Args:
|
| 870 |
+
x: Input data. Must be array-like.
|
| 871 |
+
y: Target data. Must be array-like.
|
| 872 |
+
sample_weight: Optional array of the same length as x, containing
|
| 873 |
+
weights to apply to the model's loss for each sample.
|
| 874 |
+
In the case of temporal data, you can pass a 2D array
|
| 875 |
+
with shape `(samples, sequence_length)`, to apply a different
|
| 876 |
+
weight to every timestep of every sample.
|
| 877 |
+
class_weight: Optional dictionary mapping class indices (integers)
|
| 878 |
+
to a weight (float) to apply to the model's loss for the samples
|
| 879 |
+
from this class during training. This can be useful to tell the
|
| 880 |
+
model to "pay more attention" to samples from an
|
| 881 |
+
under-represented class. When `class_weight` is specified
|
| 882 |
+
and targets have a rank of 2 or greater, either `y` must
|
| 883 |
+
be one-hot encoded, or an explicit final dimension of 1
|
| 884 |
+
must be included for sparse class labels.
|
| 885 |
+
return_dict: If `True`, loss and metric results are returned as a
|
| 886 |
+
dict, with each key being the name of the metric. If `False`,
|
| 887 |
+
they are returned as a list.
|
| 888 |
+
|
| 889 |
+
Returns:
|
| 890 |
+
A scalar loss value (when no metrics and `return_dict=False`),
|
| 891 |
+
a list of loss and metric values
|
| 892 |
+
(if there are metrics and `return_dict=False`), or a dict of
|
| 893 |
+
metric and loss values (if `return_dict=True`).
|
| 894 |
+
"""
|
| 895 |
+
raise NotImplementedError
|
| 896 |
+
|
| 897 |
+
def test_on_batch(
|
| 898 |
+
self,
|
| 899 |
+
x,
|
| 900 |
+
y=None,
|
| 901 |
+
sample_weight=None,
|
| 902 |
+
return_dict=False,
|
| 903 |
+
):
|
| 904 |
+
"""Test the model on a single batch of samples.
|
| 905 |
+
|
| 906 |
+
Args:
|
| 907 |
+
x: Input data. Must be array-like.
|
| 908 |
+
y: Target data. Must be array-like.
|
| 909 |
+
sample_weight: Optional array of the same length as x, containing
|
| 910 |
+
weights to apply to the model's loss for each sample.
|
| 911 |
+
In the case of temporal data, you can pass a 2D array
|
| 912 |
+
with shape `(samples, sequence_length)`, to apply a different
|
| 913 |
+
weight to every timestep of every sample.
|
| 914 |
+
return_dict: If `True`, loss and metric results are returned as a
|
| 915 |
+
dict, with each key being the name of the metric. If `False`,
|
| 916 |
+
they are returned as a list.
|
| 917 |
+
|
| 918 |
+
Returns:
|
| 919 |
+
A scalar loss value (when no metrics and `return_dict=False`),
|
| 920 |
+
a list of loss and metric values
|
| 921 |
+
(if there are metrics and `return_dict=False`), or a dict of
|
| 922 |
+
metric and loss values (if `return_dict=True`).
|
| 923 |
+
"""
|
| 924 |
+
raise NotImplementedError
|
| 925 |
+
|
| 926 |
+
def predict_on_batch(self, x):
|
| 927 |
+
"""Returns predictions for a single batch of samples.
|
| 928 |
+
|
| 929 |
+
Args:
|
| 930 |
+
x: Input data. It must be array-like.
|
| 931 |
+
|
| 932 |
+
Returns:
|
| 933 |
+
NumPy array(s) of predictions.
|
| 934 |
+
"""
|
| 935 |
+
raise NotImplementedError
|
| 936 |
+
|
| 937 |
+
def get_compile_config(self):
|
| 938 |
+
"""Returns a serialized config with information for compiling the model.
|
| 939 |
+
|
| 940 |
+
This method returns a config dictionary containing all the information
|
| 941 |
+
(optimizer, loss, metrics, etc.) with which the model was compiled.
|
| 942 |
+
|
| 943 |
+
Returns:
|
| 944 |
+
A dict containing information for compiling the model.
|
| 945 |
+
"""
|
| 946 |
+
if self.compiled and hasattr(self, "_compile_config"):
|
| 947 |
+
return self._compile_config.serialize()
|
| 948 |
+
|
| 949 |
+
def compile_from_config(self, config):
|
| 950 |
+
"""Compiles the model with the information given in config.
|
| 951 |
+
|
| 952 |
+
This method uses the information in the config (optimizer, loss,
|
| 953 |
+
metrics, etc.) to compile the model.
|
| 954 |
+
|
| 955 |
+
Args:
|
| 956 |
+
config: Dict containing information for compiling the model.
|
| 957 |
+
"""
|
| 958 |
+
has_overridden_compile = self.__class__.compile != Trainer.compile
|
| 959 |
+
if has_overridden_compile:
|
| 960 |
+
warnings.warn(
|
| 961 |
+
"`compile()` was not called as part of model loading "
|
| 962 |
+
"because the model's `compile()` method is custom. "
|
| 963 |
+
"All subclassed Models that have `compile()` "
|
| 964 |
+
"overridden should also override "
|
| 965 |
+
"`get_compile_config()` and `compile_from_config(config)`. "
|
| 966 |
+
"Alternatively, you can "
|
| 967 |
+
"call `compile()` manually after loading.",
|
| 968 |
+
stacklevel=2,
|
| 969 |
+
)
|
| 970 |
+
return
|
| 971 |
+
config = serialization_lib.deserialize_keras_object(config)
|
| 972 |
+
self.compile(**config)
|
| 973 |
+
if hasattr(self, "optimizer") and self.built:
|
| 974 |
+
# Create optimizer variables.
|
| 975 |
+
self.optimizer.build(self.trainable_variables)
|
| 976 |
+
|
| 977 |
+
def _should_eval(self, epoch, validation_freq):
|
| 978 |
+
epoch = epoch + 1 # one-index the user-facing epoch.
|
| 979 |
+
if isinstance(validation_freq, int):
|
| 980 |
+
return epoch % validation_freq == 0
|
| 981 |
+
elif isinstance(validation_freq, list):
|
| 982 |
+
return epoch in validation_freq
|
| 983 |
+
else:
|
| 984 |
+
raise ValueError(
|
| 985 |
+
"Expected `validation_freq` to be a list or int. "
|
| 986 |
+
f"Received: validation_freq={validation_freq} of the "
|
| 987 |
+
f"type {type(validation_freq)}."
|
| 988 |
+
)
|
| 989 |
+
|
| 990 |
+
def _get_metrics_result_or_logs(self, logs):
|
| 991 |
+
"""Returns model metrics as a dict if the keys match with input logs.
|
| 992 |
+
|
| 993 |
+
When the training / evaluation is performed with an asynchronous steps,
|
| 994 |
+
the last scheduled `train / test_step` may not give the latest metrics
|
| 995 |
+
because it is not guaranteed to be executed the last. This method gets
|
| 996 |
+
metrics from the model directly instead of relying on the return from
|
| 997 |
+
last step function.
|
| 998 |
+
|
| 999 |
+
When the user has custom train / test step functions, the metrics
|
| 1000 |
+
returned may be different from `Model.metrics`. In those instances,
|
| 1001 |
+
this function will be no-op and return the logs passed in.
|
| 1002 |
+
|
| 1003 |
+
Args:
|
| 1004 |
+
logs: A `dict` of metrics returned by train / test step function.
|
| 1005 |
+
|
| 1006 |
+
Returns:
|
| 1007 |
+
A `dict` containing values of the metrics listed in `self.metrics`
|
| 1008 |
+
when logs and model metrics keys match. Otherwise it returns input
|
| 1009 |
+
`logs`.
|
| 1010 |
+
"""
|
| 1011 |
+
metric_logs = self.get_metrics_result()
|
| 1012 |
+
# Verify that train / test step logs passed and metric logs have
|
| 1013 |
+
# matching keys. It could be different when using custom step functions,
|
| 1014 |
+
# in which case we return the logs from the last step.
|
| 1015 |
+
if isinstance(logs, dict) and set(logs.keys()) == set(
|
| 1016 |
+
metric_logs.keys()
|
| 1017 |
+
):
|
| 1018 |
+
return metric_logs
|
| 1019 |
+
return logs
|
| 1020 |
+
|
| 1021 |
+
def _flatten_metrics_in_order(self, logs):
|
| 1022 |
+
"""Turns `logs` dict into a list as per key order of `metrics_names`."""
|
| 1023 |
+
metric_names = []
|
| 1024 |
+
for metric in self.metrics:
|
| 1025 |
+
if isinstance(metric, CompileMetrics):
|
| 1026 |
+
metric_names += [
|
| 1027 |
+
sub_metric.name for sub_metric in metric.metrics
|
| 1028 |
+
]
|
| 1029 |
+
else:
|
| 1030 |
+
metric_names.append(metric.name)
|
| 1031 |
+
results = []
|
| 1032 |
+
for name in metric_names:
|
| 1033 |
+
if name in logs:
|
| 1034 |
+
results.append(logs[name])
|
| 1035 |
+
for key in sorted(logs.keys()):
|
| 1036 |
+
if key not in metric_names:
|
| 1037 |
+
results.append(logs[key])
|
| 1038 |
+
if len(results) == 1:
|
| 1039 |
+
return results[0]
|
| 1040 |
+
return results
|
| 1041 |
+
|
| 1042 |
+
def _assert_compile_called(self, method_name=None):
|
| 1043 |
+
if not self.compiled:
|
| 1044 |
+
msg = "You must call `compile()` before "
|
| 1045 |
+
if metrics_module:
|
| 1046 |
+
msg += "using the model."
|
| 1047 |
+
else:
|
| 1048 |
+
msg += f"calling `{method_name}()`."
|
| 1049 |
+
raise ValueError(msg)
|
| 1050 |
+
|
| 1051 |
+
def _symbolic_build(self, iterator=None, data_batch=None):
|
| 1052 |
+
model_unbuilt = not all(layer.built for layer in self._flatten_layers())
|
| 1053 |
+
compile_metrics_unbuilt = (
|
| 1054 |
+
self._compile_metrics is not None
|
| 1055 |
+
and not self._compile_metrics.built
|
| 1056 |
+
)
|
| 1057 |
+
compile_loss_unbuilt = (
|
| 1058 |
+
self._compile_loss is not None and not self._compile_loss.built
|
| 1059 |
+
)
|
| 1060 |
+
optimizer_unbuilt = (
|
| 1061 |
+
self.optimizer is not None and not self.optimizer.built
|
| 1062 |
+
)
|
| 1063 |
+
if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
|
| 1064 |
+
# Create symbolic tensors matching an input batch.
|
| 1065 |
+
|
| 1066 |
+
def to_symbolic_input(v):
|
| 1067 |
+
if v is None:
|
| 1068 |
+
return None
|
| 1069 |
+
return backend.KerasTensor(
|
| 1070 |
+
v.shape, backend.standardize_dtype(v.dtype)
|
| 1071 |
+
)
|
| 1072 |
+
|
| 1073 |
+
if data_batch is None:
|
| 1074 |
+
for _, data_or_iterator in iterator:
|
| 1075 |
+
if isinstance(data_or_iterator, (list, tuple)):
|
| 1076 |
+
data_batch = data_or_iterator[0]
|
| 1077 |
+
else:
|
| 1078 |
+
data_batch = next(data_or_iterator)
|
| 1079 |
+
break
|
| 1080 |
+
data_batch = tree.map_structure(to_symbolic_input, data_batch)
|
| 1081 |
+
(
|
| 1082 |
+
x,
|
| 1083 |
+
y,
|
| 1084 |
+
sample_weight,
|
| 1085 |
+
) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
|
| 1086 |
+
|
| 1087 |
+
# Build all model state with `backend.compute_output_spec`.
|
| 1088 |
+
try:
|
| 1089 |
+
y_pred = backend.compute_output_spec(self, x, training=False)
|
| 1090 |
+
except Exception as e:
|
| 1091 |
+
raise RuntimeError(
|
| 1092 |
+
"Unable to automatically build the model. "
|
| 1093 |
+
"Please build it yourself before calling "
|
| 1094 |
+
"fit/evaluate/predict. "
|
| 1095 |
+
"A model is 'built' when its variables have "
|
| 1096 |
+
"been created and its `self.built` attribute "
|
| 1097 |
+
"is True. Usually, calling the model on a batch "
|
| 1098 |
+
"of data is the right way to build it.\n"
|
| 1099 |
+
"Exception encountered:\n"
|
| 1100 |
+
f"'{e}'"
|
| 1101 |
+
)
|
| 1102 |
+
if compile_metrics_unbuilt:
|
| 1103 |
+
# Build all metric state with `backend.compute_output_spec`.
|
| 1104 |
+
backend.compute_output_spec(
|
| 1105 |
+
self.compute_metrics,
|
| 1106 |
+
x,
|
| 1107 |
+
y,
|
| 1108 |
+
y_pred,
|
| 1109 |
+
sample_weight=sample_weight,
|
| 1110 |
+
)
|
| 1111 |
+
if compile_loss_unbuilt:
|
| 1112 |
+
# Build `CompileLoss` state with `backend.compute_output_spec`.
|
| 1113 |
+
backend.compute_output_spec(
|
| 1114 |
+
self._compute_loss,
|
| 1115 |
+
x,
|
| 1116 |
+
y,
|
| 1117 |
+
y_pred,
|
| 1118 |
+
sample_weight=sample_weight,
|
| 1119 |
+
training=False,
|
| 1120 |
+
)
|
| 1121 |
+
if optimizer_unbuilt:
|
| 1122 |
+
# Build optimizer
|
| 1123 |
+
self.optimizer.build(self.trainable_variables)
|
| 1124 |
+
self._post_build()
|
| 1125 |
+
|
| 1126 |
+
|
| 1127 |
+
def model_supports_jit(model):
|
| 1128 |
+
# XLA not supported with TF on MacOS GPU
|
| 1129 |
+
if platform.system() == "Darwin" and "arm" in platform.processor().lower():
|
| 1130 |
+
if backend.backend() == "tensorflow":
|
| 1131 |
+
from keras.src.utils.module_utils import tensorflow as tf
|
| 1132 |
+
|
| 1133 |
+
if tf.config.list_physical_devices("GPU"):
|
| 1134 |
+
return False
|
| 1135 |
+
# XLA not supported by some layers
|
| 1136 |
+
if all(x.supports_jit for x in model._flatten_layers()):
|
| 1137 |
+
if backend.backend() == "tensorflow":
|
| 1138 |
+
from tensorflow.python.framework.config import (
|
| 1139 |
+
is_op_determinism_enabled,
|
| 1140 |
+
)
|
| 1141 |
+
|
| 1142 |
+
if is_op_determinism_enabled():
|
| 1143 |
+
# disable XLA with determinism enabled since not all ops are
|
| 1144 |
+
# supported by XLA with determinism enabled.
|
| 1145 |
+
return False
|
| 1146 |
+
return True
|
| 1147 |
+
return False
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.tree.tree_api import assert_same_paths
|
| 2 |
+
from keras.src.tree.tree_api import assert_same_structure
|
| 3 |
+
from keras.src.tree.tree_api import flatten
|
| 4 |
+
from keras.src.tree.tree_api import flatten_with_path
|
| 5 |
+
from keras.src.tree.tree_api import is_nested
|
| 6 |
+
from keras.src.tree.tree_api import lists_to_tuples
|
| 7 |
+
from keras.src.tree.tree_api import map_shape_structure
|
| 8 |
+
from keras.src.tree.tree_api import map_structure
|
| 9 |
+
from keras.src.tree.tree_api import map_structure_up_to
|
| 10 |
+
from keras.src.tree.tree_api import pack_sequence_as
|
| 11 |
+
from keras.src.tree.tree_api import register_tree_node_class
|
| 12 |
+
from keras.src.tree.tree_api import traverse
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (676 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc
ADDED
|
Binary file (11.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc
ADDED
|
Binary file (6.11 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py
ADDED
|
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import collections
|
| 2 |
+
import collections.abc
|
| 3 |
+
import itertools
|
| 4 |
+
|
| 5 |
+
from keras.src.backend.config import backend
|
| 6 |
+
from keras.src.utils.module_utils import dmtree
|
| 7 |
+
|
| 8 |
+
# NOTE: There are two known discrepancies between this `dmtree` implementation
|
| 9 |
+
# of the tree API and the `optree` implementation:
|
| 10 |
+
#
|
| 11 |
+
# 1. `map_structure` with *multiple* structures and `map_structure_up_to` do not
|
| 12 |
+
# use the object registration (they use the raw `dmtree.map_structure` and
|
| 13 |
+
# `dmtree.map_structure_up_to`). This only has consequences with two types of
|
| 14 |
+
# structures:
|
| 15 |
+
# - `TrackedSet` will not explored (considered as a leaf).
|
| 16 |
+
# - `OrderedDict` will be traversed in the order of sorted keys, not the
|
| 17 |
+
# order of the items. This is typically inconsequential because functions
|
| 18 |
+
# used with `map_structure` and `map_structure_up_to` are typically not
|
| 19 |
+
# order dependent and are, in fact, stateless.
|
| 20 |
+
#
|
| 21 |
+
# 2. The handling of non-sortable keys in dictionaries in inconsistent. `optree`
|
| 22 |
+
# uses the iteration order while `dmtree` raises an error. This is not an
|
| 23 |
+
# issue as keys are always strings. But this is the reason why we document
|
| 24 |
+
# non-sortable keys as unsupported (meaning behavior is undefined).
|
| 25 |
+
|
| 26 |
+
REGISTERED_CLASSES = {}
|
| 27 |
+
|
| 28 |
+
ClassRegistration = collections.namedtuple(
|
| 29 |
+
"ClassRegistration", ["flatten", "unflatten"]
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class TypeErrorRemapping:
|
| 34 |
+
def __enter__(self):
|
| 35 |
+
pass
|
| 36 |
+
|
| 37 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
| 38 |
+
if exc_type is TypeError:
|
| 39 |
+
raise ValueError(exc_value).with_traceback(traceback)
|
| 40 |
+
return False
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def register_tree_node(
|
| 44 |
+
cls,
|
| 45 |
+
flatten_func=None,
|
| 46 |
+
unflatten_func=None,
|
| 47 |
+
):
|
| 48 |
+
if flatten_func is None:
|
| 49 |
+
flatten_func = lambda x: x.tree_flatten()
|
| 50 |
+
if unflatten_func is None:
|
| 51 |
+
unflatten_func = cls.tree_unflatten
|
| 52 |
+
REGISTERED_CLASSES[cls] = ClassRegistration(flatten_func, unflatten_func)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def register_tree_node_class(cls):
|
| 56 |
+
register_tree_node(cls)
|
| 57 |
+
return cls
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
register_tree_node(
|
| 61 |
+
collections.OrderedDict,
|
| 62 |
+
lambda d: (d.values(), list(d.keys()), d.keys()),
|
| 63 |
+
lambda metadata, children: collections.OrderedDict(zip(metadata, children)),
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
if backend() == "tensorflow":
|
| 67 |
+
from tensorflow.python.trackable.data_structures import ListWrapper
|
| 68 |
+
from tensorflow.python.trackable.data_structures import _DictWrapper
|
| 69 |
+
|
| 70 |
+
register_tree_node(
|
| 71 |
+
ListWrapper,
|
| 72 |
+
lambda x: (x, None),
|
| 73 |
+
lambda metadata, children: ListWrapper(list(children)),
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def sorted_keys_and_values(d):
|
| 77 |
+
keys = sorted(list(d.keys()))
|
| 78 |
+
values = [d[k] for k in keys]
|
| 79 |
+
return values, keys, keys
|
| 80 |
+
|
| 81 |
+
register_tree_node(
|
| 82 |
+
_DictWrapper,
|
| 83 |
+
sorted_keys_and_values,
|
| 84 |
+
lambda metadata, children: _DictWrapper(
|
| 85 |
+
{key: child for key, child in zip(metadata, children)}
|
| 86 |
+
),
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def is_nested(structure):
|
| 91 |
+
return type(structure) in REGISTERED_CLASSES or dmtree.is_nested(structure)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def traverse(func, structure, top_down=True):
|
| 95 |
+
if not callable(func):
|
| 96 |
+
raise TypeError(
|
| 97 |
+
f"`func` must be callable, got {func} of type {type(func)}"
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
def remap_map_to_none(value, new_value):
|
| 101 |
+
if isinstance(value, type) and value.__name__ == "MAP_TO_NONE":
|
| 102 |
+
return new_value
|
| 103 |
+
return value
|
| 104 |
+
|
| 105 |
+
def traverse_top_down(s):
|
| 106 |
+
ret = func(s)
|
| 107 |
+
if ret is not None:
|
| 108 |
+
return remap_map_to_none(ret, dmtree.MAP_TO_NONE)
|
| 109 |
+
registration = REGISTERED_CLASSES.get(type(s), None)
|
| 110 |
+
if registration is None:
|
| 111 |
+
return None
|
| 112 |
+
flat_meta_s = registration.flatten(s)
|
| 113 |
+
flat_s = [
|
| 114 |
+
dmtree.traverse(traverse_top_down, x, top_down=True)
|
| 115 |
+
for x in list(flat_meta_s[0])
|
| 116 |
+
]
|
| 117 |
+
return registration.unflatten(flat_meta_s[1], flat_s)
|
| 118 |
+
|
| 119 |
+
def traverse_bottom_up(s):
|
| 120 |
+
registration = REGISTERED_CLASSES.get(type(s), None)
|
| 121 |
+
if registration is not None:
|
| 122 |
+
flat_meta_s = registration.flatten(s)
|
| 123 |
+
ret = [traverse_bottom_up(x) for x in list(flat_meta_s[0])]
|
| 124 |
+
ret = registration.unflatten(flat_meta_s[1], ret)
|
| 125 |
+
elif not dmtree.is_nested(s):
|
| 126 |
+
ret = s
|
| 127 |
+
elif isinstance(s, collections.abc.Mapping):
|
| 128 |
+
ret = [traverse_bottom_up(s[key]) for key in sorted(s)]
|
| 129 |
+
ret = dmtree._sequence_like(s, ret)
|
| 130 |
+
else:
|
| 131 |
+
ret = [traverse_bottom_up(x) for x in s]
|
| 132 |
+
ret = dmtree._sequence_like(s, ret)
|
| 133 |
+
func_ret = func(ret)
|
| 134 |
+
return ret if func_ret is None else remap_map_to_none(func_ret, None)
|
| 135 |
+
|
| 136 |
+
if top_down:
|
| 137 |
+
return dmtree.traverse(traverse_top_down, structure, top_down=True)
|
| 138 |
+
else:
|
| 139 |
+
return traverse_bottom_up(structure)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def flatten(structure):
|
| 143 |
+
if not is_nested(structure):
|
| 144 |
+
return [structure]
|
| 145 |
+
|
| 146 |
+
flattened = []
|
| 147 |
+
|
| 148 |
+
def flatten_func(s):
|
| 149 |
+
registration = REGISTERED_CLASSES.get(type(s), None)
|
| 150 |
+
if registration is not None:
|
| 151 |
+
flat_s = list(registration.flatten(s)[0])
|
| 152 |
+
return dmtree.traverse(flatten_func, flat_s, top_down=True)
|
| 153 |
+
if not is_nested(s):
|
| 154 |
+
flattened.append(s)
|
| 155 |
+
return dmtree.MAP_TO_NONE if s is None else s
|
| 156 |
+
return None
|
| 157 |
+
|
| 158 |
+
dmtree.traverse(flatten_func, structure, top_down=True)
|
| 159 |
+
return flattened
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _recursive_flatten_with_path(path, structure, flattened):
|
| 163 |
+
registration = REGISTERED_CLASSES.get(type(structure), None)
|
| 164 |
+
if registration is not None:
|
| 165 |
+
flat_meta_paths = registration.flatten(structure)
|
| 166 |
+
flat = flat_meta_paths[0]
|
| 167 |
+
paths = (
|
| 168 |
+
flat_meta_paths[2]
|
| 169 |
+
if len(flat_meta_paths) >= 3
|
| 170 |
+
else itertools.count()
|
| 171 |
+
)
|
| 172 |
+
for key, value in zip(paths, flat):
|
| 173 |
+
_recursive_flatten_with_path(path + (key,), value, flattened)
|
| 174 |
+
elif not dmtree.is_nested(structure):
|
| 175 |
+
flattened.append((path, structure))
|
| 176 |
+
elif isinstance(structure, collections.abc.Mapping):
|
| 177 |
+
for key in sorted(structure):
|
| 178 |
+
_recursive_flatten_with_path(
|
| 179 |
+
path + (key,), structure[key], flattened
|
| 180 |
+
)
|
| 181 |
+
else:
|
| 182 |
+
for key, value in enumerate(structure):
|
| 183 |
+
_recursive_flatten_with_path(path + (key,), value, flattened)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def flatten_with_path(structure):
|
| 187 |
+
if not is_nested(structure):
|
| 188 |
+
return [((), structure)]
|
| 189 |
+
|
| 190 |
+
# Fully reimplemented in Python to handle registered classes, OrderedDict
|
| 191 |
+
# and namedtuples the same way as optree.
|
| 192 |
+
flattened = []
|
| 193 |
+
_recursive_flatten_with_path((), structure, flattened)
|
| 194 |
+
return flattened
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
def map_structure(func, *structures):
    """Apply `func` to every leaf of the given structure(s).

    Args:
        func: Callable applied to each leaf; it receives one leaf per
            structure.
        *structures: One or more nested structures of identical layout.

    Returns:
        A structure with the same layout, holding the mapped leaves.

    Raises:
        TypeError: If `func` is not callable.
    """
    if not callable(func):
        raise TypeError(
            f"`func` must be callable, got {func} of type {type(func)}"
        )

    if len(structures) == 1:
        # Single-structure case goes through our own `traverse` so that
        # registered classes are handled; a `None` result from `func` must
        # be encoded as MAP_TO_NONE for dm-tree.
        def apply_to_leaf(node):
            if is_nested(node):
                return None
            mapped = func(node)
            return dmtree.MAP_TO_NONE if mapped is None else mapped

        return traverse(apply_to_leaf, structures[0])

    # Multi-structure case defers to dm-tree, remapping its TypeErrors.
    with TypeErrorRemapping():
        return dmtree.map_structure(func, *structures)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def map_structure_up_to(shallow_structure, func, *structures):
    """Map `func` over `structures` only down to the depth of
    `shallow_structure`; deeper sub-structures are passed to `func` whole.

    Args:
        shallow_structure: Structure whose layout bounds the traversal depth.
        func: Callable applied at each position of `shallow_structure`.
        *structures: Structures at least as deep as `shallow_structure`.

    Returns:
        A structure with the layout of `shallow_structure`.

    Raises:
        TypeError: If `func` is not callable.
    """
    if not callable(func):
        raise TypeError(
            f"`func` must be callable, got {func} of type {type(func)}"
        )

    # Delegate to dm-tree; TypeErrorRemapping converts dm-tree's TypeErrors
    # into the error type callers of this module expect.
    with TypeErrorRemapping():
        return dmtree.map_structure_up_to(shallow_structure, func, *structures)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def assert_same_structure(a, b):
    """Assert that `a` and `b` are nested in exactly the same way.

    Fully reimplemented in Python (instead of deferring to dm-tree) to
    handle registered classes.

    Args:
        a: An arbitrarily nested structure.
        b: An arbitrarily nested structure.

    Raises:
        ValueError: If the structures differ in node type, custom-node
            metadata, arity, or mapping keys at any level.
    """
    # Don't handle OrderedDict as a registered class, use the normal dict path
    # so that OrderedDict is equivalent to dict per optree behavior.
    a_registration = REGISTERED_CLASSES.get(type(a), None)
    if isinstance(a, collections.OrderedDict):
        a_registration = None

    b_registration = REGISTERED_CLASSES.get(type(b), None)
    if isinstance(b, collections.OrderedDict):
        b_registration = None

    if a_registration != b_registration:
        raise ValueError(
            f"Custom node type mismatch; "
            f"expected type: {type(a)}, got type: {type(b)} "
            f"while comparing {a} and {b}."
        )
    if a_registration is not None:
        # Registered custom nodes: compare metadata, then recurse over the
        # flattened children.
        a_flat_meta = a_registration.flatten(a)
        b_flat_meta = b_registration.flatten(b)
        a_flat = list(a_flat_meta[0])
        b_flat = list(b_flat_meta[0])
        if not a_flat_meta[1] == b_flat_meta[1]:
            raise ValueError(
                f"Mismatch custom node data; "
                f"expected: {a_flat_meta[1]}, got: {b_flat_meta[1]} "
                f"while comparing {a} and {b}."
            )
        if len(a_flat) != len(b_flat):
            # Bug fix: report the flattened child counts. A registered custom
            # node is not guaranteed to implement `__len__`, so formatting
            # `len(a)` here could itself raise TypeError — and even when it
            # works, `len(a)` is not the quantity that was compared.
            raise ValueError(
                f"Arity mismatch; expected: {len(a_flat)}, "
                f"got: {len(b_flat)} "
                f"while comparing {a} and {b}."
            )
        for sub_a, sub_b in zip(a_flat, b_flat):
            assert_same_structure(sub_a, sub_b)
    elif not dmtree.is_nested(a):
        # `a` is a leaf: `b` must be one too.
        if dmtree.is_nested(b):
            raise ValueError(
                f"Structures don't have the same nested structure: {a}, {b}."
            )
    elif isinstance(
        a, (dict, collections.OrderedDict, collections.defaultdict)
    ):
        # Dict-like nodes match on the key set, regardless of exact dict
        # subclass (dict == OrderedDict == defaultdict, per optree).
        if not isinstance(
            b, (dict, collections.OrderedDict, collections.defaultdict)
        ):
            raise ValueError(
                f"Expected an instance of dict, collections.OrderedDict, or "
                f"collections.defaultdict, got {type(b)} "
                f"while comparing {a} and {b}."
            )
        a_keys = sorted(a)
        b_keys = sorted(b)
        if not a_keys == b_keys:
            raise ValueError(
                f"Dictionary key mismatch; "
                f"expected key(s): {a_keys}, got key(s): {b_keys} "
                f"while comparing {a} and {b}."
            )
        for key in a_keys:
            assert_same_structure(a[key], b[key])
    elif isinstance(a, collections.abc.Mapping):
        # Other Mapping types must have been registered to be supported.
        raise ValueError(
            f"Encountered unregistered collections.abc.Mapping type: {type(a)} "
            f"while comparing {a} and {b}."
        )
    else:
        # Sequence-like nodes (list, tuple, namedtuple, deque): exact type
        # and length must match, then recurse element-wise.
        if type(a) is not type(b):
            raise ValueError(
                f"Expected an instance of {type(a)}, got {type(b)} "
                f"while comparing {a} and {b}."
            )
        if not len(a) == len(b):
            raise ValueError(
                f"Arity mismatch; expected: {len(a)}, got: {len(b)} "
                f"while comparing {a} and {b}."
            )
        for sub_a, sub_b in zip(a, b):
            assert_same_structure(sub_a, sub_b)
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def assert_same_paths(a, b):
    """Raise ValueError unless `a` and `b` expose the same set of leaf paths."""
    paths_a = {path for path, _ in flatten_with_path(a)}
    paths_b = {path for path, _ in flatten_with_path(b)}

    if paths_a == paths_b:
        return

    # Report both directions of the difference in one message.
    parts = ["`a` and `b` don't have the same paths."]
    only_in_a = paths_a.difference(paths_b)
    if only_in_a:
        parts.append(f"\nPaths in `a` missing in `b`:\n{only_in_a}")
    only_in_b = paths_b.difference(paths_a)
    if only_in_b:
        parts.append(f"\nPaths in `b` missing in `a`:\n{only_in_b}")
    raise ValueError("".join(parts))
|
| 324 |
+
|
| 325 |
+
|
| 326 |
+
def pack_sequence_as(structure, flat_sequence):
    """Pack the leaves in `flat_sequence` into the layout of `structure`.

    Args:
        structure: Template structure whose layout is reproduced.
        flat_sequence: Flat sequence of leaves, one per leaf of `structure`.

    Returns:
        A structure with the layout of `structure` and the leaves of
        `flat_sequence`.

    Raises:
        ValueError: If `flat_sequence` has too few or too many leaves for
            `structure`.
    """
    # This is not just an optimization for the case when structure is a leaf.
    # This is required to avoid Torch Dynamo failures.
    if not is_nested(structure):
        if len(flat_sequence) == 1:
            return flat_sequence[0]
        else:
            raise ValueError(
                "Incorrect number of leaves provided by `flat_sequence` for "
                f"`structure`; expected: 1, got {len(flat_sequence)}."
            )

    # Single shared iterator: each leaf visit below consumes the next value,
    # which is how the flat sequence is distributed across the structure.
    flat_sequence_it = enumerate(flat_sequence)

    def unflatten_func(s):
        registration = REGISTERED_CLASSES.get(type(s), None)
        if registration is not None:
            # Registered custom node: rebuild it from its recursively
            # unflattened children and original metadata.
            flat_meta_s = registration.flatten(s)
            flat_s = dmtree.traverse(
                unflatten_func, list(flat_meta_s[0]), top_down=True
            )
            return registration.unflatten(flat_meta_s[1], flat_s)
        elif not dmtree.is_nested(s):
            try:
                # Leaf position: substitute the next flat value.
                _, value = next(flat_sequence_it)
                # dm-tree requires None replacements to be spelled
                # MAP_TO_NONE.
                return dmtree.MAP_TO_NONE if value is None else value
            except StopIteration:
                raise ValueError(
                    "Too few leaves provided by `flat_sequence` for "
                    f"`structure`. Got {len(flat_sequence)}."
                )
        # Non-leaf, non-registered node: keep descending.
        return None

    ret = dmtree.traverse(unflatten_func, structure, top_down=True)
    # The iterator must be exhausted exactly; anything left over means the
    # caller supplied too many leaves.
    try:
        index, _ = next(flat_sequence_it)
        raise ValueError(
            "Too many leaves provided by `flat_sequence` for `structure`; "
            f"expected: {index}, got {len(flat_sequence)}."
        )
    except StopIteration:
        return ret
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def lists_to_tuples(structure):
    """Return a copy of `structure` with every `list` converted to a `tuple`."""

    def convert(node):
        if isinstance(node, list):
            return tuple(node)
        return None

    # Bottom-up so inner lists are converted before their parents are rebuilt.
    return traverse(convert, structure, top_down=False)
|
| 375 |
+
|
| 376 |
+
|
| 377 |
+
def map_shape_structure(func, structure):
    """Apply `func` to every shape tuple and leaf within `structure`.

    A "shape" is a list or tuple whose elements are all ints or None; such
    values are treated as leaves and passed to `func` whole instead of being
    descended into.

    Raises:
        TypeError: If `func` is not callable.
    """
    if not callable(func):
        raise TypeError(
            f"`func` must be callable, got {func} of type {type(func)}"
        )

    def visit(node):
        looks_like_shape = isinstance(node, (list, tuple)) and all(
            isinstance(item, (int, type(None))) for item in node
        )
        if not looks_like_shape and is_nested(node):
            # Ordinary container: keep descending.
            return None
        result = func(node)
        # dm-tree spells a None replacement as MAP_TO_NONE.
        return result if result is not None else dmtree.MAP_TO_NONE

    return traverse(visit, structure, top_down=True)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optree
|
| 2 |
+
import optree.utils
|
| 3 |
+
|
| 4 |
+
from keras.src.backend.config import backend
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def register_tree_node_class(cls):
    """Register `cls` as a pytree node type in the "keras" namespace."""
    registered = optree.register_pytree_node_class(cls, namespace="keras")
    return registered
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
# Register backend-specific node classes
if backend() == "tensorflow":
    # TensorFlow wraps tracked lists/dicts in these trackable proxy types;
    # register them so optree flattens them like plain lists/dicts.
    from tensorflow.python.trackable.data_structures import ListWrapper
    from tensorflow.python.trackable.data_structures import _DictWrapper

    # ListWrapper flattens to its elements with no metadata.
    optree.register_pytree_node(
        ListWrapper,
        lambda x: (x, None),
        lambda metadata, children: ListWrapper(list(children)),
        namespace="keras",
    )

    def sorted_keys_and_values(d):
        # Flatten in sorted-key order for determinism; the keys double as
        # both metadata (for unflattening) and path entries.
        keys = sorted(list(d.keys()))
        values = [d[k] for k in keys]
        return values, keys, keys

    optree.register_pytree_node(
        _DictWrapper,
        sorted_keys_and_values,
        # Rebuild by zipping the stored keys back with the children.
        lambda metadata, children: _DictWrapper(
            {key: child for key, child in zip(metadata, children)}
        ),
        namespace="keras",
    )
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def is_nested(structure):
    """Return True if `structure` is a pytree node rather than a leaf."""
    # `none_is_leaf=True`: None counts as a leaf, not an empty node.
    leaf = optree.tree_is_leaf(
        structure, none_is_leaf=True, namespace="keras"
    )
    return not leaf
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def traverse(func, structure, top_down=True):
    """Depth-first traversal of `structure`, applying `func` at each node.

    Top-down: a non-None return from `func` replaces the sub-tree and stops
    descent into it. Bottom-up: children are traversed first, then `func` is
    applied to the rebuilt node; a non-None return replaces it.
    """
    # From https://github.com/google/jax/pull/19695
    def traverse_children():
        # Flatten exactly one level: everything that is not `structure`
        # itself is treated as a leaf (i.e. a direct child).
        children, treedef = optree.tree_flatten(
            structure,
            is_leaf=lambda x: x is not structure,
            none_is_leaf=True,
            namespace="keras",
        )
        if treedef.num_nodes == 1 and treedef.num_leaves == 1:
            # `structure` is itself a leaf: nothing to recurse into.
            return structure
        else:
            return optree.tree_unflatten(
                treedef,
                [traverse(func, c, top_down=top_down) for c in children],
            )

    if top_down:
        ret = func(structure)
        if ret is None:
            # None means "keep going": descend into the children.
            return traverse_children()
    else:
        traversed_structure = traverse_children()
        ret = func(traversed_structure)
        if ret is None:
            # None means "leave the traversed sub-tree unaltered".
            return traversed_structure
    # Detect MAP_TO_NONE without tree_api import to avoid circular import.
    if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
        return None
    return ret
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def flatten(structure):
    """Flatten `structure` into a flat list of leaves (None is a leaf)."""
    # optree.tree_flatten returns (leaves, treespec); only the leaves are
    # needed here.
    flat_and_spec = optree.tree_flatten(
        structure, none_is_leaf=True, namespace="keras"
    )
    return flat_and_spec[0]
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def flatten_with_path(structure):
    """Flatten `structure` into a list of `(path, leaf)` pairs."""
    paths, leaves, _treespec = optree.tree_flatten_with_path(
        structure, none_is_leaf=True, namespace="keras"
    )
    return [(path, leaf) for path, leaf in zip(paths, leaves)]
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def map_structure(func, *structures):
    """Apply `func` leaf-wise across one or more same-layout structures.

    Raises:
        ValueError: If no structure is given, or if several structures are
            given and their layouts disagree.
    """
    if not structures:
        raise ValueError("Must provide at least one structure")

    # Add check for same structures, otherwise optree just maps to shallowest.
    def checked_func(*leaves):
        for leaf in leaves:
            if not optree.tree_is_leaf(
                leaf, none_is_leaf=True, namespace="keras"
            ):
                raise ValueError(
                    "Structures don't have the same nested structure."
                )
        return func(*leaves)

    if len(structures) > 1:
        mapped = checked_func
    else:
        mapped = func

    return optree.tree_map(
        mapped, *structures, none_is_leaf=True, namespace="keras"
    )
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def map_structure_up_to(shallow_structure, func, *structures):
    """Map `func` over `structures`, descending only as deep as
    `shallow_structure`.

    Args:
        shallow_structure: Structure whose layout bounds the traversal depth.
        func: Callable applied at each leaf position of `shallow_structure`;
            receives the corresponding sub-structures of `structures` (not
            the shallow value itself).
        *structures: Structures at least as deep as `shallow_structure`.

    Raises:
        ValueError: If no structure is given, or `shallow_structure` is
            deeper than one of `structures`.
    """
    if not structures:
        raise ValueError("Must provide at least one structure")

    # Add check that `shallow_structure` really is the shallowest.
    # Also only call `func` on `structures` and not `shallow_structure`.
    def func_with_check_without_shallow_structure(shallow, *args):
        # Use the same leaf convention as the rest of this module
        # (`none_is_leaf=True`, "keras" namespace). Without it, a `None`
        # placeholder in `shallow_structure` — explicitly supported, e.g.
        # `shallow_structure=[None, None]` — is treated by optree as an
        # internal node and incorrectly rejected here.
        if not optree.tree_is_leaf(
            shallow, none_is_leaf=True, namespace="keras"
        ):
            raise ValueError("Structures don't have the same nested structure.")
        return func(*args)

    return optree.tree_map(
        func_with_check_without_shallow_structure,
        shallow_structure,
        *structures,
        none_is_leaf=True,
        namespace="keras",
    )
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
def assert_same_structure(a, b):
    """Raise ValueError unless `a` and `b` have identical nesting."""

    def check(a_leaf, b_leaf):
        for leaf in (a_leaf, b_leaf):
            if not optree.tree_is_leaf(
                leaf, none_is_leaf=True, namespace="keras"
            ):
                raise ValueError(
                    "Structures don't have the same nested structure."
                )
        return None

    # Mapping over both structures walks them in lockstep; a mismatch either
    # surfaces in optree itself or via the leaf check above.
    optree.tree_map(check, a, b, none_is_leaf=True, namespace="keras")
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def assert_same_paths(a, b):
    """Raise ValueError unless `a` and `b` expose the same set of leaf paths."""
    paths_a = set(optree.tree_paths(a, none_is_leaf=True, namespace="keras"))
    paths_b = set(optree.tree_paths(b, none_is_leaf=True, namespace="keras"))

    if paths_a == paths_b:
        return

    # Report both directions of the difference in one message.
    parts = ["`a` and `b` don't have the same paths."]
    only_in_a = paths_a.difference(paths_b)
    if only_in_a:
        parts.append(f"\nPaths in `a` missing in `b`:\n{only_in_a}")
    only_in_b = paths_b.difference(paths_a)
    if only_in_b:
        parts.append(f"\nPaths in `b` missing in `a`:\n{only_in_b}")
    raise ValueError("".join(parts))
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def pack_sequence_as(structure, flat_sequence):
    """Pack the leaves of `flat_sequence` into the layout of `structure`."""
    # Only the treespec of `structure` matters; its leaves are discarded.
    treespec = optree.tree_flatten(
        structure, none_is_leaf=True, namespace="keras"
    )[1]
    return optree.tree_unflatten(treespec, flat_sequence)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def lists_to_tuples(structure):
    """Return a copy of `structure` with every `list` converted to a `tuple`."""

    def to_tuple(node):
        if not isinstance(node, list):
            return None
        return tuple(node)

    # Bottom-up so inner lists are converted before their parents are rebuilt.
    return traverse(to_tuple, structure, top_down=False)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def map_shape_structure(func, structure):
    """Apply `func` to every shape tuple (and ordinary leaf) in `structure`.

    A "shape" is a list or tuple whose elements are all ints or None; shapes
    are treated as leaves and handed to `func` whole.
    """

    def looks_like_shape(candidate):
        if not isinstance(candidate, (list, tuple)):
            return False
        return all(isinstance(e, (int, type(None))) for e in candidate)

    return optree.tree_map(
        func,
        structure,
        is_leaf=looks_like_shape,
        none_is_leaf=True,
        namespace="keras",
    )
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings

from keras.src.api_export import keras_export
from keras.src.utils.module_utils import dmtree
from keras.src.utils.module_utils import optree

# Select the tree backend: prefer `optree` when installed, fall back to
# `dm-tree`, and fail fast with an actionable message when neither exists.
if optree.available:
    from keras.src.tree import optree_impl as tree_impl
elif dmtree.available:
    from keras.src.tree import dmtree_impl as tree_impl
else:
    raise ImportError(
        "To use Keras, you need to have `optree` installed. "
        "Install it via `pip install optree`"
    )
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def register_tree_node_class(cls):
    """Register `cls` as a tree node type with the active tree backend."""
    return tree_impl.register_tree_node_class(cls)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@keras_export("keras.tree.MAP_TO_NONE")
|
| 23 |
+
class MAP_TO_NONE:
|
| 24 |
+
"""Special value for use with `traverse()`."""
|
| 25 |
+
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@keras_export("keras.tree.is_nested")
|
| 30 |
+
def is_nested(structure):
|
| 31 |
+
"""Checks if a given structure is nested.
|
| 32 |
+
|
| 33 |
+
Examples:
|
| 34 |
+
|
| 35 |
+
>>> keras.tree.is_nested(42)
|
| 36 |
+
False
|
| 37 |
+
>>> keras.tree.is_nested({"foo": 42})
|
| 38 |
+
True
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
structure: A structure to check.
|
| 42 |
+
|
| 43 |
+
Returns:
|
| 44 |
+
`True` if a given structure is nested, i.e. is a sequence, a mapping,
|
| 45 |
+
or a namedtuple, and `False` otherwise.
|
| 46 |
+
"""
|
| 47 |
+
return tree_impl.is_nested(structure)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
@keras_export("keras.tree.traverse")
|
| 51 |
+
def traverse(func, structure, top_down=True):
|
| 52 |
+
"""Traverses the given nested structure, applying the given function.
|
| 53 |
+
|
| 54 |
+
The traversal is depth-first. If `top_down` is True (default), parents
|
| 55 |
+
are returned before their children (giving the option to avoid traversing
|
| 56 |
+
into a sub-tree).
|
| 57 |
+
|
| 58 |
+
Examples:
|
| 59 |
+
|
| 60 |
+
>>> v = []
|
| 61 |
+
>>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True)
|
| 62 |
+
[(1, 2), [3], {'a': 4}]
|
| 63 |
+
>>> v
|
| 64 |
+
[[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]
|
| 65 |
+
|
| 66 |
+
>>> v = []
|
| 67 |
+
>>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False)
|
| 68 |
+
[(1, 2), [3], {'a': 4}]
|
| 69 |
+
>>> v
|
| 70 |
+
[1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
func: The function to be applied to each sub-nest of the structure.
|
| 74 |
+
|
| 75 |
+
When traversing top-down:
|
| 76 |
+
If `func(subtree) is None` the traversal continues into the
|
| 77 |
+
sub-tree.
|
| 78 |
+
If `func(subtree) is not None` the traversal does not continue
|
| 79 |
+
into the sub-tree. The sub-tree will be replaced by `func(subtree)`
|
| 80 |
+
in the returned structure (to replace the sub-tree with `None`, use
|
| 81 |
+
the special value `MAP_TO_NONE`).
|
| 82 |
+
|
| 83 |
+
When traversing bottom-up:
|
| 84 |
+
If `func(subtree) is None` the traversed sub-tree is returned
|
| 85 |
+
unaltered.
|
| 86 |
+
If `func(subtree) is not None` the sub-tree will be replaced by
|
| 87 |
+
`func(subtree)` in the returned structure (to replace the sub-tree
|
| 88 |
+
with None, use the special value `MAP_TO_NONE`).
|
| 89 |
+
|
| 90 |
+
structure: The structure to traverse.
|
| 91 |
+
top_down: If True, parent structures will be visited before their
|
| 92 |
+
children.
|
| 93 |
+
|
| 94 |
+
Returns:
|
| 95 |
+
The structured output from the traversal.
|
| 96 |
+
|
| 97 |
+
Raises:
|
| 98 |
+
TypeError: If `func` is not callable.
|
| 99 |
+
"""
|
| 100 |
+
return tree_impl.traverse(func, structure, top_down=top_down)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@keras_export("keras.tree.flatten")
|
| 104 |
+
def flatten(structure):
|
| 105 |
+
"""Flattens a possibly nested structure into a list.
|
| 106 |
+
|
| 107 |
+
In the case of dict instances, the sequence consists of the values,
|
| 108 |
+
sorted by key to ensure deterministic behavior. However, instances of
|
| 109 |
+
`collections.OrderedDict` are handled differently: their sequence order is
|
| 110 |
+
used instead of the sorted keys. The same convention is followed in
|
| 111 |
+
`pack_sequence_as`. This correctly unflattens dicts and `OrderedDict` after
|
| 112 |
+
they have been flattened, or vice-versa.
|
| 113 |
+
|
| 114 |
+
Dictionaries with non-sortable keys are not supported.
|
| 115 |
+
|
| 116 |
+
Examples:
|
| 117 |
+
|
| 118 |
+
>>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]])
|
| 119 |
+
[1, 2, 3, 4, 5, 6]
|
| 120 |
+
>>> keras.tree.flatten(None)
|
| 121 |
+
[None]
|
| 122 |
+
>>> keras.tree.flatten(1)
|
| 123 |
+
[1]
|
| 124 |
+
>>> keras.tree.flatten({100: 'world!', 6: 'Hello'})
|
| 125 |
+
['Hello', 'world!']
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
structure: An arbitrarily nested structure.
|
| 129 |
+
|
| 130 |
+
Returns:
|
| 131 |
+
A list, the flattened version of the input `structure`.
|
| 132 |
+
"""
|
| 133 |
+
return tree_impl.flatten(structure)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@keras_export("keras.tree.flatten_with_path")
|
| 137 |
+
def flatten_with_path(structure):
|
| 138 |
+
"""Flattens a possibly nested structure into a list.
|
| 139 |
+
|
| 140 |
+
This is a variant of flattens() which produces a
|
| 141 |
+
list of pairs: `(path, item)`. A path is a tuple of indices and/or keys
|
| 142 |
+
which uniquely identifies the position of the corresponding item.
|
| 143 |
+
|
| 144 |
+
Dictionaries with non-sortable keys are not supported.
|
| 145 |
+
|
| 146 |
+
Examples:
|
| 147 |
+
|
| 148 |
+
>>> keras.flatten_with_path([{"foo": 42}])
|
| 149 |
+
[((0, 'foo'), 42)]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
Args:
|
| 153 |
+
structure: An arbitrarily nested structure.
|
| 154 |
+
|
| 155 |
+
Returns:
|
| 156 |
+
A list of `(path, item)` pairs corresponding to the flattened
|
| 157 |
+
version of the input `structure`.
|
| 158 |
+
"""
|
| 159 |
+
return tree_impl.flatten_with_path(structure)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
@keras_export("keras.tree.map_structure")
|
| 163 |
+
def map_structure(func, *structures):
|
| 164 |
+
"""Maps `func` through given structures.
|
| 165 |
+
|
| 166 |
+
Examples:
|
| 167 |
+
|
| 168 |
+
>>> structure = [[1], [2], [3]]
|
| 169 |
+
>>> keras.tree.map_structure(lambda v: v**2, structure)
|
| 170 |
+
[[1], [4], [9]]
|
| 171 |
+
>>> keras.tree.map_structure(lambda x, y: x * y, structure, structure)
|
| 172 |
+
[[1], [4], [9]]
|
| 173 |
+
|
| 174 |
+
>>> Foo = collections.namedtuple('Foo', ['a', 'b'])
|
| 175 |
+
>>> structure = Foo(a=1, b=2)
|
| 176 |
+
>>> keras.tree.map_structure(lambda v: v * 2, structure)
|
| 177 |
+
Foo(a=2, b=4)
|
| 178 |
+
|
| 179 |
+
Args:
|
| 180 |
+
func: A callable that accepts as many arguments as there are structures.
|
| 181 |
+
*structures: Arbitrarily nested structures of the same layout.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
A new structure with the same layout as the given ones.
|
| 185 |
+
|
| 186 |
+
Raises:
|
| 187 |
+
TypeError: If `structures` is empty or `func` is not callable.
|
| 188 |
+
ValueError: If there is more than one items in `structures` and some of
|
| 189 |
+
the nested structures don't match according to the rules of
|
| 190 |
+
`assert_same_structure`.
|
| 191 |
+
"""
|
| 192 |
+
return tree_impl.map_structure(func, *structures)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@keras_export("keras.tree.map_structure_up_to")
|
| 196 |
+
def map_structure_up_to(shallow_structure, func, *structures):
|
| 197 |
+
"""Maps `func` through given structures up to `shallow_structure`.
|
| 198 |
+
|
| 199 |
+
This is a variant of `map_structure` which only maps the given structures
|
| 200 |
+
up to `shallow_structure`. All further nested components are retained as-is.
|
| 201 |
+
|
| 202 |
+
Examples:
|
| 203 |
+
|
| 204 |
+
>>> shallow_structure = [None, None]
|
| 205 |
+
>>> structure = [[1, 1], [2, 2]]
|
| 206 |
+
>>> keras.tree.map_structure_up_to(shallow_structure, len, structure)
|
| 207 |
+
[2, 2]
|
| 208 |
+
|
| 209 |
+
>>> shallow_structure = [None, [None, None]]
|
| 210 |
+
>>> keras.tree.map_structure_up_to(shallow_structure, str, structure)
|
| 211 |
+
['[1, 1]', ['2', '2']]
|
| 212 |
+
|
| 213 |
+
Args:
|
| 214 |
+
shallow_structure: A structure with layout common to all `structures`.
|
| 215 |
+
func: A callable that accepts as many arguments as there are structures.
|
| 216 |
+
*structures: Arbitrarily nested structures of the same layout.
|
| 217 |
+
|
| 218 |
+
Returns:
|
| 219 |
+
A new structure with the same layout as `shallow_structure`.
|
| 220 |
+
|
| 221 |
+
Raises:
|
| 222 |
+
TypeError: If `structures` is empty or `func` is not callable.
|
| 223 |
+
ValueError: If one of the items in `structures` doesn't match the
|
| 224 |
+
nested structure of `shallow_structure` according to the rules of
|
| 225 |
+
`assert_same_structure`. Items in `structures` are allowed to be
|
| 226 |
+
nested deeper than `shallow_structure`, but they cannot be
|
| 227 |
+
shallower.
|
| 228 |
+
"""
|
| 229 |
+
return tree_impl.map_structure_up_to(shallow_structure, func, *structures)
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
@keras_export("keras.tree.assert_same_structure")
|
| 233 |
+
def assert_same_structure(a, b, check_types=None):
|
| 234 |
+
"""Asserts that two structures are nested in the same way.
|
| 235 |
+
|
| 236 |
+
This function verifies that the nested structures match. The leafs can be of
|
| 237 |
+
any type. At each level, the structures must be of the same type and have
|
| 238 |
+
the same number of elements. Instances of `dict`, `OrderedDict` and
|
| 239 |
+
`defaultdict` are all considered the same as long as they have the same set
|
| 240 |
+
of keys. However, `list`, `tuple`, `namedtuple` and `deque` are not the same
|
| 241 |
+
structures. Two namedtuples with identical fields and even identical names
|
| 242 |
+
are not the same structures.
|
| 243 |
+
|
| 244 |
+
Examples:
|
| 245 |
+
|
| 246 |
+
>>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)])
|
| 247 |
+
|
| 248 |
+
>>> Foo = collections.namedtuple('Foo', ['a', 'b'])
|
| 249 |
+
>>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b'])
|
| 250 |
+
>>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3))
|
| 251 |
+
>>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3))
|
| 252 |
+
Traceback (most recent call last):
|
| 253 |
+
...
|
| 254 |
+
ValueError: The two structures don't have the same nested structure.
|
| 255 |
+
...
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
a: an arbitrarily nested structure.
|
| 259 |
+
b: an arbitrarily nested structure.
|
| 260 |
+
check_types: Deprecated. The behavior of this flag was inconsistent, it
|
| 261 |
+
no longer has any effect. For a looser check, use
|
| 262 |
+
`assert_same_paths` instead, which considers `list`, `tuple`,
|
| 263 |
+
`namedtuple` and `deque` as matching structures.
|
| 264 |
+
|
| 265 |
+
Raises:
|
| 266 |
+
ValueError: If the two structures `a` and `b` don't match.
|
| 267 |
+
"""
|
| 268 |
+
if check_types is not None:
|
| 269 |
+
if check_types:
|
| 270 |
+
warnings.warn(
|
| 271 |
+
"The `check_types` argument is deprecated and no longer has "
|
| 272 |
+
"any effect, please remove.",
|
| 273 |
+
DeprecationWarning,
|
| 274 |
+
stacklevel=2,
|
| 275 |
+
)
|
| 276 |
+
else:
|
| 277 |
+
warnings.warn(
|
| 278 |
+
"The `check_types` argument is deprecated and no longer has "
|
| 279 |
+
"any effect. For a looser check, use "
|
| 280 |
+
"`keras.tree.assert_same_paths()`, which considers `list`, "
|
| 281 |
+
"`tuple`, `namedtuple` and `deque` as matching",
|
| 282 |
+
DeprecationWarning,
|
| 283 |
+
stacklevel=2,
|
| 284 |
+
)
|
| 285 |
+
return tree_impl.assert_same_structure(a, b)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@keras_export("keras.tree.assert_same_paths")
|
| 289 |
+
def assert_same_paths(a, b):
|
| 290 |
+
"""Asserts that two structures have identical paths in their tree structure.
|
| 291 |
+
|
| 292 |
+
This function verifies that two nested structures have the same paths.
|
| 293 |
+
Unlike `assert_same_structure`, this function only checks the paths
|
| 294 |
+
and ignores the collection types.
|
| 295 |
+
For Sequences, to path is the index: 0, 1, 2, etc. For Mappings, the path is
|
| 296 |
+
the key, for instance "a", "b", "c". Note that namedtuples also use indices
|
| 297 |
+
and not field names for the path.
|
| 298 |
+
|
| 299 |
+
Examples:
|
| 300 |
+
>>> keras.tree.assert_same_paths([0, 1], (2, 3))
|
| 301 |
+
>>> Point1 = collections.namedtuple('Point1', ['x', 'y'])
|
| 302 |
+
>>> Point2 = collections.namedtuple('Point2', ['x', 'y'])
|
| 303 |
+
>>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3))
|
| 304 |
+
|
| 305 |
+
Args:
|
| 306 |
+
a: an arbitrarily nested structure.
|
| 307 |
+
b: an arbitrarily nested structure.
|
| 308 |
+
|
| 309 |
+
Raises:
|
| 310 |
+
ValueError: If the paths in structure `a` don't match the paths in
|
| 311 |
+
structure `b`. The error message will include the specific paths
|
| 312 |
+
that differ.
|
| 313 |
+
"""
|
| 314 |
+
return tree_impl.assert_same_paths(a, b)
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@keras_export("keras.tree.pack_sequence_as")
|
| 318 |
+
def pack_sequence_as(structure, flat_sequence):
|
| 319 |
+
"""Returns a given flattened sequence packed into a given structure.
|
| 320 |
+
|
| 321 |
+
If `structure` is an atom, `flat_sequence` must be a single-item list; in
|
| 322 |
+
this case the return value is `flat_sequence[0]`.
|
| 323 |
+
|
| 324 |
+
If `structure` is or contains a dict instance, the keys will be sorted to
|
| 325 |
+
pack the flat sequence in deterministic order. However, instances of
|
| 326 |
+
`collections.OrderedDict` are handled differently: their sequence order is
|
| 327 |
+
used instead of the sorted keys. The same convention is followed in
|
| 328 |
+
`flatten`. This correctly repacks dicts and `OrderedDicts` after they have
|
| 329 |
+
been flattened, or vice-versa.
|
| 330 |
+
|
| 331 |
+
Dictionaries with non-sortable keys are not supported.
|
| 332 |
+
|
| 333 |
+
Examples:
|
| 334 |
+
|
| 335 |
+
>>> structure = {"key3": "", "key1": "", "key2": ""}
|
| 336 |
+
>>> flat_sequence = ["value1", "value2", "value3"]
|
| 337 |
+
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
| 338 |
+
{"key3": "value3", "key1": "value1", "key2": "value2"}
|
| 339 |
+
|
| 340 |
+
>>> structure = (("a", "b"), ("c", "d", "e"), "f")
|
| 341 |
+
>>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
| 342 |
+
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
| 343 |
+
((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
|
| 344 |
+
|
| 345 |
+
>>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")},
|
| 346 |
+
... "key1": {"e": "val1", "d": "val2"}}
|
| 347 |
+
>>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
|
| 348 |
+
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
| 349 |
+
{'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}
|
| 350 |
+
|
| 351 |
+
>>> structure = ["a"]
|
| 352 |
+
>>> flat_sequence = [np.array([[1, 2], [3, 4]])]
|
| 353 |
+
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
| 354 |
+
[array([[1, 2],
|
| 355 |
+
[3, 4]])]
|
| 356 |
+
|
| 357 |
+
>>> structure = ["a"]
|
| 358 |
+
>>> flat_sequence = [keras.ops.ones([2, 2])]
|
| 359 |
+
>>> keras.tree.pack_sequence_as(structure, flat_sequence)
|
| 360 |
+
[array([[1., 1.],
|
| 361 |
+
[1., 1.]]]
|
| 362 |
+
|
| 363 |
+
Args:
|
| 364 |
+
structure: Arbitrarily nested structure.
|
| 365 |
+
flat_sequence: Flat sequence to pack.
|
| 366 |
+
|
| 367 |
+
Returns:
|
| 368 |
+
`flat_sequence` converted to have the same recursive structure as
|
| 369 |
+
`structure`.
|
| 370 |
+
|
| 371 |
+
Raises:
|
| 372 |
+
TypeError: If `flat_sequence` is not iterable.
|
| 373 |
+
ValueError: If `flat_sequence` cannot be repacked as `structure`; for
|
| 374 |
+
instance, if `flat_sequence` has too few or too many elements.
|
| 375 |
+
"""
|
| 376 |
+
return tree_impl.pack_sequence_as(structure, flat_sequence)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
@keras_export("keras.tree.lists_to_tuples")
|
| 380 |
+
def lists_to_tuples(structure):
|
| 381 |
+
"""Returns the structure with list instances changed to tuples.
|
| 382 |
+
|
| 383 |
+
Args:
|
| 384 |
+
structure: Arbitrarily nested structure.
|
| 385 |
+
|
| 386 |
+
Returns:
|
| 387 |
+
The same structure but with tuples instead of lists.
|
| 388 |
+
"""
|
| 389 |
+
return tree_impl.lists_to_tuples(structure)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
@keras_export("keras.tree.map_shape_structure")
|
| 393 |
+
def map_shape_structure(func, structure):
|
| 394 |
+
"""Variant of keras.tree.map_structure that operates on shape tuples.
|
| 395 |
+
|
| 396 |
+
Tuples containing ints and Nones are considered shapes and passed to `func`.
|
| 397 |
+
|
| 398 |
+
Args:
|
| 399 |
+
structure: Arbitrarily nested structure.
|
| 400 |
+
|
| 401 |
+
Returns:
|
| 402 |
+
The same structure with `func` applied.
|
| 403 |
+
"""
|
| 404 |
+
return tree_impl.map_shape_structure(func, structure)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory
|
| 2 |
+
from keras.src.utils.dataset_utils import split_dataset
|
| 3 |
+
from keras.src.utils.file_utils import get_file
|
| 4 |
+
from keras.src.utils.image_dataset_utils import image_dataset_from_directory
|
| 5 |
+
from keras.src.utils.image_utils import array_to_img
|
| 6 |
+
from keras.src.utils.image_utils import img_to_array
|
| 7 |
+
from keras.src.utils.image_utils import load_img
|
| 8 |
+
from keras.src.utils.image_utils import save_img
|
| 9 |
+
from keras.src.utils.io_utils import disable_interactive_logging
|
| 10 |
+
from keras.src.utils.io_utils import enable_interactive_logging
|
| 11 |
+
from keras.src.utils.io_utils import is_interactive_logging_enabled
|
| 12 |
+
from keras.src.utils.model_visualization import model_to_dot
|
| 13 |
+
from keras.src.utils.model_visualization import plot_model
|
| 14 |
+
from keras.src.utils.numerical_utils import normalize
|
| 15 |
+
from keras.src.utils.numerical_utils import to_categorical
|
| 16 |
+
from keras.src.utils.progbar import Progbar
|
| 17 |
+
from keras.src.utils.python_utils import default
|
| 18 |
+
from keras.src.utils.python_utils import is_default
|
| 19 |
+
from keras.src.utils.python_utils import removeprefix
|
| 20 |
+
from keras.src.utils.python_utils import removesuffix
|
| 21 |
+
from keras.src.utils.rng_utils import set_random_seed
|
| 22 |
+
from keras.src.utils.sequence_utils import pad_sequences
|
| 23 |
+
from keras.src.utils.text_dataset_utils import text_dataset_from_directory
|
| 24 |
+
from keras.src.utils.timeseries_dataset_utils import (
|
| 25 |
+
timeseries_dataset_from_array,
|
| 26 |
+
)
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc
ADDED
|
Binary file (2.61 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc
ADDED
|
Binary file (10.3 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc
ADDED
|
Binary file (4.91 kB). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc
ADDED
|
Binary file (979 Bytes). View file
|
|
|
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc
ADDED
|
Binary file (6 kB). View file
|
|
|