import json
import time
from argparse import Namespace as Args
from pathlib import Path

import lightning as L
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import wandb
from einops import rearrange
from sklearn.metrics import (
    f1_score, matthews_corrcoef, precision_score, recall_score, roc_auc_score)
from timm.scheduler.cosine_lr import CosineLRScheduler

from models.dataset import EnzymeDataModule, polymer2psmiles
from models.plm import EsmModelInfo
from models.polybert import PolyEncoder
from models.utils import is_wandb_running


class CrossAttnLayer(nn.Module):
    """Bidirectional cross-attention between protein and SMILES token embeddings."""

    def __init__(self, protein_dim, smiles_dim, nheads=8):
        super().__init__()
        # Project each modality into the other's width so it can act as
        # key/value for the opposite attention block.
        self.fc_smiles = nn.Linear(smiles_dim, protein_dim)
        self.fc_protein = nn.Linear(protein_dim, smiles_dim)
        self.smiles2protein = nn.MultiheadAttention(
            smiles_dim, nheads, batch_first=True)
        self.protein2smiles = nn.MultiheadAttention(
            protein_dim, nheads, batch_first=True)

    def forward(self, protein, smiles):
        down_protein = self.fc_protein(protein)  # protein_dim -> smiles_dim
        up_smiles = self.fc_smiles(smiles)       # smiles_dim -> protein_dim
        # SMILES tokens attend over the projected protein tokens and vice versa.
        l_attn, l_weights = self.smiles2protein(
            smiles, down_protein, down_protein)
        p_attn, p_weights = self.protein2smiles(protein, up_smiles, up_smiles)
        return l_attn, p_attn, l_weights, p_weights


class BaseModel(nn.Module):
    def __init__(self, in_dim1, in_dim2, n_classes):
        super().__init__()
        self.attn = CrossAttnLayer(in_dim1, in_dim2)
        self.fc = nn.Linear(in_dim1 + in_dim2, n_classes)

    def forward(self, x):
        protein, smiles = x
        l_attn, p_attn, l_weights, p_weights = self.attn(protein, smiles)
        # Mean-pool each attended sequence and classify the concatenated vector.
        x = torch.cat((l_attn.mean(dim=1), p_attn.mean(dim=1)), dim=-1)
        x = self.fc(x)
        return x, l_weights, p_weights


class PlasticPredictor(L.LightningModule):
    def __init__(self, args: Args):
        super().__init__()
        self.args = args
        info = EsmModelInfo(args.plm)
        plm_dim = info['dim'] * 2  # first and last PLM layers are concatenated
        pbert_dim = 600
        self.model = BaseModel(
            in_dim1=plm_dim, in_dim2=pbert_dim, n_classes=2)
        self.cached_proteins = {}
        self.cached_smiles = {}
        # trick: keep the frozen polyBERT encoder in a plain dict so Lightning
        # does not register it as a trainable submodule
        self.encoder = {}
        self.encoder['polybert'] = PolyEncoder()
        self.automatic_optimization = False

    def forward(self, x):
        pass

    def get_protein_embedding(self, seq_id):
        if seq_id not in self.cached_proteins:
            seq_path = f'cache/protein/{self.args.plm}/{seq_id}.pt'
            if not Path(seq_path).exists():
                raise FileNotFoundError(
                    f"Protein embedding for {seq_id} not found.")
            emb = torch.load(seq_path)
            # flatten the trailing dims into a single feature vector per position
            emb = rearrange(emb, 'b l d -> b (l d)')
            self.cached_proteins[seq_id] = emb
        return self.cached_proteins[seq_id]

    def get_smiles_embedding(self, polymer):
        smi = polymer2psmiles[polymer]
        # mol = Chem.MolFromSmiles(smi)
        # smi = Chem.MolToSmiles(mol, doRandom=True)
        # # replace * with [*]
        # smi = smi.replace('*', '[*]')
        if smi not in self.cached_smiles:
            # batch dimension is 1; slice off the special tokens at both ends
            with torch.no_grad(), torch.inference_mode():
                emb = self.encoder['polybert']([smi])[0, 2:-1, :]
            self.cached_smiles[smi] = emb
        return self.cached_smiles[smi]

    def step(self, batch):
        polymer, seq, deg, seq_id, poly_id = batch
        # look up (and cache) the precomputed embeddings for every item in the batch
        seqs = [self.get_protein_embedding(s.item()) for s in seq_id]
        polys = [self.get_smiles_embedding(p) for p in polymer]
        protein_lengths = [len(s) for s in seqs]
        smiles_lengths = [len(p) for p in polys]
        seqs = nn.utils.rnn.pad_sequence(
            seqs, batch_first=True).to(self.device)
        polys = nn.utils.rnn.pad_sequence(
            polys, batch_first=True).to(self.device)
        protein_lengths = torch.tensor(
            protein_lengths, dtype=torch.long).to(self.device)
        smiles_lengths = torch.tensor(
            smiles_lengths,
            dtype=torch.long).to(self.device)
        logits, l_weights, p_weights = self.model((seqs, polys))
        # Flatten the output for cross-entropy loss
        logits = logits.view(-1, 2)
        deg = deg.to(self.device)
        loss = F.cross_entropy(logits, deg, reduction='mean')
        y_prob = torch.softmax(logits, dim=-1)[:, 1]
        return deg, y_prob, loss

    def training_step(self, batch, batch_idx):
        y, y_prob, loss = self.step(batch)
        self.log_dict({"train/loss": loss}, prog_bar=True)
        # manual optimization: backward, optimizer step, scheduler update, zero grads
        self.manual_backward(loss)
        self.optimizers().step()
        self.lr_scheduler_step()
        self.optimizers().zero_grad()

    def validation_step(self, batch, batch_idx):
        y, y_prob, loss = self.step(batch)
        self.y.append(y.detach().cpu().numpy())
        self.y_prob.append(y_prob.detach().cpu().numpy())

    def on_validation_epoch_start(self):
        self.y, self.y_prob = [], []

    def on_validation_epoch_end(self):
        y_prob = np.concatenate(self.y_prob, axis=0)
        y = np.concatenate(self.y, axis=0)
        y_pred = y_prob > 0.5
        auc = roc_auc_score(y, y_prob)
        f1 = f1_score(y, y_pred)
        mcc = matthews_corrcoef(y, y_pred)
        precision = precision_score(y, y_pred)
        recall = recall_score(y, y_pred)
        self.log_dict({
            "val_auc": auc,
            "val_f1": f1,
            "val_mcc": mcc,
            "val_pre": precision,
            "val_rec": recall,
        }, prog_bar=True)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.args.lr,
            weight_decay=self.args.wd,
        )
        # warm up for 10% of the total number of optimizer steps
        warmup_steps = round(self.args.t_initial * 0.1)
        self.lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=self.args.t_initial,
            lr_min=1e-5,
            warmup_t=warmup_steps,
            warmup_lr_init=1e-5,
            warmup_prefix=True,
        )
        return [optimizer]

    def lr_scheduler_step(self, *args, **kwargs):
        # drive the timm scheduler by global step (manual optimization)
        if self.trainer.global_step < self.trainer.max_steps:
            self.lr_scheduler.step_update(self.trainer.global_step)


def train_plastic(args: Args):
    L.seed_everything(args.seed)
    if is_wandb_running():
        wandb.init(project="plastic-predictor")
        args.__dict__.update(dict(wandb.config))
    dm = EnzymeDataModule(args)
    for kfold in range(args.nfolds):
        print(f"Training fold {kfold + 1}/{args.nfolds}")
        model = PlasticPredictor(args)
        dm.setup_k_fold(kfold)
        print(
            f'Data loaded: {len(dm.train_dataloader())} train, '
            f'{len(dm.val_dataloader())} val, {len(dm.test_dataloader())} test')

        devices = 1
        logger = None
        # devices = torch.cuda.device_count()
        # logger = L.pytorch.loggers.WandbLogger(project="plastic-predictor")
        strategy = "ddp" if devices > 1 else "auto"

        steps_per_epoch = len(dm.train_dataloader())
        args.__dict__.update(
            {
                "batch_size": args.batch_size // devices,
                "dev_count": devices,
                "t_initial": args.epochs * steps_per_epoch,
                "steps_per_epoch": steps_per_epoch,
            }
        )
        print(f"Total steps: {args.t_initial}")

        checkpoint = L.pytorch.callbacks.ModelCheckpoint(
            dirpath=args.ckpt_dir,
            filename="plastic-{epoch:02d}-{val_auc:.4f}",
            # monitor="val_auc",
            # mode="max",
        )
        early_stopping = L.pytorch.callbacks.EarlyStopping(
            monitor="val_auc",
            patience=args.patience,
            mode="max",
            verbose=True,
        )
        precision = "16-mixed" if args.amp else "32-true"
        trainer = L.Trainer(
            max_epochs=args.epochs,
            accelerator="gpu",
            devices=devices,
            strategy=strategy,
            precision=precision,
            log_every_n_steps=1,
            callbacks=[checkpoint, early_stopping],
            # callbacks=[checkpoint],
            # enable_checkpointing=False,
            logger=logger,
        )
        trainer.fit(model, dm)

        # evaluate the best checkpoint on the validation and test splits
        trainer.validate(model, dm.val_dataloader(), ckpt_path="best", verbose=True)
        time.sleep(1)
        val_test_metrics = trainer.callback_metrics.copy()
        trainer.validate(model, dm.test_dataloader(), ckpt_path="best", verbose=True)
        time.sleep(1)
        val_test_metrics.update(
            [(k.replace("val_", "test_"), v)
             for k, v in trainer.callback_metrics.items()]
        )
        # val_test_metrics['fold'] = kfold + 1
        # add a _fold suffix to each key so metrics from different folds do not collide
        val_test_metrics = {
            k + f"_fold{kfold + 1}": v for k, v in val_test_metrics.items()
        }
        if is_wandb_running():
            wandb.log(val_test_metrics)
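

# Minimal usage sketch (an assumption, not part of the original module): train_plastic
# only needs a Namespace carrying the fields read above (plm, lr, wd, seed, nfolds,
# epochs, batch_size, patience, amp, ckpt_dir); EnzymeDataModule may require more.
# The concrete values below are illustrative defaults, not the authors' settings.
if __name__ == "__main__":
    example_args = Args(
        plm="esm2_t33_650M_UR50D",  # hypothetical PLM name; must match the cache/protein/<plm>/ layout
        lr=1e-4,
        wd=1e-2,
        seed=42,
        nfolds=5,
        epochs=50,
        batch_size=32,
        patience=10,
        amp=True,
        ckpt_dir="checkpoints/plastic",
    )
    train_plastic(example_args)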