"""Run DCMNet distributed-charge (DCM) models on a molecule given as SMILES.

The pipeline embeds the molecule in 3D with RDKit, builds a padded input
batch (atoms plus a grid of points around the molecule), applies one to four
pretrained ``MessagePassingModel`` instances and collects the resulting
charge models as ASE ``Atoms`` objects.
"""

import json

import ase
import jax
import jax.numpy as jnp
import numpy as np
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from scipy.spatial.distance import cdist

from dcmnet.data import prepare_batches
from dcmnet.modules import MessagePassingModel
from dcmnet.loss import (
    esp_loss_eval,
    esp_loss_pots,
    esp_mono_loss_pots,
    get_predictions,
)
from dcmnet.plotting import evaluate_dc, create_plots2
from dcmnet.multimodel import get_atoms_dcmol
from dcmnet.multipoles import plot_3d
from dcmnet.utils import apply_model, clip_colors, reshape_dipole

RANDOM_NUMBER = 0
# Key passed to ``prepare_batches``; it is split once at import time, so every
# call of ``run_dcm`` uses the same key.
data_key, _ = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)

# Model hyperparameters.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0

# Weight files of the pretrained models: entry ``k - 1`` holds the weights of
# the model built with ``n_dcm=k``.
_WEIGHT_PATHS = (
    "wbs/best_0.0_params_dict.json",
    "wbs/dcm2-best_1000.0_params_dict.json",
    "wbs/dcm3-best_1000.0_params_dict.json",
    "wbs/dcm4-best_1000.0_params_dict.json",
)
_MAX_N_DCM = len(_WEIGHT_PATHS)

# Fixed padding sizes of the input batch.
# NOTE(review): presumably these must match the shapes used during training —
# confirm before changing them.
_MAX_N_ATOMS = 60
_MAX_GRID_POINTS = 3143


def _create_model(n_dcm):
    """Build a ``MessagePassingModel`` with the shared hyperparameters.

    :param n_dcm: number of distributed charges the model predicts per atom.
    :return: an (unparameterised) ``MessagePassingModel``.
    """
    return MessagePassingModel(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
        n_dcm=n_dcm,
    )


def create_models():
    """Build the four models with ``n_dcm`` = 1, 2, 3 and 4.

    :return: tuple ``(dcm1, dcm2, dcm3, dcm4)``.
    """
    return tuple(_create_model(n) for n in range(1, _MAX_N_DCM + 1))


def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """Create a uniform grid of points around a molecule.

    The grid spans the bounding box of ``coordinates``, enlarged by
    ``padding`` on every side. It has ``n_points`` points along each axis.
    Points closer than ``min_distance`` to any atom are removed.

    :param coordinates: ``(n_atoms, 3)`` array of atom positions.
    :param padding: margin added to each side of the bounding box.
    :param n_points: number of grid points along each axis.
    :param min_distance: minimum allowed distance between a kept grid point
        and any atom.
    :return: ``(n_kept, 3)`` array of grid point positions.
    """
    bounds = np.array([np.min(coordinates, axis=0), np.max(coordinates, axis=0)])
    bounds = bounds + np.array([-1, 1])[:, None] * padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(bounds[0], bounds[1])]
    grid_points = np.stack(np.meshgrid(*axes), axis=0)
    # Transposing moves the stacked (x, y, z) axis last, so the reshape yields
    # one (x, y, z) row per grid point.
    grid_points = np.reshape(grid_points.T, [-1, 3])

    # Exclude points that are too close to the molecule.
    far_enough = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[far_enough]


def restore_arrays(obj):
    """Recursively convert lists from a JSON payload into ``jnp`` arrays.

    Dicts are rebuilt key by key. Lists that contain dicts are kept as lists,
    after their elements have been restored. Lists that ``jnp.asarray`` cannot
    convert (ragged, non-numeric, out of range, ...) are also kept as lists.
    All other values are returned unchanged.
    """
    if isinstance(obj, dict):
        return {key: restore_arrays(value) for key, value in obj.items()}
    if isinstance(obj, list):
        restored = [restore_arrays(value) for value in obj]
        if any(isinstance(value, dict) for value in restored):
            return restored
        try:
            return jnp.asarray(restored)
        except Exception:
            # Deliberate best-effort conversion: anything that cannot become
            # an array stays a plain list.
            return restored
    return obj


