File size: 1,396 Bytes
55880f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 | import logging
import tensorflow as tf
from ..sfcn import MultiTaskSFCN
# Module logger. NOTE(review): basicConfig configures the *root* logger as a
# side effect of importing this module (and is a no-op if the root logger
# already has handlers); consider moving this to the application entry point.
logging.basicConfig(
    level=logging.DEBUG,
    format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
)
logger = logging.getLogger(__name__)
def load_select_pretrained_weights(
    model: tf.keras.Model,
    weights: str,
    target: 'str | None' = None,
    *,
    input_shape: tuple = (224, 192, 224),
    pooling: str = 'max',
) -> tf.keras.Model:
    """Copy selected pretrained MultiTaskSFCN weights into ``model``.

    A ``MultiTaskSFCN`` backbone is built and its variables are restored
    from the checkpoint at ``weights``. The weights of a fixed set of
    layers (addressed by index) are then copied from the backbone into
    ``model``. When ``target`` names a known task, that task's dense
    prediction head is also copied into ``model.layers[27]``.

    Args:
        model: Model that receives the weights. It is modified in place.
            Its layers must line up index-by-index with the backbone's
            layers for every index copied below.
        weights: Checkpoint path passed to ``tf.train.Checkpoint.restore``.
        target: Prediction task whose head weights should be loaded,
            ``'age'`` or ``'sex'``. ``None`` (the default) skips the head.
            Any other value logs a warning and skips the head.
        input_shape: Input shape used to build the backbone. It must match
            the architecture stored in the checkpoint.
        pooling: Pooling mode used to build the backbone.

    Returns:
        The same ``model`` object, with the selected weights loaded.
    """
    logger.info('Loading pretrained weights from %s', weights)
    backbone = MultiTaskSFCN(input_shape=input_shape, pooling=pooling)
    checkpoint = tf.train.Checkpoint(backbone)
    # expect_partial() silences warnings about checkpoint values that are
    # not matched to any backbone variable.
    checkpoint.restore(weights).expect_partial()

    # Layer indices shared by `model` and the backbone. Per their names these
    # are presumably the convolution layers (2, 6, ..., 22) and the
    # normalization layers directly after them (3, 7, ..., 23).
    conv_layers = (2, 6, 10, 14, 18, 22)
    norm_layers = (3, 7, 11, 15, 19, 23)
    for idx in conv_layers + norm_layers:
        model.layers[idx].set_weights(backbone.layers[idx].get_weights())

    # The multi-task backbone holds a specific dense layer per prediction
    # task. The single-task `model` keeps its prediction head at index 27.
    head_idx = 27
    backbone_head_by_target = {'age': 27, 'sex': 28}

    if target is None:
        # Default call: no head requested, so this is not an error.
        logger.info(
            'No target given. Not loading weights for prediction layer'
        )
        return model

    source_idx = backbone_head_by_target.get(target)
    if source_idx is None:
        logger.warning(
            'Unknown target %s. Not loading weights for prediction layer',
            target
        )
        return model

    model.layers[head_idx].set_weights(
        backbone.layers[source_idx].get_weights()
    )
    # Log only after the copy succeeded.
    logger.info('Loaded %s weights for the prediction head', target)
    return model