DCMNet / dcm_app.py
EricBoi's picture
Enhance DCM-Net functionality in app.py by integrating support for multiple DCM layers, updating the rendering of 3D models, and improving the user interface with sliders and buttons. Extend dcm_app.py to create and load models for DCM-3 and DCM-4, and adjust weight loading to accommodate additional models. Update requirements.txt to include new dependencies for molecule visualization.
d8bf97a
import ase
import jax
import jax.numpy as jnp
import numpy as np
import json
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit.Chem import Draw
from scipy.spatial.distance import cdist
from dcmnet.data import prepare_batches
from dcmnet.modules import MessagePassingModel
from dcmnet.loss import (
esp_loss_eval,
esp_loss_pots,
esp_mono_loss_pots,
get_predictions,
)
from dcmnet.plotting import evaluate_dc, create_plots2
from dcmnet.multimodel import get_atoms_dcmol
from dcmnet.multipoles import plot_3d
from dcmnet.utils import apply_model, clip_colors, reshape_dipole
# Seed for the PRNG key that is later passed to ``prepare_batches``.
RANDOM_NUMBER = 0
# Keep the first of two keys split from the seeded root key.
data_key = jax.random.split(jax.random.PRNGKey(RANDOM_NUMBER), 2)[0]

# Model hyperparameters shared by every model built in ``create_models``.
features = 16
max_degree = 2
num_iterations = 2
num_basis_functions = 8
cutoff = 4.0
def create_models():
    """Build the four DCM message-passing models.

    Every model uses the module-level hyperparameters (``features``,
    ``max_degree``, ``num_iterations``, ``num_basis_functions``, ``cutoff``)
    and they differ only in ``n_dcm``, which runs from 1 to 4.

    :return: tuple ``(dcm1, dcm2, dcm3, dcm4)`` with ``n_dcm`` = 1, 2, 3, 4.
    """
    shared_kwargs = dict(
        features=features,
        max_degree=max_degree,
        num_iterations=num_iterations,
        num_basis_functions=num_basis_functions,
        cutoff=cutoff,
    )
    return tuple(
        MessagePassingModel(**shared_kwargs, n_dcm=n_charges)
        for n_charges in range(1, 5)
    )
def get_grid_points(coordinates, padding=3.0, n_points=15, min_distance=2.4):
    """Create a uniform grid of points around a molecule.

    The grid spans the axis-aligned bounding box of ``coordinates``, enlarged
    by ``padding`` on every side, with ``n_points`` evenly spaced samples per
    axis (``n_points ** 3`` candidates). Candidates closer than
    ``min_distance`` to any atom are discarded.

    :param coordinates: array-like of shape (n_atoms, 3) with atom positions.
    :param padding: distance added on each side of the bounding box.
    :param n_points: number of samples along each axis.
    :param min_distance: minimum distance a kept point must have from every
        atom (the default matches the previous hard-coded ``2.5 - 1e-1``).
    :return: array of shape (n_kept, 3) with the retained grid points.
    :raises ValueError: if ``coordinates`` is empty or not of shape (n, 3).
    """
    coordinates = np.asarray(coordinates, dtype=float)
    if coordinates.ndim != 2 or coordinates.shape[1] != 3 or len(coordinates) == 0:
        raise ValueError(
            f"coordinates must have shape (n_atoms, 3) with n_atoms > 0, "
            f"got {coordinates.shape}"
        )

    lower = coordinates.min(axis=0) - padding
    upper = coordinates.max(axis=0) + padding
    axes = [np.linspace(lo, hi, n_points) for lo, hi in zip(lower, upper)]

    grid_points = np.stack(np.meshgrid(*axes), axis=0)
    # meshgrid ('xy' indexing) + stack gives shape (3, ny, nx, nz); the full
    # transpose moves the xyz component to the last axis so each row of the
    # reshape is one point.
    grid_points = np.reshape(grid_points.T, (-1, 3))

    # Keep only points that are at least ``min_distance`` from every atom.
    keep = np.all(cdist(grid_points, coordinates) >= min_distance, axis=-1)
    return grid_points[keep]
