AlienChen's picture
download
raw
13.2 kB
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from argparse import ArgumentParser
import os
import math
import numpy as np
import torch.distributed as dist
from transformers import AutoTokenizer, AutoModelForMaskedLM
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import yaml
from easydict import EasyDict as edict
from models import *
from transformers.utils import logging
logging.set_verbosity_error()
def compute_class_weights(targets):
    """Return balanced class weights ``[non_binding_weight, binding_weight]``.

    Each class is weighted by ``total / (2 * class_count)``, so both classes
    contribute equally to a weighted binary loss.

    Args:
        targets: Binary labels (tensor or array-like); 1 marks a binding site.
            Every element is counted, so padding must be removed beforehand.

    Returns:
        A float32 tensor of shape ``(2,)`` ordered ``[non_binding, binding]``.

    Raises:
        ValueError: If ``targets`` is empty or contains only one class; one of
            the weights would otherwise be infinite.
    """
    targets = torch.as_tensor(targets)
    total = targets.numel()
    num_binding_sites = float(targets.sum())
    num_non_binding_sites = total - num_binding_sites
    if num_binding_sites <= 0 or num_non_binding_sites <= 0:
        raise ValueError(
            "compute_class_weights needs both binding and non-binding sites; "
            f"got {num_binding_sites} binding out of {total} total"
        )
    weight_for_binding = total / (2 * num_binding_sites)
    weight_for_non_binding = total / (2 * num_non_binding_sites)
    return torch.tensor([weight_for_non_binding, weight_for_binding])
