|
|
from pathlib import Path |
|
|
from einops import rearrange |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import time |
|
|
from models.polybert import PolyEncoder |
|
|
from models.training import BaseModel |
|
|
from models.utils import decrypt_checkpoint, load_private_key_from_file |
|
|
import argparse |
|
|
from tqdm import tqdm |
|
|
|
|
|
from models.utils import Config |
|
|
from models.plm import EsmModelInfo, get_model |
|
|
import pandas as pd |
|
|
|
|
|
|
|
|
def strip_key_prefix(state_dict, prefix='model.'):
    """Return the entries of *state_dict* under *prefix*, with the prefix removed.

    Only the leading occurrence of *prefix* is stripped, so a nested key such
    as ``model.encoder.model.weight`` becomes ``encoder.model.weight``
    (``str.replace`` would also delete the inner ``model.``). Keys that do not
    start with *prefix* are dropped.

    Args:
        state_dict: Mapping from parameter name to value.
        prefix: Leading string that selects the keys to keep.

    Returns:
        A new dict; the input mapping is not modified.
    """
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }


def parse_args():
    """Parse the command-line arguments of the prediction script."""
    parser = argparse.ArgumentParser(description="Predict plastic degradation")
    # --ckpt and --csv have no usable default: fail in argparse with a clear
    # message instead of deep inside checkpoint decryption / pandas.
    parser.add_argument("--ckpt", type=str, required=True,
                        help="Path to the model checkpoint")
    parser.add_argument("--plm", type=str, help="Protein language model to use",
                        default='esm2_t33_650M_UR50D')
    parser.add_argument("--csv", type=str, required=True,
                        help="Path to the CSV file with test data")
    parser.add_argument("--output", '-o', type=str, help="Path to the output file",
                        default='predictions.csv')
    parser.add_argument("--attn", action='store_true',
                        help="Save attention weights to files")
    return parser.parse_args()


def main():
    """Score every (sequence, polymer) row of a CSV and write the predictions.

    Reads the ``sequence`` and ``polymer`` columns of ``--csv``, runs the
    predictor on each row, and writes the input table extended with ``pred``
    ('Yes'/'No' at a 0.5 threshold), ``prob`` (softmax probability of class 1)
    and ``time`` (seconds spent on the row) to ``--output``. With ``--attn``,
    the two attention outputs of the model are saved per row as
    ``<output stem>.attn/<row index>.pt``.
    """
    args = parse_args()

    info = EsmModelInfo(args.plm)
    # The predictor's protein input is twice the PLM width; presumably the PLM
    # wrapper returns two vectors per sequence (flattened below) -- TODO confirm.
    plm_dim = info['dim'] * 2
    # NOTE(review): assumed to be the PolyEncoder embedding width -- confirm.
    pbert_dim = 600

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BaseModel(plm_dim, pbert_dim, n_classes=2).to(dev)

    # The checkpoint is stored encrypted; keep only the weights saved under
    # the 'model.' prefix and strip that prefix to match BaseModel's keys.
    private_key = load_private_key_from_file()
    checkpoint = decrypt_checkpoint(args.ckpt, private_key)
    model.load_state_dict(strip_key_prefix(checkpoint['state_dict'], 'model.'))
    model.eval()
    print(f'Load predictor from {args.ckpt}')

    # Use the same device as the predictor instead of a hard-coded 'cuda', so
    # the script also runs on CPU-only hosts.
    # NOTE(review): assumes get_model accepts any torch device string.
    plm_func = get_model(args.plm, dev.type)
    print(f'Loaded PLM model {args.plm}')

    polybert_func = PolyEncoder()
    print('Loaded PolyEncoder model')

    outfile = Path('predictions.csv' if args.output is None else args.output)
    # Create output locations up front so a long run cannot fail at write time.
    outfile.parent.mkdir(parents=True, exist_ok=True)
    attn_dir = outfile.with_suffix('.attn')
    if args.attn:
        attn_dir.mkdir(parents=True, exist_ok=True)

    df = pd.read_csv(args.csv)
    probs = []
    running_time = []

    # inference_mode already disables gradient tracking; no_grad is redundant.
    with torch.inference_mode():
        for i, row in tqdm(df.iterrows(), total=len(df)):
            start_time = time.perf_counter()

            seq = row['sequence'].upper()
            poly = row['polymer']

            seq_emb = plm_func([seq]).to(dev)
            # Flatten the last two axes into one feature vector per sequence,
            # then add a leading axis of size 1.
            seq_emb = rearrange(seq_emb, 'b l d -> b (l d)').unsqueeze(0)
            poly_emb = polybert_func([poly]).to(dev)

            logits, p_weights, l_weights = model((seq_emb, poly_emb))
            # Probability of class index 1 (reported as 'Yes' below).
            prob = F.softmax(logits, dim=-1)[:, 1].item()

            if args.attn:
                # File name is the DataFrame index of the row.
                torch.save((p_weights, l_weights), attn_dir / f'{i}.pt')

            probs.append(prob)
            running_time.append(time.perf_counter() - start_time)

    df['prob'] = probs
    df['pred'] = ['Yes' if p >= 0.5 else 'No' for p in probs]
    df['time'] = running_time

    # Put the prediction columns first, followed by the remaining columns.
    leading = ['pred', 'prob']
    df = df[leading + [col for col in df.columns if col not in leading]]

    df.to_csv(outfile, index=False)
    print(f'Predictions saved to {outfile}')
    if args.attn:
        print(f'Attention weights saved to {attn_dir} as <index>.pt')


if __name__ == "__main__":
    main()
|
|
|