from collections import defaultdict
import os
import pickle

import numpy as np
from rdkit import Chem
import tqdm
import torch


def create_atoms(mol, atom_dict):
    """Transform the atom types in a molecule (e.g., H, C, and O)
    into the indices (e.g., H=0, C=1, and O=2).
    Note that each atom index considers the aromaticity.

    Args:
        mol: Molecule object providing ``GetAtoms()`` and
            ``GetAromaticAtoms()`` (an RDKit molecule in this file).
        atom_dict: Mapping from an atom key to its integer index. A key is
            either the element symbol or the tuple ``(symbol, 'aromatic')``.
            ``create_datasets`` passes a ``defaultdict`` that assigns the
            next free index to an unseen key.

    Returns:
        A numpy array holding one index per atom, in atom order.
    """
    atoms = [a.GetSymbol() for a in mol.GetAtoms()]
    for a in mol.GetAromaticAtoms():
        i = a.GetIdx()
        # Aromatic atoms get a distinct key so they map to a distinct index.
        atoms[i] = (atoms[i], 'aromatic')
    atoms = [atom_dict[a] for a in atoms]
    return np.array(atoms)


def create_ijbonddict(mol, bond_dict):
    """Create a dictionary, in which each key is a node ID
    and each value is the tuples of its neighboring node
    and chemical bond (e.g., single and double) IDs.

    Args:
        mol: Molecule object providing ``GetBonds()``.
        bond_dict: Mapping from ``str(bond.GetBondType())`` to an integer
            bond ID.

    Returns:
        A ``defaultdict(list)`` mapping an atom index to a list of
        ``(neighbor_index, bond_id)`` tuples. Every bond is recorded in both
        directions. Atoms without any bond have no key.
    """
    i_jbond_dict = defaultdict(list)
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict


def extract_fingerprints(radius, atoms, i_jbond_dict,
                         fingerprint_dict, edge_dict):
    """Extract the fingerprints from a molecular graph
    based on Weisfeiler-Lehman algorithm.

    Args:
        radius: Number of refinement iterations. With ``radius == 0`` (or a
            single-atom molecule) each atom ID is looked up directly in
            ``fingerprint_dict``.
        atoms: Sequence of atom IDs, indexed by atom position.
        i_jbond_dict: Mapping from an atom index to a list of
            ``(neighbor_index, bond_id)`` tuples.
        fingerprint_dict: Mapping from a fingerprint key to its integer ID.
        edge_dict: Mapping from ``((node_a, node_b), edge_id)`` to a new
            integer edge ID.

    Returns:
        A numpy array with one fingerprint ID per atom, in atom-index order.
    """
    if (len(atoms) == 1) or (radius == 0):
        nodes = [fingerprint_dict[a] for a in atoms]

    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict
        n_atoms = len(atoms)

        for _ in range(radius):

            # Update each node ID considering its neighboring nodes and edges.
            # The updated node IDs are the fingerprint IDs.
            #
            # Iterate by atom index rather than over the dictionary items:
            # the dictionary is ordered by insertion (i.e. by bond order) and
            # has no entry for an atom without bonds, so iterating its items
            # could yield a list that is not aligned with ``atoms``.
            nodes_ = []
            for i in range(n_atoms):
                # ``get`` avoids inserting empty entries into a defaultdict.
                j_edge = i_jedge_dict.get(i, ())
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                nodes_.append(fingerprint_dict[fingerprint])

            # Also update each edge ID considering
            # its two nodes on both sides.
            i_jedge_dict_ = defaultdict(list)
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    # Sorting makes the key independent of edge direction.
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    i_jedge_dict_[i].append((j, edge))

            nodes = nodes_
            i_jedge_dict = i_jedge_dict_

    return np.array(nodes)


def split_dataset(dataset, ratio):
    """Shuffle and split a dataset.

    Note that this reseeds numpy's global random generator with 1234 and
    shuffles ``dataset`` in place before slicing it.

    Args:
        dataset: Mutable sequence to shuffle and split.
        ratio: Fraction of the items placed in the first part.

    Returns:
        A tuple ``(first, second)`` where ``first`` holds the leading
        ``int(ratio * len(dataset))`` shuffled items and ``second`` the rest.
    """
    np.random.seed(1234)  # fix the seed for shuffle.
    np.random.shuffle(dataset)
    n = int(ratio * len(dataset))
    return dataset[:n], dataset[n:]


