AlexSychovUN commited on
Commit
e33b6c9
·
1 Parent(s): 1390640

Updated all code

Browse files
.gitignore CHANGED
@@ -1 +1,7 @@
1
- .idea
 
 
 
 
 
 
 
1
+ .idea
2
+ .venv
3
+ .ipynb_checkpoints
4
+
5
+ /refined-set/
6
+ /data
7
+ /lightning_logs
GNN_classification/dataset/classification/EDA.ipynb CHANGED
@@ -126,7 +126,9 @@
126
  }
127
  },
128
  "cell_type": "code",
129
- "source": "train_dataset['label'].value_counts()",
 
 
130
  "id": "355c3ed8e5f76bbf",
131
  "outputs": [
132
  {
 
126
  }
127
  },
128
  "cell_type": "code",
129
+ "source": [
130
+ "train_dataset[\"label\"].value_counts()"
131
+ ],
132
  "id": "355c3ed8e5f76bbf",
133
  "outputs": [
134
  {
GNNs__practice.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
all_inferences.py ADDED
@@ -0,0 +1,225 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import os.path
3
+ import torch
4
+ import numpy as np
5
+ from torch_geometric.data import Data, Batch
6
+ from rdkit import Chem
7
+ from rdkit.Chem import AllChem
8
+ import nglview as nv
9
+ import py3Dmol
10
+ from nglview import write_html
11
+
12
+
13
+ import matplotlib
14
+ import matplotlib.cm as cm
15
+ import matplotlib.colors as mcolors
16
+
17
+ from dataset import get_atom_features, get_protein_features
18
+ from model_attention import BindingAffinityModel
19
+
20
+
21
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
23
+
24
+ GAT_HEADS = 2
25
+ HIDDEN_CHANNELS = 256
26
+
27
+ def get_inference_data(ligand_smiles, protein_sequence, model_path):
28
+ """
29
+ Returns:
30
+ - mol: RDKit molecule object with 3D coordinates
31
+ - importance: list of importance scores for each atom
32
+ - predicted_affinity: predicted binding affinity value
33
+ """
34
+ # Prepare ligand molecule with geometry RDKit
35
+ mol = Chem.MolFromSmiles(ligand_smiles)
36
+ mol = Chem.AddHs(mol)
37
+ AllChem.EmbedMolecule(mol, randomSeed=42)
38
+
39
+ # Graph data PyTorch
40
+ atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
41
+ x = torch.tensor(np.array(atom_features), dtype=torch.float)
42
+ edge_index = []
43
+ for bond in mol.GetBonds():
44
+ i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
45
+ edge_index.extend([(i, j), (j, i)])
46
+ edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
47
+
48
+ tokens = [get_protein_features(c) for c in protein_sequence]
49
+ if len(tokens) > 1200: tokens = tokens[:1200]
50
+ else: tokens.extend([0] * (1200 - len(tokens)))
51
+ protein_sequence = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)
52
+
53
+ data = Data(x=x, edge_index=edge_index)
54
+ batch = Batch.from_data_list([data]).to(DEVICE)
55
+ num_features = x.shape[1]
56
+
57
+ # Model loading
58
+ model = BindingAffinityModel(num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS).to(DEVICE)
59
+ model.load_state_dict(torch.load(model_path, map_location=DEVICE))
60
+ model.eval()
61
+
62
+ # Prediction
63
+ with torch.no_grad():
64
+ pred = model(batch.x, batch.edge_index, batch.batch, protein_sequence)
65
+ attention_weights = model.cross_attention.last_attention_weights[0]
66
+
67
+ # Attention importance, Max + Normalize
68
+ real_prot_len = len([t for t in tokens if t != 0])
69
+ importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()
70
+
71
+ # Normalize to [0, 1]
72
+ if importance.max() > 0:
73
+ importance = (importance - importance.min()) / (importance.max() - importance.min())
74
+
75
+ # Noise reduction
76
+ importance[importance < 0.01] = 0
77
+ return mol, importance, pred.item()
78
+
79
+
80
+ def print_atom_scores(mol, importance):
81
+ print("Atom importance scores:")
82
+
83
+ atom_data = []
84
+ for i, score in enumerate(importance):
85
+ if score > 0.1:
86
+ symbol = mol.GetAtomWithIdx(i).GetSymbol()
87
+ atom_data.append((i, symbol, score))
88
+
89
+ atom_data.sort(key=lambda x: x[2], reverse=True)
90
+
91
+ for idx, symbol, score in atom_data:
92
+ fire = "🔥" if score > 0.8 else ("✨" if score > 0.5 else "")
93
+ print(f"Atom {idx} ({symbol}): Importance = {score:.3f} {fire}")
94
+
95
+
96
+
97
+ def get_py3dmol(mol, importance, score):
98
+
99
+ view = py3Dmol.view(width=1000, height=800)
100
+ view.addModel(Chem.MolToMolBlock(mol), "sdf")
101
+ view.setBackgroundColor('white')
102
+
103
+ # 1. БАЗОВЫЙ СТИЛЬ (ГРУНТОВКА)
104
+ # Задаем единый размер для всей молекулы сразу
105
+ # scale: 0.25 — оптимальный средний размер
106
+ view.setStyle({}, {
107
+ 'stick': {'color': '#cccccc', 'radius': 0.1},
108
+ 'sphere': {'color': '#cccccc', 'scale': 0.25}
109
+ })
110
+
111
+ red_atoms = []
112
+ orange_atoms = []
113
+ blue_atoms = []
114
+
115
+ indices_sorted = np.argsort(importance)[::-1]
116
+ top_indices = set(indices_sorted[:15])
117
+ labels_to_add = []
118
+
119
+ conf = mol.GetConformer()
120
+
121
+ # 2. СОРТИРОВКА (ТОЛЬКО ЦВЕТА)
122
+ for i, val in enumerate(importance):
123
+ if val >= 0.70:
124
+ red_atoms.append(i)
125
+ elif val >= 0.55:
126
+ orange_atoms.append(i)
127
+ elif val >= 0.40:
128
+ blue_atoms.append(i)
129
+
130
+ if i in top_indices and val > 0.1:
131
+ pos = conf.GetAtomPosition(i)
132
+ symbol = mol.GetAtomWithIdx(i).GetSymbol()
133
+ labels_to_add.append({
134
+ 'text': f"{i}:{symbol}:{val:.2f}",
135
+ 'pos': {'x': pos.x, 'y': pos.y, 'z': pos.z}
136
+ })
137
+
138
+ # 3. ПРИМЕНЕНИЕ СТИЛЕЙ
139
+ # Обрати внимание: scale везде 0.25 (или 0.28, чтобы чуть выделить цветные)
140
+ # Мы меняем ТОЛЬКО ЦВЕТ.
141
+
142
+ if red_atoms:
143
+ view.addStyle({'serial': red_atoms}, {
144
+ 'sphere': {'color': '#FF0000', 'scale': 0.28},
145
+ 'stick': {'color': '#FF0000', 'radius': 0.12}
146
+ })
147
+
148
+ if orange_atoms:
149
+ view.addStyle({'serial': orange_atoms}, {
150
+ 'sphere': {'color': '#FF8C00', 'scale': 0.28},
151
+ 'stick': {'color': '#FF8C00', 'radius': 0.12}
152
+ })
153
+
154
+ if blue_atoms:
155
+ view.addStyle({'serial': blue_atoms}, {
156
+ 'sphere': {'color': '#7777FF', 'scale': 0.28}
157
+ })
158
+
159
+ # 4. МЕТКИ
160
+ for label in labels_to_add:
161
+ view.addLabel(label['text'], {
162
+ 'position': label['pos'],
163
+ 'fontSize': 14,
164
+ 'fontColor': 'white',
165
+ 'backgroundColor': 'black',
166
+ 'backgroundOpacity': 0.7,
167
+ 'borderThickness': 0,
168
+ 'inFront': True,
169
+ 'showBackground': True
170
+ })
171
+
172
+ view.zoomTo()
173
+ view.addLabel(f"Predicted pKd: {float(score):.2f}",
174
+ {'position': {'x': -5, 'y': 10, 'z': 0}, 'backgroundColor': 'black', 'fontColor': 'white'})
175
+
176
+ return view
177
+
178
+
179
+ def get_ngl(mol, importance):
180
+ pdb_temp = Chem.MolToPDBBlock(mol)
181
+ mol_pdb = Chem.MolFromPDBBlock(pdb_temp, removeHs=False)
182
+
183
+ for i, atom in enumerate(mol_pdb.GetAtoms()):
184
+ info = atom.GetPDBResidueInfo()
185
+ if info:
186
+ val = float(importance[i] * 100.0)
187
+ info.SetTempFactor(val)
188
+ final_pdb_block = Chem.MolToPDBBlock(mol_pdb)
189
+ structure = nv.TextStructure(final_pdb_block, ext="pdb")
190
+ view = nv.NGLWidget(structure)
191
+ view.clear_representations()
192
+
193
+ view.add_representation('ball+stick', colorScheme='bfactor', colorScale=['blue', 'white', 'red'], colorDomain=[10, 80], radiusScale=1.0)
194
+
195
+ indices_sorted = np.argsort(importance)[::-1]
196
+ top_indices = indices_sorted[:15]
197
+
198
+ selection_str = "@" + ",".join(map(str, top_indices))
199
+ view.add_representation('label',
200
+ selection=selection_str, # Подписываем только избранных
201
+ labelType='atomindex', # Показывать Индекс (0, 1, 2...)
202
+ color='black', # Черный текст
203
+ radius=2.0, # Размер шрифта (попробуйте 1.5 - 3.0)
204
+ zOffset=1.0) # Чуть сдвинуть к камере
205
+
206
+
207
+ view.center()
208
+ return view
209
+
210
+ if __name__ == "__main__":
211
+ smiles = "COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1"
212
+ protein = "PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF"
213
+ affinity = 11.92
214
+
215
+ file_name_py3dmol = "html_results/py3dmol_result.html"
216
+ file_name_ngl = "html_results/ngl_result.html"
217
+
218
+ mol, importance, score = get_inference_data(smiles, protein, MODEL_PATH)
219
+ print_atom_scores(mol, importance)
220
+ py3dmol_view = get_py3dmol(mol, importance, score)
221
+ py3dmol_view.write_html(file_name_py3dmol)
222
+
223
+ ngl_widget = get_ngl(mol, importance)
224
+ nv.write_html(file_name_ngl, ngl_widget)
225
+
dataset.py CHANGED
@@ -65,6 +65,14 @@ def get_atom_features(atom):
65
  degrees_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
66
  numhs_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
67
  implicit_valences_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
 
 
 
 
 
 
 
 
68
  return np.array(
69
  # Type of atom (Symbol)
70
  one_of_k_encoding(atom.GetSymbol(), symbols_list)
@@ -93,22 +101,55 @@ def get_atom_features(atom):
93
  +
94
  # Aromaticity (Boolean)
95
  [atom.GetIsAromatic()]
 
 
 
 
 
 
 
 
 
96
  )
97
 
 
98
  def get_protein_features(char):
99
- prot_vocab= {
100
- 'A': 1, 'R': 2, 'N': 3, 'D': 4, 'C': 5, 'Q': 6, 'E': 7, 'G': 8, 'H': 9,
101
- 'I': 10, 'L': 11, 'K': 12, 'M': 13, 'F': 14, 'P': 15, 'S': 16, 'T': 17,
102
- 'W': 18, 'Y': 19, 'V': 20, 'X': 21, 'Z': 21, 'B': 21,
103
- 'PAD': 0, 'UNK': 21
104
- }
105
- return prot_vocab.get(char, prot_vocab['UNK'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
 
107
 
108
  class BindingDataset(Dataset):
109
  def __init__(self, dataframe, max_seq_length=1000):
110
  self.data = dataframe
111
- self.max_seq_length = max_seq_length # Define a maximum sequence length for padding/truncation
 
 
112
 
113
  def __len__(self):
114
  return len(self.data)
@@ -144,9 +185,11 @@ class BindingDataset(Dataset):
144
  # Protein (Sequence, tensor of integers)
145
  tokens = [get_protein_features(char) for char in sequence]
146
  if len(tokens) > self.max_seq_length:
147
- tokens = tokens[:self.max_seq_length]
148
  else:
149
- tokens.extend([get_protein_features("PAD")] * (self.max_seq_length - len(tokens)))
 
 
150
  protein_tensor = torch.tensor(tokens, dtype=torch.long)
151
 
152
  # Affinity
@@ -164,4 +207,3 @@ if __name__ == "__main__":
164
 
165
  print(len(train_dataset))
166
  print(len(test_dataset))
167
-
 
65
  degrees_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
66
  numhs_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
67
  implicit_valences_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
68
+
69
+ formal_charge_list = [-2, -1, 0, 1, 2]
70
+ chirality_list = [
71
+ Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
72
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
73
+ Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
74
+ Chem.rdchem.ChiralType.CHI_OTHER,
75
+ ]
76
  return np.array(
77
  # Type of atom (Symbol)
78
  one_of_k_encoding(atom.GetSymbol(), symbols_list)
 
101
  +
102
  # Aromaticity (Boolean)
103
  [atom.GetIsAromatic()]
104
+ +
105
+ # Formal Charge
106
+ one_of_k_encoding(atom.GetFormalCharge(), formal_charge_list)
107
+ +
108
+ # Chirality (Geometry)
109
+ one_of_k_encoding(atom.GetChiralTag(), chirality_list)
110
+ +
111
+ # Is in ring (Boolean)
112
+ [atom.IsInRing()]
113
  )
114
 
115
+
116
  def get_protein_features(char):
117
+ prot_vocab = {
118
+ "A": 1,
119
+ "R": 2,
120
+ "N": 3,
121
+ "D": 4,
122
+ "C": 5,
123
+ "Q": 6,
124
+ "E": 7,
125
+ "G": 8,
126
+ "H": 9,
127
+ "I": 10,
128
+ "L": 11,
129
+ "K": 12,
130
+ "M": 13,
131
+ "F": 14,
132
+ "P": 15,
133
+ "S": 16,
134
+ "T": 17,
135
+ "W": 18,
136
+ "Y": 19,
137
+ "V": 20,
138
+ "X": 21,
139
+ "Z": 21,
140
+ "B": 21,
141
+ "PAD": 0,
142
+ "UNK": 21,
143
+ }
144
+ return prot_vocab.get(char, prot_vocab["UNK"])
145
 
146
 
147
  class BindingDataset(Dataset):
148
  def __init__(self, dataframe, max_seq_length=1000):
149
  self.data = dataframe
150
+ self.max_seq_length = (
151
+ max_seq_length # Define a maximum sequence length for padding/truncation
152
+ )
153
 
154
  def __len__(self):
155
  return len(self.data)
 
185
  # Protein (Sequence, tensor of integers)
186
  tokens = [get_protein_features(char) for char in sequence]
187
  if len(tokens) > self.max_seq_length:
188
+ tokens = tokens[: self.max_seq_length]
189
  else:
190
+ tokens.extend(
191
+ [get_protein_features("PAD")] * (self.max_seq_length - len(tokens))
192
+ )
193
  protein_tensor = torch.tensor(tokens, dtype=torch.long)
194
 
195
  # Affinity
 
207
 
208
  print(len(train_dataset))
209
  print(len(test_dataset))
 
inference.py CHANGED
@@ -10,9 +10,11 @@ from model import BindingAffinityModel
10
  from tqdm import tqdm
11
  from scipy.stats import pearsonr
12
  from torch.utils.data import random_split
 
 
 
 
13
 
14
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
- MODEL_PATH = "best_model_gat.pth"
16
 
17
  def set_seed(seed=42):
18
  random.seed(seed)
@@ -21,11 +23,12 @@ def set_seed(seed=42):
21
  np.random.seed(seed)
22
  return torch.Generator().manual_seed(seed)
23
 
 
24
  def predict_and_plot():
25
  gen = set_seed(42)
26
  print("Loading data...")
27
 
28
- dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
29
  dataframe.dropna(inplace=True)
30
  dataset = BindingDataset(dataframe)
31
  if len(dataset) == 0:
@@ -40,7 +43,12 @@ def predict_and_plot():
40
  num_features = test_dataset[0].x.shape[1]
41
 
42
  print("Loading model...")
43
- model = BindingAffinityModel(num_node_features=num_features, hidden_channels_gnn=128).to(DEVICE)
 
 
 
 
 
44
  model.load_state_dict(torch.load(MODEL_PATH))
45
  model.eval()
46
 
@@ -67,19 +75,28 @@ def predict_and_plot():
67
  print(f"Pearson Correlation: {pearson_corr:.4f}")
68
 
69
  plt.figure(figsize=(9, 9))
70
- plt.scatter(y_true, y_pred, alpha=0.4, s=15, c='blue', label='Predictions')
71
- plt.plot([min(y_true), max(y_true)], [min(y_true), max(y_true)], color='red', linestyle='--', linewidth=2,
72
- label='Ideal')
 
 
 
 
 
 
73
 
74
- plt.xlabel('Experimental Affinity (pK)')
75
- plt.ylabel('Predicted Affinity (pK)')
76
- plt.title(f'Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}')
 
 
77
  plt.legend()
78
  plt.grid(True, alpha=0.3)
79
- plot_file = 'final_results_gat.png'
80
  plt.savefig(plot_file)
81
  print(f"График сохранен в {plot_file}")
82
  plt.show()
83
 
 
84
  if __name__ == "__main__":
85
- predict_and_plot()
 
10
  from tqdm import tqdm
11
  from scipy.stats import pearsonr
12
  from torch.utils.data import random_split
13
+ from train import GAT_HEADS, DROPOUT, HIDDEN_CHANNELS
14
+
15
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ MODEL_PATH = "runs/experiment_20260122_230138_GAT_without_formal_charge_chirality_ring_scheduler/models/model_ep092_mse2.3805.pth"
17
 
 
 
18
 
19
  def set_seed(seed=42):
20
  random.seed(seed)
 
23
  np.random.seed(seed)
24
  return torch.Generator().manual_seed(seed)
25
 
26
+
27
  def predict_and_plot():
28
  gen = set_seed(42)
29
  print("Loading data...")
30
 
31
+ dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
32
  dataframe.dropna(inplace=True)
33
  dataset = BindingDataset(dataframe)
34
  if len(dataset) == 0:
 
43
  num_features = test_dataset[0].x.shape[1]
44
 
45
  print("Loading model...")
46
+ model = BindingAffinityModel(
47
+ num_node_features=num_features,
48
+ hidden_channels=HIDDEN_CHANNELS,
49
+ gat_heads=GAT_HEADS,
50
+ dropout=DROPOUT,
51
+ ).to(DEVICE)
52
  model.load_state_dict(torch.load(MODEL_PATH))
53
  model.eval()
54
 
 
75
  print(f"Pearson Correlation: {pearson_corr:.4f}")
76
 
77
  plt.figure(figsize=(9, 9))
78
+ plt.scatter(y_true, y_pred, alpha=0.4, s=15, c="blue", label="Predictions")
79
+ plt.plot(
80
+ [min(y_true), max(y_true)],
81
+ [min(y_true), max(y_true)],
82
+ color="red",
83
+ linestyle="--",
84
+ linewidth=2,
85
+ label="Ideal",
86
+ )
87
 
88
+ plt.xlabel("Experimental Affinity (pK)")
89
+ plt.ylabel("Predicted Affinity (pK)")
90
+ plt.title(
91
+ f"Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}"
92
+ )
93
  plt.legend()
94
  plt.grid(True, alpha=0.3)
95
+ plot_file = "final_results_gat.png"
96
  plt.savefig(plot_file)
97
  print(f"График сохранен в {plot_file}")
98
  plt.show()
99
 
100
+
101
  if __name__ == "__main__":
102
+ predict_and_plot()
inference_attention.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ import pandas as pd
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ from torch_geometric.loader import DataLoader
8
+ from dataset import BindingDataset
9
+ from model_attention import BindingAffinityModel
10
+ from tqdm import tqdm
11
+ from scipy.stats import pearsonr
12
+ from torch.utils.data import random_split
13
+ from train_attention import GAT_HEADS, DROPOUT, HIDDEN_CHANNELS
14
+
15
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
+ MODEL_PATH = "runs/experiment_attention20260123_103840_with_additional_data_scheduler/models/model_ep032_mse2.0264.pth"
17
+
18
+
19
+ def set_seed(seed=42):
20
+ random.seed(seed)
21
+ torch.manual_seed(seed)
22
+ torch.cuda.manual_seed(seed)
23
+ np.random.seed(seed)
24
+ return torch.Generator().manual_seed(seed)
25
+
26
+
27
+ def predict_and_plot():
28
+ gen = set_seed(42)
29
+ print("Loading data...")
30
+
31
+ dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
32
+ dataframe.dropna(inplace=True)
33
+ dataset = BindingDataset(dataframe)
34
+ if len(dataset) == 0:
35
+ print("Dataset is empty")
36
+ return
37
+
38
+ train_size = int(0.8 * len(dataset))
39
+ test_size = len(dataset) - train_size
40
+ _, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
41
+
42
+ loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
43
+ num_features = test_dataset[0].x.shape[1]
44
+
45
+ print("Loading model...")
46
+ model = BindingAffinityModel(
47
+ num_node_features=num_features,
48
+ hidden_channels=HIDDEN_CHANNELS,
49
+ gat_heads=GAT_HEADS,
50
+ dropout=DROPOUT,
51
+ ).to(DEVICE)
52
+ model.load_state_dict(torch.load(MODEL_PATH))
53
+ model.eval()
54
+
55
+ y_true = []
56
+ y_pred = []
57
+ print("Predicting...")
58
+ with torch.no_grad():
59
+ for batch in tqdm(loader):
60
+ batch = batch.to(DEVICE)
61
+ out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
62
+
63
+ y_true.extend(batch.y.cpu().numpy())
64
+ y_pred.extend(out.squeeze().cpu().numpy())
65
+ y_true = np.array(y_true)
66
+ y_pred = np.array(y_pred)
67
+
68
+ rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
69
+ mae = np.mean(np.abs(y_true - y_pred))
70
+ pearson_corr, _ = pearsonr(y_true, y_pred) # Pearson correlation
71
+
72
+ print("Results:")
73
+ print(f"RMSE: {rmse:.4f}")
74
+ print(f"MAE: {mae:.4f}")
75
+ print(f"Pearson Correlation: {pearson_corr:.4f}")
76
+
77
+ plt.figure(figsize=(9, 9))
78
+ plt.scatter(y_true, y_pred, alpha=0.4, s=15, c="blue", label="Predictions")
79
+ plt.plot(
80
+ [min(y_true), max(y_true)],
81
+ [min(y_true), max(y_true)],
82
+ color="red",
83
+ linestyle="--",
84
+ linewidth=2,
85
+ label="Ideal",
86
+ )
87
+
88
+ plt.xlabel("Experimental Affinity (pK)")
89
+ plt.ylabel("Predicted Affinity (pK)")
90
+ plt.title(
91
+ f"Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}"
92
+ )
93
+ plt.legend()
94
+ plt.grid(True, alpha=0.3)
95
+ plot_file = "final_results_gat.png"
96
+ plt.savefig(plot_file)
97
+ print(f"График сохранен в {plot_file}")
98
+ plt.show()
99
+
100
+
101
+ if __name__ == "__main__":
102
+ predict_and_plot()
main.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import uuid
3
+
4
+ from fastapi import FastAPI, Request, Form
5
+ from fastapi.templating import Jinja2Templates
6
+ from fastapi.staticfiles import StaticFiles
7
+ from fastapi.responses import HTMLResponse
8
+ from utils import get_inference_data, get_py3dmol_view,save_standalone_ngl_html
9
+ import nglview as nv
10
+
11
+
12
+ app = FastAPI()
13
+
14
+ os.makedirs("html_results", exist_ok=True)
15
+ app.mount("/results", StaticFiles(directory="html_results"), name="results")
16
+ templates = Jinja2Templates(directory="templates")
17
+
18
+ @app.get("/", response_class=HTMLResponse)
19
+ async def read_root(request: Request):
20
+ return templates.TemplateResponse("index.html", {"request": request})
21
+
22
+
23
+ @app.post("/predict", response_class=HTMLResponse)
24
+ async def predict(
25
+ request: Request,
26
+ smiles_ligand: str = Form(...),
27
+ sequence_protein: str = Form(...)
28
+ ):
29
+ mol, importance, affinity = get_inference_data(smiles_ligand, sequence_protein)
30
+
31
+ atom_list = []
32
+ sorted_indices = sorted(range(len(importance)), key=lambda k: importance[k], reverse=True)
33
+
34
+ for idx in sorted_indices[:15]:
35
+ val = importance[idx]
36
+ symbol = mol.GetAtomWithIdx(idx).GetSymbol()
37
+
38
+ icon = ""
39
+ if val >= 0.9: icon = "🔥"
40
+ elif val >= 0.7: icon = "✨"
41
+ elif val >= 0.5: icon = "⭐"
42
+ atom_list.append({
43
+ "id": idx,
44
+ "symbol": symbol,
45
+ "score": f"{val:.3f}",
46
+ "icon": icon
47
+ })
48
+
49
+ unique_id = str(uuid.uuid4())
50
+
51
+
52
+ filename_ngl = f"ngl_{unique_id}.html"
53
+ filepath_ngl = os.path.join("html_results", filename_ngl)
54
+
55
+ py3dmol_view = get_py3dmol_view(mol, importance)
56
+ py3dmol_content = py3dmol_view._make_html()
57
+
58
+ # ngl_view = get_ngl_view(mol, importance)
59
+ # nv.write_html(filepath_ngl, ngl_view)
60
+
61
+ save_standalone_ngl_html(mol, importance, filepath_ngl)
62
+
63
+ ngl_url_link = f"/results/{filename_ngl}"
64
+
65
+ return templates.TemplateResponse("index.html", {
66
+ "request": request,
67
+ "result_ready": True,
68
+ "smiles": smiles_ligand,
69
+ "protein": sequence_protein,
70
+ "affinity": f"{affinity:.2f}",
71
+ "atom_list": atom_list,
72
+ "html_py3dmol": py3dmol_content,
73
+ "url_ngl": ngl_url_link
74
+ })
75
+
76
+
77
+
78
+
79
+
model.py CHANGED
@@ -1,11 +1,9 @@
1
  import math
2
-
3
  import torch
4
  import torch.nn as nn
5
-
6
-
7
  from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
8
 
 
9
  class PositionalEncoding(nn.Module):
10
  def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
11
  super().__init__()
@@ -34,11 +32,10 @@ class PositionalEncoding(nn.Module):
34
 
35
  def forward(self, x):
36
  # x: [batch_size, seq_len, d_model]
37
- x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
38
  return self.dropout(x)
39
 
40
 
41
-
42
  # class LigandGNN(nn.Module): # GCN CONV
43
  # def __init__(self, input_dim, hidden_channels):
44
  # super().__init__()
@@ -70,8 +67,12 @@ class LigandGNN(nn.Module):
70
  # Heads=4 means we use 4 attention heads
71
  # Concat=False, we average the heads instead of concatenating them, to keep the output dimension same as hidden_channels
72
  self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
73
- self.conv2 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
74
- self.conv3 = GATConv(hidden_channels, hidden_channels, heads=heads, concat=False)
 
 
 
 
75
  self.dropout = nn.Dropout(dropout)
76
 
77
  def forward(self, x, edge_index, batch):
@@ -89,20 +90,23 @@ class LigandGNN(nn.Module):
89
  x = global_mean_pool(x, batch)
90
  return x
91
 
 
92
  class ProteinTransformer(nn.Module):
93
  def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
94
  super().__init__()
95
  self.d_model = d_model
96
  self.embedding = nn.Embedding(vocab_size, d_model)
97
  self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
98
- encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=h, batch_first=True)
 
 
99
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
100
 
101
  self.fc = nn.Linear(d_model, output_dim)
102
 
103
  def forward(self, x):
104
  # x: [batch_size, seq_len]
105
- padding_mask = (x == 0) # mask for PAD tokens
106
  x = self.embedding(x) * math.sqrt(self.d_model)
107
  x = self.pos_encoder(x)
108
  x = self.transformer(x, src_key_padding_mask=padding_mask)
@@ -116,20 +120,34 @@ class ProteinTransformer(nn.Module):
116
  x = self.fc(x)
117
  return x
118
 
 
119
  class BindingAffinityModel(nn.Module):
120
- def __init__(self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2):
 
 
121
  super().__init__()
122
  # Tower 1 - Ligand GNN
123
- self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels, heads=gat_heads, dropout=dropout)
 
 
 
 
 
124
  # Tower 2 - Protein Transformer
125
- self.protein_transformer = ProteinTransformer(vocab_size=26, d_model=hidden_channels, output_dim=hidden_channels, dropout=dropout)
 
 
 
 
 
126
 
127
  self.head = nn.Sequential(
128
- nn.Linear(hidden_channels*2, hidden_channels),
129
  nn.ReLU(),
130
  nn.Dropout(dropout),
131
  nn.Linear(hidden_channels, 1),
132
  )
 
133
  def forward(self, x, edge_index, batch, protein_seq):
134
  ligand_vec = self.ligand_gnn(x, edge_index, batch)
135
  batch_size = batch.max().item() + 1
@@ -138,5 +156,3 @@ class BindingAffinityModel(nn.Module):
138
  protein_vec = self.protein_transformer(protein_seq)
139
  combined = torch.cat([ligand_vec, protein_vec], dim=1)
140
  return self.head(combined)
141
-
142
-
 
1
  import math
 
2
  import torch
3
  import torch.nn as nn
 
 
4
  from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
5
 
6
+
7
  class PositionalEncoding(nn.Module):
8
  def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
9
  super().__init__()
 
32
 
33
  def forward(self, x):
34
  # x: [batch_size, seq_len, d_model]
35
+ x = x + (self.pe[:, : x.shape[1], :]).requires_grad_(False)
36
  return self.dropout(x)
37
 
38
 
 
39
  # class LigandGNN(nn.Module): # GCN CONV
40
  # def __init__(self, input_dim, hidden_channels):
41
  # super().__init__()
 
67
  # Heads=4 means we use 4 attention heads
68
  # Concat=False, we average the heads instead of concatenating them, to keep the output dimension same as hidden_channels
69
  self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
70
+ self.conv2 = GATConv(
71
+ hidden_channels, hidden_channels, heads=heads, concat=False
72
+ )
73
+ self.conv3 = GATConv(
74
+ hidden_channels, hidden_channels, heads=heads, concat=False
75
+ )
76
  self.dropout = nn.Dropout(dropout)
77
 
78
  def forward(self, x, edge_index, batch):
 
90
  x = global_mean_pool(x, batch)
91
  return x
92
 
93
+
94
  class ProteinTransformer(nn.Module):
95
  def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
96
  super().__init__()
97
  self.d_model = d_model
98
  self.embedding = nn.Embedding(vocab_size, d_model)
99
  self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
100
+ encoder_layer = nn.TransformerEncoderLayer(
101
+ d_model=d_model, nhead=h, batch_first=True
102
+ )
103
  self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
104
 
105
  self.fc = nn.Linear(d_model, output_dim)
106
 
107
  def forward(self, x):
108
  # x: [batch_size, seq_len]
109
+ padding_mask = x == 0 # mask for PAD tokens
110
  x = self.embedding(x) * math.sqrt(self.d_model)
111
  x = self.pos_encoder(x)
112
  x = self.transformer(x, src_key_padding_mask=padding_mask)
 
120
  x = self.fc(x)
121
  return x
122
 
123
+
124
  class BindingAffinityModel(nn.Module):
125
+ def __init__(
126
+ self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2
127
+ ):
128
  super().__init__()
129
  # Tower 1 - Ligand GNN
130
+ self.ligand_gnn = LigandGNN(
131
+ input_dim=num_node_features,
132
+ hidden_channels=hidden_channels,
133
+ heads=gat_heads,
134
+ dropout=dropout,
135
+ )
136
  # Tower 2 - Protein Transformer
137
+ self.protein_transformer = ProteinTransformer(
138
+ vocab_size=26,
139
+ d_model=hidden_channels,
140
+ output_dim=hidden_channels,
141
+ dropout=dropout,
142
+ )
143
 
144
  self.head = nn.Sequential(
145
+ nn.Linear(hidden_channels * 2, hidden_channels),
146
  nn.ReLU(),
147
  nn.Dropout(dropout),
148
  nn.Linear(hidden_channels, 1),
149
  )
150
+
151
  def forward(self, x, edge_index, batch, protein_seq):
152
  ligand_vec = self.ligand_gnn(x, edge_index, batch)
153
  batch_size = batch.max().item() + 1
 
156
  protein_vec = self.protein_transformer(protein_seq)
157
  combined = torch.cat([ligand_vec, protein_vec], dim=1)
158
  return self.head(combined)
 
 
model_attention.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch_geometric.nn import GATConv
4
+ from torch_geometric.utils import to_dense_batch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class CrossAttentionLayer(nn.Module):
9
+ def __init__(self, feature_dim, num_heads=4, dropout=0.1):
10
+ super().__init__()
11
+ # Main attention layer
12
+ # Feature dim is the dimension of the hidden features
13
+ self.attention = nn.MultiheadAttention(
14
+ feature_dim, num_heads, dropout=dropout, batch_first=True
15
+ )
16
+
17
+ # Normalization layer for stabilizing training
18
+ self.norm = nn.LayerNorm(feature_dim)
19
+
20
+ # Feedforward network for further processing, classical transformer style
21
+ self.ff = nn.Sequential(
22
+ nn.Linear(feature_dim, feature_dim * 4),
23
+ nn.GELU(),
24
+ nn.Dropout(dropout),
25
+ nn.Linear(feature_dim * 4, feature_dim),
26
+ )
27
+ self.norm_ff = nn.LayerNorm(feature_dim)
28
+ self.last_attention_weights = None
29
+
30
+ def forward(self, ligand_features, protein_features, key_padding_mask=None):
31
+ # ligand_features: [Batch, Atoms, Dim] - atoms
32
+ # protein_features: [Batch, Residues, Dim] - amino acids
33
+ # Cross attention:
34
+ # Query = Ligand (What we want to find out)
35
+ # Key, Value = Protein (Where we look for information)
36
+ # Result: "Ligand enriched with knowledge about proteins"
37
+ attention_output, attn_weights = self.attention(
38
+ query=ligand_features,
39
+ key=protein_features,
40
+ value=protein_features,
41
+ key_padding_mask=key_padding_mask,
42
+ need_weights=True,
43
+ average_attn_weights=True,
44
+ )
45
+ self.last_attention_weights = attn_weights.detach().cpu()
46
+
47
+ # Residual connection (x + attention(x)) and normalization
48
+ ligand_features = self.norm(ligand_features + attention_output)
49
+
50
+ # Feedforward network with residual connection and normalization
51
+ ff_output = self.ff(ligand_features)
52
+ ligand_features = self.norm_ff(ligand_features + ff_output)
53
+
54
+ return ligand_features
55
+
56
+
57
+ class BindingAffinityModel(nn.Module):
58
+ def __init__(
59
+ self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
60
+ ):
61
+ super().__init__()
62
+ self.dropout = dropout
63
+ self.hidden_channels = hidden_channels
64
+
65
+ # Tower 1 - Ligand GNN with GAT layers, using 3 GAT layers, so that every atom can "see" up to 3 bonds away
66
+ self.gat1 = GATConv(
67
+ num_node_features, hidden_channels, heads=gat_heads, concat=False
68
+ )
69
+ self.gat2 = GATConv(
70
+ hidden_channels, hidden_channels, heads=gat_heads, concat=False
71
+ )
72
+ self.gat3 = GATConv(
73
+ hidden_channels, hidden_channels, heads=gat_heads, concat=False
74
+ )
75
+
76
+ # Tower 2 - Protein Transformer, 22 = 21 amino acids + 1 padding token PAD
77
+ self.protein_embedding = nn.Embedding(22, hidden_channels)
78
+ # Additional positional encoding (simple linear) to give the model information about the order
79
+ self.prot_conv = nn.Conv1d(
80
+ hidden_channels, hidden_channels, kernel_size=3, padding=1
81
+ )
82
+
83
+ # Cross-Attention Layer, atoms attending to amino acids
84
+ self.cross_attention = CrossAttentionLayer(
85
+ feature_dim=hidden_channels, num_heads=4, dropout=dropout
86
+ )
87
+
88
+ self.fc1 = nn.Linear(hidden_channels, hidden_channels)
89
+ self.fc2 = nn.Linear(hidden_channels, 1) # Final output for regression, pKd
90
+
91
+ def forward(self, x, edge_index, batch, protein_seq):
92
+ # Ligand GNN forward pass (Graph -> Node Embeddings)
93
+ x = F.elu(self.gat1(x, edge_index))
94
+ x = F.dropout(x, p=self.dropout, training=self.training)
95
+
96
+ x = F.elu(self.gat2(x, edge_index))
97
+ x = F.dropout(x, p=self.dropout, training=self.training)
98
+
99
+ x = F.elu(self.gat3(x, edge_index)) # [Total_Atoms, Hidden_Channels]
100
+
101
+ # Convert graph into tensor [Batch, Max_Atoms, Hidden_Channels]
102
+ # to_dense_batch adds zeros paddings where necessary
103
+ ligand_dense, ligand_mask = to_dense_batch(x, batch)
104
+ # ligand_dense: [Batch, Max_Atoms, Hidden_Channels]
105
+ # ligand_mask: [Batch, Max_Atoms] True where there is real atom, False where there is padding
106
+
107
+ batch_size = ligand_dense.size(0)
108
+ protein_seq = protein_seq.view(batch_size, -1) # [Batch, Seq_Len]
109
+
110
+ # Protein forward pass protein_seq: [Batch, Seq_Len]
111
+ p = self.protein_embedding(protein_seq) # [Batch, Seq_Len, Hidden_Channels]
112
+
113
+ # A simple convolution to understand local context in amino acids
114
+ p = p.permute(0, 2, 1) # Change to [Batch, Hidden_Channels, Seq_Len] for Conv1d
115
+ p = F.relu(self.prot_conv(p))
116
+ p = p.permute(0, 2, 1) # [Batch, Seq, Hidden_Channels]
117
+
118
+ # Mask for protein (where PAD=0, True, but MHA needs True where IGNOREME)
119
+ # In Pytorch MHA, the key_padding_mask should be True where we want to ignore
120
+ protein_pad_mask = protein_seq == 0
121
+
122
+ # Cross-Attention
123
+ x_cross = self.cross_attention(
124
+ ligand_dense, p, key_padding_mask=protein_pad_mask
125
+ )
126
+
127
+ # Pooling over atoms to get a single vector per molecule, considering only real atoms, ignoring paddings
128
+ # ligand mask True where real atom, False where padding
129
+ mask_expanded = ligand_mask.unsqueeze(-1) # [Batch, Max_Atoms, 1]
130
+
131
+ # Zero out the padded atom features
132
+ x_cross = x_cross * mask_expanded
133
+
134
+ # Sum the features of real atoms / number of real atoms to get the mean
135
+ sum_features = torch.sum(x_cross, dim=1) # [Batch, Hidden_Channels]
136
+ num_atoms = torch.sum(mask_expanded, dim=1) # [Batch, 1]
137
+ pooled_x = sum_features / (num_atoms + 1e-6) # Avoid division by zero
138
+
139
+ # MLP Head
140
+ out = F.relu(self.fc1(pooled_x))
141
+ out = F.dropout(out, p=self.dropout, training=self.training)
142
+ out = self.fc2(out)
143
+ return out
model_pl.py CHANGED
@@ -8,19 +8,19 @@ from torch.optim import Adam
8
 
9
  from model import LigandGNN, ProteinTransformer
10
 
 
11
  class BindingAffinityModelPL(pl.LightningModule):
12
  def __init__(self, num_node_features, hidden_channels_gnn, lr):
13
  super().__init__()
14
- self.save_hyperparameters() # Save hyperparameters for easy access
15
  self.lr = lr
16
 
17
- self.ligand_gnn = LigandGNN(input_dim=num_node_features, hidden_channels=hidden_channels_gnn)
 
 
18
  self.protein_transformer = ProteinTransformer(vocab_size=26)
19
  self.head = nn.Sequential(
20
- nn.Linear(128 + 128, 256),
21
- nn.ReLU(),
22
- nn.Dropout(0.2),
23
- nn.Linear(256, 1)
24
  )
25
  self.criterion = nn.MSELoss()
26
 
 
8
 
9
  from model import LigandGNN, ProteinTransformer
10
 
11
+
12
  class BindingAffinityModelPL(pl.LightningModule):
13
  def __init__(self, num_node_features, hidden_channels_gnn, lr):
14
  super().__init__()
15
+ self.save_hyperparameters() # Save hyperparameters for easy access
16
  self.lr = lr
17
 
18
+ self.ligand_gnn = LigandGNN(
19
+ input_dim=num_node_features, hidden_channels=hidden_channels_gnn
20
+ )
21
  self.protein_transformer = ProteinTransformer(vocab_size=26)
22
  self.head = nn.Sequential(
23
+ nn.Linear(128 + 128, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 1)
 
 
 
24
  )
25
  self.criterion = nn.MSELoss()
26
 
optuna_train.py CHANGED
@@ -9,10 +9,11 @@ from torch.utils.data import random_split
9
  from dataset import BindingDataset
10
  from model import BindingAffinityModel
11
 
12
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
  N_TRIALS = 20
14
  EPOCHS_PER_TRIAL = 15
15
 
 
16
  def set_seed(seed=42):
17
  random.seed(seed)
18
  np.random.seed(seed)
@@ -20,7 +21,8 @@ def set_seed(seed=42):
20
  torch.cuda.manual_seed(seed)
21
  return torch.Generator().manual_seed(seed)
22
 
23
- dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
 
24
  dataframe.dropna(inplace=True)
25
  dataset = BindingDataset(dataframe)
26
 
@@ -28,9 +30,12 @@ gen = set_seed(42)
28
 
29
  train_size = int(0.8 * len(dataset))
30
  test_size = len(dataset) - train_size
31
- train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
 
 
32
  num_features = train_dataset[0].x.shape[1]
33
 
 
34
  def train(model, loader, optimizer, criterion):
35
  model.train()
36
  for batch in loader:
@@ -62,28 +67,50 @@ def objective(trial):
62
 
63
  # Learning
64
 
65
- lr = trial.suggest_float("lr", 1e-5, 1e-2, log=True) # Learning rate from 0.00001 to 0.01
66
- weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True) # Weight decay from 0.000001 to 0.001
 
 
 
 
67
  batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
