| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import argparse |
| import os |
|
|
| from trainer import experiment |
|
|
|
|
def get_args(argv=None):
    """Define the task arguments with their default values and parse them.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse reads the process command line, which is
            what the no-argument call did before this parameter existed.

    Returns:
        argparse.Namespace with the experiment parameters: batch_size,
        num_epochs, seed, learning_rate, weight_decay, hp_tune, job_dir
        and model_name.
    """
    args_parser = argparse.ArgumentParser()

    # Experiment arguments
    args_parser.add_argument(
        '--batch-size',
        help='Batch size for each training and evaluation step.',
        type=int,
        default=16,
    )
    args_parser.add_argument(
        '--num-epochs',
        # The previous help text referred to --train-size, --train-batch-size
        # and --train-steps, none of which this parser defines.
        help='Maximum number of training data epochs on which to train.',
        type=int,
        default=1,
    )
    args_parser.add_argument(
        '--seed',
        help='Random seed (default: 42)',
        type=int,
        default=42,
    )

    # Optimizer arguments
    args_parser.add_argument(
        '--learning-rate',
        help='Learning rate value for the optimizers.',
        type=float,
        default=2e-5,
    )
    # NOTE(review): the flag is named --weight-decay but the help text below
    # describes learning-rate decay. Confirm against how the experiment
    # consumes args.weight_decay and reword if it is an optimizer
    # weight-decay coefficient.
    args_parser.add_argument(
        '--weight-decay',
        help=(
            'The factor by which the learning rate should decay by the end '
            'of the training. '
            'decayed_learning_rate = '
            'learning_rate * decay_rate ^ (global_step / decay_steps). '
            'If set to 0, then no decay will occur. '
            'If set to 0.5, then the learning rate should reach 0.5 of its '
            'original value at the end of the training. '
            'Note that decay_steps is set to train_steps. '
            '(default: 0.01)'
        ),
        type=float,
        default=0.01,
    )

    # Hyperparameter tuning switch
    args_parser.add_argument(
        '--hp-tune',
        default="n",
        help=(
            'Enable hyperparameter tuning. '
            'Valid values are: "y" - enable, "n" - disable'
        ),
    )

    # Saved model arguments
    args_parser.add_argument(
        '--job-dir',
        # Falls back to the AIP_MODEL_DIR environment variable; the default
        # is None when that variable is unset.
        default=os.getenv('AIP_MODEL_DIR'),
        help='GCS location to export models',
    )
    args_parser.add_argument(
        '--model-name',
        default="finetuned-bert-classifier",
        help='The name of your saved model',
    )

    # parse_args(None) reads sys.argv[1:], preserving the original behavior.
    return args_parser.parse_args(argv)
|
|
|
|
def main():
    """Set up and start the experiment.

    Parses the command-line arguments with get_args(), prints the parsed
    namespace, and passes it to experiment.run().
    """
    args = get_args()
    # Echo the effective configuration before the run starts.
    print(args)
    experiment.run(args)
|
|
|
|
# Run the experiment only when this file is executed as a script,
# not when it is imported.
if __name__ == '__main__':
    main()
|
|