"""Object config serialization and deserialization logic."""

import importlib
import inspect
import types
import warnings

import numpy as np

from keras.src import api_export
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.backend.common import global_state
from keras.src.saving import object_registration
from keras.src.utils import python_utils
from keras.src.utils.module_utils import tensorflow as tf

# Python scalar types that `serialize_keras_object` / `deserialize_keras_object`
# pass through unchanged.
PLAIN_TYPES = (str, int, float, bool)

# List of Keras modules with built-in string representations for Keras defaults
BUILTIN_MODULES = (
    "activations",
    "constraints",
    "initializers",
    "losses",
    "metrics",
    "optimizers",
    "regularizers",
)


class SerializableDict:
    """Holds keyword arguments as a config dict that can be serialized."""

    def __init__(self, **config):
        self.config = config

    def serialize(self):
        """Returns `self.config` passed through `serialize_keras_object`."""
        return serialize_keras_object(self.config)


class SafeModeScope:
    """Scope to propagate safe mode flag to nested deserialization calls."""

    def __init__(self, safe_mode=True):
        self.safe_mode = safe_mode

    def __enter__(self):
        # Remember the enclosing value so it is restored on exit, which lets
        # scopes nest correctly.
        self.original_value = in_safe_mode()
        global_state.set_global_attribute("safe_mode_saving", self.safe_mode)

    def __exit__(self, *args, **kwargs):
        global_state.set_global_attribute(
            "safe_mode_saving", self.original_value
        )


@keras_export("keras.config.enable_unsafe_deserialization")
def enable_unsafe_deserialization():
    """Disables safe mode globally, allowing deserialization of lambdas."""
    global_state.set_global_attribute("safe_mode_saving", False)


def in_safe_mode():
    """Returns the current global `safe_mode_saving` flag.

    The value may be `None` when neither a `SafeModeScope` nor
    `enable_unsafe_deserialization()` has set it. `deserialize_keras_object`
    treats `None` as "use the `safe_mode` argument passed by the caller".
    """
    return global_state.get_global_attribute("safe_mode_saving")


class ObjectSharingScope:
    """Scope to enable detection and reuse of previously seen objects."""

    def __enter__(self):
        global_state.set_global_attribute("shared_objects/id_to_obj_map", {})
        global_state.set_global_attribute("shared_objects/id_to_config_map", {})

    def __exit__(self, *args, **kwargs):
        global_state.set_global_attribute("shared_objects/id_to_obj_map", None)
        global_state.set_global_attribute(
            "shared_objects/id_to_config_map", None
        )


def get_shared_object(obj_id):
    """Retrieve an object previously seen during deserialization.

    Returns `None` when `obj_id` is unknown or when no `ObjectSharingScope`
    is active.
    """
    id_to_obj_map = global_state.get_global_attribute(
        "shared_objects/id_to_obj_map"
    )
    if id_to_obj_map is not None:
        return id_to_obj_map.get(obj_id, None)


def record_object_after_serialization(obj, config):
    """Call after serializing an object, to keep track of its config.

    Inside an `ObjectSharingScope`, the first config produced for a given
    object is remembered. When the same object is serialized again, both the
    new config and the remembered one are tagged with `"shared_object_id"`,
    so deserialization can rebuild a single shared instance.
    """
    if config["module"] == "__main__":
        config["module"] = None  # Ensures module is None when no module found
    id_to_config_map = global_state.get_global_attribute(
        "shared_objects/id_to_config_map"
    )
    if id_to_config_map is None:
        return  # Not in a sharing scope
    obj_id = int(id(obj))
    if obj_id not in id_to_config_map:
        id_to_config_map[obj_id] = config
    else:
        config["shared_object_id"] = obj_id
        prev_config = id_to_config_map[obj_id]
        prev_config["shared_object_id"] = obj_id


