Commit ·
da7c0f0
0
Parent(s):
Added files
Browse files- .gitignore +2 -0
- GNN_classification/Dataset_Preparation.py +64 -0
- GNN_classification/dataset/classification/data_test.txt +0 -0
- GNN_classification/dataset/classification/data_train.txt +0 -0
- GNN_classification/model.py +35 -0
- GNN_classification/training.py +86 -0
- GNNs__practice.ipynb +0 -0
- dataset_preparation.py +110 -0
- pdbbind_refined_dataset.csv +0 -0
- requirements.txt +4 -0
- transformer_from_scratch/model.py +65 -0
- visualization.ipynb +202 -0
.gitignore
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.idea
|
| 2 |
+
.ipynb_checkpoints
|
GNN_classification/Dataset_Preparation.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import pandas as pd
|
| 3 |
+
|
| 4 |
+
from rdkit import Chem
|
| 5 |
+
from torch_geometric.data import Data
|
| 6 |
+
from torch.utils.data import Dataset
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class SmilesDataset(Dataset):
    """Dataset that converts SMILES strings into PyG graph objects.

    Each item is a torch_geometric ``Data`` graph whose node features are
    the atoms' atomic numbers and whose undirected bonds are encoded as
    two directed edges.
    """

    def __init__(self, dataframe):
        # Expects a DataFrame with "smiles" and "label" columns.
        self.data = dataframe

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        smiles_string = record["smiles"]
        target = record["label"]

        mol = Chem.MolFromSmiles(smiles_string)
        if mol is None:
            # Unparseable SMILES — callers must be prepared to skip None items.
            return None

        # Node features: one value per atom (its atomic number).
        x = torch.tensor(
            [[atom.GetAtomicNum()] for atom in mol.GetAtoms()],
            dtype=torch.float,
        )

        # Edge list: every bond contributes both directions.
        pairs = []
        for bond in mol.GetBonds():
            a = bond.GetBeginAtomIdx()
            b = bond.GetEndAtomIdx()
            pairs.extend([(a, b), (b, a)])

        if pairs:
            # t(): transpose [num_edges, 2] -> [2, num_edges];
            # contiguous(): materialize the transposed layout in memory.
            edge_index = torch.tensor(pairs, dtype=torch.long).t().contiguous()
        else:
            edge_index = torch.empty((2, 0), dtype=torch.long)

        # Graph-level label.
        y = torch.tensor([target], dtype=torch.long)
        return Data(x=x, edge_index=edge_index, y=y)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
    # Space-separated files with no header line: "<smiles> <label>" per row.
    column_names = ["smiles", "label"]
    train_frame = pd.read_csv(
        "dataset/classification/data_train.txt",
        sep=" ",
        header=None,
        names=column_names,
    )
    test_frame = pd.read_csv(
        "dataset/classification/data_test.txt",
        sep=" ",
        header=None,
        names=column_names,
    )

    train_dataset = SmilesDataset(train_frame)
    test_dataset = SmilesDataset(test_frame)

    print(len(train_dataset))
    print(len(test_dataset))
