| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from datasets import load_from_disk | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \ | |
| Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler | |
| from pytorch_lightning.loggers import WandbLogger | |
| from torch.optim.lr_scheduler import _LRScheduler | |
| from argparse import ArgumentParser | |
| import os | |
| import math | |
| import numpy as np | |
| import torch.distributed as dist | |
| from transformers import AutoTokenizer, AutoModelForMaskedLM | |
| from torch.optim import AdamW | |
| from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef | |
| import yaml | |
| from easydict import EasyDict as edict | |
| from models import * | |
| from transformers.utils import logging | |
| logging.set_verbosity_error() | |
def compute_class_weights(targets, mask=None):
    """Return balanced class weights ``[non_binding_weight, binding_weight]``.

    Each weight is ``total / (2 * class_count)``, so the rarer class receives
    the larger weight.

    Args:
        targets: tensor (or array-like) of binary labels, 1 = binding site.
        mask: optional tensor broadcast-compatible with ``targets`` for boolean
            indexing; positions where it is 0/False (e.g. padding) are left out
            of the counts. ``None`` (default) counts every element.

    Returns:
        Float tensor of shape (2,), ordered ``[non-binding, binding]``.
        NOTE: this is the reverse of the ``[binding, non-binding]`` order used
        by the hard-coded ``BindEvaluator.class_weights``.

    A class with zero occurrences is counted as 1 so its weight stays finite
    instead of becoming ``inf``, which would turn a weighted loss into NaN.
    Empty input therefore yields ``[0., 0.]`` rather than NaN.
    """
    targets = torch.as_tensor(targets)
    if mask is not None:
        targets = targets[torch.as_tensor(mask).bool()]
    total = targets.numel()
    num_binding_sites = targets.sum().item()
    num_non_binding_sites = total - num_binding_sites
    # `or 1` only replaces an exact zero count, guarding the division.
    weight_for_binding = total / (2 * (num_binding_sites or 1))
    weight_for_non_binding = total / (2 * (num_non_binding_sites or 1))
    return torch.tensor([weight_for_non_binding, weight_for_binding])
class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warmup followed by cosine annealing, in absolute learning rates.

    Scheduler steps ``[0, warmup_steps)`` ramp linearly from ``base_lr`` to
    ``max_lr``. Steps ``[warmup_steps, total_steps]`` follow a half cosine
    from ``max_lr`` down to ``min_lr``. After ``total_steps`` the rate stays
    at ``min_lr``. Every param group receives the same rate; the optimizer's
    own per-group ``lr`` values are not used.

    Args:
        optimizer: wrapped optimizer.
        warmup_steps: number of linear-warmup scheduler steps.
        total_steps: scheduler step at which annealing reaches ``min_lr``.
        base_lr: learning rate at step 0.
        max_lr: learning rate at the end of warmup.
        min_lr: final learning rate after annealing.
        last_epoch: index of the last step, as in ``torch.optim.lr_scheduler``.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        # Set before super().__init__(): the base class takes an initial step,
        # which calls get_lr() and needs these attributes.
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, once per param group."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup from base_lr to max_lr.
            lr = self.base_lr + (self.max_lr - self.base_lr) * (step / self.warmup_steps)
        else:
            # max(1, ...) avoids ZeroDivisionError when total_steps <= warmup_steps.
            # Clamping progress to 1 keeps the rate at min_lr past total_steps
            # instead of letting the cosine climb back up.
            span = max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, (step - self.warmup_steps) / span)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr] * len(self.base_lrs)
