xushijie commited on
Commit
21f308b
·
0 Parent(s):
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ trash/
2
+ __pycache__/
3
+ scripts/
.streamlit/config.toml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ [browser]
2
+ gatherUsageStats = false
3
+
4
+ [server]
5
+ headless = true
6
+ enableCORS = false
7
+ enableXsrfProtection = false
Dockerfile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11
2
+
3
+ ENV DEBIAN_FRONTEND=noninteractive
4
+ SHELL ["/bin/bash", "-c"]
5
+ RUN apt-get update -y \
6
+ && apt-get install -y build-essential python3-dev r-base make apt-utils unzip gpg doxygen git curl aria2 vim screen rsync wget locales gfortran mafft libglew-dev libeigen3-dev \
7
+ libpng-dev libfreetype6-dev libxml2-dev \
8
+ libmsgpack-dev python3-pyqt5.qtopengl libglm-dev libnetcdf-dev \
9
+ && locale-gen en_US.UTF-8 \
10
+ && rm -rf /var/lib/apt/lists/*
11
+
12
+ WORKDIR /app
13
+
14
+ COPY . .
15
+
16
+ RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
17
+ RUN pip3 install -r requirements.txt
18
+
19
+ # Ensure Hugging Face cache directory exists and is writable
20
+ RUN mkdir -p /app/.cache && chmod 777 /app/.cache
21
+ RUN mkdir -p /app/.cache/offload && chmod 777 /app/.cache/offload
22
+
23
+ EXPOSE 8501
24
+
25
+ HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
26
+
27
+ ENV HF_HOME=/app/.cache
28
+
29
+ ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: DELM
3
+ emoji: 🧬
4
+ colorFrom: red
5
+ colorTo: red
6
+ sdk: docker
7
+ app_port: 8501
8
+ tags:
9
+ - streamlit
10
+ pinned: false
11
+ short_description: Prediction of enzymatic degradation
12
+ license: cc-by-nc-sa-4.0
13
+ ---
14
+
15
+
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ transformers
3
+ fair-esm
4
+ lightning
5
+ cupy-cuda11x
6
+ scikit-learn
7
+ line_profiler
8
+ sentence-transformers
9
+ pandas
10
+ openpyxl
11
+ timm
12
+ wandb
13
+ accelerate
14
+ ipykernel
15
+ einops
16
+ SentencePiece
17
+ seaborn
18
+ streamlit
19
+ biotite
20
+ matplotlib
21
+ git+https://github.com/Ramprasad-Group/psmiles.git
22
+ py3Dmol
23
+ stmol
24
+ ipython_genutils
25
+ cryptography
src/checkpoints/weights.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:42319d679b0fac05d5e21979771295e9c97b333bc3e13a1ef0f4a8189e47dbc7
3
+ size 122999708
src/configs/train.yml ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ plm: esm2_t33_650M_UR50D
2
+ train_csv: data/train_addition.csv
3
+ test_csv: data/test.csv
4
+ batch_size: 32
5
+ epochs: 100
6
+ lr: 0.001
7
+ wd: 0
8
+ num_workers: 4
9
+ amp: true
10
+ seed: 42
11
+ nfolds: 5
12
+ ckpt_dir: checkpoints
13
+ patience: 20
src/configs/tune.yml ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ program: train.py
2
+ method: grid
3
+ metric:
4
+ name: val_acc
5
+ goal: maximize
6
+ parameters:
7
+ plm:
8
+ values:
9
+ - esm2_t33_650M_UR50D
10
+ - esm2_t48_15B_UR50D
11
+ - esm2_t36_3B_UR50D
12
+ - esm2_t12_35M_UR50D
13
+ - esm2_t30_150M_UR50D
14
+ - esm2_t6_8M_UR50D
15
+ - esm1b_t33_650M_UR50S
16
+ - prot_t5_xl_half_uniref50-enc
17
+ - prot_t5_xl_bfd
src/data/polymer2tok.csv ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ polymer_id,polymer
2
+ 0,PE
3
+ 1,PET
4
+ 2,PCL
5
+ 3,PHB
6
+ 4,PEF
7
+ 5,PBS
8
+ 6,PBSA
9
+ 7,PLA
10
+ 8,PHV
11
+ 9,PU
12
+ 10,PES
13
+ 11,PHA
14
+ 12,PHO
15
+ 13,PVA
16
+ 14,PPL
17
+ 15,P3HP
18
+ 16,P4HB
19
+ 17,PEA
20
+ 18,O-PVA
21
+ 19,P(3HB-co-3MP)
22
+ 20,PEG
23
+ 21,PHBV
24
+ 22,PHPV
25
+ 23,Nylon
26
+ 24,PBS-Blend
27
+ 25,PBSA-Blend
28
+ 26,P3HV
29
+ 27,PBAT
30
+ 28,PMCL
31
+ 29,LDPE
32
+ 30,PS
33
+ 31,NR
34
+ 32,PBSeT
35
+ 33,Ecovio-FT
36
+ 34,PHBH
37
+ 35,PHBVH
38
+ 36,Impranil
src/models/dataset.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset, DataLoader
5
+ import lightning as L
6
+ from pathlib import Path
7
+ import pandas as pd
8
+ from models.plm import get_model
9
+ from models.polybert import PolyEncoder, polymer2psmiles
10
+ from argparse import Namespace as Args
11
+ from sklearn.model_selection import KFold
12
+ from tqdm import tqdm
13
+ from torch.utils.data import WeightedRandomSampler
14
+
15
+
16
class EnzymeDataset(Dataset):
    """Dataset of (category, sequence, degradation, sequence_id, polymer_id) rows.

    On construction the CSV is read once and, if any per-sequence protein
    embedding is missing, the chosen protein language model (PLM) is loaded to
    precompute them, caching each under cache/protein/<plm>/<sequence_id>.pt so
    later runs skip the expensive encoding step.
    """

    def __init__(self, csv_file: str, plm: str):
        # One tuple per CSV row; sequences are upper-cased for the PLM tokenizer.
        self.data_list = []
        for i, row in pd.read_csv(csv_file).iterrows():
            self.data_list.append(
                (row['category'], row['sequence'].upper(), row['degradation'], row['sequence_id'], row['polymer_id']))
        (cache_dir := Path('cache')).mkdir(parents=True, exist_ok=True)
        Path(cache_dir, 'protein').mkdir(parents=True, exist_ok=True)
        Path(cache_dir, 'protein', plm).mkdir(parents=True, exist_ok=True)
        Path(cache_dir, 'polymer').mkdir(parents=True, exist_ok=True)
        # Only instantiate the (large) PLM when at least one embedding is missing.
        # NOTE(review): hard-codes 'cuda' — confirm behavior on CPU-only hosts.
        if not all(Path(cache_dir, 'protein', plm, f"{seqid}.pt").exists() for _, _, _, seqid, _ in self.data_list):
            plm_func = get_model(plm, 'cuda')
            for _, seq, _, seqid, _ in tqdm(self.data_list, desc='Encoding enzyme sequences'):
                seq_path = Path(cache_dir, 'protein', plm, f'{seqid}.pt')
                if not seq_path.exists():
                    seq_tensor = plm_func([seq])
                    torch.save(seq_tensor, seq_path)

    def __len__(self):
        """Number of rows loaded from the CSV."""
        return len(self.data_list)

    def __getitem__(self, idx):
        """Return the raw row tuple; embeddings are loaded lazily by the model."""
        return self.data_list[idx]
39
+
40
+
41
class EnzymeDataModule(L.LightningDataModule):
    """LightningDataModule providing k-fold train/val splits plus a fixed test set.

    Both CSVs are loaded eagerly in __init__ (which also triggers embedding
    precomputation in EnzymeDataset); call setup_k_fold(fold_idx) before
    requesting the train/val dataloaders.
    """

    def __init__(self, args: Args):
        super().__init__()
        self.args = args
        self.train_csv = args.train_csv
        self.test_csv = args.test_csv
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.plm = args.plm

        self.train_val_set = EnzymeDataset(self.train_csv, self.plm)
        self.test_set = EnzymeDataset(self.test_csv, self.plm)

        # Splits are fixed up front (seeded) so every fold sees the same partition.
        self.kfold = KFold(
            n_splits=args.nfolds, shuffle=True,
            random_state=self.args.seed)
        self.indices = list(range(len(self.train_val_set)))
        self.splits = list(self.kfold.split(self.indices))

    def setup_k_fold(self, fold_idx):
        """Select fold *fold_idx*: build train/val subsets and the class-balancing sampler."""
        train_idx, val_idx = self.splits[fold_idx]

        self.train_set = torch.utils.data.Subset(
            self.train_val_set, train_idx)
        self.val_set = torch.utils.data.Subset(
            self.train_val_set, val_idx)
        self.sampler = self.data_sampler()

    def data_sampler(self):
        """Build a WeightedRandomSampler that inversely weights each class's frequency.

        Raises AttributeError if setup_k_fold has not been called yet.
        """
        # Get labels for train_set
        if hasattr(self, 'train_set'):
            # train_set is a Subset, get indices
            indices = self.train_set.indices if hasattr(
                self.train_set, 'indices') else range(len(self.train_set))
            # Element [2] of each row tuple is the 'degradation' label.
            labels = [self.train_val_set[i][2] for i in indices]
            # Compute class weights
            label_counts = pd.Series(labels).value_counts()
            weights = [1.0 / label_counts[label] for label in labels]
            sampler = WeightedRandomSampler(
                weights, num_samples=len(weights), replacement=True)
            return sampler
        else:
            raise AttributeError(
                'train_set not initialized. Call setup_k_fold first.')

    def train_dataloader(self):
        """Training loader; shuffling is delegated to the weighted sampler."""
        return DataLoader(
            self.train_set, batch_size=self.batch_size,
            # shuffle=True,
            num_workers=self.num_workers,
            sampler=self.sampler,
        )

    def val_dataloader(self):
        """Validation loader for the current fold (no shuffling)."""
        return DataLoader(
            self.val_set, batch_size=self.batch_size,
            shuffle=False, num_workers=self.num_workers,)

    def test_dataloader(self):
        """Loader over the held-out test CSV (no shuffling)."""
        return DataLoader(
            self.test_set, batch_size=self.batch_size,
            shuffle=False, num_workers=self.num_workers)
src/models/plm.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
4
+ import re
5
+
6
+
7
def EsmModelInfo(name: str):
    """Look up static metadata for a supported protein language model.

    Args:
        name: model key, e.g. "esm2_t33_650M_UR50D".
    Returns:
        dict with "dim" (hidden size), "layers" (layer count, -1 when not
        applicable) and "model" (checkpoint identifier).
    Raises:
        KeyError: if *name* is not in the registry.
    """
    def spec(dim, layers, model):
        # Helper keeps the registry table compact and uniform.
        return {"dim": dim, "layers": layers, "model": model}

    registry = {
        "esm2_t48_15B_UR50D": spec(5120, 48, "facebook/esm2_t48_15B_UR50D"),
        "esm2_t36_3B_UR50D": spec(2560, 36, "facebook/esm2_t36_3B_UR50D"),
        "esm2_t33_650M_UR50D": spec(1280, 33, "facebook/esm2_t33_650M_UR50D"),
        "esm2_t30_150M_UR50D": spec(640, 30, "facebook/esm2_t30_150M_UR50D"),
        "esm2_t12_35M_UR50D": spec(480, 12, "facebook/esm2_t12_35M_UR50D"),
        "esm2_t6_8M_UR50D": spec(320, 6, "facebook/esm2_t6_8M_UR50D"),
        "esm1b_t33_650M_UR50S": spec(1280, 33, "facebook/esm1b_t33_650M_UR50S"),
        "prot_t5_xl_half_uniref50-enc": spec(1024, 24, "Rostlab/prot_t5_xl_uniref50"),
        "prot_t5_xl_bfd": spec(1024, 24, "Rostlab/prot_t5_xl_bfd"),
        "esmc-6b-2024-12": spec(2560, -1, "esmc-6b-2024-12"),
        "esmc_300m": spec(768, -1, "esmc_300m"),
        "esmc_600m": spec(1152, -1, "esmc_600m"),
    }
    return registry[name]
76
+
77
+
78
# Short display abbreviations for each supported protein language model
# (keys mirror the EsmModelInfo registry entries that have fixed layer counts).
plm2abbr = {
    'esm2_t48_15B_UR50D': 'ESM2_T48',
    'esm2_t36_3B_UR50D': 'ESM2_T36',
    'esm2_t33_650M_UR50D': 'ESM2_T33',
    'esm2_t30_150M_UR50D': 'ESM2_T30',
    'esm2_t12_35M_UR50D': 'ESM2_T12',
    'esm2_t6_8M_UR50D': 'ESM2_T6',
    'esm1b_t33_650M_UR50S': 'ESM1B_T33',
    'prot_t5_xl_half_uniref50-enc': 'PT_UR',
    'prot_t5_xl_bfd': 'PT_BFD',
}
89
+
90
+
91
class EsmEncoder(nn.Module):
    """Wrap a HuggingFace ESM model to embed one protein sequence at a time.

    Long sequences are split into overlapping windows, each window is encoded
    separately, and per-residue outputs are stitched back together so every
    residue is represented exactly once.
    """

    def __init__(self, model_name, dev):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(
            # auto, balanced_low_0
            model_name,
            device_map="balanced",  # spread layers across available devices
            # torch_dtype=torch.float16,
            torch_dtype=torch.float32,
            offload_folder=".cache/offload",
            offload_state_dict=True,
        )
        # The 15B model needs shorter windows to fit in memory.
        if model_name == "facebook/esm2_t48_15B_UR50D":
            self.max_len = 512
        else:
            self.max_len = 960
        self.overlap = 31  # residues of shared context between adjacent windows
        self.model.eval()
        # self.model.half()

    def forward(self, _seqs):
        """Encode a single sequence (batch size must be 1).

        Returns a CPU tensor of shape (seq_len, dim, 2): [..., 0] is the
        embedding-layer output, [..., 1] the last hidden state, per residue.
        """
        with torch.no_grad():
            assert len(_seqs) == 1, "currently only support batch size 1"
            seqs = _seqs[0]
            # left overlappping, right overlappping
            seqs = [
                seqs[max(0, i - self.overlap): (i + self.max_len + self.overlap)]
                for i in range(0, len(seqs), self.max_len)
            ]
            segs = []
            for seq in seqs:
                inputs = self.tokenizer(
                    [seq],
                    return_tensors="pt",
                ).to(self.model.device)
                outputs = (
                    self.model(
                        **inputs).last_hidden_state.squeeze(0).detach().cpu()
                )
                # Embedding-layer output for the same tokens.
                # NOTE(review): forwards all tokenizer outputs to
                # model.embeddings — confirm it accepts attention_mask etc.
                outputs0 = self.model.embeddings(
                    **inputs).squeeze(0).detach().cpu()
                segs.append(torch.stack([outputs0, outputs], dim=-1))
            # Stitch windows: skip the leading special token (index 0) and the
            # overlapped left context so residues are emitted exactly once.
            t = []
            for i in range(len(seqs)):
                if i == 0:
                    t.append(segs[i][1: (1 + self.max_len)])
                elif i == len(seqs) - 1:
                    t.append(segs[i][1 + self.overlap:])
                else:
                    t.append(
                        segs[i][1 + self.overlap: 1 +
                                self.max_len + self.overlap]
                    )
            # Truncate any trailing special tokens from the last window.
            outputs = torch.cat(t, dim=0)[: len(_seqs[0])]
            assert outputs.shape[0] == len(_seqs[0])
            return outputs
148
+
149
+
150
class T5Encoder(nn.Module):
    """Wrap a ProtT5 encoder to embed one protein sequence at a time.

    Mirrors EsmEncoder's windowing scheme: long sequences are split into
    overlapping chunks, encoded in one padded batch, and stitched back into a
    per-residue tensor stacking input embeddings with the last hidden state.
    """

    def __init__(self, name: str, dev) -> None:
        super().__init__()
        self.dev = dev
        if name == "Rostlab/prot_t5_xl_uniref50":
            # Load the tokenizer
            self.tokenizer = T5Tokenizer.from_pretrained(
                "Rostlab/prot_t5_xl_half_uniref50-enc",
                do_lower_case=False,
                legacy=False,
            )
            # Load the model
            self.model = T5EncoderModel.from_pretrained(
                "Rostlab/prot_t5_xl_half_uniref50-enc"
            ).to(dev)
        elif name == "Rostlab/prot_t5_xl_bfd":
            # Load the tokenizer
            self.tokenizer = T5Tokenizer.from_pretrained(
                "Rostlab/prot_t5_xl_bfd",
                do_lower_case=False,
                legacy=False,
            )
            # Load the model
            self.model = T5EncoderModel.from_pretrained("Rostlab/prot_t5_xl_bfd").to(
                dev
            )
        # NOTE(review): any other *name* leaves tokenizer/model unset — the
        # callers in get_model only pass the two names above.

        self.max_len = 960  # start_token, end_token occupy 2 positions
        self.overlap = 31
        self.model.eval()
        self.model.half()  # fp16 inference to halve memory

    def forward(self, _seqs):
        """Encode a single sequence (batch size must be 1).

        Returns a tensor of shape (seq_len, dim, 2): [..., 0] input embedding,
        [..., 1] last hidden state, per residue.
        """
        with torch.no_grad():
            assert len(_seqs) == 1, "currently only support batch size 1"
            seqs = _seqs[0]
            # replace non-amino acids with X
            seqs = re.sub(r"[^A-Z]", "X", seqs)
            # left overlappping, right overlappping
            seqs = [
                seqs[max(0, i - self.overlap)
                     : (i + self.max_len + self.overlap)]
                for i in range(0, len(seqs), self.max_len)
            ]
            # ProtT5 expects space-separated residues.
            input_ids = self.tokenizer.batch_encode_plus(
                [" ".join(list(s)) for s in seqs],
                add_special_tokens=True,
                padding="longest",
            )["input_ids"]
            input_ids = torch.tensor(input_ids).to(self.dev)
            outputs = self.model(input_ids=input_ids)
            outputs0 = self.model.get_input_embeddings()(input_ids)
            outputs = outputs.last_hidden_state
            outputs = torch.stack([outputs0, outputs], dim=-1)
            # Stitch windows: drop the leading token slot and the overlapped
            # left context so every residue is emitted exactly once.
            t = []
            for i in range(len(seqs)):
                if i == 0:
                    t.append(outputs[i, 1: (1 + self.max_len)])
                elif i == len(seqs) - 1:
                    t.append(outputs[i, 1 + self.overlap:])
                else:
                    t.append(
                        outputs[i, 1 + self.overlap: 1 +
                                self.max_len + self.overlap]
                    )

            outputs = torch.cat(t, dim=0)[: len(_seqs[0])]
            assert outputs.shape[0] == len(_seqs[0]), \
                f"outputs shape {outputs.shape} does not match input seqs length {len(_seqs[0])}: {seqs}"
            return outputs
220
+
221
+
222
def get_model(name: str, dev):
    """Instantiate the encoder wrapper for a named protein language model.

    Args:
        name: model key understood by EsmModelInfo.
        dev: target device, passed through to the encoder.
    Returns:
        EsmEncoder for ESM-family models, T5Encoder for ProtT5 models.
    Raises:
        ValueError: if *name* is not a recognized model.
    """
    esm_family = {
        "esm2_t48_15B_UR50D",
        "esm2_t36_3B_UR50D",
        "esm2_t33_650M_UR50D",
        "esm2_t30_150M_UR50D",
        "esm2_t12_35M_UR50D",
        "esm2_t6_8M_UR50D",
        "esm1b_t33_650M_UR50S",
    }
    t5_family = {"prot_t5_xl_half_uniref50-enc", "prot_t5_xl_bfd"}

    if name in esm_family:
        return EsmEncoder(EsmModelInfo(name)["model"], dev)
    if name in t5_family:
        return T5Encoder(EsmModelInfo(name)["model"], dev)
    raise ValueError(f"Unknown model name: {name}")
src/models/polybert.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModel
2
+ import torch
3
+ import torch.nn as nn
4
+
5
# Map from polymer abbreviation (as used in the dataset's category/polymer
# columns) to its repeat-unit pSMILES string, with [*] marking the two
# attachment points of the monomer.
polymer2psmiles = {
    'PHB': '[*]OC(C)CC(=O)[*]',
    'PCL': '[*]OCCCCCC(=O)[*]',
    'PVA': '[*]C(O)C[*]',
    'PPL': '[*]CCC(=O)O[*]',
    'P3HP': '[*]OCCC(=O)[*]',
    'P4HB': '[*]C(=O)CCCO[*]',
    'PEA': '[*]OCCOC(=O)CCCCC(=O)[*]',
    'PES': '[*]OCCOC(=O)CCC(=O)[*]',
    'O-PVA': '[*]C(=O)C[*]',
    'PBS': '[*]C(=O)CCC(=O)OCCCCO[*]',
    'PLA': '[*]C(C)C(=O)O[*]',
    'PEG': '[*]CCO[*]',
    'PBSA': '[*]C(=O)CCC(=O)OCCCCOC(=O)CCC(=O)[*]',
    'PET': '[*]CCOC(=O)c1ccc(C(=O)O[*])cc1',
    'PE': '[*]CC[*]',
    'PMCL': '[*]C(=O)CCC(C)CCO[*]',
    'PEF': '[*]OC(=O)c1oc(C(=O)OCC[*])cc1',
    'PS': '[*]C(c1ccccc1)C[*]',
    'NR': '[*]CC(C)=CC[*]',
    'PHV': '[*]OC(CC)CC(=O)[*]',
}
27
+
28
+
29
class PolyEncoder(nn.Module):
    """polyBERT encoder producing token embeddings for polymer pSMILES strings.

    Inference-only wrapper: gradients are disabled in forward.
    """

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
        self.polyBERT = AutoModel.from_pretrained('kuelumbus/polyBERT')

    def forward(self, psmiles_strings):
        """Encode a single pSMILES string.

        :param psmiles_strings: list containing exactly one pSMILES string.
        :return: last hidden state, shape (1, n_tokens, hidden_dim).
        """
        assert len(psmiles_strings) == 1, "Batch size must be 1 for PolyEncoder"
        encoded_input = self.tokenizer(
            psmiles_strings, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.polyBERT(**encoded_input)
        # model_output[0] is the last_hidden_state of the HF output tuple.
        return model_output[0]
src/models/training.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from models.dataset import EnzymeDataModule
6
+ import lightning as L
7
+ from timm.scheduler.cosine_lr import CosineLRScheduler
8
+ from argparse import Namespace as Args
9
+ import wandb
10
+ import time
11
+ from models.dataset import polymer2psmiles
12
+ from models.plm import EsmModelInfo
13
+ from models.polybert import PolyEncoder
14
+ from models.utils import is_wandb_running
15
+ from pathlib import Path
16
+ from einops import rearrange
17
+ import numpy as np
18
+ from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, precision_score, recall_score
19
+
20
+
21
class CrossAttnLayer(nn.Module):
    """Bidirectional cross-attention between protein and SMILES token streams.

    Each stream is linearly projected into the other's embedding width, then
    serves as key/value for the other stream's multi-head attention. Returns
    both attended streams plus their attention weight matrices.
    """

    def __init__(self, protein_dim, smiles_dim, nheads=8):
        super().__init__()
        self.fc_smiles = nn.Linear(smiles_dim, protein_dim)
        self.fc_protein = nn.Linear(protein_dim, smiles_dim)
        self.smiles2protein = nn.MultiheadAttention(
            smiles_dim, nheads, batch_first=True)
        self.protein2smiles = nn.MultiheadAttention(
            protein_dim, nheads, batch_first=True)

    def forward(self, protein, smiles):
        """Return (smiles_ctx, protein_ctx, smiles_weights, protein_weights)."""
        protein_in_smiles_space = self.fc_protein(protein)  # (B, Lp, smiles_dim)
        smiles_in_protein_space = self.fc_smiles(smiles)    # (B, Ls, protein_dim)
        smiles_ctx, smiles_w = self.smiles2protein(
            smiles, protein_in_smiles_space, protein_in_smiles_space)
        protein_ctx, protein_w = self.protein2smiles(
            protein, smiles_in_protein_space, smiles_in_protein_space)
        return smiles_ctx, protein_ctx, smiles_w, protein_w
38
+
39
+
40
class BaseModel(nn.Module):
    """Cross-attention classifier: fuse the two token streams, mean-pool each,
    concatenate, and project to class logits."""

    def __init__(self, in_dim1, in_dim2, n_classes):
        super().__init__()
        self.attn = CrossAttnLayer(in_dim1, in_dim2)
        self.fc = nn.Linear(in_dim1 + in_dim2, n_classes)

    def forward(self, x):
        """x is a (protein, smiles) pair of token tensors.

        Returns (logits, attention weights of the first stream, attention
        weights of the second stream) in the same order as CrossAttnLayer.
        """
        protein, smiles = x
        first_ctx, second_ctx, first_w, second_w = self.attn(protein, smiles)
        pooled = torch.cat(
            (first_ctx.mean(dim=1), second_ctx.mean(dim=1)), dim=-1)
        logits = self.fc(pooled)
        return logits, first_w, second_w
52
+
53
+
54
class PlasticPredictor(L.LightningModule):
    """LightningModule predicting enzymatic plastic degradation (binary).

    Pairs precomputed, cached protein PLM embeddings with polyBERT pSMILES
    embeddings and feeds them to a cross-attention classifier. Uses manual
    optimization so the cosine LR schedule can be stepped per batch.
    """

    def __init__(self, args: L.LightningModule):
        super().__init__()
        self.args = args
        info = EsmModelInfo(args.plm)
        plm_dim = info['dim']*2  # the first and last layers are concatenated
        pbert_dim = 600  # polyBERT hidden size
        self.model = BaseModel(
            in_dim1=plm_dim, in_dim2=pbert_dim, n_classes=2)

        # In-memory memoization so each embedding is loaded/computed only once.
        self.cached_proteins = {}
        self.cached_smiles = {}

        self.encoder = {}  # trick: use dictionary to exclude modules
        self.encoder['polybert'] = PolyEncoder()

        # Optimizer and LR schedule are driven by hand in training_step.
        self.automatic_optimization = False

    def forward(self, x):
        # Unused: all work happens in step()/training_step().
        pass

    def get_protein_embedding(self, seq_id):
        """Load (and memoize) the precomputed PLM embedding for one sequence id.

        Raises FileNotFoundError when the cache file is missing.
        """
        if seq_id not in self.cached_proteins:
            seq_path = f'cache/protein/{self.args.plm}/{seq_id}.pt'
            if not Path(seq_path).exists():
                raise FileNotFoundError(
                    f"Protein embedding for {seq_id} not found.")
            emb = torch.load(seq_path)
            # Flatten the stacked (embedding, last-layer) pair into one axis.
            emb = rearrange(emb, 'b l d -> b (l d)')
            self.cached_proteins[seq_id] = emb
        return self.cached_proteins[seq_id]

    def get_smiles_embedding(self, polymer):
        """Encode (and memoize) the polymer's pSMILES with polyBERT.

        Token slice [2:-1] drops leading/trailing special tokens.
        """
        smi = polymer2psmiles[polymer]
        # mol = Chem.MolFromSmiles(smi)
        # smi = Chem.MolToSmiles(mol, doRandom=True)
        # # replace * with [*]
        # smi = smi.replace('*', '[*]')

        if smi not in self.cached_smiles:
            # first dimension is 1
            with torch.no_grad(), torch.inference_mode():
                emb = self.encoder['polybert']([smi])[0, 2:-1, :]
            self.cached_smiles[smi] = emb
        return self.cached_smiles[smi]

    def step(self, batch):
        """Shared forward/loss computation for train and validation.

        Returns (labels, positive-class probabilities, cross-entropy loss).
        """
        # NOTE(review): zip(batch) wraps each collated field in a 1-tuple, so
        # every field is read as field[0] below. Row order from EnzymeDataset
        # is (category, sequence, degradation, sequence_id, polymer_id) — so
        # `polymer` here carries the 'category' column; confirm that column
        # holds the polymer abbreviations keyed in polymer2psmiles.
        polymer, seq, deg, seq_id, poly_id = zip(batch)
        seqs = [self.get_protein_embedding(s.item()) for s in seq_id[0]]
        polys = [self.get_smiles_embedding(p) for p in polymer[0]]
        protein_lengths = [len(s) for s in seqs]
        smiles_lengths = [len(p) for p in polys]

        # Pad variable-length token sequences into dense batch tensors.
        seqs = nn.utils.rnn.pad_sequence(
            seqs, batch_first=True).to(self.device)
        polys = nn.utils.rnn.pad_sequence(
            polys, batch_first=True).to(self.device)
        protein_lengths = torch.tensor(
            protein_lengths, dtype=torch.long).to(self.device)
        smiles_lengths = torch.tensor(
            smiles_lengths, dtype=torch.long).to(self.device)

        logits, P_weights, L_weights = self.model((seqs, polys))
        # Flatten the output for cross-entropy loss
        logits = logits.view(-1, 2)
        deg = deg[0].to(self.device)

        loss = F.cross_entropy(logits, deg, reduction='mean')

        # Probability of the positive (degrading) class.
        y_prob = torch.softmax(logits, dim=-1)[:, 1]

        return deg, y_prob, loss

    def training_step(self, batch, batch_idx):
        """Manual-optimization training step: backward, step, schedule, zero."""
        y, y_prob, loss = self.step(batch)
        self.log_dict({"train/loss": loss, }, prog_bar=True)

        self.manual_backward(loss)

        self.optimizers().step()
        self.lr_scheduler_step()
        self.optimizers().zero_grad()

    def validation_step(self, batch, batch_idx):
        """Accumulate labels/probabilities; metrics are computed at epoch end."""
        y, y_prob, loss = self.step(batch)
        self.y.append(y.detach().cpu().numpy())
        self.y_prob.append(y_prob.detach().cpu().numpy())

    def on_validation_epoch_start(self):
        # Reset the per-epoch accumulators.
        self.y, self.y_prob = [], []

    def on_validation_epoch_end(self):
        """Compute AUC/F1/MCC/precision/recall over the whole epoch (0.5 cutoff)."""
        y_prob = np.concatenate(self.y_prob, axis=0)
        y = np.concatenate(self.y, axis=0)
        auc = roc_auc_score(y, y_prob)
        f1 = f1_score(y, y_prob > 0.5)
        mcc = matthews_corrcoef(y, y_prob > 0.5)
        precision = precision_score(y, y_prob > 0.5)
        recall = recall_score(y, y_prob > 0.5)
        self.log_dict({
            "val_auc": auc,
            "val_f1": f1,
            "val_mcc": mcc,
            "val_pre": precision,
            "val_rec": recall,
        }, prog_bar=True)

    def configure_optimizers(self):
        """AdamW plus a timm cosine schedule (10% warmup) stepped manually."""
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.args.lr, weight_decay=self.args.wd
        )
        warmup_steps = round(self.args.t_initial * 0.1)
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=self.args.t_initial,
            lr_min=1e-5,
            warmup_t=warmup_steps,
            warmup_lr_init=1e-5,
            warmup_prefix=True,
        )
        # Kept on self (not returned): timm schedulers need step_update(step).
        self.lr_scheduler = lr_scheduler
        return [optimizer]

    def lr_scheduler_step(self, *args, **kwargs):
        """Advance the cosine schedule by global step (manual optimization)."""
        if self.trainer.global_step < self.trainer.max_steps:
            self.lr_scheduler.step_update(self.trainer.global_step)
180
+
181
+
182
def train_plastic(args: Args):
    """Run k-fold training/evaluation of the plastic-degradation predictor.

    For each fold: build the model, fit with early stopping on val_auc,
    re-validate the best checkpoint on the fold's val split and on the held-out
    test set, and (under a wandb sweep) log all metrics with per-fold suffixes.
    """
    L.seed_everything(args.seed)
    if is_wandb_running():
        # Under a sweep, wandb.config overrides the YAML/CLI arguments.
        wandb.init(project="plastic-predictor",)
        args.__dict__.update(dict(wandb.config))

    dm = EnzymeDataModule(args)

    for kfold in range(args.nfolds):
        print(f"Training fold {kfold + 1}/{args.nfolds}")

        # Fresh model per fold; the datamodule just swaps its subsets.
        model = PlasticPredictor(args)
        dm.setup_k_fold(kfold)

        print(
            f'Data loaded: {len(dm.train_dataloader())} train, {len(dm.val_dataloader())} val, {len(dm.test_dataloader())} test')

        devices = 1
        logger = None
        # devices = torch.cuda.device_count()
        # logger = L.pytorch.loggers.WandbLogger(project="plastic-predictor",)

        strategy = "ddp" if devices > 1 else "auto"
        steps_per_epoch = len(dm.train_dataloader())
        # t_initial (total optimizer steps) feeds the cosine LR schedule.
        args.__dict__.update(
            {
                "batch_size": args.batch_size // devices,
                "dev_count": devices,
                "t_initial": args.epochs * steps_per_epoch,
                "steps_per_epoch": steps_per_epoch,
            }
        )
        print(f"Total steps: {args.t_initial}")
        checkpoint = L.pytorch.callbacks.ModelCheckpoint(
            dirpath=args.ckpt_dir,
            filename=f"plastic-{{epoch:02d}}-{{val_auc:.4f}}",
            # monitor="val_auc",
            # mode="max",
        )

        early_stopping = L.pytorch.callbacks.EarlyStopping(
            monitor="val_auc",
            patience=args.patience,
            mode="max",
            verbose=True,
        )
        precision = "16-mixed" if args.amp else "32-true"
        trainer = L.Trainer(
            max_epochs=args.epochs,
            accelerator="gpu",
            devices=devices,
            strategy=strategy,
            precision=precision,
            log_every_n_steps=1,
            callbacks=[checkpoint, early_stopping],
            # callbacks=[checkpoint],
            # enable_checkpointing=False,
            logger=logger,
        )

        trainer.fit(model, dm)

        # Re-score the best checkpoint on the fold's validation split...
        trainer.validate(model, dm.val_dataloader(),
                         ckpt_path="best", verbose=True)
        time.sleep(1)
        val_test_metrics = trainer.callback_metrics.copy()
        # ...then on the held-out test set (validate() reuses the val_ metric
        # names, so they are renamed to test_ below).
        trainer.validate(model, dm.test_dataloader(),
                         ckpt_path="best", verbose=True)

        time.sleep(1)
        val_test_metrics.update(
            [(k.replace("val_", "test_"), v)
             for k, v in trainer.callback_metrics.items()]
        )

        # val_test_metrics['fold'] = kfold + 1
        # add _fold suffix to each key
        val_test_metrics = {
            k + f"_fold{kfold + 1}": v for k, v in val_test_metrics.items()
        }
        if is_wandb_running():
            wandb.log(val_test_metrics)
src/models/utils.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import signal
3
+
4
+ import psutil
5
+ import torch
6
+ import yaml
7
+ from functools import wraps
8
+ import errno
9
+ import signal
10
+ import numpy as np
11
+ from scipy.spatial import KDTree
12
+ from math import ceil
13
+ from tqdm import tqdm
14
+ import line_profiler
15
+ import os
16
+ import base64
17
+ import pickle
18
+ from cryptography.hazmat.primitives.asymmetric import padding
19
+ from cryptography.hazmat.primitives import serialization, hashes
20
+ from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
21
+ from cryptography.hazmat.backends import default_backend
22
+
23
+ import io
24
+
25
+
26
def num_parameters(model: torch.nn.Module) -> int:
    """Return the number of trainable parameters in *model*."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
29
+
30
+
31
class Config:
    """Read configuration from a YAML file and store as attributes"""

    def __init__(self, yaml_file: str):
        self._apply(yaml_file)

    def _apply(self, yaml_file: str):
        # Merge every top-level YAML key onto this instance as an attribute.
        with open(yaml_file, "r") as f:
            loaded = yaml.safe_load(f)
        for key, value in loaded.items():
            setattr(self, key, value)

    def update(self, new_yaml_file: str):
        """Overlay keys from another YAML file onto the current config."""
        self._apply(new_yaml_file)

    def save(self, yaml_file: str):
        """Dump all current attributes back out as YAML."""
        with open(yaml_file, "w") as f:
            yaml.dump(self.__dict__, f)
51
+
52
+
53
def memory_usage_psutil():
    """Return the memory usage in percentage like top"""
    # Percentage of total system RAM held by the current process
    # (psutil's memory_percent, RSS-based by default).
    process = psutil.Process(os.getpid())
    mem = process.memory_percent()
    return mem
58
+
59
+
60
def is_wandb_running():
    """Check whether this process was launched by a wandb sweep agent."""
    return os.environ.get("WANDB_SWEEP_ID") is not None
63
+
64
+
65
class TimeoutError(Exception):
    # Raised by the timeout() decorator when a call exceeds its time budget.
    # NOTE(review): shadows the builtin TimeoutError; renaming would change
    # this module's public surface, so the name is kept as-is.
    pass
67
+
68
+
69
def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Decorator factory: raise TimeoutError if the call exceeds *seconds*.

    Implemented with SIGALRM, so it is Unix-only and must run in the main
    thread. The pending alarm is always cancelled, even when the wrapped
    function raises.
    """
    def decorator(func):
        def _on_alarm(signum, frame):
            raise TimeoutError(error_message)

        @wraps(func)
        def wrapper(*args, **kwargs):
            signal.signal(signal.SIGALRM, _on_alarm)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # cancel the alarm on every exit path

        return wrapper

    return decorator
86
+
87
+
88
def shorten_path(path: str, max_len: int = 30) -> str:
    """Abbreviate *path* by eliding its middle when longer than *max_len*.

    Short paths are returned unchanged; long ones keep the first and last
    max_len // 2 characters joined by "..." (so the result may be up to
    max_len + 3 characters long).
    """
    if len(path) <= max_len:
        return path
    head = path[:max_len // 2]
    tail = path[-max_len // 2:]
    return f"{head}...{tail}"
93
+
94
+
95
def cluster_points(data: torch.Tensor, d: float) -> torch.Tensor:
    """
    Cluster points based on the Euclidean distance.

    Greedy single pass: each not-yet-assigned point becomes the seed of a new
    cluster and claims every still-unassigned point within distance *d*.

    :param data: Input data, shape (n_points, n_features), type torch.Tensor.
    :param d: Distance threshold for clustering.
    :return: Cluster indices, shape (n_points,), type torch.Tensor.
    """
    dist = torch.cdist(data, data)
    indices = torch.full((data.shape[0],), -1, dtype=torch.long)
    cluster_id = 0
    for i in range(data.shape[0]):
        if indices[i] == -1:
            # Bug fix: claim only unassigned points — previously a later seed
            # could overwrite labels already assigned by an earlier cluster,
            # silently moving points between clusters.
            indices[(dist[i] < d) & (indices == -1)] = cluster_id
            cluster_id += 1
    return indices
111
+
112
+
113
def bron_kerbosch(R, P, X, graph):
    """Yield all maximal cliques extending clique R (Bron–Kerbosch, no pivot).

    P is the candidate set, X the excluded set; *graph* maps each vertex to an
    iterable of its neighbours. P and X are mutated in place by the recursion.
    """
    if not P and not X:
        yield R
    while P:
        vertex = P.pop()
        neighbours = set(graph[vertex])
        yield from bron_kerbosch(
            R | {vertex},
            P & neighbours,
            X & neighbours,
            graph,
        )
        X.add(vertex)
125
+
126
+
127
def find_cliques(graph):
    """
    Enumerate all maximal cliques of an undirected graph.

    Thin wrapper seeding :func:`bron_kerbosch` with every node as a candidate.

    :param graph: NetworkX-like graph (needs ``.nodes()`` and ``graph[v]``).
    :return: list of maximal cliques (sets of nodes).
    """
    candidates = set(graph.nodes())
    return [clique for clique in bron_kerbosch(set(), candidates, set(), graph)]
135
+
136
+
137
def segment_cmd(cmd_str: str, max_len: int = 1000):
    """Split a ';'-separated command string into chunks of roughly *max_len* chars.

    Chunks always end on a ';' boundary; a single command longer than
    *max_len* is kept whole.  NOTE(review): any trailing text after the last
    ';' is dropped — confirm callers always pass ';'-terminated strings.
    """
    segments = ['']
    start = 0
    for pos, ch in enumerate(cmd_str):
        if ch != ';':
            continue
        piece = cmd_str[start:pos + 1]
        # Open a new chunk when appending this command would overflow the
        # current one (the trailing ';' is not counted, as in the original).
        if len(segments[-1]) + len(piece) - 1 > max_len:
            segments.append('')
        segments[-1] += piece
        start = pos + 1
    return segments
147
+
148
+
149
def get_color(v):
    """Map a scalar in [0, 1] to a PyMOL-style RGB string.

    Linearly interpolates from green (v=0) to brown (v=1).

    :param v: position along the gradient; must lie in [0, 1].
    :return: string '[r,g,b]' with each component in [0, 1], two decimals.
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    # green to brown
    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    # float(v) so integer inputs (0 or 1) don't yield an int array; the
    # original's in-place `v /= 255` raised a cast error for int v.
    rgb = (float(v) * (color2 - color1) + color1) / 255
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'
157
+
158
+
159
def generate_pymol_script(possible_sites):
    """Build a PyMOL command string placing a blue pseudoatom at each site.

    :param possible_sites: iterable of (x, y, z) coordinates.
    :return: single ';'-separated PyMOL command string.
    """
    parts = []
    for idx, site in enumerate(possible_sites):
        parts.append(
            f"pseudoatom s{idx},pos=[{site[0]:.1f},{site[1]:.1f},{site[2]:.1f}];"
            f"color blue,s{idx};")
    return ''.join(parts)
164
+
165
+
166
def remove_close_points_kdtree(points, min_distance):
    """Greedily thin *points* so no two kept points lie within *min_distance*.

    Earlier points win: each surviving point suppresses every other point in
    its ``min_distance`` ball (the KD-tree ball query is radius-inclusive).

    :param points: (n, dim) array of coordinates.
    :param min_distance: suppression radius.
    :return: subset of *points* (kept rows, original order).
    """
    tree = KDTree(points)
    survivors = np.ones(len(points), dtype=bool)
    for idx, pt in enumerate(points):
        if not survivors[idx]:
            continue
        close = tree.query_ball_point(pt, min_distance)
        survivors[close] = False
        survivors[idx] = True  # the anchor point itself always stays
    return points[survivors]
177
+
178
+
179
@line_profiler.profile
def pack_bit(x: torch.Tensor):
    """Pack a (batch, num_bits) tensor of 0/1 values into bytes.

    Bit ``i`` of each row lands in byte ``i // 8`` at bit position ``i % 8``
    (little-endian within a byte).

    Args:
        x (torch.Tensor): (batch, num_bits) tensor of 0/1 values.
    Returns:
        torch.Tensor: (batch, ceil(num_bits / 8)) uint8 tensor.
    """
    n_rows, n_bits = x.shape
    n_bytes = (n_bits + 7) // 8
    packed = torch.zeros(n_rows, n_bytes,
                         dtype=torch.uint8, device=x.device)
    for bit in range(n_bits):
        packed[:, bit // 8] |= (x[:, bit] << (bit % 8)).to(torch.uint8)
    return packed
196
+
197
+
198
@line_profiler.profile
def unpack_bit(x: torch.Tensor, num_bits: int):
    """Inverse of ``pack_bit``: expand packed bytes back into individual bits.

    Args:
        x (torch.Tensor): (batch, num_bytes) uint8 tensor.
        num_bits (int): number of bits to recover per row.
    Returns:
        torch.Tensor: (batch, num_bits) uint8 tensor of 0/1 values.
    """
    n_rows, _n_bytes = x.shape
    bits = torch.zeros(n_rows, num_bits,
                       dtype=torch.uint8, device=x.device)
    for bit in range(num_bits):
        bits[:, bit] = (x[:, bit // 8] >> (bit % 8)) & 1
    return bits
215
+
216
+
217
def safe_dist(vec1: torch.Tensor, vec2: torch.Tensor, max_size: int = 100_000_000, p: int = 2):
    """Minimum p-norm distance from each point of ``vec2`` to the set ``vec1``.

    A full (N, M) cdist matrix can exhaust memory, so ``vec2`` is processed in
    chunks sized to keep each partial matrix around ``max_size`` entries.

    vec1: (N, 3) reference points (e.g. all atom coordinates); N may be huge.

    vec2: (M, 3) query points (e.g. binding-site coordinates); M modest.

    max_size: soft cap on the number of distance-matrix entries per chunk.

    p: order of the norm forwarded to ``torch.cdist``.

    return: (M,) tensor — minimum distance of each ``vec2`` point to ``vec1``.
    """
    n_ref = vec1.shape[0]
    chunk = ceil(max_size / n_ref)
    minima = []
    for start in range(0, vec2.shape[0], chunk):
        block = torch.cdist(vec1, vec2[start:start + chunk], p=p)
        minima.append(block.min(dim=0).values)
    return torch.cat(minima)
238
+
239
+
240
@line_profiler.profile
def safe_filter(nos: torch.Tensor, pos: torch.Tensor, thr: torch.Tensor, all: torch.Tensor, lb: float, max_size: int = 100_000_000):
    """Select candidate positions whose site distances fall inside per-site bounds.

    ``pos`` is processed in chunks so each (chunk, N) distance matrix stays
    around ``max_size`` entries.

    nos: (N, 3) binding-site coordinates.
    pos: (M, 3) candidate coordinates; M may be very large.
    thr: (N, 2) per-site [lower, upper] distance thresholds.
    all: (P, 3) coordinates of every atom in the protein.
        NOTE(review): parameter name shadows the ``all`` builtin.
    lb: minimum allowed distance from a candidate to any protein atom.
    max_size: soft cap on distance-matrix entries per chunk.

    return: tuple of
        - bit-packed site masks: the transposed output of ``pack_bit``, one
          column per kept candidate, concatenated over chunks, and
        - (M,) bool tensor marking which candidates survived the filter.
    """
    N, M, P = nos.shape[0], pos.shape[0], all.shape[0]
    batch_size = ceil(max_size / N)
    output = []
    interests = []
    for i in tqdm(range(0, M, batch_size), leave=False, desc=f'Filtering (batch_size: {batch_size})'):
        # (chunk, N) candidate-to-site distances turned into a boolean
        # "within [thr_lo, thr_hi]" mask per (candidate, site) pair.
        dist = torch.cdist(pos[i:i + batch_size], nos)
        dist = (dist <= thr[:, 1].unsqueeze(0)) & \
            (dist >= thr[:, 0].unsqueeze(0))
        # Additionally drop candidates closer than lb to any protein atom.
        dist_all = safe_dist(all, pos[i:i + batch_size]) > lb
        dist = dist & dist_all.unsqueeze(-1)

        # Keep only candidates matching at least one site; pack their per-site
        # boolean rows into bytes to save memory.
        mask = dist.any(dim=1)
        output.append(pack_bit(dist[mask]).T)
        interests.append(mask)
    return torch.cat(output, dim=1), torch.cat(interests)
266
+
267
+
268
def backbone(atoms, chain_id):
    """Return the alpha-carbon (CA) atoms of the given chain.

    :param atoms: atom array with ``chain_id``, ``atom_name`` and ``element``
        fields supporting boolean-mask indexing.
    :param chain_id: chain identifier to select.
    """
    selection = (
        (atoms.chain_id == chain_id)
        & (atoms.atom_name == "CA")
        & (atoms.element == "C")
    )
    return atoms[selection]
274
+
275
+
276
def get_color(v):
    """Interpolate green (v=0) → brown (v=1) and format as a '[r,g,b]' string.

    NOTE(review): this duplicates the identical ``get_color`` defined earlier
    in the module; at import time this later definition wins.
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    start_rgb = np.array([0, 128, 0])   # green
    end_rgb = np.array([165, 42, 42])   # brown
    mixed = v * (end_rgb - start_rgb) + start_rgb
    mixed /= 255
    return f'[{mixed[0]:.2f},{mixed[1]:.2f},{mixed[2]:.2f}]'
284
+
285
+
286
def load_private_key_from_file(private_key_file=None):
    """Load an RSA private key from a base64-encoded PEM blob.

    :param private_key_file: path to a file holding the base64 text; when
        ``None``, the blob is read from the ``ModelCheckpointPrivateKey``
        environment variable instead.
    :return: the deserialized private key object.
    """
    if private_key_file is None:
        encoded = os.environ.get('ModelCheckpointPrivateKey')
    else:
        with open(private_key_file, 'r') as f:
            encoded = f.read().strip()

    pem_bytes = base64.b64decode(encoded)
    return serialization.load_pem_private_key(
        pem_bytes,
        password=None,
        backend=default_backend()
    )
300
+
301
+
302
def decrypt_checkpoint(encrypted_path, private_key):
    """Decrypt a hybrid RSA/AES-encrypted checkpoint file.

    File layout: 4-byte big-endian length of the RSA-wrapped AES key, the
    wrapped AES key itself, a 16-byte CBC IV, an 8-byte big-endian original
    (unpadded) payload size, then the AES-CBC ciphertext.

    :param encrypted_path: path to the encrypted checkpoint file.
    :param private_key: RSA private key used to unwrap the AES session key.
    :return: the deserialized checkpoint (torch payload, or pickle fallback).
    :raises Exception: any decryption/deserialization failure is logged and
        re-raised.
    """
    backend = default_backend()

    with open(encrypted_path, 'rb') as f:
        key_length = int.from_bytes(f.read(4), 'big')
        encrypted_aes_key = f.read(key_length)
        iv = f.read(16)
        original_size = int.from_bytes(f.read(8), 'big')
        encrypted_data = f.read()

    try:
        # Unwrap the AES session key with RSA-OAEP(SHA-256).
        aes_key = private_key.decrypt(
            encrypted_aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

        cipher = Cipher(algorithms.AES(aes_key),
                        modes.CBC(iv), backend=backend)
        decryptor = cipher.decryptor()
        decrypted_padded = decryptor.update(
            encrypted_data) + decryptor.finalize()

        # Strip the CBC block padding by truncating to the recorded size.
        decrypted_data = decrypted_padded[:original_size]

        # Prefer torch's loader; fall back to raw pickle for legacy payloads.
        # SECURITY: both torch.load and pickle.loads can execute arbitrary
        # code from the payload — only feed this function trusted,
        # key-protected checkpoints.
        try:
            buffer = io.BytesIO(decrypted_data)
            return torch.load(buffer, map_location='cpu')
        except Exception:
            # Was a bare `except:`, which also swallowed SystemExit and
            # KeyboardInterrupt; narrowed to Exception.
            return pickle.loads(decrypted_data)

    except Exception as e:
        print(f"Error: {e}")
        raise
src/predict.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from einops import rearrange
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import time
7
+ from models.polybert import PolyEncoder
8
+ from models.training import BaseModel
9
+ from models.utils import decrypt_checkpoint, load_private_key_from_file
10
+ import argparse
11
+ from tqdm import tqdm
12
+
13
+ from models.utils import Config
14
+ from models.plm import EsmModelInfo, get_model
15
+ import pandas as pd
16
+
17
+
18
if __name__ == "__main__":
    # CLI entry point: score protein/polymer pairs from a CSV and write
    # per-row degradation probabilities (plus optional attention dumps).
    # fmt: off
    parser = argparse.ArgumentParser(description="Predict plastic degradation")
    parser.add_argument("--ckpt", type=str, help="Path to the model checkpoint")
    parser.add_argument("--plm", type=str, help="Protein language model to use", default='esm2_t33_650M_UR50D')
    parser.add_argument("--csv", type=str, help="Path to the CSV file with test data", default=None)
    parser.add_argument("--output",'-o', type=str, help="Path to the output file", default='predictions.csv')
    parser.add_argument("--attn", action='store_true', help="Save attention weights to files")
    # fmt: on
    args = parser.parse_args()

    # NOTE(review): the *2 factor presumably reflects two concatenated pooled
    # vectors per sequence from the PLM — confirm against models.plm.
    info = EsmModelInfo(args.plm)
    plm_dim = info['dim']*2
    pbert_dim = 600

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BaseModel(plm_dim, pbert_dim, n_classes=2).to(dev)

    # load weights (the checkpoint is encrypted; the RSA key comes from the
    # ModelCheckpointPrivateKey env var when no key file is supplied)
    private_key = load_private_key_from_file()
    state_dict = decrypt_checkpoint(args.ckpt, private_key)
    # Keep only the inner model's weights, dropping the 'model.' prefix.
    state_dict = {
        k.replace('model.', ''): v for k, v in state_dict['state_dict'].items() if k.startswith('model.')}
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Load predictor from {args.ckpt}')

    # NOTE(review): the PLM is loaded on 'cuda' unconditionally while the
    # predictor uses `dev`; this looks like it would fail on CPU-only
    # machines — confirm.
    plm_func = get_model(args.plm, 'cuda')
    print(f'Loaded PLM model {args.plm}')

    polybert_func = PolyEncoder()
    print('Loaded PolyEncoder model')

    # args.output already defaults to 'predictions.csv', so the None check
    # below is redundant but harmless.
    outfile = Path(
        'predictions.csv' if args.output is None else args.output)
    # get protein embedding
    with torch.no_grad(), torch.inference_mode():
        df = pd.read_csv(args.csv)
        probs = []
        running_time = []
        for i, row in tqdm(df.iterrows()):
            start_time = time.time()

            seq = row['sequence'].upper()
            poly = row['polymer']
            seq_emb = plm_func([seq]).to(dev)
            # Flatten the per-sequence embedding into one vector, then add a
            # leading batch axis for the predictor.
            seq_emb = rearrange(seq_emb, 'b l d -> b (l d)').unsqueeze(0)
            poly_emb = polybert_func([poly]).to(dev)
            logits, p_weights, l_weights = model((seq_emb, poly_emb))
            # Probability of the positive ("degrades") class.
            prob = F.softmax(logits, dim=-1)[:, 1].item()
            if args.attn:
                # One <row-index>.pt per input row under '<output>.attn/'.
                outfile.with_suffix('.attn').mkdir(
                    parents=True, exist_ok=True)
                torch.save(
                    (p_weights, l_weights),
                    outfile.with_suffix('.attn') / f'{i}.pt')
            probs.append(prob)
            running_time.append(time.time() - start_time)

        df['prob'] = probs
        df['pred'] = df['prob'].apply(lambda x: 'Yes' if x >= 0.5 else 'No')
        df['time'] = running_time

        # move pred and prob to the front
        df = df[['pred', 'prob'] +
                [col for col in df.columns if col not in ['pred', 'prob']]]

        df.to_csv(outfile, index=False)
        print(f'Predictions saved to {outfile}')
        # NOTE(review): attention files actually go to '<output>.attn/', not
        # the current directory as this message suggests.
        print(f'Attention weights saved to current directory as <index>.pt')
src/streamlit_app.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #fmt: off
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import os
5
+ import tempfile
6
+ import subprocess
7
+ import requests
8
+ import csv
9
+ from models.polybert import polymer2psmiles
10
+ import py3Dmol
11
+
12
# Fix for permission error - disable usage stats
# Streamlit needs a writable config dir; on Hugging Face Spaces only /tmp is
# guaranteed writable, so point STREAMLIT_CONFIG_DIR there by default.
if 'STREAMLIT_CONFIG_DIR' not in os.environ:
    os.environ['STREAMLIT_CONFIG_DIR'] = '/tmp/.streamlit'

# Create streamlit config directory if it doesn't exist
streamlit_dir = os.environ.get('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
os.makedirs(streamlit_dir, exist_ok=True)

# Create config.toml to disable usage stats
# (written once; an existing config file is never overwritten)
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        f.write("""[browser]
gatherUsageStats = false

[server]
headless = true
enableCORS = false
enableXsrfProtection = false
""")
# fmt: on
32
+ # fmt: on
33
+
34
# One-letter to three-letter (PDB residue-name) codes for the 20 standard
# amino acids.
aa2resn = {
    'A': 'ALA',
    'C': 'CYS',
    'D': 'ASP',
    'E': 'GLU',
    'F': 'PHE',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'K': 'LYS',
    'L': 'LEU',
    'M': 'MET',
    'N': 'ASN',
    'P': 'PRO',
    'Q': 'GLN',
    'R': 'ARG',
    'S': 'SER',
    'T': 'THR',
    'V': 'VAL',
    'W': 'TRP',
    'Y': 'TYR'
}
56
+
57
# Fancy header
st.markdown("""
<div style='text-align: center;'>
<h1 style='color:#377EB9;font-size:2.5em;'>🧬 Plastic Degradation Predictor</h1>
<h3 style='color:#4DAE48;'>Predict the degradability of plastics using protein sequences and polymer SMILES</h3>
</div>
<hr style='border:1px solid #974F9F;'>
""", unsafe_allow_html=True)

st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")


# Load polymer names and SMILES

# Only show polymers with SMILES in the dropdown
# (names come from the bundled CSV; SMILES from models.polybert's mapping)
polymer_csv = os.path.join(os.path.dirname(
    __file__), 'data/polymer2tok.csv')
polymer_options = []
with open(polymer_csv, newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        name = row['polymer']
        smiles = polymer2psmiles.get(name, '')
        if smiles:  # Only include polymers with SMILES
            # Dropdown entries are rendered as "name | SMILES".
            polymer_options.append(f"{name} | {smiles}")
82
+
83
# Two ways to provide a protein: fetch by UniProt accession, or paste raw.
input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])

if input_type == "UniProt ID":
    uniprot_id = st.text_input("Enter UniProt ID", "P69905")
    sequence = ""
    if uniprot_id:
        # Fetch sequence from UniProt
        url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
        resp = requests.get(url)
        if resp.status_code == 200:
            fasta = resp.text
            # Drop the FASTA header line and re-join the wrapped sequence.
            sequence = "".join(fasta.split("\n")[1:])
            st.success(f"Fetched sequence for {uniprot_id}")
            st.code(sequence)
        else:
            st.error("Failed to fetch sequence from UniProt.")
else:
    sequence = st.text_area("Paste protein sequence",
                            "MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")

polymer = st.selectbox("Select polymer", polymer_options)
# Options are rendered as "name | SMILES"; recover the bare polymer name.
selected_polymer = polymer.split('|')[0].strip() if '|' in polymer else polymer


# Fixed model configuration used for every prediction.
ckpt = "src/checkpoints/weights.ckpt"
plm = "esm2_t33_650M_UR50D"
109
+
110
if st.button("Predict degradation", type="primary"):
    if not sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        # Create temp CSV
        # (one-row input in the format expected by src/predict.py)
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as tmp:
            tmp.write("sequence,polymer\n")
            tmp.write(f"{sequence},{selected_polymer}\n")
            tmp_path = tmp.name
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        st.write("Running prediction...")
        # Run the predictor in a subprocess so its heavy imports / GPU usage
        # stay out of the Streamlit process.
        result = subprocess.run([
            "python", "src/predict.py",
            "--ckpt", ckpt,
            "--plm", plm,
            "--csv", tmp_path,
            "--output", output_path,
            "--attn"
        ], capture_output=True, text=True)
        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            st.markdown(f"""
<div style='background: linear-gradient(90deg, #377EB9 0%, #4DAE48 100%); padding: 1.5em; border-radius: 12px; color: white; margin-bottom: 1em;'>
<h2 style='margin:0;'><span style='font-size:18pt'>✅</span> Prediction Complete!</h2>
<p style='font-size:12pt;'>Your input has been processed. See the results below:</p>
<p style='font-size:12pt;'>Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})</p>
</div>
""", unsafe_allow_html=True)
            st.download_button("⬇️ Download Results", data=df.to_csv(
                index=False), file_name="predictions.csv", type="primary")

            # Show top-N attention residues if attention file exists
            # (predict.py writes '<output>.attn/<row>.pt'; row 0 is ours).
            attn_dir = os.path.join(os.path.dirname(
                output_path), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")
            if os.path.exists(attn_path):
                import torch
                attn = torch.load(attn_path)
                # attn[0][0]: shape (num_heads, seq_len, seq_len) or (1, seq_len, seq_len)
                attn_matrix = attn[0][0] if isinstance(
                    attn[0], (list, tuple)) else attn[0]
                # Average over heads if needed
                if attn_matrix.ndim == 3:
                    attn_matrix = attn_matrix.mean(0)
                # For each residue, sum attention weights
                residue_scores = attn_matrix.sum(0).cpu().numpy()
                topN = min(10, len(residue_scores))
                top_idx = residue_scores.argsort()[::-1][:topN]
                st.markdown(f"**Top {topN} high-attention residues:**")
                st.write(pd.DataFrame({
                    "Amino Acid": [sequence[i] for i in top_idx],
                    "Residue Index": top_idx+1,
                    "Attention Score": residue_scores[top_idx]
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)
171
+
172
# If UniProt ID, try to download AlphaFold structure
structure_path = None

if input_type == "UniProt ID" and uniprot_id:
    af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.cif"
    # If attention available, highlight top residues
    # NOTE(review): highlight_residues is computed but never consumed below —
    # the py3Dmol section recomputes the top residues itself.
    highlight_residues = None
    attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
    attn_path = os.path.join(attn_dir, "0.pt")
    if os.path.exists(attn_path):
        import torch
        attn = torch.load(attn_path)
        # The saved tuple is (p_weights, l_weights); take the protein side.
        attn_matrix = attn[0][0] if isinstance(
            attn[0], (list, tuple)) else attn[0]
        if attn_matrix.ndim == 3:
            # Average over attention heads.
            attn_matrix = attn_matrix.mean(0)
        residue_scores = attn_matrix.sum(0).cpu().numpy()
        topN = min(10, len(residue_scores))
        top_idx = residue_scores.argsort()[::-1][:topN]
        # Molstar selection: list of residue numbers (1-based)
        highlight_residues = [int(i+1) for i in top_idx]

    structure_path = os.path.join(
        tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
    try:
        r = requests.get(af_url)
        if r.status_code == 200:
            with open(structure_path, "wb") as f:
                f.write(r.content)
            st.success(
                f"AlphaFold structure downloaded: {structure_path}")
        else:
            st.warning(
                "AlphaFoldDB structure not found for this UniProt ID.")
    except Exception as e:
        st.warning(f"AlphaFoldDB download error: {e}")
208
+
209
# Render the AlphaFold model with the top-attention residues coloured red.
# (short-circuit: attn_path/structure_path only exist when the UniProt branch
# above ran, which the first two conditions guarantee)
if input_type == "UniProt ID" and uniprot_id and os.path.exists(attn_path) and os.path.exists(structure_path):
    st.markdown("### 3D Structure Visualization (stmol)")
    import torch
    from stmol import showmol
    attn = torch.load(attn_path)
    attn_matrix = attn[0][0] if isinstance(
        attn[0], (list, tuple)) else attn[0]
    if attn_matrix.ndim == 3:
        # Average over attention heads.
        attn_matrix = attn_matrix.mean(0)
    residue_scores = attn_matrix.sum(0).cpu().numpy()
    topN = min(10, len(residue_scores))
    top_idx = residue_scores.argsort()[::-1][:topN]
    # NOTE(review): `labels` is built but never passed to the viewer.
    labels = [
        f"{sequence[i]}{i+1}: {residue_scores[i]:.4g}" for i in top_idx]
    with open(structure_path, "r") as cif_file:
        cif_data = cif_file.read()
    view = py3Dmol.view(width=600, height=400)
    view.addModel(cif_data, "cif")
    view.setStyle({"cartoon": {"color": "lightgray"}})
    for i, idx in enumerate(top_idx):
        # py3Dmol residue numbering is 1-based.
        resi_num = int(idx+1)
        view.setStyle(
            {"resi": resi_num}, {
                "cartoon": {"color": "red"}})
        view.addResLabels(
            {"resi": resi_num},
            {
                "font": 'Arial', "fontColor": 'black',
                "showBackground": False, "screenOffset": {"x": 0, "y": 0}})
    view.zoomTo()
    showmol(view, height=600, width='100%')
240
+
241
+
242
# --- Footer: License and References ---
# Static license notice rendered at the bottom of every page.
st.markdown("""
---
<h4>License</h4>
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)<br>
<a href='https://creativecommons.org/licenses/by-nc-sa/4.0/' target='_blank'>View full license details</a><br>
""", unsafe_allow_html=True)