|
|
import os |
|
|
import zipfile |
|
|
|
|
|
from absl import logging |
|
|
|
|
|
from keras.src.api_export import keras_export |
|
|
from keras.src.legacy.saving import legacy_h5_format |
|
|
from keras.src.saving import saving_lib |
|
|
from keras.src.utils import file_utils |
|
|
from keras.src.utils import io_utils |
|
|
|
|
|
try: |
|
|
import h5py |
|
|
except ImportError: |
|
|
h5py = None |
|
|
|
|
|
|
|
|
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
    """Saves a model as a `.keras` file.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        zipped: Whether to save the model as a zipped `.keras`
            archive (default when saving locally), or as an unzipped directory
            (default when saving on the Hugging Face Hub).
        **kwargs: Legacy keyword arguments.
            - `include_optimizer` (defaults to `True`) is only forwarded to
              the legacy H5 writer.
            - `save_format` is deprecated. It logs a warning when `filepath`
              already ends in `.keras`, `.h5` or `.hdf5`, and raises a
              `ValueError` otherwise.
            - Any other keyword argument raises a `ValueError`.

    Raises:
        ValueError: If a deprecated `save_format` is passed with a path that
            has no recognized extension, if unsupported keyword arguments are
            passed, or if the path extension is neither `.keras` nor
            `.h5`/`.hdf5` when saving a zipped model.

    Example:

    ```python
    model = keras.Sequential(
        [
            keras.layers.Dense(5, input_shape=(3,)),
            keras.layers.Softmax(),
        ],
    )
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = keras.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `keras.saving.save_model()`.

    The saved `.keras` file is a `zip` archive that contains:

    - The model's configuration (architecture)
    - The model's weights
    - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state.
    """
    include_optimizer = kwargs.pop("include_optimizer", True)
    save_format = kwargs.pop("save_format", False)
    filepath_str = str(filepath)

    if save_format:
        # `save_format` is redundant when the extension already tells us the
        # format; only tolerate it in that case.
        if filepath_str.endswith((".h5", ".hdf5", ".keras")):
            logging.warning(
                "The `save_format` argument is deprecated in Keras 3. "
                "We recommend removing this argument as it can be inferred "
                "from the file path. "
                f"Received: save_format={save_format}"
            )
        else:
            raise ValueError(
                "The `save_format` argument is deprecated in Keras 3. "
                "Please remove this argument and pass a file path with "
                "either `.keras` or `.h5` extension. "
                f"Received: save_format={save_format}"
            )
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    is_h5 = filepath_str.endswith((".h5", ".hdf5"))
    if is_h5:
        logging.warning(
            "You are saving your model as an HDF5 file via "
            "`model.save()` or `keras.saving.save_model(model)`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')` or "
            "`keras.saving.save_model(model, 'my_model.keras')`. "
        )

    is_hf = filepath_str.startswith("hf://")
    if zipped is None:
        # Local saves default to a zip archive; Hugging Face Hub saves default
        # to an unzipped directory.
        zipped = not is_hf

    # Remote `hf://` targets are never checked on the local filesystem.
    # `os.path.exists` raises TypeError for objects it cannot interpret as a
    # path, in which case we treat the target as non-existent.
    try:
        exists = (not is_hf) and os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    if zipped and filepath_str.endswith(".keras"):
        return saving_lib.save_model(model, filepath)
    if not zipped:
        return saving_lib.save_model(model, filepath, zipped=False)
    if is_h5:
        return legacy_h5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    raise ValueError(
        "Invalid filepath extension for saving. "
        "Please add either a `.keras` extension for the native Keras "
        "format (recommended) or a `.h5` extension. "
        "Use `model.export(filepath)` if you want to export a SavedModel "
        "for use with TFLite/TFServing/etc. "
        f"Received: filepath={filepath}."
    )
|
|
|
|
|
|
|
|
@keras_export(["keras.saving.load_model", "keras.models.load_model"])
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Loads a model saved via `model.save()`.

    Native Keras v3 artifacts are handled by `saving_lib.load_model`. These
    are zipped `.keras` archives, unzipped directories containing a
    `config.json`, and `hf://` paths. Legacy `.h5`/`.hdf5` files go through
    the legacy H5 loader.

    A remote path that is not a directory is first copied into a temporary
    local directory. If that copy turns out to be a zip file, it is loaded as
    a `.keras` archive.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model file.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Raises:
        ValueError: If a `.keras` path is not a readable zip archive, or if
            the file format is not supported.

    Example:

    ```python
    model = keras.Sequential([
        keras.layers.Dense(5, input_shape=(3,)),
        keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = np.random.random((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    path_str = str(filepath)

    # Classify the target before any download happens.
    is_keras_zip = path_str.endswith(".keras") and zipfile.is_zipfile(
        filepath
    )
    config_path = file_utils.join(filepath, "config.json")
    is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
        config_path
    )
    is_hf = path_str.startswith("hf://")

    should_fetch_locally = (
        file_utils.is_remote_path(filepath)
        and not file_utils.isdir(filepath)
        and not (is_keras_zip or is_hf)
    )
    if should_fetch_locally:
        # `zipfile` cannot inspect remote objects directly, so download a
        # copy into a temporary directory and check that instead.
        tmp_copy = file_utils.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )
        file_utils.copy(filepath, tmp_copy)
        if zipfile.is_zipfile(tmp_copy):
            filepath, is_keras_zip = tmp_copy, True

    is_native_v3 = is_keras_zip or is_keras_dir or is_hf
    if is_native_v3:
        return saving_lib.load_model(
            filepath,
            custom_objects=custom_objects,
            compile=compile,
            safe_mode=safe_mode,
        )

    # Past this point `filepath` was never redirected to a local copy
    # (redirection always takes the native branch above).
    if path_str.endswith((".h5", ".hdf5")):
        return legacy_h5_format.load_model_from_hdf5(
            filepath, custom_objects=custom_objects, compile=compile
        )

    if path_str.endswith(".keras"):
        raise ValueError(
            f"File not found: filepath={filepath}. "
            "Please ensure the file is an accessible `.keras` "
            "zip file."
        )

    raise ValueError(
        f"File format not supported: filepath={filepath}. "
        "Keras 3 only supports V3 `.keras` files and "
        "legacy H5 format files (`.h5` extension). "
        "Note that the legacy SavedModel format is not "
        "supported by `load_model()` in Keras 3. In "
        "order to reload a TensorFlow SavedModel as an "
        "inference-only layer in Keras 3, use "
        "`keras.layers.TFSMLayer("
        f"{filepath}, call_endpoint='serving_default')` "
        "(note that your `call_endpoint` "
        "might have a different name)."
    )