68
 
69
- model = BindingAffinityModel(num_node_features=num_features, hidden_channels=hidden_dim, gat_heads=gat_heads, dropout=dropout).to(DEVICE)
 
 
 
 
 
70
 
71
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
 
 
 
 
72
  criterion = nn.MSELoss()
73
 
74
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
75
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
76
 
 
 
77
  for epoch in range(EPOCHS_PER_TRIAL):
78
  train(model, train_loader, optimizer, criterion)
79
  val_loss = test(model, test_loader, criterion)
80
 
81
- print(f"Trial {trial.number} | Epoch {epoch + 1}/{EPOCHS_PER_TRIAL} | Val Loss: {val_loss:.4f}")
 
 
 
 
 
 
 
82
 
83
  trial.report(val_loss, epoch)
84
  if trial.should_prune():
85
  raise optuna.exceptions.TrialPruned()
86
- return val_loss
87
 
88
 
89
  if __name__ == "__main__":
@@ -93,7 +120,7 @@ if __name__ == "__main__":
93
  pruner=optuna.pruners.MedianPruner(),
94
  storage=storage_name,
95
  study_name="binding_prediction_optimization",
96
- load_if_exists=True
97
  )
98
  print("Start hyperparameter optimization...")
99
 
 
9
  from dataset import BindingDataset
