## Training Script

This folder holds code for training the model (`train.py`), defining the model architecture (`model.py`), and defining utility functions, including masking rate schedulers and dataloaders (`utils.py`). There is also a script for running ESM-2 on the test data (`test_esm2.py`).
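
The exact formulas behind the masking rate schedulers in `utils.py` are not documented here. As a rough illustration only, the sketch below assumes that the rate rises from `MASK_LOW` to `MASK_HIGH` (parameters set in `config.py`) as training progresses through `MASK_STEPS` steps. It shows one plausible shape for each of the three `MASK_SCHEDULER` options. The function `mask_rate` and its `n_levels` argument are hypothetical and are not part of the repository.

```python
import math


def mask_rate(step: int, low: float = 0.15, high: float = 0.40,
              steps: int = 20, scheduler: str = "cosine",
              n_levels: int = 5) -> float:
    """Hypothetical masking-rate schedule rising from `low` to `high`.

    Assumption: the rate increases over `steps` steps; `step` is clamped
    to [0, steps]. `n_levels` (used only by "stepwise") is illustrative
    and does not correspond to any documented config value.
    """
    if steps <= 0:
        raise ValueError("steps must be positive")
    t = min(max(step, 0), steps) / steps  # training progress in [0, 1]
    if scheduler == "cosine":
        # Half-cosine ramp: changes slowly at both ends, fastest mid-way.
        frac = (1 - math.cos(math.pi * t)) / 2
        return low + (high - low) * frac
    if scheduler == "loglinear":
        # Linear in log space, i.e. geometric interpolation from low to high.
        return math.exp(math.log(low) + t * (math.log(high) - math.log(low)))
    if scheduler == "stepwise":
        # Jump through n_levels equally spaced rates, reaching `high` at the end.
        if n_levels < 2:
            raise ValueError("n_levels must be at least 2")
        level = min(int(t * n_levels), n_levels - 1)
        return low + (high - low) * level / (n_levels - 1)
    raise ValueError(f"unknown scheduler: {scheduler!r}")


if __name__ == "__main__":
    # Print the assumed rate at a few steps for each schedule type.
    for name in ("cosine", "loglinear", "stepwise"):
        print(name, [round(mask_rate(s, scheduler=name), 3) for s in (0, 5, 10, 15, 20)])
```

Under these assumptions, every schedule starts at `MASK_LOW` and ends at `MASK_HIGH`. They differ only in how fast the rate climbs in between. With the default values, the cosine and stepwise versions give 0.275 at the midpoint, while the log-linear version gives the geometric mean, about 0.245.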

The weights and other necessary files for loading FusOn-pLM are stored in `checkpoints/best/ckpt`. Results on the test set are stored in `checkpoints/best/test_results.csv`.

### Usage

#### Configs

The `config.py` script holds configurations for **training** and **plotting**.

```python
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

# Masking parameters: use EITHER a variable masking rate OR a fixed masking rate
# Variable masking rate (choice 1)
VAR_MASK_RATE = True  # if True, the masking rate varies between MASK_LOW and MASK_HIGH; if False, MASK_PERCENTAGE is used
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine"  # type of scheduler to use; options: "cosine", "loglinear", "stepwise"
# Fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15  # used only if VAR_MASK_RATE = False

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True  # set to False to finetune from a checkpoint
PATH_TO_STARTING_CKPT = ''  # only set the path if FINETUNE_FROM_SCRATCH = False

# File paths: do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters: fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
```
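
Training is a long process, so it can be worth catching configuration mistakes before launching a run. The sketch below is a hypothetical helper. `check_config` is not part of the repository. It encodes the rules stated in the comments of `config.py`: exactly one masking mode is used, a checkpoint path is required when resuming, and the WandB fields must be filled in. The numeric range checks, such as `0 < MASK_LOW < MASK_HIGH < 1` and `MASK_STEPS >= 1`, are assumptions, not documented constraints.

```python
from types import SimpleNamespace

VALID_SCHEDULERS = {"cosine", "loglinear", "stepwise"}


def check_config(cfg) -> list[str]:
    """Return a list of problems found in a config.py-style object (empty if none).

    Hypothetical helper: `cfg` is anything with the config.py attribute
    names (e.g. the imported `config` module); range checks are assumed.
    """
    problems = []
    if cfg.VAR_MASK_RATE:
        # Variable masking: only the MASK_LOW/HIGH/STEPS/SCHEDULER values matter.
        if not 0 < cfg.MASK_LOW < cfg.MASK_HIGH < 1:
            problems.append("need 0 < MASK_LOW < MASK_HIGH < 1")
        if cfg.MASK_STEPS < 1:
            problems.append("MASK_STEPS must be a positive integer")
        if cfg.MASK_SCHEDULER not in VALID_SCHEDULERS:
            problems.append(f"MASK_SCHEDULER must be one of {sorted(VALID_SCHEDULERS)}")
    elif not 0 < cfg.MASK_PERCENTAGE < 1:
        # Fixed masking: only MASK_PERCENTAGE matters.
        problems.append("MASK_PERCENTAGE must be in (0, 1)")
    if not cfg.FINETUNE_FROM_SCRATCH and not cfg.PATH_TO_STARTING_CKPT:
        problems.append("set PATH_TO_STARTING_CKPT when FINETUNE_FROM_SCRATCH = False")
    if not (cfg.WANDB_PROJECT and cfg.WANDB_ENTITY and cfg.WANDB_API_KEY):
        problems.append("fill in WANDB_PROJECT, WANDB_ENTITY and WANDB_API_KEY")
    return problems


if __name__ == "__main__":
    # Values copied from config.py as shipped (WandB fields still empty).
    shipped = SimpleNamespace(
        VAR_MASK_RATE=True, MASK_LOW=0.15, MASK_HIGH=0.40, MASK_STEPS=20,
        MASK_SCHEDULER="cosine", MASK_PERCENTAGE=0.15,
        FINETUNE_FROM_SCRATCH=True, PATH_TO_STARTING_CKPT="",
        WANDB_PROJECT="", WANDB_ENTITY="", WANDB_API_KEY="",
    )
    print(check_config(shipped))  # → ['fill in WANDB_PROJECT, WANDB_ENTITY and WANDB_API_KEY']
```

From this folder, you could run `import config` and then `check_config(config)`. An empty list means none of these assumed checks failed.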

#### Training

The `train.py` script trains a fusion-aware ESM model according to the settings specified in `config.py`. To run it, enter the following in a terminal:

```bash
python train.py
```

Or, to run the (long) training process in the background, with standard output written to `train.out` and errors written to `train.err`:

```bash
nohup python train.py > train.out 2> train.err &