root committed on
Commit · 9a73cb0 · 1 Parent(s): 4f08905
uploaded training code and model weights
Browse files
- fuson_plm/training/README.md +60 -0
- fuson_plm/training/__init__.py +0 -0
- fuson_plm/training/__pycache__/__init__.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/config.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/model.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/plot.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/train.cpython-310.pyc +0 -0
- fuson_plm/training/__pycache__/utils.cpython-310.pyc +0 -0
- fuson_plm/training/config.py +38 -0
- fuson_plm/training/demo.py +46 -0
- fuson_plm/training/model.py +119 -0
- fuson_plm/training/test_esm2.py +122 -0
- fuson_plm/training/train.py +388 -0
- fuson_plm/training/utils.py +312 -0
fuson_plm/training/README.md
ADDED
@@ -0,0 +1,60 @@
## Training Script

This folder holds code for training the model (`train.py`), defining the model architecture (`model.py`), and defining utility functions, including masking rate schedulers and dataloaders (`utils.py`). There is also a script for running ESM-2 on the test data (`test_esm2.py`).

The weights and other necessary files for loading FusOn-pLM are stored in `checkpoints/best/ckpt`. Results on the test set are stored in `checkpoints/best/test_results.csv`.

### Usage

#### Configs

The `config.py` script holds configurations for **training** and **plotting**.

```python
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

### Masking parameters - must use either a variable or a fixed masking rate
# variable masking rate (choice 1)
VAR_MASK_RATE = True # if True, the masking rate varies according to the scheduler settings below
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine" # type of scheduler to use. options are: "cosine", "loglinear", "stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use this fixed masking rate

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False

# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters
# Fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
```

#### Training

The `train.py` script trains a fusion-aware ESM model according to the settings specified in `config.py`.

To run, enter in the terminal:
```bash
python train.py
```
or, to run the (long) training process in the background:
```bash
nohup python train.py > train.out 2> train.err &
```
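For reference, below is a minimal sketch of loading the saved weights for per-residue embeddings. It assumes `checkpoints/best/ckpt` is a standard Hugging Face `save_pretrained` directory (the format produced by `save_model` in `model.py`) and mirrors the approach in `demo.py`; the example sequence is truncated and purely illustrative.

```python
# Minimal sketch: load FusOn-pLM weights for per-residue embeddings.
# Assumes checkpoints/best/ckpt is a standard save_pretrained directory.
import torch
from transformers import AutoModel, AutoTokenizer

ckpt_dir = "checkpoints/best/ckpt"
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModel.from_pretrained(ckpt_dir)
model.eval()

sequence = "MVSSDRPVSLEDEVSHSMKEMIGG"  # illustrative (truncated) fusion sequence
inputs = tokenizer(sequence, return_tensors="pt", truncation=True, max_length=2000)
with torch.no_grad():
    hidden = model(**inputs).last_hidden_state.squeeze(0)  # (seq_len + 2, hidden_dim)
embeddings = hidden[1:-1].cpu().numpy()                    # drop BOS/EOS tokens
print(embeddings.shape)                                     # (len(sequence), hidden_dim)
```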
fuson_plm/training/__init__.py
ADDED
File without changes

fuson_plm/training/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (152 Bytes)

fuson_plm/training/__pycache__/config.cpython-310.pyc
ADDED
Binary file (898 Bytes)

fuson_plm/training/__pycache__/model.cpython-310.pyc
ADDED
Binary file (4.88 kB)

fuson_plm/training/__pycache__/plot.cpython-310.pyc
ADDED
Binary file (4.02 kB)

fuson_plm/training/__pycache__/train.cpython-310.pyc
ADDED
Binary file (12.2 kB)

fuson_plm/training/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (11.3 kB)
fuson_plm/training/config.py
ADDED
@@ -0,0 +1,38 @@
###### TRAINING
# Model parameters
EPOCHS = 30
BATCH_SIZE = 8
MAX_LENGTH = 2000
LEARNING_RATE = 3e-4
N_UNFROZEN_LAYERS = 8
UNFREEZE_QUERY = True
UNFREEZE_KEY = True
UNFREEZE_VALUE = True

### Masking parameters - must use either a variable or a fixed masking rate
# variable masking rate (choice 1)
VAR_MASK_RATE = True # if True, the masking rate varies according to the scheduler settings below
MASK_LOW = 0.15
MASK_HIGH = 0.40
MASK_STEPS = 20
MASK_SCHEDULER = "cosine" # type of scheduler to use. options are: "cosine", "loglinear", "stepwise"
# fixed masking rate (choice 2)
MASK_PERCENTAGE = 0.15 # if VAR_MASK_RATE = False, code will use this fixed masking rate

# To continue training a model you already started, fill in the following parameters
FINETUNE_FROM_SCRATCH = True # Set to False if you want to finetune from a checkpoint
PATH_TO_STARTING_CKPT = '' # only set the path if FINETUNE_FROM_SCRATCH = False

# File paths - do not change unless you move the training data
TRAIN_PATH = '../data/splits/train_df.csv'
VAL_PATH = '../data/splits/val_df.csv'
TEST_PATH = '../data/splits/test_df.csv'

# WandB parameters
# Fill these in with your own WandB account info
WANDB_PROJECT = ''
WANDB_ENTITY = ''
WANDB_API_KEY = ''

# GPU parameters
CUDA_VISIBLE_DEVICES = "0"
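As a quick illustration of how the variable masking-rate settings behave, the sketch below builds a scheduler from the values above using `get_mask_rate_scheduler` from `fuson_plm/training/utils.py` (included in this commit); the batch count of 100 is an arbitrary stand-in for `len(train_loader)`.

```python
# Sketch: the masking rate ramps from MASK_LOW to MASK_HIGH over one epoch's batches.
from fuson_plm.training.utils import get_mask_rate_scheduler

n_batches = 100  # stand-in for len(train_loader)
scheduler = get_mask_rate_scheduler(
    scheduler_type="cosine",   # MASK_SCHEDULER
    min_masking_rate=0.15,     # MASK_LOW
    max_masking_rate=0.40,     # MASK_HIGH
    total_batches=n_batches,
    total_steps=20,            # MASK_STEPS (only used by the "stepwise" scheduler)
)

rates = []
for _ in range(n_batches):
    scheduler.step()                       # called once per training batch in train.py
    rates.append(scheduler.get_masking_rate())
print(f"first batch: {rates[0]:.3f}, last batch: {rates[-1]:.3f}")  # ~0.15 -> ~0.40
```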
fuson_plm/training/demo.py
ADDED
@@ -0,0 +1,46 @@
from fuson_plm.training.model import FusOnpLM
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import logging
import torch
import os

os.environ['CUDA_VISIBLE_DEVICES'] = "1"

# Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the tokenizer and model
model_name = 'checkpoints/old_splits_snp_2000_ft_11layers_Q_b8_lr5e-05_mask0.15-08-12-2024-12:42:48/checkpoint_epoch_1.pth'
model = AutoModel.from_pretrained(model_name)  # initialize model
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()
model.to(device)

# Example fusion oncoprotein sequence: MLLT10:PICALM, associated with Acute Myeloid Leukemia (LAML)
# Amino acids 1-80 are derived from the head gene, MLLT10
# Amino acids 81-119 are derived from the tail gene, PICALM
sequence = "MVSSDRPVSLEDEVSHSMKEMIGGCCVCSDERGWAENPLVYCDGHGCSVAVHQACYGIVQVPTGPWFCRKCESQERAARVPPQMGSVPVMTQPTLIYSQPVMRPPNPFGPVSGAQIQFM"

# Tokenize the input sequence
inputs = tokenizer(sequence, return_tensors="pt", padding=True, truncation=True, max_length=2000)
inputs = {k: v.to(device) for k, v in inputs.items()}

# Get the embeddings
with torch.no_grad():
    outputs = model(**inputs)

# The embeddings are in the last_hidden_state tensor
embeddings = outputs.last_hidden_state
# Remove the extra batch dimension
embeddings = embeddings.squeeze(0)
# Remove the BOS and EOS tokens
embeddings = embeddings[1:-1, :]

# Convert embeddings to a numpy array (if needed)
embeddings = embeddings.cpu().numpy()

print("Sequence length: ", len(sequence))
print("Per-residue embeddings shape:", embeddings.shape)
fuson_plm/training/model.py
ADDED
@@ -0,0 +1,119 @@
from transformers import AutoModelForMaskedLM, AutoTokenizer, AutoModel
import torch
import os

class FusOnTokenizer:
    """
    FusOnTokenizer class: a wrapper around AutoTokenizer
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D'):
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying tokenizer.
        This allows calls like .tokenize(), .train(), and .eval() to be forwarded to the tokenizer.
        """
        return getattr(self.tokenizer, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnTokenizer object callable, delegating to the tokenizer's __call__ method.
        """
        return self.tokenizer(*args, **kwargs)

    def save_tokenizer(self, save_directory):
        self.tokenizer.save_pretrained(save_directory)

    def load_tokenizer(self, load_directory):
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)

class FusOnpLM:
    """
    FusOn-pLM class: a wrapper around AutoModelForMaskedLM
    """
    def __init__(self, pretrained_path='facebook/esm2_t33_650M_UR50D', ckpt_path=None, mlm_head=False):
        if not(ckpt_path is None):
            self.load_model(ckpt_path, mlm_head)
        else:
            # Load the pre-trained model and tokenizer
            self.model = AutoModelForMaskedLM.from_pretrained(pretrained_path)
            self.tokenizer = FusOnTokenizer(pretrained_path)

        self.n_layers = self.count_encoder_layers()

    def __getattr__(self, name):
        """
        Delegate attribute access to the underlying model.
        This allows calls like .to(), .train(), and .eval() to be forwarded to the model.
        """
        return getattr(self.model, name)

    def __call__(self, *args, **kwargs):
        """
        Make the FusOnpLM object callable, delegating to the model's __call__ method.
        """
        return self.model(*args, **kwargs)

    def freeze_model(self):
        """
        Freezes all parameters in the model
        """
        for param in self.model.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
        """
        Unfreezes specific parts of the final n layers in the model's encoder.

        Args:
            n_unfrozen_layers (int): Number of final layers to unfreeze.
            unfreeze_query (bool): Whether to unfreeze the query projections. Default is True.
            unfreeze_key (bool): Whether to unfreeze the key projections. Default is True.
            unfreeze_value (bool): Whether to unfreeze the value projections. Default is True.
        """
        for i, layer in enumerate(self.model.esm.encoder.layer):
            if (self.n_layers - i) <= n_unfrozen_layers:  # Only the last n layers
                if unfreeze_query:
                    self._unfreeze_parameters(layer.attention.self.query)
                if unfreeze_key:
                    self._unfreeze_parameters(layer.attention.self.key)
                if unfreeze_value:
                    self._unfreeze_parameters(layer.attention.self.value)

    def _unfreeze_parameters(self, module):
        """
        Helper method to unfreeze parameters in a given module.

        Args:
            module (nn.Module): The module whose parameters are to be unfrozen.
        """
        for param in module.parameters():
            param.requires_grad = True

    def count_encoder_layers(self):
        """
        Count the number of encoder layers in the model.
        """
        return len(self.model.esm.encoder.layer)

    def save_model(self, save_directory, optimizer=None):
        # Save the model and tokenizer
        self.model.save_pretrained(save_directory)
        self.tokenizer.save_pretrained(save_directory)

        # If an optimizer is provided, save its state dict
        if optimizer is not None:
            optimizer_path = os.path.join(save_directory, "optimizer.pt")
            torch.save(optimizer.state_dict(), optimizer_path)

    def load_model(self, load_directory, mlm_head):
        # Load a checkpoint of the model either with or without an MLM head
        if mlm_head:
            self.model = AutoModelForMaskedLM.from_pretrained(load_directory)
        else:
            # Load the model without the MLM head
            self.model = AutoModel.from_pretrained(load_directory)
        # Load the tokenizer from the same directory
        self.tokenizer = AutoTokenizer.from_pretrained(load_directory)
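A short usage sketch for the wrapper above, mirroring what `prepare_model` in `train.py` does; the layer count and flags follow the defaults in `config.py`.

```python
# Sketch: load ESM-2-650M through the FusOnpLM wrapper, freeze everything,
# then unfreeze the query/key/value projections of the last 8 encoder layers.
from fuson_plm.training.model import FusOnpLM

model = FusOnpLM()   # wraps AutoModelForMaskedLM and a FusOnTokenizer
model.freeze_model()
model.unfreeze_last_n_layers(8, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True)

n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"{n_trainable} trainable parameters across {model.n_layers} encoder layers")
```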
fuson_plm/training/test_esm2.py
ADDED
@@ -0,0 +1,122 @@
### Run ESM2 on the validation and test set. Get val and test losses.
import os
import fuson_plm.training.config as config
# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import tqdm
import numpy as np
import pandas as pd
import logging
from transformers import AutoModelForMaskedLM, AutoTokenizer
from fuson_plm.utils.logging import log_update, open_logfile, print_configpy
from fuson_plm.benchmarking.caid.utils import DisorderDataset, get_dataloader, check_dataloaders
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders
from fuson_plm.training.train import test

def load_esm2_maskedlm(esm_type, device=None):
    """
    Loads ESM-2 of a specified version (e.g. esm2_t33_650M_UR50D)
    """
    # Suppress warnings about newly initialized 'esm.pooler.dense.bias', 'esm.pooler.dense.weight' layers - these are not used to extract embeddings
    logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

    model = AutoModelForMaskedLM.from_pretrained(f"facebook/{esm_type}")
    tokenizer = AutoTokenizer.from_pretrained(f"facebook/{esm_type}")

    model.to(device)
    model.eval()  # disables dropout for deterministic results

    return model, tokenizer, device


def val(model, tokenizer, val_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Same method as test in train.py, just for running the val set
    """
    model.to(device)
    model.eval()
    total_val_loss = 0
    total_weighted_val_loss = 0
    total_val_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over val data with a progress bar
        with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Val Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=mask_percentage)

                # Forward pass
                outputs = model(**masked_inputs)
                val_loss = outputs.loss

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations
                total_val_loss += val_loss.item()
                total_weighted_val_loss += val_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_val_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity
    n_val_batches = len(val_loader)
    avg_val_loss = total_val_loss / n_val_batches
    avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens
    val_perplexity = np.exp(avg_weighted_val_loss)

    log_update(f"\nval results:\nTotal batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

    # Save to dataframe for plotting
    val_stats_df = pd.DataFrame(data={
        "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
        "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
        "val_perplexity": [val_perplexity]
    })
    val_stats_df.to_csv(f"{checkpoint_dir}/val_results.csv", index=False)  # overwrite old file no matter what; should only be one val eval

def main():
    # Load the ESM-2 model
    model, tokenizer, device = load_esm2_maskedlm("esm2_t33_650M_UR50D")

    checkpoint_dir = f"checkpoints/esm2_t33_650M_UR50D_{config.PROBABILITY_TYPE}_mask{config.MASK_PERCENTAGE}"
    os.makedirs(checkpoint_dir, exist_ok=True)

    with open_logfile(f"{checkpoint_dir}/evaluate_val_test_esm.txt"):
        # Print configurations
        print_configpy(config)

        ##### Validation
        val_loader = get_dataloader(config.VAL_PATH, tokenizer,
                                    probability_type=config.PROBABILITY_TYPE,
                                    batch_size=config.BATCH_SIZE,
                                    max_length=config.MAX_LENGTH, shuffle=False)

        # Validation
        val(model, tokenizer, val_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)


        ##### Test
        # Create dataloader
        test_loader = get_dataloader(config.TEST_PATH,
                                     tokenizer,
                                     probability_type=config.PROBABILITY_TYPE,
                                     batch_size=config.BATCH_SIZE,
                                     max_length=config.MAX_LENGTH, shuffle=False)

        # Test the model
        test(model, tokenizer, test_loader, config.MASK_PERCENTAGE, device=device, checkpoint_dir=checkpoint_dir)

if __name__ == "__main__":
    main()
fuson_plm/training/train.py
ADDED
@@ -0,0 +1,388 @@
'''
This is a training script for finetuning ESM.
I am going to freeze the parameters in the head and unfreeze the last N layers in the model.
'''

import os
import fuson_plm.training.config as config

# Set the WANDB_API_KEY environment variable
os.environ['WANDB_API_KEY'] = config.WANDB_API_KEY
os.environ['CUDA_VISIBLE_DEVICES'] = config.CUDA_VISIBLE_DEVICES

import torch
import numpy as np
import pandas as pd
import tqdm
from datetime import datetime
import wandb
import pytz
import sys

from transformers import AdamW

from fuson_plm.utils.logging import print_configpy, get_local_time, open_logfile, open_errfile, log_update
from fuson_plm.training.model import FusOnpLM
from fuson_plm.training.utils import batch_sample_mask_tokens_with_probabilities, get_dataloader, check_dataloaders, get_mask_rate_scheduler
from fuson_plm.training.plot import make_train_val_test_bd_plot

def prepare_model(model, n_unfrozen_layers, unfreeze_query=True, unfreeze_key=True, unfreeze_value=True):
    # Log the model's initial state
    n_layers = model.count_encoder_layers()
    total_params = sum(p.numel() for p in model.parameters())
    total_head_params = sum(p.numel() for p in model.lm_head.parameters())
    log_update(f'\nInitial state:\n\tTotal number of layers in the model: {n_layers}')
    log_update(f'\tTotal parameters in the AutoModelforMaskedLM model: {total_params}')
    log_update(f'\tTotal parameters in the MLM Head ONLY: {total_head_params}')

    # Freeze the model to start
    model.freeze_model()
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    log_update(f'Froze all {model.n_layers} model layers')
    log_update(f'\tTrainable params: {n_trainable_params}')

    # Unfreeze the last n layers
    model.unfreeze_last_n_layers(n_unfrozen_layers, unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)
    n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_params = '\n\t\t'.join([name for name, param in model.named_parameters() if param.requires_grad])
    num_trainable_params_lm_head = sum(p.numel() for p in model.lm_head.parameters() if p.requires_grad)
    num_trainable_params_esm = sum(p.numel() for p in model.esm.parameters() if p.requires_grad)
    log_update(f'Unfroze final {n_unfrozen_layers} layers')
    log_update(f'\tTrainable params: {n_trainable_params}\n\t\t{trainable_params}')
    log_update(f"\tTrainable parameters in the lm_head: {num_trainable_params_lm_head}")
    log_update(f"\tTrainable params in the ESM part: {num_trainable_params_esm}")

def train(model, tokenizer, optimizer, train_loader, val_loader, n_epochs=10, start_epoch=1, mask_percentage=0.15, mask_rate_scheduler=None, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Train the model
    """
    # Loop over epochs
    log_update("\n")

    for epoch in range(start_epoch, start_epoch + n_epochs):
        if mask_rate_scheduler is not None:
            mask_rate_scheduler.reset()  # resetting because we want to ramp it up again every epoch

        model.train()
        total_train_loss = 0
        total_weighted_train_loss = 0
        total_train_masked_tokens = 0

        log_update(f"Epoch {epoch}")
        # Loop over train data with progress bar
        with tqdm.tqdm(enumerate(train_loader), total=len(train_loader), desc='Training Batch', leave=True, position=0) as pbar:
            for batch_idx, (inputs, prob) in pbar:
                # Take a step with the mask rate scheduler, if there is one.
                masking_rate = mask_percentage
                if mask_rate_scheduler is not None:
                    mask_rate_scheduler.step()
                    masking_rate = mask_rate_scheduler.get_masking_rate()
                    log_update(f"\tBatch index: {batch_idx}\tMasking rate: {masking_rate:.5f}")

                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=masking_rate)

                # Forward pass and update
                optimizer.zero_grad()
                outputs = model(**masked_inputs)
                loss = outputs.loss
                loss.backward()
                optimizer.step()

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations and wandb log
                total_train_loss += loss.item()
                total_weighted_train_loss += loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_train_masked_tokens += num_masked_tokens
                wandb.log({"batch_loss": loss.item()})

        # Save a checkpoint at the end of each epoch
        checkpoint_path = os.path.join(checkpoint_dir, f'checkpoint_epoch_{epoch}')
        model.save_model(checkpoint_path, optimizer=optimizer)
        log_update(f'\nSaved checkpoint to {checkpoint_path}')

        # Calculate and log average training loss on wandb
        n_train_batches = len(train_loader)
        avg_train_loss = total_train_loss / n_train_batches
        avg_weighted_train_loss = total_weighted_train_loss / total_train_masked_tokens
        train_perplexity = np.exp(avg_weighted_train_loss)
        wandb.log({"epoch": epoch,
                   "total_train_loss": total_train_loss, "weighted_train_loss": total_weighted_train_loss,
                   "avg_train_loss": avg_train_loss, "avg_weighted_train_loss": avg_weighted_train_loss,
                   "train_perplexity": train_perplexity})

        # Track curve stats for easy re-plotting of training curves later
        train_stats_df = pd.DataFrame(data={
            "epoch": [epoch],
            "total_train_loss": [total_train_loss], "weighted_train_loss": [total_weighted_train_loss],
            "avg_train_loss": [avg_train_loss], "avg_weighted_train_loss": [avg_weighted_train_loss],
            "train_perplexity": [train_perplexity]
        })
        if os.path.exists(f"{checkpoint_dir}/train_curve.csv"):  # add to existing file if necessary
            train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv", index=False, header=False, mode='a')
        else:  # make new file if necessary
            train_stats_df.to_csv(f"{checkpoint_dir}/train_curve.csv", index=False)

        # Validation loop
        model.eval()
        total_val_loss = 0
        total_weighted_val_loss = 0
        total_val_masked_tokens = 0

        with torch.no_grad():  # No gradients needed
            # Loop over val data with progress bar
            with tqdm.tqdm(enumerate(val_loader), total=len(val_loader), desc='Validation Batch', leave=True, position=0) as vbar:
                for batch_idx, (inputs, prob) in vbar:
                    # Move tensors
                    inputs = {k: v.to(device) for k, v in inputs.items()}
                    prob = prob.to(device)

                    # Mask based on probability vectors
                    ## FIXED 15% masking for the validation set
                    masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)

                    # Forward pass
                    outputs = model(**masked_inputs)
                    val_loss = outputs.loss

                    # Number of masked tokens
                    num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                    # Loss calculations
                    total_val_loss += val_loss.item()
                    total_weighted_val_loss += val_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                    total_val_masked_tokens += num_masked_tokens

        # Calculate and log avg. loss and perplexity (wandb and locally)
        n_val_batches = len(val_loader)
        avg_val_loss = total_val_loss / n_val_batches  # avg per batch
        avg_weighted_val_loss = total_weighted_val_loss / total_val_masked_tokens  # avg per masked token
        val_perplexity = np.exp(avg_weighted_val_loss)
        wandb.log({"epoch": epoch,
                   "total_val_loss": total_val_loss, "weighted_val_loss": total_weighted_val_loss,
                   "avg_val_loss": avg_val_loss, "avg_weighted_val_loss": avg_weighted_val_loss,
                   "val_perplexity": val_perplexity})

        # Track curve stats for easy re-plotting of training curves later
        val_stats_df = pd.DataFrame(data={
            "epoch": [epoch],
            "total_val_loss": [total_val_loss], "weighted_val_loss": [total_weighted_val_loss],
            "avg_val_loss": [avg_val_loss], "avg_weighted_val_loss": [avg_weighted_val_loss],
            "val_perplexity": [val_perplexity]
        })
        if os.path.exists(f"{checkpoint_dir}/val_curve.csv"):  # add to existing file if necessary
            val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv", index=False, header=False, mode='a')
        else:  # make new file if necessary
            val_stats_df.to_csv(f"{checkpoint_dir}/val_curve.csv", index=False)

        log_update(f"Epoch: {epoch}")
        log_update(f"\tTrain set: Total batches = {n_train_batches}, Total masked tokens = {total_train_masked_tokens}, Total Loss = {total_train_loss:.4f}, Avg Batch Loss = {avg_train_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_train_loss:.4f}, Perplexity = {train_perplexity:.4f}")
        log_update(f"\tValidation set: Total batches = {n_val_batches}, Total masked tokens = {total_val_masked_tokens}, Total Loss = {total_val_loss:.4f}, Avg Batch Loss = {avg_val_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_val_loss:.4f}, Perplexity = {val_perplexity:.4f}")

def test(model, tokenizer, test_loader, mask_percentage=0.15, device='cuda', checkpoint_dir='./checkpoints'):
    """
    Evaluate the model on the test set with a fixed 15% masking rate.
    """
    model.to(device)
    model.eval()
    total_test_loss = 0
    total_weighted_test_loss = 0
    total_test_masked_tokens = 0

    with torch.no_grad():  # No gradients needed
        # Loop over test data with a progress bar
        with tqdm.tqdm(enumerate(test_loader), total=len(test_loader), desc='Test Batch', leave=True, position=0) as tbar:
            for batch_idx, (inputs, prob) in tbar:
                # Move tensors
                inputs = {k: v.to(device) for k, v in inputs.items()}
                prob = prob.to(device)

                # Mask based on probability vectors
                ### FIXED 15% masking for the testing set
                masked_inputs = batch_sample_mask_tokens_with_probabilities(inputs, prob, tokenizer, mask_percentage=0.15)

                # Forward pass
                outputs = model(**masked_inputs)
                test_loss = outputs.loss

                # Number of masked tokens
                num_masked_tokens = (masked_inputs["input_ids"] == tokenizer.mask_token_id).sum().item()

                # Loss calculations
                total_test_loss += test_loss.item()
                total_weighted_test_loss += test_loss.item() * num_masked_tokens  # Multiply loss by number of masked tokens
                total_test_masked_tokens += num_masked_tokens

    # Compute and log avg. loss and perplexity
    n_test_batches = len(test_loader)
    avg_test_loss = total_test_loss / n_test_batches
    avg_weighted_test_loss = total_weighted_test_loss / total_test_masked_tokens
    test_perplexity = np.exp(avg_weighted_test_loss)

    log_update(f"\nTest results:\nTotal batches = {n_test_batches}, Total masked tokens = {total_test_masked_tokens}, Total Loss = {total_test_loss:.4f}, Avg Batch Loss = {avg_test_loss:.4f}, Avg Masked Token-Weighted Loss = {avg_weighted_test_loss:.4f}, Perplexity = {test_perplexity:.4f}")

    # Save to dataframe for plotting
    test_stats_df = pd.DataFrame(data={
        "total_test_loss": [total_test_loss], "weighted_test_loss": [total_weighted_test_loss],
        "avg_test_loss": [avg_test_loss], "avg_weighted_test_loss": [avg_weighted_test_loss],
        "test_perplexity": [test_perplexity]
    })
    test_stats_df.to_csv(f"{checkpoint_dir}/test_results.csv", index=False)  # overwrite old file no matter what; should only be one test eval

def check_env_variables():
    log_update("\nChecking on environment variables...")
    log_update(f"\tWANDB_API_KEY: {os.environ.get('WANDB_API_KEY')}")
    log_update(f"\tCUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")
    log_update(f"\ttorch.cuda.device_count(): {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        log_update(f"\t\tDevice {i}: {torch.cuda.get_device_name(i)}")

def initialize_model_and_optimizer(finetune_from_scratch, device, path_to_starting_ckpt=None, learning_rate=1e-4, n_unfrozen_layers=0, unfreeze_query=False, unfreeze_key=False, unfreeze_value=False):
    """
    Initializes the model, either from ESM-2-650M if finetuning from scratch, or from a prior checkpoint if not finetuning from scratch.
    Also prepares the model (freezes/unfreezes the proper layers) and initializes the optimizer.

    Args:
        finetune_from_scratch (bool): True if finetuning from scratch. False if finetuning from a previous ckpt
        path_to_starting_ckpt (str): path to starting ckpt for finetuning (optional)
    """
    if not(finetune_from_scratch) and not(os.path.exists(path_to_starting_ckpt)):
        raise Exception(f"Error: could not find {path_to_starting_ckpt}. When finetuning from a prior checkpoint, you must provide a valid path to that checkpoint.")

    # if finetuning from scratch, initialize from scratch
    if finetune_from_scratch:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from scratch")
        model = FusOnpLM()  # because of __getattr__, we can use FusOnpLM() to get the model. It also contains the tokenizer.
        model.to(device)
        prepare_model(model, n_unfrozen_layers,
                      unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

        # Set the optimizer here, change it if we are finetuning from an old checkpoint
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=learning_rate)

        return model, optimizer

    # if not, initialize from starting ckpt
    else:
        log_update(f"\nInitializing FusOn-pLM model to be finetuned from previous checkpoint: {path_to_starting_ckpt}")
        model = FusOnpLM(ckpt_path=path_to_starting_ckpt, mlm_head=True)
        model.to(device)
        prepare_model(model, n_unfrozen_layers,
                      unfreeze_query=unfreeze_query, unfreeze_key=unfreeze_key, unfreeze_value=unfreeze_value)

        log_update(f"Loading optimizer state_dict from previous checkpoint")
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()))
        optimizer.load_state_dict(torch.load(os.path.join(path_to_starting_ckpt, "optimizer.pt"), map_location=device))

        return model, optimizer

def main():
    # Set probability type to uniform; only option
    config.PROBABILITY_TYPE = "uniform"

    # Set run name (WANDB_NAME)
    kqv_tag = f"{'Q' if config.UNFREEZE_QUERY else ''}" + f"{'K' if config.UNFREEZE_KEY else ''}" + f"{'V' if config.UNFREEZE_VALUE else ''}"
    timestamp = get_local_time()
    # make a mask tag _mask{config.MASK_PERCENTAGE}
    mask_tag = f"mask{config.MASK_PERCENTAGE}"
    if config.VAR_MASK_RATE:  # if variable masking rate, change the tag to reflect this
        mask_tag = f"maskvar_{config.MASK_SCHEDULER}_low{config.MASK_LOW}_high{config.MASK_HIGH}"

    # Define the train settings string and wandb name from this
    TRAIN_SETTINGS_STRING = f"{config.PROBABILITY_TYPE}_{config.MAX_LENGTH}_ft_{config.N_UNFROZEN_LAYERS}layers_{kqv_tag}_b{config.BATCH_SIZE}_lr{config.LEARNING_RATE}_{mask_tag}"
    WANDB_NAME = f'{TRAIN_SETTINGS_STRING}-{timestamp}'

    # Create directory for model checkpoints
    checkpoint_dir = f'checkpoints/{WANDB_NAME}'
    start_epoch = 1

    # Determine if we're adding to an old log file or opening a new one
    logmode = 'w'

    # If we're finetuning from a checkpoint, save to the same folder instead, and keep track of which epoch to start on
    # Also, load the optimizer from here
    if not(config.FINETUNE_FROM_SCRATCH):
        logmode = 'a'
        path_to_starting_ckpt = config.PATH_TO_STARTING_CKPT
        checkpoint_dir = path_to_starting_ckpt[0:path_to_starting_ckpt.rindex('/')]
        START_MODEL_TRAIN_SETTINGS_STRING = checkpoint_dir[checkpoint_dir.index('checkpoints/')+len('checkpoints/'):checkpoint_dir.index('-')]
        start_epoch = int(path_to_starting_ckpt.split('/checkpoint_epoch_')[1])+1

    os.makedirs(f'checkpoints', exist_ok=True)
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Open log file
    LOG_PATH = f'{checkpoint_dir}/training_log.txt'
    ERR_PATH = f'{checkpoint_dir}/training_errors.txt'
    with open_logfile(LOG_PATH, mode=logmode), open_errfile(ERR_PATH, mode=logmode):
        if not(config.FINETUNE_FROM_SCRATCH):
            log_update(f"\n{'-'*200}\nResuming finetuning from checkpoint {start_epoch-1} (first new checkpoint: {start_epoch})\n")
            log_update(f"Settings tag for original model (starting point for finetuning) = {START_MODEL_TRAIN_SETTINGS_STRING}\nSettings tag for new model based on configs = {TRAIN_SETTINGS_STRING}\nSame: {START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING}\n")
            # ONLY proceed with training if we're using the same settings, otherwise we are not finetuning the model we think we are!
            assert START_MODEL_TRAIN_SETTINGS_STRING==TRAIN_SETTINGS_STRING

        # Print configurations
        print_configpy(config)

        # Verify that the environment variables are set correctly
        check_env_variables()

        # Check CUDA availability and set device
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        log_update(f"\nUsing device: {device}")

        # Init wandb
        wandb.init(project=config.WANDB_PROJECT, entity=config.WANDB_ENTITY, name=WANDB_NAME, config={
            "batch_size": config.BATCH_SIZE,
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
        })

        # Initialize model and prepare it (freeze/unfreeze proper layers). Initialize optimizer as well. Details depend on whether we are finetuning from scratch.
        model, optimizer = initialize_model_and_optimizer(config.FINETUNE_FROM_SCRATCH, device,
                                                          path_to_starting_ckpt=config.PATH_TO_STARTING_CKPT,
                                                          learning_rate=config.LEARNING_RATE,
                                                          n_unfrozen_layers=config.N_UNFROZEN_LAYERS,
                                                          unfreeze_query=config.UNFREEZE_QUERY,
                                                          unfreeze_key=config.UNFREEZE_KEY,
                                                          unfreeze_value=config.UNFREEZE_VALUE)

        # Initialize the tokenizer (independent of starting model for finetuning)
        tokenizer = model.tokenizer

        # Create DataLoader instances and perform sanity checks on them
        train_loader = get_dataloader(config.TRAIN_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=True)  ## FOR DEBUGGING ONLY, change shuffle to False. Otherwise, True!!
        val_loader = get_dataloader(config.VAL_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)
        test_loader = get_dataloader(config.TEST_PATH, tokenizer, probability_type=config.PROBABILITY_TYPE, batch_size=config.BATCH_SIZE, max_length=config.MAX_LENGTH, shuffle=False)

        # If we're continuing to finetune an old ckpt, store the old batch diversity plot before we overwrite it
        check_dataloaders(train_loader, val_loader, test_loader, max_length=config.MAX_LENGTH, checkpoint_dir=checkpoint_dir)

        # Set up a masking rate scheduler, if one is needed
        mask_rate_scheduler = None
        if config.VAR_MASK_RATE:
            mask_rate_scheduler = get_mask_rate_scheduler(scheduler_type=config.MASK_SCHEDULER,
                                                          min_masking_rate=config.MASK_LOW,
                                                          max_masking_rate=config.MASK_HIGH,
                                                          total_batches=len(train_loader),
                                                          total_steps=config.MASK_STEPS)

        # Train the model
        train(model, tokenizer, optimizer, train_loader, val_loader,
              n_epochs=config.EPOCHS,
              start_epoch=start_epoch,
              device=device,
              mask_rate_scheduler=mask_rate_scheduler,
              mask_percentage=config.MASK_PERCENTAGE,
              checkpoint_dir=checkpoint_dir)

        # Test the model
        test(model, tokenizer, test_loader, mask_percentage=0.15, device=device, checkpoint_dir=checkpoint_dir)

if __name__ == "__main__":
    main()
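For clarity, the perplexity reported by `train` and `test` above is the exponential of the masked-token-weighted average cross-entropy (each batch loss is weighted by its number of masked tokens before averaging). A tiny numeric sketch with made-up values:

```python
import numpy as np

# per-batch (loss, num_masked_tokens) pairs -- illustrative values only
batches = [(2.10, 180), (1.95, 210), (2.30, 165)]

total_weighted_loss = sum(loss * n for loss, n in batches)
total_masked_tokens = sum(n for _, n in batches)
avg_weighted_loss = total_weighted_loss / total_masked_tokens
perplexity = np.exp(avg_weighted_loss)  # same aggregation train.py applies per epoch
print(f"avg masked-token-weighted loss = {avg_weighted_loss:.4f}, perplexity = {perplexity:.4f}")
```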
fuson_plm/training/utils.py
ADDED
@@ -0,0 +1,312 @@
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import torch
|
| 4 |
+
import os
|
| 5 |
+
from torch.nn.functional import softmax
|
| 6 |
+
from fuson_plm.utils.logging import log_update
|
| 7 |
+
from torch.utils.data import Dataset, DataLoader
|
| 8 |
+
from transformers import AutoTokenizer
|
| 9 |
+
from abc import ABC, abstractmethod
|
| 10 |
+
|
| 11 |
+
#----------------------------------------------------------------------------------------------------------------------------------------------------
|
| 12 |
+
#### Masking Rate Scheduler base class and sub classes
|
| 13 |
+
# abstract base class
|
| 14 |
+
class MaskingRateScheduler(ABC):
|
| 15 |
+
def __init__(self, total_steps, min_masking_rate, max_masking_rate, last_step=-1):
|
| 16 |
+
self.total_steps = total_steps
|
| 17 |
+
self.min_masking_rate = min_masking_rate
|
| 18 |
+
self.max_masking_rate = max_masking_rate
|
| 19 |
+
self.current_step = last_step
|
| 20 |
+
|
| 21 |
+
def step(self):
|
| 22 |
+
self.current_step += 1
|
| 23 |
+
|
| 24 |
+
def reset(self):
|
| 25 |
+
"""Reset the scheduler to its initial state."""
|
| 26 |
+
self.current_step = -1
|
| 27 |
+
|
| 28 |
+
def get_masking_rate(self):
|
| 29 |
+
progress = self.current_step / self.total_steps
|
| 30 |
+
return self.compute_masking_rate(progress)
|
| 31 |
+
|
| 32 |
+
@abstractmethod
|
| 33 |
+
def compute_masking_rate(self, progress):
|
| 34 |
+
"""To be implemented by subclasses for specific increase functions."""
|
| 35 |
+
raise NotImplementedError("Subclasses must implement this method.")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
class CosineIncreaseMaskingRateScheduler(MaskingRateScheduler):
|
| 39 |
+
def compute_masking_rate(self, progress):
|
| 40 |
+
# Use a cosine increase function
|
| 41 |
+
cosine_increase = 0.5 * (1 - np.cos(np.pi * progress))
|
| 42 |
+
return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * cosine_increase
|
| 43 |
+
|
| 44 |
+
class LogLinearIncreaseMaskingRateScheduler(MaskingRateScheduler):
|
| 45 |
+
def compute_masking_rate(self, progress):
|
| 46 |
+
# Avoid log(0) by clamping progress to a minimum of a small positive number
|
| 47 |
+
progress = max(progress, 1e-10)
|
| 48 |
+
log_linear_increase = np.log1p(progress) / np.log1p(1) # Normalizing to keep range in [0, 1]
|
| 49 |
+
return self.min_masking_rate + (self.max_masking_rate - self.min_masking_rate) * log_linear_increase
|
| 50 |
+
|
| 51 |
+
class StepwiseIncreaseMaskingRateScheduler(MaskingRateScheduler):
|
| 52 |
+
def __init__(self, total_batches, min_masking_rate, max_masking_rate, num_steps):
|
| 53 |
+
super().__init__(total_steps=total_batches, min_masking_rate=min_masking_rate, max_masking_rate=max_masking_rate)
|
| 54 |
+
self.num_steps = num_steps
|
| 55 |
+
self.batch_interval = total_batches // (num_steps) # Adjusting to ensure max rate is included
|
| 56 |
+
self.rate_increment = (max_masking_rate - min_masking_rate) / (num_steps - 1) # Include end rate in the steps
|
| 57 |
+
|
| 58 |
+
def compute_masking_rate(self, progress):
|
| 59 |
+
# Determine the current step based on the number of completed batches
|
| 60 |
+
current_step = int(self.current_step / self.batch_interval)
|
| 61 |
+
# Cap the step number to `num_steps - 1` to include the max rate at the final step
|
| 62 |
+
current_step = min(current_step, self.num_steps - 1)
|
| 63 |
+
# Calculate the masking rate for the current step
|
| 64 |
+
masking_rate = self.min_masking_rate + current_step * self.rate_increment
|
| 65 |
+
return masking_rate
|
| 66 |
+
|
| 67 |
+
def get_mask_rate_scheduler(scheduler_type="cosine",min_masking_rate=0.15,max_masking_rate=0.40,total_batches=100,total_steps=20):
|
| 68 |
+
"""
|
| 69 |
+
Initialize the mask rate scheduler and return it
|
| 70 |
+
"""
|
| 71 |
+
if scheduler_type=="cosine":
|
| 72 |
+
return CosineIncreaseMaskingRateScheduler(total_steps=total_batches,
|
| 73 |
+
min_masking_rate=min_masking_rate,
|
| 74 |
+
max_masking_rate=max_masking_rate)
|
| 75 |
+
elif scheduler_type=="loglinear":
|
| 76 |
+
return LogLinearIncreaseMaskingRateScheduler(total_steps=total_batches,
|
| 77 |
+
min_masking_rate=min_masking_rate,
|
| 78 |
+
max_masking_rate=max_masking_rate)
|
| 79 |
+
elif scheduler_type=="stepwise":
|
| 80 |
+
return StepwiseIncreaseMaskingRateScheduler(total_batches=total_batches,
|
| 81 |
+
num_steps=total_steps,
|
| 82 |
+
min_masking_rate=min_masking_rate,
|
| 83 |
+
max_masking_rate=max_masking_rate)
|
| 84 |
+
else:
|
| 85 |
+
raise Exception("Must specify valid scheduler_type: cosine, loglinear, stepwise")
|
| 86 |
+
|
| 87 |
+
# Adjusted Dataloader for the sequences and probability vectors
|
| 88 |
+
class ProteinDataset(Dataset):
|
| 89 |
+
def __init__(self, data_path, tokenizer, probability_type, max_length=512):
|
| 90 |
+
self.dataframe = pd.read_csv(data_path)
|
| 91 |
+
self.tokenizer = tokenizer
|
| 92 |
+
self.probability_type=probability_type
|
| 93 |
+
self.max_length = max_length
|
| 94 |
+
|
| 95 |
+
self.set_probabilities()
|
| 96 |
+
|
| 97 |
+
def __len__(self):
|
| 98 |
+
return len(self.dataframe)
|
| 99 |
+
|
| 100 |
+
def set_probabilities(self):
|
| 101 |
+
if self.probability_type=="snp":
|
| 102 |
+
self.dataframe = self.dataframe.rename(columns={'snp_probabilities':'probabilities'})
|
| 103 |
+
if self.probability_type=="uniform":
|
| 104 |
+
self.dataframe['probabilities'] = self.dataframe['sequence'].apply(len).apply(lambda x: ('1,'*x)[0:-1])
|
| 105 |
+
|
| 106 |
+
# make probabilities into numbers if they aren't already
|
| 107 |
+
if type(self.dataframe['probabilities'][0]) == str:
|
| 108 |
+
self.dataframe['probabilities'] = self.dataframe['probabilities'].apply(
|
| 109 |
+
lambda x: np.array([float(i) for i in x.split(',')])
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
def get_padded_probabilities(self, idx):
|
| 113 |
+
'''
|
| 114 |
+
Pads probabilities to max_length if they're too short; truncate them if they're too long
|
| 115 |
+
'''
|
| 116 |
+
no_mask_value = int(-1e9) # will be used to make sure CLS and PAD aren't masked
|
| 117 |
+
|
| 118 |
+
# add a no-mask slot for <CLS>
|
| 119 |
+
prob = np.concatenate((
|
| 120 |
+
np.array([no_mask_value]),
|
| 121 |
+
self.dataframe.iloc[idx]['probabilities']
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# Pad with no_mask_value for everything after the probability vector ends
|
| 126 |
+
if len(prob) < self.max_length:
|
| 127 |
+
return np.pad(
|
| 128 |
+
prob,
|
| 129 |
+
(0, self.max_length - len(prob)),
|
| 130 |
+
'constant', constant_values=(0,no_mask_value))
|
| 131 |
+
|
| 132 |
+
# If it's too long, we need to truncate, but we also need to change the last token to an <EOS>.
|
| 133 |
+
prob = prob[0:self.max_length-1]
|
| 134 |
+
prob = np.concatenate((
|
| 135 |
+
prob,
|
| 136 |
+
np.array([no_mask_value]),
|
| 137 |
+
)
|
| 138 |
+
)
|
| 139 |
+
return prob
|
| 140 |
+
|
| 141 |
+
def __getitem__(self, idx):
|
| 142 |
+
sequence = self.dataframe.iloc[idx]['sequence']
|
| 143 |
+
probability = self.get_padded_probabilities(idx) # extract them
|
| 144 |
+
inputs = self.tokenizer(sequence, return_tensors="pt", padding='max_length', truncation=True, max_length=self.max_length) # does this have to be 512?
|
| 145 |
+
inputs = {key: tensor.squeeze(0) for key, tensor in inputs.items()} # Remove batch dimension
|
| 146 |
+
return inputs, probability
|
| 147 |
+
|
| 148 |
+
def get_dataloader(data_path, tokenizer, probability_type='snp', max_length=512, batch_size=8, shuffle=True):
|
| 149 |
+
"""
|
| 150 |
+
Creates a DataLoader for the dataset.
|
| 151 |
+
Args:
|
| 152 |
+
data_path (str): Path to the CSV file (train, val, or test).
|
| 153 |
+
batch_size (int): Batch size.
|
| 154 |
+
shuffle (bool): Whether to shuffle the data.
|
| 155 |
+
tokenizer (Tokenizer): tokenizer object for data tokenization
|
| 156 |
+
Returns:
|
| 157 |
+
DataLoader: DataLoader object.
|
| 158 |
+
"""
|
| 159 |
+
dataset = ProteinDataset(data_path, tokenizer, probability_type, max_length=max_length)
|
| 160 |
+
return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)
|
| 161 |
+
|
| 162 |
+
def check_dataloaders(train_loader, val_loader, test_loader, max_length=512, checkpoint_dir=''):
    log_update(f'\nBuilt train, validation, and test dataloaders')
    log_update(f"\tNumber of sequences in the Training DataLoader: {len(train_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Validation DataLoader: {len(val_loader.dataset)}")
    log_update(f"\tNumber of sequences in the Test DataLoader: {len(test_loader.dataset)}")
    dataloader_overlaps = check_dataloader_overlap(train_loader, val_loader, test_loader)
    if len(dataloader_overlaps) == 0: log_update("\tDataloaders are clean (no overlaps)")
    else: log_update(f"\tWARNING! sequence overlap found: {','.join(dataloader_overlaps)}")

    # write per-batch sequence length ranges to a text file for each loader
    if not os.path.exists(f'{checkpoint_dir}/batch_diversity'):
        os.mkdir(f'{checkpoint_dir}/batch_diversity')

    max_length_violators = []
    for name, dataloader in {'train': train_loader, 'val': val_loader, 'test': test_loader}.items():
        max_length_followed, length_ranges = check_max_length_and_length_diversity(dataloader, max_length)
        if not max_length_followed:
            max_length_violators.append(name)

        with open(f'{checkpoint_dir}/batch_diversity/{name}_batch_length_ranges.txt', 'w') as f:
            for tup in length_ranges:
                f.write(f'{tup[0]}\t{tup[1]}\n')

    if len(max_length_violators) == 0: log_update(f"\tDataloaders follow the max length limit set by user: {max_length}")
    else: log_update(f"\tWARNING! these loaders have sequences longer than max length={max_length}: {','.join(max_length_violators)}")

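# Small sketch (assumption: checkpoint_dir is the same directory passed to check_dataloaders above)
# showing how to read back one of the batch_diversity files it writes; each line holds
# "<min_length>\t<max_length>" for one batch. `_demo_read_length_ranges` is a hypothetical helper
# for inspection only.
def _demo_read_length_ranges(checkpoint_dir, split='train'):
    ranges = []
    with open(f'{checkpoint_dir}/batch_diversity/{split}_batch_length_ranges.txt') as f:
        for line in f:
            lo, hi = line.strip().split('\t')
            ranges.append((int(lo), int(hi)))
    return ranges  # e.g. [(34, 512), (51, 498), ...]
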
def check_dataloader_overlap(train_loader, val_loader, test_loader):
    """
    Check the data that's about to go into the model. Make sure there is no overlap between train, test, and val.

    Returns:
        list: empty if the splits are disjoint; otherwise one description string per overlapping pair.
    """
    train_protein_seqs = set(train_loader.dataset.dataframe['sequence'].unique())
    val_protein_seqs = set(val_loader.dataset.dataframe['sequence'].unique())
    test_protein_seqs = set(test_loader.dataset.dataframe['sequence'].unique())

    tr_va = len(train_protein_seqs.intersection(val_protein_seqs))
    tr_te = len(train_protein_seqs.intersection(test_protein_seqs))
    va_te = len(val_protein_seqs.intersection(test_protein_seqs))

    overlaps = []
    if tr_va == tr_te == va_te == 0:
        return overlaps  # data is clean
    else:
        if tr_va > 0: overlaps.append(f"Train-Val Overlap={tr_va}")
        if tr_te > 0: overlaps.append(f"Train-Test Overlap={tr_te}")
        if va_te > 0: overlaps.append(f"Val-Test Overlap={va_te}")
        return overlaps

def check_max_length_and_length_diversity(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length,
    and return the sequence length ranges within each batch.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
        list: A list of tuples representing the min and max sequence lengths in each batch.
    """
    length_ranges = []
    all_within_max_length = True

    for batch_idx, (inputs, _) in enumerate(dataloader):
        input_ids = inputs['input_ids']

        # Calculate the actual (non-padding) lengths of sequences in this batch
        actual_lengths = (input_ids != dataloader.dataset.tokenizer.pad_token_id).sum(dim=1)
        min_length = actual_lengths.min().item()
        max_length_in_batch = actual_lengths.max().item()

        # Check for max length violation
        if max_length_in_batch > max_length:
            all_within_max_length = False

        # Store the length range for this batch
        length_ranges.append((min_length, max_length_in_batch))

    return all_within_max_length, length_ranges

def check_max_length_in_dataloader(dataloader, max_length):
    """
    Check if all sequences in the DataLoader conform to the specified max_length.

    Args:
        dataloader (DataLoader): The DataLoader object to check.
        max_length (int): The maximum allowed sequence length.

    Returns:
        bool: True if all sequences are within the max_length, False otherwise.
    """
    for batch_idx, (inputs, _) in enumerate(dataloader):
        input_ids = inputs['input_ids']

        # Check if any sequence length exceeds max_length
        if input_ids.size(1) > max_length:
            return False

    return True

def batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer: AutoTokenizer, mask_percentage=0.15):
    """
    Mask a batch of tokenized sequences for MLM training, drawing the positions to mask from
    per-position probability vectors instead of uniformly at random.

    Args:
        inputs (dict): batch from the DataLoader with 'input_ids' and 'attention_mask' (batch_size x max_length).
        probabilities (Tensor): per-position masking weights from ProteinDataset; special and padding tokens
            hold a large negative value, so softmax sends their masking probability to ~0.
        tokenizer (AutoTokenizer): tokenizer whose mask token replaces the selected positions.
        mask_percentage (float): fraction of real residues to mask in each sequence.

    Returns:
        dict: inputs with masked 'input_ids', an updated 'attention_mask', and MLM 'labels'
        (-100 everywhere except the masked positions, which hold the original token ids).
    """
    labels = inputs["input_ids"].detach().clone()
    labels[labels != tokenizer.mask_token_id] = -100  # Set labels for unmasked tokens to -100

    # Iterate over each sequence and its corresponding probabilities in the batch
    for idx in range(inputs["input_ids"].size(0)):  # Assuming the first dimension is batch size
        input_ids = inputs["input_ids"][idx]
        prob = probabilities[idx]

        cls_token_index = (input_ids == 0).nonzero(as_tuple=False)[0].item()  # 0 is <cls> in the ESM-2 vocabulary
        eos_token_index = (input_ids == 2).nonzero(as_tuple=False)[0].item()  # 2 is <eos> in the ESM-2 vocabulary
        seq_length = eos_token_index - (cls_token_index + 1)                  # number of real residues

        assert prob.shape[0] == input_ids.shape[0]

        # Normalize probabilities using softmax; the -1e9 entries for special/pad tokens become ~0
        prob = softmax(prob, dim=0).cpu().numpy()  # move to CPU for numpy
        assert abs(1 - sum(prob)) < 1e-6

        # Calculate the number of tokens to mask
        num_tokens_to_mask = int(mask_percentage * seq_length)

        # Choose indices to mask based on the probability distribution
        mask_indices = np.random.choice(input_ids.shape[0], size=num_tokens_to_mask, replace=False, p=prob)
        attention_mask_1_indices = np.arange(0, eos_token_index + 1, 1)

        # Mask the selected indices and set the corresponding labels
        labels[idx, mask_indices] = input_ids[mask_indices].detach().clone()
        input_ids[mask_indices] = tokenizer.mask_token_id

        # Attention is 1 from <cls> through <eos> and 0 over the padding that follows
        inputs["attention_mask"][idx] = torch.zeros_like(input_ids)
        inputs["attention_mask"][idx][attention_mask_1_indices] = 1

        # Update the input_ids in the inputs dictionary
        inputs["input_ids"][idx] = input_ids

    inputs["labels"] = labels
    return inputs
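
# Hedged sketch (not in the original file) of how a batch flows through the masking function during
# training: mask with the probability-weighted sampler above, then hand the batch to an ESM-2
# masked-LM head to get the cross-entropy loss over the masked positions only. The checkpoint name
# and the fixed 15% rate are illustrative assumptions; `_demo_masked_lm_step` is a hypothetical helper.
def _demo_masked_lm_step(batch, tokenizer, device="cpu"):
    from transformers import EsmForMaskedLM  # illustration only

    model = EsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device)

    inputs, probabilities = batch  # as yielded by get_dataloader
    inputs = batch_sample_mask_tokens_with_probabilities(inputs, probabilities, tokenizer, mask_percentage=0.15)

    outputs = model(input_ids=inputs["input_ids"].to(device),
                    attention_mask=inputs["attention_mask"].to(device),
                    labels=inputs["labels"].to(device))
    return outputs.loss  # MLM loss; positions labeled -100 are ignored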