import os

import numpy as np
import tensorflow as tf
import torch


def load_param(checkpoint_file, conversion_table, model_name):
    """
    Load parameters from a TensorFlow checkpoint according to conversion_table.

    Args:
        checkpoint_file (string): pretrained TensorFlow checkpoint file
        conversion_table (dict): { pytorch tensor in a model : checkpoint variable name }
        model_name (string): model name, used as the variable-name prefix in the
            checkpoint (e.g. 'efficientnet-b0')
    """
    for pyt_param, tf_param_name in conversion_table.items():
        tf_param_name = str(model_name) + '/' + tf_param_name
        tf_param = tf.train.load_variable(checkpoint_file, tf_param_name)

        # Conv kernels: TF stores (H, W, in_ch, out_ch); PyTorch expects
        # (out_ch, in_ch, H, W).
        if 'conv' in tf_param_name and 'kernel' in tf_param_name:
            tf_param = np.transpose(tf_param, (3, 2, 0, 1))
            # Depthwise kernels are stored as (H, W, ch, multiplier) in TF, so
            # the first two axes must be swapped again after the transpose above.
            if 'depthwise' in tf_param_name:
                tf_param = np.transpose(tf_param, (1, 0, 2, 3))
        # Dense kernels: TF stores (in, out); PyTorch expects (out, in).
        elif tf_param_name.endswith('kernel'):
            tf_param = np.transpose(tf_param)

        assert pyt_param.size() == tf_param.shape, \
            'Dim Mismatch: %s vs %s ; %s' % (tuple(pyt_param.size()), tf_param.shape, tf_param_name)
        pyt_param.data = torch.from_numpy(tf_param)
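

# The layout conventions above are the crux of the conversion: TF stores a
# regular conv kernel as (H, W, in_ch, out_ch) and a depthwise kernel as
# (H, W, ch, depth_multiplier), while PyTorch expects (out_ch, in_ch, H, W).
# A minimal self-contained sketch of the regular-conv case (the shapes and the
# function name are illustrative only, not taken from any checkpoint):
def _demo_kernel_transpose():
    tf_kernel = np.zeros((3, 3, 32, 16))               # TF layout: (H, W, in, out)
    pt_kernel = np.transpose(tf_kernel, (3, 2, 0, 1))  # PyTorch layout: (out, in, H, W)
    assert pt_kernel.shape == (16, 32, 3, 3)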


def load_efficientnet(model, checkpoint_file, model_name):
    """
    Load a PyTorch EfficientNet from a TensorFlow checkpoint file.
    """
    # Full mapping { pytorch tensor : tf variable name }, built up incrementally.
    conversion_table = {}
    merge = lambda dict1, dict2: {**dict1, **dict2}

    # Weights outside the MBConv blocks: stem, head, and classifier.
    conversion_table_for_weights_outside_blocks = {
        model._conv_stem.weight: 'stem/conv2d/kernel',
        model._bn0.bias: 'stem/tpu_batch_normalization/beta',
        model._bn0.weight: 'stem/tpu_batch_normalization/gamma',
        model._bn0.running_mean: 'stem/tpu_batch_normalization/moving_mean',
        model._bn0.running_var: 'stem/tpu_batch_normalization/moving_variance',
        model._conv_head.weight: 'head/conv2d/kernel',
        model._bn1.bias: 'head/tpu_batch_normalization/beta',
        model._bn1.weight: 'head/tpu_batch_normalization/gamma',
        model._bn1.running_mean: 'head/tpu_batch_normalization/moving_mean',
        model._bn1.running_var: 'head/tpu_batch_normalization/moving_variance',
        model._fc.bias: 'head/dense/bias',
        model._fc.weight: 'head/dense/kernel',
    }
    conversion_table = merge(conversion_table, conversion_table_for_weights_outside_blocks)

    # The first block has no expand conv, so its checkpoint variable names are
    # offset by one relative to the other blocks.
    conversion_table_for_first_block = {
        model._blocks[0]._project_conv.weight: 'blocks_0/conv2d/kernel',
        model._blocks[0]._depthwise_conv.weight: 'blocks_0/depthwise_conv2d/depthwise_kernel',
        model._blocks[0]._se_reduce.bias: 'blocks_0/se/conv2d/bias',
        model._blocks[0]._se_reduce.weight: 'blocks_0/se/conv2d/kernel',
        model._blocks[0]._se_expand.bias: 'blocks_0/se/conv2d_1/bias',
        model._blocks[0]._se_expand.weight: 'blocks_0/se/conv2d_1/kernel',
        model._blocks[0]._bn1.bias: 'blocks_0/tpu_batch_normalization/beta',
        model._blocks[0]._bn1.weight: 'blocks_0/tpu_batch_normalization/gamma',
        model._blocks[0]._bn1.running_mean: 'blocks_0/tpu_batch_normalization/moving_mean',
        model._blocks[0]._bn1.running_var: 'blocks_0/tpu_batch_normalization/moving_variance',
        model._blocks[0]._bn2.bias: 'blocks_0/tpu_batch_normalization_1/beta',
        model._blocks[0]._bn2.weight: 'blocks_0/tpu_batch_normalization_1/gamma',
        model._blocks[0]._bn2.running_mean: 'blocks_0/tpu_batch_normalization_1/moving_mean',
        model._blocks[0]._bn2.running_var: 'blocks_0/tpu_batch_normalization_1/moving_variance',
    }
    conversion_table = merge(conversion_table, conversion_table_for_first_block)

    # Weights inside the MBConv blocks. Blocks without an expand conv share the
    # first block's naming scheme; all others have one extra conv and one extra
    # batch norm, which shifts the variable-name suffixes by one.
    for i in range(len(model._blocks)):
        is_first_block = '_expand_conv.weight' not in [n for n, p in model._blocks[i].named_parameters()]
        block_prefix = 'blocks_' + str(i)

        if is_first_block:
            conversion_table_block = {
                model._blocks[i]._project_conv.weight: block_prefix + '/conv2d/kernel',
                model._blocks[i]._depthwise_conv.weight: block_prefix + '/depthwise_conv2d/depthwise_kernel',
                model._blocks[i]._se_reduce.bias: block_prefix + '/se/conv2d/bias',
                model._blocks[i]._se_reduce.weight: block_prefix + '/se/conv2d/kernel',
                model._blocks[i]._se_expand.bias: block_prefix + '/se/conv2d_1/bias',
                model._blocks[i]._se_expand.weight: block_prefix + '/se/conv2d_1/kernel',
                model._blocks[i]._bn1.bias: block_prefix + '/tpu_batch_normalization/beta',
                model._blocks[i]._bn1.weight: block_prefix + '/tpu_batch_normalization/gamma',
                model._blocks[i]._bn1.running_mean: block_prefix + '/tpu_batch_normalization/moving_mean',
                model._blocks[i]._bn1.running_var: block_prefix + '/tpu_batch_normalization/moving_variance',
                model._blocks[i]._bn2.bias: block_prefix + '/tpu_batch_normalization_1/beta',
                model._blocks[i]._bn2.weight: block_prefix + '/tpu_batch_normalization_1/gamma',
                model._blocks[i]._bn2.running_mean: block_prefix + '/tpu_batch_normalization_1/moving_mean',
                model._blocks[i]._bn2.running_var: block_prefix + '/tpu_batch_normalization_1/moving_variance',
            }
        else:
            conversion_table_block = {
                model._blocks[i]._expand_conv.weight: block_prefix + '/conv2d/kernel',
                model._blocks[i]._project_conv.weight: block_prefix + '/conv2d_1/kernel',
                model._blocks[i]._depthwise_conv.weight: block_prefix + '/depthwise_conv2d/depthwise_kernel',
                model._blocks[i]._se_reduce.bias: block_prefix + '/se/conv2d/bias',
                model._blocks[i]._se_reduce.weight: block_prefix + '/se/conv2d/kernel',
                model._blocks[i]._se_expand.bias: block_prefix + '/se/conv2d_1/bias',
                model._blocks[i]._se_expand.weight: block_prefix + '/se/conv2d_1/kernel',
                model._blocks[i]._bn0.bias: block_prefix + '/tpu_batch_normalization/beta',
                model._blocks[i]._bn0.weight: block_prefix + '/tpu_batch_normalization/gamma',
                model._blocks[i]._bn0.running_mean: block_prefix + '/tpu_batch_normalization/moving_mean',
                model._blocks[i]._bn0.running_var: block_prefix + '/tpu_batch_normalization/moving_variance',
                model._blocks[i]._bn1.bias: block_prefix + '/tpu_batch_normalization_1/beta',
                model._blocks[i]._bn1.weight: block_prefix + '/tpu_batch_normalization_1/gamma',
                model._blocks[i]._bn1.running_mean: block_prefix + '/tpu_batch_normalization_1/moving_mean',
                model._blocks[i]._bn1.running_var: block_prefix + '/tpu_batch_normalization_1/moving_variance',
                model._blocks[i]._bn2.bias: block_prefix + '/tpu_batch_normalization_2/beta',
                model._blocks[i]._bn2.weight: block_prefix + '/tpu_batch_normalization_2/gamma',
                model._blocks[i]._bn2.running_mean: block_prefix + '/tpu_batch_normalization_2/moving_mean',
                model._blocks[i]._bn2.running_var: block_prefix + '/tpu_batch_normalization_2/moving_variance',
            }

        conversion_table = merge(conversion_table, conversion_table_block)

    load_param(checkpoint_file, conversion_table, model_name)
    return conversion_table
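

# A hedged usage sketch (the checkpoint path is an assumption for illustration;
# 'tmp/model.ckpt' matches what load_and_save_temporary_tensorflow_model
# writes below):
#
#   model = EfficientNet.from_name('efficientnet-b0')
#   load_efficientnet(model, 'tmp/model.ckpt', model_name='efficientnet-b0')
#   model.eval()  # BN running statistics were loaded, so use eval mode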


def load_and_save_temporary_tensorflow_model(model_name, model_ckpt, example_img='../../example/img.jpg'):
    """
    Build the TensorFlow EfficientNet, restore the pretrained checkpoint, and
    re-save it to 'tmp/model.ckpt' so that load_param can read the variables.

    Relies on eval_ckpt_main, which is imported from 'original_tf' in the
    __main__ block below.
    """
    image_files = [example_img]
    eval_ckpt_driver = eval_ckpt_main.EvalCkptDriver(model_name)
    with tf.Graph().as_default(), tf.Session() as sess:
        images, labels = eval_ckpt_driver.build_dataset(image_files, [0] * len(image_files), False)
        # Building the model creates the variables that the checkpoint restores.
        probs = eval_ckpt_driver.build_model(images, is_training=False)
        sess.run(tf.global_variables_initializer())
        print(model_ckpt)
        eval_ckpt_driver.restore_model(sess, model_ckpt)
        # The Saver does not create missing directories, so make sure 'tmp' exists.
        os.makedirs('tmp', exist_ok=True)
        tf.train.Saver().save(sess, 'tmp/model.ckpt')
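

# Note: the function above uses TF 1.x graph-mode APIs (tf.Session,
# tf.global_variables_initializer). If this script must run under
# TensorFlow 2.x, one option (a sketch, untested here) is to replace the
# 'import tensorflow as tf' at the top of this file with:
#
#   import tensorflow.compat.v1 as tf
#   tf.disable_v2_behavior()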


if __name__ == '__main__':
    import sys
    import argparse

    # The reference TF implementation (the official EfficientNet code) is
    # expected to live in 'original_tf'.
    sys.path.append('original_tf')
    import eval_ckpt_main

    from efficientnet_pytorch import EfficientNet

    parser = argparse.ArgumentParser(
        description='Convert TF model to PyTorch model and save for easier future loading')
    parser.add_argument('--model_name', type=str, default='efficientnet-b0',
                        help='efficientnet-b{N}, where N is an integer 0 <= N <= 8')
    parser.add_argument('--tf_checkpoint', type=str, default='pretrained_tensorflow/efficientnet-b0/',
                        help='checkpoint file path')
    parser.add_argument('--output_file', type=str, default='pretrained_pytorch/efficientnet-b0.pth',
                        help='output PyTorch model file name')
    args = parser.parse_args()

    # Build a randomly initialized PyTorch EfficientNet with the right architecture.
    model = EfficientNet.from_name(args.model_name)

    # Restore the TF checkpoint and re-save it to 'tmp/model.ckpt'.
    print(args.tf_checkpoint)
    load_and_save_temporary_tensorflow_model(args.model_name, args.tf_checkpoint)

    # Copy the TF weights into the PyTorch model.
    load_efficientnet(model, 'tmp/model.ckpt', model_name=args.model_name)
    print('Loaded TF checkpoint weights')

    # Save the converted weights, creating the output directory if needed.
    output_dir = os.path.dirname(args.output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    torch.save(model.state_dict(), args.output_file)
    print('Saved model to', args.output_file)