AlexSychovUN committed on
Commit
da7c0f0
·
0 Parent(s):

Added files

Browse files
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .idea
2
+ .ipynb_checkpoints
GNN_classification/Dataset_Preparation.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import pandas as pd
3
+
4
+ from rdkit import Chem
5
+ from torch_geometric.data import Data
6
+ from torch.utils.data import Dataset
7
+
8
+
9
+ class SmilesDataset(Dataset):
10
+ def __init__(self, dataframe):
11
+ self.data = dataframe
12
+
13
+ def __len__(self):
14
+ return len(self.data)
15
+
16
+ def __getitem__(self, idx):
17
+ row = self.data.iloc[idx]
18
+ smiles = row["smiles"]
19
+ label = row["label"]
20
+
21
+ mol = Chem.MolFromSmiles(smiles)
22
+ if mol is None: return None
23
+
24
+ # Nodes
25
+ atom_features = [[atom.GetAtomicNum()] for atom in mol.GetAtoms()]
26
+ x = torch.tensor(atom_features, dtype=torch.float)
27
+
28
+ # Edges
29
+ edge_indexes = []
30
+ for bond in mol.GetBonds():
31
+ i = bond.GetBeginAtomIdx()
32
+ j = bond.GetEndAtomIdx()
33
+ edge_indexes.append((i, j))
34
+ edge_indexes.append((j, i))
35
+
36
+ # t - transpose, [num_of_edges, 2] -> [2, num_of_edges]
37
+ # contiguous - take the virtually transposed tensor and make its physical copy and lay bytes sequentially
38
+ if not edge_indexes:
39
+ edge_index = torch.empty((2, 0), dtype=torch.long)
40
+ else:
41
+ edge_index = torch.tensor(edge_indexes, dtype=torch.long).t().contiguous()
42
+
43
+
44
+ # Label
45
+ y = torch.tensor([label], dtype=torch.long)
46
+ return Data(x=x, edge_index=edge_index, y=y)
47
+
48
+
49
+ if __name__ == "__main__":
50
+ columns = ["smiles", "label"]
51
+ train_dataset = pd.read_csv(
52
+ "dataset/classification/data_train.txt", sep=" ", header=None, names=columns
53
+ )
54
+ test_dataset = pd.read_csv(
55
+ "dataset/classification/data_test.txt", sep=" ", header=None, names=columns
56
+ )
57
+
58
+ train_dataset = SmilesDataset(train_dataset)
59
+ test_dataset = SmilesDataset(test_dataset)
60
+
61
+ print(len(train_dataset))
62
+ print(len(test_dataset))
63
+
64
+
GNN_classification/dataset/classification/data_test.txt ADDED
The diff for this file is too large to render. See raw diff
 
GNN_classification/dataset/classification/data_train.txt ADDED
The diff for this file is too large to render. See raw diff
 
