| import torch | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| from datasets import load_from_disk | |
| import pytorch_lightning as pl | |
| from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \ | |
| Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler | |
| from pytorch_lightning.loggers import WandbLogger | |
| from torch.optim.lr_scheduler import _LRScheduler | |
| from argparse import ArgumentParser | |
| import os | |
| import math | |
| import numpy as np | |
| import torch.distributed as dist | |
| from transformers import AutoModelWithLMHead, EsmModel | |
| from torch.optim import AdamW | |
| from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef | |
| import yaml | |
| from easydict import EasyDict as edict | |
| from models import * | |
| from transformers.utils import logging | |
| logging.set_verbosity_error() | |
def best_mcc_threshold(y_true: np.ndarray, y_prob: np.ndarray):
    """Pick the decision threshold that maximizes the Matthews correlation.

    Predictions follow the rule ``pred = prob >= t``. Samples are ranked by
    descending probability and the confusion counts are read off cumulative
    sums, so only the distinct probability values are tried as thresholds.

    Args:
        y_true: Binary ground-truth labels (cast to int32).
        y_prob: Predicted probabilities (cast to float64).

    Returns:
        ``(best_mcc, best_threshold)`` as Python floats. When the input is
        empty or contains a single class, ``(0.0, 0.5)`` is returned.
    """
    labels = np.asarray(y_true).astype(np.int32)
    scores = np.asarray(y_prob).astype(np.float64)

    # Degenerate inputs: MCC is undefined, fall back to a neutral answer.
    if labels.size == 0:
        return 0.0, 0.5
    n_pos = int(labels.sum())
    n_neg = int(labels.size - n_pos)
    if n_pos == 0 or n_neg == 0:
        return 0.0, 0.5

    # Rank samples from most to least confident.
    ranking = np.argsort(-scores)
    ranked_labels = labels[ranking]
    ranked_scores = scores[ranking]

    # After accepting the top-k samples: k = TP + FP.
    cum_tp = np.cumsum(ranked_labels)
    cum_fp = np.arange(1, ranked_labels.size + 1) - cum_tp

    # A cut is only meaningful at the last element of each run of equal
    # scores, because ``>=`` accepts every tied sample at once.
    run_ends = np.append(ranked_scores[1:] != ranked_scores[:-1], True)
    cuts = np.flatnonzero(run_ends)

    tp = cum_tp[cuts].astype(np.float64)
    fp = cum_fp[cuts].astype(np.float64)
    fn = (n_pos - tp).astype(np.float64)
    tn = (n_neg - fp).astype(np.float64)

    numer = tp * tn - fp * fn
    denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

    # Cuts with a zero denominator (e.g. everything predicted positive)
    # keep an MCC of 0.
    mcc = np.zeros_like(tp, dtype=np.float64)
    defined = denom > 0
    mcc[defined] = numer[defined] / denom[defined]

    winner = int(np.argmax(mcc))
    return float(mcc[winner]), float(ranked_scores[cuts[winner]])
