File size: 7,358 Bytes
1f5470c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
import inspect

from keras.src.api_export import keras_export
from keras.src.backend.common import global_state

# Registry of serializable objects keyed by their registered name
# ("package>name"), filled by `register_keras_serializable`.
GLOBAL_CUSTOM_OBJECTS = {}
# Reverse lookup of the registry above: registered object -> registered name.
GLOBAL_CUSTOM_NAMES = {}


@keras_export(
    [
        "keras.saving.CustomObjectScope",
        "keras.saving.custom_object_scope",
        "keras.utils.CustomObjectScope",
        "keras.utils.custom_object_scope",
    ]
)
class CustomObjectScope:
    """Exposes custom classes/functions to Keras deserialization internals.

    Under a scope `with custom_object_scope(objects_dict)`, Keras methods such
    as `keras.models.load_model()` or
    `keras.models.model_from_config()` will be able to deserialize any
    custom object referenced by a saved config (e.g. a custom layer or metric).

    Scopes can be nested. Inside an inner scope, the objects exposed by every
    enclosing scope stay visible. If two scopes use the same name, the
    innermost scope wins. Leaving a scope restores exactly the mapping that
    was active when it was entered.

    Example:

    Consider a custom regularizer `my_regularizer`:

    ```python
    layer = Dense(3, kernel_regularizer=my_regularizer)
    # Config contains a reference to `my_regularizer`
    config = layer.get_config()
    ...
    # Later:
    with custom_object_scope({'my_regularizer': my_regularizer}):
        layer = Dense.from_config(config)
    ```

    Args:
        custom_objects: Dictionary of `{str: object}` pairs,
            where the `str` key is the object name. `None` (or any other
            falsy value) is treated as an empty dictionary.
    """

    def __init__(self, custom_objects):
        self.custom_objects = custom_objects or {}
        # Snapshot of the scope mapping taken by the most recent `__enter__`
        # that has not been exited yet; `None` while the scope is inactive.
        self.backup = None
        # One snapshot per active `__enter__`, so re-entering the same
        # instance (nested `with scope:` blocks) unwinds correctly.
        self._backups = []

    def __enter__(self):
        outer = (
            global_state.get_global_attribute("custom_objects_scope_dict", {})
            or {}
        )
        snapshot = dict(outer)
        self._backups.append(snapshot)
        self.backup = snapshot
        # Start from the enclosing scope's objects so a nested scope adds to
        # them instead of hiding them; this scope's entries override clashes.
        merged = dict(outer)
        merged.update(self.custom_objects)
        global_state.set_global_attribute("custom_objects_scope_dict", merged)
        return self

    def __exit__(self, *args, **kwargs):
        if not self._backups:
            # Exit without a matching enter: there is nothing to restore.
            return
        restored = self._backups.pop()
        global_state.set_global_attribute(
            "custom_objects_scope_dict", restored
        )
        self.backup = self._backups[-1] if self._backups else None


# Alias: lets callers write `with custom_object_scope({...}):` as shown in the
# `CustomObjectScope` docstring.
custom_object_scope = CustomObjectScope


@keras_export(
    [
        "keras.saving.get_custom_objects",
        "keras.utils.get_custom_objects",
    ]
)
def get_custom_objects():
    """Returns the global custom-object registry itself (not a copy).

    The returned dictionary is the live registry. Adding, removing or
    clearing entries therefore changes directly which names Keras can
    resolve through the global registry. Objects exposed only through
    `custom_object_scope()` are held in a separate scope mapping, so they
    never appear in this dictionary.

    Example:

    ```python
    get_custom_objects().clear()
    get_custom_objects()['MyObject'] = MyObject
    ```

    Returns:
        The global dictionary mapping registered names to classes/functions.
    """
    live_registry = GLOBAL_CUSTOM_OBJECTS
    return live_registry


