# deep/external_data/convert_tf_to_pt.py (uploaded by Aryan6192, commit 79cf6ef)
# Source: https://github.com/lukemelas/EfficientNet-PyTorch
import numpy as np
import tensorflow as tf
import torch
def load_param(checkpoint_file, conversion_table, model_name):
    """
    Copy TensorFlow checkpoint variables into the mapped PyTorch tensors.

    Args:
        checkpoint_file (string): pretrained checkpoint model file in tensorflow
        conversion_table (dict): { pytorch tensor in a model : checkpoint variable name }
        model_name (string): prefix prepended to every checkpoint variable name
    """
    for torch_tensor, var_name in conversion_table.items():
        full_name = str(model_name) + '/' + var_name
        value = tf.train.load_variable(checkpoint_file, full_name)

        if 'conv' in full_name and 'kernel' in full_name:
            # Conv kernels: TF stores HWIO, PyTorch expects OIHW.
            value = np.transpose(value, (3, 2, 0, 1))
            if 'depthwise' in full_name:
                # Depthwise kernels additionally swap the first two axes.
                value = np.transpose(value, (1, 0, 2, 3))
        elif full_name.endswith('kernel'):
            # Dense-layer weights only need a plain 2-D transpose.
            value = np.transpose(value)

        assert torch_tensor.size() == value.shape, \
            'Dim Mismatch: %s vs %s ; %s' % (tuple(torch_tensor.size()), value.shape, full_name)
        torch_tensor.data = torch.from_numpy(value)
def load_efficientnet(model, checkpoint_file, model_name):
    """
    Load PyTorch EfficientNet from TensorFlow checkpoint file.

    Args:
        model: EfficientNet-PyTorch model whose parameters are overwritten in place
        checkpoint_file (string): path to the re-saved TensorFlow checkpoint
        model_name (string): checkpoint variable prefix, e.g. 'efficientnet-b0'

    Returns:
        dict: the full { pytorch tensor : checkpoint variable name } conversion table
    """
    # All the weights not in the conv blocks
    conversion_table = {
        model._conv_stem.weight: 'stem/conv2d/kernel',                           # [3, 3, 3, 32]
        model._bn0.bias: 'stem/tpu_batch_normalization/beta',                    # [32]
        model._bn0.weight: 'stem/tpu_batch_normalization/gamma',                 # [32]
        model._bn0.running_mean: 'stem/tpu_batch_normalization/moving_mean',     # [32]
        model._bn0.running_var: 'stem/tpu_batch_normalization/moving_variance',  # [32]
        model._conv_head.weight: 'head/conv2d/kernel',                           # [1, 1, 320, 1280]
        model._bn1.bias: 'head/tpu_batch_normalization/beta',                    # [1280]
        model._bn1.weight: 'head/tpu_batch_normalization/gamma',                 # [1280]
        model._bn1.running_mean: 'head/tpu_batch_normalization/moving_mean',     # [1280]
        model._bn1.running_var: 'head/tpu_batch_normalization/moving_variance',  # [1280]
        model._fc.bias: 'head/dense/bias',                                       # [1000]
        model._fc.weight: 'head/dense/kernel',                                   # [1280, 1000]
    }

    # Conv blocks. A block without an expansion phase (e.g. block 0) has no
    # _expand_conv, which shifts the conv2d/batch-norm numbering in the
    # checkpoint, so the two cases need different tables.
    # NOTE: the original unconditional "first block" table was removed — the
    # i == 0 loop iteration below regenerates the identical mapping.
    for idx, block in enumerate(model._blocks):
        prefix = 'blocks_%d/' % idx
        has_expand = '_expand_conv.weight' in [n for n, _ in block.named_parameters()]
        if not has_expand:
            block_table = {
                block._project_conv.weight: prefix + 'conv2d/kernel',
                block._depthwise_conv.weight: prefix + 'depthwise_conv2d/depthwise_kernel',
                block._se_reduce.bias: prefix + 'se/conv2d/bias',
                block._se_reduce.weight: prefix + 'se/conv2d/kernel',
                block._se_expand.bias: prefix + 'se/conv2d_1/bias',
                block._se_expand.weight: prefix + 'se/conv2d_1/kernel',
                block._bn1.bias: prefix + 'tpu_batch_normalization/beta',
                block._bn1.weight: prefix + 'tpu_batch_normalization/gamma',
                block._bn1.running_mean: prefix + 'tpu_batch_normalization/moving_mean',
                block._bn1.running_var: prefix + 'tpu_batch_normalization/moving_variance',
                block._bn2.bias: prefix + 'tpu_batch_normalization_1/beta',
                block._bn2.weight: prefix + 'tpu_batch_normalization_1/gamma',
                block._bn2.running_mean: prefix + 'tpu_batch_normalization_1/moving_mean',
                block._bn2.running_var: prefix + 'tpu_batch_normalization_1/moving_variance',
            }
        else:
            block_table = {
                block._expand_conv.weight: prefix + 'conv2d/kernel',
                block._project_conv.weight: prefix + 'conv2d_1/kernel',
                block._depthwise_conv.weight: prefix + 'depthwise_conv2d/depthwise_kernel',
                block._se_reduce.bias: prefix + 'se/conv2d/bias',
                block._se_reduce.weight: prefix + 'se/conv2d/kernel',
                block._se_expand.bias: prefix + 'se/conv2d_1/bias',
                block._se_expand.weight: prefix + 'se/conv2d_1/kernel',
                block._bn0.bias: prefix + 'tpu_batch_normalization/beta',
                block._bn0.weight: prefix + 'tpu_batch_normalization/gamma',
                block._bn0.running_mean: prefix + 'tpu_batch_normalization/moving_mean',
                block._bn0.running_var: prefix + 'tpu_batch_normalization/moving_variance',
                block._bn1.bias: prefix + 'tpu_batch_normalization_1/beta',
                block._bn1.weight: prefix + 'tpu_batch_normalization_1/gamma',
                block._bn1.running_mean: prefix + 'tpu_batch_normalization_1/moving_mean',
                block._bn1.running_var: prefix + 'tpu_batch_normalization_1/moving_variance',
                block._bn2.bias: prefix + 'tpu_batch_normalization_2/beta',
                block._bn2.weight: prefix + 'tpu_batch_normalization_2/gamma',
                block._bn2.running_mean: prefix + 'tpu_batch_normalization_2/moving_mean',
                block._bn2.running_var: prefix + 'tpu_batch_normalization_2/moving_variance',
            }
        conversion_table.update(block_table)

    # Load TensorFlow parameters into PyTorch model
    load_param(checkpoint_file, conversion_table, model_name)
    return conversion_table