def record_object_after_deserialization(obj, obj_id):
    """Call after deserializing an object, to keep track of it in the future."""
    id_to_obj_map = global_state.get_global_attribute(
        "shared_objects/id_to_obj_map"
    )
    if id_to_obj_map is None:
        return  # Not in a sharing scope
    id_to_obj_map[obj_id] = obj


@keras_export(
    [
        "keras.saving.serialize_keras_object",
        "keras.utils.serialize_keras_object",
    ]
)
def serialize_keras_object(obj):
    """Retrieve the config dict by serializing the Keras object.

    `serialize_keras_object()` serializes a Keras object to a python dictionary
    that represents the object, and is a reciprocal function of
    `deserialize_keras_object()`. See `deserialize_keras_object()` for more
    information about the config format.

    Args:
        obj: the Keras object to serialize.

    Returns:
        A python dict that represents the object. The python dict can be
        deserialized via `deserialize_keras_object()`.
    """
    if obj is None:
        return obj

    if isinstance(obj, PLAIN_TYPES):
        return obj

    if isinstance(obj, (list, tuple)):
        config_arr = [serialize_keras_object(x) for x in obj]
        return tuple(config_arr) if isinstance(obj, tuple) else config_arr
    if isinstance(obj, dict):
        return serialize_dict(obj)

    # Special cases:
    if isinstance(obj, bytes):
        return {
            "class_name": "__bytes__",
            "config": {"value": obj.decode("utf-8")},
        }
    if isinstance(obj, slice):
        return {
            "class_name": "__slice__",
            "config": {
                "start": serialize_keras_object(obj.start),
                "stop": serialize_keras_object(obj.stop),
                "step": serialize_keras_object(obj.step),
            },
        }
    # Ellipsis is an instance, and ellipsis class is not in global scope.
    # checking equality also fails elsewhere in the library, so we have
    # to dynamically get the type.
    if isinstance(obj, type(Ellipsis)):
        return {"class_name": "__ellipsis__", "config": {}}
    if isinstance(obj, backend.KerasTensor):
        history = getattr(obj, "_keras_history", None)
        if history:
            # The first history entry is replaced by its `.name` so the
            # config holds a string rather than the object itself.
            history = list(history)
            history[0] = history[0].name
        return {
            "class_name": "__keras_tensor__",
            "config": {
                "shape": obj.shape,
                "dtype": obj.dtype,
                "keras_history": history,
            },
        }
    if tf.available and isinstance(obj, tf.TensorShape):
        return obj.as_list() if obj._dims is not None else None
    if backend.is_tensor(obj):
        return {
            "class_name": "__tensor__",
            "config": {
                "value": backend.convert_to_numpy(obj).tolist(),
                "dtype": backend.standardize_dtype(obj.dtype),
            },
        }
    if type(obj).__module__ == np.__name__:
        if isinstance(obj, np.ndarray) and obj.ndim > 0:
            return {
                "class_name": "__numpy__",
                "config": {
                    "value": obj.tolist(),
                    "dtype": backend.standardize_dtype(obj.dtype),
                },
            }
        else:
            # Treat numpy floats / etc as plain types.
            return obj.item()
    if tf.available and isinstance(obj, tf.DType):
        return obj.name
    # Python names every lambda function "<lambda>".
    if isinstance(obj, types.FunctionType) and obj.__name__ == "<lambda>":
        try:
            lambda_source = inspect.getsource(obj)
        except (OSError, TypeError):
            # The source is only needed for the warning message. When it
            # cannot be retrieved (e.g. the lambda was defined interactively),
            # serialization must still proceed.
            lambda_source = repr(obj)
        warnings.warn(
            "The object being serialized includes a `lambda`. This is unsafe. "
            "In order to reload the object, you will have to pass "
            "`safe_mode=False` to the loading function. "
            "Please avoid using `lambda` in the "
            "future, and use named Python functions instead. "
            f"This is the `lambda` being serialized: {lambda_source}",
            stacklevel=2,
        )
        return {
            "class_name": "__lambda__",
            "config": {
                "value": python_utils.func_dump(obj),
            },
        }
    if tf.available and isinstance(obj, tf.TypeSpec):
        ts_config = obj._serialize()
        # TensorShape and tf.DType conversion
        ts_config = list(
            map(
                lambda x: (
                    x.as_list()
                    if isinstance(x, tf.TensorShape)
                    else (x.name if isinstance(x, tf.DType) else x)
                ),
                ts_config,
            )
        )
        return {
            "class_name": "__typespec__",
            "spec_name": obj.__class__.__name__,
            "module": obj.__class__.__module__,
            "config": ts_config,
            "registered_name": None,
        }

    inner_config = _get_class_or_fn_config(obj)
    config_with_public_class = serialize_with_public_class(
        obj.__class__, inner_config
    )

    if config_with_public_class is not None:
        get_build_and_compile_config(obj, config_with_public_class)
        record_object_after_serialization(obj, config_with_public_class)
        return config_with_public_class

    # Any custom object or otherwise non-exported object
    if isinstance(obj, types.FunctionType):
        module = obj.__module__
    else:
        module = obj.__class__.__module__
    # For plain functions this is "function", which is what
    # `deserialize_keras_object` dispatches on.
    class_name = obj.__class__.__name__

    if module == "builtins":
        registered_name = None
    else:
        if isinstance(obj, types.FunctionType):
            registered_name = object_registration.get_registered_name(obj)
        else:
            registered_name = object_registration.get_registered_name(
                obj.__class__
            )

    config = {
        "module": module,
        "class_name": class_name,
        "config": inner_config,
        "registered_name": registered_name,
    }
    get_build_and_compile_config(obj, config)
    record_object_after_serialization(obj, config)
    return config