class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warmup followed by cosine annealing, in absolute learning rates.

    The schedule is driven by ``self.last_epoch`` (number of ``step()`` calls)
    and ignores the optimizer's per-group base LRs: every parameter group
    receives the same value.

    * ``step < warmup_steps``: linear ramp from ``base_lr`` to ``max_lr``.
    * ``warmup_steps <= step <= total_steps``: cosine decay from ``max_lr``
      to ``min_lr``.
    * ``step > total_steps``: held at ``min_lr`` (the cosine would otherwise
      swing back up).

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of linear-warmup steps.
        total_steps: Step at which the cosine phase reaches ``min_lr``.
        base_lr: LR at step 0.
        max_lr: Peak LR reached at the end of warmup.
        min_lr: Final LR at the end of the cosine phase.
        last_epoch: Index of the last step, forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        # Must be set before super().__init__, which calls step() (and thus
        # get_lr()) once during construction.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr.
            lr = self.base_lr + (self.max_lr - self.base_lr) * (step / self.warmup_steps)
        else:
            # Cosine annealing phase from max_lr to min_lr; max(1, ...) avoids a
            # zero division when total_steps == warmup_steps, and the clamp keeps
            # the LR at min_lr once the schedule has finished.
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, (step - self.warmup_steps) / decay_steps)
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr] * len(self.base_lrs)
class BindEvaluator(pl.LightningModule):
    """Binding-site classifier for a (peptide binder, protein target) pair.

    Frozen pretrained encoders embed both inputs (PeptideCLM for the binder,
    ESM-2 for the target). A trainable attention stack then produces one logit
    per position of the target encoding. The training objective is a
    class-weighted, masked BCE loss plus ``cfg.training.kl_weight`` times the
    auxiliary term from ``compute_kl_loss``.

    Args:
        cfg: EasyDict-style config; reads ``cfg.model.*`` (architecture),
            ``cfg.training.kl_weight`` and ``cfg.optim.*`` (optimizer/schedule).
    """

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): EsmModel, nn, RepeatedModule3, MultiHeadAttentionSequence
        # and FFN are not imported explicitly in this file; presumably they come
        # from `from models import *` — confirm.
        # Frozen target-protein encoder.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        for param in self.esm_model.parameters():
            param.requires_grad = False
        # Frozen binder encoder; keeps only the `.roformer` sub-module
        # (presumably the backbone without the MLM head).
        self.peptideclm_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all', trust_remote_code=True
        ).roformer
        for param in self.peptideclm_model.parameters():
            param.requires_grad = False
        self.repeated_module = RepeatedModule3(
            cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
            cfg.model.n_head, cfg.model.d_k, cfg.model.d_v, cfg.model.d_inner,
            dropout=cfg.model.dropout,
        )
        self.final_attention_layer = MultiHeadAttentionSequence(
            cfg.model.n_head, cfg.model.d_model, cfg.model.d_k, cfg.model.d_v,
            dropout=cfg.model.dropout,
        )
        self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)
        self.output_projection_prot = nn.Linear(cfg.model.d_model, 1)
        # [binding-site weight, non-binding-site weight]; index 0 multiplies the
        # positive labels in _compute_losses.
        self.class_weights = torch.tensor([7.803861120886587, 0.5342284711965681])
        self.cfg = cfg
        # Total optimizer steps for the LR schedule; None until the trainer
        # can provide an estimate.
        self._total_steps = None

    def forward(self, binder_tokens, target_tokens):
        """Return binding logits for each position of the target encoding.

        Args:
            binder_tokens: Keyword inputs (``input_ids``, ``attention_mask``)
                for the peptide encoder.
            target_tokens: Keyword inputs for the ESM-2 protein encoder.

        Returns:
            Output of ``output_projection_prot``, whose last dimension is 1
            (callers squeeze it).
        """
        peptide_sequence = self.peptideclm_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state
        # The repeated module returns six values: the protein and sequence
        # encodings followed by four attention-map lists, which are unused here.
        prot_enc, sequence_enc, _, _, _, _ = self.repeated_module(peptide_sequence, protein_sequence)
        # presumably (query, key, value): the target encoding attends over the binder.
        prot_enc, _ = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        return self.output_projection_prot(prot_enc)

    def _batch_to_inputs(self, batch):
        """Convert a raw batch dict into (target_tokens, binder_tokens, labels, mask) on this device."""
        def on_device(key):
            return torch.as_tensor(batch[key]).to(self.device)

        target_tokens = {'input_ids': on_device('target_input_ids'),
                         'attention_mask': on_device('target_attention_mask')}
        binder_tokens = {'input_ids': on_device('binder_input_ids'),
                         'attention_mask': on_device('binder_attention_mask')}
        binding_site = on_device('labels').float()
        mask = on_device('labels_mask')
        return target_tokens, binder_tokens, binding_site, mask

    def _compute_losses(self, outputs_nodes, binding_site, mask):
        """Return (total_loss, bce_loss, kl_loss), each averaged over unmasked positions."""
        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        bce_loss = F.binary_cross_entropy_with_logits(outputs_nodes, binding_site, weight=weight, reduction='none')
        # clamp_min(1): a batch with no valid positions contributes 0, not NaN.
        mean_bce_loss = (bce_loss * mask).sum() / mask.sum().clamp_min(1)
        kl_loss = self.compute_kl_loss(outputs_nodes, binding_site, mask)
        mean_loss = mean_bce_loss + self.cfg.training.kl_weight * kl_loss
        return mean_loss, mean_bce_loss, kl_loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        target_tokens, binder_tokens, binding_site, mask = self._batch_to_inputs(batch)
        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
        mean_loss, mean_bce_loss, kl_loss = self._compute_losses(outputs_nodes, binding_site, mask)
        log_kwargs = dict(on_step=True, on_epoch=True, logger=True, sync_dist=True,
                          batch_size=binding_site.shape[0])
        self.log('bce_loss', mean_bce_loss, **log_kwargs)
        self.log('kl_loss', kl_loss, **log_kwargs)
        self.log('train_loss', mean_loss, prog_bar=True, **log_kwargs)
        return mean_loss

    def validation_step(self, batch, batch_idx):
        """Log validation losses and per-batch accuracy / AUC / F1 / MCC over unmasked positions."""
        target_tokens, binder_tokens, binding_site, mask = self._batch_to_inputs(batch)
        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)
        mean_loss, mean_bce_loss, kl_loss = self._compute_losses(outputs_nodes, binding_site, mask)

        # Hard predictions at a 0.5 probability threshold, scored on unmasked positions only.
        sigmoid_outputs = torch.sigmoid(outputs_nodes)
        predict = (sigmoid_outputs >= 0.5).float()
        accuracy = ((predict == binding_site) * mask).sum() / mask.sum().clamp_min(1)

        valid = mask.bool()
        # .float() before .numpy(): numpy cannot hold bf16 tensors.
        probs_flat = sigmoid_outputs[valid].float().cpu().detach().numpy().flatten()
        labels_flat = binding_site[valid].float().cpu().detach().numpy().flatten()
        preds_flat = predict[valid].float().cpu().detach().numpy().flatten()
        # roc_auc_score raises ValueError when y_true has a single class. Log NaN
        # instead of skipping: with sync_dist=True every rank must log the same keys.
        if np.unique(labels_flat).size < 2:
            auc = float('nan')
        else:
            auc = roc_auc_score(labels_flat, probs_flat)
        f1 = f1_score(labels_flat, preds_flat)
        mcc = matthews_corrcoef(labels_flat, preds_flat)

        metrics = {
            'val_loss': mean_loss,
            'val_kl_loss': kl_loss,
            'val_bce_loss': mean_bce_loss,
            'val_accuracy': accuracy,
            'val_auc': auc,
            'val_f1': f1,
            'val_mcc': mcc,
        }
        batch_size = binding_site.shape[0]
        for name, value in metrics.items():
            self.log(name, value, on_step=True, on_epoch=True, prog_bar=True, logger=True,
                     sync_dist=True, batch_size=batch_size)

    def compute_kl_loss(self, outputs, targets, mask):
        """Masked ``F.kl_div`` between softmax(outputs, last dim) and ``targets`` as given.

        ``F.kl_div`` contributes 0 wherever the target is 0. With binary targets
        this therefore equals the sum of ``-log_softmax(outputs)`` over positive,
        unmasked positions divided by the number of unmasked positions.

        NOTE(review): the softmax also normalises over masked (padding)
        positions, and ``targets`` are not normalised to a distribution; confirm
        this is the intended objective before relying on it as a true KL.
        """
        log_probs = F.log_softmax(outputs, dim=-1)
        target_probs = targets.float()
        kl_loss = F.kl_div(log_probs, target_probs, reduction='none')
        return (kl_loss * mask).sum() / mask.sum().clamp_min(1)

    def _estimate_total_steps(self):
        """Return the trainer's total optimizer-step estimate, or None if unavailable or unbounded."""
        try:
            total = self.trainer.estimated_stepping_batches
        except (RuntimeError, AttributeError):
            # Not attached to a Trainer.
            return None
        if total is None or not math.isfinite(total):
            return None
        return int(total)

    def configure_optimizers(self):
        """AdamW with a per-step LambdaLR schedule.

        The schedule ramps linearly from 0.1x to 1.0x ``cfg.optim.lr`` over the
        first ``warmup_ratio`` (default 0.1) of the total optimizer steps. It then
        follows a cosine from 1.0x down to 0.1x.
        """
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=float(self.cfg.optim.lr),
            betas=(self.cfg.optim.beta1, self.cfg.optim.beta2),
            eps=float(self.cfg.optim.eps),
            weight_decay=self.cfg.optim.weight_decay,
            fused=self.cfg.optim.fused,
        )
        warmup_ratio = getattr(self.cfg.optim, "warmup_ratio", 0.1)
        min_scale = 0.1
        # Resolve the step budget now, not only in on_train_start. LambdaLR
        # evaluates lr_lambda(0) at construction, and that factor sets the LR
        # used by the first optimizer step.
        if self._total_steps is None:
            self._total_steps = self._estimate_total_steps()

        def lr_lambda(global_step: int):
            # Without a known step budget the schedule is a constant 1.0.
            if not self._total_steps:
                return 1.0
            total_steps = self._total_steps
            warmup_steps = max(1, int(warmup_ratio * total_steps))
            if global_step < warmup_steps:
                # Linear warmup: 0.1 -> 1.0.
                alpha = (global_step + 1) / warmup_steps
                return 0.1 + 0.9 * alpha
            # Cosine from 1.0 down to min_scale, held there past the end.
            progress = min(1.0, (global_step - warmup_steps) / max(1, total_steps - warmup_steps))
            cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1 -> 0
            return min_scale + (1.0 - min_scale) * cosine

        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "interval": "step",  # stepped after every optimizer step
                "frequency": 1,
            },
        }

    def on_train_start(self):
        # Refresh the number of optimizer steps for this fit.
        self._total_steps = self._estimate_total_steps()
