# DELM / src/predict.py
# xushijie
# add app
# 21f308b
from pathlib import Path
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from models.polybert import PolyEncoder
from models.training import BaseModel
from models.utils import decrypt_checkpoint, load_private_key_from_file
import argparse
from tqdm import tqdm
from models.utils import Config
from models.plm import EsmModelInfo, get_model
import pandas as pd
MODEL_KEY_PREFIX = 'model.'


def build_parser():
    """Build the command-line parser for the prediction script.

    Returns:
        argparse.ArgumentParser: parser exposing ``--ckpt``, ``--plm``,
        ``--csv`` (required), ``--output``/``-o`` and ``--attn``.
    """
    # fmt: off
    parser = argparse.ArgumentParser(description="Predict plastic degradation")
    parser.add_argument("--ckpt", type=str, help="Path to the model checkpoint")
    parser.add_argument("--plm", type=str, help="Protein language model to use", default='esm2_t33_650M_UR50D')
    # Required: without it pd.read_csv(None) fails with an opaque error, and
    # only after every model has already been loaded.
    parser.add_argument("--csv", type=str, help="Path to the CSV file with test data", required=True)
    parser.add_argument("--output", '-o', type=str, help="Path to the output file", default='predictions.csv')
    parser.add_argument("--attn", action='store_true', help="Save attention weights to files")
    # fmt: on
    return parser


def extract_model_state(checkpoint, prefix=MODEL_KEY_PREFIX):
    """Return the predictor weights stored under ``checkpoint['state_dict']``.

    Only entries whose key starts with ``prefix`` are kept, and only that
    leading prefix is removed. A plain ``str.replace`` would also delete later
    occurrences of the prefix inside a key (``model.encoder.model.weight``
    would become ``encoder.weight``) and corrupt nested parameter names.

    Args:
        checkpoint: mapping with a ``'state_dict'`` entry of name -> value.
        prefix: key prefix identifying the wrapped model's parameters.

    Returns:
        dict: parameter name (prefix stripped) -> value.

    Raises:
        KeyError: if ``checkpoint`` has no ``'state_dict'`` entry.
    """
    return {
        key[len(prefix):]: value
        for key, value in checkpoint['state_dict'].items()
        if key.startswith(prefix)
    }


def main(argv=None):
    """Run predictions for every row of the input CSV and write the results.

    The input CSV must provide ``sequence`` and ``polymer`` columns. The
    output CSV contains ``pred`` and ``prob`` first, followed by the original
    columns and a per-row ``time`` column (seconds). With ``--attn`` the two
    attention outputs of the model are saved per row as ``<index>.pt`` in a
    directory named after the output file with a ``.attn`` suffix.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = build_parser().parse_args(argv)

    # Read the input first so a bad path fails before the models are loaded.
    df = pd.read_csv(args.csv)

    info = EsmModelInfo(args.plm)
    # NOTE(review): the factor 2 presumably matches the flattened
    # 'b l d -> b (l d)' embedding below with l == 2 -- confirm.
    plm_dim = info['dim'] * 2
    # NOTE(review): presumably the PolyEncoder embedding width -- confirm.
    pbert_dim = 600
    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BaseModel(plm_dim, pbert_dim, n_classes=2).to(dev)

    # load weights
    private_key = load_private_key_from_file()
    checkpoint = decrypt_checkpoint(args.ckpt, private_key)
    model.load_state_dict(extract_model_state(checkpoint))
    model.eval()
    print(f'Load predictor from {args.ckpt}')

    # Follow the CPU fallback chosen for `dev` instead of hard-coding 'cuda'.
    # NOTE(review): assumes get_model's second argument is a device string.
    plm_func = get_model(args.plm, dev.type)
    print(f'Loaded PLM model {args.plm}')
    polybert_func = PolyEncoder()
    print('Loaded PolyEncoder model')

    outfile = Path(args.output)
    attn_dir = outfile.with_suffix('.attn')
    if args.attn:
        # Create the directory once rather than on every row.
        attn_dir.mkdir(parents=True, exist_ok=True)

    probs = []
    running_time = []
    # inference_mode already disables gradient tracking.
    with torch.inference_mode():
        for i, row in tqdm(df.iterrows(), total=len(df)):
            start_time = time.perf_counter()
            seq = row['sequence'].upper()
            poly = row['polymer']
            # get protein embedding
            seq_emb = plm_func([seq]).to(dev)
            seq_emb = rearrange(seq_emb, 'b l d -> b (l d)').unsqueeze(0)
            poly_emb = polybert_func([poly]).to(dev)
            logits, p_weights, l_weights = model((seq_emb, poly_emb))
            # Probability of class index 1, reported as the positive class.
            prob = F.softmax(logits, dim=-1)[:, 1].item()
            if args.attn:
                torch.save((p_weights, l_weights), attn_dir / f'{i}.pt')
            probs.append(prob)
            running_time.append(time.perf_counter() - start_time)

    df['prob'] = probs
    df['pred'] = df['prob'].apply(lambda x: 'Yes' if x >= 0.5 else 'No')
    df['time'] = running_time
    # move pred and prob to the front
    df = df[['pred', 'prob'] +
            [col for col in df.columns if col not in ['pred', 'prob']]]
    df.to_csv(outfile, index=False)
    print(f'Predictions saved to {outfile}')
    if args.attn:
        # Report the real location, and only when weights were written.
        print(f'Attention weights saved to {attn_dir} as <index>.pt')


if __name__ == "__main__":
    main()