Commit ·
e33b6c9
1
Parent(s): 1390640
Updated all code
Browse files- .gitignore +7 -1
- GNN_classification/dataset/classification/EDA.ipynb +3 -1
- GNNs__practice.ipynb +0 -0
- all_inferences.py +225 -0
- dataset.py +53 -11
- inference.py +29 -12
- inference_attention.py +102 -0
- main.py +79 -0
- model.py +31 -15
- model_attention.py +143 -0
- model_pl.py +6 -6
- optuna_train.py +36 -9
- optuna_train_attention.py +132 -0
- templates/index.html +103 -0
- train.py +43 -21
- train_attention.py +180 -0
- train_pl.py +10 -8
- transformer_from_scratch/attention_visual.ipynb +62 -14
- transformer_from_scratch/config.py +4 -3
- transformer_from_scratch/dataset.py +45 -19
- transformer_from_scratch/inference.ipynb +23 -8
- transformer_from_scratch/train.py +137 -69
- transformer_from_scratch/translate.py +83 -26
- utils.py +308 -0
- visualization.ipynb +91 -97
.gitignore
CHANGED
|
@@ -1 +1,7 @@
|
|
| 1 |
-
.idea
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.idea
|
| 2 |
+
.venv
|
| 3 |
+
.ipynb_checkpoints
|
| 4 |
+
|
| 5 |
+
/refined-set/
|
| 6 |
+
/data
|
| 7 |
+
/lightning_logs
|
GNN_classification/dataset/classification/EDA.ipynb
CHANGED
|
@@ -126,7 +126,9 @@
|
|
| 126 |
}
|
| 127 |
},
|
| 128 |
"cell_type": "code",
|
| 129 |
-
"source":
|
|
|
|
|
|
|
| 130 |
"id": "355c3ed8e5f76bbf",
|
| 131 |
"outputs": [
|
| 132 |
{
|
|
|
|
| 126 |
}
|
| 127 |
},
|
| 128 |
"cell_type": "code",
|
| 129 |
+
"source": [
|
| 130 |
+
"train_dataset[\"label\"].value_counts()"
|
| 131 |
+
],
|
| 132 |
"id": "355c3ed8e5f76bbf",
|
| 133 |
"outputs": [
|
| 134 |
{
|
GNNs__practice.ipynb
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|
all_inferences.py
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import os.path
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
from torch_geometric.data import Data, Batch
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from rdkit.Chem import AllChem
|
| 8 |
+
import nglview as nv
|
| 9 |
+
import py3Dmol
|
| 10 |
+
from nglview import write_html
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import matplotlib
|
| 14 |
+
import matplotlib.cm as cm
|
| 15 |
+
import matplotlib.colors as mcolors
|
| 16 |
+
|
| 17 |
+
from dataset import get_atom_features, get_protein_features
|
| 18 |
+
from model_attention import BindingAffinityModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
# Device used for the model and its inputs: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 22 |
+
# Checkpoint (state dict) passed to get_inference_data by the __main__ demo.
MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
|
| 23 |
+
|
| 24 |
+
# Passed as gat_heads when BindingAffinityModel is built for inference.
# NOTE(review): presumably has to match the value the checkpoint was trained with - confirm.
GAT_HEADS = 2
|
| 25 |
+
# Passed as hidden_channels when BindingAffinityModel is built for inference.
# NOTE(review): presumably has to match the value the checkpoint was trained with - confirm.
HIDDEN_CHANNELS = 256
|
| 26 |
+
|
| 27 |
+
def get_inference_data(ligand_smiles, protein_sequence, model_path):
    """
    Predict the affinity of one ligand/protein pair and score each ligand atom.

    Args:
        ligand_smiles: SMILES string of the ligand.
        protein_sequence: Protein sequence as a string of residue characters.
        model_path: Path to a saved ``BindingAffinityModel`` state dict.

    Returns:
        - mol: RDKit molecule object (explicit hydrogens added) with 3D coordinates
        - importance: NumPy array with one importance score per atom
        - predicted_affinity: predicted binding affinity value

    Raises:
        ValueError: If the SMILES cannot be parsed, no 3D conformer can be
            generated, or the protein sequence is empty.
    """
    max_protein_len = 1200  # fixed length the token sequence is cut / padded to
    pad_token = 0

    if not protein_sequence:
        # With no residues the attention slice below would be empty.
        raise ValueError("protein_sequence must not be empty")

    # Prepare ligand molecule with geometry (RDKit)
    mol = Chem.MolFromSmiles(ligand_smiles)
    if mol is None:
        raise ValueError(f"Could not parse ligand SMILES: {ligand_smiles!r}")
    mol = Chem.AddHs(mol)
    # EmbedMolecule returns -1 when it fails; retry from random coordinates
    # before giving up, because callers need the conformer for rendering.
    if AllChem.EmbedMolecule(mol, randomSeed=42) == -1:
        if AllChem.EmbedMolecule(mol, randomSeed=42, useRandomCoords=True) == -1:
            raise ValueError(
                f"Could not generate 3D coordinates for ligand: {ligand_smiles!r}"
            )

    # Graph data (PyTorch Geometric): node features plus both directions of every bond
    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(atom_features), dtype=torch.float)
    bond_pairs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_pairs.extend([(i, j), (j, i)])
    if bond_pairs:
        edge_index = torch.tensor(bond_pairs, dtype=torch.long).t().contiguous()
    else:
        # A molecule without bonds must still give a (2, 0) index, not a 1-D tensor.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    # Protein tokens, truncated or right-padded to the fixed length
    tokens = [get_protein_features(c) for c in protein_sequence]
    if len(tokens) > max_protein_len:
        tokens = tokens[:max_protein_len]
    else:
        tokens.extend([pad_token] * (max_protein_len - len(tokens)))
    protein_tensor = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)

    data = Data(x=x, edge_index=edge_index)
    batch = Batch.from_data_list([data]).to(DEVICE)
    num_features = x.shape[1]

    # Model loading
    model = BindingAffinityModel(
        num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    # Prediction
    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.batch, protein_tensor)
        # NOTE(review): presumably (num_atoms, protein_len) for the single
        # sample in the batch - confirm against model_attention.
        attention_weights = model.cross_attention.last_attention_weights[0]

    # Attention importance: per atom, the maximum over the non-padding positions
    real_prot_len = sum(1 for t in tokens if t != pad_token)
    importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()

    # Normalize to [0, 1]
    peak = importance.max()
    if peak > 0:
        floor = importance.min()
        span = peak - floor
        if span > 0:
            importance = (importance - floor) / span
        else:
            # All atoms scored the same: dividing by a zero range would give
            # NaN, so report that no atom stands out instead.
            importance = np.zeros_like(importance)

    # Noise reduction
    importance[importance < 0.01] = 0
    return mol, importance, pred.item()
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def print_atom_scores(mol, importance):
    """
    Print the atoms whose importance exceeds 0.1, highest score first.

    Args:
        mol: RDKit molecule used to look up each atom's element symbol.
        importance: Per-atom importance scores, indexed like the atoms of ``mol``.
    """
    print("Atom importance scores:")

    # Keep only atoms above the reporting threshold; the symbol is looked up
    # only for those atoms.
    ranked = sorted(
        (
            (atom_idx, mol.GetAtomWithIdx(atom_idx).GetSymbol(), weight)
            for atom_idx, weight in enumerate(importance)
            if weight > 0.1
        ),
        key=lambda entry: entry[2],
        reverse=True,
    )

    for atom_idx, element, weight in ranked:
        if weight > 0.8:
            marker = "🔥"
        elif weight > 0.5:
            marker = "✨"
        else:
            marker = ""
        print(f"Atom {atom_idx} ({element}): Importance = {weight:.3f} {marker}")
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def get_py3dmol(mol, importance, score):
    """
    Build a py3Dmol view of the ligand with atoms coloured by importance.

    Every atom is first drawn in grey. Atoms whose importance reaches 0.70,
    0.55 or 0.40 are then recoloured red, orange or blue respectively. Atoms
    that are among the 15 highest scores and exceed 0.1 get an
    "index:symbol:score" label, and the predicted value is added as a label.

    Args:
        mol: RDKit molecule; its conformer supplies the label positions.
        importance: Per-atom importance scores, indexed like the atoms of ``mol``.
        score: Predicted affinity shown in the "Predicted pKd" label.

    Returns:
        The configured py3Dmol view.
    """
    view = py3Dmol.view(width=1000, height=800)
    view.addModel(Chem.MolToMolBlock(mol), "sdf")
    view.setBackgroundColor('white')

    # Base style: one uniform size and a neutral colour for the whole molecule.
    view.setStyle({}, {
        'stick': {'color': '#cccccc', 'radius': 0.1},
        'sphere': {'color': '#cccccc', 'scale': 0.25}
    })

    # Colour tiers, strongest first; an atom goes into the first tier whose
    # threshold it reaches. Only the colour (and a slightly larger size)
    # differs from the base style.
    tiers = (
        (0.70, {
            'sphere': {'color': '#FF0000', 'scale': 0.28},
            'stick': {'color': '#FF0000', 'radius': 0.12}
        }),
        (0.55, {
            'sphere': {'color': '#FF8C00', 'scale': 0.28},
            'stick': {'color': '#FF8C00', 'radius': 0.12}
        }),
        (0.40, {
            'sphere': {'color': '#7777FF', 'scale': 0.28}
        }),
    )
    tier_members = [[] for _ in tiers]

    ranking = np.argsort(importance)[::-1]
    label_candidates = set(ranking[:15])
    pending_labels = []

    conformer = mol.GetConformer()

    for atom_idx, weight in enumerate(importance):
        for members, (threshold, _) in zip(tier_members, tiers):
            if weight >= threshold:
                members.append(atom_idx)
                break

        if atom_idx in label_candidates and weight > 0.1:
            point = conformer.GetAtomPosition(atom_idx)
            element = mol.GetAtomWithIdx(atom_idx).GetSymbol()
            pending_labels.append({
                'text': f"{atom_idx}:{element}:{weight:.2f}",
                'pos': {'x': point.x, 'y': point.y, 'z': point.z}
            })

    # Apply the tier styles.
    # NOTE(review): the selection uses 'serial' with RDKit atom indices -
    # confirm that these coincide for a model loaded from an SDF block.
    for members, (_, tier_style) in zip(tier_members, tiers):
        if members:
            view.addStyle({'serial': members}, tier_style)

    # Atom labels
    for entry in pending_labels:
        view.addLabel(entry['text'], {
            'position': entry['pos'],
            'fontSize': 14,
            'fontColor': 'white',
            'backgroundColor': 'black',
            'backgroundOpacity': 0.7,
            'borderThickness': 0,
            'inFront': True,
            'showBackground': True
        })

    view.zoomTo()
    view.addLabel(
        f"Predicted pKd: {float(score):.2f}",
        {'position': {'x': -5, 'y': 10, 'z': 0}, 'backgroundColor': 'black', 'fontColor': 'white'}
    )

    return view
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def get_ngl(mol, importance):
    """
    Build an NGL widget of the ligand coloured by per-atom importance.

    The molecule is round-tripped through a PDB block so that each atom's
    temperature factor can carry ``importance * 100``; the widget then colours
    a ball+stick representation by that field and labels the 15
    highest-scoring atoms with their atom index.

    Args:
        mol: RDKit molecule to display.
        importance: Per-atom importance scores, indexed like the atoms of ``mol``.

    Returns:
        The configured ``nglview.NGLWidget``.
    """
    annotated = Chem.MolFromPDBBlock(Chem.MolToPDBBlock(mol), removeHs=False)

    for atom_idx, atom in enumerate(annotated.GetAtoms()):
        residue_info = atom.GetPDBResidueInfo()
        if not residue_info:
            continue
        residue_info.SetTempFactor(float(importance[atom_idx] * 100.0))

    structure = nv.TextStructure(Chem.MolToPDBBlock(annotated), ext="pdb")
    view = nv.NGLWidget(structure)
    view.clear_representations()

    view.add_representation(
        'ball+stick',
        colorScheme='bfactor',
        colorScale=['blue', 'white', 'red'],
        colorDomain=[10, 80],
        radiusScale=1.0,
    )

    strongest = np.argsort(importance)[::-1][:15]
    selection_str = "@" + ",".join(str(atom_idx) for atom_idx in strongest)

    view.add_representation(
        'label',
        selection=selection_str,  # label only the selected atoms
        labelType='atomindex',    # show the atom index (0, 1, 2...)
        color='black',            # black text
        radius=2.0,               # font size (try 1.5 - 3.0)
        zOffset=1.0,              # shift slightly towards the camera
    )

    view.center()
    return view
|
| 209 |
+
|
| 210 |
+
if __name__ == "__main__":
    # Demo run: score one hard-coded ligand/protein pair and write both
    # visualisations as standalone HTML files.
    smiles = "COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1"
    protein = "PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF"
    # NOTE(review): presumably the known affinity for this pair, kept for
    # manual comparison with the prediction; it is not used below.
    affinity = 11.92

    file_name_py3dmol = "html_results/py3dmol_result.html"
    file_name_ngl = "html_results/ngl_result.html"

    # Create the output folders up front so the writes below do not fail on a
    # fresh checkout.
    for target in (file_name_py3dmol, file_name_ngl):
        os.makedirs(os.path.dirname(target), exist_ok=True)

    mol, importance, score = get_inference_data(smiles, protein, MODEL_PATH)
    print_atom_scores(mol, importance)
    py3dmol_view = get_py3dmol(mol, importance, score)
    py3dmol_view.write_html(file_name_py3dmol)

    ngl_widget = get_ngl(mol, importance)
    nv.write_html(file_name_ngl, ngl_widget)