def main():
    """Train BindEvaluator from a YAML config and print the best checkpoint path.

    The config path defaults to ``./config.yaml`` and can be overridden with
    ``--config``. Dataset locations can be set with ``cfg.data.train_path`` and
    ``cfg.data.valid_path``; otherwise the original hard-coded paths are used.
    Early stopping on ``val_mcc`` is opt-in via ``cfg.training.early_stopping``
    (patience from ``cfg.training.early_stopping_patience``, default 5).
    """
    default_train_path = '/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset/train'
    default_valid_path = '/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset/valid'

    parser = ArgumentParser(description="Train BindEvaluator")
    parser.add_argument('--config', default='./config.yaml', help='Path to the YAML training config.')
    # parse_known_args: stay tolerant of extra arguments, as the original ignored argv.
    args, _ = parser.parse_known_args()
    with open(args.config, 'r', encoding='utf-8') as f:
        cfg = edict(yaml.safe_load(f))

    run_name = f"lr={cfg.optim.lr}_nlayers={cfg.model.n_layers}_dmodel={cfg.model.d_model}_nhead={cfg.model.n_head}_dinner={cfg.model.d_inner}"
    workdir = os.path.join(cfg.work_dir, run_name)

    data_cfg = cfg.get('data') or {}
    train_dataset = load_from_disk(data_cfg.get('train_path', default_train_path))
    valid_dataset = load_from_disk(data_cfg.get('valid_path', default_valid_path))
    # batch_size=None: the datasets are assumed to be pre-batched (one item per batch) — TODO confirm.
    train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=True, num_workers=4)
    valid_dataloader = DataLoader(valid_dataset, batch_size=None, shuffle=False, num_workers=4)

    model = BindEvaluator(cfg)
    wandb_logger = WandbLogger(
        project='SMILES_BindEvaluator',
        name=run_name,
        entity='programmablebio',
    )
    checkpoint_callback = ModelCheckpoint(
        monitor='val_mcc',
        dirpath=workdir,
        filename='model-{epoch:02d}-{val_mcc:.2f}',
        save_top_k=3,
        mode='max',
    )
    lrmon = LearningRateMonitor(logging_interval="step")
    callbacks = [checkpoint_callback, lrmon]

    # Early stopping used to be constructed but never passed to the Trainer;
    # it is now opt-in so default runs behave exactly as before.
    training_cfg = cfg.get('training') or {}
    if training_cfg.get('early_stopping', False):
        callbacks.append(EarlyStopping(
            monitor='val_mcc',
            patience=training_cfg.get('early_stopping_patience', 5),
            verbose=True,
            mode='max',
        ))

    trainer = pl.Trainer(
        default_root_dir=workdir,
        max_epochs=cfg.optim.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16-mixed',
        devices=cfg.compute.ngpus,
        callbacks=callbacks,
        gradient_clip_val=cfg.optim.grad_clip,
        log_every_n_steps=10,
        logger=wandb_logger,
    )
    trainer.fit(model, train_dataloader, valid_dataloader)
    print(checkpoint_callback.best_model_path)


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
13.2 kB
·
Xet hash:
fabdc747882dcb4acd06b048e1002ac2a5b7206b0bed69c0696a6a41bdebe60d

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.