# Defaults for continued pretraining with train.py.
#
# You must also include a binding for MODEL.
#
# Required to be set:
#
# - MIXTURE_OR_TASK_NAME
# - TASK_FEATURE_LENGTHS
# - TRAIN_STEPS  # includes any steps already taken in the initial checkpoint
# - INITIAL_CHECKPOINT_PATH
# - MODEL_DIR  # automatically set when using xm_launch
#
# Commonly overridden options:
#
# - train/DatasetConfig.batch_size
# - train_eval/DatasetConfig.batch_size
# - PjitPartitioner.num_partitions
# - Trainer.num_microbatches
# - DROPOUT_RATE
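#
# Example invocation (a sketch only; the gin file names, task name, feature
# lengths, step count, and paths below are hypothetical placeholders):
#
#   python -m t5x.train \
#     --gin_file=my_model.gin \
#     --gin_file=this_config.gin \
#     --gin.MIXTURE_OR_TASK_NAME=\"'my_mixture'\" \
#     --gin.TASK_FEATURE_LENGTHS=\"{'inputs': 512, 'targets': 114}\" \
#     --gin.TRAIN_STEPS=1100000 \
#     --gin.INITIAL_CHECKPOINT_PATH=\"'gs://my-bucket/pretrained/checkpoint_1000000'\" \
#     --gin.MODEL_DIR=\"'gs://my-bucket/model'\"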
from __gin__ import dynamic_registration


import __main__ as train_script
from t5x import gin_utils
from t5x import partitioning
from t5x import trainer
from t5x import utils


MIXTURE_OR_TASK_NAME = %gin.REQUIRED
TASK_FEATURE_LENGTHS = %gin.REQUIRED
TRAIN_STEPS = %gin.REQUIRED
INITIAL_CHECKPOINT_PATH = %gin.REQUIRED
MODEL_DIR = %gin.REQUIRED
BATCH_SIZE = 128
USE_CACHED_TASKS = True


# DEPRECATED: Import the task module directly in your gin file instead.
MIXTURE_OR_TASK_MODULE = None
SHUFFLE_TRAIN_EXAMPLES = True


# HW RNG is faster than SW, but has limited determinism.
# Most notably it is not deterministic across different
# submeshes.
USE_HARDWARE_RNG = False
# If None, the faster hardware RNG is always used.
RANDOM_SEED = None
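
# For a fully reproducible run (useful for debugging), keep the software
# RNG and fix the seed; the value 0 below is an arbitrary example:
#
#   USE_HARDWARE_RNG = False
#   RANDOM_SEED = 0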


# Settings below can be overridden with `train_script.train.*` bindings.
train_script.train:
  model = %MODEL  # imported from separate gin file
  model_dir = %MODEL_DIR
  train_dataset_cfg = @train/utils.DatasetConfig()
  train_eval_dataset_cfg = @train_eval/utils.DatasetConfig()
  infer_eval_dataset_cfg = None
  checkpoint_cfg = @utils.CheckpointConfig()
  partitioner = @partitioning.PjitPartitioner()
  trainer_cls = @trainer.Trainer
  total_steps = %TRAIN_STEPS
  eval_steps = 20
  eval_period = 1000
  random_seed = %RANDOM_SEED
  use_hardware_rng = %USE_HARDWARE_RNG
  summarize_config_fn = @gin_utils.summarize_gin_config


partitioning.PjitPartitioner:
  num_partitions = 1
  model_parallel_submesh = None
  logical_axis_rules = @partitioning.standard_logical_axis_rules()
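
# Example override (illustrative): models too large for pure data
# parallelism can shard parameters across devices by raising the
# partition count:
#
#   partitioning.PjitPartitioner.num_partitions = 2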


train/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'train'
  batch_size = %BATCH_SIZE
  shuffle = %SHUFFLE_TRAIN_EXAMPLES
  seed = None  # use a new seed each run/restart
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE
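
# Example override (illustrative): the `train/` scope lets the training
# batch size be changed independently of the `train_eval/` one:
#
#   train/utils.DatasetConfig.batch_size = 256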


train_eval/utils.DatasetConfig:
  mixture_or_task_name = %MIXTURE_OR_TASK_NAME
  task_feature_lengths = %TASK_FEATURE_LENGTHS
  split = 'validation'
  batch_size = %BATCH_SIZE
  shuffle = False
  seed = 42
  use_cached = %USE_CACHED_TASKS
  pack = True
  module = %MIXTURE_OR_TASK_MODULE


utils.CheckpointConfig:
  restore = @utils.RestoreCheckpointConfig()
  save = @utils.SaveCheckpointConfig()

utils.RestoreCheckpointConfig:
  path = %INITIAL_CHECKPOINT_PATH
  mode = 'specific'
  dtype = 'float32'

utils.SaveCheckpointConfig:
  period = 50000
  dtype = 'float32'
  keep = None  # keep all checkpoints
  save_dataset = False  # don't checkpoint dataset state
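
# Example override (illustrative): to bound disk usage, retain only the
# most recent checkpoints instead of all of them:
#
#   utils.SaveCheckpointConfig.keep = 3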


trainer.Trainer:
  num_microbatches = None
  learning_rate_fn = @utils.create_learning_rate_scheduler()
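
# Example override (illustrative): if a full batch does not fit in device
# memory, accumulate gradients over microbatches; BATCH_SIZE must be
# divisible by this value:
#
#   trainer.Trainer.num_microbatches = 4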


utils.create_learning_rate_scheduler:
  factors = 'constant * rsqrt_decay'
  base_learning_rate = 0.5  # Half the usual pretraining value, since this continues training from a checkpoint.
  warmup_steps = 10000  # 10k to keep consistent with T5/MTF defaults.
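
# With these factors, the learning rate follows
#   base_learning_rate / sqrt(max(step, warmup_steps)),
# i.e. it holds at 0.5 / sqrt(10000) = 0.005 for the first 10k steps, then
# decays as 1 / sqrt(step).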