File size: 2,181 Bytes
1f5470c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from keras.src.layers.layer import Layer
from keras.src.metrics.metric import Metric
from keras.src.optimizers.optimizer import Optimizer
from keras.src.saving import saving_lib
from keras.src.saving.keras_saveable import KerasSaveable
def map_saveable_variables(saveable, store, visited_saveables):
# If the saveable has already been seen, skip it.
if id(saveable) in visited_saveables:
return
visited_saveables.add(id(saveable))
variables = []
if isinstance(saveable, Layer):
variables = (
saveable._trainable_variables + saveable._non_trainable_variables
)
elif isinstance(saveable, Optimizer):
variables = saveable._variables
elif isinstance(saveable, Metric):
variables = saveable._variables
for v in variables:
if v.path in store:
raise ValueError(
"The model contains two variables with a duplicate path: "
f"path='{v.path}' appears at least twice. "
f"This path is used for {v} and for {store[v.path]}. "
"In order to get a variable map, make sure to use "
"unique paths/names for each variable."
)
store[v.path] = v
# Recursively save state of children saveables (layers, optimizers, etc.)
for child_attr, child_obj in saving_lib._walk_saveable(saveable):
if isinstance(child_obj, KerasSaveable):
map_saveable_variables(
child_obj,
store,
visited_saveables=visited_saveables,
)
elif isinstance(child_obj, (list, dict, tuple, set)):
map_container_variables(
child_obj,
store,
visited_saveables=visited_saveables,
)
def map_container_variables(container, store, visited_saveables):
if isinstance(container, dict):
container = list(container.values())
for saveable in container:
if isinstance(saveable, KerasSaveable):
map_saveable_variables(
saveable,
store,
visited_saveables=visited_saveables,
)
|