File size: 10,904 Bytes
1f5470c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
import os
import zipfile

from absl import logging

from keras.src.api_export import keras_export
from keras.src.legacy.saving import legacy_h5_format
from keras.src.saving import saving_lib
from keras.src.utils import file_utils
from keras.src.utils import io_utils

try:
    import h5py
except ImportError:
    h5py = None


@keras_export(["keras.saving.save_model", "keras.models.save_model"])
def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
    """Saves a model as a `.keras` file.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        zipped: Whether to save the model as a zipped `.keras`
            archive (default when saving locally), or as an unzipped directory
            (default when saving on the Hugging Face Hub).
        **kwargs: Legacy options.
            - `include_optimizer` (default `True`) is only forwarded to the
              legacy `.h5`/`.hdf5` writer.
            - `save_format` is deprecated. It only logs a warning when
              `filepath` already ends in `.keras`, `.h5` or `.hdf5`, and
              raises a `ValueError` otherwise.
            Any other keyword argument raises a `ValueError`.

    Returns:
        Whatever the selected backend saver returns. Returns `None` when the
        user declines the interactive overwrite prompt.

    Raises:
        ValueError: On unsupported keyword arguments, or a deprecated
            `save_format` used with an unrecognized extension. Also raised
            when saving zipped to a path that is neither `.keras` nor
            `.h5`/`.hdf5`.

    Example:

    ```python
    model = keras.Sequential(
        [
            keras.layers.Dense(5, input_shape=(3,)),
            keras.layers.Softmax(),
        ],
    )
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = keras.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `keras.saving.save_model()`.

    The saved `.keras` file is a `zip` archive that contains:

    - The model's configuration (architecture)
    - The model's weights
    - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state.
    """
    path_str = str(filepath)
    is_h5_path = path_str.endswith((".h5", ".hdf5"))
    is_keras_path = path_str.endswith(".keras")

    include_optimizer = kwargs.pop("include_optimizer", True)
    save_format = kwargs.pop("save_format", False)
    if save_format:
        if is_h5_path or is_keras_path:
            # The format can be inferred from the extension, so only warn.
            logging.warning(
                "The `save_format` argument is deprecated in Keras 3. "
                "We recommend removing this argument as it can be inferred "
                "from the file path. "
                f"Received: save_format={save_format}"
            )
        else:
            raise ValueError(
                "The `save_format` argument is deprecated in Keras 3. "
                "Please remove this argument and pass a file path with "
                "either `.keras` or `.h5` extension. "
                f"Received: save_format={save_format}"
            )
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    # Deprecation warnings
    if is_h5_path:
        logging.warning(
            "You are saving your model as an HDF5 file via "
            "`model.save()` or `keras.saving.save_model(model)`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')` or "
            "`keras.saving.save_model(model, 'my_model.keras')`. "
        )

    is_hf = path_str.startswith("hf://")
    if zipped is None:
        zipped = not is_hf  # default behavior depends on destination

    # If file exists and should not be overwritten. Hugging Face Hub paths
    # are never checked locally; non-path-like objects (TypeError from
    # `os.path.exists`) are treated as not existing.
    try:
        exists = (not is_hf) and os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    if zipped and is_keras_path:
        return saving_lib.save_model(model, filepath)
    if not zipped:
        return saving_lib.save_model(model, filepath, zipped=False)
    if is_h5_path:
        return legacy_h5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    raise ValueError(
        "Invalid filepath extension for saving. "
        "Please add either a `.keras` extension for the native Keras "
        "format (recommended) or a `.h5` extension. "
        "Use `model.export(filepath)` if you want to export a SavedModel "
        "for use with TFLite/TFServing/etc. "
        f"Received: filepath={filepath}."
    )


