claudeson / claudson /ai /lib /python3.12 /site-packages /keras /src /legacy /saving /serialization.py
| """Legacy serialization logic for Keras models.""" | |
| import contextlib | |
| import inspect | |
| import json | |
| import threading | |
| import weakref | |
| # isort: off | |
| from keras.src.api_export import keras_export | |
| from keras.src.saving import object_registration | |
# When this flag is set, a `NotImplementedError` raised by `get_config()` on a
# custom model or layer is skipped instead of propagated. It is only meant to
# be switched on while saving to SavedModel, where the config isn't required.
_SKIP_FAILED_SERIALIZATION = False

# Marker key: a layer that has no defined config is serialized as a dictionary
# containing this key.
_LAYER_UNDEFINED_CONFIG_KEY = "layer was saved without config"

# Config key holding the unique, per-object ID of a shared object.
#
# The ID is what lets the loader rebuild the network faithfully: it tells apart
# a config that describes a brand-new object to create from a config that
# merely points at an object which has already been created.
SHARED_OBJECT_KEY = "shared_object_id"

# Thread-local holders for the shared-object state (disabled flag, active
# loading scope, active saving scope).
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()

# Attributes of a thread-local object exist only in the thread that assigned
# them, so no defaults can be installed here at import time. The accessor
# functions below supply the default values instead.
def _shared_object_disabled():
    """Report whether shared object handling is disabled for this thread.

    The flag is stored on a thread-local object, so a thread that has never
    assigned it has no such attribute; that case is reported as `False`.
    """
    try:
        return SHARED_OBJECT_DISABLED.disabled
    except AttributeError:
        return False
def _shared_object_loading_scope():
    """Return the shared object loading scope active on this thread.

    When the calling thread has no loading scope recorded, a fresh
    `NoopLoadingScope` is returned so callers can always call `get`/`set`.
    """
    try:
        return SHARED_OBJECT_LOADING.scope
    except AttributeError:
        return NoopLoadingScope()
def _shared_object_saving_scope():
    """Return the shared object saving scope active on this thread.

    Returns `None` when the calling thread has no saving scope recorded.
    """
    try:
        return SHARED_OBJECT_SAVING.scope
    except AttributeError:
        return None
class DisableSharedObjectScope:
    """A context manager for disabling handling of shared objects.

    Disables shared object handling for both saving and loading.

    Created primarily for use with `clone_model`, which does extra surgery that
    is incompatible with shared objects.

    The scope may be nested: on exit the disabled flag is put back to the value
    it had on entry, so leaving an inner scope does not re-enable shared object
    handling while an outer scope is still active.
    """

    def __enter__(self):
        # Capture the previous flag value *before* overwriting it, so that
        # `__exit__` can restore it instead of blindly resetting to False.
        self._orig_disabled = _shared_object_disabled()
        SHARED_OBJECT_DISABLED.disabled = True
        # Remember the scopes that were active on entry; they are reinstated
        # on exit.
        self._orig_loading_scope = _shared_object_loading_scope()
        self._orig_saving_scope = _shared_object_saving_scope()

    def __exit__(self, *args, **kwargs):
        SHARED_OBJECT_DISABLED.disabled = self._orig_disabled
        SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
        SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
class NoopLoadingScope:
    """The default shared object loading scope; it stores nothing.

    Exists so that serialization code which doesn't care about shared objects
    (e.g. when serializing a single object) can call `get`/`set`
    unconditionally.
    """

    def get(self, unused_object_id):
        """Always answer `None`: no object is ever remembered."""
        return None

    def set(self, object_id, obj):
        """Accept and discard the object."""
        return None
class SharedObjectLoadingScope:
    """A context manager that remembers objects as they are loaded.

    While deserializing we may meet objects that are shared across multiple
    layers. To restore the network structure faithfully, this scope lets the
    loader hand back the object it already built for a given shared object ID
    rather than creating a clone.
    """

    def __enter__(self):
        if not _shared_object_disabled():
            # Publish this scope for the current thread and start with an
            # empty ID -> object table.
            SHARED_OBJECT_LOADING.scope = self
            self._obj_ids_to_obj = {}
            return self
        # Shared object handling is switched off: hand out an inert scope.
        return NoopLoadingScope()

    def get(self, object_id):
        """Look up a previously instantiated object by its shared object ID.

        Args:
            object_id: shared object ID to look for; may be `None`.

        Returns:
            The object stored under this ID, or `None` if the ID is `None` or
            has not been seen before.
        """
        # `None` is handled here so that calling code doesn't have to check.
        if object_id is not None:
            return self._obj_ids_to_obj.get(object_id)
        return None

    def set(self, object_id, obj):
        """Remember an instantiated object for later lookup and sharing.

        A `None` ID is ignored.
        """
        if object_id is not None:
            self._obj_ids_to_obj[object_id] = obj

    def __exit__(self, *args, **kwargs):
        # Leaving the scope: the thread goes back to an inert loading scope.
        SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
