import ase
import jax
import jax.numpy as jnp
import numpy as np
import json
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from scipy.spatial.distance import cdist
from dcmnet.data import prepare_batches
from dcmnet.modules import MessagePassingModel
from dcmnet.loss import (
esp_loss_eval,
esp_loss_pots,
esp_mono_loss_pots,
get_predictions,
)
from dcmnet.plotting import evaluate_dc, create_plots2
from dcmnet.multimodel import get_atoms_dcmol
from dcmnet.multipoles import plot_3d
from dcmnet.utils import apply_model, clip_colors, reshape_dipole
# Seed for the JAX PRNG key below.
RANDOM_NUMBER = 0
# Split the seeded key in two and keep only the first half; run_dcm() passes
# it to prepare_batches().
data_key, _ = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)
# Model hyperparameters, shared by every MessagePassingModel that
# create_models() builds.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0
def create_models():
    """
    Build the four message-passing models used by this module.

    All models share the module-level hyperparameters (``features``,
    ``max_degree``, ``num_iterations``, ``num_basis_functions``, ``cutoff``)
    and differ only in ``n_dcm``, which takes the values 1, 2, 3 and 4.

    :return: tuple ``(dcm1, dcm2, dcm3, dcm4)`` of ``MessagePassingModel``
        instances, ordered by increasing ``n_dcm``.
    """
    shared_settings = dict(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
    )
    return tuple(
        MessagePassingModel(**shared_settings, n_dcm=count)
        for count in range(1, 5)
    )