|
|
|
|
|
|
|
|
@keras_export("keras.saving.save_weights")
def save_weights(
    model, filepath, overwrite=True, max_shard_size=None, **kwargs
):
    """Saves the weights of `model` to a Keras weights file.

    Args:
        model: Keras model whose weights are saved.
        filepath: `str` or `pathlib.Path`. The required suffix depends on
            `max_shard_size`:
            - When `max_shard_size` is `None`, it must end in `.weights.h5`.
            - Otherwise it must end in `.weights.json` or `.weights.h5`.
            The leading dot is required so that the saved file is
            recognized by `load_weights`.
        overwrite: Whether to overwrite an existing file at `filepath`
            without asking. If `False` and the file exists, the user is asked
            via an interactive prompt, and nothing is saved if they decline.
        max_shard_size: Optional. Forwarded positionally to
            `saving_lib.save_weights_only`. Presumably this caps the size of
            each weight shard; confirm against `saving_lib`.
        **kwargs: Forwarded unchanged to `saving_lib.save_weights_only`.

    Raises:
        ValueError: If `filepath` does not have an accepted suffix.
    """
    filepath_str = str(filepath)
    if max_shard_size is None:
        if not filepath_str.endswith(".weights.h5"):
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath_str}"
            )
    elif not filepath_str.endswith((".weights.h5", ".weights.json")):
        # Require the leading dot. Without it, names like `foo_weights.json`
        # would be accepted here but rejected (or misrouted to the legacy H5
        # loader) by `load_weights`.
        raise ValueError(
            "The filename must end in `.weights.json` when `max_shard_size` is "
            f"specified. Received: filepath={filepath_str}"
        )

    # `os.path.exists` raises TypeError for objects it cannot interpret as a
    # path; treat those as non-existent.
    try:
        exists = os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str)
        if not proceed:
            return
    saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs)
|
|
|
|
|
|
|
|
@keras_export("keras.saving.load_weights")
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from a file written by Keras.

    The loader is chosen by the suffix of `filepath`:

    - `.keras`: weights are read via `saving_lib.load_weights_only`. No extra
      keyword arguments are accepted.
    - `.weights.h5` / `.weights.json`: weights are read via
      `saving_lib.load_weights_only`. An optional `objects_to_skip` keyword
      argument is forwarded.
    - `.h5` / `.hdf5` (legacy): weights are read with `h5py`. An optional
      `by_name` keyword argument (default `False`) selects name-based
      matching.

    Args:
        model: Keras model that receives the weights.
        filepath: `str` or `pathlib.Path` of the weights file.
        skip_mismatch: Forwarded to `saving_lib.load_weights_only`, and to the
            legacy loader only when `by_name=True`.
        **kwargs: Format-specific options described above. Anything else
            raises a `ValueError`.

    Raises:
        ValueError: On unexpected keyword arguments or an unsupported suffix.
        ImportError: When loading a legacy H5 file without `h5py` installed.
    """
    filepath_str = str(filepath)
    if filepath_str.endswith(".keras"):
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
    elif filepath_str.endswith((".weights.h5", ".weights.json")):
        objects_to_skip = kwargs.pop("objects_to_skip", None)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model,
            filepath,
            skip_mismatch=skip_mismatch,
            objects_to_skip=objects_to_skip,
        )
    elif filepath_str.endswith((".h5", ".hdf5")):
        by_name = kwargs.pop("by_name", False)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        if not h5py:
            raise ImportError(
                "Loading a H5 file requires `h5py` to be installed."
            )
        with h5py.File(filepath, "r") as h5_file:
            # Files written by a full-model save nest the weights under a
            # `model_weights` group; weights-only files keep them at the root.
            group = h5_file
            if "layer_names" not in h5_file.attrs and "model_weights" in h5_file:
                group = h5_file["model_weights"]
            if by_name:
                legacy_h5_format.load_weights_from_hdf5_group_by_name(
                    group, model, skip_mismatch
                )
            else:
                # NOTE(review): `skip_mismatch` is not forwarded on this path.
                legacy_h5_format.load_weights_from_hdf5_group(group, model)
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
            "Keras 3 only supports V3 `.keras`, `.weights.h5` and "
            "`.weights.json` files, or legacy V1/V2 `.h5`/`.hdf5` files."
        )
|
|
|