| """Script for training a masked language model on TPU.""" |

import argparse
import logging
import os
import re

import tensorflow as tf

from transformers import (
    AutoConfig,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    PushToHubCallback,
    TFAutoModelForMaskedLM,
    create_optimizer,
)


logger = logging.getLogger(__name__)

AUTO = tf.data.AUTOTUNE


def parse_args():
    parser = argparse.ArgumentParser(description="Train a masked language model on TPU.")
    parser.add_argument(
        "--pretrained_model_config",
        type=str,
        default="roberta-base",
        help="The model config to use. Note that we don't copy the model's weights, only the config!",
    )
    parser.add_argument(
        "--tokenizer",
        type=str,
        default="unigram-tokenizer-wikitext",
        help="The name of the tokenizer to load. We use the pretrained tokenizer to initialize the model's vocab size.",
    )
    parser.add_argument(
        "--per_replica_batch_size",
        type=int,
        default=8,
        help="Batch size per TPU core.",
    )
    parser.add_argument(
        "--no_tpu",
        action="store_true",
        help="If set, run on CPU and don't try to initialize a TPU. Useful for debugging on non-TPU instances.",
    )
    parser.add_argument(
        "--tpu_name",
        type=str,
        help="Name of TPU resource to initialize. Should be blank on Colab, and 'local' on TPU VMs.",
        default="local",
    )
    parser.add_argument(
        "--tpu_zone",
        type=str,
        help="Google cloud zone that TPU resource is located in. Only used for non-Colab TPU nodes.",
    )
    parser.add_argument(
        "--gcp_project", type=str, help="Google cloud project name. Only used for non-Colab TPU nodes."
    )
    parser.add_argument(
        "--bfloat16",
        action="store_true",
        help="Use mixed-precision bfloat16 for training. This is the recommended lower-precision format for TPU.",
    )
    parser.add_argument(
        "--train_dataset",
        type=str,
        help="Path to training dataset to load. If the path begins with `gs://`"
        " then the dataset will be loaded from a Google Cloud Storage bucket.",
    )
    parser.add_argument(
        "--shuffle_buffer_size",
        type=int,
        default=2**18,
        help="Size of the shuffle buffer (in samples)",
    )
    parser.add_argument(
        "--eval_dataset",
        type=str,
        help="Path to evaluation dataset to load. If the path begins with `gs://`"
        " then the dataset will be loaded from a Google Cloud Storage bucket.",
    )
    parser.add_argument(
        "--num_epochs",
        type=int,
        default=1,
        help="Number of epochs to train for.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use for training.",
    )
    parser.add_argument(
        "--weight_decay_rate",
        type=float,
        default=1e-3,
        help="Weight decay rate to use for training.",
    )
    parser.add_argument(
        "--max_length",
        type=int,
        default=512,
        help="Maximum length of tokenized sequences. Should match the setting used in prepare_tfrecord_shards.py",
    )
    parser.add_argument(
        "--mlm_probability",
        type=float,
        default=0.15,
        help="Fraction of tokens to mask during training.",
    )
    parser.add_argument("--output_dir", type=str, required=True, help="Path to save model checkpoints to.")
    parser.add_argument("--hub_model_id", type=str, help="Model ID to upload to on the Hugging Face Hub.")

    args = parser.parse_args()
    return args


def initialize_tpu(args):
    try:
        if args.tpu_name:
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver(
                args.tpu_name, zone=args.tpu_zone, project=args.gcp_project
            )
        else:
            tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
    except ValueError:
        raise RuntimeError(
            "Couldn't connect to TPU! Most likely you need to specify --tpu_name, --tpu_zone, or "
            "--gcp_project. When running on a TPU VM, use --tpu_name local."
        )

    tf.config.experimental_connect_to_cluster(tpu)
    tf.tpu.experimental.initialize_tpu_system(tpu)

    return tpu


def count_samples(file_list):
    num_samples = 0
    for file in file_list:
        filename = file.split("/")[-1]
        sample_count = re.search(r"-\d+-(\d+)\.tfrecord", filename).group(1)
        sample_count = int(sample_count)
        num_samples += sample_count

    return num_samples
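
# The regex above expects shard filenames ending in "-<shard index>-<sample count>.tfrecord";
# e.g. a shard named "wikitext-00003-25000.tfrecord" (name illustrative) would contribute
# 25000 samples, letting us count the dataset without opening any files.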


def prepare_dataset(records, decode_fn, mask_fn, batch_size, shuffle, shuffle_buffer_size=None):
    num_samples = count_samples(records)
    dataset = tf.data.Dataset.from_tensor_slices(records)
    if shuffle:
        dataset = dataset.shuffle(len(dataset))
    dataset = tf.data.TFRecordDataset(dataset, num_parallel_reads=AUTO)
    # TFRecordDataset doesn't know how many samples it contains, so assert the
    # cardinality computed from the shard filenames.
    dataset = dataset.apply(tf.data.experimental.assert_cardinality(num_samples))
    dataset = dataset.map(decode_fn, num_parallel_calls=AUTO)
    if shuffle:
        assert shuffle_buffer_size is not None
        dataset = dataset.shuffle(shuffle_buffer_size)
    # drop_remainder=True is required on TPU, where every batch must have a static shape.
    dataset = dataset.batch(batch_size, drop_remainder=True)
    dataset = dataset.map(mask_fn, num_parallel_calls=AUTO)
    dataset = dataset.prefetch(AUTO)
    return dataset


def main(args):
    if not args.no_tpu:
        tpu = initialize_tpu(args)
        strategy = tf.distribute.TPUStrategy(tpu)
    else:
        # Fall back to a single-device strategy; this assumes a visible GPU.
        strategy = tf.distribute.OneDeviceStrategy(device="/gpu:0")

    if args.bfloat16:
        tf.keras.mixed_precision.set_global_policy("mixed_bfloat16")

    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
    config = AutoConfig.from_pretrained(args.pretrained_model_config)
    # We train from scratch: only the architecture comes from the config, and the
    # vocab size is taken from the tokenizer we loaded.
    config.vocab_size = tokenizer.vocab_size

    training_records = tf.io.gfile.glob(os.path.join(args.train_dataset, "*.tfrecord"))
    if not training_records:
        raise ValueError(f"No .tfrecord files found in {args.train_dataset}.")
    eval_records = tf.io.gfile.glob(os.path.join(args.eval_dataset, "*.tfrecord"))
    if not eval_records:
        raise ValueError(f"No .tfrecord files found in {args.eval_dataset}.")

    num_train_samples = count_samples(training_records)

    steps_per_epoch = num_train_samples // (args.per_replica_batch_size * strategy.num_replicas_in_sync)
    total_train_steps = steps_per_epoch * args.num_epochs
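
    # steps_per_epoch is computed from the global batch size (per-replica batch size
    # times the number of replicas); any final partial batch is discarded by the
    # drop_remainder=True batching in prepare_dataset().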

    with strategy.scope():
        model = TFAutoModelForMaskedLM.from_config(config)
        # Pass dummy inputs through the model once so that all weights get built.
        model(model.dummy_inputs)
        optimizer, schedule = create_optimizer(
            num_train_steps=total_train_steps,
            num_warmup_steps=total_train_steps // 20,
            init_lr=args.learning_rate,
            weight_decay_rate=args.weight_decay_rate,
        )

        # No loss is passed to compile(): transformers models compute their task loss
        # internally when the input batch contains a "labels" key.
        model.compile(optimizer=optimizer, metrics=["accuracy"])
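
    # Each serialized TFRecord example holds two fixed-length int64 features
    # ("input_ids" and "attention_mask"); decode_fn parses one example back into
    # dense tensors of shape (max_length,).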

    def decode_fn(example):
        features = {
            "input_ids": tf.io.FixedLenFeature(dtype=tf.int64, shape=(args.max_length,)),
            "attention_mask": tf.io.FixedLenFeature(dtype=tf.int64, shape=(args.max_length,)),
        }
        return tf.io.parse_single_example(example, features)

    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer, mlm_probability=args.mlm_probability, mlm=True, return_tensors="tf"
    )
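
    # With return_tensors="tf", the collator's tf_mask_tokens() is built from
    # TensorFlow ops, so the random masking below runs in-graph inside the
    # tf.data pipeline rather than in Python.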

    def mask_with_collator(batch):
        # Never mask padding or the special [CLS]/[SEP] tokens.
        special_tokens_mask = (
            ~tf.cast(batch["attention_mask"], tf.bool)
            | (batch["input_ids"] == tokenizer.cls_token_id)
            | (batch["input_ids"] == tokenizer.sep_token_id)
        )
        batch["input_ids"], batch["labels"] = data_collator.tf_mask_tokens(
            batch["input_ids"],
            vocab_size=len(tokenizer),
            mask_token_id=tokenizer.mask_token_id,
            special_tokens_mask=special_tokens_mask,
        )
        return batch

    batch_size = args.per_replica_batch_size * strategy.num_replicas_in_sync

    train_dataset = prepare_dataset(
        training_records,
        decode_fn=decode_fn,
        mask_fn=mask_with_collator,
        batch_size=batch_size,
        shuffle=True,
        shuffle_buffer_size=args.shuffle_buffer_size,
    )

    eval_dataset = prepare_dataset(
        eval_records,
        decode_fn=decode_fn,
        mask_fn=mask_with_collator,
        batch_size=batch_size,
        shuffle=False,
    )

    callbacks = []
    if args.hub_model_id:
        callbacks.append(
            PushToHubCallback(output_dir=args.output_dir, hub_model_id=args.hub_model_id, tokenizer=tokenizer)
        )

    model.fit(
        train_dataset,
        validation_data=eval_dataset,
        epochs=args.num_epochs,
        callbacks=callbacks,
    )

    model.save_pretrained(args.output_dir)
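
    # The final checkpoint can be reloaded later with
    # TFAutoModelForMaskedLM.from_pretrained(args.output_dir).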


if __name__ == "__main__":
    args = parse_args()
    main(args)