def create_datasets(task, dataset, radius, device):
    """Build the train/dev/test datasets and save the symbol dictionaries.

    Reads ``data_train.txt`` and ``data_test.txt`` from
    ``./NIPS_GNN/dataset/<task>/<dataset>/``, converts every molecule to
    tensors, splits the training data 90/10 into train and dev parts, and
    pickles the atom, bond, fingerprint and edge dictionaries to
    ``./NIPS_GNN/model/<dataset lower-cased>_dictionaries.pkl``.

    Args:
        task: ``'classification'`` or ``'regression'``; selects how the
            property column is converted to a tensor. For any other value
            the property is left as the raw string read from the file.
        dataset: Name of the dataset directory.
        radius: Radius passed to ``extract_fingerprints``.
        device: Device the tensors are moved to.

    Returns:
        A tuple ``(dataset_train, dataset_dev, dataset_test, N_fingerprints)``
        where each dataset is a list of
        ``(fingerprints, adjacency, molecular_size, property)`` tuples and
        ``N_fingerprints`` is the number of distinct fingerprints seen.
    """
    dir_dataset = './NIPS_GNN/dataset/' + task + '/' + dataset + '/'

    # Initialize x_dict, in which each key is a symbol type
    # (e.g., atom and chemical bond) and each value is its index.
    # Each default factory hands out the next free index for an unseen key.
    atom_dict = defaultdict(lambda: len(atom_dict))
    bond_dict = defaultdict(lambda: len(bond_dict))
    fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
    edge_dict = defaultdict(lambda: len(edge_dict))

    def create_dataset(filename):
        """Load one data file and convert each row to a tuple of tensors."""

        # Load a dataset; the first line is a header and is skipped.
        with open(dir_dataset + filename, 'r') as f:
            f.readline()
            data_original = f.read().strip().split('\n')

        # Exclude the data contains '.' in its smiles. Blank lines are
        # skipped as well so that an empty file or a stray empty row does
        # not raise IndexError on ``split()[0]``.
        data_original = [data for data in data_original
                         if data.strip() and '.' not in data.split()[0]]

        samples = []

        for data in tqdm.tqdm(data_original, total=len(data_original)):

            smiles, target = data.strip().split()

            # Create each data with the above defined functions.
            # NOTE(review): MolFromSmiles returns None for a SMILES string it
            # cannot parse, in which case AddHs raises -- confirm whether
            # such rows should be skipped instead.
            mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
            atoms = create_atoms(mol, atom_dict)
            molecular_size = len(atoms)
            i_jbond_dict = create_ijbonddict(mol, bond_dict)
            fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict,
                                                fingerprint_dict, edge_dict)
            adjacency = Chem.GetAdjacencyMatrix(mol)

            # Transform the above each data of numpy
            # to pytorch tensor on a device (i.e., CPU or GPU).
            fingerprints = torch.LongTensor(fingerprints).to(device)
            adjacency = torch.FloatTensor(adjacency).to(device)
            if task == 'classification':
                target = torch.LongTensor([int(target)]).to(device)
            elif task == 'regression':
                target = torch.FloatTensor([[float(target)]]).to(device)

            samples.append((fingerprints, adjacency, molecular_size, target))

        return samples

    # The order of these calls determines the IDs stored in the dictionaries.
    dataset_train = create_dataset('data_train.txt')
    dataset_train, dataset_dev = split_dataset(dataset_train, 0.9)
    dataset_test = create_dataset('data_test.txt')

    N_fingerprints = len(fingerprint_dict)

    dict_path = f'./NIPS_GNN/model/{dataset.lower()}_dictionaries.pkl'
    # Make sure the output directory exists so the preprocessing above is
    # not lost to a FileNotFoundError at the very end.
    os.makedirs(os.path.dirname(dict_path), exist_ok=True)
    with open(dict_path, 'wb') as f:
        # Convert to plain dicts: the defaultdict factories are lambdas,
        # which cannot be pickled.
        pickle.dump((dict(atom_dict), dict(bond_dict),
                     dict(fingerprint_dict), dict(edge_dict)), f)
    print('Dictionaries saved at:', dict_path)

    return dataset_train, dataset_dev, dataset_test, N_fingerprints