Enhance DCM-Net functionality in app.py by integrating support for multiple DCM layers, updating the rendering of 3D models, and improving the user interface with sliders and buttons. Extend dcm_app.py to create and load models for DCM-3 and DCM-4, and adjust weight loading to accommodate additional models. Update requirements.txt to include new dependencies for molecule visualization.
d8bf97a | import ase | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import json | |
| from rdkit import Chem | |
| from rdkit.Chem import AllChem | |
| from rdkit.Chem import Draw | |
| from scipy.spatial.distance import cdist | |
| from dcmnet.data import prepare_batches | |
| from dcmnet.modules import MessagePassingModel | |
| from dcmnet.loss import ( | |
| esp_loss_eval, | |
| esp_loss_pots, | |
| esp_mono_loss_pots, | |
| get_predictions, | |
| ) | |
| from dcmnet.plotting import evaluate_dc, create_plots2 | |
| from dcmnet.multimodel import get_atoms_dcmol | |
| from dcmnet.multipoles import plot_3d | |
| from dcmnet.utils import apply_model, clip_colors, reshape_dipole | |
# Seed for the JAX PRNG. The first half of the split is kept as ``data_key``,
# which run_dcm passes to prepare_batches.
RANDOM_NUMBER = 0
data_key = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)[0]

# Model hyperparameters shared by every model built in create_models().
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0
def create_models():
    """Build the DCM-1 to DCM-4 message-passing models.

    All four models use the module-level hyperparameters and differ only in
    ``n_dcm`` (1, 2, 3 and 4 respectively).

    Returns:
        tuple: ``(dcm1, dcm2, dcm3, dcm4)``.
    """
    shared_kwargs = dict(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
    )
    return tuple(
        MessagePassingModel(**shared_kwargs, n_dcm=charges)
        for charges in range(1, 5)
    )
def get_grid_points(coordinates, *, padding=3.0, n_points=15, min_distance=2.5 - 1e-1):
    """Create a uniform grid of probe points around a molecule.

    The grid spans the molecule's axis-aligned bounding box, enlarged by
    ``padding`` on every side, and is sampled with ``n_points`` points per
    axis. Points closer than ``min_distance`` to any atom are discarded.

    Args:
        coordinates: array-like of shape (n_atoms, 3) with atomic positions.
        padding: margin added to each side of the bounding box, in the same
            units as ``coordinates``.
        n_points: number of grid points along each axis.
        min_distance: grid points with a distance to some atom below this
            value are dropped; points exactly at this distance are kept.

    Returns:
        np.ndarray of shape (n_kept, 3) with the retained grid points.
    """
    coordinates = np.asarray(coordinates, dtype=float)
    lower = coordinates.min(axis=0) - padding
    upper = coordinates.max(axis=0) + padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]
    grid = np.stack(np.meshgrid(*axes), axis=0)
    # Transposing moves the x/y/z component axis last, so each row is a point.
    grid = np.reshape(grid.T, [-1, 3])
    # Keep only points that are at least min_distance from every atom.
    keep = np.all(cdist(grid, coordinates) >= min_distance, axis=-1)
    return grid[keep]
def restore_arrays(obj):
    """Recursively turn JSON-decoded lists back into JAX arrays.

    - Dicts are rebuilt with each value restored.
    - Lists are restored element by element and then passed to
      ``jnp.asarray``. The plain list is returned instead when any element
      is a dict, or when the conversion raises (e.g. ragged or non-numeric
      data).
    - Any other value is returned unchanged.
    """
    if not isinstance(obj, (dict, list)):
        return obj
    if isinstance(obj, dict):
        return {name: restore_arrays(item) for name, item in obj.items()}

    elements = [restore_arrays(item) for item in obj]
    holds_mapping = any(isinstance(item, dict) for item in elements)
    if holds_mapping:
        return elements
    try:
        return jnp.asarray(elements)
    except Exception:
        # Best effort: keep the list when it cannot become an array.
        return elements
def load_json_dict(path):
    """Read a UTF-8 JSON file and convert its numeric lists to JAX arrays.

    Args:
        path: path of the JSON file to read.

    Returns:
        The decoded JSON with lists converted by ``restore_arrays``.
    """
    with open(path, encoding="utf-8") as json_file:
        decoded = json.load(json_file)
    return restore_arrays(decoded)
