# Spaces:
# Configuration error
# Configuration error
# File size: 5,316 Bytes
from collections import defaultdict
import os
import pickle

import numpy as np
from rdkit import Chem
import torch
import tqdm
def create_atoms(mol, atom_dict):
    """Encode the atoms of ``mol`` as integer IDs.

    The key looked up in ``atom_dict`` for each atom is its element symbol
    (e.g. ``'C'``), or the tuple ``(symbol, 'aromatic')`` when the atom's
    position is among the indices reported by ``mol.GetAromaticAtoms()``.
    Aromatic and non-aromatic atoms of the same element therefore receive
    different IDs.

    Args:
        mol: RDKit molecule.
        atom_dict: Mapping from the keys above to IDs. In this file the
            caller passes a ``defaultdict`` that hands out the next free ID
            for unseen keys.

    Returns:
        numpy array of IDs, one per atom, in ``mol.GetAtoms()`` order.
    """
    symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    aromatic = {atom.GetIdx() for atom in mol.GetAromaticAtoms()}
    keys = [(sym, 'aromatic') if idx in aromatic else sym
            for idx, sym in enumerate(symbols)]
    return np.array([atom_dict[key] for key in keys])
def create_ijbonddict(mol, bond_dict):
    """Build an adjacency list of ``mol`` annotated with bond-type IDs.

    Each key is an atom index. Each value is a list of
    ``(neighbor_index, bond_id)`` tuples, where
    ``bond_id = bond_dict[str(bond.GetBondType())]``, so bond types such as
    single or double get distinct IDs. Every bond is recorded in both
    directions.

    Args:
        mol: RDKit molecule.
        bond_dict: Mapping from the bond-type string to an ID. In this file
            the caller passes a ``defaultdict`` that hands out the next free
            ID for unseen keys.

    Returns:
        A ``defaultdict(list)``. Atoms that take part in no bond have no key.
    """
    # defaultdict(list) is the idiomatic form of defaultdict(lambda: []).
    # Unlike the lambda version, it is also picklable.
    i_jbond_dict = defaultdict(list)
    for bond in mol.GetBonds():
        i, j = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bond_id = bond_dict[str(bond.GetBondType())]
        i_jbond_dict[i].append((j, bond_id))
        i_jbond_dict[j].append((i, bond_id))
    return i_jbond_dict
def extract_fingerprints(radius, atoms, i_jbond_dict,
                         fingerprint_dict, edge_dict):
    """Extract per-atom fingerprint IDs with Weisfeiler-Lehman relabeling.

    Each of the ``radius`` iterations does two things:

    1. It replaces every node ID with
       ``fingerprint_dict[(own_id, sorted((neighbor_id, edge_id), ...))]``.
    2. It replaces every edge ID with
       ``edge_dict[(sorted endpoint IDs, edge_id)]``, using the node IDs
       from before this iteration's update.

    When there is a single atom or ``radius == 0``, each atom ID is mapped
    through ``fingerprint_dict`` directly.

    Args:
        radius: Number of relabeling iterations.
        atoms: Sequence of atom IDs, indexed by atom index.
        i_jbond_dict: Mapping from atom index to a list of
            ``(neighbor_index, bond_id)`` tuples.
        fingerprint_dict: Mapping that yields an ID for each fingerprint key.
            In this file a ``defaultdict`` that assigns new IDs is passed.
        edge_dict: Mapping that yields an ID for each edge key. In this file
            a ``defaultdict`` that assigns new IDs is passed.

    Returns:
        numpy array of fingerprint IDs with one entry per atom, in atom-index
        order. This keeps it aligned with ``atoms`` and with adjacency
        matrices indexed by atom index.
    """
    if len(atoms) == 1 or radius == 0:
        return np.array([fingerprint_dict[a] for a in atoms])

    n_atoms = len(atoms)
    nodes = atoms
    i_jedge_dict = i_jbond_dict
    for _ in range(radius):
        # Node update. Iterate by atom index rather than dict insertion order
        # so that new_nodes[i] always describes atom i. Atoms without edges
        # still receive a fingerprint (with an empty neighbor tuple) instead
        # of being dropped.
        new_nodes = []
        for i in range(n_atoms):
            neighbors = sorted((nodes[j], edge)
                               for j, edge in i_jedge_dict.get(i, ()))
            new_nodes.append(fingerprint_dict[(nodes[i], tuple(neighbors))])

        # Edge update: relabel each edge from the (pre-update) IDs of its two
        # endpoints. The endpoints are sorted so both directions share one key.
        new_i_jedge_dict = defaultdict(list)
        for i, j_edge in i_jedge_dict.items():
            for j, edge in j_edge:
                both_side = tuple(sorted((nodes[i], nodes[j])))
                new_i_jedge_dict[i].append((j, edge_dict[(both_side, edge)]))

        nodes = new_nodes
        i_jedge_dict = new_i_jedge_dict
    return np.array(nodes)
