# joebruce1313's picture
# Upload 38004 files
# 1f5470c verified
import os
import zipfile
from absl import logging
from keras.src.api_export import keras_export
from keras.src.legacy.saving import legacy_h5_format
from keras.src.saving import saving_lib
from keras.src.utils import file_utils
from keras.src.utils import io_utils
# `h5py` is an optional dependency. In this module it is needed only by
# `load_weights` for legacy `.h5` / `.hdf5` files; when it is not installed
# the name is bound to `None`, and `load_weights` checks for that and raises
# an ImportError before touching it.
try:
    import h5py
except ImportError:
    h5py = None
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
    """Saves a model as a `.keras` file.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        zipped: Whether to save the model as a zipped `.keras`
            archive (default when saving locally), or as an unzipped directory
            (default when saving on the Hugging Face Hub, i.e. when
            `filepath` starts with `"hf://"`).
        **kwargs: Only two legacy keywords are accepted:
            `include_optimizer` (default `True`, used only by the legacy
            H5 writer) and the deprecated `save_format`. Any other keyword
            raises a `ValueError`.

    Returns:
        The return value of the underlying saver, or `None` if the target
        exists, `overwrite=False`, and the user declines the prompt.

    Raises:
        ValueError: If `save_format` is given with a path that does not end
            in `.keras`, `.h5` or `.hdf5`; if unsupported keyword arguments
            are passed; or if a zipped save is requested for a path that
            ends in neither `.keras` nor `.h5` / `.hdf5`.

    Example:

    ```python
    model = keras.Sequential(
        [
            keras.layers.Dense(5, input_shape=(3,)),
            keras.layers.Softmax(),
        ],
    )
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = keras.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `keras.saving.save_model()`.

    The saved `.keras` file is a `zip` archive that contains:

    - The model's configuration (architecture)
    - The model's weights
    - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state.
    """
    path_str = str(filepath)
    include_optimizer = kwargs.pop("include_optimizer", True)
    save_format = kwargs.pop("save_format", False)
    if save_format:
        # The format is inferred from the extension in Keras 3; tolerate the
        # argument (with a warning) only when the extension already matches.
        if path_str.endswith((".h5", ".hdf5", ".keras")):
            logging.warning(
                "The `save_format` argument is deprecated in Keras 3. "
                "We recommend removing this argument as it can be inferred "
                "from the file path. "
                f"Received: save_format={save_format}"
            )
        else:
            raise ValueError(
                "The `save_format` argument is deprecated in Keras 3. "
                "Please remove this argument and pass a file path with "
                "either `.keras` or `.h5` extension. "
                f"Received: save_format={save_format}"
            )
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    # Deprecation warnings
    if path_str.endswith((".h5", ".hdf5")):
        logging.warning(
            "You are saving your model as an HDF5 file via "
            "`model.save()` or `keras.saving.save_model(model)`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')` or "
            "`keras.saving.save_model(model, 'my_model.keras')`. "
        )

    is_hf = path_str.startswith("hf://")
    if zipped is None:
        zipped = not is_hf  # default behavior depends on destination

    # If file exists and should not be overwritten.
    try:
        exists = (not is_hf) and os.path.exists(filepath)
    except TypeError:
        # `os.path.exists` raises TypeError for arguments that are not
        # path-like; treat those as "does not exist" and let the saver decide.
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    if zipped and path_str.endswith(".keras"):
        return saving_lib.save_model(model, filepath)
    if not zipped:
        return saving_lib.save_model(model, filepath, zipped=False)
    if path_str.endswith((".h5", ".hdf5")):
        return legacy_h5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    raise ValueError(
        "Invalid filepath extension for saving. "
        "Please add either a `.keras` extension for the native Keras "
        "format (recommended) or a `.h5` extension. "
        "Use `model.export(filepath)` if you want to export a SavedModel "
        "for use with TFLite/TFServing/etc. "
        f"Received: filepath={filepath}."
    )
@keras_export(["keras.saving.load_model", "keras.models.load_model"])
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Loads a model saved via `model.save()`.

    Args:
        filepath: Location of the saved model, as a `str` or `pathlib.Path`.
        custom_objects: Optional dict mapping names (strings) to custom
            classes or functions that must be resolvable during
            deserialization.
        compile: Boolean. When `True`, compile the model after loading.
        safe_mode: Boolean. When `True` (the default), unsafe `lambda`
            deserialization is disallowed; with `safe_mode=False`, loading an
            object can trigger arbitrary code execution. Only meaningful for
            the Keras v3 format; it is not passed to the legacy H5 loader.

    Returns:
        A Keras model instance. It is compiled only when the original model
        was compiled and `compile=True`; otherwise it is left uncompiled.

    Raises:
        ValueError: If `filepath` ends in `.keras` but does not resolve to a
            zip archive or a model directory, or if the format is not
            supported at all.

    Example:

    ```python
    model = keras.Sequential([
        keras.layers.Dense(5, input_shape=(3,)),
        keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = np.random.random((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    path_str = str(filepath)

    # Classify the target: zipped `.keras` archive, unzipped model directory
    # (identified by its `config.json`), or Hugging Face Hub location.
    is_keras_zip = path_str.endswith(".keras") and zipfile.is_zipfile(filepath)
    is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
        file_utils.join(filepath, "config.json")
    )
    is_hf = path_str.startswith("hf://")

    # A remote single file that was not already recognized gets copied into
    # a local temp dir; if that copy is a zip archive, load from it instead.
    needs_local_copy = (
        file_utils.is_remote_path(filepath)
        and not file_utils.isdir(filepath)
        and not is_keras_zip
        and not is_hf
    )
    if needs_local_copy:
        staged_path = file_utils.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )
        file_utils.copy(filepath, staged_path)
        if zipfile.is_zipfile(staged_path):
            filepath, is_keras_zip = staged_path, True

    if is_keras_zip or is_keras_dir or is_hf:
        return saving_lib.load_model(
            filepath,
            custom_objects=custom_objects,
            compile=compile,
            safe_mode=safe_mode,
        )

    # Past this point `filepath` was never swapped for a staged copy (the
    # swap always leads to the return above), so `path_str` is still valid.
    if path_str.endswith((".h5", ".hdf5")):
        return legacy_h5_format.load_model_from_hdf5(
            filepath, custom_objects=custom_objects, compile=compile
        )

    if path_str.endswith(".keras"):
        raise ValueError(
            f"File not found: filepath={filepath}. "
            "Please ensure the file is an accessible `.keras` "
            "zip file."
        )

    raise ValueError(
        f"File format not supported: filepath={filepath}. "
        "Keras 3 only supports V3 `.keras` files and "
        "legacy H5 format files (`.h5` extension). "
        "Note that the legacy SavedModel format is not "
        "supported by `load_model()` in Keras 3. In "
        "order to reload a TensorFlow SavedModel as an "
        "inference-only layer in Keras 3, use "
        "`keras.layers.TFSMLayer("
        f"{filepath}, call_endpoint='serving_default')` "
        "(note that your `call_endpoint` "
        "might have a different name)."
    )