def load_weights():
    """Load the parameter dicts for DCM-1 to DCM-4.

    The files are read from the ``wbs/`` directory, a relative path that is
    resolved against the current working directory.

    Returns:
        tuple: ``(dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights)``,
        each as returned by ``load_json_dict``.
    """
    weight_files = (
        "wbs/best_0.0_params_dict.json",
        "wbs/dcm2-best_1000.0_params_dict.json",
        "wbs/dcm3-best_1000.0_params_dict.json",
        "wbs/dcm4-best_1000.0_params_dict.json",
    )
    return tuple(load_json_dict(weight_file) for weight_file in weight_files)
def prepare_inputs(smiles, *, max_n_atoms=60, max_grid_points=3143):
    """Build a padded, single-molecule input batch from a SMILES string.

    Steps:
    1. Parse the SMILES with RDKit and add hydrogens.
    2. Embed a 3D conformer.
    3. Generate probe points around the molecule with ``get_grid_points``.
    4. Zero-pad the atom arrays to ``max_n_atoms`` rows and the grid to
       ``max_grid_points`` rows.

    Args:
        smiles: SMILES string of the molecule.
        max_n_atoms: number of atom rows after padding, hydrogens included.
        max_grid_points: number of grid-point rows after padding.

    Returns:
        tuple: ``(data_batch, smiles_image)``.
        - ``data_batch`` is a dict of jnp arrays, each with a leading batch
          axis of size 1.
        - ``smiles_image`` is the ``Draw.MolToImage`` rendering of the
          hydrogen-free molecule, with atom indices shown as atom notes.

    Raises:
        ValueError: if the SMILES cannot be parsed, 3D embedding fails, or
            the molecule or its grid exceeds the padding sizes.
    """
    smiles_mol = Chem.MolFromSmiles(smiles)
    if smiles_mol is None:
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")
    rdkit_mol = Chem.AddHs(smiles_mol)
    # EmbedMolecule returns the new conformer id, or -1 when embedding fails;
    # without this check GetConformer(0) fails later with a vaguer error.
    if AllChem.EmbedMolecule(rdkit_mol) < 0:
        raise ValueError(f"Could not generate 3D coordinates for SMILES: {smiles!r}")
    coordinates = rdkit_mol.GetConformer(0).GetPositions()
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in rdkit_mol.GetAtoms()])

    n_real_atoms = len(atomic_numbers)
    if n_real_atoms > max_n_atoms:
        raise ValueError(
            f"Molecule has {n_real_atoms} atoms (with hydrogens); "
            f"at most {max_n_atoms} are supported."
        )
    surface = get_grid_points(coordinates)
    if len(surface) > max_grid_points:
        raise ValueError(
            f"Molecule produced {len(surface)} grid points; "
            f"at most {max_grid_points} are supported."
        )

    # Label each heavy atom with its index in the 2D depiction.
    for atom in smiles_mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    smiles_image = Draw.MolToImage(smiles_mol)

    # Zero-pad atom data to max_n_atoms rows and add a batch axis of size 1.
    pad_z = np.pad(atomic_numbers, (0, max_n_atoms - n_real_atoms))[None]
    pad_coords = np.pad(coordinates, ((0, max_n_atoms - n_real_atoms), (0, 0)))[None]
    # Padded grid rows are filled with 10000 (the "after" constant);
    # presumably this places them far from the molecule - confirm.
    pad_vdw_surface = np.pad(
        surface,
        ((0, max_grid_points - len(surface)), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )[None]
    n_atoms = np.sum(pad_z != 0)

    # Both key conventions (atomic_numbers/Z, positions/R, ngrid/n_grid) are
    # filled. "mono" and "esp" hold placeholder values here (atomic numbers
    # and zeros); presumably no reference data exists at inference time.
    data_batch = dict(
        atomic_numbers=jnp.asarray(pad_z),
        Z=jnp.asarray(pad_z),
        positions=jnp.asarray(pad_coords),
        R=jnp.asarray(pad_coords),
        # N is the number of atoms
        N=jnp.asarray([n_atoms]),
        mono=jnp.asarray(pad_z),
        ngrid=jnp.array([len(surface)]),
        n_grid=jnp.array([len(surface)]),
        esp=jnp.asarray([np.zeros(max_grid_points)]),
        vdw_surface=jnp.asarray(pad_vdw_surface),
        espMask=jnp.asarray([np.ones(max_grid_points)], dtype=jnp.bool_),
    )
    return data_batch, smiles_image
def do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm):
    """Score one DCM model's prediction and package it for display.

    Args:
        batch: normalized input batch. Reads "Z", "R" and, when present, "N".
        dipo_dc1: predicted charge positions, as returned by ``apply_model``.
        mono_dc1: predicted charge values, as returned by ``apply_model``.
        batch_size: forwarded to ``evaluate_dc``.
        n_dcm: forwarded to ``evaluate_dc``.
            NOTE(review): the slicing below uses ``mono_dc1.shape[-1]``
            instead of this argument; presumably the two always agree.

    Returns:
        dict with these keys:
        - "mono", "dipo": the raw predictions.
        - "esp_errors": from ``evaluate_dc``.
        - "atoms": ase.Atoms of the real atoms.
        - "dcmol": ase.Atoms of the distributed charges.
        - "esp_mono_pred": the monopole prediction from ``evaluate_dc``.
        - "grid", "esp", "esp_dc_pred", "idx_cut": set to None.
    """
    esp_errors, mono_pred, _, _ = evaluate_dc(
        batch, dipo_dc1, mono_dc1, batch_size, n_dcm, plot=False
    )

    # Real (unpadded) atom count: use the stored count if available,
    # otherwise count the non-zero atomic numbers.
    if "N" in batch:
        atom_count = int(batch["N"][0])
    else:
        atom_count = int(jnp.count_nonzero(batch["Z"]))
    charges_per_atom = mono_dc1.shape[-1]
    n_charges = atom_count * charges_per_atom

    atoms = ase.Atoms(
        numbers=np.array(batch["Z"][:atom_count]),
        positions=np.array(batch["R"][:atom_count]),
    )
    charge_positions = np.array(dipo_dc1).reshape(-1, 3)[:n_charges]
    charge_values = np.array(mono_dc1).reshape(-1)[:n_charges]
    # Positive charges are labelled "X", all others "He"; presumably so the
    # two signs render differently.
    symbols = ["X" if value > 0 else "He" for value in charge_values]
    dcmol = ase.Atoms(symbols, charge_positions)

    return {
        "mono": mono_dc1,
        "dipo": dipo_dc1,
        "esp_errors": esp_errors,
        "atoms": atoms,
        "dcmol": dcmol,
        "grid": None,
        "esp": None,
        "esp_dc_pred": None,
        "esp_mono_pred": mono_pred,
        "idx_cut": None,
    }
