AIDUDE0541 commited on
Commit
26abf77
·
verified ·
1 Parent(s): 5d27ebd

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py +279 -0
  2. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py +1173 -0
  3. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py +808 -0
  4. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__init__.py +5 -0
  5. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/__init__.cpython-310.pyc +0 -0
  6. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_case.cpython-310.pyc +0 -0
  7. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_utils.cpython-310.pyc +0 -0
  8. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py +796 -0
  9. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_utils.py +163 -0
  10. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__init__.py +0 -0
  11. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/__init__.cpython-310.pyc +0 -0
  12. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/compile_utils.cpython-310.pyc +0 -0
  13. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc +0 -0
  14. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc +0 -0
  15. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py +820 -0
  16. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py +154 -0
  17. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc +0 -0
  18. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc +0 -0
  19. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc +0 -0
  20. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc +0 -0
  21. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc +0 -0
  22. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc +0 -0
  23. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc +0 -0
  24. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc +0 -0
  25. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc +0 -0
  26. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py +372 -0
  27. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py +520 -0
  28. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py +97 -0
  29. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py +329 -0
  30. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py +87 -0
  31. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py +692 -0
  32. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/tf_dataset_adapter.py +141 -0
  33. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/torch_data_loader_adapter.py +83 -0
  34. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py +161 -0
  35. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py +1147 -0
  36. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py +12 -0
  37. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc +0 -0
  38. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc +0 -0
  39. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc +0 -0
  40. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc +0 -0
  41. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py +394 -0
  42. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py +187 -0
  43. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py +404 -0
  44. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py +26 -0
  45. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  46. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc +0 -0
  47. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc +0 -0
  48. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc +0 -0
  49. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc +0 -0
  50. SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc +0 -0
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+
4
+ from absl import logging
5
+
6
+ from keras.src.api_export import keras_export
7
+ from keras.src.legacy.saving import legacy_h5_format
8
+ from keras.src.saving import saving_lib
9
+ from keras.src.utils import file_utils
10
+ from keras.src.utils import io_utils
11
+
12
+ try:
13
+ import h5py
14
+ except ImportError:
15
+ h5py = None
16
+
17
+
18
+ @keras_export(["keras.saving.save_model", "keras.models.save_model"])
19
+ def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
20
+ """Saves a model as a `.keras` file.
21
+
22
+ Args:
23
+ model: Keras model instance to be saved.
24
+ filepath: `str` or `pathlib.Path` object. Path where to save the model.
25
+ overwrite: Whether we should overwrite any existing model at the target
26
+ location, or instead ask the user via an interactive prompt.
27
+ zipped: Whether to save the model as a zipped `.keras`
28
+ archive (default when saving locally), or as an unzipped directory
29
+ (default when saving on the Hugging Face Hub).
30
+
31
+ Example:
32
+
33
+ ```python
34
+ model = keras.Sequential(
35
+ [
36
+ keras.layers.Dense(5, input_shape=(3,)),
37
+ keras.layers.Softmax(),
38
+ ],
39
+ )
40
+ model.save("model.keras")
41
+ loaded_model = keras.saving.load_model("model.keras")
42
+ x = keras.random.uniform((10, 3))
43
+ assert np.allclose(model.predict(x), loaded_model.predict(x))
44
+ ```
45
+
46
+ Note that `model.save()` is an alias for `keras.saving.save_model()`.
47
+
48
+ The saved `.keras` file is a `zip` archive that contains:
49
+
50
+ - The model's configuration (architecture)
51
+ - The model's weights
52
+ - The model's optimizer's state (if any)
53
+
54
+ Thus models can be reinstantiated in the exact same state.
55
+ """
56
+ include_optimizer = kwargs.pop("include_optimizer", True)
57
+ save_format = kwargs.pop("save_format", False)
58
+ if save_format:
59
+ if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith(
60
+ ".keras"
61
+ ):
62
+ logging.warning(
63
+ "The `save_format` argument is deprecated in Keras 3. "
64
+ "We recommend removing this argument as it can be inferred "
65
+ "from the file path. "
66
+ f"Received: save_format={save_format}"
67
+ )
68
+ else:
69
+ raise ValueError(
70
+ "The `save_format` argument is deprecated in Keras 3. "
71
+ "Please remove this argument and pass a file path with "
72
+ "either `.keras` or `.h5` extension."
73
+ f"Received: save_format={save_format}"
74
+ )
75
+ if kwargs:
76
+ raise ValueError(
77
+ "The following argument(s) are not supported: "
78
+ f"{list(kwargs.keys())}"
79
+ )
80
+
81
+ # Deprecation warnings
82
+ if str(filepath).endswith((".h5", ".hdf5")):
83
+ logging.warning(
84
+ "You are saving your model as an HDF5 file via "
85
+ "`model.save()` or `keras.saving.save_model(model)`. "
86
+ "This file format is considered legacy. "
87
+ "We recommend using instead the native Keras format, "
88
+ "e.g. `model.save('my_model.keras')` or "
89
+ "`keras.saving.save_model(model, 'my_model.keras')`. "
90
+ )
91
+
92
+ is_hf = str(filepath).startswith("hf://")
93
+ if zipped is None:
94
+ zipped = not is_hf # default behavior depends on destination
95
+
96
+ # If file exists and should not be overwritten.
97
+ try:
98
+ exists = (not is_hf) and os.path.exists(filepath)
99
+ except TypeError:
100
+ exists = False
101
+ if exists and not overwrite:
102
+ proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
103
+ if not proceed:
104
+ return
105
+
106
+ if zipped and str(filepath).endswith(".keras"):
107
+ return saving_lib.save_model(model, filepath)
108
+ if not zipped:
109
+ return saving_lib.save_model(model, filepath, zipped=False)
110
+ if str(filepath).endswith((".h5", ".hdf5")):
111
+ return legacy_h5_format.save_model_to_hdf5(
112
+ model, filepath, overwrite, include_optimizer
113
+ )
114
+ raise ValueError(
115
+ "Invalid filepath extension for saving. "
116
+ "Please add either a `.keras` extension for the native Keras "
117
+ f"format (recommended) or a `.h5` extension. "
118
+ "Use `model.export(filepath)` if you want to export a SavedModel "
119
+ "for use with TFLite/TFServing/etc. "
120
+ f"Received: filepath={filepath}."
121
+ )
122
+
123
+
124
+ @keras_export(["keras.saving.load_model", "keras.models.load_model"])
125
+ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
126
+ """Loads a model saved via `model.save()`.
127
+
128
+ Args:
129
+ filepath: `str` or `pathlib.Path` object, path to the saved model file.
130
+ custom_objects: Optional dictionary mapping names
131
+ (strings) to custom classes or functions to be
132
+ considered during deserialization.
133
+ compile: Boolean, whether to compile the model after loading.
134
+ safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
135
+ When `safe_mode=False`, loading an object has the potential to
136
+ trigger arbitrary code execution. This argument is only
137
+ applicable to the Keras v3 model format. Defaults to `True`.
138
+
139
+ Returns:
140
+ A Keras model instance. If the original model was compiled,
141
+ and the argument `compile=True` is set, then the returned model
142
+ will be compiled. Otherwise, the model will be left uncompiled.
143
+
144
+ Example:
145
+
146
+ ```python
147
+ model = keras.Sequential([
148
+ keras.layers.Dense(5, input_shape=(3,)),
149
+ keras.layers.Softmax()])
150
+ model.save("model.keras")
151
+ loaded_model = keras.saving.load_model("model.keras")
152
+ x = np.random.random((10, 3))
153
+ assert np.allclose(model.predict(x), loaded_model.predict(x))
154
+ ```
155
+
156
+ Note that the model variables may have different name values
157
+ (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
158
+ It is recommended that you use layer attributes to
159
+ access specific variables, e.g. `model.get_layer("dense_1").kernel`.
160
+ """
161
+ is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
162
+ filepath
163
+ )
164
+ is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
165
+ file_utils.join(filepath, "config.json")
166
+ )
167
+ is_hf = str(filepath).startswith("hf://")
168
+
169
+ # Support for remote zip files
170
+ if (
171
+ file_utils.is_remote_path(filepath)
172
+ and not file_utils.isdir(filepath)
173
+ and not is_keras_zip
174
+ and not is_hf
175
+ ):
176
+ local_path = file_utils.join(
177
+ saving_lib.get_temp_dir(), os.path.basename(filepath)
178
+ )
179
+
180
+ # Copy from remote to temporary local directory
181
+ file_utils.copy(filepath, local_path)
182
+
183
+ # Switch filepath to local zipfile for loading model
184
+ if zipfile.is_zipfile(local_path):
185
+ filepath = local_path
186
+ is_keras_zip = True
187
+
188
+ if is_keras_zip or is_keras_dir or is_hf:
189
+ return saving_lib.load_model(
190
+ filepath,
191
+ custom_objects=custom_objects,
192
+ compile=compile,
193
+ safe_mode=safe_mode,
194
+ )
195
+ if str(filepath).endswith((".h5", ".hdf5")):
196
+ return legacy_h5_format.load_model_from_hdf5(
197
+ filepath, custom_objects=custom_objects, compile=compile
198
+ )
199
+ elif str(filepath).endswith(".keras"):
200
+ raise ValueError(
201
+ f"File not found: filepath={filepath}. "
202
+ "Please ensure the file is an accessible `.keras` "
203
+ "zip file."
204
+ )
205
+ else:
206
+ raise ValueError(
207
+ f"File format not supported: filepath={filepath}. "
208
+ "Keras 3 only supports V3 `.keras` files and "
209
+ "legacy H5 format files (`.h5` extension). "
210
+ "Note that the legacy SavedModel format is not "
211
+ "supported by `load_model()` in Keras 3. In "
212
+ "order to reload a TensorFlow SavedModel as an "
213
+ "inference-only layer in Keras 3, use "
214
+ "`keras.layers.TFSMLayer("
215
+ f"{filepath}, call_endpoint='serving_default')` "
216
+ "(note that your `call_endpoint` "
217
+ "might have a different name)."
218
+ )
219
+
220
+
221
+ @keras_export("keras.saving.save_weights")
222
+ def save_weights(model, filepath, overwrite=True, **kwargs):
223
+ if not str(filepath).endswith(".weights.h5"):
224
+ raise ValueError(
225
+ "The filename must end in `.weights.h5`. "
226
+ f"Received: filepath={filepath}"
227
+ )
228
+ try:
229
+ exists = os.path.exists(filepath)
230
+ except TypeError:
231
+ exists = False
232
+ if exists and not overwrite:
233
+ proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
234
+ if not proceed:
235
+ return
236
+ saving_lib.save_weights_only(model, filepath, **kwargs)
237
+
238
+
239
+ @keras_export("keras.saving.load_weights")
240
+ def load_weights(model, filepath, skip_mismatch=False, **kwargs):
241
+ if str(filepath).endswith(".keras"):
242
+ if kwargs:
243
+ raise ValueError(f"Invalid keyword arguments: {kwargs}")
244
+ saving_lib.load_weights_only(
245
+ model, filepath, skip_mismatch=skip_mismatch
246
+ )
247
+ elif str(filepath).endswith(".weights.h5"):
248
+ objects_to_skip = kwargs.pop("objects_to_skip", None)
249
+ if kwargs:
250
+ raise ValueError(f"Invalid keyword arguments: {kwargs}")
251
+ saving_lib.load_weights_only(
252
+ model,
253
+ filepath,
254
+ skip_mismatch=skip_mismatch,
255
+ objects_to_skip=objects_to_skip,
256
+ )
257
+ elif str(filepath).endswith(".h5") or str(filepath).endswith(".hdf5"):
258
+ by_name = kwargs.pop("by_name", False)
259
+ if kwargs:
260
+ raise ValueError(f"Invalid keyword arguments: {kwargs}")
261
+ if not h5py:
262
+ raise ImportError(
263
+ "Loading a H5 file requires `h5py` to be installed."
264
+ )
265
+ with h5py.File(filepath, "r") as f:
266
+ if "layer_names" not in f.attrs and "model_weights" in f:
267
+ f = f["model_weights"]
268
+ if by_name:
269
+ legacy_h5_format.load_weights_from_hdf5_group_by_name(
270
+ f, model, skip_mismatch
271
+ )
272
+ else:
273
+ legacy_h5_format.load_weights_from_hdf5_group(f, model)
274
+ else:
275
+ raise ValueError(
276
+ f"File format not supported: filepath={filepath}. "
277
+ "Keras 3 only supports V3 `.keras` and `.weights.h5` "
278
+ "files, or legacy V1/V2 `.h5` files."
279
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py ADDED
@@ -0,0 +1,1173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Python-based idempotent model-saving functionality."""
2
+
3
+ import datetime
4
+ import io
5
+ import json
6
+ import os
7
+ import pathlib
8
+ import shutil
9
+ import tempfile
10
+ import warnings
11
+ import zipfile
12
+
13
+ import ml_dtypes
14
+ import numpy as np
15
+
16
+ from keras.src import backend
17
+ from keras.src.backend.common import global_state
18
+ from keras.src.layers.layer import Layer
19
+ from keras.src.losses.loss import Loss
20
+ from keras.src.metrics.metric import Metric
21
+ from keras.src.optimizers.optimizer import Optimizer
22
+ from keras.src.saving.serialization_lib import ObjectSharingScope
23
+ from keras.src.saving.serialization_lib import deserialize_keras_object
24
+ from keras.src.saving.serialization_lib import serialize_keras_object
25
+ from keras.src.trainers.compile_utils import CompileMetrics
26
+ from keras.src.utils import file_utils
27
+ from keras.src.utils import io_utils
28
+ from keras.src.utils import naming
29
+ from keras.src.utils import plot_model
30
+ from keras.src.utils.model_visualization import check_pydot
31
+ from keras.src.utils.summary_utils import weight_memory_size
32
+ from keras.src.version import __version__ as keras_version
33
+
34
+ try:
35
+ import h5py
36
+ except ImportError:
37
+ h5py = None
38
+ try:
39
+ import psutil
40
+ except ImportError:
41
+ psutil = None
42
+ try:
43
+ import huggingface_hub
44
+ except ImportError:
45
+ huggingface_hub = None
46
+
47
+
48
+ _CONFIG_FILENAME = "config.json"
49
+ _METADATA_FILENAME = "metadata.json"
50
+ _VARS_FNAME = "model.weights" # Will become e.g. "model.weights.h5"
51
+ _VARS_FNAME_H5 = _VARS_FNAME + ".h5"
52
+ _VARS_FNAME_NPZ = _VARS_FNAME + ".npz"
53
+ _ASSETS_DIRNAME = "assets"
54
+ _MEMORY_UPPER_BOUND = 0.5 # 50%
55
+
56
+
57
+ _MODEL_CARD_TEMPLATE = """
58
+ ---
59
+ library_name: keras
60
+ ---
61
+
62
+ This model has been uploaded using the Keras library and can be used with JAX,
63
+ TensorFlow, and PyTorch backends.
64
+
65
+ This model card has been generated automatically and should be completed by the
66
+ model author.
67
+ See [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for
68
+ more information.
69
+
70
+ For more details about the model architecture, check out
71
+ [config.json](./config.json)."""
72
+
73
+
74
+ def save_model(model, filepath, weights_format="h5", zipped=True):
75
+ """Save a zip-archive representing a Keras model to the given file or path.
76
+
77
+ The zip-based archive contains the following structure:
78
+
79
+ - JSON-based configuration file (config.json): Records of model, layer, and
80
+ other saveables' configuration.
81
+ - H5-based saveable state files, found in respective directories, such as
82
+ model/states.npz, model/dense_layer/states.npz, etc.
83
+ - Metadata file.
84
+
85
+ The states of Keras saveables (layers, optimizers, loss, and metrics) are
86
+ automatically saved as long as they can be discovered through the attributes
87
+ returned by `dir(Model)`. Typically, the state includes the variables
88
+ associated with the saveable, but some specially purposed layers may
89
+ contain more such as the vocabularies stored in the hashmaps. The saveables
90
+ define how their states are saved by exposing `save_state()` and
91
+ `load_state()` APIs.
92
+
93
+ For the case of layer states, the variables will be visited as long as
94
+ they are either 1) referenced via layer attributes, or 2) referenced via a
95
+ container (list, tuple, or dict), and the container is referenced via a
96
+ layer attribute.
97
+ """
98
+ if weights_format == "h5" and h5py is None:
99
+ raise ImportError("h5py must be installed in order to save a model.")
100
+
101
+ if not model.built:
102
+ warnings.warn(
103
+ "You are saving a model that has not yet been built. "
104
+ "It might not contain any weights yet. "
105
+ "Consider building the model first by calling it "
106
+ "on some data.",
107
+ stacklevel=2,
108
+ )
109
+
110
+ if isinstance(filepath, io.IOBase):
111
+ _save_model_to_fileobj(model, filepath, weights_format)
112
+ return
113
+
114
+ filepath = str(filepath)
115
+ is_hf = filepath.startswith("hf://")
116
+ if zipped and not filepath.endswith(".keras"):
117
+ raise ValueError(
118
+ "Invalid `filepath` argument: expected a `.keras` extension. "
119
+ f"Received: filepath={filepath}"
120
+ )
121
+ if not zipped and filepath.endswith(".keras"):
122
+ raise ValueError(
123
+ "When using `zipped=False`, the `filepath` argument should not "
124
+ f"end in `.keras`. Received: filepath={filepath}"
125
+ )
126
+ if zipped and is_hf:
127
+ raise ValueError(
128
+ "When saving to the Hugging Face Hub, you should not save the "
129
+ f"model as zipped. Received: filepath={filepath}, zipped={zipped}"
130
+ )
131
+ if is_hf:
132
+ _upload_model_to_hf(model, filepath, weights_format)
133
+ elif not zipped:
134
+ _save_model_to_dir(model, filepath, weights_format)
135
+ else:
136
+ if file_utils.is_remote_path(filepath):
137
+ # Remote path. Zip to local memory byte io and copy to remote
138
+ zip_filepath = io.BytesIO()
139
+ _save_model_to_fileobj(model, zip_filepath, weights_format)
140
+ with file_utils.File(filepath, "wb") as f:
141
+ f.write(zip_filepath.getvalue())
142
+ else:
143
+ with open(filepath, "wb") as f:
144
+ _save_model_to_fileobj(model, f, weights_format)
145
+
146
+
147
+ def _serialize_model_as_json(model):
148
+ with ObjectSharingScope():
149
+ serialized_model_dict = serialize_keras_object(model)
150
+ config_json = json.dumps(serialized_model_dict)
151
+ metadata_json = json.dumps(
152
+ {
153
+ "keras_version": keras_version,
154
+ "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"),
155
+ }
156
+ )
157
+ return config_json, metadata_json
158
+
159
+
160
+ def _save_model_to_dir(model, dirpath, weights_format):
161
+ if not file_utils.exists(dirpath):
162
+ file_utils.makedirs(dirpath)
163
+ config_json, metadata_json = _serialize_model_as_json(model)
164
+ with open(file_utils.join(dirpath, _METADATA_FILENAME), "w") as f:
165
+ f.write(metadata_json)
166
+ with open(file_utils.join(dirpath, _CONFIG_FILENAME), "w") as f:
167
+ f.write(config_json)
168
+ weights_filepath = file_utils.join(dirpath, _VARS_FNAME_H5)
169
+ assert_dirpath = file_utils.join(dirpath, _ASSETS_DIRNAME)
170
+ try:
171
+ if weights_format == "h5":
172
+ weights_store = H5IOStore(weights_filepath, mode="w")
173
+ elif weights_format == "npz":
174
+ weights_store = NpzIOStore(weights_filepath, mode="w")
175
+ else:
176
+ raise ValueError(
177
+ "Unknown `weights_format` argument. "
178
+ "Expected 'h5' or 'npz'. "
179
+ f"Received: weights_format={weights_format}"
180
+ )
181
+ asset_store = DiskIOStore(assert_dirpath, mode="w")
182
+ _save_state(
183
+ model,
184
+ weights_store=weights_store,
185
+ assets_store=asset_store,
186
+ inner_path="",
187
+ visited_saveables=set(),
188
+ )
189
+ finally:
190
+ weights_store.close()
191
+ asset_store.close()
192
+
193
+
194
+ def _save_model_to_fileobj(model, fileobj, weights_format):
195
+ config_json, metadata_json = _serialize_model_as_json(model)
196
+
197
+ with zipfile.ZipFile(fileobj, "w") as zf:
198
+ with zf.open(_METADATA_FILENAME, "w") as f:
199
+ f.write(metadata_json.encode())
200
+ with zf.open(_CONFIG_FILENAME, "w") as f:
201
+ f.write(config_json.encode())
202
+
203
+ weights_file_path = None
204
+ weights_store = None
205
+ asset_store = None
206
+ write_zf = False
207
+ try:
208
+ if weights_format == "h5":
209
+ try:
210
+ if is_memory_sufficient(model):
211
+ # Load the model weights into memory before writing
212
+ # .keras if the system memory is sufficient.
213
+ weights_store = H5IOStore(
214
+ _VARS_FNAME_H5, archive=zf, mode="w"
215
+ )
216
+ else:
217
+ # Try opening the .h5 file, then writing it to `zf` at
218
+ # the end of the function call. This is more memory
219
+ # efficient than writing the weights into memory first.
220
+ working_dir = pathlib.Path(fileobj.name).parent
221
+ weights_file_path = tempfile.NamedTemporaryFile(
222
+ dir=working_dir
223
+ )
224
+ weights_store = H5IOStore(
225
+ weights_file_path.name, mode="w"
226
+ )
227
+ write_zf = True
228
+ except:
229
+ # If we can't use the local disk for any reason, write the
230
+ # weights into memory first, which consumes more memory.
231
+ weights_store = H5IOStore(
232
+ _VARS_FNAME_H5, archive=zf, mode="w"
233
+ )
234
+ elif weights_format == "npz":
235
+ weights_store = NpzIOStore(
236
+ _VARS_FNAME_NPZ, archive=zf, mode="w"
237
+ )
238
+ else:
239
+ raise ValueError(
240
+ "Unknown `weights_format` argument. "
241
+ "Expected 'h5' or 'npz'. "
242
+ f"Received: weights_format={weights_format}"
243
+ )
244
+
245
+ asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w")
246
+
247
+ _save_state(
248
+ model,
249
+ weights_store=weights_store,
250
+ assets_store=asset_store,
251
+ inner_path="",
252
+ visited_saveables=set(),
253
+ )
254
+ except:
255
+ # Skip the final `zf.write` if any exception is raised
256
+ write_zf = False
257
+ if weights_store:
258
+ weights_store.archive = None
259
+ raise
260
+ finally:
261
+ if weights_store:
262
+ weights_store.close()
263
+ if asset_store:
264
+ asset_store.close()
265
+ if write_zf and weights_file_path:
266
+ zf.write(weights_file_path.name, _VARS_FNAME_H5)
267
+ if weights_file_path:
268
+ weights_file_path.close()
269
+
270
+
271
+ def _upload_model_to_hf(model, hf_path, weights_format):
272
+ if huggingface_hub is None:
273
+ raise ImportError(
274
+ "To save models to the Hugging Face Hub, "
275
+ "you must install the `huggingface_hub` package."
276
+ )
277
+
278
+ original_hf_path = hf_path
279
+ if hf_path.startswith("hf://"):
280
+ hf_path = hf_path[5:]
281
+ if hf_path.count("/") > 1:
282
+ raise ValueError(
283
+ "Invalid `hf_path` argument: expected `namespace/model_name`"
284
+ f" format. Received: hf_path={original_hf_path}"
285
+ )
286
+
287
+ api = huggingface_hub.HfApi(
288
+ library_name="keras", library_version=keras_version
289
+ )
290
+ repo_url = api.create_repo(hf_path, exist_ok=True)
291
+ repo_id = repo_url.repo_id
292
+
293
+ with tempfile.TemporaryDirectory() as tmp_dir:
294
+ _save_model_to_dir(model, tmp_dir, weights_format)
295
+
296
+ model_card = _MODEL_CARD_TEMPLATE
297
+
298
+ if check_pydot():
299
+ plot_path = file_utils.join(tmp_dir, "assets", "summary_plot.png")
300
+ plot_model(
301
+ model,
302
+ to_file=plot_path,
303
+ show_layer_names=True,
304
+ show_shapes=True,
305
+ show_dtype=True,
306
+ )
307
+ if len(model.layers) <= 10:
308
+ model_card += "\n\n![](./assets/summary_plot.png)"
309
+ else:
310
+ model_card += (
311
+ "A plot of the model can be found "
312
+ "[here](./assets/summary_plot.png)."
313
+ )
314
+
315
+ with open(file_utils.join(tmp_dir, "README.md"), "w") as f:
316
+ f.write(model_card)
317
+
318
+ api.upload_folder(
319
+ repo_id=repo_id,
320
+ folder_path=tmp_dir,
321
+ commit_message="Save model using Keras.",
322
+ )
323
+ io_utils.print_msg(
324
+ f"Model saved to the Hugging Face Hub: {repo_url}\n"
325
+ "To load back the model, use "
326
+ f"`keras.saving.load_model('hf://{repo_id}')`"
327
+ )
328
+
329
+
330
+ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
331
+ """Load a zip archive representing a Keras model."""
332
+ if isinstance(filepath, io.IOBase):
333
+ return _load_model_from_fileobj(
334
+ filepath, custom_objects, compile, safe_mode
335
+ )
336
+ elif str(filepath).startswith("hf://"):
337
+ if huggingface_hub is None:
338
+ raise ImportError(
339
+ "To load models from the Hugging Face Hub, "
340
+ "you must install the `huggingface_hub` package."
341
+ )
342
+
343
+ repo_id = filepath[5:]
344
+ folder_path = huggingface_hub.snapshot_download(
345
+ repo_id=repo_id,
346
+ library_name="keras",
347
+ library_version=keras_version,
348
+ )
349
+ return _load_model_from_dir(
350
+ folder_path, custom_objects, compile, safe_mode
351
+ )
352
+ else:
353
+ filepath = str(filepath)
354
+ if not filepath.endswith(".keras"):
355
+ is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
356
+ file_utils.join(filepath, "config.json")
357
+ )
358
+ if is_keras_dir:
359
+ return _load_model_from_dir(
360
+ filepath, custom_objects, compile, safe_mode
361
+ )
362
+ raise ValueError(
363
+ "Invalid filename: expected a `.keras` extension. "
364
+ f"Received: filepath={filepath}"
365
+ )
366
+ with open(filepath, "rb") as f:
367
+ return _load_model_from_fileobj(
368
+ f, custom_objects, compile, safe_mode
369
+ )
370
+
371
+
372
+ def _load_model_from_dir(dirpath, custom_objects, compile, safe_mode):
373
+ if not file_utils.exists(dirpath):
374
+ raise ValueError(f"Directory doesn't exist: {dirpath}")
375
+ if not file_utils.isdir(dirpath):
376
+ raise ValueError(f"Path isn't a directory: {dirpath}")
377
+
378
+ with open(file_utils.join(dirpath, _CONFIG_FILENAME), "r") as f:
379
+ config_json = f.read()
380
+ model = _model_from_config(config_json, custom_objects, compile, safe_mode)
381
+
382
+ all_filenames = file_utils.listdir(dirpath)
383
+ try:
384
+ if _VARS_FNAME_H5 in all_filenames:
385
+ weights_file_path = file_utils.join(dirpath, _VARS_FNAME_H5)
386
+ weights_store = H5IOStore(weights_file_path, mode="r")
387
+ elif _VARS_FNAME_NPZ in all_filenames:
388
+ weights_file_path = file_utils.join(dirpath, _VARS_FNAME_NPZ)
389
+ weights_store = NpzIOStore(weights_file_path, mode="r")
390
+ else:
391
+ raise ValueError(
392
+ f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file."
393
+ )
394
+ if len(all_filenames) > 3:
395
+ asset_store = DiskIOStore(
396
+ file_utils.join(dirpath, _ASSETS_DIRNAME), mode="r"
397
+ )
398
+
399
+ else:
400
+ asset_store = None
401
+
402
+ failed_saveables = set()
403
+ error_msgs = {}
404
+ _load_state(
405
+ model,
406
+ weights_store=weights_store,
407
+ assets_store=asset_store,
408
+ inner_path="",
409
+ visited_saveables=set(),
410
+ failed_saveables=failed_saveables,
411
+ error_msgs=error_msgs,
412
+ )
413
+
414
+ finally:
415
+ weights_store.close()
416
+ if asset_store:
417
+ asset_store.close()
418
+
419
+ if failed_saveables:
420
+ _raise_loading_failure(error_msgs)
421
+ return model
422
+
423
+
424
+ def _model_from_config(config_json, custom_objects, compile, safe_mode):
425
+ # Note: we should NOT use a custom JSON decoder. Anything that
426
+ # needs custom decoding must be handled in deserialize_keras_object.
427
+ config_dict = json.loads(config_json)
428
+ if not compile:
429
+ # Disable compilation
430
+ config_dict["compile_config"] = None
431
+ # Construct the model from the configuration file in the archive.
432
+ with ObjectSharingScope():
433
+ model = deserialize_keras_object(
434
+ config_dict, custom_objects, safe_mode=safe_mode
435
+ )
436
+ return model
437
+
438
+
439
+ def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode):
440
+ with zipfile.ZipFile(fileobj, "r") as zf:
441
+ with zf.open(_CONFIG_FILENAME, "r") as f:
442
+ config_json = f.read()
443
+
444
+ model = _model_from_config(
445
+ config_json, custom_objects, compile, safe_mode
446
+ )
447
+
448
+ all_filenames = zf.namelist()
449
+ extract_dir = None
450
+ weights_store = None
451
+ asset_store = None
452
+ try:
453
+ if _VARS_FNAME_H5 in all_filenames:
454
+ try:
455
+ if is_memory_sufficient(model):
456
+ # Load the entire file into memory if the system memory
457
+ # is sufficient.
458
+ io_file = io.BytesIO(
459
+ zf.open(_VARS_FNAME_H5, "r").read()
460
+ )
461
+ weights_store = H5IOStore(io_file, mode="r")
462
+ else:
463
+ # Try extracting the model.weights.h5 file, and then
464
+ # loading it using using h5py. This is significantly
465
+ # faster than reading from the zip archive on the fly.
466
+ extract_dir = tempfile.TemporaryDirectory(
467
+ dir=pathlib.Path(fileobj.name).parent
468
+ )
469
+ zf.extract(_VARS_FNAME_H5, extract_dir.name)
470
+ weights_store = H5IOStore(
471
+ pathlib.Path(extract_dir.name, _VARS_FNAME_H5),
472
+ mode="r",
473
+ )
474
+ except:
475
+ # If we can't use the local disk for any reason, read the
476
+ # weights from the zip archive on the fly, which is less
477
+ # efficient.
478
+ weights_store = H5IOStore(_VARS_FNAME_H5, zf, mode="r")
479
+ elif _VARS_FNAME_NPZ in all_filenames:
480
+ weights_store = NpzIOStore(_VARS_FNAME_NPZ, zf, mode="r")
481
+ else:
482
+ raise ValueError(
483
+ f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file."
484
+ )
485
+
486
+ if len(all_filenames) > 3:
487
+ asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r")
488
+
489
+ failed_saveables = set()
490
+ error_msgs = {}
491
+ _load_state(
492
+ model,
493
+ weights_store=weights_store,
494
+ assets_store=asset_store,
495
+ inner_path="",
496
+ visited_saveables=set(),
497
+ failed_saveables=failed_saveables,
498
+ error_msgs=error_msgs,
499
+ )
500
+ finally:
501
+ if weights_store:
502
+ weights_store.close()
503
+ if asset_store:
504
+ asset_store.close()
505
+ if extract_dir:
506
+ extract_dir.cleanup()
507
+
508
+ if failed_saveables:
509
+ _raise_loading_failure(error_msgs)
510
+ return model
511
+
512
+
513
+ def save_weights_only(model, filepath, objects_to_skip=None):
514
+ """Save only the weights of a model to a target filepath.
515
+
516
+ Supports both `.weights.h5` and `.keras`.
517
+ """
518
+ if not model.built:
519
+ raise ValueError(
520
+ "You are saving a model that has not yet been built. "
521
+ "Try building the model first by calling it on some data or "
522
+ "by using `build()`."
523
+ )
524
+
525
+ filepath = str(filepath)
526
+ tmp_dir = None
527
+ remote_filepath = None
528
+ if not filepath.endswith(".weights.h5"):
529
+ raise ValueError(
530
+ "Invalid `filepath` argument: expected a `.weights.h5` extension. "
531
+ f"Received: filepath={filepath}"
532
+ )
533
+ try:
534
+ if file_utils.is_remote_path(filepath):
535
+ tmp_dir = get_temp_dir()
536
+ local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))
537
+ remote_filepath = filepath
538
+ filepath = local_filepath
539
+
540
+ weights_store = H5IOStore(filepath, mode="w")
541
+ if objects_to_skip is not None:
542
+ visited_saveables = set(id(o) for o in objects_to_skip)
543
+ else:
544
+ visited_saveables = set()
545
+ _save_state(
546
+ model,
547
+ weights_store=weights_store,
548
+ assets_store=None,
549
+ inner_path="",
550
+ visited_saveables=visited_saveables,
551
+ )
552
+ weights_store.close()
553
+ finally:
554
+ if tmp_dir is not None:
555
+ file_utils.copy(filepath, remote_filepath)
556
+ shutil.rmtree(tmp_dir)
557
+
558
+
559
+ def load_weights_only(
560
+ model, filepath, skip_mismatch=False, objects_to_skip=None
561
+ ):
562
+ """Load the weights of a model from a filepath (.keras or .weights.h5).
563
+
564
+ Note: only supports h5 for now.
565
+ """
566
+ if not model.built:
567
+ raise ValueError(
568
+ "You are loading weights into a model that has not yet been built. "
569
+ "Try building the model first by calling it on some data or "
570
+ "by using `build()`."
571
+ )
572
+
573
+ archive = None
574
+ tmp_dir = None
575
+ filepath = str(filepath)
576
+
577
+ try:
578
+ if file_utils.is_remote_path(filepath):
579
+ tmp_dir = get_temp_dir()
580
+ local_filepath = os.path.join(tmp_dir, os.path.basename(filepath))
581
+ file_utils.copy(filepath, local_filepath)
582
+ filepath = local_filepath
583
+
584
+ if filepath.endswith(".weights.h5"):
585
+ weights_store = H5IOStore(filepath, mode="r")
586
+ elif filepath.endswith(".keras"):
587
+ archive = zipfile.ZipFile(filepath, "r")
588
+ weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode="r")
589
+
590
+ failed_saveables = set()
591
+ if objects_to_skip is not None:
592
+ visited_saveables = set(id(o) for o in objects_to_skip)
593
+ else:
594
+ visited_saveables = set()
595
+ error_msgs = {}
596
+ _load_state(
597
+ model,
598
+ weights_store=weights_store,
599
+ assets_store=None,
600
+ inner_path="",
601
+ skip_mismatch=skip_mismatch,
602
+ visited_saveables=visited_saveables,
603
+ failed_saveables=failed_saveables,
604
+ error_msgs=error_msgs,
605
+ )
606
+ weights_store.close()
607
+ if archive:
608
+ archive.close()
609
+
610
+ if failed_saveables:
611
+ _raise_loading_failure(error_msgs, warn_only=skip_mismatch)
612
+ finally:
613
+ if tmp_dir is not None:
614
+ shutil.rmtree(tmp_dir)
615
+
616
+
617
+ def _raise_loading_failure(error_msgs, warn_only=False):
618
+ first_key = list(error_msgs.keys())[0]
619
+ ex_saveable, ex_error = error_msgs[first_key]
620
+ msg = (
621
+ f"A total of {len(error_msgs)} objects could not "
622
+ "be loaded. Example error message for "
623
+ f"object {ex_saveable}:\n\n"
624
+ f"{ex_error}\n\n"
625
+ "List of objects that could not be loaded:\n"
626
+ f"{[x[0] for x in error_msgs.values()]}"
627
+ )
628
+ if warn_only:
629
+ warnings.warn(msg)
630
+ else:
631
+ raise ValueError(msg)
632
+
633
+
634
+ def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path):
635
+ if not file_utils.isdir(system_path):
636
+ zipfile_to_save.write(system_path, zip_path)
637
+ else:
638
+ for file_name in file_utils.listdir(system_path):
639
+ system_file_path = file_utils.join(system_path, file_name).replace(
640
+ "\\", "/"
641
+ )
642
+ zip_file_path = file_utils.join(zip_path, file_name).replace(
643
+ "\\", "/"
644
+ )
645
+ _write_to_zip_recursively(
646
+ zipfile_to_save, system_file_path, zip_file_path
647
+ )
648
+
649
+
650
+ def _name_key(name):
651
+ """Make sure that private attributes are visited last."""
652
+ if name.startswith("_"):
653
+ return "~" + name
654
+ return name
655
+
656
+
657
+ def _walk_saveable(saveable):
658
+ from keras.src.saving.keras_saveable import KerasSaveable
659
+
660
+ if not isinstance(saveable, KerasSaveable):
661
+ raise ValueError(
662
+ "Expected object to be an "
663
+ "instance of `KerasSaveable`, but "
664
+ f"got {saveable} of type {type(saveable)}"
665
+ )
666
+
667
+ obj_type = saveable._obj_type()
668
+ attr_skipset = get_attr_skipset(obj_type)
669
+
670
+ # Save all layers directly tracked by Sequential and Functional first.
671
+ # This helps avoid ordering concerns for subclassed Sequential or Functional
672
+ # models with extra attributes--the internal Keras state take precedence.
673
+ if obj_type in ("Sequential", "Functional"):
674
+ yield "layers", saveable.layers
675
+
676
+ for child_attr in sorted(dir(saveable), key=lambda x: _name_key(x)):
677
+ if child_attr.startswith("__") or child_attr in attr_skipset:
678
+ continue
679
+ try:
680
+ child_obj = getattr(saveable, child_attr)
681
+ except Exception:
682
+ # Avoid raising the exception when visiting the attributes.
683
+ continue
684
+ yield child_attr, child_obj
685
+
686
+
687
+ def _save_state(
688
+ saveable,
689
+ weights_store,
690
+ assets_store,
691
+ inner_path,
692
+ visited_saveables,
693
+ ):
694
+ from keras.src.saving.keras_saveable import KerasSaveable
695
+
696
+ # If the saveable has already been saved, skip it.
697
+ if id(saveable) in visited_saveables:
698
+ return
699
+
700
+ if hasattr(saveable, "save_own_variables") and weights_store:
701
+ if hasattr(saveable, "name") and isinstance(saveable.name, str):
702
+ metadata = {"name": saveable.name}
703
+ else:
704
+ metadata = None
705
+ saveable.save_own_variables(
706
+ weights_store.make(inner_path, metadata=metadata)
707
+ )
708
+ if hasattr(saveable, "save_assets") and assets_store:
709
+ saveable.save_assets(assets_store.make(inner_path))
710
+
711
+ visited_saveables.add(id(saveable))
712
+
713
+ # Recursively save state of children saveables (layers, optimizers, etc.)
714
+ for child_attr, child_obj in _walk_saveable(saveable):
715
+ if isinstance(child_obj, KerasSaveable):
716
+ _save_state(
717
+ child_obj,
718
+ weights_store,
719
+ assets_store,
720
+ inner_path=file_utils.join(inner_path, child_attr).replace(
721
+ "\\", "/"
722
+ ),
723
+ visited_saveables=visited_saveables,
724
+ )
725
+ elif isinstance(child_obj, (list, dict, tuple, set)):
726
+ _save_container_state(
727
+ child_obj,
728
+ weights_store,
729
+ assets_store,
730
+ inner_path=file_utils.join(inner_path, child_attr).replace(
731
+ "\\", "/"
732
+ ),
733
+ visited_saveables=visited_saveables,
734
+ )
735
+
736
+
737
+ def _load_state(
738
+ saveable,
739
+ weights_store,
740
+ assets_store,
741
+ inner_path,
742
+ skip_mismatch=False,
743
+ visited_saveables=None,
744
+ failed_saveables=None,
745
+ error_msgs=None,
746
+ ):
747
+ from keras.src.saving.keras_saveable import KerasSaveable
748
+
749
+ if visited_saveables and id(saveable) in visited_saveables:
750
+ return
751
+
752
+ failure = False
753
+
754
+ if hasattr(saveable, "load_own_variables") and weights_store:
755
+ if skip_mismatch or failed_saveables is not None:
756
+ try:
757
+ saveable.load_own_variables(weights_store.get(inner_path))
758
+ except Exception as e:
759
+ failed_saveables.add(id(saveable))
760
+ error_msgs[id(saveable)] = saveable, e
761
+ failure = True
762
+ else:
763
+ saveable.load_own_variables(weights_store.get(inner_path))
764
+
765
+ if hasattr(saveable, "load_assets") and assets_store:
766
+ if skip_mismatch or failed_saveables is not None:
767
+ try:
768
+ saveable.load_assets(assets_store.get(inner_path))
769
+ except Exception as e:
770
+ failed_saveables.add(id(saveable))
771
+ error_msgs[id(saveable)] = saveable, e
772
+ failure = True
773
+ else:
774
+ saveable.load_assets(assets_store.get(inner_path))
775
+
776
+ if failed_saveables is not None:
777
+ currently_failed = len(failed_saveables)
778
+ else:
779
+ currently_failed = 0
780
+
781
+ # Recursively load states for Keras saveables such as layers/optimizers.
782
+ for child_attr, child_obj in _walk_saveable(saveable):
783
+ if isinstance(child_obj, KerasSaveable):
784
+ _load_state(
785
+ child_obj,
786
+ weights_store,
787
+ assets_store,
788
+ inner_path=file_utils.join(inner_path, child_attr).replace(
789
+ "\\", "/"
790
+ ),
791
+ skip_mismatch=skip_mismatch,
792
+ visited_saveables=visited_saveables,
793
+ failed_saveables=failed_saveables,
794
+ error_msgs=error_msgs,
795
+ )
796
+ elif isinstance(child_obj, (list, dict, tuple, set)):
797
+ _load_container_state(
798
+ child_obj,
799
+ weights_store,
800
+ assets_store,
801
+ inner_path=file_utils.join(inner_path, child_attr).replace(
802
+ "\\", "/"
803
+ ),
804
+ skip_mismatch=skip_mismatch,
805
+ visited_saveables=visited_saveables,
806
+ failed_saveables=failed_saveables,
807
+ error_msgs=error_msgs,
808
+ )
809
+
810
+ if failed_saveables is not None:
811
+ newly_failed = len(failed_saveables) - currently_failed
812
+ else:
813
+ newly_failed = 0
814
+
815
+ if not failure:
816
+ if visited_saveables is not None and newly_failed <= 0:
817
+ visited_saveables.add(id(saveable))
818
+ if id(saveable) in failed_saveables:
819
+ failed_saveables.remove(id(saveable))
820
+ error_msgs.pop(id(saveable))
821
+
822
+
823
+ def _save_container_state(
824
+ container, weights_store, assets_store, inner_path, visited_saveables
825
+ ):
826
+ from keras.src.saving.keras_saveable import KerasSaveable
827
+
828
+ used_names = {}
829
+ if isinstance(container, dict):
830
+ container = list(container.values())
831
+
832
+ for saveable in container:
833
+ if isinstance(saveable, KerasSaveable):
834
+ # Do NOT address the saveable via `saveable.name`, since
835
+ # names are usually autogenerated and thus not reproducible
836
+ # (i.e. they may vary across two instances of the same model).
837
+ name = naming.to_snake_case(saveable.__class__.__name__)
838
+ if name in used_names:
839
+ used_names[name] += 1
840
+ name = f"{name}_{used_names[name]}"
841
+ else:
842
+ used_names[name] = 0
843
+ _save_state(
844
+ saveable,
845
+ weights_store,
846
+ assets_store,
847
+ inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
848
+ visited_saveables=visited_saveables,
849
+ )
850
+
851
+
852
+ def _load_container_state(
853
+ container,
854
+ weights_store,
855
+ assets_store,
856
+ inner_path,
857
+ skip_mismatch,
858
+ visited_saveables,
859
+ failed_saveables,
860
+ error_msgs,
861
+ ):
862
+ from keras.src.saving.keras_saveable import KerasSaveable
863
+
864
+ used_names = {}
865
+ if isinstance(container, dict):
866
+ container = list(container.values())
867
+
868
+ for saveable in container:
869
+ if isinstance(saveable, KerasSaveable):
870
+ name = naming.to_snake_case(saveable.__class__.__name__)
871
+ if name in used_names:
872
+ used_names[name] += 1
873
+ name = f"{name}_{used_names[name]}"
874
+ else:
875
+ used_names[name] = 0
876
+ _load_state(
877
+ saveable,
878
+ weights_store,
879
+ assets_store,
880
+ inner_path=file_utils.join(inner_path, name).replace("\\", "/"),
881
+ skip_mismatch=skip_mismatch,
882
+ visited_saveables=visited_saveables,
883
+ failed_saveables=failed_saveables,
884
+ error_msgs=error_msgs,
885
+ )
886
+
887
+
888
+ class DiskIOStore:
889
+ """Asset store backed by disk storage.
890
+
891
+ If `archive` is specified, then `root_path` refers to the filename
892
+ inside the archive.
893
+
894
+ If `archive` is not specified, then `root_path` refers to the full path of
895
+ the target directory.
896
+ """
897
+
898
+ def __init__(self, root_path, archive=None, mode=None):
899
+ self.mode = mode
900
+ self.root_path = root_path
901
+ self.archive = archive
902
+ self.tmp_dir = None
903
+ if self.archive:
904
+ self.tmp_dir = get_temp_dir()
905
+ if self.mode == "r":
906
+ self.archive.extractall(path=self.tmp_dir)
907
+ self.working_dir = file_utils.join(
908
+ self.tmp_dir, self.root_path
909
+ ).replace("\\", "/")
910
+ if self.mode == "w":
911
+ file_utils.makedirs(self.working_dir)
912
+ else:
913
+ if mode == "r":
914
+ self.working_dir = root_path
915
+ else:
916
+ self.tmp_dir = get_temp_dir()
917
+ self.working_dir = file_utils.join(
918
+ self.tmp_dir, self.root_path
919
+ ).replace("\\", "/")
920
+ file_utils.makedirs(self.working_dir)
921
+
922
+ def make(self, path):
923
+ if not path:
924
+ return self.working_dir
925
+ path = file_utils.join(self.working_dir, path).replace("\\", "/")
926
+ if not file_utils.exists(path):
927
+ file_utils.makedirs(path)
928
+ return path
929
+
930
+ def get(self, path):
931
+ if not path:
932
+ return self.working_dir
933
+ path = file_utils.join(self.working_dir, path).replace("\\", "/")
934
+ if file_utils.exists(path):
935
+ return path
936
+ return None
937
+
938
+ def close(self):
939
+ if self.mode == "w" and self.archive:
940
+ _write_to_zip_recursively(
941
+ self.archive, self.working_dir, self.root_path
942
+ )
943
+ if self.tmp_dir and file_utils.exists(self.tmp_dir):
944
+ file_utils.rmtree(self.tmp_dir)
945
+
946
+
947
+ class H5IOStore:
948
+ def __init__(self, root_path, archive=None, mode="r"):
949
+ """Numerical variable store backed by HDF5.
950
+
951
+ If `archive` is specified, then `root_path` refers to the filename
952
+ inside the archive.
953
+
954
+ If `archive` is not specified, then `root_path` refers to the path of
955
+ the h5 file on disk.
956
+ """
957
+ self.root_path = root_path
958
+ self.mode = mode
959
+ self.archive = archive
960
+ self.io_file = None
961
+
962
+ if self.archive:
963
+ if self.mode == "w":
964
+ self.io_file = io.BytesIO()
965
+ else:
966
+ self.io_file = self.archive.open(self.root_path, "r")
967
+ self.h5_file = h5py.File(self.io_file, mode=self.mode)
968
+ else:
969
+ self.h5_file = h5py.File(root_path, mode=self.mode)
970
+
971
+ def make(self, path, metadata=None):
972
+ return H5Entry(self.h5_file, path, mode="w", metadata=metadata)
973
+
974
+ def get(self, path):
975
+ return H5Entry(self.h5_file, path, mode="r")
976
+
977
+ def close(self):
978
+ self.h5_file.close()
979
+ if self.mode == "w" and self.archive:
980
+ self.archive.writestr(self.root_path, self.io_file.getvalue())
981
+ if self.io_file:
982
+ self.io_file.close()
983
+
984
+
985
+ class H5Entry:
986
+ """Leaf entry in a H5IOStore."""
987
+
988
+ def __init__(self, h5_file, path, mode, metadata=None):
989
+ self.h5_file = h5_file
990
+ self.path = path
991
+ self.mode = mode
992
+ self.metadata = metadata
993
+
994
+ if mode == "w":
995
+ if not path:
996
+ self.group = self.h5_file.create_group("vars")
997
+ else:
998
+ self.group = self.h5_file.create_group(self.path).create_group(
999
+ "vars"
1000
+ )
1001
+ if self.metadata:
1002
+ for k, v in self.metadata.items():
1003
+ self.group.attrs[k] = v
1004
+ else:
1005
+ found = False
1006
+ if not path:
1007
+ if "vars" in self.h5_file:
1008
+ self.group = self.h5_file["vars"]
1009
+ found = True
1010
+ elif path in self.h5_file and "vars" in self.h5_file[path]:
1011
+ self.group = self.h5_file[path]["vars"]
1012
+ found = True
1013
+ else:
1014
+ # No hit.
1015
+ # Fix for 2.13 compatibility
1016
+ if "_layer_checkpoint_dependencies" in self.h5_file:
1017
+ path = path.replace(
1018
+ "layers", "_layer_checkpoint_dependencies"
1019
+ )
1020
+ self.path = path
1021
+ if path in self.h5_file and "vars" in self.h5_file[path]:
1022
+ self.group = self.h5_file[path]["vars"]
1023
+ found = True
1024
+ if not found:
1025
+ self.group = {}
1026
+
1027
+ def __len__(self):
1028
+ return self.group.__len__()
1029
+
1030
+ def keys(self):
1031
+ return self.group.keys()
1032
+
1033
+ def items(self):
1034
+ return self.group.items()
1035
+
1036
+ def values(self):
1037
+ return self.group.values()
1038
+
1039
+ def __setitem__(self, key, value):
1040
+ if self.mode != "w":
1041
+ raise ValueError("Setting a value is only allowed in write mode.")
1042
+ value = backend.convert_to_numpy(value)
1043
+ if backend.standardize_dtype(value.dtype) == "bfloat16":
1044
+ ds = self.group.create_dataset(key, data=value)
1045
+ ds.attrs["dtype"] = "bfloat16"
1046
+ else:
1047
+ self.group[key] = value
1048
+
1049
+ def __getitem__(self, name):
1050
+ value = self.group[name]
1051
+ if "dtype" in value.attrs and value.attrs["dtype"] == "bfloat16":
1052
+ value = np.array(value, dtype=ml_dtypes.bfloat16)
1053
+ return value
1054
+
1055
+
1056
+ class NpzIOStore:
1057
+ def __init__(self, root_path, archive=None, mode="r"):
1058
+ """Numerical variable store backed by NumPy.savez/load.
1059
+
1060
+ If `archive` is specified, then `root_path` refers to the filename
1061
+ inside the archive.
1062
+
1063
+ If `archive` is not specified, then `root_path` refers to the path of
1064
+ the npz file on disk.
1065
+ """
1066
+ self.root_path = root_path
1067
+ self.mode = mode
1068
+ self.archive = archive
1069
+ if mode == "w":
1070
+ self.contents = {}
1071
+ else:
1072
+ if self.archive:
1073
+ self.f = archive.open(root_path, mode="r")
1074
+ else:
1075
+ self.f = open(root_path, mode="rb")
1076
+ self.contents = np.load(self.f, allow_pickle=True)
1077
+
1078
+ def make(self, path, metadata=None):
1079
+ if not path:
1080
+ self.contents["__root__"] = {}
1081
+ return self.contents["__root__"]
1082
+ self.contents[path] = {}
1083
+ return self.contents[path]
1084
+
1085
+ def get(self, path):
1086
+ if not path:
1087
+ if "__root__" in self.contents:
1088
+ return dict(self.contents["__root__"])
1089
+ return {}
1090
+ if path in self.contents:
1091
+ return self.contents[path].tolist()
1092
+ return {}
1093
+
1094
+ def close(self):
1095
+ if self.mode == "w":
1096
+ if self.archive:
1097
+ self.f = self.archive.open(
1098
+ self.root_path, mode="w", force_zip64=True
1099
+ )
1100
+ else:
1101
+ self.f = open(self.root_path, mode="wb")
1102
+ np.savez(self.f, **self.contents)
1103
+ self.f.close()
1104
+
1105
+
1106
+ def get_temp_dir():
1107
+ temp_dir = tempfile.mkdtemp()
1108
+ testfile = tempfile.TemporaryFile(dir=temp_dir)
1109
+ testfile.close()
1110
+ return temp_dir
1111
+
1112
+
1113
+ def get_attr_skipset(obj_type):
1114
+ skipset = global_state.get_global_attribute(
1115
+ f"saving_attr_skiplist_{obj_type}", None
1116
+ )
1117
+ if skipset is not None:
1118
+ return skipset
1119
+
1120
+ skipset = set(
1121
+ [
1122
+ "_self_unconditional_dependency_names",
1123
+ ]
1124
+ )
1125
+ if obj_type == "Layer":
1126
+ ref_obj = Layer()
1127
+ skipset.update(dir(ref_obj))
1128
+ elif obj_type == "Functional":
1129
+ ref_obj = Layer()
1130
+ skipset.update(dir(ref_obj) + ["operations", "_operations"])
1131
+ elif obj_type == "Sequential":
1132
+ ref_obj = Layer()
1133
+ skipset.update(dir(ref_obj) + ["_functional"])
1134
+ elif obj_type == "Metric":
1135
+ ref_obj_a = Metric()
1136
+ ref_obj_b = CompileMetrics([], [])
1137
+ skipset.update(dir(ref_obj_a) + dir(ref_obj_b))
1138
+ elif obj_type == "Optimizer":
1139
+ ref_obj = Optimizer(1.0)
1140
+ skipset.update(dir(ref_obj))
1141
+ skipset.remove("variables")
1142
+ elif obj_type == "Loss":
1143
+ ref_obj = Loss()
1144
+ skipset.update(dir(ref_obj))
1145
+ else:
1146
+ raise ValueError(
1147
+ f"get_attr_skipset got invalid {obj_type=}. "
1148
+ "Accepted values for `obj_type` are "
1149
+ "['Layer', 'Functional', 'Sequential', 'Metric', "
1150
+ "'Optimizer', 'Loss']"
1151
+ )
1152
+
1153
+ global_state.set_global_attribute(
1154
+ f"saving_attr_skipset_{obj_type}", skipset
1155
+ )
1156
+ return skipset
1157
+
1158
+
1159
+ def is_memory_sufficient(model):
1160
+ """Check if there is sufficient memory to load the model into memory.
1161
+
1162
+ If psutil is installed, we can use it to determine whether the memory is
1163
+ sufficient. Otherwise, we use a predefined value of 1 GB for available
1164
+ memory.
1165
+ """
1166
+ if psutil is None:
1167
+ available_memory = 1024 * 1024 * 1024 # 1 GB in bytes
1168
+ else:
1169
+ available_memory = psutil.virtual_memory().available # In bytes
1170
+ return (
1171
+ weight_memory_size(model.variables)
1172
+ < available_memory * _MEMORY_UPPER_BOUND
1173
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Object config serialization and deserialization logic."""
2
+
3
+ import importlib
4
+ import inspect
5
+ import types
6
+ import warnings
7
+
8
+ import numpy as np
9
+
10
+ from keras.src import api_export
11
+ from keras.src import backend
12
+ from keras.src.api_export import keras_export
13
+ from keras.src.backend.common import global_state
14
+ from keras.src.saving import object_registration
15
+ from keras.src.utils import python_utils
16
+ from keras.src.utils.module_utils import tensorflow as tf
17
+
18
+ PLAIN_TYPES = (str, int, float, bool)
19
+
20
+ # List of Keras modules with built-in string representations for Keras defaults
21
+ BUILTIN_MODULES = (
22
+ "activations",
23
+ "constraints",
24
+ "initializers",
25
+ "losses",
26
+ "metrics",
27
+ "optimizers",
28
+ "regularizers",
29
+ )
30
+
31
+
32
+ class SerializableDict:
33
+ def __init__(self, **config):
34
+ self.config = config
35
+
36
+ def serialize(self):
37
+ return serialize_keras_object(self.config)
38
+
39
+
40
+ class SafeModeScope:
41
+ """Scope to propagate safe mode flag to nested deserialization calls."""
42
+
43
+ def __init__(self, safe_mode=True):
44
+ self.safe_mode = safe_mode
45
+
46
+ def __enter__(self):
47
+ self.original_value = in_safe_mode()
48
+ global_state.set_global_attribute("safe_mode_saving", self.safe_mode)
49
+
50
+ def __exit__(self, *args, **kwargs):
51
+ global_state.set_global_attribute(
52
+ "safe_mode_saving", self.original_value
53
+ )
54
+
55
+
56
+ @keras_export("keras.config.enable_unsafe_deserialization")
57
+ def enable_unsafe_deserialization():
58
+ """Disables safe mode globally, allowing deserialization of lambdas."""
59
+ global_state.set_global_attribute("safe_mode_saving", False)
60
+
61
+
62
+ def in_safe_mode():
63
+ return global_state.get_global_attribute("safe_mode_saving")
64
+
65
+
66
+ class ObjectSharingScope:
67
+ """Scope to enable detection and reuse of previously seen objects."""
68
+
69
+ def __enter__(self):
70
+ global_state.set_global_attribute("shared_objects/id_to_obj_map", {})
71
+ global_state.set_global_attribute("shared_objects/id_to_config_map", {})
72
+
73
+ def __exit__(self, *args, **kwargs):
74
+ global_state.set_global_attribute("shared_objects/id_to_obj_map", None)
75
+ global_state.set_global_attribute(
76
+ "shared_objects/id_to_config_map", None
77
+ )
78
+
79
+
80
+ def get_shared_object(obj_id):
81
+ """Retrieve an object previously seen during deserialization."""
82
+ id_to_obj_map = global_state.get_global_attribute(
83
+ "shared_objects/id_to_obj_map"
84
+ )
85
+ if id_to_obj_map is not None:
86
+ return id_to_obj_map.get(obj_id, None)
87
+
88
+
89
+ def record_object_after_serialization(obj, config):
90
+ """Call after serializing an object, to keep track of its config."""
91
+ if config["module"] == "__main__":
92
+ config["module"] = None # Ensures module is None when no module found
93
+ id_to_config_map = global_state.get_global_attribute(
94
+ "shared_objects/id_to_config_map"
95
+ )
96
+ if id_to_config_map is None:
97
+ return # Not in a sharing scope
98
+ obj_id = int(id(obj))
99
+ if obj_id not in id_to_config_map:
100
+ id_to_config_map[obj_id] = config
101
+ else:
102
+ config["shared_object_id"] = obj_id
103
+ prev_config = id_to_config_map[obj_id]
104
+ prev_config["shared_object_id"] = obj_id
105
+
106
+
107
+ def record_object_after_deserialization(obj, obj_id):
108
+ """Call after deserializing an object, to keep track of it in the future."""
109
+ id_to_obj_map = global_state.get_global_attribute(
110
+ "shared_objects/id_to_obj_map"
111
+ )
112
+ if id_to_obj_map is None:
113
+ return # Not in a sharing scope
114
+ id_to_obj_map[obj_id] = obj
115
+
116
+
117
+ @keras_export(
118
+ [
119
+ "keras.saving.serialize_keras_object",
120
+ "keras.utils.serialize_keras_object",
121
+ ]
122
+ )
123
+ def serialize_keras_object(obj):
124
+ """Retrieve the config dict by serializing the Keras object.
125
+
126
+ `serialize_keras_object()` serializes a Keras object to a python dictionary
127
+ that represents the object, and is a reciprocal function of
128
+ `deserialize_keras_object()`. See `deserialize_keras_object()` for more
129
+ information about the config format.
130
+
131
+ Args:
132
+ obj: the Keras object to serialize.
133
+
134
+ Returns:
135
+ A python dict that represents the object. The python dict can be
136
+ deserialized via `deserialize_keras_object()`.
137
+ """
138
+ if obj is None:
139
+ return obj
140
+
141
+ if isinstance(obj, PLAIN_TYPES):
142
+ return obj
143
+
144
+ if isinstance(obj, (list, tuple)):
145
+ config_arr = [serialize_keras_object(x) for x in obj]
146
+ return tuple(config_arr) if isinstance(obj, tuple) else config_arr
147
+ if isinstance(obj, dict):
148
+ return serialize_dict(obj)
149
+
150
+ # Special cases:
151
+ if isinstance(obj, bytes):
152
+ return {
153
+ "class_name": "__bytes__",
154
+ "config": {"value": obj.decode("utf-8")},
155
+ }
156
+ if isinstance(obj, slice):
157
+ return {
158
+ "class_name": "__slice__",
159
+ "config": {
160
+ "start": serialize_keras_object(obj.start),
161
+ "stop": serialize_keras_object(obj.stop),
162
+ "step": serialize_keras_object(obj.step),
163
+ },
164
+ }
165
+ # Ellipsis is an instance, and ellipsis class is not in global scope.
166
+ # checking equality also fails elsewhere in the library, so we have
167
+ # to dynamically get the type.
168
+ if isinstance(obj, type(Ellipsis)):
169
+ return {"class_name": "__ellipsis__", "config": {}}
170
+ if isinstance(obj, backend.KerasTensor):
171
+ history = getattr(obj, "_keras_history", None)
172
+ if history:
173
+ history = list(history)
174
+ history[0] = history[0].name
175
+ return {
176
+ "class_name": "__keras_tensor__",
177
+ "config": {
178
+ "shape": obj.shape,
179
+ "dtype": obj.dtype,
180
+ "keras_history": history,
181
+ },
182
+ }
183
+ if tf.available and isinstance(obj, tf.TensorShape):
184
+ return obj.as_list() if obj._dims is not None else None
185
+ if backend.is_tensor(obj):
186
+ return {
187
+ "class_name": "__tensor__",
188
+ "config": {
189
+ "value": backend.convert_to_numpy(obj).tolist(),
190
+ "dtype": backend.standardize_dtype(obj.dtype),
191
+ },
192
+ }
193
+ if type(obj).__module__ == np.__name__:
194
+ if isinstance(obj, np.ndarray) and obj.ndim > 0:
195
+ return {
196
+ "class_name": "__numpy__",
197
+ "config": {
198
+ "value": obj.tolist(),
199
+ "dtype": backend.standardize_dtype(obj.dtype),
200
+ },
201
+ }
202
+ else:
203
+ # Treat numpy floats / etc as plain types.
204
+ return obj.item()
205
+ if tf.available and isinstance(obj, tf.DType):
206
+ return obj.name
207
+ if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
208
+ warnings.warn(
209
+ "The object being serialized includes a `lambda`. This is unsafe. "
210
+ "In order to reload the object, you will have to pass "
211
+ "`safe_mode=False` to the loading function. "
212
+ "Please avoid using `lambda` in the "
213
+ "future, and use named Python functions instead. "
214
+ f"This is the `lambda` being serialized: {inspect.getsource(obj)}",
215
+ stacklevel=2,
216
+ )
217
+ return {
218
+ "class_name": "__lambda__",
219
+ "config": {
220
+ "value": python_utils.func_dump(obj),
221
+ },
222
+ }
223
+ if tf.available and isinstance(obj, tf.TypeSpec):
224
+ ts_config = obj._serialize()
225
+ # TensorShape and tf.DType conversion
226
+ ts_config = list(
227
+ map(
228
+ lambda x: (
229
+ x.as_list()
230
+ if isinstance(x, tf.TensorShape)
231
+ else (x.name if isinstance(x, tf.DType) else x)
232
+ ),
233
+ ts_config,
234
+ )
235
+ )
236
+ return {
237
+ "class_name": "__typespec__",
238
+ "spec_name": obj.__class__.__name__,
239
+ "module": obj.__class__.__module__,
240
+ "config": ts_config,
241
+ "registered_name": None,
242
+ }
243
+
244
+ inner_config = _get_class_or_fn_config(obj)
245
+ config_with_public_class = serialize_with_public_class(
246
+ obj.__class__, inner_config
247
+ )
248
+
249
+ if config_with_public_class is not None:
250
+ get_build_and_compile_config(obj, config_with_public_class)
251
+ record_object_after_serialization(obj, config_with_public_class)
252
+ return config_with_public_class
253
+
254
+ # Any custom object or otherwise non-exported object
255
+ if isinstance(obj, types.FunctionType):
256
+ module = obj.__module__
257
+ else:
258
+ module = obj.__class__.__module__
259
+ class_name = obj.__class__.__name__
260
+
261
+ if module == "builtins":
262
+ registered_name = None
263
+ else:
264
+ if isinstance(obj, types.FunctionType):
265
+ registered_name = object_registration.get_registered_name(obj)
266
+ else:
267
+ registered_name = object_registration.get_registered_name(
268
+ obj.__class__
269
+ )
270
+
271
+ config = {
272
+ "module": module,
273
+ "class_name": class_name,
274
+ "config": inner_config,
275
+ "registered_name": registered_name,
276
+ }
277
+ get_build_and_compile_config(obj, config)
278
+ record_object_after_serialization(obj, config)
279
+ return config
280
+
281
+
282
+ def get_build_and_compile_config(obj, config):
283
+ if hasattr(obj, "get_build_config"):
284
+ build_config = obj.get_build_config()
285
+ if build_config is not None:
286
+ config["build_config"] = serialize_dict(build_config)
287
+ if hasattr(obj, "get_compile_config"):
288
+ compile_config = obj.get_compile_config()
289
+ if compile_config is not None:
290
+ config["compile_config"] = serialize_dict(compile_config)
291
+ return
292
+
293
+
294
+ def serialize_with_public_class(cls, inner_config=None):
295
+ """Serializes classes from public Keras API or object registration.
296
+
297
+ Called to check and retrieve the config of any class that has a public
298
+ Keras API or has been registered as serializable via
299
+ `keras.saving.register_keras_serializable()`.
300
+ """
301
+ # This gets the `keras.*` exported name, such as
302
+ # "keras.optimizers.Adam".
303
+ keras_api_name = api_export.get_name_from_symbol(cls)
304
+
305
+ # Case of custom or unknown class object
306
+ if keras_api_name is None:
307
+ registered_name = object_registration.get_registered_name(cls)
308
+ if registered_name is None:
309
+ return None
310
+
311
+ # Return custom object config with corresponding registration name
312
+ return {
313
+ "module": cls.__module__,
314
+ "class_name": cls.__name__,
315
+ "config": inner_config,
316
+ "registered_name": registered_name,
317
+ }
318
+
319
+ # Split the canonical Keras API name into a Keras module and class name.
320
+ parts = keras_api_name.split(".")
321
+ return {
322
+ "module": ".".join(parts[:-1]),
323
+ "class_name": parts[-1],
324
+ "config": inner_config,
325
+ "registered_name": None,
326
+ }
327
+
328
+
329
+ def serialize_with_public_fn(fn, config, fn_module_name=None):
330
+ """Serializes functions from public Keras API or object registration.
331
+
332
+ Called to check and retrieve the config of any function that has a public
333
+ Keras API or has been registered as serializable via
334
+ `keras.saving.register_keras_serializable()`. If function's module name
335
+ is already known, returns corresponding config.
336
+ """
337
+ if fn_module_name:
338
+ return {
339
+ "module": fn_module_name,
340
+ "class_name": "function",
341
+ "config": config,
342
+ "registered_name": config,
343
+ }
344
+ keras_api_name = api_export.get_name_from_symbol(fn)
345
+ if keras_api_name:
346
+ parts = keras_api_name.split(".")
347
+ return {
348
+ "module": ".".join(parts[:-1]),
349
+ "class_name": "function",
350
+ "config": config,
351
+ "registered_name": config,
352
+ }
353
+ else:
354
+ registered_name = object_registration.get_registered_name(fn)
355
+ if not registered_name and not fn.__module__ == "builtins":
356
+ return None
357
+ return {
358
+ "module": fn.__module__,
359
+ "class_name": "function",
360
+ "config": config,
361
+ "registered_name": registered_name,
362
+ }
363
+
364
+
365
+ def _get_class_or_fn_config(obj):
366
+ """Return the object's config depending on its type."""
367
+ # Functions / lambdas:
368
+ if isinstance(obj, types.FunctionType):
369
+ return object_registration.get_registered_name(obj)
370
+ # All classes:
371
+ if hasattr(obj, "get_config"):
372
+ config = obj.get_config()
373
+ if not isinstance(config, dict):
374
+ raise TypeError(
375
+ f"The `get_config()` method of {obj} should return "
376
+ f"a dict. It returned: {config}"
377
+ )
378
+ return serialize_dict(config)
379
+ elif hasattr(obj, "__name__"):
380
+ return object_registration.get_registered_name(obj)
381
+ else:
382
+ raise TypeError(
383
+ f"Cannot serialize object {obj} of type {type(obj)}. "
384
+ "To be serializable, "
385
+ "a class must implement the `get_config()` method."
386
+ )
387
+
388
+
389
+ def serialize_dict(obj):
390
+ return {key: serialize_keras_object(value) for key, value in obj.items()}
391
+
392
+
393
+ @keras_export(
394
+ [
395
+ "keras.saving.deserialize_keras_object",
396
+ "keras.utils.deserialize_keras_object",
397
+ ]
398
+ )
399
+ def deserialize_keras_object(
400
+ config, custom_objects=None, safe_mode=True, **kwargs
401
+ ):
402
+ """Retrieve the object by deserializing the config dict.
403
+
404
+ The config dict is a Python dictionary that consists of a set of key-value
405
+ pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
406
+ `Metrics`, etc. The saving and loading library uses the following keys to
407
+ record information of a Keras object:
408
+
409
+ - `class_name`: String. This is the name of the class,
410
+ as exactly defined in the source
411
+ code, such as "LossesContainer".
412
+ - `config`: Dict. Library-defined or user-defined key-value pairs that store
413
+ the configuration of the object, as obtained by `object.get_config()`.
414
+ - `module`: String. The path of the python module. Built-in Keras classes
415
+ expect to have prefix `keras`.
416
+ - `registered_name`: String. The key the class is registered under via
417
+ `keras.saving.register_keras_serializable(package, name)` API. The
418
+ key has the format of '{package}>{name}', where `package` and `name` are
419
+ the arguments passed to `register_keras_serializable()`. If `name` is not
420
+ provided, it uses the class name. If `registered_name` successfully
421
+ resolves to a class (that was registered), the `class_name` and `config`
422
+ values in the dict will not be used. `registered_name` is only used for
423
+ non-built-in classes.
424
+
425
+ For example, the following dictionary represents the built-in Adam optimizer
426
+ with the relevant config:
427
+
428
+ ```python
429
+ dict_structure = {
430
+ "class_name": "Adam",
431
+ "config": {
432
+ "amsgrad": false,
433
+ "beta_1": 0.8999999761581421,
434
+ "beta_2": 0.9990000128746033,
435
+ "decay": 0.0,
436
+ "epsilon": 1e-07,
437
+ "learning_rate": 0.0010000000474974513,
438
+ "name": "Adam"
439
+ },
440
+ "module": "keras.optimizers",
441
+ "registered_name": None
442
+ }
443
+ # Returns an `Adam` instance identical to the original one.
444
+ deserialize_keras_object(dict_structure)
445
+ ```
446
+
447
+ If the class does not have an exported Keras namespace, the library tracks
448
+ it by its `module` and `class_name`. For example:
449
+
450
+ ```python
451
+ dict_structure = {
452
+ "class_name": "MetricsList",
453
+ "config": {
454
+ ...
455
+ },
456
+ "module": "keras.trainers.compile_utils",
457
+ "registered_name": "MetricsList"
458
+ }
459
+
460
+ # Returns a `MetricsList` instance identical to the original one.
461
+ deserialize_keras_object(dict_structure)
462
+ ```
463
+
464
+ And the following dictionary represents a user-customized `MeanSquaredError`
465
+ loss:
466
+
467
+ ```python
468
+ @keras.saving.register_keras_serializable(package='my_package')
469
+ class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
470
+ ...
471
+
472
+ dict_structure = {
473
+ "class_name": "ModifiedMeanSquaredError",
474
+ "config": {
475
+ "fn": "mean_squared_error",
476
+ "name": "mean_squared_error",
477
+ "reduction": "auto"
478
+ },
479
+ "registered_name": "my_package>ModifiedMeanSquaredError"
480
+ }
481
+ # Returns the `ModifiedMeanSquaredError` object
482
+ deserialize_keras_object(dict_structure)
483
+ ```
484
+
485
+ Args:
486
+ config: Python dict describing the object.
487
+ custom_objects: Python dict containing a mapping between custom
488
+ object names the corresponding classes or functions.
489
+ safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
490
+ When `safe_mode=False`, loading an object has the potential to
491
+ trigger arbitrary code execution. This argument is only
492
+ applicable to the Keras v3 model format. Defaults to `True`.
493
+
494
+ Returns:
495
+ The object described by the `config` dictionary.
496
+ """
497
+ safe_scope_arg = in_safe_mode() # Enforces SafeModeScope
498
+ safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode
499
+
500
+ module_objects = kwargs.pop("module_objects", None)
501
+ custom_objects = custom_objects or {}
502
+ tlco = global_state.get_global_attribute("custom_objects_scope_dict", {})
503
+ gco = object_registration.GLOBAL_CUSTOM_OBJECTS
504
+ custom_objects = {**custom_objects, **tlco, **gco}
505
+
506
+ if config is None:
507
+ return None
508
+
509
+ if (
510
+ isinstance(config, str)
511
+ and custom_objects
512
+ and custom_objects.get(config) is not None
513
+ ):
514
+ # This is to deserialize plain functions which are serialized as
515
+ # string names by legacy saving formats.
516
+ return custom_objects[config]
517
+
518
+ if isinstance(config, (list, tuple)):
519
+ return [
520
+ deserialize_keras_object(
521
+ x, custom_objects=custom_objects, safe_mode=safe_mode
522
+ )
523
+ for x in config
524
+ ]
525
+
526
+ if module_objects is not None:
527
+ inner_config, fn_module_name, has_custom_object = None, None, False
528
+
529
+ if isinstance(config, dict):
530
+ if "config" in config:
531
+ inner_config = config["config"]
532
+ if "class_name" not in config:
533
+ raise ValueError(
534
+ f"Unknown `config` as a `dict`, config={config}"
535
+ )
536
+
537
+ # Check case where config is function or class and in custom objects
538
+ if custom_objects and (
539
+ config["class_name"] in custom_objects
540
+ or config.get("registered_name") in custom_objects
541
+ or (
542
+ isinstance(inner_config, str)
543
+ and inner_config in custom_objects
544
+ )
545
+ ):
546
+ has_custom_object = True
547
+
548
+ # Case where config is function but not in custom objects
549
+ elif config["class_name"] == "function":
550
+ fn_module_name = config["module"]
551
+ if fn_module_name == "builtins":
552
+ config = config["config"]
553
+ else:
554
+ config = config["registered_name"]
555
+
556
+ # Case where config is class but not in custom objects
557
+ else:
558
+ if config.get("module", "_") is None:
559
+ raise TypeError(
560
+ "Cannot deserialize object of type "
561
+ f"`{config['class_name']}`. If "
562
+ f"`{config['class_name']}` is a custom class, please "
563
+ "register it using the "
564
+ "`@keras.saving.register_keras_serializable()` "
565
+ "decorator."
566
+ )
567
+ config = config["class_name"]
568
+
569
+ if not has_custom_object:
570
+ # Return if not found in either module objects or custom objects
571
+ if config not in module_objects:
572
+ # Object has already been deserialized
573
+ return config
574
+ if isinstance(module_objects[config], types.FunctionType):
575
+ return deserialize_keras_object(
576
+ serialize_with_public_fn(
577
+ module_objects[config], config, fn_module_name
578
+ ),
579
+ custom_objects=custom_objects,
580
+ )
581
+ return deserialize_keras_object(
582
+ serialize_with_public_class(
583
+ module_objects[config], inner_config=inner_config
584
+ ),
585
+ custom_objects=custom_objects,
586
+ )
587
+
588
+ if isinstance(config, PLAIN_TYPES):
589
+ return config
590
+ if not isinstance(config, dict):
591
+ raise TypeError(f"Could not parse config: {config}")
592
+
593
+ if "class_name" not in config or "config" not in config:
594
+ return {
595
+ key: deserialize_keras_object(
596
+ value, custom_objects=custom_objects, safe_mode=safe_mode
597
+ )
598
+ for key, value in config.items()
599
+ }
600
+
601
+ class_name = config["class_name"]
602
+ inner_config = config["config"] or {}
603
+ custom_objects = custom_objects or {}
604
+
605
+ # Special cases:
606
+ if class_name == "__keras_tensor__":
607
+ obj = backend.KerasTensor(
608
+ inner_config["shape"], dtype=inner_config["dtype"]
609
+ )
610
+ obj._pre_serialization_keras_history = inner_config["keras_history"]
611
+ return obj
612
+
613
+ if class_name == "__tensor__":
614
+ return backend.convert_to_tensor(
615
+ inner_config["value"], dtype=inner_config["dtype"]
616
+ )
617
+ if class_name == "__numpy__":
618
+ return np.array(inner_config["value"], dtype=inner_config["dtype"])
619
+ if config["class_name"] == "__bytes__":
620
+ return inner_config["value"].encode("utf-8")
621
+ if config["class_name"] == "__ellipsis__":
622
+ return Ellipsis
623
+ if config["class_name"] == "__slice__":
624
+ return slice(
625
+ deserialize_keras_object(
626
+ inner_config["start"],
627
+ custom_objects=custom_objects,
628
+ safe_mode=safe_mode,
629
+ ),
630
+ deserialize_keras_object(
631
+ inner_config["stop"],
632
+ custom_objects=custom_objects,
633
+ safe_mode=safe_mode,
634
+ ),
635
+ deserialize_keras_object(
636
+ inner_config["step"],
637
+ custom_objects=custom_objects,
638
+ safe_mode=safe_mode,
639
+ ),
640
+ )
641
+ if config["class_name"] == "__lambda__":
642
+ if safe_mode:
643
+ raise ValueError(
644
+ "Requested the deserialization of a `lambda` object. "
645
+ "This carries a potential risk of arbitrary code execution "
646
+ "and thus it is disallowed by default. If you trust the "
647
+ "source of the saved model, you can pass `safe_mode=False` to "
648
+ "the loading function in order to allow `lambda` loading, "
649
+ "or call `keras.config.enable_unsafe_deserialization()`."
650
+ )
651
+ return python_utils.func_load(inner_config["value"])
652
+ if tf is not None and config["class_name"] == "__typespec__":
653
+ obj = _retrieve_class_or_fn(
654
+ config["spec_name"],
655
+ config["registered_name"],
656
+ config["module"],
657
+ obj_type="class",
658
+ full_config=config,
659
+ custom_objects=custom_objects,
660
+ )
661
+ # Conversion to TensorShape and DType
662
+ inner_config = map(
663
+ lambda x: (
664
+ tf.TensorShape(x)
665
+ if isinstance(x, list)
666
+ else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x)
667
+ ),
668
+ inner_config,
669
+ )
670
+ return obj._deserialize(tuple(inner_config))
671
+
672
+ # Below: classes and functions.
673
+ module = config.get("module", None)
674
+ registered_name = config.get("registered_name", class_name)
675
+
676
+ if class_name == "function":
677
+ fn_name = inner_config
678
+ return _retrieve_class_or_fn(
679
+ fn_name,
680
+ registered_name,
681
+ module,
682
+ obj_type="function",
683
+ full_config=config,
684
+ custom_objects=custom_objects,
685
+ )
686
+
687
+ # Below, handling of all classes.
688
+ # First, is it a shared object?
689
+ if "shared_object_id" in config:
690
+ obj = get_shared_object(config["shared_object_id"])
691
+ if obj is not None:
692
+ return obj
693
+
694
+ cls = _retrieve_class_or_fn(
695
+ class_name,
696
+ registered_name,
697
+ module,
698
+ obj_type="class",
699
+ full_config=config,
700
+ custom_objects=custom_objects,
701
+ )
702
+
703
+ if isinstance(cls, types.FunctionType):
704
+ return cls
705
+ if not hasattr(cls, "from_config"):
706
+ raise TypeError(
707
+ f"Unable to reconstruct an instance of '{class_name}' because "
708
+ f"the class is missing a `from_config()` method. "
709
+ f"Full object config: {config}"
710
+ )
711
+
712
+ # Instantiate the class from its config inside a custom object scope
713
+ # so that we can catch any custom objects that the config refers to.
714
+ custom_obj_scope = object_registration.CustomObjectScope(custom_objects)
715
+ safe_mode_scope = SafeModeScope(safe_mode)
716
+ with custom_obj_scope, safe_mode_scope:
717
+ try:
718
+ instance = cls.from_config(inner_config)
719
+ except TypeError as e:
720
+ raise TypeError(
721
+ f"{cls} could not be deserialized properly. Please"
722
+ " ensure that components that are Python object"
723
+ " instances (layers, models, etc.) returned by"
724
+ " `get_config()` are explicitly deserialized in the"
725
+ " model's `from_config()` method."
726
+ f"\n\nconfig={config}.\n\nException encountered: {e}"
727
+ )
728
+ build_config = config.get("build_config", None)
729
+ if build_config and not instance.built:
730
+ instance.build_from_config(build_config)
731
+ instance.built = True
732
+ compile_config = config.get("compile_config", None)
733
+ if compile_config:
734
+ instance.compile_from_config(compile_config)
735
+ instance.compiled = True
736
+
737
+ if "shared_object_id" in config:
738
+ record_object_after_deserialization(
739
+ instance, config["shared_object_id"]
740
+ )
741
+ return instance
742
+
743
+
744
+ def _retrieve_class_or_fn(
745
+ name, registered_name, module, obj_type, full_config, custom_objects=None
746
+ ):
747
+ # If there is a custom object registered via
748
+ # `register_keras_serializable()`, that takes precedence.
749
+ if obj_type == "function":
750
+ custom_obj = object_registration.get_registered_object(
751
+ name, custom_objects=custom_objects
752
+ )
753
+ else:
754
+ custom_obj = object_registration.get_registered_object(
755
+ registered_name, custom_objects=custom_objects
756
+ )
757
+ if custom_obj is not None:
758
+ return custom_obj
759
+
760
+ if module:
761
+ # If it's a Keras built-in object,
762
+ # we cannot always use direct import, because the exported
763
+ # module name might not match the package structure
764
+ # (e.g. experimental symbols).
765
+ if module == "keras" or module.startswith("keras."):
766
+ api_name = module + "." + name
767
+
768
+ obj = api_export.get_symbol_from_name(api_name)
769
+ if obj is not None:
770
+ return obj
771
+
772
+ # Configs of Keras built-in functions do not contain identifying
773
+ # information other than their name (e.g. 'acc' or 'tanh'). This special
774
+ # case searches the Keras modules that contain built-ins to retrieve
775
+ # the corresponding function from the identifying string.
776
+ if obj_type == "function" and module == "builtins":
777
+ for mod in BUILTIN_MODULES:
778
+ obj = api_export.get_symbol_from_name(
779
+ "keras." + mod + "." + name
780
+ )
781
+ if obj is not None:
782
+ return obj
783
+
784
+ # Otherwise, attempt to retrieve the class object given the `module`
785
+ # and `class_name`. Import the module, find the class.
786
+ try:
787
+ mod = importlib.import_module(module)
788
+ except ModuleNotFoundError:
789
+ raise TypeError(
790
+ f"Could not deserialize {obj_type} '{name}' because "
791
+ f"its parent module {module} cannot be imported. "
792
+ f"Full object config: {full_config}"
793
+ )
794
+ obj = vars(mod).get(name, None)
795
+
796
+ # Special case for keras.metrics.metrics
797
+ if obj is None and registered_name is not None:
798
+ obj = vars(mod).get(registered_name, None)
799
+
800
+ if obj is not None:
801
+ return obj
802
+
803
+ raise TypeError(
804
+ f"Could not locate {obj_type} '{name}'. "
805
+ "Make sure custom classes are decorated with "
806
+ "`@keras.saving.register_keras_serializable()`. "
807
+ f"Full object config: {full_config}"
808
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from keras.src.testing.test_case import TestCase
2
+ from keras.src.testing.test_case import jax_uses_gpu
3
+ from keras.src.testing.test_case import tensorflow_uses_gpu
4
+ from keras.src.testing.test_case import torch_uses_gpu
5
+ from keras.src.testing.test_case import uses_gpu
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (398 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_case.cpython-310.pyc ADDED
Binary file (22.5 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_utils.cpython-310.pyc ADDED
Binary file (5.43 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py ADDED
@@ -0,0 +1,796 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import shutil
3
+ import tempfile
4
+ import unittest
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ from absl.testing import parameterized
9
+
10
+ from keras.src import backend
11
+ from keras.src import distribution
12
+ from keras.src import ops
13
+ from keras.src import tree
14
+ from keras.src import utils
15
+ from keras.src.backend.common import is_float_dtype
16
+ from keras.src.backend.common import standardize_dtype
17
+ from keras.src.backend.common.global_state import clear_session
18
+ from keras.src.backend.common.keras_tensor import KerasTensor
19
+ from keras.src.models import Model
20
+ from keras.src.utils import traceback_utils
21
+
22
+
23
+ class TestCase(parameterized.TestCase, unittest.TestCase):
24
+ maxDiff = None
25
+
26
+ def __init__(self, *args, **kwargs):
27
+ super().__init__(*args, **kwargs)
28
+
29
+ def setUp(self):
30
+ # clear global state so that test cases are independent
31
+ # required for the jit enabled torch tests since dynamo has
32
+ # a global cache for guards, compiled fn, etc
33
+ clear_session(free_memory=False)
34
+ if traceback_utils.is_traceback_filtering_enabled():
35
+ traceback_utils.disable_traceback_filtering()
36
+
37
+ def get_temp_dir(self):
38
+ temp_dir = tempfile.mkdtemp()
39
+ self.addCleanup(lambda: shutil.rmtree(temp_dir))
40
+ return temp_dir
41
+
42
+ def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
43
+ if not isinstance(x1, np.ndarray):
44
+ x1 = backend.convert_to_numpy(x1)
45
+ if not isinstance(x2, np.ndarray):
46
+ x2 = backend.convert_to_numpy(x2)
47
+ np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg)
48
+
49
+ def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None):
50
+ try:
51
+ self.assertAllClose(x1, x2, atol=atol, rtol=rtol, msg=msg)
52
+ except AssertionError:
53
+ return
54
+ msg = msg or ""
55
+ raise AssertionError(
56
+ f"The two values are close at all elements. \n"
57
+ f"{msg}.\n"
58
+ f"Values: {x1}"
59
+ )
60
+
61
+ def assertAlmostEqual(self, x1, x2, decimal=3, msg=None):
62
+ msg = msg or ""
63
+ if not isinstance(x1, np.ndarray):
64
+ x1 = backend.convert_to_numpy(x1)
65
+ if not isinstance(x2, np.ndarray):
66
+ x2 = backend.convert_to_numpy(x2)
67
+ np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg)
68
+
69
+ def assertAllEqual(self, x1, x2, msg=None):
70
+ self.assertEqual(len(x1), len(x2), msg=msg)
71
+ for e1, e2 in zip(x1, x2):
72
+ if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)):
73
+ self.assertAllEqual(e1, e2, msg=msg)
74
+ else:
75
+ e1 = backend.convert_to_numpy(e1)
76
+ e2 = backend.convert_to_numpy(e2)
77
+ self.assertEqual(e1, e2, msg=msg)
78
+
79
+ def assertLen(self, iterable, expected_len, msg=None):
80
+ self.assertEqual(len(iterable), expected_len, msg=msg)
81
+
82
+ def assertSparse(self, x, sparse=True):
83
+ if isinstance(x, KerasTensor):
84
+ self.assertEqual(x.sparse, sparse)
85
+ elif backend.backend() == "tensorflow":
86
+ import tensorflow as tf
87
+
88
+ if sparse:
89
+ self.assertIsInstance(x, tf.SparseTensor)
90
+ else:
91
+ self.assertNotIsInstance(x, tf.SparseTensor)
92
+ elif backend.backend() == "jax":
93
+ import jax.experimental.sparse as jax_sparse
94
+
95
+ if sparse:
96
+ self.assertIsInstance(x, jax_sparse.JAXSparse)
97
+ else:
98
+ self.assertNotIsInstance(x, jax_sparse.JAXSparse)
99
+ else:
100
+ self.assertFalse(
101
+ sparse,
102
+ f"Backend {backend.backend()} does not support sparse tensors",
103
+ )
104
+
105
+ def assertDType(self, x, dtype, msg=None):
106
+ if hasattr(x, "dtype"):
107
+ x_dtype = backend.standardize_dtype(x.dtype)
108
+ else:
109
+ # If x is a python number
110
+ x_dtype = backend.standardize_dtype(type(x))
111
+ standardized_dtype = backend.standardize_dtype(dtype)
112
+ default_msg = (
113
+ "The dtype of x does not match the expected one. "
114
+ f"Received: x.dtype={x_dtype} and dtype={dtype}"
115
+ )
116
+ msg = msg or default_msg
117
+ self.assertEqual(x_dtype, standardized_dtype, msg=msg)
118
+
119
+ def assertFileExists(self, path):
120
+ if not Path(path).is_file():
121
+ raise AssertionError(f"File {path} does not exist")
122
+
123
+ def run_class_serialization_test(self, instance, custom_objects=None):
124
+ from keras.src.saving import custom_object_scope
125
+ from keras.src.saving import deserialize_keras_object
126
+ from keras.src.saving import serialize_keras_object
127
+
128
+ # get_config roundtrip
129
+ cls = instance.__class__
130
+ config = instance.get_config()
131
+ config_json = to_json_with_tuples(config)
132
+ ref_dir = dir(instance)[:]
133
+ with custom_object_scope(custom_objects):
134
+ revived_instance = cls.from_config(config)
135
+ revived_config = revived_instance.get_config()
136
+ revived_config_json = to_json_with_tuples(revived_config)
137
+ self.assertEqual(config_json, revived_config_json)
138
+ self.assertEqual(set(ref_dir), set(dir(revived_instance)))
139
+
140
+ # serialization roundtrip
141
+ serialized = serialize_keras_object(instance)
142
+ serialized_json = to_json_with_tuples(serialized)
143
+ with custom_object_scope(custom_objects):
144
+ revived_instance = deserialize_keras_object(
145
+ from_json_with_tuples(serialized_json)
146
+ )
147
+ revived_config = revived_instance.get_config()
148
+ revived_config_json = to_json_with_tuples(revived_config)
149
+ self.assertEqual(config_json, revived_config_json)
150
+ new_dir = dir(revived_instance)[:]
151
+ for lst in [ref_dir, new_dir]:
152
+ if "__annotations__" in lst:
153
+ lst.remove("__annotations__")
154
+ self.assertEqual(set(ref_dir), set(new_dir))
155
+ return revived_instance
156
+
157
+ def run_layer_test(
158
+ self,
159
+ layer_cls,
160
+ init_kwargs,
161
+ input_shape=None,
162
+ input_dtype=None,
163
+ input_sparse=False,
164
+ input_data=None,
165
+ call_kwargs=None,
166
+ expected_output_shape=None,
167
+ expected_output_dtype=None,
168
+ expected_output_sparse=False,
169
+ expected_output=None,
170
+ expected_num_trainable_weights=None,
171
+ expected_num_non_trainable_weights=None,
172
+ expected_num_non_trainable_variables=None,
173
+ expected_num_seed_generators=None,
174
+ expected_num_losses=None,
175
+ supports_masking=None,
176
+ expected_mask_shape=None,
177
+ custom_objects=None,
178
+ run_training_check=True,
179
+ run_mixed_precision_check=True,
180
+ assert_built_after_instantiation=False,
181
+ ):
182
+ """Run basic checks on a layer.
183
+
184
+ Args:
185
+ layer_cls: The class of the layer to test.
186
+ init_kwargs: Dict of arguments to be used to
187
+ instantiate the layer.
188
+ input_shape: Shape tuple (or list/dict of shape tuples)
189
+ to call the layer on.
190
+ input_dtype: Corresponding input dtype.
191
+ input_sparse: Whether the input is a sparse tensor (this requires
192
+ the backend to support sparse tensors).
193
+ input_data: Tensor (or list/dict of tensors)
194
+ to call the layer on.
195
+ call_kwargs: Dict of arguments to use when calling the
196
+ layer (does not include the first input tensor argument)
197
+ expected_output_shape: Shape tuple
198
+ (or list/dict of shape tuples)
199
+ expected as output.
200
+ expected_output_dtype: dtype expected as output.
201
+ expected_output_sparse: Whether the output is expected to be sparse
202
+ (this requires the backend to support sparse tensors).
203
+ expected_output: Expected output tensor -- only
204
+ to be specified if input_data is provided.
205
+ expected_num_trainable_weights: Expected number
206
+ of trainable weights of the layer once built.
207
+ expected_num_non_trainable_weights: Expected number
208
+ of non-trainable weights of the layer once built.
209
+ expected_num_seed_generators: Expected number of
210
+ SeedGenerators objects of the layer once built.
211
+ expected_num_losses: Expected number of loss tensors
212
+ produced when calling the layer.
213
+ supports_masking: If True, will check that the layer
214
+ supports masking.
215
+ expected_mask_shape: Expected mask shape tuple
216
+ returned by compute_mask() (only supports 1 shape).
217
+ custom_objects: Dict of any custom objects to be
218
+ considered during deserialization.
219
+ run_training_check: Whether to attempt to train the layer
220
+ (if an input shape or input data was provided).
221
+ run_mixed_precision_check: Whether to test the layer with a mixed
222
+ precision dtype policy.
223
+ assert_built_after_instantiation: Whether to assert `built=True`
224
+ after the layer's instantiation.
225
+ """
226
+ if input_shape is not None and input_data is not None:
227
+ raise ValueError(
228
+ "input_shape and input_data cannot be passed "
229
+ "at the same time."
230
+ )
231
+ if expected_output_shape is not None and expected_output is not None:
232
+ raise ValueError(
233
+ "expected_output_shape and expected_output cannot be passed "
234
+ "at the same time."
235
+ )
236
+ if expected_output is not None and input_data is None:
237
+ raise ValueError(
238
+ "In order to use expected_output, input_data must be provided."
239
+ )
240
+ if expected_mask_shape is not None and supports_masking is not True:
241
+ raise ValueError(
242
+ "In order to use expected_mask_shape, supports_masking "
243
+ "must be True."
244
+ )
245
+
246
+ init_kwargs = init_kwargs or {}
247
+ call_kwargs = call_kwargs or {}
248
+
249
+ if input_shape is not None and input_dtype is not None:
250
+ if isinstance(input_shape, tuple) and is_shape_tuple(
251
+ input_shape[0]
252
+ ):
253
+ self.assertIsInstance(input_dtype, tuple)
254
+ self.assertEqual(
255
+ len(input_shape),
256
+ len(input_dtype),
257
+ msg="The number of input shapes and dtypes does not match",
258
+ )
259
+ elif isinstance(input_shape, dict):
260
+ self.assertIsInstance(input_dtype, dict)
261
+ self.assertEqual(
262
+ set(input_shape.keys()),
263
+ set(input_dtype.keys()),
264
+ msg="The number of input shapes and dtypes does not match",
265
+ )
266
+ elif isinstance(input_shape, list):
267
+ self.assertIsInstance(input_dtype, list)
268
+ self.assertEqual(
269
+ len(input_shape),
270
+ len(input_dtype),
271
+ msg="The number of input shapes and dtypes does not match",
272
+ )
273
+ elif not isinstance(input_shape, tuple):
274
+ raise ValueError("The type of input_shape is not supported")
275
+ if input_shape is not None and input_dtype is None:
276
+ input_dtype = tree.map_shape_structure(
277
+ lambda _: "float32", input_shape
278
+ )
279
+
280
+ # Estimate actual number of weights, variables, seed generators if
281
+ # expected ones not set. When using layers uses composition it should
282
+ # build each sublayer manually.
283
+ if input_data is not None or input_shape is not None:
284
+ if input_data is None:
285
+ input_data = create_eager_tensors(
286
+ input_shape, input_dtype, input_sparse
287
+ )
288
+ layer = layer_cls(**init_kwargs)
289
+ if isinstance(input_data, dict):
290
+ layer(**input_data, **call_kwargs)
291
+ else:
292
+ layer(input_data, **call_kwargs)
293
+
294
+ if expected_num_trainable_weights is None:
295
+ expected_num_trainable_weights = len(layer.trainable_weights)
296
+ if expected_num_non_trainable_weights is None:
297
+ expected_num_non_trainable_weights = len(
298
+ layer.non_trainable_weights
299
+ )
300
+ if expected_num_non_trainable_variables is None:
301
+ expected_num_non_trainable_variables = len(
302
+ layer.non_trainable_variables
303
+ )
304
+ if expected_num_seed_generators is None:
305
+ expected_num_seed_generators = len(get_seed_generators(layer))
306
+
307
+ # Serialization test.
308
+ layer = layer_cls(**init_kwargs)
309
+ self.run_class_serialization_test(layer, custom_objects)
310
+
311
+ # Basic masking test.
312
+ if supports_masking is not None:
313
+ self.assertEqual(
314
+ layer.supports_masking,
315
+ supports_masking,
316
+ msg="Unexpected supports_masking value",
317
+ )
318
+
319
+ def run_build_asserts(layer):
320
+ self.assertTrue(layer.built)
321
+ if expected_num_trainable_weights is not None:
322
+ self.assertLen(
323
+ layer.trainable_weights,
324
+ expected_num_trainable_weights,
325
+ msg="Unexpected number of trainable_weights",
326
+ )
327
+ if expected_num_non_trainable_weights is not None:
328
+ self.assertLen(
329
+ layer.non_trainable_weights,
330
+ expected_num_non_trainable_weights,
331
+ msg="Unexpected number of non_trainable_weights",
332
+ )
333
+ if expected_num_non_trainable_variables is not None:
334
+ self.assertLen(
335
+ layer.non_trainable_variables,
336
+ expected_num_non_trainable_variables,
337
+ msg="Unexpected number of non_trainable_variables",
338
+ )
339
+ if expected_num_seed_generators is not None:
340
+ self.assertLen(
341
+ get_seed_generators(layer),
342
+ expected_num_seed_generators,
343
+ msg="Unexpected number of seed_generators",
344
+ )
345
+ if (
346
+ backend.backend() == "torch"
347
+ and expected_num_trainable_weights is not None
348
+ and expected_num_non_trainable_weights is not None
349
+ and expected_num_seed_generators is not None
350
+ ):
351
+ self.assertLen(
352
+ layer.torch_params,
353
+ expected_num_trainable_weights
354
+ + expected_num_non_trainable_weights
355
+ + expected_num_seed_generators,
356
+ msg="Unexpected number of torch_params",
357
+ )
358
+
359
+ def run_output_asserts(layer, output, eager=False):
360
+ if expected_output_shape is not None:
361
+ if isinstance(expected_output_shape, tuple) and is_shape_tuple(
362
+ expected_output_shape[0]
363
+ ):
364
+ self.assertIsInstance(output, tuple)
365
+ self.assertEqual(
366
+ len(output),
367
+ len(expected_output_shape),
368
+ msg="Unexpected number of outputs",
369
+ )
370
+ output_shape = tuple(v.shape for v in output)
371
+ self.assertEqual(
372
+ expected_output_shape,
373
+ output_shape,
374
+ msg="Unexpected output shape",
375
+ )
376
+ elif isinstance(expected_output_shape, tuple):
377
+ self.assertEqual(
378
+ expected_output_shape,
379
+ output.shape,
380
+ msg="Unexpected output shape",
381
+ )
382
+ elif isinstance(expected_output_shape, dict):
383
+ self.assertIsInstance(output, dict)
384
+ self.assertEqual(
385
+ set(output.keys()),
386
+ set(expected_output_shape.keys()),
387
+ msg="Unexpected output dict keys",
388
+ )
389
+ output_shape = {k: v.shape for k, v in output.items()}
390
+ self.assertEqual(
391
+ expected_output_shape,
392
+ output_shape,
393
+ msg="Unexpected output shape",
394
+ )
395
+ elif isinstance(expected_output_shape, list):
396
+ self.assertIsInstance(output, list)
397
+ self.assertEqual(
398
+ len(output),
399
+ len(expected_output_shape),
400
+ msg="Unexpected number of outputs",
401
+ )
402
+ output_shape = [v.shape for v in output]
403
+ self.assertEqual(
404
+ expected_output_shape,
405
+ output_shape,
406
+ msg="Unexpected output shape",
407
+ )
408
+ else:
409
+ raise ValueError(
410
+ "The type of expected_output_shape is not supported"
411
+ )
412
+ if expected_output_dtype is not None:
413
+ if isinstance(expected_output_dtype, tuple):
414
+ self.assertIsInstance(output, tuple)
415
+ self.assertEqual(
416
+ len(output),
417
+ len(expected_output_dtype),
418
+ msg="Unexpected number of outputs",
419
+ )
420
+ output_dtype = tuple(
421
+ backend.standardize_dtype(v.dtype) for v in output
422
+ )
423
+ self.assertEqual(
424
+ expected_output_dtype,
425
+ output_dtype,
426
+ msg="Unexpected output dtype",
427
+ )
428
+ elif isinstance(expected_output_dtype, dict):
429
+ self.assertIsInstance(output, dict)
430
+ self.assertEqual(
431
+ set(output.keys()),
432
+ set(expected_output_dtype.keys()),
433
+ msg="Unexpected output dict keys",
434
+ )
435
+ output_dtype = {
436
+ k: backend.standardize_dtype(v.dtype)
437
+ for k, v in output.items()
438
+ }
439
+ self.assertEqual(
440
+ expected_output_dtype,
441
+ output_dtype,
442
+ msg="Unexpected output dtype",
443
+ )
444
+ elif isinstance(expected_output_dtype, list):
445
+ self.assertIsInstance(output, list)
446
+ self.assertEqual(
447
+ len(output),
448
+ len(expected_output_dtype),
449
+ msg="Unexpected number of outputs",
450
+ )
451
+ output_dtype = [
452
+ backend.standardize_dtype(v.dtype) for v in output
453
+ ]
454
+ self.assertEqual(
455
+ expected_output_dtype,
456
+ output_dtype,
457
+ msg="Unexpected output dtype",
458
+ )
459
+ else:
460
+ output_dtype = tree.flatten(output)[0].dtype
461
+ self.assertEqual(
462
+ expected_output_dtype,
463
+ backend.standardize_dtype(output_dtype),
464
+ msg="Unexpected output dtype",
465
+ )
466
+ if expected_output_sparse:
467
+ for x in tree.flatten(output):
468
+ self.assertSparse(x)
469
+ if eager:
470
+ if expected_output is not None:
471
+ self.assertEqual(type(expected_output), type(output))
472
+ for ref_v, v in zip(
473
+ tree.flatten(expected_output), tree.flatten(output)
474
+ ):
475
+ self.assertAllClose(
476
+ ref_v, v, msg="Unexpected output value"
477
+ )
478
+ if expected_num_losses is not None:
479
+ self.assertLen(layer.losses, expected_num_losses)
480
+
481
+ def run_training_step(layer, input_data, output_data):
482
+ class TestModel(Model):
483
+ def __init__(self, layer):
484
+ super().__init__()
485
+ self.layer = layer
486
+
487
+ def call(self, x, training=False):
488
+ return self.layer(x, training=training)
489
+
490
+ model = TestModel(layer)
491
+
492
+ data = (input_data, output_data)
493
+ if backend.backend() == "torch":
494
+ data = tree.map_structure(backend.convert_to_numpy, data)
495
+
496
+ def data_generator():
497
+ while True:
498
+ yield data
499
+
500
+ # test the "default" path for each backend by setting
501
+ # jit_compile="auto".
502
+ # for tensorflow and jax backends auto is jitted
503
+ # Note that tensorflow cannot be jitted with sparse tensors
504
+ # for torch backend auto is eager
505
+ #
506
+ # NB: for torch, jit_compile=True turns on torchdynamo
507
+ # which may not always succeed in tracing depending
508
+ # on the model. Run your program with these env vars
509
+ # to get debug traces of dynamo:
510
+ # TORCH_LOGS="+dynamo"
511
+ # TORCHDYNAMO_VERBOSE=1
512
+ # TORCHDYNAMO_REPORT_GUARD_FAILURES=1
513
+ jit_compile = "auto"
514
+ if backend.backend() == "tensorflow" and input_sparse:
515
+ jit_compile = False
516
+ model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile)
517
+ model.fit(data_generator(), steps_per_epoch=1, verbose=0)
518
+
519
+ # Build test.
520
+ if input_data is not None or input_shape is not None:
521
+ if input_shape is None:
522
+ build_shape = tree.map_structure(
523
+ lambda x: ops.shape(x), input_data
524
+ )
525
+ else:
526
+ build_shape = input_shape
527
+ layer = layer_cls(**init_kwargs)
528
+ if isinstance(build_shape, dict):
529
+ layer.build(**build_shape)
530
+ else:
531
+ layer.build(build_shape)
532
+ run_build_asserts(layer)
533
+
534
+ # Symbolic call test.
535
+ if input_shape is None:
536
+ keras_tensor_inputs = tree.map_structure(
537
+ lambda x: create_keras_tensors(
538
+ ops.shape(x), x.dtype, input_sparse
539
+ ),
540
+ input_data,
541
+ )
542
+ else:
543
+ keras_tensor_inputs = create_keras_tensors(
544
+ input_shape, input_dtype, input_sparse
545
+ )
546
+ layer = layer_cls(**init_kwargs)
547
+ if isinstance(keras_tensor_inputs, dict):
548
+ keras_tensor_outputs = layer(
549
+ **keras_tensor_inputs, **call_kwargs
550
+ )
551
+ else:
552
+ keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs)
553
+ run_build_asserts(layer)
554
+ run_output_asserts(layer, keras_tensor_outputs, eager=False)
555
+
556
+ if expected_mask_shape is not None:
557
+ output_mask = layer.compute_mask(keras_tensor_inputs)
558
+ self.assertEqual(expected_mask_shape, output_mask.shape)
559
+
560
+ # The stateless layers should be built after instantiation.
561
+ if assert_built_after_instantiation:
562
+ layer = layer_cls(**init_kwargs)
563
+ self.assertTrue(
564
+ layer.built,
565
+ msg=(
566
+ f"{type(layer)} is stateless, so it should be built "
567
+ "after instantiation."
568
+ ),
569
+ )
570
+
571
+ # Eager call test and compiled training test.
572
+ if input_data is not None or input_shape is not None:
573
+ if input_data is None:
574
+ input_data = create_eager_tensors(
575
+ input_shape, input_dtype, input_sparse
576
+ )
577
+ layer = layer_cls(**init_kwargs)
578
+ if isinstance(input_data, dict):
579
+ output_data = layer(**input_data, **call_kwargs)
580
+ else:
581
+ output_data = layer(input_data, **call_kwargs)
582
+ run_output_asserts(layer, output_data, eager=True)
583
+
584
+ if run_training_check:
585
+ run_training_step(layer, input_data, output_data)
586
+
587
+ # Never test mixed precision on torch CPU. Torch lacks support.
588
+ if run_mixed_precision_check and backend.backend() == "torch":
589
+ import torch
590
+
591
+ run_mixed_precision_check = torch.cuda.is_available()
592
+
593
+ if run_mixed_precision_check:
594
+ layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"})
595
+ input_spec = tree.map_structure(
596
+ lambda spec: KerasTensor(
597
+ spec.shape,
598
+ dtype=(
599
+ layer.compute_dtype
600
+ if layer.autocast
601
+ and backend.is_float_dtype(spec.dtype)
602
+ else spec.dtype
603
+ ),
604
+ ),
605
+ keras_tensor_inputs,
606
+ )
607
+ if isinstance(input_data, dict):
608
+ output_data = layer(**input_data, **call_kwargs)
609
+ output_spec = layer.compute_output_spec(**input_spec)
610
+ else:
611
+ output_data = layer(input_data, **call_kwargs)
612
+ output_spec = layer.compute_output_spec(input_spec)
613
+ for tensor, spec in zip(
614
+ tree.flatten(output_data), tree.flatten(output_spec)
615
+ ):
616
+ dtype = standardize_dtype(tensor.dtype)
617
+ self.assertEqual(
618
+ dtype,
619
+ spec.dtype,
620
+ f"expected output dtype {spec.dtype}, got {dtype}",
621
+ )
622
+ for weight in layer.weights:
623
+ dtype = standardize_dtype(weight.dtype)
624
+ if is_float_dtype(dtype):
625
+ self.assertEqual(dtype, "float32")
626
+
627
+
628
+ def tensorflow_uses_gpu():
629
+ return backend.backend() == "tensorflow" and uses_gpu()
630
+
631
+
632
+ def jax_uses_gpu():
633
+ return backend.backend() == "jax" and uses_gpu()
634
+
635
+
636
+ def torch_uses_gpu():
637
+ if backend.backend() != "torch":
638
+ return False
639
+ from keras.src.backend.torch.core import get_device
640
+
641
+ return get_device() == "cuda"
642
+
643
+
644
+ def uses_gpu():
645
+ # Condition used to skip tests when using the GPU
646
+ devices = distribution.list_devices()
647
+ if any(d.startswith("gpu") for d in devices):
648
+ return True
649
+ return False
650
+
651
+
652
+ def create_keras_tensors(input_shape, dtype, sparse):
653
+ if isinstance(input_shape, dict):
654
+ return {
655
+ utils.removesuffix(k, "_shape"): KerasTensor(
656
+ v, dtype=dtype[k], sparse=sparse
657
+ )
658
+ for k, v in input_shape.items()
659
+ }
660
+ return map_shape_dtype_structure(
661
+ lambda shape, dt: KerasTensor(shape, dtype=dt, sparse=sparse),
662
+ input_shape,
663
+ dtype,
664
+ )
665
+
666
+
667
+ def create_eager_tensors(input_shape, dtype, sparse):
668
+ from keras.src.backend import random
669
+
670
+ if set(tree.flatten(dtype)).difference(
671
+ [
672
+ "float16",
673
+ "float32",
674
+ "float64",
675
+ "int8",
676
+ "uint8",
677
+ "int16",
678
+ "uint16",
679
+ "int32",
680
+ "uint32",
681
+ "int64",
682
+ "uint64",
683
+ ]
684
+ ):
685
+ raise ValueError(
686
+ "dtype must be a standard float or int dtype. "
687
+ f"Received: dtype={dtype}"
688
+ )
689
+
690
+ if sparse:
691
+ if backend.backend() == "tensorflow":
692
+ import tensorflow as tf
693
+
694
+ def create_fn(shape, dt):
695
+ rng = np.random.default_rng(0)
696
+ x = (4 * rng.standard_normal(shape)).astype(dt)
697
+ x = np.multiply(x, rng.random(shape) < 0.7)
698
+ return tf.sparse.from_dense(x)
699
+
700
+ elif backend.backend() == "jax":
701
+ import jax.experimental.sparse as jax_sparse
702
+
703
+ def create_fn(shape, dt):
704
+ rng = np.random.default_rng(0)
705
+ x = (4 * rng.standard_normal(shape)).astype(dt)
706
+ x = np.multiply(x, rng.random(shape) < 0.7)
707
+ return jax_sparse.BCOO.fromdense(x, n_batch=1)
708
+
709
+ else:
710
+ raise ValueError(
711
+ f"Sparse is unsupported with backend {backend.backend()}"
712
+ )
713
+
714
+ else:
715
+
716
+ def create_fn(shape, dt):
717
+ return ops.cast(
718
+ random.uniform(shape, dtype="float32") * 3, dtype=dt
719
+ )
720
+
721
+ if isinstance(input_shape, dict):
722
+ return {
723
+ utils.removesuffix(k, "_shape"): create_fn(v, dtype[k])
724
+ for k, v in input_shape.items()
725
+ }
726
+ return map_shape_dtype_structure(create_fn, input_shape, dtype)
727
+
728
+
729
+ def is_shape_tuple(x):
730
+ return isinstance(x, (list, tuple)) and all(
731
+ isinstance(e, (int, type(None))) for e in x
732
+ )
733
+
734
+
735
+ def map_shape_dtype_structure(fn, shape, dtype):
736
+ """Variant of tree.map_structure that operates on shape tuples."""
737
+ if is_shape_tuple(shape):
738
+ return fn(tuple(shape), dtype)
739
+ if isinstance(shape, list):
740
+ return [
741
+ map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype)
742
+ ]
743
+ if isinstance(shape, tuple):
744
+ return tuple(
745
+ map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype)
746
+ )
747
+ if isinstance(shape, dict):
748
+ return {
749
+ k: map_shape_dtype_structure(fn, v, dtype[k])
750
+ for k, v in shape.items()
751
+ }
752
+ else:
753
+ raise ValueError(
754
+ f"Cannot map function to unknown objects {shape} and {dtype}"
755
+ )
756
+
757
+
758
+ def get_seed_generators(layer):
759
+ """Get a List of all seed generators in the layer recursively."""
760
+ seed_generators = []
761
+ seen_ids = set()
762
+ for sublayer in layer._flatten_layers(True, True):
763
+ for sg in sublayer._seed_generators:
764
+ if id(sg) not in seen_ids:
765
+ seed_generators.append(sg)
766
+ seen_ids.add(id(sg))
767
+ return seed_generators
768
+
769
+
770
+ def to_json_with_tuples(value):
771
+ def _tuple_encode(obj):
772
+ if isinstance(obj, tuple):
773
+ return {"__class__": "tuple", "__value__": list(obj)}
774
+ if isinstance(obj, list):
775
+ return [_tuple_encode(e) for e in obj]
776
+ if isinstance(obj, dict):
777
+ return {key: _tuple_encode(value) for key, value in obj.items()}
778
+ return obj
779
+
780
+ class _PreserveTupleJsonEncoder(json.JSONEncoder):
781
+ def encode(self, obj):
782
+ obj = _tuple_encode(obj)
783
+ return super().encode(obj)
784
+
785
+ return _PreserveTupleJsonEncoder(sort_keys=True, indent=4).encode(value)
786
+
787
+
788
+ def from_json_with_tuples(value):
789
+ def _tuple_decode(obj):
790
+ if not isinstance(obj, dict):
791
+ return obj
792
+ if "__class__" not in obj or "__value__" not in obj:
793
+ return obj
794
+ return tuple(obj["__value__"])
795
+
796
+ return json.loads(value, object_hook=_tuple_decode)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_utils.py ADDED
@@ -0,0 +1,163 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def get_test_data(
5
+ train_samples, test_samples, input_shape, num_classes, random_seed=None
6
+ ):
7
+ """Generates balanced, stratified synthetic test data to train a model on.
8
+
9
+ Args:
10
+ train_samples: Integer, how many training samples to generate.
11
+ test_samples: Integer, how many test samples to generate.
12
+ input_shape: Tuple of integers, shape of the inputs.
13
+ num_classes: Integer, number of classes for the data and targets.
14
+ random_seed: Integer, random seed used by Numpy to generate data.
15
+
16
+ Returns:
17
+ A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
18
+ """
19
+ np.random.seed(random_seed)
20
+
21
+ # Total samples
22
+ total_samples = train_samples + test_samples
23
+
24
+ # Ensure that we generate a balanced dataset
25
+ samples_per_class = total_samples // num_classes
26
+ y = np.array(
27
+ [i for i in range(num_classes) for _ in range(samples_per_class)],
28
+ dtype=np.int32,
29
+ )
30
+
31
+ # Generate extra samples in a deterministic manner
32
+ extra_samples = total_samples - len(y)
33
+ y_extra = np.array(
34
+ [i % num_classes for i in range(extra_samples)], dtype=np.int64
35
+ )
36
+ y = np.concatenate([y, y_extra])
37
+
38
+ # Generate data
39
+ templates = 2 * num_classes * np.random.random((num_classes,) + input_shape)
40
+ x = np.zeros((total_samples,) + input_shape, dtype=np.float32)
41
+ for i in range(total_samples):
42
+ x[i] = templates[y[i]] + np.random.normal(
43
+ loc=0, scale=1.0, size=input_shape
44
+ )
45
+
46
+ # Shuffle the entire dataset to ensure randomness based on seed
47
+ indices = np.arange(total_samples)
48
+ np.random.shuffle(indices)
49
+ x, y = x[indices], y[indices]
50
+
51
+ # Stratified Shuffle Split
52
+ x_train, y_train, x_test, y_test = [], [], [], []
53
+ for cls in range(num_classes):
54
+ cls_indices = np.where(y == cls)[0]
55
+ np.random.shuffle(cls_indices)
56
+ train_count = int(train_samples / num_classes)
57
+
58
+ x_train.extend(x[cls_indices[:train_count]])
59
+ y_train.extend(y[cls_indices[:train_count]])
60
+
61
+ x_test.extend(x[cls_indices[train_count:]])
62
+ y_test.extend(y[cls_indices[train_count:]])
63
+
64
+ # Convert to numpy arrays
65
+ x_train, y_train = np.array(x_train), np.array(y_train)
66
+ x_test, y_test = np.array(x_test), np.array(y_test)
67
+
68
+ # Shuffle training and test sets after stratified split
69
+ train_indices = np.arange(len(x_train))
70
+ test_indices = np.arange(len(x_test))
71
+ np.random.shuffle(train_indices)
72
+ np.random.shuffle(test_indices)
73
+
74
+ x_train, y_train = x_train[train_indices], y_train[train_indices]
75
+ x_test, y_test = x_test[test_indices], y_test[test_indices]
76
+
77
+ return (x_train, y_train), (x_test, y_test)
78
+
79
+
80
+ def named_product(*args, **kwargs):
81
+ """Utility to generate the cartesian product of parameters values and
82
+ generate a test case names for each combination.
83
+
84
+ The result of this function is to be used with the
85
+ `@parameterized.named_parameters` decorator. It is a replacement for
86
+ `@parameterized.product` which adds explicit test case names.
87
+
88
+ For example, this code:
89
+ ```
90
+ class NamedExample(parameterized.TestCase):
91
+ @parameterized.named_parameters(
92
+ named_product(
93
+ [
94
+ {'testcase_name': 'negative', 'x': -1},
95
+ {'testcase_name': 'positive', 'x': 1},
96
+ {'testcase_name': 'zero', 'x': 0},
97
+ ],
98
+ numeral_type=[float, int],
99
+ )
100
+ )
101
+ def test_conversion(self, x, numeral_type):
102
+ self.assertEqual(numeral_type(x), x)
103
+ ```
104
+ produces six tests (note that absl will reorder them by name):
105
+ - `NamedExample::test_conversion_negative_float`
106
+ - `NamedExample::test_conversion_positive_float`
107
+ - `NamedExample::test_conversion_zero_float`
108
+ - `NamedExample::test_conversion_negative_int`
109
+ - `NamedExample::test_conversion_positive_int`
110
+ - `NamedExample::test_conversion_zero_int`
111
+
112
+ This function is also useful in the case where there is no product to
113
+ generate test case names for one argument:
114
+ ```
115
+ @parameterized.named_parameters(named_product(numeral_type=[float, int]))
116
+ ```
117
+
118
+ Args:
119
+ *args: Each positional parameter is a sequence of keyword arg dicts.
120
+ Every test case generated will include exactly one dict from each
121
+ positional parameter. These will then be merged to form an overall
122
+ list of arguments for the test case. Each dict must contain a
123
+ `"testcase_name"` key whose value is combined with others to
124
+ generate the test case name.
125
+ **kwargs: A mapping of parameter names and their possible values.
126
+ Possible values should given as either a list or a tuple. A string
127
+ representation of each value is used to generate the test case name.
128
+
129
+ Returns:
130
+ A list of maps for the test parameters combinations to pass to
131
+ `@parameterized.named_parameters`.
132
+ """
133
+
134
+ def value_to_str(value):
135
+ if hasattr(value, "__name__"):
136
+ return value.__name__.lower()
137
+ return str(value).lower()
138
+
139
+ # Convert the keyword arguments in the same dict format as the args
140
+ all_test_dicts = args + tuple(
141
+ tuple({"testcase_name": value_to_str(v), key: v} for v in values)
142
+ for key, values in kwargs.items()
143
+ )
144
+
145
+ # The current list of tests, start with one empty test
146
+ tests = [{}]
147
+ for test_dicts in all_test_dicts:
148
+ new_tests = []
149
+ for test_dict in test_dicts:
150
+ for test in tests:
151
+ # Augment the testcase name by appending
152
+ testcase_name = test.get("testcase_name", "")
153
+ testcase_name += "_" if testcase_name else ""
154
+ testcase_name += test_dict["testcase_name"]
155
+ new_test = test.copy()
156
+ # Augment the test by adding all the parameters
157
+ new_test.update(test_dict)
158
+ new_test["testcase_name"] = testcase_name
159
+ new_tests.append(new_test)
160
+ # Overwrite the list of tests with the product obtained so far
161
+ tests = new_tests
162
+
163
+ return tests
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__init__.py ADDED
File without changes
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (194 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/compile_utils.cpython-310.pyc ADDED
Binary file (20.7 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc ADDED
Binary file (4.34 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc ADDED
Binary file (46.1 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py ADDED
@@ -0,0 +1,820 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import namedtuple
2
+
3
+ from keras.src import losses as losses_module
4
+ from keras.src import metrics as metrics_module
5
+ from keras.src import ops
6
+ from keras.src import tree
7
+ from keras.src.backend.common.keras_tensor import KerasTensor
8
+ from keras.src.losses import loss as loss_module
9
+ from keras.src.utils.naming import get_object_name
10
+ from keras.src.utils.tracking import Tracker
11
+
12
+
13
+ class MetricsList(metrics_module.Metric):
14
+ def __init__(self, metrics, name="metrics_list", output_name=None):
15
+ super().__init__(name=name)
16
+ self.metrics = metrics
17
+ self.output_name = output_name
18
+
19
+ def update_state(self, y_true, y_pred, sample_weight=None):
20
+ for m in self.metrics:
21
+ m.update_state(y_true, y_pred, sample_weight=sample_weight)
22
+
23
+ def reset_state(self):
24
+ for m in self.metrics:
25
+ m.reset_state()
26
+
27
+ def get_result(self):
28
+ return {m.name: m.result() for m in self.metrics}
29
+
30
+ def get_config(self):
31
+ raise NotImplementedError
32
+
33
+ @classmethod
34
+ def from_config(cls, config):
35
+ raise NotImplementedError
36
+
37
+
38
+ def is_function_like(value):
39
+ if value is None:
40
+ return True
41
+ if isinstance(value, str):
42
+ return True
43
+ if callable(value):
44
+ return True
45
+ return False
46
+
47
+
48
+ def is_binary_or_sparse_categorical(y_true, y_pred):
49
+ y_t_rank = len(y_true.shape)
50
+ y_p_rank = len(y_pred.shape)
51
+ y_t_last_dim = y_true.shape[-1]
52
+ y_p_last_dim = y_pred.shape[-1]
53
+
54
+ is_binary = y_p_last_dim == 1
55
+ is_sparse_categorical = (
56
+ y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1
57
+ )
58
+ return is_binary, is_sparse_categorical
59
+
60
+
61
+ def get_metric(identifier, y_true, y_pred):
62
+ if identifier is None:
63
+ return None # Ok to have no metric for an output.
64
+
65
+ # Convenience feature for selecting b/t binary, categorical,
66
+ # and sparse categorical.
67
+ if str(identifier).lower() not in ["accuracy", "acc"]:
68
+ metric_obj = metrics_module.get(identifier)
69
+ else:
70
+ is_binary, is_sparse_categorical = is_binary_or_sparse_categorical(
71
+ y_true, y_pred
72
+ )
73
+ if is_binary:
74
+ metric_obj = metrics_module.BinaryAccuracy(name=str(identifier))
75
+ elif is_sparse_categorical:
76
+ metric_obj = metrics_module.SparseCategoricalAccuracy(
77
+ name=str(identifier)
78
+ )
79
+ else:
80
+ metric_obj = metrics_module.CategoricalAccuracy(
81
+ name=str(identifier)
82
+ )
83
+
84
+ if isinstance(identifier, str):
85
+ metric_name = identifier
86
+ else:
87
+ metric_name = get_object_name(metric_obj)
88
+
89
+ if not isinstance(metric_obj, metrics_module.Metric):
90
+ metric_obj = metrics_module.MeanMetricWrapper(metric_obj)
91
+
92
+ metric_obj.name = metric_name
93
+ return metric_obj
94
+
95
+
96
+ def get_loss(identifier, y_true, y_pred):
97
+ if identifier is None:
98
+ return None # Ok to have no loss for an output.
99
+
100
+ # Convenience feature for selecting b/t binary, categorical,
101
+ # and sparse categorical.
102
+ if str(identifier).lower() not in ["crossentropy", "ce"]:
103
+ loss_obj = losses_module.get(identifier)
104
+ else:
105
+ is_binary, is_sparse_categorical = is_binary_or_sparse_categorical(
106
+ y_true, y_pred
107
+ )
108
+ if is_binary:
109
+ loss_obj = losses_module.binary_crossentropy
110
+ elif is_sparse_categorical:
111
+ loss_obj = losses_module.sparse_categorical_crossentropy
112
+ else:
113
+ loss_obj = losses_module.categorical_crossentropy
114
+
115
+ if not isinstance(loss_obj, losses_module.Loss):
116
+ if isinstance(identifier, str):
117
+ loss_name = identifier
118
+ else:
119
+ loss_name = get_object_name(loss_obj)
120
+ loss_obj = losses_module.LossFunctionWrapper(loss_obj, name=loss_name)
121
+ return loss_obj
122
+
123
+
124
+ class CompileMetrics(metrics_module.Metric):
125
+ def __init__(
126
+ self,
127
+ metrics,
128
+ weighted_metrics,
129
+ name="compile_metric",
130
+ output_names=None,
131
+ ):
132
+ super().__init__(name=name)
133
+ if metrics and not isinstance(metrics, (list, tuple, dict)):
134
+ raise ValueError(
135
+ "Expected `metrics` argument to be a list, tuple, or dict. "
136
+ f"Received instead: metrics={metrics} of type {type(metrics)}"
137
+ )
138
+ if weighted_metrics and not isinstance(
139
+ weighted_metrics, (list, tuple, dict)
140
+ ):
141
+ raise ValueError(
142
+ "Expected `weighted_metrics` argument to be a list, tuple, or "
143
+ f"dict. Received instead: weighted_metrics={weighted_metrics} "
144
+ f"of type {type(weighted_metrics)}"
145
+ )
146
+ self._user_metrics = metrics
147
+ self._user_weighted_metrics = weighted_metrics
148
+ self.built = False
149
+ self.name = "compile_metrics"
150
+ self.output_names = output_names
151
+
152
+ @property
153
+ def metrics(self):
154
+ if not self.built:
155
+ return []
156
+ metrics = []
157
+ for m in self._flat_metrics + self._flat_weighted_metrics:
158
+ if isinstance(m, MetricsList):
159
+ metrics.extend(m.metrics)
160
+ elif m is not None:
161
+ metrics.append(m)
162
+ return metrics
163
+
164
+ @property
165
+ def variables(self):
166
+ # Avoiding relying on implicit tracking since
167
+ # CompileMetrics may be instantiated or built in a no tracking scope.
168
+ if not self.built:
169
+ return []
170
+ vars = []
171
+ for m in self.metrics:
172
+ if m is not None:
173
+ vars.extend(m.variables)
174
+ return vars
175
+
176
+ def build(self, y_true, y_pred):
177
+ num_outputs = 1 # default
178
+ if self.output_names:
179
+ output_names = self.output_names
180
+ elif isinstance(y_pred, dict):
181
+ output_names = sorted(list(y_pred.keys()))
182
+ elif isinstance(y_pred, (list, tuple)):
183
+ num_outputs = len(y_pred)
184
+ if all(hasattr(x, "_keras_history") for x in y_pred):
185
+ output_names = [x._keras_history.operation.name for x in y_pred]
186
+ else:
187
+ output_names = None
188
+ else:
189
+ output_names = None
190
+ if output_names:
191
+ num_outputs = len(output_names)
192
+
193
+ y_pred = self._flatten_y(y_pred)
194
+ y_true = self._flatten_y(y_true)
195
+
196
+ metrics = self._user_metrics
197
+ weighted_metrics = self._user_weighted_metrics
198
+ self._flat_metrics = self._build_metrics_set(
199
+ metrics,
200
+ num_outputs,
201
+ output_names,
202
+ y_true,
203
+ y_pred,
204
+ argument_name="metrics",
205
+ )
206
+ self._flat_weighted_metrics = self._build_metrics_set(
207
+ weighted_metrics,
208
+ num_outputs,
209
+ output_names,
210
+ y_true,
211
+ y_pred,
212
+ argument_name="weighted_metrics",
213
+ )
214
+ self.built = True
215
+
216
+ def _build_metrics_set(
217
+ self, metrics, num_outputs, output_names, y_true, y_pred, argument_name
218
+ ):
219
+ flat_metrics = []
220
+ if isinstance(metrics, dict):
221
+ for name in metrics.keys():
222
+ if name not in output_names:
223
+ raise ValueError(
224
+ f"In the dict argument `{argument_name}`, key "
225
+ f"'{name}' does not correspond to any model "
226
+ f"output. Received:\n{argument_name}={metrics}"
227
+ )
228
+ if num_outputs == 1:
229
+ if not metrics:
230
+ flat_metrics.append(None)
231
+ else:
232
+ if isinstance(metrics, dict):
233
+ metrics = tree.flatten(metrics)
234
+ if not isinstance(metrics, list):
235
+ metrics = [metrics]
236
+ if not all(is_function_like(m) for m in metrics):
237
+ raise ValueError(
238
+ f"Expected all entries in the `{argument_name}` list "
239
+ f"to be metric objects. Received instead:\n"
240
+ f"{argument_name}={metrics}"
241
+ )
242
+ flat_metrics.append(
243
+ MetricsList(
244
+ [
245
+ get_metric(m, y_true[0], y_pred[0])
246
+ for m in metrics
247
+ if m is not None
248
+ ]
249
+ )
250
+ )
251
+ else:
252
+ if isinstance(metrics, (list, tuple)):
253
+ if len(metrics) != len(y_pred):
254
+ raise ValueError(
255
+ "For a model with multiple outputs, "
256
+ f"when providing the `{argument_name}` argument as a "
257
+ "list, it should have as many entries as the model has "
258
+ f"outputs. Received:\n{argument_name}={metrics}\nof "
259
+ f"length {len(metrics)} whereas the model has "
260
+ f"{len(y_pred)} outputs."
261
+ )
262
+ for idx, (mls, yt, yp) in enumerate(
263
+ zip(metrics, y_true, y_pred)
264
+ ):
265
+ if not isinstance(mls, list):
266
+ mls = [mls]
267
+ name = output_names[idx] if output_names else None
268
+ if not all(is_function_like(e) for e in mls):
269
+ raise ValueError(
270
+ f"All entries in the sublists of the "
271
+ f"`{argument_name}` list should be metric objects. "
272
+ f"Found the following sublist with unknown "
273
+ f"types: {mls}"
274
+ )
275
+ flat_metrics.append(
276
+ MetricsList(
277
+ [
278
+ get_metric(m, yt, yp)
279
+ for m in mls
280
+ if m is not None
281
+ ],
282
+ output_name=name,
283
+ )
284
+ )
285
+ elif isinstance(metrics, dict):
286
+ if output_names is None:
287
+ raise ValueError(
288
+ f"Argument `{argument_name}` can only be provided as a "
289
+ "dict when the model also returns a dict of outputs. "
290
+ f"Received {argument_name}={metrics}"
291
+ )
292
+ for name in metrics.keys():
293
+ if not isinstance(metrics[name], list):
294
+ metrics[name] = [metrics[name]]
295
+ if not all(is_function_like(e) for e in metrics[name]):
296
+ raise ValueError(
297
+ f"All entries in the sublists of the "
298
+ f"`{argument_name}` dict should be metric objects. "
299
+ f"At key '{name}', found the following sublist "
300
+ f"with unknown types: {metrics[name]}"
301
+ )
302
+ for name, yt, yp in zip(output_names, y_true, y_pred):
303
+ if name in metrics:
304
+ flat_metrics.append(
305
+ MetricsList(
306
+ [
307
+ get_metric(m, yt, yp)
308
+ for m in metrics[name]
309
+ if m is not None
310
+ ],
311
+ output_name=name,
312
+ )
313
+ )
314
+ else:
315
+ flat_metrics.append(None)
316
+ return flat_metrics
317
+
318
+ def _flatten_y(self, y):
319
+ if isinstance(y, dict) and self.output_names:
320
+ result = []
321
+ for name in self.output_names:
322
+ if name in y:
323
+ result.append(y[name])
324
+ return result
325
+ return tree.flatten(y)
326
+
327
+ def update_state(self, y_true, y_pred, sample_weight=None):
328
+ if not self.built:
329
+ self.build(y_true, y_pred)
330
+ y_true = self._flatten_y(y_true)
331
+ y_pred = self._flatten_y(y_pred)
332
+ for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred):
333
+ if m:
334
+ m.update_state(y_t, y_p)
335
+ if sample_weight is not None:
336
+ sample_weight = self._flatten_y(sample_weight)
337
+ # For multi-outputs, repeat sample weights for n outputs.
338
+ if len(sample_weight) < len(y_true):
339
+ sample_weight = [sample_weight[0] for _ in range(len(y_true))]
340
+ else:
341
+ sample_weight = [None for _ in range(len(y_true))]
342
+ for m, y_t, y_p, s_w in zip(
343
+ self._flat_weighted_metrics, y_true, y_pred, sample_weight
344
+ ):
345
+ if m:
346
+ m.update_state(y_t, y_p, s_w)
347
+
348
+ def reset_state(self):
349
+ if not self.built:
350
+ return
351
+ for m in self._flat_metrics:
352
+ if m:
353
+ m.reset_state()
354
+ for m in self._flat_weighted_metrics:
355
+ if m:
356
+ m.reset_state()
357
+
358
+ def result(self):
359
+ if not self.built:
360
+ raise ValueError(
361
+ "Cannot get result() since the metric has not yet been built."
362
+ )
363
+ results = {}
364
+ unique_name_counters = {}
365
+ for mls in self._flat_metrics:
366
+ if not mls:
367
+ continue
368
+ for m in mls.metrics:
369
+ name = m.name
370
+ if mls.output_name:
371
+ name = f"{mls.output_name}_{name}"
372
+ if name not in unique_name_counters:
373
+ results[name] = m.result()
374
+ unique_name_counters[name] = 1
375
+ else:
376
+ index = unique_name_counters[name]
377
+ unique_name_counters[name] += 1
378
+ name = f"{name}_{index}"
379
+ results[name] = m.result()
380
+
381
+ for mls in self._flat_weighted_metrics:
382
+ if not mls:
383
+ continue
384
+ for m in mls.metrics:
385
+ name = m.name
386
+ if mls.output_name:
387
+ name = f"{mls.output_name}_{name}"
388
+ if name not in unique_name_counters:
389
+ results[name] = m.result()
390
+ unique_name_counters[name] = 1
391
+ else:
392
+ name = f"weighted_{m.name}"
393
+ if mls.output_name:
394
+ name = f"{mls.output_name}_{name}"
395
+ if name not in unique_name_counters:
396
+ unique_name_counters[name] = 1
397
+ else:
398
+ index = unique_name_counters[name]
399
+ unique_name_counters[name] += 1
400
+ name = f"{name}_{index}"
401
+ results[name] = m.result()
402
+ return results
403
+
404
+ def get_config(self):
405
+ raise NotImplementedError
406
+
407
+ @classmethod
408
+ def from_config(cls, config):
409
+ raise NotImplementedError
410
+
411
+
412
+ class CompileLoss(losses_module.Loss):
413
+ Loss = namedtuple("Loss", ["path", "loss", "loss_weights", "name"])
414
+
415
+ def __init__(
416
+ self,
417
+ loss,
418
+ loss_weights=None,
419
+ reduction="sum_over_batch_size",
420
+ output_names=None,
421
+ ):
422
+ if loss_weights and not isinstance(
423
+ loss_weights, (list, tuple, dict, float)
424
+ ):
425
+ raise ValueError(
426
+ "Expected `loss_weights` argument to be a float "
427
+ "(single output case) or a list, tuple, or "
428
+ "dict (multiple output case). "
429
+ f"Received instead: loss_weights={loss_weights} "
430
+ f"of type {type(loss_weights)}"
431
+ )
432
+ self._user_loss = loss
433
+ self._user_loss_weights = loss_weights
434
+ self.built = False
435
+ self.output_names = output_names
436
+ super().__init__(name="compile_loss", reduction=reduction)
437
+
438
+ # Use `Tracker` to track metrics for individual losses.
439
+ self._metrics = []
440
+ self._tracker = Tracker(
441
+ {
442
+ "metrics": (
443
+ lambda x: isinstance(x, metrics_module.Metric),
444
+ self._metrics,
445
+ )
446
+ }
447
+ )
448
+ self._flat_losses = None
449
+ self._y_pred_build_structure = None
450
+ self._y_true_build_structure = None
451
+
452
+ @property
453
+ def metrics(self):
454
+ return self._metrics
455
+
456
+ @property
457
+ def variables(self):
458
+ vars = []
459
+ for m in self.metrics:
460
+ vars.extend(m.variables)
461
+ return vars
462
+
463
+ def _build_nested(self, y_true, y_pred, loss, output_names, current_path):
464
+ flat_y_pred = tree.flatten(y_pred)
465
+ if not tree.is_nested(loss):
466
+ _loss = loss.loss
467
+ if _loss is None:
468
+ return
469
+ loss_weight = loss.weight
470
+ resolved_loss = get_loss(_loss, y_true, y_pred)
471
+ name_path = current_path
472
+ if not tree.is_nested(output_names):
473
+ if output_names is not None:
474
+ output_name = output_names
475
+ else:
476
+ output_name = resolved_loss.name
477
+ if len(name_path) == 0:
478
+ name_path = (output_name,)
479
+ elif isinstance(name_path[-1], int):
480
+ name_path = name_path[:-1] + (output_name,)
481
+ name = "/".join([str(path) for path in name_path])
482
+ if name == "":
483
+ if isinstance(output_names, dict):
484
+ flat_output_names = list(output_names.keys())
485
+ else:
486
+ flat_output_names = tree.flatten(output_names)
487
+ name = "_".join(flat_output_names)
488
+ self._flat_losses.append(
489
+ CompileLoss.Loss(current_path, resolved_loss, loss_weight, name)
490
+ )
491
+ return
492
+ elif (
493
+ issubclass(type(loss), (list, tuple))
494
+ and all([not tree.is_nested(_loss) for _loss in loss])
495
+ and len(loss) == len(flat_y_pred)
496
+ ):
497
+ loss = tree.pack_sequence_as(y_pred, loss)
498
+ elif issubclass(type(loss), (list, tuple)) and not isinstance(
499
+ y_pred, type(loss)
500
+ ):
501
+ for _loss in loss:
502
+ self._build_nested(
503
+ y_true,
504
+ y_pred,
505
+ _loss,
506
+ output_names,
507
+ current_path,
508
+ )
509
+ return
510
+
511
+ if not tree.is_nested(loss):
512
+ return self._build_nested(
513
+ y_true, y_pred, loss, output_names, current_path
514
+ )
515
+
516
+ if not isinstance(loss, type(y_pred)):
517
+ raise KeyError(
518
+ f"The path: {current_path} in "
519
+ "the `loss` argument, can't be found in "
520
+ "the model's output (`y_pred`)."
521
+ )
522
+
523
+ # shallow traverse the loss config
524
+ if isinstance(loss, dict):
525
+ iterator = loss.items()
526
+
527
+ def key_check_fn(key, objs):
528
+ return all(
529
+ [isinstance(obj, dict) and key in obj for obj in objs]
530
+ )
531
+
532
+ elif issubclass(type(loss), (list, tuple)):
533
+ iterator = enumerate(loss)
534
+
535
+ def key_check_fn(key, objs):
536
+ return all(
537
+ [
538
+ issubclass(type(obj), (list, tuple)) and key < len(obj)
539
+ for obj in objs
540
+ ]
541
+ )
542
+
543
+ else:
544
+ raise TypeError(
545
+ f"Unsupported type {type(loss)} "
546
+ f"in the `loss` configuration."
547
+ )
548
+
549
+ for key, _loss in iterator:
550
+ if _loss is None:
551
+ continue
552
+ if not key_check_fn(key, (y_true, y_pred)):
553
+ raise KeyError(
554
+ f"The path: {current_path + (key,)} in "
555
+ "the `loss` argument, can't be found in "
556
+ "either the model's output (`y_pred`) or in the "
557
+ "labels (`y_true`)."
558
+ )
559
+
560
+ self._build_nested(
561
+ y_true[key],
562
+ y_pred[key],
563
+ _loss,
564
+ output_names[key],
565
+ current_path + (key,),
566
+ )
567
+
568
+ def build(self, y_true, y_pred):
569
+ loss = self._user_loss
570
+ loss_weights = self._user_loss_weights
571
+ flat_output_names = self.output_names
572
+ if (
573
+ self.output_names
574
+ and isinstance(self._user_loss, dict)
575
+ and not isinstance(y_pred, dict)
576
+ ):
577
+ if set(self.output_names) == set(self._user_loss.keys()):
578
+ loss = [self._user_loss[name] for name in self.output_names]
579
+ if isinstance(self._user_loss_weights, dict):
580
+ loss_weights = [
581
+ self._user_loss_weights[name]
582
+ for name in self.output_names
583
+ ]
584
+ else:
585
+ raise ValueError(
586
+ f"Expected keys {self.output_names} in loss dict, but "
587
+ f"found loss.keys()={list(self._user_loss.keys())}"
588
+ )
589
+
590
+ # Pytree leaf container
591
+ class WeightedLoss:
592
+ def __new__(cls, loss, weight):
593
+ if loss is None:
594
+ return None
595
+ return object.__new__(cls)
596
+
597
+ def __init__(self, loss, weight):
598
+ self.loss = loss
599
+ self.weight = weight
600
+
601
+ # pack the losses and the weights together
602
+ if loss_weights is not None:
603
+ try:
604
+ tree.assert_same_structure(loss, loss_weights)
605
+ except ValueError:
606
+ flat_loss_weights = tree.flatten(loss_weights)
607
+ if len(tree.flatten(loss)) != len(flat_loss_weights):
608
+ raise ValueError(
609
+ f"`loss_weights` must match the number of losses, "
610
+ f"got {len(tree.flatten(loss))} losses "
611
+ f"and {len(loss_weights)} weights."
612
+ )
613
+ loss_weights = tree.pack_sequence_as(loss, flat_loss_weights)
614
+ loss = tree.map_structure(
615
+ lambda _loss, _weight: WeightedLoss(_loss, _weight),
616
+ loss,
617
+ loss_weights,
618
+ )
619
+ else:
620
+ loss = tree.map_structure(
621
+ lambda _loss: WeightedLoss(_loss, None), loss
622
+ )
623
+
624
+ self._flat_losses = []
625
+
626
+ if (
627
+ isinstance(loss, dict)
628
+ and issubclass(type(y_pred), (list, tuple))
629
+ and set(loss.keys()) == set(flat_output_names)
630
+ and len(y_pred) == len(flat_output_names)
631
+ ):
632
+ y_pred = {name: y_p for name, y_p in zip(flat_output_names, y_pred)}
633
+ y_true = {name: y_t for name, y_t in zip(flat_output_names, y_true)}
634
+ elif (
635
+ isinstance(loss, dict)
636
+ and not tree.is_nested(y_pred)
637
+ and set(loss.keys()) == set(flat_output_names)
638
+ and len(flat_output_names) == 1
639
+ ):
640
+ y_pred = {
641
+ name: y_p for name, y_p in zip(flat_output_names, [y_pred])
642
+ }
643
+ y_true = {
644
+ name: y_t for name, y_t in zip(flat_output_names, [y_true])
645
+ }
646
+
647
+ try:
648
+ output_names = tree.pack_sequence_as(y_pred, flat_output_names)
649
+ except:
650
+ inferred_flat_output_names = self._get_y_pred_output_names(y_pred)
651
+ output_names = tree.pack_sequence_as(
652
+ y_pred, inferred_flat_output_names
653
+ )
654
+
655
+ if not tree.is_nested(loss):
656
+ loss = tree.map_structure(lambda x: loss, y_pred)
657
+
658
+ self._build_nested(y_true, y_pred, loss, output_names, ())
659
+
660
+ # Add `Mean` metric to the tracker for each loss.
661
+ if len(self._flat_losses) > 1:
662
+ for _loss in self._flat_losses:
663
+ name = _loss.name + "_loss"
664
+ self._tracker.add_to_store(
665
+ "metrics", metrics_module.Mean(name=name)
666
+ )
667
+
668
+ self._y_pred_build_structure = tree.map_structure(
669
+ lambda x: None, y_pred
670
+ )
671
+ self._y_true_build_structure = tree.map_structure(
672
+ lambda x: None, y_true
673
+ )
674
+ self.built = True
675
+
676
+ def _get_y_pred_output_names(self, y_pred):
677
+ flat_y_pred = tree.flatten(y_pred)
678
+ if all((isinstance(x, KerasTensor) for x in flat_y_pred)):
679
+ output_names = []
680
+ for tensor in flat_y_pred:
681
+ if hasattr(tensor, "_keras_history"):
682
+ output_names.append(tensor._keras_history.operation.name)
683
+ else:
684
+ output_names.append(tensor.name)
685
+ else:
686
+ output_names = [None] * len(flat_y_pred)
687
+ return output_names
688
+
689
+ def __call__(self, y_true, y_pred, sample_weight=None):
690
+ with ops.name_scope(self.name):
691
+ return self.call(y_true, y_pred, sample_weight)
692
+
693
+ def call(self, y_true, y_pred, sample_weight=None):
694
+ if not tree.is_nested(y_true) and not tree.is_nested(y_pred):
695
+ # Fast path: single output case / no loss-tracking metric.
696
+ if not self.built:
697
+ self.build(y_true, y_pred)
698
+ _, loss_fn, loss_weight, _ = self._flat_losses[0]
699
+ loss_value = ops.cast(
700
+ loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype
701
+ )
702
+ if loss_weight is not None:
703
+ loss_value = ops.multiply(loss_value, loss_weight)
704
+ return loss_value
705
+
706
+ try:
707
+ tree.assert_same_structure(y_pred, y_true)
708
+ except ValueError:
709
+ # Check case where y_true is either flat or leaf
710
+ if (
711
+ not tree.is_nested(y_true)
712
+ and hasattr(y_pred, "__len__")
713
+ and len(y_pred) == 1
714
+ ):
715
+ y_true = [y_true]
716
+
717
+ # Check case where y_pred is list/tuple and y_true is dict
718
+ elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict):
719
+ if set(self.output_names) == set(y_true.keys()):
720
+ y_true = [y_true[name] for name in self.output_names]
721
+
722
+ try:
723
+ y_true = tree.pack_sequence_as(y_pred, y_true)
724
+ except:
725
+ # Check case where y_true has the same structure but uses
726
+ # different (but reconcilable) container types,
727
+ # e.g `list` vs `tuple`.
728
+ try:
729
+ tree.assert_same_paths(y_true, y_pred)
730
+ y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true))
731
+ except:
732
+ try:
733
+ # Check case where loss is partially defined over y_pred
734
+ flat_y_true = tree.flatten(y_true)
735
+ flat_loss = tree.flatten(self._user_loss)
736
+ flat_loss_non_nones = [
737
+ (i, loss)
738
+ for i, loss in enumerate(flat_loss)
739
+ if loss is not None
740
+ ]
741
+ assert len(flat_y_true) == len(flat_loss_non_nones)
742
+ y_true = [None] * len(flat_loss)
743
+ for y_t, (i, loss) in zip(
744
+ flat_y_true, flat_loss_non_nones
745
+ ):
746
+ y_true[i] = y_t
747
+ y_true = tree.pack_sequence_as(self._user_loss, y_true)
748
+ except:
749
+ y_true_struct = tree.map_structure(
750
+ lambda _: "*", y_true
751
+ )
752
+ y_pred_struct = tree.map_structure(
753
+ lambda _: "*", y_pred
754
+ )
755
+ raise ValueError(
756
+ "y_true and y_pred have different structures.\n"
757
+ f"y_true: {y_true_struct}\n"
758
+ f"y_pred: {y_pred_struct}\n"
759
+ )
760
+
761
+ if not self.built:
762
+ self.build(y_true, y_pred)
763
+
764
+ try:
765
+ tree.assert_same_structure(self._y_pred_build_structure, y_pred)
766
+ except ValueError:
767
+ y_pred = tree.pack_sequence_as(
768
+ self._y_pred_build_structure, tree.flatten(y_pred)
769
+ )
770
+ try:
771
+ tree.assert_same_structure(self._y_true_build_structure, y_true)
772
+ except ValueError:
773
+ y_true = tree.pack_sequence_as(
774
+ self._y_true_build_structure, tree.flatten(y_true)
775
+ )
776
+
777
+ # We need to add a dummy `None` if the model has only a single output.
778
+ metrics = [None] if len(self.metrics) == 0 else self.metrics
779
+
780
+ # Iterate all losses in flat form.
781
+ loss_values = []
782
+
783
+ def resolve_path(path, object):
784
+ for _path in path:
785
+ object = object[_path]
786
+ return object
787
+
788
+ for (path, loss_fn, loss_weight, _), metric in zip(
789
+ self._flat_losses, metrics
790
+ ):
791
+ y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred)
792
+ if sample_weight is not None and tree.is_nested(sample_weight):
793
+ _sample_weight = resolve_path(path, sample_weight)
794
+ else:
795
+ _sample_weight = sample_weight
796
+
797
+ value = ops.cast(
798
+ loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype
799
+ )
800
+ # Record *unweighted* individual losses.
801
+ if metric:
802
+ metric.update_state(
803
+ loss_module.unscale_loss_for_distribution(value),
804
+ sample_weight=tree.flatten(y_p)[0].shape[0],
805
+ )
806
+ if loss_weight is not None:
807
+ value = ops.multiply(value, loss_weight)
808
+ loss_values.append(value)
809
+
810
+ if loss_values:
811
+ total_loss = sum(loss_values)
812
+ return total_loss
813
+ return None
814
+
815
+ def get_config(self):
816
+ raise NotImplementedError
817
+
818
+ @classmethod
819
+ def from_config(cls, config):
820
+ raise NotImplementedError
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py ADDED
@@ -0,0 +1,154 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import types
2
+
3
+ from keras.src.distribution import distribution_lib
4
+ from keras.src.trainers.data_adapters import array_data_adapter
5
+ from keras.src.trainers.data_adapters import data_adapter
6
+ from keras.src.trainers.data_adapters import py_dataset_adapter
7
+ from keras.src.trainers.data_adapters.array_data_adapter import ArrayDataAdapter
8
+ from keras.src.trainers.data_adapters.generator_data_adapter import (
9
+ GeneratorDataAdapter,
10
+ )
11
+ from keras.src.trainers.data_adapters.py_dataset_adapter import PyDatasetAdapter
12
+ from keras.src.trainers.data_adapters.tf_dataset_adapter import TFDatasetAdapter
13
+ from keras.src.trainers.data_adapters.torch_data_loader_adapter import (
14
+ TorchDataLoaderAdapter,
15
+ )
16
+
17
+
18
+ def get_data_adapter(
19
+ x,
20
+ y=None,
21
+ sample_weight=None,
22
+ batch_size=None,
23
+ steps_per_epoch=None,
24
+ shuffle=False,
25
+ class_weight=None,
26
+ ):
27
+ # Allow passing a custom data adapter.
28
+ if isinstance(x, data_adapter.DataAdapter):
29
+ return x
30
+
31
+ # Check for multi-process/worker distribution. Since only tf.dataset
32
+ # is supported at the moment, we will raise error if the inputs fail
33
+ # the type check
34
+ distribution = distribution_lib.distribution()
35
+ if getattr(distribution, "_is_multi_process", False) and not is_tf_dataset(
36
+ x
37
+ ):
38
+ raise ValueError(
39
+ "When using multi-worker distribution, the data must be provided "
40
+ f"as a `tf.data.Dataset` instance. Received: type(x)={type(x)}."
41
+ )
42
+
43
+ if array_data_adapter.can_convert_arrays((x, y, sample_weight)):
44
+ return ArrayDataAdapter(
45
+ x,
46
+ y,
47
+ sample_weight=sample_weight,
48
+ class_weight=class_weight,
49
+ shuffle=shuffle,
50
+ batch_size=batch_size,
51
+ steps=steps_per_epoch,
52
+ )
53
+ elif is_tf_dataset(x):
54
+ # Unsupported args: y, sample_weight, shuffle
55
+ if y is not None:
56
+ raise_unsupported_arg("y", "the targets", "tf.data.Dataset")
57
+ if sample_weight is not None:
58
+ raise_unsupported_arg(
59
+ "sample_weights", "the sample weights", "tf.data.Dataset"
60
+ )
61
+ return TFDatasetAdapter(
62
+ x, class_weight=class_weight, distribution=distribution
63
+ )
64
+ # TODO: should we warn or not?
65
+ # warnings.warn(
66
+ # "`shuffle=True` was passed, but will be ignored since the "
67
+ # "data `x` was provided as a tf.data.Dataset. The Dataset is "
68
+ # "expected to already be shuffled "
69
+ # "(via `.shuffle(tf.data.AUTOTUNE)`)"
70
+ # )
71
+ elif isinstance(x, py_dataset_adapter.PyDataset):
72
+ if y is not None:
73
+ raise_unsupported_arg("y", "the targets", "PyDataset")
74
+ if sample_weight is not None:
75
+ raise_unsupported_arg(
76
+ "sample_weights", "the sample weights", "PyDataset"
77
+ )
78
+ return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle)
79
+ # TODO: should we warn or not?
80
+ # if x.num_batches is None and shuffle:
81
+ # warnings.warn(
82
+ # "`shuffle=True` was passed, but will be ignored since the "
83
+ # "data `x` was provided as a infinite PyDataset. The "
84
+ # "PyDataset is expected to already be shuffled."
85
+ # )
86
+ elif is_torch_dataloader(x):
87
+ if y is not None:
88
+ raise_unsupported_arg("y", "the targets", "torch DataLoader")
89
+ if sample_weight is not None:
90
+ raise_unsupported_arg(
91
+ "sample_weights", "the sample weights", "torch DataLoader"
92
+ )
93
+ if class_weight is not None:
94
+ raise ValueError(
95
+ "Argument `class_weight` is not supported for torch "
96
+ f"DataLoader inputs. Received: class_weight={class_weight}"
97
+ )
98
+ return TorchDataLoaderAdapter(x)
99
+ # TODO: should we warn or not?
100
+ # warnings.warn(
101
+ # "`shuffle=True` was passed, but will be ignored since the "
102
+ # "data `x` was provided as a torch DataLoader. The DataLoader "
103
+ # "is expected to already be shuffled."
104
+ # )
105
+ elif isinstance(x, types.GeneratorType):
106
+ if y is not None:
107
+ raise_unsupported_arg("y", "the targets", "PyDataset")
108
+ if sample_weight is not None:
109
+ raise_unsupported_arg(
110
+ "sample_weights", "the sample weights", "PyDataset"
111
+ )
112
+ if class_weight is not None:
113
+ raise ValueError(
114
+ "Argument `class_weight` is not supported for Python "
115
+ f"generator inputs. Received: class_weight={class_weight}"
116
+ )
117
+ return GeneratorDataAdapter(x)
118
+ # TODO: should we warn or not?
119
+ # warnings.warn(
120
+ # "`shuffle=True` was passed, but will be ignored since the "
121
+ # "data `x` was provided as a generator. The generator "
122
+ # "is expected to yield already-shuffled data."
123
+ # )
124
+ else:
125
+ raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})")
126
+
127
+
128
+ def raise_unsupported_arg(arg_name, arg_description, input_type):
129
+ raise ValueError(
130
+ f"When providing `x` as a {input_type}, `{arg_name}` "
131
+ f"should not be passed. Instead, {arg_description} should "
132
+ f"be included as part of the {input_type}."
133
+ )
134
+
135
+
136
+ def is_tf_dataset(x):
137
+ if hasattr(x, "__class__"):
138
+ for parent in x.__class__.__mro__:
139
+ if parent.__name__ in (
140
+ "DatasetV2",
141
+ "DistributedDataset",
142
+ ) and "tensorflow.python." in str(parent.__module__):
143
+ return True
144
+ return False
145
+
146
+
147
+ def is_torch_dataloader(x):
148
+ if hasattr(x, "__class__"):
149
+ for parent in x.__class__.__mro__:
150
+ if parent.__name__ == "DataLoader" and "torch.utils.data" in str(
151
+ parent.__module__
152
+ ):
153
+ return True
154
+ return False
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (3.2 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc ADDED
Binary file (11.9 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc ADDED
Binary file (16.9 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc ADDED
Binary file (4.17 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc ADDED
Binary file (3.63 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc ADDED
Binary file (20.3 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc ADDED
Binary file (5.46 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc ADDED
Binary file (3.09 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py ADDED
@@ -0,0 +1,372 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import math
3
+
4
+ import numpy as np
5
+
6
+ from keras.src import tree
7
+ from keras.src.trainers.data_adapters import array_slicing
8
+ from keras.src.trainers.data_adapters import data_adapter_utils
9
+ from keras.src.trainers.data_adapters.data_adapter import DataAdapter
10
+
11
+
12
+ class ArrayDataAdapter(DataAdapter):
13
+ """Adapter for array-like objects, e.g. TF/JAX Tensors, NumPy arrays."""
14
+
15
+ def __init__(
16
+ self,
17
+ x,
18
+ y=None,
19
+ sample_weight=None,
20
+ batch_size=None,
21
+ steps=None,
22
+ shuffle=False,
23
+ class_weight=None,
24
+ ):
25
+ if not can_convert_arrays((x, y, sample_weight)):
26
+ raise ValueError(
27
+ "Expected all elements of `x` to be array-like. "
28
+ f"Received invalid types: x={x}"
29
+ )
30
+
31
+ if sample_weight is not None:
32
+ if class_weight is not None:
33
+ raise ValueError(
34
+ "You cannot `class_weight` and `sample_weight` "
35
+ "at the same time."
36
+ )
37
+ if tree.is_nested(y):
38
+ if isinstance(sample_weight, (list, tuple, dict)):
39
+ try:
40
+ tree.assert_same_structure(y, sample_weight)
41
+ except ValueError:
42
+ raise ValueError(
43
+ "You should provide one `sample_weight` array per "
44
+ "output in `y`. The two structures did not match:\n"
45
+ f"- y: {y}\n"
46
+ f"- sample_weight: {sample_weight}\n"
47
+ )
48
+ else:
49
+ is_samplewise = len(sample_weight.shape) == 1 or (
50
+ len(sample_weight.shape) == 2
51
+ and sample_weight.shape[1] == 1
52
+ )
53
+ if not is_samplewise:
54
+ raise ValueError(
55
+ "For a model with multiple outputs, when providing "
56
+ "a single `sample_weight` array, it should only "
57
+ "have one scalar score per sample "
58
+ "(i.e. shape `(num_samples,)`). If you want to use "
59
+ "non-scalar sample weights, pass a `sample_weight` "
60
+ "argument with one array per model output."
61
+ )
62
+ # Replicate the same sample_weight array on all outputs.
63
+ sample_weight = tree.map_structure(
64
+ lambda _: sample_weight, y
65
+ )
66
+ if class_weight is not None:
67
+ if tree.is_nested(y):
68
+ raise ValueError(
69
+ "`class_weight` is only supported for Models with a single "
70
+ "output."
71
+ )
72
+ sample_weight = data_adapter_utils.class_weight_to_sample_weights(
73
+ y, class_weight
74
+ )
75
+
76
+ inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight)
77
+
78
+ data_adapter_utils.check_data_cardinality(inputs)
79
+ num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop()
80
+ self._num_samples = num_samples
81
+ self._inputs = inputs
82
+
83
+ # If batch_size is not passed but steps is, calculate from the input
84
+ # data. Defaults to `32` for backwards compatibility.
85
+ if not batch_size:
86
+ batch_size = int(math.ceil(num_samples / steps)) if steps else 32
87
+
88
+ self._size = int(math.ceil(num_samples / batch_size))
89
+ self._batch_size = batch_size
90
+ self._partial_batch_size = num_samples % batch_size
91
+ self._shuffle = shuffle
92
+
93
+ def get_numpy_iterator(self):
94
+ inputs = array_slicing.convert_to_sliceable(
95
+ self._inputs, target_backend="numpy"
96
+ )
97
+
98
+ def slice_and_convert_to_numpy(sliceable, indices=None):
99
+ x = sliceable[indices]
100
+ x = sliceable.convert_to_numpy(x)
101
+ return x
102
+
103
+ return self._get_iterator(slice_and_convert_to_numpy, inputs)
104
+
105
+ def get_tf_dataset(self):
106
+ from keras.src.utils.module_utils import tensorflow as tf
107
+
108
+ shuffle = self._shuffle
109
+ batch_size = self._batch_size
110
+ num_samples = self._num_samples
111
+ num_full_batches = int(self._num_samples // batch_size)
112
+
113
+ # Vectorized version of shuffle.
114
+ # This is a performance improvement over using `from_tensor_slices`.
115
+ # The indices of the data are shuffled and batched, and these indices
116
+ # are then zipped with the data and used to extract a batch of the data
117
+ # at each step. The performance improvements here come from:
118
+ # 1. vectorized batch using gather
119
+ # 2. parallelized map
120
+ # 3. pipelined permutation generation
121
+ # 4. optimized permutation batching
122
+ # 5. disabled static optimizations
123
+
124
+ indices_dataset = tf.data.Dataset.range(1)
125
+
126
+ def permutation(_):
127
+ # It turns out to be more performant to make a new set of indices
128
+ # rather than reusing the same range Tensor. (presumably because of
129
+ # buffer forwarding.)
130
+ indices = tf.range(num_samples, dtype=tf.int64)
131
+ if shuffle and shuffle != "batch":
132
+ indices = tf.random.shuffle(indices)
133
+ return indices
134
+
135
+ # We prefetch a single element. Computing large permutations can take
136
+ # quite a while so we don't want to wait for prefetching over an epoch
137
+ # boundary to trigger the next permutation. On the other hand, too many
138
+ # simultaneous shuffles can contend on a hardware level and degrade all
139
+ # performance.
140
+ indices_dataset = indices_dataset.map(permutation).prefetch(1)
141
+
142
+ def slice_batch_indices(indices):
143
+ """Convert a Tensor of indices into a dataset of batched indices.
144
+
145
+ This step can be accomplished in several ways. The most natural is
146
+ to slice the Tensor in a Dataset map. (With a condition on the upper
147
+ index to handle the partial batch.) However it turns out that
148
+ coercing the Tensor into a shape which is divisible by the batch
149
+ size (and handling the last partial batch separately) allows for a
150
+ much more favorable memory access pattern and improved performance.
151
+
152
+ Args:
153
+ indices: Tensor which determines the data order for an entire
154
+ epoch.
155
+
156
+ Returns:
157
+ A Dataset of batched indices.
158
+ """
159
+ num_in_full_batch = num_full_batches * batch_size
160
+ first_k_indices = tf.slice(indices, [0], [num_in_full_batch])
161
+ first_k_indices = tf.reshape(
162
+ first_k_indices, [num_full_batches, batch_size]
163
+ )
164
+
165
+ flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices)
166
+ if self._partial_batch_size:
167
+ index_remainder = tf.data.Dataset.from_tensors(
168
+ tf.slice(
169
+ indices, [num_in_full_batch], [self._partial_batch_size]
170
+ )
171
+ )
172
+ flat_dataset = flat_dataset.concatenate(index_remainder)
173
+
174
+ return flat_dataset
175
+
176
+ def slice_inputs(indices_dataset, inputs):
177
+ """Slice inputs into a Dataset of batches.
178
+
179
+ Given a Dataset of batch indices and the unsliced inputs,
180
+ this step slices the inputs in a parallelized fashion
181
+ and produces a dataset of input batches.
182
+
183
+ Args:
184
+ indices_dataset: A Dataset of batched indices.
185
+ inputs: A python data structure that contains the inputs,
186
+ targets, and possibly sample weights.
187
+
188
+ Returns:
189
+ A Dataset of input batches matching the batch indices.
190
+ """
191
+ inputs = array_slicing.convert_to_sliceable(
192
+ self._inputs, target_backend="tensorflow"
193
+ )
194
+ inputs = tree.lists_to_tuples(inputs)
195
+
196
+ dataset = tf.data.Dataset.zip(
197
+ (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat())
198
+ )
199
+
200
+ def grab_batch(i, data):
201
+ def grab_one(x):
202
+ if isinstance(x, array_slicing.TensorflowSparseWrapper):
203
+ return array_slicing.slice_tensorflow_sparse_wrapper(
204
+ x, i
205
+ )
206
+ if isinstance(x, (list, tuple, dict)):
207
+ return None
208
+ if tf.is_tensor(x):
209
+ return tf.gather(x, i, axis=0)
210
+ return x
211
+
212
+ return tree.traverse(grab_one, data)
213
+
214
+ dataset = dataset.map(
215
+ grab_batch, num_parallel_calls=tf.data.AUTOTUNE
216
+ )
217
+
218
+ # Default optimizations are disabled to avoid the overhead of
219
+ # (unnecessary) input pipeline graph serialization & deserialization
220
+ options = tf.data.Options()
221
+ options.experimental_optimization.apply_default_optimizations = (
222
+ False
223
+ )
224
+ if self._shuffle:
225
+ options.experimental_external_state_policy = (
226
+ tf.data.experimental.ExternalStatePolicy.IGNORE
227
+ )
228
+ dataset = dataset.with_options(options)
229
+ return dataset
230
+
231
+ indices_dataset = indices_dataset.flat_map(slice_batch_indices)
232
+ if shuffle == "batch":
233
+ indices_dataset = indices_dataset.map(tf.random.shuffle)
234
+
235
+ dataset = slice_inputs(indices_dataset, self._inputs)
236
+
237
+ options = tf.data.Options()
238
+ options.experimental_distribute.auto_shard_policy = (
239
+ tf.data.experimental.AutoShardPolicy.DATA
240
+ )
241
+ dataset = dataset.with_options(options)
242
+ return dataset.prefetch(tf.data.AUTOTUNE)
243
+
244
+ def get_jax_iterator(self):
245
+ inputs = array_slicing.convert_to_sliceable(
246
+ self._inputs, target_backend="jax"
247
+ )
248
+
249
+ def slice_and_convert_to_jax(sliceable, indices=None):
250
+ x = sliceable[indices]
251
+ x = sliceable.convert_to_jax_compatible(x)
252
+ return x
253
+
254
+ return self._get_iterator(slice_and_convert_to_jax, inputs)
255
+
256
+ def get_torch_dataloader(self):
257
+ import torch
258
+
259
+ from keras.src.backend.torch.core import convert_to_tensor
260
+
261
+ class ArrayDataset(torch.utils.data.Dataset):
262
+ def __init__(self, array):
263
+ self.array = array
264
+
265
+ def __getitems__(self, indices):
266
+ def slice_and_convert(sliceable):
267
+ x = sliceable[indices]
268
+ x = sliceable.convert_to_torch_compatible(x)
269
+ x = convert_to_tensor(x)
270
+ return x
271
+
272
+ return tree.map_structure(slice_and_convert, self.array)
273
+
274
+ def __len__(self):
275
+ return len(self.array[0])
276
+
277
+ class RandomBatchSampler(torch.utils.data.Sampler):
278
+ def __init__(self, sampler):
279
+ self.sampler = sampler
280
+
281
+ def __iter__(self):
282
+ for batch in self.sampler:
283
+ yield [batch[i] for i in torch.randperm(len(batch))]
284
+
285
+ def __len__(self):
286
+ return len(self.sampler)
287
+
288
+ if self._shuffle == "batch":
289
+ batch_sampler = RandomBatchSampler(
290
+ torch.utils.data.BatchSampler(
291
+ range(self._num_samples),
292
+ batch_size=self._batch_size,
293
+ drop_last=False,
294
+ )
295
+ )
296
+ elif self._shuffle:
297
+ batch_sampler = torch.utils.data.BatchSampler(
298
+ torch.utils.data.RandomSampler(range(self._num_samples)),
299
+ batch_size=self._batch_size,
300
+ drop_last=False,
301
+ )
302
+ else:
303
+ batch_sampler = torch.utils.data.BatchSampler(
304
+ torch.utils.data.SequentialSampler(range(self._num_samples)),
305
+ batch_size=self._batch_size,
306
+ drop_last=False,
307
+ )
308
+
309
+ # Because ArrayDataset.__getitems__ returns full batches organized in
310
+ # the expected structure, there is nothing to collate.
311
+ def no_op_collate(batch):
312
+ return batch
313
+
314
+ inputs = array_slicing.convert_to_sliceable(
315
+ self._inputs, target_backend="torch"
316
+ )
317
+ dataset = ArrayDataset(inputs)
318
+ return torch.utils.data.DataLoader(
319
+ dataset, batch_sampler=batch_sampler, collate_fn=no_op_collate
320
+ )
321
+
322
+ def _get_iterator(self, slice_and_convert_fn, inputs):
323
+ global_permutation = None
324
+ if self._shuffle and self._shuffle != "batch":
325
+ global_permutation = np.random.permutation(self._num_samples)
326
+
327
+ for i in range(self._size):
328
+ start = i * self._batch_size
329
+ stop = min((i + 1) * self._batch_size, self._num_samples)
330
+ if self._shuffle == "batch":
331
+ indices = np.random.permutation(stop - start) + start
332
+ elif self._shuffle:
333
+ indices = global_permutation[start:stop]
334
+ else:
335
+ indices = slice(start, stop)
336
+
337
+ slice_indices_and_convert_fn = functools.partial(
338
+ slice_and_convert_fn, indices=indices
339
+ )
340
+ yield tree.map_structure(slice_indices_and_convert_fn, inputs)
341
+
342
+ @property
343
+ def num_batches(self):
344
+ return self._size
345
+
346
+ @property
347
+ def batch_size(self):
348
+ return self._batch_size
349
+
350
+ @property
351
+ def has_partial_batch(self):
352
+ return self._partial_batch_size > 0
353
+
354
+ @property
355
+ def partial_batch_size(self):
356
+ return self._partial_batch_size or None
357
+
358
+
359
+ def can_convert_arrays(arrays):
360
+ """Check if array like-inputs can be handled by `ArrayDataAdapter`
361
+
362
+ Args:
363
+ inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like.
364
+
365
+ Returns:
366
+ `True` if `arrays` can be handled by `ArrayDataAdapter`, `False`
367
+ otherwise.
368
+ """
369
+
370
+ return all(
371
+ tree.flatten(tree.map_structure(array_slicing.can_slice_array, arrays))
372
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py ADDED
@@ -0,0 +1,520 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import math
3
+
4
+ import numpy as np
5
+
6
+ from keras.src import backend
7
+ from keras.src import tree
8
+ from keras.src.trainers.data_adapters import data_adapter_utils
9
+ from keras.src.utils.module_utils import tensorflow as tf
10
+
11
+ try:
12
+ import pandas
13
+ except ImportError:
14
+ pandas = None
15
+
16
+
17
+ # Leave jax, tf, and torch arrays off this list. Instead we will use
18
+ # `__array__` to detect these types. Doing so allows us to avoid importing a
19
+ # backend framework we are not currently using just to do type-checking.
20
+ ARRAY_TYPES = (np.ndarray,)
21
+ if pandas:
22
+ ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame)
23
+
24
+
25
+ class Sliceable:
26
+ """`Sliceable` wrapping a tensor.
27
+
28
+ A `Sliceable` implements the subscript operator to slice or index against
29
+ the first dimension of the array. It also has conversion methods for each
30
+ one of the backends.
31
+
32
+ Args:
33
+ array: the native array or tensor to wrap.
34
+
35
+ Attributes:
36
+ shape: the shape of the full dense native array.
37
+ """
38
+
39
+ def __init__(self, array):
40
+ self.array = array
41
+
42
+ def __getitem__(self, indices):
43
+ """Select elements in the 0th dimension.
44
+
45
+ Args:
46
+ indices: the indices to select. Only needs to support one dimension,
47
+ the 0th dimension. Should support a `slice` or a list, tuple,
48
+ `np.array` or 1D tensor.
49
+ Returns: A slice of `self.array`.
50
+ """
51
+ return self.array[indices]
52
+
53
+ @classmethod
54
+ def cast(cls, x, dtype):
55
+ """Cast a tensor to a different dtype.
56
+
57
+ Only called on a full array as provided by the user.
58
+
59
+ Args:
60
+ x: the tensor to cast.
61
+ Returns: the cast tensor.
62
+ """
63
+ return x.astype(dtype)
64
+
65
+ @classmethod
66
+ def convert_to_numpy(cls, x):
67
+ """Convert a tensor to a NumPy array.
68
+
69
+ Only called after slicing using `__getitem__`.
70
+
71
+ Args:
72
+ x: the tensor to convert.
73
+ Returns: the converted tensor.
74
+ """
75
+ return x
76
+
77
+ @classmethod
78
+ def convert_to_tf_dataset_compatible(cls, x):
79
+ """Convert a tensor to something compatible with `tf.data.Dataset`.
80
+
81
+ This can be a NumPy array, `tf.Tensor` or any other type of tensor that
82
+ `tf.data.Dataset.from_tensors` can consume.
83
+ Only called on a full array as provided by the user.
84
+
85
+ Args:
86
+ x: the tensor to convert.
87
+ Returns: converted version tensor.
88
+ """
89
+ return x
90
+
91
+ @classmethod
92
+ def convert_to_jax_compatible(cls, x):
93
+ """Convert a tensor to something that the JAX backend can consume.
94
+
95
+ This can be a `JAX` array, `JAXSparse` or a NumPy array.
96
+ Only called after slicing using `__getitem__`.
97
+ Used to convert sparse tensors and densify ragged tensors.
98
+
99
+ Args:
100
+ x: the tensor to convert.
101
+ Returns: the converted tensor.
102
+ """
103
+ return x
104
+
105
+ @classmethod
106
+ def convert_to_torch_compatible(cls, x):
107
+ """Convert a tensor to something that the Torch backend can consume.
108
+
109
+ This can be a Torch tensor, NumPy array or any other type of tensor that
110
+ `keras.backend.torch.core.convert_to_tensor()` can consume.
111
+ Only called after slicing using `__getitem__`.
112
+ Used to densify sparse tensors and ragged tensors.
113
+
114
+ Args:
115
+ x: the tensor to convert.
116
+ Returns: the converted tensor.
117
+ """
118
+ return x
119
+
120
+
121
+ class NumpySliceable(Sliceable):
122
+ pass
123
+
124
+
125
+ class TensorflowSliceable(Sliceable):
126
+ def __getitem__(self, indices):
127
+ from keras.src.utils.module_utils import tensorflow as tf
128
+
129
+ if isinstance(indices, slice):
130
+ return self.array[indices]
131
+ else:
132
+ return tf.gather(self.array, indices, axis=0)
133
+
134
+ @classmethod
135
+ def cast(cls, x, dtype):
136
+ from keras.src.backend.tensorflow.core import cast
137
+
138
+ return cast(x, dtype)
139
+
140
+ @classmethod
141
+ def convert_to_numpy(cls, x):
142
+ from keras.src.backend.tensorflow.core import convert_to_numpy
143
+
144
+ return convert_to_numpy(x)
145
+
146
+
147
+ class TensorflowRaggedSliceable(TensorflowSliceable):
148
+ @classmethod
149
+ def convert_to_jax_compatible(cls, x):
150
+ return cls.convert_to_numpy(x)
151
+
152
+ @classmethod
153
+ def convert_to_torch_compatible(cls, x):
154
+ return x.to_tensor()
155
+
156
+
157
+ class TensorflowSparseSliceable(TensorflowSliceable):
158
+ def __init__(self, array):
159
+ super().__init__(to_tensorflow_sparse_wrapper(array))
160
+
161
+ @property
162
+ def shape(self):
163
+ return self.array.sparse.shape
164
+
165
+ def __getitem__(self, indices):
166
+ return slice_tensorflow_sparse_wrapper(self.array, indices)
167
+
168
+ @classmethod
169
+ def convert_to_tf_dataset_compatible(cls, x):
170
+ return to_tensorflow_sparse_wrapper(x)
171
+
172
+ @classmethod
173
+ def convert_to_jax_compatible(cls, x):
174
+ return data_adapter_utils.tf_sparse_to_jax_sparse(x)
175
+
176
+ @classmethod
177
+ def convert_to_torch_compatible(cls, x):
178
+ from keras.src.backend.tensorflow import sparse as tf_sparse
179
+
180
+ return tf_sparse.sparse_to_dense(x)
181
+
182
+
183
+ class JaxSparseSliceable(Sliceable):
184
+ def __getitem__(self, indices):
185
+ return self.array[indices, ...]
186
+
187
+ @classmethod
188
+ def convert_to_numpy(cls, x):
189
+ from keras.src.backend.jax.core import convert_to_numpy
190
+
191
+ return convert_to_numpy(x)
192
+
193
+ @classmethod
194
+ def convert_to_tf_dataset_compatible(cls, array):
195
+ return to_tensorflow_sparse_wrapper(
196
+ data_adapter_utils.jax_sparse_to_tf_sparse(array)
197
+ )
198
+
199
+ @classmethod
200
+ def convert_to_torch_compatible(cls, x):
201
+ return x.todense()
202
+
203
+
204
+ class TorchSliceable(Sliceable):
205
+ @classmethod
206
+ def cast(cls, x, dtype):
207
+ from keras.src.backend.torch.core import cast
208
+
209
+ return cast(x, dtype)
210
+
211
+ @classmethod
212
+ def convert_to_numpy(cls, x):
213
+ from keras.src.backend.torch.core import convert_to_numpy
214
+
215
+ return convert_to_numpy(x)
216
+
217
+
218
+ class PandasSliceable(Sliceable):
219
+ def __getitem__(self, indices):
220
+ return self.array.iloc[indices]
221
+
222
+ @classmethod
223
+ def convert_to_numpy(cls, x):
224
+ return x.to_numpy()
225
+
226
+ @classmethod
227
+ def convert_to_tf_dataset_compatible(cls, x):
228
+ return cls.convert_to_numpy(x)
229
+
230
+ @classmethod
231
+ def convert_to_jax_compatible(cls, x):
232
+ return cls.convert_to_numpy(x)
233
+
234
+ @classmethod
235
+ def convert_to_torch_compatible(cls, x):
236
+ return cls.convert_to_numpy(x)
237
+
238
+
239
+ class PandasDataFrameSliceable(PandasSliceable):
240
+ pass
241
+
242
+
243
+ class PandasSeriesSliceable(PandasSliceable):
244
+ @classmethod
245
+ def convert_to_numpy(cls, x):
246
+ return np.expand_dims(x.to_numpy(), axis=-1)
247
+
248
+
249
+ class ScipySparseSliceable(Sliceable):
250
+ def __init__(self, array):
251
+ # The COO representation is not indexable / sliceable and does not lend
252
+ # itself to it. Use the CSR representation instead, which is sliceable.
253
+ super().__init__(array.tocsr())
254
+
255
+ @classmethod
256
+ def convert_to_numpy(cls, x):
257
+ return x.todense()
258
+
259
+ @classmethod
260
+ def convert_to_tf_dataset_compatible(cls, x):
261
+ return to_tensorflow_sparse_wrapper(
262
+ data_adapter_utils.scipy_sparse_to_tf_sparse(x)
263
+ )
264
+
265
+ @classmethod
266
+ def convert_to_jax_compatible(cls, x):
267
+ return data_adapter_utils.scipy_sparse_to_jax_sparse(x)
268
+
269
+ @classmethod
270
+ def convert_to_torch_compatible(cls, x):
271
+ return x.todense()
272
+
273
+
274
+ # `tf.SparseTensor` does not support indexing or `tf.gather`. The COO
275
+ # representation it uses does not lend itself to indexing. We add some
276
+ # intermediary tensors to ease the indexing and slicing. We put both indices and
277
+ # values in `RaggedTensor`s where each row corresponds to a row in the sparse
278
+ # tensor. This is because the number of values per row is not fixed.
279
+ # `RaggedTensor`s do support indexing and `tf.gather`, although on CPU only.
280
+ # We then reconstruct a `SparseTensor` from extracted rows. In theory, there is
281
+ # no duplication of data for the indices and values, only the addition of row
282
+ # splits for the ragged representation.
283
+ # `TensorflowSparseWrapper` is a named tuple which combines the original
284
+ # `SparseTensor` (used for the shape) and the ragged representations of indices
285
+ # and values for indexing / slicing. We use a named tuple and not a `Sliceable`
286
+ # to be able to ingest it in `tf.data.Dataset.from_tensors()` and map it.
287
+
288
+ TensorflowSparseWrapper = collections.namedtuple(
289
+ "TensorflowSparseWrapper", ["sparse", "ragged_indices", "ragged_values"]
290
+ )
291
+
292
+
293
+ def to_tensorflow_sparse_wrapper(sparse):
294
+ from keras.src.utils.module_utils import tensorflow as tf
295
+
296
+ row_ids = sparse.indices[:, 0]
297
+ row_splits = tf.experimental.RowPartition.from_value_rowids(
298
+ row_ids
299
+ ).row_splits()
300
+
301
+ ragged_indices = tf.cast(
302
+ tf.RaggedTensor.from_row_splits(sparse.indices, row_splits), tf.int64
303
+ )
304
+ ragged_values = tf.RaggedTensor.from_row_splits(sparse.values, row_splits)
305
+ return TensorflowSparseWrapper(sparse, ragged_indices, ragged_values)
306
+
307
+
308
+ def slice_tensorflow_sparse_wrapper(sparse_wrapper, indices):
309
+ from keras.src.utils.module_utils import tensorflow as tf
310
+
311
+ if isinstance(indices, slice):
312
+ sparse_indices = sparse_wrapper.ragged_indices[indices]
313
+ sparse_values = sparse_wrapper.ragged_values[indices]
314
+ batch_dim = indices.stop - indices.start
315
+ else:
316
+ sparse_indices = tf.gather(sparse_wrapper.ragged_indices, indices)
317
+ sparse_values = tf.gather(sparse_wrapper.ragged_values, indices)
318
+ if isinstance(indices, list):
319
+ batch_dim = len(indices)
320
+ else:
321
+ batch_dim = indices.shape[0]
322
+ if batch_dim is None:
323
+ batch_dim = tf.shape(indices)[0]
324
+
325
+ row_ids = sparse_indices.value_rowids()
326
+ sparse_indices = sparse_indices.flat_values[:, 1:] # remove first value
327
+ sparse_indices = tf.concat(
328
+ [tf.expand_dims(row_ids, -1), sparse_indices], axis=1
329
+ )
330
+
331
+ sparse_values = sparse_values.flat_values
332
+ sparse_shape = (batch_dim,) + tuple(
333
+ sparse_wrapper.sparse.shape.as_list()[1:]
334
+ )
335
+ return tf.SparseTensor(sparse_indices, sparse_values, sparse_shape)
336
+
337
+
338
+ def can_slice_array(x):
339
+ return (
340
+ x is None
341
+ or isinstance(x, ARRAY_TYPES)
342
+ or data_adapter_utils.is_tensorflow_tensor(x)
343
+ or data_adapter_utils.is_jax_array(x)
344
+ or data_adapter_utils.is_torch_tensor(x)
345
+ or data_adapter_utils.is_scipy_sparse(x)
346
+ or hasattr(x, "__array__")
347
+ )
348
+
349
+
350
+ def convert_to_sliceable(arrays, target_backend=None):
351
+ """Convert a structure of arrays into `Sliceable` instances
352
+
353
+ Args:
354
+ arrays: the arrays to convert.
355
+ target_backend: the target backend for the output:
356
+ - `None` indicates that `arrays` will be wrapped into `Sliceable`s
357
+ as-is without using a different representation. This is used by
358
+ `train_validation_split()`.
359
+ - `tensorflow` indicates that
360
+ `Sliceable.convert_to_tf_dataset_compatible` will be called. The
361
+ returned structure therefore contains arrays, not `Sliceable`s.
362
+ - `numpy`, `jax` or `torch` indices that the arrays will eventually
363
+ be converted to this backend type after slicing. In this case,
364
+ the intermediary `Sliceable`s may use a different representation
365
+ from the input `arrays` for better performance.
366
+ Returns: the same structure with `Sliceable` instances or arrays.
367
+ """
368
+
369
+ def convert_single_array(x):
370
+ if x is None:
371
+ return x
372
+
373
+ # Special case: handle np "object" arrays containing strings
374
+ if (
375
+ isinstance(x, np.ndarray)
376
+ and str(x.dtype) == "object"
377
+ and backend.backend() == "tensorflow"
378
+ and all(isinstance(e, str) for e in x)
379
+ ):
380
+ x = tf.convert_to_tensor(x, dtype="string")
381
+
382
+ # Step 1. Determine which Sliceable class to use.
383
+ if isinstance(x, np.ndarray):
384
+ sliceable_class = NumpySliceable
385
+ elif data_adapter_utils.is_tensorflow_tensor(x):
386
+ if data_adapter_utils.is_tensorflow_ragged(x):
387
+ sliceable_class = TensorflowRaggedSliceable
388
+ elif data_adapter_utils.is_tensorflow_sparse(x):
389
+ sliceable_class = TensorflowSparseSliceable
390
+ else:
391
+ sliceable_class = TensorflowSliceable
392
+ elif data_adapter_utils.is_jax_array(x):
393
+ if data_adapter_utils.is_jax_sparse(x):
394
+ sliceable_class = JaxSparseSliceable
395
+ else:
396
+ x = np.asarray(x)
397
+ sliceable_class = NumpySliceable
398
+ elif data_adapter_utils.is_torch_tensor(x):
399
+ sliceable_class = TorchSliceable
400
+ elif pandas is not None and isinstance(x, pandas.DataFrame):
401
+ sliceable_class = PandasDataFrameSliceable
402
+ elif pandas is not None and isinstance(x, pandas.Series):
403
+ sliceable_class = PandasSeriesSliceable
404
+ elif data_adapter_utils.is_scipy_sparse(x):
405
+ sliceable_class = ScipySparseSliceable
406
+ elif hasattr(x, "__array__"):
407
+ x = np.asarray(x)
408
+ sliceable_class = NumpySliceable
409
+ else:
410
+ raise ValueError(
411
+ "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, "
412
+ "tf.SparseTensor, jax.np.ndarray, "
413
+ "jax.experimental.sparse.JAXSparse, torch.Tensor, "
414
+ "Pandas Dataframe, or Pandas Series. Received invalid input: "
415
+ f"{x} (of type {type(x)})"
416
+ )
417
+
418
+ # Step 2. Normalize floats to floatx.
419
+ def is_non_floatx_float(dtype):
420
+ return (
421
+ dtype is not object
422
+ and backend.is_float_dtype(dtype)
423
+ and not backend.standardize_dtype(dtype) == backend.floatx()
424
+ )
425
+
426
+ cast_dtype = None
427
+ if pandas is not None and isinstance(x, pandas.DataFrame):
428
+ if any(is_non_floatx_float(d) for d in x.dtypes.values):
429
+ cast_dtype = backend.floatx()
430
+ else:
431
+ if is_non_floatx_float(x.dtype):
432
+ cast_dtype = backend.floatx()
433
+
434
+ if cast_dtype is not None:
435
+ x = sliceable_class.cast(x, cast_dtype)
436
+
437
+ # Step 3. Apply target backend specific logic and optimizations.
438
+ if target_backend is None:
439
+ return sliceable_class(x)
440
+
441
+ if target_backend == "tensorflow":
442
+ return sliceable_class.convert_to_tf_dataset_compatible(x)
443
+
444
+ # With dense arrays and JAX as output, it is faster to use NumPy as an
445
+ # intermediary representation, so wrap input array in a NumPy array,
446
+ # which should not use extra memory.
447
+ # See https://github.com/google/jax/issues/1276 for an explanation of
448
+ # why slicing a NumPy array is faster than slicing a JAX array.
449
+ if target_backend == "jax" and sliceable_class in (
450
+ TensorflowSliceable,
451
+ TorchSliceable,
452
+ ):
453
+ x = np.asarray(x)
454
+ sliceable_class = NumpySliceable
455
+
456
+ return sliceable_class(x)
457
+
458
+ return tree.map_structure(convert_single_array, arrays)
459
+
460
+
461
+ def train_validation_split(arrays, validation_split):
462
+ """Split arrays into train and validation subsets in deterministic order.
463
+
464
+ The last part of data will become validation data.
465
+
466
+ Args:
467
+ arrays: Tensors to split. Allowed inputs are arbitrarily nested
468
+ structures of Tensors and NumPy arrays.
469
+ validation_split: Float between 0 and 1. The proportion of the dataset
470
+ to include in the validation split. The rest of the dataset will be
471
+ included in the training split.
472
+
473
+ Returns:
474
+ `(train_arrays, validation_arrays)`
475
+ """
476
+
477
+ flat_arrays = tree.flatten(arrays)
478
+ unsplitable = [type(t) for t in flat_arrays if not can_slice_array(t)]
479
+ if unsplitable:
480
+ raise ValueError(
481
+ "Argument `validation_split` is only supported "
482
+ "for tensors or NumPy arrays."
483
+ f"Found incompatible type in the input: {unsplitable}"
484
+ )
485
+
486
+ if all(t is None for t in flat_arrays):
487
+ return arrays, arrays
488
+
489
+ first_non_none = None
490
+ for t in flat_arrays:
491
+ if t is not None:
492
+ first_non_none = t
493
+ break
494
+
495
+ # Assumes all arrays have the same batch shape or are `None`.
496
+ batch_dim = int(first_non_none.shape[0])
497
+ split_at = int(math.floor(batch_dim * (1.0 - validation_split)))
498
+
499
+ if split_at == 0 or split_at == batch_dim:
500
+ raise ValueError(
501
+ f"Training data contains {batch_dim} samples, which is not "
502
+ "sufficient to split it into a validation and training set as "
503
+ f"specified by `validation_split={validation_split}`. Either "
504
+ "provide more data, or a different value for the "
505
+ "`validation_split` argument."
506
+ )
507
+
508
+ def _split(t, start, end):
509
+ if t is None:
510
+ return t
511
+ return t[start:end]
512
+
513
+ sliceables = convert_to_sliceable(arrays)
514
+ train_arrays = tree.map_structure(
515
+ lambda x: _split(x, start=0, end=split_at), sliceables
516
+ )
517
+ val_arrays = tree.map_structure(
518
+ lambda x: _split(x, start=split_at, end=batch_dim), sliceables
519
+ )
520
+ return train_arrays, val_arrays
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class DataAdapter:
2
+ """Base class for input data adapters.
3
+
4
+ The purpose of a DataAdapter is to provide a unified interface to
5
+ iterate over input data provided in a variety of formats -- such as
6
+ NumPy arrays, tf.Tensors, tf.data.Datasets, Keras PyDatasets, etc.
7
+ """
8
+
9
+ def get_numpy_iterator(self):
10
+ """Get a Python iterable for the `DataAdapter`, that yields NumPy
11
+ arrays.
12
+
13
+ Returns:
14
+ A Python iterator.
15
+ """
16
+ raise NotImplementedError
17
+
18
+ def get_tf_dataset(self):
19
+ """Get a `tf.data.Dataset` instance for the DataAdapter.
20
+
21
+ Note that the dataset returned does not repeat for epoch, so caller
22
+ might need to create new iterator for the same dataset at the beginning
23
+ of the epoch. This behavior might change in the future.
24
+
25
+ Returns:
26
+ A `tf.data.Dataset`. Caller might use the dataset in different
27
+ context, e.g. iter(dataset) in eager to get the value directly, or
28
+ in graph mode, provide the iterator tensor to Keras model function.
29
+ """
30
+ raise NotImplementedError
31
+
32
+ def get_jax_iterator(self):
33
+ """Get a Python iterable for the `DataAdapter`, that yields arrays that
34
+ that can be fed to JAX. NumPy arrays are preferred for performance.
35
+
36
+ Returns:
37
+ A Python iterator.
38
+ """
39
+ raise NotImplementedError
40
+
41
+ def get_torch_dataloader(self):
42
+ """Get a Torch `DataLoader` for the `DataAdapter`.
43
+
44
+ Returns:
45
+ A Torch `DataLoader`.
46
+ """
47
+ raise NotImplementedError
48
+
49
+ @property
50
+ def num_batches(self):
51
+ """Return the size (number of batches) for the dataset created.
52
+
53
+ For certain type of the data input, the number of batches is known, eg
54
+ for Numpy data, the size is same as (number_of_element / batch_size).
55
+ Whereas for dataset or python generator, the size is unknown since it
56
+ may or may not have an end state.
57
+
58
+ Returns:
59
+ int, the number of batches for the dataset, or None if it is
60
+ unknown. The caller could use this to control the loop of training,
61
+ show progress bar, or handle unexpected StopIteration error.
62
+ """
63
+ raise NotImplementedError
64
+
65
+ @property
66
+ def batch_size(self):
67
+ """Return the batch size of the dataset created.
68
+
69
+ For certain type of the data input, the batch size is known, and even
70
+ required, like numpy array. Whereas for dataset, the batch is unknown
71
+ unless we take a peek.
72
+
73
+ Returns:
74
+ int, the batch size of the dataset, or None if it is unknown.
75
+ """
76
+ raise NotImplementedError
77
+
78
+ @property
79
+ def has_partial_batch(self):
80
+ """Whether the dataset has partial batch at the end."""
81
+ raise NotImplementedError
82
+
83
+ @property
84
+ def partial_batch_size(self):
85
+ """The size of the final partial batch for dataset.
86
+
87
+ Will return None if has_partial_batch is False or batch_size is None.
88
+ """
89
+ raise NotImplementedError
90
+
91
+ def on_epoch_begin(self):
92
+ """A hook called before each epoch."""
93
+ pass
94
+
95
+ def on_epoch_end(self):
96
+ """A hook called after each epoch."""
97
+ pass
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py ADDED
@@ -0,0 +1,329 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ from keras.src import backend
4
+ from keras.src import ops
5
+ from keras.src import tree
6
+ from keras.src.api_export import keras_export
7
+
8
+ NUM_BATCHES_FOR_TENSOR_SPEC = 2
9
+
10
+
11
+ @keras_export("keras.utils.unpack_x_y_sample_weight")
12
+ def unpack_x_y_sample_weight(data):
13
+ """Unpacks user-provided data tuple.
14
+
15
+ This is a convenience utility to be used when overriding
16
+ `Model.train_step`, `Model.test_step`, or `Model.predict_step`.
17
+ This utility makes it easy to support data of the form `(x,)`,
18
+ `(x, y)`, or `(x, y, sample_weight)`.
19
+
20
+ Example:
21
+
22
+ >>> features_batch = ops.ones((10, 5))
23
+ >>> labels_batch = ops.zeros((10, 5))
24
+ >>> data = (features_batch, labels_batch)
25
+ >>> # `y` and `sample_weight` will default to `None` if not provided.
26
+ >>> x, y, sample_weight = unpack_x_y_sample_weight(data)
27
+ >>> sample_weight is None
28
+ True
29
+
30
+ Args:
31
+ data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`.
32
+
33
+ Returns:
34
+ The unpacked tuple, with `None`s for `y` and `sample_weight` if they are
35
+ not provided.
36
+ """
37
+ if isinstance(data, list):
38
+ data = tuple(data)
39
+ if not isinstance(data, tuple):
40
+ return (data, None, None)
41
+ elif len(data) == 1:
42
+ return (data[0], None, None)
43
+ elif len(data) == 2:
44
+ return (data[0], data[1], None)
45
+ elif len(data) == 3:
46
+ return (data[0], data[1], data[2])
47
+ error_msg = (
48
+ "Data is expected to be in format `x`, `(x,)`, `(x, y)`, "
49
+ f"or `(x, y, sample_weight)`, found: {data}"
50
+ )
51
+ raise ValueError(error_msg)
52
+
53
+
54
+ @keras_export("keras.utils.pack_x_y_sample_weight")
55
+ def pack_x_y_sample_weight(x, y=None, sample_weight=None):
56
+ """Packs user-provided data into a tuple.
57
+
58
+ This is a convenience utility for packing data into the tuple formats
59
+ that `Model.fit()` uses.
60
+
61
+ Example:
62
+
63
+ >>> x = ops.ones((10, 1))
64
+ >>> data = pack_x_y_sample_weight(x)
65
+ >>> isinstance(data, ops.Tensor)
66
+ True
67
+ >>> y = ops.ones((10, 1))
68
+ >>> data = pack_x_y_sample_weight(x, y)
69
+ >>> isinstance(data, tuple)
70
+ True
71
+ >>> x, y = data
72
+
73
+ Args:
74
+ x: Features to pass to `Model`.
75
+ y: Ground-truth targets to pass to `Model`.
76
+ sample_weight: Sample weight for each element.
77
+
78
+ Returns:
79
+ Tuple in the format used in `Model.fit()`.
80
+ """
81
+ if y is None:
82
+ # For single x-input, we do no tuple wrapping since in this case
83
+ # there is no ambiguity. This also makes NumPy and Dataset
84
+ # consistent in that the user does not have to wrap their Dataset
85
+ # data in an unnecessary tuple.
86
+ if not isinstance(x, (tuple, list)):
87
+ return x
88
+ else:
89
+ return (x,)
90
+ elif sample_weight is None:
91
+ return (x, y)
92
+ else:
93
+ return (x, y, sample_weight)
94
+
95
+
96
+ def list_to_tuple(maybe_list):
97
+ """Datasets will stack any list of tensors, so we convert them to tuples."""
98
+ if isinstance(maybe_list, list):
99
+ return tuple(maybe_list)
100
+ return maybe_list
101
+
102
+
103
+ def check_data_cardinality(data):
104
+ num_samples = set(int(i.shape[0]) for i in tree.flatten(data))
105
+ if len(num_samples) > 1:
106
+ msg = (
107
+ "Data cardinality is ambiguous. "
108
+ "Make sure all arrays contain the same number of samples."
109
+ )
110
+ for label, single_data in zip(["x", "y", "sample_weight"], data):
111
+ sizes = ", ".join(
112
+ str(i.shape[0]) for i in tree.flatten(single_data)
113
+ )
114
+ msg += f"'{label}' sizes: {sizes}\n"
115
+ raise ValueError(msg)
116
+
117
+
118
+ def class_weight_to_sample_weights(y, class_weight):
119
+ # Convert to numpy to ensure consistent handling of operations
120
+ # (e.g., np.round()) across frameworks like TensorFlow, JAX, and PyTorch
121
+
122
+ y_numpy = ops.convert_to_numpy(y)
123
+ sample_weight = np.ones(shape=(y_numpy.shape[0],), dtype=backend.floatx())
124
+ if len(y_numpy.shape) > 1:
125
+ if y_numpy.shape[-1] != 1:
126
+ y_numpy = np.argmax(y_numpy, axis=-1)
127
+ else:
128
+ y_numpy = np.squeeze(y_numpy, axis=-1)
129
+ y_numpy = np.round(y_numpy).astype("int32")
130
+
131
+ for i in range(y_numpy.shape[0]):
132
+ sample_weight[i] = class_weight.get(int(y_numpy[i]), 1.0)
133
+ return sample_weight
134
+
135
+
136
+ def get_tensor_spec(batches):
137
+ """Return the common tensor spec for a list of batches.
138
+
139
+ Args:
140
+ batches: list of structures of tensors. The structures must be
141
+ identical, but the shape at each leaf may be different.
142
+ Returns: the common tensor spec for all the batches.
143
+ """
144
+ from keras.src.utils.module_utils import tensorflow as tf
145
+
146
+ def get_single_tensor_spec(*tensors):
147
+ x = tensors[0]
148
+ rank = len(x.shape)
149
+ if rank < 1:
150
+ raise ValueError(
151
+ "When passing a dataset to a Keras model, the arrays must "
152
+ f"be at least rank 1. Received: {x} of rank {len(x.shape)}."
153
+ )
154
+ for t in tensors:
155
+ if len(t.shape) != rank:
156
+ raise ValueError(
157
+ "When passing a dataset to a Keras model, the "
158
+ "corresponding arrays in each batch must have the same "
159
+ f"rank. Received: {x} and {t}"
160
+ )
161
+ shape = []
162
+ # Merge shapes: go through each dimension one by one and keep the
163
+ # common values
164
+ for dims in zip(*[list(x.shape) for x in tensors]):
165
+ dims_set = set(dims)
166
+ shape.append(dims_set.pop() if len(dims_set) == 1 else None)
167
+ shape[0] = None # batch size may not be static
168
+
169
+ dtype = backend.standardize_dtype(x.dtype)
170
+ if isinstance(x, tf.RaggedTensor):
171
+ return tf.RaggedTensorSpec(shape=shape, dtype=dtype)
172
+ if (
173
+ isinstance(x, tf.SparseTensor)
174
+ or is_scipy_sparse(x)
175
+ or is_jax_sparse(x)
176
+ ):
177
+ return tf.SparseTensorSpec(shape=shape, dtype=dtype)
178
+ else:
179
+ return tf.TensorSpec(shape=shape, dtype=dtype)
180
+
181
+ return tree.map_structure(get_single_tensor_spec, *batches)
182
+
183
+
184
+ def get_jax_iterator(iterable):
185
+ import jax
186
+ import jax.experimental.sparse as jax_sparse
187
+
188
+ def convert_to_jax_compatible(x):
189
+ if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)):
190
+ return x
191
+ elif is_scipy_sparse(x):
192
+ return scipy_sparse_to_jax_sparse(x)
193
+ elif is_tensorflow_sparse(x):
194
+ return tf_sparse_to_jax_sparse(x)
195
+ else:
196
+ return np.asarray(x)
197
+
198
+ for batch in iterable:
199
+ yield tree.map_structure(convert_to_jax_compatible, batch)
200
+
201
+
202
+ def get_numpy_iterator(iterable):
203
+ def convert_to_numpy(x):
204
+ if not isinstance(x, np.ndarray):
205
+ # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`,
206
+ # `torch.Tensor`, as well as any other tensor-like object that
207
+ # has added numpy support.
208
+ if hasattr(x, "__array__"):
209
+ if is_torch_tensor(x):
210
+ x = x.cpu()
211
+ x = np.asarray(x)
212
+ return x
213
+
214
+ for batch in iterable:
215
+ yield tree.map_structure(convert_to_numpy, batch)
216
+
217
+
218
+ def get_torch_dataloader(iterable):
219
+ import torch.utils.data as torch_data
220
+
221
+ from keras.src.backend.torch.core import convert_to_tensor
222
+
223
+ class ConverterIterableDataset(torch_data.IterableDataset):
224
+ def __init__(self, iterable):
225
+ self.iterable = iterable
226
+
227
+ def __iter__(self):
228
+ for batch in self.iterable:
229
+ yield tree.map_structure(convert_to_tensor, batch)
230
+
231
+ dataset = ConverterIterableDataset(iterable)
232
+ # `batch_size=None` indicates that we should not re-batch
233
+ return torch_data.DataLoader(dataset, batch_size=None)
234
+
235
+
236
+ def is_tensorflow_tensor(value):
237
+ if hasattr(value, "__class__"):
238
+ if value.__class__.__name__ in ("RaggedTensor", "SparseTensor"):
239
+ return "tensorflow.python." in str(value.__class__.__module__)
240
+ for parent in value.__class__.__mro__:
241
+ if parent.__name__ in ("Tensor") and "tensorflow.python." in str(
242
+ parent.__module__
243
+ ):
244
+ return True
245
+ return False
246
+
247
+
248
+ def is_tensorflow_ragged(value):
249
+ if hasattr(value, "__class__"):
250
+ return (
251
+ value.__class__.__name__ == "RaggedTensor"
252
+ and "tensorflow.python." in str(value.__class__.__module__)
253
+ )
254
+ return False
255
+
256
+
257
+ def is_tensorflow_sparse(value):
258
+ if hasattr(value, "__class__"):
259
+ return (
260
+ value.__class__.__name__ == "SparseTensor"
261
+ and "tensorflow.python." in str(value.__class__.__module__)
262
+ )
263
+ return False
264
+
265
+
266
+ def is_jax_array(value):
267
+ if hasattr(value, "__class__"):
268
+ for parent in value.__class__.__mro__:
269
+ if parent.__name__ == "Array" and str(parent.__module__) == "jax":
270
+ return True
271
+ return is_jax_sparse(value) # JAX sparse arrays do not extend jax.Array
272
+
273
+
274
+ def is_jax_sparse(value):
275
+ if hasattr(value, "__class__"):
276
+ return str(value.__class__.__module__).startswith(
277
+ "jax.experimental.sparse"
278
+ )
279
+ return False
280
+
281
+
282
+ def is_torch_tensor(value):
283
+ if hasattr(value, "__class__"):
284
+ for parent in value.__class__.__mro__:
285
+ if parent.__name__ == "Tensor" and str(parent.__module__).endswith(
286
+ "torch"
287
+ ):
288
+ return True
289
+ return False
290
+
291
+
292
+ def is_scipy_sparse(x):
293
+ return str(x.__class__.__module__).startswith("scipy.sparse") and hasattr(
294
+ x, "tocoo"
295
+ )
296
+
297
+
298
+ def scipy_sparse_to_tf_sparse(x):
299
+ from keras.src.utils.module_utils import tensorflow as tf
300
+
301
+ coo = x.tocoo()
302
+ indices = np.concatenate(
303
+ (np.expand_dims(coo.row, 1), np.expand_dims(coo.col, 1)), axis=1
304
+ )
305
+ return tf.SparseTensor(indices, coo.data, coo.shape)
306
+
307
+
308
+ def scipy_sparse_to_jax_sparse(x):
309
+ import jax
310
+ import jax.experimental.sparse as jax_sparse
311
+
312
+ with jax.default_device(jax.local_devices(backend="cpu")[0]):
313
+ return jax_sparse.BCOO.from_scipy_sparse(x)
314
+
315
+
316
+ def tf_sparse_to_jax_sparse(x):
317
+ import jax
318
+ import jax.experimental.sparse as jax_sparse
319
+
320
+ values = np.asarray(x.values)
321
+ indices = np.asarray(x.indices)
322
+ with jax.default_device(jax.local_devices(backend="cpu")[0]):
323
+ return jax_sparse.BCOO((values, indices), shape=x.shape)
324
+
325
+
326
+ def jax_sparse_to_tf_sparse(x):
327
+ from keras.src.utils.module_utils import tensorflow as tf
328
+
329
+ return tf.SparseTensor(x.indices, x.data, x.shape)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+
3
+ from keras.src import tree
4
+ from keras.src.trainers.data_adapters import data_adapter_utils
5
+ from keras.src.trainers.data_adapters.data_adapter import DataAdapter
6
+
7
+
8
+ class GeneratorDataAdapter(DataAdapter):
9
+ """Adapter for Python generators."""
10
+
11
+ def __init__(self, generator):
12
+ first_batches, generator = peek_and_restore(generator)
13
+ self.generator = generator
14
+ self._first_batches = first_batches
15
+ self._output_signature = None
16
+ if not isinstance(first_batches[0], tuple):
17
+ raise ValueError(
18
+ "When passing a Python generator to a Keras model, "
19
+ "the generator must return a tuple, either "
20
+ "(input,) or (inputs, targets) or "
21
+ "(inputs, targets, sample_weights). "
22
+ f"Received: {first_batches[0]}"
23
+ )
24
+
25
+ def get_numpy_iterator(self):
26
+ return data_adapter_utils.get_numpy_iterator(self.generator())
27
+
28
+ def get_jax_iterator(self):
29
+ return data_adapter_utils.get_jax_iterator(self.generator())
30
+
31
+ def get_tf_dataset(self):
32
+ from keras.src.utils.module_utils import tensorflow as tf
33
+
34
+ def convert_to_tf(x, spec):
35
+ if data_adapter_utils.is_scipy_sparse(x):
36
+ x = data_adapter_utils.scipy_sparse_to_tf_sparse(x)
37
+ elif data_adapter_utils.is_jax_sparse(x):
38
+ x = data_adapter_utils.jax_sparse_to_tf_sparse(x)
39
+ if not spec.shape.is_compatible_with(x.shape):
40
+ raise TypeError(
41
+ f"Generator yielded an element of shape {x.shape} where "
42
+ f"an element of shape {spec.shape} was expected. Your "
43
+ "generator provides tensors with variable input "
44
+ "dimensions other than the batch size. Make sure that the "
45
+ "generator's first two batches do not have the same "
46
+ "dimension value wherever there is a variable input "
47
+ "dimension."
48
+ )
49
+ return x
50
+
51
+ def get_tf_iterator():
52
+ for batch in self.generator():
53
+ batch = tree.map_structure(
54
+ convert_to_tf, batch, self._output_signature
55
+ )
56
+ yield batch
57
+
58
+ if self._output_signature is None:
59
+ self._output_signature = data_adapter_utils.get_tensor_spec(
60
+ self._first_batches
61
+ )
62
+ ds = tf.data.Dataset.from_generator(
63
+ get_tf_iterator,
64
+ output_signature=self._output_signature,
65
+ )
66
+ ds = ds.prefetch(tf.data.AUTOTUNE)
67
+ return ds
68
+
69
+ def get_torch_dataloader(self):
70
+ return data_adapter_utils.get_torch_dataloader(self.generator())
71
+
72
+ @property
73
+ def num_batches(self):
74
+ return None
75
+
76
+ @property
77
+ def batch_size(self):
78
+ return None
79
+
80
+
81
+ def peek_and_restore(generator):
82
+ batches = list(
83
+ itertools.islice(
84
+ generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
85
+ )
86
+ )
87
+ return batches, lambda: itertools.chain(batches, generator)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py ADDED
@@ -0,0 +1,692 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+ import multiprocessing.dummy
3
+ import queue
4
+ import random
5
+ import threading
6
+ import warnings
7
+ import weakref
8
+ from contextlib import closing
9
+
10
+ import numpy as np
11
+
12
+ from keras.src.api_export import keras_export
13
+ from keras.src.trainers.data_adapters import data_adapter_utils
14
+ from keras.src.trainers.data_adapters.data_adapter import DataAdapter
15
+
16
+
17
+ @keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"])
18
+ class PyDataset:
19
+ """Base class for defining a parallel dataset using Python code.
20
+
21
+ Every `PyDataset` must implement the `__getitem__()` and the `__len__()`
22
+ methods. If you want to modify your dataset between epochs,
23
+ you may additionally implement `on_epoch_end()`,
24
+ or `on_epoch_begin` to be called at the start of each epoch.
25
+ The `__getitem__()` method should return a complete batch
26
+ (not a single sample), and the `__len__` method should return
27
+ the number of batches in the dataset (rather than the number of samples).
28
+
29
+ Args:
30
+ workers: Number of workers to use in multithreading or
31
+ multiprocessing.
32
+ use_multiprocessing: Whether to use Python multiprocessing for
33
+ parallelism. Setting this to `True` means that your
34
+ dataset will be replicated in multiple forked processes.
35
+ This is necessary to gain compute-level (rather than I/O level)
36
+ benefits from parallelism. However it can only be set to
37
+ `True` if your dataset can be safely pickled.
38
+ max_queue_size: Maximum number of batches to keep in the queue
39
+ when iterating over the dataset in a multithreaded or
40
+ multiprocessed setting.
41
+ Reduce this value to reduce the CPU memory consumption of
42
+ your dataset. Defaults to 10.
43
+
44
+ Notes:
45
+
46
+ - `PyDataset` is a safer way to do multiprocessing.
47
+ This structure guarantees that the model will only train
48
+ once on each sample per epoch, which is not the case
49
+ with Python generators.
50
+ - The arguments `workers`, `use_multiprocessing`, and `max_queue_size`
51
+ exist to configure how `fit()` uses parallelism to iterate
52
+ over the dataset. They are not being used by the `PyDataset` class
53
+ directly. When you are manually iterating over a `PyDataset`,
54
+ no parallelism is applied.
55
+
56
+ Example:
57
+
58
+ ```python
59
+ from skimage.io import imread
60
+ from skimage.transform import resize
61
+ import numpy as np
62
+ import math
63
+
64
+ # Here, `x_set` is list of path to the images
65
+ # and `y_set` are the associated classes.
66
+
67
+ class CIFAR10PyDataset(keras.utils.PyDataset):
68
+
69
+ def __init__(self, x_set, y_set, batch_size, **kwargs):
70
+ super().__init__(**kwargs)
71
+ self.x, self.y = x_set, y_set
72
+ self.batch_size = batch_size
73
+
74
+ def __len__(self):
75
+ # Return number of batches.
76
+ return math.ceil(len(self.x) / self.batch_size)
77
+
78
+ def __getitem__(self, idx):
79
+ # Return x, y for batch idx.
80
+ low = idx * self.batch_size
81
+ # Cap upper bound at array length; the last batch may be smaller
82
+ # if the total number of items is not a multiple of batch size.
83
+ high = min(low + self.batch_size, len(self.x))
84
+ batch_x = self.x[low:high]
85
+ batch_y = self.y[low:high]
86
+
87
+ return np.array([
88
+ resize(imread(file_name), (200, 200))
89
+ for file_name in batch_x]), np.array(batch_y)
90
+ ```
91
+ """
92
+
93
+ def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10):
94
+ self._workers = workers
95
+ self._use_multiprocessing = use_multiprocessing
96
+ self._max_queue_size = max_queue_size
97
+
98
+ def _warn_if_super_not_called(self):
99
+ warn = False
100
+ if not hasattr(self, "_workers"):
101
+ self._workers = 1
102
+ warn = True
103
+ if not hasattr(self, "_use_multiprocessing"):
104
+ self._use_multiprocessing = False
105
+ warn = True
106
+ if not hasattr(self, "_max_queue_size"):
107
+ self._max_queue_size = 10
108
+ warn = True
109
+ if warn:
110
+ warnings.warn(
111
+ "Your `PyDataset` class should call "
112
+ "`super().__init__(**kwargs)` in its constructor. "
113
+ "`**kwargs` can include `workers`, "
114
+ "`use_multiprocessing`, `max_queue_size`. Do not pass "
115
+ "these arguments to `fit()`, as they will be ignored.",
116
+ stacklevel=2,
117
+ )
118
+
119
+ @property
120
+ def workers(self):
121
+ self._warn_if_super_not_called()
122
+ return self._workers
123
+
124
+ @workers.setter
125
+ def workers(self, value):
126
+ self._workers = value
127
+
128
+ @property
129
+ def use_multiprocessing(self):
130
+ self._warn_if_super_not_called()
131
+ return self._use_multiprocessing
132
+
133
+ @use_multiprocessing.setter
134
+ def use_multiprocessing(self, value):
135
+ self._use_multiprocessing = value
136
+
137
+ @property
138
+ def max_queue_size(self):
139
+ self._warn_if_super_not_called()
140
+ return self._max_queue_size
141
+
142
+ @max_queue_size.setter
143
+ def max_queue_size(self, value):
144
+ self._max_queue_size = value
145
+
146
+ def __getitem__(self, index):
147
+ """Gets batch at position `index`.
148
+
149
+ Args:
150
+ index: position of the batch in the PyDataset.
151
+
152
+ Returns:
153
+ A batch
154
+ """
155
+ raise NotImplementedError
156
+
157
+ @property
158
+ def num_batches(self):
159
+ """Number of batches in the PyDataset.
160
+
161
+ Returns:
162
+ The number of batches in the PyDataset or `None` to indicate that
163
+ the dataset is infinite.
164
+ """
165
+ # For backwards compatibility, support `__len__`.
166
+ if hasattr(self, "__len__"):
167
+ return len(self)
168
+ raise NotImplementedError(
169
+ "You need to implement the `num_batches` property:\n\n"
170
+ "@property\ndef num_batches(self):\n return ..."
171
+ )
172
+
173
+ def on_epoch_begin(self):
174
+ """Method called at the beginning of every epoch."""
175
+ pass
176
+
177
+ def on_epoch_end(self):
178
+ """Method called at the end of every epoch."""
179
+ pass
180
+
181
+
182
+ class PyDatasetAdapter(DataAdapter):
183
+ """Adapter for `keras.utils.PyDataset` instances."""
184
+
185
+ def __init__(
186
+ self,
187
+ x,
188
+ class_weight=None,
189
+ shuffle=False,
190
+ ):
191
+ self.py_dataset = x
192
+ self.class_weight = class_weight
193
+ self.enqueuer = None
194
+ self.shuffle = shuffle
195
+ self._output_signature = None
196
+ self._within_epoch = False
197
+
198
+ workers = self.py_dataset.workers
199
+ use_multiprocessing = self.py_dataset.use_multiprocessing
200
+ if workers > 1 or (workers > 0 and use_multiprocessing):
201
+ self.enqueuer = OrderedEnqueuer(
202
+ self.py_dataset,
203
+ workers=workers,
204
+ use_multiprocessing=use_multiprocessing,
205
+ max_queue_size=self.py_dataset.max_queue_size,
206
+ shuffle=self.shuffle,
207
+ )
208
+
209
+ def _standardize_batch(self, batch):
210
+ if isinstance(batch, dict):
211
+ return batch
212
+ if isinstance(batch, np.ndarray):
213
+ batch = (batch,)
214
+ if isinstance(batch, list):
215
+ batch = tuple(batch)
216
+ if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}:
217
+ raise ValueError(
218
+ "PyDataset.__getitem__() must return a tuple or a dict. "
219
+ "If a tuple, it must be ordered either "
220
+ "(input,) or (inputs, targets) or "
221
+ "(inputs, targets, sample_weights). "
222
+ f"Received: {str(batch)[:100]}... of type {type(batch)}"
223
+ )
224
+ if self.class_weight is not None:
225
+ if len(batch) == 3:
226
+ raise ValueError(
227
+ "You cannot specify `class_weight` "
228
+ "and `sample_weight` at the same time."
229
+ )
230
+ if len(batch) == 2:
231
+ sw = data_adapter_utils.class_weight_to_sample_weights(
232
+ batch[1], self.class_weight
233
+ )
234
+ batch = batch + (sw,)
235
+ return batch
236
+
237
+ def _infinite_generator(self):
238
+ for i in itertools.count():
239
+ yield self._standardize_batch(self.py_dataset[i])
240
+
241
+ def _finite_generator(self):
242
+ indices = range(self.py_dataset.num_batches)
243
+ if self.shuffle:
244
+ indices = list(indices)
245
+ random.shuffle(indices)
246
+
247
+ for i in indices:
248
+ yield self._standardize_batch(self.py_dataset[i])
249
+
250
+ def _infinite_enqueuer_generator(self):
251
+ self.enqueuer.start()
252
+ for batch in self.enqueuer.get():
253
+ yield self._standardize_batch(batch)
254
+
255
+ def _finite_enqueuer_generator(self):
256
+ self.enqueuer.start()
257
+ num_batches = self.py_dataset.num_batches
258
+ for i, batch in enumerate(self.enqueuer.get()):
259
+ yield self._standardize_batch(batch)
260
+ if i >= num_batches - 1:
261
+ self.enqueuer.stop()
262
+ return
263
+
264
+ def _get_iterator(self):
265
+ if self.enqueuer is None:
266
+ if self.py_dataset.num_batches is None:
267
+ return self._infinite_generator()
268
+ else:
269
+ return self._finite_generator()
270
+ else:
271
+ if self.py_dataset.num_batches is None:
272
+ return self._infinite_enqueuer_generator()
273
+ else:
274
+ return self._finite_enqueuer_generator()
275
+
276
+ def get_numpy_iterator(self):
277
+ return data_adapter_utils.get_numpy_iterator(self._get_iterator())
278
+
279
+ def get_jax_iterator(self):
280
+ return data_adapter_utils.get_jax_iterator(self._get_iterator())
281
+
282
+ def get_tf_dataset(self):
283
+ from keras.src.utils.module_utils import tensorflow as tf
284
+
285
+ num_batches = self.py_dataset.num_batches
286
+ if self._output_signature is None:
287
+ num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC
288
+ if num_batches is not None:
289
+ num_samples = min(num_samples, num_batches)
290
+ batches = [
291
+ self._standardize_batch(self.py_dataset[i])
292
+ for i in range(num_samples)
293
+ ]
294
+ if len(batches) == 0:
295
+ raise ValueError("The PyDataset has length 0")
296
+ self._output_signature = data_adapter_utils.get_tensor_spec(batches)
297
+
298
+ ds = tf.data.Dataset.from_generator(
299
+ self._get_iterator,
300
+ output_signature=self._output_signature,
301
+ )
302
+ if self.enqueuer is not None:
303
+ # The enqueuer does its own multithreading / multiprocesssing to
304
+ # prefetch items. Disable the tf.data.Dataset prefetching and
305
+ # threading as it interferes.
306
+ options = tf.data.Options()
307
+ options.autotune.enabled = False
308
+ options.threading.private_threadpool_size = 1
309
+ ds = ds.with_options(options)
310
+ else:
311
+ ds = ds.prefetch(tf.data.AUTOTUNE)
312
+ return ds
313
+
314
+ def get_torch_dataloader(self):
315
+ return data_adapter_utils.get_torch_dataloader(self._get_iterator())
316
+
317
+ def on_epoch_begin(self):
318
+ if self._within_epoch:
319
+ raise ValueError(
320
+ "`on_epoch_begin` was called twice without `on_epoch_end` "
321
+ "having been called."
322
+ )
323
+ self._within_epoch = True
324
+ if self.enqueuer:
325
+ self.enqueuer.start()
326
+ self.py_dataset.on_epoch_begin()
327
+
328
+ def on_epoch_end(self):
329
+ if self.enqueuer:
330
+ self.enqueuer.stop()
331
+ self.py_dataset.on_epoch_end()
332
+ self._within_epoch = False
333
+
334
+ @property
335
+ def num_batches(self):
336
+ return self.py_dataset.num_batches
337
+
338
+ @property
339
+ def batch_size(self):
340
+ return None
341
+
342
+
343
+ # Global variables to be shared across processes
344
+ _SHARED_SEQUENCES = {}
345
+ # We use a Value to provide unique id to different processes.
346
+ _SEQUENCE_COUNTER = None
347
+
348
+
349
+ # Because multiprocessing pools are inherently unsafe, starting from a clean
350
+ # state can be essential to avoiding deadlocks. In order to accomplish this, we
351
+ # need to be able to check on the status of Pools that we create.
352
+ _DATA_POOLS = weakref.WeakSet()
353
+ _WORKER_ID_QUEUE = None # Only created if needed.
354
+ _FORCE_THREADPOOL = False
355
+
356
+
357
+ def get_pool_class(use_multiprocessing):
358
+ global _FORCE_THREADPOOL
359
+ if not use_multiprocessing or _FORCE_THREADPOOL:
360
+ return multiprocessing.dummy.Pool # ThreadPool
361
+ return multiprocessing.Pool
362
+
363
+
364
+ def get_worker_id_queue():
365
+ """Lazily create the queue to track worker ids."""
366
+ global _WORKER_ID_QUEUE
367
+ if _WORKER_ID_QUEUE is None:
368
+ _WORKER_ID_QUEUE = multiprocessing.Queue()
369
+ return _WORKER_ID_QUEUE
370
+
371
+
372
+ def get_index(uid, i):
373
+ """Get the value from the PyDataset `uid` at index `i`.
374
+
375
+ To allow multiple PyDatasets to be used at the same time, we use `uid` to
376
+ get a specific one. A single PyDataset would cause the validation to
377
+ overwrite the training PyDataset.
378
+
379
+ This methods is called from worker threads.
380
+
381
+ Args:
382
+ uid: int, PyDataset identifier
383
+ i: index
384
+
385
+ Returns:
386
+ The value at index `i`.
387
+ """
388
+ return _SHARED_SEQUENCES[uid][i]
389
+
390
+
391
+ class PyDatasetEnqueuer:
392
+ """Base class to enqueue inputs.
393
+
394
+ The task of an Enqueuer is to use parallelism to speed up preprocessing.
395
+ This is done with processes or threads.
396
+
397
+ Example:
398
+
399
+ ```python
400
+ enqueuer = PyDatasetEnqueuer(...)
401
+ enqueuer.start()
402
+ datas = enqueuer.get()
403
+ for data in datas:
404
+ # Use the inputs; training, evaluating, predicting.
405
+ # ... stop sometime.
406
+ enqueuer.stop()
407
+ ```
408
+
409
+ The `enqueuer.get()` should be an infinite stream of data.
410
+ """
411
+
412
+ def __init__(
413
+ self,
414
+ py_dataset,
415
+ workers=1,
416
+ use_multiprocessing=False,
417
+ max_queue_size=10,
418
+ ):
419
+ self.py_dataset = py_dataset
420
+
421
+ global _SEQUENCE_COUNTER
422
+ if _SEQUENCE_COUNTER is None:
423
+ try:
424
+ _SEQUENCE_COUNTER = multiprocessing.Value("i", 0)
425
+ except OSError:
426
+ # In this case the OS does not allow us to use
427
+ # multiprocessing. We resort to an int
428
+ # for enqueuer indexing.
429
+ _SEQUENCE_COUNTER = 0
430
+
431
+ if isinstance(_SEQUENCE_COUNTER, int):
432
+ self.uid = _SEQUENCE_COUNTER
433
+ _SEQUENCE_COUNTER += 1
434
+ else:
435
+ # Doing Multiprocessing.Value += x is not process-safe.
436
+ with _SEQUENCE_COUNTER.get_lock():
437
+ self.uid = _SEQUENCE_COUNTER.value
438
+ _SEQUENCE_COUNTER.value += 1
439
+
440
+ self.ready_queue = queue.Queue()
441
+ self.future_queue = queue.Queue(max_queue_size)
442
+ self.running = False
443
+ self.start_stop_lock = threading.Lock()
444
+ self.run_thread = None
445
+ if use_multiprocessing:
446
+ self.executor_fn = self._get_executor_init(workers)
447
+ else:
448
+ # We do not need the init since it's threads.
449
+ self.executor_fn = lambda _: get_pool_class(False)(workers)
450
+
451
+ def is_running(self):
452
+ """Whether the enqueuer is running.
453
+
454
+ This method is thread safe and called from many threads.
455
+
456
+ Returns: boolean indicating whether this enqueuer is running.
457
+ """
458
+ return self.running
459
+
460
+ def start(self):
461
+ """Starts the handler's workers.
462
+
463
+ This method is thread safe but is called from the main thread.
464
+ It is safe to call this method multiple times, extra calls are ignored.
465
+ """
466
+ with self.start_stop_lock:
467
+ if self.running:
468
+ return
469
+ self.running = True
470
+ self.run_thread = threading.Thread(target=self._run)
471
+ self.run_thread.name = f"Worker_{self.uid}"
472
+ self.run_thread.daemon = True
473
+ self.run_thread.start()
474
+
475
+ def stop(self, drain_queue_and_join=True):
476
+ """Stops running threads and wait for them to exit, if necessary.
477
+
478
+ This method is thread safe and is called from various threads. Note that
479
+ the `drain_queue_and_join` argument must be set correctly.
480
+ It is safe to call this method multiple times, extra calls are ignored.
481
+
482
+ Args:
483
+ drain_queue_and_join: set to True to drain the queue of pending
484
+ items and wait for the worker thread to complete. Set to False
485
+ if invoked from a worker thread to avoid deadlocks. Note that
486
+ setting this to False means this enqueuer won't be reused.
487
+ """
488
+ with self.start_stop_lock:
489
+ if not self.running:
490
+ return
491
+ self.running = False
492
+
493
+ if drain_queue_and_join:
494
+ # Drain the `future_queue` and put items in `ready_queue` for
495
+ # the next run.
496
+ while True:
497
+ try:
498
+ value = self.future_queue.get(block=True, timeout=0.1)
499
+ if isinstance(value, Exception):
500
+ raise value # Propagate exception from other thread
501
+ inputs = value.get()
502
+ self.future_queue.task_done()
503
+ if inputs is not None:
504
+ self.ready_queue.put(inputs)
505
+ except queue.Empty:
506
+ break
507
+ self.run_thread.join()
508
+
509
+ self.run_thread = None
510
+ _SHARED_SEQUENCES[self.uid] = None
511
+
512
+ def _send_py_dataset(self):
513
+ """Sends current Iterable to all workers."""
514
+ # For new processes that may spawn
515
+ _SHARED_SEQUENCES[self.uid] = self.py_dataset
516
+
517
+ def __del__(self):
518
+ self.stop(drain_queue_and_join=False)
519
+
520
+ def _run(self):
521
+ """Submits request to the executor and queue the `Future` objects."""
522
+ raise NotImplementedError
523
+
524
+ def _get_executor_init(self, workers):
525
+ """Gets the Pool initializer for multiprocessing.
526
+
527
+ Args:
528
+ workers: Number of workers.
529
+
530
+ Returns:
531
+ Function, a Function to initialize the pool
532
+ """
533
+ raise NotImplementedError
534
+
535
+ def get(self):
536
+ """Creates a generator to extract data from the queue.
537
+
538
+ Skip the data if it is `None`.
539
+
540
+ This method is called from the main thread.
541
+
542
+ Yields:
543
+ The next element in the queue, i.e. a tuple
544
+ `(inputs, targets)` or
545
+ `(inputs, targets, sample_weights)`.
546
+ """
547
+ raise NotImplementedError
548
+
549
+
550
+ class OrderedEnqueuer(PyDatasetEnqueuer):
551
+ """Builds a Enqueuer from a PyDataset.
552
+
553
+ Args:
554
+ py_dataset: A `keras.utils.PyDataset` object.
555
+ use_multiprocessing: use multiprocessing if True, otherwise threading
556
+ shuffle: whether to shuffle the data at the beginning of each epoch
557
+ """
558
+
559
+ def __init__(
560
+ self,
561
+ py_dataset,
562
+ workers=1,
563
+ use_multiprocessing=False,
564
+ max_queue_size=10,
565
+ shuffle=False,
566
+ ):
567
+ super().__init__(
568
+ py_dataset, workers, use_multiprocessing, max_queue_size
569
+ )
570
+ self.shuffle = shuffle
571
+ if self.py_dataset.num_batches is None:
572
+ # For infinite datasets, `self.indices` is created here once for all
573
+ # so that subsequent runs resume from where they stopped.
574
+ self.indices = itertools.count()
575
+
576
+ def _get_executor_init(self, workers):
577
+ """Gets the Pool initializer for multiprocessing.
578
+
579
+ Args:
580
+ workers: Number of workers.
581
+
582
+ Returns:
583
+ Function, a Function to initialize the pool
584
+ """
585
+
586
+ def pool_fn(seqs):
587
+ pool = get_pool_class(True)(
588
+ workers,
589
+ initializer=init_pool_generator,
590
+ initargs=(seqs, None, get_worker_id_queue()),
591
+ )
592
+ _DATA_POOLS.add(pool)
593
+ return pool
594
+
595
+ return pool_fn
596
+
597
+ def _run(self):
598
+ """Submits request to the executor and queue the `Future` objects.
599
+
600
+ This method is the run method of worker threads.
601
+ """
602
+ try:
603
+ if self.py_dataset.num_batches is not None:
604
+ # For finite datasets, `self.indices` is created here so that
605
+ # shuffling creates different a order each time.
606
+ indices = range(self.py_dataset.num_batches)
607
+ if self.shuffle:
608
+ indices = list(indices)
609
+ random.shuffle(indices)
610
+ self.indices = iter(indices)
611
+ self._send_py_dataset() # Share the initial py_dataset
612
+
613
+ with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor:
614
+ while self.is_running():
615
+ try:
616
+ i = next(self.indices)
617
+ self.future_queue.put(
618
+ executor.apply_async(get_index, (self.uid, i)),
619
+ block=True,
620
+ )
621
+ except StopIteration:
622
+ break
623
+ except Exception as e:
624
+ self.future_queue.put(e) # Report exception
625
+
626
+ def get(self):
627
+ """Creates a generator to extract data from the queue.
628
+
629
+ Skip the data if it is `None`.
630
+
631
+ This method is called from the main thread.
632
+
633
+ Yields:
634
+ The next element in the queue, i.e. a tuple
635
+ `(inputs, targets)` or
636
+ `(inputs, targets, sample_weights)`.
637
+ """
638
+ while self.is_running():
639
+ try:
640
+ inputs = self.ready_queue.get(block=False)
641
+ yield inputs
642
+ continue # Retry the ready_queue
643
+ except queue.Empty:
644
+ pass
645
+
646
+ try:
647
+ value = self.future_queue.get(block=True, timeout=5)
648
+ self.future_queue.task_done()
649
+ if isinstance(value, Exception):
650
+ raise value # Propagate exception from other thread
651
+ inputs = value.get()
652
+ if inputs is not None:
653
+ yield inputs
654
+ except queue.Empty:
655
+ pass
656
+ except Exception as e:
657
+ self.stop(drain_queue_and_join=True)
658
+ raise e
659
+
660
+ # Note that it is ok to poll the iterator after the initial `start`,
661
+ # which may happen before the first `on_epoch_begin`. But it's not ok to
662
+ # poll after `on_epoch_end`.
663
+ raise ValueError(
664
+ "Iterator called after `on_epoch_end` or before `on_epoch_begin`."
665
+ )
666
+
667
+
668
+ def init_pool_generator(gens, random_seed=None, id_queue=None):
669
+ """Initializer function for pool workers.
670
+
671
+ Args:
672
+ gens: State which should be made available to worker processes.
673
+ random_seed: An optional value with which to seed child processes.
674
+ id_queue: A multiprocessing Queue of worker ids.
675
+ This is used to indicate that a worker process
676
+ was created by Keras.
677
+ """
678
+ global _SHARED_SEQUENCES
679
+ _SHARED_SEQUENCES = gens
680
+
681
+ worker_proc = multiprocessing.current_process()
682
+
683
+ # name isn't used for anything, but setting a more descriptive name is
684
+ # helpful when diagnosing orphaned processes.
685
+ worker_proc.name = f"Keras_worker_{worker_proc.name}"
686
+
687
+ if random_seed is not None:
688
+ np.random.seed(random_seed + worker_proc.ident)
689
+
690
+ if id_queue is not None:
691
+ # If a worker dies during init, the pool will just create a replacement.
692
+ id_queue.put(worker_proc.ident, block=True, timeout=0.1)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/tf_dataset_adapter.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src import tree
2
+ from keras.src.trainers.data_adapters import data_adapter_utils
3
+ from keras.src.trainers.data_adapters.data_adapter import DataAdapter
4
+
5
+
6
+ class TFDatasetAdapter(DataAdapter):
7
+ """Adapter that handles `tf.data.Dataset`."""
8
+
9
+ def __init__(self, dataset, class_weight=None, distribution=None):
10
+ """Initialize the TFDatasetAdapter.
11
+
12
+ Args:
13
+ dataset: The input `tf.data.Dataset` instance.
14
+ class_weight: A map where the keys are integer class ids and values
15
+ are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`.
16
+ distribution: A `keras.distribution.Distribution` instance. Used to
17
+ shard the input dataset into per worker/process dataset
18
+ instance.
19
+ """
20
+ from keras.src.utils.module_utils import tensorflow as tf
21
+
22
+ if not isinstance(
23
+ dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)
24
+ ):
25
+ raise ValueError(
26
+ "Expected argument `dataset` to be a tf.data.Dataset. "
27
+ f"Received: {dataset}"
28
+ )
29
+ if class_weight is not None:
30
+ dataset = dataset.map(
31
+ make_class_weight_map_fn(class_weight)
32
+ ).prefetch(tf.data.AUTOTUNE)
33
+ if distribution is not None:
34
+ dataset = distribution.distribute_dataset(dataset)
35
+ self._dataset = dataset
36
+
37
+ def get_numpy_iterator(self):
38
+ from keras.src.backend.tensorflow.core import convert_to_numpy
39
+
40
+ for batch in self._dataset:
41
+ yield tree.map_structure(convert_to_numpy, batch)
42
+
43
+ def get_jax_iterator(self):
44
+ from keras.src.backend.tensorflow.core import convert_to_numpy
45
+ from keras.src.utils.module_utils import tensorflow as tf
46
+
47
+ def convert_to_jax(x):
48
+ if isinstance(x, tf.SparseTensor):
49
+ return data_adapter_utils.tf_sparse_to_jax_sparse(x)
50
+ else:
51
+ # We use numpy as an intermediary because it is faster.
52
+ return convert_to_numpy(x)
53
+
54
+ for batch in self._dataset:
55
+ yield tree.map_structure(convert_to_jax, batch)
56
+
57
+ def get_tf_dataset(self):
58
+ return self._dataset
59
+
60
+ def get_torch_dataloader(self):
61
+ return data_adapter_utils.get_torch_dataloader(self._dataset)
62
+
63
+ @property
64
+ def num_batches(self):
65
+ cardinality = self._dataset.cardinality
66
+ if callable(cardinality):
67
+ # `dataset.cardinality` is normally expected to be a callable.
68
+ cardinality = int(self._dataset.cardinality())
69
+ else:
70
+ # However, in the case of `DistributedDataset`, it's a np.int64.
71
+ cardinality = int(cardinality)
72
+ # Return None for Unknown and Infinite cardinality datasets
73
+ if cardinality < 0:
74
+ return None
75
+ return cardinality
76
+
77
+ @property
78
+ def batch_size(self):
79
+ first_element_spec = tree.flatten(self._dataset.element_spec)[0]
80
+ return first_element_spec.shape[0]
81
+
82
+ @property
83
+ def has_partial_batch(self):
84
+ return None
85
+
86
+ @property
87
+ def partial_batch_size(self):
88
+ return None
89
+
90
+
91
+ def make_class_weight_map_fn(class_weight):
92
+ """Applies class weighting to a `Dataset`.
93
+
94
+ The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
95
+ `y` must be a single `Tensor`.
96
+
97
+ Args:
98
+ class_weight: A map where the keys are integer class ids and values are
99
+ the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`
100
+
101
+ Returns:
102
+ A function that can be used with `tf.data.Dataset.map` to apply class
103
+ weighting.
104
+ """
105
+ from keras.src.utils.module_utils import tensorflow as tf
106
+
107
+ class_weight_tensor = tf.convert_to_tensor(
108
+ [
109
+ class_weight.get(int(c), 1.0)
110
+ for c in range(max(class_weight.keys()) + 1)
111
+ ]
112
+ )
113
+
114
+ def class_weights_map_fn(*data):
115
+ """Convert `class_weight` to `sample_weight`."""
116
+ x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)
117
+ if sw is not None:
118
+ raise ValueError(
119
+ "You cannot `class_weight` and `sample_weight` "
120
+ "at the same time."
121
+ )
122
+ if tree.is_nested(y):
123
+ raise ValueError(
124
+ "`class_weight` is only supported for Models with a single "
125
+ "output."
126
+ )
127
+
128
+ if y.shape.rank >= 2:
129
+ y_classes = tf.__internal__.smart_cond.smart_cond(
130
+ tf.shape(y)[-1] > 1,
131
+ lambda: tf.argmax(y, axis=-1, output_type=tf.int32),
132
+ lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),
133
+ )
134
+ else:
135
+ # Special casing for rank 1, where we can guarantee sparse encoding.
136
+ y_classes = tf.cast(tf.round(y), tf.int32)
137
+
138
+ cw = tf.gather(class_weight_tensor, y_classes)
139
+ return x, y, cw
140
+
141
+ return class_weights_map_fn
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/torch_data_loader_adapter.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import itertools
2
+
3
+ import numpy as np
4
+
5
+ from keras.src import tree
6
+ from keras.src.trainers.data_adapters import data_adapter_utils
7
+ from keras.src.trainers.data_adapters.data_adapter import DataAdapter
8
+
9
+
10
+ class TorchDataLoaderAdapter(DataAdapter):
11
+ """Adapter that handles `torch.utils.data.DataLoader`."""
12
+
13
+ def __init__(self, dataloader):
14
+ import torch
15
+
16
+ if not isinstance(dataloader, torch.utils.data.DataLoader):
17
+ raise ValueError(
18
+ f"Expected argument `dataloader` to be an instance of"
19
+ f"`torch.utils.data.DataLoader`. Received: {dataloader}"
20
+ )
21
+
22
+ self._dataloader = dataloader
23
+ self._output_signature = None
24
+ self._batch_size = dataloader.batch_size
25
+ self._num_batches = None
26
+ self._partial_batch_size = None
27
+ if hasattr(dataloader.dataset, "__len__"):
28
+ self._num_batches = len(dataloader)
29
+ if self._batch_size is not None:
30
+ self._partial_batch_size = (
31
+ len(dataloader.dataset) % self._batch_size
32
+ )
33
+
34
+ def get_numpy_iterator(self):
35
+ for batch in self._dataloader:
36
+ # shared memory using `np.asarray`
37
+ yield tuple(
38
+ tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
39
+ )
40
+
41
+ def get_jax_iterator(self):
42
+ # We use numpy as an intermediary because it is faster.
43
+ return self.get_numpy_iterator()
44
+
45
+ def get_tf_dataset(self):
46
+ from keras.src.utils.module_utils import tensorflow as tf
47
+
48
+ if self._output_signature is None:
49
+ batches = list(
50
+ itertools.islice(
51
+ self._dataloader,
52
+ data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
53
+ )
54
+ )
55
+ self._output_signature = tuple(
56
+ data_adapter_utils.get_tensor_spec(batches)
57
+ )
58
+ return tf.data.Dataset.from_generator(
59
+ self.get_numpy_iterator,
60
+ output_signature=self._output_signature,
61
+ )
62
+
63
+ def get_torch_dataloader(self):
64
+ return self._dataloader
65
+
66
+ @property
67
+ def num_batches(self):
68
+ return self._num_batches
69
+
70
+ @property
71
+ def batch_size(self):
72
+ return self._batch_size
73
+
74
+ @property
75
+ def has_partial_batch(self):
76
+ if self._partial_batch_size:
77
+ return self._partial_batch_size > 0
78
+ else:
79
+ return None
80
+
81
+ @property
82
+ def partial_batch_size(self):
83
+ return self._partial_batch_size
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/epoch_iterator.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Separation of concerns:
3
+
4
+ DataAdapter:
5
+ - x, y
6
+ - sample_weight
7
+ - class_weight
8
+ - shuffle
9
+ - batch_size
10
+ - steps, as it relates to batch_size for array data
11
+
12
+ EpochIterator:
13
+ - whether to yield numpy or tf data
14
+ - steps
15
+ - most argument validation
16
+
17
+ Trainer:
18
+ - steps_per_execution
19
+ - validation_split
20
+ - validation_data
21
+ - callbacks
22
+ - validation_freq
23
+ - epochs
24
+ - initial_epoch
25
+ - any backend-specific concern such as distribution
26
+
27
+ PyDataset:
28
+ - num_workers
29
+ - use_multiprocessing
30
+ - max_queue_size
31
+
32
+ EpochIterator steps:
33
+
34
+ 1. Look at data type and select correct DataHandler
35
+ 2. Instantiate DataHandler with correct arguments
36
+ 3. Raise or warn on unused arguments
37
+ 4. in __iter__, iterate, either for a fixed number of steps
38
+ or until there is no data
39
+
40
+ """
41
+
42
+ import contextlib
43
+ import warnings
44
+
45
+ from keras.src.trainers import data_adapters
46
+
47
+
48
+ class EpochIterator:
49
+ def __init__(
50
+ self,
51
+ x,
52
+ y=None,
53
+ sample_weight=None,
54
+ batch_size=None,
55
+ steps_per_epoch=None,
56
+ shuffle=False,
57
+ class_weight=None,
58
+ steps_per_execution=1,
59
+ ):
60
+ self.steps_per_epoch = steps_per_epoch
61
+ self.steps_per_execution = steps_per_execution
62
+ self._current_iterator = None
63
+ self._epoch_iterator = None
64
+ self._steps_seen = 0
65
+ self.data_adapter = data_adapters.get_data_adapter(
66
+ x=x,
67
+ y=y,
68
+ sample_weight=sample_weight,
69
+ batch_size=batch_size,
70
+ steps_per_epoch=steps_per_epoch,
71
+ shuffle=shuffle,
72
+ class_weight=class_weight,
73
+ )
74
+ self._num_batches = self.data_adapter.num_batches
75
+
76
+ def _get_iterator(self):
77
+ return self.data_adapter.get_numpy_iterator()
78
+
79
+ def _interrupted_warning(self):
80
+ warnings.warn(
81
+ "Your input ran out of data; interrupting training. "
82
+ "Make sure that your dataset or generator can generate "
83
+ "at least `steps_per_epoch * epochs` batches. "
84
+ "You may need to use the `.repeat()` "
85
+ "function when building your dataset.",
86
+ stacklevel=2,
87
+ )
88
+
89
+ def reset(self):
90
+ self._current_iterator = None
91
+ self._num_batches = self.data_adapter.num_batches
92
+ self._steps_seen = 0
93
+ self._epoch_iterator = None
94
+ self.data_adapter.on_epoch_end()
95
+
96
+ def _enumerate_iterator(self):
97
+ self.data_adapter.on_epoch_begin()
98
+ steps_per_epoch = self.steps_per_epoch or self._num_batches or -1
99
+
100
+ if steps_per_epoch > 0:
101
+ if self._current_iterator is None or self.steps_per_epoch is None:
102
+ self._current_iterator = iter(self._get_iterator())
103
+ self._steps_seen = 0
104
+ for step in range(0, steps_per_epoch, self.steps_per_execution):
105
+ if self._num_batches and self._steps_seen >= self._num_batches:
106
+ if self.steps_per_epoch:
107
+ self._interrupted_warning()
108
+ break
109
+ self._steps_seen += self.steps_per_execution
110
+ yield step, self._current_iterator
111
+ if self._num_batches and self._steps_seen >= self._num_batches:
112
+ self._current_iterator = iter(self._get_iterator())
113
+ self._steps_seen = 0
114
+ else:
115
+ iterator = iter(self._get_iterator())
116
+ step = -self.steps_per_execution
117
+ while True:
118
+ step += self.steps_per_execution
119
+ self._steps_seen = step + self.steps_per_execution
120
+ yield step, iterator
121
+ self.data_adapter.on_epoch_end()
122
+
123
+ def __iter__(self):
124
+ self._epoch_iterator = self._enumerate_iterator()
125
+ return self
126
+
127
+ def __next__(self):
128
+ buffer = []
129
+ step, iterator = next(self._epoch_iterator)
130
+ with self.catch_stop_iteration():
131
+ for _ in range(self.steps_per_execution):
132
+ data = next(iterator)
133
+ buffer.append(data)
134
+ return step, buffer
135
+ if buffer:
136
+ return step, buffer
137
+ raise StopIteration
138
+
139
+ def enumerate_epoch(self):
140
+ for step, data in self:
141
+ yield step, data
142
+
143
+ @contextlib.contextmanager
144
+ def catch_stop_iteration(self):
145
+ """Catches errors when an iterator runs out of data."""
146
+ try:
147
+ yield
148
+ except StopIteration:
149
+ if self._num_batches is None:
150
+ self._num_batches = self._steps_seen
151
+ self._interrupted_warning()
152
+ self._current_iterator = None
153
+ self.data_adapter.on_epoch_end()
154
+
155
+ @property
156
+ def num_batches(self):
157
+ if self.steps_per_epoch:
158
+ return self.steps_per_epoch
159
+ # Either copied from the data_adapter, or
160
+ # inferred at the end of an iteration.
161
+ return self._num_batches
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py ADDED
@@ -0,0 +1,1147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import platform
3
+ import warnings
4
+
5
+ from keras.src import backend
6
+ from keras.src import metrics as metrics_module
7
+ from keras.src import ops
8
+ from keras.src import optimizers
9
+ from keras.src import tree
10
+ from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer
11
+ from keras.src.saving import serialization_lib
12
+ from keras.src.trainers.compile_utils import CompileLoss
13
+ from keras.src.trainers.compile_utils import CompileMetrics
14
+ from keras.src.trainers.data_adapters import data_adapter_utils
15
+ from keras.src.utils import python_utils
16
+ from keras.src.utils import traceback_utils
17
+ from keras.src.utils import tracking
18
+
19
+
20
+ class Trainer:
21
+ def __init__(self):
22
+ self._lock = False
23
+ self._run_eagerly = False
24
+ self._jit_compile = None
25
+ self.compiled = False
26
+ self.loss = None
27
+ self.steps_per_execution = 1
28
+ # Can be set by callbacks in on_train_begin
29
+ self._initial_epoch = None
30
+ self._compute_loss_has_training_arg = (
31
+ "training" in inspect.signature(self.compute_loss).parameters
32
+ )
33
+
34
+ # Placeholders used in `compile`
35
+ self._compile_loss = None
36
+ self._compile_metrics = None
37
+ self._loss_tracker = None
38
+
39
+ @traceback_utils.filter_traceback
40
+ @tracking.no_automatic_dependency_tracking
41
+ def compile(
42
+ self,
43
+ optimizer="rmsprop",
44
+ loss=None,
45
+ loss_weights=None,
46
+ metrics=None,
47
+ weighted_metrics=None,
48
+ run_eagerly=False,
49
+ steps_per_execution=1,
50
+ jit_compile="auto",
51
+ auto_scale_loss=True,
52
+ ):
53
+ """Configures the model for training.
54
+
55
+ Example:
56
+
57
+ ```python
58
+ model.compile(
59
+ optimizer=keras.optimizers.Adam(learning_rate=1e-3),
60
+ loss=keras.losses.BinaryCrossentropy(),
61
+ metrics=[
62
+ keras.metrics.BinaryAccuracy(),
63
+ keras.metrics.FalseNegatives(),
64
+ ],
65
+ )
66
+ ```
67
+
68
+ Args:
69
+ optimizer: String (name of optimizer) or optimizer instance. See
70
+ `keras.optimizers`.
71
+ loss: Loss function. May be a string (name of loss function), or
72
+ a `keras.losses.Loss` instance. See `keras.losses`. A
73
+ loss function is any callable with the signature
74
+ `loss = fn(y_true, y_pred)`, where `y_true` are the ground truth
75
+ values, and `y_pred` are the model's predictions.
76
+ `y_true` should have shape `(batch_size, d0, .. dN)`
77
+ (except in the case of sparse loss functions such as
78
+ sparse categorical crossentropy which expects integer arrays of
79
+ shape `(batch_size, d0, .. dN-1)`).
80
+ `y_pred` should have shape `(batch_size, d0, .. dN)`.
81
+ The loss function should return a float tensor.
82
+ loss_weights: Optional list or dictionary specifying scalar
83
+ coefficients (Python floats) to weight the loss contributions of
84
+ different model outputs. The loss value that will be minimized
85
+ by the model will then be the *weighted sum* of all individual
86
+ losses, weighted by the `loss_weights` coefficients. If a list,
87
+ it is expected to have a 1:1 mapping to the model's outputs. If
88
+ a dict, it is expected to map output names (strings) to scalar
89
+ coefficients.
90
+ metrics: List of metrics to be evaluated by the model during
91
+ training and testing. Each of this can be a string (name of a
92
+ built-in function), function or a `keras.metrics.Metric`
93
+ instance. See `keras.metrics`. Typically you will use
94
+ `metrics=['accuracy']`. A function is any callable with the
95
+ signature `result = fn(y_true, _pred)`. To specify different
96
+ metrics for different outputs of a multi-output model, you could
97
+ also pass a dictionary, such as
98
+ `metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`.
99
+ You can also pass a list to specify a metric or a list of
100
+ metrics for each output, such as
101
+ `metrics=[['accuracy'], ['accuracy', 'mse']]`
102
+ or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass
103
+ the strings 'accuracy' or 'acc', we convert this to one of
104
+ `keras.metrics.BinaryAccuracy`,
105
+ `keras.metrics.CategoricalAccuracy`,
106
+ `keras.metrics.SparseCategoricalAccuracy` based on the
107
+ shapes of the targets and of the model output. A similar
108
+ conversion is done for the strings `"crossentropy"`
109
+ and `"ce"` as well.
110
+ The metrics passed here are evaluated without sample weighting;
111
+ if you would like sample weighting to apply, you can specify
112
+ your metrics via the `weighted_metrics` argument instead.
113
+ weighted_metrics: List of metrics to be evaluated and weighted by
114
+ `sample_weight` or `class_weight` during training and testing.
115
+ run_eagerly: Bool. If `True`, this model's forward pass
116
+ will never be compiled. It is recommended to leave this
117
+ as `False` when training (for best performance),
118
+ and to set it to `True` when debugging.
119
+ steps_per_execution: Int. The number of batches to run
120
+ during each a single compiled function call. Running multiple
121
+ batches inside a single compiled function call can
122
+ greatly improve performance on TPUs or small models with a large
123
+ Python overhead. At most, one full epoch will be run each
124
+ execution. If a number larger than the size of the epoch is
125
+ passed, the execution will be truncated to the size of the
126
+ epoch. Note that if `steps_per_execution` is set to `N`,
127
+ `Callback.on_batch_begin` and `Callback.on_batch_end` methods
128
+ will only be called every `N` batches (i.e. before/after
129
+ each compiled function execution).
130
+ Not supported with the PyTorch backend.
131
+ jit_compile: Bool or `"auto"`. Whether to use XLA compilation when
132
+ compiling a model. For `jax` and `tensorflow` backends,
133
+ `jit_compile="auto"` enables XLA compilation if the model
134
+ supports it, and disabled otherwise.
135
+ For `torch` backend, `"auto"` will default to eager
136
+ execution and `jit_compile=True` will run with `torch.compile`
137
+ with the `"inductor"` backend.
138
+ auto_scale_loss: Bool. If `True` and the model dtype policy is
139
+ `"mixed_float16"`, the passed optimizer will be automatically
140
+ wrapped in a `LossScaleOptimizer`, which will dynamically
141
+ scale the loss to prevent underflow.
142
+ """
143
+ optimizer = optimizers.get(optimizer)
144
+ self.optimizer = optimizer
145
+ if (
146
+ auto_scale_loss
147
+ and self.dtype_policy.name == "mixed_float16"
148
+ and self.optimizer
149
+ and not isinstance(self.optimizer, LossScaleOptimizer)
150
+ ):
151
+ self.optimizer = LossScaleOptimizer(
152
+ self.optimizer, name="loss_scale_optimizer"
153
+ )
154
+ if hasattr(self, "output_names"):
155
+ output_names = self.output_names
156
+ else:
157
+ output_names = None
158
+ if loss is not None:
159
+ self._compile_loss = CompileLoss(
160
+ loss, loss_weights, output_names=output_names
161
+ )
162
+ self.loss = loss
163
+ if metrics is not None or weighted_metrics is not None:
164
+ self._compile_metrics = CompileMetrics(
165
+ metrics, weighted_metrics, output_names=output_names
166
+ )
167
+ if jit_compile == "auto":
168
+ if run_eagerly:
169
+ jit_compile = False
170
+ else:
171
+ jit_compile = self._resolve_auto_jit_compile()
172
+ if jit_compile and run_eagerly:
173
+ jit_compile = False
174
+ warnings.warn(
175
+ "If `run_eagerly` is True, then `jit_compile` "
176
+ "cannot also be True. Disabling `jit_compile`.",
177
+ stacklevel=2,
178
+ )
179
+
180
+ self.jit_compile = jit_compile
181
+ self.run_eagerly = run_eagerly
182
+ self.stop_training = False
183
+ self.compiled = True
184
+ self._loss_tracker = metrics_module.Mean(name="loss")
185
+ self.steps_per_execution = steps_per_execution
186
+
187
+ self.train_function = None
188
+ self.test_function = None
189
+ self.predict_function = None
190
+
191
+ self._compile_config = serialization_lib.SerializableDict(
192
+ optimizer=optimizer,
193
+ loss=loss,
194
+ loss_weights=loss_weights,
195
+ metrics=metrics,
196
+ weighted_metrics=weighted_metrics,
197
+ run_eagerly=run_eagerly,
198
+ steps_per_execution=steps_per_execution,
199
+ jit_compile=jit_compile,
200
+ )
201
+
202
+ @property
203
+ def jit_compile(self):
204
+ if self._jit_compile is None:
205
+ # Value was never set. Resolve it now.
206
+ self._jit_compile = self._resolve_auto_jit_compile()
207
+ return self._jit_compile
208
+
209
+ @jit_compile.setter
210
+ def jit_compile(self, value):
211
+ if value and not model_supports_jit(self):
212
+ warnings.warn(
213
+ "Model doesn't support `jit_compile=True`. "
214
+ "Proceeding with `jit_compile=False`."
215
+ )
216
+ self._jit_compile = False
217
+ else:
218
+ self._jit_compile = value
219
+
220
+ def _resolve_auto_jit_compile(self):
221
+ if backend.backend() == "torch":
222
+ # jit_compile = "auto" with the pytorch backend defaults to eager
223
+ return False
224
+
225
+ if backend.backend() == "tensorflow":
226
+ import tensorflow as tf
227
+
228
+ devices = tf.config.list_physical_devices()
229
+ if not list(filter(lambda x: x.device_type != "CPU", devices)):
230
+ # Disable XLA on CPU-only machines.
231
+ return False
232
+
233
+ if self._distribute_strategy:
234
+ # Disable XLA with tf.distribute
235
+ return False
236
+
237
+ if model_supports_jit(self):
238
+ return True
239
+ return False
240
+
241
+ @property
242
+ def run_eagerly(self):
243
+ return self._run_eagerly
244
+
245
+ @run_eagerly.setter
246
+ def run_eagerly(self, value):
247
+ self._run_eagerly = value
248
+
249
+ @property
250
+ def metrics(self):
251
+ # Order: loss tracker, individual loss trackers, compiled metrics,
252
+ # custom metrcis, sublayer metrics.
253
+ metrics = []
254
+ if self.compiled:
255
+ if self._loss_tracker is not None:
256
+ metrics.append(self._loss_tracker)
257
+ if self._compile_metrics is not None:
258
+ metrics.append(self._compile_metrics)
259
+ if self._compile_loss is not None:
260
+ metrics.extend(self._compile_loss.metrics)
261
+ metrics.extend(self._metrics)
262
+ for layer in self._flatten_layers(include_self=False):
263
+ if isinstance(layer, Trainer):
264
+ # All Trainer-related metrics in sublayers should be ignored
265
+ # because a new Trainer has been instantiated.
266
+ continue
267
+ metrics.extend(layer.metrics)
268
+ return metrics
269
+
270
+ @property
271
+ def metrics_names(self):
272
+ return [m.name for m in self.metrics]
273
+
274
+ def reset_metrics(self):
275
+ for m in self.metrics:
276
+ m.reset_state()
277
+
278
+ def _get_own_metrics(self):
279
+ metrics = []
280
+ if self._loss_tracker is not None:
281
+ metrics.append(self._loss_tracker)
282
+ if self._compile_metrics is not None:
283
+ metrics.append(self._compile_metrics)
284
+ if self._compile_loss is not None:
285
+ metrics.extend(self._compile_loss.metrics)
286
+ metrics.extend(self._metrics)
287
+ return metrics
288
+
289
+ def compute_loss(
290
+ self,
291
+ x=None,
292
+ y=None,
293
+ y_pred=None,
294
+ sample_weight=None,
295
+ training=True,
296
+ ):
297
+ """Compute the total loss, validate it, and return it.
298
+
299
+ Subclasses can optionally override this method to provide custom loss
300
+ computation logic.
301
+
302
+ Example:
303
+
304
+ ```python
305
+ class MyModel(Model):
306
+ def __init__(self, *args, **kwargs):
307
+ super().__init__(*args, **kwargs)
308
+ self.loss_tracker = metrics.Mean(name='loss')
309
+
310
+ def compute_loss(self, x, y, y_pred, sample_weight, training=True):
311
+ loss = ops.mean((y_pred - y) ** 2)
312
+ loss += ops.sum(self.losses)
313
+ self.loss_tracker.update_state(loss)
314
+ return loss
315
+
316
+ def reset_metrics(self):
317
+ self.loss_tracker.reset_state()
318
+
319
+ @property
320
+ def metrics(self):
321
+ return [self.loss_tracker]
322
+
323
+ inputs = layers.Input(shape=(10,), name='my_input')
324
+ outputs = layers.Dense(10)(inputs)
325
+ model = MyModel(inputs, outputs)
326
+ model.add_loss(ops.sum(outputs))
327
+
328
+ optimizer = SGD()
329
+ model.compile(optimizer, loss='mse', steps_per_execution=10)
330
+ dataset = ...
331
+ model.fit(dataset, epochs=2, steps_per_epoch=10)
332
+ print(f"Custom loss: {model.loss_tracker.result()}")
333
+ ```
334
+
335
+ Args:
336
+ x: Input data.
337
+ y: Target data.
338
+ y_pred: Predictions returned by the model (output of `model(x)`)
339
+ sample_weight: Sample weights for weighting the loss function.
340
+ training: Whether we are training or evaluating the model.
341
+
342
+ Returns:
343
+ The total loss as a scalar tensor, or `None` if no loss results
344
+ (which is the case when called by `Model.test_step`).
345
+ """
346
+ # The default implementation does not use `x` or `training`.
347
+ del x
348
+ del training
349
+ losses = []
350
+ if self._compile_loss is not None:
351
+ loss = self._compile_loss(y, y_pred, sample_weight)
352
+ if loss is not None:
353
+ losses.append(loss)
354
+ for loss in self.losses:
355
+ losses.append(self._aggregate_additional_loss(loss))
356
+ if backend.backend() != "jax" and len(losses) == 0:
357
+ raise ValueError(
358
+ "No loss to compute. Provide a `loss` argument in `compile()`."
359
+ )
360
+ if len(losses) == 1:
361
+ total_loss = losses[0]
362
+ elif len(losses) == 0:
363
+ total_loss = ops.zeros(())
364
+ else:
365
+ total_loss = ops.sum(losses)
366
+ return total_loss
367
+
368
+ def _compute_loss(
369
+ self,
370
+ x=None,
371
+ y=None,
372
+ y_pred=None,
373
+ sample_weight=None,
374
+ training=True,
375
+ ):
376
+ """Backwards compatibility wrapper for `compute_loss`.
377
+
378
+ This should be used instead `compute_loss` within `train_step` and
379
+ `test_step` to support overrides of `compute_loss` that may not have
380
+ the `training` argument, as this argument was added in Keras 3.3.
381
+ """
382
+ if self._compute_loss_has_training_arg:
383
+ return self.compute_loss(
384
+ x, y, y_pred, sample_weight, training=training
385
+ )
386
+ else:
387
+ return self.compute_loss(x, y, y_pred, sample_weight)
388
+
389
+ def _aggregate_additional_loss(self, loss):
390
+ """Aggregates losses from `add_loss`, regularizers and sublayers.
391
+
392
+ Args:
393
+ loss: A tensor representing the additional loss to aggregate.
394
+
395
+ Returns:
396
+ A tensor representing the summed loss, cast to the `floatx()` if
397
+ necessary.
398
+ """
399
+ if not backend.is_float_dtype(loss.dtype):
400
+ loss = ops.cast(loss, dtype=backend.floatx())
401
+ return ops.sum(loss)
402
+
403
+ def stateless_compute_loss(
404
+ self,
405
+ trainable_variables,
406
+ non_trainable_variables,
407
+ metrics_variables,
408
+ x=None,
409
+ y=None,
410
+ y_pred=None,
411
+ sample_weight=None,
412
+ training=True,
413
+ ):
414
+ var_mapping = list(zip(self.trainable_variables, trainable_variables))
415
+ var_mapping.extend(
416
+ zip(self.non_trainable_variables, non_trainable_variables)
417
+ )
418
+ var_mapping.extend(zip(self.metrics_variables, metrics_variables))
419
+ with backend.StatelessScope(state_mapping=var_mapping) as scope:
420
+ # Note that this is needed for the regularization loss, which need
421
+ # the latest value of train/non-trainable variables.
422
+ loss = self._compute_loss(
423
+ x,
424
+ y,
425
+ y_pred,
426
+ sample_weight=sample_weight,
427
+ training=training,
428
+ )
429
+
430
+ # Update non trainable vars (may have been updated in compute_loss)
431
+ non_trainable_variables = []
432
+ for v in self.non_trainable_variables:
433
+ new_v = scope.get_current_value(v)
434
+ non_trainable_variables.append(new_v)
435
+
436
+ # Update metrics vars (may have been updated in compute_loss)
437
+ metrics_variables = []
438
+ for v in self.metrics_variables:
439
+ new_v = scope.get_current_value(v)
440
+ metrics_variables.append(new_v)
441
+ return loss, (
442
+ trainable_variables,
443
+ non_trainable_variables,
444
+ metrics_variables,
445
+ )
446
+
447
+ def compute_metrics(self, x, y, y_pred, sample_weight=None):
448
+ """Update metric states and collect all metrics to be returned.
449
+
450
+ Subclasses can optionally override this method to provide custom metric
451
+ updating and collection logic. Custom metrics are not passed in
452
+ `compile()`, they can be created in `__init__` or `build`. They are
453
+ automatically tracked and returned by `self.metrics`.
454
+
455
+ Example:
456
+
457
+ ```python
458
+ class MyModel(Sequential):
459
+ def __init__(self, *args, **kwargs):
460
+ super().__init__(*args, **kwargs)
461
+ self.custom_metric = MyMetric(name="custom_metric")
462
+
463
+ def compute_metrics(self, x, y, y_pred, sample_weight):
464
+ # This super call updates metrics from `compile` and returns
465
+ # results for all metrics listed in `self.metrics`.
466
+ metric_results = super().compute_metrics(
467
+ x, y, y_pred, sample_weight)
468
+
469
+ # `metric_results` contains the previous result for
470
+ # `custom_metric`, this is where we update it.
471
+ self.custom_metric.update_state(x, y, y_pred, sample_weight)
472
+ metric_results['custom_metric'] = self.custom_metric.result()
473
+ return metric_results
474
+ ```
475
+
476
+ Args:
477
+ x: Input data.
478
+ y: Target data.
479
+ y_pred: Predictions returned by the model output of `model.call(x)`.
480
+ sample_weight: Sample weights for weighting the loss function.
481
+
482
+ Returns:
483
+ A `dict` containing values that will be passed to
484
+ `keras.callbacks.CallbackList.on_train_batch_end()`. Typically,
485
+ the values of the metrics listed in `self.metrics` are returned.
486
+ Example: `{'loss': 0.2, 'accuracy': 0.7}`.
487
+ """
488
+ del x # The default implementation does not use `x`.
489
+ if self._compile_metrics is not None:
490
+ self._compile_metrics.update_state(y, y_pred, sample_weight)
491
+ return self.get_metrics_result()
492
+
493
+ def get_metrics_result(self):
494
+ """Returns the model's metrics values as a dict.
495
+
496
+ If any of the metric result is a dict (containing multiple metrics),
497
+ each of them gets added to the top level returned dict of this method.
498
+
499
+ Returns:
500
+ A `dict` containing values of the metrics listed in `self.metrics`.
501
+ Example: `{'loss': 0.2, 'accuracy': 0.7}`.
502
+ """
503
+ return_metrics = {}
504
+ for metric in self.metrics:
505
+ result = metric.result()
506
+ if isinstance(result, dict):
507
+ return_metrics.update(result)
508
+ else:
509
+ return_metrics[metric.name] = result
510
+ return python_utils.pythonify_logs(return_metrics)
511
+
512
+ def fit(
513
+ self,
514
+ x=None,
515
+ y=None,
516
+ batch_size=None,
517
+ epochs=1,
518
+ verbose="auto",
519
+ callbacks=None,
520
+ validation_split=0.0,
521
+ validation_data=None,
522
+ shuffle=True,
523
+ class_weight=None,
524
+ sample_weight=None,
525
+ initial_epoch=0,
526
+ steps_per_epoch=None,
527
+ validation_steps=None,
528
+ validation_batch_size=None,
529
+ validation_freq=1,
530
+ ):
531
+ """Trains the model for a fixed number of epochs (dataset iterations).
532
+
533
+ Args:
534
+ x: Input data. It can be:
535
+ - A NumPy array (or array-like), or a list of arrays
536
+ (in case the model has multiple inputs).
537
+ - A backend-native tensor, or a list of tensors
538
+ (in case the model has multiple inputs).
539
+ - A dict mapping input names to the corresponding array/tensors,
540
+ if the model has named inputs.
541
+ - A `keras.utils.PyDataset` returning `(inputs, targets)` or
542
+ `(inputs, targets, sample_weights)`.
543
+ - A `tf.data.Dataset` yielding `(inputs, targets)` or
544
+ `(inputs, targets, sample_weights)`.
545
+ - A `torch.utils.data.DataLoader` yielding `(inputs, targets)`
546
+ or `(inputs, targets, sample_weights)`.
547
+ - A Python generator function yielding `(inputs, targets)` or
548
+ `(inputs, targets, sample_weights)`.
549
+ y: Target data. Like the input data `x`, it can be either NumPy
550
+ array(s) or backend-native tensor(s). If `x` is a
551
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
552
+ `torch.utils.data.DataLoader` or a Python generator function,
553
+ `y` should not be specified since targets will be obtained from
554
+ `x`.
555
+ batch_size: Integer or `None`.
556
+ Number of samples per gradient update.
557
+ If unspecified, `batch_size` will default to 32.
558
+ Do not specify the `batch_size` if your input data `x` is a
559
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
560
+ `torch.utils.data.DataLoader` or Python generator function
561
+ since they generate batches.
562
+ epochs: Integer. Number of epochs to train the model.
563
+ An epoch is an iteration over the entire `x` and `y`
564
+ data provided
565
+ (unless the `steps_per_epoch` flag is set to
566
+ something other than None).
567
+ Note that in conjunction with `initial_epoch`,
568
+ `epochs` is to be understood as "final epoch".
569
+ The model is not trained for a number of iterations
570
+ given by `epochs`, but merely until the epoch
571
+ of index `epochs` is reached.
572
+ verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
573
+ 0 = silent, 1 = progress bar, 2 = one line per epoch.
574
+ "auto" becomes 1 for most cases.
575
+ Note that the progress bar is not
576
+ particularly useful when logged to a file,
577
+ so `verbose=2` is recommended when not running interactively
578
+ (e.g., in a production environment). Defaults to `"auto"`.
579
+ callbacks: List of `keras.callbacks.Callback` instances.
580
+ List of callbacks to apply during training.
581
+ See `keras.callbacks`. Note
582
+ `keras.callbacks.ProgbarLogger` and
583
+ `keras.callbacks.History` callbacks are created
584
+ automatically and need not be passed to `model.fit()`.
585
+ `keras.callbacks.ProgbarLogger` is created
586
+ or not based on the `verbose` argument in `model.fit()`.
587
+ validation_split: Float between 0 and 1.
588
+ Fraction of the training data to be used as validation data.
589
+ The model will set apart this fraction of the training data,
590
+ will not train on it, and will evaluate the loss and any model
591
+ metrics on this data at the end of each epoch. The validation
592
+ data is selected from the last samples in the `x` and `y` data
593
+ provided, before shuffling.
594
+ This argument is only supported when `x` and `y` are made of
595
+ NumPy arrays or tensors.
596
+ If both `validation_data` and `validation_split` are provided,
597
+ `validation_data` will override `validation_split`.
598
+ validation_data: Data on which to evaluate
599
+ the loss and any model metrics at the end of each epoch.
600
+ The model will not be trained on this data. Thus, note the fact
601
+ that the validation loss of data provided using
602
+ `validation_split` or `validation_data` is not affected by
603
+ regularization layers like noise and dropout.
604
+ `validation_data` will override `validation_split`.
605
+ It can be:
606
+ - A tuple `(x_val, y_val)` of NumPy arrays or tensors.
607
+ - A tuple `(x_val, y_val, val_sample_weights)` of NumPy
608
+ arrays.
609
+ - A `keras.utils.PyDataset`, a `tf.data.Dataset`, a
610
+ `torch.utils.data.DataLoader` yielding `(inputs, targets)` or a
611
+ Python generator function yielding `(x_val, y_val)` or
612
+ `(inputs, targets, sample_weights)`.
613
+ shuffle: Boolean, whether to shuffle the training data before each
614
+ epoch. This argument is ignored when `x` is a
615
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
616
+ `torch.utils.data.DataLoader` or Python generator function.
617
+ class_weight: Optional dictionary mapping class indices (integers)
618
+ to a weight (float) value, used for weighting the loss function
619
+ (during training only).
620
+ This can be useful to tell the model to
621
+ "pay more attention" to samples from
622
+ an under-represented class. When `class_weight` is specified
623
+ and targets have a rank of 2 or greater, either `y` must be
624
+ one-hot encoded, or an explicit final dimension of `1` must
625
+ be included for sparse class labels.
626
+ sample_weight: Optional NumPy array or tensor of weights for
627
+ the training samples, used for weighting the loss function
628
+ (during training only). You can either pass a flat (1D)
629
+ NumPy array or tensor with the same length as the input samples
630
+ (1:1 mapping between weights and samples), or in the case of
631
+ temporal data, you can pass a 2D NumPy array or tensor with
632
+ shape `(samples, sequence_length)` to apply a different weight
633
+ to every timestep of every sample.
634
+ This argument is not supported when `x` is a
635
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
636
+ `torch.utils.data.DataLoader` or Python generator function.
637
+ Instead, provide `sample_weights` as the third element of `x`.
638
+ Note that sample weighting does not apply to metrics specified
639
+ via the `metrics` argument in `compile()`. To apply sample
640
+ weighting to your metrics, you can specify them via the
641
+ `weighted_metrics` in `compile()` instead.
642
+ initial_epoch: Integer.
643
+ Epoch at which to start training
644
+ (useful for resuming a previous training run).
645
+ steps_per_epoch: Integer or `None`.
646
+ Total number of steps (batches of samples) before declaring one
647
+ epoch finished and starting the next epoch. When training with
648
+ input tensors or NumPy arrays, the default `None` means that the
649
+ value used is the number of samples in your dataset divided by
650
+ the batch size, or 1 if that cannot be determined.
651
+ If `x` is a `keras.utils.PyDataset`, `tf.data.Dataset`,
652
+ `torch.utils.data.DataLoader` or Python generator function, the
653
+ epoch will run until the input dataset is exhausted. When
654
+ passing an infinitely repeating dataset, you must specify the
655
+ `steps_per_epoch` argument, otherwise the training will run
656
+ indefinitely.
657
+ validation_steps: Integer or `None`.
658
+ Only relevant if `validation_data` is provided.
659
+ Total number of steps (batches of samples) to draw before
660
+ stopping when performing validation at the end of every epoch.
661
+ If `validation_steps` is `None`, validation will run until the
662
+ `validation_data` dataset is exhausted. In the case of an
663
+ infinitely repeating dataset, it will run indefinitely. If
664
+ `validation_steps` is specified and only part of the dataset
665
+ is consumed, the evaluation will start from the beginning of the
666
+ dataset at each epoch. This ensures that the same validation
667
+ samples are used every time.
668
+ validation_batch_size: Integer or `None`.
669
+ Number of samples per validation batch.
670
+ If unspecified, will default to `batch_size`.
671
+ Do not specify the `validation_batch_size` if your data is a
672
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
673
+ `torch.utils.data.DataLoader` or Python generator function
674
+ since they generate batches.
675
+ validation_freq: Only relevant if validation data is provided.
676
+ Specifies how many training epochs to run
677
+ before a new validation run is performed,
678
+ e.g. `validation_freq=2` runs validation every 2 epochs.
679
+
680
+ Unpacking behavior for iterator-like inputs:
681
+ A common pattern is to pass an iterator like object such as a
682
+ `tf.data.Dataset` or a `keras.utils.PyDataset` to `fit()`,
683
+ which will in fact yield not only features (`x`)
684
+ but optionally targets (`y`) and sample weights (`sample_weight`).
685
+ Keras requires that the output of such iterator-likes be
686
+ unambiguous. The iterator should return a tuple
687
+ of length 1, 2, or 3, where the optional second and third elements
688
+ will be used for `y` and `sample_weight` respectively.
689
+ Any other type provided will be wrapped in
690
+ a length-one tuple, effectively treating everything as `x`. When
691
+ yielding dicts, they should still adhere to the top-level tuple
692
+ structure,
693
+ e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate
694
+ features, targets, and weights from the keys of a single dict.
695
+ A notable unsupported data type is the `namedtuple`. The reason is
696
+ that it behaves like both an ordered datatype (tuple) and a mapping
697
+ datatype (dict). So given a namedtuple of the form:
698
+ `namedtuple("example_tuple", ["y", "x"])`
699
+ it is ambiguous whether to reverse the order of the elements when
700
+ interpreting the value. Even worse is a tuple of the form:
701
+ `namedtuple("other_tuple", ["x", "y", "z"])`
702
+ where it is unclear if the tuple was intended to be unpacked
703
+ into `x`, `y`, and `sample_weight` or passed through
704
+ as a single element to `x`.
705
+
706
+ Returns:
707
+ A `History` object. Its `History.history` attribute is
708
+ a record of training loss values and metrics values
709
+ at successive epochs, as well as validation loss values
710
+ and validation metrics values (if applicable).
711
+ """
712
+ raise NotImplementedError
713
+
714
+ def evaluate(
715
+ self,
716
+ x=None,
717
+ y=None,
718
+ batch_size=None,
719
+ verbose="auto",
720
+ sample_weight=None,
721
+ steps=None,
722
+ callbacks=None,
723
+ return_dict=False,
724
+ **kwargs,
725
+ ):
726
+ """Returns the loss value & metrics values for the model in test mode.
727
+
728
+ Computation is done in batches (see the `batch_size` arg.)
729
+
730
+ Args:
731
+ x: Input data. It can be:
732
+ - A NumPy array (or array-like), or a list of arrays
733
+ (in case the model has multiple inputs).
734
+ - A backend-native tensor, or a list of tensors
735
+ (in case the model has multiple inputs).
736
+ - A dict mapping input names to the corresponding array/tensors,
737
+ if the model has named inputs.
738
+ - A `keras.utils.PyDataset` returning `(inputs, targets)` or
739
+ `(inputs, targets, sample_weights)`.
740
+ - A `tf.data.Dataset` yielding `(inputs, targets)` or
741
+ `(inputs, targets, sample_weights)`.
742
+ - A `torch.utils.data.DataLoader` yielding `(inputs, targets)`
743
+ or `(inputs, targets, sample_weights)`.
744
+ - A Python generator function yielding `(inputs, targets)` or
745
+ `(inputs, targets, sample_weights)`.
746
+ y: Target data. Like the input data `x`, it can be either NumPy
747
+ array(s) or backend-native tensor(s). If `x` is a
748
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
749
+ `torch.utils.data.DataLoader` or a Python generator function,
750
+ `y` should not be specified since targets will be obtained from
751
+ `x`.
752
+ batch_size: Integer or `None`.
753
+ Number of samples per batch of computation.
754
+ If unspecified, `batch_size` will default to 32.
755
+ Do not specify the `batch_size` if your input data `x` is a
756
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
757
+ `torch.utils.data.DataLoader` or Python generator function
758
+ since they generate batches.
759
+ verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
760
+ 0 = silent, 1 = progress bar, 2 = single line.
761
+ `"auto"` becomes 1 for most cases.
762
+ Note that the progress bar is not
763
+ particularly useful when logged to a file, so `verbose=2` is
764
+ recommended when not running interactively
765
+ (e.g. in a production environment). Defaults to `"auto"`.
766
+ sample_weight: Optional NumPy array or tensor of weights for
767
+ the training samples, used for weighting the loss function
768
+ (during training only). You can either pass a flat (1D)
769
+ NumPy array or tensor with the same length as the input samples
770
+ (1:1 mapping between weights and samples), or in the case of
771
+ temporal data, you can pass a 2D NumPy array or tensor with
772
+ shape `(samples, sequence_length)` to apply a different weight
773
+ to every timestep of every sample.
774
+ This argument is not supported when `x` is a
775
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
776
+ `torch.utils.data.DataLoader` or Python generator function.
777
+ Instead, provide `sample_weights` as the third element of `x`.
778
+ Note that sample weighting does not apply to metrics specified
779
+ via the `metrics` argument in `compile()`. To apply sample
780
+ weighting to your metrics, you can specify them via the
781
+ `weighted_metrics` in `compile()` instead.
782
+ steps: Integer or `None`.
783
+ Total number of steps (batches of samples) to draw before
784
+ declaring the evaluation round finished. If `steps` is `None`,
785
+ it will run until `x` is exhausted. In the case of an infinitely
786
+ repeating dataset, it will run indefinitely.
787
+ callbacks: List of `keras.callbacks.Callback` instances.
788
+ List of callbacks to apply during evaluation.
789
+ return_dict: If `True`, loss and metric results are returned as a
790
+ dict, with each key being the name of the metric.
791
+ If `False`, they are returned as a list.
792
+
793
+ Returns:
794
+ Scalar test loss (if the model has a single output and no metrics)
795
+ or list of scalars (if the model has multiple outputs
796
+ and/or metrics). The attribute `model.metrics_names` will give you
797
+ the display labels for the scalar outputs.
798
+ """
799
+ raise NotImplementedError
800
+
801
+ def predict(
802
+ self, x, batch_size=None, verbose="auto", steps=None, callbacks=None
803
+ ):
804
+ """Generates output predictions for the input samples.
805
+
806
+ Computation is done in batches. This method is designed for batch
807
+ processing of large numbers of inputs. It is not intended for use inside
808
+ of loops that iterate over your data and process small numbers of inputs
809
+ at a time.
810
+
811
+ For small numbers of inputs that fit in one batch,
812
+ directly use `__call__()` for faster execution, e.g.,
813
+ `model(x)`, or `model(x, training=False)` if you have layers such as
814
+ `BatchNormalization` that behave differently during
815
+ inference.
816
+
817
+ Note: See [this FAQ entry](
818
+ https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call)
819
+ for more details about the difference between `Model` methods
820
+ `predict()` and `__call__()`.
821
+
822
+ Args:
823
+ x: Input data. It can be:
824
+ - A NumPy array (or array-like), or a list of arrays
825
+ (in case the model has multiple inputs).
826
+ - A backend-native tensor, or a list of tensors
827
+ (in case the model has multiple inputs).
828
+ - A dict mapping input names to the corresponding array/tensors,
829
+ if the model has named inputs.
830
+ - A `keras.utils.PyDataset`.
831
+ - A `tf.data.Dataset`.
832
+ - A `torch.utils.data.DataLoader`.
833
+ - A Python generator function.
834
+ batch_size: Integer or `None`.
835
+ Number of samples per batch of computation.
836
+ If unspecified, `batch_size` will default to 32.
837
+ Do not specify the `batch_size` if your input data `x` is a
838
+ `keras.utils.PyDataset`, `tf.data.Dataset`,
839
+ `torch.utils.data.DataLoader` or Python generator function
840
+ since they generate batches.
841
+ verbose: `"auto"`, 0, 1, or 2. Verbosity mode.
842
+ 0 = silent, 1 = progress bar, 2 = single line.
843
+ `"auto"` becomes 1 for most cases. Note that the progress bar
844
+ is not particularly useful when logged to a file,
845
+ so `verbose=2` is recommended when not running interactively
846
+ (e.g. in a production environment). Defaults to `"auto"`.
847
+ steps: Total number of steps (batches of samples) to draw before
848
+ declaring the prediction round finished. If `steps` is `None`,
849
+ it will run until `x` is exhausted. In the case of an infinitely
850
+ repeating dataset, it will run indefinitely.
851
+ callbacks: List of `keras.callbacks.Callback` instances.
852
+ List of callbacks to apply during prediction.
853
+
854
+ Returns:
855
+ NumPy array(s) of predictions.
856
+ """
857
+ raise NotImplementedError
858
+
859
+ def train_on_batch(
860
+ self,
861
+ x,
862
+ y=None,
863
+ sample_weight=None,
864
+ class_weight=None,
865
+ return_dict=False,
866
+ ):
867
+ """Runs a single gradient update on a single batch of data.
868
+
869
+ Args:
870
+ x: Input data. Must be array-like.
871
+ y: Target data. Must be array-like.
872
+ sample_weight: Optional array of the same length as x, containing
873
+ weights to apply to the model's loss for each sample.
874
+ In the case of temporal data, you can pass a 2D array
875
+ with shape `(samples, sequence_length)`, to apply a different
876
+ weight to every timestep of every sample.
877
+ class_weight: Optional dictionary mapping class indices (integers)
878
+ to a weight (float) to apply to the model's loss for the samples
879
+ from this class during training. This can be useful to tell the
880
+ model to "pay more attention" to samples from an
881
+ under-represented class. When `class_weight` is specified
882
+ and targets have a rank of 2 or greater, either `y` must
883
+ be one-hot encoded, or an explicit final dimension of 1
884
+ must be included for sparse class labels.
885
+ return_dict: If `True`, loss and metric results are returned as a
886
+ dict, with each key being the name of the metric. If `False`,
887
+ they are returned as a list.
888
+
889
+ Returns:
890
+ A scalar loss value (when no metrics and `return_dict=False`),
891
+ a list of loss and metric values
892
+ (if there are metrics and `return_dict=False`), or a dict of
893
+ metric and loss values (if `return_dict=True`).
894
+ """
895
+ raise NotImplementedError
896
+
897
+ def test_on_batch(
898
+ self,
899
+ x,
900
+ y=None,
901
+ sample_weight=None,
902
+ return_dict=False,
903
+ ):
904
+ """Test the model on a single batch of samples.
905
+
906
+ Args:
907
+ x: Input data. Must be array-like.
908
+ y: Target data. Must be array-like.
909
+ sample_weight: Optional array of the same length as x, containing
910
+ weights to apply to the model's loss for each sample.
911
+ In the case of temporal data, you can pass a 2D array
912
+ with shape `(samples, sequence_length)`, to apply a different
913
+ weight to every timestep of every sample.
914
+ return_dict: If `True`, loss and metric results are returned as a
915
+ dict, with each key being the name of the metric. If `False`,
916
+ they are returned as a list.
917
+
918
+ Returns:
919
+ A scalar loss value (when no metrics and `return_dict=False`),
920
+ a list of loss and metric values
921
+ (if there are metrics and `return_dict=False`), or a dict of
922
+ metric and loss values (if `return_dict=True`).
923
+ """
924
+ raise NotImplementedError
925
+
926
+ def predict_on_batch(self, x):
927
+ """Returns predictions for a single batch of samples.
928
+
929
+ Args:
930
+ x: Input data. It must be array-like.
931
+
932
+ Returns:
933
+ NumPy array(s) of predictions.
934
+ """
935
+ raise NotImplementedError
936
+
937
+ def get_compile_config(self):
938
+ """Returns a serialized config with information for compiling the model.
939
+
940
+ This method returns a config dictionary containing all the information
941
+ (optimizer, loss, metrics, etc.) with which the model was compiled.
942
+
943
+ Returns:
944
+ A dict containing information for compiling the model.
945
+ """
946
+ if self.compiled and hasattr(self, "_compile_config"):
947
+ return self._compile_config.serialize()
948
+
949
+ def compile_from_config(self, config):
950
+ """Compiles the model with the information given in config.
951
+
952
+ This method uses the information in the config (optimizer, loss,
953
+ metrics, etc.) to compile the model.
954
+
955
+ Args:
956
+ config: Dict containing information for compiling the model.
957
+ """
958
+ has_overridden_compile = self.__class__.compile != Trainer.compile
959
+ if has_overridden_compile:
960
+ warnings.warn(
961
+ "`compile()` was not called as part of model loading "
962
+ "because the model's `compile()` method is custom. "
963
+ "All subclassed Models that have `compile()` "
964
+ "overridden should also override "
965
+ "`get_compile_config()` and `compile_from_config(config)`. "
966
+ "Alternatively, you can "
967
+ "call `compile()` manually after loading.",
968
+ stacklevel=2,
969
+ )
970
+ return
971
+ config = serialization_lib.deserialize_keras_object(config)
972
+ self.compile(**config)
973
+ if hasattr(self, "optimizer") and self.built:
974
+ # Create optimizer variables.
975
+ self.optimizer.build(self.trainable_variables)
976
+
977
+ def _should_eval(self, epoch, validation_freq):
978
+ epoch = epoch + 1 # one-index the user-facing epoch.
979
+ if isinstance(validation_freq, int):
980
+ return epoch % validation_freq == 0
981
+ elif isinstance(validation_freq, list):
982
+ return epoch in validation_freq
983
+ else:
984
+ raise ValueError(
985
+ "Expected `validation_freq` to be a list or int. "
986
+ f"Received: validation_freq={validation_freq} of the "
987
+ f"type {type(validation_freq)}."
988
+ )
989
+
990
+ def _get_metrics_result_or_logs(self, logs):
991
+ """Returns model metrics as a dict if the keys match with input logs.
992
+
993
+ When the training / evaluation is performed with an asynchronous steps,
994
+ the last scheduled `train / test_step` may not give the latest metrics
995
+ because it is not guaranteed to be executed the last. This method gets
996
+ metrics from the model directly instead of relying on the return from
997
+ last step function.
998
+
999
+ When the user has custom train / test step functions, the metrics
1000
+ returned may be different from `Model.metrics`. In those instances,
1001
+ this function will be no-op and return the logs passed in.
1002
+
1003
+ Args:
1004
+ logs: A `dict` of metrics returned by train / test step function.
1005
+
1006
+ Returns:
1007
+ A `dict` containing values of the metrics listed in `self.metrics`
1008
+ when logs and model metrics keys match. Otherwise it returns input
1009
+ `logs`.
1010
+ """
1011
+ metric_logs = self.get_metrics_result()
1012
+ # Verify that train / test step logs passed and metric logs have
1013
+ # matching keys. It could be different when using custom step functions,
1014
+ # in which case we return the logs from the last step.
1015
+ if isinstance(logs, dict) and set(logs.keys()) == set(
1016
+ metric_logs.keys()
1017
+ ):
1018
+ return metric_logs
1019
+ return logs
1020
+
1021
+ def _flatten_metrics_in_order(self, logs):
1022
+ """Turns `logs` dict into a list as per key order of `metrics_names`."""
1023
+ metric_names = []
1024
+ for metric in self.metrics:
1025
+ if isinstance(metric, CompileMetrics):
1026
+ metric_names += [
1027
+ sub_metric.name for sub_metric in metric.metrics
1028
+ ]
1029
+ else:
1030
+ metric_names.append(metric.name)
1031
+ results = []
1032
+ for name in metric_names:
1033
+ if name in logs:
1034
+ results.append(logs[name])
1035
+ for key in sorted(logs.keys()):
1036
+ if key not in metric_names:
1037
+ results.append(logs[key])
1038
+ if len(results) == 1:
1039
+ return results[0]
1040
+ return results
1041
+
1042
+ def _assert_compile_called(self, method_name=None):
1043
+ if not self.compiled:
1044
+ msg = "You must call `compile()` before "
1045
+ if metrics_module:
1046
+ msg += "using the model."
1047
+ else:
1048
+ msg += f"calling `{method_name}()`."
1049
+ raise ValueError(msg)
1050
+
1051
+ def _symbolic_build(self, iterator=None, data_batch=None):
1052
+ model_unbuilt = not all(layer.built for layer in self._flatten_layers())
1053
+ compile_metrics_unbuilt = (
1054
+ self._compile_metrics is not None
1055
+ and not self._compile_metrics.built
1056
+ )
1057
+ compile_loss_unbuilt = (
1058
+ self._compile_loss is not None and not self._compile_loss.built
1059
+ )
1060
+ optimizer_unbuilt = (
1061
+ self.optimizer is not None and not self.optimizer.built
1062
+ )
1063
+ if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt:
1064
+ # Create symbolic tensors matching an input batch.
1065
+
1066
+ def to_symbolic_input(v):
1067
+ if v is None:
1068
+ return None
1069
+ return backend.KerasTensor(
1070
+ v.shape, backend.standardize_dtype(v.dtype)
1071
+ )
1072
+
1073
+ if data_batch is None:
1074
+ for _, data_or_iterator in iterator:
1075
+ if isinstance(data_or_iterator, (list, tuple)):
1076
+ data_batch = data_or_iterator[0]
1077
+ else:
1078
+ data_batch = next(data_or_iterator)
1079
+ break
1080
+ data_batch = tree.map_structure(to_symbolic_input, data_batch)
1081
+ (
1082
+ x,
1083
+ y,
1084
+ sample_weight,
1085
+ ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch)
1086
+
1087
+ # Build all model state with `backend.compute_output_spec`.
1088
+ try:
1089
+ y_pred = backend.compute_output_spec(self, x, training=False)
1090
+ except Exception as e:
1091
+ raise RuntimeError(
1092
+ "Unable to automatically build the model. "
1093
+ "Please build it yourself before calling "
1094
+ "fit/evaluate/predict. "
1095
+ "A model is 'built' when its variables have "
1096
+ "been created and its `self.built` attribute "
1097
+ "is True. Usually, calling the model on a batch "
1098
+ "of data is the right way to build it.\n"
1099
+ "Exception encountered:\n"
1100
+ f"'{e}'"
1101
+ )
1102
+ if compile_metrics_unbuilt:
1103
+ # Build all metric state with `backend.compute_output_spec`.
1104
+ backend.compute_output_spec(
1105
+ self.compute_metrics,
1106
+ x,
1107
+ y,
1108
+ y_pred,
1109
+ sample_weight=sample_weight,
1110
+ )
1111
+ if compile_loss_unbuilt:
1112
+ # Build `CompileLoss` state with `backend.compute_output_spec`.
1113
+ backend.compute_output_spec(
1114
+ self._compute_loss,
1115
+ x,
1116
+ y,
1117
+ y_pred,
1118
+ sample_weight=sample_weight,
1119
+ training=False,
1120
+ )
1121
+ if optimizer_unbuilt:
1122
+ # Build optimizer
1123
+ self.optimizer.build(self.trainable_variables)
1124
+ self._post_build()
1125
+
1126
+
1127
+ def model_supports_jit(model):
1128
+ # XLA not supported with TF on MacOS GPU
1129
+ if platform.system() == "Darwin" and "arm" in platform.processor().lower():
1130
+ if backend.backend() == "tensorflow":
1131
+ from keras.src.utils.module_utils import tensorflow as tf
1132
+
1133
+ if tf.config.list_physical_devices("GPU"):
1134
+ return False
1135
+ # XLA not supported by some layers
1136
+ if all(x.supports_jit for x in model._flatten_layers()):
1137
+ if backend.backend() == "tensorflow":
1138
+ from tensorflow.python.framework.config import (
1139
+ is_op_determinism_enabled,
1140
+ )
1141
+
1142
+ if is_op_determinism_enabled():
1143
+ # disable XLA with determinism enabled since not all ops are
1144
+ # supported by XLA with determinism enabled.
1145
+ return False
1146
+ return True
1147
+ return False
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.tree.tree_api import assert_same_paths
2
+ from keras.src.tree.tree_api import assert_same_structure
3
+ from keras.src.tree.tree_api import flatten
4
+ from keras.src.tree.tree_api import flatten_with_path
5
+ from keras.src.tree.tree_api import is_nested
6
+ from keras.src.tree.tree_api import lists_to_tuples
7
+ from keras.src.tree.tree_api import map_shape_structure
8
+ from keras.src.tree.tree_api import map_structure
9
+ from keras.src.tree.tree_api import map_structure_up_to
10
+ from keras.src.tree.tree_api import pack_sequence_as
11
+ from keras.src.tree.tree_api import register_tree_node_class
12
+ from keras.src.tree.tree_api import traverse
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (676 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc ADDED
Binary file (11.3 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc ADDED
Binary file (6.11 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc ADDED
Binary file (14.4 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py ADDED
@@ -0,0 +1,394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ import collections.abc
3
+ import itertools
4
+
5
+ from keras.src.backend.config import backend
6
+ from keras.src.utils.module_utils import dmtree
7
+
8
+ # NOTE: There are two known discrepancies between this `dmtree` implementation
9
+ # of the tree API and the `optree` implementation:
10
+ #
11
+ # 1. `map_structure` with *multiple* structures and `map_structure_up_to` do not
12
+ # use the object registration (they use the raw `dmtree.map_structure` and
13
+ # `dmtree.map_structure_up_to`). This only has consequences with two types of
14
+ # structures:
15
+ # - `TrackedSet` will not explored (considered as a leaf).
16
+ # - `OrderedDict` will be traversed in the order of sorted keys, not the
17
+ # order of the items. This is typically inconsequential because functions
18
+ # used with `map_structure` and `map_structure_up_to` are typically not
19
+ # order dependent and are, in fact, stateless.
20
+ #
21
+ # 2. The handling of non-sortable keys in dictionaries in inconsistent. `optree`
22
+ # uses the iteration order while `dmtree` raises an error. This is not an
23
+ # issue as keys are always strings. But this is the reason why we document
24
+ # non-sortable keys as unsupported (meaning behavior is undefined).
25
+
26
+ REGISTERED_CLASSES = {}
27
+
28
+ ClassRegistration = collections.namedtuple(
29
+ "ClassRegistration", ["flatten", "unflatten"]
30
+ )
31
+
32
+
33
+ class TypeErrorRemapping:
34
+ def __enter__(self):
35
+ pass
36
+
37
+ def __exit__(self, exc_type, exc_value, traceback):
38
+ if exc_type is TypeError:
39
+ raise ValueError(exc_value).with_traceback(traceback)
40
+ return False
41
+
42
+
43
+ def register_tree_node(
44
+ cls,
45
+ flatten_func=None,
46
+ unflatten_func=None,
47
+ ):
48
+ if flatten_func is None:
49
+ flatten_func = lambda x: x.tree_flatten()
50
+ if unflatten_func is None:
51
+ unflatten_func = cls.tree_unflatten
52
+ REGISTERED_CLASSES[cls] = ClassRegistration(flatten_func, unflatten_func)
53
+
54
+
55
+ def register_tree_node_class(cls):
56
+ register_tree_node(cls)
57
+ return cls
58
+
59
+
60
+ register_tree_node(
61
+ collections.OrderedDict,
62
+ lambda d: (d.values(), list(d.keys()), d.keys()),
63
+ lambda metadata, children: collections.OrderedDict(zip(metadata, children)),
64
+ )
65
+
66
+ if backend() == "tensorflow":
67
+ from tensorflow.python.trackable.data_structures import ListWrapper
68
+ from tensorflow.python.trackable.data_structures import _DictWrapper
69
+
70
+ register_tree_node(
71
+ ListWrapper,
72
+ lambda x: (x, None),
73
+ lambda metadata, children: ListWrapper(list(children)),
74
+ )
75
+
76
+ def sorted_keys_and_values(d):
77
+ keys = sorted(list(d.keys()))
78
+ values = [d[k] for k in keys]
79
+ return values, keys, keys
80
+
81
+ register_tree_node(
82
+ _DictWrapper,
83
+ sorted_keys_and_values,
84
+ lambda metadata, children: _DictWrapper(
85
+ {key: child for key, child in zip(metadata, children)}
86
+ ),
87
+ )
88
+
89
+
90
+ def is_nested(structure):
91
+ return type(structure) in REGISTERED_CLASSES or dmtree.is_nested(structure)
92
+
93
+
94
+ def traverse(func, structure, top_down=True):
95
+ if not callable(func):
96
+ raise TypeError(
97
+ f"`func` must be callable, got {func} of type {type(func)}"
98
+ )
99
+
100
+ def remap_map_to_none(value, new_value):
101
+ if isinstance(value, type) and value.__name__ == "MAP_TO_NONE":
102
+ return new_value
103
+ return value
104
+
105
+ def traverse_top_down(s):
106
+ ret = func(s)
107
+ if ret is not None:
108
+ return remap_map_to_none(ret, dmtree.MAP_TO_NONE)
109
+ registration = REGISTERED_CLASSES.get(type(s), None)
110
+ if registration is None:
111
+ return None
112
+ flat_meta_s = registration.flatten(s)
113
+ flat_s = [
114
+ dmtree.traverse(traverse_top_down, x, top_down=True)
115
+ for x in list(flat_meta_s[0])
116
+ ]
117
+ return registration.unflatten(flat_meta_s[1], flat_s)
118
+
119
+ def traverse_bottom_up(s):
120
+ registration = REGISTERED_CLASSES.get(type(s), None)
121
+ if registration is not None:
122
+ flat_meta_s = registration.flatten(s)
123
+ ret = [traverse_bottom_up(x) for x in list(flat_meta_s[0])]
124
+ ret = registration.unflatten(flat_meta_s[1], ret)
125
+ elif not dmtree.is_nested(s):
126
+ ret = s
127
+ elif isinstance(s, collections.abc.Mapping):
128
+ ret = [traverse_bottom_up(s[key]) for key in sorted(s)]
129
+ ret = dmtree._sequence_like(s, ret)
130
+ else:
131
+ ret = [traverse_bottom_up(x) for x in s]
132
+ ret = dmtree._sequence_like(s, ret)
133
+ func_ret = func(ret)
134
+ return ret if func_ret is None else remap_map_to_none(func_ret, None)
135
+
136
+ if top_down:
137
+ return dmtree.traverse(traverse_top_down, structure, top_down=True)
138
+ else:
139
+ return traverse_bottom_up(structure)
140
+
141
+
142
+ def flatten(structure):
143
+ if not is_nested(structure):
144
+ return [structure]
145
+
146
+ flattened = []
147
+
148
+ def flatten_func(s):
149
+ registration = REGISTERED_CLASSES.get(type(s), None)
150
+ if registration is not None:
151
+ flat_s = list(registration.flatten(s)[0])
152
+ return dmtree.traverse(flatten_func, flat_s, top_down=True)
153
+ if not is_nested(s):
154
+ flattened.append(s)
155
+ return dmtree.MAP_TO_NONE if s is None else s
156
+ return None
157
+
158
+ dmtree.traverse(flatten_func, structure, top_down=True)
159
+ return flattened
160
+
161
+
162
+ def _recursive_flatten_with_path(path, structure, flattened):
163
+ registration = REGISTERED_CLASSES.get(type(structure), None)
164
+ if registration is not None:
165
+ flat_meta_paths = registration.flatten(structure)
166
+ flat = flat_meta_paths[0]
167
+ paths = (
168
+ flat_meta_paths[2]
169
+ if len(flat_meta_paths) >= 3
170
+ else itertools.count()
171
+ )
172
+ for key, value in zip(paths, flat):
173
+ _recursive_flatten_with_path(path + (key,), value, flattened)
174
+ elif not dmtree.is_nested(structure):
175
+ flattened.append((path, structure))
176
+ elif isinstance(structure, collections.abc.Mapping):
177
+ for key in sorted(structure):
178
+ _recursive_flatten_with_path(
179
+ path + (key,), structure[key], flattened
180
+ )
181
+ else:
182
+ for key, value in enumerate(structure):
183
+ _recursive_flatten_with_path(path + (key,), value, flattened)
184
+
185
+
186
+ def flatten_with_path(structure):
187
+ if not is_nested(structure):
188
+ return [((), structure)]
189
+
190
+ # Fully reimplemented in Python to handle registered classes, OrderedDict
191
+ # and namedtuples the same way as optree.
192
+ flattened = []
193
+ _recursive_flatten_with_path((), structure, flattened)
194
+ return flattened
195
+
196
+
197
+ def map_structure(func, *structures):
198
+ if not callable(func):
199
+ raise TypeError(
200
+ f"`func` must be callable, got {func} of type {type(func)}"
201
+ )
202
+
203
+ def func_traverse_wrapper(s):
204
+ if is_nested(s):
205
+ return None
206
+ ret = func(s)
207
+ if ret is None:
208
+ return dmtree.MAP_TO_NONE
209
+ return ret
210
+
211
+ if len(structures) == 1:
212
+ return traverse(func_traverse_wrapper, structures[0])
213
+
214
+ with TypeErrorRemapping():
215
+ return dmtree.map_structure(func, *structures)
216
+
217
+
218
+ def map_structure_up_to(shallow_structure, func, *structures):
219
+ if not callable(func):
220
+ raise TypeError(
221
+ f"`func` must be callable, got {func} of type {type(func)}"
222
+ )
223
+
224
+ with TypeErrorRemapping():
225
+ return dmtree.map_structure_up_to(shallow_structure, func, *structures)
226
+
227
+
228
+ def assert_same_structure(a, b):
229
+ # Fully reimplemented in Python to handle registered classes.
230
+
231
+ # Don't handle OrderedDict as a registered class, use the normal dict path
232
+ # so that OrderedDict is equivalent to dict per optree behavior.
233
+ a_registration = REGISTERED_CLASSES.get(type(a), None)
234
+ if isinstance(a, collections.OrderedDict):
235
+ a_registration = None
236
+
237
+ b_registration = REGISTERED_CLASSES.get(type(b), None)
238
+ if isinstance(b, collections.OrderedDict):
239
+ b_registration = None
240
+
241
+ if a_registration != b_registration:
242
+ raise ValueError(
243
+ f"Custom node type mismatch; "
244
+ f"expected type: {type(a)}, got type: {type(b)} "
245
+ f"while comparing {a} and {b}."
246
+ )
247
+ if a_registration is not None:
248
+ a_flat_meta = a_registration.flatten(a)
249
+ b_flat_meta = b_registration.flatten(b)
250
+ a_flat = list(a_flat_meta[0])
251
+ b_flat = list(b_flat_meta[0])
252
+ if not a_flat_meta[1] == b_flat_meta[1]:
253
+ raise ValueError(
254
+ f"Mismatch custom node data; "
255
+ f"expected: {a_flat_meta[1]}, got: {b_flat_meta[1]} "
256
+ f"while comparing {a} and {b}."
257
+ )
258
+ if len(a_flat) != len(b_flat):
259
+ raise ValueError(
260
+ f"Arity mismatch; expected: {len(a)}, got: {len(b)} "
261
+ f"while comparing {a} and {b}."
262
+ )
263
+ for sub_a, sub_b in zip(a_flat, b_flat):
264
+ assert_same_structure(sub_a, sub_b)
265
+ elif not dmtree.is_nested(a):
266
+ if dmtree.is_nested(b):
267
+ raise ValueError(
268
+ f"Structures don't have the same nested structure: {a}, {b}."
269
+ )
270
+ elif isinstance(
271
+ a, (dict, collections.OrderedDict, collections.defaultdict)
272
+ ):
273
+ if not isinstance(
274
+ b, (dict, collections.OrderedDict, collections.defaultdict)
275
+ ):
276
+ raise ValueError(
277
+ f"Expected an instance of dict, collections.OrderedDict, or "
278
+ f"collections.defaultdict, got {type(b)} "
279
+ f"while comparing {a} and {b}."
280
+ )
281
+ a_keys = sorted(a)
282
+ b_keys = sorted(b)
283
+ if not a_keys == b_keys:
284
+ raise ValueError(
285
+ f"Dictionary key mismatch; "
286
+ f"expected key(s): {a_keys}, got key(s): {b_keys} "
287
+ f"while comparing {a} and {b}."
288
+ )
289
+ for key in a_keys:
290
+ assert_same_structure(a[key], b[key])
291
+ elif isinstance(a, collections.abc.Mapping):
292
+ raise ValueError(
293
+ f"Encountered unregistered collections.abc.Mapping type: {type(a)} "
294
+ f"while comparing {a} and {b}."
295
+ )
296
+ else:
297
+ if type(a) is not type(b):
298
+ raise ValueError(
299
+ f"Expected an instance of {type(a)}, got {type(b)} "
300
+ f"while comparing {a} and {b}."
301
+ )
302
+ if not len(a) == len(b):
303
+ raise ValueError(
304
+ f"Arity mismatch; expected: {len(a)}, got: {len(b)} "
305
+ f"while comparing {a} and {b}."
306
+ )
307
+ for sub_a, sub_b in zip(a, b):
308
+ assert_same_structure(sub_a, sub_b)
309
+
310
+
311
+ def assert_same_paths(a, b):
312
+ a_paths = set([path for path, _ in flatten_with_path(a)])
313
+ b_paths = set([path for path, _ in flatten_with_path(b)])
314
+
315
+ if a_paths != b_paths:
316
+ msg = "`a` and `b` don't have the same paths."
317
+ a_diff = a_paths.difference(b_paths)
318
+ if a_diff:
319
+ msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
320
+ b_diff = b_paths.difference(a_paths)
321
+ if b_diff:
322
+ msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
323
+ raise ValueError(msg)
324
+
325
+
326
+ def pack_sequence_as(structure, flat_sequence):
327
+ # This is not just an optimization for the case when structure is a leaf.
328
+ # This is required to avoid Torch Dynamo failures.
329
+ if not is_nested(structure):
330
+ if len(flat_sequence) == 1:
331
+ return flat_sequence[0]
332
+ else:
333
+ raise ValueError(
334
+ "Incorrect number of leaves provided by `flat_sequence` for "
335
+ f"`structure`; expected: 1, got {len(flat_sequence)}."
336
+ )
337
+
338
+ flat_sequence_it = enumerate(flat_sequence)
339
+
340
+ def unflatten_func(s):
341
+ registration = REGISTERED_CLASSES.get(type(s), None)
342
+ if registration is not None:
343
+ flat_meta_s = registration.flatten(s)
344
+ flat_s = dmtree.traverse(
345
+ unflatten_func, list(flat_meta_s[0]), top_down=True
346
+ )
347
+ return registration.unflatten(flat_meta_s[1], flat_s)
348
+ elif not dmtree.is_nested(s):
349
+ try:
350
+ _, value = next(flat_sequence_it)
351
+ return dmtree.MAP_TO_NONE if value is None else value
352
+ except StopIteration:
353
+ raise ValueError(
354
+ "Too few leaves provided by `flat_sequence` for "
355
+ f"`structure`. Got {len(flat_sequence)}."
356
+ )
357
+ return None
358
+
359
+ ret = dmtree.traverse(unflatten_func, structure, top_down=True)
360
+ try:
361
+ index, _ = next(flat_sequence_it)
362
+ raise ValueError(
363
+ "Too many leaves provided by `flat_sequence` for `structure`; "
364
+ f"expected: {index}, got {len(flat_sequence)}."
365
+ )
366
+ except StopIteration:
367
+ return ret
368
+
369
+
370
+ def lists_to_tuples(structure):
371
+ def list_to_tuple(instance):
372
+ return tuple(instance) if isinstance(instance, list) else None
373
+
374
+ return traverse(list_to_tuple, structure, top_down=False)
375
+
376
+
377
+ def map_shape_structure(func, structure):
378
+ if not callable(func):
379
+ raise TypeError(
380
+ f"`func` must be callable, got {func} of type {type(func)}"
381
+ )
382
+
383
+ def map_shape_func(x):
384
+ if isinstance(x, (list, tuple)) and all(
385
+ isinstance(e, (int, type(None))) for e in x
386
+ ):
387
+ ret = func(x)
388
+ elif is_nested(x):
389
+ return None
390
+ else:
391
+ ret = func(x)
392
+ return ret if ret is not None else dmtree.MAP_TO_NONE
393
+
394
+ return traverse(map_shape_func, structure, top_down=True)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optree
2
+ import optree.utils
3
+
4
+ from keras.src.backend.config import backend
5
+
6
+
7
+ def register_tree_node_class(cls):
8
+ return optree.register_pytree_node_class(cls, namespace="keras")
9
+
10
+
11
+ # Register backend-specific node classes
12
+ if backend() == "tensorflow":
13
+ from tensorflow.python.trackable.data_structures import ListWrapper
14
+ from tensorflow.python.trackable.data_structures import _DictWrapper
15
+
16
+ optree.register_pytree_node(
17
+ ListWrapper,
18
+ lambda x: (x, None),
19
+ lambda metadata, children: ListWrapper(list(children)),
20
+ namespace="keras",
21
+ )
22
+
23
+ def sorted_keys_and_values(d):
24
+ keys = sorted(list(d.keys()))
25
+ values = [d[k] for k in keys]
26
+ return values, keys, keys
27
+
28
+ optree.register_pytree_node(
29
+ _DictWrapper,
30
+ sorted_keys_and_values,
31
+ lambda metadata, children: _DictWrapper(
32
+ {key: child for key, child in zip(metadata, children)}
33
+ ),
34
+ namespace="keras",
35
+ )
36
+
37
+
38
+ def is_nested(structure):
39
+ return not optree.tree_is_leaf(
40
+ structure, none_is_leaf=True, namespace="keras"
41
+ )
42
+
43
+
44
+ def traverse(func, structure, top_down=True):
45
+ # From https://github.com/google/jax/pull/19695
46
+ def traverse_children():
47
+ children, treedef = optree.tree_flatten(
48
+ structure,
49
+ is_leaf=lambda x: x is not structure,
50
+ none_is_leaf=True,
51
+ namespace="keras",
52
+ )
53
+ if treedef.num_nodes == 1 and treedef.num_leaves == 1:
54
+ return structure
55
+ else:
56
+ return optree.tree_unflatten(
57
+ treedef,
58
+ [traverse(func, c, top_down=top_down) for c in children],
59
+ )
60
+
61
+ if top_down:
62
+ ret = func(structure)
63
+ if ret is None:
64
+ return traverse_children()
65
+ else:
66
+ traversed_structure = traverse_children()
67
+ ret = func(traversed_structure)
68
+ if ret is None:
69
+ return traversed_structure
70
+ # Detect MAP_TO_NONE without tree_api import to avoid circular import.
71
+ if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE":
72
+ return None
73
+ return ret
74
+
75
+
76
+ def flatten(structure):
77
+ # optree.tree_flatten returns a pair (leaves, treespec) where the first
78
+ # element is a list of leaf values and the second element is a treespec
79
+ # representing the structure of the pytree.
80
+ leaves, _ = optree.tree_flatten(
81
+ structure, none_is_leaf=True, namespace="keras"
82
+ )
83
+ return leaves
84
+
85
+
86
+ def flatten_with_path(structure):
87
+ paths, leaves, _ = optree.tree_flatten_with_path(
88
+ structure, none_is_leaf=True, namespace="keras"
89
+ )
90
+ return list(zip(paths, leaves))
91
+
92
+
93
+ def map_structure(func, *structures):
94
+ if not structures:
95
+ raise ValueError("Must provide at least one structure")
96
+
97
+ # Add check for same structures, otherwise optree just maps to shallowest.
98
+ def func_with_check(*args):
99
+ if not all(
100
+ optree.tree_is_leaf(s, none_is_leaf=True, namespace="keras")
101
+ for s in args
102
+ ):
103
+ raise ValueError("Structures don't have the same nested structure.")
104
+ return func(*args)
105
+
106
+ map_func = func_with_check if len(structures) > 1 else func
107
+
108
+ return optree.tree_map(
109
+ map_func, *structures, none_is_leaf=True, namespace="keras"
110
+ )
111
+
112
+
113
+ def map_structure_up_to(shallow_structure, func, *structures):
114
+ if not structures:
115
+ raise ValueError("Must provide at least one structure")
116
+
117
+ # Add check that `shallow_structure` really is the shallowest.
118
+ # Also only call `func` on `structures` and not `shallow_structure`.
119
+ def func_with_check_without_shallow_structure(shallow, *args):
120
+ if not optree.tree_is_leaf(shallow):
121
+ raise ValueError("Structures don't have the same nested structure.")
122
+ return func(*args)
123
+
124
+ return optree.tree_map(
125
+ func_with_check_without_shallow_structure,
126
+ shallow_structure,
127
+ *structures,
128
+ none_is_leaf=True,
129
+ namespace="keras",
130
+ )
131
+
132
+
133
+ def assert_same_structure(a, b):
134
+ def check(a_leaf, b_leaf):
135
+ if not optree.tree_is_leaf(
136
+ a_leaf, none_is_leaf=True, namespace="keras"
137
+ ) or not optree.tree_is_leaf(
138
+ b_leaf, none_is_leaf=True, namespace="keras"
139
+ ):
140
+ raise ValueError("Structures don't have the same nested structure.")
141
+ return None
142
+
143
+ optree.tree_map(check, a, b, none_is_leaf=True, namespace="keras")
144
+
145
+
146
+ def assert_same_paths(a, b):
147
+ a_paths = set(optree.tree_paths(a, none_is_leaf=True, namespace="keras"))
148
+ b_paths = set(optree.tree_paths(b, none_is_leaf=True, namespace="keras"))
149
+
150
+ if a_paths != b_paths:
151
+ msg = "`a` and `b` don't have the same paths."
152
+ a_diff = a_paths.difference(b_paths)
153
+ if a_diff:
154
+ msg += f"\nPaths in `a` missing in `b`:\n{a_diff}"
155
+ b_diff = b_paths.difference(a_paths)
156
+ if b_diff:
157
+ msg += f"\nPaths in `b` missing in `a`:\n{b_diff}"
158
+ raise ValueError(msg)
159
+
160
+
161
+ def pack_sequence_as(structure, flat_sequence):
162
+ _, treespec = optree.tree_flatten(
163
+ structure, none_is_leaf=True, namespace="keras"
164
+ )
165
+ return optree.tree_unflatten(treespec, flat_sequence)
166
+
167
+
168
+ def lists_to_tuples(structure):
169
+ def list_to_tuple(instance):
170
+ return tuple(instance) if isinstance(instance, list) else None
171
+
172
+ return traverse(list_to_tuple, structure, top_down=False)
173
+
174
+
175
+ def map_shape_structure(func, structure):
176
+ def is_shape_tuple(x):
177
+ return isinstance(x, (list, tuple)) and all(
178
+ isinstance(e, (int, type(None))) for e in x
179
+ )
180
+
181
+ return optree.tree_map(
182
+ func,
183
+ structure,
184
+ is_leaf=is_shape_tuple,
185
+ none_is_leaf=True,
186
+ namespace="keras",
187
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+
3
+ from keras.src.api_export import keras_export
4
+ from keras.src.utils.module_utils import dmtree
5
+ from keras.src.utils.module_utils import optree
6
+
7
+ if optree.available:
8
+ from keras.src.tree import optree_impl as tree_impl
9
+ elif dmtree.available:
10
+ from keras.src.tree import dmtree_impl as tree_impl
11
+ else:
12
+ raise ImportError(
13
+ "To use Keras, you need to have `optree` installed. "
14
+ "Install it via `pip install optree`"
15
+ )
16
+
17
+
18
+ def register_tree_node_class(cls):
19
+ return tree_impl.register_tree_node_class(cls)
20
+
21
+
22
+ @keras_export("keras.tree.MAP_TO_NONE")
23
+ class MAP_TO_NONE:
24
+ """Special value for use with `traverse()`."""
25
+
26
+ pass
27
+
28
+
29
+ @keras_export("keras.tree.is_nested")
30
+ def is_nested(structure):
31
+ """Checks if a given structure is nested.
32
+
33
+ Examples:
34
+
35
+ >>> keras.tree.is_nested(42)
36
+ False
37
+ >>> keras.tree.is_nested({"foo": 42})
38
+ True
39
+
40
+ Args:
41
+ structure: A structure to check.
42
+
43
+ Returns:
44
+ `True` if a given structure is nested, i.e. is a sequence, a mapping,
45
+ or a namedtuple, and `False` otherwise.
46
+ """
47
+ return tree_impl.is_nested(structure)
48
+
49
+
50
+ @keras_export("keras.tree.traverse")
51
+ def traverse(func, structure, top_down=True):
52
+ """Traverses the given nested structure, applying the given function.
53
+
54
+ The traversal is depth-first. If `top_down` is True (default), parents
55
+ are returned before their children (giving the option to avoid traversing
56
+ into a sub-tree).
57
+
58
+ Examples:
59
+
60
+ >>> v = []
61
+ >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True)
62
+ [(1, 2), [3], {'a': 4}]
63
+ >>> v
64
+ [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4]
65
+
66
+ >>> v = []
67
+ >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False)
68
+ [(1, 2), [3], {'a': 4}]
69
+ >>> v
70
+ [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]]
71
+
72
+ Args:
73
+ func: The function to be applied to each sub-nest of the structure.
74
+
75
+ When traversing top-down:
76
+ If `func(subtree) is None` the traversal continues into the
77
+ sub-tree.
78
+ If `func(subtree) is not None` the traversal does not continue
79
+ into the sub-tree. The sub-tree will be replaced by `func(subtree)`
80
+ in the returned structure (to replace the sub-tree with `None`, use
81
+ the special value `MAP_TO_NONE`).
82
+
83
+ When traversing bottom-up:
84
+ If `func(subtree) is None` the traversed sub-tree is returned
85
+ unaltered.
86
+ If `func(subtree) is not None` the sub-tree will be replaced by
87
+ `func(subtree)` in the returned structure (to replace the sub-tree
88
+ with None, use the special value `MAP_TO_NONE`).
89
+
90
+ structure: The structure to traverse.
91
+ top_down: If True, parent structures will be visited before their
92
+ children.
93
+
94
+ Returns:
95
+ The structured output from the traversal.
96
+
97
+ Raises:
98
+ TypeError: If `func` is not callable.
99
+ """
100
+ return tree_impl.traverse(func, structure, top_down=top_down)
101
+
102
+
103
+ @keras_export("keras.tree.flatten")
104
+ def flatten(structure):
105
+ """Flattens a possibly nested structure into a list.
106
+
107
+ In the case of dict instances, the sequence consists of the values,
108
+ sorted by key to ensure deterministic behavior. However, instances of
109
+ `collections.OrderedDict` are handled differently: their sequence order is
110
+ used instead of the sorted keys. The same convention is followed in
111
+ `pack_sequence_as`. This correctly unflattens dicts and `OrderedDict` after
112
+ they have been flattened, or vice-versa.
113
+
114
+ Dictionaries with non-sortable keys are not supported.
115
+
116
+ Examples:
117
+
118
+ >>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]])
119
+ [1, 2, 3, 4, 5, 6]
120
+ >>> keras.tree.flatten(None)
121
+ [None]
122
+ >>> keras.tree.flatten(1)
123
+ [1]
124
+ >>> keras.tree.flatten({100: 'world!', 6: 'Hello'})
125
+ ['Hello', 'world!']
126
+
127
+ Args:
128
+ structure: An arbitrarily nested structure.
129
+
130
+ Returns:
131
+ A list, the flattened version of the input `structure`.
132
+ """
133
+ return tree_impl.flatten(structure)
134
+
135
+
136
+ @keras_export("keras.tree.flatten_with_path")
137
+ def flatten_with_path(structure):
138
+ """Flattens a possibly nested structure into a list.
139
+
140
+ This is a variant of flattens() which produces a
141
+ list of pairs: `(path, item)`. A path is a tuple of indices and/or keys
142
+ which uniquely identifies the position of the corresponding item.
143
+
144
+ Dictionaries with non-sortable keys are not supported.
145
+
146
+ Examples:
147
+
148
+ >>> keras.flatten_with_path([{"foo": 42}])
149
+ [((0, 'foo'), 42)]
150
+
151
+
152
+ Args:
153
+ structure: An arbitrarily nested structure.
154
+
155
+ Returns:
156
+ A list of `(path, item)` pairs corresponding to the flattened
157
+ version of the input `structure`.
158
+ """
159
+ return tree_impl.flatten_with_path(structure)
160
+
161
+
162
+ @keras_export("keras.tree.map_structure")
163
+ def map_structure(func, *structures):
164
+ """Maps `func` through given structures.
165
+
166
+ Examples:
167
+
168
+ >>> structure = [[1], [2], [3]]
169
+ >>> keras.tree.map_structure(lambda v: v**2, structure)
170
+ [[1], [4], [9]]
171
+ >>> keras.tree.map_structure(lambda x, y: x * y, structure, structure)
172
+ [[1], [4], [9]]
173
+
174
+ >>> Foo = collections.namedtuple('Foo', ['a', 'b'])
175
+ >>> structure = Foo(a=1, b=2)
176
+ >>> keras.tree.map_structure(lambda v: v * 2, structure)
177
+ Foo(a=2, b=4)
178
+
179
+ Args:
180
+ func: A callable that accepts as many arguments as there are structures.
181
+ *structures: Arbitrarily nested structures of the same layout.
182
+
183
+ Returns:
184
+ A new structure with the same layout as the given ones.
185
+
186
+ Raises:
187
+ TypeError: If `structures` is empty or `func` is not callable.
188
+ ValueError: If there is more than one items in `structures` and some of
189
+ the nested structures don't match according to the rules of
190
+ `assert_same_structure`.
191
+ """
192
+ return tree_impl.map_structure(func, *structures)
193
+
194
+
195
+ @keras_export("keras.tree.map_structure_up_to")
196
+ def map_structure_up_to(shallow_structure, func, *structures):
197
+ """Maps `func` through given structures up to `shallow_structure`.
198
+
199
+ This is a variant of `map_structure` which only maps the given structures
200
+ up to `shallow_structure`. All further nested components are retained as-is.
201
+
202
+ Examples:
203
+
204
+ >>> shallow_structure = [None, None]
205
+ >>> structure = [[1, 1], [2, 2]]
206
+ >>> keras.tree.map_structure_up_to(shallow_structure, len, structure)
207
+ [2, 2]
208
+
209
+ >>> shallow_structure = [None, [None, None]]
210
+ >>> keras.tree.map_structure_up_to(shallow_structure, str, structure)
211
+ ['[1, 1]', ['2', '2']]
212
+
213
+ Args:
214
+ shallow_structure: A structure with layout common to all `structures`.
215
+ func: A callable that accepts as many arguments as there are structures.
216
+ *structures: Arbitrarily nested structures of the same layout.
217
+
218
+ Returns:
219
+ A new structure with the same layout as `shallow_structure`.
220
+
221
+ Raises:
222
+ TypeError: If `structures` is empty or `func` is not callable.
223
+ ValueError: If one of the items in `structures` doesn't match the
224
+ nested structure of `shallow_structure` according to the rules of
225
+ `assert_same_structure`. Items in `structures` are allowed to be
226
+ nested deeper than `shallow_structure`, but they cannot be
227
+ shallower.
228
+ """
229
+ return tree_impl.map_structure_up_to(shallow_structure, func, *structures)
230
+
231
+
232
+ @keras_export("keras.tree.assert_same_structure")
233
+ def assert_same_structure(a, b, check_types=None):
234
+ """Asserts that two structures are nested in the same way.
235
+
236
+ This function verifies that the nested structures match. The leafs can be of
237
+ any type. At each level, the structures must be of the same type and have
238
+ the same number of elements. Instances of `dict`, `OrderedDict` and
239
+ `defaultdict` are all considered the same as long as they have the same set
240
+ of keys. However, `list`, `tuple`, `namedtuple` and `deque` are not the same
241
+ structures. Two namedtuples with identical fields and even identical names
242
+ are not the same structures.
243
+
244
+ Examples:
245
+
246
+ >>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)])
247
+
248
+ >>> Foo = collections.namedtuple('Foo', ['a', 'b'])
249
+ >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b'])
250
+ >>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3))
251
+ >>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3))
252
+ Traceback (most recent call last):
253
+ ...
254
+ ValueError: The two structures don't have the same nested structure.
255
+ ...
256
+
257
+ Args:
258
+ a: an arbitrarily nested structure.
259
+ b: an arbitrarily nested structure.
260
+ check_types: Deprecated. The behavior of this flag was inconsistent, it
261
+ no longer has any effect. For a looser check, use
262
+ `assert_same_paths` instead, which considers `list`, `tuple`,
263
+ `namedtuple` and `deque` as matching structures.
264
+
265
+ Raises:
266
+ ValueError: If the two structures `a` and `b` don't match.
267
+ """
268
+ if check_types is not None:
269
+ if check_types:
270
+ warnings.warn(
271
+ "The `check_types` argument is deprecated and no longer has "
272
+ "any effect, please remove.",
273
+ DeprecationWarning,
274
+ stacklevel=2,
275
+ )
276
+ else:
277
+ warnings.warn(
278
+ "The `check_types` argument is deprecated and no longer has "
279
+ "any effect. For a looser check, use "
280
+ "`keras.tree.assert_same_paths()`, which considers `list`, "
281
+ "`tuple`, `namedtuple` and `deque` as matching",
282
+ DeprecationWarning,
283
+ stacklevel=2,
284
+ )
285
+ return tree_impl.assert_same_structure(a, b)
286
+
287
+
288
+ @keras_export("keras.tree.assert_same_paths")
289
+ def assert_same_paths(a, b):
290
+ """Asserts that two structures have identical paths in their tree structure.
291
+
292
+ This function verifies that two nested structures have the same paths.
293
+ Unlike `assert_same_structure`, this function only checks the paths
294
+ and ignores the collection types.
295
+ For Sequences, to path is the index: 0, 1, 2, etc. For Mappings, the path is
296
+ the key, for instance "a", "b", "c". Note that namedtuples also use indices
297
+ and not field names for the path.
298
+
299
+ Examples:
300
+ >>> keras.tree.assert_same_paths([0, 1], (2, 3))
301
+ >>> Point1 = collections.namedtuple('Point1', ['x', 'y'])
302
+ >>> Point2 = collections.namedtuple('Point2', ['x', 'y'])
303
+ >>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3))
304
+
305
+ Args:
306
+ a: an arbitrarily nested structure.
307
+ b: an arbitrarily nested structure.
308
+
309
+ Raises:
310
+ ValueError: If the paths in structure `a` don't match the paths in
311
+ structure `b`. The error message will include the specific paths
312
+ that differ.
313
+ """
314
+ return tree_impl.assert_same_paths(a, b)
315
+
316
+
317
+ @keras_export("keras.tree.pack_sequence_as")
318
+ def pack_sequence_as(structure, flat_sequence):
319
+ """Returns a given flattened sequence packed into a given structure.
320
+
321
+ If `structure` is an atom, `flat_sequence` must be a single-item list; in
322
+ this case the return value is `flat_sequence[0]`.
323
+
324
+ If `structure` is or contains a dict instance, the keys will be sorted to
325
+ pack the flat sequence in deterministic order. However, instances of
326
+ `collections.OrderedDict` are handled differently: their sequence order is
327
+ used instead of the sorted keys. The same convention is followed in
328
+ `flatten`. This correctly repacks dicts and `OrderedDicts` after they have
329
+ been flattened, or vice-versa.
330
+
331
+ Dictionaries with non-sortable keys are not supported.
332
+
333
+ Examples:
334
+
335
+ >>> structure = {"key3": "", "key1": "", "key2": ""}
336
+ >>> flat_sequence = ["value1", "value2", "value3"]
337
+ >>> keras.tree.pack_sequence_as(structure, flat_sequence)
338
+ {"key3": "value3", "key1": "value1", "key2": "value2"}
339
+
340
+ >>> structure = (("a", "b"), ("c", "d", "e"), "f")
341
+ >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
342
+ >>> keras.tree.pack_sequence_as(structure, flat_sequence)
343
+ ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0)
344
+
345
+ >>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")},
346
+ ... "key1": {"e": "val1", "d": "val2"}}
347
+ >>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0]
348
+ >>> keras.tree.pack_sequence_as(structure, flat_sequence)
349
+ {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}}
350
+
351
+ >>> structure = ["a"]
352
+ >>> flat_sequence = [np.array([[1, 2], [3, 4]])]
353
+ >>> keras.tree.pack_sequence_as(structure, flat_sequence)
354
+ [array([[1, 2],
355
+ [3, 4]])]
356
+
357
+ >>> structure = ["a"]
358
+ >>> flat_sequence = [keras.ops.ones([2, 2])]
359
+ >>> keras.tree.pack_sequence_as(structure, flat_sequence)
360
+ [array([[1., 1.],
361
+ [1., 1.]]]
362
+
363
+ Args:
364
+ structure: Arbitrarily nested structure.
365
+ flat_sequence: Flat sequence to pack.
366
+
367
+ Returns:
368
+ `flat_sequence` converted to have the same recursive structure as
369
+ `structure`.
370
+
371
+ Raises:
372
+ TypeError: If `flat_sequence` is not iterable.
373
+ ValueError: If `flat_sequence` cannot be repacked as `structure`; for
374
+ instance, if `flat_sequence` has too few or too many elements.
375
+ """
376
+ return tree_impl.pack_sequence_as(structure, flat_sequence)
377
+
378
+
379
+ @keras_export("keras.tree.lists_to_tuples")
380
+ def lists_to_tuples(structure):
381
+ """Returns the structure with list instances changed to tuples.
382
+
383
+ Args:
384
+ structure: Arbitrarily nested structure.
385
+
386
+ Returns:
387
+ The same structure but with tuples instead of lists.
388
+ """
389
+ return tree_impl.lists_to_tuples(structure)
390
+
391
+
392
+ @keras_export("keras.tree.map_shape_structure")
393
+ def map_shape_structure(func, structure):
394
+ """Variant of keras.tree.map_structure that operates on shape tuples.
395
+
396
+ Tuples containing ints and Nones are considered shapes and passed to `func`.
397
+
398
+ Args:
399
+ structure: Arbitrarily nested structure.
400
+
401
+ Returns:
402
+ The same structure with `func` applied.
403
+ """
404
+ return tree_impl.map_shape_structure(func, structure)
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory
2
+ from keras.src.utils.dataset_utils import split_dataset
3
+ from keras.src.utils.file_utils import get_file
4
+ from keras.src.utils.image_dataset_utils import image_dataset_from_directory
5
+ from keras.src.utils.image_utils import array_to_img
6
+ from keras.src.utils.image_utils import img_to_array
7
+ from keras.src.utils.image_utils import load_img
8
+ from keras.src.utils.image_utils import save_img
9
+ from keras.src.utils.io_utils import disable_interactive_logging
10
+ from keras.src.utils.io_utils import enable_interactive_logging
11
+ from keras.src.utils.io_utils import is_interactive_logging_enabled
12
+ from keras.src.utils.model_visualization import model_to_dot
13
+ from keras.src.utils.model_visualization import plot_model
14
+ from keras.src.utils.numerical_utils import normalize
15
+ from keras.src.utils.numerical_utils import to_categorical
16
+ from keras.src.utils.progbar import Progbar
17
+ from keras.src.utils.python_utils import default
18
+ from keras.src.utils.python_utils import is_default
19
+ from keras.src.utils.python_utils import removeprefix
20
+ from keras.src.utils.python_utils import removesuffix
21
+ from keras.src.utils.rng_utils import set_random_seed
22
+ from keras.src.utils.sequence_utils import pad_sequences
23
+ from keras.src.utils.text_dataset_utils import text_dataset_from_directory
24
+ from keras.src.utils.timeseries_dataset_utils import (
25
+ timeseries_dataset_from_array,
26
+ )
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.57 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc ADDED
Binary file (2.61 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc ADDED
Binary file (10.3 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc ADDED
Binary file (4.91 kB). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc ADDED
Binary file (979 Bytes). View file
 
SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (6 kB). View file