class BindEvaluator(pl.LightningModule):
    """Per-position binding-site classifier for a (binder, target protein) pair.

    The binder is embedded by the frozen ``.roformer`` backbone of PeptideCLM
    and the target by a frozen ESM-2 model. ``RepeatedModule3``, a final
    ``MultiHeadAttentionSequence`` layer and an ``FFN`` fuse the two, and a
    linear head emits one logit per position of the fused target
    representation.

    Training loss = class-weighted masked BCE
    + ``cfg.training.kl_weight`` * a masked listwise term (see ``compute_kl_loss``).

    Expected batch keys: ``target_input_ids``, ``target_attention_mask``,
    ``binder_input_ids``, ``binder_attention_mask``, ``labels`` and
    ``labels_mask``.

    Args:
        cfg: EasyDict config; reads ``cfg.model.*``, ``cfg.optim.*`` and
            ``cfg.training.kl_weight``.
    """

    def __init__(self, cfg):
        super().__init__()
        # NOTE(review): EsmModel, nn, RepeatedModule3, MultiHeadAttentionSequence
        # and FFN are not imported explicitly in this file; presumably they come
        # from `from models import *` -- confirm.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
        self.esm_model.requires_grad_(False)
        self.peptideclm_model = AutoModelForMaskedLM.from_pretrained(
            'aaronfeller/PeptideCLM-23M-all', trust_remote_code=True).roformer
        self.peptideclm_model.requires_grad_(False)
        # NOTE(review): the frozen encoders follow whatever train/eval mode is
        # set on this whole module, so any dropout inside them is active while
        # training -- confirm that is intended.

        model_cfg = cfg.model
        self.repeated_module = RepeatedModule3(
            model_cfg.n_layers, model_cfg.d_model, model_cfg.d_hidden, model_cfg.n_head,
            model_cfg.d_k, model_cfg.d_v, model_cfg.d_inner, dropout=model_cfg.dropout)
        self.final_attention_layer = MultiHeadAttentionSequence(
            model_cfg.n_head, model_cfg.d_model, model_cfg.d_k, model_cfg.d_v,
            dropout=model_cfg.dropout)
        self.final_ffn = FFN(model_cfg.d_model, model_cfg.d_inner, dropout=model_cfg.dropout)
        self.output_projection_prot = nn.Linear(model_cfg.d_model, 1)

        # [binding-site weight, non-binding-site weight]. A non-persistent buffer
        # moves with the module across devices without changing the checkpoint
        # state_dict.
        self.register_buffer(
            "class_weights",
            torch.tensor([7.803861120886587, 0.5342284711965681]),
            persistent=False,
        )
        self.cfg = cfg
        self._total_steps = None  # optimizer steps in this fit; drives the LR schedule
        self._val_probs = []  # per-step masked validation probabilities (CPU)
        self._val_labels = []  # per-step masked validation labels (CPU)

    def forward(self, binder_tokens, target_tokens):
        """Return binding logits with a trailing dimension of size 1.

        Args:
            binder_tokens: kwargs for the PeptideCLM backbone (input_ids, attention_mask).
            target_tokens: kwargs for the ESM model (input_ids, attention_mask).
        """
        peptide_embeddings = self.peptideclm_model(**binder_tokens).last_hidden_state
        protein_embeddings = self.esm_model(**target_tokens).last_hidden_state
        # RepeatedModule3 returns the two encodings followed by four attention
        # lists; only the encodings are used here.
        prot_enc, sequence_enc, _, _, _, _ = self.repeated_module(peptide_embeddings, protein_embeddings)
        prot_enc, _ = self.final_attention_layer(prot_enc, sequence_enc, sequence_enc)
        prot_enc = self.final_ffn(prot_enc)
        return self.output_projection_prot(prot_enc)

    def _prepare_batch(self, batch):
        """Convert a batch to tensors on this module's device.

        Returns:
            ``(binder_tokens, target_tokens, labels, mask)``. ``labels`` and
            ``mask`` are float tensors.
        """
        def on_device(key):
            # as_tensor accepts lists, arrays or tensors without a copy warning.
            return torch.as_tensor(batch[key], device=self.device)

        target_tokens = {'input_ids': on_device('target_input_ids'),
                         'attention_mask': on_device('target_attention_mask')}
        binder_tokens = {'input_ids': on_device('binder_input_ids'),
                         'attention_mask': on_device('binder_attention_mask')}
        labels = on_device('labels').float()
        mask = on_device('labels_mask').float()
        return binder_tokens, target_tokens, labels, mask

    def _compute_losses(self, logits, labels, mask):
        """Return ``(total_loss, bce_loss, kl_loss)`` for per-position logits."""
        weight = self.class_weights[0] * labels + self.class_weights[1] * (1 - labels)
        bce = F.binary_cross_entropy_with_logits(logits, labels, weight=weight, reduction='none')
        # clamp_min(1): a batch with no labelled positions gives 0, not NaN.
        bce_loss = (bce * mask).sum() / mask.sum().clamp_min(1)
        kl_loss = self.compute_kl_loss(logits, labels, mask)
        total_loss = bce_loss + self.cfg.training.kl_weight * kl_loss
        return total_loss, bce_loss, kl_loss

    def training_step(self, batch, batch_idx):
        """Compute, log and return the training loss for one batch."""
        binder_tokens, target_tokens, labels, mask = self._prepare_batch(batch)
        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        loss, bce_loss, kl_loss = self._compute_losses(logits, labels, mask)

        log_kwargs = dict(on_step=True, on_epoch=True, logger=True, sync_dist=True,
                          batch_size=labels.shape[0])
        self.log('bce_loss', bce_loss, **log_kwargs)
        self.log('kl_loss', kl_loss, **log_kwargs)
        self.log('train_loss', loss, prog_bar=True, **log_kwargs)
        return loss

    def on_validation_epoch_start(self):
        """Reset the buffers holding this epoch's validation outputs."""
        self._val_probs = []
        self._val_labels = []

    def validation_step(self, batch, batch_idx):
        """Log per-step validation losses and accuracy; buffer outputs for epoch metrics."""
        binder_tokens, target_tokens, labels, mask = self._prepare_batch(batch)
        logits = self.forward(binder_tokens, target_tokens).squeeze(-1)
        loss, bce_loss, kl_loss = self._compute_losses(logits, labels, mask)

        probs = torch.sigmoid(logits)
        predict = (probs >= 0.5).float()
        accuracy = ((predict == labels) * mask).sum() / mask.sum().clamp_min(1)

        # AUC/F1/MCC are computed once per epoch over all positions (see
        # on_validation_epoch_end). Per-batch values are biased when averaged,
        # and roc_auc_score raises on single-class batches.
        valid = mask.bool()
        self._val_probs.append(probs[valid].detach().float().cpu())
        self._val_labels.append(labels[valid].detach().float().cpu())

        log_kwargs = dict(on_step=True, on_epoch=True, prog_bar=True, logger=True,
                          sync_dist=True, batch_size=labels.shape[0])
        self.log('val_loss', loss, **log_kwargs)
        self.log('val_kl_loss', kl_loss, **log_kwargs)
        self.log('val_bce_loss', bce_loss, **log_kwargs)
        self.log('val_accuracy', accuracy, **log_kwargs)

    def _gather_validation_outputs(self):
        """Concatenate buffered validation outputs across steps and DDP ranks.

        Returns:
            ``(probs, labels)`` as 1-D numpy arrays, identical on every rank.
        """
        local = (
            torch.cat(self._val_probs).numpy() if self._val_probs else np.empty(0, dtype=np.float32),
            torch.cat(self._val_labels).numpy() if self._val_labels else np.empty(0, dtype=np.float32),
        )
        if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
            # all_gather_object handles per-rank arrays of different lengths.
            gathered = [None] * dist.get_world_size()
            dist.all_gather_object(gathered, local)
        else:
            gathered = [local]
        probs = np.concatenate([rank_probs for rank_probs, _ in gathered])
        labels = np.concatenate([rank_labels for _, rank_labels in gathered])
        return probs, labels

    def on_validation_epoch_end(self):
        """Log epoch-level val_auc / val_f1 / val_mcc over all labelled positions."""
        probs, labels = self._gather_validation_outputs()
        self._val_probs, self._val_labels = [], []
        if labels.size == 0:
            return
        predictions = (probs >= 0.5).astype(np.float32)
        # Values are already global (gathered), so no further syncing is needed.
        metric_kwargs = dict(prog_bar=True, logger=True, sync_dist=False)
        if np.unique(labels).size > 1:  # AUC is undefined for a single class
            self.log('val_auc', float(roc_auc_score(labels, probs)), **metric_kwargs)
        self.log('val_f1', float(f1_score(labels, predictions, zero_division=0)), **metric_kwargs)
        self.log('val_mcc', float(matthews_corrcoef(labels, predictions)), **metric_kwargs)

    def compute_kl_loss(self, outputs, targets, mask):
        """Masked listwise loss between a softmax over positions and the labels.

        The softmax runs over the last dimension of ``outputs`` and is
        restricted to positions where ``mask`` is nonzero, so padding no
        longer absorbs probability mass. With 0/1 targets, ``kl_div`` reduces
        pointwise to ``-log p_i`` at positions with target 1 and 0 elsewhere.
        The result is that sum divided by the number of valid positions.
        """
        valid = mask.bool()
        # A finite floor (not -inf) avoids 0 * inf = NaN inside kl_div at
        # masked positions.
        masked_outputs = outputs.masked_fill(~valid, torch.finfo(outputs.dtype).min)
        log_probs = F.log_softmax(masked_outputs, dim=-1)
        target_probs = targets.float().masked_fill(~valid, 0.0)
        kl_loss = F.kl_div(log_probs, target_probs, reduction='none')
        return (kl_loss * mask).sum() / mask.sum().clamp_min(1)

    def _estimate_total_steps(self):
        """Return the trainer's estimated optimizer-step count, or None if unknown or unbounded."""
        try:
            total = self.trainer.estimated_stepping_batches
        except RuntimeError:  # module not attached to a Trainer
            return None
        if total is None or not math.isfinite(total) or total <= 0:
            return None
        return int(total)

    def configure_optimizers(self):
        """AdamW with a per-step schedule: linear warmup (0.1 -> 1.0), then cosine down to 0.1 of the base LR."""
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=float(self.cfg.optim.lr),
            betas=(self.cfg.optim.beta1, self.cfg.optim.beta2),
            eps=float(self.cfg.optim.eps),
            weight_decay=self.cfg.optim.weight_decay,
            fused=self.cfg.optim.fused,
        )
        warmup_ratio = getattr(self.cfg.optim, "warmup_ratio", 0.1)
        min_scale = 0.1
        # Known before LambdaLR evaluates step 0, so the first optimizer step
        # is already warmed up rather than running at the full LR.
        self._total_steps = self._estimate_total_steps()

        def lr_lambda(global_step: int):
            # Without a known step count the schedule is a constant 1.0.
            if self._total_steps is None or self._total_steps == 0:
                return 1.0
            total_steps = self._total_steps
            warmup_steps = max(1, int(warmup_ratio * total_steps))
            if global_step < warmup_steps:
                # linear warmup: 0.1 -> 1.0
                alpha = (global_step + 1) / warmup_steps
                return 0.1 + 0.9 * alpha
            # cosine from 1.0 down to min_scale
            progress = (global_step - warmup_steps) / max(1, total_steps - warmup_steps)
            cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1 -> 0
            return min_scale + (1.0 - min_scale) * cosine

        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)
        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "interval": "step",  # stepped after every optimizer step
                "frequency": 1,
            },
        }

    def on_train_start(self):
        """Refresh the optimizer-step count used by the LR schedule."""
        total = self._estimate_total_steps()
        if total is not None:
            self._total_steps = total