GNN_classification/model.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pandas as pd
5
+ from rdkit import Chem
6
+
7
+ from torch_geometric.nn import GCNConv, global_mean_pool
8
+ from torch_geometric.data import Data
9
+ from torch_geometric.loader import DataLoader
10
+ from torch.utils.data import Dataset
11
+
12
+ class GNNClassifier(nn.Module):
13
+ def __init__(self, input_dim, output_dim, hidden_channels):
14
+ super().__init__()
15
+ self.hidden_channels = hidden_channels
16
+
17
+ self.conv1 = GCNConv(input_dim, hidden_channels)
18
+ self.conv2 = GCNConv(hidden_channels, hidden_channels)
19
+ self.conv3 = GCNConv(hidden_channels, hidden_channels)
20
+
21
+ self.lin = nn.Linear(hidden_channels, output_dim) # classification task 0 or 1
22
+
23
+ def forward(self, x, edge_index, batch):
24
+ x = self.conv1(x, edge_index)
25
+ x = x.relu()
26
+ x = self.conv2(x, edge_index)
27
+ x = x.relu()
28
+ x = self.conv3(x, edge_index)
29
+
30
+ # Averaging nodes and got the molecula vector
31
+ x = global_mean_pool(x, batch) # [batch_size, hidden_channels]
32
+
33
+ x = F.dropout(x, p=0.5, training=self.training)
34
+ x = self.lin(x)
35
+ return x
GNN_classification/training.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import pandas as pd
5
+ from rdkit import Chem
6
+
7
+ from torch_geometric.loader import DataLoader
8
+
9
+ from Dataset_Preparation import SmilesDataset
10
+ from model import GNNClassifier
11
+
12
+ DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
13
+ print(DEVICE)
14
+
15
+ def train(model, loader, optimizer, criterion):
16
+ model.train()
17
+ total_loss = 0
18
+
19
+ for batch in loader:
20
+ batch = batch.to(DEVICE)
21
+
22
+ optimizer.zero_grad()
23
+
24
+ out = model(batch.x, batch.edge_index, batch.batch)
25
+
26
+ loss = criterion(out, batch.y)
27
+ loss.backward()
28
+ optimizer.step()
29
+
30
+ total_loss += loss.item()
31
+
32
+ return total_loss / len(loader)
33
+
34
+
35
+ def test(model, loader):
36
+ model.eval()
37
+ correct = 0
38
+
39
+ with torch.no_grad():
40
+ for batch in loader:
41
+ batch = batch.to(DEVICE)
42
+ out = model(batch.x, batch.edge_index, batch.batch)
43
+
44
+ pred = out.argmax(dim=1)
45
+
46
+ correct += (pred == batch.y).sum().item()
47
+
48
+ acc = correct / len(loader.dataset)
49
+ return acc
50
+
51
+
52
+ if __name__ == "__main__":
53
+ columns = ["smiles", "label"]
54
+ train_dataset = pd.read_csv(
55
+ "dataset/classification/data_train.txt", sep=" ", header=None, names=columns
56
+ )
57
+ test_dataset = pd.read_csv(
58
+ "dataset/classification/data_test.txt", sep=" ", header=None, names=columns
59
+ )
60
+
61
+ train_dataset = SmilesDataset(train_dataset)
62
+ test_dataset = SmilesDataset(test_dataset)
63
+
64
+ num_node_features = train_dataset[0].x.shape[1]
65
+ num_classes = 2
66
+
67
+ print(f"Train samples: {len(train_dataset)}")
68
+ print(f"Test samples: {len(test_dataset)}")
69
+ print(f"Node features: {num_node_features}")
70
+
71
+ train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
72
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
73
+
74
+ model = GNNClassifier(input_dim=1, output_dim=2, hidden_channels=16).to(DEVICE)
75
+
76
+ optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
77
+ criterion = torch.nn.CrossEntropyLoss()
78
+
79
+ EPOCHS = 20
80
+ print("Start Training")
81
+
82
+ for epoch in range(1, EPOCHS + 1):
83
+ train_loss = train_epoch(model, train_loader, optimizer, criterion)
84
+
85
+ train_acc = evaluate(model, train_loader)
86
+ print(f"Epoch: {epoch}, Loss: {train_loss}, Train Accuracy: {train_acc}")
GNNs__practice.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
dataset_preparation.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from Bio.PDB import PDBParser
4
+ from Bio.SeqUtils import seq1
5
+ from Bio.PDB.Polypeptide import is_aa
6
+ from rdkit import Chem
7
+ from tqdm import tqdm
8
+
9
+ PDBBIND_PATH = "refined-set"
10
+ INDEX_NAME = "INDEX_refined_data.2020"
11
+
12
+
13
+ def get_ligand_smiles(pdb_id, pdb_dir_path):
14
+ """
15
+ Get the SMILES representation of the ligand.
16
+ """
17
+
18
+ sdf_path = os.path.join(pdb_dir_path, f"{pdb_id}_ligand.sdf")
19
+ mol2_path = os.path.join(pdb_dir_path, f"{pdb_id}_ligand.mol2")
20
+ if os.path.exists(sdf_path):
21
+ try:
22
+ sfd_file = Chem.SDMolSupplier(sdf_path)
23
+ if sfd_file:
24
+ mol = sfd_file[0]
25
+ except Exception:
26
+ mol = None
27
+
28
+ if mol is None and os.path.exists(mol2_path):
29
+ try:
30
+ mol = Chem.MolFromMol2File(mol2_path)
31
+ except Exception:
32
+ mol = None
33
+ if mol is not None:
34
+ smiles = Chem.MolToSmiles(mol)
35
+ return smiles
36
+ else:
37
+ return None
38
+
39
+
40
+ def get_protein_sequence(pdb_id, pdb_dir_path):
41
+ """
42
+ Get the protein sequence of the protein.
43
+ """
44
+ protein_path = os.path.join(pdb_dir_path, f"{pdb_id}_protein.pdb")
45
+ pdbparser = PDBParser()
46
+ structure = pdbparser.get_structure(pdb_id, protein_path)
47
+ sequences = []
48
+
49
+ for model in structure:
50
+ for chain in model:
51
+ sequence = ""
52
+ for residue in chain:
53
+ if residue.get_id()[0] == " " and is_aa(
54
+ residue.get_resname(), standard=True
55
+ ):
56
+ sequence += seq1(residue.get_resname())
57
+
58
+ sequences.append(sequence)
59
+ longest_sequence = max(sequences, key=len)
60
+ return longest_sequence
61
+
62
+
63
+ def main():
64
+ final_data = []
65
+
66
+ index_data = {}
67
+
68
+ index_file_path = os.path.join(PDBBIND_PATH, "index", INDEX_NAME)
69
+ with open(index_file_path, "r") as f:
70
+ for line in f:
71
+ if line.startswith("#"):
72
+ continue
73
+ parts = line.split()
74
+ pdb_id = parts[0]
75
+ print(pdb_id)
76
+ affinity = parts[3]
77
+
78
+ index_data[pdb_id] = affinity
79
+ print(f"Loaded index data for {len(index_data)} entries")
80
+
81
+ for pdb_id, affinity in tqdm(index_data.items()):
82
+ pdb_id_path = os.path.join(PDBBIND_PATH, pdb_id)
83
+
84
+ smiles = get_ligand_smiles(pdb_id, pdb_id_path)
85
+ sequence = get_protein_sequence(pdb_id, pdb_id_path)
86
+ if smiles is not None or sequence is not None:
87
+ final_data.append(
88
+ {
89
+ "pdb_id": pdb_id,
90
+ "smiles": smiles,
91
+ "sequence": sequence,
92
+ "affinity": affinity,
93
+ }
94
+ )
95
+
96
+ df = pd.DataFrame(final_data)
97
+ df.to_csv("pdbbind_refined_dataset.csv", index=False)
98
+
99
+
100
+ # pdb_id = "1a1e"
101
+ # PDF_ID_PATH = os.path.join(PDBBIND_PATH, pdb_id)
102
+ #
103
+ # smiles = get_ligand_smiles(pdb_id, PDF_ID_PATH)
104
+ # print(smiles)
105
+ #
106
+ # sequence = get_protein_sequence(pdb_id, PDF_ID_PATH)
107
+ # print(sequence)
108
+
109
+ if __name__ == "__main__":
110
+ main()
pdbbind_refined_dataset.csv ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ pandas
2
+ rdkit
3
+ biopython
4
+ torch
transformer_from_scratch/model.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class InputEmbeddings(nn.Module):
8
+ def __init__(self, d_model: int, vocab_size: int):
9
+ super().__init__()
10
+ self.d_model = d_model
11
+ self.vocab_size = vocab_size
12
+ self.embedding = nn.Embedding(vocab_size, d_model) # vocab_size -> 512
13
+
14
+ def forward(self, x):
15
+ return self.embedding(x) * math.sqrt(self.d_model)
16
+
17
+
18
+ class PositionalEncoding(nn.Module):
19
+ def __init__(self, d_model: int, seq_len: int, dropout: float):
20
+ super().__init__()
21
+ self.d_model = d_model
22
+ self.seq_len = seq_len
23
+ self.dropout = nn.Dropout(dropout)
24
+
25
+ # Create a matrix of shape (seq_len, d_model)
26
+ pe = torch.zeros(seq_len, d_model)
27
+
28
+ # Create a vector of shape (seq_len, 1)
29
+ position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(
30
+ 1
31
+ ) # (Seq_len, 1)
32
+ # Compute the positional encodings once in log space.
33
+ div_term = torch.exp(
34
+ torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
35
+ )
36
+ # Apply the sin to even positions
37
+ pe[:, 0::2] = torch.sin(position * div_term)
38
+ # Apply the cos to odd positions
39
+ pe[:, 1::2] = torch.cos(position * div_term)
40
+
41
+ pe = pe.unsqueeze(0) # (1, Seq_len, d_model) batch dimension
42
+ self.register_buffer("pe", pe)
43
+
44
+ def forward(self, x):
45
+ x = x + (self.pe[:, : x.shape[1], :]).requires_grad_(False)
46
+ return self.dropout(x)
47
+
48
+
49
+ class LayerNormalization(nn.Module):
50
+ def __init__(self, eps: float = 10e-6) -> None:
51
+ super().__init__()
52
+ self.eps = eps # avoid division by zero and huge numbers
53
+ self.alpha = nn.Parameter(torch.ones(1)) # Multiplied
54
+ self.bias = nn.Parameter(torch.zeros(1)) # Added
55
+
56
+ def forward(self, x):
57
+ mean = x.mean(dim=-1, keepdim=True) # To every sample
58
+ std = x.std(dim=-1, keepdim=True)
59
+ return self.alpha * (x - mean) / (std + self.eps) + self.bias
60
+
61
+
62
+ class FeedForwardBlock(nn.Module):
63
+ def __init__(self, d_model: int, d_ff: int, dropout: float):
64
+ super().__init__()
65
+ self.linear1 = nn.Linear(d_model, d_ff)
visualization.ipynb ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "initial_id",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "data": {
11
+ "application/vnd.jupyter.widget-view+json": {
12
+ "model_id": "ccfa267dcd6945b6be10a9cbeffb4e5e",
13
+ "version_major": 2,
14
+ "version_minor": 0
15
+ },
16
+ "text/plain": []
17
+ },
18
+ "metadata": {},
19
+ "output_type": "display_data"
20
+ }
21
+ ],
22
+ "source": [
23
+ "import nglview as nv\n",
24
+ "import os"
25
+ ]
26
+ },
27
+ {
28
+ "cell_type": "code",
29
+ "execution_count": 2,
30
+ "id": "d8d7978e-980a-400c-8c6a-5365990c8855",
31
+ "metadata": {},
32
+ "outputs": [],
33
+ "source": [
34
+ "PDBBIND_PATH = \"refined-set\""
35
+ ]
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "execution_count": 3,
40
+ "id": "788a6b43-c515-45c7-bc52-341d446b1a65",
41
+ "metadata": {},
42
+ "outputs": [],
43
+ "source": [
44
+ "EXAMPLE_PDB_ID = \"1a1e\""
45
+ ]
46
+ },
47
+ {
48
+ "cell_type": "code",
49
+ "execution_count": 4,
50
+ "id": "e8f4bebc-845f-43e8-bc4d-ab7b649eb49c",
51
+ "metadata": {},
52
+ "outputs": [],
53
+ "source": [
54
+ "pdb_dir = os.path.join(PDBBIND_PATH, EXAMPLE_PDB_ID)"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 5,
60
+ "id": "24b5e435-4d8f-4505-b27c-dd6317376ed4",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "protein_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_protein.pdb\")"
65
+ ]
66
+ },
67
+ {
68
+ "cell_type": "code",
69
+ "execution_count": 6,
70
+ "id": "e7fc3539-00c0-48a2-b012-c80757fa12c4",
71
+ "metadata": {},
72
+ "outputs": [],
73
+ "source": [
74
+ "ligand_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_ligand.sdf\")"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": 7,
80
+ "id": "9a053b99-7c01-4881-b3f7-e9b39090af9d",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "view = nv.NGLWidget()"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": 8,
90
+ "id": "df8c8e00-3ce6-41dd-b457-d9f50e318dad",
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "protein_comp = view.add_component(protein_file)"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": 9,
100
+ "id": "c191fead-fef8-4077-b787-5bf9552307b1",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "protein_comp.clear_representations()"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 10,
110
+ "id": "4559033a-aeda-4659-8d91-9002b5a6ecda",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "protein_comp.add_representation('cartoon', color='blue')"
115
+ ]
116
+ },
117
+ {
118
+ "cell_type": "code",
119
+ "execution_count": 11,
120
+ "id": "73ea1a50-8463-40b8-a942-0c92d3e97a97",
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": [
124
+ "ligand_comp = view.add_component(ligand_file)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": 12,
130
+ "id": "16cdb710-1ed6-4b1d-9e6a-69b7ad61a600",
131
+ "metadata": {},
132
+ "outputs": [],
133
+ "source": [
134
+ "ligand_comp.clear_representations()"
135
+ ]
136
+ },
137
+ {
138
+ "cell_type": "code",
139
+ "execution_count": 13,
140
+ "id": "2193c497-f33c-4de0-86a9-6e535002fcb7",
141
+ "metadata": {},
142
+ "outputs": [],
143
+ "source": [
144
+ "ligand_comp.add_representation('ball+stick', radius=0.3)"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": 14,
150
+ "id": "b1cc7f44-a374-4400-b4ba-8f75101b21ce",
151
+ "metadata": {},
152
+ "outputs": [
153
+ {
154
+ "data": {
155
+ "application/vnd.jupyter.widget-view+json": {
156
+ "model_id": "6037e0edee3247a49cd586e52e64a61b",
157
+ "version_major": 2,
158
+ "version_minor": 0
159
+ },
160
+ "text/plain": [
161
+ "NGLWidget()"
162
+ ]
163
+ },
164
+ "metadata": {},
165
+ "output_type": "display_data"
166
+ }
167
+ ],
168
+ "source": [
169
+ "view"
170
+ ]
171
+ },
172
+ {
173
+ "cell_type": "code",
174
+ "execution_count": null,
175
+ "id": "5655e465-bb44-4218-a5e3-db2c5e62cd9c",
176
+ "metadata": {},
177
+ "outputs": [],
178
+ "source": []
179
+ }
180
+ ],
181
+ "metadata": {
182
+ "kernelspec": {
183
+ "display_name": "Python 3 (ipykernel)",
184
+ "language": "python",
185
+ "name": "python3"
186
+ },
187
+ "language_info": {
188
+ "codemirror_mode": {
189
+ "name": "ipython",
190
+ "version": 3
191
+ },
192
+ "file_extension": ".py",
193
+ "mimetype": "text/x-python",
194
+ "name": "python",
195
+ "nbconvert_exporter": "python",
196
+ "pygments_lexer": "ipython3",
197
+ "version": "3.12.4"
198
+ }
199
+ },
200
+ "nbformat": 4,
201
+ "nbformat_minor": 5
202
+ }