diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py new file mode 100644 index 0000000000000000000000000000000000000000..91ce5e3a156aef02d17cd486484da30007de1d7e --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_api.py @@ -0,0 +1,279 @@ +import os +import zipfile + +from absl import logging + +from keras.src.api_export import keras_export +from keras.src.legacy.saving import legacy_h5_format +from keras.src.saving import saving_lib +from keras.src.utils import file_utils +from keras.src.utils import io_utils + +try: + import h5py +except ImportError: + h5py = None + + +@keras_export(["keras.saving.save_model", "keras.models.save_model"]) +def save_model(model, filepath, overwrite=True, zipped=None, **kwargs): + """Saves a model as a `.keras` file. + + Args: + model: Keras model instance to be saved. + filepath: `str` or `pathlib.Path` object. Path where to save the model. + overwrite: Whether we should overwrite any existing model at the target + location, or instead ask the user via an interactive prompt. + zipped: Whether to save the model as a zipped `.keras` + archive (default when saving locally), or as an unzipped directory + (default when saving on the Hugging Face Hub). + + Example: + + ```python + model = keras.Sequential( + [ + keras.layers.Dense(5, input_shape=(3,)), + keras.layers.Softmax(), + ], + ) + model.save("model.keras") + loaded_model = keras.saving.load_model("model.keras") + x = keras.random.uniform((10, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + ``` + + Note that `model.save()` is an alias for `keras.saving.save_model()`. 
+ + The saved `.keras` file is a `zip` archive that contains: + + - The model's configuration (architecture) + - The model's weights + - The model's optimizer's state (if any) + + Thus models can be reinstantiated in the exact same state. + """ + include_optimizer = kwargs.pop("include_optimizer", True) + save_format = kwargs.pop("save_format", False) + if save_format: + if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith( + ".keras" + ): + logging.warning( + "The `save_format` argument is deprecated in Keras 3. " + "We recommend removing this argument as it can be inferred " + "from the file path. " + f"Received: save_format={save_format}" + ) + else: + raise ValueError( + "The `save_format` argument is deprecated in Keras 3. " + "Please remove this argument and pass a file path with " + "either `.keras` or `.h5` extension." + f"Received: save_format={save_format}" + ) + if kwargs: + raise ValueError( + "The following argument(s) are not supported: " + f"{list(kwargs.keys())}" + ) + + # Deprecation warnings + if str(filepath).endswith((".h5", ".hdf5")): + logging.warning( + "You are saving your model as an HDF5 file via " + "`model.save()` or `keras.saving.save_model(model)`. " + "This file format is considered legacy. " + "We recommend using instead the native Keras format, " + "e.g. `model.save('my_model.keras')` or " + "`keras.saving.save_model(model, 'my_model.keras')`. " + ) + + is_hf = str(filepath).startswith("hf://") + if zipped is None: + zipped = not is_hf # default behavior depends on destination + + # If file exists and should not be overwritten. 
+ try: + exists = (not is_hf) and os.path.exists(filepath) + except TypeError: + exists = False + if exists and not overwrite: + proceed = io_utils.ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + + if zipped and str(filepath).endswith(".keras"): + return saving_lib.save_model(model, filepath) + if not zipped: + return saving_lib.save_model(model, filepath, zipped=False) + if str(filepath).endswith((".h5", ".hdf5")): + return legacy_h5_format.save_model_to_hdf5( + model, filepath, overwrite, include_optimizer + ) + raise ValueError( + "Invalid filepath extension for saving. " + "Please add either a `.keras` extension for the native Keras " + f"format (recommended) or a `.h5` extension. " + "Use `model.export(filepath)` if you want to export a SavedModel " + "for use with TFLite/TFServing/etc. " + f"Received: filepath={filepath}." + ) + + +@keras_export(["keras.saving.load_model", "keras.models.load_model"]) +def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): + """Loads a model saved via `model.save()`. + + Args: + filepath: `str` or `pathlib.Path` object, path to the saved model file. + custom_objects: Optional dictionary mapping names + (strings) to custom classes or functions to be + considered during deserialization. + compile: Boolean, whether to compile the model after loading. + safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization. + When `safe_mode=False`, loading an object has the potential to + trigger arbitrary code execution. This argument is only + applicable to the Keras v3 model format. Defaults to `True`. + + Returns: + A Keras model instance. If the original model was compiled, + and the argument `compile=True` is set, then the returned model + will be compiled. Otherwise, the model will be left uncompiled. 
+ + Example: + + ```python + model = keras.Sequential([ + keras.layers.Dense(5, input_shape=(3,)), + keras.layers.Softmax()]) + model.save("model.keras") + loaded_model = keras.saving.load_model("model.keras") + x = np.random.random((10, 3)) + assert np.allclose(model.predict(x), loaded_model.predict(x)) + ``` + + Note that the model variables may have different name values + (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded. + It is recommended that you use layer attributes to + access specific variables, e.g. `model.get_layer("dense_1").kernel`. + """ + is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile( + filepath + ) + is_keras_dir = file_utils.isdir(filepath) and file_utils.exists( + file_utils.join(filepath, "config.json") + ) + is_hf = str(filepath).startswith("hf://") + + # Support for remote zip files + if ( + file_utils.is_remote_path(filepath) + and not file_utils.isdir(filepath) + and not is_keras_zip + and not is_hf + ): + local_path = file_utils.join( + saving_lib.get_temp_dir(), os.path.basename(filepath) + ) + + # Copy from remote to temporary local directory + file_utils.copy(filepath, local_path) + + # Switch filepath to local zipfile for loading model + if zipfile.is_zipfile(local_path): + filepath = local_path + is_keras_zip = True + + if is_keras_zip or is_keras_dir or is_hf: + return saving_lib.load_model( + filepath, + custom_objects=custom_objects, + compile=compile, + safe_mode=safe_mode, + ) + if str(filepath).endswith((".h5", ".hdf5")): + return legacy_h5_format.load_model_from_hdf5( + filepath, custom_objects=custom_objects, compile=compile + ) + elif str(filepath).endswith(".keras"): + raise ValueError( + f"File not found: filepath={filepath}. " + "Please ensure the file is an accessible `.keras` " + "zip file." + ) + else: + raise ValueError( + f"File format not supported: filepath={filepath}. " + "Keras 3 only supports V3 `.keras` files and " + "legacy H5 format files (`.h5` extension). 
" + "Note that the legacy SavedModel format is not " + "supported by `load_model()` in Keras 3. In " + "order to reload a TensorFlow SavedModel as an " + "inference-only layer in Keras 3, use " + "`keras.layers.TFSMLayer(" + f"{filepath}, call_endpoint='serving_default')` " + "(note that your `call_endpoint` " + "might have a different name)." + ) + + +@keras_export("keras.saving.save_weights") +def save_weights(model, filepath, overwrite=True, **kwargs): + if not str(filepath).endswith(".weights.h5"): + raise ValueError( + "The filename must end in `.weights.h5`. " + f"Received: filepath={filepath}" + ) + try: + exists = os.path.exists(filepath) + except TypeError: + exists = False + if exists and not overwrite: + proceed = io_utils.ask_to_proceed_with_overwrite(filepath) + if not proceed: + return + saving_lib.save_weights_only(model, filepath, **kwargs) + + +@keras_export("keras.saving.load_weights") +def load_weights(model, filepath, skip_mismatch=False, **kwargs): + if str(filepath).endswith(".keras"): + if kwargs: + raise ValueError(f"Invalid keyword arguments: {kwargs}") + saving_lib.load_weights_only( + model, filepath, skip_mismatch=skip_mismatch + ) + elif str(filepath).endswith(".weights.h5"): + objects_to_skip = kwargs.pop("objects_to_skip", None) + if kwargs: + raise ValueError(f"Invalid keyword arguments: {kwargs}") + saving_lib.load_weights_only( + model, + filepath, + skip_mismatch=skip_mismatch, + objects_to_skip=objects_to_skip, + ) + elif str(filepath).endswith(".h5") or str(filepath).endswith(".hdf5"): + by_name = kwargs.pop("by_name", False) + if kwargs: + raise ValueError(f"Invalid keyword arguments: {kwargs}") + if not h5py: + raise ImportError( + "Loading a H5 file requires `h5py` to be installed." 
+ ) + with h5py.File(filepath, "r") as f: + if "layer_names" not in f.attrs and "model_weights" in f: + f = f["model_weights"] + if by_name: + legacy_h5_format.load_weights_from_hdf5_group_by_name( + f, model, skip_mismatch + ) + else: + legacy_h5_format.load_weights_from_hdf5_group(f, model) + else: + raise ValueError( + f"File format not supported: filepath={filepath}. " + "Keras 3 only supports V3 `.keras` and `.weights.h5` " + "files, or legacy V1/V2 `.h5` files." + ) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..13d408f9538f8b3ecc75b59073d4c8598f462aaf --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/saving_lib.py @@ -0,0 +1,1173 @@ +"""Python-based idempotent model-saving functionality.""" + +import datetime +import io +import json +import os +import pathlib +import shutil +import tempfile +import warnings +import zipfile + +import ml_dtypes +import numpy as np + +from keras.src import backend +from keras.src.backend.common import global_state +from keras.src.layers.layer import Layer +from keras.src.losses.loss import Loss +from keras.src.metrics.metric import Metric +from keras.src.optimizers.optimizer import Optimizer +from keras.src.saving.serialization_lib import ObjectSharingScope +from keras.src.saving.serialization_lib import deserialize_keras_object +from keras.src.saving.serialization_lib import serialize_keras_object +from keras.src.trainers.compile_utils import CompileMetrics +from keras.src.utils import file_utils +from keras.src.utils import io_utils +from keras.src.utils import naming +from keras.src.utils import plot_model +from keras.src.utils.model_visualization import check_pydot +from keras.src.utils.summary_utils import weight_memory_size +from keras.src.version import 
__version__ as keras_version + +try: + import h5py +except ImportError: + h5py = None +try: + import psutil +except ImportError: + psutil = None +try: + import huggingface_hub +except ImportError: + huggingface_hub = None + + +_CONFIG_FILENAME = "config.json" +_METADATA_FILENAME = "metadata.json" +_VARS_FNAME = "model.weights" # Will become e.g. "model.weights.h5" +_VARS_FNAME_H5 = _VARS_FNAME + ".h5" +_VARS_FNAME_NPZ = _VARS_FNAME + ".npz" +_ASSETS_DIRNAME = "assets" +_MEMORY_UPPER_BOUND = 0.5 # 50% + + +_MODEL_CARD_TEMPLATE = """ +--- +library_name: keras +--- + +This model has been uploaded using the Keras library and can be used with JAX, +TensorFlow, and PyTorch backends. + +This model card has been generated automatically and should be completed by the +model author. +See [Model Cards documentation](https://huggingface.co/docs/hub/model-cards) for +more information. + +For more details about the model architecture, check out +[config.json](./config.json).""" + + +def save_model(model, filepath, weights_format="h5", zipped=True): + """Save a zip-archive representing a Keras model to the given file or path. + + The zip-based archive contains the following structure: + + - JSON-based configuration file (config.json): Records of model, layer, and + other saveables' configuration. + - H5-based saveable state files, found in respective directories, such as + model/states.npz, model/dense_layer/states.npz, etc. + - Metadata file. + + The states of Keras saveables (layers, optimizers, loss, and metrics) are + automatically saved as long as they can be discovered through the attributes + returned by `dir(Model)`. Typically, the state includes the variables + associated with the saveable, but some specially purposed layers may + contain more such as the vocabularies stored in the hashmaps. The saveables + define how their states are saved by exposing `save_state()` and + `load_state()` APIs. 
+ + For the case of layer states, the variables will be visited as long as + they are either 1) referenced via layer attributes, or 2) referenced via a + container (list, tuple, or dict), and the container is referenced via a + layer attribute. + """ + if weights_format == "h5" and h5py is None: + raise ImportError("h5py must be installed in order to save a model.") + + if not model.built: + warnings.warn( + "You are saving a model that has not yet been built. " + "It might not contain any weights yet. " + "Consider building the model first by calling it " + "on some data.", + stacklevel=2, + ) + + if isinstance(filepath, io.IOBase): + _save_model_to_fileobj(model, filepath, weights_format) + return + + filepath = str(filepath) + is_hf = filepath.startswith("hf://") + if zipped and not filepath.endswith(".keras"): + raise ValueError( + "Invalid `filepath` argument: expected a `.keras` extension. " + f"Received: filepath={filepath}" + ) + if not zipped and filepath.endswith(".keras"): + raise ValueError( + "When using `zipped=False`, the `filepath` argument should not " + f"end in `.keras`. Received: filepath={filepath}" + ) + if zipped and is_hf: + raise ValueError( + "When saving to the Hugging Face Hub, you should not save the " + f"model as zipped. Received: filepath={filepath}, zipped={zipped}" + ) + if is_hf: + _upload_model_to_hf(model, filepath, weights_format) + elif not zipped: + _save_model_to_dir(model, filepath, weights_format) + else: + if file_utils.is_remote_path(filepath): + # Remote path. 
Zip to local memory byte io and copy to remote + zip_filepath = io.BytesIO() + _save_model_to_fileobj(model, zip_filepath, weights_format) + with file_utils.File(filepath, "wb") as f: + f.write(zip_filepath.getvalue()) + else: + with open(filepath, "wb") as f: + _save_model_to_fileobj(model, f, weights_format) + + +def _serialize_model_as_json(model): + with ObjectSharingScope(): + serialized_model_dict = serialize_keras_object(model) + config_json = json.dumps(serialized_model_dict) + metadata_json = json.dumps( + { + "keras_version": keras_version, + "date_saved": datetime.datetime.now().strftime("%Y-%m-%d@%H:%M:%S"), + } + ) + return config_json, metadata_json + + +def _save_model_to_dir(model, dirpath, weights_format): + if not file_utils.exists(dirpath): + file_utils.makedirs(dirpath) + config_json, metadata_json = _serialize_model_as_json(model) + with open(file_utils.join(dirpath, _METADATA_FILENAME), "w") as f: + f.write(metadata_json) + with open(file_utils.join(dirpath, _CONFIG_FILENAME), "w") as f: + f.write(config_json) + weights_filepath = file_utils.join(dirpath, _VARS_FNAME_H5) + assert_dirpath = file_utils.join(dirpath, _ASSETS_DIRNAME) + try: + if weights_format == "h5": + weights_store = H5IOStore(weights_filepath, mode="w") + elif weights_format == "npz": + weights_store = NpzIOStore(weights_filepath, mode="w") + else: + raise ValueError( + "Unknown `weights_format` argument. " + "Expected 'h5' or 'npz'. 
" + f"Received: weights_format={weights_format}" + ) + asset_store = DiskIOStore(assert_dirpath, mode="w") + _save_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + ) + finally: + weights_store.close() + asset_store.close() + + +def _save_model_to_fileobj(model, fileobj, weights_format): + config_json, metadata_json = _serialize_model_as_json(model) + + with zipfile.ZipFile(fileobj, "w") as zf: + with zf.open(_METADATA_FILENAME, "w") as f: + f.write(metadata_json.encode()) + with zf.open(_CONFIG_FILENAME, "w") as f: + f.write(config_json.encode()) + + weights_file_path = None + weights_store = None + asset_store = None + write_zf = False + try: + if weights_format == "h5": + try: + if is_memory_sufficient(model): + # Load the model weights into memory before writing + # .keras if the system memory is sufficient. + weights_store = H5IOStore( + _VARS_FNAME_H5, archive=zf, mode="w" + ) + else: + # Try opening the .h5 file, then writing it to `zf` at + # the end of the function call. This is more memory + # efficient than writing the weights into memory first. + working_dir = pathlib.Path(fileobj.name).parent + weights_file_path = tempfile.NamedTemporaryFile( + dir=working_dir + ) + weights_store = H5IOStore( + weights_file_path.name, mode="w" + ) + write_zf = True + except: + # If we can't use the local disk for any reason, write the + # weights into memory first, which consumes more memory. + weights_store = H5IOStore( + _VARS_FNAME_H5, archive=zf, mode="w" + ) + elif weights_format == "npz": + weights_store = NpzIOStore( + _VARS_FNAME_NPZ, archive=zf, mode="w" + ) + else: + raise ValueError( + "Unknown `weights_format` argument. " + "Expected 'h5' or 'npz'. 
" + f"Received: weights_format={weights_format}" + ) + + asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="w") + + _save_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + ) + except: + # Skip the final `zf.write` if any exception is raised + write_zf = False + if weights_store: + weights_store.archive = None + raise + finally: + if weights_store: + weights_store.close() + if asset_store: + asset_store.close() + if write_zf and weights_file_path: + zf.write(weights_file_path.name, _VARS_FNAME_H5) + if weights_file_path: + weights_file_path.close() + + +def _upload_model_to_hf(model, hf_path, weights_format): + if huggingface_hub is None: + raise ImportError( + "To save models to the Hugging Face Hub, " + "you must install the `huggingface_hub` package." + ) + + original_hf_path = hf_path + if hf_path.startswith("hf://"): + hf_path = hf_path[5:] + if hf_path.count("/") > 1: + raise ValueError( + "Invalid `hf_path` argument: expected `namespace/model_name`" + f" format. Received: hf_path={original_hf_path}" + ) + + api = huggingface_hub.HfApi( + library_name="keras", library_version=keras_version + ) + repo_url = api.create_repo(hf_path, exist_ok=True) + repo_id = repo_url.repo_id + + with tempfile.TemporaryDirectory() as tmp_dir: + _save_model_to_dir(model, tmp_dir, weights_format) + + model_card = _MODEL_CARD_TEMPLATE + + if check_pydot(): + plot_path = file_utils.join(tmp_dir, "assets", "summary_plot.png") + plot_model( + model, + to_file=plot_path, + show_layer_names=True, + show_shapes=True, + show_dtype=True, + ) + if len(model.layers) <= 10: + model_card += "\n\n![](./assets/summary_plot.png)" + else: + model_card += ( + "A plot of the model can be found " + "[here](./assets/summary_plot.png)." 
+ ) + + with open(file_utils.join(tmp_dir, "README.md"), "w") as f: + f.write(model_card) + + api.upload_folder( + repo_id=repo_id, + folder_path=tmp_dir, + commit_message="Save model using Keras.", + ) + io_utils.print_msg( + f"Model saved to the Hugging Face Hub: {repo_url}\n" + "To load back the model, use " + f"`keras.saving.load_model('hf://{repo_id}')`" + ) + + +def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): + """Load a zip archive representing a Keras model.""" + if isinstance(filepath, io.IOBase): + return _load_model_from_fileobj( + filepath, custom_objects, compile, safe_mode + ) + elif str(filepath).startswith("hf://"): + if huggingface_hub is None: + raise ImportError( + "To load models from the Hugging Face Hub, " + "you must install the `huggingface_hub` package." + ) + + repo_id = filepath[5:] + folder_path = huggingface_hub.snapshot_download( + repo_id=repo_id, + library_name="keras", + library_version=keras_version, + ) + return _load_model_from_dir( + folder_path, custom_objects, compile, safe_mode + ) + else: + filepath = str(filepath) + if not filepath.endswith(".keras"): + is_keras_dir = file_utils.isdir(filepath) and file_utils.exists( + file_utils.join(filepath, "config.json") + ) + if is_keras_dir: + return _load_model_from_dir( + filepath, custom_objects, compile, safe_mode + ) + raise ValueError( + "Invalid filename: expected a `.keras` extension. 
" + f"Received: filepath={filepath}" + ) + with open(filepath, "rb") as f: + return _load_model_from_fileobj( + f, custom_objects, compile, safe_mode + ) + + +def _load_model_from_dir(dirpath, custom_objects, compile, safe_mode): + if not file_utils.exists(dirpath): + raise ValueError(f"Directory doesn't exist: {dirpath}") + if not file_utils.isdir(dirpath): + raise ValueError(f"Path isn't a directory: {dirpath}") + + with open(file_utils.join(dirpath, _CONFIG_FILENAME), "r") as f: + config_json = f.read() + model = _model_from_config(config_json, custom_objects, compile, safe_mode) + + all_filenames = file_utils.listdir(dirpath) + try: + if _VARS_FNAME_H5 in all_filenames: + weights_file_path = file_utils.join(dirpath, _VARS_FNAME_H5) + weights_store = H5IOStore(weights_file_path, mode="r") + elif _VARS_FNAME_NPZ in all_filenames: + weights_file_path = file_utils.join(dirpath, _VARS_FNAME_NPZ) + weights_store = NpzIOStore(weights_file_path, mode="r") + else: + raise ValueError( + f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file." + ) + if len(all_filenames) > 3: + asset_store = DiskIOStore( + file_utils.join(dirpath, _ASSETS_DIRNAME), mode="r" + ) + + else: + asset_store = None + + failed_saveables = set() + error_msgs = {} + _load_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + + finally: + weights_store.close() + if asset_store: + asset_store.close() + + if failed_saveables: + _raise_loading_failure(error_msgs) + return model + + +def _model_from_config(config_json, custom_objects, compile, safe_mode): + # Note: we should NOT use a custom JSON decoder. Anything that + # needs custom decoding must be handled in deserialize_keras_object. 
+ config_dict = json.loads(config_json) + if not compile: + # Disable compilation + config_dict["compile_config"] = None + # Construct the model from the configuration file in the archive. + with ObjectSharingScope(): + model = deserialize_keras_object( + config_dict, custom_objects, safe_mode=safe_mode + ) + return model + + +def _load_model_from_fileobj(fileobj, custom_objects, compile, safe_mode): + with zipfile.ZipFile(fileobj, "r") as zf: + with zf.open(_CONFIG_FILENAME, "r") as f: + config_json = f.read() + + model = _model_from_config( + config_json, custom_objects, compile, safe_mode + ) + + all_filenames = zf.namelist() + extract_dir = None + weights_store = None + asset_store = None + try: + if _VARS_FNAME_H5 in all_filenames: + try: + if is_memory_sufficient(model): + # Load the entire file into memory if the system memory + # is sufficient. + io_file = io.BytesIO( + zf.open(_VARS_FNAME_H5, "r").read() + ) + weights_store = H5IOStore(io_file, mode="r") + else: + # Try extracting the model.weights.h5 file, and then + # loading it using using h5py. This is significantly + # faster than reading from the zip archive on the fly. + extract_dir = tempfile.TemporaryDirectory( + dir=pathlib.Path(fileobj.name).parent + ) + zf.extract(_VARS_FNAME_H5, extract_dir.name) + weights_store = H5IOStore( + pathlib.Path(extract_dir.name, _VARS_FNAME_H5), + mode="r", + ) + except: + # If we can't use the local disk for any reason, read the + # weights from the zip archive on the fly, which is less + # efficient. + weights_store = H5IOStore(_VARS_FNAME_H5, zf, mode="r") + elif _VARS_FNAME_NPZ in all_filenames: + weights_store = NpzIOStore(_VARS_FNAME_NPZ, zf, mode="r") + else: + raise ValueError( + f"Expected a {_VARS_FNAME_H5} or {_VARS_FNAME_NPZ} file." 
+ ) + + if len(all_filenames) > 3: + asset_store = DiskIOStore(_ASSETS_DIRNAME, archive=zf, mode="r") + + failed_saveables = set() + error_msgs = {} + _load_state( + model, + weights_store=weights_store, + assets_store=asset_store, + inner_path="", + visited_saveables=set(), + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + finally: + if weights_store: + weights_store.close() + if asset_store: + asset_store.close() + if extract_dir: + extract_dir.cleanup() + + if failed_saveables: + _raise_loading_failure(error_msgs) + return model + + +def save_weights_only(model, filepath, objects_to_skip=None): + """Save only the weights of a model to a target filepath. + + Supports both `.weights.h5` and `.keras`. + """ + if not model.built: + raise ValueError( + "You are saving a model that has not yet been built. " + "Try building the model first by calling it on some data or " + "by using `build()`." + ) + + filepath = str(filepath) + tmp_dir = None + remote_filepath = None + if not filepath.endswith(".weights.h5"): + raise ValueError( + "Invalid `filepath` argument: expected a `.weights.h5` extension. " + f"Received: filepath={filepath}" + ) + try: + if file_utils.is_remote_path(filepath): + tmp_dir = get_temp_dir() + local_filepath = os.path.join(tmp_dir, os.path.basename(filepath)) + remote_filepath = filepath + filepath = local_filepath + + weights_store = H5IOStore(filepath, mode="w") + if objects_to_skip is not None: + visited_saveables = set(id(o) for o in objects_to_skip) + else: + visited_saveables = set() + _save_state( + model, + weights_store=weights_store, + assets_store=None, + inner_path="", + visited_saveables=visited_saveables, + ) + weights_store.close() + finally: + if tmp_dir is not None: + file_utils.copy(filepath, remote_filepath) + shutil.rmtree(tmp_dir) + + +def load_weights_only( + model, filepath, skip_mismatch=False, objects_to_skip=None +): + """Load the weights of a model from a filepath (.keras or .weights.h5). 
+ + Note: only supports h5 for now. + """ + if not model.built: + raise ValueError( + "You are loading weights into a model that has not yet been built. " + "Try building the model first by calling it on some data or " + "by using `build()`." + ) + + archive = None + tmp_dir = None + filepath = str(filepath) + + try: + if file_utils.is_remote_path(filepath): + tmp_dir = get_temp_dir() + local_filepath = os.path.join(tmp_dir, os.path.basename(filepath)) + file_utils.copy(filepath, local_filepath) + filepath = local_filepath + + if filepath.endswith(".weights.h5"): + weights_store = H5IOStore(filepath, mode="r") + elif filepath.endswith(".keras"): + archive = zipfile.ZipFile(filepath, "r") + weights_store = H5IOStore(_VARS_FNAME_H5, archive=archive, mode="r") + + failed_saveables = set() + if objects_to_skip is not None: + visited_saveables = set(id(o) for o in objects_to_skip) + else: + visited_saveables = set() + error_msgs = {} + _load_state( + model, + weights_store=weights_store, + assets_store=None, + inner_path="", + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + weights_store.close() + if archive: + archive.close() + + if failed_saveables: + _raise_loading_failure(error_msgs, warn_only=skip_mismatch) + finally: + if tmp_dir is not None: + shutil.rmtree(tmp_dir) + + +def _raise_loading_failure(error_msgs, warn_only=False): + first_key = list(error_msgs.keys())[0] + ex_saveable, ex_error = error_msgs[first_key] + msg = ( + f"A total of {len(error_msgs)} objects could not " + "be loaded. 
Example error message for " + f"object {ex_saveable}:\n\n" + f"{ex_error}\n\n" + "List of objects that could not be loaded:\n" + f"{[x[0] for x in error_msgs.values()]}" + ) + if warn_only: + warnings.warn(msg) + else: + raise ValueError(msg) + + +def _write_to_zip_recursively(zipfile_to_save, system_path, zip_path): + if not file_utils.isdir(system_path): + zipfile_to_save.write(system_path, zip_path) + else: + for file_name in file_utils.listdir(system_path): + system_file_path = file_utils.join(system_path, file_name).replace( + "\\", "/" + ) + zip_file_path = file_utils.join(zip_path, file_name).replace( + "\\", "/" + ) + _write_to_zip_recursively( + zipfile_to_save, system_file_path, zip_file_path + ) + + +def _name_key(name): + """Make sure that private attributes are visited last.""" + if name.startswith("_"): + return "~" + name + return name + + +def _walk_saveable(saveable): + from keras.src.saving.keras_saveable import KerasSaveable + + if not isinstance(saveable, KerasSaveable): + raise ValueError( + "Expected object to be an " + "instance of `KerasSaveable`, but " + f"got {saveable} of type {type(saveable)}" + ) + + obj_type = saveable._obj_type() + attr_skipset = get_attr_skipset(obj_type) + + # Save all layers directly tracked by Sequential and Functional first. + # This helps avoid ordering concerns for subclassed Sequential or Functional + # models with extra attributes--the internal Keras state take precedence. + if obj_type in ("Sequential", "Functional"): + yield "layers", saveable.layers + + for child_attr in sorted(dir(saveable), key=lambda x: _name_key(x)): + if child_attr.startswith("__") or child_attr in attr_skipset: + continue + try: + child_obj = getattr(saveable, child_attr) + except Exception: + # Avoid raising the exception when visiting the attributes. 
+ continue + yield child_attr, child_obj + + +def _save_state( + saveable, + weights_store, + assets_store, + inner_path, + visited_saveables, +): + from keras.src.saving.keras_saveable import KerasSaveable + + # If the saveable has already been saved, skip it. + if id(saveable) in visited_saveables: + return + + if hasattr(saveable, "save_own_variables") and weights_store: + if hasattr(saveable, "name") and isinstance(saveable.name, str): + metadata = {"name": saveable.name} + else: + metadata = None + saveable.save_own_variables( + weights_store.make(inner_path, metadata=metadata) + ) + if hasattr(saveable, "save_assets") and assets_store: + saveable.save_assets(assets_store.make(inner_path)) + + visited_saveables.add(id(saveable)) + + # Recursively save state of children saveables (layers, optimizers, etc.) + for child_attr, child_obj in _walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + _save_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + visited_saveables=visited_saveables, + ) + elif isinstance(child_obj, (list, dict, tuple, set)): + _save_container_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + visited_saveables=visited_saveables, + ) + + +def _load_state( + saveable, + weights_store, + assets_store, + inner_path, + skip_mismatch=False, + visited_saveables=None, + failed_saveables=None, + error_msgs=None, +): + from keras.src.saving.keras_saveable import KerasSaveable + + if visited_saveables and id(saveable) in visited_saveables: + return + + failure = False + + if hasattr(saveable, "load_own_variables") and weights_store: + if skip_mismatch or failed_saveables is not None: + try: + saveable.load_own_variables(weights_store.get(inner_path)) + except Exception as e: + failed_saveables.add(id(saveable)) + error_msgs[id(saveable)] = saveable, e + failure = True + else: 
+ saveable.load_own_variables(weights_store.get(inner_path)) + + if hasattr(saveable, "load_assets") and assets_store: + if skip_mismatch or failed_saveables is not None: + try: + saveable.load_assets(assets_store.get(inner_path)) + except Exception as e: + failed_saveables.add(id(saveable)) + error_msgs[id(saveable)] = saveable, e + failure = True + else: + saveable.load_assets(assets_store.get(inner_path)) + + if failed_saveables is not None: + currently_failed = len(failed_saveables) + else: + currently_failed = 0 + + # Recursively load states for Keras saveables such as layers/optimizers. + for child_attr, child_obj in _walk_saveable(saveable): + if isinstance(child_obj, KerasSaveable): + _load_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + elif isinstance(child_obj, (list, dict, tuple, set)): + _load_container_state( + child_obj, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, child_attr).replace( + "\\", "/" + ), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + + if failed_saveables is not None: + newly_failed = len(failed_saveables) - currently_failed + else: + newly_failed = 0 + + if not failure: + if visited_saveables is not None and newly_failed <= 0: + visited_saveables.add(id(saveable)) + if id(saveable) in failed_saveables: + failed_saveables.remove(id(saveable)) + error_msgs.pop(id(saveable)) + + +def _save_container_state( + container, weights_store, assets_store, inner_path, visited_saveables +): + from keras.src.saving.keras_saveable import KerasSaveable + + used_names = {} + if isinstance(container, dict): + container = list(container.values()) + + for saveable in container: + if isinstance(saveable, KerasSaveable): 
+ # Do NOT address the saveable via `saveable.name`, since + # names are usually autogenerated and thus not reproducible + # (i.e. they may vary across two instances of the same model). + name = naming.to_snake_case(saveable.__class__.__name__) + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + _save_state( + saveable, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + visited_saveables=visited_saveables, + ) + + +def _load_container_state( + container, + weights_store, + assets_store, + inner_path, + skip_mismatch, + visited_saveables, + failed_saveables, + error_msgs, +): + from keras.src.saving.keras_saveable import KerasSaveable + + used_names = {} + if isinstance(container, dict): + container = list(container.values()) + + for saveable in container: + if isinstance(saveable, KerasSaveable): + name = naming.to_snake_case(saveable.__class__.__name__) + if name in used_names: + used_names[name] += 1 + name = f"{name}_{used_names[name]}" + else: + used_names[name] = 0 + _load_state( + saveable, + weights_store, + assets_store, + inner_path=file_utils.join(inner_path, name).replace("\\", "/"), + skip_mismatch=skip_mismatch, + visited_saveables=visited_saveables, + failed_saveables=failed_saveables, + error_msgs=error_msgs, + ) + + +class DiskIOStore: + """Asset store backed by disk storage. + + If `archive` is specified, then `root_path` refers to the filename + inside the archive. + + If `archive` is not specified, then `root_path` refers to the full path of + the target directory. 
+ """ + + def __init__(self, root_path, archive=None, mode=None): + self.mode = mode + self.root_path = root_path + self.archive = archive + self.tmp_dir = None + if self.archive: + self.tmp_dir = get_temp_dir() + if self.mode == "r": + self.archive.extractall(path=self.tmp_dir) + self.working_dir = file_utils.join( + self.tmp_dir, self.root_path + ).replace("\\", "/") + if self.mode == "w": + file_utils.makedirs(self.working_dir) + else: + if mode == "r": + self.working_dir = root_path + else: + self.tmp_dir = get_temp_dir() + self.working_dir = file_utils.join( + self.tmp_dir, self.root_path + ).replace("\\", "/") + file_utils.makedirs(self.working_dir) + + def make(self, path): + if not path: + return self.working_dir + path = file_utils.join(self.working_dir, path).replace("\\", "/") + if not file_utils.exists(path): + file_utils.makedirs(path) + return path + + def get(self, path): + if not path: + return self.working_dir + path = file_utils.join(self.working_dir, path).replace("\\", "/") + if file_utils.exists(path): + return path + return None + + def close(self): + if self.mode == "w" and self.archive: + _write_to_zip_recursively( + self.archive, self.working_dir, self.root_path + ) + if self.tmp_dir and file_utils.exists(self.tmp_dir): + file_utils.rmtree(self.tmp_dir) + + +class H5IOStore: + def __init__(self, root_path, archive=None, mode="r"): + """Numerical variable store backed by HDF5. + + If `archive` is specified, then `root_path` refers to the filename + inside the archive. + + If `archive` is not specified, then `root_path` refers to the path of + the h5 file on disk. 
+ """ + self.root_path = root_path + self.mode = mode + self.archive = archive + self.io_file = None + + if self.archive: + if self.mode == "w": + self.io_file = io.BytesIO() + else: + self.io_file = self.archive.open(self.root_path, "r") + self.h5_file = h5py.File(self.io_file, mode=self.mode) + else: + self.h5_file = h5py.File(root_path, mode=self.mode) + + def make(self, path, metadata=None): + return H5Entry(self.h5_file, path, mode="w", metadata=metadata) + + def get(self, path): + return H5Entry(self.h5_file, path, mode="r") + + def close(self): + self.h5_file.close() + if self.mode == "w" and self.archive: + self.archive.writestr(self.root_path, self.io_file.getvalue()) + if self.io_file: + self.io_file.close() + + +class H5Entry: + """Leaf entry in a H5IOStore.""" + + def __init__(self, h5_file, path, mode, metadata=None): + self.h5_file = h5_file + self.path = path + self.mode = mode + self.metadata = metadata + + if mode == "w": + if not path: + self.group = self.h5_file.create_group("vars") + else: + self.group = self.h5_file.create_group(self.path).create_group( + "vars" + ) + if self.metadata: + for k, v in self.metadata.items(): + self.group.attrs[k] = v + else: + found = False + if not path: + if "vars" in self.h5_file: + self.group = self.h5_file["vars"] + found = True + elif path in self.h5_file and "vars" in self.h5_file[path]: + self.group = self.h5_file[path]["vars"] + found = True + else: + # No hit. 
+ # Fix for 2.13 compatibility + if "_layer_checkpoint_dependencies" in self.h5_file: + path = path.replace( + "layers", "_layer_checkpoint_dependencies" + ) + self.path = path + if path in self.h5_file and "vars" in self.h5_file[path]: + self.group = self.h5_file[path]["vars"] + found = True + if not found: + self.group = {} + + def __len__(self): + return self.group.__len__() + + def keys(self): + return self.group.keys() + + def items(self): + return self.group.items() + + def values(self): + return self.group.values() + + def __setitem__(self, key, value): + if self.mode != "w": + raise ValueError("Setting a value is only allowed in write mode.") + value = backend.convert_to_numpy(value) + if backend.standardize_dtype(value.dtype) == "bfloat16": + ds = self.group.create_dataset(key, data=value) + ds.attrs["dtype"] = "bfloat16" + else: + self.group[key] = value + + def __getitem__(self, name): + value = self.group[name] + if "dtype" in value.attrs and value.attrs["dtype"] == "bfloat16": + value = np.array(value, dtype=ml_dtypes.bfloat16) + return value + + +class NpzIOStore: + def __init__(self, root_path, archive=None, mode="r"): + """Numerical variable store backed by NumPy.savez/load. + + If `archive` is specified, then `root_path` refers to the filename + inside the archive. + + If `archive` is not specified, then `root_path` refers to the path of + the npz file on disk. 
+ """ + self.root_path = root_path + self.mode = mode + self.archive = archive + if mode == "w": + self.contents = {} + else: + if self.archive: + self.f = archive.open(root_path, mode="r") + else: + self.f = open(root_path, mode="rb") + self.contents = np.load(self.f, allow_pickle=True) + + def make(self, path, metadata=None): + if not path: + self.contents["__root__"] = {} + return self.contents["__root__"] + self.contents[path] = {} + return self.contents[path] + + def get(self, path): + if not path: + if "__root__" in self.contents: + return dict(self.contents["__root__"]) + return {} + if path in self.contents: + return self.contents[path].tolist() + return {} + + def close(self): + if self.mode == "w": + if self.archive: + self.f = self.archive.open( + self.root_path, mode="w", force_zip64=True + ) + else: + self.f = open(self.root_path, mode="wb") + np.savez(self.f, **self.contents) + self.f.close() + + +def get_temp_dir(): + temp_dir = tempfile.mkdtemp() + testfile = tempfile.TemporaryFile(dir=temp_dir) + testfile.close() + return temp_dir + + +def get_attr_skipset(obj_type): + skipset = global_state.get_global_attribute( + f"saving_attr_skiplist_{obj_type}", None + ) + if skipset is not None: + return skipset + + skipset = set( + [ + "_self_unconditional_dependency_names", + ] + ) + if obj_type == "Layer": + ref_obj = Layer() + skipset.update(dir(ref_obj)) + elif obj_type == "Functional": + ref_obj = Layer() + skipset.update(dir(ref_obj) + ["operations", "_operations"]) + elif obj_type == "Sequential": + ref_obj = Layer() + skipset.update(dir(ref_obj) + ["_functional"]) + elif obj_type == "Metric": + ref_obj_a = Metric() + ref_obj_b = CompileMetrics([], []) + skipset.update(dir(ref_obj_a) + dir(ref_obj_b)) + elif obj_type == "Optimizer": + ref_obj = Optimizer(1.0) + skipset.update(dir(ref_obj)) + skipset.remove("variables") + elif obj_type == "Loss": + ref_obj = Loss() + skipset.update(dir(ref_obj)) + else: + raise ValueError( + f"get_attr_skipset got 
invalid {obj_type=}. " + "Accepted values for `obj_type` are " + "['Layer', 'Functional', 'Sequential', 'Metric', " + "'Optimizer', 'Loss']" + ) + + global_state.set_global_attribute( + f"saving_attr_skipset_{obj_type}", skipset + ) + return skipset + + +def is_memory_sufficient(model): + """Check if there is sufficient memory to load the model into memory. + + If psutil is installed, we can use it to determine whether the memory is + sufficient. Otherwise, we use a predefined value of 1 GB for available + memory. + """ + if psutil is None: + available_memory = 1024 * 1024 * 1024 # 1 GB in bytes + else: + available_memory = psutil.virtual_memory().available # In bytes + return ( + weight_memory_size(model.variables) + < available_memory * _MEMORY_UPPER_BOUND + ) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py new file mode 100644 index 0000000000000000000000000000000000000000..cf8eb327fb402639e1bc1270d49e4832ccff0312 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/saving/serialization_lib.py @@ -0,0 +1,808 @@ +"""Object config serialization and deserialization logic.""" + +import importlib +import inspect +import types +import warnings + +import numpy as np + +from keras.src import api_export +from keras.src import backend +from keras.src.api_export import keras_export +from keras.src.backend.common import global_state +from keras.src.saving import object_registration +from keras.src.utils import python_utils +from keras.src.utils.module_utils import tensorflow as tf + +PLAIN_TYPES = (str, int, float, bool) + +# List of Keras modules with built-in string representations for Keras defaults +BUILTIN_MODULES = ( + "activations", + "constraints", + "initializers", + "losses", + "metrics", + "optimizers", + "regularizers", +) + + +class SerializableDict: + def 
__init__(self, **config): + self.config = config + + def serialize(self): + return serialize_keras_object(self.config) + + +class SafeModeScope: + """Scope to propagate safe mode flag to nested deserialization calls.""" + + def __init__(self, safe_mode=True): + self.safe_mode = safe_mode + + def __enter__(self): + self.original_value = in_safe_mode() + global_state.set_global_attribute("safe_mode_saving", self.safe_mode) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute( + "safe_mode_saving", self.original_value + ) + + +@keras_export("keras.config.enable_unsafe_deserialization") +def enable_unsafe_deserialization(): + """Disables safe mode globally, allowing deserialization of lambdas.""" + global_state.set_global_attribute("safe_mode_saving", False) + + +def in_safe_mode(): + return global_state.get_global_attribute("safe_mode_saving") + + +class ObjectSharingScope: + """Scope to enable detection and reuse of previously seen objects.""" + + def __enter__(self): + global_state.set_global_attribute("shared_objects/id_to_obj_map", {}) + global_state.set_global_attribute("shared_objects/id_to_config_map", {}) + + def __exit__(self, *args, **kwargs): + global_state.set_global_attribute("shared_objects/id_to_obj_map", None) + global_state.set_global_attribute( + "shared_objects/id_to_config_map", None + ) + + +def get_shared_object(obj_id): + """Retrieve an object previously seen during deserialization.""" + id_to_obj_map = global_state.get_global_attribute( + "shared_objects/id_to_obj_map" + ) + if id_to_obj_map is not None: + return id_to_obj_map.get(obj_id, None) + + +def record_object_after_serialization(obj, config): + """Call after serializing an object, to keep track of its config.""" + if config["module"] == "__main__": + config["module"] = None # Ensures module is None when no module found + id_to_config_map = global_state.get_global_attribute( + "shared_objects/id_to_config_map" + ) + if id_to_config_map is None: + return # Not in a 
sharing scope + obj_id = int(id(obj)) + if obj_id not in id_to_config_map: + id_to_config_map[obj_id] = config + else: + config["shared_object_id"] = obj_id + prev_config = id_to_config_map[obj_id] + prev_config["shared_object_id"] = obj_id + + +def record_object_after_deserialization(obj, obj_id): + """Call after deserializing an object, to keep track of it in the future.""" + id_to_obj_map = global_state.get_global_attribute( + "shared_objects/id_to_obj_map" + ) + if id_to_obj_map is None: + return # Not in a sharing scope + id_to_obj_map[obj_id] = obj + + +@keras_export( + [ + "keras.saving.serialize_keras_object", + "keras.utils.serialize_keras_object", + ] +) +def serialize_keras_object(obj): + """Retrieve the config dict by serializing the Keras object. + + `serialize_keras_object()` serializes a Keras object to a python dictionary + that represents the object, and is a reciprocal function of + `deserialize_keras_object()`. See `deserialize_keras_object()` for more + information about the config format. + + Args: + obj: the Keras object to serialize. + + Returns: + A python dict that represents the object. The python dict can be + deserialized via `deserialize_keras_object()`. + """ + if obj is None: + return obj + + if isinstance(obj, PLAIN_TYPES): + return obj + + if isinstance(obj, (list, tuple)): + config_arr = [serialize_keras_object(x) for x in obj] + return tuple(config_arr) if isinstance(obj, tuple) else config_arr + if isinstance(obj, dict): + return serialize_dict(obj) + + # Special cases: + if isinstance(obj, bytes): + return { + "class_name": "__bytes__", + "config": {"value": obj.decode("utf-8")}, + } + if isinstance(obj, slice): + return { + "class_name": "__slice__", + "config": { + "start": serialize_keras_object(obj.start), + "stop": serialize_keras_object(obj.stop), + "step": serialize_keras_object(obj.step), + }, + } + # Ellipsis is an instance, and ellipsis class is not in global scope. 
+ # checking equality also fails elsewhere in the library, so we have + # to dynamically get the type. + if isinstance(obj, type(Ellipsis)): + return {"class_name": "__ellipsis__", "config": {}} + if isinstance(obj, backend.KerasTensor): + history = getattr(obj, "_keras_history", None) + if history: + history = list(history) + history[0] = history[0].name + return { + "class_name": "__keras_tensor__", + "config": { + "shape": obj.shape, + "dtype": obj.dtype, + "keras_history": history, + }, + } + if tf.available and isinstance(obj, tf.TensorShape): + return obj.as_list() if obj._dims is not None else None + if backend.is_tensor(obj): + return { + "class_name": "__tensor__", + "config": { + "value": backend.convert_to_numpy(obj).tolist(), + "dtype": backend.standardize_dtype(obj.dtype), + }, + } + if type(obj).__module__ == np.__name__: + if isinstance(obj, np.ndarray) and obj.ndim > 0: + return { + "class_name": "__numpy__", + "config": { + "value": obj.tolist(), + "dtype": backend.standardize_dtype(obj.dtype), + }, + } + else: + # Treat numpy floats / etc as plain types. + return obj.item() + if tf.available and isinstance(obj, tf.DType): + return obj.name + if isinstance(obj, types.FunctionType) and obj.__name__ == "": + warnings.warn( + "The object being serialized includes a `lambda`. This is unsafe. " + "In order to reload the object, you will have to pass " + "`safe_mode=False` to the loading function. " + "Please avoid using `lambda` in the " + "future, and use named Python functions instead. 
" + f"This is the `lambda` being serialized: {inspect.getsource(obj)}", + stacklevel=2, + ) + return { + "class_name": "__lambda__", + "config": { + "value": python_utils.func_dump(obj), + }, + } + if tf.available and isinstance(obj, tf.TypeSpec): + ts_config = obj._serialize() + # TensorShape and tf.DType conversion + ts_config = list( + map( + lambda x: ( + x.as_list() + if isinstance(x, tf.TensorShape) + else (x.name if isinstance(x, tf.DType) else x) + ), + ts_config, + ) + ) + return { + "class_name": "__typespec__", + "spec_name": obj.__class__.__name__, + "module": obj.__class__.__module__, + "config": ts_config, + "registered_name": None, + } + + inner_config = _get_class_or_fn_config(obj) + config_with_public_class = serialize_with_public_class( + obj.__class__, inner_config + ) + + if config_with_public_class is not None: + get_build_and_compile_config(obj, config_with_public_class) + record_object_after_serialization(obj, config_with_public_class) + return config_with_public_class + + # Any custom object or otherwise non-exported object + if isinstance(obj, types.FunctionType): + module = obj.__module__ + else: + module = obj.__class__.__module__ + class_name = obj.__class__.__name__ + + if module == "builtins": + registered_name = None + else: + if isinstance(obj, types.FunctionType): + registered_name = object_registration.get_registered_name(obj) + else: + registered_name = object_registration.get_registered_name( + obj.__class__ + ) + + config = { + "module": module, + "class_name": class_name, + "config": inner_config, + "registered_name": registered_name, + } + get_build_and_compile_config(obj, config) + record_object_after_serialization(obj, config) + return config + + +def get_build_and_compile_config(obj, config): + if hasattr(obj, "get_build_config"): + build_config = obj.get_build_config() + if build_config is not None: + config["build_config"] = serialize_dict(build_config) + if hasattr(obj, "get_compile_config"): + compile_config = 
obj.get_compile_config() + if compile_config is not None: + config["compile_config"] = serialize_dict(compile_config) + return + + +def serialize_with_public_class(cls, inner_config=None): + """Serializes classes from public Keras API or object registration. + + Called to check and retrieve the config of any class that has a public + Keras API or has been registered as serializable via + `keras.saving.register_keras_serializable()`. + """ + # This gets the `keras.*` exported name, such as + # "keras.optimizers.Adam". + keras_api_name = api_export.get_name_from_symbol(cls) + + # Case of custom or unknown class object + if keras_api_name is None: + registered_name = object_registration.get_registered_name(cls) + if registered_name is None: + return None + + # Return custom object config with corresponding registration name + return { + "module": cls.__module__, + "class_name": cls.__name__, + "config": inner_config, + "registered_name": registered_name, + } + + # Split the canonical Keras API name into a Keras module and class name. + parts = keras_api_name.split(".") + return { + "module": ".".join(parts[:-1]), + "class_name": parts[-1], + "config": inner_config, + "registered_name": None, + } + + +def serialize_with_public_fn(fn, config, fn_module_name=None): + """Serializes functions from public Keras API or object registration. + + Called to check and retrieve the config of any function that has a public + Keras API or has been registered as serializable via + `keras.saving.register_keras_serializable()`. If function's module name + is already known, returns corresponding config. 
+ """ + if fn_module_name: + return { + "module": fn_module_name, + "class_name": "function", + "config": config, + "registered_name": config, + } + keras_api_name = api_export.get_name_from_symbol(fn) + if keras_api_name: + parts = keras_api_name.split(".") + return { + "module": ".".join(parts[:-1]), + "class_name": "function", + "config": config, + "registered_name": config, + } + else: + registered_name = object_registration.get_registered_name(fn) + if not registered_name and not fn.__module__ == "builtins": + return None + return { + "module": fn.__module__, + "class_name": "function", + "config": config, + "registered_name": registered_name, + } + + +def _get_class_or_fn_config(obj): + """Return the object's config depending on its type.""" + # Functions / lambdas: + if isinstance(obj, types.FunctionType): + return object_registration.get_registered_name(obj) + # All classes: + if hasattr(obj, "get_config"): + config = obj.get_config() + if not isinstance(config, dict): + raise TypeError( + f"The `get_config()` method of {obj} should return " + f"a dict. It returned: {config}" + ) + return serialize_dict(config) + elif hasattr(obj, "__name__"): + return object_registration.get_registered_name(obj) + else: + raise TypeError( + f"Cannot serialize object {obj} of type {type(obj)}. " + "To be serializable, " + "a class must implement the `get_config()` method." + ) + + +def serialize_dict(obj): + return {key: serialize_keras_object(value) for key, value in obj.items()} + + +@keras_export( + [ + "keras.saving.deserialize_keras_object", + "keras.utils.deserialize_keras_object", + ] +) +def deserialize_keras_object( + config, custom_objects=None, safe_mode=True, **kwargs +): + """Retrieve the object by deserializing the config dict. + + The config dict is a Python dictionary that consists of a set of key-value + pairs, and represents a Keras object, such as an `Optimizer`, `Layer`, + `Metrics`, etc. 
The saving and loading library uses the following keys to + record information of a Keras object: + + - `class_name`: String. This is the name of the class, + as exactly defined in the source + code, such as "LossesContainer". + - `config`: Dict. Library-defined or user-defined key-value pairs that store + the configuration of the object, as obtained by `object.get_config()`. + - `module`: String. The path of the python module. Built-in Keras classes + expect to have prefix `keras`. + - `registered_name`: String. The key the class is registered under via + `keras.saving.register_keras_serializable(package, name)` API. The + key has the format of '{package}>{name}', where `package` and `name` are + the arguments passed to `register_keras_serializable()`. If `name` is not + provided, it uses the class name. If `registered_name` successfully + resolves to a class (that was registered), the `class_name` and `config` + values in the dict will not be used. `registered_name` is only used for + non-built-in classes. + + For example, the following dictionary represents the built-in Adam optimizer + with the relevant config: + + ```python + dict_structure = { + "class_name": "Adam", + "config": { + "amsgrad": false, + "beta_1": 0.8999999761581421, + "beta_2": 0.9990000128746033, + "decay": 0.0, + "epsilon": 1e-07, + "learning_rate": 0.0010000000474974513, + "name": "Adam" + }, + "module": "keras.optimizers", + "registered_name": None + } + # Returns an `Adam` instance identical to the original one. + deserialize_keras_object(dict_structure) + ``` + + If the class does not have an exported Keras namespace, the library tracks + it by its `module` and `class_name`. For example: + + ```python + dict_structure = { + "class_name": "MetricsList", + "config": { + ... + }, + "module": "keras.trainers.compile_utils", + "registered_name": "MetricsList" + } + + # Returns a `MetricsList` instance identical to the original one. 
+ deserialize_keras_object(dict_structure) + ``` + + And the following dictionary represents a user-customized `MeanSquaredError` + loss: + + ```python + @keras.saving.register_keras_serializable(package='my_package') + class ModifiedMeanSquaredError(keras.losses.MeanSquaredError): + ... + + dict_structure = { + "class_name": "ModifiedMeanSquaredError", + "config": { + "fn": "mean_squared_error", + "name": "mean_squared_error", + "reduction": "auto" + }, + "registered_name": "my_package>ModifiedMeanSquaredError" + } + # Returns the `ModifiedMeanSquaredError` object + deserialize_keras_object(dict_structure) + ``` + + Args: + config: Python dict describing the object. + custom_objects: Python dict containing a mapping between custom + object names the corresponding classes or functions. + safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization. + When `safe_mode=False`, loading an object has the potential to + trigger arbitrary code execution. This argument is only + applicable to the Keras v3 model format. Defaults to `True`. + + Returns: + The object described by the `config` dictionary. + """ + safe_scope_arg = in_safe_mode() # Enforces SafeModeScope + safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode + + module_objects = kwargs.pop("module_objects", None) + custom_objects = custom_objects or {} + tlco = global_state.get_global_attribute("custom_objects_scope_dict", {}) + gco = object_registration.GLOBAL_CUSTOM_OBJECTS + custom_objects = {**custom_objects, **tlco, **gco} + + if config is None: + return None + + if ( + isinstance(config, str) + and custom_objects + and custom_objects.get(config) is not None + ): + # This is to deserialize plain functions which are serialized as + # string names by legacy saving formats. 
+ return custom_objects[config] + + if isinstance(config, (list, tuple)): + return [ + deserialize_keras_object( + x, custom_objects=custom_objects, safe_mode=safe_mode + ) + for x in config + ] + + if module_objects is not None: + inner_config, fn_module_name, has_custom_object = None, None, False + + if isinstance(config, dict): + if "config" in config: + inner_config = config["config"] + if "class_name" not in config: + raise ValueError( + f"Unknown `config` as a `dict`, config={config}" + ) + + # Check case where config is function or class and in custom objects + if custom_objects and ( + config["class_name"] in custom_objects + or config.get("registered_name") in custom_objects + or ( + isinstance(inner_config, str) + and inner_config in custom_objects + ) + ): + has_custom_object = True + + # Case where config is function but not in custom objects + elif config["class_name"] == "function": + fn_module_name = config["module"] + if fn_module_name == "builtins": + config = config["config"] + else: + config = config["registered_name"] + + # Case where config is class but not in custom objects + else: + if config.get("module", "_") is None: + raise TypeError( + "Cannot deserialize object of type " + f"`{config['class_name']}`. If " + f"`{config['class_name']}` is a custom class, please " + "register it using the " + "`@keras.saving.register_keras_serializable()` " + "decorator." 
+ ) + config = config["class_name"] + + if not has_custom_object: + # Return if not found in either module objects or custom objects + if config not in module_objects: + # Object has already been deserialized + return config + if isinstance(module_objects[config], types.FunctionType): + return deserialize_keras_object( + serialize_with_public_fn( + module_objects[config], config, fn_module_name + ), + custom_objects=custom_objects, + ) + return deserialize_keras_object( + serialize_with_public_class( + module_objects[config], inner_config=inner_config + ), + custom_objects=custom_objects, + ) + + if isinstance(config, PLAIN_TYPES): + return config + if not isinstance(config, dict): + raise TypeError(f"Could not parse config: {config}") + + if "class_name" not in config or "config" not in config: + return { + key: deserialize_keras_object( + value, custom_objects=custom_objects, safe_mode=safe_mode + ) + for key, value in config.items() + } + + class_name = config["class_name"] + inner_config = config["config"] or {} + custom_objects = custom_objects or {} + + # Special cases: + if class_name == "__keras_tensor__": + obj = backend.KerasTensor( + inner_config["shape"], dtype=inner_config["dtype"] + ) + obj._pre_serialization_keras_history = inner_config["keras_history"] + return obj + + if class_name == "__tensor__": + return backend.convert_to_tensor( + inner_config["value"], dtype=inner_config["dtype"] + ) + if class_name == "__numpy__": + return np.array(inner_config["value"], dtype=inner_config["dtype"]) + if config["class_name"] == "__bytes__": + return inner_config["value"].encode("utf-8") + if config["class_name"] == "__ellipsis__": + return Ellipsis + if config["class_name"] == "__slice__": + return slice( + deserialize_keras_object( + inner_config["start"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + deserialize_keras_object( + inner_config["stop"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + deserialize_keras_object( + 
inner_config["step"], + custom_objects=custom_objects, + safe_mode=safe_mode, + ), + ) + if config["class_name"] == "__lambda__": + if safe_mode: + raise ValueError( + "Requested the deserialization of a `lambda` object. " + "This carries a potential risk of arbitrary code execution " + "and thus it is disallowed by default. If you trust the " + "source of the saved model, you can pass `safe_mode=False` to " + "the loading function in order to allow `lambda` loading, " + "or call `keras.config.enable_unsafe_deserialization()`." + ) + return python_utils.func_load(inner_config["value"]) + if tf is not None and config["class_name"] == "__typespec__": + obj = _retrieve_class_or_fn( + config["spec_name"], + config["registered_name"], + config["module"], + obj_type="class", + full_config=config, + custom_objects=custom_objects, + ) + # Conversion to TensorShape and DType + inner_config = map( + lambda x: ( + tf.TensorShape(x) + if isinstance(x, list) + else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x) + ), + inner_config, + ) + return obj._deserialize(tuple(inner_config)) + + # Below: classes and functions. + module = config.get("module", None) + registered_name = config.get("registered_name", class_name) + + if class_name == "function": + fn_name = inner_config + return _retrieve_class_or_fn( + fn_name, + registered_name, + module, + obj_type="function", + full_config=config, + custom_objects=custom_objects, + ) + + # Below, handling of all classes. + # First, is it a shared object? 
def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    """Resolve the class or function object named in a serialized config.

    Lookup precedence (first hit wins):
      1. Objects registered via `register_keras_serializable()` or supplied
         in `custom_objects` (functions are keyed by `name`, classes by
         `registered_name`).
      2. Exported Keras API symbols, when `module` is under the `keras`
         namespace.
      3. Keras built-in modules, for functions whose `module` is
         `"builtins"` (configs of built-ins carry only a short name such
         as 'acc' or 'tanh').
      4. A direct `importlib` import of `module`, looked up by `name`
         (falling back to `registered_name`).

    Args:
        name: Object name recorded in the config.
        registered_name: Registration name, if the object was registered.
        module: Dotted module path recorded at serialization time.
        obj_type: `"class"` or `"function"`; selects the registration key.
        full_config: Full config dict; used only to enrich error messages.
        custom_objects: Optional dict of user-provided custom objects.

    Returns:
        The resolved class or function object.

    Raises:
        TypeError: If `module` cannot be imported, or the object cannot be
            located by any of the strategies above.
    """
    # If there is a custom object registered via
    # `register_keras_serializable()`, that takes precedence.
    if obj_type == "function":
        custom_obj = object_registration.get_registered_object(
            name, custom_objects=custom_objects
        )
    else:
        custom_obj = object_registration.get_registered_object(
            registered_name, custom_objects=custom_objects
        )
    if custom_obj is not None:
        return custom_obj

    if module:
        # If it's a Keras built-in object,
        # we cannot always use direct import, because the exported
        # module name might not match the package structure
        # (e.g. experimental symbols).
        if module == "keras" or module.startswith("keras."):
            api_name = module + "." + name

            obj = api_export.get_symbol_from_name(api_name)
            if obj is not None:
                return obj

        # Configs of Keras built-in functions do not contain identifying
        # information other than their name (e.g. 'acc' or 'tanh'). This
        # special case searches the Keras modules that contain built-ins to
        # retrieve the corresponding function from the identifying string.
        if obj_type == "function" and module == "builtins":
            for mod in BUILTIN_MODULES:
                obj = api_export.get_symbol_from_name(
                    "keras." + mod + "." + name
                )
                if obj is not None:
                    return obj

        # Otherwise, attempt to retrieve the class object given the `module`
        # and `class_name`. Import the module, find the class.
        try:
            mod = importlib.import_module(module)
        except ModuleNotFoundError:
            raise TypeError(
                f"Could not deserialize {obj_type} '{name}' because "
                f"its parent module {module} cannot be imported. "
                f"Full object config: {full_config}"
            )
        obj = vars(mod).get(name, None)

        # Special case for keras.metrics.metrics
        if obj is None and registered_name is not None:
            obj = vars(mod).get(registered_name, None)

        if obj is not None:
            return obj

    raise TypeError(
        f"Could not locate {obj_type} '{name}'. "
        "Make sure custom classes are decorated with "
        "`@keras.saving.register_keras_serializable()`. "
        f"Full object config: {full_config}"
    )
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..790b2a7c703006602e9a404786b6559131ddd548 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/__pycache__/test_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py new file mode 100644 index 0000000000000000000000000000000000000000..a168f734cd87780eb60ba1cc0abe23252e00ba09 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/testing/test_case.py @@ -0,0 +1,796 @@ +import json +import shutil +import tempfile +import unittest +from pathlib import Path + +import numpy as np +from absl.testing import parameterized + +from keras.src import backend +from keras.src import distribution +from keras.src import ops +from keras.src import tree +from keras.src import utils +from keras.src.backend.common import is_float_dtype +from keras.src.backend.common import standardize_dtype +from keras.src.backend.common.global_state import clear_session +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.models import Model +from keras.src.utils import traceback_utils + + +class TestCase(parameterized.TestCase, unittest.TestCase): + maxDiff = None + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def setUp(self): + # clear global state so that test cases are independent + # required for the jit enabled torch tests since dynamo has + # a global cache for guards, compiled fn, etc + clear_session(free_memory=False) + if traceback_utils.is_traceback_filtering_enabled(): + traceback_utils.disable_traceback_filtering() + + def get_temp_dir(self): + temp_dir = tempfile.mkdtemp() + 
self.addCleanup(lambda: shutil.rmtree(temp_dir)) + return temp_dir + + def assertAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): + if not isinstance(x1, np.ndarray): + x1 = backend.convert_to_numpy(x1) + if not isinstance(x2, np.ndarray): + x2 = backend.convert_to_numpy(x2) + np.testing.assert_allclose(x1, x2, atol=atol, rtol=rtol, err_msg=msg) + + def assertNotAllClose(self, x1, x2, atol=1e-6, rtol=1e-6, msg=None): + try: + self.assertAllClose(x1, x2, atol=atol, rtol=rtol, msg=msg) + except AssertionError: + return + msg = msg or "" + raise AssertionError( + f"The two values are close at all elements. \n" + f"{msg}.\n" + f"Values: {x1}" + ) + + def assertAlmostEqual(self, x1, x2, decimal=3, msg=None): + msg = msg or "" + if not isinstance(x1, np.ndarray): + x1 = backend.convert_to_numpy(x1) + if not isinstance(x2, np.ndarray): + x2 = backend.convert_to_numpy(x2) + np.testing.assert_almost_equal(x1, x2, decimal=decimal, err_msg=msg) + + def assertAllEqual(self, x1, x2, msg=None): + self.assertEqual(len(x1), len(x2), msg=msg) + for e1, e2 in zip(x1, x2): + if isinstance(e1, (list, tuple)) or isinstance(e2, (list, tuple)): + self.assertAllEqual(e1, e2, msg=msg) + else: + e1 = backend.convert_to_numpy(e1) + e2 = backend.convert_to_numpy(e2) + self.assertEqual(e1, e2, msg=msg) + + def assertLen(self, iterable, expected_len, msg=None): + self.assertEqual(len(iterable), expected_len, msg=msg) + + def assertSparse(self, x, sparse=True): + if isinstance(x, KerasTensor): + self.assertEqual(x.sparse, sparse) + elif backend.backend() == "tensorflow": + import tensorflow as tf + + if sparse: + self.assertIsInstance(x, tf.SparseTensor) + else: + self.assertNotIsInstance(x, tf.SparseTensor) + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + if sparse: + self.assertIsInstance(x, jax_sparse.JAXSparse) + else: + self.assertNotIsInstance(x, jax_sparse.JAXSparse) + else: + self.assertFalse( + sparse, + f"Backend {backend.backend()} does not 
support sparse tensors", + ) + + def assertDType(self, x, dtype, msg=None): + if hasattr(x, "dtype"): + x_dtype = backend.standardize_dtype(x.dtype) + else: + # If x is a python number + x_dtype = backend.standardize_dtype(type(x)) + standardized_dtype = backend.standardize_dtype(dtype) + default_msg = ( + "The dtype of x does not match the expected one. " + f"Received: x.dtype={x_dtype} and dtype={dtype}" + ) + msg = msg or default_msg + self.assertEqual(x_dtype, standardized_dtype, msg=msg) + + def assertFileExists(self, path): + if not Path(path).is_file(): + raise AssertionError(f"File {path} does not exist") + + def run_class_serialization_test(self, instance, custom_objects=None): + from keras.src.saving import custom_object_scope + from keras.src.saving import deserialize_keras_object + from keras.src.saving import serialize_keras_object + + # get_config roundtrip + cls = instance.__class__ + config = instance.get_config() + config_json = to_json_with_tuples(config) + ref_dir = dir(instance)[:] + with custom_object_scope(custom_objects): + revived_instance = cls.from_config(config) + revived_config = revived_instance.get_config() + revived_config_json = to_json_with_tuples(revived_config) + self.assertEqual(config_json, revived_config_json) + self.assertEqual(set(ref_dir), set(dir(revived_instance))) + + # serialization roundtrip + serialized = serialize_keras_object(instance) + serialized_json = to_json_with_tuples(serialized) + with custom_object_scope(custom_objects): + revived_instance = deserialize_keras_object( + from_json_with_tuples(serialized_json) + ) + revived_config = revived_instance.get_config() + revived_config_json = to_json_with_tuples(revived_config) + self.assertEqual(config_json, revived_config_json) + new_dir = dir(revived_instance)[:] + for lst in [ref_dir, new_dir]: + if "__annotations__" in lst: + lst.remove("__annotations__") + self.assertEqual(set(ref_dir), set(new_dir)) + return revived_instance + + def run_layer_test( + self, + 
layer_cls, + init_kwargs, + input_shape=None, + input_dtype=None, + input_sparse=False, + input_data=None, + call_kwargs=None, + expected_output_shape=None, + expected_output_dtype=None, + expected_output_sparse=False, + expected_output=None, + expected_num_trainable_weights=None, + expected_num_non_trainable_weights=None, + expected_num_non_trainable_variables=None, + expected_num_seed_generators=None, + expected_num_losses=None, + supports_masking=None, + expected_mask_shape=None, + custom_objects=None, + run_training_check=True, + run_mixed_precision_check=True, + assert_built_after_instantiation=False, + ): + """Run basic checks on a layer. + + Args: + layer_cls: The class of the layer to test. + init_kwargs: Dict of arguments to be used to + instantiate the layer. + input_shape: Shape tuple (or list/dict of shape tuples) + to call the layer on. + input_dtype: Corresponding input dtype. + input_sparse: Whether the input is a sparse tensor (this requires + the backend to support sparse tensors). + input_data: Tensor (or list/dict of tensors) + to call the layer on. + call_kwargs: Dict of arguments to use when calling the + layer (does not include the first input tensor argument) + expected_output_shape: Shape tuple + (or list/dict of shape tuples) + expected as output. + expected_output_dtype: dtype expected as output. + expected_output_sparse: Whether the output is expected to be sparse + (this requires the backend to support sparse tensors). + expected_output: Expected output tensor -- only + to be specified if input_data is provided. + expected_num_trainable_weights: Expected number + of trainable weights of the layer once built. + expected_num_non_trainable_weights: Expected number + of non-trainable weights of the layer once built. + expected_num_seed_generators: Expected number of + SeedGenerators objects of the layer once built. + expected_num_losses: Expected number of loss tensors + produced when calling the layer. 
+ supports_masking: If True, will check that the layer + supports masking. + expected_mask_shape: Expected mask shape tuple + returned by compute_mask() (only supports 1 shape). + custom_objects: Dict of any custom objects to be + considered during deserialization. + run_training_check: Whether to attempt to train the layer + (if an input shape or input data was provided). + run_mixed_precision_check: Whether to test the layer with a mixed + precision dtype policy. + assert_built_after_instantiation: Whether to assert `built=True` + after the layer's instantiation. + """ + if input_shape is not None and input_data is not None: + raise ValueError( + "input_shape and input_data cannot be passed " + "at the same time." + ) + if expected_output_shape is not None and expected_output is not None: + raise ValueError( + "expected_output_shape and expected_output cannot be passed " + "at the same time." + ) + if expected_output is not None and input_data is None: + raise ValueError( + "In order to use expected_output, input_data must be provided." + ) + if expected_mask_shape is not None and supports_masking is not True: + raise ValueError( + "In order to use expected_mask_shape, supports_masking " + "must be True." 
+ ) + + init_kwargs = init_kwargs or {} + call_kwargs = call_kwargs or {} + + if input_shape is not None and input_dtype is not None: + if isinstance(input_shape, tuple) and is_shape_tuple( + input_shape[0] + ): + self.assertIsInstance(input_dtype, tuple) + self.assertEqual( + len(input_shape), + len(input_dtype), + msg="The number of input shapes and dtypes does not match", + ) + elif isinstance(input_shape, dict): + self.assertIsInstance(input_dtype, dict) + self.assertEqual( + set(input_shape.keys()), + set(input_dtype.keys()), + msg="The number of input shapes and dtypes does not match", + ) + elif isinstance(input_shape, list): + self.assertIsInstance(input_dtype, list) + self.assertEqual( + len(input_shape), + len(input_dtype), + msg="The number of input shapes and dtypes does not match", + ) + elif not isinstance(input_shape, tuple): + raise ValueError("The type of input_shape is not supported") + if input_shape is not None and input_dtype is None: + input_dtype = tree.map_shape_structure( + lambda _: "float32", input_shape + ) + + # Estimate actual number of weights, variables, seed generators if + # expected ones not set. When using layers uses composition it should + # build each sublayer manually. 
+ if input_data is not None or input_shape is not None: + if input_data is None: + input_data = create_eager_tensors( + input_shape, input_dtype, input_sparse + ) + layer = layer_cls(**init_kwargs) + if isinstance(input_data, dict): + layer(**input_data, **call_kwargs) + else: + layer(input_data, **call_kwargs) + + if expected_num_trainable_weights is None: + expected_num_trainable_weights = len(layer.trainable_weights) + if expected_num_non_trainable_weights is None: + expected_num_non_trainable_weights = len( + layer.non_trainable_weights + ) + if expected_num_non_trainable_variables is None: + expected_num_non_trainable_variables = len( + layer.non_trainable_variables + ) + if expected_num_seed_generators is None: + expected_num_seed_generators = len(get_seed_generators(layer)) + + # Serialization test. + layer = layer_cls(**init_kwargs) + self.run_class_serialization_test(layer, custom_objects) + + # Basic masking test. + if supports_masking is not None: + self.assertEqual( + layer.supports_masking, + supports_masking, + msg="Unexpected supports_masking value", + ) + + def run_build_asserts(layer): + self.assertTrue(layer.built) + if expected_num_trainable_weights is not None: + self.assertLen( + layer.trainable_weights, + expected_num_trainable_weights, + msg="Unexpected number of trainable_weights", + ) + if expected_num_non_trainable_weights is not None: + self.assertLen( + layer.non_trainable_weights, + expected_num_non_trainable_weights, + msg="Unexpected number of non_trainable_weights", + ) + if expected_num_non_trainable_variables is not None: + self.assertLen( + layer.non_trainable_variables, + expected_num_non_trainable_variables, + msg="Unexpected number of non_trainable_variables", + ) + if expected_num_seed_generators is not None: + self.assertLen( + get_seed_generators(layer), + expected_num_seed_generators, + msg="Unexpected number of seed_generators", + ) + if ( + backend.backend() == "torch" + and expected_num_trainable_weights is not None + 
and expected_num_non_trainable_weights is not None + and expected_num_seed_generators is not None + ): + self.assertLen( + layer.torch_params, + expected_num_trainable_weights + + expected_num_non_trainable_weights + + expected_num_seed_generators, + msg="Unexpected number of torch_params", + ) + + def run_output_asserts(layer, output, eager=False): + if expected_output_shape is not None: + if isinstance(expected_output_shape, tuple) and is_shape_tuple( + expected_output_shape[0] + ): + self.assertIsInstance(output, tuple) + self.assertEqual( + len(output), + len(expected_output_shape), + msg="Unexpected number of outputs", + ) + output_shape = tuple(v.shape for v in output) + self.assertEqual( + expected_output_shape, + output_shape, + msg="Unexpected output shape", + ) + elif isinstance(expected_output_shape, tuple): + self.assertEqual( + expected_output_shape, + output.shape, + msg="Unexpected output shape", + ) + elif isinstance(expected_output_shape, dict): + self.assertIsInstance(output, dict) + self.assertEqual( + set(output.keys()), + set(expected_output_shape.keys()), + msg="Unexpected output dict keys", + ) + output_shape = {k: v.shape for k, v in output.items()} + self.assertEqual( + expected_output_shape, + output_shape, + msg="Unexpected output shape", + ) + elif isinstance(expected_output_shape, list): + self.assertIsInstance(output, list) + self.assertEqual( + len(output), + len(expected_output_shape), + msg="Unexpected number of outputs", + ) + output_shape = [v.shape for v in output] + self.assertEqual( + expected_output_shape, + output_shape, + msg="Unexpected output shape", + ) + else: + raise ValueError( + "The type of expected_output_shape is not supported" + ) + if expected_output_dtype is not None: + if isinstance(expected_output_dtype, tuple): + self.assertIsInstance(output, tuple) + self.assertEqual( + len(output), + len(expected_output_dtype), + msg="Unexpected number of outputs", + ) + output_dtype = tuple( + 
backend.standardize_dtype(v.dtype) for v in output + ) + self.assertEqual( + expected_output_dtype, + output_dtype, + msg="Unexpected output dtype", + ) + elif isinstance(expected_output_dtype, dict): + self.assertIsInstance(output, dict) + self.assertEqual( + set(output.keys()), + set(expected_output_dtype.keys()), + msg="Unexpected output dict keys", + ) + output_dtype = { + k: backend.standardize_dtype(v.dtype) + for k, v in output.items() + } + self.assertEqual( + expected_output_dtype, + output_dtype, + msg="Unexpected output dtype", + ) + elif isinstance(expected_output_dtype, list): + self.assertIsInstance(output, list) + self.assertEqual( + len(output), + len(expected_output_dtype), + msg="Unexpected number of outputs", + ) + output_dtype = [ + backend.standardize_dtype(v.dtype) for v in output + ] + self.assertEqual( + expected_output_dtype, + output_dtype, + msg="Unexpected output dtype", + ) + else: + output_dtype = tree.flatten(output)[0].dtype + self.assertEqual( + expected_output_dtype, + backend.standardize_dtype(output_dtype), + msg="Unexpected output dtype", + ) + if expected_output_sparse: + for x in tree.flatten(output): + self.assertSparse(x) + if eager: + if expected_output is not None: + self.assertEqual(type(expected_output), type(output)) + for ref_v, v in zip( + tree.flatten(expected_output), tree.flatten(output) + ): + self.assertAllClose( + ref_v, v, msg="Unexpected output value" + ) + if expected_num_losses is not None: + self.assertLen(layer.losses, expected_num_losses) + + def run_training_step(layer, input_data, output_data): + class TestModel(Model): + def __init__(self, layer): + super().__init__() + self.layer = layer + + def call(self, x, training=False): + return self.layer(x, training=training) + + model = TestModel(layer) + + data = (input_data, output_data) + if backend.backend() == "torch": + data = tree.map_structure(backend.convert_to_numpy, data) + + def data_generator(): + while True: + yield data + + # test the "default" 
path for each backend by setting + # jit_compile="auto". + # for tensorflow and jax backends auto is jitted + # Note that tensorflow cannot be jitted with sparse tensors + # for torch backend auto is eager + # + # NB: for torch, jit_compile=True turns on torchdynamo + # which may not always succeed in tracing depending + # on the model. Run your program with these env vars + # to get debug traces of dynamo: + # TORCH_LOGS="+dynamo" + # TORCHDYNAMO_VERBOSE=1 + # TORCHDYNAMO_REPORT_GUARD_FAILURES=1 + jit_compile = "auto" + if backend.backend() == "tensorflow" and input_sparse: + jit_compile = False + model.compile(optimizer="sgd", loss="mse", jit_compile=jit_compile) + model.fit(data_generator(), steps_per_epoch=1, verbose=0) + + # Build test. + if input_data is not None or input_shape is not None: + if input_shape is None: + build_shape = tree.map_structure( + lambda x: ops.shape(x), input_data + ) + else: + build_shape = input_shape + layer = layer_cls(**init_kwargs) + if isinstance(build_shape, dict): + layer.build(**build_shape) + else: + layer.build(build_shape) + run_build_asserts(layer) + + # Symbolic call test. + if input_shape is None: + keras_tensor_inputs = tree.map_structure( + lambda x: create_keras_tensors( + ops.shape(x), x.dtype, input_sparse + ), + input_data, + ) + else: + keras_tensor_inputs = create_keras_tensors( + input_shape, input_dtype, input_sparse + ) + layer = layer_cls(**init_kwargs) + if isinstance(keras_tensor_inputs, dict): + keras_tensor_outputs = layer( + **keras_tensor_inputs, **call_kwargs + ) + else: + keras_tensor_outputs = layer(keras_tensor_inputs, **call_kwargs) + run_build_asserts(layer) + run_output_asserts(layer, keras_tensor_outputs, eager=False) + + if expected_mask_shape is not None: + output_mask = layer.compute_mask(keras_tensor_inputs) + self.assertEqual(expected_mask_shape, output_mask.shape) + + # The stateless layers should be built after instantiation. 
+ if assert_built_after_instantiation: + layer = layer_cls(**init_kwargs) + self.assertTrue( + layer.built, + msg=( + f"{type(layer)} is stateless, so it should be built " + "after instantiation." + ), + ) + + # Eager call test and compiled training test. + if input_data is not None or input_shape is not None: + if input_data is None: + input_data = create_eager_tensors( + input_shape, input_dtype, input_sparse + ) + layer = layer_cls(**init_kwargs) + if isinstance(input_data, dict): + output_data = layer(**input_data, **call_kwargs) + else: + output_data = layer(input_data, **call_kwargs) + run_output_asserts(layer, output_data, eager=True) + + if run_training_check: + run_training_step(layer, input_data, output_data) + + # Never test mixed precision on torch CPU. Torch lacks support. + if run_mixed_precision_check and backend.backend() == "torch": + import torch + + run_mixed_precision_check = torch.cuda.is_available() + + if run_mixed_precision_check: + layer = layer_cls(**{**init_kwargs, "dtype": "mixed_float16"}) + input_spec = tree.map_structure( + lambda spec: KerasTensor( + spec.shape, + dtype=( + layer.compute_dtype + if layer.autocast + and backend.is_float_dtype(spec.dtype) + else spec.dtype + ), + ), + keras_tensor_inputs, + ) + if isinstance(input_data, dict): + output_data = layer(**input_data, **call_kwargs) + output_spec = layer.compute_output_spec(**input_spec) + else: + output_data = layer(input_data, **call_kwargs) + output_spec = layer.compute_output_spec(input_spec) + for tensor, spec in zip( + tree.flatten(output_data), tree.flatten(output_spec) + ): + dtype = standardize_dtype(tensor.dtype) + self.assertEqual( + dtype, + spec.dtype, + f"expected output dtype {spec.dtype}, got {dtype}", + ) + for weight in layer.weights: + dtype = standardize_dtype(weight.dtype) + if is_float_dtype(dtype): + self.assertEqual(dtype, "float32") + + +def tensorflow_uses_gpu(): + return backend.backend() == "tensorflow" and uses_gpu() + + +def jax_uses_gpu(): + 
return backend.backend() == "jax" and uses_gpu() + + +def torch_uses_gpu(): + if backend.backend() != "torch": + return False + from keras.src.backend.torch.core import get_device + + return get_device() == "cuda" + + +def uses_gpu(): + # Condition used to skip tests when using the GPU + devices = distribution.list_devices() + if any(d.startswith("gpu") for d in devices): + return True + return False + + +def create_keras_tensors(input_shape, dtype, sparse): + if isinstance(input_shape, dict): + return { + utils.removesuffix(k, "_shape"): KerasTensor( + v, dtype=dtype[k], sparse=sparse + ) + for k, v in input_shape.items() + } + return map_shape_dtype_structure( + lambda shape, dt: KerasTensor(shape, dtype=dt, sparse=sparse), + input_shape, + dtype, + ) + + +def create_eager_tensors(input_shape, dtype, sparse): + from keras.src.backend import random + + if set(tree.flatten(dtype)).difference( + [ + "float16", + "float32", + "float64", + "int8", + "uint8", + "int16", + "uint16", + "int32", + "uint32", + "int64", + "uint64", + ] + ): + raise ValueError( + "dtype must be a standard float or int dtype. 
" + f"Received: dtype={dtype}" + ) + + if sparse: + if backend.backend() == "tensorflow": + import tensorflow as tf + + def create_fn(shape, dt): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal(shape)).astype(dt) + x = np.multiply(x, rng.random(shape) < 0.7) + return tf.sparse.from_dense(x) + + elif backend.backend() == "jax": + import jax.experimental.sparse as jax_sparse + + def create_fn(shape, dt): + rng = np.random.default_rng(0) + x = (4 * rng.standard_normal(shape)).astype(dt) + x = np.multiply(x, rng.random(shape) < 0.7) + return jax_sparse.BCOO.fromdense(x, n_batch=1) + + else: + raise ValueError( + f"Sparse is unsupported with backend {backend.backend()}" + ) + + else: + + def create_fn(shape, dt): + return ops.cast( + random.uniform(shape, dtype="float32") * 3, dtype=dt + ) + + if isinstance(input_shape, dict): + return { + utils.removesuffix(k, "_shape"): create_fn(v, dtype[k]) + for k, v in input_shape.items() + } + return map_shape_dtype_structure(create_fn, input_shape, dtype) + + +def is_shape_tuple(x): + return isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ) + + +def map_shape_dtype_structure(fn, shape, dtype): + """Variant of tree.map_structure that operates on shape tuples.""" + if is_shape_tuple(shape): + return fn(tuple(shape), dtype) + if isinstance(shape, list): + return [ + map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype) + ] + if isinstance(shape, tuple): + return tuple( + map_shape_dtype_structure(fn, s, d) for s, d in zip(shape, dtype) + ) + if isinstance(shape, dict): + return { + k: map_shape_dtype_structure(fn, v, dtype[k]) + for k, v in shape.items() + } + else: + raise ValueError( + f"Cannot map function to unknown objects {shape} and {dtype}" + ) + + +def get_seed_generators(layer): + """Get a List of all seed generators in the layer recursively.""" + seed_generators = [] + seen_ids = set() + for sublayer in layer._flatten_layers(True, True): + for sg in 
def to_json_with_tuples(value):
    """Serialize `value` to JSON, tagging tuples so they survive a roundtrip.

    Tuples are replaced by `{"__class__": "tuple", "__value__": [...]}`
    markers (one level deep per visit; lists and dicts are traversed
    recursively) before encoding with sorted keys and 4-space indentation.
    """

    def _mark_tuples(node):
        # Tag tuples; recurse through lists and dicts to find nested ones.
        if isinstance(node, tuple):
            return {"__class__": "tuple", "__value__": list(node)}
        if isinstance(node, list):
            return [_mark_tuples(item) for item in node]
        if isinstance(node, dict):
            return {key: _mark_tuples(item) for key, item in node.items()}
        return node

    class _TupleAwareEncoder(json.JSONEncoder):
        def encode(self, o):
            return super().encode(_mark_tuples(o))

    encoder = _TupleAwareEncoder(sort_keys=True, indent=4)
    return encoder.encode(value)


def from_json_with_tuples(value):
    """Inverse of `to_json_with_tuples`: restore tagged dicts to tuples."""

    def _restore(node):
        # Any dict carrying both marker keys is turned back into a tuple.
        if (
            isinstance(node, dict)
            and "__class__" in node
            and "__value__" in node
        ):
            return tuple(node["__value__"])
        return node

    return json.loads(value, object_hook=_restore)


def get_test_data(
    train_samples, test_samples, input_shape, num_classes, random_seed=None
):
    """Generate balanced, stratified synthetic data to train a model on.

    Args:
        train_samples: Integer, how many training samples to generate.
        test_samples: Integer, how many test samples to generate.
        input_shape: Tuple of integers, shape of the inputs.
        num_classes: Integer, number of classes for the data and targets.
        random_seed: Integer, random seed used by Numpy to generate data.

    Returns:
        A tuple of Numpy arrays: `(x_train, y_train), (x_test, y_test)`.
    """
    np.random.seed(random_seed)
    total = train_samples + test_samples

    # Balanced labels: an equal share per class, then deterministic padding
    # (round-robin over classes) for any remainder.
    per_class = total // num_classes
    labels = np.array(
        [cls for cls in range(num_classes) for _ in range(per_class)],
        dtype=np.int32,
    )
    leftovers = total - len(labels)
    padding = np.array(
        [i % num_classes for i in range(leftovers)], dtype=np.int64
    )
    labels = np.concatenate([labels, padding])

    # One random template per class; every sample is its class template
    # plus unit Gaussian noise.
    templates = 2 * num_classes * np.random.random(
        (num_classes,) + input_shape
    )
    samples = np.zeros((total,) + input_shape, dtype=np.float32)
    for idx in range(total):
        samples[idx] = templates[labels[idx]] + np.random.normal(
            loc=0, scale=1.0, size=input_shape
        )

    # Seeded global shuffle before the stratified split.
    order = np.arange(total)
    np.random.shuffle(order)
    samples, labels = samples[order], labels[order]

    # Stratified split: a fixed train quota per class, remainder to test.
    x_tr, y_tr, x_te, y_te = [], [], [], []
    quota = int(train_samples / num_classes)
    for cls in range(num_classes):
        members = np.where(labels == cls)[0]
        np.random.shuffle(members)

        x_tr.extend(samples[members[:quota]])
        y_tr.extend(labels[members[:quota]])

        x_te.extend(samples[members[quota:]])
        y_te.extend(labels[members[quota:]])

    x_train, y_train = np.array(x_tr), np.array(y_tr)
    x_test, y_test = np.array(x_te), np.array(y_te)

    # Independent final shuffles of each split.
    train_order = np.arange(len(x_train))
    test_order = np.arange(len(x_test))
    np.random.shuffle(train_order)
    np.random.shuffle(test_order)

    return (
        (x_train[train_order], y_train[train_order]),
        (x_test[test_order], y_test[test_order]),
    )


def named_product(*args, **kwargs):
    """Cartesian product of parameter values with generated test-case names.

    Drop-in companion for `@parameterized.named_parameters`, replacing
    `@parameterized.product` with explicit test-case names. Each positional
    argument is a sequence of kwargs dicts, each carrying a
    `"testcase_name"` key; each keyword argument maps a parameter name to a
    sequence of values (named by their lowercased string / `__name__`).
    One dict per combination is produced, with the combined name built by
    joining the per-choice names with underscores.

    Returns:
        A list of kwargs dicts to pass to `@parameterized.named_parameters`.
    """

    def _label(choice):
        # Prefer a type's/function's __name__; fall back to str().
        if hasattr(choice, "__name__"):
            return choice.__name__.lower()
        return str(choice).lower()

    # Normalize keyword arguments into the same shape as positional groups.
    groups = args + tuple(
        tuple(
            {"testcase_name": _label(choice), key: choice}
            for choice in choices
        )
        for key, choices in kwargs.items()
    )

    # Fold each group into the running product, extending names as we go.
    combos = [{}]
    for group in groups:
        merged = []
        for candidate in group:
            for base in combos:
                entry = dict(base)
                prefix = entry.get("testcase_name", "")
                entry.update(candidate)
                if prefix:
                    entry["testcase_name"] = (
                        prefix + "_" + candidate["testcase_name"]
                    )
                else:
                    entry["testcase_name"] = candidate["testcase_name"]
                merged.append(entry)
        combos = merged

    return combos
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/compile_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..30038801408bde88e37b94b441f8f54a359988d8 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/compile_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93c58d0bf7988c907b48ce0f118b2cf6c7717554 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/epoch_iterator.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97a3600649b741957acaf49c52fa295f93f39632 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/__pycache__/trainer.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1ca6e54f21fee1708b20e7382b283f5b6840dffa --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/compile_utils.py @@ -0,0 +1,820 @@ +from collections import namedtuple + +from keras.src import losses as losses_module +from keras.src import metrics as 
metrics_module +from keras.src import ops +from keras.src import tree +from keras.src.backend.common.keras_tensor import KerasTensor +from keras.src.losses import loss as loss_module +from keras.src.utils.naming import get_object_name +from keras.src.utils.tracking import Tracker + + +class MetricsList(metrics_module.Metric): + def __init__(self, metrics, name="metrics_list", output_name=None): + super().__init__(name=name) + self.metrics = metrics + self.output_name = output_name + + def update_state(self, y_true, y_pred, sample_weight=None): + for m in self.metrics: + m.update_state(y_true, y_pred, sample_weight=sample_weight) + + def reset_state(self): + for m in self.metrics: + m.reset_state() + + def get_result(self): + return {m.name: m.result() for m in self.metrics} + + def get_config(self): + raise NotImplementedError + + @classmethod + def from_config(cls, config): + raise NotImplementedError + + +def is_function_like(value): + if value is None: + return True + if isinstance(value, str): + return True + if callable(value): + return True + return False + + +def is_binary_or_sparse_categorical(y_true, y_pred): + y_t_rank = len(y_true.shape) + y_p_rank = len(y_pred.shape) + y_t_last_dim = y_true.shape[-1] + y_p_last_dim = y_pred.shape[-1] + + is_binary = y_p_last_dim == 1 + is_sparse_categorical = ( + y_t_rank < y_p_rank or y_t_last_dim == 1 and y_p_last_dim > 1 + ) + return is_binary, is_sparse_categorical + + +def get_metric(identifier, y_true, y_pred): + if identifier is None: + return None # Ok to have no metric for an output. + + # Convenience feature for selecting b/t binary, categorical, + # and sparse categorical. 
+ if str(identifier).lower() not in ["accuracy", "acc"]: + metric_obj = metrics_module.get(identifier) + else: + is_binary, is_sparse_categorical = is_binary_or_sparse_categorical( + y_true, y_pred + ) + if is_binary: + metric_obj = metrics_module.BinaryAccuracy(name=str(identifier)) + elif is_sparse_categorical: + metric_obj = metrics_module.SparseCategoricalAccuracy( + name=str(identifier) + ) + else: + metric_obj = metrics_module.CategoricalAccuracy( + name=str(identifier) + ) + + if isinstance(identifier, str): + metric_name = identifier + else: + metric_name = get_object_name(metric_obj) + + if not isinstance(metric_obj, metrics_module.Metric): + metric_obj = metrics_module.MeanMetricWrapper(metric_obj) + + metric_obj.name = metric_name + return metric_obj + + +def get_loss(identifier, y_true, y_pred): + if identifier is None: + return None # Ok to have no loss for an output. + + # Convenience feature for selecting b/t binary, categorical, + # and sparse categorical. + if str(identifier).lower() not in ["crossentropy", "ce"]: + loss_obj = losses_module.get(identifier) + else: + is_binary, is_sparse_categorical = is_binary_or_sparse_categorical( + y_true, y_pred + ) + if is_binary: + loss_obj = losses_module.binary_crossentropy + elif is_sparse_categorical: + loss_obj = losses_module.sparse_categorical_crossentropy + else: + loss_obj = losses_module.categorical_crossentropy + + if not isinstance(loss_obj, losses_module.Loss): + if isinstance(identifier, str): + loss_name = identifier + else: + loss_name = get_object_name(loss_obj) + loss_obj = losses_module.LossFunctionWrapper(loss_obj, name=loss_name) + return loss_obj + + +class CompileMetrics(metrics_module.Metric): + def __init__( + self, + metrics, + weighted_metrics, + name="compile_metric", + output_names=None, + ): + super().__init__(name=name) + if metrics and not isinstance(metrics, (list, tuple, dict)): + raise ValueError( + "Expected `metrics` argument to be a list, tuple, or dict. 
" + f"Received instead: metrics={metrics} of type {type(metrics)}" + ) + if weighted_metrics and not isinstance( + weighted_metrics, (list, tuple, dict) + ): + raise ValueError( + "Expected `weighted_metrics` argument to be a list, tuple, or " + f"dict. Received instead: weighted_metrics={weighted_metrics} " + f"of type {type(weighted_metrics)}" + ) + self._user_metrics = metrics + self._user_weighted_metrics = weighted_metrics + self.built = False + self.name = "compile_metrics" + self.output_names = output_names + + @property + def metrics(self): + if not self.built: + return [] + metrics = [] + for m in self._flat_metrics + self._flat_weighted_metrics: + if isinstance(m, MetricsList): + metrics.extend(m.metrics) + elif m is not None: + metrics.append(m) + return metrics + + @property + def variables(self): + # Avoiding relying on implicit tracking since + # CompileMetrics may be instantiated or built in a no tracking scope. + if not self.built: + return [] + vars = [] + for m in self.metrics: + if m is not None: + vars.extend(m.variables) + return vars + + def build(self, y_true, y_pred): + num_outputs = 1 # default + if self.output_names: + output_names = self.output_names + elif isinstance(y_pred, dict): + output_names = sorted(list(y_pred.keys())) + elif isinstance(y_pred, (list, tuple)): + num_outputs = len(y_pred) + if all(hasattr(x, "_keras_history") for x in y_pred): + output_names = [x._keras_history.operation.name for x in y_pred] + else: + output_names = None + else: + output_names = None + if output_names: + num_outputs = len(output_names) + + y_pred = self._flatten_y(y_pred) + y_true = self._flatten_y(y_true) + + metrics = self._user_metrics + weighted_metrics = self._user_weighted_metrics + self._flat_metrics = self._build_metrics_set( + metrics, + num_outputs, + output_names, + y_true, + y_pred, + argument_name="metrics", + ) + self._flat_weighted_metrics = self._build_metrics_set( + weighted_metrics, + num_outputs, + output_names, + y_true, + 
y_pred, + argument_name="weighted_metrics", + ) + self.built = True + + def _build_metrics_set( + self, metrics, num_outputs, output_names, y_true, y_pred, argument_name + ): + flat_metrics = [] + if isinstance(metrics, dict): + for name in metrics.keys(): + if name not in output_names: + raise ValueError( + f"In the dict argument `{argument_name}`, key " + f"'{name}' does not correspond to any model " + f"output. Received:\n{argument_name}={metrics}" + ) + if num_outputs == 1: + if not metrics: + flat_metrics.append(None) + else: + if isinstance(metrics, dict): + metrics = tree.flatten(metrics) + if not isinstance(metrics, list): + metrics = [metrics] + if not all(is_function_like(m) for m in metrics): + raise ValueError( + f"Expected all entries in the `{argument_name}` list " + f"to be metric objects. Received instead:\n" + f"{argument_name}={metrics}" + ) + flat_metrics.append( + MetricsList( + [ + get_metric(m, y_true[0], y_pred[0]) + for m in metrics + if m is not None + ] + ) + ) + else: + if isinstance(metrics, (list, tuple)): + if len(metrics) != len(y_pred): + raise ValueError( + "For a model with multiple outputs, " + f"when providing the `{argument_name}` argument as a " + "list, it should have as many entries as the model has " + f"outputs. Received:\n{argument_name}={metrics}\nof " + f"length {len(metrics)} whereas the model has " + f"{len(y_pred)} outputs." + ) + for idx, (mls, yt, yp) in enumerate( + zip(metrics, y_true, y_pred) + ): + if not isinstance(mls, list): + mls = [mls] + name = output_names[idx] if output_names else None + if not all(is_function_like(e) for e in mls): + raise ValueError( + f"All entries in the sublists of the " + f"`{argument_name}` list should be metric objects. 
" + f"Found the following sublist with unknown " + f"types: {mls}" + ) + flat_metrics.append( + MetricsList( + [ + get_metric(m, yt, yp) + for m in mls + if m is not None + ], + output_name=name, + ) + ) + elif isinstance(metrics, dict): + if output_names is None: + raise ValueError( + f"Argument `{argument_name}` can only be provided as a " + "dict when the model also returns a dict of outputs. " + f"Received {argument_name}={metrics}" + ) + for name in metrics.keys(): + if not isinstance(metrics[name], list): + metrics[name] = [metrics[name]] + if not all(is_function_like(e) for e in metrics[name]): + raise ValueError( + f"All entries in the sublists of the " + f"`{argument_name}` dict should be metric objects. " + f"At key '{name}', found the following sublist " + f"with unknown types: {metrics[name]}" + ) + for name, yt, yp in zip(output_names, y_true, y_pred): + if name in metrics: + flat_metrics.append( + MetricsList( + [ + get_metric(m, yt, yp) + for m in metrics[name] + if m is not None + ], + output_name=name, + ) + ) + else: + flat_metrics.append(None) + return flat_metrics + + def _flatten_y(self, y): + if isinstance(y, dict) and self.output_names: + result = [] + for name in self.output_names: + if name in y: + result.append(y[name]) + return result + return tree.flatten(y) + + def update_state(self, y_true, y_pred, sample_weight=None): + if not self.built: + self.build(y_true, y_pred) + y_true = self._flatten_y(y_true) + y_pred = self._flatten_y(y_pred) + for m, y_t, y_p in zip(self._flat_metrics, y_true, y_pred): + if m: + m.update_state(y_t, y_p) + if sample_weight is not None: + sample_weight = self._flatten_y(sample_weight) + # For multi-outputs, repeat sample weights for n outputs. 
+ if len(sample_weight) < len(y_true): + sample_weight = [sample_weight[0] for _ in range(len(y_true))] + else: + sample_weight = [None for _ in range(len(y_true))] + for m, y_t, y_p, s_w in zip( + self._flat_weighted_metrics, y_true, y_pred, sample_weight + ): + if m: + m.update_state(y_t, y_p, s_w) + + def reset_state(self): + if not self.built: + return + for m in self._flat_metrics: + if m: + m.reset_state() + for m in self._flat_weighted_metrics: + if m: + m.reset_state() + + def result(self): + if not self.built: + raise ValueError( + "Cannot get result() since the metric has not yet been built." + ) + results = {} + unique_name_counters = {} + for mls in self._flat_metrics: + if not mls: + continue + for m in mls.metrics: + name = m.name + if mls.output_name: + name = f"{mls.output_name}_{name}" + if name not in unique_name_counters: + results[name] = m.result() + unique_name_counters[name] = 1 + else: + index = unique_name_counters[name] + unique_name_counters[name] += 1 + name = f"{name}_{index}" + results[name] = m.result() + + for mls in self._flat_weighted_metrics: + if not mls: + continue + for m in mls.metrics: + name = m.name + if mls.output_name: + name = f"{mls.output_name}_{name}" + if name not in unique_name_counters: + results[name] = m.result() + unique_name_counters[name] = 1 + else: + name = f"weighted_{m.name}" + if mls.output_name: + name = f"{mls.output_name}_{name}" + if name not in unique_name_counters: + unique_name_counters[name] = 1 + else: + index = unique_name_counters[name] + unique_name_counters[name] += 1 + name = f"{name}_{index}" + results[name] = m.result() + return results + + def get_config(self): + raise NotImplementedError + + @classmethod + def from_config(cls, config): + raise NotImplementedError + + +class CompileLoss(losses_module.Loss): + Loss = namedtuple("Loss", ["path", "loss", "loss_weights", "name"]) + + def __init__( + self, + loss, + loss_weights=None, + reduction="sum_over_batch_size", + output_names=None, + 
): + if loss_weights and not isinstance( + loss_weights, (list, tuple, dict, float) + ): + raise ValueError( + "Expected `loss_weights` argument to be a float " + "(single output case) or a list, tuple, or " + "dict (multiple output case). " + f"Received instead: loss_weights={loss_weights} " + f"of type {type(loss_weights)}" + ) + self._user_loss = loss + self._user_loss_weights = loss_weights + self.built = False + self.output_names = output_names + super().__init__(name="compile_loss", reduction=reduction) + + # Use `Tracker` to track metrics for individual losses. + self._metrics = [] + self._tracker = Tracker( + { + "metrics": ( + lambda x: isinstance(x, metrics_module.Metric), + self._metrics, + ) + } + ) + self._flat_losses = None + self._y_pred_build_structure = None + self._y_true_build_structure = None + + @property + def metrics(self): + return self._metrics + + @property + def variables(self): + vars = [] + for m in self.metrics: + vars.extend(m.variables) + return vars + + def _build_nested(self, y_true, y_pred, loss, output_names, current_path): + flat_y_pred = tree.flatten(y_pred) + if not tree.is_nested(loss): + _loss = loss.loss + if _loss is None: + return + loss_weight = loss.weight + resolved_loss = get_loss(_loss, y_true, y_pred) + name_path = current_path + if not tree.is_nested(output_names): + if output_names is not None: + output_name = output_names + else: + output_name = resolved_loss.name + if len(name_path) == 0: + name_path = (output_name,) + elif isinstance(name_path[-1], int): + name_path = name_path[:-1] + (output_name,) + name = "/".join([str(path) for path in name_path]) + if name == "": + if isinstance(output_names, dict): + flat_output_names = list(output_names.keys()) + else: + flat_output_names = tree.flatten(output_names) + name = "_".join(flat_output_names) + self._flat_losses.append( + CompileLoss.Loss(current_path, resolved_loss, loss_weight, name) + ) + return + elif ( + issubclass(type(loss), (list, tuple)) + and 
all([not tree.is_nested(_loss) for _loss in loss]) + and len(loss) == len(flat_y_pred) + ): + loss = tree.pack_sequence_as(y_pred, loss) + elif issubclass(type(loss), (list, tuple)) and not isinstance( + y_pred, type(loss) + ): + for _loss in loss: + self._build_nested( + y_true, + y_pred, + _loss, + output_names, + current_path, + ) + return + + if not tree.is_nested(loss): + return self._build_nested( + y_true, y_pred, loss, output_names, current_path + ) + + if not isinstance(loss, type(y_pred)): + raise KeyError( + f"The path: {current_path} in " + "the `loss` argument, can't be found in " + "the model's output (`y_pred`)." + ) + + # shallow traverse the loss config + if isinstance(loss, dict): + iterator = loss.items() + + def key_check_fn(key, objs): + return all( + [isinstance(obj, dict) and key in obj for obj in objs] + ) + + elif issubclass(type(loss), (list, tuple)): + iterator = enumerate(loss) + + def key_check_fn(key, objs): + return all( + [ + issubclass(type(obj), (list, tuple)) and key < len(obj) + for obj in objs + ] + ) + + else: + raise TypeError( + f"Unsupported type {type(loss)} " + f"in the `loss` configuration." + ) + + for key, _loss in iterator: + if _loss is None: + continue + if not key_check_fn(key, (y_true, y_pred)): + raise KeyError( + f"The path: {current_path + (key,)} in " + "the `loss` argument, can't be found in " + "either the model's output (`y_pred`) or in the " + "labels (`y_true`)." 
+ ) + + self._build_nested( + y_true[key], + y_pred[key], + _loss, + output_names[key], + current_path + (key,), + ) + + def build(self, y_true, y_pred): + loss = self._user_loss + loss_weights = self._user_loss_weights + flat_output_names = self.output_names + if ( + self.output_names + and isinstance(self._user_loss, dict) + and not isinstance(y_pred, dict) + ): + if set(self.output_names) == set(self._user_loss.keys()): + loss = [self._user_loss[name] for name in self.output_names] + if isinstance(self._user_loss_weights, dict): + loss_weights = [ + self._user_loss_weights[name] + for name in self.output_names + ] + else: + raise ValueError( + f"Expected keys {self.output_names} in loss dict, but " + f"found loss.keys()={list(self._user_loss.keys())}" + ) + + # Pytree leaf container + class WeightedLoss: + def __new__(cls, loss, weight): + if loss is None: + return None + return object.__new__(cls) + + def __init__(self, loss, weight): + self.loss = loss + self.weight = weight + + # pack the losses and the weights together + if loss_weights is not None: + try: + tree.assert_same_structure(loss, loss_weights) + except ValueError: + flat_loss_weights = tree.flatten(loss_weights) + if len(tree.flatten(loss)) != len(flat_loss_weights): + raise ValueError( + f"`loss_weights` must match the number of losses, " + f"got {len(tree.flatten(loss))} losses " + f"and {len(loss_weights)} weights." 
+ ) + loss_weights = tree.pack_sequence_as(loss, flat_loss_weights) + loss = tree.map_structure( + lambda _loss, _weight: WeightedLoss(_loss, _weight), + loss, + loss_weights, + ) + else: + loss = tree.map_structure( + lambda _loss: WeightedLoss(_loss, None), loss + ) + + self._flat_losses = [] + + if ( + isinstance(loss, dict) + and issubclass(type(y_pred), (list, tuple)) + and set(loss.keys()) == set(flat_output_names) + and len(y_pred) == len(flat_output_names) + ): + y_pred = {name: y_p for name, y_p in zip(flat_output_names, y_pred)} + y_true = {name: y_t for name, y_t in zip(flat_output_names, y_true)} + elif ( + isinstance(loss, dict) + and not tree.is_nested(y_pred) + and set(loss.keys()) == set(flat_output_names) + and len(flat_output_names) == 1 + ): + y_pred = { + name: y_p for name, y_p in zip(flat_output_names, [y_pred]) + } + y_true = { + name: y_t for name, y_t in zip(flat_output_names, [y_true]) + } + + try: + output_names = tree.pack_sequence_as(y_pred, flat_output_names) + except: + inferred_flat_output_names = self._get_y_pred_output_names(y_pred) + output_names = tree.pack_sequence_as( + y_pred, inferred_flat_output_names + ) + + if not tree.is_nested(loss): + loss = tree.map_structure(lambda x: loss, y_pred) + + self._build_nested(y_true, y_pred, loss, output_names, ()) + + # Add `Mean` metric to the tracker for each loss. 
+ if len(self._flat_losses) > 1: + for _loss in self._flat_losses: + name = _loss.name + "_loss" + self._tracker.add_to_store( + "metrics", metrics_module.Mean(name=name) + ) + + self._y_pred_build_structure = tree.map_structure( + lambda x: None, y_pred + ) + self._y_true_build_structure = tree.map_structure( + lambda x: None, y_true + ) + self.built = True + + def _get_y_pred_output_names(self, y_pred): + flat_y_pred = tree.flatten(y_pred) + if all((isinstance(x, KerasTensor) for x in flat_y_pred)): + output_names = [] + for tensor in flat_y_pred: + if hasattr(tensor, "_keras_history"): + output_names.append(tensor._keras_history.operation.name) + else: + output_names.append(tensor.name) + else: + output_names = [None] * len(flat_y_pred) + return output_names + + def __call__(self, y_true, y_pred, sample_weight=None): + with ops.name_scope(self.name): + return self.call(y_true, y_pred, sample_weight) + + def call(self, y_true, y_pred, sample_weight=None): + if not tree.is_nested(y_true) and not tree.is_nested(y_pred): + # Fast path: single output case / no loss-tracking metric. 
+ if not self.built: + self.build(y_true, y_pred) + _, loss_fn, loss_weight, _ = self._flat_losses[0] + loss_value = ops.cast( + loss_fn(y_true, y_pred, sample_weight), dtype=self.dtype + ) + if loss_weight is not None: + loss_value = ops.multiply(loss_value, loss_weight) + return loss_value + + try: + tree.assert_same_structure(y_pred, y_true) + except ValueError: + # Check case where y_true is either flat or leaf + if ( + not tree.is_nested(y_true) + and hasattr(y_pred, "__len__") + and len(y_pred) == 1 + ): + y_true = [y_true] + + # Check case where y_pred is list/tuple and y_true is dict + elif isinstance(y_pred, (list, tuple)) and isinstance(y_true, dict): + if set(self.output_names) == set(y_true.keys()): + y_true = [y_true[name] for name in self.output_names] + + try: + y_true = tree.pack_sequence_as(y_pred, y_true) + except: + # Check case where y_true has the same structure but uses + # different (but reconcilable) container types, + # e.g `list` vs `tuple`. + try: + tree.assert_same_paths(y_true, y_pred) + y_true = tree.pack_sequence_as(y_pred, tree.flatten(y_true)) + except: + try: + # Check case where loss is partially defined over y_pred + flat_y_true = tree.flatten(y_true) + flat_loss = tree.flatten(self._user_loss) + flat_loss_non_nones = [ + (i, loss) + for i, loss in enumerate(flat_loss) + if loss is not None + ] + assert len(flat_y_true) == len(flat_loss_non_nones) + y_true = [None] * len(flat_loss) + for y_t, (i, loss) in zip( + flat_y_true, flat_loss_non_nones + ): + y_true[i] = y_t + y_true = tree.pack_sequence_as(self._user_loss, y_true) + except: + y_true_struct = tree.map_structure( + lambda _: "*", y_true + ) + y_pred_struct = tree.map_structure( + lambda _: "*", y_pred + ) + raise ValueError( + "y_true and y_pred have different structures.\n" + f"y_true: {y_true_struct}\n" + f"y_pred: {y_pred_struct}\n" + ) + + if not self.built: + self.build(y_true, y_pred) + + try: + tree.assert_same_structure(self._y_pred_build_structure, y_pred) + 
except ValueError: + y_pred = tree.pack_sequence_as( + self._y_pred_build_structure, tree.flatten(y_pred) + ) + try: + tree.assert_same_structure(self._y_true_build_structure, y_true) + except ValueError: + y_true = tree.pack_sequence_as( + self._y_true_build_structure, tree.flatten(y_true) + ) + + # We need to add a dummy `None` if the model has only a single output. + metrics = [None] if len(self.metrics) == 0 else self.metrics + + # Iterate all losses in flat form. + loss_values = [] + + def resolve_path(path, object): + for _path in path: + object = object[_path] + return object + + for (path, loss_fn, loss_weight, _), metric in zip( + self._flat_losses, metrics + ): + y_t, y_p = resolve_path(path, y_true), resolve_path(path, y_pred) + if sample_weight is not None and tree.is_nested(sample_weight): + _sample_weight = resolve_path(path, sample_weight) + else: + _sample_weight = sample_weight + + value = ops.cast( + loss_fn(y_t, y_p, _sample_weight), dtype=self.dtype + ) + # Record *unweighted* individual losses. 
+ if metric: + metric.update_state( + loss_module.unscale_loss_for_distribution(value), + sample_weight=tree.flatten(y_p)[0].shape[0], + ) + if loss_weight is not None: + value = ops.multiply(value, loss_weight) + loss_values.append(value) + + if loss_values: + total_loss = sum(loss_values) + return total_loss + return None + + def get_config(self): + raise NotImplementedError + + @classmethod + def from_config(cls, config): + raise NotImplementedError diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b10e1d233a3df9dbddb875c1d1235904e7f40abf --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__init__.py @@ -0,0 +1,154 @@ +import types + +from keras.src.distribution import distribution_lib +from keras.src.trainers.data_adapters import array_data_adapter +from keras.src.trainers.data_adapters import data_adapter +from keras.src.trainers.data_adapters import py_dataset_adapter +from keras.src.trainers.data_adapters.array_data_adapter import ArrayDataAdapter +from keras.src.trainers.data_adapters.generator_data_adapter import ( + GeneratorDataAdapter, +) +from keras.src.trainers.data_adapters.py_dataset_adapter import PyDatasetAdapter +from keras.src.trainers.data_adapters.tf_dataset_adapter import TFDatasetAdapter +from keras.src.trainers.data_adapters.torch_data_loader_adapter import ( + TorchDataLoaderAdapter, +) + + +def get_data_adapter( + x, + y=None, + sample_weight=None, + batch_size=None, + steps_per_epoch=None, + shuffle=False, + class_weight=None, +): + # Allow passing a custom data adapter. + if isinstance(x, data_adapter.DataAdapter): + return x + + # Check for multi-process/worker distribution. 
Since only tf.dataset + # is supported at the moment, we will raise error if the inputs fail + # the type check + distribution = distribution_lib.distribution() + if getattr(distribution, "_is_multi_process", False) and not is_tf_dataset( + x + ): + raise ValueError( + "When using multi-worker distribution, the data must be provided " + f"as a `tf.data.Dataset` instance. Received: type(x)={type(x)}." + ) + + if array_data_adapter.can_convert_arrays((x, y, sample_weight)): + return ArrayDataAdapter( + x, + y, + sample_weight=sample_weight, + class_weight=class_weight, + shuffle=shuffle, + batch_size=batch_size, + steps=steps_per_epoch, + ) + elif is_tf_dataset(x): + # Unsupported args: y, sample_weight, shuffle + if y is not None: + raise_unsupported_arg("y", "the targets", "tf.data.Dataset") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "tf.data.Dataset" + ) + return TFDatasetAdapter( + x, class_weight=class_weight, distribution=distribution + ) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a tf.data.Dataset. The Dataset is " + # "expected to already be shuffled " + # "(via `.shuffle(tf.data.AUTOTUNE)`)" + # ) + elif isinstance(x, py_dataset_adapter.PyDataset): + if y is not None: + raise_unsupported_arg("y", "the targets", "PyDataset") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "PyDataset" + ) + return PyDatasetAdapter(x, class_weight=class_weight, shuffle=shuffle) + # TODO: should we warn or not? + # if x.num_batches is None and shuffle: + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a infinite PyDataset. The " + # "PyDataset is expected to already be shuffled." 
+ # ) + elif is_torch_dataloader(x): + if y is not None: + raise_unsupported_arg("y", "the targets", "torch DataLoader") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "torch DataLoader" + ) + if class_weight is not None: + raise ValueError( + "Argument `class_weight` is not supported for torch " + f"DataLoader inputs. Received: class_weight={class_weight}" + ) + return TorchDataLoaderAdapter(x) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a torch DataLoader. The DataLoader " + # "is expected to already be shuffled." + # ) + elif isinstance(x, types.GeneratorType): + if y is not None: + raise_unsupported_arg("y", "the targets", "PyDataset") + if sample_weight is not None: + raise_unsupported_arg( + "sample_weights", "the sample weights", "PyDataset" + ) + if class_weight is not None: + raise ValueError( + "Argument `class_weight` is not supported for Python " + f"generator inputs. Received: class_weight={class_weight}" + ) + return GeneratorDataAdapter(x) + # TODO: should we warn or not? + # warnings.warn( + # "`shuffle=True` was passed, but will be ignored since the " + # "data `x` was provided as a generator. The generator " + # "is expected to yield already-shuffled data." + # ) + else: + raise ValueError(f"Unrecognized data type: x={x} (of type {type(x)})") + + +def raise_unsupported_arg(arg_name, arg_description, input_type): + raise ValueError( + f"When providing `x` as a {input_type}, `{arg_name}` " + f"should not be passed. Instead, {arg_description} should " + f"be included as part of the {input_type}." + ) + + +def is_tf_dataset(x): + if hasattr(x, "__class__"): + for parent in x.__class__.__mro__: + if parent.__name__ in ( + "DatasetV2", + "DistributedDataset", + ) and "tensorflow.python." 
in str(parent.__module__): + return True + return False + + +def is_torch_dataloader(x): + if hasattr(x, "__class__"): + for parent in x.__class__.__mro__: + if parent.__name__ == "DataLoader" and "torch.utils.data" in str( + parent.__module__ + ): + return True + return False diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d19a51f3e1286570e9d9d439608d2f756771e449 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9135a054d172a8b3e8e1a570bfafc57ab7a8c6dc Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_data_adapter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a9b16b36f49e8f8e121bd3c62cb3c83fb4dc0b1 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/array_slicing.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69be4ffbe973d1d6fb29f0a2f44496c6f55637b6 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b22ca7503ffebf4224cd3be46e71041b66c5e8c6 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/data_adapter_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b00949069d4ad94245ba04f20f13a134b886d351 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/generator_data_adapter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af8611f8cb335dd3d326669e8acc3f6e3e1fbb26 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/py_dataset_adapter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..49dae3c8c795fc7f3c10a0a4dc31e57b284b8603 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/tf_dataset_adapter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87b817bbb7d1eb524a70f23ce16a857e25cd259d Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/__pycache__/torch_data_loader_adapter.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..e732f28688bddb02ba973d3687dc0e56f18d3dc1 --- /dev/null +++ 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_data_adapter.py @@ -0,0 +1,372 @@ +import functools +import math + +import numpy as np + +from keras.src import tree +from keras.src.trainers.data_adapters import array_slicing +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +class ArrayDataAdapter(DataAdapter): + """Adapter for array-like objects, e.g. TF/JAX Tensors, NumPy arrays.""" + + def __init__( + self, + x, + y=None, + sample_weight=None, + batch_size=None, + steps=None, + shuffle=False, + class_weight=None, + ): + if not can_convert_arrays((x, y, sample_weight)): + raise ValueError( + "Expected all elements of `x` to be array-like. " + f"Received invalid types: x={x}" + ) + + if sample_weight is not None: + if class_weight is not None: + raise ValueError( + "You cannot `class_weight` and `sample_weight` " + "at the same time." + ) + if tree.is_nested(y): + if isinstance(sample_weight, (list, tuple, dict)): + try: + tree.assert_same_structure(y, sample_weight) + except ValueError: + raise ValueError( + "You should provide one `sample_weight` array per " + "output in `y`. The two structures did not match:\n" + f"- y: {y}\n" + f"- sample_weight: {sample_weight}\n" + ) + else: + is_samplewise = len(sample_weight.shape) == 1 or ( + len(sample_weight.shape) == 2 + and sample_weight.shape[1] == 1 + ) + if not is_samplewise: + raise ValueError( + "For a model with multiple outputs, when providing " + "a single `sample_weight` array, it should only " + "have one scalar score per sample " + "(i.e. shape `(num_samples,)`). If you want to use " + "non-scalar sample weights, pass a `sample_weight` " + "argument with one array per model output." + ) + # Replicate the same sample_weight array on all outputs. 
+ sample_weight = tree.map_structure( + lambda _: sample_weight, y + ) + if class_weight is not None: + if tree.is_nested(y): + raise ValueError( + "`class_weight` is only supported for Models with a single " + "output." + ) + sample_weight = data_adapter_utils.class_weight_to_sample_weights( + y, class_weight + ) + + inputs = data_adapter_utils.pack_x_y_sample_weight(x, y, sample_weight) + + data_adapter_utils.check_data_cardinality(inputs) + num_samples = set(i.shape[0] for i in tree.flatten(inputs)).pop() + self._num_samples = num_samples + self._inputs = inputs + + # If batch_size is not passed but steps is, calculate from the input + # data. Defaults to `32` for backwards compatibility. + if not batch_size: + batch_size = int(math.ceil(num_samples / steps)) if steps else 32 + + self._size = int(math.ceil(num_samples / batch_size)) + self._batch_size = batch_size + self._partial_batch_size = num_samples % batch_size + self._shuffle = shuffle + + def get_numpy_iterator(self): + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="numpy" + ) + + def slice_and_convert_to_numpy(sliceable, indices=None): + x = sliceable[indices] + x = sliceable.convert_to_numpy(x) + return x + + return self._get_iterator(slice_and_convert_to_numpy, inputs) + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + shuffle = self._shuffle + batch_size = self._batch_size + num_samples = self._num_samples + num_full_batches = int(self._num_samples // batch_size) + + # Vectorized version of shuffle. + # This is a performance improvement over using `from_tensor_slices`. + # The indices of the data are shuffled and batched, and these indices + # are then zipped with the data and used to extract a batch of the data + # at each step. The performance improvements here come from: + # 1. vectorized batch using gather + # 2. parallelized map + # 3. pipelined permutation generation + # 4. optimized permutation batching + # 5. 
disabled static optimizations + + indices_dataset = tf.data.Dataset.range(1) + + def permutation(_): + # It turns out to be more performant to make a new set of indices + # rather than reusing the same range Tensor. (presumably because of + # buffer forwarding.) + indices = tf.range(num_samples, dtype=tf.int64) + if shuffle and shuffle != "batch": + indices = tf.random.shuffle(indices) + return indices + + # We prefetch a single element. Computing large permutations can take + # quite a while so we don't want to wait for prefetching over an epoch + # boundary to trigger the next permutation. On the other hand, too many + # simultaneous shuffles can contend on a hardware level and degrade all + # performance. + indices_dataset = indices_dataset.map(permutation).prefetch(1) + + def slice_batch_indices(indices): + """Convert a Tensor of indices into a dataset of batched indices. + + This step can be accomplished in several ways. The most natural is + to slice the Tensor in a Dataset map. (With a condition on the upper + index to handle the partial batch.) However it turns out that + coercing the Tensor into a shape which is divisible by the batch + size (and handling the last partial batch separately) allows for a + much more favorable memory access pattern and improved performance. + + Args: + indices: Tensor which determines the data order for an entire + epoch. + + Returns: + A Dataset of batched indices. 
+ """ + num_in_full_batch = num_full_batches * batch_size + first_k_indices = tf.slice(indices, [0], [num_in_full_batch]) + first_k_indices = tf.reshape( + first_k_indices, [num_full_batches, batch_size] + ) + + flat_dataset = tf.data.Dataset.from_tensor_slices(first_k_indices) + if self._partial_batch_size: + index_remainder = tf.data.Dataset.from_tensors( + tf.slice( + indices, [num_in_full_batch], [self._partial_batch_size] + ) + ) + flat_dataset = flat_dataset.concatenate(index_remainder) + + return flat_dataset + + def slice_inputs(indices_dataset, inputs): + """Slice inputs into a Dataset of batches. + + Given a Dataset of batch indices and the unsliced inputs, + this step slices the inputs in a parallelized fashion + and produces a dataset of input batches. + + Args: + indices_dataset: A Dataset of batched indices. + inputs: A python data structure that contains the inputs, + targets, and possibly sample weights. + + Returns: + A Dataset of input batches matching the batch indices. 
+ """ + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="tensorflow" + ) + inputs = tree.lists_to_tuples(inputs) + + dataset = tf.data.Dataset.zip( + (indices_dataset, tf.data.Dataset.from_tensors(inputs).repeat()) + ) + + def grab_batch(i, data): + def grab_one(x): + if isinstance(x, array_slicing.TensorflowSparseWrapper): + return array_slicing.slice_tensorflow_sparse_wrapper( + x, i + ) + if isinstance(x, (list, tuple, dict)): + return None + if tf.is_tensor(x): + return tf.gather(x, i, axis=0) + return x + + return tree.traverse(grab_one, data) + + dataset = dataset.map( + grab_batch, num_parallel_calls=tf.data.AUTOTUNE + ) + + # Default optimizations are disabled to avoid the overhead of + # (unnecessary) input pipeline graph serialization & deserialization + options = tf.data.Options() + options.experimental_optimization.apply_default_optimizations = ( + False + ) + if self._shuffle: + options.experimental_external_state_policy = ( + tf.data.experimental.ExternalStatePolicy.IGNORE + ) + dataset = dataset.with_options(options) + return dataset + + indices_dataset = indices_dataset.flat_map(slice_batch_indices) + if shuffle == "batch": + indices_dataset = indices_dataset.map(tf.random.shuffle) + + dataset = slice_inputs(indices_dataset, self._inputs) + + options = tf.data.Options() + options.experimental_distribute.auto_shard_policy = ( + tf.data.experimental.AutoShardPolicy.DATA + ) + dataset = dataset.with_options(options) + return dataset.prefetch(tf.data.AUTOTUNE) + + def get_jax_iterator(self): + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="jax" + ) + + def slice_and_convert_to_jax(sliceable, indices=None): + x = sliceable[indices] + x = sliceable.convert_to_jax_compatible(x) + return x + + return self._get_iterator(slice_and_convert_to_jax, inputs) + + def get_torch_dataloader(self): + import torch + + from keras.src.backend.torch.core import convert_to_tensor + + class 
ArrayDataset(torch.utils.data.Dataset): + def __init__(self, array): + self.array = array + + def __getitems__(self, indices): + def slice_and_convert(sliceable): + x = sliceable[indices] + x = sliceable.convert_to_torch_compatible(x) + x = convert_to_tensor(x) + return x + + return tree.map_structure(slice_and_convert, self.array) + + def __len__(self): + return len(self.array[0]) + + class RandomBatchSampler(torch.utils.data.Sampler): + def __init__(self, sampler): + self.sampler = sampler + + def __iter__(self): + for batch in self.sampler: + yield [batch[i] for i in torch.randperm(len(batch))] + + def __len__(self): + return len(self.sampler) + + if self._shuffle == "batch": + batch_sampler = RandomBatchSampler( + torch.utils.data.BatchSampler( + range(self._num_samples), + batch_size=self._batch_size, + drop_last=False, + ) + ) + elif self._shuffle: + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.RandomSampler(range(self._num_samples)), + batch_size=self._batch_size, + drop_last=False, + ) + else: + batch_sampler = torch.utils.data.BatchSampler( + torch.utils.data.SequentialSampler(range(self._num_samples)), + batch_size=self._batch_size, + drop_last=False, + ) + + # Because ArrayDataset.__getitems__ returns full batches organized in + # the expected structure, there is nothing to collate. 
+ def no_op_collate(batch): + return batch + + inputs = array_slicing.convert_to_sliceable( + self._inputs, target_backend="torch" + ) + dataset = ArrayDataset(inputs) + return torch.utils.data.DataLoader( + dataset, batch_sampler=batch_sampler, collate_fn=no_op_collate + ) + + def _get_iterator(self, slice_and_convert_fn, inputs): + global_permutation = None + if self._shuffle and self._shuffle != "batch": + global_permutation = np.random.permutation(self._num_samples) + + for i in range(self._size): + start = i * self._batch_size + stop = min((i + 1) * self._batch_size, self._num_samples) + if self._shuffle == "batch": + indices = np.random.permutation(stop - start) + start + elif self._shuffle: + indices = global_permutation[start:stop] + else: + indices = slice(start, stop) + + slice_indices_and_convert_fn = functools.partial( + slice_and_convert_fn, indices=indices + ) + yield tree.map_structure(slice_indices_and_convert_fn, inputs) + + @property + def num_batches(self): + return self._size + + @property + def batch_size(self): + return self._batch_size + + @property + def has_partial_batch(self): + return self._partial_batch_size > 0 + + @property + def partial_batch_size(self): + return self._partial_batch_size or None + + +def can_convert_arrays(arrays): + """Check if array like-inputs can be handled by `ArrayDataAdapter` + + Args: + inputs: Structure of `Tensor`s, NumPy arrays, or tensor-like. + + Returns: + `True` if `arrays` can be handled by `ArrayDataAdapter`, `False` + otherwise. 
+ """ + + return all( + tree.flatten(tree.map_structure(array_slicing.can_slice_array, arrays)) + ) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py new file mode 100644 index 0000000000000000000000000000000000000000..74622ebb4aee3ead4179db88b7d29888b5b35ac8 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/array_slicing.py @@ -0,0 +1,520 @@ +import collections +import math + +import numpy as np + +from keras.src import backend +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.utils.module_utils import tensorflow as tf + +try: + import pandas +except ImportError: + pandas = None + + +# Leave jax, tf, and torch arrays off this list. Instead we will use +# `__array__` to detect these types. Doing so allows us to avoid importing a +# backend framework we are not currently using just to do type-checking. +ARRAY_TYPES = (np.ndarray,) +if pandas: + ARRAY_TYPES = ARRAY_TYPES + (pandas.Series, pandas.DataFrame) + + +class Sliceable: + """`Sliceable` wrapping a tensor. + + A `Sliceable` implements the subscript operator to slice or index against + the first dimension of the array. It also has conversion methods for each + one of the backends. + + Args: + array: the native array or tensor to wrap. + + Attributes: + shape: the shape of the full dense native array. + """ + + def __init__(self, array): + self.array = array + + def __getitem__(self, indices): + """Select elements in the 0th dimension. + + Args: + indices: the indices to select. Only needs to support one dimension, + the 0th dimension. Should support a `slice` or a list, tuple, + `np.array` or 1D tensor. + Returns: A slice of `self.array`. 
+ """ + return self.array[indices] + + @classmethod + def cast(cls, x, dtype): + """Cast a tensor to a different dtype. + + Only called on a full array as provided by the user. + + Args: + x: the tensor to cast. + Returns: the cast tensor. + """ + return x.astype(dtype) + + @classmethod + def convert_to_numpy(cls, x): + """Convert a tensor to a NumPy array. + + Only called after slicing using `__getitem__`. + + Args: + x: the tensor to convert. + Returns: the converted tensor. + """ + return x + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + """Convert a tensor to something compatible with `tf.data.Dataset`. + + This can be a NumPy array, `tf.Tensor` or any other type of tensor that + `tf.data.Dataset.from_tensors` can consume. + Only called on a full array as provided by the user. + + Args: + x: the tensor to convert. + Returns: converted version tensor. + """ + return x + + @classmethod + def convert_to_jax_compatible(cls, x): + """Convert a tensor to something that the JAX backend can consume. + + This can be a `JAX` array, `JAXSparse` or a NumPy array. + Only called after slicing using `__getitem__`. + Used to convert sparse tensors and densify ragged tensors. + + Args: + x: the tensor to convert. + Returns: the converted tensor. + """ + return x + + @classmethod + def convert_to_torch_compatible(cls, x): + """Convert a tensor to something that the Torch backend can consume. + + This can be a Torch tensor, NumPy array or any other type of tensor that + `keras.backend.torch.core.convert_to_tensor()` can consume. + Only called after slicing using `__getitem__`. + Used to densify sparse tensors and ragged tensors. + + Args: + x: the tensor to convert. + Returns: the converted tensor. 
+ """ + return x + + +class NumpySliceable(Sliceable): + pass + + +class TensorflowSliceable(Sliceable): + def __getitem__(self, indices): + from keras.src.utils.module_utils import tensorflow as tf + + if isinstance(indices, slice): + return self.array[indices] + else: + return tf.gather(self.array, indices, axis=0) + + @classmethod + def cast(cls, x, dtype): + from keras.src.backend.tensorflow.core import cast + + return cast(x, dtype) + + @classmethod + def convert_to_numpy(cls, x): + from keras.src.backend.tensorflow.core import convert_to_numpy + + return convert_to_numpy(x) + + +class TensorflowRaggedSliceable(TensorflowSliceable): + @classmethod + def convert_to_jax_compatible(cls, x): + return cls.convert_to_numpy(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + return x.to_tensor() + + +class TensorflowSparseSliceable(TensorflowSliceable): + def __init__(self, array): + super().__init__(to_tensorflow_sparse_wrapper(array)) + + @property + def shape(self): + return self.array.sparse.shape + + def __getitem__(self, indices): + return slice_tensorflow_sparse_wrapper(self.array, indices) + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + return to_tensorflow_sparse_wrapper(x) + + @classmethod + def convert_to_jax_compatible(cls, x): + return data_adapter_utils.tf_sparse_to_jax_sparse(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + from keras.src.backend.tensorflow import sparse as tf_sparse + + return tf_sparse.sparse_to_dense(x) + + +class JaxSparseSliceable(Sliceable): + def __getitem__(self, indices): + return self.array[indices, ...] 
+ + @classmethod + def convert_to_numpy(cls, x): + from keras.src.backend.jax.core import convert_to_numpy + + return convert_to_numpy(x) + + @classmethod + def convert_to_tf_dataset_compatible(cls, array): + return to_tensorflow_sparse_wrapper( + data_adapter_utils.jax_sparse_to_tf_sparse(array) + ) + + @classmethod + def convert_to_torch_compatible(cls, x): + return x.todense() + + +class TorchSliceable(Sliceable): + @classmethod + def cast(cls, x, dtype): + from keras.src.backend.torch.core import cast + + return cast(x, dtype) + + @classmethod + def convert_to_numpy(cls, x): + from keras.src.backend.torch.core import convert_to_numpy + + return convert_to_numpy(x) + + +class PandasSliceable(Sliceable): + def __getitem__(self, indices): + return self.array.iloc[indices] + + @classmethod + def convert_to_numpy(cls, x): + return x.to_numpy() + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + return cls.convert_to_numpy(x) + + @classmethod + def convert_to_jax_compatible(cls, x): + return cls.convert_to_numpy(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + return cls.convert_to_numpy(x) + + +class PandasDataFrameSliceable(PandasSliceable): + pass + + +class PandasSeriesSliceable(PandasSliceable): + @classmethod + def convert_to_numpy(cls, x): + return np.expand_dims(x.to_numpy(), axis=-1) + + +class ScipySparseSliceable(Sliceable): + def __init__(self, array): + # The COO representation is not indexable / sliceable and does not lend + # itself to it. Use the CSR representation instead, which is sliceable. 
+ super().__init__(array.tocsr()) + + @classmethod + def convert_to_numpy(cls, x): + return x.todense() + + @classmethod + def convert_to_tf_dataset_compatible(cls, x): + return to_tensorflow_sparse_wrapper( + data_adapter_utils.scipy_sparse_to_tf_sparse(x) + ) + + @classmethod + def convert_to_jax_compatible(cls, x): + return data_adapter_utils.scipy_sparse_to_jax_sparse(x) + + @classmethod + def convert_to_torch_compatible(cls, x): + return x.todense() + + +# `tf.SparseTensor` does not support indexing or `tf.gather`. The COO +# representation it uses does not lend itself to indexing. We add some +# intermediary tensors to ease the indexing and slicing. We put both indices and +# values in `RaggedTensor`s where each row corresponds to a row in the sparse +# tensor. This is because the number of values per row is not fixed. +# `RaggedTensor`s do support indexing and `tf.gather`, although on CPU only. +# We then reconstruct a `SparseTensor` from extracted rows. In theory, there is +# no duplication of data for the indices and values, only the addition of row +# splits for the ragged representation. +# `TensorflowSparseWrapper` is a named tuple which combines the original +# `SparseTensor` (used for the shape) and the ragged representations of indices +# and values for indexing / slicing. We use a named tuple and not a `Sliceable` +# to be able to ingest it in `tf.data.Dataset.from_tensors()` and map it. 
+ +TensorflowSparseWrapper = collections.namedtuple( + "TensorflowSparseWrapper", ["sparse", "ragged_indices", "ragged_values"] +) + + +def to_tensorflow_sparse_wrapper(sparse): + from keras.src.utils.module_utils import tensorflow as tf + + row_ids = sparse.indices[:, 0] + row_splits = tf.experimental.RowPartition.from_value_rowids( + row_ids + ).row_splits() + + ragged_indices = tf.cast( + tf.RaggedTensor.from_row_splits(sparse.indices, row_splits), tf.int64 + ) + ragged_values = tf.RaggedTensor.from_row_splits(sparse.values, row_splits) + return TensorflowSparseWrapper(sparse, ragged_indices, ragged_values) + + +def slice_tensorflow_sparse_wrapper(sparse_wrapper, indices): + from keras.src.utils.module_utils import tensorflow as tf + + if isinstance(indices, slice): + sparse_indices = sparse_wrapper.ragged_indices[indices] + sparse_values = sparse_wrapper.ragged_values[indices] + batch_dim = indices.stop - indices.start + else: + sparse_indices = tf.gather(sparse_wrapper.ragged_indices, indices) + sparse_values = tf.gather(sparse_wrapper.ragged_values, indices) + if isinstance(indices, list): + batch_dim = len(indices) + else: + batch_dim = indices.shape[0] + if batch_dim is None: + batch_dim = tf.shape(indices)[0] + + row_ids = sparse_indices.value_rowids() + sparse_indices = sparse_indices.flat_values[:, 1:] # remove first value + sparse_indices = tf.concat( + [tf.expand_dims(row_ids, -1), sparse_indices], axis=1 + ) + + sparse_values = sparse_values.flat_values + sparse_shape = (batch_dim,) + tuple( + sparse_wrapper.sparse.shape.as_list()[1:] + ) + return tf.SparseTensor(sparse_indices, sparse_values, sparse_shape) + + +def can_slice_array(x): + return ( + x is None + or isinstance(x, ARRAY_TYPES) + or data_adapter_utils.is_tensorflow_tensor(x) + or data_adapter_utils.is_jax_array(x) + or data_adapter_utils.is_torch_tensor(x) + or data_adapter_utils.is_scipy_sparse(x) + or hasattr(x, "__array__") + ) + + +def convert_to_sliceable(arrays, target_backend=None): 
+ """Convert a structure of arrays into `Sliceable` instances + + Args: + arrays: the arrays to convert. + target_backend: the target backend for the output: + - `None` indicates that `arrays` will be wrapped into `Sliceable`s + as-is without using a different representation. This is used by + `train_validation_split()`. + - `tensorflow` indicates that + `Sliceable.convert_to_tf_dataset_compatible` will be called. The + returned structure therefore contains arrays, not `Sliceable`s. + - `numpy`, `jax` or `torch` indices that the arrays will eventually + be converted to this backend type after slicing. In this case, + the intermediary `Sliceable`s may use a different representation + from the input `arrays` for better performance. + Returns: the same structure with `Sliceable` instances or arrays. + """ + + def convert_single_array(x): + if x is None: + return x + + # Special case: handle np "object" arrays containing strings + if ( + isinstance(x, np.ndarray) + and str(x.dtype) == "object" + and backend.backend() == "tensorflow" + and all(isinstance(e, str) for e in x) + ): + x = tf.convert_to_tensor(x, dtype="string") + + # Step 1. Determine which Sliceable class to use. 
+ if isinstance(x, np.ndarray): + sliceable_class = NumpySliceable + elif data_adapter_utils.is_tensorflow_tensor(x): + if data_adapter_utils.is_tensorflow_ragged(x): + sliceable_class = TensorflowRaggedSliceable + elif data_adapter_utils.is_tensorflow_sparse(x): + sliceable_class = TensorflowSparseSliceable + else: + sliceable_class = TensorflowSliceable + elif data_adapter_utils.is_jax_array(x): + if data_adapter_utils.is_jax_sparse(x): + sliceable_class = JaxSparseSliceable + else: + x = np.asarray(x) + sliceable_class = NumpySliceable + elif data_adapter_utils.is_torch_tensor(x): + sliceable_class = TorchSliceable + elif pandas is not None and isinstance(x, pandas.DataFrame): + sliceable_class = PandasDataFrameSliceable + elif pandas is not None and isinstance(x, pandas.Series): + sliceable_class = PandasSeriesSliceable + elif data_adapter_utils.is_scipy_sparse(x): + sliceable_class = ScipySparseSliceable + elif hasattr(x, "__array__"): + x = np.asarray(x) + sliceable_class = NumpySliceable + else: + raise ValueError( + "Expected a NumPy array, tf.Tensor, tf.RaggedTensor, " + "tf.SparseTensor, jax.np.ndarray, " + "jax.experimental.sparse.JAXSparse, torch.Tensor, " + "Pandas Dataframe, or Pandas Series. Received invalid input: " + f"{x} (of type {type(x)})" + ) + + # Step 2. Normalize floats to floatx. + def is_non_floatx_float(dtype): + return ( + dtype is not object + and backend.is_float_dtype(dtype) + and not backend.standardize_dtype(dtype) == backend.floatx() + ) + + cast_dtype = None + if pandas is not None and isinstance(x, pandas.DataFrame): + if any(is_non_floatx_float(d) for d in x.dtypes.values): + cast_dtype = backend.floatx() + else: + if is_non_floatx_float(x.dtype): + cast_dtype = backend.floatx() + + if cast_dtype is not None: + x = sliceable_class.cast(x, cast_dtype) + + # Step 3. Apply target backend specific logic and optimizations. 
+ if target_backend is None: + return sliceable_class(x) + + if target_backend == "tensorflow": + return sliceable_class.convert_to_tf_dataset_compatible(x) + + # With dense arrays and JAX as output, it is faster to use NumPy as an + # intermediary representation, so wrap input array in a NumPy array, + # which should not use extra memory. + # See https://github.com/google/jax/issues/1276 for an explanation of + # why slicing a NumPy array is faster than slicing a JAX array. + if target_backend == "jax" and sliceable_class in ( + TensorflowSliceable, + TorchSliceable, + ): + x = np.asarray(x) + sliceable_class = NumpySliceable + + return sliceable_class(x) + + return tree.map_structure(convert_single_array, arrays) + + +def train_validation_split(arrays, validation_split): + """Split arrays into train and validation subsets in deterministic order. + + The last part of data will become validation data. + + Args: + arrays: Tensors to split. Allowed inputs are arbitrarily nested + structures of Tensors and NumPy arrays. + validation_split: Float between 0 and 1. The proportion of the dataset + to include in the validation split. The rest of the dataset will be + included in the training split. + + Returns: + `(train_arrays, validation_arrays)` + """ + + flat_arrays = tree.flatten(arrays) + unsplitable = [type(t) for t in flat_arrays if not can_slice_array(t)] + if unsplitable: + raise ValueError( + "Argument `validation_split` is only supported " + "for tensors or NumPy arrays." + f"Found incompatible type in the input: {unsplitable}" + ) + + if all(t is None for t in flat_arrays): + return arrays, arrays + + first_non_none = None + for t in flat_arrays: + if t is not None: + first_non_none = t + break + + # Assumes all arrays have the same batch shape or are `None`. 
+ batch_dim = int(first_non_none.shape[0]) + split_at = int(math.floor(batch_dim * (1.0 - validation_split))) + + if split_at == 0 or split_at == batch_dim: + raise ValueError( + f"Training data contains {batch_dim} samples, which is not " + "sufficient to split it into a validation and training set as " + f"specified by `validation_split={validation_split}`. Either " + "provide more data, or a different value for the " + "`validation_split` argument." + ) + + def _split(t, start, end): + if t is None: + return t + return t[start:end] + + sliceables = convert_to_sliceable(arrays) + train_arrays = tree.map_structure( + lambda x: _split(x, start=0, end=split_at), sliceables + ) + val_arrays = tree.map_structure( + lambda x: _split(x, start=split_at, end=batch_dim), sliceables + ) + return train_arrays, val_arrays diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..b12dd06203bb359b67ce3008a0ce67c99c300897 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter.py @@ -0,0 +1,97 @@ +class DataAdapter: + """Base class for input data adapters. + + The purpose of a DataAdapter is to provide a unified interface to + iterate over input data provided in a variety of formats -- such as + NumPy arrays, tf.Tensors, tf.data.Datasets, Keras PyDatasets, etc. + """ + + def get_numpy_iterator(self): + """Get a Python iterable for the `DataAdapter`, that yields NumPy + arrays. + + Returns: + A Python iterator. + """ + raise NotImplementedError + + def get_tf_dataset(self): + """Get a `tf.data.Dataset` instance for the DataAdapter. 
+ + Note that the dataset returned does not repeat for epoch, so caller + might need to create new iterator for the same dataset at the beginning + of the epoch. This behavior might change in the future. + + Returns: + A `tf.data.Dataset`. Caller might use the dataset in different + context, e.g. iter(dataset) in eager to get the value directly, or + in graph mode, provide the iterator tensor to Keras model function. + """ + raise NotImplementedError + + def get_jax_iterator(self): + """Get a Python iterable for the `DataAdapter`, that yields arrays that + that can be fed to JAX. NumPy arrays are preferred for performance. + + Returns: + A Python iterator. + """ + raise NotImplementedError + + def get_torch_dataloader(self): + """Get a Torch `DataLoader` for the `DataAdapter`. + + Returns: + A Torch `DataLoader`. + """ + raise NotImplementedError + + @property + def num_batches(self): + """Return the size (number of batches) for the dataset created. + + For certain type of the data input, the number of batches is known, eg + for Numpy data, the size is same as (number_of_element / batch_size). + Whereas for dataset or python generator, the size is unknown since it + may or may not have an end state. + + Returns: + int, the number of batches for the dataset, or None if it is + unknown. The caller could use this to control the loop of training, + show progress bar, or handle unexpected StopIteration error. + """ + raise NotImplementedError + + @property + def batch_size(self): + """Return the batch size of the dataset created. + + For certain type of the data input, the batch size is known, and even + required, like numpy array. Whereas for dataset, the batch is unknown + unless we take a peek. + + Returns: + int, the batch size of the dataset, or None if it is unknown. 
+ """ + raise NotImplementedError + + @property + def has_partial_batch(self): + """Whether the dataset has partial batch at the end.""" + raise NotImplementedError + + @property + def partial_batch_size(self): + """The size of the final partial batch for dataset. + + Will return None if has_partial_batch is False or batch_size is None. + """ + raise NotImplementedError + + def on_epoch_begin(self): + """A hook called before each epoch.""" + pass + + def on_epoch_end(self): + """A hook called after each epoch.""" + pass diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..35a1327d34531d5be2f7fc7a454773a57eb434d7 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/data_adapter_utils.py @@ -0,0 +1,329 @@ +import numpy as np + +from keras.src import backend +from keras.src import ops +from keras.src import tree +from keras.src.api_export import keras_export + +NUM_BATCHES_FOR_TENSOR_SPEC = 2 + + +@keras_export("keras.utils.unpack_x_y_sample_weight") +def unpack_x_y_sample_weight(data): + """Unpacks user-provided data tuple. + + This is a convenience utility to be used when overriding + `Model.train_step`, `Model.test_step`, or `Model.predict_step`. + This utility makes it easy to support data of the form `(x,)`, + `(x, y)`, or `(x, y, sample_weight)`. + + Example: + + >>> features_batch = ops.ones((10, 5)) + >>> labels_batch = ops.zeros((10, 5)) + >>> data = (features_batch, labels_batch) + >>> # `y` and `sample_weight` will default to `None` if not provided. + >>> x, y, sample_weight = unpack_x_y_sample_weight(data) + >>> sample_weight is None + True + + Args: + data: A tuple of the form `(x,)`, `(x, y)`, or `(x, y, sample_weight)`. 
+ + Returns: + The unpacked tuple, with `None`s for `y` and `sample_weight` if they are + not provided. + """ + if isinstance(data, list): + data = tuple(data) + if not isinstance(data, tuple): + return (data, None, None) + elif len(data) == 1: + return (data[0], None, None) + elif len(data) == 2: + return (data[0], data[1], None) + elif len(data) == 3: + return (data[0], data[1], data[2]) + error_msg = ( + "Data is expected to be in format `x`, `(x,)`, `(x, y)`, " + f"or `(x, y, sample_weight)`, found: {data}" + ) + raise ValueError(error_msg) + + +@keras_export("keras.utils.pack_x_y_sample_weight") +def pack_x_y_sample_weight(x, y=None, sample_weight=None): + """Packs user-provided data into a tuple. + + This is a convenience utility for packing data into the tuple formats + that `Model.fit()` uses. + + Example: + + >>> x = ops.ones((10, 1)) + >>> data = pack_x_y_sample_weight(x) + >>> isinstance(data, ops.Tensor) + True + >>> y = ops.ones((10, 1)) + >>> data = pack_x_y_sample_weight(x, y) + >>> isinstance(data, tuple) + True + >>> x, y = data + + Args: + x: Features to pass to `Model`. + y: Ground-truth targets to pass to `Model`. + sample_weight: Sample weight for each element. + + Returns: + Tuple in the format used in `Model.fit()`. + """ + if y is None: + # For single x-input, we do no tuple wrapping since in this case + # there is no ambiguity. This also makes NumPy and Dataset + # consistent in that the user does not have to wrap their Dataset + # data in an unnecessary tuple. 
+ if not isinstance(x, (tuple, list)): + return x + else: + return (x,) + elif sample_weight is None: + return (x, y) + else: + return (x, y, sample_weight) + + +def list_to_tuple(maybe_list): + """Datasets will stack any list of tensors, so we convert them to tuples.""" + if isinstance(maybe_list, list): + return tuple(maybe_list) + return maybe_list + + +def check_data_cardinality(data): + num_samples = set(int(i.shape[0]) for i in tree.flatten(data)) + if len(num_samples) > 1: + msg = ( + "Data cardinality is ambiguous. " + "Make sure all arrays contain the same number of samples." + ) + for label, single_data in zip(["x", "y", "sample_weight"], data): + sizes = ", ".join( + str(i.shape[0]) for i in tree.flatten(single_data) + ) + msg += f"'{label}' sizes: {sizes}\n" + raise ValueError(msg) + + +def class_weight_to_sample_weights(y, class_weight): + # Convert to numpy to ensure consistent handling of operations + # (e.g., np.round()) across frameworks like TensorFlow, JAX, and PyTorch + + y_numpy = ops.convert_to_numpy(y) + sample_weight = np.ones(shape=(y_numpy.shape[0],), dtype=backend.floatx()) + if len(y_numpy.shape) > 1: + if y_numpy.shape[-1] != 1: + y_numpy = np.argmax(y_numpy, axis=-1) + else: + y_numpy = np.squeeze(y_numpy, axis=-1) + y_numpy = np.round(y_numpy).astype("int32") + + for i in range(y_numpy.shape[0]): + sample_weight[i] = class_weight.get(int(y_numpy[i]), 1.0) + return sample_weight + + +def get_tensor_spec(batches): + """Return the common tensor spec for a list of batches. + + Args: + batches: list of structures of tensors. The structures must be + identical, but the shape at each leaf may be different. + Returns: the common tensor spec for all the batches. + """ + from keras.src.utils.module_utils import tensorflow as tf + + def get_single_tensor_spec(*tensors): + x = tensors[0] + rank = len(x.shape) + if rank < 1: + raise ValueError( + "When passing a dataset to a Keras model, the arrays must " + f"be at least rank 1. 
Received: {x} of rank {len(x.shape)}." + ) + for t in tensors: + if len(t.shape) != rank: + raise ValueError( + "When passing a dataset to a Keras model, the " + "corresponding arrays in each batch must have the same " + f"rank. Received: {x} and {t}" + ) + shape = [] + # Merge shapes: go through each dimension one by one and keep the + # common values + for dims in zip(*[list(x.shape) for x in tensors]): + dims_set = set(dims) + shape.append(dims_set.pop() if len(dims_set) == 1 else None) + shape[0] = None # batch size may not be static + + dtype = backend.standardize_dtype(x.dtype) + if isinstance(x, tf.RaggedTensor): + return tf.RaggedTensorSpec(shape=shape, dtype=dtype) + if ( + isinstance(x, tf.SparseTensor) + or is_scipy_sparse(x) + or is_jax_sparse(x) + ): + return tf.SparseTensorSpec(shape=shape, dtype=dtype) + else: + return tf.TensorSpec(shape=shape, dtype=dtype) + + return tree.map_structure(get_single_tensor_spec, *batches) + + +def get_jax_iterator(iterable): + import jax + import jax.experimental.sparse as jax_sparse + + def convert_to_jax_compatible(x): + if isinstance(x, (jax.Array, jax_sparse.JAXSparse, np.ndarray)): + return x + elif is_scipy_sparse(x): + return scipy_sparse_to_jax_sparse(x) + elif is_tensorflow_sparse(x): + return tf_sparse_to_jax_sparse(x) + else: + return np.asarray(x) + + for batch in iterable: + yield tree.map_structure(convert_to_jax_compatible, batch) + + +def get_numpy_iterator(iterable): + def convert_to_numpy(x): + if not isinstance(x, np.ndarray): + # Using `__array__` should handle `tf.Tensor`, `jax.np.ndarray`, + # `torch.Tensor`, as well as any other tensor-like object that + # has added numpy support. 
+ if hasattr(x, "__array__"): + if is_torch_tensor(x): + x = x.cpu() + x = np.asarray(x) + return x + + for batch in iterable: + yield tree.map_structure(convert_to_numpy, batch) + + +def get_torch_dataloader(iterable): + import torch.utils.data as torch_data + + from keras.src.backend.torch.core import convert_to_tensor + + class ConverterIterableDataset(torch_data.IterableDataset): + def __init__(self, iterable): + self.iterable = iterable + + def __iter__(self): + for batch in self.iterable: + yield tree.map_structure(convert_to_tensor, batch) + + dataset = ConverterIterableDataset(iterable) + # `batch_size=None` indicates that we should not re-batch + return torch_data.DataLoader(dataset, batch_size=None) + + +def is_tensorflow_tensor(value): + if hasattr(value, "__class__"): + if value.__class__.__name__ in ("RaggedTensor", "SparseTensor"): + return "tensorflow.python." in str(value.__class__.__module__) + for parent in value.__class__.__mro__: + if parent.__name__ in ("Tensor") and "tensorflow.python." in str( + parent.__module__ + ): + return True + return False + + +def is_tensorflow_ragged(value): + if hasattr(value, "__class__"): + return ( + value.__class__.__name__ == "RaggedTensor" + and "tensorflow.python." in str(value.__class__.__module__) + ) + return False + + +def is_tensorflow_sparse(value): + if hasattr(value, "__class__"): + return ( + value.__class__.__name__ == "SparseTensor" + and "tensorflow.python." 
in str(value.__class__.__module__) + ) + return False + + +def is_jax_array(value): + if hasattr(value, "__class__"): + for parent in value.__class__.__mro__: + if parent.__name__ == "Array" and str(parent.__module__) == "jax": + return True + return is_jax_sparse(value) # JAX sparse arrays do not extend jax.Array + + +def is_jax_sparse(value): + if hasattr(value, "__class__"): + return str(value.__class__.__module__).startswith( + "jax.experimental.sparse" + ) + return False + + +def is_torch_tensor(value): + if hasattr(value, "__class__"): + for parent in value.__class__.__mro__: + if parent.__name__ == "Tensor" and str(parent.__module__).endswith( + "torch" + ): + return True + return False + + +def is_scipy_sparse(x): + return str(x.__class__.__module__).startswith("scipy.sparse") and hasattr( + x, "tocoo" + ) + + +def scipy_sparse_to_tf_sparse(x): + from keras.src.utils.module_utils import tensorflow as tf + + coo = x.tocoo() + indices = np.concatenate( + (np.expand_dims(coo.row, 1), np.expand_dims(coo.col, 1)), axis=1 + ) + return tf.SparseTensor(indices, coo.data, coo.shape) + + +def scipy_sparse_to_jax_sparse(x): + import jax + import jax.experimental.sparse as jax_sparse + + with jax.default_device(jax.local_devices(backend="cpu")[0]): + return jax_sparse.BCOO.from_scipy_sparse(x) + + +def tf_sparse_to_jax_sparse(x): + import jax + import jax.experimental.sparse as jax_sparse + + values = np.asarray(x.values) + indices = np.asarray(x.indices) + with jax.default_device(jax.local_devices(backend="cpu")[0]): + return jax_sparse.BCOO((values, indices), shape=x.shape) + + +def jax_sparse_to_tf_sparse(x): + from keras.src.utils.module_utils import tensorflow as tf + + return tf.SparseTensor(x.indices, x.data, x.shape) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..50603e99c7d6c3ec9dcdf8ceb17988b2493aedbf --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/generator_data_adapter.py @@ -0,0 +1,87 @@ +import itertools + +from keras.src import tree +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +class GeneratorDataAdapter(DataAdapter): + """Adapter for Python generators.""" + + def __init__(self, generator): + first_batches, generator = peek_and_restore(generator) + self.generator = generator + self._first_batches = first_batches + self._output_signature = None + if not isinstance(first_batches[0], tuple): + raise ValueError( + "When passing a Python generator to a Keras model, " + "the generator must return a tuple, either " + "(input,) or (inputs, targets) or " + "(inputs, targets, sample_weights). " + f"Received: {first_batches[0]}" + ) + + def get_numpy_iterator(self): + return data_adapter_utils.get_numpy_iterator(self.generator()) + + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self.generator()) + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + def convert_to_tf(x, spec): + if data_adapter_utils.is_scipy_sparse(x): + x = data_adapter_utils.scipy_sparse_to_tf_sparse(x) + elif data_adapter_utils.is_jax_sparse(x): + x = data_adapter_utils.jax_sparse_to_tf_sparse(x) + if not spec.shape.is_compatible_with(x.shape): + raise TypeError( + f"Generator yielded an element of shape {x.shape} where " + f"an element of shape {spec.shape} was expected. Your " + "generator provides tensors with variable input " + "dimensions other than the batch size. 
Make sure that the " + "generator's first two batches do not have the same " + "dimension value wherever there is a variable input " + "dimension." + ) + return x + + def get_tf_iterator(): + for batch in self.generator(): + batch = tree.map_structure( + convert_to_tf, batch, self._output_signature + ) + yield batch + + if self._output_signature is None: + self._output_signature = data_adapter_utils.get_tensor_spec( + self._first_batches + ) + ds = tf.data.Dataset.from_generator( + get_tf_iterator, + output_signature=self._output_signature, + ) + ds = ds.prefetch(tf.data.AUTOTUNE) + return ds + + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self.generator()) + + @property + def num_batches(self): + return None + + @property + def batch_size(self): + return None + + +def peek_and_restore(generator): + batches = list( + itertools.islice( + generator, data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC + ) + ) + return batches, lambda: itertools.chain(batches, generator) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..06e61122493889af51f77f6a3e096b0552b28d5d --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py @@ -0,0 +1,692 @@ +import itertools +import multiprocessing.dummy +import queue +import random +import threading +import warnings +import weakref +from contextlib import closing + +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.trainers.data_adapters.data_adapter import DataAdapter + + +@keras_export(["keras.utils.PyDataset", "keras.utils.Sequence"]) +class PyDataset: + """Base class for 
defining a parallel dataset using Python code. + + Every `PyDataset` must implement the `__getitem__()` and the `__len__()` + methods. If you want to modify your dataset between epochs, + you may additionally implement `on_epoch_end()`, + or `on_epoch_begin` to be called at the start of each epoch. + The `__getitem__()` method should return a complete batch + (not a single sample), and the `__len__` method should return + the number of batches in the dataset (rather than the number of samples). + + Args: + workers: Number of workers to use in multithreading or + multiprocessing. + use_multiprocessing: Whether to use Python multiprocessing for + parallelism. Setting this to `True` means that your + dataset will be replicated in multiple forked processes. + This is necessary to gain compute-level (rather than I/O level) + benefits from parallelism. However it can only be set to + `True` if your dataset can be safely pickled. + max_queue_size: Maximum number of batches to keep in the queue + when iterating over the dataset in a multithreaded or + multiprocessed setting. + Reduce this value to reduce the CPU memory consumption of + your dataset. Defaults to 10. + + Notes: + + - `PyDataset` is a safer way to do multiprocessing. + This structure guarantees that the model will only train + once on each sample per epoch, which is not the case + with Python generators. + - The arguments `workers`, `use_multiprocessing`, and `max_queue_size` + exist to configure how `fit()` uses parallelism to iterate + over the dataset. They are not being used by the `PyDataset` class + directly. When you are manually iterating over a `PyDataset`, + no parallelism is applied. + + Example: + + ```python + from skimage.io import imread + from skimage.transform import resize + import numpy as np + import math + + # Here, `x_set` is list of path to the images + # and `y_set` are the associated classes. 
+ + class CIFAR10PyDataset(keras.utils.PyDataset): + + def __init__(self, x_set, y_set, batch_size, **kwargs): + super().__init__(**kwargs) + self.x, self.y = x_set, y_set + self.batch_size = batch_size + + def __len__(self): + # Return number of batches. + return math.ceil(len(self.x) / self.batch_size) + + def __getitem__(self, idx): + # Return x, y for batch idx. + low = idx * self.batch_size + # Cap upper bound at array length; the last batch may be smaller + # if the total number of items is not a multiple of batch size. + high = min(low + self.batch_size, len(self.x)) + batch_x = self.x[low:high] + batch_y = self.y[low:high] + + return np.array([ + resize(imread(file_name), (200, 200)) + for file_name in batch_x]), np.array(batch_y) + ``` + """ + + def __init__(self, workers=1, use_multiprocessing=False, max_queue_size=10): + self._workers = workers + self._use_multiprocessing = use_multiprocessing + self._max_queue_size = max_queue_size + + def _warn_if_super_not_called(self): + warn = False + if not hasattr(self, "_workers"): + self._workers = 1 + warn = True + if not hasattr(self, "_use_multiprocessing"): + self._use_multiprocessing = False + warn = True + if not hasattr(self, "_max_queue_size"): + self._max_queue_size = 10 + warn = True + if warn: + warnings.warn( + "Your `PyDataset` class should call " + "`super().__init__(**kwargs)` in its constructor. " + "`**kwargs` can include `workers`, " + "`use_multiprocessing`, `max_queue_size`. 
Do not pass " + "these arguments to `fit()`, as they will be ignored.", + stacklevel=2, + ) + + @property + def workers(self): + self._warn_if_super_not_called() + return self._workers + + @workers.setter + def workers(self, value): + self._workers = value + + @property + def use_multiprocessing(self): + self._warn_if_super_not_called() + return self._use_multiprocessing + + @use_multiprocessing.setter + def use_multiprocessing(self, value): + self._use_multiprocessing = value + + @property + def max_queue_size(self): + self._warn_if_super_not_called() + return self._max_queue_size + + @max_queue_size.setter + def max_queue_size(self, value): + self._max_queue_size = value + + def __getitem__(self, index): + """Gets batch at position `index`. + + Args: + index: position of the batch in the PyDataset. + + Returns: + A batch + """ + raise NotImplementedError + + @property + def num_batches(self): + """Number of batches in the PyDataset. + + Returns: + The number of batches in the PyDataset or `None` to indicate that + the dataset is infinite. + """ + # For backwards compatibility, support `__len__`. + if hasattr(self, "__len__"): + return len(self) + raise NotImplementedError( + "You need to implement the `num_batches` property:\n\n" + "@property\ndef num_batches(self):\n return ..." 
+ ) + + def on_epoch_begin(self): + """Method called at the beginning of every epoch.""" + pass + + def on_epoch_end(self): + """Method called at the end of every epoch.""" + pass + + +class PyDatasetAdapter(DataAdapter): + """Adapter for `keras.utils.PyDataset` instances.""" + + def __init__( + self, + x, + class_weight=None, + shuffle=False, + ): + self.py_dataset = x + self.class_weight = class_weight + self.enqueuer = None + self.shuffle = shuffle + self._output_signature = None + self._within_epoch = False + + workers = self.py_dataset.workers + use_multiprocessing = self.py_dataset.use_multiprocessing + if workers > 1 or (workers > 0 and use_multiprocessing): + self.enqueuer = OrderedEnqueuer( + self.py_dataset, + workers=workers, + use_multiprocessing=use_multiprocessing, + max_queue_size=self.py_dataset.max_queue_size, + shuffle=self.shuffle, + ) + + def _standardize_batch(self, batch): + if isinstance(batch, dict): + return batch + if isinstance(batch, np.ndarray): + batch = (batch,) + if isinstance(batch, list): + batch = tuple(batch) + if not isinstance(batch, tuple) or len(batch) not in {1, 2, 3}: + raise ValueError( + "PyDataset.__getitem__() must return a tuple or a dict. " + "If a tuple, it must be ordered either " + "(input,) or (inputs, targets) or " + "(inputs, targets, sample_weights). " + f"Received: {str(batch)[:100]}... of type {type(batch)}" + ) + if self.class_weight is not None: + if len(batch) == 3: + raise ValueError( + "You cannot specify `class_weight` " + "and `sample_weight` at the same time." 
+ ) + if len(batch) == 2: + sw = data_adapter_utils.class_weight_to_sample_weights( + batch[1], self.class_weight + ) + batch = batch + (sw,) + return batch + + def _infinite_generator(self): + for i in itertools.count(): + yield self._standardize_batch(self.py_dataset[i]) + + def _finite_generator(self): + indices = range(self.py_dataset.num_batches) + if self.shuffle: + indices = list(indices) + random.shuffle(indices) + + for i in indices: + yield self._standardize_batch(self.py_dataset[i]) + + def _infinite_enqueuer_generator(self): + self.enqueuer.start() + for batch in self.enqueuer.get(): + yield self._standardize_batch(batch) + + def _finite_enqueuer_generator(self): + self.enqueuer.start() + num_batches = self.py_dataset.num_batches + for i, batch in enumerate(self.enqueuer.get()): + yield self._standardize_batch(batch) + if i >= num_batches - 1: + self.enqueuer.stop() + return + + def _get_iterator(self): + if self.enqueuer is None: + if self.py_dataset.num_batches is None: + return self._infinite_generator() + else: + return self._finite_generator() + else: + if self.py_dataset.num_batches is None: + return self._infinite_enqueuer_generator() + else: + return self._finite_enqueuer_generator() + + def get_numpy_iterator(self): + return data_adapter_utils.get_numpy_iterator(self._get_iterator()) + + def get_jax_iterator(self): + return data_adapter_utils.get_jax_iterator(self._get_iterator()) + + def get_tf_dataset(self): + from keras.src.utils.module_utils import tensorflow as tf + + num_batches = self.py_dataset.num_batches + if self._output_signature is None: + num_samples = data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC + if num_batches is not None: + num_samples = min(num_samples, num_batches) + batches = [ + self._standardize_batch(self.py_dataset[i]) + for i in range(num_samples) + ] + if len(batches) == 0: + raise ValueError("The PyDataset has length 0") + self._output_signature = data_adapter_utils.get_tensor_spec(batches) + + ds = 
tf.data.Dataset.from_generator( + self._get_iterator, + output_signature=self._output_signature, + ) + if self.enqueuer is not None: + # The enqueuer does its own multithreading / multiprocesssing to + # prefetch items. Disable the tf.data.Dataset prefetching and + # threading as it interferes. + options = tf.data.Options() + options.autotune.enabled = False + options.threading.private_threadpool_size = 1 + ds = ds.with_options(options) + else: + ds = ds.prefetch(tf.data.AUTOTUNE) + return ds + + def get_torch_dataloader(self): + return data_adapter_utils.get_torch_dataloader(self._get_iterator()) + + def on_epoch_begin(self): + if self._within_epoch: + raise ValueError( + "`on_epoch_begin` was called twice without `on_epoch_end` " + "having been called." + ) + self._within_epoch = True + if self.enqueuer: + self.enqueuer.start() + self.py_dataset.on_epoch_begin() + + def on_epoch_end(self): + if self.enqueuer: + self.enqueuer.stop() + self.py_dataset.on_epoch_end() + self._within_epoch = False + + @property + def num_batches(self): + return self.py_dataset.num_batches + + @property + def batch_size(self): + return None + + +# Global variables to be shared across processes +_SHARED_SEQUENCES = {} +# We use a Value to provide unique id to different processes. +_SEQUENCE_COUNTER = None + + +# Because multiprocessing pools are inherently unsafe, starting from a clean +# state can be essential to avoiding deadlocks. In order to accomplish this, we +# need to be able to check on the status of Pools that we create. +_DATA_POOLS = weakref.WeakSet() +_WORKER_ID_QUEUE = None # Only created if needed. 
+_FORCE_THREADPOOL = False + + +def get_pool_class(use_multiprocessing): + global _FORCE_THREADPOOL + if not use_multiprocessing or _FORCE_THREADPOOL: + return multiprocessing.dummy.Pool # ThreadPool + return multiprocessing.Pool + + +def get_worker_id_queue(): + """Lazily create the queue to track worker ids.""" + global _WORKER_ID_QUEUE + if _WORKER_ID_QUEUE is None: + _WORKER_ID_QUEUE = multiprocessing.Queue() + return _WORKER_ID_QUEUE + + +def get_index(uid, i): + """Get the value from the PyDataset `uid` at index `i`. + + To allow multiple PyDatasets to be used at the same time, we use `uid` to + get a specific one. A single PyDataset would cause the validation to + overwrite the training PyDataset. + + This methods is called from worker threads. + + Args: + uid: int, PyDataset identifier + i: index + + Returns: + The value at index `i`. + """ + return _SHARED_SEQUENCES[uid][i] + + +class PyDatasetEnqueuer: + """Base class to enqueue inputs. + + The task of an Enqueuer is to use parallelism to speed up preprocessing. + This is done with processes or threads. + + Example: + + ```python + enqueuer = PyDatasetEnqueuer(...) + enqueuer.start() + datas = enqueuer.get() + for data in datas: + # Use the inputs; training, evaluating, predicting. + # ... stop sometime. + enqueuer.stop() + ``` + + The `enqueuer.get()` should be an infinite stream of data. + """ + + def __init__( + self, + py_dataset, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + ): + self.py_dataset = py_dataset + + global _SEQUENCE_COUNTER + if _SEQUENCE_COUNTER is None: + try: + _SEQUENCE_COUNTER = multiprocessing.Value("i", 0) + except OSError: + # In this case the OS does not allow us to use + # multiprocessing. We resort to an int + # for enqueuer indexing. + _SEQUENCE_COUNTER = 0 + + if isinstance(_SEQUENCE_COUNTER, int): + self.uid = _SEQUENCE_COUNTER + _SEQUENCE_COUNTER += 1 + else: + # Doing Multiprocessing.Value += x is not process-safe. 
+ with _SEQUENCE_COUNTER.get_lock(): + self.uid = _SEQUENCE_COUNTER.value + _SEQUENCE_COUNTER.value += 1 + + self.ready_queue = queue.Queue() + self.future_queue = queue.Queue(max_queue_size) + self.running = False + self.start_stop_lock = threading.Lock() + self.run_thread = None + if use_multiprocessing: + self.executor_fn = self._get_executor_init(workers) + else: + # We do not need the init since it's threads. + self.executor_fn = lambda _: get_pool_class(False)(workers) + + def is_running(self): + """Whether the enqueuer is running. + + This method is thread safe and called from many threads. + + Returns: boolean indicating whether this enqueuer is running. + """ + return self.running + + def start(self): + """Starts the handler's workers. + + This method is thread safe but is called from the main thread. + It is safe to call this method multiple times, extra calls are ignored. + """ + with self.start_stop_lock: + if self.running: + return + self.running = True + self.run_thread = threading.Thread(target=self._run) + self.run_thread.name = f"Worker_{self.uid}" + self.run_thread.daemon = True + self.run_thread.start() + + def stop(self, drain_queue_and_join=True): + """Stops running threads and wait for them to exit, if necessary. + + This method is thread safe and is called from various threads. Note that + the `drain_queue_and_join` argument must be set correctly. + It is safe to call this method multiple times, extra calls are ignored. + + Args: + drain_queue_and_join: set to True to drain the queue of pending + items and wait for the worker thread to complete. Set to False + if invoked from a worker thread to avoid deadlocks. Note that + setting this to False means this enqueuer won't be reused. + """ + with self.start_stop_lock: + if not self.running: + return + self.running = False + + if drain_queue_and_join: + # Drain the `future_queue` and put items in `ready_queue` for + # the next run. 
+ while True: + try: + value = self.future_queue.get(block=True, timeout=0.1) + if isinstance(value, Exception): + raise value # Propagate exception from other thread + inputs = value.get() + self.future_queue.task_done() + if inputs is not None: + self.ready_queue.put(inputs) + except queue.Empty: + break + self.run_thread.join() + + self.run_thread = None + _SHARED_SEQUENCES[self.uid] = None + + def _send_py_dataset(self): + """Sends current Iterable to all workers.""" + # For new processes that may spawn + _SHARED_SEQUENCES[self.uid] = self.py_dataset + + def __del__(self): + self.stop(drain_queue_and_join=False) + + def _run(self): + """Submits request to the executor and queue the `Future` objects.""" + raise NotImplementedError + + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. + + Args: + workers: Number of workers. + + Returns: + Function, a Function to initialize the pool + """ + raise NotImplementedError + + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + + This method is called from the main thread. + + Yields: + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + """ + raise NotImplementedError + + +class OrderedEnqueuer(PyDatasetEnqueuer): + """Builds a Enqueuer from a PyDataset. + + Args: + py_dataset: A `keras.utils.PyDataset` object. + use_multiprocessing: use multiprocessing if True, otherwise threading + shuffle: whether to shuffle the data at the beginning of each epoch + """ + + def __init__( + self, + py_dataset, + workers=1, + use_multiprocessing=False, + max_queue_size=10, + shuffle=False, + ): + super().__init__( + py_dataset, workers, use_multiprocessing, max_queue_size + ) + self.shuffle = shuffle + if self.py_dataset.num_batches is None: + # For infinite datasets, `self.indices` is created here once for all + # so that subsequent runs resume from where they stopped. 
+ self.indices = itertools.count() + + def _get_executor_init(self, workers): + """Gets the Pool initializer for multiprocessing. + + Args: + workers: Number of workers. + + Returns: + Function, a Function to initialize the pool + """ + + def pool_fn(seqs): + pool = get_pool_class(True)( + workers, + initializer=init_pool_generator, + initargs=(seqs, None, get_worker_id_queue()), + ) + _DATA_POOLS.add(pool) + return pool + + return pool_fn + + def _run(self): + """Submits request to the executor and queue the `Future` objects. + + This method is the run method of worker threads. + """ + try: + if self.py_dataset.num_batches is not None: + # For finite datasets, `self.indices` is created here so that + # shuffling creates different a order each time. + indices = range(self.py_dataset.num_batches) + if self.shuffle: + indices = list(indices) + random.shuffle(indices) + self.indices = iter(indices) + self._send_py_dataset() # Share the initial py_dataset + + with closing(self.executor_fn(_SHARED_SEQUENCES)) as executor: + while self.is_running(): + try: + i = next(self.indices) + self.future_queue.put( + executor.apply_async(get_index, (self.uid, i)), + block=True, + ) + except StopIteration: + break + except Exception as e: + self.future_queue.put(e) # Report exception + + def get(self): + """Creates a generator to extract data from the queue. + + Skip the data if it is `None`. + + This method is called from the main thread. + + Yields: + The next element in the queue, i.e. a tuple + `(inputs, targets)` or + `(inputs, targets, sample_weights)`. 
+ """ + while self.is_running(): + try: + inputs = self.ready_queue.get(block=False) + yield inputs + continue # Retry the ready_queue + except queue.Empty: + pass + + try: + value = self.future_queue.get(block=True, timeout=5) + self.future_queue.task_done() + if isinstance(value, Exception): + raise value # Propagate exception from other thread + inputs = value.get() + if inputs is not None: + yield inputs + except queue.Empty: + pass + except Exception as e: + self.stop(drain_queue_and_join=True) + raise e + + # Note that it is ok to poll the iterator after the initial `start`, + # which may happen before the first `on_epoch_begin`. But it's not ok to + # poll after `on_epoch_end`. + raise ValueError( + "Iterator called after `on_epoch_end` or before `on_epoch_begin`." + ) + + +def init_pool_generator(gens, random_seed=None, id_queue=None): + """Initializer function for pool workers. + + Args: + gens: State which should be made available to worker processes. + random_seed: An optional value with which to seed child processes. + id_queue: A multiprocessing Queue of worker ids. + This is used to indicate that a worker process + was created by Keras. + """ + global _SHARED_SEQUENCES + _SHARED_SEQUENCES = gens + + worker_proc = multiprocessing.current_process() + + # name isn't used for anything, but setting a more descriptive name is + # helpful when diagnosing orphaned processes. + worker_proc.name = f"Keras_worker_{worker_proc.name}" + + if random_seed is not None: + np.random.seed(random_seed + worker_proc.ident) + + if id_queue is not None: + # If a worker dies during init, the pool will just create a replacement. 
class TFDatasetAdapter(DataAdapter):
    """Adapter that handles `tf.data.Dataset`."""

    def __init__(self, dataset, class_weight=None, distribution=None):
        """Initialize the TFDatasetAdapter.

        Args:
            dataset: The input `tf.data.Dataset` instance.
            class_weight: A map where the keys are integer class ids and
                values are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`.
            distribution: A `keras.distribution.Distribution` instance. Used
                to shard the input dataset into per worker/process dataset
                instance.
        """
        from keras.src.utils.module_utils import tensorflow as tf

        if not isinstance(
            dataset, (tf.data.Dataset, tf.distribute.DistributedDataset)
        ):
            raise ValueError(
                "Expected argument `dataset` to be a tf.data.Dataset. "
                f"Received: {dataset}"
            )
        if class_weight is not None:
            # Rewrite each batch to append a sample_weight derived from
            # `class_weight`; prefetch so the mapping overlaps iteration.
            dataset = dataset.map(
                make_class_weight_map_fn(class_weight)
            ).prefetch(tf.data.AUTOTUNE)
        if distribution is not None:
            dataset = distribution.distribute_dataset(dataset)
        self._dataset = dataset

    def get_numpy_iterator(self):
        from keras.src.backend.tensorflow.core import convert_to_numpy

        for batch in self._dataset:
            yield tree.map_structure(convert_to_numpy, batch)

    def get_jax_iterator(self):
        from keras.src.backend.tensorflow.core import convert_to_numpy
        from keras.src.utils.module_utils import tensorflow as tf

        def convert_to_jax(x):
            if isinstance(x, tf.SparseTensor):
                return data_adapter_utils.tf_sparse_to_jax_sparse(x)
            else:
                # We use numpy as an intermediary because it is faster.
                return convert_to_numpy(x)

        for batch in self._dataset:
            yield tree.map_structure(convert_to_jax, batch)

    def get_tf_dataset(self):
        return self._dataset

    def get_torch_dataloader(self):
        return data_adapter_utils.get_torch_dataloader(self._dataset)

    @property
    def num_batches(self):
        cardinality = self._dataset.cardinality
        if callable(cardinality):
            # `dataset.cardinality` is normally expected to be a callable.
            cardinality = int(self._dataset.cardinality())
        else:
            # However, in the case of `DistributedDataset`, it's a np.int64.
            cardinality = int(cardinality)
        # Return None for Unknown and Infinite cardinality datasets
        if cardinality < 0:
            return None
        return cardinality

    @property
    def batch_size(self):
        first_element_spec = tree.flatten(self._dataset.element_spec)[0]
        # NOTE(review): may be None when the batch dimension is not
        # statically known in the element spec.
        return first_element_spec.shape[0]

    @property
    def has_partial_batch(self):
        # Unknown for `tf.data` pipelines.
        return None

    @property
    def partial_batch_size(self):
        # Unknown for `tf.data` pipelines.
        return None


def make_class_weight_map_fn(class_weight):
    """Applies class weighting to a `Dataset`.

    The `Dataset` is assumed to be in format `(x, y)` or `(x, y, sw)`, where
    `y` must be a single `Tensor`.

    Args:
        class_weight: A map where the keys are integer class ids and values
            are the class weights, e.g. `{0: 0.2, 1: 0.6, 2: 0.3}`

    Returns:
        A function that can be used with `tf.data.Dataset.map` to apply class
        weighting.
    """
    from keras.src.utils.module_utils import tensorflow as tf

    # Dense lookup table indexed by class id; ids absent from `class_weight`
    # default to a weight of 1.0.
    class_weight_tensor = tf.convert_to_tensor(
        [
            class_weight.get(int(c), 1.0)
            for c in range(max(class_weight.keys()) + 1)
        ]
    )

    def class_weights_map_fn(*data):
        """Convert `class_weight` to `sample_weight`."""
        x, y, sw = data_adapter_utils.unpack_x_y_sample_weight(data)
        if sw is not None:
            # Fixed wording: the original message was missing the verb
            # ("You cannot `class_weight` and `sample_weight` ...").
            raise ValueError(
                "You cannot use `class_weight` and `sample_weight` "
                "at the same time."
            )
        if tree.is_nested(y):
            raise ValueError(
                "`class_weight` is only supported for Models with a single "
                "output."
            )

        if y.shape.rank >= 2:
            # One-hot targets (last dim > 1) -> argmax; otherwise squeeze
            # the trailing length-1 dim and round to get sparse labels.
            y_classes = tf.__internal__.smart_cond.smart_cond(
                tf.shape(y)[-1] > 1,
                lambda: tf.argmax(y, axis=-1, output_type=tf.int32),
                lambda: tf.cast(tf.round(tf.squeeze(y, axis=-1)), tf.int32),
            )
        else:
            # Special casing for rank 1, where we can guarantee sparse
            # encoding.
            y_classes = tf.cast(tf.round(y), tf.int32)

        cw = tf.gather(class_weight_tensor, y_classes)
        return x, y, cw

    return class_weights_map_fn
class TorchDataLoaderAdapter(DataAdapter):
    """Adapter that handles `torch.utils.data.DataLoader`."""

    def __init__(self, dataloader):
        import torch

        if not isinstance(dataloader, torch.utils.data.DataLoader):
            # Fixed message: the original concatenated two f-strings without
            # a separating space ("...instance of`torch...`").
            raise ValueError(
                "Expected argument `dataloader` to be an instance of "
                f"`torch.utils.data.DataLoader`. Received: {dataloader}"
            )

        self._dataloader = dataloader
        self._output_signature = None
        self._batch_size = dataloader.batch_size
        self._num_batches = None
        self._partial_batch_size = None
        # Only sized (map-style) datasets expose a length; iterable-style
        # datasets leave num_batches / partial_batch_size unknown (None).
        if hasattr(dataloader.dataset, "__len__"):
            self._num_batches = len(dataloader)
            if self._batch_size is not None:
                self._partial_batch_size = (
                    len(dataloader.dataset) % self._batch_size
                )

    def get_numpy_iterator(self):
        for batch in self._dataloader:
            # shared memory using `np.asarray`
            yield tuple(
                tree.map_structure(lambda x: np.asarray(x.cpu()), batch)
            )

    def get_jax_iterator(self):
        # We use numpy as an intermediary because it is faster.
        return self.get_numpy_iterator()

    def get_tf_dataset(self):
        from keras.src.utils.module_utils import tensorflow as tf

        if self._output_signature is None:
            # Peek at the first few batches to infer a tensor spec.
            batches = list(
                itertools.islice(
                    self._dataloader,
                    data_adapter_utils.NUM_BATCHES_FOR_TENSOR_SPEC,
                )
            )
            self._output_signature = tuple(
                data_adapter_utils.get_tensor_spec(batches)
            )
        return tf.data.Dataset.from_generator(
            self.get_numpy_iterator,
            output_signature=self._output_signature,
        )

    def get_torch_dataloader(self):
        return self._dataloader

    @property
    def num_batches(self):
        return self._num_batches

    @property
    def batch_size(self):
        return self._batch_size

    @property
    def has_partial_batch(self):
        # NOTE(review): returns None (not False) when the partial batch size
        # is 0 or unknown — preserved as callers may rely on the tri-state.
        if self._partial_batch_size:
            return self._partial_batch_size > 0
        else:
            return None

    @property
    def partial_batch_size(self):
        return self._partial_batch_size
class EpochIterator:
    """Drives one epoch of iteration over a `DataAdapter`.

    Yields `(step, batch_buffer)` pairs where `batch_buffer` holds up to
    `steps_per_execution` batches. Tracks how many steps were actually
    produced so an exhausted generator can be detected and reported.
    """

    def __init__(
        self,
        x,
        y=None,
        sample_weight=None,
        batch_size=None,
        steps_per_epoch=None,
        shuffle=False,
        class_weight=None,
        steps_per_execution=1,
    ):
        self.steps_per_epoch = steps_per_epoch
        self.steps_per_execution = steps_per_execution
        self._current_iterator = None
        self._epoch_iterator = None
        self._steps_seen = 0
        self.data_adapter = data_adapters.get_data_adapter(
            x=x,
            y=y,
            sample_weight=sample_weight,
            batch_size=batch_size,
            steps_per_epoch=steps_per_epoch,
            shuffle=shuffle,
            class_weight=class_weight,
        )
        # May be None when the adapter cannot determine its cardinality.
        self._num_batches = self.data_adapter.num_batches

    def _get_iterator(self):
        return self.data_adapter.get_numpy_iterator()

    def _interrupted_warning(self):
        warnings.warn(
            "Your input ran out of data; interrupting training. "
            "Make sure that your dataset or generator can generate "
            "at least `steps_per_epoch * epochs` batches. "
            "You may need to use the `.repeat()` "
            "function when building your dataset.",
            stacklevel=2,
        )

    def reset(self):
        self._current_iterator = None
        self._num_batches = self.data_adapter.num_batches
        self._steps_seen = 0
        self._epoch_iterator = None
        self.data_adapter.on_epoch_end()

    def _enumerate_iterator(self):
        self.data_adapter.on_epoch_begin()
        steps_per_epoch = self.steps_per_epoch or self._num_batches or -1

        if steps_per_epoch > 0:
            if self._current_iterator is None or self.steps_per_epoch is None:
                self._current_iterator = iter(self._get_iterator())
                self._steps_seen = 0
            for step in range(0, steps_per_epoch, self.steps_per_execution):
                if self._num_batches and self._steps_seen >= self._num_batches:
                    if self.steps_per_epoch:
                        self._interrupted_warning()
                    break
                self._steps_seen += self.steps_per_execution
                yield step, self._current_iterator
            if self._num_batches and self._steps_seen >= self._num_batches:
                # Epoch consumed: start a fresh iterator for the next epoch.
                self._current_iterator = iter(self._get_iterator())
                self._steps_seen = 0
        else:
            # Unknown cardinality: iterate until the data runs out.
            iterator = iter(self._get_iterator())
            step = -self.steps_per_execution
            while True:
                step += self.steps_per_execution
                self._steps_seen = step + self.steps_per_execution
                yield step, iterator
        self.data_adapter.on_epoch_end()

    def __iter__(self):
        self._epoch_iterator = self._enumerate_iterator()
        return self

    def __next__(self):
        buffer = []
        step, iterator = next(self._epoch_iterator)
        with self.catch_stop_iteration():
            for _ in range(self.steps_per_execution):
                data = next(iterator)
                buffer.append(data)
            # Fix: this return belongs INSIDE the `with` block. In the
            # broken layout it followed the `with`, making the partial-buffer
            # handling below unreachable dead code.
            return step, buffer
        # Only reached when `catch_stop_iteration` suppressed a
        # StopIteration: emit the partial buffer if any batches were read.
        if buffer:
            return step, buffer
        raise StopIteration

    def enumerate_epoch(self):
        for step, data in self:
            yield step, data

    @contextlib.contextmanager
    def catch_stop_iteration(self):
        """Catches errors when an iterator runs out of data."""
        try:
            yield
        except StopIteration:
            if self._num_batches is None:
                # Infer the cardinality from how far we got.
                self._num_batches = self._steps_seen
            self._interrupted_warning()
            self._current_iterator = None
            self.data_adapter.on_epoch_end()

    @property
    def num_batches(self):
        if self.steps_per_epoch:
            return self.steps_per_epoch
        # Either copied from the data_adapter, or
        # inferred at the end of an iteration.
        return self._num_batches
self._current_iterator = None + self.data_adapter.on_epoch_end() + + @property + def num_batches(self): + if self.steps_per_epoch: + return self.steps_per_epoch + # Either copied from the data_adapter, or + # inferred at the end of an iteration. + return self._num_batches diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..a0de303faf862303328275eacf8feb60652d516e --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/trainers/trainer.py @@ -0,0 +1,1147 @@ +import inspect +import platform +import warnings + +from keras.src import backend +from keras.src import metrics as metrics_module +from keras.src import ops +from keras.src import optimizers +from keras.src import tree +from keras.src.optimizers.loss_scale_optimizer import LossScaleOptimizer +from keras.src.saving import serialization_lib +from keras.src.trainers.compile_utils import CompileLoss +from keras.src.trainers.compile_utils import CompileMetrics +from keras.src.trainers.data_adapters import data_adapter_utils +from keras.src.utils import python_utils +from keras.src.utils import traceback_utils +from keras.src.utils import tracking + + +class Trainer: + def __init__(self): + self._lock = False + self._run_eagerly = False + self._jit_compile = None + self.compiled = False + self.loss = None + self.steps_per_execution = 1 + # Can be set by callbacks in on_train_begin + self._initial_epoch = None + self._compute_loss_has_training_arg = ( + "training" in inspect.signature(self.compute_loss).parameters + ) + + # Placeholders used in `compile` + self._compile_loss = None + self._compile_metrics = None + self._loss_tracker = None + + @traceback_utils.filter_traceback + @tracking.no_automatic_dependency_tracking + def compile( + self, + optimizer="rmsprop", + 
loss=None, + loss_weights=None, + metrics=None, + weighted_metrics=None, + run_eagerly=False, + steps_per_execution=1, + jit_compile="auto", + auto_scale_loss=True, + ): + """Configures the model for training. + + Example: + + ```python + model.compile( + optimizer=keras.optimizers.Adam(learning_rate=1e-3), + loss=keras.losses.BinaryCrossentropy(), + metrics=[ + keras.metrics.BinaryAccuracy(), + keras.metrics.FalseNegatives(), + ], + ) + ``` + + Args: + optimizer: String (name of optimizer) or optimizer instance. See + `keras.optimizers`. + loss: Loss function. May be a string (name of loss function), or + a `keras.losses.Loss` instance. See `keras.losses`. A + loss function is any callable with the signature + `loss = fn(y_true, y_pred)`, where `y_true` are the ground truth + values, and `y_pred` are the model's predictions. + `y_true` should have shape `(batch_size, d0, .. dN)` + (except in the case of sparse loss functions such as + sparse categorical crossentropy which expects integer arrays of + shape `(batch_size, d0, .. dN-1)`). + `y_pred` should have shape `(batch_size, d0, .. dN)`. + The loss function should return a float tensor. + loss_weights: Optional list or dictionary specifying scalar + coefficients (Python floats) to weight the loss contributions of + different model outputs. The loss value that will be minimized + by the model will then be the *weighted sum* of all individual + losses, weighted by the `loss_weights` coefficients. If a list, + it is expected to have a 1:1 mapping to the model's outputs. If + a dict, it is expected to map output names (strings) to scalar + coefficients. + metrics: List of metrics to be evaluated by the model during + training and testing. Each of this can be a string (name of a + built-in function), function or a `keras.metrics.Metric` + instance. See `keras.metrics`. Typically you will use + `metrics=['accuracy']`. A function is any callable with the + signature `result = fn(y_true, _pred)`. 
To specify different + metrics for different outputs of a multi-output model, you could + also pass a dictionary, such as + `metrics={'a':'accuracy', 'b':['accuracy', 'mse']}`. + You can also pass a list to specify a metric or a list of + metrics for each output, such as + `metrics=[['accuracy'], ['accuracy', 'mse']]` + or `metrics=['accuracy', ['accuracy', 'mse']]`. When you pass + the strings 'accuracy' or 'acc', we convert this to one of + `keras.metrics.BinaryAccuracy`, + `keras.metrics.CategoricalAccuracy`, + `keras.metrics.SparseCategoricalAccuracy` based on the + shapes of the targets and of the model output. A similar + conversion is done for the strings `"crossentropy"` + and `"ce"` as well. + The metrics passed here are evaluated without sample weighting; + if you would like sample weighting to apply, you can specify + your metrics via the `weighted_metrics` argument instead. + weighted_metrics: List of metrics to be evaluated and weighted by + `sample_weight` or `class_weight` during training and testing. + run_eagerly: Bool. If `True`, this model's forward pass + will never be compiled. It is recommended to leave this + as `False` when training (for best performance), + and to set it to `True` when debugging. + steps_per_execution: Int. The number of batches to run + during each a single compiled function call. Running multiple + batches inside a single compiled function call can + greatly improve performance on TPUs or small models with a large + Python overhead. At most, one full epoch will be run each + execution. If a number larger than the size of the epoch is + passed, the execution will be truncated to the size of the + epoch. Note that if `steps_per_execution` is set to `N`, + `Callback.on_batch_begin` and `Callback.on_batch_end` methods + will only be called every `N` batches (i.e. before/after + each compiled function execution). + Not supported with the PyTorch backend. + jit_compile: Bool or `"auto"`. 
Whether to use XLA compilation when + compiling a model. For `jax` and `tensorflow` backends, + `jit_compile="auto"` enables XLA compilation if the model + supports it, and disabled otherwise. + For `torch` backend, `"auto"` will default to eager + execution and `jit_compile=True` will run with `torch.compile` + with the `"inductor"` backend. + auto_scale_loss: Bool. If `True` and the model dtype policy is + `"mixed_float16"`, the passed optimizer will be automatically + wrapped in a `LossScaleOptimizer`, which will dynamically + scale the loss to prevent underflow. + """ + optimizer = optimizers.get(optimizer) + self.optimizer = optimizer + if ( + auto_scale_loss + and self.dtype_policy.name == "mixed_float16" + and self.optimizer + and not isinstance(self.optimizer, LossScaleOptimizer) + ): + self.optimizer = LossScaleOptimizer( + self.optimizer, name="loss_scale_optimizer" + ) + if hasattr(self, "output_names"): + output_names = self.output_names + else: + output_names = None + if loss is not None: + self._compile_loss = CompileLoss( + loss, loss_weights, output_names=output_names + ) + self.loss = loss + if metrics is not None or weighted_metrics is not None: + self._compile_metrics = CompileMetrics( + metrics, weighted_metrics, output_names=output_names + ) + if jit_compile == "auto": + if run_eagerly: + jit_compile = False + else: + jit_compile = self._resolve_auto_jit_compile() + if jit_compile and run_eagerly: + jit_compile = False + warnings.warn( + "If `run_eagerly` is True, then `jit_compile` " + "cannot also be True. 
Disabling `jit_compile`.", + stacklevel=2, + ) + + self.jit_compile = jit_compile + self.run_eagerly = run_eagerly + self.stop_training = False + self.compiled = True + self._loss_tracker = metrics_module.Mean(name="loss") + self.steps_per_execution = steps_per_execution + + self.train_function = None + self.test_function = None + self.predict_function = None + + self._compile_config = serialization_lib.SerializableDict( + optimizer=optimizer, + loss=loss, + loss_weights=loss_weights, + metrics=metrics, + weighted_metrics=weighted_metrics, + run_eagerly=run_eagerly, + steps_per_execution=steps_per_execution, + jit_compile=jit_compile, + ) + + @property + def jit_compile(self): + if self._jit_compile is None: + # Value was never set. Resolve it now. + self._jit_compile = self._resolve_auto_jit_compile() + return self._jit_compile + + @jit_compile.setter + def jit_compile(self, value): + if value and not model_supports_jit(self): + warnings.warn( + "Model doesn't support `jit_compile=True`. " + "Proceeding with `jit_compile=False`." + ) + self._jit_compile = False + else: + self._jit_compile = value + + def _resolve_auto_jit_compile(self): + if backend.backend() == "torch": + # jit_compile = "auto" with the pytorch backend defaults to eager + return False + + if backend.backend() == "tensorflow": + import tensorflow as tf + + devices = tf.config.list_physical_devices() + if not list(filter(lambda x: x.device_type != "CPU", devices)): + # Disable XLA on CPU-only machines. + return False + + if self._distribute_strategy: + # Disable XLA with tf.distribute + return False + + if model_supports_jit(self): + return True + return False + + @property + def run_eagerly(self): + return self._run_eagerly + + @run_eagerly.setter + def run_eagerly(self, value): + self._run_eagerly = value + + @property + def metrics(self): + # Order: loss tracker, individual loss trackers, compiled metrics, + # custom metrcis, sublayer metrics. 
+ metrics = [] + if self.compiled: + if self._loss_tracker is not None: + metrics.append(self._loss_tracker) + if self._compile_metrics is not None: + metrics.append(self._compile_metrics) + if self._compile_loss is not None: + metrics.extend(self._compile_loss.metrics) + metrics.extend(self._metrics) + for layer in self._flatten_layers(include_self=False): + if isinstance(layer, Trainer): + # All Trainer-related metrics in sublayers should be ignored + # because a new Trainer has been instantiated. + continue + metrics.extend(layer.metrics) + return metrics + + @property + def metrics_names(self): + return [m.name for m in self.metrics] + + def reset_metrics(self): + for m in self.metrics: + m.reset_state() + + def _get_own_metrics(self): + metrics = [] + if self._loss_tracker is not None: + metrics.append(self._loss_tracker) + if self._compile_metrics is not None: + metrics.append(self._compile_metrics) + if self._compile_loss is not None: + metrics.extend(self._compile_loss.metrics) + metrics.extend(self._metrics) + return metrics + + def compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + """Compute the total loss, validate it, and return it. + + Subclasses can optionally override this method to provide custom loss + computation logic. 
+ + Example: + + ```python + class MyModel(Model): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.loss_tracker = metrics.Mean(name='loss') + + def compute_loss(self, x, y, y_pred, sample_weight, training=True): + loss = ops.mean((y_pred - y) ** 2) + loss += ops.sum(self.losses) + self.loss_tracker.update_state(loss) + return loss + + def reset_metrics(self): + self.loss_tracker.reset_state() + + @property + def metrics(self): + return [self.loss_tracker] + + inputs = layers.Input(shape=(10,), name='my_input') + outputs = layers.Dense(10)(inputs) + model = MyModel(inputs, outputs) + model.add_loss(ops.sum(outputs)) + + optimizer = SGD() + model.compile(optimizer, loss='mse', steps_per_execution=10) + dataset = ... + model.fit(dataset, epochs=2, steps_per_epoch=10) + print(f"Custom loss: {model.loss_tracker.result()}") + ``` + + Args: + x: Input data. + y: Target data. + y_pred: Predictions returned by the model (output of `model(x)`) + sample_weight: Sample weights for weighting the loss function. + training: Whether we are training or evaluating the model. + + Returns: + The total loss as a scalar tensor, or `None` if no loss results + (which is the case when called by `Model.test_step`). + """ + # The default implementation does not use `x` or `training`. + del x + del training + losses = [] + if self._compile_loss is not None: + loss = self._compile_loss(y, y_pred, sample_weight) + if loss is not None: + losses.append(loss) + for loss in self.losses: + losses.append(self._aggregate_additional_loss(loss)) + if backend.backend() != "jax" and len(losses) == 0: + raise ValueError( + "No loss to compute. Provide a `loss` argument in `compile()`." 
+ ) + if len(losses) == 1: + total_loss = losses[0] + elif len(losses) == 0: + total_loss = ops.zeros(()) + else: + total_loss = ops.sum(losses) + return total_loss + + def _compute_loss( + self, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + """Backwards compatibility wrapper for `compute_loss`. + + This should be used instead `compute_loss` within `train_step` and + `test_step` to support overrides of `compute_loss` that may not have + the `training` argument, as this argument was added in Keras 3.3. + """ + if self._compute_loss_has_training_arg: + return self.compute_loss( + x, y, y_pred, sample_weight, training=training + ) + else: + return self.compute_loss(x, y, y_pred, sample_weight) + + def _aggregate_additional_loss(self, loss): + """Aggregates losses from `add_loss`, regularizers and sublayers. + + Args: + loss: A tensor representing the additional loss to aggregate. + + Returns: + A tensor representing the summed loss, cast to the `floatx()` if + necessary. + """ + if not backend.is_float_dtype(loss.dtype): + loss = ops.cast(loss, dtype=backend.floatx()) + return ops.sum(loss) + + def stateless_compute_loss( + self, + trainable_variables, + non_trainable_variables, + metrics_variables, + x=None, + y=None, + y_pred=None, + sample_weight=None, + training=True, + ): + var_mapping = list(zip(self.trainable_variables, trainable_variables)) + var_mapping.extend( + zip(self.non_trainable_variables, non_trainable_variables) + ) + var_mapping.extend(zip(self.metrics_variables, metrics_variables)) + with backend.StatelessScope(state_mapping=var_mapping) as scope: + # Note that this is needed for the regularization loss, which need + # the latest value of train/non-trainable variables. 
+ loss = self._compute_loss( + x, + y, + y_pred, + sample_weight=sample_weight, + training=training, + ) + + # Update non trainable vars (may have been updated in compute_loss) + non_trainable_variables = [] + for v in self.non_trainable_variables: + new_v = scope.get_current_value(v) + non_trainable_variables.append(new_v) + + # Update metrics vars (may have been updated in compute_loss) + metrics_variables = [] + for v in self.metrics_variables: + new_v = scope.get_current_value(v) + metrics_variables.append(new_v) + return loss, ( + trainable_variables, + non_trainable_variables, + metrics_variables, + ) + + def compute_metrics(self, x, y, y_pred, sample_weight=None): + """Update metric states and collect all metrics to be returned. + + Subclasses can optionally override this method to provide custom metric + updating and collection logic. Custom metrics are not passed in + `compile()`, they can be created in `__init__` or `build`. They are + automatically tracked and returned by `self.metrics`. + + Example: + + ```python + class MyModel(Sequential): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.custom_metric = MyMetric(name="custom_metric") + + def compute_metrics(self, x, y, y_pred, sample_weight): + # This super call updates metrics from `compile` and returns + # results for all metrics listed in `self.metrics`. + metric_results = super().compute_metrics( + x, y, y_pred, sample_weight) + + # `metric_results` contains the previous result for + # `custom_metric`, this is where we update it. + self.custom_metric.update_state(x, y, y_pred, sample_weight) + metric_results['custom_metric'] = self.custom_metric.result() + return metric_results + ``` + + Args: + x: Input data. + y: Target data. + y_pred: Predictions returned by the model output of `model.call(x)`. + sample_weight: Sample weights for weighting the loss function. 
+ + Returns: + A `dict` containing values that will be passed to + `keras.callbacks.CallbackList.on_train_batch_end()`. Typically, + the values of the metrics listed in `self.metrics` are returned. + Example: `{'loss': 0.2, 'accuracy': 0.7}`. + """ + del x # The default implementation does not use `x`. + if self._compile_metrics is not None: + self._compile_metrics.update_state(y, y_pred, sample_weight) + return self.get_metrics_result() + + def get_metrics_result(self): + """Returns the model's metrics values as a dict. + + If any of the metric result is a dict (containing multiple metrics), + each of them gets added to the top level returned dict of this method. + + Returns: + A `dict` containing values of the metrics listed in `self.metrics`. + Example: `{'loss': 0.2, 'accuracy': 0.7}`. + """ + return_metrics = {} + for metric in self.metrics: + result = metric.result() + if isinstance(result, dict): + return_metrics.update(result) + else: + return_metrics[metric.name] = result + return python_utils.pythonify_logs(return_metrics) + + def fit( + self, + x=None, + y=None, + batch_size=None, + epochs=1, + verbose="auto", + callbacks=None, + validation_split=0.0, + validation_data=None, + shuffle=True, + class_weight=None, + sample_weight=None, + initial_epoch=0, + steps_per_epoch=None, + validation_steps=None, + validation_batch_size=None, + validation_freq=1, + ): + """Trains the model for a fixed number of epochs (dataset iterations). + + Args: + x: Input data. It can be: + - A NumPy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A backend-native tensor, or a list of tensors + (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, + if the model has named inputs. + - A `keras.utils.PyDataset` returning `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `tf.data.Dataset` yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. 
+ - A `torch.utils.data.DataLoader` yielding `(inputs, targets)` + or `(inputs, targets, sample_weights)`. + - A Python generator function yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + y: Target data. Like the input data `x`, it can be either NumPy + array(s) or backend-native tensor(s). If `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or a Python generator function, + `y` should not be specified since targets will be obtained from + `x`. + batch_size: Integer or `None`. + Number of samples per gradient update. + If unspecified, `batch_size` will default to 32. + Do not specify the `batch_size` if your input data `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + epochs: Integer. Number of epochs to train the model. + An epoch is an iteration over the entire `x` and `y` + data provided + (unless the `steps_per_epoch` flag is set to + something other than None). + Note that in conjunction with `initial_epoch`, + `epochs` is to be understood as "final epoch". + The model is not trained for a number of iterations + given by `epochs`, but merely until the epoch + of index `epochs` is reached. + verbose: `"auto"`, 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = one line per epoch. + "auto" becomes 1 for most cases. + Note that the progress bar is not + particularly useful when logged to a file, + so `verbose=2` is recommended when not running interactively + (e.g., in a production environment). Defaults to `"auto"`. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during training. + See `keras.callbacks`. Note + `keras.callbacks.ProgbarLogger` and + `keras.callbacks.History` callbacks are created + automatically and need not be passed to `model.fit()`. 
+ `keras.callbacks.ProgbarLogger` is created + or not based on the `verbose` argument in `model.fit()`. + validation_split: Float between 0 and 1. + Fraction of the training data to be used as validation data. + The model will set apart this fraction of the training data, + will not train on it, and will evaluate the loss and any model + metrics on this data at the end of each epoch. The validation + data is selected from the last samples in the `x` and `y` data + provided, before shuffling. + This argument is only supported when `x` and `y` are made of + NumPy arrays or tensors. + If both `validation_data` and `validation_split` are provided, + `validation_data` will override `validation_split`. + validation_data: Data on which to evaluate + the loss and any model metrics at the end of each epoch. + The model will not be trained on this data. Thus, note the fact + that the validation loss of data provided using + `validation_split` or `validation_data` is not affected by + regularization layers like noise and dropout. + `validation_data` will override `validation_split`. + It can be: + - A tuple `(x_val, y_val)` of NumPy arrays or tensors. + - A tuple `(x_val, y_val, val_sample_weights)` of NumPy + arrays. + - A `keras.utils.PyDataset`, a `tf.data.Dataset`, a + `torch.utils.data.DataLoader` yielding `(inputs, targets)` or a + Python generator function yielding `(x_val, y_val)` or + `(inputs, targets, sample_weights)`. + shuffle: Boolean, whether to shuffle the training data before each + epoch. This argument is ignored when `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) value, used for weighting the loss function + (during training only). + This can be useful to tell the model to + "pay more attention" to samples from + an under-represented class. 
When `class_weight` is specified + and targets have a rank of 2 or greater, either `y` must be + one-hot encoded, or an explicit final dimension of `1` must + be included for sparse class labels. + sample_weight: Optional NumPy array or tensor of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + NumPy array or tensor with the same length as the input samples + (1:1 mapping between weights and samples), or in the case of + temporal data, you can pass a 2D NumPy array or tensor with + shape `(samples, sequence_length)` to apply a different weight + to every timestep of every sample. + This argument is not supported when `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function. + Instead, provide `sample_weights` as the third element of `x`. + Note that sample weighting does not apply to metrics specified + via the `metrics` argument in `compile()`. To apply sample + weighting to your metrics, you can specify them via the + `weighted_metrics` in `compile()` instead. + initial_epoch: Integer. + Epoch at which to start training + (useful for resuming a previous training run). + steps_per_epoch: Integer or `None`. + Total number of steps (batches of samples) before declaring one + epoch finished and starting the next epoch. When training with + input tensors or NumPy arrays, the default `None` means that the + value used is the number of samples in your dataset divided by + the batch size, or 1 if that cannot be determined. + If `x` is a `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function, the + epoch will run until the input dataset is exhausted. When + passing an infinitely repeating dataset, you must specify the + `steps_per_epoch` argument, otherwise the training will run + indefinitely. + validation_steps: Integer or `None`. + Only relevant if `validation_data` is provided. 
+ Total number of steps (batches of samples) to draw before + stopping when performing validation at the end of every epoch. + If `validation_steps` is `None`, validation will run until the + `validation_data` dataset is exhausted. In the case of an + infinitely repeating dataset, it will run indefinitely. If + `validation_steps` is specified and only part of the dataset + is consumed, the evaluation will start from the beginning of the + dataset at each epoch. This ensures that the same validation + samples are used every time. + validation_batch_size: Integer or `None`. + Number of samples per validation batch. + If unspecified, will default to `batch_size`. + Do not specify the `validation_batch_size` if your data is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + validation_freq: Only relevant if validation data is provided. + Specifies how many training epochs to run + before a new validation run is performed, + e.g. `validation_freq=2` runs validation every 2 epochs. + + Unpacking behavior for iterator-like inputs: + A common pattern is to pass an iterator like object such as a + `tf.data.Dataset` or a `keras.utils.PyDataset` to `fit()`, + which will in fact yield not only features (`x`) + but optionally targets (`y`) and sample weights (`sample_weight`). + Keras requires that the output of such iterator-likes be + unambiguous. The iterator should return a tuple + of length 1, 2, or 3, where the optional second and third elements + will be used for `y` and `sample_weight` respectively. + Any other type provided will be wrapped in + a length-one tuple, effectively treating everything as `x`. When + yielding dicts, they should still adhere to the top-level tuple + structure, + e.g. `({"x0": x0, "x1": x1}, y)`. Keras will not attempt to separate + features, targets, and weights from the keys of a single dict. + A notable unsupported data type is the `namedtuple`. 
The reason is + that it behaves like both an ordered datatype (tuple) and a mapping + datatype (dict). So given a namedtuple of the form: + `namedtuple("example_tuple", ["y", "x"])` + it is ambiguous whether to reverse the order of the elements when + interpreting the value. Even worse is a tuple of the form: + `namedtuple("other_tuple", ["x", "y", "z"])` + where it is unclear if the tuple was intended to be unpacked + into `x`, `y`, and `sample_weight` or passed through + as a single element to `x`. + + Returns: + A `History` object. Its `History.history` attribute is + a record of training loss values and metrics values + at successive epochs, as well as validation loss values + and validation metrics values (if applicable). + """ + raise NotImplementedError + + def evaluate( + self, + x=None, + y=None, + batch_size=None, + verbose="auto", + sample_weight=None, + steps=None, + callbacks=None, + return_dict=False, + **kwargs, + ): + """Returns the loss value & metrics values for the model in test mode. + + Computation is done in batches (see the `batch_size` arg.) + + Args: + x: Input data. It can be: + - A NumPy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A backend-native tensor, or a list of tensors + (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, + if the model has named inputs. + - A `keras.utils.PyDataset` returning `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `tf.data.Dataset` yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + - A `torch.utils.data.DataLoader` yielding `(inputs, targets)` + or `(inputs, targets, sample_weights)`. + - A Python generator function yielding `(inputs, targets)` or + `(inputs, targets, sample_weights)`. + y: Target data. Like the input data `x`, it can be either NumPy + array(s) or backend-native tensor(s). 
If `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or a Python generator function, + `y` should not be specified since targets will be obtained from + `x`. + batch_size: Integer or `None`. + Number of samples per batch of computation. + If unspecified, `batch_size` will default to 32. + Do not specify the `batch_size` if your input data `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + verbose: `"auto"`, 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = single line. + `"auto"` becomes 1 for most cases. + Note that the progress bar is not + particularly useful when logged to a file, so `verbose=2` is + recommended when not running interactively + (e.g. in a production environment). Defaults to `"auto"`. + sample_weight: Optional NumPy array or tensor of weights for + the training samples, used for weighting the loss function + (during training only). You can either pass a flat (1D) + NumPy array or tensor with the same length as the input samples + (1:1 mapping between weights and samples), or in the case of + temporal data, you can pass a 2D NumPy array or tensor with + shape `(samples, sequence_length)` to apply a different weight + to every timestep of every sample. + This argument is not supported when `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function. + Instead, provide `sample_weights` as the third element of `x`. + Note that sample weighting does not apply to metrics specified + via the `metrics` argument in `compile()`. To apply sample + weighting to your metrics, you can specify them via the + `weighted_metrics` in `compile()` instead. + steps: Integer or `None`. + Total number of steps (batches of samples) to draw before + declaring the evaluation round finished. If `steps` is `None`, + it will run until `x` is exhausted. 
In the case of an infinitely + repeating dataset, it will run indefinitely. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during evaluation. + return_dict: If `True`, loss and metric results are returned as a + dict, with each key being the name of the metric. + If `False`, they are returned as a list. + + Returns: + Scalar test loss (if the model has a single output and no metrics) + or list of scalars (if the model has multiple outputs + and/or metrics). The attribute `model.metrics_names` will give you + the display labels for the scalar outputs. + """ + raise NotImplementedError + + def predict( + self, x, batch_size=None, verbose="auto", steps=None, callbacks=None + ): + """Generates output predictions for the input samples. + + Computation is done in batches. This method is designed for batch + processing of large numbers of inputs. It is not intended for use inside + of loops that iterate over your data and process small numbers of inputs + at a time. + + For small numbers of inputs that fit in one batch, + directly use `__call__()` for faster execution, e.g., + `model(x)`, or `model(x, training=False)` if you have layers such as + `BatchNormalization` that behave differently during + inference. + + Note: See [this FAQ entry]( + https://keras.io/getting_started/faq/#whats-the-difference-between-model-methods-predict-and-call) + for more details about the difference between `Model` methods + `predict()` and `__call__()`. + + Args: + x: Input data. It can be: + - A NumPy array (or array-like), or a list of arrays + (in case the model has multiple inputs). + - A backend-native tensor, or a list of tensors + (in case the model has multiple inputs). + - A dict mapping input names to the corresponding array/tensors, + if the model has named inputs. + - A `keras.utils.PyDataset`. + - A `tf.data.Dataset`. + - A `torch.utils.data.DataLoader`. + - A Python generator function. + batch_size: Integer or `None`. 
+ Number of samples per batch of computation. + If unspecified, `batch_size` will default to 32. + Do not specify the `batch_size` if your input data `x` is a + `keras.utils.PyDataset`, `tf.data.Dataset`, + `torch.utils.data.DataLoader` or Python generator function + since they generate batches. + verbose: `"auto"`, 0, 1, or 2. Verbosity mode. + 0 = silent, 1 = progress bar, 2 = single line. + `"auto"` becomes 1 for most cases. Note that the progress bar + is not particularly useful when logged to a file, + so `verbose=2` is recommended when not running interactively + (e.g. in a production environment). Defaults to `"auto"`. + steps: Total number of steps (batches of samples) to draw before + declaring the prediction round finished. If `steps` is `None`, + it will run until `x` is exhausted. In the case of an infinitely + repeating dataset, it will run indefinitely. + callbacks: List of `keras.callbacks.Callback` instances. + List of callbacks to apply during prediction. + + Returns: + NumPy array(s) of predictions. + """ + raise NotImplementedError + + def train_on_batch( + self, + x, + y=None, + sample_weight=None, + class_weight=None, + return_dict=False, + ): + """Runs a single gradient update on a single batch of data. + + Args: + x: Input data. Must be array-like. + y: Target data. Must be array-like. + sample_weight: Optional array of the same length as x, containing + weights to apply to the model's loss for each sample. + In the case of temporal data, you can pass a 2D array + with shape `(samples, sequence_length)`, to apply a different + weight to every timestep of every sample. + class_weight: Optional dictionary mapping class indices (integers) + to a weight (float) to apply to the model's loss for the samples + from this class during training. This can be useful to tell the + model to "pay more attention" to samples from an + under-represented class. 
When `class_weight` is specified + and targets have a rank of 2 or greater, either `y` must + be one-hot encoded, or an explicit final dimension of 1 + must be included for sparse class labels. + return_dict: If `True`, loss and metric results are returned as a + dict, with each key being the name of the metric. If `False`, + they are returned as a list. + + Returns: + A scalar loss value (when no metrics and `return_dict=False`), + a list of loss and metric values + (if there are metrics and `return_dict=False`), or a dict of + metric and loss values (if `return_dict=True`). + """ + raise NotImplementedError + + def test_on_batch( + self, + x, + y=None, + sample_weight=None, + return_dict=False, + ): + """Test the model on a single batch of samples. + + Args: + x: Input data. Must be array-like. + y: Target data. Must be array-like. + sample_weight: Optional array of the same length as x, containing + weights to apply to the model's loss for each sample. + In the case of temporal data, you can pass a 2D array + with shape `(samples, sequence_length)`, to apply a different + weight to every timestep of every sample. + return_dict: If `True`, loss and metric results are returned as a + dict, with each key being the name of the metric. If `False`, + they are returned as a list. + + Returns: + A scalar loss value (when no metrics and `return_dict=False`), + a list of loss and metric values + (if there are metrics and `return_dict=False`), or a dict of + metric and loss values (if `return_dict=True`). + """ + raise NotImplementedError + + def predict_on_batch(self, x): + """Returns predictions for a single batch of samples. + + Args: + x: Input data. It must be array-like. + + Returns: + NumPy array(s) of predictions. + """ + raise NotImplementedError + + def get_compile_config(self): + """Returns a serialized config with information for compiling the model. + + This method returns a config dictionary containing all the information + (optimizer, loss, metrics, etc.) 
with which the model was compiled. + + Returns: + A dict containing information for compiling the model. + """ + if self.compiled and hasattr(self, "_compile_config"): + return self._compile_config.serialize() + + def compile_from_config(self, config): + """Compiles the model with the information given in config. + + This method uses the information in the config (optimizer, loss, + metrics, etc.) to compile the model. + + Args: + config: Dict containing information for compiling the model. + """ + has_overridden_compile = self.__class__.compile != Trainer.compile + if has_overridden_compile: + warnings.warn( + "`compile()` was not called as part of model loading " + "because the model's `compile()` method is custom. " + "All subclassed Models that have `compile()` " + "overridden should also override " + "`get_compile_config()` and `compile_from_config(config)`. " + "Alternatively, you can " + "call `compile()` manually after loading.", + stacklevel=2, + ) + return + config = serialization_lib.deserialize_keras_object(config) + self.compile(**config) + if hasattr(self, "optimizer") and self.built: + # Create optimizer variables. + self.optimizer.build(self.trainable_variables) + + def _should_eval(self, epoch, validation_freq): + epoch = epoch + 1 # one-index the user-facing epoch. + if isinstance(validation_freq, int): + return epoch % validation_freq == 0 + elif isinstance(validation_freq, list): + return epoch in validation_freq + else: + raise ValueError( + "Expected `validation_freq` to be a list or int. " + f"Received: validation_freq={validation_freq} of the " + f"type {type(validation_freq)}." + ) + + def _get_metrics_result_or_logs(self, logs): + """Returns model metrics as a dict if the keys match with input logs. + + When the training / evaluation is performed with an asynchronous steps, + the last scheduled `train / test_step` may not give the latest metrics + because it is not guaranteed to be executed the last. 
This method gets + metrics from the model directly instead of relying on the return from + last step function. + + When the user has custom train / test step functions, the metrics + returned may be different from `Model.metrics`. In those instances, + this function will be no-op and return the logs passed in. + + Args: + logs: A `dict` of metrics returned by train / test step function. + + Returns: + A `dict` containing values of the metrics listed in `self.metrics` + when logs and model metrics keys match. Otherwise it returns input + `logs`. + """ + metric_logs = self.get_metrics_result() + # Verify that train / test step logs passed and metric logs have + # matching keys. It could be different when using custom step functions, + # in which case we return the logs from the last step. + if isinstance(logs, dict) and set(logs.keys()) == set( + metric_logs.keys() + ): + return metric_logs + return logs + + def _flatten_metrics_in_order(self, logs): + """Turns `logs` dict into a list as per key order of `metrics_names`.""" + metric_names = [] + for metric in self.metrics: + if isinstance(metric, CompileMetrics): + metric_names += [ + sub_metric.name for sub_metric in metric.metrics + ] + else: + metric_names.append(metric.name) + results = [] + for name in metric_names: + if name in logs: + results.append(logs[name]) + for key in sorted(logs.keys()): + if key not in metric_names: + results.append(logs[key]) + if len(results) == 1: + return results[0] + return results + + def _assert_compile_called(self, method_name=None): + if not self.compiled: + msg = "You must call `compile()` before " + if metrics_module: + msg += "using the model." + else: + msg += f"calling `{method_name}()`." 
+ raise ValueError(msg) + + def _symbolic_build(self, iterator=None, data_batch=None): + model_unbuilt = not all(layer.built for layer in self._flatten_layers()) + compile_metrics_unbuilt = ( + self._compile_metrics is not None + and not self._compile_metrics.built + ) + compile_loss_unbuilt = ( + self._compile_loss is not None and not self._compile_loss.built + ) + optimizer_unbuilt = ( + self.optimizer is not None and not self.optimizer.built + ) + if model_unbuilt or compile_metrics_unbuilt or compile_loss_unbuilt: + # Create symbolic tensors matching an input batch. + + def to_symbolic_input(v): + if v is None: + return None + return backend.KerasTensor( + v.shape, backend.standardize_dtype(v.dtype) + ) + + if data_batch is None: + for _, data_or_iterator in iterator: + if isinstance(data_or_iterator, (list, tuple)): + data_batch = data_or_iterator[0] + else: + data_batch = next(data_or_iterator) + break + data_batch = tree.map_structure(to_symbolic_input, data_batch) + ( + x, + y, + sample_weight, + ) = data_adapter_utils.unpack_x_y_sample_weight(data_batch) + + # Build all model state with `backend.compute_output_spec`. + try: + y_pred = backend.compute_output_spec(self, x, training=False) + except Exception as e: + raise RuntimeError( + "Unable to automatically build the model. " + "Please build it yourself before calling " + "fit/evaluate/predict. " + "A model is 'built' when its variables have " + "been created and its `self.built` attribute " + "is True. Usually, calling the model on a batch " + "of data is the right way to build it.\n" + "Exception encountered:\n" + f"'{e}'" + ) + if compile_metrics_unbuilt: + # Build all metric state with `backend.compute_output_spec`. + backend.compute_output_spec( + self.compute_metrics, + x, + y, + y_pred, + sample_weight=sample_weight, + ) + if compile_loss_unbuilt: + # Build `CompileLoss` state with `backend.compute_output_spec`. 
+ backend.compute_output_spec( + self._compute_loss, + x, + y, + y_pred, + sample_weight=sample_weight, + training=False, + ) + if optimizer_unbuilt: + # Build optimizer + self.optimizer.build(self.trainable_variables) + self._post_build() + + +def model_supports_jit(model): + # XLA not supported with TF on MacOS GPU + if platform.system() == "Darwin" and "arm" in platform.processor().lower(): + if backend.backend() == "tensorflow": + from keras.src.utils.module_utils import tensorflow as tf + + if tf.config.list_physical_devices("GPU"): + return False + # XLA not supported by some layers + if all(x.supports_jit for x in model._flatten_layers()): + if backend.backend() == "tensorflow": + from tensorflow.python.framework.config import ( + is_op_determinism_enabled, + ) + + if is_op_determinism_enabled(): + # disable XLA with determinism enabled since not all ops are + # supported by XLA with determinism enabled. + return False + return True + return False diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a719378ef3501f640d87ba9d24425b05e8f6a6c0 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__init__.py @@ -0,0 +1,12 @@ +from keras.src.tree.tree_api import assert_same_paths +from keras.src.tree.tree_api import assert_same_structure +from keras.src.tree.tree_api import flatten +from keras.src.tree.tree_api import flatten_with_path +from keras.src.tree.tree_api import is_nested +from keras.src.tree.tree_api import lists_to_tuples +from keras.src.tree.tree_api import map_shape_structure +from keras.src.tree.tree_api import map_structure +from keras.src.tree.tree_api import map_structure_up_to +from keras.src.tree.tree_api import pack_sequence_as +from keras.src.tree.tree_api import register_tree_node_class +from 
keras.src.tree.tree_api import traverse diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66c794fb30ae4b8504d6503f336ac815857b9bc6 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bdab73dcb386c4808a066c58cf5ac5af3a6a862 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/dmtree_impl.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b3c491a6999f0c3f5553de14c3441c7e3ec2aba Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/optree_impl.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec0e038fe99e0ff3e7bf26fded497837dc91574d Binary files /dev/null and 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/__pycache__/tree_api.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..ff9964b43d74be665f2914d1d907321bdf7de089 --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/dmtree_impl.py @@ -0,0 +1,394 @@ +import collections +import collections.abc +import itertools + +from keras.src.backend.config import backend +from keras.src.utils.module_utils import dmtree + +# NOTE: There are two known discrepancies between this `dmtree` implementation +# of the tree API and the `optree` implementation: +# +# 1. `map_structure` with *multiple* structures and `map_structure_up_to` do not +# use the object registration (they use the raw `dmtree.map_structure` and +# `dmtree.map_structure_up_to`). This only has consequences with two types of +# structures: +# - `TrackedSet` will not explored (considered as a leaf). +# - `OrderedDict` will be traversed in the order of sorted keys, not the +# order of the items. This is typically inconsequential because functions +# used with `map_structure` and `map_structure_up_to` are typically not +# order dependent and are, in fact, stateless. +# +# 2. The handling of non-sortable keys in dictionaries in inconsistent. `optree` +# uses the iteration order while `dmtree` raises an error. This is not an +# issue as keys are always strings. But this is the reason why we document +# non-sortable keys as unsupported (meaning behavior is undefined). 
+ +REGISTERED_CLASSES = {} + +ClassRegistration = collections.namedtuple( + "ClassRegistration", ["flatten", "unflatten"] +) + + +class TypeErrorRemapping: + def __enter__(self): + pass + + def __exit__(self, exc_type, exc_value, traceback): + if exc_type is TypeError: + raise ValueError(exc_value).with_traceback(traceback) + return False + + +def register_tree_node( + cls, + flatten_func=None, + unflatten_func=None, +): + if flatten_func is None: + flatten_func = lambda x: x.tree_flatten() + if unflatten_func is None: + unflatten_func = cls.tree_unflatten + REGISTERED_CLASSES[cls] = ClassRegistration(flatten_func, unflatten_func) + + +def register_tree_node_class(cls): + register_tree_node(cls) + return cls + + +register_tree_node( + collections.OrderedDict, + lambda d: (d.values(), list(d.keys()), d.keys()), + lambda metadata, children: collections.OrderedDict(zip(metadata, children)), +) + +if backend() == "tensorflow": + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + register_tree_node( + ListWrapper, + lambda x: (x, None), + lambda metadata, children: ListWrapper(list(children)), + ) + + def sorted_keys_and_values(d): + keys = sorted(list(d.keys())) + values = [d[k] for k in keys] + return values, keys, keys + + register_tree_node( + _DictWrapper, + sorted_keys_and_values, + lambda metadata, children: _DictWrapper( + {key: child for key, child in zip(metadata, children)} + ), + ) + + +def is_nested(structure): + return type(structure) in REGISTERED_CLASSES or dmtree.is_nested(structure) + + +def traverse(func, structure, top_down=True): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + def remap_map_to_none(value, new_value): + if isinstance(value, type) and value.__name__ == "MAP_TO_NONE": + return new_value + return value + + def traverse_top_down(s): + ret = func(s) + if ret is not None: + return 
remap_map_to_none(ret, dmtree.MAP_TO_NONE) + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is None: + return None + flat_meta_s = registration.flatten(s) + flat_s = [ + dmtree.traverse(traverse_top_down, x, top_down=True) + for x in list(flat_meta_s[0]) + ] + return registration.unflatten(flat_meta_s[1], flat_s) + + def traverse_bottom_up(s): + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is not None: + flat_meta_s = registration.flatten(s) + ret = [traverse_bottom_up(x) for x in list(flat_meta_s[0])] + ret = registration.unflatten(flat_meta_s[1], ret) + elif not dmtree.is_nested(s): + ret = s + elif isinstance(s, collections.abc.Mapping): + ret = [traverse_bottom_up(s[key]) for key in sorted(s)] + ret = dmtree._sequence_like(s, ret) + else: + ret = [traverse_bottom_up(x) for x in s] + ret = dmtree._sequence_like(s, ret) + func_ret = func(ret) + return ret if func_ret is None else remap_map_to_none(func_ret, None) + + if top_down: + return dmtree.traverse(traverse_top_down, structure, top_down=True) + else: + return traverse_bottom_up(structure) + + +def flatten(structure): + if not is_nested(structure): + return [structure] + + flattened = [] + + def flatten_func(s): + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is not None: + flat_s = list(registration.flatten(s)[0]) + return dmtree.traverse(flatten_func, flat_s, top_down=True) + if not is_nested(s): + flattened.append(s) + return dmtree.MAP_TO_NONE if s is None else s + return None + + dmtree.traverse(flatten_func, structure, top_down=True) + return flattened + + +def _recursive_flatten_with_path(path, structure, flattened): + registration = REGISTERED_CLASSES.get(type(structure), None) + if registration is not None: + flat_meta_paths = registration.flatten(structure) + flat = flat_meta_paths[0] + paths = ( + flat_meta_paths[2] + if len(flat_meta_paths) >= 3 + else itertools.count() + ) + for key, value in zip(paths, flat): + 
_recursive_flatten_with_path(path + (key,), value, flattened) + elif not dmtree.is_nested(structure): + flattened.append((path, structure)) + elif isinstance(structure, collections.abc.Mapping): + for key in sorted(structure): + _recursive_flatten_with_path( + path + (key,), structure[key], flattened + ) + else: + for key, value in enumerate(structure): + _recursive_flatten_with_path(path + (key,), value, flattened) + + +def flatten_with_path(structure): + if not is_nested(structure): + return [((), structure)] + + # Fully reimplemented in Python to handle registered classes, OrderedDict + # and namedtuples the same way as optree. + flattened = [] + _recursive_flatten_with_path((), structure, flattened) + return flattened + + +def map_structure(func, *structures): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + def func_traverse_wrapper(s): + if is_nested(s): + return None + ret = func(s) + if ret is None: + return dmtree.MAP_TO_NONE + return ret + + if len(structures) == 1: + return traverse(func_traverse_wrapper, structures[0]) + + with TypeErrorRemapping(): + return dmtree.map_structure(func, *structures) + + +def map_structure_up_to(shallow_structure, func, *structures): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + with TypeErrorRemapping(): + return dmtree.map_structure_up_to(shallow_structure, func, *structures) + + +def assert_same_structure(a, b): + # Fully reimplemented in Python to handle registered classes. + + # Don't handle OrderedDict as a registered class, use the normal dict path + # so that OrderedDict is equivalent to dict per optree behavior. 
+    # Resolve any registered custom-node handling for each structure; an
+    # OrderedDict is deliberately treated as a plain dict here.
+    a_registration = REGISTERED_CLASSES.get(type(a), None)
+    if isinstance(a, collections.OrderedDict):
+        a_registration = None
+
+    b_registration = REGISTERED_CLASSES.get(type(b), None)
+    if isinstance(b, collections.OrderedDict):
+        b_registration = None
+
+    # Both structures must use the same registration (or none at all).
+    if a_registration != b_registration:
+        raise ValueError(
+            f"Custom node type mismatch; "
+            f"expected type: {type(a)}, got type: {type(b)} "
+            f"while comparing {a} and {b}."
+        )
+    # Registered custom nodes: compare flatten metadata and arity, then
+    # recurse into the flattened children.
+    if a_registration is not None:
+        a_flat_meta = a_registration.flatten(a)
+        b_flat_meta = b_registration.flatten(b)
+        a_flat = list(a_flat_meta[0])
+        b_flat = list(b_flat_meta[0])
+        if not a_flat_meta[1] == b_flat_meta[1]:
+            raise ValueError(
+                f"Mismatch custom node data; "
+                f"expected: {a_flat_meta[1]}, got: {b_flat_meta[1]} "
+                f"while comparing {a} and {b}."
+            )
+        if len(a_flat) != len(b_flat):
+            raise ValueError(
+                f"Arity mismatch; expected: {len(a)}, got: {len(b)} "
+                f"while comparing {a} and {b}."
+            )
+        for sub_a, sub_b in zip(a_flat, b_flat):
+            assert_same_structure(sub_a, sub_b)
+    # Leaf in `a` (per dmtree): `b` must also be a leaf.
+    elif not dmtree.is_nested(a):
+        if dmtree.is_nested(b):
+            raise ValueError(
+                f"Structures don't have the same nested structure: {a}, {b}."
+            )
+    # Dicts (and dict subclasses listed below) compare by sorted keys,
+    # then recurse per key.
+    elif isinstance(
+        a, (dict, collections.OrderedDict, collections.defaultdict)
+    ):
+        if not isinstance(
+            b, (dict, collections.OrderedDict, collections.defaultdict)
+        ):
+            raise ValueError(
+                f"Expected an instance of dict, collections.OrderedDict, or "
+                f"collections.defaultdict, got {type(b)} "
+                f"while comparing {a} and {b}."
+            )
+        a_keys = sorted(a)
+        b_keys = sorted(b)
+        if not a_keys == b_keys:
+            raise ValueError(
+                f"Dictionary key mismatch; "
+                f"expected key(s): {a_keys}, got key(s): {b_keys} "
+                f"while comparing {a} and {b}."
+            )
+        for key in a_keys:
+            assert_same_structure(a[key], b[key])
+    # Any other Mapping subclass must be registered explicitly to be
+    # traversable.
+    elif isinstance(a, collections.abc.Mapping):
+        raise ValueError(
+            f"Encountered unregistered collections.abc.Mapping type: {type(a)} "
+            f"while comparing {a} and {b}."
+ ) + else: + if type(a) is not type(b): + raise ValueError( + f"Expected an instance of {type(a)}, got {type(b)} " + f"while comparing {a} and {b}." + ) + if not len(a) == len(b): + raise ValueError( + f"Arity mismatch; expected: {len(a)}, got: {len(b)} " + f"while comparing {a} and {b}." + ) + for sub_a, sub_b in zip(a, b): + assert_same_structure(sub_a, sub_b) + + +def assert_same_paths(a, b): + a_paths = set([path for path, _ in flatten_with_path(a)]) + b_paths = set([path for path, _ in flatten_with_path(b)]) + + if a_paths != b_paths: + msg = "`a` and `b` don't have the same paths." + a_diff = a_paths.difference(b_paths) + if a_diff: + msg += f"\nPaths in `a` missing in `b`:\n{a_diff}" + b_diff = b_paths.difference(a_paths) + if b_diff: + msg += f"\nPaths in `b` missing in `a`:\n{b_diff}" + raise ValueError(msg) + + +def pack_sequence_as(structure, flat_sequence): + # This is not just an optimization for the case when structure is a leaf. + # This is required to avoid Torch Dynamo failures. + if not is_nested(structure): + if len(flat_sequence) == 1: + return flat_sequence[0] + else: + raise ValueError( + "Incorrect number of leaves provided by `flat_sequence` for " + f"`structure`; expected: 1, got {len(flat_sequence)}." + ) + + flat_sequence_it = enumerate(flat_sequence) + + def unflatten_func(s): + registration = REGISTERED_CLASSES.get(type(s), None) + if registration is not None: + flat_meta_s = registration.flatten(s) + flat_s = dmtree.traverse( + unflatten_func, list(flat_meta_s[0]), top_down=True + ) + return registration.unflatten(flat_meta_s[1], flat_s) + elif not dmtree.is_nested(s): + try: + _, value = next(flat_sequence_it) + return dmtree.MAP_TO_NONE if value is None else value + except StopIteration: + raise ValueError( + "Too few leaves provided by `flat_sequence` for " + f"`structure`. Got {len(flat_sequence)}." 
+ ) + return None + + ret = dmtree.traverse(unflatten_func, structure, top_down=True) + try: + index, _ = next(flat_sequence_it) + raise ValueError( + "Too many leaves provided by `flat_sequence` for `structure`; " + f"expected: {index}, got {len(flat_sequence)}." + ) + except StopIteration: + return ret + + +def lists_to_tuples(structure): + def list_to_tuple(instance): + return tuple(instance) if isinstance(instance, list) else None + + return traverse(list_to_tuple, structure, top_down=False) + + +def map_shape_structure(func, structure): + if not callable(func): + raise TypeError( + f"`func` must be callable, got {func} of type {type(func)}" + ) + + def map_shape_func(x): + if isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ): + ret = func(x) + elif is_nested(x): + return None + else: + ret = func(x) + return ret if ret is not None else dmtree.MAP_TO_NONE + + return traverse(map_shape_func, structure, top_down=True) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py new file mode 100644 index 0000000000000000000000000000000000000000..1aad9c2b13530f3674a4dca3b53f446980c4e7db --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/optree_impl.py @@ -0,0 +1,187 @@ +import optree +import optree.utils + +from keras.src.backend.config import backend + + +def register_tree_node_class(cls): + return optree.register_pytree_node_class(cls, namespace="keras") + + +# Register backend-specific node classes +if backend() == "tensorflow": + from tensorflow.python.trackable.data_structures import ListWrapper + from tensorflow.python.trackable.data_structures import _DictWrapper + + optree.register_pytree_node( + ListWrapper, + lambda x: (x, None), + lambda metadata, children: ListWrapper(list(children)), + namespace="keras", + ) + + def 
sorted_keys_and_values(d): + keys = sorted(list(d.keys())) + values = [d[k] for k in keys] + return values, keys, keys + + optree.register_pytree_node( + _DictWrapper, + sorted_keys_and_values, + lambda metadata, children: _DictWrapper( + {key: child for key, child in zip(metadata, children)} + ), + namespace="keras", + ) + + +def is_nested(structure): + return not optree.tree_is_leaf( + structure, none_is_leaf=True, namespace="keras" + ) + + +def traverse(func, structure, top_down=True): + # From https://github.com/google/jax/pull/19695 + def traverse_children(): + children, treedef = optree.tree_flatten( + structure, + is_leaf=lambda x: x is not structure, + none_is_leaf=True, + namespace="keras", + ) + if treedef.num_nodes == 1 and treedef.num_leaves == 1: + return structure + else: + return optree.tree_unflatten( + treedef, + [traverse(func, c, top_down=top_down) for c in children], + ) + + if top_down: + ret = func(structure) + if ret is None: + return traverse_children() + else: + traversed_structure = traverse_children() + ret = func(traversed_structure) + if ret is None: + return traversed_structure + # Detect MAP_TO_NONE without tree_api import to avoid circular import. + if isinstance(ret, type) and ret.__name__ == "MAP_TO_NONE": + return None + return ret + + +def flatten(structure): + # optree.tree_flatten returns a pair (leaves, treespec) where the first + # element is a list of leaf values and the second element is a treespec + # representing the structure of the pytree. + leaves, _ = optree.tree_flatten( + structure, none_is_leaf=True, namespace="keras" + ) + return leaves + + +def flatten_with_path(structure): + paths, leaves, _ = optree.tree_flatten_with_path( + structure, none_is_leaf=True, namespace="keras" + ) + return list(zip(paths, leaves)) + + +def map_structure(func, *structures): + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check for same structures, otherwise optree just maps to shallowest. 
+ def func_with_check(*args): + if not all( + optree.tree_is_leaf(s, none_is_leaf=True, namespace="keras") + for s in args + ): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + map_func = func_with_check if len(structures) > 1 else func + + return optree.tree_map( + map_func, *structures, none_is_leaf=True, namespace="keras" + ) + + +def map_structure_up_to(shallow_structure, func, *structures): + if not structures: + raise ValueError("Must provide at least one structure") + + # Add check that `shallow_structure` really is the shallowest. + # Also only call `func` on `structures` and not `shallow_structure`. + def func_with_check_without_shallow_structure(shallow, *args): + if not optree.tree_is_leaf(shallow): + raise ValueError("Structures don't have the same nested structure.") + return func(*args) + + return optree.tree_map( + func_with_check_without_shallow_structure, + shallow_structure, + *structures, + none_is_leaf=True, + namespace="keras", + ) + + +def assert_same_structure(a, b): + def check(a_leaf, b_leaf): + if not optree.tree_is_leaf( + a_leaf, none_is_leaf=True, namespace="keras" + ) or not optree.tree_is_leaf( + b_leaf, none_is_leaf=True, namespace="keras" + ): + raise ValueError("Structures don't have the same nested structure.") + return None + + optree.tree_map(check, a, b, none_is_leaf=True, namespace="keras") + + +def assert_same_paths(a, b): + a_paths = set(optree.tree_paths(a, none_is_leaf=True, namespace="keras")) + b_paths = set(optree.tree_paths(b, none_is_leaf=True, namespace="keras")) + + if a_paths != b_paths: + msg = "`a` and `b` don't have the same paths." 
+ a_diff = a_paths.difference(b_paths) + if a_diff: + msg += f"\nPaths in `a` missing in `b`:\n{a_diff}" + b_diff = b_paths.difference(a_paths) + if b_diff: + msg += f"\nPaths in `b` missing in `a`:\n{b_diff}" + raise ValueError(msg) + + +def pack_sequence_as(structure, flat_sequence): + _, treespec = optree.tree_flatten( + structure, none_is_leaf=True, namespace="keras" + ) + return optree.tree_unflatten(treespec, flat_sequence) + + +def lists_to_tuples(structure): + def list_to_tuple(instance): + return tuple(instance) if isinstance(instance, list) else None + + return traverse(list_to_tuple, structure, top_down=False) + + +def map_shape_structure(func, structure): + def is_shape_tuple(x): + return isinstance(x, (list, tuple)) and all( + isinstance(e, (int, type(None))) for e in x + ) + + return optree.tree_map( + func, + structure, + is_leaf=is_shape_tuple, + none_is_leaf=True, + namespace="keras", + ) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py new file mode 100644 index 0000000000000000000000000000000000000000..a4f98f068eec3846bb358569ec0b7a7ae69727bc --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/tree/tree_api.py @@ -0,0 +1,404 @@ +import warnings + +from keras.src.api_export import keras_export +from keras.src.utils.module_utils import dmtree +from keras.src.utils.module_utils import optree + +if optree.available: + from keras.src.tree import optree_impl as tree_impl +elif dmtree.available: + from keras.src.tree import dmtree_impl as tree_impl +else: + raise ImportError( + "To use Keras, you need to have `optree` installed. 
" + "Install it via `pip install optree`" + ) + + +def register_tree_node_class(cls): + return tree_impl.register_tree_node_class(cls) + + +@keras_export("keras.tree.MAP_TO_NONE") +class MAP_TO_NONE: + """Special value for use with `traverse()`.""" + + pass + + +@keras_export("keras.tree.is_nested") +def is_nested(structure): + """Checks if a given structure is nested. + + Examples: + + >>> keras.tree.is_nested(42) + False + >>> keras.tree.is_nested({"foo": 42}) + True + + Args: + structure: A structure to check. + + Returns: + `True` if a given structure is nested, i.e. is a sequence, a mapping, + or a namedtuple, and `False` otherwise. + """ + return tree_impl.is_nested(structure) + + +@keras_export("keras.tree.traverse") +def traverse(func, structure, top_down=True): + """Traverses the given nested structure, applying the given function. + + The traversal is depth-first. If `top_down` is True (default), parents + are returned before their children (giving the option to avoid traversing + into a sub-tree). + + Examples: + + >>> v = [] + >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=True) + [(1, 2), [3], {'a': 4}] + >>> v + [[(1, 2), [3], {'a': 4}], (1, 2), 1, 2, [3], 3, {'a': 4}, 4] + + >>> v = [] + >>> keras.tree.traverse(v.append, [(1, 2), [3], {"a": 4}], top_down=False) + [(1, 2), [3], {'a': 4}] + >>> v + [1, 2, (1, 2), 3, [3], 4, {'a': 4}, [(1, 2), [3], {'a': 4}]] + + Args: + func: The function to be applied to each sub-nest of the structure. + + When traversing top-down: + If `func(subtree) is None` the traversal continues into the + sub-tree. + If `func(subtree) is not None` the traversal does not continue + into the sub-tree. The sub-tree will be replaced by `func(subtree)` + in the returned structure (to replace the sub-tree with `None`, use + the special value `MAP_TO_NONE`). + + When traversing bottom-up: + If `func(subtree) is None` the traversed sub-tree is returned + unaltered. 
+ If `func(subtree) is not None` the sub-tree will be replaced by + `func(subtree)` in the returned structure (to replace the sub-tree + with None, use the special value `MAP_TO_NONE`). + + structure: The structure to traverse. + top_down: If True, parent structures will be visited before their + children. + + Returns: + The structured output from the traversal. + + Raises: + TypeError: If `func` is not callable. + """ + return tree_impl.traverse(func, structure, top_down=top_down) + + +@keras_export("keras.tree.flatten") +def flatten(structure): + """Flattens a possibly nested structure into a list. + + In the case of dict instances, the sequence consists of the values, + sorted by key to ensure deterministic behavior. However, instances of + `collections.OrderedDict` are handled differently: their sequence order is + used instead of the sorted keys. The same convention is followed in + `pack_sequence_as`. This correctly unflattens dicts and `OrderedDict` after + they have been flattened, or vice-versa. + + Dictionaries with non-sortable keys are not supported. + + Examples: + + >>> keras.tree.flatten([[1, 2, 3], [4, [5], [[6]]]]) + [1, 2, 3, 4, 5, 6] + >>> keras.tree.flatten(None) + [None] + >>> keras.tree.flatten(1) + [1] + >>> keras.tree.flatten({100: 'world!', 6: 'Hello'}) + ['Hello', 'world!'] + + Args: + structure: An arbitrarily nested structure. + + Returns: + A list, the flattened version of the input `structure`. + """ + return tree_impl.flatten(structure) + + +@keras_export("keras.tree.flatten_with_path") +def flatten_with_path(structure): + """Flattens a possibly nested structure into a list. + + This is a variant of flattens() which produces a + list of pairs: `(path, item)`. A path is a tuple of indices and/or keys + which uniquely identifies the position of the corresponding item. + + Dictionaries with non-sortable keys are not supported. 
+ + Examples: + + >>> keras.flatten_with_path([{"foo": 42}]) + [((0, 'foo'), 42)] + + + Args: + structure: An arbitrarily nested structure. + + Returns: + A list of `(path, item)` pairs corresponding to the flattened + version of the input `structure`. + """ + return tree_impl.flatten_with_path(structure) + + +@keras_export("keras.tree.map_structure") +def map_structure(func, *structures): + """Maps `func` through given structures. + + Examples: + + >>> structure = [[1], [2], [3]] + >>> keras.tree.map_structure(lambda v: v**2, structure) + [[1], [4], [9]] + >>> keras.tree.map_structure(lambda x, y: x * y, structure, structure) + [[1], [4], [9]] + + >>> Foo = collections.namedtuple('Foo', ['a', 'b']) + >>> structure = Foo(a=1, b=2) + >>> keras.tree.map_structure(lambda v: v * 2, structure) + Foo(a=2, b=4) + + Args: + func: A callable that accepts as many arguments as there are structures. + *structures: Arbitrarily nested structures of the same layout. + + Returns: + A new structure with the same layout as the given ones. + + Raises: + TypeError: If `structures` is empty or `func` is not callable. + ValueError: If there is more than one items in `structures` and some of + the nested structures don't match according to the rules of + `assert_same_structure`. + """ + return tree_impl.map_structure(func, *structures) + + +@keras_export("keras.tree.map_structure_up_to") +def map_structure_up_to(shallow_structure, func, *structures): + """Maps `func` through given structures up to `shallow_structure`. + + This is a variant of `map_structure` which only maps the given structures + up to `shallow_structure`. All further nested components are retained as-is. 
+ + Examples: + + >>> shallow_structure = [None, None] + >>> structure = [[1, 1], [2, 2]] + >>> keras.tree.map_structure_up_to(shallow_structure, len, structure) + [2, 2] + + >>> shallow_structure = [None, [None, None]] + >>> keras.tree.map_structure_up_to(shallow_structure, str, structure) + ['[1, 1]', ['2', '2']] + + Args: + shallow_structure: A structure with layout common to all `structures`. + func: A callable that accepts as many arguments as there are structures. + *structures: Arbitrarily nested structures of the same layout. + + Returns: + A new structure with the same layout as `shallow_structure`. + + Raises: + TypeError: If `structures` is empty or `func` is not callable. + ValueError: If one of the items in `structures` doesn't match the + nested structure of `shallow_structure` according to the rules of + `assert_same_structure`. Items in `structures` are allowed to be + nested deeper than `shallow_structure`, but they cannot be + shallower. + """ + return tree_impl.map_structure_up_to(shallow_structure, func, *structures) + + +@keras_export("keras.tree.assert_same_structure") +def assert_same_structure(a, b, check_types=None): + """Asserts that two structures are nested in the same way. + + This function verifies that the nested structures match. The leafs can be of + any type. At each level, the structures must be of the same type and have + the same number of elements. Instances of `dict`, `OrderedDict` and + `defaultdict` are all considered the same as long as they have the same set + of keys. However, `list`, `tuple`, `namedtuple` and `deque` are not the same + structures. Two namedtuples with identical fields and even identical names + are not the same structures. 
+ + Examples: + + >>> keras.tree.assert_same_structure([(0, 1)], [(2, 3)]) + + >>> Foo = collections.namedtuple('Foo', ['a', 'b']) + >>> AlsoFoo = collections.namedtuple('Foo', ['a', 'b']) + >>> keras.tree.assert_same_structure(Foo(0, 1), Foo(2, 3)) + >>> keras.tree.assert_same_structure(Foo(0, 1), AlsoFoo(2, 3)) + Traceback (most recent call last): + ... + ValueError: The two structures don't have the same nested structure. + ... + + Args: + a: an arbitrarily nested structure. + b: an arbitrarily nested structure. + check_types: Deprecated. The behavior of this flag was inconsistent, it + no longer has any effect. For a looser check, use + `assert_same_paths` instead, which considers `list`, `tuple`, + `namedtuple` and `deque` as matching structures. + + Raises: + ValueError: If the two structures `a` and `b` don't match. + """ + if check_types is not None: + if check_types: + warnings.warn( + "The `check_types` argument is deprecated and no longer has " + "any effect, please remove.", + DeprecationWarning, + stacklevel=2, + ) + else: + warnings.warn( + "The `check_types` argument is deprecated and no longer has " + "any effect. For a looser check, use " + "`keras.tree.assert_same_paths()`, which considers `list`, " + "`tuple`, `namedtuple` and `deque` as matching", + DeprecationWarning, + stacklevel=2, + ) + return tree_impl.assert_same_structure(a, b) + + +@keras_export("keras.tree.assert_same_paths") +def assert_same_paths(a, b): + """Asserts that two structures have identical paths in their tree structure. + + This function verifies that two nested structures have the same paths. + Unlike `assert_same_structure`, this function only checks the paths + and ignores the collection types. + For Sequences, to path is the index: 0, 1, 2, etc. For Mappings, the path is + the key, for instance "a", "b", "c". Note that namedtuples also use indices + and not field names for the path. 
+ + Examples: + >>> keras.tree.assert_same_paths([0, 1], (2, 3)) + >>> Point1 = collections.namedtuple('Point1', ['x', 'y']) + >>> Point2 = collections.namedtuple('Point2', ['x', 'y']) + >>> keras.tree.assert_same_paths(Point1(0, 1), Point2(2, 3)) + + Args: + a: an arbitrarily nested structure. + b: an arbitrarily nested structure. + + Raises: + ValueError: If the paths in structure `a` don't match the paths in + structure `b`. The error message will include the specific paths + that differ. + """ + return tree_impl.assert_same_paths(a, b) + + +@keras_export("keras.tree.pack_sequence_as") +def pack_sequence_as(structure, flat_sequence): + """Returns a given flattened sequence packed into a given structure. + + If `structure` is an atom, `flat_sequence` must be a single-item list; in + this case the return value is `flat_sequence[0]`. + + If `structure` is or contains a dict instance, the keys will be sorted to + pack the flat sequence in deterministic order. However, instances of + `collections.OrderedDict` are handled differently: their sequence order is + used instead of the sorted keys. The same convention is followed in + `flatten`. This correctly repacks dicts and `OrderedDicts` after they have + been flattened, or vice-versa. + + Dictionaries with non-sortable keys are not supported. + + Examples: + + >>> structure = {"key3": "", "key1": "", "key2": ""} + >>> flat_sequence = ["value1", "value2", "value3"] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + {"key3": "value3", "key1": "value1", "key2": "value2"} + + >>> structure = (("a", "b"), ("c", "d", "e"), "f") + >>> flat_sequence = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + ((1.0, 2.0), (3.0, 4.0, 5.0), 6.0) + + >>> structure = {"key3": {"c": ("alpha", "beta"), "a": ("gamma")}, + ... 
"key1": {"e": "val1", "d": "val2"}} + >>> flat_sequence = ["val2", "val1", 3.0, 1.0, 2.0] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + {'key3': {'c': (1.0, 2.0), 'a': 3.0}, 'key1': {'e': 'val1', 'd': 'val2'}} + + >>> structure = ["a"] + >>> flat_sequence = [np.array([[1, 2], [3, 4]])] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + [array([[1, 2], + [3, 4]])] + + >>> structure = ["a"] + >>> flat_sequence = [keras.ops.ones([2, 2])] + >>> keras.tree.pack_sequence_as(structure, flat_sequence) + [array([[1., 1.], + [1., 1.]]] + + Args: + structure: Arbitrarily nested structure. + flat_sequence: Flat sequence to pack. + + Returns: + `flat_sequence` converted to have the same recursive structure as + `structure`. + + Raises: + TypeError: If `flat_sequence` is not iterable. + ValueError: If `flat_sequence` cannot be repacked as `structure`; for + instance, if `flat_sequence` has too few or too many elements. + """ + return tree_impl.pack_sequence_as(structure, flat_sequence) + + +@keras_export("keras.tree.lists_to_tuples") +def lists_to_tuples(structure): + """Returns the structure with list instances changed to tuples. + + Args: + structure: Arbitrarily nested structure. + + Returns: + The same structure but with tuples instead of lists. + """ + return tree_impl.lists_to_tuples(structure) + + +@keras_export("keras.tree.map_shape_structure") +def map_shape_structure(func, structure): + """Variant of keras.tree.map_structure that operates on shape tuples. + + Tuples containing ints and Nones are considered shapes and passed to `func`. + + Args: + structure: Arbitrarily nested structure. + + Returns: + The same structure with `func` applied. 
+ """ + return tree_impl.map_shape_structure(func, structure) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c503a2043776abf97036e01bafcbf18d4c2852aa --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__init__.py @@ -0,0 +1,26 @@ +from keras.src.utils.audio_dataset_utils import audio_dataset_from_directory +from keras.src.utils.dataset_utils import split_dataset +from keras.src.utils.file_utils import get_file +from keras.src.utils.image_dataset_utils import image_dataset_from_directory +from keras.src.utils.image_utils import array_to_img +from keras.src.utils.image_utils import img_to_array +from keras.src.utils.image_utils import load_img +from keras.src.utils.image_utils import save_img +from keras.src.utils.io_utils import disable_interactive_logging +from keras.src.utils.io_utils import enable_interactive_logging +from keras.src.utils.io_utils import is_interactive_logging_enabled +from keras.src.utils.model_visualization import model_to_dot +from keras.src.utils.model_visualization import plot_model +from keras.src.utils.numerical_utils import normalize +from keras.src.utils.numerical_utils import to_categorical +from keras.src.utils.progbar import Progbar +from keras.src.utils.python_utils import default +from keras.src.utils.python_utils import is_default +from keras.src.utils.python_utils import removeprefix +from keras.src.utils.python_utils import removesuffix +from keras.src.utils.rng_utils import set_random_seed +from keras.src.utils.sequence_utils import pad_sequences +from keras.src.utils.text_dataset_utils import text_dataset_from_directory +from keras.src.utils.timeseries_dataset_utils import ( + timeseries_dataset_from_array, +) diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99b4154f80e2a3bbc7d62efa146ba44bbbd6604b Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ca88c8c5d6b14f418f860fba9f6b29acc278820 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/argument_validation.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75c0ff8f190be2065887c0fdd2c2cb3a063e490f Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/audio_dataset_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..afb06d65f62d5e5c4847ec27b7face87614897d8 Binary files /dev/null and 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/backend_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2295c0e124061bb85118a410d563c482eb7578dd Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/code_stats.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0b10ecc0480ebbe9bb7fb738861d68cffd0b650 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/config.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/dataset_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..020db5da3249012b96028e469ba0152db9c9163c Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/dataset_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/dtype_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/dtype_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa745aa9cc67f58065a15fc188b0b898b7c8f693 
Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/dtype_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/file_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/file_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6698426a3b7c9fafb49378501ea812e01c8f5e73 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/file_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/image_dataset_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/image_dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c9cb7bddc9dfab54ff6e1af6dd010b8c61c1f3 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/image_dataset_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/image_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/image_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2269eb87e96afe85c6eb4ca8ab1a6b01f4abd745 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/image_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/io_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/io_utils.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..902f5d1fd8e08870ee0671e01203158fa177f85f Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/io_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/jax_layer.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/jax_layer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c7ea4ee718417d8405c99725653463ef4ca2a16 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/jax_layer.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/jax_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/jax_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e35618c70280109ff68ef6bf483cbbd224a3a188 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/jax_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/model_visualization.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/model_visualization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e19e1a245eec642f21db908b7cb4444426eecf69 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/model_visualization.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/module_utils.cpython-310.pyc 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/module_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29d97d026691b70c95fac08c3c2c38f6b6b48709 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/module_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/naming.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/naming.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4556fc035053728d324ea2b13bdbcfc88f2e501 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/naming.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/numerical_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/numerical_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeb0f00f2384282b84fdc50345ffe56865fc4e53 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/numerical_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/progbar.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/progbar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e160b13d1e6b561ec0188c245e319ab8d8a7f082 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/progbar.cpython-310.pyc differ diff --git 
a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/python_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/python_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9167e711ad75af5ac916ecbad7f00c132c03654 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/python_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/rng_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/rng_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c46cd677692a106722812078d91984c427f4befa Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/rng_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/sequence_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/sequence_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6f8f7bc35e0bac49fa60b5f3b0358967f23bd72 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/sequence_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/summary_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/summary_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b63db18e2a06a39e6dbd234e99e4873eb5c68f5 Binary files /dev/null and 
b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/summary_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/text_dataset_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/text_dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..faf525fbe5a9af266c4e27a463cb68414dc2279c Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/text_dataset_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/tf_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/tf_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d64c8c9b96e7ba1d4f9000b15ceb6ef45f9e1177 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/tf_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/timeseries_dataset_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/timeseries_dataset_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4ac507cfad95ca21654c5a93194dac89fbdf2db Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/timeseries_dataset_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/torch_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/torch_utils.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..9b296e13955559aa225a92eb94f9c3d99fe13c4f Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/torch_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/traceback_utils.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/traceback_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74d47bd97718d767b243bbc9dccdf5c127945e46 Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/traceback_utils.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/tracking.cpython-310.pyc b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/tracking.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ff76a1a6b89503bd4d72be8ef431df42f0cba4c Binary files /dev/null and b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/__pycache__/tracking.cpython-310.pyc differ diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/argument_validation.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/argument_validation.py new file mode 100644 index 0000000000000000000000000000000000000000..8f772b11e5c0d26d44100e23bfec6a84ffe02a1c --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/argument_validation.py @@ -0,0 +1,92 @@ +def standardize_tuple(value, n, name, allow_zero=False): + """Transforms non-negative/positive integer/integers into an integer tuple. + + Args: + value: int or iterable of ints. The value to validate and convert. + n: int. The size of the tuple to be returned. + name: string. 
The name of the argument being validated, e.g. "strides" + or "kernel_size". This is only used to format error messages. + allow_zero: bool, defaults to `False`. A `ValueError` will raised + if zero is received and this argument is `False`. + + Returns: + A tuple of n integers. + """ + error_msg = ( + f"The `{name}` argument must be a tuple of {n} integers. " + f"Received {name}={value}" + ) + + if isinstance(value, int): + value_tuple = (value,) * n + else: + try: + value_tuple = tuple(value) + except TypeError: + raise ValueError(error_msg) + if len(value_tuple) != n: + raise ValueError(error_msg) + for single_value in value_tuple: + try: + int(single_value) + except (ValueError, TypeError): + error_msg += ( + f"including element {single_value} of " + f"type {type(single_value)}" + ) + raise ValueError(error_msg) + + if allow_zero: + unqualified_values = {v for v in value_tuple if v < 0} + req_msg = ">= 0" + else: + unqualified_values = {v for v in value_tuple if v <= 0} + req_msg = "> 0" + + if unqualified_values: + error_msg += ( + f", including values {unqualified_values}" + f" that do not satisfy `value {req_msg}`" + ) + raise ValueError(error_msg) + + return value_tuple + + +def standardize_padding(value, allow_causal=False): + if isinstance(value, (list, tuple)): + return value + padding = value.lower() + if allow_causal: + allowed_values = {"valid", "same", "causal"} + else: + allowed_values = {"valid", "same"} + if padding not in allowed_values: + raise ValueError( + "The `padding` argument must be a list/tuple or one of " + f"{allowed_values}. 
" + f"Received: {padding}" + ) + return padding + + +def validate_string_arg( + value, + allowable_strings, + caller_name, + arg_name, + allow_none=False, + allow_callables=False, +): + """Validates the correctness of a string-based arg.""" + if allow_none and value is None: + return + elif allow_callables and callable(value): + return + elif isinstance(value, str) and value in allowable_strings: + return + raise ValueError( + f"Unknown value for `{arg_name}` argument of {caller_name}. " + f"Allowed values are: {allowable_strings}. Received: " + f"{arg_name}={value}" + ) diff --git a/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/audio_dataset_utils.py b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/audio_dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6f27d37c85cb6745d5424ee68503bd2f760a1df --- /dev/null +++ b/SwarmUI/dlbackend/ComfyUI/venv/lib/python3.10/site-packages/keras/src/utils/audio_dataset_utils.py @@ -0,0 +1,453 @@ +import numpy as np + +from keras.src.api_export import keras_export +from keras.src.utils import dataset_utils +from keras.src.utils.module_utils import tensorflow as tf +from keras.src.utils.module_utils import tensorflow_io as tfio + +ALLOWED_FORMATS = (".wav",) + + +@keras_export("keras.utils.audio_dataset_from_directory") +def audio_dataset_from_directory( + directory, + labels="inferred", + label_mode="int", + class_names=None, + batch_size=32, + sampling_rate=None, + output_sequence_length=None, + ragged=False, + shuffle=True, + seed=None, + validation_split=None, + subset=None, + follow_links=False, + verbose=True, +): + """Generates a `tf.data.Dataset` from audio files in a directory. 
+ + If your directory structure is: + + ``` + main_directory/ + ...class_a/ + ......a_audio_1.wav + ......a_audio_2.wav + ...class_b/ + ......b_audio_1.wav + ......b_audio_2.wav + ``` + + Then calling `audio_dataset_from_directory(main_directory, + labels='inferred')` + will return a `tf.data.Dataset` that yields batches of audio files from + the subdirectories `class_a` and `class_b`, together with labels + 0 and 1 (0 corresponding to `class_a` and 1 corresponding to `class_b`). + + Only `.wav` files are supported at this time. + + Args: + directory: Directory where the data is located. + If `labels` is `"inferred"`, it should contain subdirectories, + each containing audio files for a class. Otherwise, the directory + structure is ignored. + labels: Either "inferred" (labels are generated from the directory + structure), `None` (no labels), or a list/tuple of integer labels + of the same size as the number of audio files found in + the directory. Labels should be sorted according to the + alphanumeric order of the audio file paths + (obtained via `os.walk(directory)` in Python). + label_mode: String describing the encoding of `labels`. Options are: + - `"int"`: means that the labels are encoded as integers (e.g. for + `sparse_categorical_crossentropy` loss). + - `"categorical"` means that the labels are encoded as a categorical + vector (e.g. for `categorical_crossentropy` loss) + - `"binary"` means that the labels (there can be only 2) + are encoded as `float32` scalars with values 0 + or 1 (e.g. for `binary_crossentropy`). + - `None` (no labels). + class_names: Only valid if "labels" is `"inferred"`. + This is the explicit list of class names + (must match names of subdirectories). Used to control the order + of the classes (otherwise alphanumerical order is used). + batch_size: Size of the batches of data. Default: 32. If `None`, + the data will not be batched + (the dataset will yield individual samples). 
+ sampling_rate: Audio sampling rate (in samples per second). + output_sequence_length: Maximum length of an audio sequence. Audio files + longer than this will be truncated to `output_sequence_length`. + If set to `None`, then all sequences in the same batch will + be padded to the + length of the longest sequence in the batch. + ragged: Whether to return a Ragged dataset (where each sequence has its + own length). Defaults to `False`. + shuffle: Whether to shuffle the data. + If set to `False`, sorts the data in alphanumeric order. + Defaults to `True`. + seed: Optional random seed for shuffling and transformations. + validation_split: Optional float between 0 and 1, fraction of data to + reserve for validation. + subset: Subset of the data to return. One of `"training"`, + `"validation"` or `"both"`. Only used if `validation_split` is set. + follow_links: Whether to visits subdirectories pointed to by symlinks. + Defaults to `False`. + verbose: Whether to display number information on classes and + number of files found. Defaults to `True`. + + Returns: + + A `tf.data.Dataset` object. + + - If `label_mode` is `None`, it yields `string` tensors of shape + `(batch_size,)`, containing the contents of a batch of audio files. + - Otherwise, it yields a tuple `(audio, labels)`, where `audio` + has shape `(batch_size, sequence_length, num_channels)` and `labels` + follows the format described + below. + + Rules regarding labels format: + + - if `label_mode` is `int`, the labels are an `int32` tensor of shape + `(batch_size,)`. + - if `label_mode` is `binary`, the labels are a `float32` tensor of + 1s and 0s of shape `(batch_size, 1)`. + - if `label_mode` is `categorical`, the labels are a `float32` tensor + of shape `(batch_size, num_classes)`, representing a one-hot + encoding of the class index. 
+ """ + if labels not in ("inferred", None): + if not isinstance(labels, (list, tuple)): + raise ValueError( + "The `labels` argument should be a list/tuple of integer " + "labels, of the same size as the number of audio files in " + "the target directory. If you wish to infer the labels from " + "the subdirectory names in the target directory," + ' pass `labels="inferred"`. ' + "If you wish to get a dataset that only contains audio samples " + f"(no labels), pass `labels=None`. Received: labels={labels}" + ) + if class_names: + raise ValueError( + "You can only pass `class_names` if " + f'`labels="inferred"`. Received: labels={labels}, and ' + f"class_names={class_names}" + ) + if label_mode not in {"int", "categorical", "binary", None}: + raise ValueError( + '`label_mode` argument must be one of "int", "categorical", ' + '"binary", ' + f"or None. Received: label_mode={label_mode}" + ) + + if ragged and output_sequence_length is not None: + raise ValueError( + "Cannot set both `ragged` and `output_sequence_length`" + ) + + if sampling_rate is not None: + if not isinstance(sampling_rate, int): + raise ValueError( + "`sampling_rate` should have an integer value. " + f"Received: sampling_rate={sampling_rate}" + ) + + if sampling_rate <= 0: + raise ValueError( + "`sampling_rate` should be higher than 0. " + f"Received: sampling_rate={sampling_rate}" + ) + + if not tfio.available: + raise ImportError( + "To use the argument `sampling_rate`, you should install " + "tensorflow_io. You can install it via `pip install " + "tensorflow-io`." 
+ ) + + if labels is None or label_mode is None: + labels = None + label_mode = None + + dataset_utils.check_validation_split_arg( + validation_split, subset, shuffle, seed + ) + + if seed is None: + seed = np.random.randint(1e6) + if batch_size is not None: + shuffle_buffer_size = batch_size * 8 + else: + shuffle_buffer_size = 1024 + + file_paths, labels, class_names = dataset_utils.index_directory( + directory, + labels, + formats=ALLOWED_FORMATS, + class_names=class_names, + shuffle=shuffle, + seed=seed, + follow_links=follow_links, + verbose=verbose, + ) + + if label_mode == "binary" and len(class_names) != 2: + raise ValueError( + 'When passing `label_mode="binary"`, there must be exactly 2 ' + f"class_names. Received: class_names={class_names}" + ) + + if subset == "both": + train_dataset, val_dataset = get_training_and_validation_dataset( + file_paths=file_paths, + labels=labels, + validation_split=validation_split, + directory=directory, + label_mode=label_mode, + class_names=class_names, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + train_dataset = prepare_dataset( + dataset=train_dataset, + batch_size=batch_size, + class_names=class_names, + output_sequence_length=output_sequence_length, + ragged=ragged, + ) + val_dataset = prepare_dataset( + dataset=val_dataset, + batch_size=batch_size, + class_names=class_names, + output_sequence_length=output_sequence_length, + ragged=ragged, + ) + return train_dataset, val_dataset + + else: + dataset = get_dataset( + file_paths=file_paths, + labels=labels, + directory=directory, + validation_split=validation_split, + subset=subset, + label_mode=label_mode, + class_names=class_names, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + dataset = prepare_dataset( + 
dataset=dataset, + batch_size=batch_size, + class_names=class_names, + output_sequence_length=output_sequence_length, + ragged=ragged, + ) + return dataset + + +def prepare_dataset( + dataset, + batch_size, + class_names, + output_sequence_length, + ragged, +): + dataset = dataset.prefetch(tf.data.AUTOTUNE) + if batch_size is not None: + if output_sequence_length is None and not ragged: + dataset = dataset.padded_batch( + batch_size, padded_shapes=([None, None], []) + ) + else: + dataset = dataset.batch(batch_size) + + # Users may need to reference `class_names`. + dataset.class_names = class_names + return dataset + + +def get_training_and_validation_dataset( + file_paths, + labels, + validation_split, + directory, + label_mode, + class_names, + sampling_rate, + output_sequence_length, + ragged, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + ( + file_paths_train, + labels_train, + ) = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, "training" + ) + if not file_paths_train: + raise ValueError( + f"No training audio files found in directory {directory}. " + f"Allowed format(s): {ALLOWED_FORMATS}" + ) + + file_paths_val, labels_val = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, "validation" + ) + if not file_paths_val: + raise ValueError( + f"No validation audio files found in directory {directory}. 
" + f"Allowed format(s): {ALLOWED_FORMATS}" + ) + + train_dataset = paths_and_labels_to_dataset( + file_paths=file_paths_train, + labels=labels_train, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + + val_dataset = paths_and_labels_to_dataset( + file_paths=file_paths_val, + labels=labels_val, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=False, + ) + + return train_dataset, val_dataset + + +def get_dataset( + file_paths, + labels, + directory, + validation_split, + subset, + label_mode, + class_names, + sampling_rate, + output_sequence_length, + ragged, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + file_paths, labels = dataset_utils.get_training_or_validation_split( + file_paths, labels, validation_split, subset + ) + if not file_paths: + raise ValueError( + f"No audio files found in directory {directory}. 
" + f"Allowed format(s): {ALLOWED_FORMATS}" + ) + + return paths_and_labels_to_dataset( + file_paths=file_paths, + labels=labels, + label_mode=label_mode, + num_classes=len(class_names) if class_names else 0, + sampling_rate=sampling_rate, + output_sequence_length=output_sequence_length, + ragged=ragged, + shuffle=shuffle, + shuffle_buffer_size=shuffle_buffer_size, + seed=seed, + ) + + +def read_and_decode_audio( + path, sampling_rate=None, output_sequence_length=None +): + """Reads and decodes audio file.""" + audio = tf.io.read_file(path) + + if output_sequence_length is None: + output_sequence_length = -1 + + audio, default_audio_rate = tf.audio.decode_wav( + contents=audio, desired_samples=output_sequence_length + ) + if sampling_rate is not None: + # default_audio_rate should have dtype=int64 + default_audio_rate = tf.cast(default_audio_rate, tf.int64) + audio = tfio.audio.resample( + input=audio, rate_in=default_audio_rate, rate_out=sampling_rate + ) + return audio + + +def paths_and_labels_to_dataset( + file_paths, + labels, + label_mode, + num_classes, + sampling_rate, + output_sequence_length, + ragged, + shuffle=False, + shuffle_buffer_size=None, + seed=None, +): + """Constructs a fixed-size dataset of audio and labels.""" + path_ds = tf.data.Dataset.from_tensor_slices(file_paths) + if label_mode: + label_ds = dataset_utils.labels_to_dataset( + labels, label_mode, num_classes + ) + ds = tf.data.Dataset.zip((path_ds, label_ds)) + else: + ds = path_ds + + if shuffle: + ds = ds.shuffle(buffer_size=shuffle_buffer_size or 1024, seed=seed) + + if label_mode: + ds = ds.map( + lambda x, y: ( + read_and_decode_audio(x, sampling_rate, output_sequence_length), + y, + ), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + if ragged: + ds = ds.map( + lambda x, y: (tf.RaggedTensor.from_tensor(x), y), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + else: + ds = ds.map( + lambda x: read_and_decode_audio( + x, sampling_rate, output_sequence_length + ), + 
num_parallel_calls=tf.data.AUTOTUNE, + ) + + if ragged: + ds = ds.map( + lambda x: tf.RaggedTensor.from_tensor(x), + num_parallel_calls=tf.data.AUTOTUNE, + ) + + return ds