File size: 10,904 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 |
import os
import zipfile
from absl import logging
from keras.src.api_export import keras_export
from keras.src.legacy.saving import legacy_h5_format
from keras.src.saving import saving_lib
from keras.src.utils import file_utils
from keras.src.utils import io_utils
try:
import h5py
except ImportError:
h5py = None
@keras_export(["keras.saving.save_model", "keras.models.save_model"])
def save_model(model, filepath, overwrite=True, zipped=None, **kwargs):
    """Saves a model as a `.keras` file.

    Args:
        model: Keras model instance to be saved.
        filepath: `str` or `pathlib.Path` object. Path where to save the model.
        overwrite: Whether we should overwrite any existing model at the target
            location, or instead ask the user via an interactive prompt.
        zipped: Whether to save the model as a zipped `.keras`
            archive (default when saving locally), or as an unzipped directory
            (default when saving on the Hugging Face Hub).
        **kwargs: Only two legacy keys are accepted. `include_optimizer`
            (default `True`) is forwarded to the HDF5 saver only.
            `save_format` is deprecated: it triggers a warning when
            `filepath` ends in `.keras`, `.h5` or `.hdf5`, and an error
            otherwise.

    Returns:
        The value returned by the underlying saving routine, or `None` when
        the target exists, `overwrite=False`, and the user declines the
        overwrite prompt.

    Raises:
        ValueError: If `save_format` is passed with a file path that has no
            recognized extension, if unsupported keyword arguments are
            passed, or if `zipped` is true and `filepath` ends in neither
            `.keras` nor `.h5`/`.hdf5`.

    Example:

    ```python
    model = keras.Sequential(
        [
            keras.layers.Dense(5, input_shape=(3,)),
            keras.layers.Softmax(),
        ],
    )
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = keras.random.uniform((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that `model.save()` is an alias for `keras.saving.save_model()`.

    The saved `.keras` file is a `zip` archive that contains:

    - The model's configuration (architecture)
    - The model's weights
    - The model's optimizer's state (if any)

    Thus models can be reinstantiated in the exact same state.
    """
    include_optimizer = kwargs.pop("include_optimizer", True)
    save_format = kwargs.pop("save_format", False)

    # Normalize once: `filepath` may be a `pathlib.Path`.
    path_str = str(filepath)
    is_h5 = path_str.endswith((".h5", ".hdf5"))
    is_keras = path_str.endswith(".keras")

    if save_format:
        if is_h5 or is_keras:
            logging.warning(
                "The `save_format` argument is deprecated in Keras 3. "
                "We recommend removing this argument as it can be inferred "
                "from the file path. "
                f"Received: save_format={save_format}"
            )
        else:
            raise ValueError(
                "The `save_format` argument is deprecated in Keras 3. "
                "Please remove this argument and pass a file path with "
                "either `.keras` or `.h5` extension. "
                f"Received: save_format={save_format}"
            )
    if kwargs:
        raise ValueError(
            "The following argument(s) are not supported: "
            f"{list(kwargs.keys())}"
        )

    # Deprecation warnings
    if is_h5:
        logging.warning(
            "You are saving your model as an HDF5 file via "
            "`model.save()` or `keras.saving.save_model(model)`. "
            "This file format is considered legacy. "
            "We recommend using instead the native Keras format, "
            "e.g. `model.save('my_model.keras')` or "
            "`keras.saving.save_model(model, 'my_model.keras')`. "
        )

    is_hf = path_str.startswith("hf://")
    if zipped is None:
        zipped = not is_hf  # default behavior depends on destination

    # If file exists and should not be overwritten.
    # `hf://` targets are never probed on the local filesystem.
    try:
        exists = (not is_hf) and os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
        if not proceed:
            return

    if zipped and is_keras:
        return saving_lib.save_model(model, filepath)
    if not zipped:
        # Unzipped saves are delegated regardless of the path's extension.
        return saving_lib.save_model(model, filepath, zipped=False)
    if is_h5:
        return legacy_h5_format.save_model_to_hdf5(
            model, filepath, overwrite, include_optimizer
        )
    raise ValueError(
        "Invalid filepath extension for saving. "
        "Please add either a `.keras` extension for the native Keras "
        "format (recommended) or a `.h5` extension. "
        "Use `model.export(filepath)` if you want to export a SavedModel "
        "for use with TFLite/TFServing/etc. "
        f"Received: filepath={filepath}."
    )