def main(config_path='./config.yaml'):
    """Train a BindEvaluator from a YAML config and print the best checkpoint path.

    Args:
        config_path: path of the YAML config. Defaults to ``./config.yaml``.

    Optional config keys ``data.train_dir`` and ``data.valid_dir`` override
    the default on-disk locations of the tokenized datasets.
    """
    default_train_dir = '/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset/train'
    default_valid_dir = '/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset/valid'

    with open(config_path, 'r', encoding='utf-8') as f:
        config_dict = yaml.safe_load(f)
    cfg = edict(config_dict)

    run_name = f"lr={cfg.optim.lr}_nlayers={cfg.model.n_layers}_dmodel={cfg.model.d_model}_nhead={cfg.model.n_head}_dinner={cfg.model.d_inner}"
    workdir = os.path.join(cfg.work_dir, run_name)

    data_cfg = cfg.get('data') or {}
    train_dataset = load_from_disk(data_cfg.get('train_dir', default_train_dir))
    valid_dataset = load_from_disk(data_cfg.get('valid_dir', default_valid_dir))
    # batch_size=None: each dataset row is yielded as-is (no extra batching).
    train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=True, num_workers=4)
    valid_dataloader = DataLoader(valid_dataset, batch_size=None, shuffle=False, num_workers=4)

    model = BindEvaluator(cfg)

    wandb_logger = WandbLogger(
        project='SMILES_BindEvaluator',
        name=run_name,
        entity='programmablebio',
    )
    checkpoint_callback = ModelCheckpoint(
        monitor='val_mcc',
        dirpath=workdir,
        filename='model-{epoch:02d}-{val_mcc:.2f}',
        save_top_k=3,
        mode='max',
    )
    lrmon = LearningRateMonitor(logging_interval="step")
    # To enable early stopping, add
    # EarlyStopping(monitor='val_mcc', patience=5, verbose=True, mode='max')
    # to `callbacks` below.
    trainer = pl.Trainer(
        default_root_dir=workdir,
        max_epochs=cfg.optim.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16-mixed',
        devices=cfg.compute.ngpus,
        callbacks=[checkpoint_callback, lrmon],
        gradient_clip_val=cfg.optim.grad_clip,
        log_every_n_steps=10,
        logger=wandb_logger,
    )
    trainer.fit(model, train_dataloader, valid_dataloader)
    best_model_path = checkpoint_callback.best_model_path
    print(best_model_path)


if __name__ == "__main__":
    main()
# Xet Storage Details
# - Size: 13.2 kB
# - Xet hash: fabdc747882dcb4acd06b048e1002ac2a5b7206b0bed69c0696a6a41bdebe60d
#
# Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.