def normalize_batch(batch):
    """Drop a singleton axis 1 from the grid-shaped batch entries, in place.

    - "vdw_surface" is squeezed when it is 4-D with ``shape[1] == 1``.
    - "esp" and "espMask" are squeezed when they are 3-D with
      ``shape[1] == 1``.
    - Missing keys and other shapes are left untouched.

    Args:
        batch: dict of arrays; it is modified in place.

    Returns:
        The same ``batch`` object.
    """
    for key, expected_rank in (("vdw_surface", 4), ("esp", 3), ("espMask", 3)):
        array = batch.get(key)
        if array is None:
            continue
        if array.ndim == expected_rank and array.shape[1] == 1:
            batch[key] = array.squeeze(axis=1)
    return batch
def run_dcm(smiles="C1NCCCC1", n_dcm=1):
    """Run models DCM-1 up to DCM-``n_dcm`` on a molecule given as SMILES.

    Args:
        smiles: SMILES string of the molecule.
        n_dcm: how many models to run, starting from DCM-1. Values above the
            number of models returned by ``create_models`` (4) run all of
            them.

    Returns:
        dict with:
        - "atoms", "dcmol", "smiles_image" and "mono_dc1" for DCM-1.
        - For each further model k that was run: "dcmol{k}" and "mono_dc{k}".

    Raises:
        ValueError: if ``n_dcm`` is smaller than 1.
    """
    # Fail fast: with n_dcm < 1 no model would run and there would be no
    # result dict to return.
    if n_dcm < 1:
        raise ValueError(f"n_dcm must be at least 1, got {n_dcm!r}")

    models = create_models()
    weights = load_weights()
    data_batch, smiles_image = prepare_inputs(smiles)
    batch_size = 1
    psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
    batch = normalize_batch(psi4_test_batches[0])

    results = {}
    for k, (model, params) in enumerate(zip(models, weights), start=1):
        if k > n_dcm:
            break
        mono, dipo = apply_model(model, params, batch, batch_size)
        evaluated = do_eval(batch, dipo, mono, batch_size, n_dcm=k)
        if k == 1:
            # DCM-1 also supplies the molecule itself and the 2D depiction.
            results["atoms"] = evaluated["atoms"]
            results["dcmol"] = evaluated["dcmol"]
            results["smiles_image"] = smiles_image
        else:
            results[f"dcmol{k}"] = evaluated["dcmol"]
        results[f"mono_dc{k}"] = mono
    return results