import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import lightning as L
from timm.scheduler.cosine_lr import CosineLRScheduler
from argparse import Namespace as Args
import wandb
import time
from models.dataset import EnzymeDataModule, polymer2psmiles
from models.plm import EsmModelInfo
from models.polybert import PolyEncoder
from models.utils import is_wandb_running
from pathlib import Path
from einops import rearrange
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, precision_score, recall_score

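# Bidirectional cross-attention between protein residue embeddings and polymer
# SMILES token embeddings: each modality is linearly projected into the other's
# embedding dimension and used as key/value for the other's queries.
# Shape sketch: protein (B, Lp, protein_dim), smiles (B, Ls, smiles_dim)
# -> l_attn (B, Ls, smiles_dim), p_attn (B, Lp, protein_dim).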
class CrossAttnLayer(nn.Module):
    def __init__(self, protein_dim, smiles_dim, nheads=8):
        super().__init__()
        self.fc_smiles = nn.Linear(smiles_dim, protein_dim)
        self.fc_protein = nn.Linear(protein_dim, smiles_dim)
        self.smiles2protein = nn.MultiheadAttention(
            smiles_dim, nheads, batch_first=True)
        self.protein2smiles = nn.MultiheadAttention(
            protein_dim, nheads, batch_first=True)

    def forward(self, protein, smiles):
        # project each modality into the other's embedding dimension
        down_protein = self.fc_protein(protein)
        up_smiles = self.fc_smiles(smiles)
        # SMILES tokens attend over protein residues (queries: smiles)
        l_attn, l_weights = self.smiles2protein(
            smiles, down_protein, down_protein)
        # protein residues attend over SMILES tokens (queries: protein)
        p_attn, p_weights = self.protein2smiles(protein, up_smiles, up_smiles)
        return l_attn, p_attn, l_weights, p_weights

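# Classification head: applies the cross-attention block, mean-pools both
# attended sequences over their length dimension, and maps the concatenated
# vector to class logits.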
class BaseModel(nn.Module):
    def __init__(self, in_dim1, in_dim2, n_classes):
        super().__init__()
        self.attn = CrossAttnLayer(in_dim1, in_dim2)
        self.fc = nn.Linear(in_dim1 + in_dim2, n_classes)

    def forward(self, x):
        protein, smiles = x
        # unpack in the order CrossAttnLayer returns: SMILES-side output first,
        # protein-side output second
        l_attn, p_attn, l_weights, p_weights = self.attn(protein, smiles)
        x = torch.cat((l_attn.mean(dim=1), p_attn.mean(dim=1)), dim=-1)
        x = self.fc(x)
        return x, l_weights, p_weights

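# LightningModule wrapping BaseModel. Protein embeddings come from precomputed
# ESM features on disk and polymer embeddings from polyBERT; both are cached in
# memory. Optimization is manual (automatic_optimization = False), so
# training_step drives backward, the optimizer, and the LR scheduler itself.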
class PlasticPredictor(L.LightningModule):
    def __init__(self, args: Args):
        super().__init__()
        self.args = args
        info = EsmModelInfo(args.plm)
        # doubled to match the per-residue features flattened in get_protein_embedding
        plm_dim = info['dim'] * 2
        pbert_dim = 600
        self.model = BaseModel(
            in_dim1=plm_dim, in_dim2=pbert_dim, n_classes=2)

        # in-memory caches for already-encoded proteins and SMILES strings
        self.cached_proteins = {}
        self.cached_smiles = {}

        # stored in a plain dict so the polyBERT encoder is not registered as a
        # submodule; it is only used for inference below
        self.encoder = {}
        self.encoder['polybert'] = PolyEncoder()

        self.automatic_optimization = False

    def forward(self, x):
        # unused: training and validation go through step() directly
        pass

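    # Protein embeddings are precomputed and cached on disk, one .pt file per
    # sequence id under cache/protein/<plm>/. The rearrange below assumes a
    # (seq_len, n, dim) tensor and flattens it to (seq_len, n * dim) per residue,
    # which matches plm_dim = info['dim'] * 2 in __init__ when n == 2.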
    def get_protein_embedding(self, seq_id):
        if seq_id not in self.cached_proteins:
            seq_path = f'cache/protein/{self.args.plm}/{seq_id}.pt'
            if not Path(seq_path).exists():
                raise FileNotFoundError(
                    f"Protein embedding for {seq_id} not found.")
            emb = torch.load(seq_path)
            emb = rearrange(emb, 'b l d -> b (l d)')
            self.cached_proteins[seq_id] = emb
        return self.cached_proteins[seq_id]

    def get_smiles_embedding(self, polymer):
        # map the polymer name to its polymer SMILES (pSMILES) string
        smi = polymer2psmiles[polymer]

        if smi not in self.cached_smiles:
            with torch.inference_mode():
                # keep only the interior token embeddings (the [0, 2:-1, :]
                # slice presumably strips special tokens)
                emb = self.encoder['polybert']([smi])[0, 2:-1, :]
            self.cached_smiles[smi] = emb
        return self.cached_smiles[smi]

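    # Shared train/val step. The batch is the collated tuple
    # (polymer, seq, deg, seq_id, poly_id); per-sample embeddings are fetched
    # from the caches, padded into batch-first tensors, and classified. The
    # length tensors are built but not yet passed to the attention layers as
    # padding masks.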
    def step(self, batch):
        polymer, seq, deg, seq_id, poly_id = batch
        seqs = [self.get_protein_embedding(s.item()) for s in seq_id]
        polys = [self.get_smiles_embedding(p) for p in polymer]
        protein_lengths = [len(s) for s in seqs]
        smiles_lengths = [len(p) for p in polys]

        seqs = nn.utils.rnn.pad_sequence(
            seqs, batch_first=True).to(self.device)
        polys = nn.utils.rnn.pad_sequence(
            polys, batch_first=True).to(self.device)
        protein_lengths = torch.tensor(
            protein_lengths, dtype=torch.long).to(self.device)
        smiles_lengths = torch.tensor(
            smiles_lengths, dtype=torch.long).to(self.device)

        logits, l_weights, p_weights = self.model((seqs, polys))
        logits = logits.view(-1, 2)
        deg = deg.to(self.device)

        loss = F.cross_entropy(logits, deg, reduction='mean')

        # probability of the positive class
        y_prob = torch.softmax(logits, dim=-1)[:, 1]

        return deg, y_prob, loss

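    # Manual optimization (automatic_optimization is False): each training step
    # runs backward, steps the optimizer, advances the LR scheduler by global
    # step, and then clears the gradients.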
    def training_step(self, batch, batch_idx):
        y, y_prob, loss = self.step(batch)
        self.log_dict({"train/loss": loss}, prog_bar=True)

        self.manual_backward(loss)

        self.optimizers().step()
        self.lr_scheduler_step()
        self.optimizers().zero_grad()

    def validation_step(self, batch, batch_idx):
        y, y_prob, loss = self.step(batch)
        self.y.append(y.detach().cpu().numpy())
        self.y_prob.append(y_prob.detach().cpu().numpy())

    def on_validation_epoch_start(self):
        self.y, self.y_prob = [], []

    def on_validation_epoch_end(self):
        y_prob = np.concatenate(self.y_prob, axis=0)
        y = np.concatenate(self.y, axis=0)
        auc = roc_auc_score(y, y_prob)
        f1 = f1_score(y, y_prob > 0.5)
        mcc = matthews_corrcoef(y, y_prob > 0.5)
        precision = precision_score(y, y_prob > 0.5)
        recall = recall_score(y, y_prob > 0.5)
        self.log_dict({
            "val_auc": auc,
            "val_f1": f1,
            "val_mcc": mcc,
            "val_pre": precision,
            "val_rec": recall,
        }, prog_bar=True)

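    # AdamW with a timm CosineLRScheduler. The scheduler is not handed to
    # Lightning (only the optimizer is returned); it is stored on the module and
    # advanced manually per global step in lr_scheduler_step(), with a linear
    # warmup of roughly 10% of t_initial steps starting from warmup_lr_init.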
    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.args.lr, weight_decay=self.args.wd
        )
        warmup_steps = round(self.args.t_initial * 0.1)
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=self.args.t_initial,
            lr_min=1e-5,
            warmup_t=warmup_steps,
            warmup_lr_init=1e-5,
            warmup_prefix=True,
        )
        self.lr_scheduler = lr_scheduler
        return [optimizer]

    def lr_scheduler_step(self, *args, **kwargs):
        # trainer.max_steps is -1 when only max_epochs is set, so bound the
        # update by the configured schedule length instead
        if self.trainer.global_step < self.args.t_initial:
            self.lr_scheduler.step_update(self.trainer.global_step)

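# K-fold training entry point: for each fold a fresh model is trained with
# early stopping on val_auc, the best checkpoint is evaluated on the fold's
# validation and test splits, and fold-suffixed metrics are logged to wandb
# when a run is active.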
def train_plastic(args: Args):
    L.seed_everything(args.seed)
    if is_wandb_running():
        wandb.init(project="plastic-predictor")
        args.__dict__.update(dict(wandb.config))

    dm = EnzymeDataModule(args)

    for kfold in range(args.nfolds):
        print(f"Training fold {kfold + 1}/{args.nfolds}")

        model = PlasticPredictor(args)
        dm.setup_k_fold(kfold)

        print(
            f'Data loaded: {len(dm.train_dataloader())} train, '
            f'{len(dm.val_dataloader())} val, '
            f'{len(dm.test_dataloader())} test batches')

        devices = 1
        logger = None

        strategy = "ddp" if devices > 1 else "auto"
        steps_per_epoch = len(dm.train_dataloader())
        args.__dict__.update(
            {
                "batch_size": args.batch_size // devices,
                "dev_count": devices,
                "t_initial": args.epochs * steps_per_epoch,
                "steps_per_epoch": steps_per_epoch,
            }
        )
        print(f"Total steps: {args.t_initial}")
        checkpoint = L.pytorch.callbacks.ModelCheckpoint(
            dirpath=args.ckpt_dir,
            filename="plastic-{epoch:02d}-{val_auc:.4f}",
            # track the best checkpoint by val_auc so ckpt_path="best" below
            # resolves to it (consistent with the EarlyStopping monitor)
            monitor="val_auc",
            mode="max",
        )

        early_stopping = L.pytorch.callbacks.EarlyStopping(
            monitor="val_auc",
            patience=args.patience,
            mode="max",
            verbose=True,
        )
        precision = "16-mixed" if args.amp else "32-true"
        trainer = L.Trainer(
            max_epochs=args.epochs,
            accelerator="gpu",
            devices=devices,
            strategy=strategy,
            precision=precision,
            log_every_n_steps=1,
            callbacks=[checkpoint, early_stopping],
            logger=logger,
        )

        trainer.fit(model, dm)

        # evaluate the best checkpoint on the validation split, keep its
        # metrics, then run the same routine on the test split and re-key those
        # results as test_*
        trainer.validate(model, dm.val_dataloader(),
                         ckpt_path="best", verbose=True)
        time.sleep(1)
        val_test_metrics = trainer.callback_metrics.copy()
        trainer.validate(model, dm.test_dataloader(),
                         ckpt_path="best", verbose=True)

        time.sleep(1)
        val_test_metrics.update(
            [(k.replace("val_", "test_"), v)
             for k, v in trainer.callback_metrics.items()]
        )

        # suffix every metric with the fold index before logging
        val_test_metrics = {
            k + f"_fold{kfold + 1}": v for k, v in val_test_metrics.items()
        }
        if is_wandb_running():
            wandb.log(val_test_metrics)
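

if __name__ == "__main__":
    # Minimal usage sketch. The real entry point (and any dataset-related fields
    # EnzymeDataModule expects) may live elsewhere; the values below are
    # illustrative defaults, not the project's actual configuration.
    args = Args(
        plm="esm2_t33_650M_UR50D",  # assumed ESM checkpoint name
        seed=42,
        nfolds=5,
        epochs=50,
        batch_size=32,
        lr=1e-4,
        wd=1e-2,
        patience=10,
        amp=True,
        ckpt_dir="checkpoints/plastic",
    )
    train_plastic(args)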