def load_and_save_temporary_tensorflow_model(model_name, model_ckpt, example_img='../../example/img.jpg'):
    """Restore a TensorFlow EfficientNet checkpoint and re-save it under tmp/.

    Builds the eval graph once so every variable is materialized, then writes
    the restored session to 'tmp/model.ckpt' for later variable-by-variable
    loading. NOTE(review): relies on ``eval_ckpt_main`` being importable; in
    this file it is only imported inside the __main__ guard.
    """
    driver = eval_ckpt_main.EvalCkptDriver(model_name)
    image_files = [example_img]
    dummy_labels = [0] * len(image_files)
    with tf.Graph().as_default(), tf.Session() as sess:
        images, _ = driver.build_dataset(image_files, dummy_labels, False)
        driver.build_model(images, is_training=False)
        sess.run(tf.global_variables_initializer())
        print(model_ckpt)
        driver.restore_model(sess, model_ckpt)
        tf.train.Saver().save(sess, 'tmp/model.ckpt')
if __name__ == '__main__':
    import argparse
    import sys

    # eval_ckpt_main ships with the original TF implementation, vendored here.
    sys.path.append('original_tf')
    import eval_ckpt_main
    from efficientnet_pytorch import EfficientNet

    def _parse_args():
        """Command-line options for the TF -> PyTorch conversion."""
        parser = argparse.ArgumentParser(
            description='Convert TF model to PyTorch model and save for easier future loading')
        parser.add_argument('--model_name', type=str, default='efficientnet-b0',
                            help='efficientnet-b{N}, where N is an integer 0 <= N <= 8')
        parser.add_argument('--tf_checkpoint', type=str, default='pretrained_tensorflow/efficientnet-b0/',
                            help='checkpoint file path')
        parser.add_argument('--output_file', type=str, default='pretrained_pytorch/efficientnet-b0.pth',
                            help='output PyTorch model file name')
        return parser.parse_args()

    args = _parse_args()

    # Build an untrained PyTorch model with the requested architecture.
    model = EfficientNet.from_name(args.model_name)

    # Re-save the TF checkpoint under tmp/ first (works around TF loading nuances).
    print(args.tf_checkpoint)
    load_and_save_temporary_tensorflow_model(args.model_name, args.tf_checkpoint)

    # Copy the TF weights into the PyTorch model, then persist its state dict.
    load_efficientnet(model, 'tmp/model.ckpt', model_name=args.model_name)
    print('Loaded TF checkpoint weights')
    torch.save(model.state_dict(), args.output_file)
    print('Saved model to', args.output_file)