@keras_export(["keras.saving.load_model", "keras.models.load_model"])
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Loads a model saved via `model.save()`.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model file.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Raises:
        ValueError: If a `.keras` path is not a readable zip archive, or if
            the file format is not supported.

    Example:

    ```python
    model = keras.Sequential([
        keras.layers.Dense(5, input_shape=(3,)),
        keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = np.random.random((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.

    Remote, non-directory paths (other than `hf://`) that are not already
    readable zip archives are first copied to a local temporary directory.
    The local copy is loaded when it is a zip archive or an `.h5`/`.hdf5`
    file.
    """
    is_keras_zip = str(filepath).endswith(".keras") and zipfile.is_zipfile(
        filepath
    )
    is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
        file_utils.join(filepath, "config.json")
    )
    is_hf = str(filepath).startswith("hf://")

    # Support for remote files: copy them to a local temporary directory.
    if (
        file_utils.is_remote_path(filepath)
        and not file_utils.isdir(filepath)
        and not is_keras_zip
        and not is_hf
    ):
        local_path = file_utils.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )

        # Copy from remote to temporary local directory
        file_utils.copy(filepath, local_path)

        if zipfile.is_zipfile(local_path):
            # Switch filepath to local zipfile for loading model
            filepath = local_path
            is_keras_zip = True
        elif str(local_path).endswith((".h5", ".hdf5")):
            # `h5py.File` opens local files, not remote URLs, so the legacy
            # H5 loader must be given the local copy rather than the
            # original remote path.
            filepath = local_path

    if is_keras_zip or is_keras_dir or is_hf:
        return saving_lib.load_model(
            filepath,
            custom_objects=custom_objects,
            compile=compile,
            safe_mode=safe_mode,
        )
    if str(filepath).endswith((".h5", ".hdf5")):
        return legacy_h5_format.load_model_from_hdf5(
            filepath, custom_objects=custom_objects, compile=compile
        )
    elif str(filepath).endswith(".keras"):
        raise ValueError(
            f"File not found: filepath={filepath}. "
            "Please ensure the file is an accessible `.keras` "
            "zip file."
        )
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
            "Keras 3 only supports V3 `.keras` files and "
            "legacy H5 format files (`.h5` extension). "
            "Note that the legacy SavedModel format is not "
            "supported by `load_model()` in Keras 3. In "
            "order to reload a TensorFlow SavedModel as an "
            "inference-only layer in Keras 3, use "
            "`keras.layers.TFSMLayer("
            f"{filepath}, call_endpoint='serving_default')` "
            "(note that your `call_endpoint` "
            "might have a different name)."
        )


@keras_export("keras.saving.save_weights")
def save_weights(
    model, filepath, overwrite=True, max_shard_size=None, **kwargs
):
    """Saves only the weights of `model` to `filepath`.

    Args:
        model: Keras model whose weights are saved.
        filepath: `str` or path-like destination. Without `max_shard_size`
            it must end in `.weights.h5`. With `max_shard_size`, a
            `.weights.json` suffix is also accepted.
        overwrite: Whether to overwrite an existing file at `filepath`
            without asking. If `False` and the file exists, the user is
            prompted and nothing is saved if they decline.
        max_shard_size: Forwarded to `saving_lib.save_weights_only`.
            NOTE(review): presumably the per-shard size limit for sharded
            saving; units are defined by `saving_lib` and should be
            confirmed there.
        **kwargs: Forwarded unchanged to `saving_lib.save_weights_only`.

    Raises:
        ValueError: If `filepath` does not have a supported suffix.
    """
    filepath_str = str(filepath)
    if max_shard_size is None and not filepath_str.endswith(".weights.h5"):
        raise ValueError(
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath_str}"
        )
    # Require the leading dot so the accepted names match what
    # `load_weights` recognizes (`.weights.h5` / `.weights.json`); otherwise
    # e.g. "modelweights.json" would be written but could not be reloaded.
    elif max_shard_size is not None and not filepath_str.endswith(
        (".weights.h5", ".weights.json")
    ):
        raise ValueError(
            "The filename must end in `.weights.json` when `max_shard_size` is "
            f"specified. Received: filepath={filepath_str}"
        )
    # Non-path-like objects (TypeError from `os.path.exists`) are treated
    # as not existing.
    try:
        exists = os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str)
        if not proceed:
            return
    saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs)


@keras_export("keras.saving.load_weights")
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from `filepath`, choosing a loader by suffix.

    The suffix is tested in this order (it matters because `.weights.h5`
    also ends in `.h5`):

    - `.keras`: `saving_lib.load_weights_only`; no extra kwargs are allowed.
    - `.weights.h5` / `.weights.json`: `saving_lib.load_weights_only`,
      accepting an optional `objects_to_skip` kwarg.
    - `.h5` / `.hdf5`: legacy loader via `h5py`, accepting an optional
      `by_name` kwarg. `skip_mismatch` is only forwarded in the `by_name`
      path.

    Args:
        model: Model receiving the weights.
        filepath: `str` or path-like source file.
        skip_mismatch: Forwarded to the underlying loader (see above).
        **kwargs: Format-specific options listed above. Any leftover
            keyword raises a `ValueError`.

    Raises:
        ValueError: On unexpected keyword arguments or an unsupported suffix.
        ImportError: When loading a legacy H5 file without `h5py` installed.
    """
    path_text = str(filepath)

    def _reject_unused_kwargs():
        # Whatever is still in `kwargs` was not consumed by the chosen loader.
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")

    if path_text.endswith(".keras"):
        _reject_unused_kwargs()
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
        return

    if path_text.endswith((".weights.h5", ".weights.json")):
        skip_list = kwargs.pop("objects_to_skip", None)
        _reject_unused_kwargs()
        saving_lib.load_weights_only(
            model,
            filepath,
            skip_mismatch=skip_mismatch,
            objects_to_skip=skip_list,
        )
        return

    if path_text.endswith((".h5", ".hdf5")):
        match_by_name = kwargs.pop("by_name", False)
        _reject_unused_kwargs()
        if not h5py:
            raise ImportError(
                "Loading a H5 file requires `h5py` to be installed."
            )
        with h5py.File(filepath, "r") as h5_file:
            # Full-model H5 files nest the weights under "model_weights";
            # weights-only files keep "layer_names" at the root.
            group = h5_file
            if "layer_names" not in group.attrs and "model_weights" in group:
                group = group["model_weights"]
            if match_by_name:
                legacy_h5_format.load_weights_from_hdf5_group_by_name(
                    group, model, skip_mismatch
                )
            else:
                legacy_h5_format.load_weights_from_hdf5_group(group, model)
        return

    raise ValueError(
        f"File format not supported: filepath={filepath}. "
        "Keras 3 only supports V3 `.keras` and `.weights.h5` "
        "files, or legacy V1/V2 `.h5` files."
    )