Spaces:
Sleeping
Sleeping
Delete export_tfhub.py
Browse files- export_tfhub.py +0 -219
export_tfhub.py
DELETED
|
@@ -1,219 +0,0 @@
|
|
| 1 |
-
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
| 2 |
-
#
|
| 3 |
-
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
-
# you may not use this file except in compliance with the License.
|
| 5 |
-
# You may obtain a copy of the License at
|
| 6 |
-
#
|
| 7 |
-
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
-
#
|
| 9 |
-
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
-
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
-
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
-
# See the License for the specific language governing permissions and
|
| 13 |
-
# limitations under the License.
|
| 14 |
-
|
| 15 |
-
r"""Exports a BERT-like encoder and its preprocessing as SavedModels for TF Hub.
|
| 16 |
-
|
| 17 |
-
This tool creates preprocessor and encoder SavedModels suitable for uploading
|
| 18 |
-
to https://tfhub.dev that implement the preprocessor and encoder APIs defined
|
| 19 |
-
at https://www.tensorflow.org/hub/common_saved_model_apis/text.
|
| 20 |
-
|
| 21 |
-
For a full usage guide, see
|
| 22 |
-
https://github.com/tensorflow/models/blob/master/official/nlp/docs/tfhub.md
|
| 23 |
-
|
| 24 |
-
Minimal usage examples:
|
| 25 |
-
|
| 26 |
-
1) Exporting an Encoder from checkpoint and config.
|
| 27 |
-
|
| 28 |
-
```
|
| 29 |
-
export_tfhub \
|
| 30 |
-
--encoder_config_file=${BERT_DIR:?}/bert_encoder.yaml \
|
| 31 |
-
--model_checkpoint_path=${BERT_DIR:?}/bert_model.ckpt \
|
| 32 |
-
--vocab_file=${BERT_DIR:?}/vocab.txt \
|
| 33 |
-
--export_type=model \
|
| 34 |
-
--export_path=/tmp/bert_model
|
| 35 |
-
```
|
| 36 |
-
|
| 37 |
-
An --encoder_config_file can specify encoder types other than BERT.
|
| 38 |
-
For BERT, a --bert_config_file in the legacy JSON format can be passed instead.
|
| 39 |
-
|
| 40 |
-
Flag --vocab_file (and flag --do_lower_case, whose default value is guessed
|
| 41 |
-
from the vocab_file path) capture how BertTokenizer was used in pre-training.
|
| 42 |
-
Use flag --sp_model_file instead if SentencepieceTokenizer was used.
|
| 43 |
-
|
| 44 |
-
Changing --export_type to model_with_mlm additionally creates an `.mlm`
|
| 45 |
-
subobject on the exported SavedModel that can be called to produce
|
| 46 |
-
the logits of the Masked Language Model task from pretraining.
|
| 47 |
-
The help string for flag --model_checkpoint_path explains the checkpoint
|
| 48 |
-
formats required for each --export_type.
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
2) Exporting a preprocessor SavedModel
|
| 52 |
-
|
| 53 |
-
```
|
| 54 |
-
export_tfhub \
|
| 55 |
-
--vocab_file ${BERT_DIR:?}/vocab.txt \
|
| 56 |
-
--export_type preprocessing --export_path /tmp/bert_preprocessing
|
| 57 |
-
```
|
| 58 |
-
|
| 59 |
-
Be sure to use flag values that match the encoder and how it has been
|
| 60 |
-
pre-trained (see above for --vocab_file vs --sp_model_file).
|
| 61 |
-
|
| 62 |
-
If your encoder has been trained with text preprocessing for which tfhub.dev
|
| 63 |
-
already has SavedModel, you could guide your users to reuse that one instead
|
| 64 |
-
of exporting and publishing your own.
|
| 65 |
-
|
| 66 |
-
TODO(b/175369555): When exporting to users of TensorFlow 2.4, add flag
|
| 67 |
-
`--experimental_disable_assert_in_preprocessing`.
|
| 68 |
-
"""
|
| 69 |
-
|
| 70 |
-
from absl import app
|
| 71 |
-
from absl import flags
|
| 72 |
-
import gin
|
| 73 |
-
|
| 74 |
-
from official.legacy.bert import configs
|
| 75 |
-
from official.modeling import hyperparams
|
| 76 |
-
from official.nlp.configs import encoders
|
| 77 |
-
from official.nlp.tools import export_tfhub_lib
|
| 78 |
-
|
| 79 |
-
FLAGS = flags.FLAGS

# --- What to export, and where. ---
flags.DEFINE_enum(
    "export_type", "model",
    ["model", "model_with_mlm", "preprocessing"],
    "The overall type of SavedModel to export. Flags "
    "--bert_config_file/--encoder_config_file and --vocab_file/--sp_model_file "
    "control which particular encoder model and preprocessing are exported.")
flags.DEFINE_string(
    "export_path", None,
    "Directory to which the SavedModel is written.")

