| import ase |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import pandas as pd |
| from rdkit import Chem |
| from rdkit.Chem import AllChem |
| from rdkit.Chem import Draw |
| from scipy.spatial.distance import cdist |
|
|
| from dcmnet.data import prepare_batches |
| from dcmnet.modules import MessagePassingModel |
| from dcmnet.loss import ( |
| esp_loss_eval, |
| esp_loss_pots, |
| esp_mono_loss_pots, |
| get_predictions, |
| ) |
| from dcmnet.plotting import evaluate_dc, create_plots2 |
| from dcmnet.multimodel import get_atoms_dcmol |
| from dcmnet.multipoles import plot_3d |
| from dcmnet.utils import apply_model, clip_colors, reshape_dipole |
# Seed for the JAX pseudo-random number generator used by this script.
RANDOM_NUMBER = 0
# Split the seeded key in two and keep the first sub-key; ``run_dcm`` hands it
# to ``prepare_batches``. The second sub-key is intentionally discarded.
data_key, _ = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), num=2)

# Hyperparameters forwarded unchanged to ``MessagePassingModel`` in
# ``create_models``.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0
|
|
|
|
def create_models():
    """Instantiate the two message-passing models used by this script.

    Both models are built from the module-level hyperparameters and differ
    only in the ``n_dcm`` argument (1 for the first model, 2 for the second).

    :return: tuple ``(dcm1, dcm2)`` of ``MessagePassingModel`` instances.
    """
    shared_hyperparameters = dict(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
    )
    dcm1, dcm2 = (
        MessagePassingModel(**shared_hyperparameters, n_dcm=n_dcm)
        for n_dcm in (1, 2)
    )
    return dcm1, dcm2
