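"""Per-residue protein solubility classifier for guided generation in MeMDLM_v2.

A frozen pretrained ESM model supplies token embeddings; a small transformer
encoder plus MLP head predicts one solubility logit per residue.
"""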
import gc
import torch
import torch.nn as nn
import lightning.pytorch as pl
from omegaconf import OmegaConf
from transformers import AutoModel
from torchmetrics.classification import BinaryAUROC, BinaryAccuracy

from src.utils.model_utils import _print
from src.guidance.utils import CosineWarmup

# Standalone config load (note: hardcoded absolute path); the class below
# receives its own config argument, so this object is only used at module level.
config = OmegaConf.load("/scratch/sgoel/MeMDLM_v2/src/configs/guidance.yaml")

class SolubilityClassifier(pl.LightningModule):
    """Token-level solubility classifier: frozen ESM embeddings -> transformer encoder -> MLP head."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.loss_fn = nn.BCEWithLogitsLoss(reduction='none')  # per-token loss; masked in compute_loss
        self.auroc = BinaryAUROC()
        self.accuracy = BinaryAccuracy()

        # Frozen pretrained ESM model, used only as an embedding extractor
        self.esm_model = AutoModel.from_pretrained(self.config.lm.pretrained_esm)
        for p in self.esm_model.parameters():
            p.requires_grad = False

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.model.d_model,
            nhead=config.model.num_heads,
            dropout=config.model.dropout,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, config.model.num_layers)
        self.layer_norm = nn.LayerNorm(config.model.d_model)
        self.dropout = nn.Dropout(config.model.dropout)
        self.mlp = nn.Sequential(
            nn.Linear(config.model.d_model, config.model.d_model // 2),
            nn.ReLU(),
            nn.Dropout(config.model.dropout),
            nn.Linear(config.model.d_model // 2, 1),  # one solubility logit per token
        )
    # -------# Classifier step #-------- #
    def forward(self, batch):
        if 'input_ids' in batch:
            esm_embeds = self.get_esm_embeddings(batch['input_ids'], batch['attention_mask'])
        elif 'embeds' in batch:
            esm_embeds = batch['embeds']
        else:
            raise ValueError("batch must contain either 'input_ids' or precomputed 'embeds'")
        padding_mask = (batch['attention_mask'] == 0)  # True at padded positions
        encodings = self.encoder(esm_embeds, src_key_padding_mask=padding_mask)
        encodings = self.dropout(self.layer_norm(encodings))
        logits = self.mlp(encodings).squeeze(-1)  # (batch, seq_len)
        return logits
    # -------# Training / Evaluation #-------- #
    def training_step(self, batch, batch_idx):
        train_loss, _ = self.compute_loss(batch)
        self.log(name="train/loss", value=train_loss.item(), on_step=True, on_epoch=False, logger=True, sync_dist=True)
        self.save_ckpt()
        return train_loss

    def validation_step(self, batch, batch_idx):
        val_loss, _ = self.compute_loss(batch)
        self.log(name="val/loss", value=val_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return val_loss

    def test_step(self, batch, batch_idx):
        test_loss, preds = self.compute_loss(batch)
        auroc, accuracy = self.get_metrics(batch, preds)
        self.log(name="test/loss", value=test_loss.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/AUROC", value=auroc.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        self.log(name="test/accuracy", value=accuracy.item(), on_step=False, on_epoch=True, logger=True, sync_dist=True)
        return test_loss
    def on_test_epoch_end(self):
        # Reset metric state so a later test run starts fresh
        self.auroc.reset()
        self.accuracy.reset()

    def optimizer_step(self, *args, **kwargs):
        super().optimizer_step(*args, **kwargs)
        # Aggressively release memory after every optimizer step
        gc.collect()
        torch.cuda.empty_cache()
    def configure_optimizers(self):
        train_cfg = self.config.training
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.config.optim.lr)
        lr_scheduler = CosineWarmup(
            optimizer,
            warmup_steps=train_cfg.warmup_steps,
            total_steps=train_cfg.max_steps,
        )
        scheduler_dict = {
            "scheduler": lr_scheduler,
            "interval": "step",
            "frequency": 1,
            "monitor": "val/loss",
            "name": "learning_rate",
        }
        return [optimizer], [scheduler_dict]
    def save_ckpt(self):
        curr_step = self.global_step
        save_every = self.config.training.val_check_interval
        if curr_step % save_every == 0 and curr_step > 0:  # save every `val_check_interval` steps
            ckpt_path = f"{self.config.checkpointing.save_dir}/step={curr_step}.ckpt"
            self.trainer.save_checkpoint(ckpt_path)
    # -------# Loss and Test Set Metrics #-------- #
    @torch.no_grad()
    def get_esm_embeddings(self, input_ids, attention_mask):
        """Extract last-layer hidden states from the frozen ESM model."""
        outputs = self.esm_model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state
    def compute_loss(self, batch):
        """Masked BCE loss: only non-pad positions contribute."""
        labels = batch['labels']
        preds = self.forward(batch)
        loss_mask = (labels != self.config.model.label_pad_value)  # only calculate loss over non-pad tokens
        # Zero out pad labels so BCE receives valid {0, 1} targets; they are masked out below
        safe_labels = torch.where(loss_mask, labels, torch.zeros_like(labels)).float()
        loss = self.loss_fn(preds, safe_labels)
        loss = (loss * loss_mask).sum() / loss_mask.sum()
        return loss, preds
    def get_metrics(self, batch, preds):
        """Compute AUROC and accuracy over non-pad positions."""
        labels = batch['labels']
        valid_mask = (labels != self.config.model.label_pad_value)
        labels = labels[valid_mask].long()  # torchmetrics expects integer targets
        preds = preds[valid_mask]
        _print(f"labels {labels.shape}")
        _print(f"preds {preds.shape}")
        auroc = self.auroc(preds, labels)
        accuracy = self.accuracy(preds, labels)
        return auroc, accuracy
    # -------# Helper Functions #-------- #
    def get_state_dict(self, ckpt_path):
        """Load a trained model's state dict from a saved checkpoint, stripping any 'model.' key prefix."""
        def remove_model_prefix(state_dict):
            # str.replace returns a new string rather than mutating the key,
            # so rebuild the dict with the renamed keys
            return {
                (k[len("model."):] if k.startswith("model.") else k): v
                for k, v in state_dict.items()
            }

        checkpoint = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
        state_dict = checkpoint.get("state_dict", checkpoint)
        if any(k.startswith("model.") for k in state_dict.keys()):
            state_dict = remove_model_prefix(state_dict)
        return state_dict
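
if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the training pipeline):
    # build the classifier from the module-level config and restore weights from
    # a saved checkpoint. The step number in the path below is illustrative.
    model = SolubilityClassifier(config)
    state_dict = model.get_state_dict(f"{config.checkpointing.save_dir}/step=250.ckpt")
    model.load_state_dict(state_dict)
    model.eval()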