class SharedObjectConfig(dict):
    """A config dictionary that counts how often it is referenced.

    As soon as the config is referenced more than once, it attaches a shared
    object ID to itself, which allows the shared object to be reconstructed
    properly at load time.

    Subclassing `collections.UserDict` or `collections.Mapping` would usually
    be the more proper choice, but python's json encoder does not support
    `Mapping`s and JSON encoding must keep working since this class is used for
    serialization. Subclassing `dict` is safe here because no core method is
    overridden; the class only adds reference counting.
    """

    def __init__(self, base_config, object_id, **kwargs):
        super().__init__(base_config, **kwargs)
        # ID to publish under the shared-object key once the config is reused.
        self.object_id = object_id
        # The config starts with the single reference held by its creator.
        self.ref_count = 1

    def increment_ref_count(self):
        """Record one more reference, tagging the config on its first reuse."""
        # The ID is attached only when the object is seen a second time, i.e.
        # only when strictly necessary, which makes backwards compatibility
        # breakage less likely.
        seen_only_once = self.ref_count == 1
        if seen_only_once:
            self[SHARED_OBJECT_KEY] = self.object_id
        self.ref_count = self.ref_count + 1
class SharedObjectSavingScope:
    """Tracks the configs of shared objects while serializing."""

    def __enter__(self):
        if _shared_object_disabled():
            return None

        # Serialization can be triggered at several layers, so a saving scope
        # may be opened while another one is already active. In that case the
        # outermost scope is used and this inner one is ignored, since there is
        # not (yet) a reasonable use case for nested, distinct scopes.
        outer_scope = _shared_object_saving_scope()
        self._passthrough = outer_scope is not None
        if self._passthrough:
            return outer_scope

        SHARED_OBJECT_SAVING.scope = self
        self._shared_objects_config = weakref.WeakKeyDictionary()
        self._next_id = 0
        return self

    def get_config(self, obj):
        """Fetch the `SharedObjectConfig` already recorded for `obj`, if any.

        A successful lookup also increments the config's reference count.

        Args:
            obj: The object whose `SharedObjectConfig` is wanted.

        Returns:
            The recorded `SharedObjectConfig`, or `None` when `obj` has not
            been seen or cannot be used as a key.
        """
        try:
            known_config = self._shared_objects_config[obj]
        except (TypeError, KeyError):
            # A `TypeError` is thrown for unhashable objects (e.g. a subclass
            # of `AbstractBaseClass` that has not overridden `__hash__`); such
            # objects simply go without shared object support.
            return None
        else:
            known_config.increment_ref_count()
            return known_config

    def create_config(self, base_config, obj):
        """Build a new `SharedObjectConfig` for `obj` and try to record it."""
        new_config = SharedObjectConfig(base_config, self._next_id)
        self._next_id += 1
        # Unhashable objects (e.g. a subclass of `AbstractBaseClass` that has
        # not overridden `__hash__`) throw a `TypeError` here; the config is
        # still returned, just without shared object support.
        with contextlib.suppress(TypeError):
            self._shared_objects_config[obj] = new_config
        return new_config

    def __exit__(self, *args, **kwargs):
        # A passthrough scope never installed itself, so it has nothing to
        # tear down.
        if getattr(self, "_passthrough", False):
            return
        SHARED_OBJECT_SAVING.scope = None
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None
):
    """Return the serialization of the class with the given config.

    Args:
        cls_name: Value stored under the `"class_name"` key.
        cls_config: Value stored under the `"config"` key.
        obj: Optional object being serialized; used as the lookup key in the
            active `SharedObjectSavingScope`, if there is one.
        shared_object_id: Optional shared object ID to embed in the result.

    Returns:
        A plain dict when no saving scope is active or `obj` is `None`;
        otherwise the scope's `SharedObjectConfig` for `obj` (the existing one
        if `obj` was already seen, else a newly created one).
    """
    serialized = {"class_name": cls_name, "config": cls_config}

    # Some branches of the load path call this function too; there a shared
    # object ID may already exist and has to be retained.
    if shared_object_id is not None:
        serialized[SHARED_OBJECT_KEY] = shared_object_id

    saving_scope = _shared_object_saving_scope()
    if saving_scope is None or obj is None:
        return serialized

    # With an active scope, reuse the config already recorded for this object
    # if there is one. Reusing it stores an extra ID field in the config,
    # which lets the shared object relationship be re-created at load time.
    already_recorded = saving_scope.get_config(obj)
    if already_recorded is not None:
        return already_recorded
    return saving_scope.create_config(serialized, obj)