def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """
    Create a uniform grid of points around the molecule.

    The grid spans the axis-aligned bounding box of ``coordinates``, enlarged
    by ``padding`` on every side, with ``n_points`` evenly spaced samples per
    axis. Grid points closer than ``min_distance`` to any atom are dropped.

    :param coordinates: array-like of shape (n_atoms, 3) with atom positions.
    :param padding: margin subtracted from the minimum and added to the
        maximum coordinate along each axis.
    :param n_points: number of grid samples along each axis.
    :param min_distance: smallest allowed distance between a retained grid
        point and any atom.
    :return: array of shape (n_kept, 3) with the retained grid points.
    :raises ValueError: if ``coordinates`` is not a non-empty (n_atoms, 3)
        array.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    if coordinates.ndim != 2 or coordinates.shape[1] != 3 or coordinates.shape[0] == 0:
        raise ValueError(
            "coordinates must be a non-empty array of shape (n_atoms, 3), "
            f"got shape {coordinates.shape}"
        )

    lower = coordinates.min(axis=0) - padding
    upper = coordinates.max(axis=0) + padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]

    # Stack the default ("xy"-indexed) mesh and transpose before flattening so
    # the ordering of the returned points stays exactly as it has always been.
    grid_points = np.stack(np.meshgrid(*axes), axis=0).T.reshape(-1, 3)

    # Exclude points that are too close to the molecule: a point is kept only
    # if its distance to every atom reaches the threshold.
    far_enough = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[far_enough]
def restore_arrays(obj):
    """
    Recursively turn nested lists inside a decoded JSON structure into arrays.

    Dictionaries are rebuilt with every value processed recursively. A list
    has its items processed recursively first; if any processed item is a
    dictionary the list is returned as a list, otherwise conversion with
    ``jnp.asarray`` is attempted and the list is returned unchanged when that
    conversion fails. Any other object is returned as is.

    :param obj: dictionary, list, or leaf value.
    :return: the structure with convertible lists replaced by arrays.
    """
    if isinstance(obj, dict):
        return {name: restore_arrays(item) for name, item in obj.items()}
    if not isinstance(obj, list):
        return obj

    items = [restore_arrays(item) for item in obj]
    for item in items:
        if isinstance(item, dict):
            # Lists that contain dictionaries are containers, not arrays.
            return items
    try:
        converted = jnp.asarray(items)
    except Exception:
        # Best effort: anything that cannot become an array stays a list.
        return items
    return converted
def load_json_dict(path):
    """
    Read a UTF-8 encoded JSON file and pass the decoded content through
    ``restore_arrays``.

    :param path: location of the JSON file.
    :return: the decoded structure as returned by ``restore_arrays``.
    """
    with open(path, "r", encoding="utf-8") as stream:
        raw_text = stream.read()
    return restore_arrays(json.loads(raw_text))
def load_weights():
    """
    Load the four parameter files from the ``wbs`` directory.

    :return: tuple ``(dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights)``
        with one ``load_json_dict`` result per file, in that order.
    """
    weight_files = (
        "wbs/best_0.0_params_dict.json",
        "wbs/dcm2-best_1000.0_params_dict.json",
        "wbs/dcm3-best_1000.0_params_dict.json",
        "wbs/dcm4-best_1000.0_params_dict.json",
    )
    return tuple(load_json_dict(weight_file) for weight_file in weight_files)
def prepare_inputs(smiles, max_n_atoms=60, max_grid_points=3143):
    """
    Build a zero-padded, single-molecule input batch from a SMILES string.

    The molecule is parsed with RDKit, hydrogens are added, one 3D conformer
    is embedded, and ``get_grid_points`` generates grid points around it.
    Atomic numbers, positions and grid points are padded to fixed sizes and
    given a leading batch axis of length one.

    :param smiles: SMILES string of the molecule.
    :param max_n_atoms: padded number of atom slots.
    :param max_grid_points: padded number of grid-point slots.
    :return: tuple ``(data_batch, smiles_image)``. ``data_batch`` is a dict of
        JAX arrays; ``smiles_image`` is the RDKit drawing of the molecule
        without explicit hydrogens, each atom annotated with its index.
    :raises ValueError: if no conformer could be embedded, or if the molecule
        or its grid does not fit into the padded sizes.
    """
    smiles_mol = Chem.MolFromSmiles(smiles)
    rdkit_mol = Chem.AddHs(smiles_mol)

    # EmbedMolecule signals failure with a negative conformer id; without this
    # check the failure only surfaces later as an opaque conformer-id error.
    if AllChem.EmbedMolecule(rdkit_mol) < 0:
        raise ValueError(f"could not embed a 3D conformer for SMILES {smiles!r}")
    coordinates = rdkit_mol.GetConformer(0).GetPositions()
    vdw_surface = get_grid_points(coordinates)

    # Annotate each atom of the hydrogen-free molecule with its index so the
    # index shows up in the drawing.
    for atom in smiles_mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    smiles_image = Draw.MolToImage(smiles_mol)

    # RDKit yields element symbols, so look the atomic numbers up directly.
    symbols = [atom.GetSymbol() for atom in rdkit_mol.GetAtoms()]
    atomic_numbers = np.array(
        [ase.data.atomic_numbers[symbol.capitalize()] for symbol in symbols]
    )

    # Fail with a clear message instead of letting np.pad reject a negative
    # pad width.
    if len(atomic_numbers) > max_n_atoms:
        raise ValueError(
            f"molecule has {len(atomic_numbers)} atoms, "
            f"more than the supported maximum of {max_n_atoms}"
        )
    if len(vdw_surface) > max_grid_points:
        raise ValueError(
            f"grid has {len(vdw_surface)} points, "
            f"more than the supported maximum of {max_grid_points}"
        )

    atom_padding = max_n_atoms - len(atomic_numbers)
    grid_padding = max_grid_points - len(vdw_surface)

    # Every padded array gets a leading batch axis of length one.
    pad_z = np.pad(atomic_numbers, (0, atom_padding))[None, :]
    pad_coords = np.pad(coordinates, ((0, atom_padding), (0, 0)))[None, ...]
    # Unused grid slots are filled with the coordinate value 10000.
    pad_vdw_surface = np.pad(
        vdw_surface,
        ((0, grid_padding), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )[None, ...]

    # Padded atom slots carry atomic number zero, so counting the non-zero
    # entries gives the number of atoms.
    n_atoms = np.sum(pad_z != 0)

    data_batch = dict(
        atomic_numbers=jnp.asarray(pad_z),
        Z=jnp.asarray(pad_z),
        positions=jnp.asarray(pad_coords),
        R=jnp.asarray(pad_coords),
        # N is the number of atoms
        N=jnp.asarray([n_atoms]),
        # NOTE(review): "mono" holds the padded atomic numbers, presumably as
        # a placeholder - confirm against the consumers of this batch.
        mono=jnp.asarray(pad_z),
        ngrid=jnp.array([len(vdw_surface)]),
        n_grid=jnp.array([len(vdw_surface)]),
        esp=jnp.asarray(np.zeros((1, max_grid_points))),
        vdw_surface=jnp.asarray(pad_vdw_surface),
        # NOTE(review): the mask is True for every slot, padded grid points
        # included - confirm this is intended.
        espMask=jnp.ones((1, max_grid_points), dtype=jnp.bool_),
    )
    return data_batch, smiles_image
def do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm):
    """
    Evaluate one model's output and package it together with ASE structures.

    ``evaluate_dc`` is called with plotting disabled. The first ``N`` entries
    of ``batch["Z"]`` and ``batch["R"]`` (falling back to the number of
    non-zero entries of ``batch["Z"]`` when ``"N"`` is absent) become an
    ``ase.Atoms`` object. The predicted positions and charges are flattened
    and truncated to ``n_atoms * mono_dc1.shape[-1]`` entries to build a
    second ``ase.Atoms`` object, labelled ``"X"`` where the charge is positive
    and ``"He"`` otherwise.

    :param batch: dict with at least the key ``"Z"`` and ``"R"``.
    :param dipo_dc1: predicted positions, reshaped here to rows of three.
    :param mono_dc1: predicted charges.
    :param batch_size: forwarded to ``evaluate_dc``.
    :param n_dcm: forwarded to ``evaluate_dc``; the truncation below uses the
        last axis length of ``mono_dc1`` instead.
    :return: dict with the keys ``mono``, ``dipo``, ``esp_errors``, ``atoms``,
        ``dcmol``, ``grid``, ``esp``, ``esp_dc_pred``, ``esp_mono_pred`` and
        ``idx_cut``; the entries not computed here are ``None``.
    """
    esp_errors, mono_pred, _, _ = evaluate_dc(
        batch, dipo_dc1, mono_dc1, batch_size, n_dcm, plot=False
    )

    # The fallback is evaluated up front, exactly as a dict.get default is.
    fallback_count = jnp.array([jnp.count_nonzero(batch["Z"])])
    n_atoms = int(batch.get("N", fallback_count)[0])
    n_sites = n_atoms * mono_dc1.shape[-1]

    atoms = ase.Atoms(
        numbers=np.array(batch["Z"][:n_atoms]),
        positions=np.array(batch["R"][:n_atoms]),
    )

    site_positions = np.array(dipo_dc1).reshape(-1, 3)[:n_sites]
    site_charges = np.array(mono_dc1).reshape(-1)[:n_sites]
    site_labels = ["X" if charge > 0 else "He" for charge in site_charges]
    dcmol = ase.Atoms(site_labels, site_positions)

    return {
        "mono": mono_dc1,
        "dipo": dipo_dc1,
        "esp_errors": esp_errors,
        "atoms": atoms,
        "dcmol": dcmol,
        "grid": None,
        "esp": None,
        "esp_dc_pred": None,
        "esp_mono_pred": mono_pred,
        "idx_cut": None,
    }
def normalize_batch(batch):
    """
    Remove a length-one second axis from selected batch entries, in place.

    ``"vdw_surface"`` is squeezed when it has four dimensions, ``"esp"`` and
    ``"espMask"`` when they have three; in each case only if the second axis
    has length one. Missing or ``None`` entries are left alone.

    :param batch: dict of arrays; modified in place.
    :return: the same ``batch`` object.
    """
    squeezable = (("vdw_surface", 4), ("esp", 3), ("espMask", 3))
    for key, expected_rank in squeezable:
        value = batch.get(key)
        if value is None:
            continue
        if value.ndim == expected_rank and value.shape[1] == 1:
            batch[key] = value.squeeze(axis=1)
    return batch
def run_dcm(smiles="C1NCCCC1", n_dcm=1):
    """
    Run the models with up to ``n_dcm`` charges on one molecule.

    All four models and weight sets are created and loaded, the molecule is
    converted to a batch with ``prepare_inputs``, and the models with
    ``n_dcm`` values 1 up to ``min(n_dcm, 4)`` are applied and evaluated in
    order.

    :param smiles: SMILES string of the molecule.
    :param n_dcm: highest model level to run; values above 4 run all four
        models.
    :return: dict with ``atoms``, ``dcmol``, ``smiles_image`` and ``mono_dc1``
        from the first model, plus ``dcmol<k>`` and ``mono_dc<k>`` for every
        further level ``k`` that was run.
    :raises ValueError: if ``n_dcm`` is smaller than 1.
    """
    # Without this guard no model runs and the function fails at the return
    # statement because the result dict was never created.
    if n_dcm < 1:
        raise ValueError(f"n_dcm must be at least 1, got {n_dcm!r}")

    models = create_models()
    weights = load_weights()
    data_batch, smiles_image = prepare_inputs(smiles)
    batch_size = 1
    psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
    batch = normalize_batch(psi4_test_batches[0])

    results = {}
    for level, (model, params) in enumerate(zip(models, weights), start=1):
        if level > n_dcm:
            break
        mono, dipo = apply_model(model, params, batch, batch_size)
        evaluated = do_eval(batch, dipo, mono, batch_size, n_dcm=level)
        if level == 1:
            # The first level also contributes the molecule and its drawing,
            # and uses the unsuffixed "dcmol" key.
            results["atoms"] = evaluated["atoms"]
            results["dcmol"] = evaluated["dcmol"]
            results["smiles_image"] = smiles_image
            results["mono_dc1"] = mono
        else:
            results[f"dcmol{level}"] = evaluated["dcmol"]
            results[f"mono_dc{level}"] = mono
    return results