@keras_export("keras.saving.save_weights")
def save_weights(
    model, filepath, overwrite=True, max_shard_size=None, **kwargs
):
    """Saves only the weights of `model` to `filepath`.

    Args:
        model: Keras model whose weights are written.
        filepath: `str` or `pathlib.Path`. Without `max_shard_size` it must
            end in `.weights.h5`. With `max_shard_size` it must end in
            `.weights.json` (a `.weights.h5` name is also accepted here).
        overwrite: If `False` and `filepath` already exists, ask the user
            via an interactive prompt before overwriting. Returns without
            saving if the user declines.
        max_shard_size: Optional; forwarded to
            `saving_lib.save_weights_only`. NOTE(review): its units and
            sharding semantics are defined there — confirm before relying
            on them.
        **kwargs: Forwarded unchanged to `saving_lib.save_weights_only`.

    Raises:
        ValueError: If `filepath` does not have a required suffix.
    """
    filepath_str = str(filepath)
    if max_shard_size is None:
        if not filepath_str.endswith(".weights.h5"):
            raise ValueError(
                "The filename must end in `.weights.h5`. "
                f"Received: filepath={filepath_str}"
            )
    elif not filepath_str.endswith((".weights.h5", ".weights.json")):
        # Require the leading dot: `load_weights` only recognizes the
        # dotted ".weights.h5" / ".weights.json" suffixes, so an undotted
        # name such as "myweights.json" could be saved but never reloaded.
        raise ValueError(
            "The filename must end in `.weights.json` when `max_shard_size` is "
            f"specified. Received: filepath={filepath_str}"
        )

    try:
        exists = os.path.exists(filepath)
    except TypeError:
        # `os.path.exists` raises TypeError for non-path-like arguments.
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str)
        if not proceed:
            return
    saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs)
@keras_export("keras.saving.load_weights")
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from `filepath`.

    The loader is selected from the file suffix:

    - `.keras`: native archive; no extra keyword arguments are accepted.
    - `.weights.h5` / `.weights.json`: weights-only files; accepts the
      optional keyword `objects_to_skip` (default `None`).
    - `.h5` / `.hdf5`: legacy H5 files (requires `h5py`); accepts the
      optional keyword `by_name` (default `False`).

    Args:
        model: Keras model that receives the weights.
        filepath: `str` or `pathlib.Path` to the weights file.
        skip_mismatch: Forwarded to `saving_lib.load_weights_only` and to the
            legacy by-name loader. NOTE(review): it is not forwarded to
            `legacy_h5_format.load_weights_from_hdf5_group` (the default
            legacy path) — confirm whether that loader supports it.
        **kwargs: Format-specific options listed above; anything else
            raises a `ValueError`.

    Raises:
        ValueError: On unexpected keyword arguments or an unsupported suffix.
        ImportError: If a legacy H5 file is given and `h5py` is not installed.
    """
    filepath_str = str(filepath)
    if filepath_str.endswith(".keras"):
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
    elif filepath_str.endswith((".weights.h5", ".weights.json")):
        objects_to_skip = kwargs.pop("objects_to_skip", None)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model,
            filepath,
            skip_mismatch=skip_mismatch,
            objects_to_skip=objects_to_skip,
        )
    elif filepath_str.endswith((".h5", ".hdf5")):
        # Must come after the ".weights.h5" branch, which also ends in ".h5".
        by_name = kwargs.pop("by_name", False)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        if not h5py:
            raise ImportError(
                "Loading a H5 file requires `h5py` to be installed."
            )
        with h5py.File(filepath, "r") as h5_file:
            # If the root has no "layer_names" attribute but has a
            # "model_weights" group, read the weights from that group
            # (presumably a full-model H5 save — confirm against the writer).
            group = h5_file
            if "layer_names" not in h5_file.attrs and "model_weights" in h5_file:
                group = h5_file["model_weights"]
            if by_name:
                legacy_h5_format.load_weights_from_hdf5_group_by_name(
                    group, model, skip_mismatch
                )
            else:
                legacy_h5_format.load_weights_from_hdf5_group(group, model)
    else:
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
            "Keras 3 only supports V3 `.keras`, `.weights.h5` and "
            "`.weights.json` files, or legacy V1/V2 `.h5` files."
        )