Training Script
This folder holds code for training the model (train.py), defining the model architecture (model.py), and defining utility functions, including masking rate schedulers and dataloaders (utils.py). There is also a script for running ESM-2 on the test data (test_esm2.py).
The weights and other necessary files for loading FusOn-pLM are stored in checkpoints/best/ckpt. Results on the test set are stored in checkpoints/best/test_results.csv.
Usage
Configs
The config.py script holds configurations for training and plotting.
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True
### Masking parameters - must use either variable or fixed masking rate
# var masking rate (choice 1)
VAR_MASK_RATE = True # if True, use the variable masking rate set by MASK_LOW, MASK_HIGH, MASK_STEPS and MASK_SCHEDULER
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine" # type of masking rate scheduler to use; options are "cosine", "loglinear", "stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use fixed masking rate
# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False
# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'
# WandB parameters
# Fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''
# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
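The variable masking options can be pictured as a function from training step to masking rate. This is a minimal sketch, not the implementation in utils.py: it assumes the rate rises from MASK_LOW at step 0 to MASK_HIGH at step MASK_STEPS - 1, reads "loglinear" as linear interpolation in log space, and gives "stepwise" four evenly spaced plateaus. The helpers `mask_rate` and `current_mask_rate` and the `n_levels` parameter are hypothetical.

```python
import math

# Defaults mirror the variable-masking values in config.py
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20


def mask_rate(step, scheduler="cosine", low=MASK_LOW, high=MASK_HIGH,
              n_steps=MASK_STEPS, n_levels=4):
    """Hypothetical schedule rising from `low` (step 0) to `high` (step n_steps - 1)."""
    if n_steps < 2:
        raise ValueError("n_steps must be at least 2")
    # Progress through the schedule, clamped to [0, 1]
    t = min(max(step, 0), n_steps - 1) / (n_steps - 1)
    if scheduler == "cosine":
        # Half-cosine ramp: slow start, fast middle, slow finish
        frac = (1 - math.cos(math.pi * t)) / 2
    elif scheduler == "loglinear":
        # Assumed meaning: interpolate linearly in log space, then map back
        return math.exp(math.log(low) + t * (math.log(high) - math.log(low)))
    elif scheduler == "stepwise":
        # Hold the rate on n_levels evenly spaced plateaus (n_levels is an assumption)
        frac = min(math.floor(t * n_levels), n_levels - 1) / (n_levels - 1)
    else:
        raise ValueError(f"unknown scheduler: {scheduler!r}")
    return low + frac * (high - low)


def current_mask_rate(step, var_mask_rate=True, scheduler="cosine", fixed=0.15):
    """Choice 1 (variable schedule) vs. choice 2 (fixed MASK_PERCENTAGE)."""
    return mask_rate(step, scheduler) if var_mask_rate else fixed
```

For example, `[round(mask_rate(s, "stepwise"), 3) for s in range(20)]` shows the plateaus. Clamping `step` keeps the rate at MASK_HIGH once the schedule ends; that is one reasonable choice, and the scheduler in utils.py may behave differently.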
Training
The train.py script trains a fusion-aware ESM model according to the settings specified in config.py.
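The N_UNFROZEN_LAYERS and UNFREEZE_* settings suggest that only part of ESM-2 is fine-tuned. The sketch below shows one plausible reading, not necessarily what train.py does: only the query, key and value projections of the last N_UNFROZEN_LAYERS encoder layers are trainable, and everything else (embeddings, feed-forward blocks, LM head) stays frozen. It assumes Hugging Face-style parameter names such as `esm.encoder.layer.<i>.attention.self.query.weight`; the `is_trainable` helper is hypothetical.

```python
import re

# Defaults mirror config.py
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

# Assumed naming: "esm.encoder.layer.<i>.attention.self.<query|key|value>.<weight|bias>"
_QKV = re.compile(r"encoder\.layer\.(\d+)\.attention\.self\.(query|key|value)\.")


def is_trainable(name, n_layers, n_unfrozen=N_UNFROZEN_LAYERS,
                 query=UNFREEZE_QUERY, key=UNFREEZE_KEY, value=UNFREEZE_VALUE):
    """Return True if parameter `name` should receive gradients under this hypothetical policy."""
    match = _QKV.search(name)
    if match is None:
        return False  # embeddings, feed-forward blocks, LM head, etc. stay frozen
    layer, projection = int(match.group(1)), match.group(2)
    if layer < n_layers - n_unfrozen:
        return False  # only the last n_unfrozen layers are eligible
    return {"query": query, "key": key, "value": value}[projection]
```

With a PyTorch model this would be applied as `for name, param in model.named_parameters(): param.requires_grad = is_trainable(name, model.config.num_hidden_layers)`.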
To run, enter the following in a terminal:
python train.py
Or, to run the (long) training process in the background, with standard output written to train.out and errors to train.err:
nohup python train.py > train.out 2> train.err &
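Because a full run is long, it can be worth checking config.py before launching it with nohup. The sketch below is a hypothetical `check_config` helper (not part of this repository) that flags the inconsistencies the config comments warn about: invalid masking settings for the chosen option, a checkpoint path that conflicts with FINETUNE_FROM_SCRATCH, missing data files, and empty WandB fields. The bounds it enforces are assumptions.

```python
import os

VALID_SCHEDULERS = {"cosine", "loglinear", "stepwise"}


def check_config(cfg) -> list:
    """Return a list of human-readable problems found in a config.py-style object."""
    problems = []
    if cfg.VAR_MASK_RATE:
        # Variable masking rate: bounds must be valid fractions with LOW <= HIGH
        if not 0 < cfg.MASK_LOW <= cfg.MASK_HIGH < 1:
            problems.append("need 0 < MASK_LOW <= MASK_HIGH < 1")
        if not (isinstance(cfg.MASK_STEPS, int) and cfg.MASK_STEPS > 0):
            problems.append("MASK_STEPS should be a positive integer")
        if cfg.MASK_SCHEDULER not in VALID_SCHEDULERS:
            problems.append(f"unknown MASK_SCHEDULER {cfg.MASK_SCHEDULER!r}")
    elif not 0 < cfg.MASK_PERCENTAGE < 1:
        # Fixed masking rate: must be a fraction
        problems.append("MASK_PERCENTAGE should be a fraction between 0 and 1")
    # The starting checkpoint is only meaningful when resuming
    if cfg.FINETUNE_FROM_SCRATCH and cfg.PATH_TO_STARTING_CKPT:
        problems.append("PATH_TO_STARTING_CKPT is set but FINETUNE_FROM_SCRATCH = True")
    if not cfg.FINETUNE_FROM_SCRATCH and not os.path.exists(cfg.PATH_TO_STARTING_CKPT):
        problems.append("FINETUNE_FROM_SCRATCH = False but PATH_TO_STARTING_CKPT does not exist")
    # Data splits must be present at the configured paths
    for attr in ("TRAIN_PATH", "VAL_PATH", "TEST_PATH"):
        if not os.path.isfile(getattr(cfg, attr)):
            problems.append(f"{attr} not found: {getattr(cfg, attr)}")
    # WandB fields must be filled in with your own account info
    for attr in ("WANDB_PROJECT", "WANDB_ENTITY", "WANDB_API_KEY"):
        if not getattr(cfg, attr):
            problems.append(f"{attr} is empty")
    return problems
```

Running `import config; print(check_config(config))` from this folder should list any problems, or print an empty list if none were found.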