|
| 225 |
+
|
dataset.py
CHANGED
|
@@ -65,6 +65,14 @@ def get_atom_features(atom):
|
|
| 65 |
degrees_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 66 |
numhs_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 67 |
implicit_valences_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
return np.array(
|
| 69 |
# Type of atom (Symbol)
|
| 70 |
one_of_k_encoding(atom.GetSymbol(), symbols_list)
|
|
@@ -93,22 +101,55 @@ def get_atom_features(atom):
|
|
| 93 |
+
|
| 94 |
# Aromaticity (Boolean)
|
| 95 |
[atom.GetIsAromatic()]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
)
|
| 97 |
|
|
|
|
| 98 |
def get_protein_features(char):
|
| 99 |
-
prot_vocab= {
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
|
| 108 |
class BindingDataset(Dataset):
|
| 109 |
def __init__(self, dataframe, max_seq_length=1000):
|
| 110 |
self.data = dataframe
|
| 111 |
-
self.max_seq_length =
|
|
|
|
|
|
|
| 112 |
|
| 113 |
def __len__(self):
|
| 114 |
return len(self.data)
|
|
@@ -144,9 +185,11 @@ class BindingDataset(Dataset):
|
|
| 144 |
# Protein (Sequence, tensor of integers)
|
| 145 |
tokens = [get_protein_features(char) for char in sequence]
|
| 146 |
if len(tokens) > self.max_seq_length:
|
| 147 |
-
tokens = tokens[:self.max_seq_length]
|
| 148 |
else:
|
| 149 |
-
tokens.extend(
|
|
|
|
|
|
|
| 150 |
protein_tensor = torch.tensor(tokens, dtype=torch.long)
|
| 151 |
|
| 152 |
# Affinity
|
|
@@ -164,4 +207,3 @@ if __name__ == "__main__":
|
|
| 164 |
|
| 165 |
print(len(train_dataset))
|
| 166 |
print(len(test_dataset))
|
| 167 |
-
|
|
|
|
| 65 |
degrees_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 66 |
numhs_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 67 |
implicit_valences_list = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
|
| 68 |
+
|
| 69 |
+
formal_charge_list = [-2, -1, 0, 1, 2]
|
| 70 |
+
chirality_list = [
|
| 71 |
+
Chem.rdchem.ChiralType.CHI_UNSPECIFIED,
|
| 72 |
+
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CW,
|
| 73 |
+
Chem.rdchem.ChiralType.CHI_TETRAHEDRAL_CCW,
|
| 74 |
+
Chem.rdchem.ChiralType.CHI_OTHER,
|
| 75 |
+
]
|
| 76 |
return np.array(
|
| 77 |
# Type of atom (Symbol)
|
| 78 |
one_of_k_encoding(atom.GetSymbol(), symbols_list)
|
|
|
|
| 101 |
+
|
| 102 |
# Aromaticity (Boolean)
|
| 103 |
[atom.GetIsAromatic()]
|
| 104 |
+
+
|
| 105 |
+
# Formal Charge
|
| 106 |
+
one_of_k_encoding(atom.GetFormalCharge(), formal_charge_list)
|
| 107 |
+
+
|
| 108 |
+
# Chirality (Geometry)
|
| 109 |
+
one_of_k_encoding(atom.GetChiralTag(), chirality_list)
|
| 110 |
+
+
|
| 111 |
+
# Is in ring (Boolean)
|
| 112 |
+
[atom.IsInRing()]
|
| 113 |
)
|
| 114 |
|
| 115 |
+
|
| 116 |
# Residue vocabulary, built once at import time instead of on every call
# (get_protein_features runs once per residue of every sequence).
# Index 0 is padding; 1-20 are the standard amino acids; 21 covers the
# ambiguity codes X/Z/B and anything unknown.
_PROT_VOCAB = {
    "A": 1,
    "R": 2,
    "N": 3,
    "D": 4,
    "C": 5,
    "Q": 6,
    "E": 7,
    "G": 8,
    "H": 9,
    "I": 10,
    "L": 11,
    "K": 12,
    "M": 13,
    "F": 14,
    "P": 15,
    "S": 16,
    "T": 17,
    "W": 18,
    "Y": 19,
    "V": 20,
    "X": 21,
    "Z": 21,
    "B": 21,
    "PAD": 0,
    "UNK": 21,
}


def get_protein_features(char):
    """
    Map a residue code (or the special keys "PAD" / "UNK") to its integer token.

    Args:
        char: Upper-case one-letter residue code, or "PAD" / "UNK".

    Returns:
        The integer token; any key not in the vocabulary (including
        lower-case letters) maps to the "UNK" token, 21.
    """
    return _PROT_VOCAB.get(char, _PROT_VOCAB["UNK"])
|
| 145 |
|
| 146 |
|
| 147 |
class BindingDataset(Dataset):
|
| 148 |
def __init__(self, dataframe, max_seq_length=1000):
|
| 149 |
self.data = dataframe
|
| 150 |
+
self.max_seq_length = (
|
| 151 |
+
max_seq_length # Define a maximum sequence length for padding/truncation
|
| 152 |
+
)
|
| 153 |
|
| 154 |
def __len__(self):
|
| 155 |
return len(self.data)
|
|
|
|
| 185 |
# Protein (Sequence, tensor of integers)
|
| 186 |
tokens = [get_protein_features(char) for char in sequence]
|
| 187 |
if len(tokens) > self.max_seq_length:
|
| 188 |
+
tokens = tokens[: self.max_seq_length]
|
| 189 |
else:
|
| 190 |
+
tokens.extend(
|
| 191 |
+
[get_protein_features("PAD")] * (self.max_seq_length - len(tokens))
|
| 192 |
+
)
|
| 193 |
protein_tensor = torch.tensor(tokens, dtype=torch.long)
|
| 194 |
|
| 195 |
# Affinity
|
|
|
|
| 207 |
|
| 208 |
print(len(train_dataset))
|
| 209 |
print(len(test_dataset))
|
|
|
inference.py
CHANGED
|
@@ -10,9 +10,11 @@ from model import BindingAffinityModel
|
|
| 10 |
from tqdm import tqdm
|
| 11 |
from scipy.stats import pearsonr
|
| 12 |
from torch.utils.data import random_split
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
|
| 14 |
-
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
-
MODEL_PATH = "best_model_gat.pth"
|
| 16 |
|
| 17 |
def set_seed(seed=42):
|
| 18 |
random.seed(seed)
|
|
@@ -21,11 +23,12 @@ def set_seed(seed=42):
|
|
| 21 |
np.random.seed(seed)
|
| 22 |
return torch.Generator().manual_seed(seed)
|
| 23 |
|
|
|
|
| 24 |
def predict_and_plot():
|
| 25 |
gen = set_seed(42)
|
| 26 |
print("Loading data...")
|
| 27 |
|
| 28 |
-
dataframe = pd.read_csv(
|
| 29 |
dataframe.dropna(inplace=True)
|
| 30 |
dataset = BindingDataset(dataframe)
|
| 31 |
if len(dataset) == 0:
|
|
@@ -40,7 +43,12 @@ def predict_and_plot():
|
|
| 40 |
num_features = test_dataset[0].x.shape[1]
|
| 41 |
|
| 42 |
print("Loading model...")
|
| 43 |
-
model = BindingAffinityModel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
model.load_state_dict(torch.load(MODEL_PATH))
|
| 45 |
model.eval()
|
| 46 |
|
|
@@ -67,19 +75,28 @@ def predict_and_plot():
|
|
| 67 |
print(f"Pearson Correlation: {pearson_corr:.4f}")
|
| 68 |
|
| 69 |
plt.figure(figsize=(9, 9))
|
| 70 |
-
plt.scatter(y_true, y_pred, alpha=0.4, s=15, c=
|
| 71 |
-
plt.plot(
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
-
plt.xlabel(
|
| 75 |
-
plt.ylabel(
|
| 76 |
-
plt.title(
|
|
|
|
|
|
|
| 77 |
plt.legend()
|
| 78 |
plt.grid(True, alpha=0.3)
|
| 79 |
-
plot_file =
|
| 80 |
plt.savefig(plot_file)
|
| 81 |
print(f"График сохранен в {plot_file}")
|
| 82 |
plt.show()
|
| 83 |
|
|
|
|
| 84 |
if __name__ == "__main__":
|
| 85 |
-
predict_and_plot()
|
|
|
|
| 10 |
from tqdm import tqdm
|
| 11 |
from scipy.stats import pearsonr
|
| 12 |
from torch.utils.data import random_split
|
| 13 |
+
from train import GAT_HEADS, DROPOUT, HIDDEN_CHANNELS
|
| 14 |
+
|
| 15 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
MODEL_PATH = "runs/experiment_20260122_230138_GAT_without_formal_charge_chirality_ring_scheduler/models/model_ep092_mse2.3805.pth"
|
| 17 |
|
|
|
|
|
|
|
| 18 |
|
| 19 |
def set_seed(seed=42):
|
| 20 |
random.seed(seed)
|
|
|
|
| 23 |
np.random.seed(seed)
|
| 24 |
return torch.Generator().manual_seed(seed)
|
| 25 |
|
| 26 |
+
|
| 27 |
def predict_and_plot():
|
| 28 |
gen = set_seed(42)
|
| 29 |
print("Loading data...")
|
| 30 |
|
| 31 |
+
dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
|
| 32 |
dataframe.dropna(inplace=True)
|
| 33 |
dataset = BindingDataset(dataframe)
|
| 34 |
if len(dataset) == 0:
|
|
|
|
| 43 |
num_features = test_dataset[0].x.shape[1]
|
| 44 |
|
| 45 |
print("Loading model...")
|
| 46 |
+
model = BindingAffinityModel(
|
| 47 |
+
num_node_features=num_features,
|
| 48 |
+
hidden_channels=HIDDEN_CHANNELS,
|
| 49 |
+
gat_heads=GAT_HEADS,
|
| 50 |
+
dropout=DROPOUT,
|
| 51 |
+
).to(DEVICE)
|
| 52 |
model.load_state_dict(torch.load(MODEL_PATH))
|
| 53 |
model.eval()
|
| 54 |
|
|
|
|
| 75 |
print(f"Pearson Correlation: {pearson_corr:.4f}")
|
| 76 |
|
| 77 |
plt.figure(figsize=(9, 9))
|
| 78 |
+
plt.scatter(y_true, y_pred, alpha=0.4, s=15, c="blue", label="Predictions")
|
| 79 |
+
plt.plot(
|
| 80 |
+
[min(y_true), max(y_true)],
|
| 81 |
+
[min(y_true), max(y_true)],
|
| 82 |
+
color="red",
|
| 83 |
+
linestyle="--",
|
| 84 |
+
linewidth=2,
|
| 85 |
+
label="Ideal",
|
| 86 |
+
)
|
| 87 |
|
| 88 |
+
plt.xlabel("Experimental Affinity (pK)")
|
| 89 |
+
plt.ylabel("Predicted Affinity (pK)")
|
| 90 |
+
plt.title(
|
| 91 |
+
f"Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}"
|
| 92 |
+
)
|
| 93 |
plt.legend()
|
| 94 |
plt.grid(True, alpha=0.3)
|
| 95 |
+
plot_file = "final_results_gat.png"
|
| 96 |
plt.savefig(plot_file)
|
| 97 |
print(f"График сохранен в {plot_file}")
|
| 98 |
plt.show()
|
| 99 |
|
| 100 |
+
|
| 101 |
if __name__ == "__main__":
|
| 102 |
+
predict_and_plot()
|
inference_attention.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
from torch_geometric.loader import DataLoader
|
| 8 |
+
from dataset import BindingDataset
|
| 9 |
+
from model_attention import BindingAffinityModel
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from scipy.stats import pearsonr
|
| 12 |
+
from torch.utils.data import random_split
|
| 13 |
+
from train_attention import GAT_HEADS, DROPOUT, HIDDEN_CHANNELS
|
| 14 |
+
|
| 15 |
+
# Device the model and each batch are moved to: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 16 |
+
# Checkpoint (state dict) that predict_and_plot loads into the model.
MODEL_PATH = "runs/experiment_attention20260123_103840_with_additional_data_scheduler/models/model_ep032_mse2.0264.pth"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def set_seed(seed=42):
    """
    Seed the Python, PyTorch (CPU and CUDA) and NumPy random generators.

    Args:
        seed: Integer seed applied to every generator.

    Returns:
        A new ``torch.Generator`` seeded with ``seed``.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    generator = torch.Generator()
    generator.manual_seed(seed)
    return generator
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def predict_and_plot():
    """
    Evaluate the saved attention model on the held-out split and plot the result.

    Loads ``pdbbind_refined_dataset.csv``, takes the 20% test part of a seeded
    80/20 random split, runs the checkpoint at ``MODEL_PATH`` over it, prints
    RMSE, MAE and Pearson correlation, and saves/shows a scatter plot of
    predicted against experimental affinity. Returns early (after printing a
    message) when the dataset is empty.
    """
    gen = set_seed(42)
    print("Loading data...")

    dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
    dataframe.dropna(inplace=True)
    dataset = BindingDataset(dataframe)
    if len(dataset) == 0:
        print("Dataset is empty")
        return

    train_size = int(0.8 * len(dataset))
    test_size = len(dataset) - train_size
    _, test_dataset = random_split(dataset, [train_size, test_size], generator=gen)

    loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
    num_features = test_dataset[0].x.shape[1]

    print("Loading model...")
    model = BindingAffinityModel(
        num_node_features=num_features,
        hidden_channels=HIDDEN_CHANNELS,
        gat_heads=GAT_HEADS,
        dropout=DROPOUT,
    ).to(DEVICE)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
    model.eval()

    y_true = []
    y_pred = []
    print("Predicting...")
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(DEVICE)
            out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)

            # Flatten explicitly: squeeze() would turn a final batch of one
            # sample into a 0-d array that cannot be extended from, and a
            # trailing singleton axis on the targets would broadcast wrongly
            # in the metrics below.
            y_true.extend(batch.y.reshape(-1).cpu().numpy())
            y_pred.extend(out.reshape(-1).cpu().numpy())
    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    rmse = np.sqrt(np.mean((y_true - y_pred) ** 2))
    mae = np.mean(np.abs(y_true - y_pred))
    pearson_corr, _ = pearsonr(y_true, y_pred)  # Pearson correlation

    print("Results:")
    print(f"RMSE: {rmse:.4f}")
    print(f"MAE: {mae:.4f}")
    print(f"Pearson Correlation: {pearson_corr:.4f}")

    plt.figure(figsize=(9, 9))
    plt.scatter(y_true, y_pred, alpha=0.4, s=15, c="blue", label="Predictions")
    # Diagonal y = x over the range of the true values: where perfect
    # predictions would lie.
    low, high = y_true.min(), y_true.max()
    plt.plot(
        [low, high],
        [low, high],
        color="red",
        linestyle="--",
        linewidth=2,
        label="Ideal",
    )

    plt.xlabel("Experimental Affinity (pK)")
    plt.ylabel("Predicted Affinity (pK)")
    plt.title(
        f"Binding affinity Results\nRMSE={rmse:.3f}, Pearson R={pearson_corr:.3f}"
    )
    plt.legend()
    plt.grid(True, alpha=0.3)
    # NOTE(review): inference.py saves its plot under this same file name, so
    # running both scripts overwrites one plot with the other - confirm intent.
    plot_file = "final_results_gat.png"
    plt.savefig(plot_file)
    print(f"График сохранен в {plot_file}")
    plt.show()
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
predict_and_plot()
|
main.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import uuid
|
| 3 |
+
|
| 4 |
+
from fastapi import FastAPI, Request, Form
|
| 5 |
+
from fastapi.templating import Jinja2Templates
|
| 6 |
+
from fastapi.staticfiles import StaticFiles
|
| 7 |
+
from fastapi.responses import HTMLResponse
|
| 8 |
+
from utils import get_inference_data, get_py3dmol_view,save_standalone_ngl_html
|
| 9 |
+
import nglview as nv
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
app = FastAPI()
|
| 13 |
+
|
| 14 |
+
# Make sure the output folder exists before it is mounted as a static directory below.
os.makedirs("html_results", exist_ok=True)
|
| 15 |
+
# Serve the files written to html_results/ under the /results URL prefix.
app.mount("/results", StaticFiles(directory="html_results"), name="results")
|
| 16 |
+
# Jinja2 templates (index.html) are loaded from the templates/ directory.
templates = Jinja2Templates(directory="templates")
|
| 17 |
+
|
| 18 |
+
@app.get("/", response_class=HTMLResponse)
async def read_root(request: Request):
    """Render ``index.html`` with only the request in the template context."""
    context = {"request": request}
    return templates.TemplateResponse("index.html", context)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def _importance_icon(val):
    """Return the emoji shown beside an atom for its importance score."""
    if val >= 0.9:
        return "🔥"
    if val >= 0.7:
        return "✨"
    if val >= 0.5:
        return "⭐"
    return ""


@app.post("/predict", response_class=HTMLResponse)
async def predict(
    request: Request,
    smiles_ligand: str = Form(...),
    sequence_protein: str = Form(...)
):
    """
    Run inference for the submitted ligand/protein pair and render the result.

    Args:
        request: Incoming request, forwarded to the template.
        smiles_ligand: Ligand SMILES from the form.
        sequence_protein: Protein sequence from the form.

    Returns:
        ``index.html`` rendered with the predicted affinity, the 15
        highest-scoring atoms, the inline py3Dmol HTML and the URL of the
        NGL page saved under ``html_results``.
    """
    mol, importance, affinity = get_inference_data(smiles_ligand, sequence_protein)

    # Atom indices ordered from most to least important; only the first 15
    # are listed on the page.
    ranked = sorted(
        range(len(importance)),
        key=lambda atom_idx: importance[atom_idx],
        reverse=True,
    )
    atom_list = [
        {
            "id": atom_idx,
            "symbol": mol.GetAtomWithIdx(atom_idx).GetSymbol(),
            "score": f"{importance[atom_idx]:.3f}",
            "icon": _importance_icon(importance[atom_idx]),
        }
        for atom_idx in ranked[:15]
    ]

    # A random file name keeps the results of separate requests apart.
    unique_id = str(uuid.uuid4())
    filename_ngl = f"ngl_{unique_id}.html"
    filepath_ngl = os.path.join("html_results", filename_ngl)

    py3dmol_content = get_py3dmol_view(mol, importance)._make_html()

    save_standalone_ngl_html(mol, importance, filepath_ngl)
    ngl_url_link = f"/results/{filename_ngl}"

    context = {
        "request": request,
        "result_ready": True,
        "smiles": smiles_ligand,
        "protein": sequence_protein,
        "affinity": f"{affinity:.2f}",
        "atom_list": atom_list,
        "html_py3dmol": py3dmol_content,
        "url_ngl": ngl_url_link
    }
    return templates.TemplateResponse("index.html", context)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
|
model.py
CHANGED
|
@@ -1,11 +1,9 @@
|
|
| 1 |
import math
|
| 2 |
-
|
| 3 |
import torch
|
| 4 |
import torch.nn as nn
|
| 5 |
-
|
| 6 |
-
|
| 7 |
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
|
| 8 |
|
|
|
|
| 9 |
class PositionalEncoding(nn.Module):
|
| 10 |
def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
|
| 11 |
super().__init__()
|
|
@@ -34,11 +32,10 @@ class PositionalEncoding(nn.Module):
|
|
| 34 |
|
| 35 |
def forward(self, x):
|
| 36 |
# x: [batch_size, seq_len, d_model]
|
| 37 |
-
x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False)
|
| 38 |
return self.dropout(x)
|
| 39 |
|
| 40 |
|
| 41 |
-
|
| 42 |
# class LigandGNN(nn.Module): # GCN CONV
|
| 43 |
# def __init__(self, input_dim, hidden_channels):
|
| 44 |
# super().__init__()
|
|
@@ -70,8 +67,12 @@ class LigandGNN(nn.Module):
|
|
| 70 |
# Heads=4 means we use 4 attention heads
|
| 71 |
# Concat=False, we average the heads instead of concatenating them, to keep the output dimension same as hidden_channels
|
| 72 |
self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
|
| 73 |
-
self.conv2 = GATConv(
|
| 74 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
self.dropout = nn.Dropout(dropout)
|
| 76 |
|
| 77 |
def forward(self, x, edge_index, batch):
|
|
@@ -89,20 +90,23 @@ class LigandGNN(nn.Module):
|
|
| 89 |
x = global_mean_pool(x, batch)
|
| 90 |
return x
|
| 91 |
|
|
|
|
| 92 |
class ProteinTransformer(nn.Module):
|
| 93 |
def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
|
| 94 |
super().__init__()
|
| 95 |
self.d_model = d_model
|
| 96 |
self.embedding = nn.Embedding(vocab_size, d_model)
|
| 97 |
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
|
| 98 |
-
encoder_layer = nn.TransformerEncoderLayer(
|
|
|
|
|
|
|
| 99 |
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
|
| 100 |
|
| 101 |
self.fc = nn.Linear(d_model, output_dim)
|
| 102 |
|
| 103 |
def forward(self, x):
|
| 104 |
# x: [batch_size, seq_len]
|
| 105 |
-
padding_mask =
|
| 106 |
x = self.embedding(x) * math.sqrt(self.d_model)
|
| 107 |
x = self.pos_encoder(x)
|
| 108 |
x = self.transformer(x, src_key_padding_mask=padding_mask)
|
|
@@ -116,20 +120,34 @@ class ProteinTransformer(nn.Module):
|
|
| 116 |
x = self.fc(x)
|
| 117 |
return x
|
| 118 |
|
|
|
|
| 119 |
class BindingAffinityModel(nn.Module):
|
| 120 |
-
def __init__(
|
|
|
|
|
|
|
| 121 |
super().__init__()
|
| 122 |
# Tower 1 - Ligand GNN
|
| 123 |
-
self.ligand_gnn = LigandGNN(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
# Tower 2 - Protein Transformer
|
| 125 |
-
self.protein_transformer = ProteinTransformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
|
| 127 |
self.head = nn.Sequential(
|
| 128 |
-
nn.Linear(hidden_channels*2, hidden_channels),
|
| 129 |
nn.ReLU(),
|
| 130 |
nn.Dropout(dropout),
|
| 131 |
nn.Linear(hidden_channels, 1),
|
| 132 |
)
|
|
|
|
| 133 |
def forward(self, x, edge_index, batch, protein_seq):
|
| 134 |
ligand_vec = self.ligand_gnn(x, edge_index, batch)
|
| 135 |
batch_size = batch.max().item() + 1
|
|
@@ -138,5 +156,3 @@ class BindingAffinityModel(nn.Module):
|
|
| 138 |
protein_vec = self.protein_transformer(protein_seq)
|
| 139 |
combined = torch.cat([ligand_vec, protein_vec], dim=1)
|
| 140 |
return self.head(combined)
|
| 141 |
-
|
| 142 |
-
|
|
|
|
| 1 |
import math
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
|
|
|
|
|
|
| 4 |
from torch_geometric.nn import GCNConv, GATConv, global_mean_pool
|
| 5 |
|
| 6 |
+
|
| 7 |
class PositionalEncoding(nn.Module):
|
| 8 |
def __init__(self, d_model: int, seq_len: int = 5000, dropout: float = 0.1):
|
| 9 |
super().__init__()
|
|
|
|
| 32 |
|
| 33 |
def forward(self, x):
|
| 34 |
# x: [batch_size, seq_len, d_model]
|
| 35 |
+
x = x + (self.pe[:, : x.shape[1], :]).requires_grad_(False)
|
| 36 |
return self.dropout(x)
|
| 37 |
|
| 38 |
|
|
|
|
| 39 |
# class LigandGNN(nn.Module): # GCN CONV
|
| 40 |
# def __init__(self, input_dim, hidden_channels):
|
| 41 |
# super().__init__()
|
|
|
|
| 67 |
# Heads=4 means we use 4 attention heads
|
| 68 |
# Concat=False, we average the heads instead of concatenating them, to keep the output dimension same as hidden_channels
|
| 69 |
self.conv1 = GATConv(input_dim, hidden_channels, heads=heads, concat=False)
|
| 70 |
+
self.conv2 = GATConv(
|
| 71 |
+
hidden_channels, hidden_channels, heads=heads, concat=False
|
| 72 |
+
)
|
| 73 |
+
self.conv3 = GATConv(
|
| 74 |
+
hidden_channels, hidden_channels, heads=heads, concat=False
|
| 75 |
+
)
|
| 76 |
self.dropout = nn.Dropout(dropout)
|
| 77 |
|
| 78 |
def forward(self, x, edge_index, batch):
|
|
|
|
| 90 |
x = global_mean_pool(x, batch)
|
| 91 |
return x
|
| 92 |
|
| 93 |
+
|
| 94 |
class ProteinTransformer(nn.Module):
|
| 95 |
def __init__(self, vocab_size, d_model=128, N=2, h=4, output_dim=128, dropout=0.2):
|
| 96 |
super().__init__()
|
| 97 |
self.d_model = d_model
|
| 98 |
self.embedding = nn.Embedding(vocab_size, d_model)
|
| 99 |
self.pos_encoder = PositionalEncoding(d_model, dropout=dropout)
|
| 100 |
+
encoder_layer = nn.TransformerEncoderLayer(
|
| 101 |
+
d_model=d_model, nhead=h, batch_first=True
|
| 102 |
+
)
|
| 103 |
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=N)
|
| 104 |
|
| 105 |
self.fc = nn.Linear(d_model, output_dim)
|
| 106 |
|
| 107 |
def forward(self, x):
|
| 108 |
# x: [batch_size, seq_len]
|
| 109 |
+
padding_mask = x == 0 # mask for PAD tokens
|
| 110 |
x = self.embedding(x) * math.sqrt(self.d_model)
|
| 111 |
x = self.pos_encoder(x)
|
| 112 |
x = self.transformer(x, src_key_padding_mask=padding_mask)
|
|
|
|
| 120 |
x = self.fc(x)
|
| 121 |
return x
|
| 122 |
|
| 123 |
+
|
| 124 |
class BindingAffinityModel(nn.Module):
|
| 125 |
+
def __init__(
|
| 126 |
+
self, num_node_features, hidden_channels=128, gat_heads=4, dropout=0.2
|
| 127 |
+
):
|
| 128 |
super().__init__()
|
| 129 |
# Tower 1 - Ligand GNN
|
| 130 |
+
self.ligand_gnn = LigandGNN(
|
| 131 |
+
input_dim=num_node_features,
|
| 132 |
+
hidden_channels=hidden_channels,
|
| 133 |
+
heads=gat_heads,
|
| 134 |
+
dropout=dropout,
|
| 135 |
+
)
|
| 136 |
# Tower 2 - Protein Transformer
|
| 137 |
+
self.protein_transformer = ProteinTransformer(
|
| 138 |
+
vocab_size=26,
|
| 139 |
+
d_model=hidden_channels,
|
| 140 |
+
output_dim=hidden_channels,
|
| 141 |
+
dropout=dropout,
|
| 142 |
+
)
|
| 143 |
|
| 144 |
self.head = nn.Sequential(
|
| 145 |
+
nn.Linear(hidden_channels * 2, hidden_channels),
|
| 146 |
nn.ReLU(),
|
| 147 |
nn.Dropout(dropout),
|
| 148 |
nn.Linear(hidden_channels, 1),
|
| 149 |
)
|
| 150 |
+
|
| 151 |
def forward(self, x, edge_index, batch, protein_seq):
|
| 152 |
ligand_vec = self.ligand_gnn(x, edge_index, batch)
|
| 153 |
batch_size = batch.max().item() + 1
|
|
|
|
| 156 |
protein_vec = self.protein_transformer(protein_seq)
|
| 157 |
combined = torch.cat([ligand_vec, protein_vec], dim=1)
|
| 158 |
return self.head(combined)
|
|
|
|
|
|
model_attention.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from torch_geometric.nn import GATConv
|
| 4 |
+
from torch_geometric.utils import to_dense_batch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CrossAttentionLayer(nn.Module):
|
| 9 |
+
def __init__(self, feature_dim, num_heads=4, dropout=0.1):
|
| 10 |
+
super().__init__()
|
| 11 |
+
# Main attention layer
|
| 12 |
+
# Feature dim is the dimension of the hidden features
|
| 13 |
+
self.attention = nn.MultiheadAttention(
|
| 14 |
+
feature_dim, num_heads, dropout=dropout, batch_first=True
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Normalization layer for stabilizing training
|
| 18 |
+
self.norm = nn.LayerNorm(feature_dim)
|
| 19 |
+
|
| 20 |
+
# Feedforward network for further processing, classical transformer style
|
| 21 |
+
self.ff = nn.Sequential(
|
| 22 |
+
nn.Linear(feature_dim, feature_dim * 4),
|
| 23 |
+
nn.GELU(),
|
| 24 |
+
nn.Dropout(dropout),
|
| 25 |
+
nn.Linear(feature_dim * 4, feature_dim),
|
| 26 |
+
)
|
| 27 |
+
self.norm_ff = nn.LayerNorm(feature_dim)
|
| 28 |
+
self.last_attention_weights = None
|
| 29 |
+
|
| 30 |
+
def forward(self, ligand_features, protein_features, key_padding_mask=None):
|
| 31 |
+
# ligand_features: [Batch, Atoms, Dim] - atoms
|
| 32 |
+
# protein_features: [Batch, Residues, Dim] - amino acids
|
| 33 |
+
# Cross attention:
|
| 34 |
+
# Query = Ligand (What we want to find out)
|
| 35 |
+
# Key, Value = Protein (Where we look for information)
|
| 36 |
+
# Result: "Ligand enriched with knowledge about proteins"
|
| 37 |
+
attention_output, attn_weights = self.attention(
|
| 38 |
+
query=ligand_features,
|
| 39 |
+
key=protein_features,
|
| 40 |
+
value=protein_features,
|
| 41 |
+
key_padding_mask=key_padding_mask,
|
| 42 |
+
need_weights=True,
|
| 43 |
+
average_attn_weights=True,
|
| 44 |
+
)
|
| 45 |
+
self.last_attention_weights = attn_weights.detach().cpu()
|
| 46 |
+
|
| 47 |
+
# Residual connection (x + attention(x)) and normalization
|
| 48 |
+
ligand_features = self.norm(ligand_features + attention_output)
|
| 49 |
+
|
| 50 |
+
# Feedforward network with residual connection and normalization
|
| 51 |
+
ff_output = self.ff(ligand_features)
|
| 52 |
+
ligand_features = self.norm_ff(ligand_features + ff_output)
|
| 53 |
+
|
| 54 |
+
return ligand_features
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class BindingAffinityModel(nn.Module):
|
| 58 |
+
def __init__(
|
| 59 |
+
self, num_node_features, hidden_channels=256, gat_heads=2, dropout=0.3
|
| 60 |
+
):
|
| 61 |
+
super().__init__()
|
| 62 |
+
self.dropout = dropout
|
| 63 |
+
self.hidden_channels = hidden_channels
|
| 64 |
+
|
| 65 |
+
# Tower 1 - Ligand GNN with GAT layers, using 3 GAT layers, so that every atom can "see" up to 3 bonds away
|
| 66 |
+
self.gat1 = GATConv(
|
| 67 |
+
num_node_features, hidden_channels, heads=gat_heads, concat=False
|
| 68 |
+
)
|
| 69 |
+
self.gat2 = GATConv(
|
| 70 |
+
hidden_channels, hidden_channels, heads=gat_heads, concat=False
|
| 71 |
+
)
|
| 72 |
+
self.gat3 = GATConv(
|
| 73 |
+
hidden_channels, hidden_channels, heads=gat_heads, concat=False
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
# Tower 2 - Protein Transformer, 22 = 21 amino acids + 1 padding token PAD
|
| 77 |
+
self.protein_embedding = nn.Embedding(22, hidden_channels)
|
| 78 |
+
# Additional positional encoding (simple linear) to give the model information about the order
|
| 79 |
+
self.prot_conv = nn.Conv1d(
|
| 80 |
+
hidden_channels, hidden_channels, kernel_size=3, padding=1
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
# Cross-Attention Layer, atoms attending to amino acids
|
| 84 |
+
self.cross_attention = CrossAttentionLayer(
|
| 85 |
+
feature_dim=hidden_channels, num_heads=4, dropout=dropout
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
self.fc1 = nn.Linear(hidden_channels, hidden_channels)
|
| 89 |
+
self.fc2 = nn.Linear(hidden_channels, 1) # Final output for regression, pKd
|
| 90 |
+
|
| 91 |
+
def forward(self, x, edge_index, batch, protein_seq):
|
| 92 |
+
# Ligand GNN forward pass (Graph -> Node Embeddings)
|
| 93 |
+
x = F.elu(self.gat1(x, edge_index))
|
| 94 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 95 |
+
|
| 96 |
+
x = F.elu(self.gat2(x, edge_index))
|
| 97 |
+
x = F.dropout(x, p=self.dropout, training=self.training)
|
| 98 |
+
|
| 99 |
+
x = F.elu(self.gat3(x, edge_index)) # [Total_Atoms, Hidden_Channels]
|
| 100 |
+
|
| 101 |
+
# Convert graph into tensor [Batch, Max_Atoms, Hidden_Channels]
|
| 102 |
+
# to_dense_batch adds zeros paddings where necessary
|
| 103 |
+
ligand_dense, ligand_mask = to_dense_batch(x, batch)
|
| 104 |
+
# ligand_dense: [Batch, Max_Atoms, Hidden_Channels]
|
| 105 |
+
# ligand_mask: [Batch, Max_Atoms] True where there is real atom, False where there is padding
|
| 106 |
+
|
| 107 |
+
batch_size = ligand_dense.size(0)
|
| 108 |
+
protein_seq = protein_seq.view(batch_size, -1) # [Batch, Seq_Len]
|
| 109 |
+
|
| 110 |
+
# Protein forward pass protein_seq: [Batch, Seq_Len]
|
| 111 |
+
p = self.protein_embedding(protein_seq) # [Batch, Seq_Len, Hidden_Channels]
|
| 112 |
+
|
| 113 |
+
# A simple convolution to understand local context in amino acids
|
| 114 |
+
p = p.permute(0, 2, 1) # Change to [Batch, Hidden_Channels, Seq_Len] for Conv1d
|
| 115 |
+
p = F.relu(self.prot_conv(p))
|
| 116 |
+
p = p.permute(0, 2, 1) # [Batch, Seq, Hidden_Channels]
|
| 117 |
+
|
| 118 |
+
# Mask for protein (where PAD=0, True, but MHA needs True where IGNOREME)
|
| 119 |
+
# In Pytorch MHA, the key_padding_mask should be True where we want to ignore
|
| 120 |
+
protein_pad_mask = protein_seq == 0
|
| 121 |
+
|
| 122 |
+
# Cross-Attention
|
| 123 |
+
x_cross = self.cross_attention(
|
| 124 |
+
ligand_dense, p, key_padding_mask=protein_pad_mask
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
# Pooling over atoms to get a single vector per molecule, considering only real atoms, ignoring paddings
|
| 128 |
+
# ligand mask True where real atom, False where padding
|
| 129 |
+
mask_expanded = ligand_mask.unsqueeze(-1) # [Batch, Max_Atoms, 1]
|
| 130 |
+
|
| 131 |
+
# Zero out the padded atom features
|
| 132 |
+
x_cross = x_cross * mask_expanded
|
| 133 |
+
|
| 134 |
+
# Sum the features of real atoms / number of real atoms to get the mean
|
| 135 |
+
sum_features = torch.sum(x_cross, dim=1) # [Batch, Hidden_Channels]
|
| 136 |
+
num_atoms = torch.sum(mask_expanded, dim=1) # [Batch, 1]
|
| 137 |
+
pooled_x = sum_features / (num_atoms + 1e-6) # Avoid division by zero
|
| 138 |
+
|
| 139 |
+
# MLP Head
|
| 140 |
+
out = F.relu(self.fc1(pooled_x))
|
| 141 |
+
out = F.dropout(out, p=self.dropout, training=self.training)
|
| 142 |
+
out = self.fc2(out)
|
| 143 |
+
return out
|
model_pl.py
CHANGED
|
@@ -8,19 +8,19 @@ from torch.optim import Adam
|
|
| 8 |
|
| 9 |
from model import LigandGNN, ProteinTransformer
|
| 10 |
|
|
|
|
| 11 |
class BindingAffinityModelPL(pl.LightningModule):
|
| 12 |
def __init__(self, num_node_features, hidden_channels_gnn, lr):
|
| 13 |
super().__init__()
|
| 14 |
-
self.save_hyperparameters()
|
| 15 |
self.lr = lr
|
| 16 |
|
| 17 |
-
self.ligand_gnn = LigandGNN(
|
|
|
|
|
|
|
| 18 |
self.protein_transformer = ProteinTransformer(vocab_size=26)
|
| 19 |
self.head = nn.Sequential(
|
| 20 |
-
nn.Linear(128 + 128, 256),
|
| 21 |
-
nn.ReLU(),
|
| 22 |
-
nn.Dropout(0.2),
|
| 23 |
-
nn.Linear(256, 1)
|
| 24 |
)
|
| 25 |
self.criterion = nn.MSELoss()
|
| 26 |
|
|
|
|
| 8 |
|
| 9 |
from model import LigandGNN, ProteinTransformer
|
| 10 |
|
| 11 |
+
|
| 12 |
class BindingAffinityModelPL(pl.LightningModule):
|
| 13 |
def __init__(self, num_node_features, hidden_channels_gnn, lr):
|
| 14 |
super().__init__()
|
| 15 |
+
self.save_hyperparameters() # Save hyperparameters for easy access
|
| 16 |
self.lr = lr
|
| 17 |
|
| 18 |
+
self.ligand_gnn = LigandGNN(
|
| 19 |
+
input_dim=num_node_features, hidden_channels=hidden_channels_gnn
|
| 20 |
+
)
|
| 21 |
self.protein_transformer = ProteinTransformer(vocab_size=26)
|
| 22 |
self.head = nn.Sequential(
|
| 23 |
+
nn.Linear(128 + 128, 256), nn.ReLU(), nn.Dropout(0.2), nn.Linear(256, 1)
|
|
|
|
|
|
|
|
|
|
| 24 |
)
|
| 25 |
self.criterion = nn.MSELoss()
|
| 26 |
|
optuna_train.py
CHANGED
|
@@ -9,10 +9,11 @@ from torch.utils.data import random_split
|
|
| 9 |
from dataset import BindingDataset
|
| 10 |
from model import BindingAffinityModel
|
| 11 |
|
| 12 |
-
DEVICE = torch.device(
|
| 13 |
N_TRIALS = 20
|
| 14 |
EPOCHS_PER_TRIAL = 15
|
| 15 |
|
|
|
|
| 16 |
def set_seed(seed=42):
|
| 17 |
random.seed(seed)
|
| 18 |
np.random.seed(seed)
|
|
@@ -20,7 +21,8 @@ def set_seed(seed=42):
|
|
| 20 |
torch.cuda.manual_seed(seed)
|
| 21 |
return torch.Generator().manual_seed(seed)
|
| 22 |
|
| 23 |
-
|
|
|
|
| 24 |
dataframe.dropna(inplace=True)
|
| 25 |
dataset = BindingDataset(dataframe)
|
| 26 |
|
|
@@ -28,9 +30,12 @@ gen = set_seed(42)
|
|
| 28 |
|
| 29 |
train_size = int(0.8 * len(dataset))
|
| 30 |
test_size = len(dataset) - train_size
|
| 31 |
-
train_dataset, test_dataset = random_split(
|
|
|
|
|
|
|
| 32 |
num_features = train_dataset[0].x.shape[1]
|
| 33 |
|
|
|
|
| 34 |
def train(model, loader, optimizer, criterion):
|
| 35 |
model.train()
|
| 36 |
for batch in loader:
|
|
@@ -62,28 +67,50 @@ def objective(trial):
|
|
| 62 |
|
| 63 |
# Learning
|
| 64 |
|
| 65 |
-
lr = trial.suggest_float(
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
|
| 68 |
|
| 69 |
-
model = BindingAffinityModel(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
criterion = nn.MSELoss()
|
| 73 |
|
| 74 |
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 75 |
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
| 76 |
|
|
|
|
|
|
|
| 77 |
for epoch in range(EPOCHS_PER_TRIAL):
|
| 78 |
train(model, train_loader, optimizer, criterion)
|
| 79 |
val_loss = test(model, test_loader, criterion)
|
| 80 |
|
| 81 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
trial.report(val_loss, epoch)
|
| 84 |
if trial.should_prune():
|
| 85 |
raise optuna.exceptions.TrialPruned()
|
| 86 |
-
return
|
| 87 |
|
| 88 |
|
| 89 |
if __name__ == "__main__":
|
|
@@ -93,7 +120,7 @@ if __name__ == "__main__":
|
|
| 93 |
pruner=optuna.pruners.MedianPruner(),
|
| 94 |
storage=storage_name,
|
| 95 |
study_name="binding_prediction_optimization",
|
| 96 |
-
load_if_exists=True
|
| 97 |
)
|
| 98 |
print("Start hyperparameter optimization...")
|
| 99 |
|
|
|
|
| 9 |
from dataset import BindingDataset
|
| 10 |
from model import BindingAffinityModel
|
| 11 |
|
| 12 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
N_TRIALS = 20
|
| 14 |
EPOCHS_PER_TRIAL = 15
|
| 15 |
|
| 16 |
+
|
| 17 |
def set_seed(seed=42):
|
| 18 |
random.seed(seed)
|
| 19 |
np.random.seed(seed)
|
|
|
|
| 21 |
torch.cuda.manual_seed(seed)
|
| 22 |
return torch.Generator().manual_seed(seed)
|
| 23 |
|
| 24 |
+
|
| 25 |
+
dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
|
| 26 |
dataframe.dropna(inplace=True)
|
| 27 |
dataset = BindingDataset(dataframe)
|
| 28 |
|
|
|
|
| 30 |
|
| 31 |
train_size = int(0.8 * len(dataset))
|
| 32 |
test_size = len(dataset) - train_size
|
| 33 |
+
train_dataset, test_dataset = random_split(
|
| 34 |
+
dataset, [train_size, test_size], generator=gen
|
| 35 |
+
)
|
| 36 |
num_features = train_dataset[0].x.shape[1]
|
| 37 |
|
| 38 |
+
|
| 39 |
def train(model, loader, optimizer, criterion):
|
| 40 |
model.train()
|
| 41 |
for batch in loader:
|
|
|
|
| 67 |
|
| 68 |
# Learning
|
| 69 |
|
| 70 |
+
lr = trial.suggest_float(
|
| 71 |
+
"lr", 1e-5, 1e-2, log=True
|
| 72 |
+
) # Learning rate from 0.00001 to 0.01
|
| 73 |
+
weight_decay = trial.suggest_float(
|
| 74 |
+
"weight_decay", 1e-6, 1e-3, log=True
|
| 75 |
+
) # Weight decay from 0.000001 to 0.001
|
| 76 |
batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
|
| 77 |
|
| 78 |
+
model = BindingAffinityModel(
|
| 79 |
+
num_node_features=num_features,
|
| 80 |
+
hidden_channels=hidden_dim,
|
| 81 |
+
gat_heads=gat_heads,
|
| 82 |
+
dropout=dropout,
|
| 83 |
+
).to(DEVICE)
|
| 84 |
|
| 85 |
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
| 86 |
+
|
| 87 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 88 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 89 |
+
)
|
| 90 |
criterion = nn.MSELoss()
|
| 91 |
|
| 92 |
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 93 |
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
| 94 |
|
| 95 |
+
best_val_loss = float("inf")
|
| 96 |
+
|
| 97 |
for epoch in range(EPOCHS_PER_TRIAL):
|
| 98 |
train(model, train_loader, optimizer, criterion)
|
| 99 |
val_loss = test(model, test_loader, criterion)
|
| 100 |
|
| 101 |
+
if val_loss < best_val_loss:
|
| 102 |
+
best_val_loss = val_loss
|
| 103 |
+
|
| 104 |
+
scheduler.step(val_loss)
|
| 105 |
+
|
| 106 |
+
print(
|
| 107 |
+
f"Trial {trial.number} | Epoch {epoch + 1}/{EPOCHS_PER_TRIAL} | Val Loss: {val_loss:.4f}"
|
| 108 |
+
)
|
| 109 |
|
| 110 |
trial.report(val_loss, epoch)
|
| 111 |
if trial.should_prune():
|
| 112 |
raise optuna.exceptions.TrialPruned()
|
| 113 |
+
return best_val_loss
|
| 114 |
|
| 115 |
|
| 116 |
if __name__ == "__main__":
|
|
|
|
| 120 |
pruner=optuna.pruners.MedianPruner(),
|
| 121 |
storage=storage_name,
|
| 122 |
study_name="binding_prediction_optimization",
|
| 123 |
+
load_if_exists=True,
|
| 124 |
)
|
| 125 |
print("Start hyperparameter optimization...")
|
| 126 |
|
optuna_train_attention.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import optuna
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
from torch_geometric.loader import DataLoader
|
| 8 |
+
from torch.utils.data import random_split
|
| 9 |
+
from dataset import BindingDataset
|
| 10 |
+
from model_attention import BindingAffinityModel
|
| 11 |
+
|
| 12 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
+
N_TRIALS = 50
|
| 14 |
+
MAX_EPOCHS_PER_TRIAL = 60
|
| 15 |
+
LOG_DIR = "runs"
|
| 16 |
+
DATA_CSV = "pdbbind_refined_dataset.csv"
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def set_seed(seed=42):
|
| 20 |
+
random.seed(seed)
|
| 21 |
+
np.random.seed(seed)
|
| 22 |
+
torch.manual_seed(seed)
|
| 23 |
+
torch.cuda.manual_seed(seed)
|
| 24 |
+
return torch.Generator().manual_seed(seed)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
dataframe = pd.read_csv(DATA_CSV)
|
| 28 |
+
dataframe.dropna(inplace=True)
|
| 29 |
+
dataset = BindingDataset(dataframe, max_seq_length=1200)
|
| 30 |
+
|
| 31 |
+
gen = set_seed(42)
|
| 32 |
+
|
| 33 |
+
train_size = int(0.8 * len(dataset))
|
| 34 |
+
test_size = len(dataset) - train_size
|
| 35 |
+
train_dataset, test_dataset = random_split(
|
| 36 |
+
dataset, [train_size, test_size], generator=gen
|
| 37 |
+
)
|
| 38 |
+
num_features = train_dataset[0].x.shape[1]
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def train(model, loader, optimizer, criterion):
|
| 42 |
+
model.train()
|
| 43 |
+
for batch in loader:
|
| 44 |
+
batch = batch.to(DEVICE)
|
| 45 |
+
optimizer.zero_grad()
|
| 46 |
+
out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
|
| 47 |
+
loss = criterion(out.squeeze(), batch.y.squeeze())
|
| 48 |
+
loss.backward()
|
| 49 |
+
optimizer.step()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def test(model, loader, criterion):
|
| 53 |
+
model.eval()
|
| 54 |
+
total_loss = 0
|
| 55 |
+
with torch.no_grad():
|
| 56 |
+
for batch in loader:
|
| 57 |
+
batch = batch.to(DEVICE)
|
| 58 |
+
out = model(batch.x, batch.edge_index, batch.batch, batch.protein_seq)
|
| 59 |
+
loss = criterion(out.squeeze(), batch.y.squeeze())
|
| 60 |
+
total_loss += loss.item()
|
| 61 |
+
return total_loss / len(loader)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def objective(trial):
|
| 65 |
+
# Architecture
|
| 66 |
+
hidden_dim = trial.suggest_categorical("hidden_dim", [128, 256])
|
| 67 |
+
gat_heads = trial.suggest_categorical("gat_heads", [2, 4])
|
| 68 |
+
dropout = trial.suggest_float("dropout", 0.2, 0.5)
|
| 69 |
+
|
| 70 |
+
# Learning
|
| 71 |
+
|
| 72 |
+
lr = trial.suggest_float(
|
| 73 |
+
"lr", 1e-5, 1e-3, log=True
|
| 74 |
+
) # Learning rate from 0.00001 to 0.001
|
| 75 |
+
weight_decay = trial.suggest_float(
|
| 76 |
+
"weight_decay", 1e-6, 1e-3, log=True
|
| 77 |
+
) # Weight decay from 0.000001 to 0.001
|
| 78 |
+
batch_size = 16
|
| 79 |
+
|
| 80 |
+
model = BindingAffinityModel(
|
| 81 |
+
num_node_features=num_features,
|
| 82 |
+
hidden_channels=hidden_dim,
|
| 83 |
+
gat_heads=gat_heads,
|
| 84 |
+
dropout=dropout,
|
| 85 |
+
).to(DEVICE)
|
| 86 |
+
|
| 87 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
|
| 88 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 89 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 90 |
+
)
|
| 91 |
+
criterion = nn.MSELoss()
|
| 92 |
+
|
| 93 |
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
|
| 94 |
+
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
|
| 95 |
+
|
| 96 |
+
best_val_loss = float("inf")
|
| 97 |
+
|
| 98 |
+
for epoch in range(MAX_EPOCHS_PER_TRIAL):
|
| 99 |
+
train(model, train_loader, optimizer, criterion)
|
| 100 |
+
val_loss = test(model, test_loader, criterion)
|
| 101 |
+
scheduler.step(val_loss)
|
| 102 |
+
|
| 103 |
+
if val_loss < best_val_loss:
|
| 104 |
+
best_val_loss = val_loss
|
| 105 |
+
print(
|
| 106 |
+
f"Trial {trial.number} | Epoch {epoch + 1}/{MAX_EPOCHS_PER_TRIAL} | Val Loss: {val_loss:.4f}"
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
trial.report(val_loss, epoch)
|
| 110 |
+
if trial.should_prune():
|
| 111 |
+
raise optuna.exceptions.TrialPruned()
|
| 112 |
+
return best_val_loss
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
|
| 116 |
+
storage_name = "sqlite:///db.sqlite3"
|
| 117 |
+
study = optuna.create_study(
|
| 118 |
+
direction="minimize",
|
| 119 |
+
pruner=optuna.pruners.MedianPruner(n_min_trials=5, n_warmup_steps=10),
|
| 120 |
+
storage=storage_name,
|
| 121 |
+
study_name="binding_prediction_optimization_attentionV2",
|
| 122 |
+
load_if_exists=True,
|
| 123 |
+
)
|
| 124 |
+
print("Start hyperparameter optimization...")
|
| 125 |
+
|
| 126 |
+
study.optimize(objective, n_trials=N_TRIALS)
|
| 127 |
+
print("\n--- Optimization Finished ---")
|
| 128 |
+
print("Best parameters found: ", study.best_params)
|
| 129 |
+
print("Best Test MSE: ", study.best_value)
|
| 130 |
+
|
| 131 |
+
df_results = study.trials_dataframe()
|
| 132 |
+
df_results.to_csv("optuna_results_attention.csv")
|
templates/index.html
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="ru">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>BioBinding AI Vis</title>
|
| 7 |
+
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/css/bootstrap.min.css" rel="stylesheet">
|
| 8 |
+
<style>
|
| 9 |
+
body { background-color: #f4f6f9; font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; }
|
| 10 |
+
.sidebar { background: white; border-right: 1px solid #dee2e6; height: 100vh; overflow-y: auto; }
|
| 11 |
+
.result-card { border: none; box-shadow: 0 4px 6px rgba(0,0,0,0.05); border-radius: 12px; margin-bottom: 20px; }
|
| 12 |
+
.affinity-score { font-size: 3rem; font-weight: bold; color: #4e73df; }
|
| 13 |
+
.atom-badge { font-size: 0.9rem; padding: 8px 12px; }
|
| 14 |
+
.mol-container { width: 100%; height: 600px; border-radius: 12px; overflow: hidden; border: 1px solid #ddd; }
|
| 15 |
+
iframe { width: 100%; height: 100%; border: none; }
|
| 16 |
+
</style>
|
| 17 |
+
</head>
|
| 18 |
+
<body>
|
| 19 |
+
|
| 20 |
+
<div class="container-fluid">
|
| 21 |
+
<div class="row">
|
| 22 |
+
<div class="col-md-3 sidebar p-4">
|
| 23 |
+
<h3 class="mb-4 text-primary">🧪 BioBind AI</h3>
|
| 24 |
+
|
| 25 |
+
<form action="/predict" method="post">
|
| 26 |
+
<div class="mb-3">
|
| 27 |
+
<label class="form-label fw-bold">Ligand (SMILES)</label>
|
| 28 |
+
<textarea class="form-control" name="smiles_ligand" rows="3" required>{{ smiles_ligand if smiles_ligand else 'COc1ccc(S(=O)(=O)N(CC(C)C)C[C@@H](O)[C@H](Cc2ccccc2)NC(=O)O[C@@H]2C[C@@H]3NC(=O)O[C@@H]3C2)cc1' }}</textarea>
|
| 29 |
+
</div>
|
| 30 |
+
|
| 31 |
+
<div class="mb-3">
|
| 32 |
+
<label class="form-label fw-bold">Protein Sequence</label>
|
| 33 |
+
<textarea class="form-control" name="sequence_protein" rows="3" required>{{ sequence_protein if sequence_protein else 'PQITLWKRPLVTIKIGGQLKEALLDTGADDTVIEEMSLPGRWKPKMIGGIGGFIKVRQYDQIIIEIAGHKAIGTVLVGPTPVNIIGRNLLTQIGATLNF' }}</textarea>
|
| 34 |
+
</div>
|
| 35 |
+
|
| 36 |
+
<button type="submit" class="btn btn-primary w-100 py-2">🔮 Calculate Binding</button>
|
| 37 |
+
</form>
|
| 38 |
+
|
| 39 |
+
{% if result_ready %}
|
| 40 |
+
<hr class="my-4">
|
| 41 |
+
<h5 class="mb-3">Top Important Atoms</h5>
|
| 42 |
+
<div class="list-group">
|
| 43 |
+
{% for atom in atom_list %}
|
| 44 |
+
<div class="list-group-item d-flex justify-content-between align-items-center">
|
| 45 |
+
<span>
|
| 46 |
+
<span class="fw-bold">#{{ atom.id }}</span> {{ atom.symbol }}
|
| 47 |
+
</span>
|
| 48 |
+
<span>
|
| 49 |
+
<span class="badge bg-light text-dark border">{{ atom.score }}</span>
|
| 50 |
+
<span>{{ atom.icon }}</span>
|
| 51 |
+
</span>
|
| 52 |
+
</div>
|
| 53 |
+
{% endfor %}
|
| 54 |
+
</div>
|
| 55 |
+
{% endif %}
|
| 56 |
+
</div>
|
| 57 |
+
|
| 58 |
+
<div class="col-md-9 p-4">
|
| 59 |
+
{% if result_ready %}
|
| 60 |
+
<div class="card result-card p-4 text-center">
|
| 61 |
+
<h2 class="text-muted">Predicted Binding Affinity (pKd)</h2>
|
| 62 |
+
<div class="affinity-score">{{ affinity }}</div>
|
| 63 |
+
</div>
|
| 64 |
+
|
| 65 |
+
<div class="card result-card p-3">
|
| 66 |
+
<ul class="nav nav-pills mb-3" id="pills-tab" role="tablist">
|
| 67 |
+
<li class="nav-item" role="presentation">
|
| 68 |
+
<button class="nav-link active" id="pills-py3dmol-tab" data-bs-toggle="pill" data-bs-target="#pills-py3dmol" type="button">🧬 Py3Dmol (High Contrast)</button>
|
| 69 |
+
</li>
|
| 70 |
+
<li class="nav-item" role="presentation">
|
| 71 |
+
<button class="nav-link" id="pills-ngl-tab" data-bs-toggle="pill" data-bs-target="#pills-ngl" type="button">🔬 NGLView</button>
|
| 72 |
+
</li>
|
| 73 |
+
</ul>
|
| 74 |
+
|
| 75 |
+
<div class="tab-content" id="pills-tabContent">
|
| 76 |
+
<div class="tab-pane fade show active" id="pills-py3dmol" role="tabpanel">
|
| 77 |
+
<div class="mol-container">
|
| 78 |
+
<iframe srcdoc="{{ html_py3dmol }}" style="width: 100%; height: 100%; border: none;"></iframe>
|
| 79 |
+
</div>
|
| 80 |
+
</div>
|
| 81 |
+
|
| 82 |
+
<div class="tab-pane fade" id="pills-ngl" role="tabpanel">
|
| 83 |
+
<div class="mol-container">
|
| 84 |
+
<iframe src="{{ url_ngl }}" style="width: 100%; height: 100%; border: none;"></iframe>
|
| 85 |
+
</div>
|
| 86 |
+
</div>
|
| 87 |
+
</div>
|
| 88 |
+
</div>
|
| 89 |
+
{% else %}
|
| 90 |
+
<div class="d-flex align-items-center justify-content-center h-100 text-muted">
|
| 91 |
+
<div class="text-center">
|
| 92 |
+
<h1>🧬 Ready to Analyze</h1>
|
| 93 |
+
<p>Enter SMILES and Protein sequence on the left to start.</p>
|
| 94 |
+
</div>
|
| 95 |
+
</div>
|
| 96 |
+
{% endif %}
|
| 97 |
+
</div>
|
| 98 |
+
</div>
|
| 99 |
+
</div>
|
| 100 |
+
|
| 101 |
+
<script src="https://cdn.jsdelivr.net/npm/bootstrap@5.3.0/dist/js/bootstrap.bundle.min.js"></script>
|
| 102 |
+
</body>
|
| 103 |
+
</html>
|
train.py
CHANGED
|
@@ -20,12 +20,14 @@ WEIGHT_DECAY = 7.06e-6
|
|
| 20 |
EPOCS = 100
|
| 21 |
DROPOUT = 0.325
|
| 22 |
GAT_HEADS = 2
|
|
|
|
| 23 |
|
| 24 |
-
DEVICE = torch.device(
|
| 25 |
-
LOG_DIR = f"runs/
|
| 26 |
TOP_K = 3
|
| 27 |
SAVES_DIR = LOG_DIR + "/models"
|
| 28 |
|
|
|
|
| 29 |
def set_seed(seed=42):
|
| 30 |
random.seed(seed)
|
| 31 |
torch.manual_seed(seed)
|
|
@@ -52,13 +54,14 @@ def train_epoch(epoch, model, loader, optimizer, criterion, writer):
|
|
| 52 |
total_loss += current_loss
|
| 53 |
|
| 54 |
global_step = (epoch - 1) * len(loader) + i
|
| 55 |
-
writer.add_scalar(
|
| 56 |
|
| 57 |
-
loop.set_postfix(loss
|
| 58 |
|
| 59 |
avg_loss = total_loss / len(loader)
|
| 60 |
return avg_loss
|
| 61 |
|
|
|
|
| 62 |
def evaluate(epoch, model, loader, criterion, writer):
|
| 63 |
model.eval()
|
| 64 |
total_loss = 0
|
|
@@ -70,9 +73,10 @@ def evaluate(epoch, model, loader, criterion, writer):
|
|
| 70 |
total_loss += loss.item()
|
| 71 |
|
| 72 |
avg_loss = total_loss / len(loader)
|
| 73 |
-
writer.add_scalar(
|
| 74 |
return avg_loss
|
| 75 |
|
|
|
|
| 76 |
def main():
|
| 77 |
gen = set_seed(42)
|
| 78 |
writer = SummaryWriter(LOG_DIR)
|
|
@@ -82,20 +86,21 @@ def main():
|
|
| 82 |
print(f"Logging to {LOG_DIR}...")
|
| 83 |
print(f"Model saves to {SAVES_DIR}...")
|
| 84 |
# Load dataset
|
| 85 |
-
dataframe = pd.read_csv(
|
| 86 |
dataframe.dropna(inplace=True)
|
| 87 |
print("Dataset loaded with {} samples".format(len(dataframe)))
|
| 88 |
-
dataset = BindingDataset(dataframe)
|
| 89 |
print("Dataset transformed with {} samples".format(len(dataset)))
|
| 90 |
|
| 91 |
if len(dataset) == 0:
|
| 92 |
print("Dataset is empty")
|
| 93 |
return
|
| 94 |
|
| 95 |
-
|
| 96 |
train_size = int(0.8 * len(dataset))
|
| 97 |
test_size = len(dataset) - train_size
|
| 98 |
-
train_dataset, test_dataset = random_split(
|
|
|
|
|
|
|
| 99 |
|
| 100 |
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
| 101 |
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
|
@@ -104,40 +109,57 @@ def main():
|
|
| 104 |
|
| 105 |
model = BindingAffinityModel(
|
| 106 |
num_node_features=num_features,
|
| 107 |
-
hidden_channels=
|
| 108 |
gat_heads=GAT_HEADS,
|
| 109 |
-
dropout=DROPOUT
|
| 110 |
).to(DEVICE)
|
| 111 |
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
criterion = nn.MSELoss()
|
| 113 |
|
| 114 |
top_models = []
|
| 115 |
|
| 116 |
print(f"Starting training on {DEVICE}")
|
| 117 |
for epoch in range(1, EPOCS + 1):
|
| 118 |
-
train_loss = train_epoch(
|
|
|
|
|
|
|
| 119 |
test_loss = evaluate(epoch, model, test_loader, criterion, writer)
|
| 120 |
|
| 121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 122 |
|
| 123 |
filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
|
| 124 |
|
| 125 |
torch.save(model.state_dict(), filename)
|
| 126 |
-
top_models.append({
|
| 127 |
|
| 128 |
-
top_models.sort(key=lambda x: x[
|
| 129 |
|
| 130 |
if len(top_models) > TOP_K:
|
| 131 |
worst_model = top_models.pop()
|
| 132 |
-
os.remove(worst_model[
|
| 133 |
|
| 134 |
-
if any(m[
|
| 135 |
-
rank = [m[
|
| 136 |
-
print(f
|
| 137 |
else:
|
| 138 |
print("")
|
| 139 |
|
| 140 |
-
|
| 141 |
writer.close()
|
| 142 |
print("Training finished.")
|
| 143 |
print("Top models saved:")
|
|
@@ -146,4 +168,4 @@ def main():
|
|
| 146 |
|
| 147 |
|
| 148 |
if __name__ == "__main__":
|
| 149 |
-
main()
|
|
|
|
| 20 |
EPOCS = 100
|
| 21 |
DROPOUT = 0.325
|
| 22 |
GAT_HEADS = 2
|
| 23 |
+
HIDDEN_CHANNELS = 256
|
| 24 |
|
| 25 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 26 |
+
LOG_DIR = f"runs/experiment_scheduler{datetime.now().strftime('%Y%m%d_%H%M%S')}"
|
| 27 |
TOP_K = 3
|
| 28 |
SAVES_DIR = LOG_DIR + "/models"
|
| 29 |
|
| 30 |
+
|
| 31 |
def set_seed(seed=42):
|
| 32 |
random.seed(seed)
|
| 33 |
torch.manual_seed(seed)
|
|
|
|
| 54 |
total_loss += current_loss
|
| 55 |
|
| 56 |
global_step = (epoch - 1) * len(loader) + i
|
| 57 |
+
writer.add_scalar("Loss/Train_Step", current_loss, global_step)
|
| 58 |
|
| 59 |
+
loop.set_postfix(loss=loss.item())
|
| 60 |
|
| 61 |
avg_loss = total_loss / len(loader)
|
| 62 |
return avg_loss
|
| 63 |
|
| 64 |
+
|
| 65 |
def evaluate(epoch, model, loader, criterion, writer):
|
| 66 |
model.eval()
|
| 67 |
total_loss = 0
|
|
|
|
| 73 |
total_loss += loss.item()
|
| 74 |
|
| 75 |
avg_loss = total_loss / len(loader)
|
| 76 |
+
writer.add_scalar("Loss/Test", avg_loss, epoch)
|
| 77 |
return avg_loss
|
| 78 |
|
| 79 |
+
|
| 80 |
def main():
|
| 81 |
gen = set_seed(42)
|
| 82 |
writer = SummaryWriter(LOG_DIR)
|
|
|
|
| 86 |
print(f"Logging to {LOG_DIR}...")
|
| 87 |
print(f"Model saves to {SAVES_DIR}...")
|
| 88 |
# Load dataset
|
| 89 |
+
dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
|
| 90 |
dataframe.dropna(inplace=True)
|
| 91 |
print("Dataset loaded with {} samples".format(len(dataframe)))
|
| 92 |
+
dataset = BindingDataset(dataframe, max_seq_length=1200)
|
| 93 |
print("Dataset transformed with {} samples".format(len(dataset)))
|
| 94 |
|
| 95 |
if len(dataset) == 0:
|
| 96 |
print("Dataset is empty")
|
| 97 |
return
|
| 98 |
|
|
|
|
| 99 |
train_size = int(0.8 * len(dataset))
|
| 100 |
test_size = len(dataset) - train_size
|
| 101 |
+
train_dataset, test_dataset = random_split(
|
| 102 |
+
dataset, [train_size, test_size], generator=gen
|
| 103 |
+
)
|
| 104 |
|
| 105 |
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
|
| 106 |
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
|
|
|
|
| 109 |
|
| 110 |
model = BindingAffinityModel(
|
| 111 |
num_node_features=num_features,
|
| 112 |
+
hidden_channels=HIDDEN_CHANNELS,
|
| 113 |
gat_heads=GAT_HEADS,
|
| 114 |
+
dropout=DROPOUT,
|
| 115 |
).to(DEVICE)
|
| 116 |
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
|
| 117 |
+
# factor of 0.5 means reducing lr to half when triggered
|
| 118 |
+
# patience of 5 means wait for 5 epochs before reducing lr
|
| 119 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
|
| 120 |
+
optimizer, mode="min", factor=0.5, patience=5
|
| 121 |
+
)
|
| 122 |
criterion = nn.MSELoss()
|
| 123 |
|
| 124 |
top_models = []
|
| 125 |
|
| 126 |
print(f"Starting training on {DEVICE}")
|
| 127 |
for epoch in range(1, EPOCS + 1):
|
| 128 |
+
train_loss = train_epoch(
|
| 129 |
+
epoch, model, train_loader, optimizer, criterion, writer
|
| 130 |
+
)
|
| 131 |
test_loss = evaluate(epoch, model, test_loader, criterion, writer)
|
| 132 |
|
| 133 |
+
old_lr = optimizer.param_groups[0]["lr"]
|
| 134 |
+
scheduler.step(test_loss)
|
| 135 |
+
new_lr = optimizer.param_groups[0]["lr"]
|
| 136 |
+
|
| 137 |
+
if new_lr != old_lr:
|
| 138 |
+
print(
|
| 139 |
+
f"\nEpoch {epoch}: Scheduler reduced LR from {old_lr:.6f} to {new_lr:.6f}!"
|
| 140 |
+
)
|
| 141 |
+
print(
|
| 142 |
+
f"Epoch {epoch:02d} | LR: {new_lr:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
|
| 143 |
+
end="",
|
| 144 |
+
)
|
| 145 |
|
| 146 |
filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
|
| 147 |
|
| 148 |
torch.save(model.state_dict(), filename)
|
| 149 |
+
top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
|
| 150 |
|
| 151 |
+
top_models.sort(key=lambda x: x["loss"])
|
| 152 |
|
| 153 |
if len(top_models) > TOP_K:
|
| 154 |
worst_model = top_models.pop()
|
| 155 |
+
os.remove(worst_model["path"])
|
| 156 |
|
| 157 |
+
if any(m["epoch"] == epoch for m in top_models):
|
| 158 |
+
rank = [m["epoch"] for m in top_models].index(epoch) + 1
|
| 159 |
+
print(f"-- Model saved (Rank: {rank})")
|
| 160 |
else:
|
| 161 |
print("")
|
| 162 |
|
|
|
|
| 163 |
writer.close()
|
| 164 |
print("Training finished.")
|
| 165 |
print("Top models saved:")
|
|
|
|
| 168 |
|
| 169 |
|
| 170 |
if __name__ == "__main__":
|
| 171 |
+
main()
|
train_attention.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from torch.utils.data import random_split
|
| 7 |
+
from torch_geometric.loader import DataLoader
|
| 8 |
+
from dataset import BindingDataset
|
| 9 |
+
from model_attention import BindingAffinityModel
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from torch.utils.tensorboard import SummaryWriter
|
| 12 |
+
import numpy as np
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
import os
|
| 15 |
+
|
| 16 |
+
# 2.02
|
| 17 |
+
# BATCH_SIZE = 16
|
| 18 |
+
# LR = 0.00035 # Reduced learning rate
|
| 19 |
+
# WEIGHT_DECAY = 1e-5 # Slightly increased weight decay (regularization)
|
| 20 |
+
# EPOCHS = 100
|
| 21 |
+
# DROPOUT = 0.3 # Slightly reduced dropout
|
| 22 |
+
# GAT_HEADS = 2
|
| 23 |
+
# HIDDEN_CHANNELS = 256
|
| 24 |
+
|
| 25 |
+
# 1.90 from Optuna
|
| 26 |
+
# --- Hyperparameters (the Optuna-derived set noted above) -------------------
BATCH_SIZE = 16  # batch_size handed to both DataLoaders
LR = 0.000034  # initial Adam learning rate; halved on plateau by the scheduler
WEIGHT_DECAY = 1e-6  # Adam weight_decay
DROPOUT = 0.26  # forwarded to the model's `dropout` argument
EPOCHS = 100  # number of train/evaluate passes
HIDDEN_CHANNELS = 256  # forwarded to the model's `hidden_channels` argument
GAT_HEADS = 2  # forwarded to the model's `gat_heads` argument

# --- Runtime / bookkeeping --------------------------------------------------
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
DEVICE = (
    torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
)
# One TensorBoard directory per launch, suffixed with the start timestamp.
LOG_DIR = "runs/experiment_attention" + datetime.now().strftime("%Y%m%d_%H%M%S")
TOP_K = 3  # how many of the best checkpoints are kept on disk
SAVES_DIR = f"{LOG_DIR}/models"  # checkpoints live next to the event files
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def set_seed(seed=42):
    """Seed every random number generator this script touches.

    The same ``seed`` is applied to Python's ``random`` module, PyTorch
    (``torch.manual_seed`` and ``torch.cuda.manual_seed``) and NumPy's global
    generator.

    Args:
        seed: Integer seed shared by all generators.

    Returns:
        A new ``torch.Generator`` seeded with ``seed``; ``main`` passes it to
        ``random_split`` so the train/test partition is reproducible.
    """
    seeders = (
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    split_generator = torch.Generator()
    split_generator.manual_seed(seed)
    return split_generator
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def train_epoch(epoch, model, loader, optimizer, criterion, writer):
    """Run one optimisation pass over ``loader``.

    Each batch is moved to ``DEVICE``, pushed through the model, and used for
    a single optimizer step. The per-step loss is written to TensorBoard under
    ``Loss/Train_Step`` and shown in the progress bar.

    Args:
        epoch: 1-based epoch number; used for the progress-bar label and to
            compute the global TensorBoard step.
        model: Called as ``model(x, edge_index, batch, protein_seq)``.
        loader: Iterable of batches exposing ``x``, ``edge_index``, ``batch``,
            ``protein_seq`` and ``y``; must support ``len()``.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss callable applied to the squeezed prediction and target.
        writer: TensorBoard writer receiving the per-step scalar.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    model.train()
    steps_per_epoch = len(loader)
    # Step index of this epoch's first batch on the global TensorBoard axis.
    first_step = (epoch - 1) * steps_per_epoch
    running_loss = 0

    progress = tqdm(loader, desc=f"Training epoch: {epoch}", leave=False)
    for step, graphs in enumerate(progress):
        graphs = graphs.to(DEVICE)

        optimizer.zero_grad()
        prediction = model(
            graphs.x, graphs.edge_index, graphs.batch, graphs.protein_seq
        )
        loss = criterion(prediction.squeeze(), graphs.y.squeeze())
        loss.backward()
        optimizer.step()

        step_loss = loss.item()
        running_loss += step_loss
        writer.add_scalar("Loss/Train_Step", step_loss, first_step + step)
        progress.set_postfix(loss=step_loss)

    return running_loss / steps_per_epoch
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def evaluate(epoch, model, loader, criterion, writer):
    """Compute the mean batch loss over ``loader`` without gradient tracking.

    The model is switched to eval mode and every batch is moved to ``DEVICE``
    before the forward pass. The resulting average is logged to TensorBoard
    under ``Loss/Test`` at step ``epoch``.

    Note that the result is the unweighted mean of per-batch losses, so a
    smaller final batch counts as much as a full one.

    Args:
        epoch: Epoch number; used for the progress-bar label and as the
            TensorBoard step.
        model: Called as ``model(x, edge_index, batch, protein_seq)``.
        loader: Iterable of batches; must support ``len()``.
        criterion: Loss callable applied to the squeezed prediction and target.
        writer: TensorBoard writer receiving the epoch-level scalar.

    Returns:
        The sum of per-batch losses divided by the number of batches.
    """
    model.eval()
    running_loss = 0

    with torch.no_grad():
        progress = tqdm(loader, desc=f"Evaluating epoch: {epoch}", leave=False)
        for graphs in progress:
            graphs = graphs.to(DEVICE)
            prediction = model(
                graphs.x, graphs.edge_index, graphs.batch, graphs.protein_seq
            )
            batch_loss = criterion(prediction.squeeze(), graphs.y.squeeze())
            running_loss += batch_loss.item()

    mean_loss = running_loss / len(loader)
    writer.add_scalar("Loss/Test", mean_loss, epoch)
    return mean_loss
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def main():
    """Train the attention-based binding-affinity model and keep the best checkpoints.

    Workflow:
        1. Seed the RNGs, open a TensorBoard writer at ``LOG_DIR`` and make
           sure ``SAVES_DIR`` exists.
        2. Load ``pdbbind_refined_dataset.csv``, drop rows with missing
           values and wrap the frame in ``BindingDataset``.
        3. Split 80/20 with the seeded generator and build the loaders.
        4. Train for ``EPOCHS`` epochs with Adam; ``ReduceLROnPlateau`` halves
           the learning rate when the held-out loss stops improving.
        5. Save a checkpoint after every epoch and prune so that only the
           ``TOP_K`` lowest held-out losses remain on disk.

    The held-out ("test") split drives both the LR schedule and checkpoint
    selection, so it effectively acts as a validation set.

    Returns:
        None. Returns early (after printing a message) when the dataset is
        empty.
    """
    gen = set_seed(42)
    writer = SummaryWriter(LOG_DIR)

    # Everything after the writer is opened runs under try/finally so the
    # event file is closed on the empty-dataset early return and on any
    # exception, not only after a complete run.
    try:
        # exist_ok avoids the check-then-create race of os.path.exists().
        os.makedirs(SAVES_DIR, exist_ok=True)
        print(f"Logging to {LOG_DIR}...")
        print(f"Model saves to {SAVES_DIR}...")

        # Load dataset
        dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
        dataframe.dropna(inplace=True)
        print(f"Dataset loaded with {len(dataframe)} samples")
        dataset = BindingDataset(dataframe, max_seq_length=1200)
        print(f"Dataset transformed with {len(dataset)} samples")

        if len(dataset) == 0:
            print("Dataset is empty")
            return

        # 80/20 split; the seeded generator keeps the partition reproducible.
        train_size = int(0.8 * len(dataset))
        test_size = len(dataset) - train_size
        train_dataset, test_dataset = random_split(
            dataset, [train_size, test_size], generator=gen
        )

        train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
        test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
        num_features = train_dataset[0].x.shape[1]
        print("Number of node features:", num_features)

        model = BindingAffinityModel(
            num_node_features=num_features,
            hidden_channels=HIDDEN_CHANNELS,
            gat_heads=GAT_HEADS,
            dropout=DROPOUT,
        ).to(DEVICE)
        optimizer = torch.optim.Adam(
            model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY
        )
        # factor of 0.5 means reducing lr to half when triggered
        # patience of 8 means wait for 8 epochs before reducing lr
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=8
        )
        criterion = nn.MSELoss()

        # Checkpoint records sorted by ascending held-out loss (best first).
        top_models = []

        print(f"Starting training on {DEVICE}")
        for epoch in range(1, EPOCHS + 1):
            train_loss = train_epoch(
                epoch, model, train_loader, optimizer, criterion, writer
            )
            test_loss = evaluate(epoch, model, test_loader, criterion, writer)

            # Compare the LR around the scheduler step to report reductions.
            old_lr = optimizer.param_groups[0]["lr"]
            scheduler.step(test_loss)
            new_lr = optimizer.param_groups[0]["lr"]

            if new_lr != old_lr:
                print(
                    f"\nEpoch {epoch}: Scheduler reduced LR from {old_lr:.6f} to {new_lr:.6f}!"
                )
            # No newline here: the checkpoint status is appended below.
            print(
                f"Epoch {epoch:02d} | LR: {new_lr:.6f} | Train: {train_loss:.4f} | Test: {test_loss:.4f}",
                end="",
            )

            filename = f"{SAVES_DIR}/model_ep{epoch:03d}_mse{test_loss:.4f}.pth"
            torch.save(model.state_dict(), filename)
            top_models.append({"loss": test_loss, "path": filename, "epoch": epoch})
            top_models.sort(key=lambda record: record["loss"])

            # Keep only the TOP_K best; the evicted file may be this epoch's.
            if len(top_models) > TOP_K:
                worst_model = top_models.pop()
                os.remove(worst_model["path"])

            kept_epochs = [record["epoch"] for record in top_models]
            if epoch in kept_epochs:
                rank = kept_epochs.index(epoch) + 1
                print(f"-- Model saved (Rank: {rank})")
            else:
                print("")
    finally:
        writer.close()

    print("Training finished.")
    print("Top models saved:")
    for position, record in enumerate(top_models, start=1):
        print(f"{position}. {record['path']} (MSE: {record['loss']:.4f})")


if __name__ == "__main__":
    main()
|
train_pl.py
CHANGED
|
@@ -6,10 +6,11 @@ from torch.utils.data import random_split
|
|
| 6 |
from model_pl import BindingAffinityModelPL
|
| 7 |
import pandas as pd
|
| 8 |
|
|
|
|
| 9 |
def main():
|
| 10 |
lr = 0.0005
|
| 11 |
# Load dataset
|
| 12 |
-
dataframe = pd.read_csv(
|
| 13 |
dataframe.dropna(inplace=True)
|
| 14 |
print("Dataset loaded with {} samples".format(len(dataframe)))
|
| 15 |
dataset = BindingDataset(dataframe)
|
|
@@ -30,21 +31,22 @@ def main():
|
|
| 30 |
|
| 31 |
model = BindingAffinityModelPL(num_node_features=84, hidden_channels_gnn=128, lr=lr)
|
| 32 |
checkpoint_callback = ModelCheckpoint(
|
| 33 |
-
monitor=
|
| 34 |
-
dirpath=
|
| 35 |
-
filename=
|
| 36 |
save_top_k=3,
|
| 37 |
-
mode=
|
| 38 |
)
|
| 39 |
early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)
|
| 40 |
|
| 41 |
trainer = pl.Trainer(
|
| 42 |
max_epochs=20,
|
| 43 |
-
accelerator="auto",
|
| 44 |
devices=1,
|
| 45 |
-
callbacks=[checkpoint_callback, early_stop_callback]
|
| 46 |
)
|
| 47 |
trainer.fit(model, train_loader, val_loader)
|
| 48 |
|
|
|
|
| 49 |
if __name__ == "__main__":
|
| 50 |
-
main()
|
|
|
|
| 6 |
from model_pl import BindingAffinityModelPL
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
+
|
| 10 |
def main():
|
| 11 |
lr = 0.0005
|
| 12 |
# Load dataset
|
| 13 |
+
dataframe = pd.read_csv("pdbbind_refined_dataset.csv")
|
| 14 |
dataframe.dropna(inplace=True)
|
| 15 |
print("Dataset loaded with {} samples".format(len(dataframe)))
|
| 16 |
dataset = BindingDataset(dataframe)
|
|
|
|
| 31 |
|
| 32 |
model = BindingAffinityModelPL(num_node_features=84, hidden_channels_gnn=128, lr=lr)
|
| 33 |
checkpoint_callback = ModelCheckpoint(
|
| 34 |
+
monitor="val_loss",
|
| 35 |
+
dirpath="checkpoints/",
|
| 36 |
+
filename="best-checkpoint",
|
| 37 |
save_top_k=3,
|
| 38 |
+
mode="min",
|
| 39 |
)
|
| 40 |
early_stop_callback = EarlyStopping(monitor="val_loss", patience=5)
|
| 41 |
|
| 42 |
trainer = pl.Trainer(
|
| 43 |
max_epochs=20,
|
| 44 |
+
accelerator="auto", # Use GPU if available
|
| 45 |
devices=1,
|
| 46 |
+
callbacks=[checkpoint_callback, early_stop_callback],
|
| 47 |
)
|
| 48 |
trainer.fit(model, train_loader, val_loader)
|
| 49 |
|
| 50 |
+
|
| 51 |
if __name__ == "__main__":
|
| 52 |
+
main()
|
transformer_from_scratch/attention_visual.ipynb
CHANGED
|
@@ -21,6 +21,7 @@
|
|
| 21 |
"import pandas as pd\n",
|
| 22 |
"import numpy as np\n",
|
| 23 |
"import warnings\n",
|
|
|
|
| 24 |
"warnings.filterwarnings(\"ignore\")"
|
| 25 |
]
|
| 26 |
},
|
|
@@ -72,12 +73,14 @@
|
|
| 72 |
"source": [
|
| 73 |
"config = get_config()\n",
|
| 74 |
"train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
|
| 75 |
-
"model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(
|
|
|
|
|
|
|
| 76 |
"\n",
|
| 77 |
"# Load the pretrained weights\n",
|
| 78 |
"model_filename = get_weights_file_path(config, f\"34\")\n",
|
| 79 |
"state = torch.load(model_filename)\n",
|
| 80 |
-
"model.load_state_dict(state[
|
| 81 |
]
|
| 82 |
},
|
| 83 |
{
|
|
@@ -95,16 +98,26 @@
|
|
| 95 |
" decoder_input = batch[\"decoder_input\"].to(device)\n",
|
| 96 |
" decoder_mask = batch[\"decoder_mask\"].to(device)\n",
|
| 97 |
"\n",
|
| 98 |
-
" encoder_input_tokens = [
|
| 99 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 100 |
"\n",
|
| 101 |
" # check that the batch size is 1\n",
|
| 102 |
-
" assert encoder_input.size(\n",
|
| 103 |
-
" 0) == 1, \"Batch size must be 1 for validation\"\n",
|
| 104 |
"\n",
|
| 105 |
" model_out = greedy_decode(\n",
|
| 106 |
-
" model
|
| 107 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
" return batch, encoder_input_tokens, decoder_input_tokens"
|
| 109 |
]
|
| 110 |
},
|
|
@@ -132,6 +145,7 @@
|
|
| 132 |
" columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
|
| 133 |
" )\n",
|
| 134 |
"\n",
|
|
|
|
| 135 |
"def get_attn_map(attn_type: str, layer: int, head: int):\n",
|
| 136 |
" if attn_type == \"encoder\":\n",
|
| 137 |
" attn = model.encoder.layers[layer].self_attention_block.attention_scores\n",
|
|
@@ -141,6 +155,7 @@
|
|
| 141 |
" attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n",
|
| 142 |
" return attn[0, head].data\n",
|
| 143 |
"\n",
|
|
|
|
| 144 |
"def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
|
| 145 |
" df = mtx2df(\n",
|
| 146 |
" get_attn_map(attn_type, layer, head),\n",
|
|
@@ -158,17 +173,29 @@
|
|
| 158 |
" color=\"value\",\n",
|
| 159 |
" tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
|
| 160 |
" )\n",
|
| 161 |
-
"
|
| 162 |
" .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
|
| 163 |
" .interactive()\n",
|
| 164 |
" )\n",
|
| 165 |
"\n",
|
| 166 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 167 |
" charts = []\n",
|
| 168 |
" for layer in layers:\n",
|
| 169 |
" rowCharts = []\n",
|
| 170 |
" for head in heads:\n",
|
| 171 |
-
" rowCharts.append(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
" charts.append(alt.hconcat(*rowCharts))\n",
|
| 173 |
" return alt.vconcat(*charts)"
|
| 174 |
]
|
|
@@ -287,7 +314,14 @@
|
|
| 287 |
"heads = [0, 1, 2, 3, 4, 5, 6, 7]\n",
|
| 288 |
"\n",
|
| 289 |
"# Encoder Self-Attention\n",
|
| 290 |
-
"get_all_attention_maps(\"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 291 |
]
|
| 292 |
},
|
| 293 |
{
|
|
@@ -379,7 +413,14 @@
|
|
| 379 |
],
|
| 380 |
"source": [
|
| 381 |
"# Decoder Self-Attention\n",
|
| 382 |
-
"get_all_attention_maps(\"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
]
|
| 384 |
},
|
| 385 |
{
|
|
@@ -471,7 +512,14 @@
|
|
| 471 |
],
|
| 472 |
"source": [
|
| 473 |
"# Encoder-Decoder Cross-Attention\n",
|
| 474 |
-
"get_all_attention_maps(\"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
]
|
| 476 |
},
|
| 477 |
{
|
|
|
|
| 21 |
"import pandas as pd\n",
|
| 22 |
"import numpy as np\n",
|
| 23 |
"import warnings\n",
|
| 24 |
+
"\n",
|
| 25 |
"warnings.filterwarnings(\"ignore\")"
|
| 26 |
]
|
| 27 |
},
|
|
|
|
| 73 |
"source": [
|
| 74 |
"config = get_config()\n",
|
| 75 |
"train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)\n",
|
| 76 |
+
"model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(\n",
|
| 77 |
+
" device\n",
|
| 78 |
+
")\n",
|
| 79 |
"\n",
|
| 80 |
"# Load the pretrained weights\n",
|
| 81 |
"model_filename = get_weights_file_path(config, f\"34\")\n",
|
| 82 |
"state = torch.load(model_filename)\n",
|
| 83 |
+
"model.load_state_dict(state[\"model_state_dict\"])"
|
| 84 |
]
|
| 85 |
},
|
| 86 |
{
|
|
|
|
| 98 |
" decoder_input = batch[\"decoder_input\"].to(device)\n",
|
| 99 |
" decoder_mask = batch[\"decoder_mask\"].to(device)\n",
|
| 100 |
"\n",
|
| 101 |
+
" encoder_input_tokens = [\n",
|
| 102 |
+
" vocab_src.id_to_token(idx) for idx in encoder_input[0].cpu().numpy()\n",
|
| 103 |
+
" ]\n",
|
| 104 |
+
" decoder_input_tokens = [\n",
|
| 105 |
+
" vocab_tgt.id_to_token(idx) for idx in decoder_input[0].cpu().numpy()\n",
|
| 106 |
+
" ]\n",
|
| 107 |
"\n",
|
| 108 |
" # check that the batch size is 1\n",
|
| 109 |
+
" assert encoder_input.size(0) == 1, \"Batch size must be 1 for validation\"\n",
|
|
|
|
| 110 |
"\n",
|
| 111 |
" model_out = greedy_decode(\n",
|
| 112 |
+
" model,\n",
|
| 113 |
+
" encoder_input,\n",
|
| 114 |
+
" encoder_mask,\n",
|
| 115 |
+
" vocab_src,\n",
|
| 116 |
+
" vocab_tgt,\n",
|
| 117 |
+
" config[\"seq_len\"],\n",
|
| 118 |
+
" device,\n",
|
| 119 |
+
" )\n",
|
| 120 |
+
"\n",
|
| 121 |
" return batch, encoder_input_tokens, decoder_input_tokens"
|
| 122 |
]
|
| 123 |
},
|
|
|
|
| 145 |
" columns=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
|
| 146 |
" )\n",
|
| 147 |
"\n",
|
| 148 |
+
"\n",
|
| 149 |
"def get_attn_map(attn_type: str, layer: int, head: int):\n",
|
| 150 |
" if attn_type == \"encoder\":\n",
|
| 151 |
" attn = model.encoder.layers[layer].self_attention_block.attention_scores\n",
|
|
|
|
| 155 |
" attn = model.decoder.layers[layer].cross_attention_block.attention_scores\n",
|
| 156 |
" return attn[0, head].data\n",
|
| 157 |
"\n",
|
| 158 |
+
"\n",
|
| 159 |
"def attn_map(attn_type, layer, head, row_tokens, col_tokens, max_sentence_len):\n",
|
| 160 |
" df = mtx2df(\n",
|
| 161 |
" get_attn_map(attn_type, layer, head),\n",
|
|
|
|
| 173 |
" color=\"value\",\n",
|
| 174 |
" tooltip=[\"row\", \"column\", \"value\", \"row_token\", \"col_token\"],\n",
|
| 175 |
" )\n",
|
| 176 |
+
" # .title(f\"Layer {layer} Head {head}\")\n",
|
| 177 |
" .properties(height=400, width=400, title=f\"Layer {layer} Head {head}\")\n",
|
| 178 |
" .interactive()\n",
|
| 179 |
" )\n",
|
| 180 |
"\n",
|
| 181 |
+
"\n",
|
| 182 |
+
"def get_all_attention_maps(\n",
|
| 183 |
+
" attn_type: str,\n",
|
| 184 |
+
" layers: list[int],\n",
|
| 185 |
+
" heads: list[int],\n",
|
| 186 |
+
" row_tokens: list,\n",
|
| 187 |
+
" col_tokens,\n",
|
| 188 |
+
" max_sentence_len: int,\n",
|
| 189 |
+
"):\n",
|
| 190 |
" charts = []\n",
|
| 191 |
" for layer in layers:\n",
|
| 192 |
" rowCharts = []\n",
|
| 193 |
" for head in heads:\n",
|
| 194 |
+
" rowCharts.append(\n",
|
| 195 |
+
" attn_map(\n",
|
| 196 |
+
" attn_type, layer, head, row_tokens, col_tokens, max_sentence_len\n",
|
| 197 |
+
" )\n",
|
| 198 |
+
" )\n",
|
| 199 |
" charts.append(alt.hconcat(*rowCharts))\n",
|
| 200 |
" return alt.vconcat(*charts)"
|
| 201 |
]
|
|
|
|
| 314 |
"heads = [0, 1, 2, 3, 4, 5, 6, 7]\n",
|
| 315 |
"\n",
|
| 316 |
"# Encoder Self-Attention\n",
|
| 317 |
+
"get_all_attention_maps(\n",
|
| 318 |
+
" \"encoder\",\n",
|
| 319 |
+
" layers,\n",
|
| 320 |
+
" heads,\n",
|
| 321 |
+
" encoder_input_tokens,\n",
|
| 322 |
+
" encoder_input_tokens,\n",
|
| 323 |
+
" min(20, sentence_len),\n",
|
| 324 |
+
")"
|
| 325 |
]
|
| 326 |
},
|
| 327 |
{
|
|
|
|
| 413 |
],
|
| 414 |
"source": [
|
| 415 |
"# Decoder Self-Attention\n",
|
| 416 |
+
"get_all_attention_maps(\n",
|
| 417 |
+
" \"decoder\",\n",
|
| 418 |
+
" layers,\n",
|
| 419 |
+
" heads,\n",
|
| 420 |
+
" decoder_input_tokens,\n",
|
| 421 |
+
" decoder_input_tokens,\n",
|
| 422 |
+
" min(20, sentence_len),\n",
|
| 423 |
+
")"
|
| 424 |
]
|
| 425 |
},
|
| 426 |
{
|
|
|
|
| 512 |
],
|
| 513 |
"source": [
|
| 514 |
"# Encoder-Decoder Cross-Attention\n",
|
| 515 |
+
"get_all_attention_maps(\n",
|
| 516 |
+
" \"encoder-decoder\",\n",
|
| 517 |
+
" layers,\n",
|
| 518 |
+
" heads,\n",
|
| 519 |
+
" encoder_input_tokens,\n",
|
| 520 |
+
" decoder_input_tokens,\n",
|
| 521 |
+
" min(20, sentence_len),\n",
|
| 522 |
+
")"
|
| 523 |
]
|
| 524 |
},
|
| 525 |
{
|
transformer_from_scratch/config.py
CHANGED
|
@@ -14,14 +14,15 @@ def get_config():
|
|
| 14 |
"model_basename": "tmodel_",
|
| 15 |
"preload": None,
|
| 16 |
"tokenizer_file": "tokenizer_{0}.json",
|
| 17 |
-
"experiment_name": "runs/tmodel"
|
| 18 |
}
|
| 19 |
|
|
|
|
| 20 |
def get_weights_file_path(config, epoch):
|
| 21 |
model_folder = config["model_folder"]
|
| 22 |
model_basename = config["model_basename"]
|
| 23 |
model_filename = f"{model_basename}{epoch}.pt"
|
| 24 |
-
return str(Path(
|
| 25 |
|
| 26 |
|
| 27 |
def latest_weights_file_path(config):
|
|
@@ -31,4 +32,4 @@ def latest_weights_file_path(config):
|
|
| 31 |
if len(weights_files) == 0:
|
| 32 |
return None
|
| 33 |
weights_files.sort()
|
| 34 |
-
return str(weights_files[-1])
|
|
|
|
| 14 |
"model_basename": "tmodel_",
|
| 15 |
"preload": None,
|
| 16 |
"tokenizer_file": "tokenizer_{0}.json",
|
| 17 |
+
"experiment_name": "runs/tmodel",
|
| 18 |
}
|
| 19 |
|
| 20 |
+
|
| 21 |
def get_weights_file_path(config, epoch):
    """Build the checkpoint path for a given epoch.

    Args:
        config: Mapping providing ``"model_folder"`` and ``"model_basename"``.
        epoch: Epoch identifier appended to the base name (any value that can
            be formatted into a string).

    Returns:
        ``<model_folder>/<model_basename><epoch>.pt`` as a string, relative to
        the current working directory.

    Raises:
        KeyError: If either required key is missing from ``config``.
    """
    folder = Path(".") / config["model_folder"]
    checkpoint_name = "{}{}.pt".format(config["model_basename"], epoch)
    return str(folder / checkpoint_name)
|
| 26 |
|
| 27 |
|
| 28 |
def latest_weights_file_path(config):
|
|
|
|
| 32 |
if len(weights_files) == 0:
|
| 33 |
return None
|
| 34 |
weights_files.sort()
|
| 35 |
+
return str(weights_files[-1])
|
transformer_from_scratch/dataset.py
CHANGED
|
@@ -13,26 +13,34 @@ class BilingualDataset(Dataset):
|
|
| 13 |
self.src_lang = src_lang
|
| 14 |
self.tgt_lang = tgt_lang
|
| 15 |
|
| 16 |
-
self.sos_token = torch.tensor(
|
| 17 |
-
|
| 18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
def __len__(self):
|
| 21 |
return len(self.ds)
|
| 22 |
|
| 23 |
def __getitem__(self, index):
|
| 24 |
src_target_pair = self.ds[index]
|
| 25 |
-
src_text = src_target_pair[
|
| 26 |
-
tgt_text = src_target_pair[
|
| 27 |
|
| 28 |
enc_input_tokens = self.tokenizer_src.encode(src_text).ids
|
| 29 |
dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
|
| 30 |
|
| 31 |
-
enc_num_padding_tokens =
|
| 32 |
-
|
|
|
|
|
|
|
| 33 |
|
| 34 |
if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
|
| 35 |
-
raise ValueError(
|
| 36 |
|
| 37 |
# Add SOS and EOS tokens to source text
|
| 38 |
encoder_input = torch.cat(
|
|
@@ -40,7 +48,9 @@ class BilingualDataset(Dataset):
|
|
| 40 |
self.sos_token,
|
| 41 |
torch.tensor(enc_input_tokens, dtype=torch.int64),
|
| 42 |
self.eos_token,
|
| 43 |
-
torch.tensor(
|
|
|
|
|
|
|
| 44 |
]
|
| 45 |
)
|
| 46 |
# Add SOS token to the decoder input
|
|
@@ -48,7 +58,9 @@ class BilingualDataset(Dataset):
|
|
| 48 |
[
|
| 49 |
self.sos_token,
|
| 50 |
torch.tensor(dec_input_tokens, dtype=torch.int64),
|
| 51 |
-
torch.tensor(
|
|
|
|
|
|
|
| 52 |
]
|
| 53 |
)
|
| 54 |
# Add EOS token to the label (what we want )
|
|
@@ -56,7 +68,9 @@ class BilingualDataset(Dataset):
|
|
| 56 |
[
|
| 57 |
torch.tensor(dec_input_tokens, dtype=torch.int64),
|
| 58 |
self.eos_token,
|
| 59 |
-
torch.tensor(
|
|
|
|
|
|
|
| 60 |
]
|
| 61 |
)
|
| 62 |
|
|
@@ -65,15 +79,27 @@ class BilingualDataset(Dataset):
|
|
| 65 |
assert label.size(0) == self.seq_len
|
| 66 |
|
| 67 |
return {
|
| 68 |
-
"encoder_input": encoder_input,
|
| 69 |
-
"decoder_input": decoder_input,
|
| 70 |
-
"encoder_mask": (encoder_input != self.pad_token)
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
"src_text": src_text,
|
| 74 |
-
"tgt_text": tgt_text
|
| 75 |
}
|
| 76 |
|
|
|
|
| 77 |
def casual_mask(size):
|
| 78 |
-
mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(
|
| 79 |
-
|
|
|
|
|
|
|
|
|
| 13 |
self.src_lang = src_lang
|
| 14 |
self.tgt_lang = tgt_lang
|
| 15 |
|
| 16 |
+
self.sos_token = torch.tensor(
|
| 17 |
+
[tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64
|
| 18 |
+
)
|
| 19 |
+
self.eos_token = torch.tensor(
|
| 20 |
+
[tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64
|
| 21 |
+
)
|
| 22 |
+
self.pad_token = torch.tensor(
|
| 23 |
+
[tokenizer_src.token_to_id("[PAD]")], dtype=torch.int64
|
| 24 |
+
)
|
| 25 |
|
| 26 |
def __len__(self):
|
| 27 |
return len(self.ds)
|
| 28 |
|
| 29 |
def __getitem__(self, index):
|
| 30 |
src_target_pair = self.ds[index]
|
| 31 |
+
src_text = src_target_pair["translation"][self.src_lang]
|
| 32 |
+
tgt_text = src_target_pair["translation"][self.tgt_lang]
|
| 33 |
|
| 34 |
enc_input_tokens = self.tokenizer_src.encode(src_text).ids
|
| 35 |
dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids
|
| 36 |
|
| 37 |
+
enc_num_padding_tokens = (
|
| 38 |
+
self.seq_len - len(enc_input_tokens) - 2
|
| 39 |
+
) # for SOS and EOS
|
| 40 |
+
dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1 # for SOS
|
| 41 |
|
| 42 |
if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
|
| 43 |
+
raise ValueError("Sentence is too long")
|
| 44 |
|
| 45 |
# Add SOS and EOS tokens to source text
|
| 46 |
encoder_input = torch.cat(
|
|
|
|
| 48 |
self.sos_token,
|
| 49 |
torch.tensor(enc_input_tokens, dtype=torch.int64),
|
| 50 |
self.eos_token,
|
| 51 |
+
torch.tensor(
|
| 52 |
+
[self.pad_token] * enc_num_padding_tokens, dtype=torch.int64
|
| 53 |
+
),
|
| 54 |
]
|
| 55 |
)
|
| 56 |
# Add SOS token to the decoder input
|
|
|
|
| 58 |
[
|
| 59 |
self.sos_token,
|
| 60 |
torch.tensor(dec_input_tokens, dtype=torch.int64),
|
| 61 |
+
torch.tensor(
|
| 62 |
+
[self.pad_token] * dec_num_padding_tokens, dtype=torch.int64
|
| 63 |
+
),
|
| 64 |
]
|
| 65 |
)
|
| 66 |
# Add EOS token to the label (what we want )
|
|
|
|
| 68 |
[
|
| 69 |
torch.tensor(dec_input_tokens, dtype=torch.int64),
|
| 70 |
self.eos_token,
|
| 71 |
+
torch.tensor(
|
| 72 |
+
[self.pad_token] * dec_num_padding_tokens, dtype=torch.int64
|
| 73 |
+
),
|
| 74 |
]
|
| 75 |
)
|
| 76 |
|
|
|
|
| 79 |
assert label.size(0) == self.seq_len
|
| 80 |
|
| 81 |
return {
|
| 82 |
+
"encoder_input": encoder_input, # (Seq_len)
|
| 83 |
+
"decoder_input": decoder_input, # (Seq_len)
|
| 84 |
+
"encoder_mask": (encoder_input != self.pad_token)
|
| 85 |
+
.unsqueeze(0)
|
| 86 |
+
.unsqueeze(0)
|
| 87 |
+
.int(), # (1, 1, Seq_len)
|
| 88 |
+
"decoder_mask": (decoder_input != self.pad_token)
|
| 89 |
+
.unsqueeze(0)
|
| 90 |
+
.unsqueeze(0)
|
| 91 |
+
.int()
|
| 92 |
+
& casual_mask(
|
| 93 |
+
decoder_input.size(0)
|
| 94 |
+
), # (1, Seq_len) & (1, Seq_len, Seq_len)
|
| 95 |
+
"label": label, # (Seq_len)
|
| 96 |
"src_text": src_text,
|
| 97 |
+
"tgt_text": tgt_text,
|
| 98 |
}
|
| 99 |
|
| 100 |
+
|
| 101 |
def casual_mask(size):
    """Return a boolean causal (look-ahead) mask of shape ``(1, size, size)``.

    Entry ``[0, i, j]`` is ``True`` when ``j <= i`` (position ``i`` may attend
    to position ``j``) and ``False`` for every entry above the main diagonal.
    The function name keeps its historical spelling because callers use it.

    Args:
        size: Sequence length; the mask is square in its last two dimensions.

    Returns:
        A ``torch.bool`` tensor of shape ``(1, size, size)``.
    """
    # Ones on and below the main diagonal, zeros strictly above it.
    lower_triangle = torch.tril(torch.ones((1, size, size)))
    return lower_triangle != 0
|
transformer_from_scratch/inference.ipynb
CHANGED
|
@@ -12,7 +12,7 @@
|
|
| 12 |
},
|
| 13 |
"source": [
|
| 14 |
"import torch\n",
|
| 15 |
-
"from config import get_config,latest_weights_file_path\n",
|
| 16 |
"from train import get_model, get_ds, run_validation\n",
|
| 17 |
"from translate import translate"
|
| 18 |
],
|
|
@@ -22,10 +22,10 @@
|
|
| 22 |
"evalue": "cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)",
|
| 23 |
"output_type": "error",
|
| 24 |
"traceback": [
|
| 25 |
-
"\
|
| 26 |
-
"\
|
| 27 |
-
"\
|
| 28 |
-
"\
|
| 29 |
]
|
| 30 |
}
|
| 31 |
],
|
|
@@ -42,12 +42,14 @@
|
|
| 42 |
"print(\"Using device:\", device)\n",
|
| 43 |
"config = get_config()\n",
|
| 44 |
"train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
|
| 45 |
-
"model = get_model(
|
|
|
|
|
|
|
| 46 |
"\n",
|
| 47 |
"# Load the pretrained weights\n",
|
| 48 |
"model_filename = latest_weights_file_path(config)\n",
|
| 49 |
"state = torch.load(model_filename)\n",
|
| 50 |
-
"model.load_state_dict(state[
|
| 51 |
],
|
| 52 |
"id": "e6b0b6022c4d1c15"
|
| 53 |
},
|
|
@@ -56,7 +58,20 @@
|
|
| 56 |
"cell_type": "code",
|
| 57 |
"outputs": [],
|
| 58 |
"execution_count": null,
|
| 59 |
-
"source":
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
"id": "be2c2169c183a445"
|
| 61 |
},
|
| 62 |
{
|
|
|
|
| 12 |
},
|
| 13 |
"source": [
|
| 14 |
"import torch\n",
|
| 15 |
+
"from config import get_config, latest_weights_file_path\n",
|
| 16 |
"from train import get_model, get_ds, run_validation\n",
|
| 17 |
"from translate import translate"
|
| 18 |
],
|
|
|
|
| 22 |
"evalue": "cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)",
|
| 23 |
"output_type": "error",
|
| 24 |
"traceback": [
|
| 25 |
+
"\u001b[31m---------------------------------------------------------------------------\u001b[39m",
|
| 26 |
+
"\u001b[31mImportError\u001b[39m Traceback (most recent call last)",
|
| 27 |
+
"\u001b[36mCell\u001b[39m\u001b[36m \u001b[39m\u001b[32mIn[2]\u001b[39m\u001b[32m, line 3\u001b[39m\n\u001b[32m 1\u001b[39m \u001b[38;5;28;01mimport\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtorch\u001b[39;00m\n\u001b[32m 2\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mconfig\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_config,latest_weights_file_path\n\u001b[32m----> \u001b[39m\u001b[32m3\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtrain\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m get_model, get_ds, run_validation\n\u001b[32m 4\u001b[39m \u001b[38;5;28;01mfrom\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[34;01mtranslate\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;28;01mimport\u001b[39;00m translate\n",
|
| 28 |
+
"\u001b[31mImportError\u001b[39m: cannot import name 'get_model' from 'train' (C:\\Users\\Alex\\Desktop\\binding_affinity\\train.py)"
|
| 29 |
]
|
| 30 |
}
|
| 31 |
],
|
|
|
|
| 42 |
"print(\"Using device:\", device)\n",
|
| 43 |
"config = get_config()\n",
|
| 44 |
"train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)\n",
|
| 45 |
+
"model = get_model(\n",
|
| 46 |
+
" config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()\n",
|
| 47 |
+
").to(device)\n",
|
| 48 |
"\n",
|
| 49 |
"# Load the pretrained weights\n",
|
| 50 |
"model_filename = latest_weights_file_path(config)\n",
|
| 51 |
"state = torch.load(model_filename)\n",
|
| 52 |
+
"model.load_state_dict(state[\"model_state_dict\"])"
|
| 53 |
],
|
| 54 |
"id": "e6b0b6022c4d1c15"
|
| 55 |
},
|
|
|
|
| 58 |
"cell_type": "code",
|
| 59 |
"outputs": [],
|
| 60 |
"execution_count": null,
|
| 61 |
+
"source": [
|
| 62 |
+
"run_validation(\n",
|
| 63 |
+
" model,\n",
|
| 64 |
+
" val_dataloader,\n",
|
| 65 |
+
" tokenizer_src,\n",
|
| 66 |
+
" tokenizer_tgt,\n",
|
| 67 |
+
" config[\"seq_len\"],\n",
|
| 68 |
+
" device,\n",
|
| 69 |
+
" lambda msg: print(msg),\n",
|
| 70 |
+
" 0,\n",
|
| 71 |
+
" None,\n",
|
| 72 |
+
" num_examples=10,\n",
|
| 73 |
+
")"
|
| 74 |
+
],
|
| 75 |
"id": "be2c2169c183a445"
|
| 76 |
},
|
| 77 |
{
|
transformer_from_scratch/train.py
CHANGED
|
@@ -19,19 +19,24 @@ from model import build_transformer
|
|
| 19 |
from config import get_weights_file_path, get_config
|
| 20 |
import warnings
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
# Precompute the encoder output and reuse it for every token we get from the decoder
|
| 27 |
encoder_output = model.encode(source, source_mask)
|
| 28 |
# Initialize the decoder input with the sos token
|
| 29 |
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
|
| 30 |
while True:
|
| 31 |
-
if decoder_input
|
| 32 |
break
|
| 33 |
# Build mask for the target (decoder input)
|
| 34 |
-
decoder_mask =
|
|
|
|
|
|
|
| 35 |
|
| 36 |
# Calculate the output of the decoder
|
| 37 |
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
|
|
@@ -40,15 +45,31 @@ def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_
|
|
| 40 |
prob = model.project(out[:, -1])
|
| 41 |
# Select the token with the highest probability (because it's a greedy search)
|
| 42 |
_, next_word = torch.max(prob, dim=1)
|
| 43 |
-
decoder_input = torch.cat(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
if next_word == eos_idx:
|
| 46 |
break
|
| 47 |
return decoder_input.squeeze(0)
|
| 48 |
|
| 49 |
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
model.eval()
|
| 53 |
count = 0
|
| 54 |
|
|
@@ -61,25 +82,33 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
|
|
| 61 |
with torch.no_grad():
|
| 62 |
for batch in validation_ds:
|
| 63 |
count += 1
|
| 64 |
-
encoder_input = batch[
|
| 65 |
-
encoder_mask = batch[
|
| 66 |
|
| 67 |
assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
|
| 68 |
|
| 69 |
-
model_out = greedy_decode(
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
|
| 74 |
|
| 75 |
source_texts.append(source_text)
|
| 76 |
expected.append(target_text)
|
| 77 |
predicted.append(model_out_text)
|
| 78 |
|
| 79 |
-
print_msg(
|
| 80 |
-
print_msg(f
|
| 81 |
-
print_msg(f
|
| 82 |
-
print_msg(f
|
| 83 |
|
| 84 |
if count == num_examples:
|
| 85 |
break
|
|
@@ -91,25 +120,22 @@ def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len,
|
|
| 91 |
# Compute the char error rate
|
| 92 |
metric = CharErrorRate()
|
| 93 |
cer = metric(predicted, expected)
|
| 94 |
-
writer.add_scalar(
|
| 95 |
writer.flush()
|
| 96 |
|
| 97 |
# Compute the word error rate
|
| 98 |
metric = WordErrorRate()
|
| 99 |
wer = metric(predicted, expected)
|
| 100 |
-
writer.add_scalar(
|
| 101 |
writer.flush()
|
| 102 |
|
| 103 |
# Compute the BLEU metric
|
| 104 |
metric = BLEUScore()
|
| 105 |
bleu = metric(predicted, expected)
|
| 106 |
-
writer.add_scalar(
|
| 107 |
writer.flush()
|
| 108 |
|
| 109 |
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
def get_all_sentences(ds, lang):
|
| 114 |
for item in ds:
|
| 115 |
yield item["translation"][lang]
|
|
@@ -145,84 +171,117 @@ def get_ds(config):
|
|
| 145 |
val_ds_size = len(ds_raw) - train_ds_size
|
| 146 |
train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
|
| 147 |
|
| 148 |
-
train_ds = BilingualDataset(
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
max_len_src = 0
|
| 152 |
max_len_tgt = 0
|
| 153 |
|
| 154 |
for item in ds_raw:
|
| 155 |
-
src_ids = tokenizer_src.encode(item[
|
| 156 |
-
tgt_ids = tokenizer_tgt.encode(item[
|
| 157 |
|
| 158 |
max_len_src = max(len(src_ids), max_len_src)
|
| 159 |
max_len_tgt = max(len(tgt_ids), max_len_tgt)
|
| 160 |
|
| 161 |
-
print(f
|
| 162 |
-
print(f
|
| 163 |
|
| 164 |
-
train_dataloader = DataLoader(
|
|
|
|
|
|
|
| 165 |
val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
|
| 166 |
|
| 167 |
return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
|
| 168 |
|
| 169 |
|
| 170 |
def get_model(config, vocab_src_len, vocab_tgt_len):
|
| 171 |
-
model = build_transformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 172 |
return model
|
| 173 |
|
|
|
|
| 174 |
def train_model(config):
|
| 175 |
# Define the device
|
| 176 |
-
device = torch.device(
|
| 177 |
-
print(f
|
| 178 |
|
| 179 |
-
Path(config[
|
| 180 |
|
| 181 |
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
|
| 182 |
-
model = get_model(
|
|
|
|
|
|
|
| 183 |
|
| 184 |
# Tensorboard
|
| 185 |
-
writer = SummaryWriter(config[
|
| 186 |
|
| 187 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=config[
|
| 188 |
|
| 189 |
initial_epoch = 0
|
| 190 |
global_step = 0
|
| 191 |
|
| 192 |
-
if config[
|
| 193 |
-
model_filename = get_weights_file_path(config, config[
|
| 194 |
-
print(f
|
| 195 |
state = torch.load(model_filename)
|
| 196 |
-
initial_epoch = state[
|
| 197 |
-
optimizer.load_state_dict(state[
|
| 198 |
-
global_step = state[
|
| 199 |
|
| 200 |
-
loss_fn = torch.nn.CrossEntropyLoss(
|
|
|
|
|
|
|
| 201 |
|
| 202 |
-
for epoch in range(initial_epoch, config[
|
| 203 |
-
batch_iterator = tqdm(train_dataloader,desc=f
|
| 204 |
for batch in batch_iterator:
|
| 205 |
model.train()
|
| 206 |
-
encoder_input = batch[
|
| 207 |
-
decoder_input = batch[
|
| 208 |
|
| 209 |
-
encoder_mask = batch[
|
| 210 |
-
decoder_mask = batch[
|
| 211 |
|
| 212 |
# Run the tensors through the transformer model
|
| 213 |
-
encoder_output = model.encode(
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
|
| 216 |
-
proj_output = model.project(decoder_output)
|
| 217 |
-
label = batch[
|
| 218 |
|
| 219 |
# (B, Seq_len, tgt_vocab_size) --> (B * Seq_len, tgt_vocab_size)
|
| 220 |
-
loss = loss_fn(
|
|
|
|
|
|
|
| 221 |
|
| 222 |
-
batch_iterator.set_postfix({f
|
| 223 |
|
| 224 |
# Log the loss
|
| 225 |
-
writer.add_scalar(
|
| 226 |
writer.flush()
|
| 227 |
|
| 228 |
# Backpropagate the loss
|
|
@@ -234,23 +293,32 @@ def train_model(config):
|
|
| 234 |
|
| 235 |
global_step += 1
|
| 236 |
|
| 237 |
-
run_validation(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
# Save the model at the end of each epoch
|
| 240 |
-
model_filename = get_weights_file_path(config, f
|
| 241 |
torch.save(
|
| 242 |
{
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
},
|
|
|
|
|
|
|
| 248 |
|
| 249 |
|
| 250 |
if __name__ == "__main__":
|
| 251 |
-
warnings.filterwarnings(
|
| 252 |
config = get_config()
|
| 253 |
train_model(config)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
-
|
|
|
|
| 19 |
from config import get_weights_file_path, get_config
|
| 20 |
import warnings
|
| 21 |
|
| 22 |
+
|
| 23 |
+
def greedy_decode(
|
| 24 |
+
model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device
|
| 25 |
+
):
|
| 26 |
+
sos_idx = tokenizer_tgt.token_to_id("[SOS]")
|
| 27 |
+
eos_idx = tokenizer_tgt.token_to_id("[EOS]")
|
| 28 |
|
| 29 |
# Precompute the encoder output and reuse it for every token we get from the decoder
|
| 30 |
encoder_output = model.encode(source, source_mask)
|
| 31 |
# Initialize the decoder input with the sos token
|
| 32 |
decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
|
| 33 |
while True:
|
| 34 |
+
if decoder_input.size(1) == max_len:
|
| 35 |
break
|
| 36 |
# Build mask for the target (decoder input)
|
| 37 |
+
decoder_mask = (
|
| 38 |
+
casual_mask(decoder_input.size(1)).type_as(source_mask).to(device)
|
| 39 |
+
)
|
| 40 |
|
| 41 |
# Calculate the output of the decoder
|
| 42 |
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
|
|
|
|
| 45 |
prob = model.project(out[:, -1])
|
| 46 |
# Select the token with the highest probability (because it's a greedy search)
|
| 47 |
_, next_word = torch.max(prob, dim=1)
|
| 48 |
+
decoder_input = torch.cat(
|
| 49 |
+
[
|
| 50 |
+
decoder_input,
|
| 51 |
+
torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device),
|
| 52 |
+
],
|
| 53 |
+
dim=1,
|
| 54 |
+
)
|
| 55 |
|
| 56 |
if next_word == eos_idx:
|
| 57 |
break
|
| 58 |
return decoder_input.squeeze(0)
|
| 59 |
|
| 60 |
|
| 61 |
+
def run_validation(
|
| 62 |
+
model,
|
| 63 |
+
validation_ds,
|
| 64 |
+
tokenizer_src,
|
| 65 |
+
tokenizer_tgt,
|
| 66 |
+
max_len,
|
| 67 |
+
device,
|
| 68 |
+
print_msg,
|
| 69 |
+
global_step,
|
| 70 |
+
writer,
|
| 71 |
+
num_examples=2,
|
| 72 |
+
):
|
| 73 |
model.eval()
|
| 74 |
count = 0
|
| 75 |
|
|
|
|
| 82 |
with torch.no_grad():
|
| 83 |
for batch in validation_ds:
|
| 84 |
count += 1
|
| 85 |
+
encoder_input = batch["encoder_input"].to(device)
|
| 86 |
+
encoder_mask = batch["encoder_mask"].to(device)
|
| 87 |
|
| 88 |
assert encoder_input.size(0) == 1, "Batch size must be 1 for validation"
|
| 89 |
|
| 90 |
+
model_out = greedy_decode(
|
| 91 |
+
model,
|
| 92 |
+
encoder_input,
|
| 93 |
+
encoder_mask,
|
| 94 |
+
tokenizer_src,
|
| 95 |
+
tokenizer_tgt,
|
| 96 |
+
max_len,
|
| 97 |
+
device,
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
source_text = batch["src_text"][0]
|
| 101 |
+
target_text = batch["tgt_text"][0]
|
| 102 |
model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())
|
| 103 |
|
| 104 |
source_texts.append(source_text)
|
| 105 |
expected.append(target_text)
|
| 106 |
predicted.append(model_out_text)
|
| 107 |
|
| 108 |
+
print_msg("-" * console_width)
|
| 109 |
+
print_msg(f"Source: {source_text}")
|
| 110 |
+
print_msg(f"Expected: {target_text}")
|
| 111 |
+
print_msg(f"Predicted: {model_out_text}")
|
| 112 |
|
| 113 |
if count == num_examples:
|
| 114 |
break
|
|
|
|
| 120 |
# Compute the char error rate
|
| 121 |
metric = CharErrorRate()
|
| 122 |
cer = metric(predicted, expected)
|
| 123 |
+
writer.add_scalar("validation cer", cer, global_step)
|
| 124 |
writer.flush()
|
| 125 |
|
| 126 |
# Compute the word error rate
|
| 127 |
metric = WordErrorRate()
|
| 128 |
wer = metric(predicted, expected)
|
| 129 |
+
writer.add_scalar("validation wer", wer, global_step)
|
| 130 |
writer.flush()
|
| 131 |
|
| 132 |
# Compute the BLEU metric
|
| 133 |
metric = BLEUScore()
|
| 134 |
bleu = metric(predicted, expected)
|
| 135 |
+
writer.add_scalar("validation BLEU", bleu, global_step)
|
| 136 |
writer.flush()
|
| 137 |
|
| 138 |
|
|
|
|
|
|
|
|
|
|
| 139 |
def get_all_sentences(ds, lang):
|
| 140 |
for item in ds:
|
| 141 |
yield item["translation"][lang]
|
|
|
|
| 171 |
val_ds_size = len(ds_raw) - train_ds_size
|
| 172 |
train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])
|
| 173 |
|
| 174 |
+
train_ds = BilingualDataset(
|
| 175 |
+
train_ds_raw,
|
| 176 |
+
tokenizer_src,
|
| 177 |
+
tokenizer_tgt,
|
| 178 |
+
config["lang_src"],
|
| 179 |
+
config["lang_tgt"],
|
| 180 |
+
config["seq_len"],
|
| 181 |
+
)
|
| 182 |
+
val_ds = BilingualDataset(
|
| 183 |
+
val_ds_raw,
|
| 184 |
+
tokenizer_src,
|
| 185 |
+
tokenizer_tgt,
|
| 186 |
+
config["lang_src"],
|
| 187 |
+
config["lang_tgt"],
|
| 188 |
+
config["seq_len"],
|
| 189 |
+
)
|
| 190 |
|
| 191 |
max_len_src = 0
|
| 192 |
max_len_tgt = 0
|
| 193 |
|
| 194 |
for item in ds_raw:
|
| 195 |
+
src_ids = tokenizer_src.encode(item["translation"][config["lang_src"]]).ids
|
| 196 |
+
tgt_ids = tokenizer_tgt.encode(item["translation"][config["lang_tgt"]]).ids
|
| 197 |
|
| 198 |
max_len_src = max(len(src_ids), max_len_src)
|
| 199 |
max_len_tgt = max(len(tgt_ids), max_len_tgt)
|
| 200 |
|
| 201 |
+
print(f"Max length of the source sentence: {max_len_src}")
|
| 202 |
+
print(f"Max length of the target sentence: {max_len_tgt}")
|
| 203 |
|
| 204 |
+
train_dataloader = DataLoader(
|
| 205 |
+
train_ds, batch_size=config["batch_size"], shuffle=True
|
| 206 |
+
)
|
| 207 |
val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)
|
| 208 |
|
| 209 |
return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
|
| 210 |
|
| 211 |
|
| 212 |
def get_model(config, vocab_src_len, vocab_tgt_len):
    """Build the transformer from the config and the two vocabulary sizes.

    `config["seq_len"]` is passed twice and `config["d_model"]` once, all
    positionally, after the two vocabulary sizes.
    NOTE(review): presumably the two `seq_len` arguments are the source and
    target sequence lengths - confirm against `build_transformer`.
    """
    seq_len = config["seq_len"]
    return build_transformer(
        vocab_src_len, vocab_tgt_len, seq_len, seq_len, config["d_model"]
    )
|
| 221 |
|
| 222 |
+
|
| 223 |
def train_model(config):
|
| 224 |
# Define the device
|
| 225 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 226 |
+
print(f"using device: {device}")
|
| 227 |
|
| 228 |
+
Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)
|
| 229 |
|
| 230 |
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
|
| 231 |
+
model = get_model(
|
| 232 |
+
config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()
|
| 233 |
+
).to(device)
|
| 234 |
|
| 235 |
# Tensorboard
|
| 236 |
+
writer = SummaryWriter(config["experiment_name"])
|
| 237 |
|
| 238 |
+
optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], eps=1e-9)
|
| 239 |
|
| 240 |
initial_epoch = 0
|
| 241 |
global_step = 0
|
| 242 |
|
| 243 |
+
if config["preload"]:
|
| 244 |
+
model_filename = get_weights_file_path(config, config["preload"])
|
| 245 |
+
print(f"Preloading model {model_filename}")
|
| 246 |
state = torch.load(model_filename)
|
| 247 |
+
initial_epoch = state["epoch"] + 1
|
| 248 |
+
optimizer.load_state_dict(state["optimizer_state_dict"])
|
| 249 |
+
global_step = state["global_step"]
|
| 250 |
|
| 251 |
+
loss_fn = torch.nn.CrossEntropyLoss(
|
| 252 |
+
ignore_index=tokenizer_src.token_to_id("[PAD]"), label_smoothing=0.1
|
| 253 |
+
)
|
| 254 |
|
| 255 |
+
for epoch in range(initial_epoch, config["num_epochs"]):
|
| 256 |
+
batch_iterator = tqdm(train_dataloader, desc=f"Processing epoch {epoch:02d}")
|
| 257 |
for batch in batch_iterator:
|
| 258 |
model.train()
|
| 259 |
+
encoder_input = batch["encoder_input"].to(device) # (B, Seq_len)
|
| 260 |
+
decoder_input = batch["decoder_input"].to(device) # (B, Seq_len)
|
| 261 |
|
| 262 |
+
encoder_mask = batch["encoder_mask"].to(device) # (B, 1, 1, Seq_len)
|
| 263 |
+
decoder_mask = batch["decoder_mask"].to(device) # (B, 1, Seq_len, Seq_len)
|
| 264 |
|
| 265 |
# Run the tensors through the transformer model
|
| 266 |
+
encoder_output = model.encode(
|
| 267 |
+
encoder_input, encoder_mask
|
| 268 |
+
) # (B, Seq_len, d_model)
|
| 269 |
+
decoder_output = model.decode(
|
| 270 |
+
encoder_output, encoder_mask, decoder_input, decoder_mask
|
| 271 |
+
) # (B, Seq_len, d_model)
|
| 272 |
|
| 273 |
+
proj_output = model.project(decoder_output) # (B, Seq_len, tgt_vocab_size)
|
| 274 |
+
label = batch["label"].to(device) # (B, Seq_len)
|
| 275 |
|
| 276 |
# (B, Seq_len, tgt_vocab_size) --> (B * Seq_len, tgt_vocab_size)
|
| 277 |
+
loss = loss_fn(
|
| 278 |
+
proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1)
|
| 279 |
+
)
|
| 280 |
|
| 281 |
+
batch_iterator.set_postfix({f"loss": f"{loss.item(): 6.3f}"})
|
| 282 |
|
| 283 |
# Log the loss
|
| 284 |
+
writer.add_scalar("train loss", loss.item(), global_step)
|
| 285 |
writer.flush()
|
| 286 |
|
| 287 |
# Backpropagate the loss
|
|
|
|
| 293 |
|
| 294 |
global_step += 1
|
| 295 |
|
| 296 |
+
run_validation(
|
| 297 |
+
model,
|
| 298 |
+
val_dataloader,
|
| 299 |
+
tokenizer_src,
|
| 300 |
+
tokenizer_tgt,
|
| 301 |
+
config["seq_len"],
|
| 302 |
+
device,
|
| 303 |
+
lambda msg: batch_iterator.write(msg),
|
| 304 |
+
global_step,
|
| 305 |
+
writer,
|
| 306 |
+
)
|
| 307 |
|
| 308 |
# Save the model at the end of each epoch
|
| 309 |
+
model_filename = get_weights_file_path(config, f"{epoch:02d}")
|
| 310 |
torch.save(
|
| 311 |
{
|
| 312 |
+
"epoch": epoch,
|
| 313 |
+
"model_state_dict": model.state_dict(),
|
| 314 |
+
"optimizer_state_dict": optimizer.state_dict(),
|
| 315 |
+
"global_step": global_step,
|
| 316 |
+
},
|
| 317 |
+
model_filename,
|
| 318 |
+
)
|
| 319 |
|
| 320 |
|
| 321 |
if __name__ == "__main__":
|
| 322 |
+
warnings.filterwarnings("ignore")
|
| 323 |
config = get_config()
|
| 324 |
train_model(config)
|
|
|
|
|
|
|
|
|
transformer_from_scratch/translate.py
CHANGED
|
@@ -7,32 +7,53 @@ from dataset import BilingualDataset
|
|
| 7 |
import torch
|
| 8 |
import sys
|
| 9 |
|
|
|
|
| 10 |
def translate(sentence: str):
|
| 11 |
# Define the device, tokenizers, and model
|
| 12 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 13 |
print("Using device:", device)
|
| 14 |
config = get_config()
|
| 15 |
|
| 16 |
-
tokenizer_src = Tokenizer.from_file(
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
model = build_transformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# Load the pretrained weights
|
| 22 |
model_filename = latest_weights_file_path(config)
|
| 23 |
state = torch.load(model_filename)
|
| 24 |
-
model.load_state_dict(state[
|
| 25 |
|
| 26 |
# if the sentence is a number use it as an index to the test set
|
| 27 |
label = ""
|
| 28 |
if type(sentence) == int or sentence.isdigit():
|
| 29 |
id = int(sentence)
|
| 30 |
-
ds = load_dataset(
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 34 |
label = ds[id]["tgt_text"]
|
| 35 |
-
seq_len = config[
|
| 36 |
|
| 37 |
# translate the sentence
|
| 38 |
|
|
@@ -40,46 +61,82 @@ def translate(sentence: str):
|
|
| 40 |
with torch.no_grad():
|
| 41 |
# Precompute the encoder output and reuse it for every generation step
|
| 42 |
source = tokenizer_src.encode(sentence)
|
| 43 |
-
source = torch.cat(
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 50 |
encoder_output = model.encode(source, source_mask)
|
| 51 |
|
| 52 |
# Initialize the decoder input with the sos token
|
| 53 |
-
decoder_input =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
# Print the source sentence and target start prompt
|
| 56 |
-
if label != "":
|
|
|
|
| 57 |
print(f"{f'SOURCE: ':>12}{sentence}")
|
| 58 |
-
if label != "":
|
| 59 |
-
|
|
|
|
| 60 |
|
| 61 |
# Generate the translation word by word
|
| 62 |
while decoder_input.size(1) < seq_len:
|
| 63 |
# build mask for target and calculate output
|
| 64 |
-
decoder_mask =
|
| 65 |
-
torch.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 66 |
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
|
| 67 |
|
| 68 |
# project next token
|
| 69 |
prob = model.project(out[:, -1])
|
| 70 |
_, next_word = torch.max(prob, dim=1)
|
| 71 |
decoder_input = torch.cat(
|
| 72 |
-
[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
|
| 74 |
# print the translated word
|
| 75 |
-
print(f"{tokenizer_tgt.decode([next_word.item()])}", end=
|
| 76 |
|
| 77 |
# break if we predict the end of sentence token
|
| 78 |
-
if next_word == tokenizer_tgt.token_to_id(
|
| 79 |
break
|
| 80 |
|
| 81 |
# convert ids to tokens
|
| 82 |
return tokenizer_tgt.decode(decoder_input[0].tolist())
|
| 83 |
|
|
|
|
| 84 |
# read sentence from argument
|
| 85 |
-
translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good student.")
|
|
|
|
| 7 |
import torch
|
| 8 |
import sys
|
| 9 |
|
| 10 |
+
|
| 11 |
def translate(sentence: str):
|
| 12 |
# Define the device, tokenizers, and model
|
| 13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
print("Using device:", device)
|
| 15 |
config = get_config()
|
| 16 |
|
| 17 |
+
tokenizer_src = Tokenizer.from_file(
|
| 18 |
+
str(Path(config["tokenizer_file"].format(config["lang_src"])))
|
| 19 |
+
)
|
| 20 |
+
tokenizer_tgt = Tokenizer.from_file(
|
| 21 |
+
str(Path(config["tokenizer_file"].format(config["lang_tgt"])))
|
| 22 |
+
)
|
| 23 |
|
| 24 |
+
model = build_transformer(
|
| 25 |
+
tokenizer_src.get_vocab_size(),
|
| 26 |
+
tokenizer_tgt.get_vocab_size(),
|
| 27 |
+
config["seq_len"],
|
| 28 |
+
config["seq_len"],
|
| 29 |
+
d_model=config["d_model"],
|
| 30 |
+
).to(device)
|
| 31 |
|
| 32 |
# Load the pretrained weights
|
| 33 |
model_filename = latest_weights_file_path(config)
|
| 34 |
state = torch.load(model_filename)
|
| 35 |
+
model.load_state_dict(state["model_state_dict"])
|
| 36 |
|
| 37 |
# if the sentence is a number use it as an index to the test set
|
| 38 |
label = ""
|
| 39 |
if type(sentence) == int or sentence.isdigit():
|
| 40 |
id = int(sentence)
|
| 41 |
+
ds = load_dataset(
|
| 42 |
+
f"{config['datasource']}",
|
| 43 |
+
f"{config['lang_src']}-{config['lang_tgt']}",
|
| 44 |
+
split="all",
|
| 45 |
+
)
|
| 46 |
+
ds = BilingualDataset(
|
| 47 |
+
ds,
|
| 48 |
+
tokenizer_src,
|
| 49 |
+
tokenizer_tgt,
|
| 50 |
+
config["lang_src"],
|
| 51 |
+
config["lang_tgt"],
|
| 52 |
+
config["seq_len"],
|
| 53 |
+
)
|
| 54 |
+
sentence = ds[id]["src_text"]
|
| 55 |
label = ds[id]["tgt_text"]
|
| 56 |
+
seq_len = config["seq_len"]
|
| 57 |
|
| 58 |
# translate the sentence
|
| 59 |
|
|
|
|
| 61 |
with torch.no_grad():
|
| 62 |
# Precompute the encoder output and reuse it for every generation step
|
| 63 |
source = tokenizer_src.encode(sentence)
|
| 64 |
+
source = torch.cat(
|
| 65 |
+
[
|
| 66 |
+
torch.tensor([tokenizer_src.token_to_id("[SOS]")], dtype=torch.int64),
|
| 67 |
+
torch.tensor(source.ids, dtype=torch.int64),
|
| 68 |
+
torch.tensor([tokenizer_src.token_to_id("[EOS]")], dtype=torch.int64),
|
| 69 |
+
torch.tensor(
|
| 70 |
+
[tokenizer_src.token_to_id("[PAD]")]
|
| 71 |
+
* (seq_len - len(source.ids) - 2),
|
| 72 |
+
dtype=torch.int64,
|
| 73 |
+
),
|
| 74 |
+
],
|
| 75 |
+
dim=0,
|
| 76 |
+
).to(device)
|
| 77 |
+
source_mask = (
|
| 78 |
+
(source != tokenizer_src.token_to_id("[PAD]"))
|
| 79 |
+
.unsqueeze(0)
|
| 80 |
+
.unsqueeze(0)
|
| 81 |
+
.int()
|
| 82 |
+
.to(device)
|
| 83 |
+
)
|
| 84 |
encoder_output = model.encode(source, source_mask)
|
| 85 |
|
| 86 |
# Initialize the decoder input with the sos token
|
| 87 |
+
decoder_input = (
|
| 88 |
+
torch.empty(1, 1)
|
| 89 |
+
.fill_(tokenizer_tgt.token_to_id("[SOS]"))
|
| 90 |
+
.type_as(source)
|
| 91 |
+
.to(device)
|
| 92 |
+
)
|
| 93 |
|
| 94 |
# Print the source sentence and target start prompt
|
| 95 |
+
if label != "":
|
| 96 |
+
print(f"{f'ID: ':>12}{id}")
|
| 97 |
print(f"{f'SOURCE: ':>12}{sentence}")
|
| 98 |
+
if label != "":
|
| 99 |
+
print(f"{f'TARGET: ':>12}{label}")
|
| 100 |
+
print(f"{f'PREDICTED: ':>12}", end="")
|
| 101 |
|
| 102 |
# Generate the translation word by word
|
| 103 |
while decoder_input.size(1) < seq_len:
|
| 104 |
# build mask for target and calculate output
|
| 105 |
+
decoder_mask = (
|
| 106 |
+
torch.triu(
|
| 107 |
+
torch.ones((1, decoder_input.size(1), decoder_input.size(1))),
|
| 108 |
+
diagonal=1,
|
| 109 |
+
)
|
| 110 |
+
.type(torch.int)
|
| 111 |
+
.type_as(source_mask)
|
| 112 |
+
.to(device)
|
| 113 |
+
)
|
| 114 |
out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)
|
| 115 |
|
| 116 |
# project next token
|
| 117 |
prob = model.project(out[:, -1])
|
| 118 |
_, next_word = torch.max(prob, dim=1)
|
| 119 |
decoder_input = torch.cat(
|
| 120 |
+
[
|
| 121 |
+
decoder_input,
|
| 122 |
+
torch.empty(1, 1)
|
| 123 |
+
.type_as(source)
|
| 124 |
+
.fill_(next_word.item())
|
| 125 |
+
.to(device),
|
| 126 |
+
],
|
| 127 |
+
dim=1,
|
| 128 |
+
)
|
| 129 |
|
| 130 |
# print the translated word
|
| 131 |
+
print(f"{tokenizer_tgt.decode([next_word.item()])}", end=" ")
|
| 132 |
|
| 133 |
# break if we predict the end of sentence token
|
| 134 |
+
if next_word == tokenizer_tgt.token_to_id("[EOS]"):
|
| 135 |
break
|
| 136 |
|
| 137 |
# convert ids to tokens
|
| 138 |
return tokenizer_tgt.decode(decoder_input[0].tolist())
|
| 139 |
|
| 140 |
+
|
| 141 |
# read sentence from argument
|
| 142 |
+
translate(sys.argv[1] if len(sys.argv) > 1 else "I am not a very good student.")
|
utils.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from torch_geometric.data import Data, Batch
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
from rdkit.Chem import AllChem
|
| 6 |
+
import nglview as nv
|
| 7 |
+
import py3Dmol
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
from dataset import get_atom_features, get_protein_features
|
| 11 |
+
from model_attention import BindingAffinityModel
|
| 12 |
+
|
| 13 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 14 |
+
MODEL_PATH = "runs/experiment_attention20260124_104439_optuna/models/model_ep041_mse1.9153.pth"
|
| 15 |
+
|
| 16 |
+
GAT_HEADS = 2
|
| 17 |
+
HIDDEN_CHANNELS = 256
|
| 18 |
+
|
| 19 |
+
def get_inference_data(ligand_smiles, protein_sequence, model_path=MODEL_PATH):
    """
    Run the attention model on a single ligand/protein pair.

    Args:
        ligand_smiles: SMILES string describing the ligand.
        protein_sequence: protein sequence; every character is converted
            with `get_protein_features`, then the token list is truncated
            or zero-padded to 1200 entries.
        model_path: path of the saved state dict loaded into
            `BindingAffinityModel`.

    Returns:
        - mol: RDKit molecule object with explicit hydrogens (and 3D
          coordinates when the embedding succeeds)
        - importance: numpy array of importance scores (presumably one per
          atom - confirm against the cross-attention layer's output layout)
        - predicted_affinity: predicted binding affinity value

    Raises:
        ValueError: if `ligand_smiles` cannot be parsed by RDKit.
    """
    max_protein_len = 1200  # fixed token length fed to the model

    # Prepare ligand molecule with geometry RDKit
    mol = Chem.MolFromSmiles(ligand_smiles)
    if mol is None:
        # MolFromSmiles signals a parse failure by returning None; fail with
        # a clear message instead of an obscure error inside AddHs.
        raise ValueError(f"Could not parse ligand SMILES: {ligand_smiles!r}")
    mol = Chem.AddHs(mol)
    # NOTE(review): the return code of EmbedMolecule is not checked; if the
    # embedding fails the molecule has no conformer and later calls to
    # mol.GetConformer() will raise. Prediction itself does not need it.
    AllChem.EmbedMolecule(mol, randomSeed=42)

    # Graph data PyTorch
    atom_features = [get_atom_features(atom) for atom in mol.GetAtoms()]
    x = torch.tensor(np.array(atom_features), dtype=torch.float)
    bond_pairs = []
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        # Store both directions so the graph is undirected.
        bond_pairs.extend([(i, j), (j, i)])
    # reshape(-1, 2) keeps the (2, E) layout even when there are no bonds
    # (a plain .t() on an empty 1-D tensor would give shape (0,)).
    edge_index = (
        torch.tensor(bond_pairs, dtype=torch.long).reshape(-1, 2).t().contiguous()
    )

    # Tokenise the protein, then truncate / zero-pad to the fixed length.
    tokens = [get_protein_features(c) for c in protein_sequence][:max_protein_len]
    tokens.extend([0] * (max_protein_len - len(tokens)))
    protein_tokens = torch.tensor(tokens, dtype=torch.long).unsqueeze(0).to(DEVICE)

    data = Data(x=x, edge_index=edge_index)
    batch = Batch.from_data_list([data]).to(DEVICE)
    num_features = x.shape[1]

    # Model loading
    model = BindingAffinityModel(
        num_features, hidden_channels=HIDDEN_CHANNELS, gat_heads=GAT_HEADS
    ).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()

    # Prediction
    with torch.no_grad():
        pred = model(batch.x, batch.edge_index, batch.batch, protein_tokens)
        attention_weights = model.cross_attention.last_attention_weights[0]

    # Attention importance, Max + Normalize
    # Only columns belonging to non-padding (non-zero) tokens are considered.
    real_prot_len = len([t for t in tokens if t != 0])
    importance = attention_weights[:, :real_prot_len].max(dim=1).values.cpu().numpy()

    # Normalize to [0, 1]; skipped when all scores are equal, which would
    # otherwise divide by zero and fill the array with NaN.
    lowest, highest = importance.min(), importance.max()
    if highest > 0 and highest > lowest:
        importance = (importance - lowest) / (highest - lowest)

    # Noise reduction
    importance[importance < 0.01] = 0
    return mol, importance, pred.item()
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def get_py3dmol_view(mol, importance):
    """Build a py3Dmol view of ``mol`` with labels on its most important atoms.

    The molecule is rendered as thin sticks with small spheres on a white
    background. The 15 atoms with the highest ``importance`` values each get
    a text label of the form ``"<index>:<symbol>:<score>"`` placed at the
    atom's coordinates.

    Args:
        mol: RDKit molecule; its conformer is read for the label positions.
        importance: Per-atom scores, indexable by atom index.

    Returns:
        The configured ``py3Dmol`` view, zoomed to the model.
    """
    viewer = py3Dmol.view(width="100%", height="600px")
    viewer.addModel(Chem.MolToMolBlock(mol), "sdf")
    viewer.setBackgroundColor("white")

    base_style = {
        "stick": {"radius": 0.15},
        "sphere": {"scale": 0.25},
    }
    viewer.setStyle({}, base_style)

    # Rank atoms from most to least important and keep the first 15; labels
    # are then added in ascending atom-index order.
    ranked = np.argsort(importance)[::-1]
    labelled_atoms = sorted(int(idx) for idx in ranked[:15])

    conformer = mol.GetConformer()

    for atom_idx in labelled_atoms:
        score = importance[atom_idx]
        xyz = conformer.GetAtomPosition(atom_idx)
        element = mol.GetAtomWithIdx(atom_idx).GetSymbol()

        label_options = {
            "position": {"x": xyz.x, "y": xyz.y, "z": xyz.z},
            "fontSize": 14,
            "fontColor": "white",
            "backgroundColor": "black",
            "backgroundOpacity": 0.7,
            "borderThickness": 0,
            "inFront": True,
            "showBackground": True,
        }
        viewer.addLabel(f"{atom_idx}:{element}:{score:.2f}", label_options)

    viewer.zoomTo()
    return viewer
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def save_standalone_ngl_html(mol, importance, filepath):
    """Write a self-contained NGL viewer HTML page for ``mol``.

    Each atom's importance (multiplied by 100) is stored in the PDB
    temperature-factor column so the page can colour atoms with NGL's
    ``bfactor`` colour scheme. The 15 highest-scoring atoms are labelled with
    their atom index, a checkbox switches between heatmap and element
    colouring, and a tooltip shows the hovered atom's score.

    Args:
        mol: RDKit molecule; it is converted to a PDB block.
        importance: Per-atom scores, indexable by atom index.
        filepath: Destination path of the HTML file (written as UTF-8).

    Raises:
        ValueError: If RDKit cannot re-parse the PDB block generated from
            ``mol``.
    """
    # Local import so this function does not depend on the file's import block.
    import json

    # Round-trip through PDB so atoms can carry PDB residue info, which holds
    # the temperature factor.
    pdb_block = Chem.MolToPDBBlock(mol)
    mol_pdb = Chem.MolFromPDBBlock(pdb_block, removeHs=False)
    if mol_pdb is None:
        # MolFromPDBBlock signals failure by returning None; fail with a clear
        # message instead of an AttributeError on the next line.
        raise ValueError("RDKit could not re-parse the PDB block generated from the molecule")

    for i, atom in enumerate(mol_pdb.GetAtoms()):
        info = atom.GetPDBResidueInfo()
        if info:
            info.SetTempFactor(float(importance[i]) * 100)

    final_pdb_block = Chem.MolToPDBBlock(mol_pdb)

    # Embed the PDB text as a JSON-encoded JavaScript string literal. Unlike
    # escaping only backticks inside a template literal, this also handles
    # backslashes, "${" sequences and quotes. "</" is neutralised so the text
    # can never terminate the surrounding <script> element.
    pdb_js_literal = json.dumps(final_pdb_block).replace("</", "<\\/")

    indices_sorted = np.argsort(importance)[::-1]
    top_indices = indices_sorted[:15]

    selection_list = [str(i) for i in top_indices]
    selection_str = "@" + ",".join(selection_list)

    # Guard against an empty selection
    if not selection_list:
        selection_str = "@-1"

    html_content = f"""<!DOCTYPE html>
<html lang="en">
<head>
    <meta charset="UTF-8">
    <title>NGL Visualization</title>
    <script src="https://unpkg.com/ngl@2.0.0-dev.37/dist/ngl.js"></script>
    <style>
        html, body {{ width: 100%; height: 100%; margin: 0; padding: 0; overflow: hidden; font-family: sans-serif; }}
        #viewport {{ width: 100%; height: 100%; }}

        /* Стиль подсказки */
        #tooltip {{
            display: none;
            position: absolute;
            z-index: 100;
            pointer-events: none; /* Чтобы мышь не 'застревала' на подсказке */
            background-color: rgba(20, 20, 20, 0.9);
            color: white;
            padding: 8px 12px;
            border-radius: 6px;
            font-size: 14px;
            box-shadow: 0 4px 6px rgba(0,0,0,0.3);
            white-space: nowrap;
            border: 1px solid rgba(255,255,255,0.2);
            transition: opacity 0.1s ease;
        }}

        /* Панель управления */
        #controls {{
            position: absolute;
            top: 20px;
            right: 20px;
            z-index: 50;
            background: rgba(255, 255, 255, 0.95);
            padding: 15px;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0,0,0,0.1);
            display: flex;
            align-items: center;
        }}

        /* Стили переключателя */
        .switch-container {{
            display: flex;
            align-items: center;
            gap: 10px;
            cursor: pointer;
            font-weight: bold;
            color: #333;
        }}

        input[type=checkbox] {{
            transform: scale(1.5);
            cursor: pointer;
        }}
    </style>
</head>
<body>
    <div id="controls">
        <label class="switch-container">
            <input type="checkbox" id="heatmapToggle" checked>
            <span>Show Heatmap</span>
        </label>
    </div>

    <div id="tooltip"></div>

    <div id="viewport"></div>

    <script>
        var pdbData = {pdb_js_literal};
        var selectionString = "{selection_str}";
        var component; // Глобальная переменная для доступа к модели

        document.addEventListener("DOMContentLoaded", function () {{
            var stage = new NGL.Stage("viewport", {{ backgroundColor: "white" }});
            var tooltip = document.getElementById("tooltip");
            var toggle = document.getElementById("heatmapToggle");

            // Загружаем данные
            var stringBlob = new Blob([pdbData], {{type: 'text/plain'}});

            stage.loadFile(stringBlob, {{ ext: 'pdb' }}).then(function (o) {{
                component = o; // Сохраняем ссылку

                // Рисуем начальное состояние
                updateVisualization();
                o.autoView();
            }});

            // --- ФУНКЦИЯ ОБНОВЛЕНИЯ ВИДА ---
            function updateVisualization() {{
                if (!component) return;

                // Очищаем старые представления (чтобы не накладывались)
                component.removeAllRepresentations();

                var useHeatmap = toggle.checked;

                if (useHeatmap) {{
                    // 1. РЕЖИМ HEATMAP
                    component.addRepresentation("ball+stick", {{
                        colorScheme: "bfactor",
                        colorDomain: [20, 80],
                        colorScale: ["blue", "white", "red"],
                        radiusScale: 1.0
                    }});
                }} else {{
                    // 2. ОБЫЧНЫЙ РЕЖИМ (По элементам)
                    component.addRepresentation("ball+stick", {{
                        colorScheme: "element",
                        radiusScale: 1.0
                    }});
                }}

                // Добавляем метки (они нужны всегда)
                if (selectionString.length > 1 && selectionString !== "@-1") {{
                    component.addRepresentation("label", {{
                        sele: selectionString,
                        labelType: "atomindex",
                        color: "black",
                        radius: 1.1,
                        yOffset: 0.0,
                        zOffset: 2.0,
                        attachment: "middle_center",
                        pickable: true // ВАЖНО: Делаем текст интерактивным
                    }});
                }}
            }}

            // Слушаем переключатель
            toggle.addEventListener("change", updateVisualization);

            // --- УМНЫЙ TOOLTIP ---
            stage.mouseControls.remove("hoverPick"); // Убираем стандартное поведение

            stage.signals.hovered.add(function (pickingProxy) {{
                // Проверяем, навели ли мы на атом ИЛИ на метку (текст)
                // NGL возвращает pickingProxy.atom даже если мы навели на label этого атома
                if (pickingProxy && (pickingProxy.atom || pickingProxy.closestBondAtom)) {{
                    var atom = pickingProxy.atom || pickingProxy.closestBondAtom;
                    var score = atom.bfactor.toFixed(2);

                    tooltip.innerHTML = `
                        <div style="margin-bottom:2px;"><b>Atom ID:</b> ${{atom.index}} (${{atom.element}}: ${{atom.atomname}})</div>
                        <div style="color: #ffcccc;"><b>Importance:</b> ${{(score/100).toFixed(3)}}</div>
                    `;
                    tooltip.style.display = "block";
                    tooltip.style.opacity = "1";

                    // Позиционирование: сдвиг вправо и вниз, чтобы не мешать
                    var cp = pickingProxy.canvasPosition;
                    tooltip.style.left = (cp.x + 20) + "px";
                    tooltip.style.top = (cp.y + 20) + "px";

                }} else {{
                    // Скрываем, если увели мышь
                    tooltip.style.display = "none";
                    tooltip.style.opacity = "0";
                }}
            }});

            // Ресайз окна
            window.addEventListener("resize", function(event){{
                stage.handleResize();
            }}, false);
        }});
    </script>
</body>
</html>"""
    with open(filepath, "w", encoding="utf-8") as f:
        f.write(html_content)
|
visualization.ipynb
CHANGED
|
@@ -2,261 +2,255 @@
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
|
|
|
| 5 |
"id": "initial_id",
|
| 6 |
"metadata": {
|
| 7 |
"ExecuteTime": {
|
| 8 |
-
"end_time": "
|
| 9 |
-
"start_time": "
|
| 10 |
}
|
| 11 |
},
|
| 12 |
-
"source": [
|
| 13 |
-
"import nglview as nv\n",
|
| 14 |
-
"import os"
|
| 15 |
-
],
|
| 16 |
"outputs": [
|
| 17 |
{
|
| 18 |
"data": {
|
| 19 |
-
"text/plain": [],
|
| 20 |
"application/vnd.jupyter.widget-view+json": {
|
|
|
|
| 21 |
"version_major": 2,
|
| 22 |
-
"version_minor": 0
|
| 23 |
-
|
| 24 |
-
|
| 25 |
},
|
| 26 |
"metadata": {},
|
| 27 |
-
"output_type": "display_data"
|
| 28 |
-
"jetTransient": {
|
| 29 |
-
"display_id": null
|
| 30 |
-
}
|
| 31 |
}
|
| 32 |
],
|
| 33 |
-
"
|
|
|
|
|
|
|
|
|
|
| 34 |
},
|
| 35 |
{
|
| 36 |
"cell_type": "code",
|
|
|
|
| 37 |
"id": "d8d7978e-980a-400c-8c6a-5365990c8855",
|
| 38 |
"metadata": {
|
| 39 |
"ExecuteTime": {
|
| 40 |
-
"end_time": "
|
| 41 |
-
"start_time": "
|
| 42 |
}
|
| 43 |
},
|
|
|
|
| 44 |
"source": [
|
| 45 |
"PDBBIND_PATH = \"refined-set\""
|
| 46 |
-
]
|
| 47 |
-
"outputs": [],
|
| 48 |
-
"execution_count": 2
|
| 49 |
},
|
| 50 |
{
|
| 51 |
"cell_type": "code",
|
|
|
|
| 52 |
"id": "788a6b43-c515-45c7-bc52-341d446b1a65",
|
| 53 |
"metadata": {
|
| 54 |
"ExecuteTime": {
|
| 55 |
-
"end_time": "
|
| 56 |
-
"start_time": "
|
| 57 |
}
|
| 58 |
},
|
|
|
|
| 59 |
"source": [
|
| 60 |
"EXAMPLE_PDB_ID = \"1a1e\""
|
| 61 |
-
]
|
| 62 |
-
"outputs": [],
|
| 63 |
-
"execution_count": 3
|
| 64 |
},
|
| 65 |
{
|
| 66 |
"cell_type": "code",
|
|
|
|
| 67 |
"id": "e8f4bebc-845f-43e8-bc4d-ab7b649eb49c",
|
| 68 |
"metadata": {
|
| 69 |
"ExecuteTime": {
|
| 70 |
-
"end_time": "
|
| 71 |
-
"start_time": "
|
| 72 |
}
|
| 73 |
},
|
|
|
|
| 74 |
"source": [
|
| 75 |
"pdb_dir = os.path.join(PDBBIND_PATH, EXAMPLE_PDB_ID)"
|
| 76 |
-
]
|
| 77 |
-
"outputs": [],
|
| 78 |
-
"execution_count": 4
|
| 79 |
},
|
| 80 |
{
|
| 81 |
"cell_type": "code",
|
|
|
|
| 82 |
"id": "24b5e435-4d8f-4505-b27c-dd6317376ed4",
|
| 83 |
"metadata": {
|
| 84 |
"ExecuteTime": {
|
| 85 |
-
"end_time": "
|
| 86 |
-
"start_time": "
|
| 87 |
}
|
| 88 |
},
|
|
|
|
| 89 |
"source": [
|
| 90 |
"protein_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_protein.pdb\")"
|
| 91 |
-
]
|
| 92 |
-
"outputs": [],
|
| 93 |
-
"execution_count": 5
|
| 94 |
},
|
| 95 |
{
|
| 96 |
"cell_type": "code",
|
|
|
|
| 97 |
"id": "e7fc3539-00c0-48a2-b012-c80757fa12c4",
|
| 98 |
"metadata": {
|
| 99 |
"ExecuteTime": {
|
| 100 |
-
"end_time": "
|
| 101 |
-
"start_time": "
|
| 102 |
}
|
| 103 |
},
|
|
|
|
| 104 |
"source": [
|
| 105 |
"ligand_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_ligand.sdf\")"
|
| 106 |
-
]
|
| 107 |
-
"outputs": [],
|
| 108 |
-
"execution_count": 6
|
| 109 |
},
|
| 110 |
{
|
| 111 |
"cell_type": "code",
|
|
|
|
| 112 |
"id": "9a053b99-7c01-4881-b3f7-e9b39090af9d",
|
| 113 |
"metadata": {
|
| 114 |
"ExecuteTime": {
|
| 115 |
-
"end_time": "
|
| 116 |
-
"start_time": "
|
| 117 |
}
|
| 118 |
},
|
|
|
|
| 119 |
"source": [
|
| 120 |
"view = nv.NGLWidget()"
|
| 121 |
-
]
|
| 122 |
-
"outputs": [],
|
| 123 |
-
"execution_count": 7
|
| 124 |
},
|
| 125 |
{
|
| 126 |
"cell_type": "code",
|
|
|
|
| 127 |
"id": "df8c8e00-3ce6-41dd-b457-d9f50e318dad",
|
| 128 |
"metadata": {
|
| 129 |
"ExecuteTime": {
|
| 130 |
-
"end_time": "
|
| 131 |
-
"start_time": "
|
| 132 |
}
|
| 133 |
},
|
|
|
|
| 134 |
"source": [
|
| 135 |
"protein_comp = view.add_component(protein_file)"
|
| 136 |
-
]
|
| 137 |
-
"outputs": [],
|
| 138 |
-
"execution_count": 8
|
| 139 |
},
|
| 140 |
{
|
| 141 |
"cell_type": "code",
|
|
|
|
| 142 |
"id": "c191fead-fef8-4077-b787-5bf9552307b1",
|
| 143 |
"metadata": {
|
| 144 |
"ExecuteTime": {
|
| 145 |
-
"end_time": "
|
| 146 |
-
"start_time": "
|
| 147 |
}
|
| 148 |
},
|
|
|
|
| 149 |
"source": [
|
| 150 |
"protein_comp.clear_representations()"
|
| 151 |
-
]
|
| 152 |
-
"outputs": [],
|
| 153 |
-
"execution_count": 9
|
| 154 |
},
|
| 155 |
{
|
| 156 |
"cell_type": "code",
|
|
|
|
| 157 |
"id": "4559033a-aeda-4659-8d91-9002b5a6ecda",
|
| 158 |
"metadata": {
|
| 159 |
"ExecuteTime": {
|
| 160 |
-
"end_time": "
|
| 161 |
-
"start_time": "
|
| 162 |
}
|
| 163 |
},
|
| 164 |
-
"source": [
|
| 165 |
-
"protein_comp.add_representation('cartoon', color='blue')"
|
| 166 |
-
],
|
| 167 |
"outputs": [],
|
| 168 |
-
"
|
|
|
|
|
|
|
| 169 |
},
|
| 170 |
{
|
| 171 |
"cell_type": "code",
|
|
|
|
| 172 |
"id": "73ea1a50-8463-40b8-a942-0c92d3e97a97",
|
| 173 |
"metadata": {
|
| 174 |
"ExecuteTime": {
|
| 175 |
-
"end_time": "
|
| 176 |
-
"start_time": "
|
| 177 |
}
|
| 178 |
},
|
|
|
|
| 179 |
"source": [
|
| 180 |
"ligand_comp = view.add_component(ligand_file)"
|
| 181 |
-
]
|
| 182 |
-
"outputs": [],
|
| 183 |
-
"execution_count": 11
|
| 184 |
},
|
| 185 |
{
|
| 186 |
"cell_type": "code",
|
|
|
|
| 187 |
"id": "16cdb710-1ed6-4b1d-9e6a-69b7ad61a600",
|
| 188 |
"metadata": {
|
| 189 |
"ExecuteTime": {
|
| 190 |
-
"end_time": "
|
| 191 |
-
"start_time": "
|
| 192 |
}
|
| 193 |
},
|
|
|
|
| 194 |
"source": [
|
| 195 |
"ligand_comp.clear_representations()"
|
| 196 |
-
]
|
| 197 |
-
"outputs": [],
|
| 198 |
-
"execution_count": 12
|
| 199 |
},
|
| 200 |
{
|
| 201 |
"cell_type": "code",
|
|
|
|
| 202 |
"id": "2193c497-f33c-4de0-86a9-6e535002fcb7",
|
| 203 |
"metadata": {
|
| 204 |
"ExecuteTime": {
|
| 205 |
-
"end_time": "
|
| 206 |
-
"start_time": "
|
| 207 |
}
|
| 208 |
},
|
| 209 |
-
"source": [
|
| 210 |
-
"ligand_comp.add_representation('ball+stick', radius=0.3)"
|
| 211 |
-
],
|
| 212 |
"outputs": [],
|
| 213 |
-
"
|
|
|
|
|
|
|
| 214 |
},
|
| 215 |
{
|
| 216 |
"cell_type": "code",
|
|
|
|
| 217 |
"id": "b1cc7f44-a374-4400-b4ba-8f75101b21ce",
|
| 218 |
"metadata": {
|
| 219 |
"ExecuteTime": {
|
| 220 |
-
"end_time": "
|
| 221 |
-
"start_time": "
|
| 222 |
}
|
| 223 |
},
|
| 224 |
-
"source": [
|
| 225 |
-
"view"
|
| 226 |
-
],
|
| 227 |
"outputs": [
|
| 228 |
{
|
| 229 |
"data": {
|
| 230 |
-
"text/plain": [
|
| 231 |
-
"NGLWidget()"
|
| 232 |
-
],
|
| 233 |
"application/vnd.jupyter.widget-view+json": {
|
|
|
|
| 234 |
"version_major": 2,
|
| 235 |
-
"version_minor": 0
|
| 236 |
-
|
| 237 |
-
|
|
|
|
|
|
|
| 238 |
},
|
| 239 |
"metadata": {},
|
| 240 |
-
"output_type": "display_data"
|
| 241 |
-
"jetTransient": {
|
| 242 |
-
"display_id": null
|
| 243 |
-
}
|
| 244 |
}
|
| 245 |
],
|
| 246 |
-
"
|
|
|
|
|
|
|
| 247 |
},
|
| 248 |
{
|
| 249 |
"cell_type": "code",
|
|
|
|
| 250 |
"id": "5655e465-bb44-4218-a5e3-db2c5e62cd9c",
|
| 251 |
"metadata": {
|
| 252 |
"ExecuteTime": {
|
| 253 |
-
"end_time": "
|
| 254 |
-
"start_time": "
|
| 255 |
}
|
| 256 |
},
|
| 257 |
-
"source": [],
|
| 258 |
"outputs": [],
|
| 259 |
-
"
|
| 260 |
}
|
| 261 |
],
|
| 262 |
"metadata": {
|
|
|
|
| 2 |
"cells": [
|
| 3 |
{
|
| 4 |
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
"id": "initial_id",
|
| 7 |
"metadata": {
|
| 8 |
"ExecuteTime": {
|
| 9 |
+
"end_time": "2026-01-24T09:06:36.981469Z",
|
| 10 |
+
"start_time": "2026-01-24T09:06:36.975634Z"
|
| 11 |
}
|
| 12 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
"outputs": [
|
| 14 |
{
|
| 15 |
"data": {
|
|
|
|
| 16 |
"application/vnd.jupyter.widget-view+json": {
|
| 17 |
+
"model_id": "5077355be9e64d4f814509a151b6c8b6",
|
| 18 |
"version_major": 2,
|
| 19 |
+
"version_minor": 0
|
| 20 |
+
},
|
| 21 |
+
"text/plain": []
|
| 22 |
},
|
| 23 |
"metadata": {},
|
| 24 |
+
"output_type": "display_data"
|
|
|
|
|
|
|
|
|
|
| 25 |
}
|
| 26 |
],
|
| 27 |
+
"source": [
|
| 28 |
+
"import nglview as nv\n",
|
| 29 |
+
"import os"
|
| 30 |
+
]
|
| 31 |
},
|
| 32 |
{
|
| 33 |
"cell_type": "code",
|
| 34 |
+
"execution_count": 2,
|
| 35 |
"id": "d8d7978e-980a-400c-8c6a-5365990c8855",
|
| 36 |
"metadata": {
|
| 37 |
"ExecuteTime": {
|
| 38 |
+
"end_time": "2026-01-24T09:06:37.011231Z",
|
| 39 |
+
"start_time": "2026-01-24T09:06:37.005099Z"
|
| 40 |
}
|
| 41 |
},
|
| 42 |
+
"outputs": [],
|
| 43 |
"source": [
|
| 44 |
"PDBBIND_PATH = \"refined-set\""
|
| 45 |
+
]
|
|
|
|
|
|
|
| 46 |
},
|
| 47 |
{
|
| 48 |
"cell_type": "code",
|
| 49 |
+
"execution_count": 3,
|
| 50 |
"id": "788a6b43-c515-45c7-bc52-341d446b1a65",
|
| 51 |
"metadata": {
|
| 52 |
"ExecuteTime": {
|
| 53 |
+
"end_time": "2026-01-24T09:06:37.022991Z",
|
| 54 |
+
"start_time": "2026-01-24T09:06:37.016849Z"
|
| 55 |
}
|
| 56 |
},
|
| 57 |
+
"outputs": [],
|
| 58 |
"source": [
|
| 59 |
"EXAMPLE_PDB_ID = \"1a1e\""
|
| 60 |
+
]
|
|
|
|
|
|
|
| 61 |
},
|
| 62 |
{
|
| 63 |
"cell_type": "code",
|
| 64 |
+
"execution_count": 4,
|
| 65 |
"id": "e8f4bebc-845f-43e8-bc4d-ab7b649eb49c",
|
| 66 |
"metadata": {
|
| 67 |
"ExecuteTime": {
|
| 68 |
+
"end_time": "2026-01-24T09:06:37.041322Z",
|
| 69 |
+
"start_time": "2026-01-24T09:06:37.035944Z"
|
| 70 |
}
|
| 71 |
},
|
| 72 |
+
"outputs": [],
|
| 73 |
"source": [
|
| 74 |
"pdb_dir = os.path.join(PDBBIND_PATH, EXAMPLE_PDB_ID)"
|
| 75 |
+
]
|
|
|
|
|
|
|
| 76 |
},
|
| 77 |
{
|
| 78 |
"cell_type": "code",
|
| 79 |
+
"execution_count": 5,
|
| 80 |
"id": "24b5e435-4d8f-4505-b27c-dd6317376ed4",
|
| 81 |
"metadata": {
|
| 82 |
"ExecuteTime": {
|
| 83 |
+
"end_time": "2026-01-24T09:06:37.064924Z",
|
| 84 |
+
"start_time": "2026-01-24T09:06:37.059278Z"
|
| 85 |
}
|
| 86 |
},
|
| 87 |
+
"outputs": [],
|
| 88 |
"source": [
|
| 89 |
"protein_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_protein.pdb\")"
|
| 90 |
+
]
|
|
|
|
|
|
|
| 91 |
},
|
| 92 |
{
|
| 93 |
"cell_type": "code",
|
| 94 |
+
"execution_count": 6,
|
| 95 |
"id": "e7fc3539-00c0-48a2-b012-c80757fa12c4",
|
| 96 |
"metadata": {
|
| 97 |
"ExecuteTime": {
|
| 98 |
+
"end_time": "2026-01-24T09:06:37.080165Z",
|
| 99 |
+
"start_time": "2026-01-24T09:06:37.074657Z"
|
| 100 |
}
|
| 101 |
},
|
| 102 |
+
"outputs": [],
|
| 103 |
"source": [
|
| 104 |
"ligand_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_ligand.sdf\")"
|
| 105 |
+
]
|
|
|
|
|
|
|
| 106 |
},
|
| 107 |
{
|
| 108 |
"cell_type": "code",
|
| 109 |
+
"execution_count": 7,
|
| 110 |
"id": "9a053b99-7c01-4881-b3f7-e9b39090af9d",
|
| 111 |
"metadata": {
|
| 112 |
"ExecuteTime": {
|
| 113 |
+
"end_time": "2026-01-24T09:06:37.126934Z",
|
| 114 |
+
"start_time": "2026-01-24T09:06:37.107047Z"
|
| 115 |
}
|
| 116 |
},
|
| 117 |
+
"outputs": [],
|
| 118 |
"source": [
|
| 119 |
"view = nv.NGLWidget()"
|
| 120 |
+
]
|
|
|
|
|
|
|
| 121 |
},
|
| 122 |
{
|
| 123 |
"cell_type": "code",
|
| 124 |
+
"execution_count": 8,
|
| 125 |
"id": "df8c8e00-3ce6-41dd-b457-d9f50e318dad",
|
| 126 |
"metadata": {
|
| 127 |
"ExecuteTime": {
|
| 128 |
+
"end_time": "2026-01-24T09:06:37.209871Z",
|
| 129 |
+
"start_time": "2026-01-24T09:06:37.140785Z"
|
| 130 |
}
|
| 131 |
},
|
| 132 |
+
"outputs": [],
|
| 133 |
"source": [
|
| 134 |
"protein_comp = view.add_component(protein_file)"
|
| 135 |
+
]
|
|
|
|
|
|
|
| 136 |
},
|
| 137 |
{
|
| 138 |
"cell_type": "code",
|
| 139 |
+
"execution_count": 9,
|
| 140 |
"id": "c191fead-fef8-4077-b787-5bf9552307b1",
|
| 141 |
"metadata": {
|
| 142 |
"ExecuteTime": {
|
| 143 |
+
"end_time": "2026-01-24T09:06:37.243271Z",
|
| 144 |
+
"start_time": "2026-01-24T09:06:37.235380Z"
|
| 145 |
}
|
| 146 |
},
|
| 147 |
+
"outputs": [],
|
| 148 |
"source": [
|
| 149 |
"protein_comp.clear_representations()"
|
| 150 |
+
]
|
|
|
|
|
|
|
| 151 |
},
|
| 152 |
{
|
| 153 |
"cell_type": "code",
|
| 154 |
+
"execution_count": 10,
|
| 155 |
"id": "4559033a-aeda-4659-8d91-9002b5a6ecda",
|
| 156 |
"metadata": {
|
| 157 |
"ExecuteTime": {
|
| 158 |
+
"end_time": "2026-01-24T09:06:37.276519Z",
|
| 159 |
+
"start_time": "2026-01-24T09:06:37.270030Z"
|
| 160 |
}
|
| 161 |
},
|
|
|
|
|
|
|
|
|
|
| 162 |
"outputs": [],
|
| 163 |
+
"source": [
|
| 164 |
+
"protein_comp.add_representation(\"cartoon\", color=\"blue\")"
|
| 165 |
+
]
|
| 166 |
},
|
| 167 |
{
|
| 168 |
"cell_type": "code",
|
| 169 |
+
"execution_count": 11,
|
| 170 |
"id": "73ea1a50-8463-40b8-a942-0c92d3e97a97",
|
| 171 |
"metadata": {
|
| 172 |
"ExecuteTime": {
|
| 173 |
+
"end_time": "2026-01-24T09:06:37.309460Z",
|
| 174 |
+
"start_time": "2026-01-24T09:06:37.299153Z"
|
| 175 |
}
|
| 176 |
},
|
| 177 |
+
"outputs": [],
|
| 178 |
"source": [
|
| 179 |
"ligand_comp = view.add_component(ligand_file)"
|
| 180 |
+
]
|
|
|
|
|
|
|
| 181 |
},
|
| 182 |
{
|
| 183 |
"cell_type": "code",
|
| 184 |
+
"execution_count": 12,
|
| 185 |
"id": "16cdb710-1ed6-4b1d-9e6a-69b7ad61a600",
|
| 186 |
"metadata": {
|
| 187 |
"ExecuteTime": {
|
| 188 |
+
"end_time": "2026-01-24T09:06:37.340286Z",
|
| 189 |
+
"start_time": "2026-01-24T09:06:37.333802Z"
|
| 190 |
}
|
| 191 |
},
|
| 192 |
+
"outputs": [],
|
| 193 |
"source": [
|
| 194 |
"ligand_comp.clear_representations()"
|
| 195 |
+
]
|
|
|
|
|
|
|
| 196 |
},
|
| 197 |
{
|
| 198 |
"cell_type": "code",
|
| 199 |
+
"execution_count": 13,
|
| 200 |
"id": "2193c497-f33c-4de0-86a9-6e535002fcb7",
|
| 201 |
"metadata": {
|
| 202 |
"ExecuteTime": {
|
| 203 |
+
"end_time": "2026-01-24T09:06:37.372239Z",
|
| 204 |
+
"start_time": "2026-01-24T09:06:37.365156Z"
|
| 205 |
}
|
| 206 |
},
|
|
|
|
|
|
|
|
|
|
| 207 |
"outputs": [],
|
| 208 |
+
"source": [
|
| 209 |
+
"ligand_comp.add_representation(\"ball+stick\", radius=0.3)"
|
| 210 |
+
]
|
| 211 |
},
|
| 212 |
{
|
| 213 |
"cell_type": "code",
|
| 214 |
+
"execution_count": 14,
|
| 215 |
"id": "b1cc7f44-a374-4400-b4ba-8f75101b21ce",
|
| 216 |
"metadata": {
|
| 217 |
"ExecuteTime": {
|
| 218 |
+
"end_time": "2026-01-24T09:06:37.406445Z",
|
| 219 |
+
"start_time": "2026-01-24T09:06:37.398945Z"
|
| 220 |
}
|
| 221 |
},
|
|
|
|
|
|
|
|
|
|
| 222 |
"outputs": [
|
| 223 |
{
|
| 224 |
"data": {
|
|
|
|
|
|
|
|
|
|
| 225 |
"application/vnd.jupyter.widget-view+json": {
|
| 226 |
+
"model_id": "11e403e6733946b9b6942f47bff2464e",
|
| 227 |
"version_major": 2,
|
| 228 |
+
"version_minor": 0
|
| 229 |
+
},
|
| 230 |
+
"text/plain": [
|
| 231 |
+
"NGLWidget()"
|
| 232 |
+
]
|
| 233 |
},
|
| 234 |
"metadata": {},
|
| 235 |
+
"output_type": "display_data"
|
|
|
|
|
|
|
|
|
|
| 236 |
}
|
| 237 |
],
|
| 238 |
+
"source": [
|
| 239 |
+
"view"
|
| 240 |
+
]
|
| 241 |
},
|
| 242 |
{
|
| 243 |
"cell_type": "code",
|
| 244 |
+
"execution_count": null,
|
| 245 |
"id": "5655e465-bb44-4218-a5e3-db2c5e62cd9c",
|
| 246 |
"metadata": {
|
| 247 |
"ExecuteTime": {
|
| 248 |
+
"end_time": "2026-01-24T09:06:37.420258Z",
|
| 249 |
+
"start_time": "2026-01-24T09:06:37.416018Z"
|
| 250 |
}
|
| 251 |
},
|
|
|
|
| 252 |
"outputs": [],
|
| 253 |
+
"source": []
|
| 254 |
}
|
| 255 |
],
|
| 256 |
"metadata": {
|