|
| 63 |
+
|
| 64 |
+
|
GNN_classification/dataset/classification/data_test.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GNN_classification/dataset/classification/data_train.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
GNN_classification/model.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
|
| 7 |
+
from torch_geometric.nn import GCNConv, global_mean_pool
|
| 8 |
+
from torch_geometric.data import Data
|
| 9 |
+
from torch_geometric.loader import DataLoader
|
| 10 |
+
from torch.utils.data import Dataset
|
| 11 |
+
|
| 12 |
+
class GNNClassifier(nn.Module):
    """Three-layer GCN graph classifier.

    Node embeddings from stacked GCNConv layers are mean-pooled into one
    vector per graph, regularized with dropout, and mapped to class
    logits by a linear head.
    """

    def __init__(self, input_dim, output_dim, hidden_channels):
        super().__init__()
        self.hidden_channels = hidden_channels

        self.conv1 = GCNConv(input_dim, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        # Classification head (e.g. output_dim=2 for a binary task).
        self.lin = nn.Linear(hidden_channels, output_dim)

    def forward(self, x, edge_index, batch):
        # Message passing; ReLU after the first two conv layers only.
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = self.conv3(x, edge_index)

        # Collapse node embeddings into one vector per graph:
        # [num_nodes, hidden_channels] -> [batch_size, hidden_channels].
        x = global_mean_pool(x, batch)

        x = F.dropout(x, p=0.5, training=self.training)
        return self.lin(x)
|
GNN_classification/training.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from rdkit import Chem
|
| 6 |
+
|
| 7 |
+
from torch_geometric.loader import DataLoader
|
| 8 |
+
|
| 9 |
+
from Dataset_Preparation import SmilesDataset
|
| 10 |
+
from model import GNNClassifier
|
| 11 |
+
|
| 12 |
+
# Prefer the GPU when CUDA is available; otherwise fall back to CPU.
_use_cuda = torch.cuda.is_available()
DEVICE = torch.device('cuda' if _use_cuda else 'cpu')
print(DEVICE)
|
| 14 |
+
|
| 15 |
+
def train(model, loader, optimizer, criterion):
    """Run one optimization epoch and return the mean per-batch loss."""
    model.train()
    running_loss = 0.0

    for graphs in loader:
        graphs = graphs.to(DEVICE)

        optimizer.zero_grad()
        logits = model(graphs.x, graphs.edge_index, graphs.batch)

        batch_loss = criterion(logits, graphs.y)
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(loader)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def test(model, loader):
    """Return classification accuracy of `model` over all of `loader`."""
    model.eval()
    num_correct = 0

    with torch.no_grad():
        for graphs in loader:
            graphs = graphs.to(DEVICE)
            logits = model(graphs.x, graphs.edge_index, graphs.batch)

            # Predicted class = index of the max logit.
            predictions = logits.argmax(dim=1)
            num_correct += (predictions == graphs.y).sum().item()

    return num_correct / len(loader.dataset)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
    # Space-separated files with no header line: "<smiles> <label>" per row.
    columns = ["smiles", "label"]
    train_dataset = pd.read_csv(
        "dataset/classification/data_train.txt", sep=" ", header=None, names=columns
    )
    test_dataset = pd.read_csv(
        "dataset/classification/data_test.txt", sep=" ", header=None, names=columns
    )

    train_dataset = SmilesDataset(train_dataset)
    test_dataset = SmilesDataset(test_dataset)

    # Node feature dimensionality taken from the first sample's graph.
    num_node_features = train_dataset[0].x.shape[1]
    num_classes = 2

    print(f"Train samples: {len(train_dataset)}")
    print(f"Test samples: {len(test_dataset)}")
    print(f"Node features: {num_node_features}")

    train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
    # NOTE(review): shuffling an evaluation loader is unnecessary (accuracy is
    # order-independent); consider shuffle=False here.
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)

    model = GNNClassifier(input_dim=1, output_dim=2, hidden_channels=16).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()

    EPOCHS = 20
    print("Start Training")

    for epoch in range(1, EPOCHS + 1):
        # BUG FIX: this module defines `train` and `test` above; the previous
        # calls to `train_epoch` / `evaluate` were undefined names and raised
        # NameError on the first epoch.
        train_loss = train(model, train_loader, optimizer, criterion)
        train_acc = test(model, train_loader)
        print(f"Epoch: {epoch}, Loss: {train_loss}, Train Accuracy: {train_acc}")
