# DELM/src/models/training.py
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from models.dataset import EnzymeDataModule
import lightning as L
from timm.scheduler.cosine_lr import CosineLRScheduler
from argparse import Namespace as Args
import wandb
import time
from models.dataset import polymer2psmiles
from models.plm import EsmModelInfo
from models.polybert import PolyEncoder
from models.utils import is_wandb_running
from pathlib import Path
from einops import rearrange
import numpy as np
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, precision_score, recall_score
class CrossAttnLayer(nn.Module):
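    """Bidirectional cross-attention between protein and SMILES token embeddings.

    Each modality is linearly projected to the other's embedding dimension so it
    can serve as keys/values for the other modality's queries.
    """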
def __init__(self, protein_dim, smiles_dim, nheads=8):
super().__init__()
self.fc_smiles = nn.Linear(smiles_dim, protein_dim)
self.fc_protein = nn.Linear(protein_dim, smiles_dim)
self.smiles2protein = nn.MultiheadAttention(
smiles_dim, nheads, batch_first=True)
self.protein2smiles = nn.MultiheadAttention(
protein_dim, nheads, batch_first=True)
def forward(self, protein, smiles):
down_protein = self.fc_protein(protein)
up_smiles = self.fc_smiles(smiles)
l_attn, l_weights = self.smiles2protein(
smiles, down_protein, down_protein)
p_attn, p_weights = self.protein2smiles(protein, up_smiles, up_smiles)
return l_attn, p_attn, l_weights, p_weights
class BaseModel(nn.Module):
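    """Fuse protein and SMILES tokens with cross-attention, mean-pool each stream,
    and classify with a single linear head."""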
def __init__(self, in_dim1, in_dim2, n_classes):
super().__init__()
self.attn = CrossAttnLayer(in_dim1, in_dim2)
self.fc = nn.Linear(in_dim1 + in_dim2, n_classes)
def forward(self, x):
protein, smiles = x
        # attn returns (smiles-query output, protein-query output) and their attention weights
        lig_attn, prot_attn, lig_weights, prot_weights = self.attn(protein, smiles)
        x = torch.cat((lig_attn.mean(dim=1), prot_attn.mean(dim=1)), dim=-1)
        x = self.fc(x)
        return x, lig_weights, prot_weights
class PlasticPredictor(L.LightningModule):
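    """LightningModule that combines cached protein language-model embeddings with
    polyBERT polymer embeddings via cross-attention for binary classification of
    the `deg` label."""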
    def __init__(self, args: Args):
super().__init__()
self.args = args
info = EsmModelInfo(args.plm)
plm_dim = info['dim']*2 # the first and last layers are concatenated
pbert_dim = 600
self.model = BaseModel(
in_dim1=plm_dim, in_dim2=pbert_dim, n_classes=2)
self.cached_proteins = {}
self.cached_smiles = {}
        # store the polyBERT encoder in a plain dict so it is not registered as a
        # submodule (keeps its weights out of the checkpoint and the optimizer)
        self.encoder = {}
        self.encoder['polybert'] = PolyEncoder()
        # optimization is driven manually in training_step
        self.automatic_optimization = False
    def forward(self, x):
        # unused; batches are run through `step` instead
        pass
def get_protein_embedding(self, seq_id):
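        """Load the precomputed PLM embedding for `seq_id` from the on-disk cache,
        memoizing it in memory for subsequent batches."""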
if seq_id not in self.cached_proteins:
seq_path = f'cache/protein/{self.args.plm}/{seq_id}.pt'
if not Path(seq_path).exists():
raise FileNotFoundError(
f"Protein embedding for {seq_id} not found.")
emb = torch.load(seq_path)
emb = rearrange(emb, 'b l d -> b (l d)')
self.cached_proteins[seq_id] = emb
return self.cached_proteins[seq_id]
def get_smiles_embedding(self, polymer):
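        """Encode a polymer's PSMILES string with polyBERT, caching the token
        embeddings per SMILES string."""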
smi = polymer2psmiles[polymer]
# mol = Chem.MolFromSmiles(smi)
# smi = Chem.MolToSmiles(mol, doRandom=True)
# # replace * with [*]
# smi = smi.replace('*', '[*]')
if smi not in self.cached_smiles:
# first dimension is 1
                with torch.inference_mode():
                    # drop the leading/trailing tokens (presumably special tokens)
                    # and keep the per-token embeddings
                    emb = self.encoder['polybert']([smi])[0, 2:-1, :]
self.cached_smiles[smi] = emb
return self.cached_smiles[smi]
def step(self, batch):
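        """Run one batch through the model and return (labels, positive-class
        probabilities, cross-entropy loss)."""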
        # the dataloader yields a tuple of (polymers, sequences, labels, sequence ids, polymer ids)
        polymer, seq, deg, seq_id, poly_id = batch
        seqs = [self.get_protein_embedding(s.item()) for s in seq_id]
        polys = [self.get_smiles_embedding(p) for p in polymer]
        # per-sample sequence lengths (computed but not currently used downstream)
        protein_lengths = [len(s) for s in seqs]
        smiles_lengths = [len(p) for p in polys]
seqs = nn.utils.rnn.pad_sequence(
seqs, batch_first=True).to(self.device)
polys = nn.utils.rnn.pad_sequence(
polys, batch_first=True).to(self.device)
protein_lengths = torch.tensor(
protein_lengths, dtype=torch.long).to(self.device)
smiles_lengths = torch.tensor(
smiles_lengths, dtype=torch.long).to(self.device)
logits, P_weights, L_weights = self.model((seqs, polys))
# Flatten the output for cross-entropy loss
logits = logits.view(-1, 2)
        deg = deg.to(self.device)
loss = F.cross_entropy(logits, deg, reduction='mean')
y_prob = torch.softmax(logits, dim=-1)[:, 1]
return deg, y_prob, loss
def training_step(self, batch, batch_idx):
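        # manual optimization: backward, optimizer step, LR-scheduler update and
        # zero_grad are all driven explicitly below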
y, y_prob, loss = self.step(batch)
self.log_dict({"train/loss": loss, }, prog_bar=True)
self.manual_backward(loss)
self.optimizers().step()
self.lr_scheduler_step()
self.optimizers().zero_grad()
def validation_step(self, batch, batch_idx):
y, y_prob, loss = self.step(batch)
self.y.append(y.detach().cpu().numpy())
self.y_prob.append(y_prob.detach().cpu().numpy())
def on_validation_epoch_start(self):
self.y, self.y_prob = [], []
def on_validation_epoch_end(self):
y_prob = np.concatenate(self.y_prob, axis=0)
y = np.concatenate(self.y, axis=0)
        auc = roc_auc_score(y, y_prob)
        y_pred = y_prob > 0.5
        f1 = f1_score(y, y_pred)
        mcc = matthews_corrcoef(y, y_pred)
        precision = precision_score(y, y_pred)
        recall = recall_score(y, y_pred)
self.log_dict({
"val_auc": auc,
"val_f1": f1,
"val_mcc": mcc,
"val_pre": precision,
"val_rec": recall,
}, prog_bar=True)
def configure_optimizers(self):
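        # timm's CosineLRScheduler is not a torch.optim scheduler, so only the
        # optimizer is handed to Lightning and the scheduler is stepped manually
        # via lr_scheduler_step (called from training_step)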
optimizer = torch.optim.AdamW(
self.model.parameters(), lr=self.args.lr, weight_decay=self.args.wd
)
warmup_steps = round(self.args.t_initial * 0.1)
lr_scheduler = CosineLRScheduler(
optimizer,
t_initial=self.args.t_initial,
lr_min=1e-5,
warmup_t=warmup_steps,
warmup_lr_init=1e-5,
warmup_prefix=True,
)
self.lr_scheduler = lr_scheduler
return [optimizer]
def lr_scheduler_step(self, *args, **kwargs):
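        # called explicitly from training_step; timm schedulers advance per step
        # through step_update rather than step()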
if self.trainer.global_step < self.trainer.max_steps:
self.lr_scheduler.step_update(self.trainer.global_step)
def train_plastic(args: Args):
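    """Train and evaluate PlasticPredictor with k-fold cross-validation, logging
    per-fold validation and test metrics (optionally to wandb)."""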
L.seed_everything(args.seed)
if is_wandb_running():
wandb.init(project="plastic-predictor",)
args.__dict__.update(dict(wandb.config))
dm = EnzymeDataModule(args)
for kfold in range(args.nfolds):
print(f"Training fold {kfold + 1}/{args.nfolds}")
model = PlasticPredictor(args)
dm.setup_k_fold(kfold)
        print(
            f'Data loaded: {len(dm.train_dataloader())} train, {len(dm.val_dataloader())} val, '
            f'{len(dm.test_dataloader())} test batches')
devices = 1
logger = None
# devices = torch.cuda.device_count()
# logger = L.pytorch.loggers.WandbLogger(project="plastic-predictor",)
strategy = "ddp" if devices > 1 else "auto"
steps_per_epoch = len(dm.train_dataloader())
args.__dict__.update(
{
"batch_size": args.batch_size // devices,
"dev_count": devices,
"t_initial": args.epochs * steps_per_epoch,
"steps_per_epoch": steps_per_epoch,
}
)
print(f"Total steps: {args.t_initial}")
        checkpoint = L.pytorch.callbacks.ModelCheckpoint(
            dirpath=args.ckpt_dir,
            filename="plastic-{epoch:02d}-{val_auc:.4f}",
            monitor="val_auc",
            mode="max",
        )
early_stopping = L.pytorch.callbacks.EarlyStopping(
monitor="val_auc",
patience=args.patience,
mode="max",
verbose=True,
)
precision = "16-mixed" if args.amp else "32-true"
trainer = L.Trainer(
max_epochs=args.epochs,
accelerator="gpu",
devices=devices,
strategy=strategy,
precision=precision,
log_every_n_steps=1,
callbacks=[checkpoint, early_stopping],
# callbacks=[checkpoint],
# enable_checkpointing=False,
logger=logger,
)
trainer.fit(model, dm)
trainer.validate(model, dm.val_dataloader(),
ckpt_path="best", verbose=True)
time.sleep(1)
val_test_metrics = trainer.callback_metrics.copy()
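        # reuse the validation loop on the test split; the logged val_* keys are
        # renamed to test_* below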
trainer.validate(model, dm.test_dataloader(),
ckpt_path="best", verbose=True)
time.sleep(1)
val_test_metrics.update(
[(k.replace("val_", "test_"), v)
for k, v in trainer.callback_metrics.items()]
)
# val_test_metrics['fold'] = kfold + 1
# add _fold suffix to each key
val_test_metrics = {
k + f"_fold{kfold + 1}": v for k, v in val_test_metrics.items()
}
if is_wandb_running():
wandb.log(val_test_metrics)