AlienChen's picture
download
raw
18 kB
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from datasets import load_from_disk
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, \
Timer, TQDMProgressBar, LearningRateMonitor, StochasticWeightAveraging, GradientAccumulationScheduler
from pytorch_lightning.loggers import WandbLogger
from torch.optim.lr_scheduler import _LRScheduler
from argparse import ArgumentParser
import os
import math
import numpy as np
import torch.distributed as dist
from transformers import AutoModelWithLMHead, EsmModel
from torch.optim import AdamW
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef
import yaml
from easydict import EasyDict as edict
from models import *
from transformers.utils import logging
# Restrict the transformers library's own logger to error-level messages so
# that model-loading warnings do not flood the training output.
logging.set_verbosity_error()
def best_mcc_threshold(y_true: np.ndarray, y_prob: np.ndarray):
    """Pick the cutoff ``t`` that yields the highest MCC for ``pred = prob >= t``.

    Samples are ranked by descending probability and running TP/FP counts are
    accumulated once, so each distinct probability value is scored as a
    candidate threshold without re-scanning the data.

    Args:
        y_true: Binary ground-truth labels (cast to int32).
        y_prob: Predicted probabilities, same length as ``y_true`` (cast to
            float64).

    Returns:
        ``(best_mcc, best_threshold)`` as Python floats. ``(0.0, 0.5)`` is
        returned when the input is empty or contains a single class.
    """
    labels = np.asarray(y_true).astype(np.int32)
    scores = np.asarray(y_prob).astype(np.float64)

    fallback = (0.0, 0.5)
    if labels.size == 0:
        return fallback

    n_pos = int(labels.sum())
    n_neg = int(labels.size - n_pos)
    if n_pos == 0 or n_neg == 0:
        return fallback

    # Rank samples from most to least confident.
    ranking = np.argsort(-scores)
    ranked_labels = labels[ranking]
    ranked_scores = scores[ranking]

    # Running confusion counts if everything up to position i is predicted positive.
    cum_tp = np.cumsum(ranked_labels)
    cum_fp = np.arange(1, ranked_labels.size + 1) - cum_tp

    # A threshold is only meaningful at the last element of each run of equal
    # probabilities, because ``prob >= t`` always includes the whole run.
    is_run_end = np.r_[ranked_scores[1:] != ranked_scores[:-1], True]
    cut_points = np.flatnonzero(is_run_end)

    tp = cum_tp[cut_points].astype(np.float64)
    fp = cum_fp[cut_points].astype(np.float64)
    fn = (n_pos - tp).astype(np.float64)
    tn = (n_neg - fp).astype(np.float64)

    numer = tp * tn - fp * fn
    denom = np.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))

    # MCC is defined as 0 wherever the denominator vanishes.
    defined = denom > 0
    mcc_values = np.divide(
        numer,
        denom,
        out=np.zeros_like(tp, dtype=np.float64),
        where=defined,
    )

    winner = int(np.argmax(mcc_values))
    return float(mcc_values[winner]), float(ranked_scores[cut_points[winner]])