@contextlib.contextmanager
def skip_failed_serialization():
    """Context manager that turns on `_SKIP_FAILED_SERIALIZATION`.

    Inside the `with` block the module-level flag is `True`; on exit (normal
    or via an exception) it is put back to whatever value it had on entry, so
    nested uses unwind correctly.

    Note that the flag is a plain module global, not a thread-local value.
    """
    global _SKIP_FAILED_SERIALIZATION
    prev = _SKIP_FAILED_SERIALIZATION
    try:
        _SKIP_FAILED_SERIALIZATION = True
        yield
    finally:
        # Always restore the previous value, even if the body raised.
        _SKIP_FAILED_SERIALIZATION = prev
def serialize_keras_object(instance):
    """Serialize a Keras object into a JSON-compatible representation.

    When called underneath the `SharedObjectSavingScope` context manager, any
    object re-used across multiple layers is saved with a special shared
    object ID, which allows the network to be re-created properly during
    deserialization.

    Args:
        instance: The object to serialize.

    Returns:
        `None` for `None`; a dict-like, JSON-compatible config for objects
        that implement `get_config()`; the registered name for objects that
        only have a `__name__`.

    Raises:
        ValueError: If the object has neither `get_config()` nor `__name__`.
        NotImplementedError: Propagated from `get_config()` unless
            `_SKIP_FAILED_SERIALIZATION` is set.
    """
    instance = inspect.unwrap(instance)
    if instance is None:
        return None

    if not hasattr(instance, "get_config"):
        # Objects without a config (e.g. functions) are saved by name only.
        if hasattr(instance, "__name__"):
            return object_registration.get_registered_name(instance)
        raise ValueError(
            f"Cannot serialize {instance} because it doesn't implement "
            "`get_config()`."
        )

    name = object_registration.get_registered_name(instance.__class__)
    try:
        config = instance.get_config()
    except NotImplementedError:
        if not _SKIP_FAILED_SERIALIZATION:
            raise
        # Emit a placeholder config flagged with the undefined-config key.
        return serialize_keras_class_and_config(
            name, {_LAYER_UNDEFINED_CONFIG_KEY: True}
        )

    serialization_config = {}
    for key, item in config.items():
        if isinstance(item, str):
            serialization_config[key] = item
            continue

        # Anything that is not a string has to be converted to a string or a
        # dict for serialization (e.g. custom functions, custom classes).
        try:
            serialized_item = serialize_keras_object(item)
        except ValueError:
            # Not serializable this way: keep the raw value.
            serialization_config[key] = item
            continue

        # Tag dicts that were produced here from a non-dict value, so the
        # loader knows to deserialize them again.
        produced_dict = isinstance(serialized_item, dict)
        if produced_dict and not isinstance(item, dict):
            serialized_item["__passive_serialization__"] = True
        serialization_config[key] = serialized_item

    return serialize_keras_class_and_config(
        name, serialization_config, instance
    )
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Return the class and config for a serialized keras object.

    Args:
        config: Dict with `class_name` and `config` keys.
        module_objects: Optional dict of built-in objects used for the class
            lookup.
        custom_objects: Optional dict of custom objects used for lookups.
        printable_module_name: Human-readable type name used in the error
            raised for an unknown class.

    Returns:
        A `(cls, cls_config)` tuple. When `cls_config` is a dict it is the
        same dict object found in `config["config"]`, updated in place with
        any deserialized entries.

    Raises:
        ValueError: If `config` is malformed or the class cannot be found.
    """
    well_formed = (
        isinstance(config, dict)
        and "class_name" in config
        and "config" in config
    )
    if not well_formed:
        raise ValueError(
            f"Improper config format for {config}. "
            "Expecting python dict contains `class_name` and `config` as keys"
        )

    class_name = config["class_name"]
    cls = object_registration.get_registered_object(
        class_name, custom_objects, module_objects
    )
    if cls is None:
        raise ValueError(
            f"Unknown {printable_module_name}: '{class_name}'. "
            "Please ensure you are using a `keras.utils.custom_object_scope` "
            "and that this object is included in the scope. See "
            "https://www.tensorflow.org/guide/keras/save_and_serialize"
            "#registering_the_custom_object for details."
        )

    cls_config = config["config"]
    # A list-valued config is returned as-is together with the class, so the
    # caller can deserialize the entries recursively. Old sequential models
    # (e.g. `keras_version` == "2.0.6") are serialized in this structure:
    # "{'class_name': 'Sequential',
    #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
    if isinstance(cls_config, list):
        return (cls, cls_config)

    # Collect replacements first and write them back afterwards, so the dict
    # is not modified while it is being iterated.
    replacements = {}
    for key, item in cls_config.items():
        if key == "name":
            # The value of 'name' is assumed to be a plain string that must
            # not be deserialized as a function, even if it happens to be
            # identical to the name of a custom function.
            replacements[key] = item
            continue

        if isinstance(item, dict):
            if "__passive_serialization__" in item:
                replacements[key] = deserialize_keras_object(
                    item,
                    module_objects=module_objects,
                    custom_objects=custom_objects,
                    printable_module_name="config_item",
                )
            continue

        if isinstance(item, str):
            # TODO(momernick): Should this also have 'module_objects'?
            # Functions are saved by name only. A string matching a function
            # among the custom objects is converted back to that function.
            # A string field could clash with a custom function name, but this
            # should be rare. A clash with a custom object is harmless, since
            # the config of an object will always be a dict.
            candidate = object_registration.get_registered_object(
                item, custom_objects
            )
            if inspect.isfunction(candidate):
                replacements[key] = candidate

    for key, value in replacements.items():
        cls_config[key] = value
    return (cls, cls_config)
def deserialize_keras_object(
    identifier,
    module_objects=None,
    custom_objects=None,
    printable_module_name="object",
):
    """Turns the serialized form of a Keras object back into an actual object.

    This function is for mid-level library implementers rather than end users.

    Importantly, this utility requires you to provide the dict of
    `module_objects` to use for looking up the object config; this is not
    populated by default. If you need a deserialization utility that has
    preexisting knowledge of built-in Keras objects, use e.g.
    `keras.layers.deserialize(config)`, `keras.metrics.deserialize(config)`,
    etc.

    Calling `deserialize_keras_object` while underneath the
    `SharedObjectLoadingScope` context manager will cause any already-seen
    shared objects to be returned as-is rather than creating a new object.

    Args:
        identifier: the serialized form of the object. May be `None`, a config
            dict, a name string, or an already-deserialized function.
        module_objects: A dictionary of built-in objects to look the name up in.
            Generally, `module_objects` is provided by midlevel library
            implementers. `None` is treated as an empty dictionary.
        custom_objects: A dictionary of custom objects to look the name up in.
            Generally, `custom_objects` is provided by the end user.
        printable_module_name: A human-readable string representing the type of
            the object. Printed in case of exception.

    Returns:
        The deserialized object. `None` is returned for a `None` identifier.
        A name that resolves to a class yields a no-argument instance of that
        class; a name that resolves to anything else yields that object as-is.

    Raises:
        ValueError: If a name cannot be resolved, or if `identifier` is of a
            type that cannot be interpreted.

    Example:

    A mid-level library implementer might want to implement a utility for
    retrieving an object from its config, as such:

    ```python
    def deserialize(config, custom_objects=None):
        return deserialize_keras_object(
            config,
            module_objects=globals(),
            custom_objects=custom_objects,
            printable_module_name="MyObjectType",
        )
    ```

    This is how e.g. `keras.layers.deserialize()` is implemented.
    """
    if identifier is None:
        return None

    if isinstance(identifier, dict):
        # In this case we are dealing with a Keras config dictionary.
        config = identifier
        (cls, cls_config) = class_and_config_for_serialized_keras_object(
            config, module_objects, custom_objects, printable_module_name
        )

        # If this object has already been loaded (i.e. it's shared between
        # multiple objects), return the already-loaded object.
        shared_object_id = config.get(SHARED_OBJECT_KEY)
        shared_object = _shared_object_loading_scope().get(shared_object_id)
        if shared_object is not None:
            return shared_object

        custom_objects = custom_objects or {}
        if hasattr(cls, "from_config"):
            arg_spec = inspect.getfullargspec(cls.from_config)

            # TODO(nkovela): Swap find and replace args during Keras 3.0 release
            # The find and replace strings are currently identical, so this
            # call only round-trips the config through JSON.
            cls_config = _find_replace_nested_dict(
                cls_config, "keras.", "keras."
            )

            if "custom_objects" in arg_spec.args:
                # `from_config` accepts custom objects explicitly: pass the
                # global registry overlaid with the caller's entries.
                deserialized_obj = cls.from_config(
                    cls_config,
                    custom_objects={
                        **object_registration.GLOBAL_CUSTOM_OBJECTS,
                        **custom_objects,
                    },
                )
            else:
                with object_registration.CustomObjectScope(custom_objects):
                    deserialized_obj = cls.from_config(cls_config)
        else:
            # Then `cls` may be a function returning a class.
            # in this case by convention `config` holds
            # the kwargs of the function.
            with object_registration.CustomObjectScope(custom_objects):
                deserialized_obj = cls(**cls_config)

        # Add object to shared objects, in case we find it referenced again.
        _shared_object_loading_scope().set(shared_object_id, deserialized_obj)
        return deserialized_obj

    if isinstance(identifier, str):
        object_name = identifier
        # NOTE(review): the dict branch above reads
        # `object_registration.GLOBAL_CUSTOM_OBJECTS`, while this branch reads
        # the underscore-prefixed `_THREAD_LOCAL_CUSTOM_OBJECTS` and
        # `_GLOBAL_CUSTOM_OBJECTS` -- confirm both spellings exist.
        if custom_objects and object_name in custom_objects:
            obj = custom_objects.get(object_name)
        elif (
            object_name
            in object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__
        ):
            obj = object_registration._THREAD_LOCAL_CUSTOM_OBJECTS.__dict__[
                object_name
            ]
        elif object_name in object_registration._GLOBAL_CUSTOM_OBJECTS:
            obj = object_registration._GLOBAL_CUSTOM_OBJECTS[object_name]
        else:
            # `module_objects` defaults to `None`; fall back to an empty dict
            # so an unknown name produces the ValueError below rather than an
            # AttributeError on `None.get`.
            obj = (module_objects or {}).get(object_name)
            if obj is None:
                raise ValueError(
                    f"Unknown {printable_module_name}: '{object_name}'. "
                    "Please ensure you are using a "
                    "`keras.utils.custom_object_scope` "
                    "and that this object is included in the scope. See "
                    "https://www.tensorflow.org/guide/keras/save_and_serialize"
                    "#registering_the_custom_object for details."
                )

        # Classes passed by name are instantiated with no args, functions are
        # returned as-is.
        if inspect.isclass(obj):
            return obj()
        return obj

    if inspect.isfunction(identifier):
        # If a function has already been deserialized, return as is.
        return identifier

    raise ValueError(
        "Could not interpret serialized "
        f"{printable_module_name}: {identifier}"
    )
def validate_config(config):
    """Tell whether `config` looks like a valid layer config.

    A valid config is a dict that does not carry the marker key used for
    layers that were saved without a config.
    """
    if not isinstance(config, dict):
        return False
    return _LAYER_UNDEFINED_CONFIG_KEY not in config
def is_default(method):
    """Check if a method is decorated with the `default` wrapper.

    Returns the method's `_is_default` attribute, or `False` when the method
    has no such attribute.
    """
    try:
        return method._is_default
    except AttributeError:
        return False
def _find_replace_nested_dict(config, find, replace):
    """Replace every occurrence of `find` with `replace` throughout `config`.

    The config is dumped to a JSON string, the substitution is applied to that
    text, and the result is parsed back. The returned value is therefore a new
    object, and `config` must be JSON-serializable.
    """
    return json.loads(json.dumps(config).replace(find, replace))