@keras_export(["keras.saving.load_model", "keras.models.load_model"])
def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
    """Loads a model saved via `model.save()`.

    Args:
        filepath: `str` or `pathlib.Path` object, path to the saved model file.
        custom_objects: Optional dictionary mapping names
            (strings) to custom classes or functions to be
            considered during deserialization.
        compile: Boolean, whether to compile the model after loading.
        safe_mode: Boolean, whether to disallow unsafe `lambda` deserialization.
            When `safe_mode=False`, loading an object has the potential to
            trigger arbitrary code execution. This argument is only
            applicable to the Keras v3 model format. Defaults to `True`.

    Returns:
        A Keras model instance. If the original model was compiled,
        and the argument `compile=True` is set, then the returned model
        will be compiled. Otherwise, the model will be left uncompiled.

    Raises:
        ValueError: If `filepath` ends in `.keras` but is not a loadable
            zip archive, or if its format is not recognized at all.

    Example:

    ```python
    model = keras.Sequential([
        keras.layers.Dense(5, input_shape=(3,)),
        keras.layers.Softmax()])
    model.save("model.keras")
    loaded_model = keras.saving.load_model("model.keras")
    x = np.random.random((10, 3))
    assert np.allclose(model.predict(x), loaded_model.predict(x))
    ```

    Note that the model variables may have different name values
    (`var.name` property, e.g. `"dense_1/kernel:0"`) after being reloaded.
    It is recommended that you use layer attributes to
    access specific variables, e.g. `model.get_layer("dense_1").kernel`.
    """
    target = str(filepath)
    has_keras_suffix = target.endswith(".keras")

    # Classify the source: zipped archive, unzipped directory, or HF Hub.
    is_keras_zip = has_keras_suffix and zipfile.is_zipfile(filepath)
    is_keras_dir = file_utils.isdir(filepath) and file_utils.exists(
        file_utils.join(filepath, "config.json")
    )
    is_hf = target.startswith("hf://")

    # Support for remote zip files
    needs_local_copy = (
        file_utils.is_remote_path(filepath)
        and not file_utils.isdir(filepath)
        and not is_keras_zip
        and not is_hf
    )
    if needs_local_copy:
        # Copy from remote to temporary local directory
        local_copy = file_utils.join(
            saving_lib.get_temp_dir(), os.path.basename(filepath)
        )
        file_utils.copy(filepath, local_copy)
        # Switch filepath to local zipfile for loading model
        if zipfile.is_zipfile(local_copy):
            filepath = local_copy
            is_keras_zip = True

    if is_keras_zip or is_keras_dir or is_hf:
        return saving_lib.load_model(
            filepath,
            custom_objects=custom_objects,
            compile=compile,
            safe_mode=safe_mode,
        )

    if target.endswith((".h5", ".hdf5")):
        return legacy_h5_format.load_model_from_hdf5(
            filepath, custom_objects=custom_objects, compile=compile
        )

    if has_keras_suffix:
        # Right extension, but the zip check above did not pass.
        raise ValueError(
            f"File not found: filepath={filepath}. "
            "Please ensure the file is an accessible `.keras` "
            "zip file."
        )

    raise ValueError(
        f"File format not supported: filepath={filepath}. "
        "Keras 3 only supports V3 `.keras` files and "
        "legacy H5 format files (`.h5` extension). "
        "Note that the legacy SavedModel format is not "
        "supported by `load_model()` in Keras 3. In "
        "order to reload a TensorFlow SavedModel as an "
        "inference-only layer in Keras 3, use "
        "`keras.layers.TFSMLayer("
        f"{filepath}, call_endpoint='serving_default')` "
        "(note that your `call_endpoint` "
        "might have a different name)."
    )