def restore_arrays(obj):
    """Recursively turn JSON-style lists back into JAX arrays.

    Dicts are rebuilt with restored values. A list is restored element-wise;
    if none of the restored elements is a dict, the list is converted with
    ``jnp.asarray``. If that conversion fails for any reason, the restored
    plain list is returned instead. Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        return {name: restore_arrays(item) for name, item in obj.items()}
    if not isinstance(obj, list):
        return obj

    elements = list(map(restore_arrays, obj))
    holds_dict = any(isinstance(element, dict) for element in elements)
    if holds_dict:
        return elements
    # Best-effort conversion: ragged or non-numeric lists stay as lists.
    try:
        return jnp.asarray(elements)
    except Exception:
        return elements
def load_json_dict(path):
    """Read the JSON file at ``path`` and pass its content through ``restore_arrays``.

    :param path: path of a UTF-8 encoded JSON file.
    :return: the decoded payload with lists restored as JAX arrays where possible.
    """
    with open(path, "r", encoding="utf-8") as json_file:
        raw_payload = json.load(json_file)
    return restore_arrays(raw_payload)
def load_weights():
    """Load the stored parameter dictionaries for the four DCM models.

    Files are read from the relative ``wbs/`` directory, in the order
    DCM-1, DCM-2, DCM-3, DCM-4.

    :return: tuple ``(dcm1_weights, dcm2_weights, dcm3_weights, dcm4_weights)``.
    """
    weight_paths = (
        "wbs/best_0.0_params_dict.json",
        "wbs/dcm2-best_1000.0_params_dict.json",
        "wbs/dcm3-best_1000.0_params_dict.json",
        "wbs/dcm4-best_1000.0_params_dict.json",
    )
    return tuple(load_json_dict(weight_path) for weight_path in weight_paths)
def prepare_inputs(smiles):
    """Turn a SMILES string into a padded single-molecule input batch.

    Steps:
    1. Parse the SMILES and add hydrogens.
    2. Embed a 3D conformer with RDKit.
    3. Build evaluation points around the molecule with ``get_grid_points``.
    4. Pad the atom arrays to 60 rows and the grid to 3143 rows.

    :param smiles: SMILES string of the molecule.
    :return: ``(data_batch, smiles_image)``.
        ``data_batch`` is a dict of JAX arrays whose leading axis (size 1) is
        the batch axis.
        ``smiles_image`` is RDKit's 2D drawing of the parsed SMILES molecule
        (before hydrogens are added), with each atom annotated by its index.
    :raises ValueError: if the SMILES cannot be parsed or has no atoms, if no
        3D conformer can be embedded, or if the molecule (with hydrogens) or
        its grid exceeds the padding sizes.
    """
    max_n_atoms = 60
    max_grid_points = 3143

    smiles_mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None for unparsable input; fail clearly here
    # instead of letting AddHs raise an opaque RDKit error.
    if smiles_mol is None or smiles_mol.GetNumAtoms() == 0:
        raise ValueError(f"Could not build a molecule from SMILES {smiles!r}.")
    rdkit_mol = Chem.AddHs(smiles_mol)
    # EmbedMolecule returns the new conformer id, or -1 when embedding fails
    # (GetConformer(0) would otherwise fail with an unrelated error).
    if AllChem.EmbedMolecule(rdkit_mol) == -1:
        raise ValueError(f"Could not embed a 3D conformer for SMILES {smiles!r}.")
    coordinates = rdkit_mol.GetConformer(0).GetPositions()
    atomic_numbers = np.array([atom.GetAtomicNum() for atom in rdkit_mol.GetAtoms()])
    n_real_atoms = len(atomic_numbers)
    if n_real_atoms > max_n_atoms:
        raise ValueError(
            f"Molecule has {n_real_atoms} atoms including hydrogens; "
            f"at most {max_n_atoms} are supported."
        )

    vdw_surface = get_grid_points(coordinates)
    n_grid = len(vdw_surface)
    if n_grid > max_grid_points:
        raise ValueError(
            f"Grid has {n_grid} points; at most {max_grid_points} are supported."
        )

    # Annotate the atoms of the hydrogen-free SMILES molecule with their index
    # for the 2D depiction.
    for atom in smiles_mol.GetAtoms():
        atom.SetProp("atomNote", str(atom.GetIdx()))
    smiles_image = Draw.MolToImage(smiles_mol)

    # Zero-pad per-atom arrays; [None] adds the leading batch axis of size 1.
    atom_padding = max_n_atoms - n_real_atoms
    pad_z = np.pad(atomic_numbers, (0, atom_padding))[None, :]
    pad_coords = np.pad(coordinates, ((0, atom_padding), (0, 0)))[None, :, :]
    # Padded grid rows are filled with 10000 rather than 0, presumably so they
    # lie far from every atom.
    pad_vdw_surface = np.pad(
        vdw_surface,
        ((0, max_grid_points - n_grid), (0, 0)),
        "constant",
        constant_values=(0, 10000),
    )[None, :, :]
    n_atoms = np.sum(pad_z != 0)

    data_batch = dict(
        atomic_numbers=jnp.asarray(pad_z),
        Z=jnp.asarray(pad_z),
        positions=jnp.asarray(pad_coords),
        R=jnp.asarray(pad_coords),
        # N is the number of atoms
        N=jnp.asarray([n_atoms]),
        # NOTE(review): "mono" is filled with atomic numbers; presumably a
        # placeholder since no reference data exists here — confirm.
        mono=jnp.asarray(pad_z),
        ngrid=jnp.array([n_grid]),
        n_grid=jnp.array([n_grid]),
        # No reference ESP is computed, so the ESP target is all zeros.
        esp=jnp.asarray([np.zeros(max_grid_points)]),
        vdw_surface=jnp.asarray(pad_vdw_surface),
        # NOTE(review): every slot, including padded grid rows, is marked
        # valid; confirm the padded points should not be masked out.
        espMask=jnp.asarray([np.ones(max_grid_points)], dtype=jnp.bool_),
    )
    return data_batch, smiles_image
def do_eval(batch, dipo_dc1, mono_dc1, batch_size, n_dcm):
    """Evaluate one DCM prediction and package it for visualisation.

    Runs ``evaluate_dc`` (without plotting) and builds two ASE objects: the
    real atoms of the molecule, and a pseudo-molecule with one site per
    predicted distributed charge.

    :param batch: batch dict; reads "Z", "R" and, if present, "N".
    :param dipo_dc1: predicted charge positions (reshaped to (-1, 3) here).
    :param mono_dc1: predicted charge values. The size of its last axis is
        used as the number of charges per atom when selecting charge sites.
    :param batch_size: forwarded to ``evaluate_dc``.
    :param n_dcm: forwarded to ``evaluate_dc``.
    :return: dict with keys "mono", "dipo", "esp_errors", "atoms", "dcmol" and
        "esp_mono_pred". The keys "grid", "esp", "esp_dc_pred" and "idx_cut"
        are present but set to None.
    """
    esp_errors, mono_pred, _, _ = evaluate_dc(
        batch,
        dipo_dc1,
        mono_dc1,
        batch_size,
        n_dcm,
        plot=False,
    )

    # Prefer the stored atom count. Count non-zero atomic numbers only when
    # "N" is absent, so that work is not done on every call.
    if "N" in batch:
        n_atoms = int(batch["N"][0])
    else:
        n_atoms = int(jnp.count_nonzero(batch["Z"]))

    atoms = ase.Atoms(
        numbers=np.array(batch["Z"][:n_atoms]),
        positions=np.array(batch["R"][:n_atoms]),
    )

    # The charges-per-atom count is read from the prediction's shape. It uses a
    # separate name so the ``n_dcm`` argument is not silently rebound.
    charges_per_atom = mono_dc1.shape[-1]
    n_sites = n_atoms * charges_per_atom
    dcm_positions = np.array(dipo_dc1).reshape(-1, 3)[:n_sites]
    dcm_charges = np.array(mono_dc1).reshape(-1)[:n_sites]
    # Positive charges become "X" sites; zero or negative charges become "He".
    dcmol = ase.Atoms(
        ["X" if charge > 0 else "He" for charge in dcm_charges],
        dcm_positions,
    )

    outDict = {
        "mono": mono_dc1,
        "dipo": dipo_dc1,
        "esp_errors": esp_errors,
        "atoms": atoms,
        "dcmol": dcmol,
        "grid": None,
        "esp": None,
        "esp_dc_pred": None,
        "esp_mono_pred": mono_pred,
        "idx_cut": None,
    }
    return outDict
def normalize_batch(batch):
    """Remove a singleton axis 1 from the grid-related arrays of a batch.

    The checks per key are:
    - ``vdw_surface`` is squeezed when it is 4-D;
    - ``esp`` and ``espMask`` are squeezed when they are 3-D;
    - in every case only if axis 1 has length 1.

    Missing keys and other shapes are left alone. The batch dict is modified
    in place and also returned.
    """
    squeeze_rules = (
        ("vdw_surface", 4),
        ("esp", 3),
        ("espMask", 3),
    )
    for field, rank in squeeze_rules:
        array = batch.get(field)
        if array is None:
            continue
        if array.ndim == rank and array.shape[1] == 1:
            batch[field] = array.squeeze(axis=1)
    return batch
def run_dcm(smiles="C1NCCCC1", n_dcm=1):
    """Run the DCM-1 ... DCM-n models on a molecule given as SMILES.

    Models 1 to ``n_dcm`` are applied in order to the same prepared batch.
    Only four models exist, so values above 4 behave like 4.

    :param smiles: SMILES string of the molecule.
    :param n_dcm: highest DCM model to evaluate; must be at least 1.
    :return: dict with "atoms", "dcmol", "smiles_image" and "mono_dc1" from
        DCM-1, plus "dcmol{k}" and "mono_dc{k}" for each further model k
        that is evaluated.
    :raises ValueError: if ``n_dcm`` is less than 1. Checking this up front
        avoids loading models and weights only to fail with no results.
    """
    if n_dcm < 1:
        raise ValueError(f"n_dcm must be at least 1, got {n_dcm!r}")

    models = create_models()
    weights = load_weights()
    data_batch, smiles_image = prepare_inputs(smiles)
    batch_size = 1
    psi4_test_batches = prepare_batches(data_key, data_batch, batch_size)
    batch = normalize_batch(psi4_test_batches[0])

    results = {}
    for level, (model, params) in enumerate(zip(models, weights), start=1):
        if level > n_dcm:
            break
        mono, dipo = apply_model(model, params, batch, batch_size)
        evaluation = do_eval(batch, dipo, mono, batch_size, n_dcm=level)
        if level == 1:
            # DCM-1 also provides the molecule itself and the 2D drawing.
            results["atoms"] = evaluation["atoms"]
            results["dcmol"] = evaluation["dcmol"]
            results["smiles_image"] = smiles_image
        else:
            results[f"dcmol{level}"] = evaluation["dcmol"]
        results[f"mono_dc{level}"] = mono
    return results