# --- Encoder configuration (exactly one of the two config flags). ---
flags.DEFINE_string(
    "encoder_config_file", None,
    "A yaml file representing `encoders.EncoderConfig` to define the encoder "
    "(BERT or other). "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_string(
    "bert_config_file", None,
    "A JSON file with a legacy BERT configuration to define the BERT encoder. "
    "Exactly one of --bert_config_file and --encoder_config_file can be set. "
    "Needed for --export_type model and model_with_mlm.")
flags.DEFINE_bool(
    "copy_pooler_dense_to_encoder", False,
    "When the model is trained using `BertPretrainerV2`, the pool layer "
    "of next sentence prediction task exists in `ClassificationHead` passed "
    "to `BertPretrainerV2`. If True, we will copy this pooler's dense layer "
    "to the encoder that is exported by this tool (as in classic BERT). "
    "Using `BertPretrainerV2` and leaving this False exports an untrained "
    "(randomly initialized) pooling layer, which some authors recommend for "
    # Fixed: help string previously ended with a comma instead of a period.
    "subsequent fine-tuning.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to a pre-trained model checkpoint. "
    "For --export_type model, this has to be an object-based (TF2) checkpoint "
    "that can be restored to `tf.train.Checkpoint(encoder=encoder)` "
    # Fixed: missing trailing space made the rendered help read
    # "...config file.(Legacy checkpoints...".
    "for the `encoder` defined by the config file. "
    "(Legacy checkpoints with `model=` instead of `encoder=` are also "
    "supported for now.) "
    "For --export_type model_with_mlm, it must be restorable to "
    "`tf.train.Checkpoint(**BertPretrainerV2(...).checkpoint_items)`. "
    "(For now, `tf.train.Checkpoint(pretrainer=BertPretrainerV2(...))` is also "
    "accepted.)")

# --- Tokenization/preprocessing (exactly one of the two tokenizer flags). ---
flags.DEFINE_string(
    "vocab_file", None,
    # Fixed typo: "BertTokenzier" -> "BertTokenizer".
    "For encoders trained on BertTokenizer input: "
    "the vocabulary file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_string(
    "sp_model_file", None,
    # Fixed typo: "SentencepieceTokenzier" -> "SentencepieceTokenizer".
    "For encoders trained on SentencepieceTokenizer input: "
    "the SentencePiece .model file that the encoder model was trained with. "
    "Exactly one of --vocab_file and --sp_model_file can be set. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. "
    "If left as None, and --vocab_file is set, do_lower_case will be enabled "
    "if 'uncased' appears in the name of --vocab_file. "
    "If left as None, and --sp_model_file set, do_lower_case defaults to true. "
    "Needed for --export_type model, model_with_mlm and preprocessing.")
flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from "
    "top-level preprocess method. This is also the default "
    # Fixed: missing trailing space made the rendered help read
    # "...subobject.Needed for...".
    "sequence length for the bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")
flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # TODO(b/181866850)
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")

# --- Gin configuration pass-through. ---
flags.DEFINE_multi_string(
    "gin_file", default=None,
    help="List of paths to the config files.")
flags.DEFINE_multi_string(
    "gin_params", default=None,
    help="List of Gin bindings.")

flags.DEFINE_bool(  # TODO(b/175369555): Remove this flag and its use.
    "experimental_disable_assert_in_preprocessing", False,
    "Export a preprocessing model without tf.Assert ops. "
    "Usually, that would be a bad idea, except TF2.4 has an issue with "
    "Assert ops in tf.functions used in Dataset.map() on a TPU worker, "
    "and omitting the Assert ops lets SavedModels avoid the issue.")
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
def main(argv):
  """Validates flag combinations and dispatches to the requested exporter."""
  if len(argv) > 1:
    raise app.UsageError("Too many command-line arguments.")
  gin.parse_config_files_and_bindings(FLAGS.gin_file, FLAGS.gin_params)

  # Exactly one of the two tokenizer assets must be supplied (XOR check).
  if bool(FLAGS.vocab_file) == bool(FLAGS.sp_model_file):
    raise ValueError("Exactly one of `vocab_file` and `sp_model_file` "
                     "can be specified, but got %s and %s." %
                     (FLAGS.vocab_file, FLAGS.sp_model_file))
  lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file, FLAGS.sp_model_file)

  export_type = FLAGS.export_type
  if export_type == "preprocessing":
    export_tfhub_lib.export_preprocessing(
        FLAGS.export_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=lower_case,
        default_seq_length=FLAGS.default_seq_length,
        tokenize_with_offsets=FLAGS.tokenize_with_offsets,
        experimental_disable_assert=
        FLAGS.experimental_disable_assert_in_preprocessing)
  elif export_type in ("model", "model_with_mlm"):
    # Exactly one of the two encoder config sources must be supplied.
    if bool(FLAGS.bert_config_file) == bool(FLAGS.encoder_config_file):
      raise ValueError("Exactly one of `bert_config_file` and "
                       "`encoder_config_file` can be specified, but got "
                       "%s and %s." %
                       (FLAGS.bert_config_file, FLAGS.encoder_config_file))
    bert_config = None
    encoder_config = None
    if FLAGS.bert_config_file:
      # Legacy JSON config path: BERT only.
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
    else:
      # Yaml config path: any encoder type supported by encoders.EncoderConfig.
      encoder_config = hyperparams.override_params_dict(
          encoders.EncoderConfig(), FLAGS.encoder_config_file, is_strict=True)
    export_tfhub_lib.export_model(
        FLAGS.export_path,
        bert_config=bert_config,
        encoder_config=encoder_config,
        model_checkpoint_path=FLAGS.model_checkpoint_path,
        vocab_file=FLAGS.vocab_file,
        sp_model_file=FLAGS.sp_model_file,
        do_lower_case=lower_case,
        with_mlm=export_type == "model_with_mlm",
        copy_pooler_dense_to_encoder=FLAGS.copy_pooler_dense_to_encoder)
  else:
    raise app.UsageError(
        "Unknown value '%s' for flag --export_type" % export_type)


if __name__ == "__main__":
  app.run(main)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|