|
GNNs__practice.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
dataset_preparation.py
ADDED
|
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from Bio.PDB import PDBParser
|
| 4 |
+
from Bio.SeqUtils import seq1
|
| 5 |
+
from Bio.PDB.Polypeptide import is_aa
|
| 6 |
+
from rdkit import Chem
|
| 7 |
+
from tqdm import tqdm
|
| 8 |
+
|
| 9 |
+
# Root directory of the PDBbind refined set (relative to the working dir).
PDBBIND_PATH = "refined-set"
# Index file inside <PDBBIND_PATH>/index listing entries and affinities.
INDEX_NAME = "INDEX_refined_data.2020"
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_ligand_smiles(pdb_id, pdb_dir_path):
    """Return the canonical SMILES of a ligand, or None if it cannot be read.

    Tries the SDF file first, then falls back to the MOL2 file.

    Args:
        pdb_id: PDBbind entry identifier (e.g. "1a1e").
        pdb_dir_path: Directory containing "<pdb_id>_ligand.sdf" / ".mol2".
    """
    # BUG FIX: `mol` must be initialized up front. Previously it was only
    # assigned inside the conditional branches, so a missing SDF file made the
    # `mol is None` check below raise UnboundLocalError.
    mol = None

    sdf_path = os.path.join(pdb_dir_path, f"{pdb_id}_ligand.sdf")
    mol2_path = os.path.join(pdb_dir_path, f"{pdb_id}_ligand.mol2")

    if os.path.exists(sdf_path):
        try:
            supplier = Chem.SDMolSupplier(sdf_path)  # was misspelled `sfd_file`
            if supplier:
                mol = supplier[0]
        except Exception:
            mol = None

    # Fall back to the MOL2 file when the SDF could not be parsed.
    if mol is None and os.path.exists(mol2_path):
        try:
            mol = Chem.MolFromMol2File(mol2_path)
        except Exception:
            mol = None

    if mol is not None:
        return Chem.MolToSmiles(mol)
    return None
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_protein_sequence(pdb_id, pdb_dir_path):
    """Return the longest chain's sequence as one-letter amino-acid codes.

    Parses "<pdb_id>_protein.pdb", collects the standard amino-acid residues
    of every chain, and returns the longest chain's sequence. Returns None
    when the structure contains no chains.
    """
    protein_path = os.path.join(pdb_dir_path, f"{pdb_id}_protein.pdb")
    pdbparser = PDBParser()
    structure = pdbparser.get_structure(pdb_id, protein_path)
    sequences = []

    for model in structure:
        for chain in model:
            sequence = ""
            for residue in chain:
                # Hetero/water records carry a non-blank hetero flag in
                # residue id[0]; keep only standard amino acids.
                if residue.get_id()[0] == " " and is_aa(
                    residue.get_resname(), standard=True
                ):
                    sequence += seq1(residue.get_resname())

            sequences.append(sequence)

    # ROBUSTNESS: max() on an empty list raises ValueError; a chain-less
    # structure now yields None, matching get_ligand_smiles's contract.
    if not sequences:
        return None
    return max(sequences, key=len)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def main():
    """Build pdbbind_refined_dataset.csv from the PDBbind refined set.

    Reads the index file for (pdb_id, affinity) pairs, then extracts each
    entry's ligand SMILES and protein sequence and writes the combined
    table to CSV.
    """
    final_data = []
    index_data = {}

    index_file_path = os.path.join(PDBBIND_PATH, "index", INDEX_NAME)
    with open(index_file_path, "r") as f:
        for line in f:
            # Header/comment lines in the index start with '#'.
            if line.startswith("#"):
                continue
            parts = line.split()
            pdb_id = parts[0]
            # Fourth whitespace-separated column of the index line; presumably
            # the binding-affinity value — verify against the index format.
            affinity = parts[3]
            index_data[pdb_id] = affinity
    # CLEANUP: removed the per-line debug `print(pdb_id)` that flooded stdout
    # (tqdm below already reports progress).
    print(f"Loaded index data for {len(index_data)} entries")

    for pdb_id, affinity in tqdm(index_data.items()):
        pdb_id_path = os.path.join(PDBBIND_PATH, pdb_id)

        smiles = get_ligand_smiles(pdb_id, pdb_id_path)
        sequence = get_protein_sequence(pdb_id, pdb_id_path)
        # NOTE(review): `or` keeps rows where only one of smiles/sequence was
        # extracted; confirm whether both should be required (`and`).
        if smiles is not None or sequence is not None:
            final_data.append(
                {
                    "pdb_id": pdb_id,
                    "smiles": smiles,
                    "sequence": sequence,
                    "affinity": affinity,
                }
            )

    df = pd.DataFrame(final_data)
    df.to_csv("pdbbind_refined_dataset.csv", index=False)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