10
  from model import BindingAffinityModel
11
 
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  N_TRIALS = 20
14
  EPOCHS_PER_TRIAL = 15
15
 
16
+
17
  def set_seed(seed=42):
18
  random.seed(seed)
19
  np.random.seed(seed)
 
21
  torch.cuda.manual_seed(seed)
22
  return torch.Generator().manual_seed(seed)
23
 
24
+
25
+ dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
26
  dataframe.dropna(inplace=True)
27
  dataset = BindingDataset(dataframe)
28
 
 
30
 
31
  train_size = int(0.8 * len(dataset))
32
  test_size = len(dataset) - train_size
33
+ train_dataset, test_dataset = random_split(
34
+ dataset, [train_size, test_size], generator=gen
35
+ )
36
  num_features = train_dataset[0].x.shape[1]
37
 
38
+
39
  def train(model, loader, optimizer, criterion):
40
  model.train()
41
  for batch in loader:
 
67
 
68
  # Learning
69
 
70
+ lr = trial.suggest_float(
71
+ "lr", 1e-5, 1e-2, log=True
72
+ ) # Learning rate from 0.00001 to 0.01
73
+ weight_decay = trial.suggest_float(
74
+ "weight_decay", 1e-6, 1e-3, log=True
75
+ ) # Weight decay from 0.000001 to 0.001
76
  batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
77
 
78
+ model = BindingAffinityModel(
79
+ num_node_features=num_features,
80
+ hidden_channels=hidden_dim,
81
+ gat_heads=gat_heads,
82
+ dropout=dropout,
83
+ ).to(DEVICE)
84
 
85
  optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
86
+
87
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
88
+ optimizer, mode="min", factor=0.5, patience=5
89
+ )
90
  criterion = nn.MSELoss()
91
 
92
  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
93
  test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
94
 
95
+ best_val_loss = float("inf")
96
+
97
  for epoch in range(EPOCHS_PER_TRIAL):
98
  train(model, train_loader, optimizer, criterion)
99
  val_loss = test(model, test_loader, criterion)
100
 
101
+ if val_loss < best_val_loss:
102
+ best_val_loss = val_loss
103
+
104
+ scheduler.step(val_loss)
105
+
106
+ print(
107
+ f"Trial {trial.number} | Epoch {epoch + 1}/{EPOCHS_PER_TRIAL} | Val Loss: {val_loss:.4f}"
108
+ )
109
 
110
  trial.report(val_loss, epoch)
111
  if trial.should_prune():
112
  raise optuna.exceptions.TrialPruned()
113
+ return best_val_loss
114
 
115
 
116
  if __name__ == "__main__":
 
120
  pruner=optuna.pruners.MedianPruner(),
121
  storage=storage_name,
122
  study_name="binding_prediction_optimization",
123
+ load_if_exists=True,
124
  )
125
  print("Start hyperparameter optimization...")
126
 
optuna_train_attention.py ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import optuna
2
+ import torch
3
+ import torch.nn as nn
4
+ import pandas as pd
5
+ import random
6
+ import numpy as np
7
+ from torch_geometric.loader import DataLoader
8
+ from torch.utils.data import random_split
9
+ from dataset import BindingDataset
10
+ from model_attention import BindingAffinityModel
11
+
12
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+ N_TRIALS = 50
14
+ MAX_EPOCHS_PER_TRIAL = 60
15
+ LOG_DIR = "runs"
16
+ DATA_CSV = "pdbbind_refined_dataset.csv"
17
+
18
+
19
def set_seed(seed=42):
    """Seed every RNG used by the script and return a seeded torch.Generator.

    Args:
        seed: Seed applied to python's `random`, numpy, and torch (CPU + GPU).

    Returns:
        A ``torch.Generator`` seeded with ``seed`` (e.g. for ``random_split``).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed *all* visible GPUs, not only the current device, so multi-GPU
    # runs stay reproducible (no-op on CPU-only machines).
    torch.cuda.manual_seed_all(seed)
    return torch.Generator().manual_seed(seed)
25
+
26
+
27
+ dataframe = pd.read_csv(DATA_CSV)
28
+ dataframe.dropna(inplace=True)
29
+ dataset = BindingDataset(dataframe, max_seq_length=1200)
30
+
31
+ gen = set_seed(42)
32
+
33
+ train_size = int(0.8 * len(dataset))
34
+ test_size = len(dataset) - train_size
35
+ train_dataset, test_dataset = random_split(
36
+ dataset, [train_size, test_size], generator=gen
37
+ )
38
+ num_features = train_dataset[0].x.shape[1]
39
+
40
+
41
def train(model, loader, optimizer, criterion):
    """Run one training epoch over ``loader``; updates ``model`` in place."""
    model.train()
    for data in loader:
        data = data.to(DEVICE)
        optimizer.zero_grad()
        prediction = model(data.x, data.edge_index, data.batch, data.protein_seq)
        loss = criterion(prediction.squeeze(), data.y.squeeze())
        loss.backward()
        optimizer.step()
50
+
51
+
52
def test(model, loader, criterion):
    """Evaluate ``model`` on ``loader``; returns the mean per-batch loss."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for data in loader:
            data = data.to(DEVICE)
            prediction = model(data.x, data.edge_index, data.batch, data.protein_seq)
            batch_losses.append(criterion(prediction.squeeze(), data.y.squeeze()).item())
    return sum(batch_losses) / len(loader)
62
+
63
+
64
def objective(trial):
    """Optuna objective: train the attention model with sampled hyperparameters.

    Reports validation loss every epoch (so the median pruner can cut weak
    trials early) and returns the best validation MSE seen during the trial.
    """
    # Architecture search space
    hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256])
    gat_heads = trial.suggest_categorical("gat_heads", [2, 4])
    dropout = trial.suggest_float("dropout", 0.2, 0.5)

    # Optimizer search space
    lr = trial.suggest_float(
        "lr", 1e-5, 1e-3, log=True
    )  # Learning rate from 0.00001 to 0.001
    weight_decay = trial.suggest_float(
        "weight_decay", 1e-6, 1e-3, log=True
    )  # Weight decay from 0.000001 to 0.001
    batch_size = 16  # fixed (not tuned) in this script

    model = BindingAffinityModel(
        num_node_features=num_features,
        hidden_channels=hidden_dim,
        gat_heads=gat_heads,
        dropout=dropout,
    ).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    # Halve the LR after 5 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = nn.MSELoss()

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    best_val_loss = float("inf")

    for epoch in range(MAX_EPOCHS_PER_TRIAL):
        train(model, train_loader, optimizer, criterion)
        val_loss = test(model, val_loader, criterion)
        scheduler.step(val_loss)

        best_val_loss = min(best_val_loss, val_loss)
        print(
            f"Trial {trial.number} | Epoch {epoch + 1}/{MAX_EPOCHS_PER_TRIAL} | Val Loss: {val_loss:.4f}"
        )

        trial.report(val_loss, epoch)
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()
    return best_val_loss
113
+
114
+
115
if __name__ == "__main__":
    # Persist the study in sqlite so interrupted searches can be resumed.
    study = optuna.create_study(
        direction="minimize",
        pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=10),
        storage="sqlite:///db.sqlite3",
        study_name="binding_prediction_optimization_attentionV2",
        load_if_exists=True,
    )
    print("Start hyperparameter optimization...")

    study.optimize(objective, n_trials=N_TRIALS)
    print("\n--- Optimization Finished ---")
    print("Best parameters found: ", study.best_params)
    print("Best Test MSE: ", study.best_value)

    # Dump the full trial history for offline analysis.
    study.trials_dataframe().to_csv("optuna_results_attention.csv")
