AlexSychovUN committed on
Commit
13188b8
·
1 Parent(s): 543ad41

Prepared for deploy

Browse files
Files changed (36) hide show
  1. .dockerignore +6 -0
  2. .gitignore +3 -1
  3. Dockerfile +16 -0
  4. docker-compose.yml +8 -0
  5. main.py +37 -34
  6. model_attention.py +1 -1
  7. models/model_ep028_weighted_loss6.7715.pth +3 -0
  8. requirements.txt +17 -7
  9. EDA.ipynb → research/EDA.ipynb +0 -0
  10. {GNN_classification → research/GNN_classification}/Dataset_Preparation.py +0 -0
  11. {GNN_classification → research/GNN_classification}/dataset/classification/EDA.ipynb +0 -0
  12. {GNN_classification → research/GNN_classification}/dataset/classification/data_test.csv +0 -0
  13. {GNN_classification → research/GNN_classification}/dataset/classification/data_test.txt +0 -0
  14. {GNN_classification → research/GNN_classification}/dataset/classification/data_train.csv +0 -0
  15. {GNN_classification → research/GNN_classification}/dataset/classification/data_train.txt +0 -0
  16. {GNN_classification → research/GNN_classification}/model.py +0 -0
  17. {GNN_classification → research/GNN_classification}/training.py +0 -0
  18. GNNs__practice.ipynb → research/GNNs__practice.ipynb +0 -0
  19. all_inferences.py → research/all_inferences.py +82 -47
  20. research/dataset.py +209 -0
  21. dataset_preparation.py → research/dataset_preparation.py +0 -0
  22. inference.py → research/inference.py +0 -0
  23. inference_attention.py → research/inference_attention.py +0 -0
  24. research/loss.py +18 -0
  25. research/model.py +158 -0
  26. research/model_attention.py +144 -0
  27. model_pl.py → research/model_pl.py +0 -0
  28. optuna_train.py → research/optuna_train.py +0 -0
  29. optuna_train_attention.py → research/optuna_train_attention.py +5 -3
  30. pdbbind_refined_dataset.csv → research/pdbbind_refined_dataset.csv +0 -0
  31. research/requirements_dev.txt +18 -0
  32. train.py → research/train.py +0 -0
  33. train_attention.py → research/train_attention.py +23 -10
  34. train_pl.py → research/train_pl.py +0 -0
  35. visualization.ipynb → research/visualization.ipynb +0 -0
  36. utils.py +52 -36
  36. utils.py +52 -36
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .git
2
+ .idea
3
+ .venv
4
+ __pycache__
5
+ research
6
+ .env
.gitignore CHANGED
@@ -1,7 +1,9 @@
1
  .idea
2
  .venv
3
  .ipynb_checkpoints
 
4
 
5
  /refined-set/
6
  /data
7
- /lightning_logs
 
 
1
  .idea
2
  .venv
3
  .ipynb_checkpoints
4
+ __pycache__
5
 
6
  /refined-set/
7
  /data