# pdb_id = "1a1e"
|
| 101 |
+
# PDF_ID_PATH = os.path.join(PDBBIND_PATH, pdb_id)
|
| 102 |
+
#
|
| 103 |
+
# smiles = get_ligand_smiles(pdb_id, PDF_ID_PATH)
|
| 104 |
+
# print(smiles)
|
| 105 |
+
#
|
| 106 |
+
# sequence = get_protein_sequence(pdb_id, PDF_ID_PATH)
|
| 107 |
+
# print(sequence)
|
| 108 |
+
|
| 109 |
+
# Script entry point: build the PDBbind CSV dataset.
if __name__ == "__main__":
    main()
|
pdbbind_refined_dataset.csv
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pandas
|
| 2 |
+
rdkit
|
| 3 |
+
biopython
|
| 4 |
+
torch
|
transformer_from_scratch/model.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class InputEmbeddings(nn.Module):
    """Maps token ids to d_model-dimensional vectors, scaled by sqrt(d_model)."""

    def __init__(self, d_model: int, vocab_size: int):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        # Lookup table: vocab_size rows of d_model-dimensional vectors.
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # Scale so embedding magnitudes match the positional encodings.
        embedded = self.embedding(x)
        return embedded * math.sqrt(self.d_model)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position information to token embeddings.

    The encoding table is computed once at construction time (frequencies in
    log space for numerical stability) and stored as a non-trainable buffer.
    """

    def __init__(self, d_model: int, seq_len: int, dropout: float):
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)

        # positions: (seq_len, 1); frequencies: (d_model / 2,)
        positions = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1)
        frequencies = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )

        # Interleave: sine on even feature indices, cosine on odd ones.
        encoding = torch.zeros(seq_len, d_model)
        encoding[:, 0::2] = torch.sin(positions * frequencies)
        encoding[:, 1::2] = torch.cos(positions * frequencies)

        # (1, seq_len, d_model): leading batch dim. register_buffer keeps the
        # table out of the optimizer while still moving it with .to()/.cuda().
        self.register_buffer("pe", encoding.unsqueeze(0))

    def forward(self, x):
        # Slice to the actual sequence length; no gradient flows into pe.
        x = x + (self.pe[:, : x.shape[1], :]).requires_grad_(False)
        return self.dropout(x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class LayerNormalization(nn.Module):
    """Layer norm over the last dimension with learnable scale and shift."""

    def __init__(self, eps: float = 10e-6) -> None:
        super().__init__()
        # Added to the std to avoid division by zero / huge outputs.
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(1))  # learnable scale
        self.bias = nn.Parameter(torch.zeros(1))  # learnable shift

    def forward(self, x):
        # Normalize each sample independently along the feature axis.
        mu = x.mean(dim=-1, keepdim=True)
        sigma = x.std(dim=-1, keepdim=True)
        normalized = (x - mu) / (sigma + self.eps)
        return self.alpha * normalized + self.bias
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
class FeedForwardBlock(nn.Module):
|
| 63 |
+
def __init__(self, d_model: int, d_ff: int, dropout: float):
|
| 64 |
+
super().__init__()
|
| 65 |
+
self.linear1 = nn.Linear(d_model, d_ff)
|
visualization.ipynb
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "code",
|
| 5 |
+
"execution_count": 1,
|
| 6 |
+
"id": "initial_id",
|
| 7 |
+
"metadata": {},
|
| 8 |
+
"outputs": [
|
| 9 |
+
{
|
| 10 |
+
"data": {
|
| 11 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 12 |
+
"model_id": "ccfa267dcd6945b6be10a9cbeffb4e5e",
|
| 13 |
+
"version_major": 2,
|
| 14 |
+
"version_minor": 0
|
| 15 |
+
},
|
| 16 |
+
"text/plain": []
|
| 17 |
+
},
|
| 18 |
+
"metadata": {},
|
| 19 |
+
"output_type": "display_data"
|
| 20 |
+
}
|
| 21 |
+
],
|
| 22 |
+
"source": [
|
| 23 |
+
"import nglview as nv\n",
|
| 24 |
+
"import os"
|
| 25 |
+
]
|
| 26 |
+
},
|
| 27 |
+
{
|
| 28 |
+
"cell_type": "code",
|
| 29 |
+
"execution_count": 2,
|
| 30 |
+
"id": "d8d7978e-980a-400c-8c6a-5365990c8855",
|
| 31 |
+
"metadata": {},
|
| 32 |
+
"outputs": [],
|
| 33 |
+
"source": [
|
| 34 |
+
"PDBBIND_PATH = \"refined-set\""
|
| 35 |
+
]
|
| 36 |
+
},
|
| 37 |
+
{
|
| 38 |
+
"cell_type": "code",
|
| 39 |
+
"execution_count": 3,
|
| 40 |
+
"id": "788a6b43-c515-45c7-bc52-341d446b1a65",
|
| 41 |
+
"metadata": {},
|
| 42 |
+
"outputs": [],
|
| 43 |
+
"source": [
|
| 44 |
+
"EXAMPLE_PDB_ID = \"1a1e\""
|
| 45 |
+
]
|
| 46 |
+
},
|
| 47 |
+
{
|
| 48 |
+
"cell_type": "code",
|
| 49 |
+
"execution_count": 4,
|
| 50 |
+
"id": "e8f4bebc-845f-43e8-bc4d-ab7b649eb49c",
|
| 51 |
+
"metadata": {},
|
| 52 |
+
"outputs": [],
|
| 53 |
+
"source": [
|
| 54 |
+
"pdb_dir = os.path.join(PDBBIND_PATH, EXAMPLE_PDB_ID)"
|
| 55 |
+
]
|
| 56 |
+
},
|
| 57 |
+
{
|
| 58 |
+
"cell_type": "code",
|
| 59 |
+
"execution_count": 5,
|
| 60 |
+
"id": "24b5e435-4d8f-4505-b27c-dd6317376ed4",
|
| 61 |
+
"metadata": {},
|
| 62 |
+
"outputs": [],
|
| 63 |
+
"source": [
|
| 64 |
+
"protein_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_protein.pdb\")"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"execution_count": 6,
|
| 70 |
+
"id": "e7fc3539-00c0-48a2-b012-c80757fa12c4",
|
| 71 |
+
"metadata": {},
|
| 72 |
+
"outputs": [],
|
| 73 |
+
"source": [
|
| 74 |
+
"ligand_file = os.path.join(pdb_dir, f\"{EXAMPLE_PDB_ID}_ligand.sdf\")"
|
| 75 |
+
]
|
| 76 |
+
},
|
| 77 |
+
{
|
| 78 |
+
"cell_type": "code",
|
| 79 |
+
"execution_count": 7,
|
| 80 |
+
"id": "9a053b99-7c01-4881-b3f7-e9b39090af9d",
|
| 81 |
+
"metadata": {},
|
| 82 |
+
"outputs": [],
|
| 83 |
+
"source": [
|
| 84 |
+
"view = nv.NGLWidget()"
|
| 85 |
+
]
|
| 86 |
+
},
|
| 87 |
+
{
|
| 88 |
+
"cell_type": "code",
|
| 89 |
+
"execution_count": 8,
|
| 90 |
+
"id": "df8c8e00-3ce6-41dd-b457-d9f50e318dad",
|
| 91 |
+
"metadata": {},
|
| 92 |
+
"outputs": [],
|
| 93 |
+
"source": [
|
| 94 |
+
"protein_comp = view.add_component(protein_file)"
|
| 95 |
+
]
|
| 96 |
+
},
|
| 97 |
+
{
|
| 98 |
+
"cell_type": "code",
|
| 99 |
+
"execution_count": 9,
|
| 100 |
+
"id": "c191fead-fef8-4077-b787-5bf9552307b1",
|
| 101 |
+
"metadata": {},
|
| 102 |
+
"outputs": [],
|
| 103 |
+
"source": [
|
| 104 |
+
"protein_comp.clear_representations()"
|
| 105 |
+
]
|
| 106 |
+
},
|
| 107 |
+
{
|
| 108 |
+
"cell_type": "code",
|
| 109 |
+
"execution_count": 10,
|
| 110 |
+
"id": "4559033a-aeda-4659-8d91-9002b5a6ecda",
|
| 111 |
+
"metadata": {},
|
| 112 |
+
"outputs": [],
|
| 113 |
+
"source": [
|
| 114 |
+
"protein_comp.add_representation('cartoon', color='blue')"
|
| 115 |
+
]
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"cell_type": "code",
|
| 119 |
+
"execution_count": 11,
|
| 120 |
+
"id": "73ea1a50-8463-40b8-a942-0c92d3e97a97",
|
| 121 |
+
"metadata": {},
|
| 122 |
+
"outputs": [],
|
| 123 |
+
"source": [
|
| 124 |
+
"ligand_comp = view.add_component(ligand_file)"
|
| 125 |
+
]
|
| 126 |
+
},
|
| 127 |
+
{
|
| 128 |
+
"cell_type": "code",
|
| 129 |
+
"execution_count": 12,
|
| 130 |
+
"id": "16cdb710-1ed6-4b1d-9e6a-69b7ad61a600",
|
| 131 |
+
"metadata": {},
|
| 132 |
+
"outputs": [],
|
| 133 |
+
"source": [
|
| 134 |
+
"ligand_comp.clear_representations()"
|
| 135 |
+
]
|
| 136 |
+
},
|
| 137 |
+
{
|
| 138 |
+
"cell_type": "code",
|
| 139 |
+
"execution_count": 13,
|
| 140 |
+
"id": "2193c497-f33c-4de0-86a9-6e535002fcb7",
|
| 141 |
+
"metadata": {},
|
| 142 |
+
"outputs": [],
|
| 143 |
+
"source": [
|
| 144 |
+
"ligand_comp.add_representation('ball+stick', radius=0.3)"
|
| 145 |
+
]
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"cell_type": "code",
|
| 149 |
+
"execution_count": 14,
|
| 150 |
+
"id": "b1cc7f44-a374-4400-b4ba-8f75101b21ce",
|
| 151 |
+
"metadata": {},
|
| 152 |
+
"outputs": [
|
| 153 |
+
{
|
| 154 |
+
"data": {
|
| 155 |
+
"application/vnd.jupyter.widget-view+json": {
|
| 156 |
+
"model_id": "6037e0edee3247a49cd586e52e64a61b",
|
| 157 |
+
"version_major": 2,
|
| 158 |
+
"version_minor": 0
|
| 159 |
+
},
|
| 160 |
+
"text/plain": [
|
| 161 |
+
"NGLWidget()"
|
| 162 |
+
]
|
| 163 |
+
},
|
| 164 |
+
"metadata": {},
|
| 165 |
+
"output_type": "display_data"
|
| 166 |
+
}
|
| 167 |
+
],
|
| 168 |
+
"source": [
|
| 169 |
+
"view"
|
| 170 |
+
]
|
| 171 |
+
},
|
| 172 |
+
{
|
| 173 |
+
"cell_type": "code",
|
| 174 |
+
"execution_count": null,
|
| 175 |
+
"id": "5655e465-bb44-4218-a5e3-db2c5e62cd9c",
|
| 176 |
+
"metadata": {},
|
| 177 |
+
"outputs": [],
|
| 178 |
+
"source": []
|
| 179 |
+
}
|
| 180 |
+
],
|
| 181 |
+
"metadata": {
|
| 182 |
+
"kernelspec": {
|
| 183 |
+
"display_name": "Python 3 (ipykernel)",
|
| 184 |
+
"language": "python",
|
| 185 |
+
"name": "python3"
|
| 186 |
+
},
|
| 187 |
+
"language_info": {
|
| 188 |
+
"codemirror_mode": {
|
| 189 |
+
"name": "ipython",
|
| 190 |
+
"version": 3
|
| 191 |
+
},
|
| 192 |
+
"file_extension": ".py",
|
| 193 |
+
"mimetype": "text/x-python",
|
| 194 |
+
"name": "python",
|
| 195 |
+
"nbconvert_exporter": "python",
|
| 196 |
+
"pygments_lexer": "ipython3",
|
| 197 |
+
"version": "3.12.4"
|
| 198 |
+
}
|
| 199 |
+
},
|
| 200 |
+
"nbformat": 4,
|
| 201 |
+
"nbformat_minor": 5
|
| 202 |
+
}
|