def get_build_and_compile_config(obj, config):
    """Adds `build_config` / `compile_config` entries to `config` in place.

    Each entry is added only when `obj` defines the corresponding
    `get_build_config()` / `get_compile_config()` method and that method
    returns a non-`None` value.
    """
    if hasattr(obj, "get_build_config"):
        build_config = obj.get_build_config()
        if build_config is not None:
            config["build_config"] = serialize_dict(build_config)
    if hasattr(obj, "get_compile_config"):
        compile_config = obj.get_compile_config()
        if compile_config is not None:
            config["compile_config"] = serialize_dict(compile_config)


def serialize_with_public_class(cls, inner_config=None):
    """Serializes classes from public Keras API or object registration.

    Called to check and retrieve the config of any class that has a public
    Keras API or has been registered as serializable via
    `keras.saving.register_keras_serializable()`.

    Returns `None` when `cls` is neither exported nor registered.
    """
    # This gets the `keras.*` exported name, such as
    # "keras.optimizers.Adam".
    keras_api_name = api_export.get_name_from_symbol(cls)

    # Case of custom or unknown class object
    if keras_api_name is None:
        registered_name = object_registration.get_registered_name(cls)
        if registered_name is None:
            return None

        # Return custom object config with corresponding registration name
        return {
            "module": cls.__module__,
            "class_name": cls.__name__,
            "config": inner_config,
            "registered_name": registered_name,
        }

    # Split the canonical Keras API name into a Keras module and class name.
    parts = keras_api_name.split(".")
    return {
        "module": ".".join(parts[:-1]),
        "class_name": parts[-1],
        "config": inner_config,
        "registered_name": None,
    }


def serialize_with_public_fn(fn, config, fn_module_name=None):
    """Serializes functions from public Keras API or object registration.

    Called to check and retrieve the config of any function that has a public
    Keras API or has been registered as serializable via
    `keras.saving.register_keras_serializable()`. If function's module name
    is already known, returns corresponding config.

    Returns `None` when `fn` is neither exported, registered, nor a builtin.
    """
    if fn_module_name:
        return {
            "module": fn_module_name,
            "class_name": "function",
            "config": config,
            "registered_name": config,
        }
    keras_api_name = api_export.get_name_from_symbol(fn)
    if keras_api_name:
        parts = keras_api_name.split(".")
        return {
            "module": ".".join(parts[:-1]),
            "class_name": "function",
            "config": config,
            "registered_name": config,
        }
    else:
        registered_name = object_registration.get_registered_name(fn)
        if not registered_name and not fn.__module__ == "builtins":
            return None
        return {
            "module": fn.__module__,
            "class_name": "function",
            "config": config,
            "registered_name": registered_name,
        }