class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warmup followed by cosine decay, applied to every param group.

    For ``last_epoch < warmup_steps`` the learning rate rises linearly from
    ``base_lr`` to ``max_lr``. Afterwards it follows half a cosine period
    from ``max_lr`` down to ``min_lr``, reaching ``min_lr`` at
    ``total_steps`` and staying there for any later step.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of scheduler steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        base_lr: Learning rate at step 0 of the warmup.
        max_lr: Learning rate at the end of warmup / start of decay.
        min_lr: Learning rate at (and after) ``total_steps``.
        last_epoch: Index of the last step, ``-1`` to start fresh.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        # Attributes must exist before the parent constructor runs, because
        # it performs an initial step() that calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, one entry per param group."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr.
            lr = self.base_lr + (self.max_lr - self.base_lr) * (step / self.warmup_steps)
        else:
            # Cosine annealing phase from max_lr to min_lr.
            # max(1, ...) avoids a ZeroDivisionError when the decay span is
            # empty; clamping keeps the LR at min_lr past total_steps instead
            # of letting the cosine swing back up.
            decay_span = max(1, self.total_steps - self.warmup_steps)
            progress = min(1.0, max(0.0, (step - self.warmup_steps) / decay_span))
            cosine_decay = 0.5 * (1.0 + math.cos(math.pi * progress))
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr for _ in self.base_lrs]
class BindEvaluator(pl.LightningModule):
    """Per-position binding-site classifier over a binder/target token pair.

    Two frozen pretrained encoders (ChemBERTa for the binder tokens, ESM-2
    for the target tokens) produce embeddings that are mixed by a trainable
    ``RepeatedModule`` stack, a final cross-attention layer and an FFN. A
    linear head emits one logit per target position.

    Training minimizes a class-weighted, mask-averaged BCE plus
    ``cfg.training.kl_weight`` times a KL term (see ``compute_kl_loss``).
    After every validation epoch the MCC-optimal threshold over all cached
    validation predictions is stored in ``self.val_threshold``.

    Args:
        cfg: Attribute-style config. This class reads ``cfg.model.*``
            (n_layers, d_model, d_hidden, n_head, d_k, d_v, d_inner,
            dropout), ``cfg.training.kl_weight`` and ``cfg.optim.*``
            (lr, beta1, beta2, eps, weight_decay, fused, optional
            warmup_ratio).
    """

    def __init__(self, cfg):
        super().__init__()

        # Frozen pretrained encoders: no gradients, kept in eval mode.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
        for param in self.esm_model.parameters():
            param.requires_grad = False

        self.chemberta_model = AutoModelWithLMHead.from_pretrained(
            "seyonec/ChemBERTa_zinc250k_v2_40k"
        ).roberta.eval()
        for param in self.chemberta_model.parameters():
            param.requires_grad = False

        # Trainable interaction stack.
        self.repeated_module = RepeatedModule(
            cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
            cfg.model.n_head, cfg.model.d_k, cfg.model.d_v, cfg.model.d_inner,
            dropout=cfg.model.dropout,
        )
        self.final_attention_layer = MultiHeadAttentionSequence(
            cfg.model.n_head, cfg.model.d_model,
            cfg.model.d_k, cfg.model.d_v, dropout=cfg.model.dropout,
        )
        self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)

        # Referenced through ``torch.nn`` so the class does not depend on a
        # bare ``nn`` name leaking in through a wildcard import.
        self.output_projection_prot = torch.nn.Linear(cfg.model.d_model, 1)

        # [binding-site weight, non-binding-site weight].
        # Registered as a non-persistent buffer: it follows the module's
        # device but is not written to (or expected in) checkpoints.
        self.register_buffer(
            "class_weights",
            torch.tensor([7.803861120886587, 0.5342284711965681]),
            persistent=False,
        )

        self.cfg = cfg
        # Filled in on_train_start; the LR lambda returns 1.0 until then.
        self._total_steps = None

        # Threshold selected each validation epoch to maximize MCC.
        self.val_threshold = 0.5
        self._val_probs = []
        self._val_labels = []

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen encoders in eval.

        A plain ``module.train()`` would put the frozen encoders back into
        training mode and re-enable any dropout they contain.
        """
        super().train(mode)
        self.esm_model.eval()
        self.chemberta_model.eval()
        return self

    def forward(self, binder_tokens, target_tokens):
        """Compute one logit per target position.

        Args:
            binder_tokens: Dict with ``input_ids`` and ``attention_mask``
                for the ChemBERTa encoder; mask is [B, Ls].
            target_tokens: Dict with ``input_ids`` and ``attention_mask``
                for the ESM encoder; mask is [B, Lp].

        Returns:
            Output of the final linear head (last dimension of size 1);
            callers squeeze the last dimension to obtain [B, Lp] logits.
        """
        peptide_sequence = self.chemberta_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        binder_mask = binder_tokens["attention_mask"]  # [B, Ls]
        target_mask = target_tokens["attention_mask"]  # [B, Lp]

        # The four attention lists returned by the stack are not used here.
        (prot_enc, sequence_enc, _sequence_attention_list, _prot_attention_list,
         _prot_seq_attention_list, _seq_prot_attention_list) = self.repeated_module(
            peptide_sequence,
            protein_sequence,
            peptide_mask=binder_mask,
            protein_mask=target_mask,
        )

        # Final cross-attention: protein queries attend to binder keys.
        prot_enc, _final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc,
            key_padding_mask=binder_mask,
            query_padding_mask=target_mask,
        )

        prot_enc = self.final_ffn(prot_enc, padding_mask=target_mask)
        return self.output_projection_prot(prot_enc)

    def _to_device(self, value):
        """Return ``value`` as a tensor on this module's device.

        ``torch.as_tensor`` accepts nested lists as well as tensors, and
        avoids the copy/warning that ``torch.tensor(tensor)`` triggers.
        """
        return torch.as_tensor(value, device=self.device)

    def _shared_step(self, batch):
        """Run the model on one batch and compute all loss terms.

        Args:
            batch: Mapping with ``target_input_ids``,
                ``target_attention_mask``, ``binder_input_ids``,
                ``binder_attention_mask``, ``labels`` and ``labels_mask``.

        Returns:
            Tuple ``(logits, labels, mask, mean_bce_loss, kl_loss, mean_loss)``.
        """
        target_tokens = {
            'input_ids': self._to_device(batch['target_input_ids']),
            'attention_mask': self._to_device(batch['target_attention_mask']),
        }
        binder_tokens = {
            'input_ids': self._to_device(batch['binder_input_ids']),
            'attention_mask': self._to_device(batch['binder_attention_mask']),
        }
        binding_site = self._to_device(batch['labels']).float()
        mask = self._to_device(batch['labels_mask'])

        outputs_nodes = self(binder_tokens, target_tokens).squeeze(-1)

        # Per-position class weight: positives get weights[0], negatives weights[1].
        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        bce_loss = F.binary_cross_entropy_with_logits(
            outputs_nodes, binding_site, weight=weight, reduction='none'
        )
        # Average over labelled positions only; clamp so that a batch with
        # no labelled position yields 0 instead of NaN.
        mean_bce_loss = (bce_loss * mask).sum() / mask.sum().clamp_min(1)

        kl_loss = self.compute_kl_loss(outputs_nodes, binding_site, mask)
        mean_loss = mean_bce_loss + self.cfg.training.kl_weight * kl_loss
        return outputs_nodes, binding_site, mask, mean_bce_loss, kl_loss, mean_loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch; returns the total loss."""
        _, binding_site, _, mean_bce_loss, kl_loss, mean_loss = self._shared_step(batch)

        batch_size = binding_site.shape[0]
        self.log('bce_loss', mean_bce_loss, on_step=False, on_epoch=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('kl_loss', kl_loss, on_step=False, on_epoch=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('train_loss', mean_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
        return mean_loss

    def on_validation_epoch_start(self):
        """Reset the per-epoch caches used for threshold optimization."""
        self._val_probs = []
        self._val_labels = []

    def validation_step(self, batch, batch_idx):
        """Log validation losses and cache masked probabilities and labels."""
        outputs_nodes, binding_site, mask, mean_bce_loss, kl_loss, mean_loss = self._shared_step(batch)

        # Cache on the CPU in float32: keeps accelerator memory flat over
        # the epoch and makes the tensors convertible to NumPy even when
        # the forward pass ran in reduced precision.
        keep = mask.bool()
        self._val_probs.append(torch.sigmoid(outputs_nodes)[keep].detach().float().cpu())
        self._val_labels.append(binding_site[keep].detach().float().cpu())

        batch_size = binding_site.shape[0]
        self.log('val_loss', mean_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('val_bce_loss', mean_bce_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)

    def on_validation_epoch_end(self):
        """Compute epoch-level metrics at the MCC-optimal threshold and log them.

        Updates ``self.val_threshold`` and logs ``val_auc``, ``val_f1``,
        ``val_mcc``, ``val_accuracy`` and ``val_mcc_threshold``.
        """
        if self._val_probs:
            probs_np = torch.cat(self._val_probs, dim=0).numpy()
            labels_np = torch.cat(self._val_labels, dim=0).numpy()
        else:
            probs_np = np.empty(0, dtype=np.float32)
            labels_np = np.empty(0, dtype=np.float32)

        # Gather variable-length arrays across ranks. Every rank must reach
        # this collective, including ranks that cached nothing, otherwise
        # the remaining ranks would block forever.
        if dist.is_available() and dist.is_initialized():
            gathered = [None] * dist.get_world_size()
            dist.all_gather_object(gathered, {"probs": probs_np, "labels": labels_np})
            probs_np = np.concatenate([g["probs"] for g in gathered], axis=0)
            labels_np = np.concatenate([g["labels"] for g in gathered], axis=0)

        # Release the caches; the gathered arrays are all that is needed.
        self._val_probs = []
        self._val_labels = []

        # After the gather this condition is identical on every rank.
        if probs_np.size == 0:
            return

        # Threshold-free metric (undefined if only one class is present).
        try:
            auc = float(roc_auc_score(labels_np, probs_np))
        except ValueError:
            auc = float("nan")

        # Find the MCC-optimal threshold on the whole validation set.
        _, best_thr = best_mcc_threshold(labels_np, probs_np)
        self.val_threshold = best_thr

        preds_np = (probs_np >= best_thr).astype(np.int32)
        labels_i = labels_np.astype(np.int32)

        f1 = float(f1_score(labels_i, preds_np, zero_division=0))
        mcc = float(matthews_corrcoef(labels_i, preds_np))
        acc = float((preds_np == labels_i).mean())

        self.log("val_auc", auc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_f1", f1, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_mcc", mcc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_mcc_threshold", best_thr, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)

    def compute_kl_loss(self, logits, targets, valid_mask, eps=1e-8, label_smoothing=0.0):
        """Return KL(q || p) averaged over samples with at least one positive.

        ``p`` is the softmax of ``logits`` over the valid positions of each
        row; ``q`` is the multi-hot target normalized to sum to one over the
        positive positions.

        Args:
            logits: [B, L] raw per-position outputs.
            targets: [B, L] 0/1 binding-site labels.
            valid_mask: [B, L], 1 for real residues and 0 for padding.
            eps: Added to the target mass before normalizing.
            label_smoothing: Mixing weight toward a uniform distribution
                over valid positions; 0 disables smoothing.

        Returns:
            Scalar tensor; 0 when no row in the batch has a positive label.
        """
        valid = valid_mask.bool()

        # 1) Predicted log-distribution over residues (pads excluded).
        masked_logits = logits.masked_fill(~valid, -1e9)
        log_p = F.log_softmax(masked_logits, dim=-1)  # [B, L]

        # 2) Target distribution q (pads excluded, normalized over positives).
        q_mass = targets.float() * valid.float()  # [B, L]
        q_sum = q_mass.sum(dim=-1, keepdim=True)  # [B, 1]
        has_pos = q_sum.squeeze(-1) > 0
        q = q_mass / (q_sum + eps)  # [B, L]

        # Optional: label smoothing toward uniform over valid residues.
        if label_smoothing > 0.0:
            uniform = valid.float()
            uniform = uniform / uniform.sum(dim=-1, keepdim=True).clamp_min(1.0)
            q = (1.0 - label_smoothing) * q + label_smoothing * uniform

        # 3) F.kl_div expects input=log-probs, target=probs.
        kl_per_sample = F.kl_div(log_p, q, reduction="none").sum(dim=-1)  # [B]

        # Rows without any positive have no target distribution; skip them.
        if has_pos.any():
            return kl_per_sample[has_pos].mean()
        return logits.new_tensor(0.0)

    def configure_optimizers(self):
        """Build AdamW and a per-step warmup + cosine ``LambdaLR`` schedule.

        The LR multiplier rises linearly to 1.0 during the first
        ``warmup_ratio`` fraction of the steps (default 0.1), then decays
        along a cosine to ``min_scale`` (0.1) at the last step.
        """
        # All parameters are passed so the param-group layout stays
        # unchanged; frozen ones never receive a gradient.
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=float(self.cfg.optim.lr),
            betas=(self.cfg.optim.beta1, self.cfg.optim.beta2),
            eps=float(self.cfg.optim.eps),
            weight_decay=self.cfg.optim.weight_decay,
            fused=self.cfg.optim.fused,
        )

        warmup_ratio = getattr(self.cfg.optim, "warmup_ratio", 0.1)
        min_scale = 0.1

        def lr_lambda(global_step: int):
            total_steps = self._total_steps
            # Until on_train_start runs (or if the step estimate is
            # unusable) keep the base learning rate.
            if not total_steps or not math.isfinite(total_steps):
                return 1.0

            warmup_steps = max(1, int(warmup_ratio * total_steps))
            if global_step < warmup_steps:
                # Linear warmup up to 1.0.
                alpha = (global_step + 1) / warmup_steps
                return 0.1 + 0.9 * alpha

            # Cosine from 1.0 down to min_scale; clamped so steps beyond
            # the estimate stay at min_scale.
            progress = (global_step - warmup_steps) / max(1, total_steps - warmup_steps)
            progress = min(1.0, progress)
            cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1 -> 0
            return min_scale + (1.0 - min_scale) * cosine

        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "interval": "step",  # scheduler is stepped every optimizer step
                "frequency": 1,
            },
        }

    def on_train_start(self):
        """Record how many optimizer steps this fit is expected to take."""
        self._total_steps = self.trainer.estimated_stepping_batches
| # def on_training_epoch_end(self, outputs): | |
| # gc.collect() | |
| # torch.cuda.empty_cache() | |
| # super().training_epoch_end(outputs) | |
| # def on_validation_epoch_end(self, outputs): | |
| # gc.collect() | |
| # torch.cuda.empty_cache() | |
| # super().validation_epoch_end(outputs) | |
def main():
    """Train a BindEvaluator from ``./config.yaml`` and re-validate the best checkpoint.

    Reads the YAML config, loads the pre-tokenized train/valid datasets
    from disk, fits the model with checkpointing on ``val_mcc`` and early
    stopping on ``val_loss``, then evaluates the best checkpoint on the
    validation set.
    """
    CONFIG_PATH = './config.yaml'
    with open(CONFIG_PATH, 'r') as f:
        cfg = edict(yaml.safe_load(f))

    run_name = f"mcc_lr={cfg.optim.lr}_nlayers={cfg.model.n_layers}_dmodel={cfg.model.d_model}_nhead={cfg.model.n_head}_dinner={cfg.model.d_inner}_kl{cfg.training.kl_weight}"
    workdir = os.path.join(cfg.work_dir, run_name)

    train_dataset = load_from_disk('/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset_ChemBERTa/train')
    valid_dataset = load_from_disk('/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset_ChemBERTa/valid')

    # batch_size=None disables automatic batching: each dataset item is
    # passed to the model as one batch.
    train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=True, num_workers=4)
    valid_dataloader = DataLoader(valid_dataset, batch_size=None, shuffle=False, num_workers=4)

    model = BindEvaluator(cfg)

    wandb_logger = WandbLogger(
        project='SMILES_BindEvaluator',
        name=run_name,
        entity='programmablebio',
    )

    checkpoint_callback = ModelCheckpoint(
        monitor='val_mcc',
        dirpath=workdir,
        filename='model-{epoch:02d}-{val_mcc:.2f}',
        save_top_k=5,
        mode='max',
    )
    lrmon = LearningRateMonitor(logging_interval="step")

    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min',
    )

    # accumulator = GradientAccumulationScheduler(scheduling={0: 8, 3: 4, 20: 2})

    trainer = pl.Trainer(
        default_root_dir=workdir,
        max_epochs=cfg.optim.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16-mixed',
        devices=cfg.compute.ngpus,
        callbacks=[checkpoint_callback, early_stopping_callback, lrmon],
        gradient_clip_val=cfg.optim.grad_clip,
        log_every_n_steps=10,
        logger=wandb_logger,
    )

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

    best_model_path = checkpoint_callback.best_model_path
    print(f"Best model saved at: {best_model_path}")

    # The module defines no test_step and there is no test dataloader, so
    # Trainer.test would raise here. Re-run validation on the best
    # checkpoint instead, and only if a checkpoint was actually written.
    if best_model_path:
        trainer.validate(model, dataloaders=valid_dataloader, ckpt_path=best_model_path)


if __name__ == "__main__":
    main()
Xet Storage Details
- Size:
- 18 kB
- Xet hash:
- a6409685d9996a60ac8148c9514821385d48a48bd92423fde9b984da8cea21d2
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.