Commit Β·
13188b8
1
Parent(s): 543ad41
Prepared for deploy
Browse files- .dockerignore +6 -0
- .gitignore +3 -1
- Dockerfile +16 -0
- docker-compose.yml +8 -0
- main.py +37 -34
- model_attention.py +1 -1
- models/model_ep028_weighted_loss6.7715.pth +3 -0
- requirements.txt +17 -7
- EDA.ipynb β research/EDA.ipynb +0 -0
- {GNN_classification β research/GNN_classification}/Dataset_Preparation.py +0 -0
- {GNN_classification β research/GNN_classification}/dataset/classification/EDA.ipynb +0 -0
- {GNN_classification β research/GNN_classification}/dataset/classification/data_test.csv +0 -0
- {GNN_classification β research/GNN_classification}/dataset/classification/data_test.txt +0 -0
- {GNN_classification β research/GNN_classification}/dataset/classification/data_train.csv +0 -0
- {GNN_classification β research/GNN_classification}/dataset/classification/data_train.txt +0 -0
- {GNN_classification β research/GNN_classification}/model.py +0 -0
- {GNN_classification β research/GNN_classification}/training.py +0 -0
- GNNs__practice.ipynb β research/GNNs__practice.ipynb +0 -0
- all_inferences.py β research/all_inferences.py +82 -47
- research/dataset.py +209 -0
- dataset_preparation.py β research/dataset_preparation.py +0 -0
- inference.py β research/inference.py +0 -0
- inference_attention.py β research/inference_attention.py +0 -0
- research/loss.py +18 -0
- research/model.py +158 -0
- research/model_attention.py +144 -0
- model_pl.py β research/model_pl.py +0 -0
- optuna_train.py β research/optuna_train.py +0 -0
- optuna_train_attention.py β research/optuna_train_attention.py +5 -3
- pdbbind_refined_dataset.csv β research/pdbbind_refined_dataset.csv +0 -0
- research/requirements_dev.txt +18 -0
- train.py β research/train.py +0 -0
- train_attention.py β research/train_attention.py +23 -10
- train_pl.py β research/train_pl.py +0 -0
- visualization.ipynb β research/visualization.ipynb +0 -0
- utils.py +52 -36
.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.idea
|
| 3 |
+
.venv
|
| 4 |
+
__pycache__
|
| 5 |
+
research
|
| 6 |
+
.env
|
.gitignore
CHANGED
|
@@ -1,7 +1,9 @@
|
|
| 1 |
.idea
|
| 2 |
.venv
|
| 3 |
.ipynb_checkpoints
|
|
|
|
| 4 |
|
| 5 |
/refined-set/
|
| 6 |
/data
|
| 7 |
-
/lightning_logs
|
|
|
|
|
|
| 1 |
.idea
|
| 2 |
.venv
|
| 3 |
.ipynb_checkpoints
|
| 4 |
+
__pycache__
|
| 5 |
|
| 6 |
/refined-set/
|
| 7 |
/data
|
| 8 |
+
/lightning_logs
|
| 9 |
+
.env
|
Dockerfile
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
RUN apt-get update && apt-get install -y \
|
| 6 |
+
build-essential \
|
| 7 |
+
libxrender1 \
|
| 8 |
+
libxext6 \
|
| 9 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 10 |
+
|
| 11 |
+
COPY requirements.txt .
|
| 12 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 13 |
+
|
| 14 |
+
COPY . .
|
| 15 |
+
EXPOSE 7860
|
| 16 |
+
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
|
docker-compose.yml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
binding-predictor:
|
| 3 |
+
build: .
|
| 4 |
+
container_name: binding-app
|
| 5 |
+
ports:
|
| 6 |
+
- "8000:7860"
|
| 7 |
+
environment:
|
| 8 |
+
- GEMINI_API_KEY=${GEMINI_API_KEY}
|
main.py
CHANGED
|
@@ -5,8 +5,13 @@ from fastapi import FastAPI, Request, Form
|
|
| 5 |
from fastapi.templating import Jinja2Templates
|
| 6 |
from fastapi.staticfiles import StaticFiles
|
| 7 |
from fastapi.responses import HTMLResponse
|
| 8 |
-
from utils import
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
app = FastAPI()
|
| 12 |
|
|
@@ -17,6 +22,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
|
|
| 17 |
|
| 18 |
templates = Jinja2Templates(directory="templates")
|
| 19 |
|
|
|
|
| 20 |
@app.get("/", response_class=HTMLResponse)
|
| 21 |
async def read_root(request: Request):
|
| 22 |
return templates.TemplateResponse("index.html", {"request": request})
|
|
@@ -24,33 +30,32 @@ async def read_root(request: Request):
|
|
| 24 |
|
| 25 |
@app.post("/predict", response_class=HTMLResponse)
|
| 26 |
async def predict(
|
| 27 |
-
request: Request,
|
| 28 |
-
smiles_ligand: str = Form(...),
|
| 29 |
-
sequence_protein: str = Form(...)
|
| 30 |
):
|
| 31 |
mol, importance, affinity = get_inference_data(smiles_ligand, sequence_protein)
|
| 32 |
|
| 33 |
atom_list = []
|
| 34 |
-
sorted_indices = sorted(
|
|
|
|
|
|
|
| 35 |
|
| 36 |
for idx in sorted_indices[:15]:
|
| 37 |
val = importance[idx]
|
| 38 |
symbol = mol.GetAtomWithIdx(idx).GetSymbol()
|
| 39 |
|
| 40 |
icon = ""
|
| 41 |
-
if val >= 0.9:
|
| 42 |
-
|
| 43 |
-
elif val >= 0.
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
"
|
| 47 |
-
|
| 48 |
-
"icon": icon
|
| 49 |
-
|
| 50 |
|
| 51 |
unique_id = str(uuid.uuid4())
|
| 52 |
|
| 53 |
-
|
| 54 |
filename_ngl = f"ngl_{unique_id}.html"
|
| 55 |
filepath_ngl = os.path.join("html_results", filename_ngl)
|
| 56 |
|
|
@@ -71,23 +76,21 @@ async def predict(
|
|
| 71 |
sequence_protein,
|
| 72 |
f"{affinity:.2f}",
|
| 73 |
atom_list,
|
| 74 |
-
lipinski_properties
|
| 75 |
)
|
| 76 |
|
| 77 |
-
return templates.TemplateResponse(
|
| 78 |
-
"
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
|
|
|
| 5 |
from fastapi.templating import Jinja2Templates
|
| 6 |
from fastapi.staticfiles import StaticFiles
|
| 7 |
from fastapi.responses import HTMLResponse
|
| 8 |
+
from utils import (
|
| 9 |
+
get_inference_data,
|
| 10 |
+
get_py3dmol_view,
|
| 11 |
+
save_standalone_ngl_html,
|
| 12 |
+
get_lipinski_properties,
|
| 13 |
+
get_gemini_explanation,
|
| 14 |
+
)
|
| 15 |
|
| 16 |
app = FastAPI()
|
| 17 |
|
|
|
|
| 22 |
|
| 23 |
templates = Jinja2Templates(directory="templates")
|
| 24 |
|
| 25 |
+
|
| 26 |
@app.get("/", response_class=HTMLResponse)
|
| 27 |
async def read_root(request: Request):
|
| 28 |
return templates.TemplateResponse("index.html", {"request": request})
|
|
|
|
| 30 |
|
| 31 |
@app.post("/predict", response_class=HTMLResponse)
|
| 32 |
async def predict(
|
| 33 |
+
request: Request, smiles_ligand: str = Form(...), sequence_protein: str = Form(...)
|
|
|
|
|
|
|
| 34 |
):
|
| 35 |
mol, importance, affinity = get_inference_data(smiles_ligand, sequence_protein)
|
| 36 |
|
| 37 |
atom_list = []
|
| 38 |
+
sorted_indices = sorted(
|
| 39 |
+
range(len(importance)), key=lambda k: importance[k], reverse=True
|
| 40 |
+
)
|
| 41 |
|
| 42 |
for idx in sorted_indices[:15]:
|
| 43 |
val = importance[idx]
|
| 44 |
symbol = mol.GetAtomWithIdx(idx).GetSymbol()
|
| 45 |
|
| 46 |
icon = ""
|
| 47 |
+
if val >= 0.9:
|
| 48 |
+
icon = "π₯"
|
| 49 |
+
elif val >= 0.7:
|
| 50 |
+
icon = "β¨"
|
| 51 |
+
elif val >= 0.5:
|
| 52 |
+
icon = "β"
|
| 53 |
+
atom_list.append(
|
| 54 |
+
{"id": idx, "symbol": symbol, "score": f"{val:.3f}", "icon": icon}
|
| 55 |
+
)
|
| 56 |
|
| 57 |
unique_id = str(uuid.uuid4())
|
| 58 |
|
|
|
|
| 59 |
filename_ngl = f"ngl_{unique_id}.html"
|
| 60 |
filepath_ngl = os.path.join("html_results", filename_ngl)
|
| 61 |
|
|
|
|
| 76 |
sequence_protein,
|
| 77 |
f"{affinity:.2f}",
|
| 78 |
atom_list,
|
| 79 |
+
lipinski_properties,
|
| 80 |
)
|
| 81 |
|
| 82 |
+
return templates.TemplateResponse(
|
| 83 |
+
"index.html",
|
| 84 |
+
{
|
| 85 |
+
"request": request,
|
| 86 |
+
"result_ready": True,
|
| 87 |
+
"smiles": smiles_ligand,
|
| 88 |
+
"protein": sequence_protein,
|
| 89 |
+
"affinity": f"{affinity:.2f}",
|
| 90 |
+
"atom_list": atom_list,
|
| 91 |
+
"html_py3dmol": py3dmol_content,
|
| 92 |
+
"url_ngl": ngl_url_link,
|
| 93 |
+
"lipinski": lipinski_properties,
|
| 94 |
+
"ai_explanation": ai_explanation,
|
| 95 |
+
},
|
| 96 |
+
)
|
|
|
|
|
|
model_attention.py
CHANGED
|
@@ -20,7 +20,7 @@ class CrossAttentionLayer(nn.Module):
|
|
| 20 |
# Feedforward network for further processing, classical transformer style
|
| 21 |
self.ff = nn.Sequential(
|
| 22 |
nn.Linear(feature_dim, feature_dim * 4),
|
| 23 |
-
nn.GELU(),
|
| 24 |
nn.Dropout(dropout),
|
| 25 |
nn.Linear(feature_dim * 4, feature_dim),
|
| 26 |
)
|
|
|
|
| 20 |
# Feedforward network for further processing, classical transformer style
|
| 21 |
self.ff = nn.Sequential(
|
| 22 |
nn.Linear(feature_dim, feature_dim * 4),
|
| 23 |
+
nn.GELU(), # GELU works better with transformers
|
| 24 |
nn.Dropout(dropout),
|
| 25 |
nn.Linear(feature_dim * 4, feature_dim),
|
| 26 |
)
|
models/model_ep028_weighted_loss6.7715.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d64adee176c82baa3489ffbae33567c31f5cea7c7b254e782651507754e3dfc2
|
| 3 |
+
size 1810686
|
requirements.txt
CHANGED
|
@@ -1,10 +1,20 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
|
|
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
|
|
|
| 1 |
+
fastapi==0.128.0
|
| 2 |
+
uvicorn[standard]==0.38.0
|
| 3 |
+
python-multipart==0.0.21
|
| 4 |
+
requests==2.32.5
|
| 5 |
|
| 6 |
+
python-decouple==3.8
|
| 7 |
+
py3Dmol==2.5.4
|
| 8 |
|
| 9 |
+
numpy==2.2.6
|
| 10 |
+
pandas==2.3.3
|
| 11 |
+
|
| 12 |
+
rdkit==2025.9.1
|
| 13 |
+
|
| 14 |
+
google-genai==1.53.0
|
| 15 |
+
|
| 16 |
+
--extra-index-url https://download.pytorch.org/whl/cpu
|
| 17 |
+
|
| 18 |
+
torch>=2.5.0
|
| 19 |
+
torch-geometric>=2.7.0
|
| 20 |
|
EDA.ipynb β research/EDA.ipynb
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/Dataset_Preparation.py
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/dataset/classification/EDA.ipynb
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/dataset/classification/data_test.csv
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/dataset/classification/data_test.txt
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/dataset/classification/data_train.csv
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/dataset/classification/data_train.txt
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/model.py
RENAMED
|
File without changes
|
{GNN_classification β research/GNN_classification}/training.py
RENAMED
|
File without changes
|
GNNs__practice.ipynb β research/GNNs__practice.ipynb
RENAMED
|
File without changes
|
all_inferences.py β research/all_inferences.py
RENAMED
|
@@ -19,11 +19,14 @@ from model_attention import BindingAffinityModel
|
|
| 19 |
|
| 20 |
|
| 21 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 22 |
-
MODEL_PATH =
|
|
|
|
|
|
|
| 23 |
|
| 24 |
GAT_HEADS = 2
|
| 25 |
HIDDEN_CHANNELS = 256
|
| 26 |
|
|
|
|
| 27 |
def get_inference_data(ligand_smiles, protein_sequence, model_path):
|
| 28 |
"""
|
| 29 |
Returns:
|
|
@@ -46,8 +49,10 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path):
|
|
| 46 |
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
|
| 47 |
|
| 48 |
tokens = [get_protein_features(c) for c in protein_sequence]
|
| 49 |
-
if len(tokens) > 1200:
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
protein_sequence = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
|
| 52 |
|
| 53 |
data = Data(x=x, edge_index=edge_index)
|
|
@@ -55,7 +60,9 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path):
|
|
| 55 |
num_features = x.shape[1]
|
| 56 |
|
| 57 |
# Model loading
|
| 58 |
-
model = BindingAffinityModel(
|
|
|
|
|
|
|
| 59 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 60 |
model.eval()
|
| 61 |
|
|
@@ -70,7 +77,9 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path):
|
|
| 70 |
|
| 71 |
# Normalize to [0, 1]
|
| 72 |
if importance.max() > 0:
|
| 73 |
-
importance = (importance - importance.min()) / (
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Noise reduction
|
| 76 |
importance[importance < 0.01] = 0
|
|
@@ -93,20 +102,22 @@ def print_atom_scores(mol, importance):
|
|
| 93 |
print(f"Atom {idx} ({symbol}): Importance = {score:.3f} {fire}")
|
| 94 |
|
| 95 |
|
| 96 |
-
|
| 97 |
def get_py3dmol(mol, importance, score):
|
| 98 |
|
| 99 |
view = py3Dmol.view(width=1000, height=800)
|
| 100 |
view.addModel(Chem.MolToMolBlock(mol), "sdf")
|
| 101 |
-
view.setBackgroundColor(
|
| 102 |
|
| 103 |
# 1. ΠΠΠΠΠΠ«Π Π‘Π’ΠΠΠ¬ (ΠΠ Π£ΠΠ’ΠΠΠΠ)
|
| 104 |
# ΠΠ°Π΄Π°Π΅ΠΌ Π΅Π΄ΠΈΠ½ΡΠΉ ΡΠ°Π·ΠΌΠ΅Ρ Π΄Π»Ρ Π²ΡΠ΅ΠΉ ΠΌΠΎΠ»Π΅ΠΊΡΠ»Ρ ΡΡΠ°Π·Ρ
|
| 105 |
# scale: 0.25 β ΠΎΠΏΡΠΈΠΌΠ°Π»ΡΠ½ΡΠΉ ΡΡΠ΅Π΄Π½ΠΈΠΉ ΡΠ°Π·ΠΌΠ΅Ρ
|
| 106 |
-
view.setStyle(
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
red_atoms = []
|
| 112 |
orange_atoms = []
|
|
@@ -130,48 +141,65 @@ def get_py3dmol(mol, importance, score):
|
|
| 130 |
if i in top_indices and val > 0.1:
|
| 131 |
pos = conf.GetAtomPosition(i)
|
| 132 |
symbol = mol.GetAtomWithIdx(i).GetSymbol()
|
| 133 |
-
labels_to_add.append(
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
|
|
|
|
|
|
| 137 |
|
| 138 |
# 3. ΠΠ ΠΠΠΠΠΠΠΠ Π‘Π’ΠΠΠΠ
|
| 139 |
# ΠΠ±ΡΠ°ΡΠΈ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅: scale Π²Π΅Π·Π΄Π΅ 0.25 (ΠΈΠ»ΠΈ 0.28, ΡΡΠΎΠ±Ρ ΡΡΡΡ Π²ΡΠ΄Π΅Π»ΠΈΡΡ ΡΠ²Π΅ΡΠ½ΡΠ΅)
|
| 140 |
# ΠΡ ΠΌΠ΅Π½ΡΠ΅ΠΌ Π’ΠΠΠ¬ΠΠ Π¦ΠΠΠ’.
|
| 141 |
|
| 142 |
if red_atoms:
|
| 143 |
-
view.addStyle(
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
if orange_atoms:
|
| 149 |
-
view.addStyle(
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
|
|
|
|
|
|
|
|
|
| 153 |
|
| 154 |
if blue_atoms:
|
| 155 |
-
view.addStyle(
|
| 156 |
-
|
| 157 |
-
|
| 158 |
|
| 159 |
# 4. ΠΠΠ’ΠΠ
|
| 160 |
for label in labels_to_add:
|
| 161 |
-
view.addLabel(
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
view.zoomTo()
|
| 173 |
-
view.addLabel(
|
| 174 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
return view
|
| 177 |
|
|
@@ -190,23 +218,31 @@ def get_ngl(mol, importance):
|
|
| 190 |
view = nv.NGLWidget(structure)
|
| 191 |
view.clear_representations()
|
| 192 |
|
| 193 |
-
view.add_representation(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 194 |
|
| 195 |
indices_sorted = np.argsort(importance)[::-1]
|
| 196 |
top_indices = indices_sorted[:15]
|
| 197 |
|
| 198 |
selection_str = "@" + ",".join(map(str, top_indices))
|
| 199 |
-
view.add_representation(
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
|
|
|
| 206 |
|
| 207 |
view.center()
|
| 208 |
return view
|
| 209 |
|
|
|
|
| 210 |
if __name__ == "__main__":
|
| 211 |
smiles = "COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1"
|
| 212 |
protein = "PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF"
|
|
@@ -222,4 +258,3 @@ if __name__ == "__main__":
|
|
| 222 |
|
| 223 |
ngl_widget = get_ngl(mol, importance)
|
| 224 |
nv.write_html(file_name_ngl, ngl_widget)
|
| 225 |
-
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
MODEL_PATH = (
|
| 23 |
+
"runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
|
| 24 |
+
)
|
| 25 |
|
| 26 |
GAT_HEADS = 2
|
| 27 |
HIDDEN_CHANNELS = 256
|
| 28 |
|
| 29 |
+
|
| 30 |
def get_inference_data(ligand_smiles, protein_sequence, model_path):
|
| 31 |
"""
|
| 32 |
Returns:
|
|
|
|
| 49 |
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
|
| 50 |
|
| 51 |
tokens = [get_protein_features(c) for c in protein_sequence]
|
| 52 |
+
if len(tokens) > 1200:
|
| 53 |
+
tokens = tokens[:1200]
|
| 54 |
+
else:
|
| 55 |
+
tokens.extend([0] * (1200 - len(tokens)))
|
| 56 |
protein_sequence = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
|
| 57 |
|
| 58 |
data = Data(x=x, edge_index=edge_index)
|
|
|
|
| 60 |
num_features = x.shape[1]
|
| 61 |
|
| 62 |
# Model loading
|
| 63 |
+
model = BindingAffinityModel(
|
| 64 |
+
num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
|
| 65 |
+
).to(DEVICE)
|
| 66 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 67 |
model.eval()
|
| 68 |
|
|
|
|
| 77 |
|
| 78 |
# Normalize to [0, 1]
|
| 79 |
if importance.max() > 0:
|
| 80 |
+
importance = (importance - importance.min()) / (
|
| 81 |
+
importance.max() - importance.min()
|
| 82 |
+
)
|
| 83 |
|
| 84 |
# Noise reduction
|
| 85 |
importance[importance < 0.01] = 0
|
|
|
|
| 102 |
print(f"Atom {idx} ({symbol}): Importance = {score:.3f} {fire}")
|
| 103 |
|
| 104 |
|
|
|
|
| 105 |
def get_py3dmol(mol, importance, score):
|
| 106 |
|
| 107 |
view = py3Dmol.view(width=1000, height=800)
|
| 108 |
view.addModel(Chem.MolToMolBlock(mol), "sdf")
|
| 109 |
+
view.setBackgroundColor("white")
|
| 110 |
|
| 111 |
# 1. ΠΠΠΠΠΠ«Π Π‘Π’ΠΠΠ¬ (ΠΠ Π£ΠΠ’ΠΠΠΠ)
|
| 112 |
# ΠΠ°Π΄Π°Π΅ΠΌ Π΅Π΄ΠΈΠ½ΡΠΉ ΡΠ°Π·ΠΌΠ΅Ρ Π΄Π»Ρ Π²ΡΠ΅ΠΉ ΠΌΠΎΠ»Π΅ΠΊΡΠ»Ρ ΡΡΠ°Π·Ρ
|
| 113 |
# scale: 0.25 β ΠΎΠΏΡΠΈΠΌΠ°Π»ΡΠ½ΡΠΉ ΡΡΠ΅Π΄Π½ΠΈΠΉ ΡΠ°Π·ΠΌΠ΅Ρ
|
| 114 |
+
view.setStyle(
|
| 115 |
+
{},
|
| 116 |
+
{
|
| 117 |
+
"stick": {"color": "#cccccc", "radius": 0.1},
|
| 118 |
+
"sphere": {"color": "#cccccc", "scale": 0.25},
|
| 119 |
+
},
|
| 120 |
+
)
|
| 121 |
|
| 122 |
red_atoms = []
|
| 123 |
orange_atoms = []
|
|
|
|
| 141 |
if i in top_indices and val > 0.1:
|
| 142 |
pos = conf.GetAtomPosition(i)
|
| 143 |
symbol = mol.GetAtomWithIdx(i).GetSymbol()
|
| 144 |
+
labels_to_add.append(
|
| 145 |
+
{
|
| 146 |
+
"text": f"{i}:{symbol}:{val:.2f}",
|
| 147 |
+
"pos": {"x": pos.x, "y": pos.y, "z": pos.z},
|
| 148 |
+
}
|
| 149 |
+
)
|
| 150 |
|
| 151 |
# 3. ΠΠ ΠΠΠΠΠΠΠΠ Π‘Π’ΠΠΠΠ
|
| 152 |
# ΠΠ±ΡΠ°ΡΠΈ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅: scale Π²Π΅Π·Π΄Π΅ 0.25 (ΠΈΠ»ΠΈ 0.28, ΡΡΠΎΠ±Ρ ΡΡΡΡ Π²ΡΠ΄Π΅Π»ΠΈΡΡ ΡΠ²Π΅ΡΠ½ΡΠ΅)
|
| 153 |
# ΠΡ ΠΌΠ΅Π½ΡΠ΅ΠΌ Π’ΠΠΠ¬ΠΠ Π¦ΠΠΠ’.
|
| 154 |
|
| 155 |
if red_atoms:
|
| 156 |
+
view.addStyle(
|
| 157 |
+
{"serial": red_atoms},
|
| 158 |
+
{
|
| 159 |
+
"sphere": {"color": "#FF0000", "scale": 0.28},
|
| 160 |
+
"stick": {"color": "#FF0000", "radius": 0.12},
|
| 161 |
+
},
|
| 162 |
+
)
|
| 163 |
|
| 164 |
if orange_atoms:
|
| 165 |
+
view.addStyle(
|
| 166 |
+
{"serial": orange_atoms},
|
| 167 |
+
{
|
| 168 |
+
"sphere": {"color": "#FF8C00", "scale": 0.28},
|
| 169 |
+
"stick": {"color": "#FF8C00", "radius": 0.12},
|
| 170 |
+
},
|
| 171 |
+
)
|
| 172 |
|
| 173 |
if blue_atoms:
|
| 174 |
+
view.addStyle(
|
| 175 |
+
{"serial": blue_atoms}, {"sphere": {"color": "#7777FF", "scale": 0.28}}
|
| 176 |
+
)
|
| 177 |
|
| 178 |
# 4. ΠΠΠ’ΠΠ
|
| 179 |
for label in labels_to_add:
|
| 180 |
+
view.addLabel(
|
| 181 |
+
label["text"],
|
| 182 |
+
{
|
| 183 |
+
"position": label["pos"],
|
| 184 |
+
"fontSize": 14,
|
| 185 |
+
"fontColor": "white",
|
| 186 |
+
"backgroundColor": "black",
|
| 187 |
+
"backgroundOpacity": 0.7,
|
| 188 |
+
"borderThickness": 0,
|
| 189 |
+
"inFront": True,
|
| 190 |
+
"showBackground": True,
|
| 191 |
+
},
|
| 192 |
+
)
|
| 193 |
|
| 194 |
view.zoomTo()
|
| 195 |
+
view.addLabel(
|
| 196 |
+
f"Predicted pKd: {float(score):.2f}",
|
| 197 |
+
{
|
| 198 |
+
"position": {"x": -5, "y": 10, "z": 0},
|
| 199 |
+
"backgroundColor": "black",
|
| 200 |
+
"fontColor": "white",
|
| 201 |
+
},
|
| 202 |
+
)
|
| 203 |
|
| 204 |
return view
|
| 205 |
|
|
|
|
| 218 |
view = nv.NGLWidget(structure)
|
| 219 |
view.clear_representations()
|
| 220 |
|
| 221 |
+
view.add_representation(
|
| 222 |
+
"ball+stick",
|
| 223 |
+
colorScheme="bfactor",
|
| 224 |
+
colorScale=["blue", "white", "red"],
|
| 225 |
+
colorDomain=[10, 80],
|
| 226 |
+
radiusScale=1.0,
|
| 227 |
+
)
|
| 228 |
|
| 229 |
indices_sorted = np.argsort(importance)[::-1]
|
| 230 |
top_indices = indices_sorted[:15]
|
| 231 |
|
| 232 |
selection_str = "@" + ",".join(map(str, top_indices))
|
| 233 |
+
view.add_representation(
|
| 234 |
+
"label",
|
| 235 |
+
selection=selection_str, # ΠΠΎΠ΄ΠΏΠΈΡΡΠ²Π°Π΅ΠΌ ΡΠΎΠ»ΡΠΊΠΎ ΠΈΠ·Π±ΡΠ°Π½Π½ΡΡ
|
| 236 |
+
labelType="atomindex", # ΠΠΎΠΊΠ°Π·ΡΠ²Π°ΡΡ ΠΠ½Π΄Π΅ΠΊΡ (0, 1, 2...)
|
| 237 |
+
color="black", # Π§Π΅ΡΠ½ΡΠΉ ΡΠ΅ΠΊΡΡ
|
| 238 |
+
radius=2.0, # Π Π°Π·ΠΌΠ΅Ρ ΡΡΠΈΡΡΠ° (ΠΏΠΎΠΏΡΠΎΠ±ΡΠΉΡΠ΅ 1.5 - 3.0)
|
| 239 |
+
zOffset=1.0,
|
| 240 |
+
) # Π§ΡΡΡ ΡΠ΄Π²ΠΈΠ½ΡΡΡ ΠΊ ΠΊΠ°ΠΌΠ΅ΡΠ΅
|
| 241 |
|
| 242 |
view.center()
|
| 243 |
return view
|
| 244 |
|
| 245 |
+
|
| 246 |
if __name__ == "__main__":
|
| 247 |
smiles = "COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1"
|
| 248 |
protein = "PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF"
|
|
|
|
| 258 |
|
| 259 |
ngl_widget = get_ngl(mol, importance)
|
| 260 |
nv.write_html(file_name_ngl, ngl_widget)
|
|
|
research/dataset.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from rdkit import Chem, rdBase
|
| 5 |
+
from torch_geometric.data import Data
|
| 6 |
+
from torch.utils.data import Dataset, random_split
|
| 7 |
+
|
| 8 |
+
rdBase.DisableLog("rdApp.*")
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def one_of_k_encoding(x, allowable_set):
|
| 12 |
+
# last position - unknown
|
| 13 |
+
if x not in allowable_set:
|
| 14 |
+
x = allowable_set[-1]
|
| 15 |
+
return list(map(lambda s: x == s, allowable_set))
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def get_atom_features(atom):
|
| 19 |
+
symbols_list = [
|
| 20 |
+
"C",
|
| 21 |
+
"N",
|
| 22 |
+
"O",
|
| 23 |
+
"S",
|
| 24 |
+
"F",
|
| 25 |
+
"Si",
|
| 26 |
+
"P",
|
| 27 |
+
"Cl",
|
| 28 |
+
"Br",
|
| 29 |
+
"Mg",
|
| 30 |
+
"Na",
|
| 31 |
+
"Ca",
|
| 32 |
+
"Fe",
|
| 33 |
+
"As",
|
| 34 |
+
"Al",
|
| 35 |
+
"I",
|
| 36 |
+
"B",
|
| 37 |
+
"V",
|
| 38 |
+
"K",
|
| 39 |
+
"Tl",
|
| 40 |
+
"Yb",
|
| 41 |
+
"Sb",
|
| 42 |
+
"Sn",
|
| 43 |
+
"Ag",
|
| 44 |
+
"Pd",
|
| 45 |
+
"Co",
|
| 46 |
+
"Se",
|
| 47 |
+
"Ti",
|
| 48 |
+
"Zn",
|
| 49 |
+
"H",
|
| 50 |
+
"Li",
|
| 51 |
+
"Ge",
|
| 52 |
+
"Cu",
|
| 53 |
+
"Au",
|
| 54 |
+
"Ni",
|
| 55 |
+
"Cd",
|
| 56 |
+
"In",
|
| 57 |
+
"Mn",
|
| 58 |
+
"Zr",
|
| 59 |
+
"Cr",
|
| 60 |
+
"Pt",
|
| 61 |
+
"Hg",
|
| 62 |
+
"Pb",
|
| 63 |
+
"Unknown",
|
| 64 |
+
]
|
| 65 |
+
degrees_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 66 |
+
numhs_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 67 |
+
implicit_valences_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 68 |
+
|
| 69 |
+
formal_charge_list = [-2, -1, 0, 1, 2]
|
| 70 |
+
chirality_list = [
|
| 71 |
+
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
|
| 72 |
+
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
|
| 73 |
+
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
|
| 74 |
+
Chem.rdchem.ChiralType.CHI_OTHER,
|
| 75 |
+
]
|
| 76 |
+
return np.array(
|
| 77 |
+
# Type of atom (Symbol)
|
| 78 |
+
one_of_k_encoding(atom.GetSymbol(), symbols_list)
|
| 79 |
+
+
|
| 80 |
+
# Number of neighbours (Degree)
|
| 81 |
+
one_of_k_encoding(atom.GetDegree(), degrees_list)
|
| 82 |
+
+
|
| 83 |
+
# Number of hydrogen atoms (Implicit Hs) - bond donors
|
| 84 |
+
one_of_k_encoding(atom.GetTotalNumHs(), numhs_list)
|
| 85 |
+
+
|
| 86 |
+
# Valence - chemical potential
|
| 87 |
+
one_of_k_encoding(atom.GetImplicitValence(), implicit_valences_list)
|
| 88 |
+
+
|
| 89 |
+
# Hybridization - so important for 3d structure, sp2 - Trigonal planar, sp3 - Tetrahedral
|
| 90 |
+
one_of_k_encoding(
|
| 91 |
+
atom.GetHybridization(),
|
| 92 |
+
[
|
| 93 |
+
Chem.rdchem.HybridizationType.SP,
|
| 94 |
+
Chem.rdchem.HybridizationType.SP2,
|
| 95 |
+
Chem.rdchem.HybridizationType.SP3,
|
| 96 |
+
Chem.rdchem.HybridizationType.SP3D,
|
| 97 |
+
Chem.rdchem.HybridizationType.SP3D2,
|
| 98 |
+
"other",
|
| 99 |
+
],
|
| 100 |
+
)
|
| 101 |
+
+
|
| 102 |
+
# Aromaticity (Boolean)
|
| 103 |
+
[atom.GetIsAromatic()]
|
| 104 |
+
+
|
| 105 |
+
# Formal Charge
|
| 106 |
+
one_of_k_encoding(atom.GetFormalCharge(), formal_charge_list)
|
| 107 |
+
+
|
| 108 |
+
# Chirality (Geometry)
|
| 109 |
+
one_of_k_encoding(atom.GetChiralTag(), chirality_list)
|
| 110 |
+
+
|
| 111 |
+
# Is in ring (Boolean)
|
| 112 |
+
[atom.IsInRing()]
|
| 113 |
+
)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_protein_features(char):
|
| 117 |
+
prot_vocab = {
|
| 118 |
+
"A": 1,
|
| 119 |
+
"R": 2,
|
| 120 |
+
"N": 3,
|
| 121 |
+
"D": 4,
|
| 122 |
+
"C": 5,
|
| 123 |
+
"Q": 6,
|
| 124 |
+
"E": 7,
|
| 125 |
+
"G": 8,
|
| 126 |
+
"H": 9,
|
| 127 |
+
"I": 10,
|
| 128 |
+
"L": 11,
|
| 129 |
+
"K": 12,
|
| 130 |
+
"M": 13,
|
| 131 |
+
"F": 14,
|
| 132 |
+
"P": 15,
|
| 133 |
+
"S": 16,
|
| 134 |
+
"T": 17,
|
| 135 |
+
"W": 18,
|
| 136 |
+
"Y": 19,
|
| 137 |
+
"V": 20,
|
| 138 |
+
"X": 21,
|
| 139 |
+
"Z": 21,
|
| 140 |
+
"B": 21,
|
| 141 |
+
"PAD": 0,
|
| 142 |
+
"UNK": 21,
|
| 143 |
+
}
|
| 144 |
+
return prot_vocab.get(char, prot_vocab["UNK"])
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class BindingDataset(Dataset):
|
| 148 |
+
def __init__(self, dataframe, max_seq_length=1000):
|
| 149 |
+
self.data = dataframe
|
| 150 |
+
self.max_seq_length = (
|
| 151 |
+
max_seq_length # Define a maximum sequence length for padding/truncation
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
def __len__(self):
|
| 155 |
+
return len(self.data)
|
| 156 |
+
|
| 157 |
+
def __getitem__(self, idx):
|
| 158 |
+
row = self.data.iloc[idx]
|
| 159 |
+
smiles = row["smiles"]
|
| 160 |
+
sequence = row["sequence"]
|
| 161 |
+
affinity = row["affinity"]
|
| 162 |
+
|
| 163 |
+
mol = Chem.MolFromSmiles(smiles)
|
| 164 |
+
if mol is None:
|
| 165 |
+
return None
|
| 166 |
+
|
| 167 |
+
# Ligand (Graph)
|
| 168 |
+
# Nodes
|
| 169 |
+
atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
|
| 170 |
+
x = torch.tensor(np.array(atom_features), dtype=torch.float)
|
| 171 |
+
|
| 172 |
+
# Edges
|
| 173 |
+
edge_indexes = []
|
| 174 |
+
for bond in mol.GetBonds():
|
| 175 |
+
i = bond.GetBeginAtomIdx()
|
| 176 |
+
j = bond.GetEndAtomIdx()
|
| 177 |
+
edge_indexes.append((i, j))
|
| 178 |
+
edge_indexes.append((j, i))
|
| 179 |
+
|
| 180 |
+
# t - transpose, [num_of_edges, 2] -> [2, num_of_edges]
|
| 181 |
+
# contiguous - take the virtually transposed tensor and make its physical copy and lay bytes sequentially
|
| 182 |
+
|
| 183 |
+
edge_index = torch.tensor(edge_indexes, dtype=torch.long).t().contiguous()
|
| 184 |
+
|
| 185 |
+
# Protein (Sequence, tensor of integers)
|
| 186 |
+
tokens = [get_protein_features(char) for char in sequence]
|
| 187 |
+
if len(tokens) > self.max_seq_length:
|
| 188 |
+
tokens = tokens[: self.max_seq_length]
|
| 189 |
+
else:
|
| 190 |
+
tokens.extend(
|
| 191 |
+
[get_protein_features("PAD")] * (self.max_seq_length - len(tokens))
|
| 192 |
+
)
|
| 193 |
+
protein_tensor = torch.tensor(tokens, dtype=torch.long)
|
| 194 |
+
|
| 195 |
+
# Affinity
|
| 196 |
+
y = torch.tensor([affinity], dtype=torch.float)
|
| 197 |
+
return Data(x=x, edge_index=edge_index, protein_seq=protein_tensor, y=y)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
if __name__ == "__main__":
|
| 201 |
+
dataset = pd.read_csv("pdbbind_refined_dataset.csv")
|
| 202 |
+
dataset = BindingDataset(dataset)
|
| 203 |
+
|
| 204 |
+
train_size = int(0.8 * len(dataset))
|
| 205 |
+
test_size = len(dataset) - train_size
|
| 206 |
+
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
|
| 207 |
+
|
| 208 |
+
print(len(train_dataset))
|
| 209 |
+
print(len(test_dataset))
|
dataset_preparation.py β research/dataset_preparation.py
RENAMED
|
File without changes
|
inference.py β research/inference.py
RENAMED
|
File without changes
|
inference_attention.py β research/inference_attention.py
RENAMED
|
File without changes
|
research/loss.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class WeightedMSELoss(nn.Module):
|
| 6 |
+
def __init__(self):
|
| 7 |
+
super().__init__()
|
| 8 |
+
|
| 9 |
+
def forward(self, prediction, target):
|
| 10 |
+
squared_errors = (prediction - target) ** 2
|
| 11 |
+
weights = torch.ones_like(target)
|
| 12 |
+
|
| 13 |
+
weights[target >= 6.0] = 2.0 # Fine x2 pKd > 6 good binding
|
| 14 |
+
weights[target >= 7.0] = 5.0 # Fine x5 pKd > 7 great binding
|
| 15 |
+
weights[target >= 8.0] = 10.0 # Fine x10 pKd > 8 super binding
|
| 16 |
+
|
| 17 |
+
weighted_loss = squared_errors * weights
|
| 18 |
+
return torch.mean(weighted_loss)
|
research/model.py
ADDED
|
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017) plus dropout.

    Adds a fixed, non-learned position signal to token embeddings so a
    downstream transformer can distinguish sequence order. The table is
    precomputed once for up to `seq_len` positions and registered as a
    buffer (moves with `.to(device)`, never trained).
    """

    def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Positions 0..seq_len-1 as a column vector: (seq_len, 1).
        positions = torch.arange(seq_len, dtype=torch.float).unsqueeze(1)
        # Frequency terms, computed in log-space for numerical stability.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        encoding = torch.zeros(seq_len, d_model)
        encoding[:, 0::2] = torch.sin(positions * freqs)  # even channels
        encoding[:, 1::2] = torch.cos(positions * freqs)  # odd channels

        # Leading batch dim (1, seq_len, d_model) lets the table broadcast
        # over any batch size at add time.
        self.register_buffer("pe", encoding.unsqueeze(0))

    def forward(self, x):
        """Add positional encodings to `x` of shape [batch, seq_len, d_model]."""
        x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)
        return self.dropout(x)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
# class LigandGNN(nn.Module): # GCN CONV
|
| 40 |
+
# def __init__(self, input_dim, hidden_channels):
|
| 41 |
+
# super().__init__()
|
| 42 |
+
# self.hidden_channels = hidden_channels
|
| 43 |
+
#
|
| 44 |
+
# self.conv1 = GCNConv(input_dim, hidden_channels)
|
| 45 |
+
# self.conv2 = GCNConv(hidden_channels, hidden_channels)
|
| 46 |
+
# self.conv3 = GCNConv(hidden_channels, hidden_channels)
|
| 47 |
+
# self.dropout = nn.Dropout(0.2)
|
| 48 |
+
#
|
| 49 |
+
# def forward(self, x, edge_index, batch):
|
| 50 |
+
# x = self.conv1(x, edge_index)
|
| 51 |
+
# x = x.relu()
|
| 52 |
+
# x = self.dropout(x)
|
| 53 |
+
#
|
| 54 |
+
# x = self.conv2(x, edge_index)
|
| 55 |
+
# x = x.relu()
|
| 56 |
+
# x = self.conv3(x, edge_index)
|
| 57 |
+
# x = self.dropout(x)
|
| 58 |
+
#
|
| 59 |
+
# # Averaging nodes and got the molecula vector
|
| 60 |
+
# x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
|
| 61 |
+
# return x
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class LigandGNN(nn.Module):
    """Three-layer GAT encoder mapping a molecular graph to one vector.

    Three message-passing hops let every atom aggregate information from
    neighbours up to three bonds away; attention weighs each neighbour's
    contribution. Mean pooling over nodes yields one embedding per graph.
    """

    def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
        super().__init__()
        # concat=False averages the attention heads instead of concatenating
        # them, so every layer keeps the width at hidden_channels.
        self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
        self.conv2 = GATConv(
            hidden_channels, hidden_channels, heads=heads, concat=False
        )
        self.conv3 = GATConv(
            hidden_channels, hidden_channels, heads=heads, concat=False
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        """Encode node features into graph-level vectors.

        Args:
            x: [num_nodes, input_dim] node features.
            edge_index: [2, num_edges] graph connectivity.
            batch: [num_nodes] graph-membership index used for pooling.

        Returns:
            [batch_size, hidden_channels] pooled graph embeddings.
        """
        # First two layers: conv -> ReLU -> dropout.
        for layer in (self.conv1, self.conv2):
            x = self.dropout(layer(x, edge_index).relu())

        # Final layer has no activation/dropout before pooling.
        x = self.conv3(x, edge_index)

        # Average node embeddings per graph -> one molecule vector each.
        return global_mean_pool(x, batch)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ProteinTransformer(nn.Module):
    """Transformer encoder over an amino-acid token sequence.

    Embeds token ids, adds sinusoidal positions, runs N encoder layers,
    then masked-mean-pools over real (non-PAD) tokens and projects to
    `output_dim`. PAD is assumed to be token id 0 throughout.
    """

    def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True),
            num_layers=N,
        )
        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        """Map token ids [batch, seq_len] to sequence embeddings [batch, output_dim]."""
        pad_positions = x == 0  # True at PAD tokens; excluded from attention
        # Scale embeddings by sqrt(d_model), per the original transformer recipe.
        h = self.embedding(x) * math.sqrt(self.d_model)
        h = self.pos_encoder(h)
        h = self.transformer(h, src_key_padding_mask=pad_positions)

        # Masked mean pooling: zero out PAD outputs, then divide by the
        # number of real tokens (clamped to avoid division by zero).
        keep = (~pad_positions).float().unsqueeze(-1)
        h = h * keep
        pooled = h.sum(dim=1) / keep.sum(dim=1).clamp(min=1e-9)

        return self.fc(pooled)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class BindingAffinityModel(nn.Module):
    """Two-tower binding-affinity (pKd) regressor.

    Tower 1 encodes the ligand graph with a GAT; tower 2 encodes the
    protein sequence with a transformer. Both produce vectors of size
    `hidden_channels`; their concatenation goes through an MLP head that
    outputs one scalar affinity per pair.
    """

    def __init__(
        self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2
    ):
        super().__init__()
        # Tower 1 - graph attention network over the ligand.
        self.ligand_gnn = LigandGNN(
            input_dim=num_node_features,
            hidden_channels=hidden_channels,
            heads=gat_heads,
            dropout=dropout,
        )
        # Tower 2 - transformer over the protein sequence (26 token ids).
        self.protein_transformer = ProteinTransformer(
            vocab_size=26,
            d_model=hidden_channels,
            output_dim=hidden_channels,
            dropout=dropout,
        )
        # Regression head over the concatenated tower outputs.
        self.head = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        )

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict affinity for a batch of ligand graphs and protein sequences."""
        ligand_vec = self.ligand_gnn(x, edge_index, batch)

        # protein_seq arrives flattened; recover [batch, seq_len] from the
        # number of graphs implied by the batch index vector.
        n_graphs = batch.max().item() + 1
        protein_vec = self.protein_transformer(protein_seq.view(n_graphs, -1))

        return self.head(torch.cat([ligand_vec, protein_vec], dim=1))
|
research/model_attention.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch_geometric.nn import GATConv
|
| 4 |
+
from torch_geometric.utils import to_dense_batch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CrossAttentionLayer(nn.Module):
    """Ligand-to-protein cross-attention block with a feed-forward tail.

    Ligand atom features act as queries; protein residue features supply
    keys and values, so each atom gathers protein context ("ligand
    enriched with knowledge about the protein"). Both the attention and
    the feed-forward sub-layers use a residual connection followed by
    LayerNorm (post-norm transformer style). The latest head-averaged
    attention map is cached on `last_attention_weights` for visualization.
    """

    def __init__(self, feature_dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Multi-head attention over feature_dim-wide tokens.
        self.attention = nn.MultiheadAttention(
            feature_dim, num_heads, dropout=dropout, batch_first=True
        )
        # Post-attention normalization stabilizes training.
        self.norm = nn.LayerNorm(feature_dim)
        # Classic transformer FFN with 4x expansion; GELU tends to work
        # well in attention blocks.
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
        )
        self.norm_ff = nn.LayerNorm(feature_dim)
        # Head-averaged attention weights from the most recent forward pass.
        self.last_attention_weights = None

    def forward(self, ligand_features, protein_features, key_padding_mask=None):
        """Enrich ligand atoms with protein context.

        Args:
            ligand_features: [batch, atoms, dim] query tokens.
            protein_features: [batch, residues, dim] key/value tokens.
            key_padding_mask: [batch, residues] bool; True = ignore position.

        Returns:
            [batch, atoms, dim] updated ligand features.
        """
        attended, weights = self.attention(
            query=ligand_features,
            key=protein_features,
            value=protein_features,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=True,
        )
        # Detach and move to CPU so downstream visualization never keeps
        # the autograd graph alive.
        self.last_attention_weights = weights.detach().cpu()

        # Residual + norm around attention, then around the FFN.
        out = self.norm(ligand_features + attended)
        out = self.norm_ff(out + self.ff(out))
        return out
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BindingAffinityModel(nn.Module):
    """Cross-attention binding-affinity (pKd) regressor.

    Ligand atoms are encoded by a 3-layer GAT; protein residues by an
    embedding followed by a 1D convolution for local context. A
    cross-attention layer lets each atom attend over residues; masked
    mean pooling over real atoms then feeds a two-layer MLP head that
    outputs one scalar per complex.
    """

    def __init__(
        self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
    ):
        super().__init__()
        # Stored as plain floats/ints so forward() can use functional dropout.
        self.dropout = dropout
        self.hidden_channels = hidden_channels

        # Tower 1 - Ligand GNN with GAT layers, using 3 GAT layers, so that every atom can "see" up to 3 bonds away,
        # Attention allows to measure the importance of the neighbours
        self.gat1 = GATConv(
            num_node_features, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat2 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat3 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )

        # Tower 2 - Protein Transformer, 22 = 21 amino acids + 1 padding token PAD
        self.protein_embedding = nn.Embedding(22, hidden_channels)
        # Additional positional encoding (simple linear) to give the model information about the order
        self.prot_conv = nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size=3, padding=1
        )

        # Cross-Attention Layer, atoms attending to amino acids
        self.cross_attention = CrossAttentionLayer(
            feature_dim=hidden_channels, num_heads=4, dropout=dropout
        )

        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, 1)  # Final output for regression, pKd

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict affinity.

        Args:
            x: [total_atoms, num_node_features] node features for all graphs.
            edge_index: [2, num_edges] ligand-graph connectivity.
            batch: [total_atoms] graph-membership index per atom.
            protein_seq: flattened token ids, reshaped here to [batch, seq_len];
                PAD is assumed to be token id 0.

        Returns:
            [batch, 1] predicted affinity values.
        """
        # Ligand GNN forward pass (Graph -> Node Embeddings)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat2(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)

        x = F.elu(self.gat3(x, edge_index))  # [Total_Atoms, Hidden_Channels]

        # Convert graph into tensor [Batch, Max_Atoms, Hidden_Channels]
        # to_dense_batch adds zeros paddings where necessary to the size of the largest graph in the batch
        ligand_dense, ligand_mask = to_dense_batch(x, batch)
        # ligand_dense: [Batch, Max_Atoms, Hidden_Channels]
        # ligand_mask: [Batch, Max_Atoms] True where there is real atom, False where there is padding

        batch_size = ligand_dense.size(0)
        protein_seq = protein_seq.view(batch_size, -1)  # [Batch, Seq_Len]

        # Protein forward pass protein_seq: [Batch, Seq_Len]
        p = self.protein_embedding(protein_seq)  # [Batch, Seq_Len, Hidden_Channels]

        # A simple convolution to understand local context in amino acids
        p = p.permute(0, 2, 1)  # Change to [Batch, Hidden_Channels, Seq_Len] for Conv1d
        p = F.relu(self.prot_conv(p))
        p = p.permute(0, 2, 1)  # [Batch, Seq, Hidden_Channels]

        # Mask for protein (where PAD=0, True, but MHA needs True where IGNOREME)
        # In Pytorch MHA, the key_padding_mask should be True where we want to ignore
        protein_pad_mask = protein_seq == 0

        # Cross-Attention
        x_cross = self.cross_attention(
            ligand_dense, p, key_padding_mask=protein_pad_mask
        )

        # Pooling over atoms to get a single vector per molecule, considering only real atoms, ignoring paddings
        # ligand mask True where real atom, False where padding
        mask_expanded = ligand_mask.unsqueeze(-1)  # [Batch, Max_Atoms, 1]

        # Zero out the padded atom features
        x_cross = x_cross * mask_expanded

        # Sum the features of real atoms / number of real atoms to get the mean
        sum_features = torch.sum(x_cross, dim=1)  # [Batch, Hidden_Channels]
        num_atoms = torch.sum(mask_expanded, dim=1)  # [Batch, 1]
        pooled_x = sum_features / (num_atoms + 1e-6)  # Avoid division by zero

        # MLP Head
        out = F.relu(self.fc1(pooled_x))
        out = F.dropout(out, p=self.dropout, training=self.training)
        out = self.fc2(out)
        return out
|
model_pl.py β research/model_pl.py
RENAMED
|
File without changes
|
optuna_train.py β research/optuna_train.py
RENAMED
|
File without changes
|
optuna_train_attention.py β research/optuna_train_attention.py
RENAMED
|
@@ -7,11 +7,12 @@ import numpy as np
|
|
| 7 |
from torch_geometric.loader import DataLoader
|
| 8 |
from torch.utils.data import random_split
|
| 9 |
from dataset import BindingDataset
|
|
|
|
| 10 |
from model_attention import BindingAffinityModel
|
| 11 |
|
| 12 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
N_TRIALS = 50
|
| 14 |
-
MAX_EPOCHS_PER_TRIAL =
|
| 15 |
LOG_DIR = "runs"
|
| 16 |
DATA_CSV = "pdbbind_refined_dataset.csv"
|
| 17 |
|
|
@@ -88,7 +89,8 @@ def objective(trial):
|
|
| 88 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 89 |
optimizer, mode="min", factor=0.5, patience=5
|
| 90 |
)
|
| 91 |
-
criterion = nn.MSELoss()
|
|
|
|
| 92 |
|
| 93 |
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 94 |
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
|
@@ -118,7 +120,7 @@ if __name__ == "__main__":
|
|
| 118 |
direction="minimize",
|
| 119 |
pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=10),
|
| 120 |
storage=storage_name,
|
| 121 |
-
study_name="
|
| 122 |
load_if_exists=True,
|
| 123 |
)
|
| 124 |
print("Start hyperparameter optimization...")
|
|
|
|
| 7 |
from torch_geometric.loader import DataLoader
|
| 8 |
from torch.utils.data import random_split
|
| 9 |
from dataset import BindingDataset
|
| 10 |
+
from loss import WeightedMSELoss
|
| 11 |
from model_attention import BindingAffinityModel
|
| 12 |
|
| 13 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
N_TRIALS = 50
|
| 15 |
+
MAX_EPOCHS_PER_TRIAL = 50
|
| 16 |
LOG_DIR = "runs"
|
| 17 |
DATA_CSV = "pdbbind_refined_dataset.csv"
|
| 18 |
|
|
|
|
| 89 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 90 |
optimizer, mode="min", factor=0.5, patience=5
|
| 91 |
)
|
| 92 |
+
# criterion = nn.MSELoss()
|
| 93 |
+
criterion = WeightedMSELoss()
|
| 94 |
|
| 95 |
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 96 |
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
|
|
|
| 120 |
direction="minimize",
|
| 121 |
pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=10),
|
| 122 |
storage=storage_name,
|
| 123 |
+
study_name="binding_prediction_optimization_attentionWeightedLoss",
|
| 124 |
load_if_exists=True,
|
| 125 |
)
|
| 126 |
print("Start hyperparameter optimization...")
|
pdbbind_refined_dataset.csv β research/pdbbind_refined_dataset.csv
RENAMED
|
File without changes
|
research/requirements_dev.txt
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
pytorch-lightning
|
| 3 |
+
optuna
|
| 4 |
+
optuna-dashboard
|
| 5 |
+
tensorboard
|
| 6 |
+
|
| 7 |
+
numpy
|
| 8 |
+
pandas
|
| 9 |
+
matplotlib
|
| 10 |
+
seaborn
|
| 11 |
+
|
| 12 |
+
rdkit
|
| 13 |
+
biopython
|
| 14 |
+
|
| 15 |
+
jupyter
|
| 16 |
+
tqdm
|
| 17 |
+
black
|
| 18 |
+
|
train.py β research/train.py
RENAMED
|
File without changes
|
train_attention.py β research/train_attention.py
RENAMED
|
@@ -6,6 +6,7 @@ import pandas as pd
|
|
| 6 |
from torch.utils.data import random_split
|
| 7 |
from torch_geometric.loader import DataLoader
|
| 8 |
from dataset import BindingDataset
|
|
|
|
| 9 |
from model_attention import BindingAffinityModel
|
| 10 |
from tqdm import tqdm
|
| 11 |
from torch.utils.tensorboard import SummaryWriter
|
|
@@ -13,7 +14,7 @@ import numpy as np
|
|
| 13 |
from datetime import datetime
|
| 14 |
import os
|
| 15 |
|
| 16 |
-
# 2.02
|
| 17 |
# BATCH_SIZE = 16
|
| 18 |
# LR = 0.00035 # Reduced learning rate
|
| 19 |
# WEIGHT_DECAY = 1e-5 # Slightly increased weight decay (regularization)
|
|
@@ -23,16 +24,27 @@ import os
|
|
| 23 |
# HIDDEN_CHANNELS = 256
|
| 24 |
|
| 25 |
# 1.90 from Optuna
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
BATCH_SIZE = 16
|
| 27 |
-
LR = 0.
|
| 28 |
-
WEIGHT_DECAY = 1e-
|
| 29 |
-
DROPOUT = 0.
|
| 30 |
EPOCHS = 100
|
| 31 |
-
HIDDEN_CHANNELS =
|
| 32 |
-
GAT_HEADS =
|
| 33 |
|
| 34 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 35 |
-
LOG_DIR =
|
|
|
|
|
|
|
| 36 |
TOP_K = 3
|
| 37 |
SAVES_DIR = LOG_DIR + "/models"
|
| 38 |
|
|
@@ -128,7 +140,8 @@ def main():
|
|
| 128 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 129 |
optimizer, mode="min", factor=0.5, patience=8
|
| 130 |
)
|
| 131 |
-
criterion = nn.MSELoss()
|
|
|
|
| 132 |
|
| 133 |
top_models = []
|
| 134 |
|
|
@@ -152,7 +165,7 @@ def main():
|
|
| 152 |
end="",
|
| 153 |
)
|
| 154 |
|
| 155 |
-
filename = f"{SAVES_DIR}/model_ep{epoch:03d}
|
| 156 |
|
| 157 |
torch.save(model.state_dict(), filename)
|
| 158 |
top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
|
|
@@ -173,7 +186,7 @@ def main():
|
|
| 173 |
print("Training finished.")
|
| 174 |
print("Top models saved:")
|
| 175 |
for i, m in enumerate(top_models):
|
| 176 |
-
print(f"{i + 1}. {m['path']} (
|
| 177 |
|
| 178 |
|
| 179 |
if __name__ == "__main__":
|
|
|
|
| 6 |
from torch.utils.data import random_split
|
| 7 |
from torch_geometric.loader import DataLoader
|
| 8 |
from dataset import BindingDataset
|
| 9 |
+
from loss import WeightedMSELoss
|
| 10 |
from model_attention import BindingAffinityModel
|
| 11 |
from tqdm import tqdm
|
| 12 |
from torch.utils.tensorboard import SummaryWriter
|
|
|
|
| 14 |
from datetime import datetime
|
| 15 |
import os
|
| 16 |
|
| 17 |
+
# 2.02 default parameters
|
| 18 |
# BATCH_SIZE = 16
|
| 19 |
# LR = 0.00035 # Reduced learning rate
|
| 20 |
# WEIGHT_DECAY = 1e-5 # Slightly increased weight decay (regularization)
|
|
|
|
| 24 |
# HIDDEN_CHANNELS = 256
|
| 25 |
|
| 26 |
# 1.90 from Optuna
|
| 27 |
+
# BATCH_SIZE = 16
|
| 28 |
+
# LR = 0.000034
|
| 29 |
+
# WEIGHT_DECAY = 1e-6
|
| 30 |
+
# DROPOUT = 0.26
|
| 31 |
+
# EPOCHS = 100
|
| 32 |
+
# HIDDEN_CHANNELS = 256
|
| 33 |
+
# GAT_HEADS = 2
|
| 34 |
+
|
| 35 |
+
# Weighted Loss
|
| 36 |
BATCH_SIZE = 16
|
| 37 |
+
LR = 0.00022
|
| 38 |
+
WEIGHT_DECAY = 1e-5
|
| 39 |
+
DROPOUT = 0.25
|
| 40 |
EPOCHS = 100
|
| 41 |
+
HIDDEN_CHANNELS = 128
|
| 42 |
+
GAT_HEADS = 4
|
| 43 |
|
| 44 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 45 |
+
LOG_DIR = (
|
| 46 |
+
f"runs/experiment_attention{datetime.now().strftime('%Y%m%d_%H%M%S')}_weighted_loss"
|
| 47 |
+
)
|
| 48 |
TOP_K = 3
|
| 49 |
SAVES_DIR = LOG_DIR + "/models"
|
| 50 |
|
|
|
|
| 140 |
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 141 |
optimizer, mode="min", factor=0.5, patience=8
|
| 142 |
)
|
| 143 |
+
# criterion = nn.MSELoss()
|
| 144 |
+
criterion = WeightedMSELoss()
|
| 145 |
|
| 146 |
top_models = []
|
| 147 |
|
|
|
|
| 165 |
end="",
|
| 166 |
)
|
| 167 |
|
| 168 |
+
filename = f"{SAVES_DIR}/model_ep{epoch:03d}_weighted_loss{test_loss:.4f}.pth"
|
| 169 |
|
| 170 |
torch.save(model.state_dict(), filename)
|
| 171 |
top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
|
|
|
|
| 186 |
print("Training finished.")
|
| 187 |
print("Top models saved:")
|
| 188 |
for i, m in enumerate(top_models):
|
| 189 |
+
print(f"{i + 1}. {m['path']} (Weighted Loss: {m['loss']:.4f})")
|
| 190 |
|
| 191 |
|
| 192 |
if __name__ == "__main__":
|
train_pl.py β research/train_pl.py
RENAMED
|
File without changes
|
visualization.ipynb β research/visualization.ipynb
RENAMED
|
File without changes
|
utils.py
CHANGED
|
@@ -7,21 +7,24 @@ from rdkit.Chem import Descriptors
|
|
| 7 |
import py3Dmol
|
| 8 |
from jinja2 import Environment, FileSystemLoader
|
| 9 |
from google import genai
|
| 10 |
-
from google.genai import types
|
| 11 |
from decouple import config
|
| 12 |
|
| 13 |
GEMINI_API_KEY = config("GEMINI_API_KEY")
|
| 14 |
|
| 15 |
-
|
| 16 |
from dataset import get_atom_features, get_protein_features
|
| 17 |
from model_attention import BindingAffinityModel
|
| 18 |
|
| 19 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 20 |
-
# ΠΠ±Π½ΠΎΠ²ΠΈΡΠ΅ ΠΏΡΡΡ, Π΅ΡΠ»ΠΈ Π½ΡΠΆΠ½ΠΎ
|
| 21 |
-
MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
|
| 27 |
def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
|
|
@@ -45,14 +48,18 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
|
|
| 45 |
tokens = tokens[:1200]
|
| 46 |
else:
|
| 47 |
tokens.extend([0] * (1200 - len(tokens)))
|
| 48 |
-
protein_sequence_tensor =
|
|
|
|
|
|
|
| 49 |
|
| 50 |
data = Data(x=x, edge_index=edge_index)
|
| 51 |
batch = Batch.from_data_list([data]).to(DEVICE)
|
| 52 |
num_features = x.shape[1]
|
| 53 |
|
| 54 |
# Model
|
| 55 |
-
model = BindingAffinityModel(
|
|
|
|
|
|
|
| 56 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 57 |
model.eval()
|
| 58 |
|
|
@@ -65,7 +72,9 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
|
|
| 65 |
importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()
|
| 66 |
|
| 67 |
if importance.max() > 0:
|
| 68 |
-
importance = (importance - importance.min()) / (
|
|
|
|
|
|
|
| 69 |
|
| 70 |
importance[importance < 0.01] = 0
|
| 71 |
return mol, importance, pred.item()
|
|
@@ -112,19 +121,16 @@ def get_lipinski_properties(mol):
|
|
| 112 |
"violations": violations,
|
| 113 |
"status_text": status,
|
| 114 |
"css_class": css_class,
|
| 115 |
-
"bad_params": ", ".join(bad_params) if bad_params else "None"
|
| 116 |
}
|
| 117 |
|
| 118 |
|
| 119 |
def get_py3dmol_view(mol, importance):
|
| 120 |
view = py3Dmol.view(width="100%", height="600px")
|
| 121 |
view.addModel(Chem.MolToMolBlock(mol), "sdf")
|
| 122 |
-
view.setBackgroundColor(
|
| 123 |
|
| 124 |
-
view.setStyle({}, {
|
| 125 |
-
'stick': {'radius': 0.15},
|
| 126 |
-
'sphere': {'scale': 0.25}
|
| 127 |
-
})
|
| 128 |
|
| 129 |
indices_sorted = np.argsort(importance)[::-1]
|
| 130 |
top_indices = set(indices_sorted[:15])
|
|
@@ -137,16 +143,19 @@ def get_py3dmol_view(mol, importance):
|
|
| 137 |
symbol = mol.GetAtomWithIdx(i).GetSymbol()
|
| 138 |
label_text = f"{i}:{symbol}:{val:.2f}"
|
| 139 |
|
| 140 |
-
view.addLabel(
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
| 150 |
view.zoomTo()
|
| 151 |
return view
|
| 152 |
|
|
@@ -166,32 +175,40 @@ def save_standalone_ngl_html(mol, importance, filepath):
|
|
| 166 |
indices_sorted = np.argsort(importance)[::-1]
|
| 167 |
top_indices = indices_sorted[:15]
|
| 168 |
|
| 169 |
-
|
| 170 |
selection_list = [str(i) for i in top_indices]
|
| 171 |
selection_str = "@" + ",".join(selection_list)
|
| 172 |
|
| 173 |
if not selection_list:
|
| 174 |
selection_str = "@-1"
|
| 175 |
|
|
|
|
|
|
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
rendered_html = template.render(pdb_block=final_pdb_block, selection_str=selection_str)
|
| 181 |
|
| 182 |
with open(filepath, "w", encoding="utf-8") as f:
|
| 183 |
f.write(rendered_html)
|
| 184 |
|
| 185 |
|
| 186 |
-
def get_gemini_explanation(
|
|
|
|
|
|
|
| 187 |
if not GEMINI_API_KEY:
|
| 188 |
return "<p class='text-warning'>API Key for Gemini not found. Please set GOOGLE_API_KEY environment variable.</p>"
|
| 189 |
|
| 190 |
# Forming a list of top important atoms for a prompt
|
| 191 |
-
atoms_desc = ", ".join(
|
|
|
|
|
|
|
| 192 |
|
| 193 |
# Cut a protein to not spend too many tokens
|
| 194 |
-
prot_short =
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
prompt = f"""
|
| 197 |
You are an expert Computational Chemist and Drug Discovery Scientist.
|
|
@@ -216,9 +233,8 @@ def get_gemini_explanation(ligand_smiles, protein_sequence, affinity, top_atoms,
|
|
| 216 |
try:
|
| 217 |
client = genai.Client(api_key=GEMINI_API_KEY)
|
| 218 |
response = client.models.generate_content(
|
| 219 |
-
model="gemini-2.5-flash",
|
| 220 |
-
contents=prompt
|
| 221 |
)
|
| 222 |
return response.text
|
| 223 |
except Exception as e:
|
| 224 |
-
return f"<p class='text-danger'>Error generating explanation: {str(e)}</p>"
|
|
|
|
| 7 |
import py3Dmol
|
| 8 |
from jinja2 import Environment, FileSystemLoader
|
| 9 |
from google import genai
|
|
|
|
| 10 |
from decouple import config
|
| 11 |
|
| 12 |
GEMINI_API_KEY = config("GEMINI_API_KEY")
|
| 13 |
|
| 14 |
+
|
| 15 |
from dataset import get_atom_features, get_protein_features
|
| 16 |
from model_attention import BindingAffinityModel
|
| 17 |
|
| 18 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
| 19 |
|
| 20 |
+
# MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
|
| 21 |
+
#
|
| 22 |
+
# GAT_HEADS = 2
|
| 23 |
+
# HIDDEN_CHANNELS = 256
|
| 24 |
+
|
| 25 |
+
MODEL_PATH = "runs/experiment_attention20260127_055340_weighted_loss/models/model_ep028_weighted_loss6.7715.pth"
|
| 26 |
+
GAT_HEADS = 4
|
| 27 |
+
HIDDEN_CHANNELS = 128
|
| 28 |
|
| 29 |
|
| 30 |
def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
|
|
|
|
| 48 |
tokens = tokens[:1200]
|
| 49 |
else:
|
| 50 |
tokens.extend([0] * (1200 - len(tokens)))
|
| 51 |
+
protein_sequence_tensor = (
|
| 52 |
+
torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
|
| 53 |
+
)
|
| 54 |
|
| 55 |
data = Data(x=x, edge_index=edge_index)
|
| 56 |
batch = Batch.from_data_list([data]).to(DEVICE)
|
| 57 |
num_features = x.shape[1]
|
| 58 |
|
| 59 |
# Model
|
| 60 |
+
model = BindingAffinityModel(
|
| 61 |
+
num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
|
| 62 |
+
).to(DEVICE)
|
| 63 |
model.load_state_dict(torch.load(model_path, map_location=DEVICE))
|
| 64 |
model.eval()
|
| 65 |
|
|
|
|
| 72 |
importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()
|
| 73 |
|
| 74 |
if importance.max() > 0:
|
| 75 |
+
importance = (importance - importance.min()) / (
|
| 76 |
+
importance.max() - importance.min()
|
| 77 |
+
)
|
| 78 |
|
| 79 |
importance[importance < 0.01] = 0
|
| 80 |
return mol, importance, pred.item()
|
|
|
|
| 121 |
"violations": violations,
|
| 122 |
"status_text": status,
|
| 123 |
"css_class": css_class,
|
| 124 |
+
"bad_params": ", ".join(bad_params) if bad_params else "None",
|
| 125 |
}
|
| 126 |
|
| 127 |
|
| 128 |
def get_py3dmol_view(mol, importance):
|
| 129 |
view = py3Dmol.view(width="100%", height="600px")
|
| 130 |
view.addModel(Chem.MolToMolBlock(mol), "sdf")
|
| 131 |
+
view.setBackgroundColor("white")
|
| 132 |
|
| 133 |
+
view.setStyle({}, {"stick": {"radius": 0.15}, "sphere": {"scale": 0.25}})
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
indices_sorted = np.argsort(importance)[::-1]
|
| 136 |
top_indices = set(indices_sorted[:15])
|
|
|
|
| 143 |
symbol = mol.GetAtomWithIdx(i).GetSymbol()
|
| 144 |
label_text = f"{i}:{symbol}:{val:.2f}"
|
| 145 |
|
| 146 |
+
view.addLabel(
|
| 147 |
+
label_text,
|
| 148 |
+
{
|
| 149 |
+
"position": {"x": pos.x, "y": pos.y, "z": pos.z},
|
| 150 |
+
"fontSize": 14,
|
| 151 |
+
"fontColor": "white",
|
| 152 |
+
"backgroundColor": "black",
|
| 153 |
+
"backgroundOpacity": 0.7,
|
| 154 |
+
"borderThickness": 0,
|
| 155 |
+
"inFront": True,
|
| 156 |
+
"showBackground": True,
|
| 157 |
+
},
|
| 158 |
+
)
|
| 159 |
view.zoomTo()
|
| 160 |
return view
|
| 161 |
|
|
|
|
| 175 |
indices_sorted = np.argsort(importance)[::-1]
|
| 176 |
top_indices = indices_sorted[:15]
|
| 177 |
|
|
|
|
| 178 |
selection_list = [str(i) for i in top_indices]
|
| 179 |
selection_str = "@" + ",".join(selection_list)
|
| 180 |
|
| 181 |
if not selection_list:
|
| 182 |
selection_str = "@-1"
|
| 183 |
|
| 184 |
+
env = Environment(loader=FileSystemLoader("templates"))
|
| 185 |
+
template = env.get_template("ngl_view.html")
|
| 186 |
|
| 187 |
+
rendered_html = template.render(
|
| 188 |
+
pdb_block=final_pdb_block, selection_str=selection_str
|
| 189 |
+
)
|
|
|
|
| 190 |
|
| 191 |
with open(filepath, "w", encoding="utf-8") as f:
|
| 192 |
f.write(rendered_html)
|
| 193 |
|
| 194 |
|
| 195 |
+
def get_gemini_explanation(
|
| 196 |
+
ligand_smiles, protein_sequence, affinity, top_atoms, lipinski
|
| 197 |
+
):
|
| 198 |
if not GEMINI_API_KEY:
|
| 199 |
return "<p class='text-warning'>API Key for Gemini not found. Please set GOOGLE_API_KEY environment variable.</p>"
|
| 200 |
|
| 201 |
# Forming a list of top important atoms for a prompt
|
| 202 |
+
atoms_desc = ", ".join(
|
| 203 |
+
[f"{a['symbol']}(idx {a['id']}, score {a['score']})" for a in top_atoms[:10]]
|
| 204 |
+
)
|
| 205 |
|
| 206 |
# Cut a protein to not spend too many tokens
|
| 207 |
+
prot_short = (
|
| 208 |
+
protein_sequence[:100] + "..."
|
| 209 |
+
if len(protein_sequence) > 100
|
| 210 |
+
else protein_sequence
|
| 211 |
+
)
|
| 212 |
|
| 213 |
prompt = f"""
|
| 214 |
You are an expert Computational Chemist and Drug Discovery Scientist.
|
|
|
|
| 233 |
try:
|
| 234 |
client = genai.Client(api_key=GEMINI_API_KEY)
|
| 235 |
response = client.models.generate_content(
|
| 236 |
+
model="gemini-2.5-flash", contents=prompt
|
|
|
|
| 237 |
)
|
| 238 |
return response.text
|
| 239 |
except Exception as e:
|
| 240 |
+
return f"<p class='text-danger'>Error generating explanation: {str(e)}</p>"
|