class CosineAnnealingWithWarmup(_LRScheduler):
    """Linear warmup followed by cosine decay, applied equally to every param group.

    The learning rate rises linearly from ``base_lr`` to ``max_lr`` over the
    first ``warmup_steps`` scheduler steps, then follows half a cosine from
    ``max_lr`` down to ``min_lr`` until ``total_steps``. After ``total_steps``
    the rate is held at ``min_lr``.

    Args:
        optimizer: Wrapped optimizer.
        warmup_steps: Number of steps spent in linear warmup.
        total_steps: Step at which the cosine decay reaches ``min_lr``.
        base_lr: Learning rate at step 0.
        max_lr: Learning rate at the end of warmup.
        min_lr: Learning rate at and after ``total_steps``.
        last_epoch: Index of the last step, forwarded to the base scheduler.
    """

    def __init__(self, optimizer, warmup_steps, total_steps, base_lr, max_lr, min_lr, last_epoch=-1):
        # These must exist before the base constructor runs, because it
        # performs an initial step that calls get_lr().
        self.warmup_steps = warmup_steps
        self.total_steps = total_steps
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.min_lr = min_lr
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate for the current step, one entry per param group."""
        step = self.last_epoch
        if step < self.warmup_steps:
            # Linear warmup phase from base_lr to max_lr.
            lr = self.base_lr + (self.max_lr - self.base_lr) * (step / self.warmup_steps)
        else:
            # Cosine annealing phase from max_lr to min_lr.
            # max(1, ...) avoids a division by zero when total_steps == warmup_steps.
            decay_steps = max(1, self.total_steps - self.warmup_steps)
            # Clamp so the cosine never swings back up once total_steps is exceeded.
            progress = min(1.0, (step - self.warmup_steps) / decay_steps)
            cosine_decay = 0.5 * (1.0 + float(np.cos(np.pi * progress)))
            lr = self.min_lr + (self.max_lr - self.min_lr) * cosine_decay
        return [lr for _ in self.base_lrs]
class BindEvaluator(pl.LightningModule):
    """Lightning module producing one binding-site logit per target-protein token.

    The target is embedded with a frozen ESM-2 encoder and the binder with the
    frozen RoBERTa backbone of a ChemBERTa checkpoint. A trainable
    ``RepeatedModule`` stack, a final cross-attention layer, a feed-forward
    block and a linear projection turn these embeddings into per-token logits.

    Training minimises a class-weighted, mask-aware BCE plus
    ``cfg.training.kl_weight`` times a KL term (see ``compute_kl_loss``). At
    the end of every validation epoch the MCC-optimal decision threshold is
    recomputed over the whole validation set and stored in ``val_threshold``.
    """

    def __init__(self, cfg):
        super().__init__()

        # Pretrained encoders are feature extractors only: eval mode, no gradients.
        self.esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D").eval()
        self.esm_model.requires_grad_(False)

        self.chemberta_model = AutoModelWithLMHead.from_pretrained("seyonec/ChemBERTa_zinc250k_v2_40k").roberta.eval()
        self.chemberta_model.requires_grad_(False)

        self.repeated_module = RepeatedModule(
            cfg.model.n_layers, cfg.model.d_model, cfg.model.d_hidden,
            cfg.model.n_head, cfg.model.d_k, cfg.model.d_v, cfg.model.d_inner,
            dropout=cfg.model.dropout,
        )

        self.final_attention_layer = MultiHeadAttentionSequence(
            cfg.model.n_head, cfg.model.d_model,
            cfg.model.d_k, cfg.model.d_v, dropout=cfg.model.dropout,
        )

        self.final_ffn = FFN(cfg.model.d_model, cfg.model.d_inner, dropout=cfg.model.dropout)

        # Referenced through ``torch`` so the class does not depend on ``nn``
        # leaking in from a star import.
        self.output_projection_prot = torch.nn.Linear(cfg.model.d_model, 1)

        # Index 0 weights binding-site (positive) tokens, index 1 weights
        # non-binding-site (negative) tokens.
        self.class_weights = torch.tensor([7.803861120886587, 0.5342284711965681])

        self.cfg = cfg
        # Filled in on_train_start; read by the LR schedule.
        self._total_steps = None

        # Threshold selected each validation epoch to maximize MCC.
        self.val_threshold = 0.5
        self._val_probs = []
        self._val_labels = []

    def train(self, mode: bool = True):
        """Switch train/eval mode while keeping the frozen encoders in eval mode.

        ``nn.Module.train`` recurses into every child, which would re-enable
        dropout inside the frozen encoders whenever the parent is put into
        training mode. Pinning them to eval keeps their embeddings
        deterministic.
        """
        super().train(mode)
        self.esm_model.eval()
        self.chemberta_model.eval()
        return self

    def forward(self, binder_tokens, target_tokens):
        """Compute per-token logits for the target protein.

        Args:
            binder_tokens: Dict with ``input_ids`` and ``attention_mask`` for
                the ChemBERTa encoder.
            target_tokens: Dict with ``input_ids`` and ``attention_mask`` for
                the ESM encoder.

        Returns:
            Output of the final linear projection (last dimension of size 1),
            one value per target token.
        """
        peptide_sequence = self.chemberta_model(**binder_tokens).last_hidden_state
        protein_sequence = self.esm_model(**target_tokens).last_hidden_state

        binder_mask = binder_tokens["attention_mask"]  # [B, Ls]
        target_mask = target_tokens["attention_mask"]  # [B, Lp]

        # Only the two encodings are used downstream; the attention maps
        # returned by the stack are discarded.
        prot_enc, sequence_enc, sequence_attention_list, prot_attention_list, \
            prot_seq_attention_list, seq_prot_attention_list = self.repeated_module(
                peptide_sequence,
                protein_sequence,
                peptide_mask=binder_mask,
                protein_mask=target_mask,
            )

        # Final cross-attention: protein queries attend to binder keys.
        prot_enc, final_prot_seq_attention = self.final_attention_layer(
            prot_enc, sequence_enc, sequence_enc,
            key_padding_mask=binder_mask,
            query_padding_mask=target_mask,
        )

        prot_enc = self.final_ffn(prot_enc, padding_mask=target_mask)
        prot_enc = self.output_projection_prot(prot_enc)
        return prot_enc

    def _shared_step(self, batch):
        """Run the model on one batch and compute the losses shared by train/val.

        Returns:
            ``(logits, binding_site, mask, mean_bce_loss, kl_loss, mean_loss)``.
        """
        device = self.device
        # as_tensor accepts both nested lists and ready-made tensors without
        # the extra copy (and warning) torch.tensor() produces for tensors.
        target_tokens = {
            'input_ids': torch.as_tensor(batch['target_input_ids'], device=device),
            'attention_mask': torch.as_tensor(batch['target_attention_mask'], device=device),
        }
        binder_tokens = {
            'input_ids': torch.as_tensor(batch['binder_input_ids'], device=device),
            'attention_mask': torch.as_tensor(batch['binder_attention_mask'], device=device),
        }
        binding_site = torch.as_tensor(batch['labels'], device=device).float()
        mask = torch.as_tensor(batch['labels_mask'], device=device)

        outputs_nodes = self.forward(binder_tokens, target_tokens).squeeze(-1)

        # Per-token weight: class_weights[0] for positives, class_weights[1] for negatives.
        weight = self.class_weights[0] * binding_site + self.class_weights[1] * (1 - binding_site)
        bce_loss = F.binary_cross_entropy_with_logits(
            outputs_nodes, binding_site, weight=weight, reduction='none'
        )
        masked_bce_loss = bce_loss * mask
        # mask is 0/1; clamp so a fully masked batch gives 0 instead of NaN.
        mean_bce_loss = masked_bce_loss.sum() / mask.sum().clamp_min(1)

        kl_loss = self.compute_kl_loss(outputs_nodes, binding_site, mask)
        mean_loss = mean_bce_loss + self.cfg.training.kl_weight * kl_loss
        return outputs_nodes, binding_site, mask, mean_bce_loss, kl_loss, mean_loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training losses; return the combined loss."""
        _, binding_site, _, mean_bce_loss, kl_loss, mean_loss = self._shared_step(batch)

        batch_size = binding_site.shape[0]
        self.log('bce_loss', mean_bce_loss, on_step=False, on_epoch=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('kl_loss', kl_loss, on_step=False, on_epoch=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('train_loss', mean_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
        return mean_loss

    def on_validation_epoch_start(self):
        """Clear cached validation probabilities/labels used for threshold optimization."""
        self._val_probs = []
        self._val_labels = []

    def validation_step(self, batch, batch_idx):
        """Log validation losses and cache masked probabilities/labels for epoch end."""
        outputs_nodes, binding_site, mask, mean_bce_loss, kl_loss, mean_loss = self._shared_step(batch)

        # Collect masked probabilities + labels for epoch-end threshold optimization.
        keep = mask.bool()
        self._val_probs.append(torch.sigmoid(outputs_nodes)[keep].detach())
        self._val_labels.append(binding_site[keep].detach())

        batch_size = binding_site.shape[0]
        self.log('val_loss', mean_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('val_kl_loss', kl_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)
        self.log('val_bce_loss', mean_bce_loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True, batch_size=batch_size)

    def on_validation_epoch_end(self):
        """Compute validation-set metrics and refresh ``val_threshold``.

        Logs ``val_auc``, ``val_f1``, ``val_mcc``, ``val_accuracy`` and
        ``val_mcc_threshold``. Under DDP the cached predictions from every
        rank are gathered first, so all ranks compute identical metrics.
        """
        if self._val_probs:
            probs_np = torch.cat(self._val_probs, dim=0).float().detach().cpu().numpy()
            labels_np = torch.cat(self._val_labels, dim=0).float().detach().cpu().numpy()
        else:
            probs_np = np.zeros(0, dtype=np.float32)
            labels_np = np.zeros(0, dtype=np.float32)

        # DDP-safe: gather variable-length arrays across ranks. Every rank must
        # enter this collective, even one with nothing cached; returning early
        # before it would leave the remaining ranks blocked.
        if dist.is_available() and dist.is_initialized():
            gathered = [None] * dist.get_world_size()
            dist.all_gather_object(gathered, {"probs": probs_np, "labels": labels_np})
            probs_np = np.concatenate([g["probs"] for g in gathered], axis=0)
            labels_np = np.concatenate([g["labels"] for g in gathered], axis=0)

        # After the gather this condition is identical on every rank.
        if probs_np.size == 0:
            return

        # Threshold-free metric (may be undefined if only one class is present).
        try:
            auc = roc_auc_score(labels_np, probs_np)
        except ValueError:
            auc = float("nan")

        # Find the MCC-optimal threshold on the whole validation set.
        best_mcc, best_thr = best_mcc_threshold(labels_np, probs_np)
        self.val_threshold = best_thr

        preds_np = (probs_np >= best_thr).astype(np.int32)
        labels_i = labels_np.astype(np.int32)

        f1 = f1_score(labels_i, preds_np)
        mcc = matthews_corrcoef(labels_i, preds_np)
        acc = (preds_np == labels_i).mean()

        # Epoch-level metrics; val_mcc is the quantity a checkpoint callback can monitor.
        self.log("val_auc", auc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_f1", f1, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_mcc", mcc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_accuracy", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        self.log("val_mcc_threshold", best_thr, on_step=False, on_epoch=True, prog_bar=False, logger=True, sync_dist=True)

    def compute_kl_loss(self, logits, targets, valid_mask, eps=1e-8, label_smoothing=0.0):
        """Return the mean KL(q || p) over samples that contain a positive residue.

        ``p`` is the softmax of ``logits`` over the valid residues of a sample.
        ``q`` is the multi-hot target normalised into a distribution over the
        positive residues.

        Args:
            logits: [B, L] raw per-residue outputs.
            targets: [B, L] 0/1 binding-site labels.
            valid_mask: [B, L] with 1 for real residues and 0 for padding.
            eps: Added to the target normaliser to avoid division by zero.
            label_smoothing: If > 0, mixes ``q`` with a uniform distribution
                over the valid residues.

        Returns:
            Scalar tensor; 0 when no sample in the batch has a positive residue.
        """
        valid = valid_mask.bool()

        # 1) Predicted log-prob distribution over residues (pads excluded).
        masked_logits = logits.masked_fill(~valid, -1e9)
        log_p = F.log_softmax(masked_logits, dim=-1)  # [B, L]

        # 2) Target distribution q over residues (pads excluded, normalised over positives).
        q_mass = targets.float() * valid.float()  # [B, L]
        q_sum = q_mass.sum(dim=-1, keepdim=True)  # [B, 1]

        # Samples with no positive residues are left out of the average.
        has_pos = q_sum.squeeze(-1) > 0

        q = q_mass / (q_sum + eps)  # [B, L]

        # Optional: label smoothing toward uniform over valid residues.
        if label_smoothing > 0.0:
            uniform = valid.float()
            uniform = uniform / uniform.sum(dim=-1, keepdim=True).clamp_min(1.0)
            q = (1.0 - label_smoothing) * q + label_smoothing * uniform

        # 3) KL(q || p) per sample; F.kl_div expects input=log-probs, target=probs.
        kl_per_pos = F.kl_div(log_p, q, reduction="none")  # [B, L]
        kl_per_sample = kl_per_pos.sum(dim=-1)  # [B]

        if has_pos.any():
            return kl_per_sample[has_pos].mean()
        # No positives in the batch (rare) -> return 0 so it doesn't blow up.
        return logits.new_tensor(0.0)

    def configure_optimizers(self):
        """Build AdamW and a per-step warmup + cosine ``LambdaLR`` schedule.

        The schedule multiplies the base LR by a factor that rises linearly
        from 0.1 to 1.0 during the first ``warmup_ratio`` of the run and then
        follows a cosine down to ``min_scale``. Until ``on_train_start`` has
        recorded the total step count the factor is 1.0.
        """
        opt = torch.optim.AdamW(
            self.parameters(),
            lr=float(self.cfg.optim.lr),
            betas=(self.cfg.optim.beta1, self.cfg.optim.beta2),
            eps=float(self.cfg.optim.eps),
            weight_decay=self.cfg.optim.weight_decay,
            fused=self.cfg.optim.fused,
        )

        warmup_ratio = getattr(self.cfg.optim, "warmup_ratio", 0.1)
        warmup_start_scale = 0.1
        min_scale = 0.1

        def lr_lambda(global_step: int):
            total_steps = self._total_steps
            # Until on_train_start runs (or if the step count is unknown or
            # unbounded) leave the base LR untouched.
            if not total_steps or not math.isfinite(total_steps):
                return 1.0

            warmup_steps = max(1, int(warmup_ratio * total_steps))

            if global_step < warmup_steps:
                # Linear warmup: warmup_start_scale -> 1.0
                alpha = (global_step + 1) / warmup_steps
                return warmup_start_scale + (1.0 - warmup_start_scale) * alpha

            # Cosine from 1.0 down to min_scale. Progress is clamped so the
            # factor stays at min_scale if more steps run than were estimated.
            progress = (global_step - warmup_steps) / max(1, total_steps - warmup_steps)
            progress = min(1.0, progress)
            cosine = 0.5 * (1 + math.cos(math.pi * progress))  # 1 -> 0
            return min_scale + (1.0 - min_scale) * cosine

        sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda=lr_lambda)

        return {
            "optimizer": opt,
            "lr_scheduler": {
                "scheduler": sch,
                "interval": "step",  # scheduler advances once per optimizer step
                "frequency": 1,
            },
        }

    def on_train_start(self):
        """Record how many optimizer steps this fit is expected to take."""
        self._total_steps = self.trainer.estimated_stepping_batches
# def on_training_epoch_end(self, outputs):
# gc.collect()
# torch.cuda.empty_cache()
# super().training_epoch_end(outputs)
# def on_validation_epoch_end(self, outputs):
# gc.collect()
# torch.cuda.empty_cache()
# super().validation_epoch_end(outputs)
def main():
    """Train ``BindEvaluator`` from ``./config.yaml`` and re-score the best checkpoint.

    Loads the pre-tokenized train/valid datasets from disk, fits the model
    with checkpointing on ``val_mcc`` and early stopping on ``val_loss``, then
    evaluates the best checkpoint on the validation set.
    """
    CONFIG_PATH = './config.yaml'
    with open(CONFIG_PATH, 'r') as f:
        cfg = edict(yaml.safe_load(f))

    run_name = f"mcc_lr={cfg.optim.lr}_nlayers={cfg.model.n_layers}_dmodel={cfg.model.d_model}_nhead={cfg.model.n_head}_dinner={cfg.model.d_inner}_kl{cfg.training.kl_weight}"
    workdir = os.path.join(cfg.work_dir, run_name)

    train_dataset = load_from_disk('/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset_ChemBERTa/train')
    valid_dataset = load_from_disk('/scratch/pranamlab/tong/SMILES_BindEvaluator/datasets/tokenized_dataset_ChemBERTa/valid')

    # batch_size=None disables automatic batching: each dataset item is passed
    # to the model as-is.
    train_dataloader = DataLoader(train_dataset, batch_size=None, shuffle=True, num_workers=4)
    valid_dataloader = DataLoader(valid_dataset, batch_size=None, shuffle=False, num_workers=4)

    model = BindEvaluator(cfg)

    wandb_logger = WandbLogger(
        project='SMILES_BindEvaluator',
        name=run_name,
        entity='programmablebio',
    )

    # Keep the five checkpoints with the highest validation MCC.
    checkpoint_callback = ModelCheckpoint(
        monitor='val_mcc',
        dirpath=workdir,
        filename='model-{epoch:02d}-{val_mcc:.2f}',
        save_top_k=5,
        mode='max',
    )

    lrmon = LearningRateMonitor(logging_interval="step")

    # Early stopping follows val_loss, not the val_mcc used for checkpointing.
    early_stopping_callback = EarlyStopping(
        monitor='val_loss',
        patience=5,
        verbose=True,
        mode='min'
    )

    # accumulator = GradientAccumulationScheduler(scheduling={0: 8, 3: 4, 20: 2})

    trainer = pl.Trainer(
        default_root_dir=workdir,
        max_epochs=cfg.optim.max_epochs,
        accelerator='gpu',
        strategy='ddp_find_unused_parameters_true',
        precision='bf16-mixed',
        devices=cfg.compute.ngpus,
        callbacks=[checkpoint_callback, early_stopping_callback, lrmon],
        gradient_clip_val=cfg.optim.grad_clip,
        log_every_n_steps=10,
        logger=wandb_logger,
    )

    trainer.fit(model, train_dataloaders=train_dataloader, val_dataloaders=valid_dataloader)

    best_model_path = checkpoint_callback.best_model_path
    print(f"Best model saved at: {best_model_path}")

    # The module defines no test_step and no test dataloader exists, so
    # trainer.test() would raise here. Re-run validation on the best
    # checkpoint instead; skip it if no checkpoint was written.
    if best_model_path:
        trainer.validate(model, dataloaders=valid_dataloader, ckpt_path=best_model_path)


if __name__ == "__main__":
    main()

Xet Storage Details

Size:
18 kB
·
Xet hash:
a6409685d9996a60ac8148c9514821385d48a48bd92423fde9b984da8cea21d2

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.