def split_dataset(dataset, ratio, seed=1234):
    """Shuffle ``dataset`` in place and split it into two parts.

    This reseeds NumPy's global random generator with ``seed`` (via
    ``np.random.seed``) before shuffling. The split is therefore
    reproducible, and later draws from ``np.random`` elsewhere are affected
    as well.

    Args:
        dataset: Mutable sequence (e.g. a list). It is shuffled in place.
        ratio: Fraction of items in the first part. The first part receives
            ``int(ratio * len(dataset))`` items.
        seed: Seed used for the shuffle. Defaults to 1234, the value that
            was previously hard-coded.

    Returns:
        Tuple ``(first, second)`` of slices of the shuffled dataset.
    """
    np.random.seed(seed)
    np.random.shuffle(dataset)
    n = int(ratio * len(dataset))
    return dataset[:n], dataset[n:]
def create_datasets(task, dataset, radius, device):
    """Build train/dev/test datasets of molecular graphs.

    Reads ``data_train.txt`` and ``data_test.txt`` from
    ``./NIPS_GNN/dataset/<task>/<dataset>/``. Each file's first line is a
    header and is skipped. Every following non-blank line holds a SMILES
    string and a property value separated by whitespace. Entries are skipped
    when their SMILES contains '.' or when RDKit cannot parse it.

    The training file is split 90/10 into train/dev with ``split_dataset``.
    The symbol dictionaries built while processing both files are pickled to
    ``./NIPS_GNN/model/<dataset lowercased>_dictionaries.pkl``. The model
    directory is created if needed.

    Args:
        task: ``'classification'`` (property stored as a LongTensor) or
            ``'regression'`` (property stored as a 1x1 FloatTensor). Any
            other value leaves the property as its raw string. Also used as
            a path component.
        dataset: Dataset name. Used as the sub-directory and as the prefix
            of the dictionary file.
        radius: Number of Weisfeiler-Lehman iterations for fingerprints.
        device: torch device that all tensors are moved to.

    Returns:
        ``(dataset_train, dataset_dev, dataset_test, N_fingerprints)``.
        Each dataset is a list of
        ``(fingerprints, adjacency, molecular_size, property)`` tuples.
        ``N_fingerprints`` is the number of fingerprint keys seen across the
        train and test files.
    """
    dir_dataset = os.path.join('./NIPS_GNN/dataset', task, dataset)

    # Symbol -> index dictionaries. Each unseen key gets the next free index.
    # They are shared by train and test so that IDs stay consistent.
    atom_dict = defaultdict(lambda: len(atom_dict))
    bond_dict = defaultdict(lambda: len(bond_dict))
    fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
    edge_dict = defaultdict(lambda: len(edge_dict))

    def create_dataset(filename):
        """Load ``filename`` and convert each molecule to tensors on ``device``."""
        with open(os.path.join(dir_dataset, filename), 'r') as f:
            f.readline()  # Skip the header line.
            lines = [line for line in f.read().strip().split('\n')
                     if line.strip()]  # Blank lines would break split()[0].

        # Exclude multi-fragment SMILES (those containing '.').
        lines = [line for line in lines if '.' not in line.split()[0]]

        data_list = []
        n_invalid = 0
        for line in tqdm.tqdm(lines, total=len(lines)):
            smiles, prop = line.strip().split()
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                # MolFromSmiles returns None on a parse failure, and
                # Chem.AddHs(None) would raise. Skip the entry instead.
                n_invalid += 1
                continue
            mol = Chem.AddHs(mol)
            atoms = create_atoms(mol, atom_dict)
            molecular_size = len(atoms)
            i_jbond_dict = create_ijbonddict(mol, bond_dict)
            fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict,
                                                fingerprint_dict, edge_dict)
            adjacency = Chem.GetAdjacencyMatrix(mol)

            # Convert the numpy data to torch tensors on the target device.
            fingerprints = torch.LongTensor(fingerprints).to(device)
            adjacency = torch.FloatTensor(adjacency).to(device)
            if task == 'classification':
                prop = torch.LongTensor([int(prop)]).to(device)
            elif task == 'regression':
                prop = torch.FloatTensor([[float(prop)]]).to(device)
            data_list.append((fingerprints, adjacency, molecular_size, prop))

        if n_invalid:
            print(f'Skipped {n_invalid} unparsable SMILES in {filename}')
        return data_list

    dataset_train = create_dataset('data_train.txt')
    dataset_train, dataset_dev = split_dataset(dataset_train, 0.9)
    dataset_test = create_dataset('data_test.txt')

    N_fingerprints = len(fingerprint_dict)

    model_dir = './NIPS_GNN/model'
    os.makedirs(model_dir, exist_ok=True)  # open(..., 'wb') needs the dir.
    dict_path = f'{model_dir}/{dataset.lower()}_dictionaries.pkl'
    # Convert to plain dicts before pickling, because the lambda default
    # factories of these defaultdicts cannot be pickled.
    with open(dict_path, 'wb') as f:
        pickle.dump((dict(atom_dict), dict(bond_dict),
                     dict(fingerprint_dict), dict(edge_dict)), f)
    print('Dictionaries saved at:', dict_path)

    return dataset_train, dataset_dev, dataset_test, N_fingerprints
|