Zaixi's picture
1
dcacefd
import torch
import matplotlib
import numpy as np
import scipy.stats as sp_stats
# Element symbol -> atomic number for every atom type this module handles.
atom_encoder = {
    'H': 1,
    'C': 6,
    'N': 7,
    'O': 8,
    'F': 9,
    'P': 15,
    'S': 16,
    'Cl': 17,
}
# Reverse lookup (atomic number -> element symbol), built in the same key order.
atom_decoder = dict(zip(atom_encoder.values(), atom_encoder.keys()))
# Bond lengths from http://www.wiredchemist.com/chemistry/data/bond_energies_lengths.html
# Single-bond lengths, indexed as bonds1[element_a][element_b]; the table is symmetric.
# Values look like picometres (e.g. C-H = 109). get_bond_order multiplies its distance
# argument by 100 before comparing, presumably converting Angstrom -> pm (confirm).
# Every pair of supported elements has an entry, so any two atoms closer than
# bonds1 + margin1 are counted as bonded (order >= 1).
bonds1 = {'H': {'H': 74, 'C': 109, 'N': 101, 'O': 96, 'F': 92, 'P': 144, 'S': 134, 'Cl': 127},
'C': {'H': 109, 'C': 154, 'N': 147, 'O': 143, 'F': 135, 'P': 184, 'S': 182, 'Cl': 177},
'N': {'H': 101, 'C': 147, 'N': 145, 'O': 140, 'F': 136, 'P': 177, 'S': 168, 'Cl': 175},
'O': {'H': 96, 'C': 143, 'N': 140, 'O': 148, 'F': 142, 'P': 163, 'S': 151, 'Cl': 164},
'F': {'H': 92, 'C': 135, 'N': 136, 'O': 142, 'F': 142, 'P': 156, 'S': 158, 'Cl': 166},
'P': {'H': 144, 'C': 184, 'N': 177, 'O': 163, 'F': 156, 'P': 221, 'S': 210, 'Cl': 203},
'S': {'H': 134, 'C': 182, 'N': 168, 'O': 151, 'F': 158, 'P': 210, 'S': 204, 'Cl': 207},
'Cl': {'H': 127, 'C': 177, 'N': 175, 'O': 164, 'F': 166, 'P': 203, 'S': 207, 'Cl': 199}
}
# Double-bond lengths, same indexing and units as bonds1 (symmetric table).
# -1 marks element pairs with no listed double bond. get_bond_order adds margin2 (5)
# to the entry, giving a threshold of 4 that essentially no real inter-atomic distance
# falls below, so those pairs never reach bond order 2.
bonds2 = {'H': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'C': {'H': -1, 'C': 134, 'N': 129, 'O': 120, 'F': -1, 'P': -1, 'S': 160, 'Cl': -1},
'N': {'H': -1, 'C': 129, 'N': 125, 'O': 121, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'O': {'H': -1, 'C': 120, 'N': 121, 'O': 121, 'F': -1, 'P': 150, 'S': -1, 'Cl': -1},
'F': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'P': {'H': -1, 'C': -1, 'N': -1, 'O': 150, 'F': -1, 'P': -1, 'S': 186, 'Cl': -1},
'S': {'H': -1, 'C': 160, 'N': -1, 'O': -1, 'F': -1, 'P': 186, 'S': -1, 'Cl': -1},
'Cl': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
}
# Triple-bond lengths, same indexing and units as bonds1 (symmetric table).
# Only C-C, C-N, C-O and N-N have entries; -1 elsewhere. get_bond_order adds margin3 (3),
# so a -1 entry yields a threshold of 2 and those pairs effectively never reach order 3.
bonds3 = {'H': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'C': {'H': -1, 'C': 120, 'N': 116, 'O': 113, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'N': {'H': -1, 'C': 116, 'N': 110, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'O': {'H': -1, 'C': 113, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'F': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'P': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'S': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
'Cl': {'H': -1, 'C': -1, 'N': -1, 'O': -1, 'F': -1, 'P': -1, 'S': -1, 'Cl': -1},
}
# NOTE(review): stdv is not referenced in the visible code; its purpose
# (per-element standard deviation?) is unclear — confirm before relying on it.
stdv = {'H': 5, 'C': 1, 'N': 1, 'O': 2, 'F': 3}
# Tolerances added to the bonds1 / bonds2 / bonds3 thresholds in get_bond_order
# (same units as the tables). Per the comment in get_bond_order, they were tuned on
# QM9 ground-truth molecules.
margin1, margin2, margin3 = 10, 5, 3
# Expected total bond order (valence) per element. check_stability requires an exact
# match when hs=True, otherwise between 1 and this value.
allowed_bonds = {'H': 1, 'C': 4, 'N': 3, 'O': 2, 'F': 1, 'P': 5, 'S': 4, 'Cl': 1}
def normalize_histogram(hist):
    """Scale a histogram so that all of its entries sum to one.

    Args:
        hist: Array-like of bin counts (presumably non-negative).

    Returns:
        A new NumPy array shaped like ``hist``, divided by the sum over all entries.
    """
    counts = np.array(hist)
    total = np.sum(counts)
    return counts / total
def coord2distances(x):
    """Return all pairwise Euclidean distances between points, flattened to 1-D.

    Args:
        x: Tensor of point coordinates; presumably shaped (batch, n_points, dims),
            since the squared differences are summed over dim 3 of the expanded
            tensor — TODO confirm against callers.

    Returns:
        A 1-D tensor containing every ``|x[b, i] - x[b, j]|`` (including i == j).
    """
    # Broadcasting x[:, :, None] against x[:, None] gives every difference x[b, i] - x[b, j].
    pairwise_diff = x.unsqueeze(2) - x.unsqueeze(1)
    squared_norm = pairwise_diff.pow(2).sum(dim=3)
    return squared_norm.sqrt().flatten()
def earth_mover_distance(h1, h2, bin_centers=None):
    """Return the 1-D earth mover's (Wasserstein-1) distance between two histograms.

    The histograms are treated as weights over bin positions. The previous
    implementation passed the normalised probabilities to SciPy as sample *values*,
    so it compared the distribution of probability magnitudes rather than the
    histograms. For example, ``[1, 0, 0]`` vs ``[0, 0, 1]`` came out as 0.

    Args:
        h1, h2: Array-like bin counts (non-negative, with a positive total).
        bin_centers: Optional positions of the bins, shared by both histograms.
            Defaults to ``0, 1, ..., len - 1``, i.e. unit spacing between adjacent bins.

    Returns:
        The Wasserstein-1 distance as a float, in units of the bin positions.
    """
    # SciPy normalises the weights itself, so raw counts can be passed directly.
    w1 = np.asarray(h1, dtype=float)
    w2 = np.asarray(h2, dtype=float)
    if bin_centers is None:
        u_values = np.arange(len(w1))
        v_values = np.arange(len(w2))
    else:
        u_values = v_values = np.asarray(bin_centers, dtype=float)
    return sp_stats.wasserstein_distance(u_values, v_values, u_weights=w1, v_weights=w2)
def kl_divergence(p1, p2):
    """Return the Kullback-Leibler divergence KL(p1 || p2) = sum(p1 * log(p1 / p2)).

    No smoothing is applied here. A zero in ``p2`` gives ``inf`` and a zero in
    ``p1`` gives ``nan``, so the histogram helpers in this module add a small
    epsilon before calling it.
    """
    log_ratio = np.log(p1 / p2)
    return np.sum(p1 * log_ratio)
def kl_divergence_sym(h1, h2):
    """Return the symmetrised KL divergence (KL(p||q) + KL(q||p)) / 2 of two histograms.

    Both histograms are normalised to sum to one and then shifted by 1e-10, so
    empty bins do not produce log(0) or division by zero.
    """
    eps = 1e-10
    p, q = (normalize_histogram(h) + eps for h in (h1, h2))
    forward = kl_divergence(p, q)
    backward = kl_divergence(q, p)
    return (forward + backward) / 2.
def js_divergence(h1, h2):
    """Return the Jensen-Shannon divergence between two histograms.

    Each histogram is normalised and shifted by 1e-10 to avoid empty bins. The
    result is the mean of the KL divergences of each distribution from their
    pointwise average.
    """
    eps = 1e-10
    p = normalize_histogram(h1) + eps
    q = normalize_histogram(h2) + eps
    mixture = (p + q) / 2
    kl_p = kl_divergence(p, mixture)
    kl_q = kl_divergence(q, mixture)
    return (kl_p + kl_q) / 2
def get_bond_order(atom1, atom2, distance):
    """Infer the bond order between two atoms from their separation.

    Args:
        atom1, atom2: Element symbols used as keys into bonds1/bonds2/bonds3.
        distance: Inter-atomic distance. It is multiplied by 100 before being
            compared with the tables (presumably Angstrom -> pm; confirm).

    Returns:
        0 (no bond), 1, 2 or 3. This is the number of consecutive thresholds
        (single, then double, then triple length plus its margin) that the scaled
        distance falls strictly below.
    """
    scaled = 100 * distance  # bring the distance onto the tables' scale
    # margin1, margin2 and margin3 have been tuned to maximize the stability of the QM9 true samples
    order = 0
    for table, margin in ((bonds1, margin1), (bonds2, margin2), (bonds3, margin3)):
        # Written as `not <` (rather than `>=`) so a NaN distance still yields order 0.
        if not scaled < table[atom1][atom2] + margin:
            break
        order += 1
    return order
def check_stability(positions, atom_type, debug=False, hs=False, return_nr_bonds=False):
    """Check whether every atom of a molecule has an acceptable total bond order.

    Bonds are inferred for every atom pair with ``get_bond_order``, using their
    Euclidean distance. Each atom's summed bond orders are then compared with
    ``allowed_bonds`` for its element.

    Args:
        positions: Array of shape ``(n_atoms, 3)`` with atom coordinates, in whatever
            units ``get_bond_order`` expects.
        atom_type: Sequence of integer atom codes (keys of ``atom_decoder``), one per atom.
        debug: If true, print a message for every atom whose bond count is invalid.
        hs: If true, an atom is stable only when its bond count equals the allowed
            valence exactly. Otherwise any count in ``1..allowed`` is accepted
            (presumably for molecules without explicit hydrogens; confirm).
        return_nr_bonds: If true, also return the per-atom bond-order sums.

    Returns:
        ``(molecule_stable, nr_stable_bonds, n_atoms)``, or
        ``(molecule_stable, nr_stable_bonds, n_atoms, nr_bonds)`` when
        ``return_nr_bonds`` is true. ``nr_stable_bonds`` counts stable atoms.
    """
    assert len(positions.shape) == 2
    assert positions.shape[1] == 3
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]
    n_atoms = len(x)

    # Build each coordinate vector and element symbol once, not once per atom pair.
    coords = [np.array([x[i], y[i], z[i]]) for i in range(n_atoms)]
    symbols = [atom_decoder[t] for t in atom_type[:n_atoms]]

    nr_bonds = np.zeros(n_atoms, dtype='int')
    for i in range(n_atoms):
        for j in range(i + 1, n_atoms):
            dist = np.sqrt(np.sum((coords[i] - coords[j]) ** 2))
            order = get_bond_order(symbols[i], symbols[j], dist)
            nr_bonds[i] += order
            nr_bonds[j] += order

    nr_stable_bonds = 0
    for symbol, nr_bonds_i in zip(symbols, nr_bonds):
        valence = allowed_bonds[symbol]
        if hs:
            is_stable = valence == nr_bonds_i
        else:
            is_stable = valence >= nr_bonds_i > 0
        if not is_stable and debug:
            print("Invalid bonds for atom %s with %d bonds" % (symbol, nr_bonds_i))
        nr_stable_bonds += int(is_stable)

    molecule_stable = nr_stable_bonds == n_atoms
    if return_nr_bonds:
        return molecule_stable, nr_stable_bonds, n_atoms, nr_bonds
    return molecule_stable, nr_stable_bonds, n_atoms
def analyze_stability_for_molecules(molecule_list):
    """Compute molecule- and atom-level stability fractions for a set of samples.

    Args:
        molecule_list: Iterable of ``(one_hot, x)`` tensor pairs. ``one_hot`` is
            reduced with ``argmax(2).squeeze(0)`` and ``x`` with ``squeeze(0)``
            before being passed to ``check_stability``, which requires ``x`` to be
            ``(n_atoms, 3)``. Presumably the inputs are therefore
            ``(1, n_atoms, n_types)`` and ``(1, n_atoms, 3)``; TODO confirm against callers.

    Returns:
        ``(validity_dict, molecule_stable_list)``:

        - ``validity_dict``: ``'mol_stable'`` is the fraction of stable molecules and
          ``'atm_stable'`` the fraction of stable atoms. Each is 0.0 when there is
          nothing to count, instead of raising ZeroDivisionError.
        - ``molecule_stable_list``: ``(x, atom_type)`` NumPy pairs for every stable molecule.
    """
    n_samples = 0
    n_stable_molecules = 0
    n_stable_atoms = 0
    n_atoms = 0
    molecule_stable_list = []
    for one_hot, x in molecule_list:
        n_samples += 1
        atom_type = one_hot.argmax(2).squeeze(0).cpu().detach().numpy()
        x = x.squeeze(0).cpu().detach().numpy()
        is_stable, nr_stable_bonds, nr_atoms = check_stability(x, atom_type)
        n_stable_molecules += int(is_stable)
        n_stable_atoms += int(nr_stable_bonds)
        n_atoms += int(nr_atoms)
        if is_stable:
            molecule_stable_list.append((x, atom_type))

    # Guard the fractions: an empty batch (or molecules with no atoms) has nothing to divide by.
    validity_dict = {
        'mol_stable': n_stable_molecules / n_samples if n_samples else 0.0,
        'atm_stable': n_stable_atoms / n_atoms if n_atoms else 0.0,
    }
    return validity_dict, molecule_stable_list