xushijie committed on
Commit ·
21f308b
0
Parent(s):
add app
Browse files- .gitattributes +36 -0
- .gitignore +3 -0
- .streamlit/config.toml +7 -0
- Dockerfile +29 -0
- README.md +15 -0
- requirements.txt +25 -0
- src/checkpoints/weights.ckpt +3 -0
- src/configs/train.yml +13 -0
- src/configs/tune.yml +17 -0
- src/data/polymer2tok.csv +38 -0
- src/models/dataset.py +102 -0
- src/models/plm.py +239 -0
- src/models/polybert.py +41 -0
- src/models/training.py +263 -0
- src/models/utils.py +342 -0
- src/predict.py +87 -0
- src/streamlit_app.py +248 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
trash/
|
| 2 |
+
__pycache__/
|
| 3 |
+
scripts/
|
.streamlit/config.toml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[browser]
|
| 2 |
+
gatherUsageStats = false
|
| 3 |
+
|
| 4 |
+
[server]
|
| 5 |
+
headless = true
|
| 6 |
+
enableCORS = false
|
| 7 |
+
enableXsrfProtection = false
|
Dockerfile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11
|
| 2 |
+
|
| 3 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 4 |
+
SHELL ["/bin/bash", "-c"]
|
| 5 |
+
RUN apt-get update -y \
|
| 6 |
+
&& apt-get install -y build-essential python3-dev r-base make apt-utils unzip gpg doxygen git curl aria2 vim screen rsync wget locales gfortran mafft libglew-dev libeigen3-dev \
|
| 7 |
+
libpng-dev libfreetype6-dev libxml2-dev \
|
| 8 |
+
libmsgpack-dev python3-pyqt5.qtopengl libglm-dev libnetcdf-dev \
|
| 9 |
+
&& locale-gen en_US.UTF-8 \
|
| 10 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 11 |
+
|
| 12 |
+
WORKDIR /app
|
| 13 |
+
|
| 14 |
+
COPY . .
|
| 15 |
+
|
| 16 |
+
RUN pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cpu
|
| 17 |
+
RUN pip3 install -r requirements.txt
|
| 18 |
+
|
| 19 |
+
# Ensure Hugging Face cache directory exists and is writable
|
| 20 |
+
RUN mkdir -p /app/.cache && chmod 777 /app/.cache
|
| 21 |
+
RUN mkdir -p /app/.cache/offload && chmod 777 /app/.cache/offload
|
| 22 |
+
|
| 23 |
+
EXPOSE 8501
|
| 24 |
+
|
| 25 |
+
HEALTHCHECK CMD curl --fail http://localhost:8501/_stcore/health
|
| 26 |
+
|
| 27 |
+
ENV HF_HOME=/app/.cache
|
| 28 |
+
|
| 29 |
+
ENTRYPOINT ["streamlit", "run", "src/streamlit_app.py", "--server.port=8501", "--server.address=0.0.0.0"]
|
README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: DELM
|
| 3 |
+
emoji: 🧬
|
| 4 |
+
colorFrom: red
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: docker
|
| 7 |
+
app_port: 8501
|
| 8 |
+
tags:
|
| 9 |
+
- streamlit
|
| 10 |
+
pinned: false
|
| 11 |
+
short_description: Prediction of enzymatic degradation
|
| 12 |
+
license: cc-by-nc-sa-4.0
|
| 13 |
+
---
|
| 14 |
+
|
| 15 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tqdm
|
| 2 |
+
transformers
|
| 3 |
+
fair-esm
|
| 4 |
+
lightning
|
| 5 |
+
cupy-cuda11x
|
| 6 |
+
scikit-learn
|
| 7 |
+
line_profiler
|
| 8 |
+
sentence-transformers
|
| 9 |
+
pandas
|
| 10 |
+
openpyxl
|
| 11 |
+
timm
|
| 12 |
+
wandb
|
| 13 |
+
accelerate
|
| 14 |
+
ipykernel
|
| 15 |
+
einops
|
| 16 |
+
SentencePiece
|
| 17 |
+
seaborn
|
| 18 |
+
streamlit
|
| 19 |
+
biotite
|
| 20 |
+
matplotlib
|
| 21 |
+
git+https://github.com/Ramprasad-Group/psmiles.git
|
| 22 |
+
py3Dmol
|
| 23 |
+
stmol
|
| 24 |
+
ipython_genutils
|
| 25 |
+
cryptography
|
src/checkpoints/weights.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:42319d679b0fac05d5e21979771295e9c97b333bc3e13a1ef0f4a8189e47dbc7
|
| 3 |
+
size 122999708
|
src/configs/train.yml
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
plm: esm2_t33_650M_UR50D
|
| 2 |
+
train_csv: data/train_addition.csv
|
| 3 |
+
test_csv: data/test.csv
|
| 4 |
+
batch_size: 32
|
| 5 |
+
epochs: 100
|
| 6 |
+
lr: 0.001
|
| 7 |
+
wd: 0
|
| 8 |
+
num_workers: 4
|
| 9 |
+
amp: true
|
| 10 |
+
seed: 42
|
| 11 |
+
nfolds: 5
|
| 12 |
+
ckpt_dir: checkpoints
|
| 13 |
+
patience: 20
|
src/configs/tune.yml
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
program: train.py
|
| 2 |
+
method: grid
|
| 3 |
+
metric:
|
| 4 |
+
name: val_acc
|
| 5 |
+
goal: maximize
|
| 6 |
+
parameters:
|
| 7 |
+
plm:
|
| 8 |
+
values:
|
| 9 |
+
- esm2_t33_650M_UR50D
|
| 10 |
+
- esm2_t48_15B_UR50D
|
| 11 |
+
- esm2_t36_3B_UR50D
|
| 12 |
+
- esm2_t12_35M_UR50D
|
| 13 |
+
- esm2_t30_150M_UR50D
|
| 14 |
+
- esm2_t6_8M_UR50D
|
| 15 |
+
- esm1b_t33_650M_UR50S
|
| 16 |
+
- prot_t5_xl_half_uniref50-enc
|
| 17 |
+
- prot_t5_xl_bfd
|
src/data/polymer2tok.csv
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
polymer_id,polymer
|
| 2 |
+
0,PE
|
| 3 |
+
1,PET
|
| 4 |
+
2,PCL
|
| 5 |
+
3,PHB
|
| 6 |
+
4,PEF
|
| 7 |
+
5,PBS
|
| 8 |
+
6,PBSA
|
| 9 |
+
7,PLA
|
| 10 |
+
8,PHV
|
| 11 |
+
9,PU
|
| 12 |
+
10,PES
|
| 13 |
+
11,PHA
|
| 14 |
+
12,PHO
|
| 15 |
+
13,PVA
|
| 16 |
+
14,PPL
|
| 17 |
+
15,P3HP
|
| 18 |
+
16,P4HB
|
| 19 |
+
17,PEA
|
| 20 |
+
18,O-PVA
|
| 21 |
+
19,P(3HB-co-3MP)
|
| 22 |
+
20,PEG
|
| 23 |
+
21,PHBV
|
| 24 |
+
22,PHPV
|
| 25 |
+
23,Nylon
|
| 26 |
+
24,PBS-Blend
|
| 27 |
+
25,PBSA-Blend
|
| 28 |
+
26,P3HV
|
| 29 |
+
27,PBAT
|
| 30 |
+
28,PMCL
|
| 31 |
+
29,LDPE
|
| 32 |
+
30,PS
|
| 33 |
+
31,NR
|
| 34 |
+
32,PBSeT
|
| 35 |
+
33,Ecovio-FT
|
| 36 |
+
34,PHBH
|
| 37 |
+
35,PHBVH
|
| 38 |
+
36,Impranil
|
src/models/dataset.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from torch.utils.data import Dataset, DataLoader
|
| 5 |
+
import lightning as L
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from models.plm import get_model
|
| 9 |
+
from models.polybert import PolyEncoder, polymer2psmiles
|
| 10 |
+
from argparse import Namespace as Args
|
| 11 |
+
from sklearn.model_selection import KFold
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from torch.utils.data import WeightedRandomSampler
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class EnzymeDataset(Dataset):
    """Dataset of (category, sequence, degradation, sequence_id, polymer_id) rows.

    Loads the records from *csv_file* and ensures a cached protein-language-model
    embedding exists on disk at ``cache/protein/<plm>/<sequence_id>.pt`` for every
    sequence, encoding the missing ones with the requested PLM.
    """

    def __init__(self, csv_file: str, plm: str):
        frame = pd.read_csv(csv_file)
        self.data_list = [
            (row['category'], row['sequence'].upper(), row['degradation'],
             row['sequence_id'], row['polymer_id'])
            for _, row in frame.iterrows()
        ]

        cache_dir = Path('cache')
        protein_dir = Path(cache_dir, 'protein', plm)
        for directory in (cache_dir,
                          Path(cache_dir, 'protein'),
                          protein_dir,
                          Path(cache_dir, 'polymer')):
            directory.mkdir(parents=True, exist_ok=True)

        missing = not all(
            Path(protein_dir, f"{seqid}.pt").exists()
            for _, _, _, seqid, _ in self.data_list)
        if missing:
            # Only instantiate the (expensive) encoder when something is missing.
            # NOTE(review): device is hard-coded to 'cuda' — confirm a GPU is
            # always available where this dataset is built.
            encoder = get_model(plm, 'cuda')
            for _, seq, _, seqid, _ in tqdm(self.data_list, desc='Encoding enzyme sequences'):
                target = Path(protein_dir, f'{seqid}.pt')
                if not target.exists():
                    torch.save(encoder([seq]), target)

    def __len__(self):
        return len(self.data_list)

    def __getitem__(self, idx):
        return self.data_list[idx]
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class EnzymeDataModule(L.LightningDataModule):
    """Lightning data module wrapping the enzyme/polymer dataset.

    Builds the train+validation and test datasets once, pre-computes all
    K-fold index splits, and exposes fold-specific loaders after
    :meth:`setup_k_fold` has been called.
    """

    def __init__(self, args: Args):
        super().__init__()
        self.args = args
        self.train_csv = args.train_csv
        self.test_csv = args.test_csv
        self.batch_size = args.batch_size
        self.num_workers = args.num_workers
        self.plm = args.plm

        self.train_val_set = EnzymeDataset(self.train_csv, self.plm)
        self.test_set = EnzymeDataset(self.test_csv, self.plm)

        # Pre-compute every fold's (train, val) index split up front so that
        # setup_k_fold() is a cheap lookup.
        self.kfold = KFold(
            n_splits=args.nfolds, shuffle=True,
            random_state=self.args.seed)
        self.indices = list(range(len(self.train_val_set)))
        self.splits = list(self.kfold.split(self.indices))

    def setup_k_fold(self, fold_idx):
        """Select fold *fold_idx*: build train/val subsets and the sampler."""
        train_idx, val_idx = self.splits[fold_idx]
        self.train_set = torch.utils.data.Subset(self.train_val_set, train_idx)
        self.val_set = torch.utils.data.Subset(self.train_val_set, val_idx)
        self.sampler = self.data_sampler()

    def data_sampler(self):
        """Build a class-balanced WeightedRandomSampler over the train subset.

        Raises:
            AttributeError: if called before :meth:`setup_k_fold`.
        """
        if not hasattr(self, 'train_set'):
            raise AttributeError(
                'train_set not initialized. Call setup_k_fold first.')
        # train_set is normally a Subset; fall back to positional indices.
        subset_indices = (self.train_set.indices
                          if hasattr(self.train_set, 'indices')
                          else range(len(self.train_set)))
        # Label is element 2 of each record (the 'degradation' field).
        labels = [self.train_val_set[i][2] for i in subset_indices]
        # Inverse-frequency weights so every class is drawn equally often.
        counts = pd.Series(labels).value_counts()
        weights = [1.0 / counts[label] for label in labels]
        return WeightedRandomSampler(
            weights, num_samples=len(weights), replacement=True)

    def train_dataloader(self):
        # shuffle is intentionally omitted: the weighted sampler already
        # randomizes order, and DataLoader forbids combining the two.
        return DataLoader(
            self.train_set, batch_size=self.batch_size,
            num_workers=self.num_workers,
            sampler=self.sampler,
        )

    def val_dataloader(self):
        return DataLoader(
            self.val_set, batch_size=self.batch_size,
            shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(
            self.test_set, batch_size=self.batch_size,
            shuffle=False, num_workers=self.num_workers)
src/models/plm.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers import AutoModel, AutoTokenizer, T5EncoderModel, T5Tokenizer
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Registry of supported protein language models: embedding dimension,
# number of transformer layers (-1 where not applicable), and the
# checkpoint identifier used to load the model.
_PLM_REGISTRY = {
    "esm2_t48_15B_UR50D": {
        "dim": 5120, "layers": 48, "model": "facebook/esm2_t48_15B_UR50D"},
    "esm2_t36_3B_UR50D": {
        "dim": 2560, "layers": 36, "model": "facebook/esm2_t36_3B_UR50D"},
    "esm2_t33_650M_UR50D": {
        "dim": 1280, "layers": 33, "model": "facebook/esm2_t33_650M_UR50D"},
    "esm2_t30_150M_UR50D": {
        "dim": 640, "layers": 30, "model": "facebook/esm2_t30_150M_UR50D"},
    "esm2_t12_35M_UR50D": {
        "dim": 480, "layers": 12, "model": "facebook/esm2_t12_35M_UR50D"},
    "esm2_t6_8M_UR50D": {
        "dim": 320, "layers": 6, "model": "facebook/esm2_t6_8M_UR50D"},
    "esm1b_t33_650M_UR50S": {
        "dim": 1280, "layers": 33, "model": "facebook/esm1b_t33_650M_UR50S"},
    # NOTE: this key deliberately maps to the full-precision repo name;
    # T5Encoder translates it to the half-precision encoder checkpoint.
    "prot_t5_xl_half_uniref50-enc": {
        "dim": 1024, "layers": 24, "model": "Rostlab/prot_t5_xl_uniref50"},
    "prot_t5_xl_bfd": {
        "dim": 1024, "layers": 24, "model": "Rostlab/prot_t5_xl_bfd"},
    "esmc-6b-2024-12": {
        "dim": 2560, "layers": -1, "model": "esmc-6b-2024-12"},
    "esmc_300m": {
        "dim": 768, "layers": -1, "model": "esmc_300m"},
    "esmc_600m": {
        "dim": 1152, "layers": -1, "model": "esmc_600m"},
}


def EsmModelInfo(name: str):
    """Get model info by name.

    Args:
        name: model name (a key of ``_PLM_REGISTRY``).

    Returns:
        dict with keys ``dim``, ``layers``, ``model``. A fresh copy is
        returned so callers may mutate it safely, matching the original
        behavior of rebuilding the literal on each call.

    Raises:
        KeyError: if *name* is not a known model (same as the original
        inline-dict lookup).
    """
    # Previously the entire table was rebuilt on every call; it is now a
    # module-level constant and only the selected entry is copied per call.
    return dict(_PLM_REGISTRY[name])
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Short display abbreviation for each supported protein language model.
plm2abbr = {
    'esm2_t48_15B_UR50D': 'ESM2_T48',
    'esm2_t36_3B_UR50D': 'ESM2_T36',
    'esm2_t33_650M_UR50D': 'ESM2_T33',
    'esm2_t30_150M_UR50D': 'ESM2_T30',
    'esm2_t12_35M_UR50D': 'ESM2_T12',
    'esm2_t6_8M_UR50D': 'ESM2_T6',
    'esm1b_t33_650M_UR50S': 'ESM1B_T33',
    'prot_t5_xl_half_uniref50-enc': 'PT_UR',
    'prot_t5_xl_bfd': 'PT_BFD',
}
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class EsmEncoder(nn.Module):
    """Per-residue ESM embeddings for a single protein sequence.

    Long sequences are split into overlapping windows of ``max_len``
    residues; window outputs are stitched back together so every residue
    gets exactly one embedding. The result stacks the input embedding and
    the last hidden state along a trailing axis, one row per residue.
    """

    def __init__(self, model_name, dev):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # device_map alternatives: "auto", "balanced_low_0"
        self.model = AutoModel.from_pretrained(
            model_name,
            device_map="balanced",
            torch_dtype=torch.float32,
            offload_folder=".cache/offload",
            offload_state_dict=True,
        )
        # The 15B checkpoint needs shorter windows to fit in memory.
        self.max_len = 512 if model_name == "facebook/esm2_t48_15B_UR50D" else 960
        self.overlap = 31
        self.model.eval()

    def forward(self, _seqs):
        with torch.no_grad():
            assert len(_seqs) == 1, "currently only support batch size 1"
            full_seq = _seqs[0]
            # Overlapping windows: each chunk carries `overlap` residues of
            # left/right context so boundary residues see their neighborhood.
            windows = [
                full_seq[max(0, start - self.overlap):
                         start + self.max_len + self.overlap]
                for start in range(0, len(full_seq), self.max_len)
            ]
            segments = []
            for window in windows:
                inputs = self.tokenizer(
                    [window],
                    return_tensors="pt",
                ).to(self.model.device)
                hidden = (
                    self.model(**inputs)
                    .last_hidden_state.squeeze(0).detach().cpu()
                )
                initial = self.model.embeddings(
                    **inputs).squeeze(0).detach().cpu()
                segments.append(torch.stack([initial, hidden], dim=-1))
            # Stitch windows back together, dropping token position 0 and the
            # duplicated overlap context from every window but the first.
            pieces = []
            last = len(windows) - 1
            for idx, seg in enumerate(segments):
                if idx == 0:
                    pieces.append(seg[1: (1 + self.max_len)])
                elif idx == last:
                    pieces.append(seg[1 + self.overlap:])
                else:
                    pieces.append(
                        seg[1 + self.overlap:
                            1 + self.max_len + self.overlap])
            stitched = torch.cat(pieces, dim=0)[: len(full_seq)]
            assert stitched.shape[0] == len(full_seq)
            return stitched
| 148 |
+
|
| 149 |
+
|
| 150 |
+
class T5Encoder(nn.Module):
    """Per-residue ProtT5 embeddings for a single protein sequence.

    Long sequences are split into overlapping windows of ``max_len``
    residues; window outputs are stitched back together so every residue
    gets exactly one embedding. The result stacks the input embedding and
    the last hidden state along a trailing axis, one row per residue.
    """

    def __init__(self, name: str, dev) -> None:
        super().__init__()
        self.dev = dev
        if name == "Rostlab/prot_t5_xl_uniref50":
            # Load the half-precision encoder-only UniRef50 checkpoint.
            repo = "Rostlab/prot_t5_xl_half_uniref50-enc"
            self.tokenizer = T5Tokenizer.from_pretrained(
                repo,
                do_lower_case=False,
                legacy=False,
            )
            self.model = T5EncoderModel.from_pretrained(repo).to(dev)
        elif name == "Rostlab/prot_t5_xl_bfd":
            self.tokenizer = T5Tokenizer.from_pretrained(
                name,
                do_lower_case=False,
                legacy=False,
            )
            self.model = T5EncoderModel.from_pretrained(name).to(dev)

        self.max_len = 960  # start_token, end_token occupy 2 positions
        self.overlap = 31
        self.model.eval()
        self.model.half()

    def forward(self, _seqs):
        with torch.no_grad():
            assert len(_seqs) == 1, "currently only support batch size 1"
            # Replace anything that is not an upper-case letter with the
            # unknown-residue symbol X (length is preserved).
            cleaned = re.sub(r"[^A-Z]", "X", _seqs[0])
            # Overlapping windows: each chunk carries `overlap` residues of
            # left/right context so boundary residues see their neighborhood.
            windows = [
                cleaned[max(0, start - self.overlap):
                        start + self.max_len + self.overlap]
                for start in range(0, len(cleaned), self.max_len)
            ]
            input_ids = self.tokenizer.batch_encode_plus(
                [" ".join(list(w)) for w in windows],
                add_special_tokens=True,
                padding="longest",
            )["input_ids"]
            input_ids = torch.tensor(input_ids).to(self.dev)
            encoded = self.model(input_ids=input_ids)
            initial = self.model.get_input_embeddings()(input_ids)
            hidden = encoded.last_hidden_state
            stacked = torch.stack([initial, hidden], dim=-1)
            # Stitch windows back together, dropping token position 0 and the
            # duplicated overlap context from every window but the first.
            # NOTE(review): slicing from index 1 assumes a leading special
            # token; T5 tokenizers typically only append </s> — confirm.
            pieces = []
            last = len(windows) - 1
            for idx in range(len(windows)):
                if idx == 0:
                    pieces.append(stacked[idx, 1: (1 + self.max_len)])
                elif idx == last:
                    pieces.append(stacked[idx, 1 + self.overlap:])
                else:
                    pieces.append(
                        stacked[idx, 1 + self.overlap:
                                1 + self.max_len + self.overlap])

            result = torch.cat(pieces, dim=0)[: len(_seqs[0])]
            assert result.shape[0] == len(_seqs[0]), \
                f"outputs shape {result.shape} does not match input seqs length {len(_seqs[0])}: {windows}"
            return result
| 220 |
+
|
| 221 |
+
|
| 222 |
+
def get_model(name: str, dev):
    """Get model by name.

    Returns an ``EsmEncoder`` for ESM checkpoints or a ``T5Encoder`` for
    ProtT5 checkpoints, placed on device *dev*.

    Raises:
        ValueError: if *name* is not a supported model.
    """
    esm_models = {
        "esm2_t48_15B_UR50D",
        "esm2_t36_3B_UR50D",
        "esm2_t33_650M_UR50D",
        "esm2_t30_150M_UR50D",
        "esm2_t12_35M_UR50D",
        "esm2_t6_8M_UR50D",
        "esm1b_t33_650M_UR50S",
    }
    t5_models = {"prot_t5_xl_half_uniref50-enc", "prot_t5_xl_bfd"}
    if name in esm_models:
        return EsmEncoder(EsmModelInfo(name)["model"], dev)
    if name in t5_models:
        return T5Encoder(EsmModelInfo(name)["model"], dev)
    raise ValueError(f"Unknown model name: {name}")
src/models/polybert.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from transformers import AutoTokenizer, AutoModel
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
|
| 5 |
+
# Polymer abbreviation -> pSMILES repeat-unit string ([*] marks the two
# attachment points of the monomer).
polymer2psmiles = {
    'PHB': '[*]OC(C)CC(=O)[*]',
    'PCL': '[*]OCCCCCC(=O)[*]',
    'PVA': '[*]C(O)C[*]',
    'PPL': '[*]CCC(=O)O[*]',
    'P3HP': '[*]OCCC(=O)[*]',
    'P4HB': '[*]C(=O)CCCO[*]',
    'PEA': '[*]OCCOC(=O)CCCCC(=O)[*]',
    'PES': '[*]OCCOC(=O)CCC(=O)[*]',
    'O-PVA': '[*]C(=O)C[*]',
    'PBS': '[*]C(=O)CCC(=O)OCCCCO[*]',
    'PLA': '[*]C(C)C(=O)O[*]',
    'PEG': '[*]CCO[*]',
    'PBSA': '[*]C(=O)CCC(=O)OCCCCOC(=O)CCC(=O)[*]',
    'PET': '[*]CCOC(=O)c1ccc(C(=O)O[*])cc1',
    'PE': '[*]CC[*]',
    'PMCL': '[*]C(=O)CCC(C)CCO[*]',
    'PEF': '[*]OC(=O)c1oc(C(=O)OCC[*])cc1',
    'PS': '[*]C(c1ccccc1)C[*]',
    'NR': '[*]CC(C)=CC[*]',
    'PHV': '[*]OC(CC)CC(=O)[*]',
}
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class PolyEncoder(nn.Module):
    """polyBERT encoder producing per-token embeddings for one pSMILES string."""

    def __init__(self):
        super().__init__()
        self.tokenizer = AutoTokenizer.from_pretrained('kuelumbus/polyBERT')
        self.polyBERT = AutoModel.from_pretrained('kuelumbus/polyBERT')

    def forward(self, psmiles_strings):
        assert len(psmiles_strings) == 1, "Batch size must be 1 for PolyEncoder"
        tokens = self.tokenizer(
            psmiles_strings, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            output = self.polyBERT(**tokens)
        # First element of the model output: the token-level hidden states.
        return output[0]
src/models/training.py
ADDED
|
@@ -0,0 +1,263 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
from models.dataset import EnzymeDataModule
|
| 6 |
+
import lightning as L
|
| 7 |
+
from timm.scheduler.cosine_lr import CosineLRScheduler
|
| 8 |
+
from argparse import Namespace as Args
|
| 9 |
+
import wandb
|
| 10 |
+
import time
|
| 11 |
+
from models.dataset import polymer2psmiles
|
| 12 |
+
from models.plm import EsmModelInfo
|
| 13 |
+
from models.polybert import PolyEncoder
|
| 14 |
+
from models.utils import is_wandb_running
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
import numpy as np
|
| 18 |
+
from sklearn.metrics import roc_auc_score, f1_score, matthews_corrcoef, precision_score, recall_score
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class CrossAttnLayer(nn.Module):
    """Bidirectional cross-attention between protein and SMILES token streams.

    Each modality is linearly projected into the other's embedding width so
    it can serve as keys/values for the other's attention. Returns, in order:
    the smiles-query attention output (width ``smiles_dim``), the
    protein-query attention output (width ``protein_dim``), and the two
    corresponding attention-weight tensors.
    """

    def __init__(self, protein_dim, smiles_dim, nheads=8):
        super().__init__()
        # Projections mapping one modality into the other's embedding width.
        self.fc_smiles = nn.Linear(smiles_dim, protein_dim)
        self.fc_protein = nn.Linear(protein_dim, smiles_dim)
        self.smiles2protein = nn.MultiheadAttention(
            smiles_dim, nheads, batch_first=True)
        self.protein2smiles = nn.MultiheadAttention(
            protein_dim, nheads, batch_first=True)

    def forward(self, protein, smiles):
        protein_as_smiles = self.fc_protein(protein)
        smiles_as_protein = self.fc_smiles(smiles)
        # SMILES tokens attend over (projected) protein tokens, and vice versa.
        ligand_ctx, ligand_w = self.smiles2protein(
            smiles, protein_as_smiles, protein_as_smiles)
        protein_ctx, protein_w = self.protein2smiles(
            protein, smiles_as_protein, smiles_as_protein)
        return ligand_ctx, protein_ctx, ligand_w, protein_w
| 38 |
+
|
| 39 |
+
|
| 40 |
+
class BaseModel(nn.Module):
    """Cross-attention classifier over (protein, smiles) token embeddings.

    Mean-pools both attended streams, concatenates them, and applies a
    linear classification head.
    """

    def __init__(self, in_dim1, in_dim2, n_classes):
        super().__init__()
        self.attn = CrossAttnLayer(in_dim1, in_dim2)
        self.fc = nn.Linear(in_dim1 + in_dim2, n_classes)

    def forward(self, x):
        protein, smiles = x
        # NOTE(review): CrossAttnLayer returns the smiles-query stream first
        # (width in_dim2) and the protein-query stream second (width in_dim1);
        # the original's P/L names were swapped relative to that order, but
        # the concatenated width (in_dim1 + in_dim2) still matches fc's input.
        smiles_ctx, protein_ctx, smiles_w, protein_w = self.attn(protein, smiles)
        pooled = torch.cat(
            (smiles_ctx.mean(dim=1), protein_ctx.mean(dim=1)), dim=-1)
        logits = self.fc(pooled)
        return logits, smiles_w, protein_w
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class PlasticPredictor(L.LightningModule):
|
| 55 |
+
def __init__(self, args: L.LightningModule):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.args = args
|
| 58 |
+
info = EsmModelInfo(args.plm)
|
| 59 |
+
plm_dim = info['dim']*2 # the first and last layers are concatenated
|
| 60 |
+
pbert_dim = 600
|
| 61 |
+
self.model = BaseModel(
|
| 62 |
+
in_dim1=plm_dim, in_dim2=pbert_dim, n_classes=2)
|
| 63 |
+
|
| 64 |
+
self.cached_proteins = {}
|
| 65 |
+
self.cached_smiles = {}
|
| 66 |
+
|
| 67 |
+
self.encoder = {} # trick: use dictionary to exclude modules
|
| 68 |
+
self.encoder['polybert'] = PolyEncoder()
|
| 69 |
+
|
| 70 |
+
self.automatic_optimization = False
|
| 71 |
+
|
| 72 |
+
def forward(self, x):
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
def get_protein_embedding(self, seq_id):
|
| 76 |
+
if seq_id not in self.cached_proteins:
|
| 77 |
+
seq_path = f'cache/protein/{self.args.plm}/{seq_id}.pt'
|
| 78 |
+
if not Path(seq_path).exists():
|
| 79 |
+
raise FileNotFoundError(
|
| 80 |
+
f"Protein embedding for {seq_id} not found.")
|
| 81 |
+
emb = torch.load(seq_path)
|
| 82 |
+
emb = rearrange(emb, 'b l d -> b (l d)')
|
| 83 |
+
self.cached_proteins[seq_id] = emb
|
| 84 |
+
return self.cached_proteins[seq_id]
|
| 85 |
+
|
| 86 |
+
def get_smiles_embedding(self, polymer):
|
| 87 |
+
smi = polymer2psmiles[polymer]
|
| 88 |
+
# mol = Chem.MolFromSmiles(smi)
|
| 89 |
+
# smi = Chem.MolToSmiles(mol, doRandom=True)
|
| 90 |
+
# # replace * with [*]
|
| 91 |
+
# smi = smi.replace('*', '[*]')
|
| 92 |
+
|
| 93 |
+
if smi not in self.cached_smiles:
|
| 94 |
+
# first dimension is 1
|
| 95 |
+
with torch.no_grad(), torch.inference_mode():
|
| 96 |
+
emb = self.encoder['polybert']([smi])[0, 2:-1, :]
|
| 97 |
+
self.cached_smiles[smi] = emb
|
| 98 |
+
return self.cached_smiles[smi]
|
| 99 |
+
|
| 100 |
+
def step(self, batch):
|
| 101 |
+
polymer, seq, deg, seq_id, poly_id = zip(batch)
|
| 102 |
+
seqs = [self.get_protein_embedding(s.item()) for s in seq_id[0]]
|
| 103 |
+
polys = [self.get_smiles_embedding(p) for p in polymer[0]]
|
| 104 |
+
protein_lengths = [len(s) for s in seqs]
|
| 105 |
+
smiles_lengths = [len(p) for p in polys]
|
| 106 |
+
|
| 107 |
+
seqs = nn.utils.rnn.pad_sequence(
|
| 108 |
+
seqs, batch_first=True).to(self.device)
|
| 109 |
+
polys = nn.utils.rnn.pad_sequence(
|
| 110 |
+
polys, batch_first=True).to(self.device)
|
| 111 |
+
protein_lengths = torch.tensor(
|
| 112 |
+
protein_lengths, dtype=torch.long).to(self.device)
|
| 113 |
+
smiles_lengths = torch.tensor(
|
| 114 |
+
smiles_lengths, dtype=torch.long).to(self.device)
|
| 115 |
+
|
| 116 |
+
logits, P_weights, L_weights = self.model((seqs, polys))
|
| 117 |
+
# Flatten the output for cross-entropy loss
|
| 118 |
+
logits = logits.view(-1, 2)
|
| 119 |
+
deg = deg[0].to(self.device)
|
| 120 |
+
|
| 121 |
+
loss = F.cross_entropy(logits, deg, reduction='mean')
|
| 122 |
+
|
| 123 |
+
y_prob = torch.softmax(logits, dim=-1)[:, 1]
|
| 124 |
+
|
| 125 |
+
return deg, y_prob, loss
|
| 126 |
+
|
| 127 |
+
    def training_step(self, batch, batch_idx):
        """Manual-optimisation training step: forward, backward, step, zero.

        NOTE(review): calls self.manual_backward / self.optimizers().step()
        directly, so this module presumably runs with
        ``automatic_optimization = False`` — confirm on the class __init__.
        """
        y, y_prob, loss = self.step(batch)
        self.log_dict({"train/loss": loss, }, prog_bar=True)

        self.manual_backward(loss)

        # Optimiser step, then advance the LR schedule, then clear grads.
        self.optimizers().step()
        self.lr_scheduler_step()
        self.optimizers().zero_grad()
|
| 136 |
+
|
| 137 |
+
    def validation_step(self, batch, batch_idx):
        """Accumulate labels and predicted probabilities for epoch-end metrics."""
        y, y_prob, loss = self.step(batch)
        # Stash per-batch results on CPU; reduced in on_validation_epoch_end.
        self.y.append(y.detach().cpu().numpy())
        self.y_prob.append(y_prob.detach().cpu().numpy())
|
| 141 |
+
|
| 142 |
+
    def on_validation_epoch_start(self):
        """Reset the per-epoch label/probability accumulators."""
        self.y, self.y_prob = [], []
|
| 144 |
+
|
| 145 |
+
    def on_validation_epoch_end(self):
        """Compute AUC/F1/MCC/precision/recall over the full validation epoch."""
        y_prob = np.concatenate(self.y_prob, axis=0)
        y = np.concatenate(self.y, axis=0)
        auc = roc_auc_score(y, y_prob)
        # Hard predictions use a fixed 0.5 probability threshold.
        f1 = f1_score(y, y_prob > 0.5)
        mcc = matthews_corrcoef(y, y_prob > 0.5)
        precision = precision_score(y, y_prob > 0.5)
        recall = recall_score(y, y_prob > 0.5)
        self.log_dict({
            "val_auc": auc,
            "val_f1": f1,
            "val_mcc": mcc,
            "val_pre": precision,
            "val_rec": recall,
        }, prog_bar=True)
|
| 160 |
+
|
| 161 |
+
    def configure_optimizers(self):
        """AdamW plus a cosine LR schedule with 10% linear warmup.

        The scheduler is stepped manually via ``lr_scheduler_step`` (it is a
        ``CosineLRScheduler`` driven by ``step_update``), so only the
        optimizer is returned to Lightning.
        """
        optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=self.args.lr, weight_decay=self.args.wd
        )
        # Warm up for the first 10% of the total training steps (t_initial).
        warmup_steps = round(self.args.t_initial * 0.1)
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=self.args.t_initial,
            lr_min=1e-5,
            warmup_t=warmup_steps,
            warmup_lr_init=1e-5,
            warmup_prefix=True,
        )
        self.lr_scheduler = lr_scheduler
        return [optimizer]
