|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """Checkpoint converter for Mobilebert."""
|
| import os
|
|
|
| from absl import app
|
| from absl import flags
|
| from absl import logging
|
| import numpy as np
|
| import tensorflow.compat.v1 as tf
|
|
|
| from official.projects.mobilebert import model_utils
|
|
|
|
|
FLAGS = flags.FLAGS

# Path to the JSON config that defines the core MobileBERT layers.
flags.DEFINE_string(
    "bert_config_file", None,
    "Bert configuration file to define core mobilebert layers.")
# Source (TF1, name-based) checkpoint to read.
flags.DEFINE_string("tf1_checkpoint_path", None,
                    "Path to load tf1 checkpoint.")
# Destination (TF2, object-based) checkpoint to write.
flags.DEFINE_string("tf2_checkpoint_path", None,
                    "Path to save tf2 checkpoint.")
# When False, main() rewrites _NAME_REPLACEMENT[0] so the "bert/" prefix is
# stripped instead of being mapped to the model name.
flags.DEFINE_boolean("use_model_prefix", False,
                     ("If use model name as prefix for variables. Turn this"
                      "flag on when the converted checkpoint is used for model"
                      "in subclass implementation, which uses the model name as"
                      "prefix for all variable names."))
|
|
|
|
|
| def _bert_name_replacement(var_name, name_replacements):
|
| """Gets the variable name replacement."""
|
| for src_pattern, tgt_pattern in name_replacements:
|
| if src_pattern in var_name:
|
| old_var_name = var_name
|
| var_name = var_name.replace(src_pattern, tgt_pattern)
|
| logging.info("Converted: %s --> %s", old_var_name, var_name)
|
| return var_name
|
|
|
|
|
| def _has_exclude_patterns(name, exclude_patterns):
|
| """Checks if a string contains substrings that match patterns to exclude."""
|
| for p in exclude_patterns:
|
| if p in name:
|
| return True
|
| return False
|
|
|
|
|
| def _get_permutation(name, permutations):
|
| """Checks whether a variable requires transposition by pattern matching."""
|
| for src_pattern, permutation in permutations:
|
| if src_pattern in name:
|
| logging.info("Permuted: %s --> %s", name, permutation)
|
| return permutation
|
|
|
| return None
|
|
|
|
|
| def _get_new_shape(name, shape, num_heads):
|
| """Checks whether a variable requires reshape by pattern matching."""
|
| if "attention/attention_output/kernel" in name:
|
| return tuple([num_heads, shape[0] // num_heads, shape[1]])
|
| if "attention/attention_output/bias" in name:
|
| return shape
|
|
|
| patterns = [
|
| "attention/query", "attention/value", "attention/key"
|
| ]
|
| for pattern in patterns:
|
| if pattern in name:
|
| if "kernel" in name:
|
| return tuple([shape[0], num_heads, shape[1] // num_heads])
|
| if "bias" in name:
|
| return tuple([num_heads, shape[0] // num_heads])
|
| return None
|
|
|
|
|
def convert(checkpoint_from_path,
            checkpoint_to_path,
            name_replacements,
            permutations,
            bert_config,
            exclude_patterns=None):
  """Migrates the names of variables within a checkpoint.

  Reads a TF1 (name-based) checkpoint, renames/reshapes/transposes its
  variables, writes them to a temporary TF1 checkpoint, and finally converts
  that into an object-based TF2 checkpoint via a freshly built MobileBERT
  pretrainer. The temporary checkpoint directory is removed on a best-effort
  basis.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    name_replacements: A list of tuples of the form (match_str, replace_str)
      describing variable names to adjust.
    permutations: A list of tuples of the form (match_str, permutation)
      describing permutations to apply to given variables. Note that match_str
      should match the original variable name, not the replaced one.
    bert_config: A `BertConfig` to create the core model.
    exclude_patterns: A list of string patterns to exclude variables from
      checkpoint conversion.
  """
  # Resolve the "LAST_FFN_LAYER_ID" placeholder in the replacement table to
  # the index of the last feedforward sub-layer for this configuration.
  last_ffn_layer_id = str(bert_config.num_feedforward_networks - 1)
  name_replacements = [
      (src, tgt.replace("LAST_FFN_LAYER_ID", last_ffn_layer_id))
      for src, tgt in name_replacements
  ]

  output_dir, _ = os.path.split(checkpoint_to_path)
  tf.io.gfile.makedirs(output_dir)

  # Intermediate TF1 checkpoint holding the renamed variables; it is loaded
  # by name into the TF2 model below and deleted at the end.
  temporary_checkpoint_dir = os.path.join(output_dir, "temp_v1")
  temporary_checkpoint = os.path.join(temporary_checkpoint_dir, "ckpt")

  with tf.Graph().as_default():
    logging.info("Reading checkpoint_from_path %s", checkpoint_from_path)
    reader = tf.train.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    new_variable_map = {}
    conversion_map = {}
    for var_name in name_shape_map:
      if exclude_patterns and _has_exclude_patterns(var_name, exclude_patterns):
        continue

      tensor = reader.get_tensor(var_name)

      new_var_name = _bert_name_replacement(var_name, name_replacements)

      # Reshape matching is done on the *new* name, since the reshape
      # patterns refer to the TF2 naming scheme.
      new_shape = None
      if bert_config.num_attention_heads > 0:
        new_shape = _get_new_shape(new_var_name, tensor.shape,
                                   bert_config.num_attention_heads)
      if new_shape:
        logging.info("Variable %s has a shape change from %s to %s",
                     var_name, tensor.shape, new_shape)
        tensor = np.reshape(tensor, new_shape)

      # Permutations, by contract, match the original (pre-replacement) name.
      permutation = _get_permutation(var_name, permutations)
      if permutation:
        tensor = np.transpose(tensor, permutation)

      var = tf.Variable(tensor, name=var_name)

      new_variable_map[new_var_name] = var

      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    saver = tf.train.Saver(new_variable_map)

    with tf.Session() as sess:
      sess.run(tf.global_variables_initializer())
      logging.info("Writing checkpoint_to_path %s", temporary_checkpoint)
      saver.save(sess, temporary_checkpoint, write_meta_graph=False)

  logging.info("Summary:")
  logging.info("Converted %d variable name(s).", len(new_variable_map))
  logging.info("Converted: %s", str(conversion_map))

  mobilebert_model = model_utils.create_mobilebert_pretrainer(bert_config)
  create_v2_checkpoint(
      mobilebert_model, temporary_checkpoint, checkpoint_to_path)

  # Best-effort cleanup of the intermediate checkpoint directory.
  try:
    tf.io.gfile.rmtree(temporary_checkpoint_dir)
  except tf.errors.OpError:
    pass
|
|
|
|
|
def create_v2_checkpoint(model, src_checkpoint, output_path):
  """Converts a name-based matched TF V1 checkpoint to TF V2 checkpoint."""

  # Restore the V1 values into the model; assert_existing_objects_matched
  # raises if any already-built model object found no value in the checkpoint.
  model.load_weights(src_checkpoint).assert_existing_objects_matched()
  # Re-save as an object-based V2 checkpoint keyed by the model's declared
  # checkpoint items. NOTE(review): assumes `model.checkpoint_items` is a
  # dict of trackables — confirm against the pretrainer implementation.
  checkpoint = tf.train.Checkpoint(**model.checkpoint_items)
  checkpoint.save(output_path)
|
|
|
|
|
# Ordered (src_pattern, tgt_pattern) substring replacements applied to every
# TF1 variable name by `convert` (via `_bert_name_replacement`). Replacements
# are cumulative and order-sensitive: the explicit "ffn_layer_N/..." entries
# must precede the generic "output/..." and "intermediate/..." fallbacks.
# The "LAST_FFN_LAYER_ID" placeholder is resolved in `convert` from
# `bert_config.num_feedforward_networks`. Entry 0 is overwritten in `main`
# when --use_model_prefix is false.
_NAME_REPLACEMENT = [
    # Top-level scopes.
    ("bert/", "mobile_bert_encoder/"),
    ("encoder/layer_", "transformer_layer_"),

    # Embedding lookup / projection / normalization.
    ("embeddings/embedding_transformation",
     "mobile_bert_embedding/embedding_projection"),
    ("embeddings/position_embeddings",
     "mobile_bert_embedding/position_embedding/embeddings"),
    ("embeddings/token_type_embeddings",
     "mobile_bert_embedding/type_embedding/embeddings"),
    ("embeddings/word_embeddings",
     "mobile_bert_embedding/word_embedding/embeddings"),
    ("embeddings/FakeLayerNorm", "mobile_bert_embedding/embedding_norm"),
    ("embeddings/LayerNorm", "mobile_bert_embedding/embedding_norm"),

    # Attention block (specific "attention/output" entries before the
    # generic "attention/self" collapse).
    ("attention/output/dense", "attention/attention_output"),
    ("attention/output/FakeLayerNorm", "attention/norm"),
    ("attention/output/LayerNorm", "attention/norm"),
    ("attention/self", "attention"),

    # Bottleneck input and shared key/query bottleneck.
    ("bottleneck/input/dense", "bottleneck_input/dense"),
    ("bottleneck/input/FakeLayerNorm", "bottleneck_input/norm"),
    ("bottleneck/input/LayerNorm", "bottleneck_input/norm"),
    ("bottleneck/attention/dense", "kq_shared_bottleneck/dense"),
    ("bottleneck/attention/FakeLayerNorm", "kq_shared_bottleneck/norm"),
    ("bottleneck/attention/LayerNorm", "kq_shared_bottleneck/norm"),

    # Feedforward stacks; the unnumbered "output/..." / "intermediate/..."
    # variables belong to the last FFN layer.
    ("ffn_layer_0/output/dense", "ffn_layer_0/output_dense"),
    ("ffn_layer_1/output/dense", "ffn_layer_1/output_dense"),
    ("ffn_layer_2/output/dense", "ffn_layer_2/output_dense"),
    ("output/dense", "ffn_layer_LAST_FFN_LAYER_ID/output_dense"),
    ("ffn_layer_0/output/FakeLayerNorm", "ffn_layer_0/norm"),
    ("ffn_layer_0/output/LayerNorm", "ffn_layer_0/norm"),
    ("ffn_layer_1/output/FakeLayerNorm", "ffn_layer_1/norm"),
    ("ffn_layer_1/output/LayerNorm", "ffn_layer_1/norm"),
    ("ffn_layer_2/output/FakeLayerNorm", "ffn_layer_2/norm"),
    ("ffn_layer_2/output/LayerNorm", "ffn_layer_2/norm"),
    ("output/FakeLayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
    ("output/LayerNorm", "ffn_layer_LAST_FFN_LAYER_ID/norm"),
    ("ffn_layer_0/intermediate/dense", "ffn_layer_0/intermediate_dense"),
    ("ffn_layer_1/intermediate/dense", "ffn_layer_1/intermediate_dense"),
    ("ffn_layer_2/intermediate/dense", "ffn_layer_2/intermediate_dense"),
    ("intermediate/dense", "ffn_layer_LAST_FFN_LAYER_ID/intermediate_dense"),

    # Bottleneck output.
    ("output/bottleneck/FakeLayerNorm", "bottleneck_output/norm"),
    ("output/bottleneck/LayerNorm", "bottleneck_output/norm"),
    ("output/bottleneck/dense", "bottleneck_output/dense"),

    # Pooler.
    ("pooler/dense", "pooler"),

    # Masked-LM prediction head (the second entry appends "/bias" to the
    # name already rewritten by the first).
    ("cls/predictions", "bert/cls/predictions"),
    ("cls/predictions/output_bias", "cls/predictions/output_bias/bias")
]
|
|
|
| _EXCLUDE_PATTERNS = ["cls/seq_relationship", "global_step"]
|
|
|
|
|
def main(argv):
  """Converts a TF1 MobileBERT checkpoint to TF2 per the command-line flags.

  Args:
    argv: Positional command-line arguments (none are expected).

  Raises:
    app.UsageError: If extra positional arguments are given or a required
      flag is missing.
  """
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")

  # Fail fast with a clear message instead of a cryptic downstream error
  # when a required path flag was not provided.
  for flag_name in ("bert_config_file", "tf1_checkpoint_path",
                    "tf2_checkpoint_path"):
    if not getattr(FLAGS, flag_name):
      raise app.UsageError("Flag --%s is required." % flag_name)

  if not FLAGS.use_model_prefix:
    # Strip the "bert/" prefix entirely; per the flag's help text, the
    # model-name prefix is only needed for the subclass implementation.
    _NAME_REPLACEMENT[0] = ("bert/", "")

  bert_config = model_utils.BertConfig.from_json_file(FLAGS.bert_config_file)
  convert(FLAGS.tf1_checkpoint_path,
          FLAGS.tf2_checkpoint_path,
          _NAME_REPLACEMENT,
          [],
          bert_config,
          _EXCLUDE_PATTERNS)
|
|
|
if __name__ == "__main__":
  # absl parses flags and then invokes main(argv).
  app.run(main)
|
|
|