import os
import re
import warnings
import numpy as np
from keras.src import backend
from keras.src.api_export import keras_export
from keras.src.callbacks.callback import Callback
from keras.src.utils import file_utils
from keras.src.utils import io_utils
@keras_export("keras.callbacks.ModelCheckpoint")
class ModelCheckpoint(Callback):
    """Callback to save the Keras model or model weights at some frequency.

    `ModelCheckpoint` callback is used in conjunction with training using
    `model.fit()` to save a model or weights (in a checkpoint file) at some
    interval, so the model or weights can be loaded later to continue the
    training from the state saved.

    A few options this callback provides include:

    - Whether to only keep the model that has achieved the "best performance" so
      far, or whether to save the model at the end of every epoch regardless of
      performance.
    - Definition of "best"; which quantity to monitor and whether it should be
      maximized or minimized.
    - The frequency it should save at. Currently, the callback supports saving
      at the end of every epoch, or after a fixed number of training batches.
    - Whether only weights are saved, or the whole model is saved.

    Example:

    ```python
    model.compile(loss=..., optimizer=...,
                  metrics=['accuracy'])

    EPOCHS = 10
    checkpoint_filepath = '/tmp/ckpt/checkpoint.model.keras'
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model is saved at the end of every epoch, if it's the best seen so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model (that are considered the best) can be loaded as -
    keras.models.load_model(checkpoint_filepath)

    # Alternatively, one could checkpoint just the model weights as -
    checkpoint_filepath = '/tmp/ckpt/checkpoint.weights.h5'
    model_checkpoint_callback = keras.callbacks.ModelCheckpoint(
        filepath=checkpoint_filepath,
        save_weights_only=True,
        monitor='val_accuracy',
        mode='max',
        save_best_only=True)

    # Model weights are saved at the end of every epoch, if it's the best seen
    # so far.
    model.fit(epochs=EPOCHS, callbacks=[model_checkpoint_callback])

    # The model weights (that are considered the best) can be loaded as -
    model.load_weights(checkpoint_filepath)
    ```

    Args:
        filepath: string or `PathLike`, path to save the model file.
            `filepath` can contain named formatting options,
            which will be filled the value of `epoch` and keys in `logs`
            (passed in `on_epoch_end`).
            The `filepath` name needs to end with `".weights.h5"` when
            `save_weights_only=True` or should end with `".keras"` or `".h5"`
            when checkpoint saving the whole model (default).
            For example:
            if `filepath` is `"{epoch:02d}-{val_loss:.2f}.keras"` or
            "{epoch:02d}-{val_loss:.2f}.weights.h5"`, then the model
            checkpoints will be saved with the epoch number and the validation
            loss in the filename. The directory of the filepath
            should not be reused by any other callbacks to avoid conflicts.
        monitor: The metric name to monitor. Typically the metrics are set by
            the `Model.compile` method. Note:
            * Prefix the name with `"val_"` to monitor validation metrics.
            * Use `"loss"` or `"val_loss"` to monitor the model's total loss.
            * If you specify metrics as strings, like `"accuracy"`, pass the
                same string (with or without the `"val_"` prefix).
            * If you pass `metrics.Metric` objects, `monitor` should be set to
                `metric.name`
            * If you're not sure about the metric names you can check the
                contents of the `history.history` dictionary returned by
                `history = model.fit()`
            * Multi-output models set additional prefixes on the metric names.
        verbose: Verbosity mode, 0 or 1. Mode 0 is silent, and mode 1
            displays messages when the callback takes an action.
        save_best_only: if `save_best_only=True`, it only saves when the model
            is considered the "best" and the latest best model according to the
            quantity monitored will not be overwritten. If `filepath` doesn't
            contain formatting options like `{epoch}` then `filepath` will be
            overwritten by each new better model.
        mode: one of {`"auto"`, `"min"`, `"max"`}. If `save_best_only=True`, the
            decision to overwrite the current save file is made based on either
            the maximization or the minimization of the monitored quantity.
            For `val_acc`, this should be `"max"`, for `val_loss` this should be
            `"min"`, etc. In `"auto"` mode, the mode is set to `"max"` if the
            quantities monitored are `"acc"` or start with `"fmeasure"` and are
            set to `"min"` for the rest of the quantities.
        save_weights_only: if `True`, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model is
            saved (`model.save(filepath)`).
        save_freq: `"epoch"` or integer. When using `"epoch"`, the callback
            saves the model after each epoch. When using integer, the callback
            saves the model at end of this many batches. If the `Model` is
            compiled with `steps_per_execution=N`, then the saving criteria will
            be checked every Nth batch. Note that if the saving isn't aligned to
            epochs, the monitored metric may potentially be less reliable (it
            could reflect as little as 1 batch, since the metrics get reset
            every epoch). Defaults to `"epoch"`.
        initial_value_threshold: Floating point initial "best" value of the
            metric to be monitored. Only applies if `save_best_only=True`. Only
            overwrites the model weights already saved if the performance of
            current model is better than this value.
    """

    def __init__(
        self,
        filepath,
        monitor="val_loss",
        verbose=0,
        save_best_only=False,
        save_weights_only=False,
        mode="auto",
        save_freq="epoch",
        initial_value_threshold=None,
    ):
        super().__init__()
        self.monitor = monitor
        self.verbose = verbose
        self.filepath = file_utils.path_to_string(filepath)
        self.save_best_only = save_best_only
        self.save_weights_only = save_weights_only
        self.save_freq = save_freq
        self._batches_seen_since_last_saving = 0
        self._last_batch_seen = 0
        # Initialized here so batch-level saving does not fail with an
        # `AttributeError` if `on_train_batch_end` fires before the first
        # `on_epoch_begin` (which normally keeps this up to date).
        self._current_epoch = 0
        # `best` starts at the user-provided threshold; the mode handling
        # below fills in +/-inf when no threshold was given.
        self.best = initial_value_threshold

        if mode not in ["auto", "min", "max"]:
            warnings.warn(
                f"ModelCheckpoint mode '{mode}' is unknown, "
                "fallback to auto mode.",
                stacklevel=2,
            )
            mode = "auto"

        if mode == "min":
            self.monitor_op = np.less
            if self.best is None:
                self.best = np.inf
        elif mode == "max":
            self.monitor_op = np.greater
            if self.best is None:
                self.best = -np.inf
        else:
            # "auto": infer the direction from the metric name. Accuracy-like
            # and fmeasure-like metrics are maximized; everything else
            # (losses, errors) is minimized.
            if "acc" in self.monitor or self.monitor.startswith("fmeasure"):
                self.monitor_op = np.greater
                if self.best is None:
                    self.best = -np.inf
            else:
                self.monitor_op = np.less
                if self.best is None:
                    self.best = np.inf

        if self.save_freq != "epoch" and not isinstance(self.save_freq, int):
            raise ValueError(
                f"Unrecognized save_freq: {self.save_freq}. "
                "Expected save_freq are 'epoch' or integer values"
            )

        # Validate the file extension eagerly so the user gets the error at
        # construction time rather than at the first save.
        if save_weights_only:
            if not self.filepath.endswith(".weights.h5"):
                raise ValueError(
                    "When using `save_weights_only=True` in `ModelCheckpoint`"
                    ", the filepath provided must end in `.weights.h5` "
                    "(Keras weights format). Received: "
                    f"filepath={self.filepath}"
                )
        else:
            if not any(
                self.filepath.endswith(ext) for ext in (".keras", ".h5")
            ):
                raise ValueError(
                    "The filepath provided must end in `.keras` "
                    "(Keras model format) or `.h5` (legacy HDF5 format). "
                    f"Received: filepath={self.filepath}"
                )

    def on_train_batch_end(self, batch, logs=None):
        if self._should_save_on_batch(batch):
            self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)

    def on_epoch_begin(self, epoch, logs=None):
        self._current_epoch = epoch

    def on_epoch_end(self, epoch, logs=None):
        if self.save_freq == "epoch":
            self._save_model(epoch=epoch, batch=None, logs=logs)

    def _should_save_on_batch(self, batch):
        """Handles batch-level saving logic, supports steps_per_execution."""
        if self.save_freq == "epoch":
            return False
        if batch <= self._last_batch_seen:  # New epoch.
            add_batches = batch + 1  # batches are zero-indexed.
        else:
            # With `steps_per_execution=N`, batch indices advance by N at a
            # time, so count every batch in between as seen.
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch

        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False

    def _should_save_model(self, epoch, batch, logs, filepath):
        """Determines whether the model should be saved.

        The model should be saved in the following cases:

        - self.save_best_only is False
        - self.save_best_only is True and `monitor` is a numpy array or
          backend tensor (falls back to `save_best_only=False`)
        - self.save_best_only is True and `self.monitor_op(current, self.best)`
          evaluates to True.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `"epoch"`.
            logs: the `logs` dict passed in to `on_batch_end` or
                `on_epoch_end`.
            filepath: the path where the model would be saved
        """
        logs = logs or {}
        if self.save_best_only:
            current = logs.get(self.monitor)
            if current is None:
                warnings.warn(
                    f"Can save best model only with {self.monitor} available.",
                    stacklevel=2,
                )
                # Fall back to unconditional saving rather than silently
                # skipping the checkpoint.
                return True
            elif (
                isinstance(current, np.ndarray) or backend.is_tensor(current)
            ) and len(current.shape) > 0:
                # A non-scalar monitored value cannot be compared with
                # `monitor_op`; fall back to always saving.
                warnings.warn(
                    "Can save best model only when `monitor` is "
                    f"a scalar value. Received: {current}. "
                    "Falling back to `save_best_only=False`.",
                    stacklevel=2,
                )
                return True
            else:
                if self.monitor_op(current, self.best):
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: {self.monitor} "
                            "improved "
                            f"from {self.best:.5f} to {current:.5f}, "
                            f"saving model to {filepath}"
                        )
                    self.best = current
                    return True
                else:
                    if self.verbose > 0:
                        io_utils.print_msg(
                            f"\nEpoch {epoch + 1}: "
                            f"{self.monitor} did not improve "
                            f"from {self.best:.5f}"
                        )
                    return False
        else:
            if self.verbose > 0:
                io_utils.print_msg(
                    f"\nEpoch {epoch + 1}: saving model to {filepath}"
                )
            return True

    def _save_model(self, epoch, batch, logs):
        """Saves the model.

        Args:
            epoch: the epoch this iteration is in.
            batch: the batch this iteration is in. `None` if the `save_freq`
                is set to `"epoch"`.
            logs: the `logs` dict passed in to `on_batch_end` or `on_epoch_end`.
        """
        filepath = self._get_file_path(epoch, batch, logs)

        try:
            if self._should_save_model(epoch, batch, logs, filepath):
                # Create host directory if it doesn't exist.
                dirname = os.path.dirname(filepath)
                if dirname and not file_utils.exists(dirname):
                    file_utils.makedirs(dirname)

                if self.save_weights_only:
                    self.model.save_weights(filepath, overwrite=True)
                else:
                    self.model.save(filepath, overwrite=True)
        except IsADirectoryError as e:  # h5py 3.x
            raise IOError(
                "Please specify a non-directory filepath for "
                "ModelCheckpoint. Filepath used is an existing "
                f"directory: {filepath}"
            ) from e
        except IOError as e:  # h5py 2.x
            # `e.errno` appears to be `None` so checking the content of
            # `e.args[0]`.
            if "is a directory" in str(e.args[0]).lower():
                raise IOError(
                    "Please specify a non-directory filepath for "
                    "ModelCheckpoint. Filepath used is an existing "
                    f"directory: {filepath}"
                ) from e
            # Re-throw the error for any other causes.
            raise e

    def _get_file_path(self, epoch, batch, logs):
        """Returns the file path for checkpoint."""
        try:
            # `filepath` may contain placeholders such as
            # `{epoch:02d}`,`{batch:02d}` and `{mape:.2f}`. A mismatch between
            # logged metrics and the path's placeholders can cause formatting to
            # fail.
            if batch is None or "batch" in logs:
                file_path = self.filepath.format(epoch=epoch + 1, **logs)
            else:
                file_path = self.filepath.format(
                    epoch=epoch + 1, batch=batch + 1, **logs
                )
        except KeyError as e:
            raise KeyError(
                f'Failed to format this callback filepath: "{self.filepath}". '
                f"Reason: {e}"
            ) from e
        return file_path

    def _checkpoint_exists(self, filepath):
        """Returns whether the checkpoint `filepath` refers to exists."""
        return file_utils.exists(filepath)

    def _get_most_recently_modified_file_matching_pattern(self, pattern):
        """Returns the most recently modified filepath matching pattern.

        In the rare case where there are more than one pattern-matching file
        having the same modified time that is most recent among all, return the
        filepath that is largest (by `>` operator, lexicographically using the
        numeric equivalents). This provides a tie-breaker when multiple files
        are most recent. Note that a larger `filepath` can sometimes indicate a
        later time of modification (for instance, when epoch/batch is used as
        formatting option), but not necessarily (when accuracy or loss is used).
        The tie-breaker is put in the logic as best effort to return the most
        recent, and to avoid nondeterministic result.

        Modified time of a file is obtained with `os.path.getmtime()`.

        This utility function is best demonstrated via an example:

        ```python
        file_pattern = 'batch{batch:02d}epoch{epoch:02d}.keras'
        test_dir = self.get_temp_dir()
        path_pattern = os.path.join(test_dir, file_pattern)
        file_paths = [
            os.path.join(test_dir, file_name) for file_name in
            ['batch03epoch02.keras',
             'batch02epoch02.keras', 'batch01epoch01.keras']
        ]
        for file_path in file_paths:
            # Write something to each of the files
            ...
        self.assertEqual(
            _get_most_recently_modified_file_matching_pattern(path_pattern),
            file_paths[-1])
        ```

        Args:
            pattern: The file pattern that may optionally contain python
                placeholder such as `{epoch:02d}`.

        Returns:
            The most recently modified file's full filepath matching `pattern`.
            If `pattern` does not contain any placeholder, this returns the
            filepath that exactly matches `pattern`. Returns `None` if no match
            is found.
        """
        dir_name = os.path.dirname(pattern)
        base_name = os.path.basename(pattern)
        # Replace every `{...}` placeholder with a wildcard to match the
        # concrete file names that were produced from the pattern.
        base_name_regex = "^" + re.sub(r"{.*}", r".*", base_name) + "$"

        latest_mod_time = 0
        file_path_with_latest_mod_time = None
        n_file_with_latest_mod_time = 0
        file_path_with_largest_file_name = None

        if file_utils.exists(dir_name):
            for file_name in os.listdir(dir_name):
                # Only consider if `file_name` matches the pattern.
                if re.match(base_name_regex, file_name):
                    file_path = os.path.join(dir_name, file_name)
                    mod_time = os.path.getmtime(file_path)
                    if (
                        file_path_with_largest_file_name is None
                        or file_path > file_path_with_largest_file_name
                    ):
                        file_path_with_largest_file_name = file_path
                    if mod_time > latest_mod_time:
                        latest_mod_time = mod_time
                        file_path_with_latest_mod_time = file_path
                        # In the case a file with later modified time is found,
                        # reset the counter for the number of files with latest
                        # modified time.
                        n_file_with_latest_mod_time = 1
                    elif mod_time == latest_mod_time:
                        # In the case a file has modified time tied with the
                        # most recent, increment the counter for the number of
                        # files with latest modified time by 1.
                        n_file_with_latest_mod_time += 1

        if n_file_with_latest_mod_time == 1:
            # Return the sole file that has most recent modified time.
            return file_path_with_latest_mod_time
        else:
            # If there are more than one file having latest modified time,
            # return the file path with the largest file name.
            return file_path_with_largest_file_name
|