def load_json_dict(path):
    """Load a JSON file and convert its nested lists to ``jnp`` arrays."""
    with open(path, "r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return restore_arrays(payload)


def load_weights():
    """Load the parameters of all four pretrained models.

    :return: tuple ``(dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights)``.
    """
    return tuple(load_json_dict(path) for path in _WEIGHT_PATHS)


def prepare_inputs(smiles):
    """Build a padded single-molecule input batch from a SMILES string.

    Hydrogens are added and one 3D conformer is embedded with RDKit. A grid of
    points around that conformer is generated with ``get_grid_points``.

    :param smiles: SMILES string of the molecule.
    :return: ``(data_batch, smiles_image)``.
        ``data_batch`` maps input names to ``jnp`` arrays with a leading
        batch axis of size 1, padded to ``_MAX_N_ATOMS`` atoms and
        ``_MAX_GRID_POINTS`` grid points.
        ``smiles_image`` is a 2D depiction of the molecule annotated with
        atom indices.
    :raises ValueError: if the SMILES cannot be parsed or has no atoms, if no
        3D conformer can be embedded, or if the molecule or its grid exceeds
        the padding limits.
    :raises KeyError: if an element symbol is unknown to ASE.
    """
    smiles_mol = Chem.MolFromSmiles(smiles)
    if smiles_mol is None or smiles_mol.GetNumAtoms() == 0:
        raise ValueError(f"RDKit could not parse a molecule from SMILES {smiles!r}")
    rdkit_mol = Chem.AddHs(smiles_mol)
    elements = [atom.GetSymbol() for atom in rdkit_mol.GetAtoms()]
    # EmbedMolecule returns the new conformer id, or -1 on failure.
    if AllChem.EmbedMolecule(rdkit_mol) < 0:
        raise ValueError(f"RDKit could not embed a 3D conformer for {smiles!r}")
    coordinates = rdkit_mol.GetConformer(0).GetPositions()
    vdw_surface = get_grid_points(coordinates)

    # Annotate the 2D depiction with atom indices.
    for atom in smiles_mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    smiles_image = Draw.MolToImage(smiles_mol)

    atomic_numbers = np.array(
        [ase.data.atomic_numbers[symbol.capitalize()] for symbol in elements]
    )

    n_atoms = len(coordinates)
    if n_atoms > _MAX_N_ATOMS:
        raise ValueError(
            f"Molecule has {n_atoms} atoms (including hydrogens); "
            f"at most {_MAX_N_ATOMS} are supported"
        )
    n_grid = len(vdw_surface)
    if n_grid > _MAX_GRID_POINTS:
        raise ValueError(
            f"Grid has {n_grid} points; at most {_MAX_GRID_POINTS} are supported"
        )

    # Zero-pad atoms up to the fixed size and add a leading batch axis.
    pad_z = np.pad(atomic_numbers, (0, _MAX_N_ATOMS - n_atoms))[np.newaxis]
    pad_coords = np.pad(coordinates, ((0, _MAX_N_ATOMS - n_atoms), (0, 0)))[
        np.newaxis
    ]
    # Padding grid rows get coordinate 10000.
    # NOTE(review): presumably this places them far from the molecule — confirm.
    pad_vdw_surface = np.pad(
        vdw_surface,
        ((0, _MAX_GRID_POINTS - n_grid), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )[np.newaxis]

    data_batch = dict(
        atomic_numbers=jnp.asarray(pad_z),
        Z=jnp.asarray(pad_z),
        positions=jnp.asarray(pad_coords),
        R=jnp.asarray(pad_coords),
        # N is the number of (non-padding) atoms.
        N=jnp.asarray([np.sum(pad_z != 0)]),
        mono=jnp.asarray(pad_z),
        ngrid=jnp.array([n_grid]),
        n_grid=jnp.array([n_grid]),
        esp=jnp.asarray([np.zeros(_MAX_GRID_POINTS)]),
        vdw_surface=jnp.asarray(pad_vdw_surface),
        espMask=jnp.asarray([np.ones(_MAX_GRID_POINTS)], dtype=jnp.bool_),
    )
    return data_batch, smiles_image


def do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm):
    """Evaluate one model's charges and package them as ASE objects.

    :param batch: normalised input batch. ``"Z"`` and ``"R"`` are read, and
        ``"N"`` is read if present.
    :param dipo_dc1: predicted charge positions, reshaped to ``(-1, 3)``.
    :param mono_dc1: predicted charge values. Its last axis is taken as the
        number of charges per atom.
    :param batch_size: forwarded to ``evaluate_dc``.
    :param n_dcm: forwarded to ``evaluate_dc``.
    :return: dict with ``"mono"``, ``"dipo"``, ``"esp_errors"``,
        ``"atoms"`` (the molecule), ``"dcmol"`` (one pseudo-atom per charge),
        ``"esp_mono_pred"``, and the keys ``"grid"``, ``"esp"``,
        ``"esp_dc_pred"``, ``"idx_cut"``, which are always ``None``.
    """
    esp_errors, mono_pred, _, _ = evaluate_dc(
        batch,
        dipo_dc1,
        mono_dc1,
        batch_size,
        n_dcm,
        plot=False,
    )
    if "N" in batch:
        n_atoms = int(batch["N"][0])
    else:
        n_atoms = int(jnp.count_nonzero(batch["Z"]))
    # NOTE(review): assumes the last axis of ``mono_dc1`` indexes the charges
    # of each atom.
    charges_per_atom = mono_dc1.shape[-1]
    n_charges = n_atoms * charges_per_atom

    atoms = ase.Atoms(
        numbers=np.array(batch["Z"][:n_atoms]),
        positions=np.array(batch["R"][:n_atoms]),
    )
    dcm_positions = np.array(dipo_dc1).reshape(-1, 3)[:n_charges]
    dcm_charges = np.array(mono_dc1).reshape(-1)[:n_charges]
    # Positive charges are labelled "X", all others "He".
    # NOTE(review): presumably this is only to tell the signs apart when
    # visualising.
    dcmol = ase.Atoms(
        ["X" if charge > 0 else "He" for charge in dcm_charges],
        dcm_positions,
    )
    out_dict = {
        "mono": mono_dc1,
        "dipo": dipo_dc1,
        "esp_errors": esp_errors,
        "atoms": atoms,
        "dcmol": dcmol,
        "grid": None,
        "esp": None,
        "esp_dc_pred": None,
        "esp_mono_pred": mono_pred,
        "idx_cut": None,
    }
    return out_dict


def normalize_batch(batch):
    """Drop a singleton axis 1 from grid-related batch entries, in place.

    ``"vdw_surface"`` is squeezed when it is 4-D; ``"esp"`` and ``"espMask"``
    are squeezed when they are 3-D. In all cases axis 1 must have size 1.

    :return: the same (mutated) ``batch``.
    """
    squeeze_ranks = {"vdw_surface": 4, "esp": 3, "espMask": 3}
    for key, rank in squeeze_ranks.items():
        value = batch.get(key)
        if value is not None and value.ndim == rank and value.shape[1] == 1:
            batch[key] = value.squeeze(axis=1)
    return batch


def run_dcm(smiles="C1NCCCC1", n_dcm=1):
    """Run the pretrained DCM models with 1..``n_dcm`` charges per atom.

    :param smiles: SMILES string of the molecule.
    :param n_dcm: run the models with ``n_dcm`` = 1 up to this value.
        Values above the number of available models (4) are capped.
    :return: dict with
        ``"atoms"`` (the molecule), ``"dcmol"`` (charges of the 1-charge model),
        ``"smiles_image"`` and ``"mono_dc1"``. For every further model ``k``,
        it also has ``"dcmol{k}"`` and ``"mono_dc{k}"``.
    :raises ValueError: if ``n_dcm`` is smaller than 1.
    """
    if n_dcm < 1:
        raise ValueError(f"n_dcm must be at least 1, got {n_dcm!r}")
    n_models = min(int(n_dcm), _MAX_N_DCM)

    # Load only the weights that will be used. Loading them before building
    # the inputs reports a missing file before any RDKit work happens.
    all_weights = [load_json_dict(path) for path in _WEIGHT_PATHS[:n_models]]

    data_batch, smiles_image = prepare_inputs(smiles)
    batch_size = 1
    psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
    batch = normalize_batch(psi4_test_batches[0])

    results = {}
    for k, weights in enumerate(all_weights, start=1):
        model = _create_model(k)
        mono, dipo = apply_model(model, weights, batch, batch_size)
        evaluation = do_eval(batch, dipo, mono, batch_size, n_dcm=k)
        if k == 1:
            results["atoms"] = evaluation["atoms"]
            results["dcmol"] = evaluation["dcmol"]
            results["smiles_image"] = smiles_image
        else:
            results[f"dcmol{k}"] = evaluation["dcmol"]
        results[f"mono_dc{k}"] = mono
    return results