|
| 176 |
+
|
| 177 |
+
    def lr_scheduler_step(self, *args, **kwargs):
        """Advance the cosine scheduler keyed on the trainer's global step.

        Stops updating once global_step reaches max_steps, freezing the LR.
        """
        if self.trainer.global_step < self.trainer.max_steps:
            self.lr_scheduler.step_update(self.trainer.global_step)
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def train_plastic(args: Args):
    """Train the plastic-degradation predictor with k-fold cross-validation.

    For each fold: builds a fresh model, trains with early stopping on
    val_auc, then evaluates the best checkpoint on the fold's validation and
    test splits. When running inside a wandb sweep, hyper-parameters are
    pulled from ``wandb.config`` and per-fold metrics are logged back.
    """
    L.seed_everything(args.seed)
    if is_wandb_running():
        wandb.init(project="plastic-predictor",)
        # Sweep-provided hyper-parameters override the CLI/config values.
        args.__dict__.update(dict(wandb.config))

    dm = EnzymeDataModule(args)

    for kfold in range(args.nfolds):
        print(f"Training fold {kfold + 1}/{args.nfolds}")

        model = PlasticPredictor(args)
        dm.setup_k_fold(kfold)

        print(
            f'Data loaded: {len(dm.train_dataloader())} train, {len(dm.val_dataloader())} val, {len(dm.test_dataloader())} test')

        # Single-device run; the commented lines show the multi-GPU/wandb setup.
        devices = 1
        logger = None
        # devices = torch.cuda.device_count()
        # logger = L.pytorch.loggers.WandbLogger(project="plastic-predictor",)

        strategy = "ddp" if devices > 1 else "auto"
        steps_per_epoch = len(dm.train_dataloader())
        # Derive the total schedule length (t_initial) used by the LR scheduler.
        args.__dict__.update(
            {
                "batch_size": args.batch_size // devices,
                "dev_count": devices,
                "t_initial": args.epochs * steps_per_epoch,
                "steps_per_epoch": steps_per_epoch,
            }
        )
        print(f"Total steps: {args.t_initial}")
        checkpoint = L.pytorch.callbacks.ModelCheckpoint(
            dirpath=args.ckpt_dir,
            filename=f"plastic-{{epoch:02d}}-{{val_auc:.4f}}",
            # monitor="val_auc",
            # mode="max",
        )

        early_stopping = L.pytorch.callbacks.EarlyStopping(
            monitor="val_auc",
            patience=args.patience,
            mode="max",
            verbose=True,
        )
        precision = "16-mixed" if args.amp else "32-true"
        trainer = L.Trainer(
            max_epochs=args.epochs,
            accelerator="gpu",
            devices=devices,
            strategy=strategy,
            precision=precision,
            log_every_n_steps=1,
            callbacks=[checkpoint, early_stopping],
            # callbacks=[checkpoint],
            # enable_checkpointing=False,
            logger=logger,
        )

        trainer.fit(model, dm)

        # Re-evaluate the best checkpoint on the validation split…
        trainer.validate(model, dm.val_dataloader(),
                         ckpt_path="best", verbose=True)
        time.sleep(1)
        val_test_metrics = trainer.callback_metrics.copy()
        # …then on the test split (metrics land in callback_metrics again,
        # still under their "val_*" names).
        trainer.validate(model, dm.test_dataloader(),
                         ckpt_path="best", verbose=True)

        time.sleep(1)
        val_test_metrics.update(
            [(k.replace("val_", "test_"), v)
             for k, v in trainer.callback_metrics.items()]
        )

        # val_test_metrics['fold'] = kfold + 1
        # add _fold suffix to each key
        val_test_metrics = {
            k + f"_fold{kfold + 1}": v for k, v in val_test_metrics.items()
        }
        if is_wandb_running():
            wandb.log(val_test_metrics)