@keras_export(
    [
        "keras.saving.register_keras_serializable",
        "keras.utils.register_keras_serializable",
    ]
)
def register_keras_serializable(package="Custom", name=None):
    """Registers an object with the Keras serialization framework.

    The returned decorator records the decorated class or function in the
    global custom-object registry. Keras can then serialize and deserialize
    it without the caller passing a `custom_objects` dict. The decorator also
    records the reverse mapping, which `get_registered_name` uses to look up
    the object's serialization key.

    Classes must implement a `get_config()` method to be registered.
    Functions have no such requirement.

    The object is registered under the key `'package>name'`. If `name` is not
    passed, it defaults to the object's `__name__`.

    Example:

    ```python
    # Note that `'my_package'` is used as the `package` argument here, and since
    # the `name` argument is not provided, `'MyDense'` is used as the `name`.
    @register_keras_serializable('my_package')
    class MyDense(keras.layers.Dense):
        pass

    assert get_registered_object('my_package>MyDense') == MyDense
    assert get_registered_name(MyDense) == 'my_package>MyDense'
    ```

    Args:
        package: The package that the object belongs to. It forms the first
            half of the registration key (`"package>name"`). This is the
            first argument passed into the decorator.
        name: The name to register the object under within `package`. If
            omitted or `None`, the object's `__name__` is used. This is the
            case when the decorator is called with a single argument, which
            becomes `package`.

    Returns:
        A decorator that registers the decorated class or function and
        returns it unchanged.

    Raises:
        ValueError: (from the decorator) if the decorated object is a class
            without a `get_config()` method.
    """

    def decorator(arg):
        """Records `arg` in both global registries and returns it."""
        short_name = arg.__name__ if name is None else name
        registry_key = ">".join((package, short_name))

        lacks_config = inspect.isclass(arg) and not hasattr(
            arg, "get_config"
        )
        if lacks_config:
            raise ValueError(
                "Cannot register a class that does not have a "
                "get_config() method."
            )

        # Forward map first, then reverse map (same order as before).
        GLOBAL_CUSTOM_OBJECTS[registry_key] = arg
        GLOBAL_CUSTOM_NAMES[arg] = registry_key
        return arg

    return decorator


@keras_export(
    [
        "keras.saving.get_registered_name",
        "keras.utils.get_registered_name",
    ]
)
def get_registered_name(obj):
    """Returns the name registered to an object within the Keras framework.

    This is the reverse lookup of `register_keras_serializable`. It maps an
    object to the `"package>name"` string under which it was registered.

    Args:
        obj: The object to look up.

    Returns:
        The registered name if `obj` was registered. Otherwise, the object's
        own `__name__` attribute.
    """
    if obj not in GLOBAL_CUSTOM_NAMES:
        # Unregistered: fall back to the plain Python name.
        return obj.__name__
    return GLOBAL_CUSTOM_NAMES[obj]


@keras_export(
    [
        "keras.saving.get_registered_object",
        "keras.utils.get_registered_object",
    ]
)
def get_registered_object(name, custom_objects=None, module_objects=None):
    """Returns the class associated with `name` if it is registered with Keras.

    This function is part of the Keras serialization and deserialization
    framework. It maps strings to the objects associated with them.

    Lookup order (the first match wins):

    1. the mapping installed by the active `custom_object_scope`,
    2. the global registry (`get_custom_objects()`),
    3. `custom_objects`,
    4. `module_objects`.

    Example:

    ```python
    @classmethod
    def from_config(cls, config, custom_objects=None):
        if 'my_custom_object_name' in config:
            config['hidden_cls'] = keras.saving.get_registered_object(
                config['my_custom_object_name'], custom_objects=custom_objects)
    ```

    Args:
        name: The name to look up.
        custom_objects: Optional dictionary of custom objects to search,
            usually provided by the user.
        module_objects: Optional dictionary of objects to search, usually
            provided by mid-level library implementers.

    Returns:
        The object registered under `name`, or `None` if no source has it.
    """
    scoped = global_state.get_global_attribute(
        "custom_objects_scope_dict", {}
    )
    # These two mappings always exist, so they are probed unconditionally.
    for registry in (scoped, GLOBAL_CUSTOM_OBJECTS):
        if name in registry:
            return registry[name]
    # Caller-supplied mappings may be None or empty; skip them in that case.
    for extra in (custom_objects, module_objects):
        if extra and name in extra:
            return extra[name]
    return None