# Deploy polymer property prediction model with LFS
# (commit c53d10d, author: FridayCode)
import os
import pickle
from collections import defaultdict

import numpy as np
import torch
import tqdm
from rdkit import Chem
def create_atoms(mol, atom_dict):
    """Map every atom in *mol* to an integer type ID via *atom_dict*.

    Aromatic atoms are keyed as (symbol, 'aromatic') so they receive
    IDs distinct from their non-aromatic counterparts. *atom_dict* is
    a defaultdict and is mutated in place as new symbols appear.

    Returns:
        np.ndarray of atom-type IDs, one per atom, in atom-index order.
    """
    symbols = [atom.GetSymbol() for atom in mol.GetAtoms()]
    for atom in mol.GetAromaticAtoms():
        idx = atom.GetIdx()
        symbols[idx] = (symbols[idx], 'aromatic')
    return np.array([atom_dict[s] for s in symbols])
def create_ijbonddict(mol, bond_dict):
    """Create a dictionary, in which each key is a node ID
    and each value is the tuples of its neighboring node
    and chemical bond (e.g., single and double) IDs.

    Args:
        mol: RDKit molecule (anything exposing GetBonds()).
        bond_dict: defaultdict assigning a fresh ID to each new bond-type
            string. Mutated in place as new bond types appear.

    Returns:
        defaultdict: node index -> list of (neighbor index, bond ID).
        Every bond is recorded in both directions (undirected graph).
    """
    # defaultdict(list) is the idiomatic form of defaultdict(lambda: []).
    i_jbond_dict = defaultdict(list)
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bond = bond_dict[str(b.GetBondType())]
        # Record the bond from both endpoints.
        i_jbond_dict[i].append((j, bond))
        i_jbond_dict[j].append((i, bond))
    return i_jbond_dict
def extract_fingerprints(radius, atoms, i_jbond_dict,
                         fingerprint_dict, edge_dict):
    """Extract the fingerprints from a molecular graph
    based on Weisfeiler-Lehman algorithm.

    Args:
        radius: number of WL relabeling iterations (0 = use raw atom IDs).
        atoms: array of atom-type IDs (output of create_atoms).
        i_jbond_dict: node index -> [(neighbor index, bond ID), ...].
        fingerprint_dict: defaultdict assigning a fresh ID to each new
            (node, sorted neighborhood) pattern; mutated in place, so
            IDs depend on the order molecules are processed.
        edge_dict: defaultdict assigning a fresh ID to each new
            (sorted node pair, edge ID) pattern; mutated in place.

    Returns:
        np.ndarray of fingerprint IDs, one per atom.
    """
    # Degenerate cases: a single atom has no neighborhood to hash, and
    # radius 0 means no relabeling -- just map the raw atom IDs.
    if (len(atoms) == 1) or (radius == 0):
        nodes = [fingerprint_dict[a] for a in atoms]
    else:
        nodes = atoms
        i_jedge_dict = i_jbond_dict
        for _ in range(radius):
            """Update each node ID considering its neighboring nodes and edges.
            The updated node IDs are the fingerprint IDs.
            """
            nodes_ = []
            for i, j_edge in i_jedge_dict.items():
                neighbors = [(nodes[j], edge) for j, edge in j_edge]
                # Sorting makes the fingerprint invariant to neighbor order.
                fingerprint = (nodes[i], tuple(sorted(neighbors)))
                nodes_.append(fingerprint_dict[fingerprint])
            """Also update each edge ID considering
            its two nodes on both sides.
            """
            i_jedge_dict_ = defaultdict(lambda: [])
            for i, j_edge in i_jedge_dict.items():
                for j, edge in j_edge:
                    # NOTE: deliberately uses the *pre-update* node IDs
                    # (`nodes` is reassigned only after this loop), matching
                    # the reference WL implementation.
                    both_side = tuple(sorted((nodes[i], nodes[j])))
                    edge = edge_dict[(both_side, edge)]
                    i_jedge_dict_[i].append((j, edge))
            nodes = nodes_
            i_jedge_dict = i_jedge_dict_
    return np.array(nodes)