|
src/models/utils.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import signal
|
| 3 |
+
|
| 4 |
+
import psutil
|
| 5 |
+
import torch
|
| 6 |
+
import yaml
|
| 7 |
+
from functools import wraps
|
| 8 |
+
import errno
|
| 9 |
+
import signal
|
| 10 |
+
import numpy as np
|
| 11 |
+
from scipy.spatial import KDTree
|
| 12 |
+
from math import ceil
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import line_profiler
|
| 15 |
+
import os
|
| 16 |
+
import base64
|
| 17 |
+
import pickle
|
| 18 |
+
from cryptography.hazmat.primitives.asymmetric import padding
|
| 19 |
+
from cryptography.hazmat.primitives import serialization, hashes
|
| 20 |
+
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
| 21 |
+
from cryptography.hazmat.backends import default_backend
|
| 22 |
+
|
| 23 |
+
import io
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def num_parameters(model: torch.nn.Module) -> int:
    """Count the trainable parameters of *model*.

    Parameters with ``requires_grad=False`` (frozen weights) are excluded.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class Config:
    """Read configuration from a YAML file and expose its keys as attributes."""

    def __init__(self, yaml_file: str):
        self._absorb(yaml_file)

    def _absorb(self, yaml_file: str):
        # Mirror every top-level key of the YAML mapping onto the instance.
        with open(yaml_file, "r") as f:
            for key, value in yaml.safe_load(f).items():
                setattr(self, key, value)

    def update(self, new_yaml_file: str):
        """Overlay the keys of another YAML file onto this configuration."""
        self._absorb(new_yaml_file)

    def save(self, yaml_file: str):
        """Dump all current attributes back to a YAML file."""
        with open(yaml_file, "w") as f:
            yaml.dump(self.__dict__, f)
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def memory_usage_psutil():
    """Return this process's memory usage as a percentage (like top's %MEM)."""
    return psutil.Process(os.getpid()).memory_percent()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def is_wandb_running():
    """True when running inside a wandb sweep (WANDB_SWEEP_ID is set)."""
    return os.environ.get("WANDB_SWEEP_ID") is not None
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class TimeoutError(Exception):
    """Raised by the timeout() decorator when a call exceeds its time budget.

    NOTE(review): this shadows the builtin ``TimeoutError`` within this
    module — callers catching the builtin will not catch this one.
    """
    pass
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def timeout(seconds=10, error_message=os.strerror(errno.ETIME)):
    """Decorator factory: abort the wrapped call with TimeoutError after *seconds*.

    Uses SIGALRM, so it only works on Unix in the main thread, and a process
    has a single alarm — nested use is not supported.

    :param seconds: wall-clock budget for each call to the wrapped function.
    :param error_message: message carried by the raised TimeoutError.
    """
    def decorator(func):
        def _handle_timeout(signum, frame):
            raise TimeoutError(error_message)

        @wraps(func)
        def wrapper(*args, **kwargs):
            # BUGFIX: remember the previous SIGALRM handler and restore it,
            # instead of permanently replacing whatever handler was installed.
            previous = signal.signal(signal.SIGALRM, _handle_timeout)
            signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                # Always cancel the pending alarm and restore the old handler,
                # even when the wrapped call raises.
                signal.alarm(0)
                signal.signal(signal.SIGALRM, previous)

        return wrapper

    return decorator
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def shorten_path(path: str, max_len: int = 30) -> str:
    """Shorten the path to at most max_len characters with a middle ellipsis.

    BUGFIX: the previous version returned max_len + 3 characters (it did not
    budget for the "..."); the result is now guaranteed to be <= max_len.

    :param path: the path (or any string) to shorten.
    :param max_len: maximum length of the returned string.
    :return: *path* unchanged when it already fits, otherwise
        ``head...tail`` of exactly max_len characters.
    """
    if len(path) <= max_len:
        return path
    if max_len <= 3:
        # Not enough room for an ellipsis; fall back to a plain truncation.
        return path[:max_len]
    keep = max_len - 3            # characters left after the "..."
    head = (keep + 1) // 2        # front gets the extra char when keep is odd
    tail = keep - head
    return path[:head] + "..." + path[len(path) - tail:]
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def cluster_points(data: torch.Tensor, d: float) -> torch.Tensor:
    """
    Cluster points based on the Euclidean distance.

    Greedy single-pass clustering: each still-unassigned point seeds a new
    cluster and absorbs every *unassigned* point within distance d of it.

    :param data: Input data, shape (n_points, n_features), type torch.Tensor.
    :param d: Distance threshold for clustering.
    :return: Cluster indices, shape (n_points,), type torch.Tensor.
    """
    dist = torch.cdist(data, data)
    indices = torch.full((data.shape[0],), -1, dtype=torch.long)
    cluster_id = 0
    for i in range(data.shape[0]):
        if indices[i] == -1:
            # BUGFIX: only claim points that are not yet clustered — the old
            # code re-assigned already-clustered neighbours to the new
            # cluster, leaving earlier clusters inconsistent with their seed.
            mask = (dist[i] < d) & (indices == -1)
            indices[mask] = cluster_id
            cluster_id += 1
    return indices
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def bron_kerbosch(R, P, X, graph):
    """Yield every maximal clique reachable from (R, P, X) — Bron–Kerbosch.

    :param R: set of vertices in the clique being grown.
    :param P: candidate vertices that may still extend R (consumed in place).
    :param X: vertices already exhausted for R (grown in place) — used to
        suppress duplicate reports.
    :param graph: adjacency structure where ``graph[v]`` iterates v's
        neighbours (e.g. a networkx graph).
    """
    if not P and not X:
        # R can no longer be extended and was not reported before: maximal.
        yield R
    while P:
        v = P.pop()
        # Recurse with v added to the clique; candidates and exclusions are
        # restricted to v's neighbourhood.
        yield from bron_kerbosch(
            R | {v},
            P & set(graph[v]),
            X & set(graph[v]),
            graph
        )
        X.add(v)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def find_cliques(graph):
    """
    Find all maximal cliques in an undirected graph with the Bron–Kerbosch algorithm.

    :param graph: Input graph as a NetworkX graph
    :return: List of maximal cliques, each a set of node labels
    """
    return list(bron_kerbosch(set(), set(graph.nodes()), set(), graph))
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def segment_cmd(cmd_str: str, max_len: int = 1000):
    """Split a ';'-separated command string into chunks of roughly max_len chars.

    Commands are never cut in the middle: a new chunk is opened whenever
    appending the next command (not counting its ';') would push the current
    chunk past max_len. Any trailing text after the final ';' is discarded.
    """
    segments = ['']
    start = 0
    for idx, ch in enumerate(cmd_str):
        if ch != ';':
            continue
        piece = cmd_str[start:idx + 1]  # the command text plus its ';'
        if len(segments[-1]) + len(piece) - 1 > max_len:
            segments.append('')
        segments[-1] += piece
        start = idx + 1
    return segments
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
def get_color(v):
    """Map v in [0, 1] to a green→brown RGB string "[r,g,b]".

    Linearly interpolates between green (0, 128, 0) and brown (165, 42, 42),
    scales to [0, 1], and formats each channel with two decimals.
    """
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    green = np.array([0, 128, 0])
    brown = np.array([165, 42, 42])
    rgb = (green + v * (brown - green)) / 255
    return f'[{rgb[0]:.2f},{rgb[1]:.2f},{rgb[2]:.2f}]'
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def generate_pymol_script(possible_sites):
    """Build a PyMOL command string placing one blue pseudoatom per site.

    :param possible_sites: iterable of coordinate triples (x, y, z).
    :return: concatenated ';'-terminated PyMOL commands.
    """
    parts = []
    for idx, pos in enumerate(possible_sites):
        parts.append(
            f"pseudoatom s{idx},pos=[{pos[0]:.1f},{pos[1]:.1f},{pos[2]:.1f}];color blue,s{idx};")
    return ''.join(parts)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def remove_close_points_kdtree(points, min_distance):
    """Greedily thin *points* so no two kept points are within min_distance.

    Points are visited in index order; each surviving point suppresses all of
    its neighbours within min_distance (earlier-indexed points win).

    :param points: (N, D) numpy array of coordinates.
    :param min_distance: exclusion radius around each kept point.
    :return: the surviving subset of rows of *points*.
    """
    tree = KDTree(points)
    survivors = np.ones(len(points), dtype=bool)
    for idx, pt in enumerate(points):
        if not survivors[idx]:
            continue
        # Knock out every neighbour in range, then re-mark this point kept.
        close = tree.query_ball_point(pt, min_distance)
        survivors[close] = False
        survivors[idx] = True
    return points[survivors]
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@line_profiler.profile
|
| 180 |
+
def pack_bit(x: torch.Tensor):
    """ Pack a (batch, num_bits) 0/1 tensor into bytes, LSB-first per byte.
    Args:
        x (torch.Tensor): binary tensor of shape (batch, num_bits).
    Returns:
        torch.Tensor: uint8 tensor of shape (batch, ceil(num_bits / 8));
        bit i of the input lands in bit (i % 8) of byte (i // 8).
    """
    batch_size, num_bits = x.shape
    num_bytes = (num_bits + 7) // 8
    packed = torch.zeros(batch_size, num_bytes,
                         dtype=torch.uint8, device=x.device)
    for bit in range(num_bits):
        packed[:, bit // 8] |= (x[:, bit] << (bit % 8)).to(torch.uint8)
    return packed
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
@line_profiler.profile
|
| 199 |
+
def unpack_bit(x: torch.Tensor, num_bits: int):
    """ Inverse of pack_bit: expand packed bytes back into a 0/1 bit tensor.
    Args:
        x (torch.Tensor): uint8 tensor of shape (batch, num_bytes).
        num_bits (int): number of bits to recover.
    Returns:
        torch.Tensor: uint8 tensor of shape (batch, num_bits) with 0/1 values.
    """
    bits = torch.zeros(x.shape[0], num_bits,
                       dtype=torch.uint8, device=x.device)
    for bit in range(num_bits):
        bits[:, bit] = (x[:, bit // 8] >> (bit % 8)) & 1
    return bits
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def safe_dist(vec1: torch.Tensor, vec2: torch.Tensor, max_size: int = 100_000_000, p: int = 2):
    """ Minimum p-norm distance from each point of vec2 to the cloud vec1.

    vec2 is processed in chunks so the pairwise distance matrix never holds
    more than roughly max_size entries at a time.

    vec1: (N, 3), N could be very very large, i.e., all atoms' coordinates in a large protein

    vec2: (M, 3), M are not very large, usually the coordinates of the binding sites

    max_size: the maximum size of the distance matrix to compute at once

    p: the p-norm to use for distance calculation

    return: (M, ) the minimum distance of each binding site to the protein
    """
    n_ref = vec1.shape[0]
    chunk = ceil(max_size / n_ref)
    minima = [
        torch.cdist(vec1, vec2[start:start + chunk], p=p).min(dim=0).values
        for start in range(0, vec2.shape[0], chunk)
    ]
    return torch.cat(minima)
|
| 238 |
+
|
| 239 |
+
|
| 240 |
+
@line_profiler.profile
def safe_filter(nos: torch.Tensor, pos: torch.Tensor, thr: torch.Tensor, all: torch.Tensor, lb: float, max_size: int = 100_000_000):
    """ filter the binding sites based on the distance matrix

    nos: (N, 3), N are the coordinates of the binding sites
    pos: (M, 3), M are the coordinates of the candidate positions, could be very large
    thr: (N, 2), per-site [lower, upper] distance band
    all: (P, 3), P are the coordinates of all atoms in the protein
        (NOTE(review): this parameter shadows the builtin ``all`` inside this
        function — consider renaming at the next interface change)
    lb: the lower bound of the distance of a candidate position to any atom
    max_size: cap on the number of pairwise distances computed at once

    return: (bit-packed matrix of shape (ceil(N / 8), total_kept) — the
        per-site boolean rows of kept positions packed by pack_bit and
        transposed, concatenated over chunks — and a (M,) boolean mask of
        which candidate positions were kept)
    """
    N, M, P = nos.shape[0], pos.shape[0], all.shape[0]
    # Chunk the candidate positions so each cdist stays within ~max_size cells.
    batch_size = ceil(max_size / N)
    output = []
    interests = []
    for i in tqdm(range(0, M, batch_size), leave=False, desc=f'Filtering (batch_size: {batch_size})'):
        dist = torch.cdist(pos[i:i + batch_size], nos)
        # Keep (position, site) pairs whose distance lies inside the site's band.
        dist = (dist <= thr[:, 1].unsqueeze(0)) & \
            (dist >= thr[:, 0].unsqueeze(0))
        # A position must also be at least `lb` away from every protein atom.
        dist_all = safe_dist(all, pos[i:i + batch_size]) > lb
        dist = dist & dist_all.unsqueeze(-1)

        # Drop positions that satisfy no site at all; bit-pack the survivors.
        mask = dist.any(dim=1)
        output.append(pack_bit(dist[mask]).T)
        interests.append(mask)
    return torch.cat(output, dim=1), torch.cat(interests)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def backbone(atoms, chain_id):
    """Select the CA (alpha-carbon) atoms of the given chain.

    :param atoms: atom table exposing ``chain_id``, ``atom_name`` and
        ``element`` fields and supporting boolean-mask indexing.
    :param chain_id: chain identifier to select.
    :return: the subset of atoms forming that chain's backbone trace.
    """
    selector = (
        (atoms.chain_id == chain_id)
        & (atoms.atom_name == "CA")
        & (atoms.element == "C")
    )
    return atoms[selector]
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
# NOTE(review): this re-defines get_color, which is already declared earlier
# in this module with an identical body; this later definition shadows the
# first — consider deleting one of the two.
def get_color(v):
    """Map v in [0, 1] to a green→brown RGB string "[r,g,b]"."""
    assert 0 <= v <= 1, f'v should be in [0, 1], got {v}'
    # green to brown
    color1 = np.array([0, 128, 0])
    color2 = np.array([165, 42, 42])
    v = v * (color2 - color1) + color1
    v /= 255
    return f'[{v[0]:.2f},{v[1]:.2f},{v[2]:.2f}]'
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def load_private_key_from_file(private_key_file=None):
    """Load an RSA private key from a file or from the environment.

    :param private_key_file: path to a base64-encoded PEM private key; when
        None, the key is read from the ``ModelCheckpointPrivateKey``
        environment variable instead.
    :return: a ``cryptography`` private-key object (no passphrase expected).
    """
    if private_key_file is None:
        # NOTE(review): if the environment variable is unset this returns
        # None and b64decode(None) below raises a TypeError with no helpful
        # message — consider an explicit check.
        private_key_b64 = os.environ.get('ModelCheckpointPrivateKey')
    else:
        with open(private_key_file, 'r') as f:
            private_key_b64 = f.read().strip()

    # The stored key is base64-wrapped PEM text; decode before parsing.
    private_pem = base64.b64decode(private_key_b64)
    private_key = serialization.load_pem_private_key(
        private_pem,
        password=None,
        backend=default_backend()
    )
    return private_key
|
| 300 |
+
|
| 301 |
+
|
| 302 |
+
def decrypt_checkpoint(encrypted_path, private_key):
    """Decrypt a hybrid-encrypted (RSA-OAEP + AES-CBC) checkpoint file.

    File layout: 4-byte big-endian length of the RSA-encrypted AES key,
    the encrypted AES key, a 16-byte IV, an 8-byte big-endian plaintext
    size, then the AES-CBC ciphertext.

    :param encrypted_path: path to the encrypted checkpoint file.
    :param private_key: RSA private key used to unwrap the AES key.
    :return: the deserialized checkpoint object.
    :raises Exception: re-raises any decryption/deserialization failure
        after printing it.
    """
    backend = default_backend()

    with open(encrypted_path, 'rb') as f:
        key_length = int.from_bytes(f.read(4), 'big')
        encrypted_aes_key = f.read(key_length)
        iv = f.read(16)
        original_size = int.from_bytes(f.read(8), 'big')
        encrypted_data = f.read()

    try:
        # Unwrap the per-file AES key with RSA-OAEP/SHA-256.
        aes_key = private_key.decrypt(
            encrypted_aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

        cipher = Cipher(algorithms.AES(aes_key),
                        modes.CBC(iv), backend=backend)
        decryptor = cipher.decryptor()
        decrypted_padded = decryptor.update(
            encrypted_data) + decryptor.finalize()

        # Trim the CBC padding using the recorded plaintext size.
        decrypted_data = decrypted_padded[:original_size]

        # BUGFIX: the fallback used a bare `except:`, which also swallowed
        # KeyboardInterrupt/SystemExit; narrowed to Exception.
        try:
            buffer = io.BytesIO(decrypted_data)
            return torch.load(buffer, map_location='cpu')
        except Exception:
            # Fallback for checkpoints that were pickled directly.
            # SECURITY: pickle.loads executes arbitrary code — only use with
            # checkpoints obtained from a trusted source.
            return pickle.loads(decrypted_data)

    except Exception as e:
        print(f"Error: {e}")
        raise
|
src/predict.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from einops import rearrange
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import time
|
| 7 |
+
from models.polybert import PolyEncoder
|
| 8 |
+
from models.training import BaseModel
|
| 9 |
+
from models.utils import decrypt_checkpoint, load_private_key_from_file
|
| 10 |
+
import argparse
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from models.utils import Config
|
| 14 |
+
from models.plm import EsmModelInfo, get_model
|
| 15 |
+
import pandas as pd
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
if __name__ == "__main__":
    # fmt: off
    parser = argparse.ArgumentParser(description="Predict plastic degradation")
    parser.add_argument("--ckpt", type=str, help="Path to the model checkpoint")
    parser.add_argument("--plm", type=str, help="Protein language model to use", default='esm2_t33_650M_UR50D')
    parser.add_argument("--csv", type=str, help="Path to the CSV file with test data", default=None)
    parser.add_argument("--output", '-o', type=str, help="Path to the output file", default='predictions.csv')
    parser.add_argument("--attn", action='store_true', help="Save attention weights to files")
    # fmt: on
    args = parser.parse_args()

    # Embedding dims: the PLM embedding is doubled downstream; polyBERT is 600-d.
    info = EsmModelInfo(args.plm)
    plm_dim = info['dim'] * 2
    pbert_dim = 600

    dev = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = BaseModel(plm_dim, pbert_dim, n_classes=2).to(dev)

    # Load the encrypted weights and strip the Lightning "model." prefix.
    private_key = load_private_key_from_file()
    state_dict = decrypt_checkpoint(args.ckpt, private_key)
    state_dict = {
        k.replace('model.', ''): v for k, v in state_dict['state_dict'].items() if k.startswith('model.')}
    model.load_state_dict(state_dict)
    model.eval()
    print(f'Load predictor from {args.ckpt}')

    # BUGFIX: the PLM was previously loaded onto 'cuda' unconditionally,
    # which fails on CPU-only machines even though `dev` falls back to CPU.
    plm_func = get_model(args.plm, str(dev))
    print(f'Loaded PLM model {args.plm}')

    polybert_func = PolyEncoder()
    print('Loaded PolyEncoder model')

    outfile = Path(
        'predictions.csv' if args.output is None else args.output)
    # get protein embedding
    with torch.no_grad(), torch.inference_mode():
        df = pd.read_csv(args.csv)
        probs = []
        running_time = []
        for i, row in tqdm(df.iterrows()):
            start_time = time.time()

            seq = row['sequence'].upper()
            poly = row['polymer']
            seq_emb = plm_func([seq]).to(dev)
            seq_emb = rearrange(seq_emb, 'b l d -> b (l d)').unsqueeze(0)
            poly_emb = polybert_func([poly]).to(dev)
            logits, p_weights, l_weights = model((seq_emb, poly_emb))
            prob = F.softmax(logits, dim=-1)[:, 1].item()
            if args.attn:
                # One <row index>.pt file per input under <output>.attn/.
                outfile.with_suffix('.attn').mkdir(
                    parents=True, exist_ok=True)
                torch.save(
                    (p_weights, l_weights),
                    outfile.with_suffix('.attn') / f'{i}.pt')
            probs.append(prob)
            running_time.append(time.time() - start_time)

    df['prob'] = probs
    df['pred'] = df['prob'].apply(lambda x: 'Yes' if x >= 0.5 else 'No')
    df['time'] = running_time

    # move pred and prob to the front
    df = df[['pred', 'prob'] +
            [col for col in df.columns if col not in ['pred', 'prob']]]

    df.to_csv(outfile, index=False)
    print(f'Predictions saved to {outfile}')
    # BUGFIX: only report attention output when --attn was given, and point
    # at the actual directory (it was never the current directory).
    if args.attn:
        print(
            f"Attention weights saved to {outfile.with_suffix('.attn')} as <index>.pt")
|
src/streamlit_app.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#fmt: off
|
| 2 |
+
import streamlit as st
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import os
|
| 5 |
+
import tempfile
|
| 6 |
+
import subprocess
|
| 7 |
+
import requests
|
| 8 |
+
import csv
|
| 9 |
+
from models.polybert import polymer2psmiles
|
| 10 |
+
import py3Dmol
|
| 11 |
+
|
| 12 |
+
# Fix for permission error - disable usage stats
# (containerised Streamlit cannot always write to the default ~/.streamlit)
if 'STREAMLIT_CONFIG_DIR' not in os.environ:
    os.environ['STREAMLIT_CONFIG_DIR'] = '/tmp/.streamlit'

# Create streamlit config directory if it doesn't exist
streamlit_dir = os.environ.get('STREAMLIT_CONFIG_DIR', '/tmp/.streamlit')
os.makedirs(streamlit_dir, exist_ok=True)

# Create config.toml to disable usage stats (only written once; an existing
# file is left untouched).
config_path = os.path.join(streamlit_dir, 'config.toml')
if not os.path.exists(config_path):
    with open(config_path, 'w') as f:
        f.write("""[browser]
gatherUsageStats = false

[server]
headless = true
enableCORS = false
enableXsrfProtection = false
""")
|
| 32 |
+
# fmt: on
|
| 33 |
+
|
| 34 |
+
# One-letter -> three-letter amino-acid code table (20 standard residues);
# presumably used for mapping sequence positions onto PDB/mmCIF residue
# names during structure visualisation — confirm against the later sections.
aa2resn = {
    'A': 'ALA',
    'C': 'CYS',
    'D': 'ASP',
    'E': 'GLU',
    'F': 'PHE',
    'G': 'GLY',
    'H': 'HIS',
    'I': 'ILE',
    'K': 'LYS',
    'L': 'LEU',
    'M': 'MET',
    'N': 'ASN',
    'P': 'PRO',
    'Q': 'GLN',
    'R': 'ARG',
    'S': 'SER',
    'T': 'THR',
    'V': 'VAL',
    'W': 'TRP',
    'Y': 'TYR'
}
|
| 56 |
+
|
| 57 |
+
# Fancy header
|
| 58 |
+
st.markdown("""
|
| 59 |
+
<div style='text-align: center;'>
|
| 60 |
+
<h1 style='color:#377EB9;font-size:2.5em;'>🧬 Plastic Degradation Predictor</h1>
|
| 61 |
+
<h3 style='color:#4DAE48;'>Predict the degradability of plastics using protein sequences and polymer SMILES</h3>
|
| 62 |
+
</div>
|
| 63 |
+
<hr style='border:1px solid #974F9F;'>
|
| 64 |
+
""", unsafe_allow_html=True)
|
| 65 |
+
|
| 66 |
+
st.write("Enter a UniProt ID or paste a protein sequence. Select a polymer from the list below.")
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
# Load polymer names and SMILES

# Only show polymers with SMILES in the dropdown
# (the CSV lives next to this script under data/; each dropdown entry is
# rendered as "name | smiles").
polymer_csv = os.path.join(os.path.dirname(
    __file__), 'data/polymer2tok.csv')
polymer_options = []
with open(polymer_csv, newline='') as f:
    reader = csv.DictReader(f)
    for row in reader:
        name = row['polymer']
        # Look up the pSMILES for each polymer; entries without one are hidden.
        smiles = polymer2psmiles.get(name, '')
        if smiles:  # Only include polymers with SMILES
            polymer_options.append(f"{name} | {smiles}")
|
| 82 |
+
|
| 83 |
+
input_type = st.radio("Input type", ["UniProt ID", "Protein Sequence"])
|
| 84 |
+
|
| 85 |
+
if input_type == "UniProt ID":
|
| 86 |
+
uniprot_id = st.text_input("Enter UniProt ID", "P69905")
|
| 87 |
+
sequence = ""
|
| 88 |
+
if uniprot_id:
|
| 89 |
+
# Fetch sequence from UniProt
|
| 90 |
+
url = f"https://rest.uniprot.org/uniprotkb/{uniprot_id}.fasta"
|
| 91 |
+
resp = requests.get(url)
|
| 92 |
+
if resp.status_code == 200:
|
| 93 |
+
fasta = resp.text
|
| 94 |
+
sequence = "".join(fasta.split("\n")[1:])
|
| 95 |
+
st.success(f"Fetched sequence for {uniprot_id}")
|
| 96 |
+
st.code(sequence)
|
| 97 |
+
else:
|
| 98 |
+
st.error("Failed to fetch sequence from UniProt.")
|
| 99 |
+
else:
|
| 100 |
+
sequence = st.text_area("Paste protein sequence",
|
| 101 |
+
"MVLSPADKTNVKAAWGKVGAHAGEYGAEALERMFLSFPTTKTYFPHFDLSHGSAQVKGHG")
|
| 102 |
+
|
| 103 |
+
polymer = st.selectbox("Select polymer", polymer_options)
|
| 104 |
+
selected_polymer = polymer.split('|')[0].strip() if '|' in polymer else polymer
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
ckpt = "src/checkpoints/weights.ckpt"
|
| 108 |
+
plm = "esm2_t33_650M_UR50D"
|
| 109 |
+
|
| 110 |
+
if st.button("Predict degradation", type="primary"):
    if not sequence or not selected_polymer:
        st.error("Please provide both sequence and polymer.")
    else:
        # Write the single (sequence, polymer) pair to a temp CSV in the
        # format src/predict.py expects.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as tmp:
            tmp.write("sequence,polymer\n")
            tmp.write(f"{sequence},{selected_polymer}\n")
            tmp_path = tmp.name
        output_path = os.path.join(tempfile.gettempdir(), "predictions.csv")
        st.write("Running prediction...")
        # Use the running interpreter instead of whatever "python" resolves
        # to on PATH -- inside a venv or Docker image the bare name may be
        # missing or point at a different environment.
        import sys
        result = subprocess.run([
            sys.executable, "src/predict.py",
            "--ckpt", ckpt,
            "--plm", plm,
            "--csv", tmp_path,
            "--output", output_path,
            "--attn"
        ], capture_output=True, text=True)
        # The temp input is no longer needed once the subprocess returns;
        # delete it so repeated predictions don't accumulate files.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        if result.returncode == 0 and os.path.exists(output_path):
            df = pd.read_csv(output_path)
            if 'time' in df.columns:
                df = df.rename(columns={'time': 'running time'})
            st.markdown(f"""
<div style='background: linear-gradient(90deg, #377EB9 0%, #4DAE48 100%); padding: 1.5em; border-radius: 12px; color: white; margin-bottom: 1em;'>
<h2 style='margin:0;'><span style='font-size:18pt'>✅</span> Prediction Complete!</h2>
<p style='font-size:12pt;'>Your input has been processed. See the results below:</p>
<p style='font-size:12pt;'>Degradation: {df['pred'].values[0]} (Probability: {df['prob'].values[0]:.4f})</p>
</div>
""", unsafe_allow_html=True)
            st.download_button("⬇️ Download Results", data=df.to_csv(index=False),
                               file_name="predictions.csv", type="primary")

            # Show top-N attention residues if attention file exists.
            attn_dir = os.path.join(os.path.dirname(output_path), "predictions.attn")
            attn_path = os.path.join(attn_dir, "0.pt")
            if os.path.exists(attn_path):
                import torch
                # map_location keeps this working on CPU-only hosts even if
                # the attention tensors were saved from a GPU run.
                attn = torch.load(attn_path, map_location="cpu")
                # attn[0][0]: shape (num_heads, seq_len, seq_len) or (1, seq_len, seq_len)
                attn_matrix = attn[0][0] if isinstance(attn[0], (list, tuple)) else attn[0]
                # Average over heads if needed.
                if attn_matrix.ndim == 3:
                    attn_matrix = attn_matrix.mean(0)
                # Score each residue by the total attention it receives.
                residue_scores = attn_matrix.sum(0).cpu().numpy()
                topN = min(10, len(residue_scores))
                top_idx = residue_scores.argsort()[::-1][:topN]
                # NOTE(review): assumes attention positions map 1:1 onto
                # sequence residues (no special tokens) -- verify in predict.py.
                st.markdown(f"**Top {topN} high-attention residues:**")
                st.write(pd.DataFrame({
                    "Amino Acid": [sequence[i] for i in top_idx],
                    "Residue Index": top_idx + 1,  # 1-based for display
                    "Attention Score": residue_scores[top_idx]
                }))
            else:
                st.info("No attention file found for visualization.")
        else:
            st.error("Prediction failed. See details below:")
            st.text(result.stderr)
|
| 171 |
+
|
| 172 |
+
# If UniProt ID, try to download AlphaFold structure.
structure_path = None

if input_type == "UniProt ID" and uniprot_id:
    af_url = f"https://alphafold.ebi.ac.uk/files/AF-{uniprot_id}-F1-model_v4.cif"
    # If attention is available, compute the top residues to highlight.
    highlight_residues = None
    attn_dir = os.path.join(tempfile.gettempdir(), "predictions.attn")
    attn_path = os.path.join(attn_dir, "0.pt")
    if os.path.exists(attn_path):
        import torch
        # map_location keeps this working on CPU-only hosts even if the
        # attention tensors were saved from a GPU run.
        attn = torch.load(attn_path, map_location="cpu")
        attn_matrix = attn[0][0] if isinstance(attn[0], (list, tuple)) else attn[0]
        if attn_matrix.ndim == 3:
            attn_matrix = attn_matrix.mean(0)
        residue_scores = attn_matrix.sum(0).cpu().numpy()
        topN = min(10, len(residue_scores))
        top_idx = residue_scores.argsort()[::-1][:topN]
        # Molstar selection: list of residue numbers (1-based).
        highlight_residues = [int(i + 1) for i in top_idx]

    structure_path = os.path.join(
        tempfile.gettempdir(), f"AF-{uniprot_id}-F1-model_v4.cif")
    try:
        # A timeout keeps the Streamlit UI from hanging indefinitely when
        # the EBI server is slow or unreachable.
        r = requests.get(af_url, timeout=30)
        if r.status_code == 200:
            with open(structure_path, "wb") as f:
                f.write(r.content)
            st.success(
                f"AlphaFold structure downloaded: {structure_path}")
        else:
            st.warning(
                "AlphaFoldDB structure not found for this UniProt ID.")
    except Exception as e:
        st.warning(f"AlphaFoldDB download error: {e}")
|
| 208 |
+
|
| 209 |
+
if input_type == "UniProt ID" and uniprot_id and os.path.exists(attn_path) and os.path.exists(structure_path):
    st.markdown("### 3D Structure Visualization (stmol)")
    import torch
    from stmol import showmol
    # Recompute top-attention residues from the saved attention tensors.
    # map_location keeps this working on CPU-only hosts.
    attn = torch.load(attn_path, map_location="cpu")
    attn_matrix = attn[0][0] if isinstance(attn[0], (list, tuple)) else attn[0]
    if attn_matrix.ndim == 3:
        attn_matrix = attn_matrix.mean(0)
    residue_scores = attn_matrix.sum(0).cpu().numpy()
    topN = min(10, len(residue_scores))
    top_idx = residue_scores.argsort()[::-1][:topN]
    with open(structure_path, "r") as cif_file:
        cif_data = cif_file.read()
    view = py3Dmol.view(width=600, height=400)
    view.addModel(cif_data, "cif")
    # Whole structure in gray; top-attention residues recolored red below.
    view.setStyle({"cartoon": {"color": "lightgray"}})
    for idx in top_idx:
        resi_num = int(idx + 1)  # py3Dmol residue numbering is 1-based
        view.setStyle(
            {"resi": resi_num}, {
                "cartoon": {"color": "red"}})
        view.addResLabels(
            {"resi": resi_num},
            {
                "font": 'Arial', "fontColor": 'black',
                "showBackground": False, "screenOffset": {"x": 0, "y": 0}})
    view.zoomTo()
    showmol(view, height=600, width='100%')
|
| 240 |
+
|
| 241 |
+
|
| 242 |
+
# --- Footer: License and References ---
_FOOTER_HTML = """
---
<h4>License</h4>
Attribution-NonCommercial-ShareAlike 4.0 International (CC BY-NC-SA 4.0)<br>
<a href='https://creativecommons.org/licenses/by-nc-sa/4.0/' target='_blank'>View full license details</a><br>
"""
st.markdown(_FOOTER_HTML, unsafe_allow_html=True)
|