templates/index.html ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="ru">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>BioBinding AI Vis</title>
7
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
8
+ <style>
9
+ body { background-color: #f4f6f9; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
10
+ .sidebar { background: white; border-right: 1px solid #dee2e6; height: 100vh; overflow-y: auto; }
11
+ .result-card { border: none; box-shadow: 0 4px 6px rgba(0,0,0,0.05); border-radius: 12px; margin-bottom: 20px; }
12
+ .affinity-score { font-size: 3rem; font-weight: bold; color: #4e73df; }
13
+ .atom-badge { font-size: 0.9rem; padding: 8px 12px; }
14
+ .mol-container { width: 100%; height: 600px; border-radius: 12px; overflow: hidden; border: 1px solid #ddd; }
15
+ iframe { width: 100%; height: 100%; border: none; }
16
+ </style>
17
+ </head>
18
+ <body>
19
+
20
+ <div class="container-fluid">
21
+ <div class="row">
22
+ <div class="col-md-3 sidebar p-4">
23
+ <h3 class="mb-4 text-primary">🧪 BioBind AI</h3>
24
+
25
+ <form action="/predict" method="post">
26
+ <div class="mb-3">
27
+ <label class="form-label fw-bold">Ligand (SMILES)</label>
28
+ <textarea class="form-control" name="smiles_ligand" rows="3" required>{{ smiles_ligand if smiles_ligand else 'COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1' }}</textarea>
29
+ </div>
30
+
31
+ <div class="mb-3">
32
+ <label class="form-label fw-bold">Protein Sequence</label>
33
+ <textarea class="form-control" name="sequence_protein" rows="3" required>{{ sequence_protein if sequence_protein else 'PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF' }}</textarea>
34
+ </div>
35
+
36
+ <button type="submit" class="btn btn-primary w-100 py-2">🔮 Calculate Binding</button>
37
+ </form>
38
+
39
+ {% if result_ready %}
40
+ <hr class="my-4">
41
+ <h5 class="mb-3">Top Important Atoms</h5>
42
+ <div class="list-group">
43
+ {% for atom in atom_list %}
44
+ <div class="list-group-item d-flex justify-content-between align-items-center">
45
+ <span>
46
+ <span class="fw-bold">#{{ atom.id }}</span> {{ atom.symbol }}
47
+ </span>
48
+ <span>
49
+ <span class="badge bg-light text-dark border">{{ atom.score }}</span>
50
+ <span>{{ atom.icon }}</span>
51
+ </span>
52
+ </div>
53
+ {% endfor %}
54
+ </div>
55
+ {% endif %}
56
+ </div>
57
+
58
+ <div class="col-md-9 p-4">
59
+ {% if result_ready %}
60
+ <div class="card result-card p-4 text-center">
61
+ <h2 class="text-muted">Predicted Binding Affinity (pKd)</h2>
62
+ <div class="affinity-score">{{ affinity }}</div>
63
+ </div>
64
+
65
+ <div class="card result-card p-3">
66
+ <ul class="nav nav-pills mb-3" id="pills-tab" role="tablist">
67
+ <li class="nav-item" role="presentation">
68
+ <button class="nav-link active" id="pills-py3dmol-tab" data-bs-toggle="pill" data-bs-target="#pills-py3dmol" type="button">🧬 Py3Dmol (High Contrast)</button>
69
+ </li>
70
+ <li class="nav-item" role="presentation">
71
+ <button class="nav-link" id="pills-ngl-tab" data-bs-toggle="pill" data-bs-target="#pills-ngl" type="button">🔬 NGLView</button>
72
+ </li>
73
+ </ul>
74
+
75
+ <div class="tab-content" id="pills-tabContent">
76
+ <div class="tab-pane fade show active" id="pills-py3dmol" role="tabpanel">
77
+ <div class="mol-container">
78
+ <iframe srcdoc="{{ html_py3dmol }}" style="width: 100%; height: 100%; border: none;"></iframe>
79
+ </div>
80
+ </div>
81
+
82
+ <div class="tab-pane fade" id="pills-ngl" role="tabpanel">
83
+ <div class="mol-container">
84
+ <iframe src="{{ url_ngl }}" style="width: 100%; height: 100%; border: none;"></iframe>
85
+ </div>
86
+ </div>
87
+ </div>
88
+ </div>
89
+ {% else %}
90
+ <div class="d-flex align-items-center justify-content-center h-100 text-muted">
91
+ <div class="text-center">
92
+ <h1>🧬 Ready to Analyze</h1>
93
+ <p>Enter SMILES and Protein sequence on the left to start.</p>
94
+ </div>
95
+ </div>
96
+ {% endif %}
97
+ </div>
98
+ </div>
99
+ </div>
100
+
101
+ <script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
102
+ </body>
103
+ </html>
train.py CHANGED
@@ -20,12 +20,14 @@ WEIGHT_DECAY = 7.06e-6
20
  EPOCS = 100
21
  DROPOUT = 0.325
22
  GAT_HEADS = 2
 
23
 
24
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
25
- LOG_DIR = f"runs/experiment_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
26
  TOP_K = 3
27
  SAVES_DIR = LOG_DIR + "/models"
28
 
 
29
  def set_seed(seed=42):
30
  random.seed(seed)
31
  torch.manual_seed(seed)
@@ -52,13 +54,14 @@ def train_epoch(epoch, model, loader, optimizer, criterion, writer):
52
  total_loss += current_loss
53
 
54
  global_step = (epoch - 1) * len(loader) + i
55
- writer.add_scalar('Loss/Train_Step', current_loss, global_step)
56
 
57
- loop.set_postfix(loss = loss.item())
58
 
59
  avg_loss = total_loss / len(loader)
60
  return avg_loss
61
 
 
62
  def evaluate(epoch, model, loader, criterion, writer):
63
  model.eval()
64
  total_loss = 0
@@ -70,9 +73,10 @@ def evaluate(epoch, model, loader, criterion, writer):
70
  total_loss += loss.item()
71
 
72
  avg_loss = total_loss / len(loader)
73
- writer.add_scalar('Loss/Test', avg_loss, epoch)
74
  return avg_loss
75
 
 
76
  def main():
77
  gen = set_seed(42)
78
  writer = SummaryWriter(LOG_DIR)
@@ -82,20 +86,21 @@ def main():
82
  print(f"Logging to {LOG_DIR}...")
83
  print(f"Model saves to {SAVES_DIR}...")
84
  # Load dataset
85
- dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
86
  dataframe.dropna(inplace=True)
87
  print("Dataset loaded with {} samples".format(len(dataframe)))
88
- dataset = BindingDataset(dataframe)
89
  print("Dataset transformed with {} samples".format(len(dataset)))
90
 
91
  if len(dataset) == 0:
92
  print("Dataset is empty")
93
  return
94
 
95
-
96
  train_size = int(0.8 * len(dataset))
97
  test_size = len(dataset) - train_size
98
- train_dataset, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)
 
 
99
 
100
  train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
101
  test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
@@ -104,40 +109,57 @@ def main():
104
 
105
  model = BindingAffinityModel(
106
  num_node_features=num_features,
107
- hidden_channels=256,
108
  gat_heads=GAT_HEADS,
109
- dropout=DROPOUT
110
  ).to(DEVICE)
111
  optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
 
 
 
 
 
112
  criterion = nn.MSELoss()
113
 
114
  top_models = []
115
 
116
  print(f"Starting training on {DEVICE}")
117
  for epoch in range(1, EPOCS + 1):
118
- train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion, writer)
 
 
119
  test_loss = evaluate(epoch, model, test_loader, criterion, writer)
120
 
121
- print(f'Epoch {epoch:02d}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
124
 
125
  torch.save(model.state_dict(), filename)
126
- top_models.append({'loss': test_loss, 'path': filename, 'epoch': epoch})
127
 
128
- top_models.sort(key=lambda x: x['loss'])
129
 
130
  if len(top_models) > TOP_K:
131
  worst_model = top_models.pop()
132
- os.remove(worst_model['path'])
133
 
134
- if any(m['epoch'] == epoch for m in top_models):
135
- rank = [m['epoch'] for m in top_models].index(epoch) + 1
136
- print(f'-- Model saved (Rank: {rank})')
137
  else:
138
  print("")
139
 
140
-
141
  writer.close()
142
  print("Training finished.")
143
  print("Top models saved:")
@@ -146,4 +168,4 @@ def main():
146
 
147
 
148
  if __name__ == "__main__":
149
- main()
 
20
  EPOCS = 100
21
  DROPOUT = 0.325
22
  GAT_HEADS = 2
23
+ HIDDEN_CHANNELS = 256
24
 
25
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+ LOG_DIR = f"runs/experiment_scheduler{datetime.now().strftime('%Y%m%d_%H%M%S')}"
27
  TOP_K = 3
28
  SAVES_DIR = LOG_DIR + "/models"
29
 
30
+
31
  def set_seed(seed=42):
32
  random.seed(seed)
33
  torch.manual_seed(seed)
 
54
  total_loss += current_loss
55
 
56
  global_step = (epoch - 1) * len(loader) + i
57
+ writer.add_scalar("Loss/Train_Step", current_loss, global_step)
58
 
59
+ loop.set_postfix(loss=loss.item())
60
 
61
  avg_loss = total_loss / len(loader)
62
  return avg_loss
63
 
64
+
65
  def evaluate(epoch, model, loader, criterion, writer):
66
  model.eval()
67
  total_loss = 0
 
73
  total_loss += loss.item()
74
 
75
  avg_loss = total_loss / len(loader)
76
+ writer.add_scalar("Loss/Test", avg_loss, epoch)
77
  return avg_loss
78
 
79
+
80
  def main():
81
  gen = set_seed(42)
82
  writer = SummaryWriter(LOG_DIR)
 
86
  print(f"Logging to {LOG_DIR}...")
87
  print(f"Model saves to {SAVES_DIR}...")
88
  # Load dataset
89
+ dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
90
  dataframe.dropna(inplace=True)
91
  print("Dataset loaded with {} samples".format(len(dataframe)))
92
+ dataset = BindingDataset(dataframe, max_seq_length=1200)
93
  print("Dataset transformed with {} samples".format(len(dataset)))
94
 
95
  if len(dataset) == 0:
96
  print("Dataset is empty")
97
  return
98
 
 
99
  train_size = int(0.8 * len(dataset))
100
  test_size = len(dataset) - train_size
101
+ train_dataset, test_dataset = random_split(
102
+ dataset, [train_size, test_size], generator=gen
103
+ )
104
 
105
  train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
106
  test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
 
109
 
110
  model = BindingAffinityModel(
111
  num_node_features=num_features,
112
+ hidden_channels=HIDDEN_CHANNELS,
113
  gat_heads=GAT_HEADS,
114
+ dropout=DROPOUT,
115
  ).to(DEVICE)
116
  optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
117
+ # factor of 0.5 means reducing lr to half when triggered
118
+ # patience of 5 means wait for 5 epochs before reducing lr
119
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
120
+ optimizer, mode="min", factor=0.5, patience=5
121
+ )
122
  criterion = nn.MSELoss()
123
 
124
  top_models = []
125
 
126
  print(f"Starting training on {DEVICE}")
127
  for epoch in range(1, EPOCS + 1):
128
+ train_loss = train_epoch(
129
+ epoch, model, train_loader, optimizer, criterion, writer
130
+ )
131
  test_loss = evaluate(epoch, model, test_loader, criterion, writer)
132
 
133
+ old_lr = optimizer.param_groups[0]["lr"]
134
+ scheduler.step(test_loss)
135
+ new_lr = optimizer.param_groups[0]["lr"]
136
+
137
+ if new_lr != old_lr:
138
+ print(
139
+ f"\nEpoch {epoch}: Scheduler reduced LR from {old_lr:.6f} to {new_lr:.6f}!"
140
+ )
141
+ print(
142
+ f"Epoch {epoch:02d} | LR: {new_lr:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
143
+ end="",
144
+ )
145
 
146
  filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
147
 
148
  torch.save(model.state_dict(), filename)
149
+ top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
150
 
151
+ top_models.sort(key=lambda x: x["loss"])
152
 
153
  if len(top_models) > TOP_K:
154
  worst_model = top_models.pop()
155
+ os.remove(worst_model["path"])
156
 
157
+ if any(m["epoch"] == epoch for m in top_models):
158
+ rank = [m["epoch"] for m in top_models].index(epoch) + 1
159
+ print(f"-- Model saved (Rank: {rank})")
160
  else:
161
  print("")
162
 
 
163
  writer.close()
164
  print("Training finished.")
165
  print("Top models saved:")
 
168
 
169
 
170
  if __name__ == "__main__":
171
+ main()
train_attention.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import pandas as pd
6
+ from torch.utils.data import random_split
7
+ from torch_geometric.loader import DataLoader
8
+ from dataset import BindingDataset
9
+ from model_attention import BindingAffinityModel
10
+ from tqdm import tqdm
11
+ from torch.utils.tensorboard import SummaryWriter
12
+ import numpy as np
13
+ from datetime import datetime
14
+ import os
15
+
16
+ # 2.02
17
+ # BATCH_SIZE = 16
18
+ # LR = 0.00035 # Reduced learning rate
19
+ # WEIGHT_DECAY = 1e-5 # Slightly increased weight decay (regularization)
20
+ # EPOCHS = 100
21
+ # DROPOUT = 0.3 # Slightly reduced dropout
22
+ # GAT_HEADS = 2
23
+ # HIDDEN_CHANNELS = 256
24
+
25
+ # 1.90 from Optuna
26
+ BATCH_SIZE = 16
27
+ LR = 0.000034
28
+ WEIGHT_DECAY = 1e-6
29
+ DROPOUT = 0.26
30
+ EPOCHS = 100
31
+ HIDDEN_CHANNELS = 256
32
+ GAT_HEADS = 2
33
+
34
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ LOG_DIR = f"runs/experiment_attention{datetime.now().strftime('%Y%m%d_%H%M%S')}"
36
+ TOP_K = 3
37
+ SAVES_DIR = LOG_DIR + "/models"
38
+
39
+
40
def set_seed(seed=42):
    """Seed every RNG used by the script and return a seeded torch.Generator.

    Args:
        seed: Seed applied to python's `random`, numpy, and torch (CPU + GPU).

    Returns:
        A ``torch.Generator`` seeded with ``seed`` (e.g. for ``random_split``).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed *all* visible GPUs, not only the current device, so multi-GPU
    # runs stay reproducible (no-op on CPU-only machines).
    torch.cuda.manual_seed_all(seed)
    return torch.Generator().manual_seed(seed)
46
+
47
+
48
def train_epoch(epoch, model, loader, optimizer, criterion, writer):
    """Train ``model`` for one epoch, logging the per-step loss to TensorBoard.

    Returns the mean batch loss of the epoch.
    """
    model.train()
    running_loss = 0.0

    progress = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
    for step, batch in enumerate(progress):
        batch = batch.to(DEVICE)
        optimizer.zero_grad()

        out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
        loss = criterion(out.squeeze(), batch.y.squeeze())

        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss

        # Global step index so curves from different epochs line up in TB.
        writer.add_scalar("Loss/Train_Step", step_loss, (epoch - 1) * len(loader) + step)
        progress.set_postfix(loss=step_loss)

    return running_loss / len(loader)
72
+
73
+
74
def evaluate(epoch, model, loader, criterion, writer):
    """Evaluate ``model`` on ``loader``; logs and returns the mean batch loss."""
    model.eval()
    batch_losses = []
    with torch.no_grad():
        for batch in tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False):
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
            batch_losses.append(criterion(out.squeeze(), batch.y.squeeze()).item())

    avg_loss = sum(batch_losses) / len(loader)
    writer.add_scalar("Loss/Test", avg_loss, epoch)
    return avg_loss
87
+
88
+
89
def main():
    """End-to-end training: load data, train with LR scheduling, keep top-K checkpoints."""
    gen = set_seed(42)
    writer = SummaryWriter(LOG_DIR)

    os.makedirs(SAVES_DIR, exist_ok=True)
    print(f"Logging to {LOG_DIR}...")
    print(f"Model saves to {SAVES_DIR}...")

    # Load dataset
    dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
    dataframe.dropna(inplace=True)
    print("Dataset loaded with {} samples".format(len(dataframe)))
    dataset = BindingDataset(dataframe, max_seq_length=1200)
    print("Dataset transformed with {} samples".format(len(dataset)))

    if len(dataset) == 0:
        print("Dataset is empty")
        return

    # 80/20 split, reproducible via the seeded generator.
    n_train = int(0.8 * len(dataset))
    train_dataset, test_dataset = random_split(
        dataset, [n_train, len(dataset) - n_train], generator=gen
    )

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    num_features = train_dataset[0].x.shape[1]
    print("Number of node features:", num_features)

    model = BindingAffinityModel(
        num_node_features=num_features,
        hidden_channels=HIDDEN_CHANNELS,
        gat_heads=GAT_HEADS,
        dropout=DROPOUT,
    ).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    # Halve the LR after 8 epochs without test-loss improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=8
    )
    criterion = nn.MSELoss()

    top_models = []

    print(f"Starting training on {DEVICE}")
    for epoch in range(1, EPOCHS + 1):
        train_loss = train_epoch(epoch, model, train_loader, optimizer, criterion, writer)
        test_loss = evaluate(epoch, model, test_loader, criterion, writer)

        lr_before = optimizer.param_groups[0]["lr"]
        scheduler.step(test_loss)
        lr_after = optimizer.param_groups[0]["lr"]

        if lr_after != lr_before:
            print(
                f"\nEpoch {epoch}: Scheduler reduced LR from {lr_before:.6f} to {lr_after:.6f}!"
            )
        print(
            f"Epoch {epoch:02d} | LR: {lr_after:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
            end="",
        )

        checkpoint_path = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
        torch.save(model.state_dict(), checkpoint_path)
        top_models.append({"loss": test_loss, "path": checkpoint_path, "epoch": epoch})
        top_models.sort(key=lambda entry: entry["loss"])

        # Keep only the TOP_K best checkpoints on disk.
        if len(top_models) > TOP_K:
            os.remove(top_models.pop()["path"])

        saved_epochs = [entry["epoch"] for entry in top_models]
        if epoch in saved_epochs:
            print(f"-- Model saved (Rank: {saved_epochs.index(epoch) + 1})")
        else:
            print("")

    writer.close()
    print("Training finished.")
    print("Top models saved:")
    for i, m in enumerate(top_models):
        print(f"{i + 1}. {m['path']} (MSE: {m['loss']:.4f})")
177
+
178
+
179
+ if __name__ == "__main__":
180
+ main()
train_pl.py CHANGED
@@ -6,10 +6,11 @@ from torch.utils.data import random_split
6
  from model_pl import BindingAffinityModelPL
7
  import pandas as pd
8
 
 
9
  def main():
10
  lr = 0.0005
11
  # Load dataset
12
- dataframe = pd.read_csv('pdbbind_refined_dataset.csv')
13
  dataframe.dropna(inplace=True)
14
  print("Dataset loaded with {} samples".format(len(dataframe)))
15
  dataset = BindingDataset(dataframe)
@@ -30,21 +31,22 @@ def main():
30
 
31
  model = BindingAffinityModelPL(num_node_features=84, hidden_channels_gnn=128, lr=lr)
32
  checkpoint_callback = ModelCheckpoint(
33
- monitor='val_loss',
34
- dirpath='checkpoints/',
35
- filename='best-checkpoint',
36
  save_top_k=3,
37
- mode='min'
38
  )
39
  early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)
40
 
41
  trainer = pl.Trainer(
42
  max_epochs=20,
43
- accelerator="auto", # Use GPU if available
44
  devices=1,
45
- callbacks=[checkpoint_callback, early_stop_callback]
46
  )
47
  trainer.fit(model, train_loader, val_loader)
48
 
 
49
  if __name__ == "__main__":
50
- main()
 
6
  from model_pl import BindingAffinityModelPL
7
  import pandas as pd
8
 
9
+
10
  def main():
11
  lr = 0.0005
12
  # Load dataset
13
+ dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
14
  dataframe.dropna(inplace=True)
15
  print("Dataset loaded with {} samples".format(len(dataframe)))
16
  dataset = BindingDataset(dataframe)
 
31
 
32
  model = BindingAffinityModelPL(num_node_features=84, hidden_channels_gnn=128, lr=lr)
33
  checkpoint_callback = ModelCheckpoint(
34
+ monitor="val_loss",
35
+ dirpath="checkpoints/",
36
+ filename="best-checkpoint",
37
  save_top_k=3,
38
+ mode="min",
39
  )
40
  early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)
41
 
42
  trainer = pl.Trainer(
43
  max_epochs=20,
44
+ accelerator="auto", # Use GPU if available
45
  devices=1,
46
+ callbacks=[checkpoint_callback, early_stop_callback],
47
  )
48
  trainer.fit(model, train_loader, val_loader)
49
 
50
+
51
  if __name__ == "__main__":
52
+ main()
transformer_from_scratch/attention_visual.ipynb CHANGED
@@ -21,6 +21,7 @@
21
  "import pandas as pd\n",
22
  "import numpy as np\n",
23
  "import warnings\n",
 
24
  "warnings.filterwarnings(\"ignore\")"
25
  ]
26
  },
@@ -72,12 +73,14 @@
72
  "source": [
73
  "config = get_config()\n",
74
  "train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
75
- "model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)\n",
 
 
76
  "\n",
77
  "# Load the pretrained weights\n",
78
  "model_filename = get_weights_file_path(config, f\"34\")\n",
79
  "state = torch.load(model_filename)\n",
80
- "model.load_state_dict(state['model_state_dict'])"
81
  ]
82
  },
83
  {
@@ -95,16 +98,26 @@
95
  " decoder_input = batch[\"decoder_input\"].to(device)\n",
96
  " decoder_mask = batch[\"decoder_mask\"].to(device)\n",
97
  "\n",
98
- " encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()]\n",
99
- " decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()]\n",
 
 
 
 
100
  "\n",
101
  " # check that the batch size is 1\n",
102
- " assert encoder_input.size(\n",
103
- " 0) == 1, \"Batch size must be 1 for validation\"\n",
104
  "\n",
105
  " model_out = greedy_decode(\n",
106
- " model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config['seq_len'], device)\n",
107
- " \n",
 
 
 
 
 
 
 
108
  " return batch, encoder_input_tokens, decoder_input_tokens"
109
  ]
110
  },
@@ -132,6 +145,7 @@
132
  " columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
133
  " )\n",
134
  "\n",
 
135
  "def get_attn_map(attn_type: str, layer: int, head: int):\n",
136
  " if attn_type == \"encoder\":\n",
137
  " attn = model.encoder.layers[layer].self_attention_block.attention_scores\n",
@@ -141,6 +155,7 @@
141
  " attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n",
142
  " return attn[0, head].data\n",
143
  "\n",
 
144
  "def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
145
  " df = mtx2df(\n",
146
  " get_attn_map(attn_type, layer, head),\n",
@@ -158,17 +173,29 @@
158
  " color=\"value\",\n",
159
  " tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
160
  " )\n",
161
- " #.title(f\"Layer {layer} Head {head}\")\n",
162
  " .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
163
  " .interactive()\n",
164
  " )\n",
165
  "\n",
166
- "def get_all_attention_maps(attn_type: str, layers: list[int], heads: list[int], row_tokens: list, col_tokens, max_sentence_len: int):\n",
 
 
 
 
 
 
 
 
167
  " charts = []\n",
168
  " for layer in layers:\n",
169
  " rowCharts = []\n",
170
  " for head in heads:\n",
171
- " rowCharts.append(attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len))\n",
 
 
 
 
172
  " charts.append(alt.hconcat(*rowCharts))\n",
173
  " return alt.vconcat(*charts)"
174
  ]
@@ -287,7 +314,14 @@
287
  "heads = [0, 1, 2, 3, 4, 5, 6, 7]\n",
288
  "\n",
289
  "# Encoder Self-Attention\n",
290
- "get_all_attention_maps(\"encoder\", layers, heads, encoder_input_tokens, encoder_input_tokens, min(20, sentence_len))"
 
 
 
 
 
 
 
291
  ]
292
  },
293
  {
@@ -379,7 +413,14 @@
379
  ],
380
  "source": [
381
  "# Encoder Self-Attention\n",
382
- "get_all_attention_maps(\"decoder\", layers, heads, decoder_input_tokens, decoder_input_tokens, min(20, sentence_len))"
 
 
 
 
 
 
 
383
  ]
384
  },
385
  {
@@ -471,7 +512,14 @@
471
  ],
472
  "source": [
473
  "# Encoder Self-Attention\n",
474
- "get_all_attention_maps(\"encoder-decoder\", layers, heads, encoder_input_tokens, decoder_input_tokens, min(20, sentence_len))"
 
 
 
 
 
 
 
475
  ]
476
  },
477
  {
 
21
  "import pandas as pd\n",
22
  "import numpy as np\n",
23
  "import warnings\n",
24
+ "\n",
25
  "warnings.filterwarnings(\"ignore\")"
26
  ]
27
  },
 
73
  "source": [
74
  "config = get_config()\n",
75
  "train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
76
+ "model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(\n",
77
+ " device\n",
78
+ ")\n",
79
  "\n",
80
  "# Load the pretrained weights\n",
81
  "model_filename = get_weights_file_path(config, f\"34\")\n",
82
  "state = torch.load(model_filename)\n",
83
+ "model.load_state_dict(state[\"model_state_dict\"])"
84
  ]
85
  },
86
  {
 
98
  " decoder_input = batch[\"decoder_input\"].to(device)\n",
99
  " decoder_mask = batch[\"decoder_mask\"].to(device)\n",
100
  "\n",
101
+ " encoder_input_tokens = [\n",
102
+ " vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()\n",
103
+ " ]\n",
104
+ " decoder_input_tokens = [\n",
105
+ " vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()\n",
106
+ " ]\n",
107
  "\n",
108
  " # check that the batch size is 1\n",
109
+ " assert encoder_input.size(0) == 1, \"Batch size must be 1 for validation\"\n",
 
110
  "\n",
111
  " model_out = greedy_decode(\n",
112
+ " model,\n",
113
+ " encoder_input,\n",
114
+ " encoder_mask,\n",
115
+ " vocab_src,\n",
116
+ " vocab_tgt,\n",
117
+ " config[\"seq_len\"],\n",
118
+ " device,\n",
119
+ " )\n",
120
+ "\n",
121
  " return batch, encoder_input_tokens, decoder_input_tokens"
122
  ]
123
  },
 
145
  " columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
146
  " )\n",
147
  "\n",
148
+ "\n",
149
  "def get_attn_map(attn_type: str, layer: int, head: int):\n",
150
  " if attn_type == \"encoder\":\n",
151
  " attn = model.encoder.layers[layer].self_attention_block.attention_scores\n",
 
155
  " attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n",
156
  " return attn[0, head].data\n",
157
  "\n",
158
+ "\n",
159
  "def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
160
  " df = mtx2df(\n",
161
  " get_attn_map(attn_type, layer, head),\n",
 
173
  " color=\"value\",\n",
174
  " tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
175
  " )\n",
176
+ " # .title(f\"Layer {layer} Head {head}\")\n",
177
  " .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
178
  " .interactive()\n",
179
  " )\n",
180
  "\n",
181
+ "\n",
182
+ "def get_all_attention_maps(\n",
183
+ " attn_type: str,\n",
184
+ " layers: list[int],\n",
185
+ " heads: list[int],\n",
186
+ " row_tokens: list,\n",
187
+ " col_tokens,\n",
188
+ " max_sentence_len: int,\n",
189
+ "):\n",
190
  " charts = []\n",
191
  " for layer in layers:\n",
192
  " rowCharts = []\n",
193
  " for head in heads:\n",
194
+ " rowCharts.append(\n",
195
+ " attn_map(\n",
196
+ " attn_type, layer, head, row_tokens, col_tokens, max_sentence_len\n",
197
+ " )\n",
198
+ " )\n",
199
  " charts.append(alt.hconcat(*rowCharts))\n",
200
  " return alt.vconcat(*charts)"
201
  ]
 
314
  "heads = [0, 1, 2, 3, 4, 5, 6, 7]\n",
315
  "\n",
316
  "# Encoder Self-Attention\n",
317
+ "get_all_attention_maps(\n",
318
+ " \"encoder\",\n",
319
+ " layers,\n",
320
+ " heads,\n",
321
+ " encoder_input_tokens,\n",
322
+ " encoder_input_tokens,\n",
323
+ " min(20, sentence_len),\n",
324
+ ")"
325
  ]
326
  },
327
  {
 
413
  ],
414
  "source": [
415
  "# Encoder Self-Attention\n",
416
+ "get_all_attention_maps(\n",
417
+ " \"decoder\",\n",
418
+ " layers,\n",
419
+ " heads,\n",
420
+ " decoder_input_tokens,\n",
421
+ " decoder_input_tokens,\n",
422
+ " min(20, sentence_len),\n",
423
+ ")"
424
  ]
425
  },
426
  {
 
512
  ],
513
  "source": [
514
  "# Encoder Self-Attention\n",
515
+ "get_all_attention_maps(\n",
516
+ " \"encoder-decoder\",\n",
517
+ " layers,\n",
518
+ " heads,\n",
519
+ " encoder_input_tokens,\n",
520
+ " decoder_input_tokens,\n",
521
+ " min(20, sentence_len),\n",
522
+ ")"
523
  ]
524
  },
525
  {
transformer_from_scratch/config.py CHANGED
@@ -14,14 +14,15 @@ def get_config():
14
  "model_basename": "tmodel_",
15
  "preload": None,
16
  "tokenizer_file": "tokenizer_{0}.json",
17
- "experiment_name": "runs/tmodel"
18
  }
19
 
 
20
  def get_weights_file_path(config, epoch):
21
  model_folder = config["model_folder"]
22
  model_basename = config["model_basename"]
23
  model_filename = f"{model_basename}{epoch}.pt"
24
- return str(Path('.') / model_folder / model_filename)
25
 
26
 
27
  def latest_weights_file_path(config):
@@ -31,4 +32,4 @@ def latest_weights_file_path(config):
31
  if len(weights_files) == 0:
32
  return None
33
  weights_files.sort()
34
- return str(weights_files[-1])
 
14
  "model_basename": "tmodel_",
15
  "preload": None,
16
  "tokenizer_file": "tokenizer_{0}.json",
17
+ "experiment_name": "runs/tmodel",
18
  }
19
 
20
+
21
  def get_weights_file_path(config, epoch):
22
  model_folder = config["model_folder"]
23
  model_basename = config["model_basename"]
24
  model_filename = f"{model_basename}{epoch}.pt"
25
+ return str(Path(".") / model_folder / model_filename)
26
 
27
 
28
  def latest_weights_file_path(config):
 
32
  if len(weights_files) == 0:
33
  return None
34
  weights_files.sort()
35
+ return str(weights_files[-1])
transformer_from_scratch/dataset.py CHANGED
@@ -13,26 +13,34 @@ class BilingualDataset(Dataset):
13
  self.src_lang = src_lang
14
  self.tgt_lang = tgt_lang
15
 
16
- self.sos_token = torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64)
17
- self.eos_token = torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64)
18
- self.pad_token = torch.tensor([tokenizer_src.token_to_id('[PAD]')], dtype=torch.int64)
 
 
 
 
 
 
19
 
20
  def __len__(self):
21
  return len(self.ds)
22
 
23
  def __getitem__(self, index):
24
  src_target_pair = self.ds[index]
25
- src_text = src_target_pair['translation'][self.src_lang]
26
- tgt_text = src_target_pair['translation'][self.tgt_lang]
27
 
28
  enc_input_tokens = self.tokenizer_src.encode(src_text).ids
29
  dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
30
 
31
- enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2 # for SOS and EOS
32
- dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 # for SOS
 
 
33
 
34
  if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
35
- raise ValueError('Sentence is too long')
36
 
37
  # Add SOS and EOS tokens to source text
38
  encoder_input = torch.cat(
@@ -40,7 +48,9 @@ class BilingualDataset(Dataset):
40
  self.sos_token,
41
  torch.tensor(enc_input_tokens, dtype=torch.int64),
42
  self.eos_token,
43
- torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64)
 
 
44
  ]
45
  )
46
  # Add SOS token to the decoder input
@@ -48,7 +58,9 @@ class BilingualDataset(Dataset):
48
  [
49
  self.sos_token,
50
  torch.tensor(dec_input_tokens, dtype=torch.int64),
51
- torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
 
 
52
  ]
53
  )
54
  # Add EOS token to the label (what we want )
@@ -56,7 +68,9 @@ class BilingualDataset(Dataset):
56
  [
57
  torch.tensor(dec_input_tokens, dtype=torch.int64),
58
  self.eos_token,
59
- torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64)
 
 
60
  ]
61
  )
62
 
@@ -65,15 +79,27 @@ class BilingualDataset(Dataset):
65
  assert label.size(0) == self.seq_len
66
 
67
  return {
68
- "encoder_input": encoder_input, # (Seq_len)
69
- "decoder_input": decoder_input, # (Seq_len)
70
- "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, Seq_len)
71
- "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int() & casual_mask(decoder_input.size(0)), # (1, Seq_len) & (1, Seq_len, Seq_len)
72
- "label": label, # (Seq_len)
 
 
 
 
 
 
 
 
 
73
  "src_text": src_text,
74
- "tgt_text": tgt_text
75
  }
76
 
 
77
  def casual_mask(size):
78
- mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int) # Upper triangular matrix, above the main diagonal
79
- return mask==0
 
 
 
13
  self.src_lang = src_lang
14
  self.tgt_lang = tgt_lang
15
 
16
+ self.sos_token = torch.tensor(
17
+ [tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64
18
+ )
19
+ self.eos_token = torch.tensor(
20
+ [tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64
21
+ )
22
+ self.pad_token = torch.tensor(
23
+ [tokenizer_src.token_to_id("[PAD]")], dtype=torch.int64
24
+ )
25
 
26
  def __len__(self):
27
  return len(self.ds)
28
 
29
  def __getitem__(self, index):
30
  src_target_pair = self.ds[index]
31
+ src_text = src_target_pair["translation"][self.src_lang]
32
+ tgt_text = src_target_pair["translation"][self.tgt_lang]
33
 
34
  enc_input_tokens = self.tokenizer_src.encode(src_text).ids
35
  dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
36
 
37
+ enc_num_padding_tokens = (
38
+ self.seq_len - len(enc_input_tokens) - 2
39
+ ) # for SOS and EOS
40
+ dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 # for SOS
41
 
42
  if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
43
+ raise ValueError("Sentence is too long")
44
 
45
  # Add SOS and EOS tokens to source text
46
  encoder_input = torch.cat(
 
48
  self.sos_token,
49
  torch.tensor(enc_input_tokens, dtype=torch.int64),
50
  self.eos_token,
51
+ torch.tensor(
52
+ [self.pad_token] * enc_num_padding_tokens, dtype=torch.int64
53
+ ),
54
  ]
55
  )
56
  # Add SOS token to the decoder input
 
58
  [
59
  self.sos_token,
60
  torch.tensor(dec_input_tokens, dtype=torch.int64),
61
+ torch.tensor(
62
+ [self.pad_token] * dec_num_padding_tokens, dtype=torch.int64
63
+ ),
64
  ]
65
  )
66
  # Add EOS token to the label (what we want )
 
68
  [
69
  torch.tensor(dec_input_tokens, dtype=torch.int64),
70
  self.eos_token,
71
+ torch.tensor(
72
+ [self.pad_token] * dec_num_padding_tokens, dtype=torch.int64
73
+ ),
74
  ]
75
  )
76
 
 
79
  assert label.size(0) == self.seq_len
80
 
81
  return {
82
+ "encoder_input": encoder_input, # (Seq_len)
83
+ "decoder_input": decoder_input, # (Seq_len)
84
+ "encoder_mask": (encoder_input != self.pad_token)
85
+ .unsqueeze(0)
86
+ .unsqueeze(0)
87
+ .int(), # (1, 1, Seq_len)
88
+ "decoder_mask": (decoder_input != self.pad_token)
89
+ .unsqueeze(0)
90
+ .unsqueeze(0)
91
+ .int()
92
+ & casual_mask(
93
+ decoder_input.size(0)
94
+ ), # (1, Seq_len) & (1, Seq_len, Seq_len)
95
+ "label": label, # (Seq_len)
96
  "src_text": src_text,
97
+ "tgt_text": tgt_text,
98
  }
99
 
100
+
101
  def casual_mask(size):
102
+ mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(
103
+ torch.int
104
+ ) # Upper triangular matrix, above the main diagonal
105
+ return mask == 0
transformer_from_scratch/inference.ipynb CHANGED
@@ -12,7 +12,7 @@
12
  },
13
  "source": [
14
  "import torch\n",
15
- "from config import get_config,latest_weights_file_path\n",
16
  "from train import get_model, get_ds, run_validation\n",
17
  "from translate import translate"
18
  ],
@@ -22,10 +22,10 @@
22
  "evalue": "cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)",
23
  "output_type": "error",
24
  "traceback": [
25
- "\u001B[31m---------------------------------------------------------------------------\u001B[39m",
26
- "\u001B[31mImportError\u001B[39m Traceback (most recent call last)",
27
- "\u001B[36mCell\u001B[39m\u001B[36m \u001B[39m\u001B[32mIn[2]\u001B[39m\u001B[32m, line 3\u001B[39m\n\u001B[32m 1\u001B[39m \u001B[38;5;28;01mimport\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtorch\u001B[39;00m\n\u001B[32m 2\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mconfig\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m get_config,latest_weights_file_path\n\u001B[32m----> \u001B[39m\u001B[32m3\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtrain\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m get_model, get_ds, run_validation\n\u001B[32m 4\u001B[39m \u001B[38;5;28;01mfrom\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[34;01mtranslate\u001B[39;00m\u001B[38;5;250m \u001B[39m\u001B[38;5;28;01mimport\u001B[39;00m translate\n",
28
- "\u001B[31mImportError\u001B[39m: cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)"
29
  ]
30
  }
31
  ],
@@ -42,12 +42,14 @@
42
  "print(\"Using device:\", device)\n",
43
  "config = get_config()\n",
44
  "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
45
- "model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)\n",
 
 
46
  "\n",
47
  "# Load the pretrained weights\n",
48
  "model_filename = latest_weights_file_path(config)\n",
49
  "state = torch.load(model_filename)\n",
50
- "model.load_state_dict(state['model_state_dict'])"
51
  ],
52
  "id": "e6b0b6022c4d1c15"
53
  },
@@ -56,7 +58,20 @@
56
  "cell_type": "code",
57
  "outputs": [],
58
  "execution_count": null,
59
- "source": "run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: print(msg), 0, None, num_examples=10)",
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  "id": "be2c2169c183a445"
61
  },
62
  {
 
12
  },
13
  "source": [
14
  "import torch\n",
15
+ "from config import get_config, latest_weights_file_path\n",
16
  "from train import get_model, get_ds, run_validation\n",
17
  "from translate import translate"
18
  ],
 
22
  "evalue": "cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)",
23
  "output_type": "error",
24
  "traceback": [
25
+ "\u001b[31m---------------------------------------------------------------------------\u001b[39m",
26
+ "\u001b[31mImportError\u001b[39m Traceback (most recent call last)",
27
+ "\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mconfig\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_config,latest_weights_file_path\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtrain\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_model, get_ds, run_validation\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtranslate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m translate\n",
28
+ "\u001b[31mImportError\u001b[39m: cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)"
29
  ]
30
  }
31
  ],
 
42
  "print(\"Using device:\", device)\n",
43
  "config = get_config()\n",
44
  "train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
45
+ "model = get_model(\n",
46
+ " config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()\n",
47
+ ").to(device)\n",
48
  "\n",
49
  "# Load the pretrained weights\n",
50
  "model_filename = latest_weights_file_path(config)\n",
51
  "state = torch.load(model_filename)\n",
52
+ "model.load_state_dict(state[\"model_state_dict\"])"
53
  ],
54
  "id": "e6b0b6022c4d1c15"
55
  },
 
58
  "cell_type": "code",
59
  "outputs": [],
60
  "execution_count": null,
61
+ "source": [
62
+ "run_validation(\n",
63
+ " model,\n",
64
+ " val_dataloader,\n",
65
+ " tokenizer_src,\n",
66
+ " tokenizer_tgt,\n",
67
+ " config[\"seq_len\"],\n",
68
+ " device,\n",
69
+ " lambda msg: print(msg),\n",
70
+ " 0,\n",
71
+ " None,\n",
72
+ " num_examples=10,\n",
73
+ ")"
74
+ ],
75
  "id": "be2c2169c183a445"
76
  },
77
  {
transformer_from_scratch/train.py CHANGED
@@ -19,19 +19,24 @@ from model import build_transformer
19
  from config import get_weights_file_path, get_config
20
  import warnings
21
 
22
- def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
23
- sos_idx = tokenizer_tgt.token_to_id('[SOS]')
24
- eos_idx = tokenizer_tgt.token_to_id('[EOS]')
 
 
 
25
 
26
  # Precompute the encoder output and reuse it for every token we get from the decoder
27
  encoder_output = model.encode(source, source_mask)
28
  # Initialize the decoder input with the sos token
29
  decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
30
  while True:
31
- if decoder_input .size(1) == max_len:
32
  break
33
  # Build mask for the target (decoder input)
34
- decoder_mask = casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)
 
 
35
 
36
  # Calculate the output of the decoder
37
  out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
@@ -40,15 +45,31 @@ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_
40
  prob = model.project(out[:, -1])
41
  # Select the token with the highest probability (because it's a greedy search)
42
  _, next_word = torch.max(prob, dim=1)
43
- decoder_input = torch.cat([decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)
 
 
 
 
 
 
44
 
45
  if next_word == eos_idx:
46
  break
47
  return decoder_input.squeeze(0)
48
 
49
 
50
-
51
- def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer, num_examples=2):
 
 
 
 
 
 
 
 
 
 
52
  model.eval()
53
  count = 0
54
 
@@ -61,25 +82,33 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
61
  with torch.no_grad():
62
  for batch in validation_ds:
63
  count += 1
64
- encoder_input = batch['encoder_input'].to(device)
65
- encoder_mask = batch['encoder_mask'].to(device)
66
 
67
  assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
68
 
69
- model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
70
-
71
- source_text = batch['src_text'][0]
72
- target_text = batch['tgt_text'][0]
 
 
 
 
 
 
 
 
73
  model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
74
 
75
  source_texts.append(source_text)
76
  expected.append(target_text)
77
  predicted.append(model_out_text)
78
 
79
- print_msg('-' * console_width)
80
- print_msg(f'Source: {source_text}')
81
- print_msg(f'Expected: {target_text}')
82
- print_msg(f'Predicted: {model_out_text}')
83
 
84
  if count == num_examples:
85
  break
@@ -91,25 +120,22 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
91
  # Compute the char error rate
92
  metric = CharErrorRate()
93
  cer = metric(predicted, expected)
94
- writer.add_scalar('validation cer', cer, global_step)
95
  writer.flush()
96
 
97
  # Compute the word error rate
98
  metric = WordErrorRate()
99
  wer = metric(predicted, expected)
100
- writer.add_scalar('validation wer', wer, global_step)
101
  writer.flush()
102
 
103
  # Compute the BLEU metric
104
  metric = BLEUScore()
105
  bleu = metric(predicted, expected)
106
- writer.add_scalar('validation BLEU', bleu, global_step)
107
  writer.flush()
108
 
109
 
110
-
111
-
112
-
113
  def get_all_sentences(ds, lang):
114
  for item in ds:
115
  yield item["translation"][lang]
@@ -145,84 +171,117 @@ def get_ds(config):
145
  val_ds_size = len(ds_raw) - train_ds_size
146
  train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
147
 
148
- train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
149
- val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'], config['seq_len'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
  max_len_src = 0
152
  max_len_tgt = 0
153
 
154
  for item in ds_raw:
155
- src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
156
- tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
157
 
158
  max_len_src = max(len(src_ids), max_len_src)
159
  max_len_tgt = max(len(tgt_ids), max_len_tgt)
160
 
161
- print(f'Max length of the source sentence: {max_len_src}')
162
- print(f'Max length of the target sentence: {max_len_tgt}')
163
 
164
- train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
 
 
165
  val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
166
 
167
  return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
168
 
169
 
170
  def get_model(config, vocab_src_len, vocab_tgt_len):
171
- model = build_transformer(vocab_src_len, vocab_tgt_len, config['seq_len'], config['seq_len'], config['d_model'])
 
 
 
 
 
 
172
  return model
173
 
 
174
  def train_model(config):
175
  # Define the device
176
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
177
- print(f'using device: {device}')
178
 
179
- Path(config['model_folder']).mkdir(parents=True, exist_ok=True)
180
 
181
  train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
182
- model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
 
 
183
 
184
  # Tensorboard
185
- writer = SummaryWriter(config['experiment_name'])
186
 
187
- optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)
188
 
189
  initial_epoch = 0
190
  global_step = 0
191
 
192
- if config['preload']:
193
- model_filename = get_weights_file_path(config, config['preload'])
194
- print(f'Preloading model {model_filename}')
195
  state = torch.load(model_filename)
196
- initial_epoch = state['epoch'] + 1
197
- optimizer.load_state_dict(state['optimizer_state_dict'])
198
- global_step = state['global_step']
199
 
200
- loss_fn = torch.nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1)
 
 
201
 
202
- for epoch in range(initial_epoch, config['num_epochs']):
203
- batch_iterator = tqdm(train_dataloader,desc=f'Processing epoch {epoch:02d}')
204
  for batch in batch_iterator:
205
  model.train()
206
- encoder_input = batch['encoder_input'].to(device) # (B, Seq_len)
207
- decoder_input = batch['decoder_input'].to(device) # (B, Seq_len)
208
 
209
- encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, Seq_len)
210
- decoder_mask = batch['decoder_mask'].to(device) # (B, 1, Seq_len, Seq_len)
211
 
212
  # Run the tensors through the transformer model
213
- encoder_output = model.encode(encoder_input, encoder_mask) # (B, Seq_len, d_model)
214
- decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, Seq_len, d_model)
 
 
 
 
215
 
216
- proj_output = model.project(decoder_output) # (B, Seq_len, tgt_vocab_size)
217
- label = batch['label'].to(device) # (B, Seq_len)
218
 
219
  # (B, Seq_len, tgt_vocab_size) --> (B * Seq_len, tgt_vocab_size)
220
- loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))
 
 
221
 
222
- batch_iterator.set_postfix({f'loss' : f'{loss.item(): 6.3f}'})
223
 
224
  # Log the loss
225
- writer.add_scalar('train loss', loss.item(), global_step)
226
  writer.flush()
227
 
228
  # Backpropagate the loss
@@ -234,23 +293,32 @@ def train_model(config):
234
 
235
  global_step += 1
236
 
237
- run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)
 
 
 
 
 
 
 
 
 
 
238
 
239
  # Save the model at the end of each epoch
240
- model_filename = get_weights_file_path(config, f'{epoch:02d}')
241
  torch.save(
242
  {
243
- 'epoch': epoch,
244
- 'model_state_dict': model.state_dict(),
245
- 'optimizer_state_dict': optimizer.state_dict(),
246
- 'global_step': global_step
247
- }, model_filename)
 
 
248
 
249
 
250
  if __name__ == "__main__":
251
- warnings.filterwarnings('ignore')
252
  config = get_config()
253
  train_model(config)
254
-
255
-
256
-
 
19
  from config import get_weights_file_path, get_config
20
  import warnings
21
 
22
+
23
+ def greedy_decode(
24
+ model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device
25
+ ):
26
+ sos_idx = tokenizer_tgt.token_to_id("[SOS]")
27
+ eos_idx = tokenizer_tgt.token_to_id("[EOS]")
28
 
29
  # Precompute the encoder output and reuse it for every token we get from the decoder
30
  encoder_output = model.encode(source, source_mask)
31
  # Initialize the decoder input with the sos token
32
  decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
33
  while True:
34
+ if decoder_input.size(1) == max_len:
35
  break
36
  # Build mask for the target (decoder input)
37
+ decoder_mask = (
38
+ casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)
39
+ )
40
 
41
  # Calculate the output of the decoder
42
  out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
 
45
  prob = model.project(out[:, -1])
46
  # Select the token with the highest probability (because it's a greedy search)
47
  _, next_word = torch.max(prob, dim=1)
48
+ decoder_input = torch.cat(
49
+ [
50
+ decoder_input,
51
+ torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device),
52
+ ],
53
+ dim=1,
54
+ )
55
 
56
  if next_word == eos_idx:
57
  break
58
  return decoder_input.squeeze(0)
59
 
60
 
61
+ def run_validation(
62
+ model,
63
+ validation_ds,
64
+ tokenizer_src,
65
+ tokenizer_tgt,
66
+ max_len,
67
+ device,
68
+ print_msg,
69
+ global_step,
70
+ writer,
71
+ num_examples=2,
72
+ ):
73
  model.eval()
74
  count = 0
75
 
 
82
  with torch.no_grad():
83
  for batch in validation_ds:
84
  count += 1
85
+ encoder_input = batch["encoder_input"].to(device)
86
+ encoder_mask = batch["encoder_mask"].to(device)
87
 
88
  assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
89
 
90
+ model_out = greedy_decode(
91
+ model,
92
+ encoder_input,
93
+ encoder_mask,
94
+ tokenizer_src,
95
+ tokenizer_tgt,
96
+ max_len,
97
+ device,
98
+ )
99
+
100
+ source_text = batch["src_text"][0]
101
+ target_text = batch["tgt_text"][0]
102
  model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
103
 
104
  source_texts.append(source_text)
105
  expected.append(target_text)
106
  predicted.append(model_out_text)
107
 
108
+ print_msg("-" * console_width)
109
+ print_msg(f"Source: {source_text}")
110
+ print_msg(f"Expected: {target_text}")
111
+ print_msg(f"Predicted: {model_out_text}")
112
 
113
  if count == num_examples:
114
  break
 
120
  # Compute the char error rate
121
  metric = CharErrorRate()
122
  cer = metric(predicted, expected)
123
+ writer.add_scalar("validation cer", cer, global_step)
124
  writer.flush()
125
 
126
  # Compute the word error rate
127
  metric = WordErrorRate()
128
  wer = metric(predicted, expected)
129
+ writer.add_scalar("validation wer", wer, global_step)
130
  writer.flush()
131
 
132
  # Compute the BLEU metric
133
  metric = BLEUScore()
134
  bleu = metric(predicted, expected)
135
+ writer.add_scalar("validation BLEU", bleu, global_step)
136
  writer.flush()
137
 
138
 
 
 
 
139
  def get_all_sentences(ds, lang):
140
  for item in ds:
141
  yield item["translation"][lang]
 
171
  val_ds_size = len(ds_raw) - train_ds_size
172
  train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
173
 
174
+ train_ds = BilingualDataset(
175
+ train_ds_raw,
176
+ tokenizer_src,
177
+ tokenizer_tgt,
178
+ config["lang_src"],
179
+ config["lang_tgt"],
180
+ config["seq_len"],
181
+ )
182
+ val_ds = BilingualDataset(
183
+ val_ds_raw,
184
+ tokenizer_src,
185
+ tokenizer_tgt,
186
+ config["lang_src"],
187
+ config["lang_tgt"],
188
+ config["seq_len"],
189
+ )
190
 
191
  max_len_src = 0
192
  max_len_tgt = 0
193
 
194
  for item in ds_raw:
195
+ src_ids = tokenizer_src.encode(item["translation"][config["lang_src"]]).ids
196
+ tgt_ids = tokenizer_tgt.encode(item["translation"][config["lang_tgt"]]).ids
197
 
198
  max_len_src = max(len(src_ids), max_len_src)
199
  max_len_tgt = max(len(tgt_ids), max_len_tgt)
200
 
201
+ print(f"Max length of the source sentence: {max_len_src}")
202
+ print(f"Max length of the target sentence: {max_len_tgt}")
203
 
204
+ train_dataloader = DataLoader(
205
+ train_ds, batch_size=config["batch_size"], shuffle=True
206
+ )
207
  val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
208
 
209
  return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
210
 
211
 
212
  def get_model(config, vocab_src_len, vocab_tgt_len):
213
+ model = build_transformer(
214
+ vocab_src_len,
215
+ vocab_tgt_len,
216
+ config["seq_len"],
217
+ config["seq_len"],
218
+ config["d_model"],
219
+ )
220
  return model
221
 
222
+
223
  def train_model(config):
224
  # Define the device
225
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
226
+ print(f"using device: {device}")
227
 
228
+ Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
229
 
230
  train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
231
+ model = get_model(
232
+ config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()
233
+ ).to(device)
234
 
235
  # Tensorboard
236
+ writer = SummaryWriter(config["experiment_name"])
237
 
238
+ optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
239
 
240
  initial_epoch = 0
241
  global_step = 0
242
 
243
+ if config["preload"]:
244
+ model_filename = get_weights_file_path(config, config["preload"])
245
+ print(f"Preloading model {model_filename}")
246
  state = torch.load(model_filename)
247
+ initial_epoch = state["epoch"] + 1
248
+ optimizer.load_state_dict(state["optimizer_state_dict"])
249
+ global_step = state["global_step"]
250
 
251
+ loss_fn = torch.nn.CrossEntropyLoss(
252
+ ignore_index=tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1
253
+ )
254
 
255
+ for epoch in range(initial_epoch, config["num_epochs"]):
256
+ batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch {epoch:02d}")
257
  for batch in batch_iterator:
258
  model.train()
259
+ encoder_input = batch["encoder_input"].to(device) # (B, Seq_len)
260
+ decoder_input = batch["decoder_input"].to(device) # (B, Seq_len)
261
 
262
+ encoder_mask = batch["encoder_mask"].to(device) # (B, 1, 1, Seq_len)
263
+ decoder_mask = batch["decoder_mask"].to(device) # (B, 1, Seq_len, Seq_len)
264
 
265
  # Run the tensors through the transformer model
266
+ encoder_output = model.encode(
267
+ encoder_input, encoder_mask
268
+ ) # (B, Seq_len, d_model)
269
+ decoder_output = model.decode(
270
+ encoder_output, encoder_mask, decoder_input, decoder_mask
271
+ ) # (B, Seq_len, d_model)
272
 
273
+ proj_output = model.project(decoder_output) # (B, Seq_len, tgt_vocab_size)
274
+ label = batch["label"].to(device) # (B, Seq_len)
275
 
276
  # (B, Seq_len, tgt_vocab_size) --> (B * Seq_len, tgt_vocab_size)
277
+ loss = loss_fn(
278
+ proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1)
279
+ )
280
 
281
+ batch_iterator.set_postfix({f"loss": f"{loss.item(): 6.3f}"})
282
 
283
  # Log the loss
284
+ writer.add_scalar("train loss", loss.item(), global_step)
285
  writer.flush()
286
 
287
  # Backpropagate the loss
 
293
 
294
  global_step += 1
295
 
296
+ run_validation(
297
+ model,
298
+ val_dataloader,
299
+ tokenizer_src,
300
+ tokenizer_tgt,
301
+ config["seq_len"],
302
+ device,
303
+ lambda msg: batch_iterator.write(msg),
304
+ global_step,
305
+ writer,
306
+ )
307
 
308
  # Save the model at the end of each epoch
309
+ model_filename = get_weights_file_path(config, f"{epoch:02d}")
310
  torch.save(
311
  {
312
+ "epoch": epoch,
313
+ "model_state_dict": model.state_dict(),
314
+ "optimizer_state_dict": optimizer.state_dict(),
315
+ "global_step": global_step,
316
+ },
317
+ model_filename,
318
+ )
319
 
320
 
321
  if __name__ == "__main__":
322
+ warnings.filterwarnings("ignore")
323
  config = get_config()
324
  train_model(config)
 
 
 
transformer_from_scratch/translate.py CHANGED
@@ -7,32 +7,53 @@ from dataset import BilingualDataset
7
  import torch
8
  import sys
9
 
 
10
  def translate(sentence: str):
11
  # Define the device, tokenizers, and model
12
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
  print("Using device:", device)
14
  config = get_config()
15
 
16
- tokenizer_src = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_src']))))
17
- tokenizer_tgt = Tokenizer.from_file(str(Path(config['tokenizer_file'].format(config['lang_tgt']))))
 
 
 
 
18
 
19
- model = build_transformer(tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size(), config["seq_len"], config['seq_len'], d_model=config['d_model']).to(device)
 
 
 
 
 
 
20
 
21
  # Load the pretrained weights
22
  model_filename = latest_weights_file_path(config)
23
  state = torch.load(model_filename)
24
- model.load_state_dict(state['model_state_dict'])
25
 
26
  # if the sentence is a number use it as an index to the test set
27
  label = ""
28
  if type(sentence) == int or sentence.isdigit():
29
  id = int(sentence)
30
- ds = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='all')
31
- ds = BilingualDataset(ds, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'],
32
- config['seq_len'])
33
- sentence = ds[id]['src_text']
 
 
 
 
 
 
 
 
 
 
34
  label = ds[id]["tgt_text"]
35
- seq_len = config['seq_len']
36
 
37
  # translate the sentence
38
 
@@ -40,46 +61,82 @@ def translate(sentence: str):
40
  with torch.no_grad():
41
  # Precompute the encoder output and reuse it for every generation step
42
  source = tokenizer_src.encode(sentence)
43
- source = torch.cat([
44
- torch.tensor([tokenizer_src.token_to_id('[SOS]')], dtype=torch.int64),
45
- torch.tensor(source.ids, dtype=torch.int64),
46
- torch.tensor([tokenizer_src.token_to_id('[EOS]')], dtype=torch.int64),
47
- torch.tensor([tokenizer_src.token_to_id('[PAD]')] * (seq_len - len(source.ids) - 2), dtype=torch.int64)
48
- ], dim=0).to(device)
49
- source_mask = (source != tokenizer_src.token_to_id('[PAD]')).unsqueeze(0).unsqueeze(0).int().to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  encoder_output = model.encode(source, source_mask)
51
 
52
  # Initialize the decoder input with the sos token
53
- decoder_input = torch.empty(1, 1).fill_(tokenizer_tgt.token_to_id('[SOS]')).type_as(source).to(device)
 
 
 
 
 
54
 
55
  # Print the source sentence and target start prompt
56
- if label != "": print(f"{f'ID: ':>12}{id}")
 
57
  print(f"{f'SOURCE: ':>12}{sentence}")
58
- if label != "": print(f"{f'TARGET: ':>12}{label}")
59
- print(f"{f'PREDICTED: ':>12}", end='')
 
60
 
61
  # Generate the translation word by word
62
  while decoder_input.size(1) < seq_len:
63
  # build mask for target and calculate output
64
- decoder_mask = torch.triu(torch.ones((1, decoder_input.size(1), decoder_input.size(1))), diagonal=1).type(
65
- torch.int).type_as(source_mask).to(device)
 
 
 
 
 
 
 
66
  out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
67
 
68
  # project next token
69
  prob = model.project(out[:, -1])
70
  _, next_word = torch.max(prob, dim=1)
71
  decoder_input = torch.cat(
72
- [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1)
 
 
 
 
 
 
 
 
73
 
74
  # print the translated word
75
- print(f"{tokenizer_tgt.decode([next_word.item()])}", end=' ')
76
 
77
  # break if we predict the end of sentence token
78
- if next_word == tokenizer_tgt.token_to_id('[EOS]'):
79
  break
80
 
81
  # convert ids to tokens
82
  return tokenizer_tgt.decode(decoder_input[0].tolist())
83
 
 
84
  # read sentence from argument
85
- translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good student.")
 
7
  import torch
8
  import sys
9
 
10
+
11
  def translate(sentence: str):
12
  # Define the device, tokenizers, and model
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
  print("Using device:", device)
15
  config = get_config()
16
 
17
+ tokenizer_src = Tokenizer.from_file(
18
+ str(Path(config["tokenizer_file"].format(config["lang_src"])))
19
+ )
20
+ tokenizer_tgt = Tokenizer.from_file(
21
+ str(Path(config["tokenizer_file"].format(config["lang_tgt"])))
22
+ )
23
 
24
+ model = build_transformer(
25
+ tokenizer_src.get_vocab_size(),
26
+ tokenizer_tgt.get_vocab_size(),
27
+ config["seq_len"],
28
+ config["seq_len"],
29
+ d_model=config["d_model"],
30
+ ).to(device)
31
 
32
  # Load the pretrained weights
33
  model_filename = latest_weights_file_path(config)
34
  state = torch.load(model_filename)
35
+ model.load_state_dict(state["model_state_dict"])
36
 
37
  # if the sentence is a number use it as an index to the test set
38
  label = ""
39
  if type(sentence) == int or sentence.isdigit():
40
  id = int(sentence)
41
+ ds = load_dataset(
42
+ f"{config['datasource']}",
43
+ f"{config['lang_src']}-{config['lang_tgt']}",
44
+ split="all",
45
+ )
46
+ ds = BilingualDataset(
47
+ ds,
48
+ tokenizer_src,
49
+ tokenizer_tgt,
50
+ config["lang_src"],
51
+ config["lang_tgt"],
52
+ config["seq_len"],
53
+ )
54
+ sentence = ds[id]["src_text"]
55
  label = ds[id]["tgt_text"]
56
+ seq_len = config["seq_len"]
57
 
58
  # translate the sentence
59
 
 
61
  with torch.no_grad():
62
  # Precompute the encoder output and reuse it for every generation step
63
  source = tokenizer_src.encode(sentence)
64
+ source = torch.cat(
65
+ [
66
+ torch.tensor([tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64),
67
+ torch.tensor(source.ids, dtype=torch.int64),
68
+ torch.tensor([tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64),
69
+ torch.tensor(
70
+ [tokenizer_src.token_to_id("[PAD]")]
71
+ * (seq_len - len(source.ids) - 2),
72
+ dtype=torch.int64,
73
+ ),
74
+ ],
75
+ dim=0,
76
+ ).to(device)
77
+ source_mask = (
78
+ (source != tokenizer_src.token_to_id("[PAD]"))
79
+ .unsqueeze(0)
80
+ .unsqueeze(0)
81
+ .int()
82
+ .to(device)
83
+ )
84
  encoder_output = model.encode(source, source_mask)
85
 
86
  # Initialize the decoder input with the sos token
87
+ decoder_input = (
88
+ torch.empty(1, 1)
89
+ .fill_(tokenizer_tgt.token_to_id("[SOS]"))
90
+ .type_as(source)
91
+ .to(device)
92
+ )
93
 
94
  # Print the source sentence and target start prompt
95
+ if label != "":
96
+ print(f"{f'ID: ':>12}{id}")
97
  print(f"{f'SOURCE: ':>12}{sentence}")
98
+ if label != "":
99
+ print(f"{f'TARGET: ':>12}{label}")
100
+ print(f"{f'PREDICTED: ':>12}", end="")
101
 
102
  # Generate the translation word by word
103
  while decoder_input.size(1) < seq_len:
104
  # build mask for target and calculate output
105
+ decoder_mask = (
106
+ torch.triu(
107
+ torch.ones((1, decoder_input.size(1), decoder_input.size(1))),
108
+ diagonal=1,
109
+ )
110
+ .type(torch.int)
111
+ .type_as(source_mask)
112
+ .to(device)
113
+ )
114
  out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
115
 
116
  # project next token
117
  prob = model.project(out[:, -1])
118
  _, next_word = torch.max(prob, dim=1)
119
  decoder_input = torch.cat(
120
+ [
121
+ decoder_input,
122
+ torch.empty(1, 1)
123
+ .type_as(source)
124
+ .fill_(next_word.item())
125
+ .to(device),
126
+ ],
127
+ dim=1,
128
+ )
129
 
130
  # print the translated word
131
+ print(f"{tokenizer_tgt.decode([next_word.item()])}", end=" ")
132
 
133
  # break if we predict the end of sentence token
134
+ if next_word == tokenizer_tgt.token_to_id("[EOS]"):
135
  break
136
 
137
  # convert ids to tokens
138
  return tokenizer_tgt.decode(decoder_input[0].tolist())
139
 
140
+
141
  # read sentence from argument
142
+ translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good student.")
utils.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
import numpy as np
from torch_geometric.data import Data, Batch
from rdkit import Chem
from rdkit.Chem import AllChem
import nglview as nv
import py3Dmol


from dataset import get_atom_features, get_protein_features
from model_attention import BindingAffinityModel

# Run inference on GPU when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint produced by the Optuna-tuned attention experiment (epoch 41).
MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"

# Architecture hyperparameters of the saved model; must match the checkpoint above.
GAT_HEADS = 2
HIDDEN_CHANNELS = 256

19
def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
    """Run a single ligand/protein pair through the trained model.

    Args:
        ligand_smiles: SMILES string of the ligand.
        protein_sequence: protein amino-acid sequence (one character per residue).
        model_path: path to the saved model state dict.

    Returns:
        mol: RDKit molecule with explicit hydrogens and embedded 3D coordinates.
        importance: per-atom attention importance scores, normalized to [0, 1].
        predicted_affinity: predicted binding affinity value (float).

    Raises:
        ValueError: if the SMILES string cannot be parsed.
    """
    # Prepare ligand molecule with 3D geometry via RDKit.
    mol = Chem.MolFromSmiles(ligand_smiles)
    if mol is None:
        # MolFromSmiles returns None (no exception) on an invalid SMILES;
        # fail early with a clear message instead of an AttributeError below.
        raise ValueError(f"Could not parse SMILES: {ligand_smiles!r}")
    mol = Chem.AddHs(mol)
    AllChem.EmbedMolecule(mol, randomSeed=42)  # fixed seed -> reproducible coords

    # Build the ligand graph: node features plus each bond as two directed edges.
    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(atom_features), dtype=torch.float)
    edge_index = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        edge_index.extend([(i, j), (j, i)])
    if edge_index:
        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
    else:
        # Bond-less (single-atom) molecule: keep the expected (2, 0) shape
        # instead of the 1-D tensor the generic path would produce.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Tokenize the protein and pad/truncate to a fixed length of 1200 (0 = padding).
    tokens = [get_protein_features(c) for c in protein_sequence]
    if len(tokens) > 1200:
        tokens = tokens[:1200]
    else:
        tokens.extend([0] * (1200 - len(tokens)))
    protein_sequence = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)

    data = Data(x=x, edge_index=edge_index)
    batch = Batch.from_data_list([data]).to(DEVICE)
    num_features = x.shape[1]

    # Model loading (hyperparameters must match the checkpoint).
    model = BindingAffinityModel(
        num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    # Prediction; the model records cross-attention weights during the forward pass.
    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.batch, protein_sequence)
        attention_weights = model.cross_attention.last_attention_weights[0]

    # Per-atom importance: max attention over the real (non-padding) protein positions.
    real_prot_len = len([t for t in tokens if t != 0])
    importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()

    # Normalize to [0, 1]. The original guard (`importance.max() > 0`) still
    # divided by zero when every score was equal, turning the array into NaNs;
    # checking the spread instead is safe for constant vectors.
    spread = importance.max() - importance.min()
    if spread > 0:
        importance = (importance - importance.min()) / spread

    # Noise reduction: zero-out near-irrelevant atoms.
    importance[importance < 0.01] = 0
    return mol, importance, pred.item()
70
+
71
+
72
+
73
def get_py3dmol_view(mol, importance):
    """Build an interactive py3Dmol view of the ligand.

    The molecule is rendered as sticks plus small spheres, and the 15 atoms
    with the highest importance scores are annotated with a text label of the
    form "<index>:<symbol>:<score>".
    """
    viewer = py3Dmol.view(width="100%", height="600px")
    viewer.addModel(Chem.MolToMolBlock(mol), "sdf")
    viewer.setBackgroundColor("white")

    # Base rendering applied to every atom.
    base_style = {"stick": {"radius": 0.15}, "sphere": {"scale": 0.25}}
    viewer.setStyle({}, base_style)

    # Indices of the 15 most important atoms (descending importance).
    ranked = np.argsort(importance)[::-1]
    labelled = set(ranked[:15])

    conformer = mol.GetConformer()

    for atom_idx, score in enumerate(importance):
        if atom_idx not in labelled:
            continue
        coords = conformer.GetAtomPosition(atom_idx)
        element = mol.GetAtomWithIdx(atom_idx).GetSymbol()

        caption = f"{atom_idx}:{element}:{score:.2f}"

        label_opts = {
            "position": {"x": coords.x, "y": coords.y, "z": coords.z},
            "fontSize": 14,
            "fontColor": "white",
            "backgroundColor": "black",
            "backgroundOpacity": 0.7,
            "borderThickness": 0,
            "inFront": True,
            "showBackground": True,
        }
        viewer.addLabel(caption, label_opts)
    viewer.zoomTo()
    return viewer
107
+
108
+
109
+
110
def save_standalone_ngl_html(mol, importance, filepath):
    """Write a self-contained interactive NGL HTML page for the ligand.

    Per-atom importance scores are stored in the PDB B-factor column
    (scaled to 0-100) so the embedded NGL viewer can color the molecule as a
    heatmap; the 15 most important atoms are labelled and a hover tooltip
    shows each atom's score.

    Args:
        mol: RDKit molecule with 3D coordinates (and explicit hydrogens).
        importance: per-atom importance scores in [0, 1], indexed by atom index.
        filepath: destination path for the generated HTML file.
    """
    # Round-trip through PDB so each atom carries PDB residue info we can edit.
    pdb_block = Chem.MolToPDBBlock(mol)
    mol_pdb = Chem.MolFromPDBBlock(pdb_block, removeHs=False)

    # Store importance in the temp-factor (B-factor) column, scaled to 0-100.
    # NOTE(review): assumes the PDB round-trip preserves atom order so that
    # importance[i] matches atom i of mol_pdb — confirm for edge cases.
    for i, atom in enumerate(mol_pdb.GetAtoms()):
        info = atom.GetPDBResidueInfo()
        if info:

            info.SetTempFactor(float(importance[i]) * 100)

    final_pdb_block = Chem.MolToPDBBlock(mol_pdb)

    # The PDB text is embedded in a JS template literal; escape backticks.
    final_pdb_block = final_pdb_block.replace("`", "\\`")

    # Pick the 15 most important atoms for labelling (descending importance).
    indices_sorted = np.argsort(importance)[::-1]
    top_indices = indices_sorted[:15]

    # NGL atom-index selection string, e.g. "@3,7,12".
    selection_list = [str(i) for i in top_indices]
    selection_str = "@" + ",".join(selection_list)

    # Guard against an empty selection.
    if not selection_list:
        selection_str = "@-1"

    html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>NGL Visualization</title>
    <script src="https://unpkg.com/ngl@2.0.0-dev.37/dist/ngl.js"></script>
    <style>
        html, body {{ width: 100%; height: 100%; margin: 0; padding: 0; overflow: hidden; font-family: sans-serif; }}
        #viewport {{ width: 100%; height: 100%; }}

        /* Стиль подсказки */
        #tooltip {{
            display: none;
            position: absolute;
            z-index: 100;
            pointer-events: none; /* Чтобы мышь не 'застревала' на подсказке */
            background-color: rgba(20, 20, 20, 0.9);
            color: white;
            padding: 8px 12px;
            border-radius: 6px;
            font-size: 14px;
            box-shadow: 0 4px 6px rgba(0,0,0,0.3);
            white-space: nowrap;
            border: 1px solid rgba(255,255,255,0.2);
            transition: opacity 0.1s ease;
        }}

        /* Панель управления */
        #controls {{
            position: absolute;
            top: 20px;
            right: 20px;
            z-index: 50;
            background: rgba(255, 255, 255, 0.95);
            padding: 15px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            display: flex;
            align-items: center;
        }}

        /* Стили переключателя */
        .switch-container {{
            display: flex;
            align-items: center;
            gap: 10px;
            cursor: pointer;
            font-weight: bold;
            color: #333;
        }}

        input[type=checkbox] {{
            transform: scale(1.5);
            cursor: pointer;
        }}
    </style>
</head>
<body>
    <div id="controls">
        <label class="switch-container">
            <input type="checkbox" id="heatmapToggle" checked>
            <span>Show Heatmap</span>
        </label>
    </div>

    <div id="tooltip"></div>

    <div id="viewport"></div>

    <script>
        var pdbData = `{final_pdb_block}`;
        var selectionString = "{selection_str}";
        var component; // Глобальная переменная для доступа к модели

        document.addEventListener("DOMContentLoaded", function () {{
            var stage = new NGL.Stage("viewport", {{ backgroundColor: "white" }});
            var tooltip = document.getElementById("tooltip");
            var toggle = document.getElementById("heatmapToggle");

            // Загружаем данные
            var stringBlob = new Blob([pdbData], {{type: 'text/plain'}});

            stage.loadFile(stringBlob, {{ ext: 'pdb' }}).then(function (o) {{
                component = o; // Сохраняем ссылку

                // Рисуем начальное состояние
                updateVisualization();
                o.autoView();
            }});

            // --- ФУНКЦИЯ ОБНОВЛЕНИЯ ВИДА ---
            function updateVisualization() {{
                if (!component) return;

                // Очищаем старые представления (чтобы не накладывались)
                component.removeAllRepresentations();

                var useHeatmap = toggle.checked;

                if (useHeatmap) {{
                    // 1. РЕЖИМ HEATMAP
                    component.addRepresentation("ball+stick", {{
                        colorScheme: "bfactor",
                        colorDomain: [20, 80],
                        colorScale: ["blue", "white", "red"],
                        radiusScale: 1.0
                    }});
                }} else {{
                    // 2. ОБЫЧНЫЙ РЕЖИМ (По элементам)
                    component.addRepresentation("ball+stick", {{
                        colorScheme: "element",
                        radiusScale: 1.0
                    }});
                }}

                // Добавляем метки (они нужны всегда)
                if (selectionString.length > 1 && selectionString !== "@-1") {{
                    component.addRepresentation("label", {{
                        sele: selectionString,
                        labelType: "atomindex",
                        color: "black",
                        radius: 1.1,
                        yOffset: 0.0,
                        zOffset: 2.0,
                        attachment: "middle_center",
                        pickable: true // ВАЖНО: Делаем текст интерактивным
                    }});
                }}
            }}

            // Слушаем переключатель
            toggle.addEventListener("change", updateVisualization);

            // --- УМНЫЙ TOOLTIP ---
            stage.mouseControls.remove("hoverPick"); // Убираем стандартное поведение

            stage.signals.hovered.add(function (pickingProxy) {{
                // Проверяем, навели ли мы на атом ИЛИ на метку (текст)
                // NGL возвращает pickingProxy.atom даже если мы навели на label этого атома
                if (pickingProxy && (pickingProxy.atom || pickingProxy.closestBondAtom)) {{
                    var atom = pickingProxy.atom || pickingProxy.closestBondAtom;
                    var score = atom.bfactor.toFixed(2);

                    tooltip.innerHTML = `
                        <div style="margin-bottom:2px;"><b>Atom ID:</b> ${{atom.index}} (${{atom.element}}: ${{atom.atomname}})</div>
                        <div style="color: #ffcccc;"><b>Importance:</b> ${{(score/100).toFixed(3)}}</div>
                    `;
                    tooltip.style.display = "block";
                    tooltip.style.opacity = "1";

                    // Позиционирование: сдвиг вправо и вниз, чтобы не мешать
                    var cp = pickingProxy.canvasPosition;
                    tooltip.style.left = (cp.x + 20) + "px";
                    tooltip.style.top = (cp.y + 20) + "px";

                }} else {{
                    // Скрываем, если увели мышь
                    tooltip.style.display = "none";
                    tooltip.style.opacity = "0";
                }}
            }});

            // Ресайз окна
            window.addEventListener("resize", function(event){{
                stage.handleResize();
            }}, false);
        }});
    </script>
</body>
</html>"""
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(html_content)
visualization.ipynb CHANGED
@@ -2,261 +2,255 @@
2
  "cells": [
3
  {
4
  "cell_type": "code",
 
5
  "id": "initial_id",
6
  "metadata": {
7
  "ExecuteTime": {
8
- "end_time": "2025-12-05T14:02:00.479196Z",
9
- "start_time": "2025-12-05T14:02:00.003864Z"
10
  }
11
  },
12
- "source": [
13
- "import nglview as nv\n",
14
- "import os"
15
- ],
16
  "outputs": [
17
  {
18
  "data": {
19
- "text/plain": [],
20
  "application/vnd.jupyter.widget-view+json": {
 
21
  "version_major": 2,
22
- "version_minor": 0,
23
- "model_id": "3016118bc02a458cbcb4491a27089a6a"
24
- }
25
  },
26
  "metadata": {},
27
- "output_type": "display_data",
28
- "jetTransient": {
29
- "display_id": null
30
- }
31
  }
32
  ],
33
- "execution_count": 1
 
 
 
34
  },
35
  {
36
  "cell_type": "code",
 
37
  "id": "d8d7978e-980a-400c-8c6a-5365990c8855",
38
  "metadata": {
39
  "ExecuteTime": {
40
- "end_time": "2025-12-05T14:02:00.497753Z",
41
- "start_time": "2025-12-05T14:02:00.493751Z"
42
  }
43
  },
 
44
  "source": [
45
  "PDBBIND_PATH = \"refined-set\""
46
- ],
47
- "outputs": [],
48
- "execution_count": 2
49
  },
50
  {
51
  "cell_type": "code",
 
52
  "id": "788a6b43-c515-45c7-bc52-341d446b1a65",
53
  "metadata": {
54
  "ExecuteTime": {
55
- "end_time": "2025-12-05T14:02:00.510747Z",
56
- "start_time": "2025-12-05T14:02:00.505672Z"
57
  }
58
  },
 
59
  "source": [
60
  "EXAMPLE_PDB_ID = \"1a1e\""
61
- ],
62
- "outputs": [],
63
- "execution_count": 3
64
  },
65
  {
66
  "cell_type": "code",
 
67
  "id": "e8f4bebc-845f-43e8-bc4d-ab7b649eb49c",
68
  "metadata": {
69
  "ExecuteTime": {
70
- "end_time": "2025-12-05T14:02:00.523669Z",
71
- "start_time": "2025-12-05T14:02:00.518519Z"
72
  }
73
  },
 
74
  "source": [
75
  "pdb_dir = os.path.join(PDBBIND_PATH, EXAMPLE_PDB_ID)"
76
- ],
77
- "outputs": [],
78
- "execution_count": 4
79
  },
80
  {
81
  "cell_type": "code",
 
82
  "id": "24b5e435-4d8f-4505-b27c-dd6317376ed4",
83
  "metadata": {
84
  "ExecuteTime": {
85
- "end_time": "2025-12-05T14:02:00.570497Z",
86
- "start_time": "2025-12-05T14:02:00.565454Z"
87
  }
88
  },
 
89
  "source": [
90
  "protein_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_protein.pdb\")"
91
- ],
92
- "outputs": [],
93
- "execution_count": 5
94
  },
95
  {
96
  "cell_type": "code",
 
97
  "id": "e7fc3539-00c0-48a2-b012-c80757fa12c4",
98
  "metadata": {
99
  "ExecuteTime": {
100
- "end_time": "2025-12-05T14:02:00.584673Z",
101
- "start_time": "2025-12-05T14:02:00.578982Z"
102
  }
103
  },
 
104
  "source": [
105
  "ligand_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_ligand.sdf\")"
106
- ],
107
- "outputs": [],
108
- "execution_count": 6
109
  },
110
  {
111
  "cell_type": "code",
 
112
  "id": "9a053b99-7c01-4881-b3f7-e9b39090af9d",
113
  "metadata": {
114
  "ExecuteTime": {
115
- "end_time": "2025-12-05T14:02:00.649631Z",
116
- "start_time": "2025-12-05T14:02:00.591897Z"
117
  }
118
  },
 
119
  "source": [
120
  "view = nv.NGLWidget()"
121
- ],
122
- "outputs": [],
123
- "execution_count": 7
124
  },
125
  {
126
  "cell_type": "code",
 
127
  "id": "df8c8e00-3ce6-41dd-b457-d9f50e318dad",
128
  "metadata": {
129
  "ExecuteTime": {
130
- "end_time": "2025-12-05T14:02:00.779528Z",
131
- "start_time": "2025-12-05T14:02:00.657448Z"
132
  }
133
  },
 
134
  "source": [
135
  "protein_comp = view.add_component(protein_file)"
136
- ],
137
- "outputs": [],
138
- "execution_count": 8
139
  },
140
  {
141
  "cell_type": "code",
 
142
  "id": "c191fead-fef8-4077-b787-5bf9552307b1",
143
  "metadata": {
144
  "ExecuteTime": {
145
- "end_time": "2025-12-05T14:02:00.802894Z",
146
- "start_time": "2025-12-05T14:02:00.795534Z"
147
  }
148
  },
 
149
  "source": [
150
  "protein_comp.clear_representations()"
151
- ],
152
- "outputs": [],
153
- "execution_count": 9
154
  },
155
  {
156
  "cell_type": "code",
 
157
  "id": "4559033a-aeda-4659-8d91-9002b5a6ecda",
158
  "metadata": {
159
  "ExecuteTime": {
160
- "end_time": "2025-12-05T14:02:00.824161Z",
161
- "start_time": "2025-12-05T14:02:00.817622Z"
162
  }
163
  },
164
- "source": [
165
- "protein_comp.add_representation('cartoon', color='blue')"
166
- ],
167
  "outputs": [],
168
- "execution_count": 10
 
 
169
  },
170
  {
171
  "cell_type": "code",
 
172
  "id": "73ea1a50-8463-40b8-a942-0c92d3e97a97",
173
  "metadata": {
174
  "ExecuteTime": {
175
- "end_time": "2025-12-05T14:02:00.850013Z",
176
- "start_time": "2025-12-05T14:02:00.840262Z"
177
  }
178
  },
 
179
  "source": [
180
  "ligand_comp = view.add_component(ligand_file)"
181
- ],
182
- "outputs": [],
183
- "execution_count": 11
184
  },
185
  {
186
  "cell_type": "code",
 
187
  "id": "16cdb710-1ed6-4b1d-9e6a-69b7ad61a600",
188
  "metadata": {
189
  "ExecuteTime": {
190
- "end_time": "2025-12-05T14:02:00.866184Z",
191
- "start_time": "2025-12-05T14:02:00.859732Z"
192
  }
193
  },
 
194
  "source": [
195
  "ligand_comp.clear_representations()"
196
- ],
197
- "outputs": [],
198
- "execution_count": 12
199
  },
200
  {
201
  "cell_type": "code",
 
202
  "id": "2193c497-f33c-4de0-86a9-6e535002fcb7",
203
  "metadata": {
204
  "ExecuteTime": {
205
- "end_time": "2025-12-05T14:02:00.882846Z",
206
- "start_time": "2025-12-05T14:02:00.876856Z"
207
  }
208
  },
209
- "source": [
210
- "ligand_comp.add_representation('ball+stick', radius=0.3)"
211
- ],
212
  "outputs": [],
213
- "execution_count": 13
 
 
214
  },
215
  {
216
  "cell_type": "code",
 
217
  "id": "b1cc7f44-a374-4400-b4ba-8f75101b21ce",
218
  "metadata": {
219
  "ExecuteTime": {
220
- "end_time": "2025-12-05T14:02:00.903573Z",
221
- "start_time": "2025-12-05T14:02:00.897038Z"
222
  }
223
  },
224
- "source": [
225
- "view"
226
- ],
227
  "outputs": [
228
  {
229
  "data": {
230
- "text/plain": [
231
- "NGLWidget()"
232
- ],
233
  "application/vnd.jupyter.widget-view+json": {
 
234
  "version_major": 2,
235
- "version_minor": 0,
236
- "model_id": "028b8398377e4869a80fba4c3d5e5921"
237
- }
 
 
238
  },
239
  "metadata": {},
240
- "output_type": "display_data",
241
- "jetTransient": {
242
- "display_id": null
243
- }
244
  }
245
  ],
246
- "execution_count": 14
 
 
247
  },
248
  {
249
  "cell_type": "code",
 
250
  "id": "5655e465-bb44-4218-a5e3-db2c5e62cd9c",
251
  "metadata": {
252
  "ExecuteTime": {
253
- "end_time": "2025-12-05T14:02:00.915090Z",
254
- "start_time": "2025-12-05T14:02:00.912563Z"
255
  }
256
  },
257
- "source": [],
258
  "outputs": [],
259
- "execution_count": null
260
  }
261
  ],
262
  "metadata": {
 
2
  "cells": [
3
  {
4
  "cell_type": "code",
5
+ "execution_count": 1,
6
  "id": "initial_id",
7
  "metadata": {
8
  "ExecuteTime": {
9
+ "end_time": "2026-01-24T09:06:36.981469Z",
10
+ "start_time": "2026-01-24T09:06:36.975634Z"
11
  }
12
  },
 
 
 
 
13
  "outputs": [
14
  {
15
  "data": {
 
16
  "application/vnd.jupyter.widget-view+json": {
17
+ "model_id": "5077355be9e64d4f814509a151b6c8b6",
18
  "version_major": 2,
19
+ "version_minor": 0
20
+ },
21
+ "text/plain": []
22
  },
23
  "metadata": {},
24
+ "output_type": "display_data"
 
 
 
25
  }
26
  ],
27
+ "source": [
28
+ "import nglview as nv\n",
29
+ "import os"
30
+ ]
31
  },
32
  {
33
  "cell_type": "code",
34
+ "execution_count": 2,
35
  "id": "d8d7978e-980a-400c-8c6a-5365990c8855",
36
  "metadata": {
37
  "ExecuteTime": {
38
+ "end_time": "2026-01-24T09:06:37.011231Z",
39
+ "start_time": "2026-01-24T09:06:37.005099Z"
40
  }
41
  },
42
+ "outputs": [],
43
  "source": [
44
  "PDBBIND_PATH = \"refined-set\""
45
+ ]
 
 
46
  },
47
  {
48
  "cell_type": "code",
49
+ "execution_count": 3,
50
  "id": "788a6b43-c515-45c7-bc52-341d446b1a65",
51
  "metadata": {
52
  "ExecuteTime": {
53
+ "end_time": "2026-01-24T09:06:37.022991Z",
54
+ "start_time": "2026-01-24T09:06:37.016849Z"
55
  }
56
  },
57
+ "outputs": [],
58
  "source": [
59
  "EXAMPLE_PDB_ID = \"1a1e\""
60
+ ]
 
 
61
  },
62
  {
63
  "cell_type": "code",
64
+ "execution_count": 4,
65
  "id": "e8f4bebc-845f-43e8-bc4d-ab7b649eb49c",
66
  "metadata": {
67
  "ExecuteTime": {
68
+ "end_time": "2026-01-24T09:06:37.041322Z",
69
+ "start_time": "2026-01-24T09:06:37.035944Z"
70
  }
71
  },
72
+ "outputs": [],
73
  "source": [
74
  "pdb_dir = os.path.join(PDBBIND_PATH, EXAMPLE_PDB_ID)"
75
+ ]
 
 
76
  },
77
  {
78
  "cell_type": "code",
79
+ "execution_count": 5,
80
  "id": "24b5e435-4d8f-4505-b27c-dd6317376ed4",
81
  "metadata": {
82
  "ExecuteTime": {
83
+ "end_time": "2026-01-24T09:06:37.064924Z",
84
+ "start_time": "2026-01-24T09:06:37.059278Z"
85
  }
86
  },
87
+ "outputs": [],
88
  "source": [
89
  "protein_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_protein.pdb\")"
90
+ ]
 
 
91
  },
92
  {
93
  "cell_type": "code",
94
+ "execution_count": 6,
95
  "id": "e7fc3539-00c0-48a2-b012-c80757fa12c4",
96
  "metadata": {
97
  "ExecuteTime": {
98
+ "end_time": "2026-01-24T09:06:37.080165Z",
99
+ "start_time": "2026-01-24T09:06:37.074657Z"
100
  }
101
  },
102
+ "outputs": [],
103
  "source": [
104
  "ligand_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_ligand.sdf\")"
105
+ ]
 
 
106
  },
107
  {
108
  "cell_type": "code",
109
+ "execution_count": 7,
110
  "id": "9a053b99-7c01-4881-b3f7-e9b39090af9d",
111
  "metadata": {
112
  "ExecuteTime": {
113
+ "end_time": "2026-01-24T09:06:37.126934Z",
114
+ "start_time": "2026-01-24T09:06:37.107047Z"
115
  }
116
  },
117
+ "outputs": [],
118
  "source": [
119
  "view = nv.NGLWidget()"
120
+ ]
 
 
121
  },
122
  {
123
  "cell_type": "code",
124
+ "execution_count": 8,
125
  "id": "df8c8e00-3ce6-41dd-b457-d9f50e318dad",
126
  "metadata": {
127
  "ExecuteTime": {
128
+ "end_time": "2026-01-24T09:06:37.209871Z",
129
+ "start_time": "2026-01-24T09:06:37.140785Z"
130
  }
131
  },
132
+ "outputs": [],
133
  "source": [
134
  "protein_comp = view.add_component(protein_file)"
135
+ ]
 
 
136
  },
137
  {
138
  "cell_type": "code",
139
+ "execution_count": 9,
140
  "id": "c191fead-fef8-4077-b787-5bf9552307b1",
141
  "metadata": {
142
  "ExecuteTime": {
143
+ "end_time": "2026-01-24T09:06:37.243271Z",
144
+ "start_time": "2026-01-24T09:06:37.235380Z"
145
  }
146
  },
147
+ "outputs": [],
148
  "source": [
149
  "protein_comp.clear_representations()"
150
+ ]
 
 
151
  },
152
  {
153
  "cell_type": "code",
154
+ "execution_count": 10,
155
  "id": "4559033a-aeda-4659-8d91-9002b5a6ecda",
156
  "metadata": {
157
  "ExecuteTime": {
158
+ "end_time": "2026-01-24T09:06:37.276519Z",
159
+ "start_time": "2026-01-24T09:06:37.270030Z"
160
  }
161
  },
 
 
 
162
  "outputs": [],
163
+ "source": [
164
+ "protein_comp.add_representation(\"cartoon\", color=\"blue\")"
165
+ ]
166
  },
167
  {
168
  "cell_type": "code",
169
+ "execution_count": 11,
170
  "id": "73ea1a50-8463-40b8-a942-0c92d3e97a97",
171
  "metadata": {
172
  "ExecuteTime": {
173
+ "end_time": "2026-01-24T09:06:37.309460Z",
174
+ "start_time": "2026-01-24T09:06:37.299153Z"
175
  }
176
  },
177
+ "outputs": [],
178
  "source": [
179
  "ligand_comp = view.add_component(ligand_file)"
180
+ ]
 
 
181
  },
182
  {
183
  "cell_type": "code",
184
+ "execution_count": 12,
185
  "id": "16cdb710-1ed6-4b1d-9e6a-69b7ad61a600",
186
  "metadata": {
187
  "ExecuteTime": {
188
+ "end_time": "2026-01-24T09:06:37.340286Z",
189
+ "start_time": "2026-01-24T09:06:37.333802Z"
190
  }
191
  },
192
+ "outputs": [],
193
  "source": [
194
  "ligand_comp.clear_representations()"
195
+ ]
 
 
196
  },
197
  {
198
  "cell_type": "code",
199
+ "execution_count": 13,
200
  "id": "2193c497-f33c-4de0-86a9-6e535002fcb7",
201
  "metadata": {
202
  "ExecuteTime": {
203
+ "end_time": "2026-01-24T09:06:37.372239Z",
204
+ "start_time": "2026-01-24T09:06:37.365156Z"
205
  }
206
  },
 
 
 
207
  "outputs": [],
208
+ "source": [
209
+ "ligand_comp.add_representation(\"ball+stick\", radius=0.3)"
210
+ ]
211
  },
212
  {
213
  "cell_type": "code",
214
+ "execution_count": 14,
215
  "id": "b1cc7f44-a374-4400-b4ba-8f75101b21ce",
216
  "metadata": {
217
  "ExecuteTime": {
218
+ "end_time": "2026-01-24T09:06:37.406445Z",
219
+ "start_time": "2026-01-24T09:06:37.398945Z"
220
  }
221
  },
 
 
 
222
  "outputs": [
223
  {
224
  "data": {
 
 
 
225
  "application/vnd.jupyter.widget-view+json": {
226
+ "model_id": "11e403e6733946b9b6942f47bff2464e",
227
  "version_major": 2,
228
+ "version_minor": 0
229
+ },
230
+ "text/plain": [
231
+ "NGLWidget()"
232
+ ]
233
  },
234
  "metadata": {},
235
+ "output_type": "display_data"
 
 
 
236
  }
237
  ],
238
+ "source": [
239
+ "view"
240
+ ]
241
  },
242
  {
243
  "cell_type": "code",
244
+ "execution_count": null,
245
  "id": "5655e465-bb44-4218-a5e3-db2c5e62cd9c",
246
  "metadata": {
247
  "ExecuteTime": {
248
+ "end_time": "2026-01-24T09:06:37.420258Z",
249
+ "start_time": "2026-01-24T09:06:37.416018Z"
250
  }
251
  },
 
252
  "outputs": [],
253
+ "source": []
254
  }
255
  ],
256
  "metadata": {