File size: 3,334 Bytes
21f308b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
from pathlib import Path
from einops import rearrange
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
from models.polybert import PolyEncoder
from models.training import BaseModel
from models.utils import decrypt_checkpoint, load_private_key_from_file
import argparse
from tqdm import tqdm

from models.utils import Config
from models.plm import EsmModelInfo, get_model
import pandas as pd


if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser(description="Predict plastic degradation")
    parser.add_argument("--ckpt", type=str, help="Path to the model checkpoint")
    parser.add_argument("--plm", type=str, help="Protein language model to use", default='esm2_t33_650M_UR50D')
    parser.add_argument("--csv", type=str, help="Path to the CSV file with test data", default=None)
    parser.add_argument("--output",'-o', type=str, help="Path to the output file", default='predictions.csv')
    parser.add_argument("--attn", action='store_true', help="Save attention weights to files")
    # fmt: on
    args = parser.parse_args()

    info = EsmModelInfo(args.plm)
    plm_dim = info['dim']*2
    pbert_dim = 600

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BaseModel(plm_dim, pbert_dim, n_classes=2).to(dev)

    # load weights
    private_key = load_private_key_from_file()
    state_dict = decrypt_checkpoint(args.ckpt, private_key)
    state_dict = {
        k.replace('model.', ''): v for k, v in state_dict['state_dict'].items() if k.startswith('model.')}
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Load predictor from {args.ckpt}')

    plm_func = get_model(args.plm, 'cuda')
    print(f'Loaded PLM model {args.plm}')

    polybert_func = PolyEncoder()
    print('Loaded PolyEncoder model')

    outfile = Path(
        'predictions.csv' if args.output is None else args.output)
    # get protein embedding
    with torch.no_grad(), torch.inference_mode():
        df = pd.read_csv(args.csv)
        probs = []
        running_time = []
        for i, row in tqdm(df.iterrows()):
            start_time = time.time()

            seq = row['sequence'].upper()
            poly = row['polymer']
            seq_emb = plm_func([seq]).to(dev)
            seq_emb = rearrange(seq_emb, 'b l d -> b (l d)').unsqueeze(0)
            poly_emb = polybert_func([poly]).to(dev)
            logits, p_weights, l_weights = model((seq_emb, poly_emb))
            prob = F.softmax(logits, dim=-1)[:, 1].item()
            if args.attn:
                outfile.with_suffix('.attn').mkdir(
                    parents=True, exist_ok=True)
                torch.save(
                    (p_weights, l_weights),
                    outfile.with_suffix('.attn') / f'{i}.pt')
            probs.append(prob)
            running_time.append(time.time() - start_time)

        df['prob'] = probs
        df['pred'] = df['prob'].apply(lambda x: 'Yes' if x >= 0.5 else 'No')
        df['time'] = running_time

        # move pred and prob to the front
        df = df[['pred', 'prob'] +
                [col for col in df.columns if col not in ['pred', 'prob']]]

        df.to_csv(outfile, index=False)
        print(f'Predictions saved to {outfile}')
        print(f'Attention weights saved to current directory as <index>.pt')