@keras_export("keras.saving.save_weights")
def save_weights(
    model, filepath, overwrite=True, max_shard_size=None, **kwargs
):
    """Saves the weights of `model` to `filepath`.

    Args:
        model: Object whose weights are saved; passed through to
            `saving_lib.save_weights_only`.
        filepath: `str` or `pathlib.Path` object. Must end in `.weights.h5`
            when `max_shard_size` is `None`; otherwise must end in
            `.weights.h5` or `.weights.json`.
        overwrite: Whether to overwrite an existing file at `filepath`
            without asking. When `False` and the file exists, the user is
            asked via an interactive prompt, and nothing is saved if they
            decline.
        max_shard_size: Forwarded unchanged to
            `saving_lib.save_weights_only`; here it only selects which
            filename suffixes are accepted.
        **kwargs: Forwarded unchanged to `saving_lib.save_weights_only`.

    Raises:
        ValueError: If `filepath` does not carry a suffix permitted for the
            given `max_shard_size`.
    """
    filepath_str = str(filepath)
    if max_shard_size is None and not filepath_str.endswith(".weights.h5"):
        raise ValueError(
            "The filename must end in `.weights.h5`. "
            f"Received: filepath={filepath_str}"
        )
    # The leading dot matters: without it, names such as `myweights.json`
    # pass validation although `load_weights` cannot read them back.
    elif max_shard_size is not None and not filepath_str.endswith(
        (".weights.h5", ".weights.json")
    ):
        raise ValueError(
            "The filename must end in `.weights.json` when `max_shard_size` is "
            f"specified. Received: filepath={filepath_str}"
        )
    try:
        exists = os.path.exists(filepath)
    except TypeError:
        exists = False
    if exists and not overwrite:
        proceed = io_utils.ask_to_proceed_with_overwrite(filepath_str)
        if not proceed:
            return
    saving_lib.save_weights_only(model, filepath, max_shard_size, **kwargs)
@keras_export("keras.saving.load_weights")
def load_weights(model, filepath, skip_mismatch=False, **kwargs):
    """Loads weights into `model` from the file at `filepath`.

    The loader is selected from the filename suffix, checked in this order:

    - `.keras`: handled by `saving_lib.load_weights_only`; no extra keyword
      arguments are accepted.
    - `.weights.h5` / `.weights.json`: handled by
      `saving_lib.load_weights_only`; accepts `objects_to_skip`.
    - `.h5` / `.hdf5`: legacy HDF5 files read with `h5py`; accepts
      `by_name`.

    Args:
        model: Object receiving the weights.
        filepath: `str` or `pathlib.Path` object, path to the weights file.
        skip_mismatch: Forwarded to the selected loader. For legacy HDF5
            files it is only passed on when `by_name` is true.
        **kwargs: Format-specific options as listed above.

    Raises:
        ValueError: If unexpected keyword arguments are given for the
            detected format, or if the suffix is not supported.
        ImportError: If a legacy HDF5 file is requested but `h5py` is not
            installed.
    """
    path_str = str(filepath)

    if path_str.endswith(".keras"):
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model, filepath, skip_mismatch=skip_mismatch
        )
        return

    if path_str.endswith((".weights.h5", ".weights.json")):
        skipped = kwargs.pop("objects_to_skip", None)
        if kwargs:
            raise ValueError(f"Invalid keyword arguments: {kwargs}")
        saving_lib.load_weights_only(
            model,
            filepath,
            skip_mismatch=skip_mismatch,
            objects_to_skip=skipped,
        )
        return

    if not path_str.endswith((".h5", ".hdf5")):
        raise ValueError(
            f"File format not supported: filepath={filepath}. "
            "Keras 3 only supports V3 `.keras` and `.weights.h5` "
            "files, or legacy V1/V2 `.h5` files."
        )

    # Legacy HDF5 weights.
    by_name = kwargs.pop("by_name", False)
    if kwargs:
        raise ValueError(f"Invalid keyword arguments: {kwargs}")
    if not h5py:
        raise ImportError(
            "Loading a H5 file requires `h5py` to be installed."
        )
    with h5py.File(filepath, "r") as h5_file:
        # Without a top-level `layer_names` attribute, read from the
        # `model_weights` group when the file has one.
        weights_root = h5_file
        if "layer_names" not in h5_file.attrs and "model_weights" in h5_file:
            weights_root = h5_file["model_weights"]
        if by_name:
            legacy_h5_format.load_weights_from_hdf5_group_by_name(
                weights_root, model, skip_mismatch
            )
        else:
            legacy_h5_format.load_weights_from_hdf5_group(weights_root, model)
|