|
|
|
|
def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """Create a uniform grid of points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates``, enlarged
    by ``padding`` on every side, with ``n_points`` samples along each axis.
    Grid points that lie closer than ``min_distance`` to any atom are removed.

    :param coordinates: array-like of shape ``(n_atoms, 3)`` holding the atom
        positions.
    :param padding: margin subtracted from the minimum and added to the
        maximum coordinate along each axis (same units as ``coordinates``).
    :param n_points: number of grid samples along each axis.
    :param min_distance: grid points whose distance to any atom is below this
        value are discarded.
    :return: ``numpy.ndarray`` of shape ``(n_kept, 3)`` with the grid points
        that are at least ``min_distance`` away from every atom.
    :raises ValueError: if ``coordinates`` is empty or not of shape
        ``(n_atoms, 3)``.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    if coordinates.ndim != 2 or coordinates.shape[1] != 3 or len(coordinates) == 0:
        raise ValueError(
            "coordinates must be a non-empty array of shape (n_atoms, 3), "
            f"got shape {coordinates.shape}"
        )

    lower = coordinates.min(axis=0) - padding
    upper = coordinates.max(axis=0) + padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]

    # Keep the meshgrid -> stack -> transpose -> reshape sequence as is: it
    # fixes the row order of the returned points.
    grid_points = np.stack(np.meshgrid(*axes), axis=0)
    grid_points = np.reshape(grid_points.T, [-1, 3])

    # Keep only the points that are far enough from *every* atom.
    far_enough = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[far_enough]
|
|
|
|
def load_weights():
    """Load the pickled parameters for both models from the ``wbs`` directory.

    The paths are relative, so they are resolved against the current working
    directory.

    :return: tuple ``(dcm1_weights, dcm2_weights)`` as returned by
        ``pandas.read_pickle``.
    """
    # NOTE: unpickling executes arbitrary code; only load trusted weight files.
    weight_files = (
        "wbs/best_0.0_params.pkl",
        "wbs/dcm2-best_1000.0_params.pkl",
    )
    dcm1_weights, dcm2_weights = map(pd.read_pickle, weight_files)
    return dcm1_weights, dcm2_weights
|
|
|
|
def prepare_inputs(smiles, max_n_atoms=60, max_grid_points=3143):
    """Build a padded single-molecule batch and a 2D depiction from a SMILES.

    The SMILES is parsed with RDKit, hydrogens are added, one 3D conformer is
    embedded and a grid of points around it is generated with
    ``get_grid_points``. Per-atom arrays are zero-padded to ``max_n_atoms``
    and the grid is padded to ``max_grid_points`` rows; every array gets a
    leading batch axis of length 1.

    :param smiles: SMILES string of the molecule.
    :param max_n_atoms: length the per-atom arrays are zero-padded to.
    :param max_grid_points: number of rows the grid array is padded to.
    :return: tuple ``(data_batch, smiles_image)``. ``data_batch`` is a dict of
        JAX arrays with the keys ``atomic_numbers``/``Z``, ``positions``/``R``,
        ``N``, ``mono``, ``ngrid``/``n_grid``, ``esp``, ``vdw_surface`` and
        ``espMask``. ``smiles_image`` is the image returned by
        ``Draw.MolToImage`` for the hydrogen-free molecule, with each atom
        annotated by its index.
    :raises ValueError: if the SMILES cannot be parsed, no conformer can be
        embedded, an atom has atomic number 0, or the molecule/grid does not
        fit into the padded arrays.
    """
    smiles_mol = Chem.MolFromSmiles(smiles)
    if smiles_mol is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    rdkit_mol = Chem.AddHs(smiles_mol)
    # EmbedMolecule returns the new conformer id, or -1 when embedding failed.
    if AllChem.EmbedMolecule(rdkit_mol) < 0:
        raise ValueError(f"Could not embed a 3D conformer for SMILES: {smiles!r}")
    coordinates = rdkit_mol.GetConformer(0).GetPositions()

    # Ask RDKit for the atomic numbers directly instead of converting element
    # symbols through a lookup table.
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in rdkit_mol.GetAtoms()])
    if np.any(atomic_numbers == 0):
        # Zero is the padding marker below (see ``n_atoms``), so it cannot
        # also denote a real atom.
        raise ValueError(f"SMILES contains atoms without an element: {smiles!r}")

    vdw_surface = get_grid_points(coordinates)

    if len(atomic_numbers) > max_n_atoms:
        raise ValueError(
            f"Molecule has {len(atomic_numbers)} atoms, "
            f"more than max_n_atoms={max_n_atoms}"
        )
    if len(vdw_surface) > max_grid_points:
        raise ValueError(
            f"Grid has {len(vdw_surface)} points, "
            f"more than max_grid_points={max_grid_points}"
        )

    # Annotate every atom of the hydrogen-free molecule with its index so the
    # depiction shows the numbering.
    for atom in smiles_mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    smiles_image = Draw.MolToImage(smiles_mol)

    atom_padding = max_n_atoms - len(atomic_numbers)
    grid_padding = max_grid_points - len(vdw_surface)

    # ``[None]`` adds the leading batch axis (batch of one molecule).
    pad_z = np.pad(atomic_numbers, (0, atom_padding))[None]
    pad_coords = np.pad(coordinates, ((0, atom_padding), (0, 0)))[None]
    # Rows appended after the real grid points are filled with 10000.
    pad_vdw_surface = np.pad(
        vdw_surface,
        ((0, grid_padding), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )[None]

    # Padded entries have atomic number 0, so counting non-zeros gives the
    # real number of atoms.
    n_atoms = np.sum(pad_z != 0)

    data_batch = dict(
        atomic_numbers=jnp.asarray(pad_z),
        Z=jnp.asarray(pad_z),
        positions=jnp.asarray(pad_coords),
        R=jnp.asarray(pad_coords),
        N=jnp.asarray([n_atoms]),
        # No reference values are available here: ``mono`` reuses the padded
        # atomic numbers and ``esp`` is all zeros.
        mono=jnp.asarray(pad_z),
        ngrid=jnp.array([len(vdw_surface)]),
        n_grid=jnp.array([len(vdw_surface)]),
        esp=jnp.asarray([np.zeros(max_grid_points)]),
        vdw_surface=jnp.asarray(pad_vdw_surface),
        # NOTE(review): the mask is True for every row, including the padded
        # grid rows — confirm downstream code relies on ``n_grid`` instead.
        espMask=jnp.asarray([np.ones(max_grid_points)], dtype=jnp.bool_),
    )

    return data_batch, smiles_image
|
|
def do_eval(batch, dipo_dc1, mono_dc1, batch_size):
    """Evaluate one model's predictions on a batch and bundle the results.

    Calls ``evaluate_dc`` (with plotting disabled) and ``create_plots2`` and
    gathers their outputs, together with the inputs, in a single dict.

    :param batch: batch the predictions were computed for.
    :param dipo_dc1: second output of ``apply_model`` for this batch.
    :param mono_dc1: first output of ``apply_model`` for this batch.
    :param batch_size: batch size forwarded to both helpers.
    :return: dict with the keys ``mono``, ``dipo``, ``esp_errors``, ``atoms``,
        ``dcmol``, ``grid``, ``esp``, ``esp_dc_pred``, ``esp_mono_pred`` and
        ``idx_cut``.
    """
    # NOTE(review): the literal 1 handed to both helpers is presumably the
    # number of distributed charges per atom — confirm, since this function is
    # also called with the n_dcm=2 model's outputs.
    esp_errors, esp_mono_pred, _unused_a, _unused_b = evaluate_dc(
        batch, dipo_dc1, mono_dc1, batch_size, 1, plot=False
    )
    plot_outputs = create_plots2(mono_dc1, dipo_dc1, batch, batch_size, 1)
    atoms, dcmol, grid, esp, esp_dc_pred, idx_cut = plot_outputs

    return dict(
        mono=mono_dc1,
        dipo=dipo_dc1,
        esp_errors=esp_errors,
        atoms=atoms,
        dcmol=dcmol,
        grid=grid,
        esp=esp,
        esp_dc_pred=esp_dc_pred,
        esp_mono_pred=esp_mono_pred,
        idx_cut=idx_cut,
    )
|
|
def run_dcm(smiles="C1NCCCC1"):
    """Run both models on one molecule given as a SMILES string.

    Builds the models, loads their weights, prepares a single-molecule batch,
    applies both models and evaluates each with ``do_eval``.

    :param smiles: SMILES string of the molecule to process.
    :return: dict with ``smiles_image`` (depiction from ``prepare_inputs``),
        ``atoms`` and ``dcmol`` (from the first model's evaluation) and
        ``dcmol2`` (``dcmol`` from the second model's evaluation).
    """
    batch_size = 1
    models = create_models()
    weights = load_weights()
    data_batch, smiles_image = prepare_inputs(smiles)

    # Only one molecule is prepared, so the first batch is the whole data set.
    batch = prepare_batches(data_key, data_batch, batch_size)[0]

    # Apply both models first, then evaluate both, in that order.
    predictions = [
        apply_model(model, params, batch, batch_size)
        for model, params in zip(models, weights)
    ]
    dcm1_results, dcm2_results = [
        do_eval(batch, dipo, mono, batch_size) for mono, dipo in predictions
    ]

    return dict(
        smiles_image=smiles_image,
        atoms=dcm1_results["atoms"],
        dcmol=dcm1_results["dcmol"],
        dcmol2=dcm2_results["dcmol"],
    )
|
|
|
|
if __name__ == "__main__":
    # Script entry point: run both models on an example molecule and print
    # the resulting dict.
    smiles = "C1NCCCC1"
    results = run_dcm(smiles=smiles)
    print(results)