8
+ /lightning_logs
9
+ .env
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ WORKDIR /app
4
+
5
+ RUN apt-get update && apt-get install -y \
6
+ build-essential \
7
+ libxrender1 \
8
+ libxext6 \
9
+ && rm -rf /var/lib/apt/lists/*
10
+
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ COPY . .
15
+ EXPOSE 7860
16
+ CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
docker-compose.yml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ binding-predictor:
3
+ build: .
4
+ container_name: binding-app
5
+ ports:
6
+ - "8000:7860"
7
+ environment:
8
+ - GEMINI_API_KEY=${GEMINI_API_KEY}
main.py CHANGED
@@ -5,8 +5,13 @@ from fastapi import FastAPI, Request, Form
5
  from fastapi.templating import Jinja2Templates
6
  from fastapi.staticfiles import StaticFiles
7
  from fastapi.responses import HTMLResponse
8
- from utils import get_inference_data, get_py3dmol_view, save_standalone_ngl_html, get_lipinski_properties, \
9
- get_gemini_explanation
 
 
 
 
 
10
 
11
  app = FastAPI()
12
 
@@ -17,6 +22,7 @@ app.mount("/static", StaticFiles(directory="static"), name="static")
17
 
18
  templates = Jinja2Templates(directory="templates")
19
 
 
20
  @app.get("/", response_class=HTMLResponse)
21
  async def read_root(request: Request):
22
  return templates.TemplateResponse("index.html", {"request": request})
@@ -24,33 +30,32 @@ async def read_root(request: Request):
24
 
25
  @app.post("/predict", response_class=HTMLResponse)
26
  async def predict(
27
- request: Request,
28
- smiles_ligand: str = Form(...),
29
- sequence_protein: str = Form(...)
30
  ):
31
  mol, importance, affinity = get_inference_data(smiles_ligand, sequence_protein)
32
 
33
  atom_list = []
34
- sorted_indices = sorted(range(len(importance)), key=lambda k: importance[k], reverse=True)
 
 
35
 
36
  for idx in sorted_indices[:15]:
37
  val = importance[idx]
38
  symbol = mol.GetAtomWithIdx(idx).GetSymbol()
39
 
40
  icon = ""
41
- if val >= 0.9: icon = "πŸ”₯"
42
- elif val >= 0.7: icon = "✨"
43
- elif val >= 0.5: icon = "⭐"
44
- atom_list.append({
45
- "id": idx,
46
- "symbol": symbol,
47
- "score": f"{val:.3f}",
48
- "icon": icon
49
- })
50
 
51
  unique_id = str(uuid.uuid4())
52
 
53
-
54
  filename_ngl = f"ngl_{unique_id}.html"
55
  filepath_ngl = os.path.join("html_results", filename_ngl)
56
 
@@ -71,23 +76,21 @@ async def predict(
71
  sequence_protein,
72
  f"{affinity:.2f}",
73
  atom_list,
74
- lipinski_properties
75
  )
76
 
77
- return templates.TemplateResponse("index.html", {
78
- "request": request,
79
- "result_ready": True,
80
- "smiles": smiles_ligand,
81
- "protein": sequence_protein,
82
- "affinity": f"{affinity:.2f}",
83
- "atom_list": atom_list,
84
- "html_py3dmol": py3dmol_content,
85
- "url_ngl": ngl_url_link,
86
- "lipinski": lipinski_properties,
87
- "ai_explanation": ai_explanation
88
- })
89
-
90
-
91
-
92
-
93
-
 
5
  from fastapi.templating import Jinja2Templates
6
  from fastapi.staticfiles import StaticFiles
7
  from fastapi.responses import HTMLResponse
8
+ from utils import (
9
+ get_inference_data,
10
+ get_py3dmol_view,
11
+ save_standalone_ngl_html,
12
+ get_lipinski_properties,
13
+ get_gemini_explanation,
14
+ )
15
 
16
  app = FastAPI()
17
 
 
22
 
23
  templates = Jinja2Templates(directory="templates")
24
 
25
+
26
  @app.get("/", response_class=HTMLResponse)
27
  async def read_root(request: Request):
28
  return templates.TemplateResponse("index.html", {"request": request})
 
30
 
31
  @app.post("/predict", response_class=HTMLResponse)
32
  async def predict(
33
+ request: Request, smiles_ligand: str = Form(...), sequence_protein: str = Form(...)
 
 
34
  ):
35
  mol, importance, affinity = get_inference_data(smiles_ligand, sequence_protein)
36
 
37
  atom_list = []
38
+ sorted_indices = sorted(
39
+ range(len(importance)), key=lambda k: importance[k], reverse=True
40
+ )
41
 
42
  for idx in sorted_indices[:15]:
43
  val = importance[idx]
44
  symbol = mol.GetAtomWithIdx(idx).GetSymbol()
45
 
46
  icon = ""
47
+ if val >= 0.9:
48
+ icon = "πŸ”₯"
49
+ elif val >= 0.7:
50
+ icon = "✨"
51
+ elif val >= 0.5:
52
+ icon = "⭐"
53
+ atom_list.append(
54
+ {"id": idx, "symbol": symbol, "score": f"{val:.3f}", "icon": icon}
55
+ )
56
 
57
  unique_id = str(uuid.uuid4())
58
 
 
59
  filename_ngl = f"ngl_{unique_id}.html"
60
  filepath_ngl = os.path.join("html_results", filename_ngl)
61
 
 
76
  sequence_protein,
77
  f"{affinity:.2f}",
78
  atom_list,
79
+ lipinski_properties,
80
  )
81
 
82
+ return templates.TemplateResponse(
83
+ "index.html",
84
+ {
85
+ "request": request,
86
+ "result_ready": True,
87
+ "smiles": smiles_ligand,
88
+ "protein": sequence_protein,
89
+ "affinity": f"{affinity:.2f}",
90
+ "atom_list": atom_list,
91
+ "html_py3dmol": py3dmol_content,
92
+ "url_ngl": ngl_url_link,
93
+ "lipinski": lipinski_properties,
94
+ "ai_explanation": ai_explanation,
95
+ },
96
+ )
 
 
model_attention.py CHANGED
@@ -20,7 +20,7 @@ class CrossAttentionLayer(nn.Module):
20
  # Feedforward network for further processing, classical transformer style
21
  self.ff = nn.Sequential(
22
  nn.Linear(feature_dim, feature_dim * 4),
23
- nn.GELU(), # GELU works better with transformers
24
  nn.Dropout(dropout),
25
  nn.Linear(feature_dim * 4, feature_dim),
26
  )
 
20
  # Feedforward network for further processing, classical transformer style
21
  self.ff = nn.Sequential(
22
  nn.Linear(feature_dim, feature_dim * 4),
23
+ nn.GELU(), # GELU works better with transformers
24
  nn.Dropout(dropout),
25
  nn.Linear(feature_dim * 4, feature_dim),
26
  )
models/model_ep028_weighted_loss6.7715.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d64adee176c82baa3489ffbae33567c31f5cea7c7b254e782651507754e3dfc2
3
+ size 1810686
requirements.txt CHANGED
@@ -1,10 +1,20 @@
1
- torch
2
- pytorch-lightning
3
- optuna
 
4
 
5
- numpy
6
- pandas
7
 
8
- rdkit
9
- biopython
 
 
 
 
 
 
 
 
 
10
 
 
1
+ fastapi==0.128.0
2
+ uvicorn[standard]==0.38.0
3
+ python-multipart==0.0.21
4
+ requests==2.32.5
5
 
6
+ python-decouple==3.8
7
+ py3Dmol==2.5.4
8
 
9
+ numpy==2.2.6
10
+ pandas==2.3.3
11
+
12
+ rdkit==2025.9.1
13
+
14
+ google-genai==1.53.0
15
+
16
+ --extra-index-url https://download.pytorch.org/whl/cpu
17
+
18
+ torch>=2.5.0
19
+ torch-geometric>=2.7.0
20
 
EDA.ipynb → research/EDA.ipynb RENAMED
File without changes
{GNN_classification → research/GNN_classification}/Dataset_Preparation.py RENAMED
File without changes
{GNN_classification → research/GNN_classification}/dataset/classification/EDA.ipynb RENAMED
File without changes
{GNN_classification → research/GNN_classification}/dataset/classification/data_test.csv RENAMED
File without changes
{GNN_classification → research/GNN_classification}/dataset/classification/data_test.txt RENAMED
File without changes
{GNN_classification → research/GNN_classification}/dataset/classification/data_train.csv RENAMED
File without changes
{GNN_classification → research/GNN_classification}/dataset/classification/data_train.txt RENAMED
File without changes
{GNN_classification → research/GNN_classification}/model.py RENAMED
File without changes
{GNN_classification → research/GNN_classification}/training.py RENAMED
File without changes
GNNs__practice.ipynb → research/GNNs__practice.ipynb RENAMED
File without changes
all_inferences.py → research/all_inferences.py RENAMED
@@ -19,11 +19,14 @@ from model_attention import BindingAffinityModel
19
 
20
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
- MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
 
 
23
 
24
  GAT_HEADS = 2
25
  HIDDEN_CHANNELS = 256
26
 
 
27
  def get_inference_data(ligand_smiles, protein_sequence, model_path):
28
  """
29
  Returns:
@@ -46,8 +49,10 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path):
46
  edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
47
 
48
  tokens = [get_protein_features(c) for c in protein_sequence]
49
- if len(tokens) > 1200: tokens = tokens[:1200]
50
- else: tokens.extend([0] * (1200 - len(tokens)))
 
 
51
  protein_sequence = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
52
 
53
  data = Data(x=x, edge_index=edge_index)
@@ -55,7 +60,9 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path):
55
  num_features = x.shape[1]
56
 
57
  # Model loading
58
- model = BindingAffinityModel(num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS).to(DEVICE)
 
 
59
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
60
  model.eval()
61
 
@@ -70,7 +77,9 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path):
70
 
71
  # Normalize to [0, 1]
72
  if importance.max() > 0:
73
- importance = (importance - importance.min()) / (importance.max() - importance.min())
 
 
74
 
75
  # Noise reduction
76
  importance[importance < 0.01] = 0
@@ -93,20 +102,22 @@ def print_atom_scores(mol, importance):
93
  print(f"Atom {idx} ({symbol}): Importance = {score:.3f} {fire}")
94
 
95
 
96
-
97
  def get_py3dmol(mol, importance, score):
98
 
99
  view = py3Dmol.view(width=1000, height=800)
100
  view.addModel(Chem.MolToMolBlock(mol), "sdf")
101
- view.setBackgroundColor('white')
102
 
103
  # 1. Π‘ΠΠ—ΠžΠ’Π«Π™ Π‘Π’Π˜Π›Π¬ (Π“Π Π£ΠΠ’ΠžΠ’ΠšΠ)
104
  # Π—Π°Π΄Π°Π΅ΠΌ Π΅Π΄ΠΈΠ½Ρ‹ΠΉ Ρ€Π°Π·ΠΌΠ΅Ρ€ для всСй ΠΌΠΎΠ»Π΅ΠΊΡƒΠ»Ρ‹ сразу
105
  # scale: 0.25 β€” ΠΎΠΏΡ‚ΠΈΠΌΠ°Π»ΡŒΠ½Ρ‹ΠΉ срСдний Ρ€Π°Π·ΠΌΠ΅Ρ€
106
- view.setStyle({}, {
107
- 'stick': {'color': '#cccccc', 'radius': 0.1},
108
- 'sphere': {'color': '#cccccc', 'scale': 0.25}
109
- })
 
 
 
110
 
111
  red_atoms = []
112
  orange_atoms = []
@@ -130,48 +141,65 @@ def get_py3dmol(mol, importance, score):
130
  if i in top_indices and val > 0.1:
131
  pos = conf.GetAtomPosition(i)
132
  symbol = mol.GetAtomWithIdx(i).GetSymbol()
133
- labels_to_add.append({
134
- 'text': f"{i}:{symbol}:{val:.2f}",
135
- 'pos': {'x': pos.x, 'y': pos.y, 'z': pos.z}
136
- })
 
 
137
 
138
  # 3. ΠŸΠ Π˜ΠœΠ•ΠΠ•ΠΠ˜Π• Π‘Π’Π˜Π›Π•Π™
139
  # ΠžΠ±Ρ€Π°Ρ‚ΠΈ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅: scale Π²Π΅Π·Π΄Π΅ 0.25 (ΠΈΠ»ΠΈ 0.28, Ρ‡Ρ‚ΠΎΠ±Ρ‹ Ρ‡ΡƒΡ‚ΡŒ Π²Ρ‹Π΄Π΅Π»ΠΈΡ‚ΡŒ Ρ†Π²Π΅Ρ‚Π½Ρ‹Π΅)
140
  # ΠœΡ‹ мСняСм Π’ΠžΠ›Π¬ΠšΠž Π¦Π’Π•Π’.
141
 
142
  if red_atoms:
143
- view.addStyle({'serial': red_atoms}, {
144
- 'sphere': {'color': '#FF0000', 'scale': 0.28},
145
- 'stick': {'color': '#FF0000', 'radius': 0.12}
146
- })
 
 
 
147
 
148
  if orange_atoms:
149
- view.addStyle({'serial': orange_atoms}, {
150
- 'sphere': {'color': '#FF8C00', 'scale': 0.28},
151
- 'stick': {'color': '#FF8C00', 'radius': 0.12}
152
- })
 
 
 
153
 
154
  if blue_atoms:
155
- view.addStyle({'serial': blue_atoms}, {
156
- 'sphere': {'color': '#7777FF', 'scale': 0.28}
157
- })
158
 
159
  # 4. ΠœΠ•Π’ΠšΠ˜
160
  for label in labels_to_add:
161
- view.addLabel(label['text'], {
162
- 'position': label['pos'],
163
- 'fontSize': 14,
164
- 'fontColor': 'white',
165
- 'backgroundColor': 'black',
166
- 'backgroundOpacity': 0.7,
167
- 'borderThickness': 0,
168
- 'inFront': True,
169
- 'showBackground': True
170
- })
 
 
 
171
 
172
  view.zoomTo()
173
- view.addLabel(f"Predicted pKd: {float(score):.2f}",
174
- {'position': {'x': -5, 'y': 10, 'z': 0}, 'backgroundColor': 'black', 'fontColor': 'white'})
 
 
 
 
 
 
175
 
176
  return view
177
 
@@ -190,23 +218,31 @@ def get_ngl(mol, importance):
190
  view = nv.NGLWidget(structure)
191
  view.clear_representations()
192
 
193
- view.add_representation('ball+stick', colorScheme='bfactor', colorScale=['blue', 'white', 'red'], colorDomain=[10, 80], radiusScale=1.0)
 
 
 
 
 
 
194
 
195
  indices_sorted = np.argsort(importance)[::-1]
196
  top_indices = indices_sorted[:15]
197
 
198
  selection_str = "@" + ",".join(map(str, top_indices))
199
- view.add_representation('label',
200
- selection=selection_str, # ΠŸΠΎΠ΄ΠΏΠΈΡΡ‹Π²Π°Π΅ΠΌ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ ΠΈΠ·Π±Ρ€Π°Π½Π½Ρ‹Ρ…
201
- labelType='atomindex', # ΠŸΠΎΠΊΠ°Π·Ρ‹Π²Π°Ρ‚ΡŒ ИндСкс (0, 1, 2...)
202
- color='black', # Π§Π΅Ρ€Π½Ρ‹ΠΉ тСкст
203
- radius=2.0, # Π Π°Π·ΠΌΠ΅Ρ€ ΡˆΡ€ΠΈΡ„Ρ‚Π° (ΠΏΠΎΠΏΡ€ΠΎΠ±ΡƒΠΉΡ‚Π΅ 1.5 - 3.0)
204
- zOffset=1.0) # Π§ΡƒΡ‚ΡŒ ΡΠ΄Π²ΠΈΠ½ΡƒΡ‚ΡŒ ΠΊ ΠΊΠ°ΠΌΠ΅Ρ€Π΅
205
-
 
206
 
207
  view.center()
208
  return view
209
 
 
210
  if __name__ == "__main__":
211
  smiles = "COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1"
212
  protein = "PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF"
@@ -222,4 +258,3 @@ if __name__ == "__main__":
222
 
223
  ngl_widget = get_ngl(mol, importance)
224
  nv.write_html(file_name_ngl, ngl_widget)
225
-
 
19
 
20
 
21
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ MODEL_PATH = (
23
+ "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
24
+ )
25
 
26
  GAT_HEADS = 2
27
  HIDDEN_CHANNELS = 256
28
 
29
+
30
  def get_inference_data(ligand_smiles, protein_sequence, model_path):
31
  """
32
  Returns:
 
49
  edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
50
 
51
  tokens = [get_protein_features(c) for c in protein_sequence]
52
+ if len(tokens) > 1200:
53
+ tokens = tokens[:1200]
54
+ else:
55
+ tokens.extend([0] * (1200 - len(tokens)))
56
  protein_sequence = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
57
 
58
  data = Data(x=x, edge_index=edge_index)
 
60
  num_features = x.shape[1]
61
 
62
  # Model loading
63
+ model = BindingAffinityModel(
64
+ num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
65
+ ).to(DEVICE)
66
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
67
  model.eval()
68
 
 
77
 
78
  # Normalize to [0, 1]
79
  if importance.max() > 0:
80
+ importance = (importance - importance.min()) / (
81
+ importance.max() - importance.min()
82
+ )
83
 
84
  # Noise reduction
85
  importance[importance < 0.01] = 0
 
102
  print(f"Atom {idx} ({symbol}): Importance = {score:.3f} {fire}")
103
 
104
 
 
105
  def get_py3dmol(mol, importance, score):
106
 
107
  view = py3Dmol.view(width=1000, height=800)
108
  view.addModel(Chem.MolToMolBlock(mol), "sdf")
109
+ view.setBackgroundColor("white")
110
 
111
  # 1. Π‘ΠΠ—ΠžΠ’Π«Π™ Π‘Π’Π˜Π›Π¬ (Π“Π Π£ΠΠ’ΠžΠ’ΠšΠ)
112
  # Π—Π°Π΄Π°Π΅ΠΌ Π΅Π΄ΠΈΠ½Ρ‹ΠΉ Ρ€Π°Π·ΠΌΠ΅Ρ€ для всСй ΠΌΠΎΠ»Π΅ΠΊΡƒΠ»Ρ‹ сразу
113
  # scale: 0.25 β€” ΠΎΠΏΡ‚ΠΈΠΌΠ°Π»ΡŒΠ½Ρ‹ΠΉ срСдний Ρ€Π°Π·ΠΌΠ΅Ρ€
114
+ view.setStyle(
115
+ {},
116
+ {
117
+ "stick": {"color": "#cccccc", "radius": 0.1},
118
+ "sphere": {"color": "#cccccc", "scale": 0.25},
119
+ },
120
+ )
121
 
122
  red_atoms = []
123
  orange_atoms = []
 
141
  if i in top_indices and val > 0.1:
142
  pos = conf.GetAtomPosition(i)
143
  symbol = mol.GetAtomWithIdx(i).GetSymbol()
144
+ labels_to_add.append(
145
+ {
146
+ "text": f"{i}:{symbol}:{val:.2f}",
147
+ "pos": {"x": pos.x, "y": pos.y, "z": pos.z},
148
+ }
149
+ )
150
 
151
  # 3. ΠŸΠ Π˜ΠœΠ•ΠΠ•ΠΠ˜Π• Π‘Π’Π˜Π›Π•Π™
152
  # ΠžΠ±Ρ€Π°Ρ‚ΠΈ Π²Π½ΠΈΠΌΠ°Π½ΠΈΠ΅: scale Π²Π΅Π·Π΄Π΅ 0.25 (ΠΈΠ»ΠΈ 0.28, Ρ‡Ρ‚ΠΎΠ±Ρ‹ Ρ‡ΡƒΡ‚ΡŒ Π²Ρ‹Π΄Π΅Π»ΠΈΡ‚ΡŒ Ρ†Π²Π΅Ρ‚Π½Ρ‹Π΅)
153
  # ΠœΡ‹ мСняСм Π’ΠžΠ›Π¬ΠšΠž Π¦Π’Π•Π’.
154
 
155
  if red_atoms:
156
+ view.addStyle(
157
+ {"serial": red_atoms},
158
+ {
159
+ "sphere": {"color": "#FF0000", "scale": 0.28},
160
+ "stick": {"color": "#FF0000", "radius": 0.12},
161
+ },
162
+ )
163
 
164
  if orange_atoms:
165
+ view.addStyle(
166
+ {"serial": orange_atoms},
167
+ {
168
+ "sphere": {"color": "#FF8C00", "scale": 0.28},
169
+ "stick": {"color": "#FF8C00", "radius": 0.12},
170
+ },
171
+ )
172
 
173
  if blue_atoms:
174
+ view.addStyle(
175
+ {"serial": blue_atoms}, {"sphere": {"color": "#7777FF", "scale": 0.28}}
176
+ )
177
 
178
  # 4. ΠœΠ•Π’ΠšΠ˜
179
  for label in labels_to_add:
180
+ view.addLabel(
181
+ label["text"],
182
+ {
183
+ "position": label["pos"],
184
+ "fontSize": 14,
185
+ "fontColor": "white",
186
+ "backgroundColor": "black",
187
+ "backgroundOpacity": 0.7,
188
+ "borderThickness": 0,
189
+ "inFront": True,
190
+ "showBackground": True,
191
+ },
192
+ )
193
 
194
  view.zoomTo()
195
+ view.addLabel(
196
+ f"Predicted pKd: {float(score):.2f}",
197
+ {
198
+ "position": {"x": -5, "y": 10, "z": 0},
199
+ "backgroundColor": "black",
200
+ "fontColor": "white",
201
+ },
202
+ )
203
 
204
  return view
205
 
 
218
  view = nv.NGLWidget(structure)
219
  view.clear_representations()
220
 
221
+ view.add_representation(
222
+ "ball+stick",
223
+ colorScheme="bfactor",
224
+ colorScale=["blue", "white", "red"],
225
+ colorDomain=[10, 80],
226
+ radiusScale=1.0,
227
+ )
228
 
229
  indices_sorted = np.argsort(importance)[::-1]
230
  top_indices = indices_sorted[:15]
231
 
232
  selection_str = "@" + ",".join(map(str, top_indices))
233
+ view.add_representation(
234
+ "label",
235
+ selection=selection_str, # ΠŸΠΎΠ΄ΠΏΠΈΡΡ‹Π²Π°Π΅ΠΌ Ρ‚ΠΎΠ»ΡŒΠΊΠΎ ΠΈΠ·Π±Ρ€Π°Π½Π½Ρ‹Ρ…
236
+ labelType="atomindex", # ΠŸΠΎΠΊΠ°Π·Ρ‹Π²Π°Ρ‚ΡŒ ИндСкс (0, 1, 2...)
237
+ color="black", # Π§Π΅Ρ€Π½Ρ‹ΠΉ тСкст
238
+ radius=2.0, # Π Π°Π·ΠΌΠ΅Ρ€ ΡˆΡ€ΠΈΡ„Ρ‚Π° (ΠΏΠΎΠΏΡ€ΠΎΠ±ΡƒΠΉΡ‚Π΅ 1.5 - 3.0)
239
+ zOffset=1.0,
240
+ ) # Π§ΡƒΡ‚ΡŒ ΡΠ΄Π²ΠΈΠ½ΡƒΡ‚ΡŒ ΠΊ ΠΊΠ°ΠΌΠ΅Ρ€Π΅
241
 
242
  view.center()
243
  return view
244
 
245
+
246
  if __name__ == "__main__":
247
  smiles = "COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1"
248
  protein = "PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF"
 
258
 
259
  ngl_widget = get_ngl(mol, importance)
260
  nv.write_html(file_name_ngl, ngl_widget)
 
research/dataset.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import pandas as pd
4
+ from rdkit import Chem, rdBase
5
+ from torch_geometric.data import Data
6
+ from torch.utils.data import Dataset, random_split
7
+
8
+ rdBase.DisableLog("rdApp.*")
9
+
10
+
11
def one_of_k_encoding(x, allowable_set):
    """One-hot encode *x* against *allowable_set* as a list of booleans.

    Values not present in the set fall back to the final entry, which by
    convention is the "unknown" bucket.
    """
    if x not in allowable_set:
        x = allowable_set[-1]
    return [item == x for item in allowable_set]
16
+
17
+
18
def get_atom_features(atom):
    """Build a fixed-length binary feature vector for an RDKit atom.

    Concatenates one-hot encodings of element symbol, degree, total H count,
    implicit valence, hybridization, formal charge and chirality, plus two
    boolean flags (aromaticity, ring membership).  Returns a numpy array.
    """
    # Recognised element symbols; "Unknown" is the fallback bucket.
    symbols = [
        "C", "N", "O", "S", "F", "Si", "P", "Cl", "Br", "Mg", "Na",
        "Ca", "Fe", "As", "Al", "I", "B", "V", "K", "Tl", "Yb", "Sb",
        "Sn", "Ag", "Pd", "Co", "Se", "Ti", "Zn", "H", "Li", "Ge",
        "Cu", "Au", "Ni", "Cd", "In", "Mn", "Zr", "Cr", "Pt", "Hg",
        "Pb", "Unknown",
    ]
    # Shared 0..10 bucket list for degree, H count and implicit valence.
    zero_to_ten = list(range(11))
    # Hybridization drives 3D geometry (sp2 trigonal planar, sp3 tetrahedral);
    # "other" is the fallback bucket.
    hybridizations = [
        Chem.rdchem.HybridizationType.SP,
        Chem.rdchem.HybridizationType.SP2,
        Chem.rdchem.HybridizationType.SP3,
        Chem.rdchem.HybridizationType.SP3D,
        Chem.rdchem.HybridizationType.SP3D2,
        "other",
    ]
    charges = [-2, -1, 0, 1, 2]
    chirality_tags = [
        Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
        Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
        Chem.rdchem.ChiralType.CHI_OTHER,
    ]

    features = (
        one_of_k_encoding(atom.GetSymbol(), symbols)                 # element type
        + one_of_k_encoding(atom.GetDegree(), zero_to_ten)           # neighbour count
        + one_of_k_encoding(atom.GetTotalNumHs(), zero_to_ten)       # H atoms - bond donors
        + one_of_k_encoding(atom.GetImplicitValence(), zero_to_ten)  # chemical potential
        + one_of_k_encoding(atom.GetHybridization(), hybridizations)
        + [atom.GetIsAromatic()]                                     # aromaticity flag
        + one_of_k_encoding(atom.GetFormalCharge(), charges)
        + one_of_k_encoding(atom.GetChiralTag(), chirality_tags)     # geometry
        + [atom.IsInRing()]                                          # ring-membership flag
    )
    return np.array(features)
114
+
115
+
116
# Amino-acid vocabulary: the 20 standard residues map to 1-20 (in the order
# below), ambiguous codes (X/Z/B) and unknown characters map to 21, and the
# padding token maps to 0.  Built once at module level because the lookup is
# invoked once per residue for every sequence (previously the dict was
# reconstructed on every single call).
_PROT_VOCAB = {aa: i for i, aa in enumerate("ARNDCQEGHILKMFPSTWYV", start=1)}
_PROT_VOCAB.update({"X": 21, "Z": 21, "B": 21, "PAD": 0, "UNK": 21})


def get_protein_features(char):
    """Return the integer token id for a one-letter amino-acid code.

    Any character outside the vocabulary maps to the UNK id (21).
    """
    return _PROT_VOCAB.get(char, _PROT_VOCAB["UNK"])
145
+
146
+
147
class BindingDataset(Dataset):
    """PyTorch dataset pairing ligand graphs with protein token sequences.

    Each item is a torch_geometric ``Data`` object holding the ligand graph
    (node features ``x``, ``edge_index``), the tokenised protein sequence
    (``protein_seq``) and the binding-affinity target (``y``).  Rows whose
    SMILES cannot be parsed yield ``None``.
    """

    def __init__(self, dataframe, max_seq_length=1000):
        self.data = dataframe
        # Protein sequences are truncated/padded to this length.
        self.max_seq_length = max_seq_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        smiles = record["smiles"]
        sequence = record["sequence"]
        affinity = record["affinity"]

        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # Unparseable SMILES: caller is expected to filter out Nones.
            return None

        # --- Ligand graph ------------------------------------------------
        # Node feature matrix, one row per atom.
        x = torch.tensor(
            np.array([get_atom_features(a) for a in mol.GetAtoms()]),
            dtype=torch.float,
        )

        # Each undirected bond becomes two directed edges.
        pairs = []
        for bond in mol.GetBonds():
            i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            pairs.extend(((i, j), (j, i)))
        # .t() transposes [num_edges, 2] -> [2, num_edges]; .contiguous()
        # materialises the transposed view with sequential memory layout.
        edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()

        # --- Protein sequence (tensor of integer token ids) --------------
        tokens = [get_protein_features(c) for c in sequence]
        if len(tokens) > self.max_seq_length:
            tokens = tokens[: self.max_seq_length]
        else:
            pad_id = get_protein_features("PAD")
            tokens.extend([pad_id] * (self.max_seq_length - len(tokens)))
        protein_tensor = torch.tensor(tokens, dtype=torch.long)

        # --- Target ------------------------------------------------------
        y = torch.tensor([affinity], dtype=torch.float)
        return Data(x=x, edge_index=edge_index, protein_seq=protein_tensor, y=y)
198
+
199
+
200
if __name__ == "__main__":
    # Smoke check: load the refined PDBbind CSV and report the sizes of an
    # 80/20 random train/test split.
    frame = pd.read_csv("pdbbind_refined_dataset.csv")
    full_dataset = BindingDataset(frame)

    n_train = int(0.8 * len(full_dataset))
    n_test = len(full_dataset) - n_train
    train_split, test_split = random_split(full_dataset, [n_train, n_test])

    print(len(train_split))
    print(len(test_split))
dataset_preparation.py → research/dataset_preparation.py RENAMED
File without changes
inference.py → research/inference.py RENAMED
File without changes
inference_attention.py → research/inference_attention.py RENAMED
File without changes
research/loss.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+
5
class WeightedMSELoss(nn.Module):
    """MSE loss that up-weights strong binders.

    Samples with higher pKd targets receive larger weights so mispredicting
    good binders costs more: pKd >= 6 -> x2, pKd >= 7 -> x5, pKd >= 8 -> x10.
    """

    def __init__(self):
        super().__init__()

    def forward(self, prediction, target):
        weights = torch.ones_like(target)
        # Later assignments overwrite earlier ones, so the tiers nest:
        # the >= 8 mask wins over >= 7, which wins over >= 6.
        for threshold, weight in ((6.0, 2.0), (7.0, 5.0), (8.0, 10.0)):
            weights[target >= threshold] = weight
        return ((prediction - target) ** 2 * weights).mean()
research/model.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
5
+
6
+
7
class PositionalEncoding(nn.Module):
    """Classic sinusoidal positional encoding (Vaswani et al., 2017).

    Adds fixed sin/cos position signals to token embeddings, followed by
    dropout.  The table is precomputed once and stored as a non-trainable
    buffer named ``pe`` with shape [1, seq_len, d_model].
    """

    def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # Positions 0..seq_len-1 as a column vector: shape (seq_len, 1).
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        # Frequency terms computed once in log space for numerical stability.
        freqs = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        table = torch.zeros(seq_len, d_model)
        table[:, 0::2] = torch.sin(positions * freqs)  # even dimensions
        table[:, 1::2] = torch.cos(positions * freqs)  # odd dimensions
        # Prepend a batch dimension -> (1, seq_len, d_model).
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        # x: [batch_size, seq_len, d_model].  Slice the table to the input
        # length and keep it out of the autograd graph.
        x = x + self.pe[:, : x.shape[1], :].requires_grad_(False)
        return self.dropout(x)
37
+
38
+
39
+ # class LigandGNN(nn.Module): # GCN CONV
40
+ # def __init__(self, input_dim, hidden_channels):
41
+ # super().__init__()
42
+ # self.hidden_channels = hidden_channels
43
+ #
44
+ # self.conv1 = GCNConv(input_dim, hidden_channels)
45
+ # self.conv2 = GCNConv(hidden_channels, hidden_channels)
46
+ # self.conv3 = GCNConv(hidden_channels, hidden_channels)
47
+ # self.dropout = nn.Dropout(0.2)
48
+ #
49
+ # def forward(self, x, edge_index, batch):
50
+ # x = self.conv1(x, edge_index)
51
+ # x = x.relu()
52
+ # x = self.dropout(x)
53
+ #
54
+ # x = self.conv2(x, edge_index)
55
+ # x = x.relu()
56
+ # x = self.conv3(x, edge_index)
57
+ # x = self.dropout(x)
58
+ #
59
+ # # Averaging nodes and got the molecula vector
60
+ # x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
61
+ # return x
62
+
63
+
64
class LigandGNN(nn.Module):
    """Three-layer GAT encoder mapping a ligand graph to a single vector.

    With ``concat=False`` each GATConv averages its attention heads instead
    of concatenating them, so the hidden width stays ``hidden_channels``
    throughout the stack.
    """

    def __init__(self, input_dim, hidden_channels, heads=4, dropout=0.2):
        super().__init__()
        self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
        self.conv2 = GATConv(
            hidden_channels, hidden_channels, heads=heads, concat=False
        )
        self.conv3 = GATConv(
            hidden_channels, hidden_channels, heads=heads, concat=False
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, edge_index, batch):
        # Two conv blocks with ReLU + dropout, then a final conv without.
        for conv in (self.conv1, self.conv2):
            x = self.dropout(conv(x, edge_index).relu())
        x = self.conv3(x, edge_index)
        # Average node embeddings per graph -> [batch_size, hidden_channels].
        return global_mean_pool(x, batch)
92
+
93
+
94
class ProteinTransformer(nn.Module):
    """Transformer encoder that embeds a protein token sequence into a vector.

    Token id 0 is treated as padding: it is masked out of self-attention and
    excluded from the final mean pooling.
    """

    def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=h, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
        self.fc = nn.Linear(d_model, output_dim)

    def forward(self, x):
        # x: [batch_size, seq_len] of integer token ids.
        pad_mask = x == 0  # True where the token is padding
        # Scale embeddings by sqrt(d_model) as in the original paper.
        hidden = self.embedding(x) * math.sqrt(self.d_model)
        hidden = self.pos_encoder(hidden)
        hidden = self.transformer(hidden, src_key_padding_mask=pad_mask)

        # Mean-pool over real (non-pad) tokens only; clamp guards against
        # division by zero for all-pad rows.
        valid = (~pad_mask).float().unsqueeze(-1)
        pooled = (hidden * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1e-9)
        return self.fc(pooled)
122
+
123
+
124
class BindingAffinityModel(nn.Module):
    """Two-tower model predicting binding affinity.

    Tower 1 encodes the ligand graph with a GAT; tower 2 encodes the protein
    sequence with a Transformer.  The two embeddings are concatenated and fed
    to an MLP head that regresses a single value.
    """

    def __init__(
        self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2
    ):
        super().__init__()
        # Tower 1: ligand GNN.
        self.ligand_gnn = LigandGNN(
            input_dim=num_node_features,
            hidden_channels=hidden_channels,
            heads=gat_heads,
            dropout=dropout,
        )
        # Tower 2: protein Transformer.
        self.protein_transformer = ProteinTransformer(
            vocab_size=26,
            d_model=hidden_channels,
            output_dim=hidden_channels,
            dropout=dropout,
        )
        # Fusion head: 2*hidden -> hidden -> 1.
        self.head = nn.Sequential(
            nn.Linear(hidden_channels * 2, hidden_channels),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_channels, 1),
        )

    def forward(self, x, edge_index, batch, protein_seq):
        ligand_vec = self.ligand_gnn(x, edge_index, batch)
        # protein_seq arrives flattened by batching; reshape it back to
        # [num_graphs, seq_len] using the graph-assignment vector.
        num_graphs = batch.max().item() + 1
        protein_vec = self.protein_transformer(protein_seq.view(num_graphs, -1))
        return self.head(torch.cat([ligand_vec, protein_vec], dim=1))
research/model_attention.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import GATConv
4
+ from torch_geometric.utils import to_dense_batch
5
+ import torch.nn.functional as F
6
+
7
+
8
class CrossAttentionLayer(nn.Module):
    """Transformer-style block where ligand atoms attend to protein residues.

    Query = ligand tokens, Key/Value = protein tokens, so the output is the
    ligand representation enriched with protein context.  The head-averaged
    attention map of the most recent forward pass is cached (detached, on
    CPU) in ``last_attention_weights`` for later inspection.
    """

    def __init__(self, feature_dim, num_heads=4, dropout=0.1):
        super().__init__()
        # Cross-attention over hidden features of size `feature_dim`.
        self.attention = nn.MultiheadAttention(
            feature_dim, num_heads, dropout=dropout, batch_first=True
        )

        # LayerNorm after attention stabilises training.
        self.norm = nn.LayerNorm(feature_dim)

        # Position-wise feed-forward block, classic transformer style
        # (4x expansion, GELU activation).
        self.ff = nn.Sequential(
            nn.Linear(feature_dim, feature_dim * 4),
            nn.GELU(),  # GELU works better with transformers
            nn.Dropout(dropout),
            nn.Linear(feature_dim * 4, feature_dim),
        )
        self.norm_ff = nn.LayerNorm(feature_dim)
        self.last_attention_weights = None

    def forward(self, ligand_features, protein_features, key_padding_mask=None):
        """Enrich ligand features with protein context via cross-attention.

        Args:
            ligand_features: [batch, atoms, dim] atom embeddings (queries).
            protein_features: [batch, residues, dim] residue embeddings (keys/values).
            key_padding_mask: optional [batch, residues] bool mask,
                True at positions to ignore (PyTorch MHA convention).

        Returns:
            Updated ligand features, same shape as ``ligand_features``.
        """
        attended, weights = self.attention(
            query=ligand_features,
            key=protein_features,
            value=protein_features,
            key_padding_mask=key_padding_mask,
            need_weights=True,
            average_attn_weights=True,
        )
        # Keep a CPU copy of the attention map for visualisation.
        self.last_attention_weights = weights.detach().cpu()

        # Residual connection + norm, then feed-forward with its own
        # residual connection + norm.
        out = self.norm(ligand_features + attended)
        return self.norm_ff(out + self.ff(out))
55
+
56
+
57
class BindingAffinityModel(nn.Module):
    """GAT ligand encoder + convolutional protein encoder fused by cross-attention.

    The ligand tower stacks three GAT layers so every atom can "see" up to
    three bonds away.  The protein tower embeds amino-acid tokens and applies
    a 1D convolution for local context.  A cross-attention layer lets atoms
    attend to residues; the result is mean-pooled over real atoms and fed to
    an MLP head that regresses a single value (pKd).
    """

    def __init__(
        self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
    ):
        super().__init__()
        self.dropout = dropout
        self.hidden_channels = hidden_channels

        # Tower 1 - ligand GNN: three GAT layers; attention weighs the
        # importance of each neighbour.  concat=False averages the heads so
        # the width stays at hidden_channels.
        self.gat1 = GATConv(
            num_node_features, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat2 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )
        self.gat3 = GATConv(
            hidden_channels, hidden_channels, heads=gat_heads, concat=False
        )

        # Tower 2 - protein encoder: 22 = 21 amino-acid tokens + 1 PAD token.
        self.protein_embedding = nn.Embedding(22, hidden_channels)
        # Local sequence context via a simple 1D convolution.
        self.prot_conv = nn.Conv1d(
            hidden_channels, hidden_channels, kernel_size=3, padding=1
        )

        # Cross-attention: atoms attending to amino acids.
        self.cross_attention = CrossAttentionLayer(
            feature_dim=hidden_channels, num_heads=4, dropout=dropout
        )

        # Regression head -> scalar affinity (pKd).
        self.fc1 = nn.Linear(hidden_channels, hidden_channels)
        self.fc2 = nn.Linear(hidden_channels, 1)

    def forward(self, x, edge_index, batch, protein_seq):
        """Predict affinity for a batch of (ligand graph, protein sequence) pairs.

        Args:
            x: [total_atoms, num_node_features] atom features.
            edge_index: graph connectivity in COO format.
            batch: graph-membership vector assigning each atom to a sample.
            protein_seq: flat protein token tensor, reshaped per batch below.

        Returns:
            Tensor of shape [batch_size, 1].
        """
        # --- Ligand: graph -> per-atom embeddings ---
        h = F.elu(self.gat1(x, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)

        h = F.elu(self.gat2(h, edge_index))
        h = F.dropout(h, p=self.dropout, training=self.training)

        h = F.elu(self.gat3(h, edge_index))  # [total_atoms, hidden_channels]

        # Densify: pad every graph with zeros up to the largest in the batch.
        # lig_dense: [batch, max_atoms, hidden]; lig_mask: True at real atoms.
        lig_dense, lig_mask = to_dense_batch(h, batch)
        bsz = lig_dense.size(0)

        # --- Protein: tokens -> contextual residue embeddings ---
        seq = protein_seq.view(bsz, -1)            # [batch, seq_len]
        p = self.protein_embedding(seq)            # [batch, seq_len, hidden]
        # Conv1d wants channels first, so permute around the convolution.
        p = F.relu(self.prot_conv(p.permute(0, 2, 1))).permute(0, 2, 1)

        # PyTorch MHA's key_padding_mask is True where positions must be
        # IGNORED, i.e. at PAD (token id 0).
        pad_mask = seq.eq(0)

        # Cross-attention: ligand queries, protein keys/values.
        fused = self.cross_attention(lig_dense, p, key_padding_mask=pad_mask)

        # Mean-pool over real atoms only: zero out padded rows, sum, divide
        # by the per-molecule atom count (epsilon avoids division by zero).
        keep = lig_mask.unsqueeze(-1)              # [batch, max_atoms, 1]
        fused = fused * keep
        pooled = fused.sum(dim=1) / (keep.sum(dim=1) + 1e-6)

        # --- MLP head ---
        out = F.relu(self.fc1(pooled))
        out = F.dropout(out, p=self.dropout, training=self.training)
        return self.fc2(out)
model_pl.py β†’ research/model_pl.py RENAMED
File without changes
optuna_train.py β†’ research/optuna_train.py RENAMED
File without changes
optuna_train_attention.py β†’ research/optuna_train_attention.py RENAMED
@@ -7,11 +7,12 @@ import numpy as np
7
  from torch_geometric.loader import DataLoader
8
  from torch.utils.data import random_split
9
  from dataset import BindingDataset
 
10
  from model_attention import BindingAffinityModel
11
 
12
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  N_TRIALS = 50
14
- MAX_EPOCHS_PER_TRIAL = 60
15
  LOG_DIR = "runs"
16
  DATA_CSV = "pdbbind_refined_dataset.csv"
17
 
@@ -88,7 +89,8 @@ def objective(trial):
88
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
89
  optimizer, mode="min", factor=0.5, patience=5
90
  )
91
- criterion = nn.MSELoss()
 
92
 
93
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
94
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
@@ -118,7 +120,7 @@ if __name__ == "__main__":
118
  direction="minimize",
119
  pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=10),
120
  storage=storage_name,
121
- study_name="binding_prediction_optimization_attentionV2",
122
  load_if_exists=True,
123
  )
124
  print("Start hyperparameter optimization...")
 
7
  from torch_geometric.loader import DataLoader
8
  from torch.utils.data import random_split
9
  from dataset import BindingDataset
10
+ from loss import WeightedMSELoss
11
  from model_attention import BindingAffinityModel
12
 
13
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  N_TRIALS = 50
15
+ MAX_EPOCHS_PER_TRIAL = 50
16
  LOG_DIR = "runs"
17
  DATA_CSV = "pdbbind_refined_dataset.csv"
18
 
 
89
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
90
  optimizer, mode="min", factor=0.5, patience=5
91
  )
92
+ # criterion = nn.MSELoss()
93
+ criterion = WeightedMSELoss()
94
 
95
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
96
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
 
120
  direction="minimize",
121
  pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=10),
122
  storage=storage_name,
123
+ study_name="binding_prediction_optimization_attentionWeightedLoss",
124
  load_if_exists=True,
125
  )
126
  print("Start hyperparameter optimization...")
pdbbind_refined_dataset.csv β†’ research/pdbbind_refined_dataset.csv RENAMED
File without changes
research/requirements_dev.txt ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ pytorch-lightning
3
+ optuna
4
+ optuna-dashboard
5
+ tensorboard
6
+
7
+ numpy
8
+ pandas
9
+ matplotlib
10
+ seaborn
11
+
12
+ rdkit
13
+ biopython
14
+
15
+ jupyter
16
+ tqdm
17
+ black
18
+
train.py β†’ research/train.py RENAMED
File without changes
train_attention.py β†’ research/train_attention.py RENAMED
@@ -6,6 +6,7 @@ import pandas as pd
6
  from torch.utils.data import random_split
7
  from torch_geometric.loader import DataLoader
8
  from dataset import BindingDataset
 
9
  from model_attention import BindingAffinityModel
10
  from tqdm import tqdm
11
  from torch.utils.tensorboard import SummaryWriter
@@ -13,7 +14,7 @@ import numpy as np
13
  from datetime import datetime
14
  import os
15
 
16
- # 2.02
17
  # BATCH_SIZE = 16
18
  # LR = 0.00035 # Reduced learning rate
19
  # WEIGHT_DECAY = 1e-5 # Slightly increased weight decay (regularization)
@@ -23,16 +24,27 @@ import os
23
  # HIDDEN_CHANNELS = 256
24
 
25
  # 1.90 from Optuna
 
 
 
 
 
 
 
 
 
26
  BATCH_SIZE = 16
27
- LR = 0.000034
28
- WEIGHT_DECAY = 1e-6
29
- DROPOUT = 0.26
30
  EPOCHS = 100
31
- HIDDEN_CHANNELS = 256
32
- GAT_HEADS = 2
33
 
34
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
- LOG_DIR = f"runs/experiment_attention{datetime.now().strftime('%Y%m%d_%H%M%S')}"
 
 
36
  TOP_K = 3
37
  SAVES_DIR = LOG_DIR + "/models"
38
 
@@ -128,7 +140,8 @@ def main():
128
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
129
  optimizer, mode="min", factor=0.5, patience=8
130
  )
131
- criterion = nn.MSELoss()
 
132
 
133
  top_models = []
134
 
@@ -152,7 +165,7 @@ def main():
152
  end="",
153
  )
154
 
155
- filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
156
 
157
  torch.save(model.state_dict(), filename)
158
  top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
@@ -173,7 +186,7 @@ def main():
173
  print("Training finished.")
174
  print("Top models saved:")
175
  for i, m in enumerate(top_models):
176
- print(f"{i + 1}. {m['path']} (MSE: {m['loss']:.4f})")
177
 
178
 
179
  if __name__ == "__main__":
 
6
  from torch.utils.data import random_split
7
  from torch_geometric.loader import DataLoader
8
  from dataset import BindingDataset
9
+ from loss import WeightedMSELoss
10
  from model_attention import BindingAffinityModel
11
  from tqdm import tqdm
12
  from torch.utils.tensorboard import SummaryWriter
 
14
  from datetime import datetime
15
  import os
16
 
17
+ # 2.02 default parameters
18
  # BATCH_SIZE = 16
19
  # LR = 0.00035 # Reduced learning rate
20
  # WEIGHT_DECAY = 1e-5 # Slightly increased weight decay (regularization)
 
24
  # HIDDEN_CHANNELS = 256
25
 
26
  # 1.90 from Optuna
27
+ # BATCH_SIZE = 16
28
+ # LR = 0.000034
29
+ # WEIGHT_DECAY = 1e-6
30
+ # DROPOUT = 0.26
31
+ # EPOCHS = 100
32
+ # HIDDEN_CHANNELS = 256
33
+ # GAT_HEADS = 2
34
+
35
+ # Weighted Loss
36
  BATCH_SIZE = 16
37
+ LR = 0.00022
38
+ WEIGHT_DECAY = 1e-5
39
+ DROPOUT = 0.25
40
  EPOCHS = 100
41
+ HIDDEN_CHANNELS = 128
42
+ GAT_HEADS = 4
43
 
44
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
45
+ LOG_DIR = (
46
+ f"runs/experiment_attention{datetime.now().strftime('%Y%m%d_%H%M%S')}_weighted_loss"
47
+ )
48
  TOP_K = 3
49
  SAVES_DIR = LOG_DIR + "/models"
50
 
 
140
  scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
141
  optimizer, mode="min", factor=0.5, patience=8
142
  )
143
+ # criterion = nn.MSELoss()
144
+ criterion = WeightedMSELoss()
145
 
146
  top_models = []
147
 
 
165
  end="",
166
  )
167
 
168
+ filename = f"{SAVES_DIR}/model_ep{epoch:03d}_weighted_loss{test_loss:.4f}.pth"
169
 
170
  torch.save(model.state_dict(), filename)
171
  top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
 
186
  print("Training finished.")
187
  print("Top models saved:")
188
  for i, m in enumerate(top_models):
189
+ print(f"{i + 1}. {m['path']} (Weighted Loss: {m['loss']:.4f})")
190
 
191
 
192
  if __name__ == "__main__":
train_pl.py β†’ research/train_pl.py RENAMED
File without changes
visualization.ipynb β†’ research/visualization.ipynb RENAMED
File without changes
utils.py CHANGED
@@ -7,21 +7,24 @@ from rdkit.Chem import Descriptors
7
  import py3Dmol
8
  from jinja2 import Environment, FileSystemLoader
9
  from google import genai
10
- from google.genai import types
11
  from decouple import config
12
 
13
  GEMINI_API_KEY = config("GEMINI_API_KEY")
14
 
15
- # Remove unnecessary imports (nglview is no longer needed here for standalone HTML)
16
  from dataset import get_atom_features, get_protein_features
17
  from model_attention import BindingAffinityModel
18
 
19
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
- # Update the path if needed
21
- MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
22
 
23
- GAT_HEADS = 2
24
- HIDDEN_CHANNELS = 256
 
 
 
 
 
 
25
 
26
 
27
  def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
@@ -45,14 +48,18 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
45
  tokens = tokens[:1200]
46
  else:
47
  tokens.extend([0] * (1200 - len(tokens)))
48
- protein_sequence_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
 
 
49
 
50
  data = Data(x=x, edge_index=edge_index)
51
  batch = Batch.from_data_list([data]).to(DEVICE)
52
  num_features = x.shape[1]
53
 
54
  # Model
55
- model = BindingAffinityModel(num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS).to(DEVICE)
 
 
56
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
57
  model.eval()
58
 
@@ -65,7 +72,9 @@ def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
65
  importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()
66
 
67
  if importance.max() > 0:
68
- importance = (importance - importance.min()) / (importance.max() - importance.min())
 
 
69
 
70
  importance[importance < 0.01] = 0
71
  return mol, importance, pred.item()
@@ -112,19 +121,16 @@ def get_lipinski_properties(mol):
112
  "violations": violations,
113
  "status_text": status,
114
  "css_class": css_class,
115
- "bad_params": ", ".join(bad_params) if bad_params else "None"
116
  }
117
 
118
 
119
  def get_py3dmol_view(mol, importance):
120
  view = py3Dmol.view(width="100%", height="600px")
121
  view.addModel(Chem.MolToMolBlock(mol), "sdf")
122
- view.setBackgroundColor('white')
123
 
124
- view.setStyle({}, {
125
- 'stick': {'radius': 0.15},
126
- 'sphere': {'scale': 0.25}
127
- })
128
 
129
  indices_sorted = np.argsort(importance)[::-1]
130
  top_indices = set(indices_sorted[:15])
@@ -137,16 +143,19 @@ def get_py3dmol_view(mol, importance):
137
  symbol = mol.GetAtomWithIdx(i).GetSymbol()
138
  label_text = f"{i}:{symbol}:{val:.2f}"
139
 
140
- view.addLabel(label_text, {
141
- 'position': {'x': pos.x, 'y': pos.y, 'z': pos.z},
142
- 'fontSize': 14,
143
- 'fontColor': 'white',
144
- 'backgroundColor': 'black',
145
- 'backgroundOpacity': 0.7,
146
- 'borderThickness': 0,
147
- 'inFront': True,
148
- 'showBackground': True
149
- })
 
 
 
150
  view.zoomTo()
151
  return view
152
 
@@ -166,32 +175,40 @@ def save_standalone_ngl_html(mol, importance, filepath):
166
  indices_sorted = np.argsort(importance)[::-1]
167
  top_indices = indices_sorted[:15]
168
 
169
-
170
  selection_list = [str(i) for i in top_indices]
171
  selection_str = "@" + ",".join(selection_list)
172
 
173
  if not selection_list:
174
  selection_str = "@-1"
175
 
 
 
176
 
177
- env = Environment(loader=FileSystemLoader('templates'))
178
- template = env.get_template('ngl_view.html')
179
-
180
- rendered_html = template.render(pdb_block=final_pdb_block, selection_str=selection_str)
181
 
182
  with open(filepath, "w", encoding="utf-8") as f:
183
  f.write(rendered_html)
184
 
185
 
186
- def get_gemini_explanation(ligand_smiles, protein_sequence, affinity, top_atoms, lipinski):
 
 
187
  if not GEMINI_API_KEY:
188
  return "<p class='text-warning'>API Key for Gemini not found. Please set GOOGLE_API_KEY environment variable.</p>"
189
 
190
  # Forming a list of top important atoms for a prompt
191
- atoms_desc = ", ".join([f"{a['symbol']}(idx {a['id']}, score {a['score']})" for a in top_atoms[:10]])
 
 
192
 
193
  # Cut a protein to not spend too many tokens
194
- prot_short = protein_sequence[:100] + "..." if len(protein_sequence) > 100 else protein_sequence
 
 
 
 
195
 
196
  prompt = f"""
197
  You are an expert Computational Chemist and Drug Discovery Scientist.
@@ -216,9 +233,8 @@ def get_gemini_explanation(ligand_smiles, protein_sequence, affinity, top_atoms,
216
  try:
217
  client = genai.Client(api_key=GEMINI_API_KEY)
218
  response = client.models.generate_content(
219
- model="gemini-2.5-flash",
220
- contents=prompt
221
  )
222
  return response.text
223
  except Exception as e:
224
- return f"<p class='text-danger'>Error generating explanation: {str(e)}</p>"
 
7
  import py3Dmol
8
  from jinja2 import Environment, FileSystemLoader
9
  from google import genai
 
10
  from decouple import config
11
 
12
  GEMINI_API_KEY = config("GEMINI_API_KEY")
13
 
14
+
15
  from dataset import get_atom_features, get_protein_features
16
  from model_attention import BindingAffinityModel
17
 
18
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
19
 
20
+ # MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
21
+ #
22
+ # GAT_HEADS = 2
23
+ # HIDDEN_CHANNELS = 256
24
+
25
+ MODEL_PATH = "runs/experiment_attention20260127_055340_weighted_loss/models/model_ep028_weighted_loss6.7715.pth"
26
+ GAT_HEADS = 4
27
+ HIDDEN_CHANNELS = 128
28
 
29
 
30
  def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
 
48
  tokens = tokens[:1200]
49
  else:
50
  tokens.extend([0] * (1200 - len(tokens)))
51
+ protein_sequence_tensor = (
52
+ torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
53
+ )
54
 
55
  data = Data(x=x, edge_index=edge_index)
56
  batch = Batch.from_data_list([data]).to(DEVICE)
57
  num_features = x.shape[1]
58
 
59
  # Model
60
+ model = BindingAffinityModel(
61
+ num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
62
+ ).to(DEVICE)
63
  model.load_state_dict(torch.load(model_path, map_location=DEVICE))
64
  model.eval()
65
 
 
72
  importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()
73
 
74
  if importance.max() > 0:
75
+ importance = (importance - importance.min()) / (
76
+ importance.max() - importance.min()
77
+ )
78
 
79
  importance[importance < 0.01] = 0
80
  return mol, importance, pred.item()
 
121
  "violations": violations,
122
  "status_text": status,
123
  "css_class": css_class,
124
+ "bad_params": ", ".join(bad_params) if bad_params else "None",
125
  }
126
 
127
 
128
  def get_py3dmol_view(mol, importance):
129
  view = py3Dmol.view(width="100%", height="600px")
130
  view.addModel(Chem.MolToMolBlock(mol), "sdf")
131
+ view.setBackgroundColor("white")
132
 
133
+ view.setStyle({}, {"stick": {"radius": 0.15}, "sphere": {"scale": 0.25}})
 
 
 
134
 
135
  indices_sorted = np.argsort(importance)[::-1]
136
  top_indices = set(indices_sorted[:15])
 
143
  symbol = mol.GetAtomWithIdx(i).GetSymbol()
144
  label_text = f"{i}:{symbol}:{val:.2f}"
145
 
146
+ view.addLabel(
147
+ label_text,
148
+ {
149
+ "position": {"x": pos.x, "y": pos.y, "z": pos.z},
150
+ "fontSize": 14,
151
+ "fontColor": "white",
152
+ "backgroundColor": "black",
153
+ "backgroundOpacity": 0.7,
154
+ "borderThickness": 0,
155
+ "inFront": True,
156
+ "showBackground": True,
157
+ },
158
+ )
159
  view.zoomTo()
160
  return view
161
 
 
175
  indices_sorted = np.argsort(importance)[::-1]
176
  top_indices = indices_sorted[:15]
177
 
 
178
  selection_list = [str(i) for i in top_indices]
179
  selection_str = "@" + ",".join(selection_list)
180
 
181
  if not selection_list:
182
  selection_str = "@-1"
183
 
184
+ env = Environment(loader=FileSystemLoader("templates"))
185
+ template = env.get_template("ngl_view.html")
186
 
187
+ rendered_html = template.render(
188
+ pdb_block=final_pdb_block, selection_str=selection_str
189
+ )
 
190
 
191
  with open(filepath, "w", encoding="utf-8") as f:
192
  f.write(rendered_html)
193
 
194
 
195
+ def get_gemini_explanation(
196
+ ligand_smiles, protein_sequence, affinity, top_atoms, lipinski
197
+ ):
198
  if not GEMINI_API_KEY:
199
  return "<p class='text-warning'>API Key for Gemini not found. Please set GOOGLE_API_KEY environment variable.</p>"
200
 
201
  # Forming a list of top important atoms for a prompt
202
+ atoms_desc = ", ".join(
203
+ [f"{a['symbol']}(idx {a['id']}, score {a['score']})" for a in top_atoms[:10]]
204
+ )
205
 
206
  # Cut a protein to not spend too many tokens
207
+ prot_short = (
208
+ protein_sequence[:100] + "..."
209
+ if len(protein_sequence) > 100
210
+ else protein_sequence
211
+ )
212
 
213
  prompt = f"""
214
  You are an expert Computational Chemist and Drug Discovery Scientist.
 
233
  try:
234
  client = genai.Client(api_key=GEMINI_API_KEY)
235
  response = client.models.generate_content(
236
+ model="gemini-2.5-flash", contents=prompt
 
237
  )
238
  return response.text
239
  except Exception as e:
240
+ return f"<p class='text-danger'>Error generating explanation: {str(e)}</p>"