def _get_class_or_fn_config(obj):
    """Return the object's config depending on its type."""
    # Functions / lambdas:
    if isinstance(obj, types.FunctionType):
        return object_registration.get_registered_name(obj)
    # All classes:
    if hasattr(obj, "get_config"):
        config = obj.get_config()
        if not isinstance(config, dict):
            raise TypeError(
                f"The `get_config()` method of {obj} should return "
                f"a dict. It returned: {config}"
            )
        return serialize_dict(config)
    elif hasattr(obj, "__name__"):
        return object_registration.get_registered_name(obj)
    else:
        raise TypeError(
            f"Cannot serialize object {obj} of type {type(obj)}. "
            "To be serializable, "
            "a class must implement the `get_config()` method."
        )


def serialize_dict(obj):
    """Returns a new dict with every value of `obj` serialized."""
    return {key: serialize_keras_object(value) for key, value in obj.items()}


@keras_export(
    [
        "keras.saving.deserialize_keras_object",
        "keras.utils.deserialize_keras_object",
    ]
)
def deserialize_keras_object(
    config, custom_objects=None, safe_mode=True, **kwargs
):
    """Retrieve the object by deserializing the config dict.

    The config dict is a Python dictionary that consists of a set of key-value
    pairs, and represents a Keras object, such as an `Optimizer`, `Layer`,
    `Metrics`, etc. The saving and loading library uses the following keys to
    record information of a Keras object:

    - `class_name`: String. This is the name of the class,
      as exactly defined in the source
      code, such as "LossesContainer".
    - `config`: Dict. Library-defined or user-defined key-value pairs that
      store the configuration of the object, as obtained by
      `object.get_config()`.
    - `module`: String. The path of the python module. Built-in Keras classes
      expect to have prefix `keras`.
    - `registered_name`: String. The key the class is registered under via
      `keras.saving.register_keras_serializable(package, name)` API. The
      key has the format of '{package}>{name}', where `package` and `name` are
      the arguments passed to `register_keras_serializable()`. If `name` is not
      provided, it uses the class name. If `registered_name` successfully
      resolves to a class (that was registered), the `class_name` and `config`
      values in the dict will not be used. `registered_name` is only used for
      non-built-in classes.

    For example, the following dictionary represents the built-in Adam optimizer
    with the relevant config:

    ```python
    dict_structure = {
        "class_name": "Adam",
        "config": {
            "amsgrad": False,
            "beta_1": 0.8999999761581421,
            "beta_2": 0.9990000128746033,
            "decay": 0.0,
            "epsilon": 1e-07,
            "learning_rate": 0.0010000000474974513,
            "name": "Adam"
        },
        "module": "keras.optimizers",
        "registered_name": None
    }
    # Returns an `Adam` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    If the class does not have an exported Keras namespace, the library tracks
    it by its `module` and `class_name`. For example:

    ```python
    dict_structure = {
      "class_name": "MetricsList",
      "config": {
          ...
      },
      "module": "keras.trainers.compile_utils",
      "registered_name": "MetricsList"
    }

    # Returns a `MetricsList` instance identical to the original one.
    deserialize_keras_object(dict_structure)
    ```

    And the following dictionary represents a user-customized
    `MeanSquaredError` loss:

    ```python
    @keras.saving.register_keras_serializable(package='my_package')
    class ModifiedMeanSquaredError(keras.losses.MeanSquaredError):
      ...

    dict_structure = {
        "class_name": "ModifiedMeanSquaredError",
        "config": {
            "fn": "mean_squared_error",
            "name": "mean_squared_error",
            "reduction": "auto"
        },
        "registered_name": "my_package>ModifiedMeanSquaredError"
    }
    # Returns the `ModifiedMeanSquaredError` object
    deserialize_keras_object(dict_structure)
    ```

    Args:
        config: Python dict describing the object.
        custom_objects: Python dict containing a mapping between custom
            object names and the corresponding classes or functions.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.

    Returns:
        The object described by the `config` dictionary.
    """
    # An active SafeModeScope (or the global override) wins over the argument.
    safe_scope_arg = in_safe_mode()  # Enforces SafeModeScope
    safe_mode = safe_scope_arg if safe_scope_arg is not None else safe_mode

    module_objects = kwargs.pop("module_objects", None)
    # Merge order: scoped custom objects, then globally registered ones,
    # override the explicitly passed `custom_objects` on key collisions.
    custom_objects = custom_objects or {}
    tlco = global_state.get_global_attribute("custom_objects_scope_dict", {})
    gco = object_registration.GLOBAL_CUSTOM_OBJECTS
    custom_objects = {**custom_objects, **tlco, **gco}

    if config is None:
        return None

    if (
        isinstance(config, str)
        and custom_objects
        and custom_objects.get(config) is not None
    ):
        # This is to deserialize plain functions which are serialized as
        # string names by legacy saving formats.
        return custom_objects[config]

    if isinstance(config, (list, tuple)):
        return [
            deserialize_keras_object(
                x, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for x in config
        ]

    if module_objects is not None:
        inner_config, fn_module_name, has_custom_object = None, None, False

        if isinstance(config, dict):
            if "config" in config:
                inner_config = config["config"]
            if "class_name" not in config:
                raise ValueError(
                    f"Unknown `config` as a `dict`, config={config}"
                )

            # Check case where config is function or class and in custom objects
            if custom_objects and (
                config["class_name"] in custom_objects
                or config.get("registered_name") in custom_objects
                or (
                    isinstance(inner_config, str)
                    and inner_config in custom_objects
                )
            ):
                has_custom_object = True

            # Case where config is function but not in custom objects
            elif config["class_name"] == "function":
                fn_module_name = config["module"]
                if fn_module_name == "builtins":
                    config = config["config"]
                else:
                    config = config["registered_name"]

            # Case where config is class but not in custom objects
            else:
                if config.get("module", "_") is None:
                    raise TypeError(
                        "Cannot deserialize object of type "
                        f"`{config['class_name']}`. If "
                        f"`{config['class_name']}` is a custom class, please "
                        "register it using the "
                        "`@keras.saving.register_keras_serializable()` "
                        "decorator."
                    )
                config = config["class_name"]

        if not has_custom_object:
            # Return if not found in either module objects or custom objects
            if config not in module_objects:
                # Object has already been deserialized
                return config
            # Forward `safe_mode` like every other recursive call here, so a
            # caller's explicit choice applies to nested objects too.
            if isinstance(module_objects[config], types.FunctionType):
                return deserialize_keras_object(
                    serialize_with_public_fn(
                        module_objects[config], config, fn_module_name
                    ),
                    custom_objects=custom_objects,
                    safe_mode=safe_mode,
                )
            return deserialize_keras_object(
                serialize_with_public_class(
                    module_objects[config], inner_config=inner_config
                ),
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            )

    if isinstance(config, PLAIN_TYPES):
        return config
    if not isinstance(config, dict):
        raise TypeError(f"Could not parse config: {config}")

    # A dict that is not an object config: deserialize its values.
    if "class_name" not in config or "config" not in config:
        return {
            key: deserialize_keras_object(
                value, custom_objects=custom_objects, safe_mode=safe_mode
            )
            for key, value in config.items()
        }

    class_name = config["class_name"]
    inner_config = config["config"] or {}
    custom_objects = custom_objects or {}

    # Special cases:
    if class_name == "__keras_tensor__":
        obj = backend.KerasTensor(
            inner_config["shape"], dtype=inner_config["dtype"]
        )
        obj._pre_serialization_keras_history = inner_config["keras_history"]
        return obj

    if class_name == "__tensor__":
        return backend.convert_to_tensor(
            inner_config["value"], dtype=inner_config["dtype"]
        )
    if class_name == "__numpy__":
        return np.array(inner_config["value"], dtype=inner_config["dtype"])
    if config["class_name"] == "__bytes__":
        return inner_config["value"].encode("utf-8")
    if config["class_name"] == "__ellipsis__":
        return Ellipsis
    if config["class_name"] == "__slice__":
        return slice(
            deserialize_keras_object(
                inner_config["start"],
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            ),
            deserialize_keras_object(
                inner_config["stop"],
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            ),
            deserialize_keras_object(
                inner_config["step"],
                custom_objects=custom_objects,
                safe_mode=safe_mode,
            ),
        )
    if config["class_name"] == "__lambda__":
        if safe_mode:
            raise ValueError(
                "Requested the deserialization of a `lambda` object. "
                "This carries a potential risk of arbitrary code execution "
                "and thus it is disallowed by default. If you trust the "
                "source of the saved model, you can pass `safe_mode=False` to "
                "the loading function in order to allow `lambda` loading, "
                "or call `keras.config.enable_unsafe_deserialization()`."
            )
        return python_utils.func_load(inner_config["value"])
    # NOTE(review): `tf` is a lazy module object, so `tf is not None` is
    # always true; the serializer side checks `tf.available` instead.
    if tf is not None and config["class_name"] == "__typespec__":
        obj = _retrieve_class_or_fn(
            config["spec_name"],
            config["registered_name"],
            config["module"],
            obj_type="class",
            full_config=config,
            custom_objects=custom_objects,
        )
        # Conversion to TensorShape and DType
        inner_config = map(
            lambda x: (
                tf.TensorShape(x)
                if isinstance(x, list)
                else (getattr(tf, x) if hasattr(tf.dtypes, str(x)) else x)
            ),
            inner_config,
        )
        return obj._deserialize(tuple(inner_config))

    # Below: classes and functions.
    module = config.get("module", None)
    registered_name = config.get("registered_name", class_name)

    if class_name == "function":
        fn_name = inner_config
        return _retrieve_class_or_fn(
            fn_name,
            registered_name,
            module,
            obj_type="function",
            full_config=config,
            custom_objects=custom_objects,
        )

    # Below, handling of all classes.
    # First, is it a shared object?
    if "shared_object_id" in config:
        obj = get_shared_object(config["shared_object_id"])
        if obj is not None:
            return obj

    cls = _retrieve_class_or_fn(
        class_name,
        registered_name,
        module,
        obj_type="class",
        full_config=config,
        custom_objects=custom_objects,
    )

    if isinstance(cls, types.FunctionType):
        return cls
    if not hasattr(cls, "from_config"):
        raise TypeError(
            f"Unable to reconstruct an instance of '{class_name}' because "
            f"the class is missing a `from_config()` method. "
            f"Full object config: {config}"
        )

    # Instantiate the class from its config inside a custom object scope
    # so that we can catch any custom objects that the config refers to.
    custom_obj_scope = object_registration.CustomObjectScope(custom_objects)
    safe_mode_scope = SafeModeScope(safe_mode)
    with custom_obj_scope, safe_mode_scope:
        try:
            instance = cls.from_config(inner_config)
        except TypeError as e:
            raise TypeError(
                f"{cls} could not be deserialized properly. Please"
                " ensure that components that are Python object"
                " instances (layers, models, etc.) returned by"
                " `get_config()` are explicitly deserialized in the"
                " model's `from_config()` method."
                f"\n\nconfig={config}.\n\nException encountered: {e}"
            )
        build_config = config.get("build_config", None)
        if build_config and not instance.built:
            instance.build_from_config(build_config)
            instance.built = True
        compile_config = config.get("compile_config", None)
        if compile_config:
            instance.compile_from_config(compile_config)
            instance.compiled = True

    if "shared_object_id" in config:
        record_object_after_deserialization(
            instance, config["shared_object_id"]
        )
    return instance


def _retrieve_class_or_fn(
    name, registered_name, module, obj_type, full_config, custom_objects=None
):
    """Resolves a class or function from its serialized identifiers.

    Lookup order:
    1. Registered / custom objects. Functions are looked up by `name`,
       classes by `registered_name`.
    2. The exported Keras API name `module + "." + name`.
    3. For builtin functions, each module in `BUILTIN_MODULES`.
    4. Custom objects whose key ends in `">" + name`, for configs saved by
       Keras <= 3.6.
    5. A direct import of `module` when it belongs to one of the Keras
       packages.

    Args:
        name: Class or function name from the config.
        registered_name: Registration key from the config (used for classes).
        module: Module path from the config; may be `None`.
        obj_type: `"class"` or `"function"`; used for lookup and messages.
        full_config: The complete config, included in error messages.
        custom_objects: Optional dict of custom object names to objects.

    Returns:
        The resolved class or function.

    Raises:
        TypeError: If the object cannot be located or its module cannot be
            imported.
    """
    # If there is a custom object registered via
    # `register_keras_serializable()`, that takes precedence.
    if obj_type == "function":
        custom_obj = object_registration.get_registered_object(
            name, custom_objects=custom_objects
        )
    else:
        custom_obj = object_registration.get_registered_object(
            registered_name, custom_objects=custom_objects
        )
    if custom_obj is not None:
        return custom_obj

    if module:
        # If it's a Keras built-in object,
        # we cannot always use direct import, because the exported
        # module name might not match the package structure
        # (e.g. experimental symbols).
        if module == "keras" or module.startswith("keras."):
            api_name = module + "." + name

            obj = api_export.get_symbol_from_name(api_name)
            if obj is not None:
                return obj

        # Configs of Keras built-in functions do not contain identifying
        # information other than their name (e.g. 'acc' or 'tanh'). This special
        # case searches the Keras modules that contain built-ins to retrieve
        # the corresponding function from the identifying string.
        if obj_type == "function" and module == "builtins":
            for builtin_module in BUILTIN_MODULES:
                obj = api_export.get_symbol_from_name(
                    "keras." + builtin_module + "." + name
                )
                if obj is not None:
                    return obj

            # Workaround for serialization bug in Keras <= 3.6 whereby custom
            # functions would only be saved by name instead of registered name,
            # i.e. "name" instead of "package>name". This allows recent versions
            # of Keras to reload models saved with 3.6 and lower.
            if ">" not in name:
                separated_name = ">" + name
                # `custom_objects` defaults to None; guard before iterating.
                for custom_name, custom_object in (custom_objects or {}).items():
                    if custom_name.endswith(separated_name):
                        return custom_object

        # Otherwise, attempt to retrieve the class object given the `module`
        # and `class_name`. Import the module, find the class.
        package = module.split(".", maxsplit=1)[0]
        if package in {"keras", "keras_hub", "keras_cv", "keras_nlp"}:
            try:
                imported_module = importlib.import_module(module)
            except ModuleNotFoundError as e:
                raise TypeError(
                    f"Could not deserialize {obj_type} '{name}' because "
                    f"its parent module {module} cannot be imported. "
                    f"Full object config: {full_config}"
                ) from e
            obj = vars(imported_module).get(name, None)
            if obj is not None:
                return obj

    raise TypeError(
        f"Could not locate {obj_type} '{name}'. "
        "Make sure custom classes are decorated with "
        "`@keras.saving.register_keras_serializable()`. "
        f"Full object config: {full_config}"
    )