def split_dataset(dataset, ratio):
    """Shuffle a dataset and split it into two parts.

    Fixes two side effects of the original: the caller's list is no
    longer mutated (we shuffle a shallow copy), and NumPy's *global*
    RNG state is no longer clobbered -- a private RandomState(1234) is
    used instead. The legacy RandomState shuffle is the same algorithm
    the global np.random.seed/np.random.shuffle pair uses, so the
    resulting split is identical to before.

    Args:
        dataset: sequence of samples.
        ratio: fraction in [0, 1] of samples placed in the first part.

    Returns:
        (first, second): two lists with len(first) == int(ratio * len(dataset)).
    """
    shuffled = list(dataset)
    np.random.RandomState(1234).shuffle(shuffled)  # deterministic shuffle
    n = int(ratio * len(shuffled))
    return shuffled[:n], shuffled[n:]
def create_datasets(task, dataset, radius, device):
    """Build train/dev/test molecular-graph datasets for *task*.

    Reads SMILES/label pairs from
    './NIPS_GNN/dataset/<task>/<dataset>/data_{train,test}.txt',
    converts each molecule into fingerprint and adjacency tensors on
    *device*, and pickles the symbol dictionaries so inference can
    reuse the exact same index assignment.

    Args:
        task: 'classification' or 'regression' (controls label dtype);
            any other value now raises ValueError instead of silently
            appending the raw string label.
        dataset: dataset name (also used for the pickle filename).
        radius: Weisfeiler-Lehman fingerprint radius.
        device: torch device the tensors are created on.

    Returns:
        (dataset_train, dataset_dev, dataset_test, N_fingerprints)
    """
    dir_dataset = './NIPS_GNN/dataset/' + task + '/' + dataset + '/'

    # Each dict maps a symbol (atom / bond / fingerprint / edge pattern)
    # to a fresh index the first time it is seen; shared across files so
    # train and test use consistent IDs.
    atom_dict = defaultdict(lambda: len(atom_dict))
    bond_dict = defaultdict(lambda: len(bond_dict))
    fingerprint_dict = defaultdict(lambda: len(fingerprint_dict))
    edge_dict = defaultdict(lambda: len(edge_dict))

    def create_dataset(filename):
        """Load one file and convert each line into model-ready tensors."""
        with open(dir_dataset + filename, 'r') as f:
            header = f.readline().strip().split()  # column header line, unused
            data_original = f.read().strip().split('\n')

        # Exclude multi-fragment molecules ('.' in the SMILES string).
        data_original = [data for data in data_original
                        if '.' not in data.split()[0]]

        converted = []
        for data in tqdm.tqdm(data_original, total=len(data_original)):
            # `prop` avoids shadowing the builtin `property`.
            smiles, prop = data.strip().split()

            # Build the molecular graph with the helpers defined above.
            mol = Chem.AddHs(Chem.MolFromSmiles(smiles))
            atoms = create_atoms(mol, atom_dict)
            molecular_size = len(atoms)
            i_jbond_dict = create_ijbonddict(mol, bond_dict)
            fingerprints = extract_fingerprints(radius, atoms, i_jbond_dict,
                                                fingerprint_dict, edge_dict)
            adjacency = Chem.GetAdjacencyMatrix(mol)

            # Move everything onto the target device as torch tensors.
            fingerprints = torch.LongTensor(fingerprints).to(device)
            adjacency = torch.FloatTensor(adjacency).to(device)
            if task == 'classification':
                prop = torch.LongTensor([int(prop)]).to(device)
            elif task == 'regression':
                prop = torch.FloatTensor([[float(prop)]]).to(device)
            else:
                # Previously an unknown task left the label as a raw string;
                # fail loudly instead.
                raise ValueError(f'unknown task: {task!r}')

            converted.append((fingerprints, adjacency, molecular_size, prop))
        return converted

    dataset_train = create_dataset('data_train.txt')
    dataset_train, dataset_dev = split_dataset(dataset_train, 0.9)
    dataset_test = create_dataset('data_test.txt')

    N_fingerprints = len(fingerprint_dict)

    dict_path = f'./NIPS_GNN/model/{dataset.lower()}_dictionaries.pkl'
    # Robustness: create the model directory if it does not exist yet,
    # instead of crashing on the open() below.
    os.makedirs(os.path.dirname(dict_path), exist_ok=True)
    with open(dict_path, 'wb') as f:
        pickle.dump((dict(atom_dict), dict(bond_dict),
                     dict(fingerprint_dict), dict(edge_dict)), f)
    print('Dictionaries saved at:', dict_path)

    return dataset_train, dataset_dev, dataset_test, N_fingerprints